昨天搞了一天用Go语言部署TensorFlow模型,把整个过程记录一下,以备大家参考(现在没有题图,以后在搞一个图)。
首先我们要有一个已经保存好的TensorFlow模型,也就是.pb文件。这个文件固化了计算图和权重,Go语言只需要根据这个代码跑相应的Session就行了。关于如何产生.pb文件,如果大家有兴趣的话可以私信我,我可以根据大家的需求情况写一份文档。具体部分可以参见tf.saved_model。
然后编译TF的源代码得到libtensorflow.so和libtensorflow_framework.so。也可以用官网上的下载链接(我没用过,大家可以尝试一下)。需要注意的是请务必保证保存模型的TF版本和这个动态链接库的TF版本一致,不然的话后面的Go代码可能会挂(大坑)。如果需要编译的话可以参考TF的官方文档,如果有兴趣的话同上请私信我。
有了这个东西,为了让ld能够找到这两个文件,Linux上需要设置$LIBRARY_PATH和$LD_LIBRARY_PATH这两个环境变量。
export LIBRARY_PATH=[.so文件所在的目录]
export LD_LIBRARY_PATH=[.so文件所在的目录]
然后是下载我们的依赖包。可以使用下边的命令。第一个是下载依赖,第二个是测试下载的依赖有没有问题。如果第二个出错,就证明前面的步骤有问题。
go get github.com/tensorflow/tensorflow/tensorflow/go
go get github.com/tensorflow/tensorflow/tensorflow/go
接下来就可以愉快的载入模型开始玩了。下面是载入模型的示例代码。载入模型的时候需要给模型所在的文件夹和模型的名字(模型的名字可以用saved_model_cli这个工具来查看)。后面的一段是我自己家的,意思是打印出当前模型图里面所有的Operator。这个代码返回一个tf.SavedModel的struct,这个struct有两个成员,第一个是Session,第二个是Graph。如果大家对于TF的python API很熟应该知道这两个是什么东西。
func LoadModel(modelPath string, modelNames []string) *tf.SavedModel {
model, err := tf.LoadSavedModel(modelPath, modelNames, nil) // 载入模型 if err != nil {
log.Fatal("LoadSavedModel(): %v", err)
}
log.Println("List possible ops in graphs") // 打印出所有的Operator for _, op := range model.Graph.Operations() {
//log.Printf("Op name: %v, on device: %v", op.Name(), op.Device()) log.Printf("Op name: %v", op.Name())
}
return model
}
有了Session和Graph之后,我们就能跑这个模型了。我这边用的是gin这个web框架,直接把输入的JSON编码成TensorFlow接受的输入,然后调用Session.Run方法来跑整个计算图。这个方法传三个参数,第一个参数是一个map,把每个tf.Output类型映射成一个tf.Tensor。前面一个在知道输入的Operator的情况下,可以通过Operator.Output(0)方法拿到,后面一个,可以使用tf.NewTensor这个函数,传入输入的Go数组来生成。如果大家熟悉TensorFlow的Python API的话,我们会发现,第一个类似与feed_dict这个参数。第二个参数是输出的张量的列表。我们同样可以在拿到Operator以后,通过Operator.Output(0)方法拿到,注意要把他们包装成一个[]Output类型,即使里面只有一个元素。第三个是不执行的Operator的列表,这里我们设置成nil。
func main () {
m := LoadModel("../freeze_model", []string{"serve"})
s := m.Session
// ... ServeJSON := func (c *gin.Context) {
var json map[string] int64
if c.BindJSON(&json) == nil {
log.Println(json)
}
ret, err:= s.Run(MapGraphInputs(CreateMapFromJSON(json), m),
GetGraphOutputs([]string{"prob"}, m), nil)
if err != nil {
log.Fatal("Error in executing graph...", err)
}
// ... }
// ... }
然后我们编译一下源代码,跑一下,发现gin框架起来了,我们就可以用这个可执行文件做web服务了~这个可执行文件和.so文件,以及模型文件一起,完全可以一起copy到docker的container里面,这样就可以用k8s愉快的和这个模型玩耍了。