TFRecord is the input file format officially recommended by TensorFlow.
For small-scale data, files can be generated with a script like this:
https://github.com/godkillok/tensorflow_template/blob/master/data/text/print_csv_tfrecords.py
To inspect the contents of a TFRecord file:
https://github.com/godkillok/tensorflow_template/blob/master/data/text/print_csv_tfrecords.py
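If TensorFlow is not at hand, the raw records can also be inspected with the standard library alone, since the TFRecord container format is just a sequence of length-prefixed blobs with CRC32C checksums. The sketch below is a minimal reader (the function name `read_tfrecords` is made up for illustration, and the checksums are skipped rather than verified):

```python
import struct

def read_tfrecords(path):
    """Return the raw serialized records from a TFRecord file.

    On-disk framing of each record:
      uint64  length (little-endian)
      uint32  masked CRC32C of the length bytes
      bytes   record data (`length` bytes)
      uint32  masked CRC32C of the data
    The two CRC fields are skipped here, not verified.
    """
    records = []
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            (length,) = struct.unpack("<Q", header)
            f.read(4)                     # length CRC (ignored)
            records.append(f.read(length))
            f.read(4)                     # data CRC (ignored)
    return records
```

Each returned byte string is a serialized proto; for files written as tf.train.Example records, it can be parsed with tf.train.Example.FromString when TensorFlow is available.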
Large scale
Installing and configuring spark-tensorflow-connector
1. First clone the project (git cannot clone a subdirectory directly, so clone the whole ecosystem repo; the connector lives under spark/spark-tensorflow-connector):
git clone https://github.com/tensorflow/ecosystem.git
2. Install Maven
2.1 Download the archive: http://maven.apache.org/download.cgi
2.2 unzip apache-maven-3.6.0-bin.zip
2.3 Check that JAVA_HOME is set:
echo $JAVA_HOME
2.4 Add Maven to your PATH:
export PATH=/path/where/you/unzipped/apache-maven-3.6.0/bin:$PATH
3. cd into the hadoop directory of the project cloned in step 1:
cd ../ecosystem/hadoop
mvn clean install
4. cd into the spark-tensorflow-connector directory:
cd ../spark/spark-tensorflow-connector
mvn clean install
5. pyspark --jars target/spark-tensorflow-connector_2.11-1.10.0.jar (after the build, a new folder named target appears inside the spark-tensorflow-connector directory)
Using it after installation
Simply run: pyspark --jars /ecosystem/spark/spark-tensorflow-connector/target/spark-tensorflow-connector_2.11-1.10.0.jar
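Instead of passing --jars on every launch, the jar can also be registered once through Spark's spark.jars property in conf/spark-defaults.conf. A sketch, assuming the same build path as above:

```properties
# conf/spark-defaults.conf
spark.jars  /ecosystem/spark/spark-tensorflow-connector/target/spark-tensorflow-connector_2.11-1.10.0.jar
```

With this in place, a plain pyspark (or spark-submit) picks up the connector automatically.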
The following is a demo:
from pyspark.sql.types import *

# Runs inside the pyspark shell launched above, where `spark` already exists.
path = "test-output.tfrecord"

# Schema covering the column types the connector supports.
fields = [StructField("id", IntegerType()), StructField("IntegerCol", IntegerType()),
          StructField("LongCol", LongType()), StructField("FloatCol", FloatType()),
          StructField("DoubleCol", DoubleType()), StructField("VectorCol", ArrayType(DoubleType(), True)),
          StructField("StringCol", StringType())]
schema = StructType(fields)
test_rows = [[11, 1, 23, 10.0, 14.0, [1.0, 2.0], "r1"], [21, 2, 24, 12.0, 15.0, [2.0, 2.0], "r2"]]
rdd = spark.sparkContext.parallelize(test_rows)
df = spark.createDataFrame(rdd, schema)

# Write the DataFrame out as TFRecord files of tf.train.Example records.
df.write.format("tfrecords").option("recordType", "Example").save(path)

# Read it back and display it to verify the round trip.
df = spark.read.format("tfrecords").option("recordType", "Example").load(path)
df.show()