onnx是Facebook打造的AI中间件,但是Tensorflow官方不支持onnx,所以只能用onnx自己提供的方式从tensorflow尝试转换
Tensorflow模型转onnx
Tensorflow转onnx, onnx官方github上有提供转换的方式,地址为https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowExport.ipynb 。按链接中的步骤一步一步就能完成mnist的模型转换,我也成功转换出了mnist.onnx模型。但是在上面步骤中model = onnx.load('mnist.onnx')
之后执行tf_rep = prepare(model)
一直不成功。但是换成网上别人用pytorch转的mnist.onnx执行tf_rep = prepare(model)
又完全是OK的,这个暂时还没找到原因在哪里。
onnx模型转换为Tensorflow模型
上面提到按官网的教程从tensorflow转换生成的onnx模型执行tf_rep = prepare(model)
有问题。所以这里我从网上下载的一个pytorch转换的mnist onnx模型为实验对象,实验用的onnx下载地址:https://download.csdn.net/download/computerme/10448754
onnx模型转换为Tensorflow模型的代码如下:
import onnx
import numpy as np
from onnx_tf.backend import prepare
model = onnx.load('./assets/mnist_model.onnx')
tf_rep = prepare(model)
img = np.load("./assets/image.npz")
output = tf_rep.run(img.reshape([1, 1,28,28]))
print("outpu mat: \n",output)
print("The digit is classified as ", np.argmax(output))
import tensorflow as tf
with tf.Session() as persisted_sess:
print("load graph")
persisted_sess.graph.as_default()
tf.import_graph_def(tf_rep.predict_net.graph.as_graph_def(), name='')
inp = persisted_sess.graph.get_tensor_by_name(
tf_rep.predict_net.tensor_dict[tf_rep.predict_net.external_input[0]].name
)
out = persisted_sess.graph.get_tensor_by_name(
tf_rep.predict_net.tensor_dict[tf_rep.predict_net.external_output[0]].name
)
res = persisted_sess.run(out, {inp: img.reshape([1, 1,28,28])})
print(res)
print("The digit is classified as ",np.argmax(res))
tf_rep.export_graph('tf.pb')
转换完成后,需要对转换出的tf.pb
模型进行验证,验证方式如下:
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
name = "tf.pb"
with tf.Session() as persisted_sess:
print("load graph")
with gfile.FastGFile(name, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
inp = persisted_sess.graph.get_tensor_by_name('0:0')
out = persisted_sess.graph.get_tensor_by_name('LogSoftmax:0')
#test = np.random.rand(1, 1, 28, 28).astype(np.float32)
#feed_dict = {inp: test}
img = np.load("./assets/image.npz")
feed_dict = {inp: img.reshape([1, 1,28,28])}
classification = persisted_sess.run(out, feed_dict)
print(out)
print(classification)
Reference:
https://github.com/onnx/onnx-tensorflow/issues/167