python – 将tensorflow检查点保存到.pb protobuf文件

我已经在tensorflow上训练了一个
pix2pix模型,模型已经以检查点的形式保存,并带有以下文件:

model-15000.meta,model-15000.index,model-15000.data-00000-of-00001,graph.pbtxt,checkpoint.

现在,我想将其转换为protobuf文件(.pb)以进行部署.我遇到了freeze_graph.py脚本这样做,但是我遇到了其中一个参数的问题,它是output_node_names.

我尝试了几个图层名称,但是我收到以下错误:

AssertionError: generator/decoder_2/batchnorm/scale/gradients is not in graph

不确定如何找到output_node_names

最佳答案 尝试以下代码将meta转换为pb文件:

import tensorflow as tf
#Step 1 
#import the model metagraph
saver = tf.train.import_meta_graph('./model.meta', clear_devices=True)
#make that as the default graph
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
#now restore the variables
saver.restore(sess, "./model")

#Step 2
# Find the output name
graph = tf.get_default_graph()
for op in graph.get_operations(): 
  print (op.name)

#Step 3
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util

output_node_names="predictions_mod/Sigmoid"
output_graph_def = graph_util.convert_variables_to_constants(
        sess, # The session
        input_graph_def, # input_graph_def is useful for retrieving the nodes 
        output_node_names.split(",")  )    

#Step 4
#output folder
output_fld ='./'
#output pb file name
output_model_file = 'model.pb'
from tensorflow.python.framework import graph_io
#write the graph
graph_io.write_graph(output_graph_def, output_fld, output_model_file, as_text=False)

希望这个有用!!!

点赞