tensorflow 2.0 vs pytorch

TensorFlow 2.0开源了,相较于TensoforFlow 1,TF2更专注于简单性和易用性,具有热切执行(Eager Execution),直观的API,融合Keras等更新。

Tensorflow 2www.tensorflow.org

随着这些更新,TensorFlow 2.0也变得越来越像Pytorch, 我们先来看一段手写字体识别(MNIST)代码,你能知道是什么深度学习框架写的吗?

# define model
class MyModel(torch.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.conv1 = Conv2D(32, 3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    self.d2 = Dense(10, activation='softmax')

  def call(self, x):
    x = self.conv1(x)
    x = self.flatten(x)
    x = self.d1(x)
    return self.d2(x)
  
# Create an instance of the model
model = MyModel()

# define optimizer 
loss_object = torch.losses.SparseCategoricalCrossentropy()
optimizer = torch.optimizers.Adam()

# define metric
train_loss = torch.metrics.Mean(name='train_loss')
train_accuracy = torch.metrics.SparseCategoricalAccuracy(name='train_accuracy')

# load data loader
mnist = torch.datasets.mnist
train_ds, test_ds = load_data(mnist, batch_size)

# Train and Test
EPOCHS = 5
for epoch in range(EPOCHS):
  for images, labels in train_ds:
      train_step(images, labels)

  for test_images, test_labels in test_ds:
    test_step(test_images, test_labels)

整个代码风格和Pytorh几乎一致,不过,这不是Pytorch,而是TensorFlow 2. 只不过,我加了这样的一行: import tensorflow.keras as torch

完整的import代码如下:

%tensorflow_version 2.x
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
import tensorflow.keras as torch

TF2 的热切执行机制让代码风格变得更 pythonic, 不需要再一步步地placeholder以及build graph。就本人而言,我是个Pytorch拥护者,还没有正式使用TF2。不过对于Pytorch用户而言,转到TF2简直不能更简单。

总而言之,Have fun~

    原文作者:登高居士
    原文地址: https://zhuanlan.zhihu.com/p/80064752
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞