Inside TF-Slim(12) 使用 slim & tf.data在VOC2012数据集中训练VGG

0. 前言

  • 个人笔记:Inside TF-Slim(6) learning – part1,介绍tf.slim.learning的基本使用。
  • 个人笔记:TensorFlow API(2) tf.data,介绍tf.data的使用。
  • 个人笔记:TensorFlow API(3) tf.app.flags,介绍tf的命令行工具。
  • 源码地址
  • 目标:
    • 使用tf.data.Dataset来构建数据集。
    • 使用tensorflow models中的slim.nets搭建模型。
    • 使用slim.learning相关函数进行训练。
  • 其他功能:
    • 使用了tf.slim的 pre-trained model。
    • 每个epoch后,在验证集上评价模型(计算损失函数平均数以及准确率平均数)。保存验证集上准确率最高的模型。
    • 对 pre-trained model 进行 finetune:
      • 首先,固定VGG中的部分权重进行训练,即仅训练fc8层的权重。
      • 然后,对VGG中所有权重进行训练。
    • 图像增强:
      • 训练阶段:随机获取图片边长(范围为[256, 512])并resize,之后进行随机切片获取224*224的图片。
      • 预测阶段:图片resize为384*384。

1. 准备工作

  • 导入依赖包。
    • 除了使用tensorflow外,还需要下载tensorflow models,并导入/path/to/models/research/slim。
    • 设置日志信息。
    • 定义各种参数(包括命令行参数)。
# 导入依赖包
import sys
# 把下载好的tensorflow models中的slim路径,添加到系统路径下
sys.path.append("/home/ubuntu/models/research/slim")  
import nets.vgg as vgg  # 这个就是tensorflow models中的内容,定义了vgg网络结构
import tensorflow as tf
import tensorflow.contrib.slim as slim
import math
import os

# 设置日志信息
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
logger.addHandler(ch)
# 各种参数定义
# VOC2012 分类信息,共有20类
CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
           'dog', 'horse', 'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor']

# 命令行基本参数
tf.flags.DEFINE_integer('EPOCH_STAGE_1', 5, 'epochs to train only fc8')
tf.flags.DEFINE_integer('EPOCH_STAGE_2', 30, 'epochs to train all variables')
tf.flags.DEFINE_integer('BATCH_SIZE', 64, 'batch size')
tf.flags.DEFINE_float('WEIGHT_DECAY', 0.00005, 'l2 loss')
tf.flags.DEFINE_float('KEEP_PROB', 0.5, 'dropout layer')
tf.flags.DEFINE_float('MOMENTUM', 0.9, 'optimizer momentum')
tf.flags.DEFINE_string('VOC2012_ROOT', "/home/ubuntu/data/VOC2012", 'where to store logs')
tf.flags.DEFINE_integer('TRAIN_IMAGE_SIZE', 224, 'image size during training') 
tf.flags.DEFINE_integer('VAL_IMAGE_SIZE', 384, 'image size during evaluation')
tf.flags.DEFINE_integer('NUM_CLASSES', 20, 'how many classes in this dataset')
tf.flags.DEFINE_integer('LOGGING_EVERY_N_STEPS', 10, 'logging in console every n steps')
tf.flags.DEFINE_string('LOGS_DIR', './logs/', 'where to store logs')

# 学习率相关参数
tf.flags.DEFINE_float('LEARNING_RATE_START', 0.001, 'learning rate at epoch 0')
tf.flags.DEFINE_integer('DECAY_STEPS', 500, 'learning rate decay var')
tf.flags.DEFINE_float('DECAY_RATE', 0.5, 'learning rate decay var')

# pre-trained model相关
tf.flags.DEFINE_boolean('USE_PRE_TRAINED_MODEL', True, 'whether or not use slim pre-trained model.')
tf.flags.DEFINE_string('PRE_TRAINED_MODEL_PATH', '/home/ubuntu/data/slim/vgg_19.ckpt', 'pre-trained model path.')

FLAGS = tf.app.flags.FLAGS

2. 获取tf.data.Dataset实例

  • 实现数据增强:包括resize、切片、水平镜像三类操作。
def get_dataset(mode='train', resize_image_min=256, resize_image_max=512, image_size=224):
    """  获取数据集  :param image_size: 输出图片的尺寸  :param resize_image_max: 在切片前,将图片resize的最小尺寸  :param resize_image_min: 在切片前,将图片resize的最大尺寸  :param mode: 可以是 train val trainval 三者之一,对应于VOC数据集中的预先设定好的训练集、验证集  :return: 返回元组,第一个参数是 tf.data.Dataset实例,第二个是数据集中元素数量  """
    def get_image_paths_and_labels():
        # 从本地文件系统中,获取所有图片的绝对路径以及对应的标签
        if mode not in ['train', 'val', 'trainval']:
            raise ValueError('Unknown mode: {}'.format(mode))
        result_dict = {}
        for i, class_name in enumerate(CLASSES):
            file_name = class_name + "_" + mode + '.txt'
            for line in open(os.path.join(FLAGS.VOC2012_ROOT, 'ImageSets', 'Main', file_name), 'r'):
                line = line.replace(' ', ' ').replace('\n', '')
                parts = line.split(' ')
                if int(parts[1]) == 1:
                    result_dict[os.path.join(FLAGS.VOC2012_ROOT, 'JPEGImages', parts[0] + '.jpg')] = i
        keys, values = [], []
        for key, value in result_dict.items():
            keys.append(key)
            values.append(value)
        return keys, values

    def norm_imagenet(image):
        # 对每张图片减去 ImageNet RGB平均数
        means = [103.939, 116.779, 123.68]
        channels = tf.split(axis=2, num_or_size_splits=3, value=image)
        for i in range(3):
            channels[i] -= means[i]
        return tf.concat(axis=2, values=channels)

    def random_crop(images, cur_image_size):
        # 随机切片
        image_height = tf.shape(images)[-3]
        image_width = tf.shape(images)[-2]
        offset_height = tf.random_uniform([], 0, (image_height - cur_image_size + 1), dtype=tf.int32)
        offset_width = tf.random_uniform([], 0, (image_width - cur_image_size + 1), dtype=tf.int32)
        return tf.image.crop_to_bounding_box(images, offset_height, offset_width, cur_image_size, cur_image_size)

    def parse_image_by_path_fn(image_path):
        # 通过文件路径读取图片,并进行数据增强
        img_file = tf.read_file(image_path)
        cur_image = tf.image.decode_jpeg(img_file)
        cur_image = tf.image.resize_images(cur_image, [random_image_size, random_image_size])
        cur_image = tf.image.random_flip_left_right(cur_image)
        cur_image = random_crop(cur_image, image_size)
        cur_image = norm_imagenet(cur_image)
        return cur_image

    # 获取所有图片路径以及对应标签
    random_image_size = tf.random_uniform([], resize_image_min, resize_image_max, tf.int32)
    paths, labels = get_image_paths_and_labels()

    # 建立tf.data.Dataset实例
    images_dataset = tf.data.Dataset.from_tensor_slices(paths).map(parse_image_by_path_fn)
    labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
    dataset = tf.data.Dataset.zip((images_dataset, labels_dataset))
    if mode == 'train':
        dataset = dataset.shuffle(buffer_size=len(paths))
    return dataset.batch(batch_size=FLAGS.BATCH_SIZE), len(paths)

3. 通过tf.slim建立VGG模型

  • 调用 tensorflow models 中已有模型,并设置l2 loss参数。
def get_vgg_model(x, is_training):
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_regularizer=slim.l2_regularizer(FLAGS.WEIGHT_DECAY),
                        biases_initializer=tf.zeros_initializer()):
        with slim.arg_scope([slim.conv2d], padding='SAME'):
            logits, _ = vgg.vgg_19(x,
                                   num_classes=FLAGS.NUM_CLASSES,
                                   is_training=is_training,
                                   dropout_keep_prob=FLAGS.KEEP_PROB,
                                   spatial_squeeze=True,
                                   scope='vgg_19',
                                   fc_conv_padding='VALID',
                                   global_pool=True)
    return logits

4. 构建slim.learning中最重要的 train_step 函数

  • 基本功能:进行一次梯度下降计算。
  • 其他功能:
    • 每个epoch训练完成后,对验证集进行评价。
    • 获取每个epoch的平均准确率。
  • 请注意:
    • 配合slim.learning.train中的train_step_kwargs参数,train_step函数可以实现各种扩展功能。
    • 但这就导致一个问题,代码写得很难看……
    • 这里是强行使用 slim.learning,我本人并不是特别喜欢。
  • 关于finetune分阶段训练的实现:
    • 这里用的是笨办法:train_step函数本身拥有一个train_op,在train_step_kwargs中传入另一个train_op,在函数中分别处理。
    • 后来发现,有更简单的办法能够实现这个功能:通过 tf.cond 将两个train_op合并为一个。
# train_step函数,必须有以下四个参数,且必须返回两个参数(total_loss, should_stop)
# 其中 train_step_kwargs 中可以自定义各种参数
def train_step_vgg(sess, train_op, global_step, train_step_kwargs):
    """  slim.learning中要调用的函数,代表一次梯度下降  :param sess:  :param train_op:  :param global_step:  :param train_step_kwargs:  :return:  """
    cur_global_step = sess.run(global_step)
    # 获取训练集数据,tf.data相关
    train_next_element = train_step_kwargs['train_next_element']
    try:
        cur_images, cur_labels = sess.run(train_next_element)
    except:
        # 如果获取失败,要么说明刚开始训练,要么说明一个epoch遍历结束

        if int(cur_global_step) != 0:
            # 如果是一个epoch遍历结束,则需要对验证集进行评估
            val_set_test(sess, global_step, train_step_kwargs)

        # 开始新一轮epoch训练
        train_iter_initializer = train_step_kwargs['train_iter_initializer']
        sess.run(train_iter_initializer)
        cur_images, cur_labels = sess.run(train_next_element)

        # 由于是计算每个epoch的性能指标,所以在epoch结束需要对各种性能指标清零
        sess.run(train_step_kwargs['reset_metrics_ops'])
    feed_dict = {train_step_kwargs['ph_x']: cur_images,
                 train_step_kwargs['ph_y']: cur_labels,
                 train_step_kwargs['ph_is_training']: True,
                 train_step_kwargs['ph_image_size']: FLAGS.TRAIN_IMAGE_SIZE}

    # 根据不同的train_op来处理finetune的两个阶段
    train_op = train_op if cur_global_step < train_step_kwargs['stage_1_steps'] else train_step_kwargs['train_op2']

    # 进行一次梯度下降训练
    cur_total_loss, cur_accuracy = sess.run(
        [train_op, train_step_kwargs['train_accuracy']],
        feed_dict=feed_dict)

    # 输出训练日志信息
    if cur_global_step % train_step_kwargs['logging_every_n_steps'] == 0:
        logger.info('step %d: loss is %.4f, accuracy is %.3f.' % (cur_global_step, cur_total_loss, cur_accuracy))

    # 结束训练时,对验证集进行操作
    if cur_global_step + 1 >= train_step_kwargs['max_steps']:
        val_set_test(sess, global_step, train_step_kwargs)

    # 必须输出两个值
    return cur_total_loss, cur_global_step + 1 >= train_step_kwargs['max_steps']


# 评估验证集
# 在 train_step 函数中,每个epoch结束时调用
# 获取平均loss和平均accuracy,并保存val accuray最好的模型
def val_set_test(sess, global_step, train_step_kwargs):
    # 调用tf.data相关接口,初始化数据集
    sess.run(train_step_kwargs['val_iter_initializer'])
    while True:
        try:
            cur_images, cur_labels = sess.run(train_step_kwargs['val_next_element'])
            feed_dict = {train_step_kwargs['ph_x']: cur_images,
                         train_step_kwargs['ph_y']: cur_labels,
                         train_step_kwargs['ph_is_training']: False,
                         train_step_kwargs['ph_image_size']: FLAGS.VAL_IMAGE_SIZE}
            # 使用tf.metrics获取的update_op,更新相关评价指标
            sess.run([train_step_kwargs['mean_val_loss_update_op'], train_step_kwargs['val_accuracy_update_op']],
                     feed_dict=feed_dict)
        except:
            break

    # 获取 loss 和 accuracy 在验证集上的平均数
    val_loss, val_accuracy = sess.run([train_step_kwargs['mean_val_loss'], train_step_kwargs['val_accuracy']])
    logger.info('epoch val loss is %.4f, val accuracy is %.2f' % (val_loss, val_accuracy))

    # 保存验证集上accuracy最高的模型
    if val_accuracy > train_step_kwargs['best_val_accuracy']:
        train_step_kwargs['best_val_accuracy'] = val_accuracy
        train_step_kwargs['saver'].save(sess, os.path.join(FLAGS.LOGS_DIR, 'model.ckpt'), global_step=global_step)
    return val_loss, val_accuracy

5. 主程序

  • 第一步:获取训练集与验证集(调用之前的get_dataset方法)以及对应的Iterator实例。
  • 第二步构建模型:
    • 定义placeholder。
    • 创建模型。
    • 定义损失函数与优化器。
    • metrics相关操作。
  • 第三步:导入pre-trained model。
  • 第四步:调用slim.learning.create_train_op创建train_op实例。
    • 注意,这里要创建两个train_op实例,分别代表finetune的两个阶段:阶段一,只训练VGG中的fc8;阶段二,训练VGG中的所有参数。
  • 第五步:调用slim.learning.train开始训练。
    • 这一步主要获取 train_step 中所需要的参数字典。
def main(_):
    # 1. 获取训练集与验证集
    train_set, train_set_size = get_dataset('train')
    train_iter = train_set.make_initializable_iterator()
    logger.info('train set created successfully with {} items.'.format(train_set_size))
    val_set, val_set_size = get_dataset('val',
                                        image_size=FLAGS.VAL_IMAGE_SIZE,
                                        resize_image_min=FLAGS.VAL_IMAGE_SIZE,
                                        resize_image_max=FLAGS.VAL_IMAGE_SIZE+1)
    val_iter = val_set.make_initializable_iterator()
    logger.info('val set created successfully with {} items.'.format(val_set_size))

    # 2. 构建模型

    # 2.1. 定义tf.placeholder
    ph_image_size = tf.placeholder(tf.int32)
    ph_x = tf.placeholder(tf.float32)
    ph_y = tf.placeholder(tf.int32, [None])
    ph_is_training = tf.placeholder(tf.bool)

    # 2.2. 创建模型
    logits = get_vgg_model(tf.reshape(ph_x, [-1, ph_image_size, ph_image_size, 3]), ph_is_training)

    # 2.3. 定义损失函数与优化器
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(FLAGS.LEARNING_RATE_START, global_step,
                                               decay_rate=FLAGS.DECAY_RATE, decay_steps=FLAGS.DECAY_STEPS) # 学习率衰减
    tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=ph_y)
    total_loss = tf.losses.get_total_loss()
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=FLAGS.MOMENTUM)

    # 2.4. metrics相关操作
    accuracy, _ = tf.metrics.accuracy(ph_y, tf.argmax(tf.nn.softmax(logits), axis=1),
                                      updates_collections=tf.GraphKeys.UPDATE_OPS)
    val_accuracy, val_accuracy_update_op = tf.metrics.accuracy(ph_y, tf.argmax(tf.nn.softmax(logits), axis=1))
    mean_val_loss, mean_val_loss_update_op = tf.metrics.mean(total_loss, name='mean_val_loss')

    # 3. 导入pre-trained model
    # 需要先下载模型,并指定模型文件位置
    if FLAGS.USE_PRE_TRAINED_MODEL:
        variables_to_restore = slim.get_variables_to_restore(include=['vgg_19'], exclude=['vgg_19/fc8'])
        init_fn = slim.assign_from_checkpoint_fn(FLAGS.PRE_TRAINED_MODEL_PATH, variables_to_restore, True, True)
        logger.info('use pre-trained model with %d variables' % len(variables_to_restore))
    else:
        init_fn = None

    # 4. 创建 train_op
    # 根据 slim.learning.create_train_op 中的 variables_to_train 参数来指定需要训练的参数
    train_op1 = slim.learning.create_train_op(total_loss, optimizer,
                                              variables_to_train=slim.get_variables_to_restore(include=['vgg_19/fc8']),
                                              global_step=global_step)
    train_op2 = slim.learning.create_train_op(total_loss, optimizer, global_step=global_step)

    # 5. 通过 slim.learning.train 开始训练
    # 构练 train_step 函数所需参数
    def get_train_step_kwargs():
        """  获取 train_step 函数所需要的参数  """
        metrics = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
        reset_metrics_ops = []
        for metric in metrics:
            reset_metrics_ops.append(tf.assign(metric, 0))  # 对所有metrics清零
        saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
        train_step_kwargs = {'max_steps': int(math.ceil(train_set_size / FLAGS.BATCH_SIZE)) * FLAGS.EPOCH_STAGE_2,
                             'logging_every_n_steps': FLAGS.LOGGING_EVERY_N_STEPS,
                             'train_next_element': train_iter.get_next(),
                             'train_iter_initializer': train_iter.initializer,
                             'train_accuracy': accuracy,
                             'ph_x': ph_x,
                             'ph_y': ph_y,
                             'ph_is_training': ph_is_training,
                             'ph_image_size': ph_image_size,
                             'reset_metrics_ops': reset_metrics_ops,
                             'val_next_element': val_iter.get_next(),
                             'val_iter_initializer': val_iter.initializer,
                             'mean_val_loss': mean_val_loss,
                             'mean_val_loss_update_op': mean_val_loss_update_op,
                             'val_accuracy': val_accuracy,
                             'val_accuracy_update_op': val_accuracy_update_op,
                             'saver': saver,
                             'best_val_accuracy': .0,
                             'train_op2': train_op2,
                             'stage_1_steps': int(math.ceil(train_set_size / FLAGS.BATCH_SIZE)) * FLAGS.EPOCH_STAGE_1  # finetune fc8的step数量
                             }
        return train_step_kwargs
    # 开始训练
    slim.learning.train(train_op1,
                        logdir=FLAGS.LOGS_DIR,
                        train_step_fn=train_step_vgg,
                        train_step_kwargs=get_train_step_kwargs(),
                        init_fn=init_fn,
                        global_step=global_step,
                        save_interval_secs=None,
                        save_summaries_secs=None)

6. 后记

  • 研究了源码,如果不写个实例之后都忘光了。
  • 写了实例之后发现,其实以后都不会来用这个东西。
  • 网上tf.slim.learning的实例代码基本都是照抄官方文档,使用的都是自带的train_step函数。
  • 在研究完源码后,花一天写了一个实例,自定义train_step函数,扩展部分功能。
  • 我太菜了,现阶段只能实现功能,对性能(比如tf.data)没有太多感受,希望进一步提升。
    原文作者:清欢守护者
    原文地址: https://zhuanlan.zhihu.com/p/36198988
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞