2.2 MobileNet V1思考

本文来源于学习知乎文章 CNN模型之MobileNet

         小而高效的CNN有两个方向:一是对训练好的复杂模型进行压缩得到小模型;二是直接设计小模型并进行训练。本文要介绍的MobileNet是后者。

原则:保持模型性能(accuracy)的前提下降低模型大小(parameters size),同时提升模型速度(speed, low latency)。

我的文章的SqueezeNet和ShuffleNet也是基于这个原则。这个方向非常有前景,尤其在工业界非常有用。

1 深度级可分离卷积(Depthwise separable convolution)

         MobileNet的基本单元是深度级可分离卷积(depthwise separable convolution)(类似于SqueezeNet中的fire模块)。深度级可分离卷积其实是一种可分解卷积操作(factorized convolutions),其可以分解为两个更小的操作:

  • depthwise convolution
  • pointwise convolution,

《2.2 MobileNet V1思考》 图1 Depthwise separable convolution

如图1所示。

  • Depthwise convolution(DC)和标准卷积不同。标准卷积的卷积核是用在所有的输入通道上(input channels)。depthwise convolution针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说depthwise convolution是depth级别的操作。
  • pointwise convolution(PC)其实就是普通的卷积,只不过其采用《2.2 MobileNet V1思考》的卷积核。

《2.2 MobileNet V1思考》 图2 Depthwise convolution和pointwise convolution

         图2中更清晰地展示了两种操作。

  • depthwise separable convolution,首先是采用depthwise convolution对不同输入通道分别进行卷积
  • pointwise convolution将上面的输出再进行结合,这样其实整体效果和一个标准卷积是差不多的,但是会大大减少计算量和模型参数量。
             奇技淫巧的出发点是我们在文首提到的原则。为了达到这个目的,使用了各种奇技淫巧。
             下面分析下Depthwise separable convolution和标准卷积的区别。
             假设输入特征图大小为《2.2 MobileNet V1思考》,输出特征图大小为《2.2 MobileNet V1思考》。这里假设输入特征图和输出特征图的height和width相等。
             对于标准卷积,其计算量为
    《2.2 MobileNet V1思考》
             对于depthwise separable convolution而言,其计算量为
    《2.2 MobileNet V1思考》
             上面两个计算量相除得到
    《2.2 MobileNet V1思考》
             由上式可知,若卷积核大小为《2.2 MobileNet V1思考》,则运算量下降9倍。而且这个下降倍数随着《2.2 MobileNet V1思考》的增加而变大。
             标准卷积的参数数量为
    《2.2 MobileNet V1思考》
             depthwise separable convolution的参数数量为
    《2.2 MobileNet V1思考》
    上述两式相除得到
    《2.2 MobileNet V1思考》

2 MobileNet网络结构

         MobileNet在使用depthwise separable convolution构建网络的时候会大量使用batchnorm,这不利于硬件进行并行运算。使用BN和ReLU的depthwise separable convolution的基本结构如下图所示。

《2.2 MobileNet V1思考》 图3 加入BN和ReLU的depthwise separable convolution
《2.2 MobileNet V1思考》 图4 MobileNet的网络结构

         MobileNet的网络结构如图4所示。首先是一个3×3的标准卷积,然后后面就是堆积depthwise separable convolution,并且可以看到其中的部分depthwise convolution会通过strides=2进行down sampling。然后采用average pooling将feature变成《2.2 MobileNet V1思考》,根据预测类别大小加上全连接层,最后是一个softmax层。

《2.2 MobileNet V1思考》 图5 MobileNet网络的计算与参数分布

         如果单独计算depthwise convolution和pointwise convolution,整个网络有28层(这里Avg Pool和Softmax不计算在内)。我们还可以分析整个网络的参数和计算量分布,如图5所示。可以看到整个计算量基本集中在《2.2 MobileNet V1思考》卷积上,如果你熟悉卷积底层实现的话,卷积一般通过一种im2col方式实现,其需要内存重组,但是当卷积核为《2.2 MobileNet V1思考》时,其实就不需要这种操作了,底层可以有更快的实现。对于参数也主要集中在《2.2 MobileNet V1思考》卷积,除此之外还有就是全连接层占了一部分参数。

MobileNet改进
目的:在MobileNet基准模型的基础上再想得到更小的模型。
主要方法是引入了两个超参数:

  • width multiplier
  • resolution multiplier

width multiplier为《2.2 MobileNet V1思考》,且《2.2 MobileNet V1思考》,于是depthwise separable convolution的计算量为:
《2.2 MobileNet V1思考》
resolution multiplier为《2.2 MobileNet V1思考》,且《2.2 MobileNet V1思考》,于是depthwise separable convolution的计算量为:
《2.2 MobileNet V1思考》

3 MobileNet的tensorflow实现

         tensorflow中内置了depthwise convolution算子tf.nn.depthwise_conv2d。
先实现depthwise_separable_convolution子模块

def _depthwise_separable_conv2d(inputs, num_filters, width_multiplier, scope, downsample=False):
  num_filters = round(num_filters * width_multiplier)
  strides = 2 if downsample else 0
  
  with tf.variable_scope(scope):
    dw_conv = depthwise_conv2d(inputs, "depthwise_conv", strides = strides)
    bn = batchnorm(dw_conv, "dw_bn", is_training = True)
    relu = tf.nn.relu(bn)
    pw_conv = conv2d(relu, "pointwise_conv", num_filters)
    bn = batchnorm(pw_conv, "pw_bn", is_training=True)
    return tf.nn.relu(bn)

利用上面的_depthwise_separable_conv2d来构建MobileNet。

def MobileNet(inputs, num_classes, width_multiplier, scope = "MobileNet"):
  with tf.variable_scope(scope):
    net = conv2d(inputs, "conv_1", round(32 * width_multiplier), filter_size=3, strides=2)               # ->[N, 112, 112, 32]
    net = tf.nn.relu(bacthnorm(net, "conv_1/bn", is_training=self.is_training))
    net = self._depthwise_separable_conv2d(net, 64, width_multiplier, "ds_conv_2")                       # ->[N, 112, 112, 64]
    net = self._depthwise_separable_conv2d(net, 128, width_multiplier, "ds_conv_3", downsample=True)     # ->[N, 56, 56, 128]
    net = self._depthwise_separable_conv2d(net, 128, width_multiplier, "ds_conv_4")                      # ->[N, 56, 56, 128]
    net = self._depthwise_separable_conv2d(net, 256, width_multiplier, "ds_conv_5", downsample=True)     # ->[N, 28, 28, 256]
    net = self._depthwise_separable_conv2d(net, 256, width_multiplier, "ds_conv_6")                      # ->[N, 28, 28, 256]
    net = self._depthwise_separable_conv2d(net, 512, width_multiplier, "ds_conv_7", downsample=True)     # ->[N, 14, 14, 512]
    net = self._depthwise_separable_conv2d(net, 512, width_multiplier, "ds_conv_8")                      # ->[N, 14, 14, 512]
    net = self._depthwise_separable_conv2d(net, 512, width_multiplier, "ds_conv_9")                      # ->[N, 14, 14, 512]
    net = self._depthwise_separable_conv2d(net, 512, self.width_multiplier, "ds_conv_10")                # ->[N, 14, 14, 512]
    net = self._depthwise_separable_conv2d(net, 512, width_multiplier, "ds_conv_11")                     # ->[N, 14, 14, 512]
    net = self._depthwise_separable_conv2d(net, 512, width_multiplier, "ds_conv_12")                     # ->[N, 14, 14, 512]
    net = self._depthwise_separable_conv2d(net, 1024, width_multiplier, "ds_conv_13", downsample=True)   # ->[N, 7, 7, 1024]
    net = self._depthwise_separable_conv2d(net, 1024, width_multiplier, "ds_conv_14")                    # ->[N, 7, 7, 1024]
    net = avg_pool(net, 7, "avg_pool_15")
    net = tf.squeeze(net, [1, 2], name="SpatialSqueeze")
    self.logits = fully_connected(net, num_classes, "fc_16")
    self.predictions = tf.nn.softmax(self.logits)

4 MobileNet的缺点分析

         一点愚见:

  • 依然使用了batchnorm,导致在FPGA中难以实现并行处理
  • 使用了平均池化,对增加计算量

5 总结

         本文简单介绍了Google提出的移动端模型MobileNet,其核心是采用了可分解的depthwise separable convolution,其不仅可以降低模型计算复杂度,而且可以大大降低模型大小。在真实的移动端应用场景,像MobileNet这样类似的网络将是持续研究的重点。后面我们会介绍其他的移动端CNN模型。

    原文作者:深度学习模型优化
    原文地址: https://www.jianshu.com/p/708f4518bdc7
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞