Generating TensorFlow TFRecord files with Spark

TFRecord is the input file format officially recommended by TensorFlow.

For small-scale file generation, a script like the following can be used:

https://github.com/godkillok/tensorflow_template/blob/master/data/text/print_csv_tfrecords.py
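The linked script shows one approach; as a minimal sketch (assuming TensorFlow 2.x is installed, with illustrative feature names and file name), records can also be written directly with the `tf.train.Example` API:

```python
import tensorflow as tf

def make_example(text, label):
    # Build a tf.train.Example with one bytes feature and one int64 feature
    feature = {
        "text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[text.encode("utf-8")])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# Serialize each Example and append it to the TFRecord file
with tf.io.TFRecordWriter("small.tfrecord") as writer:
    for text, label in [("hello world", 0), ("spark and tf", 1)]:
        writer.write(make_example(text, label).SerializeToString())
```

This is fine for data that fits on one machine; for anything larger, use the Spark route below.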

To inspect the contents of a TFRecord file:

https://github.com/godkillok/tensorflow_template/blob/master/data/text/print_csv_tfrecords.py
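A self-contained sketch of inspecting a TFRecord file (assuming TensorFlow 2.x; the file name is illustrative, and one record is written first so the example runs on its own — point the path at your own file in practice):

```python
import tensorflow as tf

path = "demo.tfrecord"  # hypothetical path; replace with the file you want to inspect

# Write one record so this sketch is self-contained
ex = tf.train.Example(features=tf.train.Features(feature={
    "id": tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))}))
with tf.io.TFRecordWriter(path) as w:
    w.write(ex.SerializeToString())

# Decode each serialized record back into a tf.train.Example and print it
for raw in tf.data.TFRecordDataset(path):
    example = tf.train.Example.FromString(raw.numpy())
    print(example)
```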

Large-scale generation

Installing and configuring spark-tensorflow-connector

1. First clone the project: git clone https://github.com/tensorflow/ecosystem.git (the connector lives under spark/spark-tensorflow-connector; note that git clone takes the repository root, not a subdirectory URL)

2. Install Maven

2.1 Download the archive: http://maven.apache.org/download.cgi

2.2 unzip apache-maven-3.6.0-bin.zip

2.3 Check that JAVA_HOME is set:

echo $JAVA_HOME

2.4 Add Maven to your PATH:

export PATH=/path/where/you/unzipped/apache-maven-3.6.0/bin:$PATH
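To make the change permanent, the same export can go into your shell profile (the unpack location below is an assumption; adjust it to wherever you unzipped Maven):

```shell
# Add to ~/.bashrc or ~/.bash_profile — assumes Maven was unpacked under $HOME
export MAVEN_HOME=$HOME/apache-maven-3.6.0
export PATH=$MAVEN_HOME/bin:$PATH
# Then verify the installation with: mvn -version
```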

3. cd into the hadoop directory of the project cloned in step 1:

cd ../ecosystem/hadoop

mvn clean install

4. cd into the spark-tensorflow-connector directory:

cd ../spark/spark-tensorflow-connector

mvn clean install

5. pyspark --jars target/spark-tensorflow-connector_2.11-1.10.0.jar (the build creates a new folder named target inside the connector directory)

Using it after installation

To use it directly: pyspark --jars /ecosystem/spark/spark-tensorflow-connector/target/spark-tensorflow-connector_2.11-1.10.0.jar
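For non-interactive jobs, the same jar can be passed to spark-submit (the script name below is hypothetical, and the jar path should match where you built the connector):

```shell
# Hypothetical batch job; adjust the jar path to your build output
spark-submit \
  --jars /ecosystem/spark/spark-tensorflow-connector/target/spark-tensorflow-connector_2.11-1.10.0.jar \
  your_write_tfrecords_job.py
```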

The following is a demo:

from pyspark.sql.types import *

# Run inside the pyspark shell started above, where `spark` is already defined
path = "test-output.tfrecord"

fields = [StructField("id", IntegerType()), StructField("IntegerCol", IntegerType()),
          StructField("LongCol", LongType()), StructField("FloatCol", FloatType()),
          StructField("DoubleCol", DoubleType()), StructField("VectorCol", ArrayType(DoubleType(), True)),
          StructField("StringCol", StringType())]
schema = StructType(fields)
test_rows = [[11, 1, 23, 10.0, 14.0, [1.0, 2.0], "r1"], [21, 2, 24, 12.0, 15.0, [2.0, 2.0], "r2"]]
rdd = spark.sparkContext.parallelize(test_rows)
df = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").option("recordType", "Example").save(path)

df = spark.read.format("tfrecords").option("recordType", "Example").load(path)
df.show()
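On the TensorFlow side, records written this way can be parsed back with tf.io.parse_single_example. A sketch using a subset of the demo's columns (assuming TensorFlow 2.x; one matching record is written inline so the example is self-contained — in practice Spark writes a directory of part files, so glob for them, e.g. with tf.io.gfile.glob("test-output.tfrecord/part-*")):

```python
import tensorflow as tf

# Write one Example with the demo's column names so this sketch runs on its own
ex = tf.train.Example(features=tf.train.Features(feature={
    "id": tf.train.Feature(int64_list=tf.train.Int64List(value=[11])),
    "FloatCol": tf.train.Feature(float_list=tf.train.FloatList(value=[10.0])),
    "StringCol": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"r1"])),
}))
with tf.io.TFRecordWriter("demo.tfrecord") as w:
    w.write(ex.SerializeToString())

# Feature spec matching the demo's schema (subset for brevity)
feature_spec = {
    "id": tf.io.FixedLenFeature([], tf.int64),
    "FloatCol": tf.io.FixedLenFeature([], tf.float32),
    "StringCol": tf.io.FixedLenFeature([], tf.string),
}

dataset = tf.data.TFRecordDataset("demo.tfrecord").map(
    lambda raw: tf.io.parse_single_example(raw, feature_spec))
for parsed in dataset:
    print(int(parsed["id"]), float(parsed["FloatCol"]), parsed["StringCol"].numpy())
```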

    Original author: 汤go
    Original article: https://zhuanlan.zhihu.com/p/51819048
    This article was reposted from the web to share knowledge; if it infringes any rights, please contact the blogger for removal.