TensorFlow Study Notes (14): Saving and Restoring an RNN Model with SignatureDef

Environment:
Python: 3.5.2
TensorFlow: 1.11.0
Ubuntu: 16.04

Saving the model (code on GitHub)

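  import tensorflow as tf

  # xs, ys, keep_prob, prediction, train_step, init, mnist and compute_accuracy
  # come from the model-building part of the full script (not shown here).
  saved_model_dir = './model'
  # SavedModelBuilder will not overwrite an existing export directory:
  # delete ./model before re-running this script.
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)

  # Signature inputs: input_x, input_y, keep_prob
  inputs = {'input_x': tf.saved_model.utils.build_tensor_info(xs),
            'input_y': tf.saved_model.utils.build_tensor_info(ys),
            'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)}

  # prediction is the output op; after restoring, it is used to make predictions
  outputs = {'prediction': tf.saved_model.utils.build_tensor_info(prediction)}

  # 'test_sig_name' is the signature's method name, a free-form string here
  # (TensorFlow Serving expects a standard one such as
  # tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
  signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')

  with tf.Session() as sess:
      sess.run(init)

      for i in range(1000):
          batch_xs, batch_ys = mnist.train.next_batch(100)
          sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
          if i % 50 == 0:
              print(compute_accuracy(sess, prediction,
                  mnist.test.images[:1000], mnist.test.labels[:1000]))

      # Tag the MetaGraph 'model_final' and register the signature under the key 'test_signature'
      builder.add_meta_graph_and_variables(sess, ['model_final'], {'test_signature': signature})
      builder.save()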
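For experimenting outside the MNIST script, the same export steps can be tried on a minimal, self-contained graph. This is a sketch under stated assumptions: it uses the TF 1.x graph-mode API (importing `tensorflow.compat.v1` on newer releases and falling back to plain `tensorflow` on old 1.x versions such as 1.11), it replaces the article's network with a hypothetical toy softmax layer, and the function name `export_toy_model` and the tensor names are illustrative only. It keeps the article's tag `model_final` and key `test_signature`, but uses the standard `PREDICT_METHOD_NAME` as the method name.

```python
import os
import tempfile

try:
    # TF 1.13+ and TF 2.x: use the 1.x graph-mode API explicitly
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()
except ImportError:
    # Older TF 1.x releases (e.g. 1.11) have no compat.v1 module
    import tensorflow as tf


def export_toy_model(export_dir, tag='model_final', signature_key='test_signature'):
    """Export a hypothetical toy softmax classifier with one named SignatureDef."""
    graph = tf.Graph()
    with graph.as_default():
        xs = tf.placeholder(tf.float32, [None, 784], name='xs')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        w = tf.Variable(tf.zeros([784, 10]), name='w')
        b = tf.Variable(tf.zeros([10]), name='b')
        dropped = tf.nn.dropout(xs, keep_prob)
        prediction = tf.nn.softmax(tf.matmul(dropped, w) + b, name='prediction')

        # Logical keys -> TensorInfo, exactly as in the article's snippet
        inputs = {'input_x': tf.saved_model.utils.build_tensor_info(xs),
                  'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)}
        outputs = {'prediction': tf.saved_model.utils.build_tensor_info(prediction)}
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs, outputs, tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        # export_dir must not already contain a SavedModel
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            # (training steps would run here)
            builder.add_meta_graph_and_variables(
                sess, [tag], signature_def_map={signature_key: signature})
        builder.save()
    return export_dir


if __name__ == '__main__':
    # Write into a fresh temporary directory so re-runs never collide
    out_dir = os.path.join(tempfile.mkdtemp(), 'toy_model')
    print(export_toy_model(out_dir))
```

The export directory ends up holding `saved_model.pb` plus a `variables/` subdirectory. When every input and output is a plain tensor, `tf.saved_model.signature_def_utils.predict_signature_def(inputs, outputs)` is a shortcut that takes dicts of tensors directly and fills in the predict method name.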

Restoring the model (code on GitHub)

  import tensorflow as tf

  saved_model_dir = './model'

  # These keys must match the ones used when the model was saved
  signature_key = 'test_signature'
  input_key_x = 'input_x'
  input_key_y = 'input_y'
  input_key_keep_prob = 'keep_prob'
  output_key_prediction = 'prediction'

  with tf.Session() as sess:
      # The tag list must match the one passed to add_meta_graph_and_variables
      meta_graph_def = tf.saved_model.loader.load(sess, ['model_final'], saved_model_dir)

      # signature_def is a map from signature key to SignatureDef
      signature = meta_graph_def.signature_def

      # Look up the concrete input/output tensor names in the signature
      x_tensor_name = signature[signature_key].inputs[input_key_x].name
      y_tensor_name = signature[signature_key].inputs[input_key_y].name
      keep_prob_tensor_name = signature[signature_key].inputs[input_key_keep_prob].name
      prediction_tensor_name = signature[signature_key].outputs[output_key_prediction].name

      # Fetch the tensors by name for inference
      input_x = sess.graph.get_tensor_by_name(x_tensor_name)
      input_y = sess.graph.get_tensor_by_name(y_tensor_name)
      keep_prob = sess.graph.get_tensor_by_name(keep_prob_tensor_name)
      prediction = sess.graph.get_tensor_by_name(prediction_tensor_name)
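The four lookups above follow one pattern: each logical key maps to a `TensorInfo` whose `name` field is the tensor's name in the graph. A small helper can resolve every key of a signature at once. This is a sketch assuming TF 1.x graph mode; `tensors_from_signature` is a hypothetical name, and inside the session block it would be called as `tensors_from_signature(sess.graph, meta_graph_def.signature_def['test_signature'])`.

```python
try:
    # TF 1.13+ and TF 2.x: use the 1.x graph-mode API explicitly
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()
except ImportError:
    # Older TF 1.x releases (e.g. 1.11) have no compat.v1 module
    import tensorflow as tf


def tensors_from_signature(graph, signature_def):
    """Resolve every input/output key of a SignatureDef to a tensor in `graph`."""
    # TensorInfo.name holds the graph tensor name, e.g. 'xs:0'
    inputs = {key: graph.get_tensor_by_name(info.name)
              for key, info in signature_def.inputs.items()}
    outputs = {key: graph.get_tensor_by_name(info.name)
               for key, info in signature_def.outputs.items()}
    return inputs, outputs
```

Returning dicts keyed by the signature's own keys means new inputs or outputs added at export time are picked up without editing the restore code.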

Predicting with the restored model

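      # (continues inside the `with tf.Session() as sess:` block above;
      # mnist is loaded as in the training script)

      # Test a single sample; `index` is the position of a test image, chosen elsewhere
      x = mnist.test.images[index].reshape(1, 784)
      y = mnist.test.labels[index].reshape(1, 10)  # one-hot label as a batch of one
      print(y)

      # input_y is not needed for prediction; keep_prob = 1.0 disables dropout
      pred_y = sess.run(prediction, feed_dict={input_x: x, keep_prob: 1.0})
      print(pred_y)

      print("Actual class:", sess.run(tf.argmax(y, 1)),
            ", predicted class:", sess.run(tf.argmax(pred_y, 1)),
            ", correct:", sess.run(tf.equal(tf.argmax(y, 1), tf.argmax(pred_y, 1))))

      # Test on a slice of the test set; this compute_accuracy variant also takes
      # the restored input tensors (its definition is in the full script)
      print(compute_accuracy(sess, prediction, input_x, keep_prob,
          mnist.test.images[:1000], mnist.test.labels[:1000]))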
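Two parts of the prediction code can also be handled on the NumPy side. Each `sess.run(tf.argmax(...))` call adds new ops to the restored graph, whereas `np.argmax` on the arrays that `sess.run` already returned does not; and the body of `compute_accuracy` is not shown in the article. The sketch below is a hypothetical reconstruction of both: `compute_accuracy` matches the call signature used above and feeds `keep_prob = 1.0` to disable dropout, and `compare_prediction` is an illustrative name, not part of the original code.

```python
import numpy as np


def compare_prediction(y_onehot, pred_y):
    """Return (actual_class, predicted_class, correct) for a batch of one sample."""
    actual = int(np.argmax(y_onehot, axis=1)[0])
    predicted = int(np.argmax(pred_y, axis=1)[0])
    return actual, predicted, actual == predicted


def compute_accuracy(sess, prediction, input_x, keep_prob, images, labels):
    """Hypothetical helper: share of samples whose argmax matches the one-hot label."""
    # keep_prob = 1.0 disables dropout at inference time
    pred = sess.run(prediction, feed_dict={input_x: images, keep_prob: 1.0})
    correct = np.argmax(pred, axis=1) == np.argmax(labels, axis=1)
    return float(np.mean(correct))
```

Inside the session block this would read `actual, predicted, ok = compare_prediction(y, pred_y)`, and the existing `compute_accuracy(...)` call works unchanged. Only `sess.run` is used, so the helpers work with any object that behaves like a session.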
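    Original author: 谢昆明
    Original article: https://www.jianshu.com/p/be6bde3a93bb
    This article was reposted from the web solely to share knowledge; if it infringes any rights, please contact the blogger to have it removed.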