TensorFlow Object Detection API 源码(4) 数据集

0. 前言

  • 没找到什么研究相关源码的资料。
  • 主要内容:
    • 学习构建tfrecords。
    • 学习如何通过tfrecords文件集构建数据集。
  • 如果想通过 TensorFlow Object Detection API 来训练自己的数据集,个人建议:
    • 主要了解 tfrecords 的构建方法,目标是将自己的数据构建为API可以识别的 tfrecords 格式。
    • 对于 train.py 等脚本,对于数据的具体预处理过程了解即可。
    • 对于配置文件内容需要非常熟悉。

1. 构建tfrecords

1.1. 综述

1.2. PASCAL 脚本解析

  • 源码地址:create_pascal_tf_record.py
  • 脚本使用实例:
    • data_dir:数据集根目录。
    • set:可以是 train, val, trainval, test 其中之一。
    • annotations_dir:标签文件夹名称,默认为Annotations,相对于data_dir的相对路径。
    • year:可以是 VOC2007, VOC2012, merged 其中之一。
    • output_path:生成tfrecords文件的路径。
    • label_map_path:用到的标签map文件。
    • ignore_difficult_instances:是否忽略有 difficult 标记的数据。
# 在 /path/to/models/research 路径下运行
python object_detection/dataset_tools/create_pascal_tf_record.py \
    --data_dir=/home/user/VOCdevkit \
    --set=train \
    --annotations_dir=/path/to/annotations \
    --year=VOC2007 \
    --output_path=/home/user/pascal.record \
    --label_map_path=data/pascal_label_map.pbtxt \
    --ignore_difficult_instances=Fasle
  • main 函数解析
def main(_):
  # 输入参数合法性判断
  if FLAGS.set not in SETS:
    raise ValueError('set must be in : {}'.format(SETS))
  if FLAGS.year not in YEARS:
    raise ValueError('year must be in : {}'.format(YEARS))
  data_dir = FLAGS.data_dir
  years = ['VOC2007', 'VOC2012']
  if FLAGS.year != 'merged':
    years = [FLAGS.year]

  # 创建 TFRecordWriter 对象,准备写入本地文件
  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

  # 解析 label_map_path 文件
  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

  # 读取输入数据
  for year in years:
    logging.info('Reading from PASCAL %s dataset.', year)

    # 读取输入图片列表
    examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                 'aeroplane_' + FLAGS.set + '.txt')
    examples_list = dataset_util.read_examples_list(examples_path)

    # 获取标签绝对路径
    annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)

    # 依次处理所有文件
    for idx, example in enumerate(examples_list):
      if idx % 100 == 0:
        logging.info('On image %d of %d', idx, len(examples_list))

      # 解析 xml 标签文件,转换为字典
      path = os.path.join(annotations_dir, example + '.xml')
      with tf.gfile.GFile(path, 'r') as fid:
        xml_str = fid.read()
      xml = etree.fromstring(xml_str)
      data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

      # 根据解析得到的字典,获取 tf.train.Example 对象,并将该对象写入本地文件
      tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                      FLAGS.ignore_difficult_instances)
      writer.write(tf_example.SerializeToString())

  writer.close()
  • dict_to_tf_example 函数解析
    • 主要功能:读取图片数据,读取标签数据,构建tf.train.Example对象。
def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
  # 获取图片绝对路径
  img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
  full_path = os.path.join(dataset_directory, img_path)

  # 读取读片数据,并判断图片格式
  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')

  # 根据图片数据
  key = hashlib.sha256(encoded_jpg).hexdigest()

  # 获取 tfrecord 文件所需数据
  width = int(data['size']['width'])
  height = int(data['size']['height'])
  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []
  if 'object' in data:
    for obj in data['object']:
      difficult = bool(int(obj['difficult']))
      if ignore_difficult_instances and difficult:
        continue
      difficult_obj.append(int(difficult))
      xmin.append(float(obj['bndbox']['xmin']) / width)
      ymin.append(float(obj['bndbox']['ymin']) / height)
      xmax.append(float(obj['bndbox']['xmax']) / width)
      ymax.append(float(obj['bndbox']['ymax']) / height)
      classes_text.append(obj['name'].encode('utf8'))
      classes.append(label_map_dict[obj['name']])
      truncated.append(int(obj['truncated']))
      poses.append(obj['pose'].encode('utf8'))

  # 构建 tf.train.Example 对象
  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
  return example
  • label_map介绍:
    • 输入配置文件:本质是一个 string_int_label_map.proto 配置文件。当然,对于常用数据集,配置文件都已经提供(位于 data 文件夹)。
    • 通过调用 label_map_util.pyget_label_map_dict 方法,可以将配置文件转换为一个字典,key为标签名称,value为标签对应的数字。

1.3. 构建自己的数据集

  • 总体思路:
    • 所有 tfrecords 文件中,主要的属性名称必须相同。
    • 如果自己数据集的标签格式与 COCO, Pascal 等相同,则可以稍微修改一下相应脚本。
    • 如果标签信息格式独特,那肯定需要自己写脚本来生成 tfrecords 文件,要求重点属性名相同。
  • 重点属性名(查看 coco, pascal, kitti 等脚本中相同的属性名):
    • 图片基本信息:
      • image/height
      • image/width
      • image/filename
      • image/source_id
      • image/key/sha256:根据图片encoded生成的sha256编码。
      • image/encoded:图片数据本身。
      • image/format:图片格式。
    • bbox信息:
      • image/object/bbox/xmin
      • image/object/bbox/xmax
      • image/object/bbox/ymin
      • image/object/bbox/ymax
      • image/object/class/text:分类标签文本。
      • image/object/class/label:分类标签对应数字。
      • image/object/difficult
    • 完整版属性名可以参考:standard_fields.py 中的 TfExampleFields

2. 通过 tfrecords 文件构建数据集对象

2.1. 综述

  • 在看了 train.py 以及 eval.py 的源码后:
    • 构建数据集主要通过 dataset_builder 构建。对应的配置文件是 input_reader.proto,返回参数是一个 tf.data.Dataset 对象。
    • 之后,以tf.data.Dataset为输入,通过 dataset_util.make_initializable_iterator 创建 make_initializable_iterator 结果,并将对应的 iterator.initializer 添加到 tf.GraphKeys.TABLE_INITIALIZERS 中,以iterator.get_next()作为后续操作的输入。
    • train.py 中额外包含了数据增强部分,使用 preprocessor_builder,对应的配置文件是 preprocessor.proto

2.2. 通过 dataset_builder 构建 tf.data.Dataset 对象

  • 源码地址:dataset_builderinput_reader.proto
  • dataset_builder.build 函数解析
    • 作用:通过配置文件,调用 dataset_util.read_dataset 方法生成 tf.data.Dataset 对象。
def build(input_reader_config,
          transform_input_data_fn=None,
          batch_size=None,
          max_num_boxes=None,
          num_classes=None,
          spatial_image_shape=None,
          num_additional_channels=0):
  # 判断配置文件合法性
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  # 以 tfrecords 文件作为输入
  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    # 读取配置文件信息
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')
    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      label_map_proto_file = input_reader_config.label_map_path

    # 根据配置文件内容,构建 TfExampleDecoder 对象
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file,
        use_display_name=input_reader_config.use_display_name,
        num_additional_channels=num_additional_channels)

    # 数据预处理函数
    # 功能:解析 tfrecords,并调用自定义的 transform_input_data_fn 函数处理原始数据
    def process_fn(value):
      processed = decoder.decode(value)
      if transform_input_data_fn is not None:
        return transform_input_data_fn(processed)
      return processed

    # 根据配置生成 tf.data.Dataset 对象
    dataset = dataset_util.read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        process_fn, config.input_path[:], input_reader_config)

    # 如果需要 batch 数据,则需要将所有图片都保持在同一尺寸
    # 注意,并不是进行 resize,而是通过 padding 
    # 具体可以参考 tf.contrib.data.padded_batch_and_drop_remainder 或 tf.data.Dataset.padded_batch
    if batch_size:
      padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
                                           spatial_image_shape)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padding_shapes))
    return dataset

  raise ValueError('Unsupported input_reader_config.')
  • dataset_util.read_dataset 方法:
    • 功能:实际生成 tf.data.Dataset 实例。
def read_dataset(file_read_func, decode_func, input_files, config):
  # 获取所有输入文件名称
  filenames = tf.gfile.Glob(input_files)
  num_readers = config.num_readers
  if num_readers > len(filenames):
    num_readers = len(filenames)
    tf.logging.warning('num_readers has been reduced to %d to match input file '
                       'shards.' % num_readers)

  # 以文件名队列作为输入,构建 tf.data.Dataset 对象
  # 主要包括 from_tensor_slices shuffle repeat parallel_interleave map prefetch 操作
  filename_dataset = tf.data.Dataset.from_tensor_slices(tf.unstack(filenames))
  if config.shuffle:
    filename_dataset = filename_dataset.shuffle(
        config.filenames_shuffle_buffer_size)
  elif num_readers > 1:
    tf.logging.warning('`shuffle` is false, but the input data stream is '
                       'still slightly shuffled since `num_readers` > 1.')
  filename_dataset = filename_dataset.repeat(config.num_epochs or None)
  records_dataset = filename_dataset.apply(
      tf.contrib.data.parallel_interleave(
          file_read_func,
          cycle_length=num_readers,
          block_length=config.read_block_length,
          sloppy=config.shuffle))
  if config.shuffle:
    records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)

  # 该方法包括解析 tfrecords 文件
  # 并进行用户自定义数据转换(不是 model.preprocess,也不是数据增强)
  tensor_dataset = records_dataset.map(
      decode_func, num_parallel_calls=config.num_parallel_map_calls)
  return tensor_dataset.prefetch(config.prefetch_size)

2.3. 训练过程数据的后续处理

  • 源码地址:trainer.py
  • 流程:
    • 根据配置文件,进行数据增强。
    • 构建 batcher.BatchQueue 实例用于多GPU训练。
  • 数据增强细节请参考 preprocessor_builderpreprocessor.protopreprocessor.py
    • 三者的关系大概是:主程序通过 preprocessor.proto 配置文件信息,通过 preprocessor_builder 返回 preprocessor.py 中对应的方法。
  • trainer.pycreate_input_queue函数解析。
def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
                       batch_queue_capacity, num_batch_queue_threads,
                       prefetch_queue_capacity, data_augmentation_options):
  # create_tensor_dict_fn 的本质就是 tf.data.Dataset().make_initializable_iterator.get_next()
  tensor_dict = create_tensor_dict_fn()

  # 获取图片数据并简单转换
  tensor_dict[fields.InputDataFields.image] = tf.expand_dims(
      tensor_dict[fields.InputDataFields.image], 0)
  images = tensor_dict[fields.InputDataFields.image]
  float_images = tf.to_float(images)
  tensor_dict[fields.InputDataFields.image] = float_images

  # 获取各种标志位
  include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
                            in tensor_dict)
  include_keypoints = (fields.InputDataFields.groundtruth_keypoints
                       in tensor_dict)
  include_multiclass_scores = (fields.InputDataFields.multiclass_scores
                               in tensor_dict)

  # 进行数据增强
  if data_augmentation_options:
    tensor_dict = preprocessor.preprocess(
        tensor_dict, data_augmentation_options,
        func_arg_map=preprocessor.get_default_func_arg_map(
            include_multiclass_scores=include_multiclass_scores,
            include_instance_masks=include_instance_masks,
            include_keypoints=include_keypoints))

  # 生成 batcher.BatchQueue 方法用于分布式训练
  input_queue = batcher.BatchQueue(
      tensor_dict,
      batch_size=batch_size_per_clone,
      batch_queue_capacity=batch_queue_capacity,
      num_batch_queue_threads=num_batch_queue_threads,
      prefetch_queue_capacity=prefetch_queue_capacity)
  return input_queue

3. 流水账

3.1. 为什么 dataset.make_initializable_iterator().get_next() 的返回值是一个字典?

  • 相关源码如下:
# 项目中源码
def get_next(config):
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()
create_input_dict_fn = functools.partial(get_next, input_config)
  • 猜测与 TfExampleDecoder.decode 有关:
    • 在创建tf.data.Dataset对象时,调用了该方法。
    • 该方法的返回值是一个字典。

3.2. batcher.BatchQueue 的使用

  • 源码地址:batcher.py
  • 官方实例(Example input pipeline with batching):
key, string_tensor = slim.parallel_reader.parallel_read(...)
tensor_dict = decoder.decode(string_tensor)
tensor_dict = preprocessor.preprocess(tensor_dict, ...)
batch_queue = batcher.BatchQueue(tensor_dict,
                                batch_size=32,
                                batch_queue_capacity=2000,
                                num_batch_queue_threads=8,
                                prefetch_queue_capacity=20)
tensor_dict = batch_queue.dequeue()
outputs = Model(tensor_dict)
  • 原理:
    • 初始化实例:通过 tf.train.batch 进行batch操作,并通过 prefetcher.prefetch 构建 tf.PaddingFIFOQueue 对象。
    • dequeue 方法:获取 a batch of tensor_dict 实例,本质就是调用了 tf.PaddingFIFOQueuedequeue 方法。

3.3. 数据读取流程

  • 获取图片数据以及bbox ground truth信息。
    • 如果是公开数据集,其实就是把数据集下载好。
    • 如果是自己的数据集,则需要把所有bbox打好。
  • 生成 tfrecords 文件。
    • 如果图片的 bbox 标签与 Pascal,COCO 等相同,则可以套用现有的脚本进行转换(当然,一般也需要少量修改)。
    • 如果 bbox 标签格式比较特殊,则需要自己写脚本转换,需要注意 tfrecords 中属性名与默认的相同。
  • 设置数据集相关参数。
    • 包括 tf.data.Dataset 相关操作,主要参考 input_reader.protodataset_builder
    • 包括数据增强相关配置,请参考 preprocessor_builderpreprocessor.protopreprocessor.py
  • 根据分布式训练需要,对数据集进行进一步处理。
    • 主要就是通过 batcher.BatchQueue 实现。
    原文作者:清欢守护者
    原文地址: https://zhuanlan.zhihu.com/p/38464485
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞