使用 Tensorflow 读取大数据时的踩坑之路

之前训练的几个神经网络使用的数据集比较小,一个服务器节点自带64G内存 + 虚拟内存基本上可以保证训练过程正常运行。最近终于用到了将近两百万个样例的2维训练数据,一次读入几乎把内存占满,神经网络的训练无以为继。这个时候,发现 Tensorflow 提供一种所谓的流水线式数据读入方式,貌似可以解决在大数据集上训练时内存不够的难题。一切都是看上去很美好,尝试使用时才发现踩坑无数。为了节省自己,还有 DeepLearner 未来在使用到同样功能时不跌入同样的深坑,写此笔记。

既然有 tensorflow 的官方文档,https://www.tensorflow.org/programmers_guide/reading_data

以及非常完整的中文笔记, https://saicoco.github.io/tf3/

这篇笔记只记录遇到的几个大坑以及填坑方法。在下面的这段正确实现的代码里(例子程序里经常看到),隐藏着好几个踩平了的坑:

57  example_batch, label_batch = tf.train.shuffle_batch(
58           [example, label], batch_size=batch_size, capacity=capacity,
59           min_after_dequeue=min_after_dequeue)
60 
61  with tf.Session() as sess:
62      init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
63      sess.run(init_op)
64      coord = tf.train.Coordinator()
65      threads = tf.train.start_queue_runners(coord=coord)

第一坑:Tensorflow 自定义的几个队列操作很多时候用到了局域变量,所以在初始化的时候不仅要初始化全局变量,同时必须初始化局域变量,如上文例子代码第 62, 63 行所示;官方文档里的相应代码如下,只初始化了全局变量,完全没管局部变量,跟着官方文档走被坑死。

# Create the graph, etc.
init_op = tf.global_variables_initializer()
# Create a session for running operations in the Graph.
sess = tf.Session()

第二坑:有几段代码必须满足特定的执行顺序,否则万劫不复。这几段代码就是上述正确示例中的 57, 64, 65 行. 在美国,你经常会看到道路施工时有一个指挥员,协调交通以及施工操作。 tf.train.Coordinator() 就是这样的协调员,他负责告诉不同的工人(线程 threads)如何将物品并行的送入队列,而 Tensorflow 在训练时要从队列的另一端批量读出数据,做 mini-batch 训练。必须要注意的是 第 57 行 代码必须放在 64, 65 行代码的前面,因为 57 行代码里面隐性的声明了队列,而 tf.train.start_queue_runners(coord=coord) 必须要在队列声明之后调用。否则,train_shuffle_batch 函数会永久等待 start_queue_runner() 函数。

第三坑:上面例子中第 57 行的 example_batch 是 tf.Tensor 类型,不能用作 place_holder 的feed dictionary。事实上,有了example_batch 后,可以省略掉定义输入数据 place_holder的步骤而直接在计算图中使用 example_batch。

第四坑: 使用 tf.TextLineReader(skip_header_lines=1) 从 csv 文件读入数据太慢,希望能够将数据存入二进制格式,使用 TFRecordReader()。这里的一个坑就是代表2维图像的 float array 必须要转化成 bytes,没有任何文档或例子告诉你应该怎么做。还好,最后从网上搜到了一个python 库,

import struct
image_bytes = struct.pack('%sf' % len(image), *image)

其中 image 是一个二维float数组展平之后的一维数组。

下面是完整的实现代码,convert() 负责把数据转化成 TFRecords 格式,read_and_decode() 读取一个样本,input_pipeline() 负责将输入样本顺序打乱 (shuffle=True),并生成 mini_batch的训练样本。read_test() 是一个从文件中批读出的例子程序。

现在的 tensorflow documents 看起来真的很应一句话,“满纸荒唐言,一把辛酸泪”。希望大家在使用到大数据输入,能够从下面例子出发,不再像我一样心塞。

#/usr/bin/env python
#author: lgpang
#email: lgpang@qq.com
#createTime: Thu 10 Aug 2017 07:42:13 AM CST

from tqdm import tqdm
import numpy as np
import tensorflow as tf
import os
import struct

def convert():
    fname = "input.csv"
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    fname_out = "output.tfrecords"
    writer = tf.python_io.TFRecordWriter(fname_out)

    with open(fname, 'r') as fin:
        fin.readline()
        #for idx in xrange(8000 * 197):
        for idx in xrange(8000 * 197):
            event_id = idx // 8000
            oversamp_id = idx % 8000
            print(event_id, oversamp_id)
            label = idx
            image = list(map(np.float32, fin.readline().split(',')[2:]))
            image_bytes = struct.pack('%sf' % len(image), *image)
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(64),
                'width': _int64_feature(64),
                'label': _int64_feature(label),
                'image': _bytes_feature(tf.compat.as_bytes(image_bytes))
                }))
            writer.write(example.SerializeToString())

    writer.close()


def read_and_decode(filename_queue, pixels_in_image=64*64):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
            serialized_example,
            features={
                'height': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'label': tf.FixedLenFeature([], tf.int64),
                'image': tf.FixedLenFeature([], tf.string),
                })

    image = tf.decode_raw(features['image'], tf.float32, little_endian=True)
    image.set_shape([pixels_in_image])
    label = features['label']
    return image, label

def input_pipeline(filenames, batch_size, num_epochs=None, num_features=None):
    '''num_features := width * height for 2D image'''
    filename_queue = tf.train.string_input_producer(
            filenames, num_epochs=None, shuffle=True)
    example, label = read_and_decode(filename_queue, num_features)
    # min_after_dequeue defines how big a buffer we will randomly sample
    # from -- bigger means better shuffling but slower start up and more
    # memory used.
    # capacity must be larger than min_after_dequeue and the amount larger
    # determines the maximum we will prefetch. Recommendation:
    # min_after_dequeue + (num_threads + a small safety margin) * batch_size
    min_after_dequeue = 10000
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
            [example, label], batch_size=batch_size, capacity=capacity,
            min_after_dequeue=min_after_dequeue)
    return example_batch, label_batch


def read_test():
    '''example usage to read batches of records from TFRcords files'''
    filenames = ['output.tfrecords']
    example_batch, label_batch = input_pipeline(filenames, batch_size=100,
        num_epochs=1, num_features=64*64)

    with tf.Session() as sess:
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            idx = 0
            while not coord.should_stop():
                # Run training steps or whatever
                images, labels = sess.run([example_batch, label_batch])
                print(images, labels)
                idx = idx + 1
                if idx > 2: break
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # When done, ask the threads to stop.
            coord.request_stop()

        coord.join(threads)

if __name__ == '__main__':
    convert()
    print("convert finishes")
    #read_test()

    原文作者:hahakity
    原文地址: https://zhuanlan.zhihu.com/p/28450111
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞