0. 本章内容
在之前已经介绍了torch.nn这个包,其主要负责快速构建常用的深度学习网络,如卷积层等。但很多时候,我们需要自己定义一个torch.nn中并未实现的网络层,以使得代码更加模块化。这时候就需要我们自己拓展torch.nn.Modules了。本章主要介绍如何通过继承torch.nn.Modules类来定义一个新的网络层。
1. torch.nn.Modules回顾
- torch.nn.Modules相当于是对网络某种层的封装,包括网络结构以及网络参数,和其他有用的操作如输出参数
- torch.nn包中的各个类实际上就是由torch.nn.Modules继承而拓展
- pytorch中有两种拓展pytorch方式,一种是拓展Function(后面将介绍),一种是就是拓展Modules
2. 如何利用torch.nn.Modules进行拓展
- 进行拓展,需要继承Modules类,并实现__init__()方法,以及forward()方法
- __init__()方法,用于定义一些新的属性,这些属性可以包括Modules的实例,如一个torch.nn.Conv2d。即创建该网络中的子网络,在创建这些子网络时,这些网络的参数也被初始化
# 在module初始化时,执行__init__()
def __init__(self):
# 调用Module的初始化
super(MyBlock, self).__init__()
# 创建将要调用的子层(Module),注意:此时还并未实现MyBlock网络的结构(即forward运算),只是初始化了其子层(结构+参数)
self.conv1 = nn.Conv2d(1, 3, 3)
self.conv2 = nn.Conv2d(3, 3, 3)
- forward()方法,用于定义该Module进行forward时的运算,forward()方法接受一个输入,然后通过其他modules或者其他Function运算,来进行forward,返回一个输出结果
# 在定义好forward后,该module调用forward,将按forward进行前向传播,并构建网络
def forward(self, x):
# 这里relu与pool层选择用Function来实现(后续将介绍Function,可以认为Function是一种运算),而不使用Module,用Module也可以
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
return x
3. 作用
- Module的作用就是可以结构化定义网络的层,并提供对该层的封装,包括该层的结构,参数以及一些其他操作
- Module中的forward可以使用其他Module,其在调用forward时,其内部其他Module将按顺序进行forward(具体见例子)
4. 例子
构建一个简单的类vggblock
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
class MyBlock(nn.Module):
def __init__(self):
# 调用Module的初始化
super(MyBlock, self).__init__()
# 创建将要调用的子层(Module),注意:此时还并未实现MyBlock网络的结构,只是初始化了其子层(结构+参数)
self.conv1 = nn.Conv2d(1, 3, 3)
self.conv2 = nn.Conv2d(3, 3, 3)
def forward(self, x):
# 这里relu与pool层选择用Function来实现,而不使用Module,用Module也可以
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
return x
# 实例化一个新建的网络
test_MyBlock = MyBlock()
# 可以看Module中的子Module
print test_MyBlock
# MyBlock (
# (conv1): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
# (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# )
# 可以通过Module中的多种方法,实现输出参数等功能
print(test_MyBlock.state_dict().keys())
# ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']
# 可以直接对Module中的子Module进行修改
test_MyBlock.conv1 = nn.Conv2d(1, 3, 5)
print test_MyBlock.conv1
# ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']
# 随机生成一个输入
x = torch.randn(1, 1, 10, 10)
# 进行forward操作,建立网络,此处可以直接使用test_MyBlock(x)
test_MyBlock.forward(x)
利用上面构建的Block构建一个简单的vgg网络
class SimpleVgg(nn.Module):
def __init__(self):
super(SimpleVgg, self).__init__()
# 利用刚才构建的block
self.block1 = MyBlock()
self.block2 = MyBlock()
self.block2.conv1 = nn.Conv2d(3, 3, 3)
self.fc = nn.Linear(75, 10)
def forward(self, x):
# forward时,数据线经过两个block,而后通过fc层
x = self.block1(x)
x = self.block2(x)
# 将3维张量化为1维向量
x = x.view(-1, self.num_flat_features(x))
x = self.fc(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
vgg = SimpleVgg()
print vgg