TensorFlow 2.0开源了,相较于TensoforFlow 1,TF2更专注于简单性和易用性,具有热切执行(Eager Execution),直观的API,融合Keras等更新。
Tensorflow 2 www.tensorflow.org
随着这些更新,TensorFlow 2.0也变得越来越像Pytorch, 我们先来看一段手写字体识别(MNIST)代码,你能知道是什么深度学习框架写的吗?
# define model
class MyModel(torch.Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32, 3, activation='relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
# Create an instance of the model
model = MyModel()
# define optimizer
loss_object = torch.losses.SparseCategoricalCrossentropy()
optimizer = torch.optimizers.Adam()
# define metric
train_loss = torch.metrics.Mean(name='train_loss')
train_accuracy = torch.metrics.SparseCategoricalAccuracy(name='train_accuracy')
# load data loader
mnist = torch.datasets.mnist
train_ds, test_ds = load_data(mnist, batch_size)
# Train and Test
EPOCHS = 5
for epoch in range(EPOCHS):
for images, labels in train_ds:
train_step(images, labels)
for test_images, test_labels in test_ds:
test_step(test_images, test_labels)
整个代码风格和Pytorh几乎一致,不过,这不是Pytorch,而是TensorFlow 2. 只不过,我加了这样的一行: import tensorflow.keras as torch
完整的import代码如下:
%tensorflow_version 2.x
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
import tensorflow.keras as torch
TF2 的热切执行机制让代码风格变得更 pythonic, 不需要再一步步地placeholder以及build graph。就本人而言,我是个Pytorch拥护者,还没有正式使用TF2。不过对于Pytorch用户而言,转到TF2简直不能更简单。
总而言之,Have fun~