前言
昨晚睡了12小时,早上起来神清气爽,索性把之前提的一个Issue:Is there any plan to port TensorframeOnSpark(From yahoo) 给尝试着集成进来。 前两天已经添加了一个 TFTextEstimator:为Spark Deep Learning 添加NLP处理实现,不过只能做hyper parameter tuning,做不了真正的分布式训练,所以正好把这个特性加到了这个Estimator里。
使用方法
建议看这篇文章之前,先看为Spark Deep Learning 添加NLP处理实现。 我给TFTextFileEstimator 添加了一个新的参数叫做 runningMode。目前只有两个值: Normal 和 TFoS。
# create a estimator to training where map_fun contains tensorflow's code
estimator = TFTextFileEstimator(inputCol="sentence_matrix", outputCol="sentence_matrix", labelCol="preds",
fitParam=[{"epochs": 1, "cluster_size": 2, "batch_size": 1, "model": "/tmp/model"}],
runningMode="TFoS",
mapFnParam=map_fun)
如果使用TFoS model参数是必须的。并且 map_fun方法也需要做些改造。主要是tensorflow 分布式training 和 单机多device 还是有区别的。
原理
在TFTextEstimator里,通过参数runningMode控制:
if self.getRunningMode() == "TFoS":
return self._fitInCluster(dataset, paramMaps)
else:
return self._fitInParallel(dataset, paramMaps)
如果是,则走集群模式,否则走并行训练。 我们来看看_fitInCluster:
def _fitInCluster(self, dataset, paramMaps):
sc = JVMAPI._curr_sc()
temp_item = dataset.take(1)[0]
vocab_s = temp_item["vocab_size"]
embedding_size = temp_item["embedding_size"]
baseParamMap = self.extractParamMap()
baseParamDict = dict([(param.name, val) for param, val in baseParamMap.items()])
args = self._clusterModelDefaultValue(sc, paramMaps[0])
args["feature"] = self.getInputCol()
args["label"] = self.getLabelCol()
args["vacab_size"] = vocab_s
args["embedding_size"] = embedding_size
args["params"] = baseParamDict
cluster = TFCluster.run(sc, self.getMapFnParam(), args, args['cluster_size'], args['num_ps'],
args['tensorboard'],
TFCluster.InputMode.SPARK)
cluster.train(dataset.rdd, args["epochs"])
cluster.shutdown()
很简单,创建 TFCluster对象,并且调用其train方法。 最核心的还是 map_fun函数,这里实现了所有的tf逻辑(除了数据以外)。我后面会单独一个篇幅来讲。在做实现的过程,发现两个问题:
- TFoS 最好一个批次的数据会丢失 ,我对应提了一个IssueWhen training, the data of last batch will not be trained
- TFoS 没有办法跑在Local模式,所以调试麻烦些,需要跑在spark standalone模式下。
可运行的实例代码在: TFoSTest.py
mapfun函数解析
TFoSTest.py 里的代码兼容单机和集群模式运行。
def map_fun(args={}, ctx=None, _read_data=None):
如果ctx为None,则是单机模式,否则为集群模式。如果是集群模式则直接使用
TFNode.DataFeed(ctx.mgr, True)
获取数据,否则使用 _read_data 获取数据。具体详细参看示例中代码。
结束语
这个只是Demo性质的,单机和集群模式的融合度还不够好,map_fun编写难度还有些。