The process boils down to two steps:
Step 1: export the PyTorch model from Python.
Step 2: load the exported model in C++ and run it.
The official tutorial is here:
https://pytorch.org/tutorials/advanced/cpp_export.html
What follows is my own understanding; feedback is welcome.
Step 1: convert a PyTorch func/model to Torch Script

There are two ways to do this.
Method 1: torch.jit.trace

- function

import torch

def foo(x):
    return torch.sigmoid(x)

# trace the function with an example input
tmp_in = torch.rand(5)
script_func = torch.jit.trace(foo, tmp_in)

# use the traced function
x = torch.rand(3)
out = script_func(x)
- net module

class Mymodel(torch.nn.Module):
    def __init__(self):
        super(Mymodel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 2, 2)

    def forward(self, x):
        out = self.conv(x)
        return out

tmp_in = torch.rand(1, 3, 5, 5)
m = Mymodel()
script_model = torch.jit.trace(m, tmp_in)

# use the traced model (inputs must have the traced shape)
x = torch.rand(1, 3, 5, 5)
out = script_model(x)
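As a quick sanity check (my own sketch, not part of the original post): for an input of the traced shape, the traced module should agree with the eager module, since tracing carries over the module's parameters.

```python
import torch

class Mymodel(torch.nn.Module):
    def __init__(self):
        super(Mymodel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 2, 2)

    def forward(self, x):
        return self.conv(x)

m = Mymodel()
script_model = torch.jit.trace(m, torch.rand(1, 3, 5, 5))

# Eager and traced outputs should match for a same-shaped input.
x = torch.rand(1, 3, 5, 5)
same = torch.allclose(m(x), script_model(x))
```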
Method 2: annotation

Control flow that depends on the input is not suited to jit.trace, for example:

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output
Although jit.trace still runs here, it only 'records' the sequence of operations executed for tmp_in; any later input replays that recorded sequence, so the model cannot choose the correct branch for the actual input.
The same problem arises when iterating over a list whose length can vary.
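To see the divergence concretely (a sketch I am adding, using the module above with example sizes N = M = 3): trace with a positive input, then feed a negative one. The traced module keeps replaying the recorded mv branch, so its output shape no longer matches the eager module.

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

m = MyModule(3, 3)
tmp_in = torch.ones(3)            # sum > 0, so tracing records the mv branch
traced = torch.jit.trace(m, tmp_in)  # emits a TracerWarning about the branch

neg_in = -torch.ones(3)           # sum < 0: eager model takes the else branch
eager_out = m(neg_in)             # weight + input, shape (3, 3)
traced_out = traced(neg_in)       # still runs the recorded mv branch, shape (3,)
```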
In such cases, use annotation.
- func

@torch.jit.script
def foo(x):
    return torch.sigmoid(x)

# foo is now directly a scripted function
- net module

class MyModel(torch.jit.ScriptModule):  # the parent class is torch.jit.ScriptModule
    def __init__(self, N, M):
        super(MyModel, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

script_model = MyModel(3, 3)  # N and M are example sizes
# script_model is directly a scripted model
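Unlike the traced version, the scripted model compiles the whole forward and keeps both branches, so it picks the right one at run time. A small check (my own sketch, with example sizes N = M = 3): a positive input goes through mv and yields a vector, a negative one goes through the addition and yields a matrix.

```python
import torch

class MyModel(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModel, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

m = MyModel(3, 3)
pos_out = m(torch.ones(3))    # sum > 0: mv branch, shape (3,)
neg_out = m(-torch.ones(3))   # sum < 0: addition branch, shape (3, 3)
```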
Step 2: use the script model in C++

- save the traced/scripted model in Python

script_model.save('model.pt')

- load the .pt file in C++ and execute the script model
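Before moving to C++, the saved file can be sanity-checked from Python with torch.jit.load, the Python counterpart of torch::jit::load (a sketch I am adding; the file path is arbitrary):

```python
import os
import tempfile

import torch

def foo(x):
    return torch.sigmoid(x)

traced = torch.jit.trace(foo, torch.rand(5))

# Serialize, then load back and compare against the eager function.
path = os.path.join(tempfile.gettempdir(), 'foo.pt')
traced.save(path)

loaded = torch.jit.load(path)
x = torch.rand(3)
ok = torch.allclose(loaded(x), torch.sigmoid(x))
```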
#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  // Deserialize the ScriptModule from a file using torch::jit::load().
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
  assert(module != nullptr);
  std::cout << "ok\n";

  // Create a vector of inputs.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}));

  // Execute the model and turn its output into a tensor.
  at::Tensor output = module->forward(inputs).toTensor();
  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

  return 0;
}
Mirrored on CSDN: https://blog.csdn.net/qq_14975217/article/details/90475779