C++ 如何调用Pytorch模型

这个过程大体就两步:

第一:在python中导出pytorch模型

第二:c++ load 上述模型并使用。

官方文档在此。

https://pytorch.org/tutorials/advanced/cpp_export.htmlpytorch.org

以下是个人理解,欢迎交流。

第一步:convert Pytorch func/model to Torch Script

有2中方法可以完成该任务。

一、 torch.jit.trace

  1. function
def foo(x):
    return torch.sigmoid(x)
# trace func
tmp_in = torch.rand(5)
script_func = torch.jit.trace(foo, tmp_in)

# use traced func
x = torch.rand(3)
out = script_func(x)
  1. net module
class Mymodel(torch.nn.Module):
    def __init__(self):
        super(Mymodel, self).__init__()
        self.conv=torch.nn.Conv2d(3,2,2)
    
    def forward(self, x):
        out = self.conv(x)
        return x

tmp_in = torch.rand(1,3,5,5)
m = Mymodel()
script_model = torch.jit.trace(m, tmp_in)

# use traced model
out = script_model(some_inputs)

二、annotation

某些与输入有关的控制流,不适合用jit.trace的方法,如:

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

虽然jit.trace依然可以正常运行,但是它只‘记录’了tmp_input所进行的操作流,当其他输入时仍然只会走‘记录’的操作,不能根据当前输入选择正确的操作。

同样的问题还有‘遍历列表’(列表的长度会变化)

这时就要使用annotation。

  1. func
@torch.jit.trace
def foo(x):
    return torch.sigmoid(x)

# 这样foo直接就是traced func
  1. net module
class MyModel(torch.jit.ScriptModule):  # 父类为torch.jit.ScriptModule
    def __init__(self, N, M):
        super(MyModel self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))
    
    @torch.jit.script_method
    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

script_model = MyModel()
# script_model 直接为traced model

第二步: 在c++中使用script model

  1. 在python中保存traced script model

script_model.save('model.pt')

  1. c++ load .pt and execute script model
#include <torch/script.h> // One-stop header. 
#include <iostream> #include <memory> 
int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  // Deserialize the ScriptModule from a file using torch::jit::load().   std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
  assert(module != nullptr);
  std::cout << "ok\n";
    
  // Create a vector of inputs.   std::vector<torch::jit::IValue> inputs; 
  inputs.push_back(torch::ones({1, 3, 224, 224}));

  // Execute the model and turn its output into a tensor.   at::Tensor output = module->forward(inputs).toTensor();

  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
    
  return 0;
}

同步csdn页:https://blog.csdn.net/qq_14975217/article/details/90475779

    原文作者:leaf
    原文地址: https://zhuanlan.zhihu.com/p/66707105
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞