博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow学习3---mnist
阅读量:4327 次
发布时间:2019-06-06

本文共 3752 字,大约阅读时间需要 12 分钟。

1 import tensorflow as tf  2 from tensorflow.examples.tutorials.mnist import input_data 3  4 '''数据下载''' 5 mnist=input_data.read_data_sets('Mnist_data',one_hot=True) 6 #one_hot标签 7        8 '''生成层 函数''' 9 def add_layer(input,in_size,out_size,n_layer='layer',activation_function=None):10     layer_name='layer %s' % n_layer11     '''补充知识'''12     #tf.name_scope:Wrapper for Graph.name_scope() using the default graph.13     #scope名字的作用域14     #sprase:A string (not ending with '/') will create a new name scope, in which name is appended to the prefix of all operations created in the context. 15     #If name has been used before, it will be made unique by calling self.unique_name(name).16     with tf.name_scope('weights'):17         Weights=tf.Variable(tf.random_normal([in_size,out_size]),name='w')18         tf.summary.histogram(layer_name+'/wights',Weights)19         #tf.summary.histogram:output summary with histogram直方图20         #tf,random_normal正太分布21     with tf.name_scope('biases'):22         biases=tf.Variable(tf.zeros([1,out_size])+0.1)23         tf.summary.histogram(layer_name+'/biases',biases)24         #tf.summary.histogram:k25     with tf.name_scope('Wx_plus_b'):26         Wx_plus_b=tf.matmul(input,Weights)+biases27     if activation_function==None:28         outputs=Wx_plus_b29     else:30         outputs=activation_function(Wx_plus_b)31     tf.summary.histogram(layer_name+'/output',outputs)32     return outputs33 '''准确率'''34 def compute_accuracy(v_xs,v_ys):35     global prediction36     y_pre=sess.run(prediction,feed_dict={xs:v_xs})#<37     #tf.equal()对比预测值的索引和实际label的索引是否一样,一样返回True,否则返回false38     correct_prediction=tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))39     #correct_prediction-->[ True False  True ...,  True  True  True]40     '''补充知识-tf.argmax'''41     #tf.argmax:Returns the index with the largest value across dimensions of a tensor.42     #tf.argmax()----->43     accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))44     #正确cast为1,错误cast为045     '''补充知识 tf.cast'''46     #tf.cast:   Casts a tensor to a new type.47     ## tensor `a` is [1.8, 2.2], dtype=tf.float48     #tf.cast(a, tf.int32) ==> [1, 2]  # dtype=tf.int3249     result=sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys})50     #print(sess.run(correct_prediction,feed_dict={xs:v_xs,ys:v_ys}))51     #ckc=tf.cast(correct_prediction,tf.float32)52     #print(sess.run(ckc,feed_dict={xs:v_xs,ys:v_ys}))53     return result54 55 56 '''占位符'''57 xs=tf.placeholder(tf.float32,[None,784])58 ys=tf.placeholder(tf.float32,[None,10])59 60 '''添加层'''61 62 prediction=add_layer(xs,784,10,activation_function=tf.nn.softmax)63 #sotmax激活函数,用于分类函数64 65 '''计算'''66 #交叉熵cross_entropy损失函数,参数分别为实际的预测值和实际的label值y,re67 '''补充知识'''68 #reduce_mean()69 # 'x' is [[1., 1. ]]70 #         [2., 2.]]71 #tf.reduce_mean(x) ==> 1.572 #tf.reduce_mean(x, 0) ==> [1.5, 1.5]73 #tf.reduce_mean(x, 1) ==> [1.,  2.]74 cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1]))75 '''补充知识'''76 #reduce_sum77 # 'x' is [[1, 1, 1]]78 #         [1, 1, 1]]79 #tf.reduce_sum(x) ==> 680 #tf.reduce_sum(x, 0) ==> [2, 2, 2]81 #tf.reduce_sum(x, 1) ==> [3, 3]82 #tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]83 #tf.reduce_sum(x, [0, 1]) ==> 684 85 train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)86 87 '''Session_begin'''88 with tf.Session() as sess:89     sess.run(tf.global_variables_initializer())90     for i in range(1000):91         batch_xs,batch_ys=mnist.train.next_batch(100) #逐个batch去取数据92         sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})93         if(i%50==0):94             print(compute_accuracy(mnist.test.images,mnist.test.labels))95

 

转载于:https://www.cnblogs.com/ChenKe-cheng/p/8889229.html

你可能感兴趣的文章
openssl 升级
查看>>
ASP.NET MVC:通过 FileResult 向 浏览器 发送文件
查看>>
CVE-2010-2883Adobe Reader和Acrobat CoolType.dll栈缓冲区溢出漏洞分析
查看>>
使用正确的姿势跨域
查看>>
AccountManager教程
查看>>
Android学习笔记(十一)——从意图返回结果
查看>>
算法导论笔记(四)算法分析常用符号
查看>>
ultraedit激活
查看>>
总结(6)--- python基础知识点小结(细全)
查看>>
亿级曝光品牌视频的幕后设定
查看>>
ARPA
查看>>
JSP开发模式
查看>>
我的Android进阶之旅------&gt;Android嵌入图像InsetDrawable的使用方法
查看>>
Detours信息泄漏漏洞
查看>>
win32使用拖放文件
查看>>
Android 动态显示和隐藏软键盘
查看>>
raid5什么意思?怎样做raid5?raid5 几块硬盘?
查看>>
【转】how can i build fast
查看>>
null?对象?异常?到底应该如何返回错误信息
查看>>
django登录验证码操作
查看>>