之前训练的几个神经网络使用的数据集比较小,一个服务器节点自带64G内存 + 虚拟内存基本上可以保证训练过程正常运行。最近终于用到了将近两百万个样例的2维训练数据,一次读入几乎把内存占满,神经网络的训练无以为继。这个时候,发现 Tensorflow 提供一种所谓的流水线式数据读入方式,貌似可以解决在大数据集上训练时内存不够的难题。一切都是看上去很美好,尝试使用时才发现踩坑无数。为了节省自己,还有 DeepLearner 未来在使用到同样功能时不跌入同样的深坑,写此笔记。
既然有 tensorflow 的官方文档,https://www.tensorflow.org/programmers_guide/reading_data
以及非常完整的中文笔记, https://saicoco.github.io/tf3/
这篇笔记只记录遇到的几个大坑以及填坑方法。在下面的这段正确实现的代码里(例子程序里经常看到),隐藏着好几个踩平了的坑:
57 example_batch, label_batch = tf.train.shuffle_batch(
58 [example, label], batch_size=batch_size, capacity=capacity,
59 min_after_dequeue=min_after_dequeue)
60
61 with tf.Session() as sess:
62 init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
63 sess.run(init_op)
64 coord = tf.train.Coordinator()
65 threads = tf.train.start_queue_runners(coord=coord)
第一坑:Tensorflow 自定义的几个队列操作很多时候用到了局域变量,所以在初始化的时候不仅要初始化全局变量,同时必须初始化局域变量,如上文例子代码第 62, 63 行所示;官方文档里的相应代码如下,只初始化了全局变量,完全没管局部变量,跟着官方文档走被坑死。
# Create the graph, etc.
init_op = tf.global_variables_initializer()
# Create a session for running operations in the Graph.
sess = tf.Session()
第二坑:有几段代码必须满足特定的执行顺序,否则万劫不复。这几段代码就是上述正确示例中的 57, 64, 65 行. 在美国,你经常会看到道路施工时有一个指挥员,协调交通以及施工操作。 tf.train.Coordinator() 就是这样的协调员,他负责告诉不同的工人(线程 threads)如何将物品并行的送入队列,而 Tensorflow 在训练时要从队列的另一端批量读出数据,做 mini-batch 训练。必须要注意的是 第 57 行 代码必须放在 64, 65 行代码的前面,因为 57 行代码里面隐性的声明了队列,而 tf.train.start_queue_runners(coord=coord) 必须要在队列声明之后调用。否则,train_shuffle_batch 函数会永久等待 start_queue_runner() 函数。
第三坑:上面例子中第 57 行的 example_batch 是 tf.Tensor 类型,不能用作 place_holder 的feed dictionary。事实上,有了example_batch 后,可以省略掉定义输入数据 place_holder的步骤而直接在计算图中使用 example_batch。
第四坑: 使用 tf.TextLineReader(skip_header_lines=1) 从 csv 文件读入数据太慢,希望能够将数据存入二进制格式,使用 TFRecordReader()。这里的一个坑就是代表2维图像的 float array 必须要转化成 bytes,没有任何文档或例子告诉你应该怎么做。还好,最后从网上搜到了一个python 库,
import struct
image_bytes = struct.pack('%sf' % len(image), *image)
其中 image 是一个二维float数组展平之后的一维数组。
下面是完整的实现代码,convert() 负责把数据转化成 TFRecords 格式,read_and_decode() 读取一个样本,input_pipeline() 负责将输入样本顺序打乱 (shuffle=True),并生成 mini_batch的训练样本。read_test() 是一个从文件中批读出的例子程序。
现在的 tensorflow documents 看起来真的很应一句话,“满纸荒唐言,一把辛酸泪”。希望大家在使用到大数据输入,能够从下面例子出发,不再像我一样心塞。
#/usr/bin/env python
#author: lgpang
#email: lgpang@qq.com
#createTime: Thu 10 Aug 2017 07:42:13 AM CST
from tqdm import tqdm
import numpy as np
import tensorflow as tf
import os
import struct
def convert():
fname = "input.csv"
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
fname_out = "output.tfrecords"
writer = tf.python_io.TFRecordWriter(fname_out)
with open(fname, 'r') as fin:
fin.readline()
#for idx in xrange(8000 * 197):
for idx in xrange(8000 * 197):
event_id = idx // 8000
oversamp_id = idx % 8000
print(event_id, oversamp_id)
label = idx
image = list(map(np.float32, fin.readline().split(',')[2:]))
image_bytes = struct.pack('%sf' % len(image), *image)
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(64),
'width': _int64_feature(64),
'label': _int64_feature(label),
'image': _bytes_feature(tf.compat.as_bytes(image_bytes))
}))
writer.write(example.SerializeToString())
writer.close()
def read_and_decode(filename_queue, pixels_in_image=64*64):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.int64),
'image': tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['image'], tf.float32, little_endian=True)
image.set_shape([pixels_in_image])
label = features['label']
return image, label
def input_pipeline(filenames, batch_size, num_epochs=None, num_features=None):
'''num_features := width * height for 2D image'''
filename_queue = tf.train.string_input_producer(
filenames, num_epochs=None, shuffle=True)
example, label = read_and_decode(filename_queue, num_features)
# min_after_dequeue defines how big a buffer we will randomly sample
# from -- bigger means better shuffling but slower start up and more
# memory used.
# capacity must be larger than min_after_dequeue and the amount larger
# determines the maximum we will prefetch. Recommendation:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
def read_test():
'''example usage to read batches of records from TFRcords files'''
filenames = ['output.tfrecords']
example_batch, label_batch = input_pipeline(filenames, batch_size=100,
num_epochs=1, num_features=64*64)
with tf.Session() as sess:
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
idx = 0
while not coord.should_stop():
# Run training steps or whatever
images, labels = sess.run([example_batch, label_batch])
print(images, labels)
idx = idx + 1
if idx > 2: break
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
# When done, ask the threads to stop.
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
convert()
print("convert finishes")
#read_test()