keras, tensorflow模型部署通过jar包部署到spark环境攻略

这是个我想干很久的事情了。之前研究tensorflow on spark, DL4j 都没有成功。所以这里首先讲一下我做这件事情的流程。模型的部署,首先你得有一个模型。这里假设你有了一个keras模型,假设你保存了一个keras 的.h5模型

python 准备阶段

你需要通过以下代码将keras h5的模型转化为pb文件

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
	:param session: 需要转换的tensorflow的session
	:param keep_var_names:需要保留的variable,默认全部转换constant
	:param output_names:output的名字
	:param clear_devices:是否移除设备指令以获得更好的可移植性
	:return:
	"""
	from tensorflow.python.framework.graph_util import convert_variables_to_constants
	graph = session.graph
	with graph.as_default():
	freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
	output_names = output_names or []
	# 如果指定了output名字,则复制一个新的Tensor,并且以指定的名字命名
	if len(output_names) > 0:
	for i in range(len(output_names)):
	# 当前graph中复制一个新的Tensor,指定名字
	tf.identity(model.outputs[i], name=output_names[i])
	output_names += [v.op.name for v in tf.global_variables()]
	input_graph_def = graph.as_graph_def()
	if clear_devices:
	for node in input_graph_def.node:
	node.device = ""
	frozen_graph = convert_variables_to_constants(session, input_graph_def,
	output_names, freeze_var_names)
	return frozen_graph

from keras import backend as K
import tensorflow as tf
from keras.models import load_model
model = load_model("models/model.h5")
print(model.input.op.name)
print(model.output.op.name)
print(model)
# 自定义output_names
frozen_graph = freeze_session(K.get_session(), output_names=["output"])
tf.train.write_graph(frozen_graph, "./", "model.pb", as_text=False)

如果你用的是tensorflow模型可以从这里开始看

这两行代码打印的是tensorflow模型的输入和输出,这个输入和输出的名字在后面java读入模型的时候有用到

print(model.input.op.name)
print(model.output.op.name)

java 部分才是细节多的地方

java部分,用java的好处在于可以把程序和资源打成jar包。不依赖于集群的资源。

类必须实现Serializable虚接口,否则不能将其写成UDF函数。

其实这一部分也可以用scala来写。这样类中的某些不能序列化的成员变量,比如Graph对象,可以通过 with Serializable来实现

public class Predictor implements Serializable 

java中资源文件的读取

资源在jar包当中时,需要在<build>标签中添加如下的内容

<resources>
<resource>
<directory>src/main/resources</directory>
<targetPath>resource</targetPath>
</resource>
</resources>

因为jar包是一个单独的文件,在打成jar包了以后,不能像在IDEA中那样运行,需要通过createTempFile和FileOutputStream的方式将内容读取出来

File labelEncoderFile = null;
String resource = "/resource/label_encoder.json";
URL res = getClass().getResource(resource);
if (res.getProtocol().equals("jar")) {
try {
    InputStream input = getClass().getResourceAsStream(resource);
    labelEncoderFile = File.createTempFile("tempfile", ".tmp");
    OutputStream out = new FileOutputStream(labelEncoderFile);
int read;
byte[] bytes = new byte[1024];

while ((read = input.read(bytes)) != -1) {
    out.write(bytes, 0, read);
}
out.close();
labelEncoderFile.deleteOnExit();
}catch (IOException ex) {
    ex.printStackTrace();
}
}else {
//this will probably work in your IDE, but not from a JAR     labelEncoderFile = new File(res.getFile());
}

if (labelEncoderFile != null && !labelEncoderFile.exists()) {
    throw new RuntimeException("Error: File " + labelEncoderFile + " not found!");
}

python 预处理文件的保存和java对预处理文件的恢复

一个重要的问题是,在python中的模型预处理阶段也需要放到java当中,以实现端到端的

pipeline ml flow.对于文本处理来说,有两个,一个是tokenizer,一个是LabelEncoder

我个人比较建议使用json进行读写,因为dict转json比较容易,java也比较容易读取这个内容

with open('label_encoder.json', 'w') as f:
    json.dump(le_dict, f)

with open('word_index.json', 'w') as f:
    json.dump(tokenizer.word_index, f)

下边是java从json的自己构造的labelEncoder的代码。LabelEncoder 和tokenizer 实际上就是一个java 的Map<String, String>。所以转换成了map也就实现了我自己的LabelEncoder

String labelEncoderContent = null;
try {
    labelEncoderContent = FileUtils.readFileToString(labelEncoderFile, "UTF-8");
}catch (IOException e) {
    e.printStackTrace();
}

Map<String, String> outputToLabel = new HashMap<String, String>();
ObjectMapper labelEncoderMapper = new ObjectMapper();
try {
    outputToLabel = labelEncoderMapper.readValue(labelEncoderContent, new TypeReference<HashMap<String, String>>() {
});
}catch (Exception e) {
    e.printStackTrace();
}

Predictor预测类中的两个关键成员变量

Predictor是我的预测java类,该类有两个很关键的成员变量,一个是this.graph,另外一个是this.sess 这两个类是用来读取pb文件,并预测的主要工具

而graph没有实现Serializable接口,所以不能在构造函数当中对其初始化。否则就不能广播,也就没有办法写成UDF。所以我在预测阶段判断了成员变量是否为空。

//this.graph是否为空,从而导入图和session
if (this.graph == null) {
    this.graph = new Graph();
    this.graph.importGraphDef(this.pbBytes);
}
if (this.sess == null) {
    this.sess = new Session(this.graph);
}

当我们的sess就位了以后就是使用tensorflow 的java接口来进行预测了。注意要在pom当中添加tensorflow的依赖

终于到预测部分了

float[][] index_seqs = new float[1][MAX_LEN];
try (Tensor x = Tensor.create(index_seqs);
// input是输入的name,output是输出的name

Tensor y = sess.runner()
.feed("input_1_3", x)
.feed("dropout_1/keras_learning_phase", Tensor.create(false))
.fetch("dense_1_3/Softmax").run().get(0)) {

float[][] result = new float[1][2033];
y.copyTo(result);

result 就是我们想要的模型输出向量

scala part

而在scala部分 最主要的是org.apache.spark.sql.functions.udf的使用

import org.apache.spark.sql.functions.udf
val predictor:Predictor = Predictor.getInstance()
def predictWithProbability= { goods_name:String => predictor.predictWithProbability(name)}
val predictWithProbabilityUDF = udf(predictWithProbability)
val predictedDataSet:DataFrame = predictDataSet.withColumn("result", predictWithProbabilityUDF(predictDataSet.col("name")))

这样就能够得到结果了

    原文作者:量子侠
    原文地址: https://zhuanlan.zhihu.com/p/62903077
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞