Tensorflow 模型预测inference步骤

前言

在使用tensorflow跑完深度学习等模型后,生成的pb模型,在做预测时,有固定的步骤。记录一下过程。

主要步骤

以tensorflow的deeplab模型为例,在训练完成后,生成了pb格式的模型。在做验证预测的时候,固定格式如下:

MODEL_NAME="./"
PATH_TO_CKPT = MODEL_NAME + 'frozen_inference_graph.pb'  


# In[20]:

detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:      #加载模型
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')


# In[21]:

test_img_base_path="./Sample"
imgs_files=os.path.join(test_img_base_path,"*","*.png")   #测试验证图片的路径
imgs_list=glob.glob(imgs_files)
num_imgs=len(imgs_list)
print("Images num:"+str(num_imgs))
inference_path="./inference_result"
new_files=[]
if not os.path.exists(inference_path):
    os.mkdir(inference_path)
total_time = 0


# In[22]:

with detection_graph.as_default():  
  with tf.Session(graph=detection_graph) as sess: 
    image_tensor = detection_graph.get_tensor_by_name('ImageTensor:0')     #获取输入图片的tensor
    prediction = detection_graph.get_tensor_by_name('SemanticPredictions:0')  #输出prediction的tensor
    start_time=datetime.datetime.now()
    print("STARTING ...")
    for image_path in imgs_list:
        image_np = Image.open(image_path)
        image_np_expanded = np.expand_dims(image_np, axis=0)    
          #图片的处理,每次预测一张图,batch_size=1,当然也可以一次预测多张图片
        # Definite input and output Tensors for detection_graph 
        out_name=os.path.join(inference_path,image_path.split("/")[-2],image_path.split("/")[-1])
        time1 = time.time()
        prediction_out= sess.run(  
          prediction,feed_dict={image_tensor: image_np_expanded})   #运行一次模型
        time2 = time.time()
        total_time += float(time2-time1)
        result=Image.fromarray(np.array(prediction_out[0]*200).astype(np.uint8))
         #由于本例是图像分割模型,输出也是图片,将prediction直接转为array格式保存即可
        if  not os.path.exists(os.path.join(inference_path,out_name.split("/")[-2])):
            os.mkdir(os.path.join(inference_path,out_name.split("/")[-2]))
        result.save(out_name)
    end_time=datetime.datetime.now()
    
    print("START TIME :"+str(start_time))
    print("END TIME :"+str(end_time))
    print("THE TOTAL TIME COST IS:"+str(total_time))
    print("THE average TIME COST IS:"+str(float(total_time)/float(num_imgs)))

对于预测过程中,前面In[20],In[21]部分的作用构建图、加载模型,该部分对于tensorflow生成的pb模型文件来讲基本上都是固定的,因此套用即可。

总结

对于tensorflow的验证部分,比较难以构建,对于不同的任务验证过程有差异,但是基本的步骤相差不大。

完整代码附上。

    原文作者:陈远
    原文地址: https://zhuanlan.zhihu.com/p/50506840
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞