前言
在使用tensorflow跑完深度学习等模型后,生成的pb模型,在做预测时,有固定的步骤。记录一下过程。
主要步骤
以tensorflow的deeplab模型为例,在训练完成后,生成了pb格式的模型。在做验证预测的时候,固定格式如下:
MODEL_NAME="./"
PATH_TO_CKPT = MODEL_NAME + 'frozen_inference_graph.pb'
# In[20]:
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: #加载模型
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
# In[21]:
test_img_base_path="./Sample"
imgs_files=os.path.join(test_img_base_path,"*","*.png") #测试验证图片的路径
imgs_list=glob.glob(imgs_files)
num_imgs=len(imgs_list)
print("Images num:"+str(num_imgs))
inference_path="./inference_result"
new_files=[]
if not os.path.exists(inference_path):
os.mkdir(inference_path)
total_time = 0
# In[22]:
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
image_tensor = detection_graph.get_tensor_by_name('ImageTensor:0') #获取输入图片的tensor
prediction = detection_graph.get_tensor_by_name('SemanticPredictions:0') #输出prediction的tensor
start_time=datetime.datetime.now()
print("STARTING ...")
for image_path in imgs_list:
image_np = Image.open(image_path)
image_np_expanded = np.expand_dims(image_np, axis=0)
#图片的处理,每次预测一张图,batch_size=1,当然也可以一次预测多张图片
# Definite input and output Tensors for detection_graph
out_name=os.path.join(inference_path,image_path.split("/")[-2],image_path.split("/")[-1])
time1 = time.time()
prediction_out= sess.run(
prediction,feed_dict={image_tensor: image_np_expanded}) #运行一次模型
time2 = time.time()
total_time += float(time2-time1)
result=Image.fromarray(np.array(prediction_out[0]*200).astype(np.uint8))
#由于本例是图像分割模型,输出也是图片,将prediction直接转为array格式保存即可
if not os.path.exists(os.path.join(inference_path,out_name.split("/")[-2])):
os.mkdir(os.path.join(inference_path,out_name.split("/")[-2]))
result.save(out_name)
end_time=datetime.datetime.now()
print("START TIME :"+str(start_time))
print("END TIME :"+str(end_time))
print("THE TOTAL TIME COST IS:"+str(total_time))
print("THE average TIME COST IS:"+str(float(total_time)/float(num_imgs)))
对于预测过程中,前面In[20],In[21]部分的作用构建图、加载模型,该部分对于tensorflow生成的pb模型文件来讲基本上都是固定的,因此套用即可。
总结
对于tensorflow的验证部分,比较难以构建,对于不同的任务验证过程有差异,但是基本的步骤相差不大。