Pytorch笔记03-torch.nn.Modules及拓展新的Modules

0. 本章内容

在之前已经介绍了torch.nn这个包,其主要负责快速构建常用的深度学习网络,如卷积层等。但很多时候,我们需要自己定义一个torch.nn中并未实现的网络层,以使得代码更加模块化。这时候就需要我们自己拓展torch.nn.Modules了。本章主要介绍如何通过继承torch.nn.Modules类来定义一个新的网络层。

1. torch.nn.Modules回顾

  • torch.nn.Modules相当于是对网络某种层的封装,包括网络结构以及网络参数,和其他有用的操作如输出参数
  • torch.nn包中的各个类实际上就是由torch.nn.Modules继承而拓展
  • pytorch中有两种拓展pytorch方式,一种是拓展Function(后面将介绍),一种是就是拓展Modules

2. 如何利用torch.nn.Modules进行拓展

  • 进行拓展,需要继承Modules类,并实现__init__()方法,以及forward()方法
  • __init__()方法,用于定义一些新的属性,这些属性可以包括Modules的实例,如一个torch.nn.Conv2d。即创建该网络中的子网络,在创建这些子网络时,这些网络的参数也被初始化
# 在module初始化时,执行__init__()
def __init__(self):
    # 调用Module的初始化
    super(MyBlock, self).__init__()
        
    # 创建将要调用的子层(Module),注意:此时还并未实现MyBlock网络的结构(即forward运算),只是初始化了其子层(结构+参数)
    self.conv1 = nn.Conv2d(1, 3, 3)
    self.conv2 = nn.Conv2d(3, 3, 3)

  • forward()方法,用于定义该Module进行forward时的运算,forward()方法接受一个输入,然后通过其他modules或者其他Function运算,来进行forward,返回一个输出结果
# 在定义好forward后,该module调用forward,将按forward进行前向传播,并构建网络
def forward(self, x):
        
    # 这里relu与pool层选择用Function来实现(后续将介绍Function,可以认为Function是一种运算),而不使用Module,用Module也可以
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, 2)
    return x

3. 作用

  • Module的作用就是可以结构化定义网络的层,并提供对该层的封装,包括该层的结构,参数以及一些其他操作
  • Module中的forward可以使用其他Module,其在调用forward时,其内部其他Module将按顺序进行forward(具体见例子)

4. 例子

构建一个简单的类vggblock

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

class MyBlock(nn.Module):
    def __init__(self):
        # 调用Module的初始化
        super(MyBlock, self).__init__()
        
        # 创建将要调用的子层(Module),注意:此时还并未实现MyBlock网络的结构,只是初始化了其子层(结构+参数)
        self.conv1 = nn.Conv2d(1, 3, 3)
        self.conv2 = nn.Conv2d(3, 3, 3)
        
    def forward(self, x):
        
        # 这里relu与pool层选择用Function来实现,而不使用Module,用Module也可以
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        return x

# 实例化一个新建的网络 
test_MyBlock = MyBlock()

# 可以看Module中的子Module
print test_MyBlock

# MyBlock (
# (conv1): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
# (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# )

# 可以通过Module中的多种方法,实现输出参数等功能
print(test_MyBlock.state_dict().keys())

# ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']

# 可以直接对Module中的子Module进行修改
test_MyBlock.conv1 = nn.Conv2d(1, 3, 5)
print test_MyBlock.conv1
# ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']

# 随机生成一个输入
x = torch.randn(1, 1, 10, 10)

# 进行forward操作,建立网络,此处可以直接使用test_MyBlock(x)
test_MyBlock.forward(x)

利用上面构建的Block构建一个简单的vgg网络

class SimpleVgg(nn.Module):
    def __init__(self):
        super(SimpleVgg, self).__init__()
        
        # 利用刚才构建的block
        self.block1 = MyBlock()
        self.block2 = MyBlock()
        self.block2.conv1 = nn.Conv2d(3, 3, 3)
        self.fc = nn.Linear(75, 10)
    
    def forward(self, x):

        # forward时,数据线经过两个block,而后通过fc层
        x = self.block1(x)
        x = self.block2(x)
        
        # 将3维张量化为1维向量
        x = x.view(-1, self.num_flat_features(x))
        x = self.fc(x)
        return x
    
    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

vgg = SimpleVgg()

print vgg
    原文作者:林梅林
    原文地址: https://zhuanlan.zhihu.com/p/27545732
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞