0. 前言
- 没找到什么研究相关源码的资料。
- 主要内容:
- 学习构建tfrecords。
- 学习如何通过tfrecords文件集构建数据集。
- 如果想通过 TensorFlow Object Detection API 来训练自己的数据集,个人建议:
- 主要了解 tfrecords 的构建方法,目标是将自己的数据构建为API可以识别的 tfrecords 格式。
- 对于
train.py
等脚本,对于数据的具体预处理过程了解即可。 - 对于配置文件内容需要非常熟悉。
1. 构建tfrecords
1.1. 综述
dataset_tools
包括两类功能:- 创建某类数据集对应 tfrecords 文件的脚本。
- 创建 tfrecords 文件的工具类。
- 提供对应脚本的数据集包括:
- 工具类介绍:
1.2. PASCAL 脚本解析
- 源码地址:create_pascal_tf_record.py
- 脚本使用实例:
data_dir
:数据集根目录。set
:可以是 train
, val
, trainval
, test
其中之一。annotations_dir
:标签文件夹名称,默认为Annotations
,相对于data_dir
的相对路径。year
:可以是 VOC2007
, VOC2012
, merged
其中之一。output_path
:生成tfrecords文件的路径。label_map_path
:用到的标签map文件。ignore_difficult_instances
:是否忽略有 difficult
标记的数据。
# 在 /path/to/models/research 路径下运行
python object_detection/dataset_tools/create_pascal_tf_record.py \
--data_dir=/home/user/VOCdevkit \
--set=train \
--annotations_dir=/path/to/annotations \
--year=VOC2007 \
--output_path=/home/user/pascal.record \
--label_map_path=data/pascal_label_map.pbtxt \
--ignore_difficult_instances=Fasle
def main(_):
# 输入参数合法性判断
if FLAGS.set not in SETS:
raise ValueError('set must be in : {}'.format(SETS))
if FLAGS.year not in YEARS:
raise ValueError('year must be in : {}'.format(YEARS))
data_dir = FLAGS.data_dir
years = ['VOC2007', 'VOC2012']
if FLAGS.year != 'merged':
years = [FLAGS.year]
# 创建 TFRecordWriter 对象,准备写入本地文件
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
# 解析 label_map_path 文件
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
# 读取输入数据
for year in years:
logging.info('Reading from PASCAL %s dataset.', year)
# 读取输入图片列表
examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
'aeroplane_' + FLAGS.set + '.txt')
examples_list = dataset_util.read_examples_list(examples_path)
# 获取标签绝对路径
annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)
# 依次处理所有文件
for idx, example in enumerate(examples_list):
if idx % 100 == 0:
logging.info('On image %d of %d', idx, len(examples_list))
# 解析 xml 标签文件,转换为字典
path = os.path.join(annotations_dir, example + '.xml')
with tf.gfile.GFile(path, 'r') as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
# 根据解析得到的字典,获取 tf.train.Example 对象,并将该对象写入本地文件
tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
FLAGS.ignore_difficult_instances)
writer.write(tf_example.SerializeToString())
writer.close()
dict_to_tf_example
函数解析- 主要功能:读取图片数据,读取标签数据,构建
tf.train.Example
对象。
def dict_to_tf_example(data,
dataset_directory,
label_map_dict,
ignore_difficult_instances=False,
image_subdirectory='JPEGImages'):
# 获取图片绝对路径
img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
full_path = os.path.join(dataset_directory, img_path)
# 读取读片数据,并判断图片格式
with tf.gfile.GFile(full_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = PIL.Image.open(encoded_jpg_io)
if image.format != 'JPEG':
raise ValueError('Image format not JPEG')
# 根据图片数据
key = hashlib.sha256(encoded_jpg).hexdigest()
# 获取 tfrecord 文件所需数据
width = int(data['size']['width'])
height = int(data['size']['height'])
xmin = []
ymin = []
xmax = []
ymax = []
classes = []
classes_text = []
truncated = []
poses = []
difficult_obj = []
if 'object' in data:
for obj in data['object']:
difficult = bool(int(obj['difficult']))
if ignore_difficult_instances and difficult:
continue
difficult_obj.append(int(difficult))
xmin.append(float(obj['bndbox']['xmin']) / width)
ymin.append(float(obj['bndbox']['ymin']) / height)
xmax.append(float(obj['bndbox']['xmax']) / width)
ymax.append(float(obj['bndbox']['ymax']) / height)
classes_text.append(obj['name'].encode('utf8'))
classes.append(label_map_dict[obj['name']])
truncated.append(int(obj['truncated']))
poses.append(obj['pose'].encode('utf8'))
# 构建 tf.train.Example 对象
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/source_id': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
'image/object/truncated': dataset_util.int64_list_feature(truncated),
'image/object/view': dataset_util.bytes_list_feature(poses),
}))
return example
1.3. 构建自己的数据集
- 总体思路:
- 所有 tfrecords 文件中,主要的属性名称必须相同。
- 如果自己数据集的标签格式与 COCO, Pascal 等相同,则可以稍微修改一下相应脚本。
- 如果标签信息格式独特,那肯定需要自己写脚本来生成 tfrecords 文件,要求重点属性名相同。
- 重点属性名(查看 coco, pascal, kitti 等脚本中相同的属性名):
- 图片基本信息:
image/height
image/width
image/filename
image/source_id
image/key/sha256
:根据图片encoded生成的sha256编码。image/encoded
:图片数据本身。image/format
:图片格式。
- bbox信息:
image/object/bbox/xmin
image/object/bbox/xmax
image/object/bbox/ymin
image/object/bbox/ymax
image/object/class/text
:分类标签文本。image/object/class/label
:分类标签对应数字。image/object/difficult
- 完整版属性名可以参考:standard_fields.py 中的
TfExampleFields
。
2. 通过 tfrecords 文件构建数据集对象
2.1. 综述
- 在看了
train.py
以及 eval.py
的源码后: - 构建数据集主要通过
dataset_builder
构建。对应的配置文件是 input_reader.proto
,返回参数是一个 tf.data.Dataset
对象。 - 之后,以
tf.data.Dataset
为输入,通过 dataset_util.make_initializable_iterator
创建 make_initializable_iterator
结果,并将对应的 iterator.initializer
添加到 tf.GraphKeys.TABLE_INITIALIZERS
中,以iterator.get_next()
作为后续操作的输入。 train.py
中额外包含了数据增强部分,使用 preprocessor_builder
,对应的配置文件是 preprocessor.proto
。
2.2. 通过 dataset_builder 构建 tf.data.Dataset 对象
def build(input_reader_config,
transform_input_data_fn=None,
batch_size=None,
max_num_boxes=None,
num_classes=None,
spatial_image_shape=None,
num_additional_channels=0):
# 判断配置文件合法性
if not isinstance(input_reader_config, input_reader_pb2.InputReader):
raise ValueError('input_reader_config not of type '
'input_reader_pb2.InputReader.')
# 以 tfrecords 文件作为输入
if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
# 读取配置文件信息
config = input_reader_config.tf_record_input_reader
if not config.input_path:
raise ValueError('At least one input path must be specified in '
'`input_reader_config`.')
label_map_proto_file = None
if input_reader_config.HasField('label_map_path'):
label_map_proto_file = input_reader_config.label_map_path
# 根据配置文件内容,构建 TfExampleDecoder 对象
decoder = tf_example_decoder.TfExampleDecoder(
load_instance_masks=input_reader_config.load_instance_masks,
instance_mask_type=input_reader_config.mask_type,
label_map_proto_file=label_map_proto_file,
use_display_name=input_reader_config.use_display_name,
num_additional_channels=num_additional_channels)
# 数据预处理函数
# 功能:解析 tfrecords,并调用自定义的 transform_input_data_fn 函数处理原始数据
def process_fn(value):
processed = decoder.decode(value)
if transform_input_data_fn is not None:
return transform_input_data_fn(processed)
return processed
# 根据配置生成 tf.data.Dataset 对象
dataset = dataset_util.read_dataset(
functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
process_fn, config.input_path[:], input_reader_config)
# 如果需要 batch 数据,则需要将所有图片都保持在同一尺寸
# 注意,并不是进行 resize,而是通过 padding
# 具体可以参考 tf.contrib.data.padded_batch_and_drop_remainder 或 tf.data.Dataset.padded_batch
if batch_size:
padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
spatial_image_shape)
dataset = dataset.apply(
tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
padding_shapes))
return dataset
raise ValueError('Unsupported input_reader_config.')
dataset_util.read_dataset
方法:- 功能:实际生成
tf.data.Dataset
实例。
def read_dataset(file_read_func, decode_func, input_files, config):
# 获取所有输入文件名称
filenames = tf.gfile.Glob(input_files)
num_readers = config.num_readers
if num_readers > len(filenames):
num_readers = len(filenames)
tf.logging.warning('num_readers has been reduced to %d to match input file '
'shards.' % num_readers)
# 以文件名队列作为输入,构建 tf.data.Dataset 对象
# 主要包括 from_tensor_slices shuffle repeat parallel_interleave map prefetch 操作
filename_dataset = tf.data.Dataset.from_tensor_slices(tf.unstack(filenames))
if config.shuffle:
filename_dataset = filename_dataset.shuffle(
config.filenames_shuffle_buffer_size)
elif num_readers > 1:
tf.logging.warning('`shuffle` is false, but the input data stream is '
'still slightly shuffled since `num_readers` > 1.')
filename_dataset = filename_dataset.repeat(config.num_epochs or None)
records_dataset = filename_dataset.apply(
tf.contrib.data.parallel_interleave(
file_read_func,
cycle_length=num_readers,
block_length=config.read_block_length,
sloppy=config.shuffle))
if config.shuffle:
records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)
# 该方法包括解析 tfrecords 文件
# 并进行用户自定义数据转换(不是 model.preprocess,也不是数据增强)
tensor_dataset = records_dataset.map(
decode_func, num_parallel_calls=config.num_parallel_map_calls)
return tensor_dataset.prefetch(config.prefetch_size)
2.3. 训练过程数据的后续处理
- 源码地址:trainer.py
- 流程:
- 根据配置文件,进行数据增强。
- 构建
batcher.BatchQueue
实例用于多GPU训练。
- 数据增强细节请参考
preprocessor_builder
,preprocessor.proto
,preprocessor.py
。 - 三者的关系大概是:主程序通过
preprocessor.proto
配置文件信息,通过 preprocessor_builder
返回 preprocessor.py
中对应的方法。
trainer.py
中create_input_queue
函数解析。
def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
batch_queue_capacity, num_batch_queue_threads,
prefetch_queue_capacity, data_augmentation_options):
# create_tensor_dict_fn 的本质就是 tf.data.Dataset().make_initializable_iterator.get_next()
tensor_dict = create_tensor_dict_fn()
# 获取图片数据并简单转换
tensor_dict[fields.InputDataFields.image] = tf.expand_dims(
tensor_dict[fields.InputDataFields.image], 0)
images = tensor_dict[fields.InputDataFields.image]
float_images = tf.to_float(images)
tensor_dict[fields.InputDataFields.image] = float_images
# 获取各种标志位
include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
in tensor_dict)
include_keypoints = (fields.InputDataFields.groundtruth_keypoints
in tensor_dict)
include_multiclass_scores = (fields.InputDataFields.multiclass_scores
in tensor_dict)
# 进行数据增强
if data_augmentation_options:
tensor_dict = preprocessor.preprocess(
tensor_dict, data_augmentation_options,
func_arg_map=preprocessor.get_default_func_arg_map(
include_multiclass_scores=include_multiclass_scores,
include_instance_masks=include_instance_masks,
include_keypoints=include_keypoints))
# 生成 batcher.BatchQueue 方法用于分布式训练
input_queue = batcher.BatchQueue(
tensor_dict,
batch_size=batch_size_per_clone,
batch_queue_capacity=batch_queue_capacity,
num_batch_queue_threads=num_batch_queue_threads,
prefetch_queue_capacity=prefetch_queue_capacity)
return input_queue
3. 流水账
3.1. 为什么 dataset.make_initializable_iterator().get_next()
的返回值是一个字典?
# 项目中源码
def get_next(config):
return dataset_util.make_initializable_iterator(
dataset_builder.build(config)).get_next()
create_input_dict_fn = functools.partial(get_next, input_config)
- 猜测与
TfExampleDecoder.decode
有关: - 在创建
tf.data.Dataset
对象时,调用了该方法。 - 该方法的返回值是一个字典。
3.2. batcher.BatchQueue 的使用
- 源码地址:batcher.py
- 官方实例(Example input pipeline with batching):
key, string_tensor = slim.parallel_reader.parallel_read(...)
tensor_dict = decoder.decode(string_tensor)
tensor_dict = preprocessor.preprocess(tensor_dict, ...)
batch_queue = batcher.BatchQueue(tensor_dict,
batch_size=32,
batch_queue_capacity=2000,
num_batch_queue_threads=8,
prefetch_queue_capacity=20)
tensor_dict = batch_queue.dequeue()
outputs = Model(tensor_dict)
- 原理:
- 初始化实例:通过
tf.train.batch
进行batch操作,并通过 prefetcher.prefetch
构建 tf.PaddingFIFOQueue
对象。 dequeue
方法:获取 a batch of tensor_dict 实例,本质就是调用了 tf.PaddingFIFOQueue
的 dequeue
方法。
3.3. 数据读取流程
- 获取图片数据以及bbox ground truth信息。
- 如果是公开数据集,其实就是把数据集下载好。
- 如果是自己的数据集,则需要把所有bbox打好。
- 生成 tfrecords 文件。
- 如果图片的 bbox 标签与 Pascal,COCO 等相同,则可以套用现有的脚本进行转换(当然,一般也需要少量修改)。
- 如果 bbox 标签格式比较特殊,则需要自己写脚本转换,需要注意 tfrecords 中属性名与默认的相同。
- 设置数据集相关参数。
- 包括
tf.data.Dataset
相关操作,主要参考 input_reader.proto
和 dataset_builder
。 - 包括数据增强相关配置,请参考
preprocessor_builder
,preprocessor.proto
,preprocessor.py
。
- 根据分布式训练需要,对数据集进行进一步处理。
- 主要就是通过
batcher.BatchQueue
实现。