PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景

PyTorch 中有一些基础概念在构建网络的时候很重要,比如 nn.Module, nn.ModuleList, nn.Sequential,这些类我们称之为容器 (containers),因为我们可以添加模块 (module) 到它们之中。这些容器之间很容易混淆,本文中我们主要学习一下 nn.ModuleList 和 nn.Sequential,并判断在什么时候用哪一个比较合适。

本文中的例子使用的是 PyTorch 1.0 版本。第一次在知乎写有关编程的文章,如果文章中有错误或者没有说明白的地方,欢迎在评论区指正和讨论。

nn.ModuleList

首先说说 nn.ModuleList 这个类,你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。描述看起来很枯燥,我们来看几个例子。

第一个网络,我们先来看看使用 nn.ModuleList 来构建一个小型网络,包括3个全连接层:

class net1(nn.Module):
    def __init__(self):
        super(net1, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(2)])
    def forward(self, x):
        for m in self.linears:
            x = m(x)
        return x

net = net1()
print(net)
# net1(
# (modules): ModuleList(
# (0): Linear(in_features=10, out_features=10, bias=True)
# (1): Linear(in_features=10, out_features=10, bias=True)
# )
# )

for param in net.parameters():
    print(type(param.data), param.size())
# <class 'torch.Tensor'> torch.Size([10, 10])
# <class 'torch.Tensor'> torch.Size([10])
# <class 'torch.Tensor'> torch.Size([10, 10])
# <class 'torch.Tensor'> torch.Size([10])

我们可以看到,这个网络包含两个全连接层,他们的权重 (weithgs) 和偏置 (bias) 都在这个网络之内。接下来我们看看第二个网络,它使用 Python 自带的 list:

class net2(nn.Module):
    def __init__(self):
        super(net2, self).__init__()
        self.linears = [nn.Linear(10,10) for i in range(2)]
    def forward(self, x):
        for m in self.linears:
            x = m(x)
        return x

net = net2()
print(net)
# net2()
print(list(net.parameters()))
# []

显然,使用 Python 的 list 添加的全连接层和它们的 parameters 并没有自动注册到我们的网络中。当然,我们还是可以使用 forward 来计算输出结果。但是如果用 net2 实例化的网络进行训练的时候,因为这些层的 parameters 不在整个网络之中,所以其网络参数也不会被更新,也就是无法训练。

好,看到这里,我们大致明白了 nn.ModuleList 是干什么的了:它是一个储存不同 module,并自动将每个 module 的 parameters 添加到网络之中的容器。但是,我们需要注意到,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言,比如:

class net3(nn.Module):
    def __init__(self):
        super(net3, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
    def forward(self, x):
        x = self.linears[2](x)
        x = self.linears[0](x)
        x = self.linears[1](x) 
        return x

net = net3()
print(net)
# net3(
# (linears): ModuleList(
# (0): Linear(in_features=10, out_features=20, bias=True)
# (1): Linear(in_features=20, out_features=30, bias=True)
# (2): Linear(in_features=5, out_features=10, bias=True)
# )
# )
input = torch.randn(32, 5)
print(net(input).shape)
# torch.Size([32, 30])

根据 net3 的结果,我们可以看出来这个 ModuleList 里面的顺序并不能决定什么,网络的执行顺序是根据 forward 函数来决定的。如果你非要 ModuleList 和 forward 中的顺序不一样, PyTorch 表示它无所谓,但以后 review 你代码的人可能会意见比较大。

我们再考虑另外一种情况,既然这个 ModuleList 可以根据序号来调用,那么一个模块是否可以在 forward 函数中被调用多次呢?答案当然是可以的,但是,被调用多次的模块,是使用同一组 parameters 的,也就是它们的参数是共享的,无论之后怎么更新。例子如下,虽然在 forward 中我们用了 nn.Linear(10,10) 两次,但是它们只有一组参数。这么做有什么用处呢,我目前没有想到…

class net4(nn.Module):
    def __init__(self):
        super(net4, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(5, 10), nn.Linear(10, 10)])
    def forward(self, x):
        x = self.linears[0](x)
        x = self.linears[1](x)
        x = self.linears[1](x)
        return x

net = net4()
print(net)
# net4(
# (linears): ModuleList(
# (0): Linear(in_features=5, out_features=10, bias=True)
# (1): Linear(in_features=10, out_features=10, bias=True)
# )
# )
for name, param in net.named_parameters():
    print(name, param.size())
# linears.0.weight torch.Size([10, 5])
# linears.0.bias torch.Size([10])
# linears.1.weight torch.Size([10, 10])
# linears.1.bias torch.Size([10])

nn.Sequential

现在我们来研究一下 nn.Sequential,不同于 nn.ModuleList,它已经实现了内部的 forward 函数,而且里面的模块必须是按照顺序进行排列的,所以我们必须确保前一个模块的输出大小和下一个模块的输入大小是一致的,如下面的例子所示:

class net5(nn.Module):
    def __init__(self):
        super(net5, self).__init__()
        self.block = nn.Sequential(nn.Conv2d(1,20,5),
                                    nn.ReLU(),
                                    nn.Conv2d(20,64,5),
                                    nn.ReLU())
    def forward(self, x):
        x = self.block(x)
        return x

net = net5()
print(net)
# net5(
# (block): Sequential(
# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
# (1): ReLU()
# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
# (3): ReLU()
# )
# )

下面给出了两个 nn.Sequential 初始化的例子,来自于 官网教程。在第二个初始化中我们用到了 OrderedDict 来指定每个 module 的名字,而不是采用默认的命名方式 (按序号 0,1,2,3…) 。

# Example of using Sequential
model1 = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
print(model1)
# Sequential(
# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
# (1): ReLU()
# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
# (3): ReLU()
# )

# Example of using Sequential with OrderedDict
import collections
model2 = nn.Sequential(collections.OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))
print(model2)
# Sequential(
# (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
# (relu1): ReLU()
# (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
# (relu2): ReLU()
# )

有同学可能发现了,诶,你这个 model1 和 从类 net5 实例化来的 net 有什么区别吗?是没有的。这两个网络是相同的,因为 nn.Sequential 就是一个 nn.Module 的子类,也就是 nn.Module 所有的方法 (method) 它都有。并且直接使用 nn.Sequential 不用写 forward 函数,因为它内部已经帮你写好了。

这时候有同学该说了,既然 nn.Sequential 这么好,我以后都直接用它了。如果你确定 nn.Sequential 里面的顺序是你想要的,而且不需要再添加一些其他处理的函数 (比如 nn.functional 里面的函数,nn 与 nn.functional 有什么区别? ),那么完全可以直接用 nn.Sequential。这么做的代价就是失去了部分灵活性,毕竟不能自己去定制 forward 函数里面的内容了。

一般情况下 nn.Sequential 的用法是来组成卷积块 (block),然后像拼积木一样把不同的 block 拼成整个网络,让代码更简洁,更加结构化。

nn.ModuleList 和 nn.Sequential: 到底该用哪个

前边我们已经简单介绍了这两个类,现在我们来讨论一下在两个不同的场景中,选择哪一个比较合适。

场景一,有的时候网络中有很多相似或者重复的层,我们一般会考虑用 for 循环来创建它们,而不是一行一行地写,比如:

layers = [nn.Linear(10, 10) for i in range(5)]

这个时候,很自然而然地,我们会想到使用 ModuleList,像这样:

class net6(nn.Module):
    def __init__(self):
        super(net6, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])

    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x

net = net6()
print(net)
# net6(
# (linears): ModuleList(
# (0): Linear(in_features=10, out_features=10, bias=True)
# (1): Linear(in_features=10, out_features=10, bias=True)
# (2): Linear(in_features=10, out_features=10, bias=True)
# )
# )

这个是比较一般的方法,但如果不想这么麻烦,我们也可以用 Sequential 来实现,如 net7 所示!注意 * 这个操作符,它可以把一个 list 拆开成一个个独立的元素。但是,请注意这个 list 里面的模块必须是按照想要的顺序来进行排列的。在 场景一 中,我个人觉得使用 net7 这种方法比较方便和整洁。

class net7(nn.Module):
    def __init__(self):
        super(net7, self).__init__()
        self.linear_list = [nn.Linear(10, 10) for i in range(3)]
        self.linears = nn.Sequential(*self.linears_list)

    def forward(self, x):
        self.x = self.linears(x)
        return x

net = net7()
print(net)
# net7(
# (linears): Sequential(
# (0): Linear(in_features=10, out_features=10, bias=True)
# (1): Linear(in_features=10, out_features=10, bias=True)
# (2): Linear(in_features=10, out_features=10, bias=True)
# )
# )

下面我们考虑 场景二,当我们需要之前层的信息的时候,比如 ResNets 中的 shortcut 结构,或者是像 FCN 中用到的 skip architecture 之类的,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList 比较方便,一个非常简单的例子如下:

class net8(nn.Module):
    def __init__(self):
        super(net8, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 20), nn.Linear(20, 30), nn.Linear(30, 50)])
        self.trace = []

    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
            self.trace.append(x)
        return x

net = net8()
input  = torch.randn(32, 10) # input batch size: 32
output = net(input)
for each in net.trace:
    print(each.shape)
# torch.Size([32, 20])
# torch.Size([32, 30])
# torch.Size([32, 50])

我们使用了一个 trace 的列表来储存网络每层的输出结果,这样如果以后的层要用的话,就可以很方便地调用了。

总结

本文中我们通过一些实例学习了 ModuleList 和 Sequential 这两种 nn containers,ModuleList 就是一个储存各种模块的 list,这些模块之间没有联系,没有实现 forward 功能,但相比于普通的 Python list,ModuleList 可以把添加到其中的模块和参数自动注册到网络上。而Sequential 内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部 forward 功能已经实现,可以使代码更加整洁。在不同场景中,如果二者都适用,那就看个人偏好了。非常推荐大家看一下 PyTorch 官方的 TorchVision 下面模型实现的代码,能学到很多构建网络的技巧。

Reference:

  1. nn.ModuleList 和 Sequential 由来、用法和实例 —— 写网络模型
  2. The difference in usage between nn.ModuleList and python list
  3. When should I use nn.ModuleList and when should I use nn.Sequential?
  4. Pytorch: how and when to use Module, Sequential, ModuleList and ModuleDict
    原文作者:xiaopl
    原文地址: https://zhuanlan.zhihu.com/p/64990232
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞