sess.run()中的feed_dict
参考文献:
【1】 tensorflow学习笔记(十):sess.run()
我们都知道feed_dict的作用是给使用placeholder创建出来的tensor赋值。其实,他的作用更加广泛:feed使用一个值临时替换一个op的输出结果。你可以提供feed数据作为run()调用的参数。feed只在调用它的方法内有效,方法结束,feed就会消失。
sess.run()
参考文献:
【2】tensorflow学习笔记(十):sess.run()
当我们构建完图后,需要在一个会话中启动图,启动的第一步是创建一个Session对象。
为了取回(Fetch)操作的输出内容,可以在使用Session对象的run()调用执行图时,传入一些tensor,这些tensor会帮助你取回结果。
在python语言中,返回的tensor是numpy ndarray对象。
在执行sess.run()时,tensorflow并不是计算了整个图,只是计算了与想要fetch的值相关的部分。
使用feed_dict字典填充
参考文献:
【3】简单的Tensorflow(3):使用feed_dict字典填充
tensorflow还提供字典填充函数,使输入和输出更为简单:feed_dict = {}。
例如:需要吧8和2填充到字典中,就需要占位符tensorflow.placeholder()而非变量,input1 = tf.placeholder(tf.float32),因为是一个元素不需要矩阵相乘,只要简单的乘法即可:tensorflow.multiply()。
import tensorflow as tf
#设置两个乘数,用占位符表示
input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)
#设置乘积
output = tf.multiply(input1, input2)
with tf.Session() as sess:
#用feed_dict以字典的方式填充占位
print(sess.run(output, feed_dict={input1:[8.],input2:[2.]}))
结果是:
[ 16.]
占位符和feed_dict
参考文献:
【4】 Tensorflow学习笔记——占位符和feed_dict(二)
import tensorflow as tf
import numpy as np
list_of_points1_ = [[1, 2], [3, 4], [5, 6], [7, 8]]
list_of_points2_ = [[15, 16], [13, 14], [11, 12], [9, 10]]
list_of_points1 = np.array([np.array(elem).reshape(1, 2) for elem in list_of_points1_])
list_of_points2 = np.array([np.array(elem).reshape(1, 2) for elem in list_of_points2_])
graph = tf.Graph()
with graph.as_default():
#我们使用tf.placeholder()创建占位符 ,在session.run()过程中再投递数据
point1 = tf.placeholder(tf.float32, shape=(1, 2))
point2 = tf.placeholder(tf.float32, shape=(1, 2))
def calculate_eucledian_distance(point1, point2):
difference = tf.subtract(point1, point2)
power2 = tf.pow(difference, tf.constant(2.0, shape=(1, 2)))
add = tf.reduce_sum(power2)
eucledian_distance = tf.sqrt(add)
return eucledian_distance
dist = calculate_eucledian_distance(point1, point2)
with tf.Session(graph=graph) as session:
tf.global_variables_initializer().run()
for ii in range(len(list_of_points1)):
point1_ = list_of_points1[ii]
point2_ = list_of_points2[ii]
#使用feed_dict将数据投入到[dist]中
feed_dict = {point1: point1_, point2: point2_}
distance = session.run([dist], feed_dict=feed_dict)
print("the distance between {} and {} -> {}".format(point1_, point2_, distance))
输出:
the distance between [[1 2]] and [[15 16]] -> [19.79899]
the distance between [[3 4]] and [[13 14]] -> [14.142136]
the distance between [[5 6]] and [[11 12]] -> [8.485281]
the distance between [[7 8]] and [[ 9 10]] -> [2.828427]
请大家批评指正,谢谢😄~