使用 TensorFlow Python 创建并训练模型,得到 .pb 模型文件后,用 C++ API 加载模型并进行预测:
1 #include <iostream> 2 #include <map> 3 4 #include "tensorflow/cc/ops/image_ops.h" 5 #include "tensorflow/cc/ops/standard_ops.h" 6 #include "tensorflow/core/framework/graph.pb.h" 7 #include "tensorflow/core/framework/tensor.h" 8 #include "tensorflow/core/graph/default_device.h" 9 #include "tensorflow/core/graph/graph_def_builder.h" 10 #include "tensorflow/core/platform/logging.h" 11 #include "tensorflow/core/platform/types.h" 12 #include "tensorflow/core/public/session.h" 13 14 using namespace std ; 15 using namespace tensorflow; 16 using tensorflow::Tensor; 17 using tensorflow::Status; 18 using tensorflow::string; 19 using tensorflow::int32; 20 21 22 //从文件名中读取数据 23 Status ReadTensorFromImageFile(string file_name, const int input_height, 24 const int input_width, 25 vector<Tensor>* out_tensors) { 26 auto root = Scope::NewRootScope(); 27 using namespace ops; 28 29 auto file_reader = ops::ReadFile(root.WithOpName("file_reader"),file_name); 30 const int wanted_channels = 1; 31 Output image_reader; 32 std::size_t found = file_name.find(".png"); 33 //判断文件格式 34 if (found!=std::string::npos) { 35 image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,DecodePng::Channels(wanted_channels)); 36 } 37 else { 38 image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,DecodeJpeg::Channels(wanted_channels)); 39 } 40 // 下面几步是读取图片并处理 41 auto float_caster =Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT); 42 auto dims_expander = ExpandDims(root, float_caster, 0); 43 auto resized = ResizeBilinear(root, dims_expander,Const(root.WithOpName("resize"), {input_height, input_width})); 44 // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std}); 45 Transpose(root.WithOpName("transpose"),resized,{0,2,1,3}); 46 47 GraphDef graph; 48 root.ToGraphDef(&graph); 49 50 unique_ptr<Session> session(NewSession(SessionOptions())); 51 session->Create(graph); 52 session->Run({}, {"transpose"}, {}, out_tensors);//Run,获取图片数据保存到Tensor中 53 54 
return Status::OK(); 55 } 56 57 int main(int argc, char* argv[]) { 58 59 string graph_path = "aov_crnn.pb"; 60 GraphDef graph_def; 61 //读取模型文件 62 if (!ReadBinaryProto(Env::Default(), graph_path, &graph_def).ok()) { 63 cout << "Read model .pb failed"<<endl; 64 return -1; 65 } 66 67 //新建session 68 unique_ptr<Session> session; 69 SessionOptions sess_opt; 70 sess_opt.config.mutable_gpu_options()->set_allow_growth(true); 71 (&session)->reset(NewSession(sess_opt)); 72 if (!session->Create(graph_def).ok()) { 73 cout<<"Create graph failed"<<endl; 74 return -1; 75 } 76 77 //读取图像到inputs中 78 int input_height = 40; 79 int input_width = 240; 80 vector<Tensor> inputs; 81 // string image_path(argv[1]); 82 string image_path("test.jpg"); 83 if (!ReadTensorFromImageFile(image_path, input_height, input_width,&inputs).ok()) { 84 cout<<"Read image file failed"<<endl; 85 return -1; 86 } 87 88 vector<Tensor> outputs; 89 string input = "inputs_sq"; 90 string output = "results_sq";//graph中的输入节点和输出节点,需要预先知道 91 92 pair<string,Tensor>img(input,inputs[0]); 93 Status status = session->Run({img},{output}, {}, &outputs);//Run,得到运行结果,存到outputs中 94 if (!status.ok()) { 95 cout<<"Running model failed"<<endl; 96 cout<<status.ToString()<<endl; 97 return -1; 98 } 99 100 101 //得到模型运行结果 102 Tensor t = outputs[0]; 103 auto tmap = t.tensor<int64, 2>(); 104 int output_dim = t.shape().dim_size(1); 105 106 107 return 0; 108 }
g++ -g tf_predict.cpp -o tf_predict -I /usr/include/eigen3 -I /usr/local/include/tf -L/usr/local/lib/ `pkg-config --cflags --libs protobuf` -ltensorflow_cc -ltensorflow_framework
也可以用 OpenCV C++ 库读取图片,再将 Mat 中的像素逐个复制到 Tensor 中:
1 tensorflow::Tensor readTensor(string filename){ 2 tensorflow::Tensor input_tensor(DT_FLOAT,TensorShape({1,240,40,1})); 3 Mat src=imread(filename,0); 4 Mat dst; 5 resize(src,dst,Size(240,40));//resize 6 Mat dst_transpose=dst.t();//transpose 7 8 auto tmap=input_tensor.tensor<float,4>(); 9 10 for(int i=0;i<240;i++){//Mat复制到Tensor 11 for(int j=0;j<40;j++){ 12 tmap(0,i,j,0)=dst_transpose.at<uchar>(i,j); 13 } 14 } 15 16 return input_tensor; 17 }
也可以让 cv::Mat 通过指针直接引用 Tensor 的底层缓冲区,由 convertTo 直接把数据写入 Tensor,省去逐元素复制:
1 tensorflow::Tensor input_tensor(DT_FLOAT,TensorShape({1,height,width,3})); 2 float *tensor_data_ptr = input_tensor.flat<float>().data(); 3 cv::Mat fake_mat(dst.rows, dst.cols, CV_32FC(src.channels()), tensor_data_ptr); 4 dst.convertTo(fake_mat, CV_32FC3);