Tensorflow——sess.run()

sess.run()中的feed_dict

参考文献:

【1】 tensorflow学习笔记(十):sess.run()

我们都知道feed_dict的作用是给使用placeholder创建出来的tensor赋值。其实,他的作用更加广泛:feed使用一个值临时替换一个op的输出结果。你可以提供feed数据作为run()调用的参数。feed只在调用它的方法内有效,方法结束,feed就会消失。

sess.run()

参考文献:

【2】tensorflow学习笔记(十):sess.run()

当我们构建完图后,需要在一个会话中启动图,启动的第一步是创建一个Session对象。

为了取回(Fetch)操作的输出内容,可以在使用Session对象的run()调用执行图时,传入一些tensor,这些tensor会帮助你取回结果。

在python语言中,返回的tensor是numpy ndarray对象。

在执行sess.run()时,tensorflow并不是计算了整个图,只是计算了与想要fetch的值相关的部分。

使用feed_dict字典填充

参考文献:

【3】简单的Tensorflow(3):使用feed_dict字典填充

tensorflow还提供字典填充函数,使输入和输出更为简单:feed_dict = {}。

例如:需要吧8和2填充到字典中,就需要占位符tensorflow.placeholder()而非变量,input1 = tf.placeholder(tf.float32),因为是一个元素不需要矩阵相乘,只要简单的乘法即可:tensorflow.multiply()。

import tensorflow as tf
#设置两个乘数,用占位符表示
input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)
#设置乘积
output = tf.multiply(input1, input2)
with tf.Session() as sess:
  #用feed_dict以字典的方式填充占位
 print(sess.run(output, feed_dict={input1:[8.],input2:[2.]}))

结果是:

[ 16.]

占位符和feed_dict

参考文献:

【4】 Tensorflow学习笔记——占位符和feed_dict(二)

import tensorflow as tf
import numpy as np

list_of_points1_ = [[1, 2], [3, 4], [5, 6], [7, 8]]
list_of_points2_ = [[15, 16], [13, 14], [11, 12], [9, 10]] 

list_of_points1 = np.array([np.array(elem).reshape(1, 2) for elem in list_of_points1_])
list_of_points2 = np.array([np.array(elem).reshape(1, 2) for elem in list_of_points2_])

graph = tf.Graph()

with graph.as_default():
#我们使用tf.placeholder()创建占位符 ,在session.run()过程中再投递数据
    point1 = tf.placeholder(tf.float32, shape=(1, 2))
    point2 = tf.placeholder(tf.float32, shape=(1, 2))

def calculate_eucledian_distance(point1, point2):
    difference = tf.subtract(point1, point2)
    power2 = tf.pow(difference, tf.constant(2.0, shape=(1, 2)))
    add = tf.reduce_sum(power2)
    eucledian_distance = tf.sqrt(add)
    return eucledian_distance

dist = calculate_eucledian_distance(point1, point2)

with tf.Session(graph=graph) as session:
    tf.global_variables_initializer().run()
    for ii in range(len(list_of_points1)):
        point1_ = list_of_points1[ii]
        point2_ = list_of_points2[ii]
        #使用feed_dict将数据投入到[dist]中
        feed_dict = {point1: point1_, point2: point2_}
        distance = session.run([dist], feed_dict=feed_dict)
        print("the distance between {} and {} -> {}".format(point1_, point2_, distance))

输出:

the distance between [[1 2]] and [[15 16]] -> [19.79899]
the distance between [[3 4]] and [[13 14]] -> [14.142136]
the distance between [[5 6]] and [[11 12]] -> [8.485281]
the distance between [[7 8]] and [[ 9 10]] -> [2.828427]

请大家批评指正,谢谢😄~

    原文作者:搬砖的旺财
    原文地址: https://zhuanlan.zhihu.com/p/51165622
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞