Comparing TensorFlow and PyTorch: static vs. dynamic graph construction
A static graph cannot be modified once it has been defined, and defining it requires special, framework-specific syntax. This means ordinary Python control structures such as if, while, and for loops cannot be used when laying out the graph; the framework must supply its own dedicated constructs instead. When building the graph, we also have to account for every possible case up front (the subgraph for every if branch must be present in the graph, even though not all of them will be used on any given run), which can make a static graph very large and consume excessive GPU memory.
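The point above can be illustrated with a toy sketch in plain Python (this is not a real framework API, just a hypothetical `static_cond`/`dynamic_cond` pair): the static version must construct both branch subgraphs at definition time, while the dynamic version only ever builds the branch actually taken.

```python
def static_cond(pred_fn, true_fn, false_fn):
    # Static style: BOTH subgraphs are built at definition time,
    # regardless of which branch will run later.
    graph = {"true_branch": true_fn(), "false_branch": false_fn()}

    def run(z):
        branch = "true_branch" if pred_fn(z) else "false_branch"
        return graph[branch], set(graph)  # result + nodes held in memory

    return run


def dynamic_cond(pred_fn, true_fn, false_fn, z):
    # Dynamic style: only the chosen branch is ever constructed.
    return true_fn() if pred_fn(z) else false_fn()


run = static_cond(lambda z: z < 0,
                  lambda: "matmul(x, w1)",
                  lambda: "matmul(x, w2)")
result, nodes = run(10)
print(result)         # matmul(x, w2)
print(sorted(nodes))  # ['false_branch', 'true_branch'] -- both live in the graph
print(dynamic_cond(lambda z: z < 0,
                   lambda: "matmul(x, w1)",
                   lambda: "matmul(x, w2)", 10))  # matmul(x, w2)
```

Even though only one branch's result is ever returned, the static graph keeps both branch subgraphs alive, which is exactly why static graphs grow large.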
A dynamic graph has no such concern: it is compatible with all of Python's control-flow syntax, and the graph actually created depends on the branch taken on each run. Below we compare how TensorFlow and PyTorch build a graph containing an if branch:
# Author : Hellcat
# Time   : 2018/2/9

def tf_graph_if():
    import numpy as np
    import tensorflow as tf
    x = tf.placeholder(tf.float32, shape=(3, 4))
    z = tf.placeholder(tf.float32, shape=None)
    w1 = tf.placeholder(tf.float32, shape=(4, 5))
    w2 = tf.placeholder(tf.float32, shape=(4, 5))

    def f1():
        return tf.matmul(x, w1)

    def f2():
        return tf.matmul(x, w2)

    y = tf.cond(tf.less(z, 0), f1, f2)
    with tf.Session() as sess:
        y_out = sess.run(y, feed_dict={
            x: np.random.randn(3, 4),
            z: 10,
            w1: np.random.randn(4, 5),
            w2: np.random.randn(4, 5)})
    return y_out

def t_graph_if():
    import torch as t
    from torch.autograd import Variable
    x = Variable(t.randn(3, 4))
    w1 = Variable(t.randn(4, 5))
    w2 = Variable(t.randn(4, 5))
    z = 10
    if z < 0:  # mirrors tf.less(z, 0): use w1 when z < 0, otherwise w2
        y = x.mm(w1)
    else:
        y = x.mm(w2)
    return y

if __name__ == "__main__":
    print(tf_graph_if())
    print(t_graph_if())
The computed output is as follows:
[[ 4.0871315 0.90317607 -4.65211582 0.71610922 -2.70281982]
[ 3.67874336 -0.58160967 -3.43737102 1.9781189 -2.18779659]
[ 2.6638422 -0.81783319 -0.30386463 -0.61386991 -3.80232286]]
Variable containing:
-0.2474 0.1269 0.0830 3.4642 0.2255
0.7555 -0.8057 -2.8159 3.7416 0.6230
0.9010 -0.9469 -2.5086 -0.8848 0.2499
[torch.FloatTensor of size 3x5]
Personally, I feel the comparison above is not quite fair; for a closer match we should use TensorFlow variables, rewriting the function as follows:
# Author : Hellcat
# Time   : 2018/2/9

def tf_graph_if():
    import tensorflow as tf
    x = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[3, 4]))
    z = tf.constant(dtype=tf.float32, value=10)
    w1 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))
    w2 = tf.Variable(dtype=tf.float32, initial_value=tf.random_uniform(shape=[4, 5]))

    def f1():
        return tf.matmul(x, w1)

    def f2():
        return tf.matmul(x, w2)

    y = tf.cond(tf.less(z, 0), f1, f2)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        y_out = sess.run(y)
    return y_out
The output is essentially unchanged:
[[ 1.89582038 1.12734962 0.59730953 0.99833554 0.86517167]
[ 1.2659111 0.77320379 0.63649696 0.5804953 0.82271856]
[ 1.92151642 1.64715886 1.19869363 1.31581473 1.5636673 ]]
As we can see, TensorFlow implements the if branch with the function tf.cond(tf.less(z, 0), f1, f2), which is very different from PyTorch's direct use of Python's if; moreover, the dynamic graph needs no feed_dict and can simply be run as-is. Even from this brief comparison, PyTorch's logic is clearly more concise, which makes it very appealing.
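The "define-by-run" idea behind PyTorch's if can be sketched with a toy trace recorder in plain Python (the `trace_if` helper and the recorded op strings are illustrative, not a real API): the graph recorded on each run depends on the branch actually executed.

```python
def trace_if(z):
    # Toy define-by-run trace: ops are recorded as they execute,
    # so only the branch actually taken appears in the "graph".
    tape = []
    tape.append("x = randn(3, 4)")
    if z < 0:
        tape.append("y = x.mm(w1)")
    else:
        tape.append("y = x.mm(w2)")
    return tape


print(trace_if(10))   # ['x = randn(3, 4)', 'y = x.mm(w2)']
print(trace_if(-1))   # ['x = randn(3, 4)', 'y = x.mm(w1)']
```

Two runs with different inputs yield two different traces, whereas a static graph built with tf.cond would contain both matmul subgraphs in every run.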