TensorFlow教程(一)MNIST手写体数字识别

本文改编自TensorFLow官方教程中文版,力求更加简洁、清晰。

一、介绍

TensorFlow是当前最流行的机器学习框架,有了它,开发人工智能程序就像Java编程一样简单。今天,就让我们从手写体数字识别入手,看看如何用机器学习的方法解决这个问题。

二、编程环境

Python2.7+TensorFlow0.5.0下测试通过,Python3.5下未测试。请参考《TensorFLow下载与安装》配置环境。

三、思路

没有接触过图像处理的人可能会很纳闷,从一张图片识别出里面的内容似乎是件相当神奇的事情。其实,当你把图片当成一枚枚像素来看的话,就没那么神秘了。下图为手写体数字1的图片,它在计算机中的存储其实是一个二维矩阵,每个元素都是0~1之间的数字,0代表白色,1代表黑色,小数代表某种程度的灰色。

数字1的像素表示

现在,对于MNIST数据集中的图片来说,我们只要把它当成长度为784的向量就可以了(忽略它的二维结构,28×28=784)。我们的任务就是让这个向量经过一个函数后输出一个类别,呐,就是下边这个函数,称为Softmax分类器。

线性分类器示意图

这个式子里的图片向量的长度只有3,用x表示。乘上一个系数矩阵W,再加上一个列向量b,然后输入softmax函数,输出就是分类结果y。W是一个权重矩阵,W的每一行与整个图片像素相乘的结果是一个分数score,分数越高表示图片越接近该行代表的类别。因此,W x + b 的结果其实是一个列向量,每一行代表图片属于该类的评分。熟悉图像分类的同学应该了解,通常分类的结果并非评分,而是概率,表示有多大的概率属于此类别。因此,Softmax函数的作用就是把评分转换成概率,并使总的概率为1。

有了这个模型,如何训练它呢?

对于机器学习算法来说,训练就是不断调整模型参数使误差达到最小的过程。这里的模型参数就是W和b。接下来我们需要定义误差。误差当然是把预测的结果y和正确结果相比较得到的,但是由于正确结果是one_hot向量(即只有一个元素是1,其它元素都是0),而预测结果是个概率向量,用什么方法比较其实是个需要深入考虑的事情。事实上,我们使用的是交叉熵损失(cross-entropy loss),为什么用这个,其实我现在也不太清楚,所以姑且先用着吧,以后见得多了自然就明白了。

好了,到这里思路大体上就讲完了,还有不清楚的地方让我们看看代码就能理解了。

四、TensorFlow实现

说实话,这个代码比想象中还要简练,只有33行,所以我把它直接贴出来。

# coding=utf-8

import tensorflow as tf
import input_data

# 下载MNIST数据集到'MNIST_data'文件夹并解压
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 设置权重weights和偏置biases作为优化变量,初始值设为0
weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([10]))

# 构建模型
x = tf.placeholder("float", [None, 784])
y = tf.nn.softmax(tf.matmul(x, weights) + biases)                                   # 模型的预测值
y_real = tf.placeholder("float", [None, 10])                                        # 真实值

cross_entropy = -tf.reduce_sum(y_real * tf.log(y))                                  # 预测值与真实值的交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)        # 使用梯度下降优化器最小化交叉熵

# 开始训练
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)                                # 每次随机选取100个数据进行训练,即所谓的“随机梯度下降(Stochastic Gradient Descent,SGD)”
    sess.run(train_step, feed_dict={x: batch_xs, y_real:batch_ys})                  # 正式执行train_step,用feed_dict的数据取代placeholder

    if i % 100 == 0:
        # 每训练100次后评估模型
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.arg_max(y_real, 1))       # 比较预测值和真实值是否一致
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))             # 统计预测正确的个数,取均值得到准确率
        print sess.run(accuracy, feed_dict={x: mnist.test.images, y_real: mnist.test.labels})

这里用到了官方给的一个代码文件input_data,我已经放到工程里了。导入input_data,就可以用它来读取MNIST数据集,非常方便。

整体来说,使用TensorFLow编程主要分为两个阶段,第一个阶段是构建模型,把网络模型用代码搭建起来。TensorFlow的本质是数据流图,因此这一阶段其实是在规定数据的流动方向。第二个阶段是开始训练,把数据输入到模型中,并通过梯度下降等方法优化变量的值。

首先,我们需要把权重weights和偏置biases设置成优化变量,只有优化变量才可以在后面被Optimizer优化。并且需要为它们赋初值,这里将weights设为784×10的zero矩阵,把biases设为1×10的zero矩阵。

然后构建模型。模型的输入一般设置为placeholder,译为占位符。在训练的过程中只有placeholder可以允许数据输入。第一维的长度为None表示允许输入任意长度,也就是说输入可以是任意张图像。

使用tf.log计算y中每个元素的对数,并逐个与y_real相乘,再求和并取反,就得到了交叉熵。使用梯度下降优化器最小化交叉熵作为训练步骤train_step

接下来开始训练。首先要调用tf.initialize_all_variables()方法初始化所有变量。再创建一个tf.Session对象来控制整个训练流程。循环训练1000次,每次从训练集中随机取100个数据进行训练。

在训练的过程中,每隔100次对模型进行一次评估。评估使用测试集数据,统计正确预测的个数的百分比并输出。结果如下:

$ /usr/bin/python2.7 /home/wjg/projects/MNISTRecognition/main.py
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.4075
0.894
0.8989
0.9012
0.904
0.9105
0.9086
0.9137
0.9105
0.9174
Process finished with exit code 0

可见预测准确率逐渐上升,最后达到91%。

五、总结

这是我第一次使用TensorFlow,它给我的感觉是非常方便,很贴合程序员的开发习惯。相比之下,之前用Caffe的时候就总是摸不着头脑。当然也可能是因为TensorFlow的官方文档更友好的缘故。

本文在很多地方都语焉不详,因为作者水平有限,有关深奥的数学原理都一带而过。所以如果想要深入了解,还是推荐大家看官方教程。文末的参考资料一栏列出了一些有帮助的文章和视频。

最后,可以从我的GitHub上下载完整代码:https://github.com/jingedawang/MNISTRecognition

另外,熟悉多维矩阵操作(NumPy中的切片和广播)可以更好的地理解代码,建议阅读参考资料最后一条:P

六、参考资料

MNIST机器学习入门 TensorFlow中文社区
莫烦 Tensorflow 16 Classification 分类学习 莫烦
Classification 分类学习 莫烦
Softmax 函数的特点和作用是什么? 知乎
CS231n课程笔记翻译:线性分类笔记(下) 杜客译
CS231n课程笔记翻译:Python Numpy教程 杜客译

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,711评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 83,932评论 2 376
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,770评论 0 330
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,799评论 1 271
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,697评论 5 359
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,069评论 1 276
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,535评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,200评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,353评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,290评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,331评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,020评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,610评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,694评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,927评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,330评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,904评论 2 341

推荐阅读更多精彩内容