1. torch.nn.Module.state_dict():
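Returns a dictionary (an OrderedDict) containing the whole state of the module.
Both parameters and persistent buffers are included, and the keys are the parameter and buffer names. Entries that come from a submodule are prefixed with the submodule's attribute name (e.g. conv2.weight). Plain tensor attributes that were never registered as parameters or buffers are not included.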
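Only *persistent* buffers are saved. The minimal sketch below makes that distinction visible. `BufferDemo` and its attribute names are hypothetical, and the sketch assumes PyTorch 1.6 or later, where `register_buffer` accepts the `persistent` keyword.

```python
import torch
import torch.nn as nn


class BufferDemo(nn.Module):
    """Hypothetical module used only to show which members state_dict() keeps."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(3))                        # parameter -> saved
        self.register_buffer("running", torch.ones(3))                    # persistent buffer -> saved
        self.register_buffer("scratch", torch.ones(3), persistent=False)  # non-persistent buffer -> skipped
        self.plain = torch.ones(3)                                        # unregistered attribute -> skipped


demo = BufferDemo()
keys = list(demo.state_dict().keys())
buffers = [name for name, _ in demo.named_buffers()]
print(keys)     # ['weight', 'running']
print(buffers)  # ['running', 'scratch']  (still a buffer, just not saved)
```

The non-persistent buffer still moves with the module (e.g. on `.to(device)`), but it is left out of the saved state.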
Example:
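```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Submodule: its parameters appear as conv2.weight / conv2.bias
        self.conv2 = nn.Linear(1, 2)
        # Plain tensor attribute (originally wrapped in the deprecated
        # torch.autograd.Variable): NOT registered, so it is not saved
        self.vari = torch.rand([1])
        # Registered parameter
        self.par = nn.Parameter(torch.rand([1]))
        # Registered (persistent) buffer
        self.register_buffer("buffer", torch.randn([2, 3]))


model = Model()
print(model.state_dict().keys())
```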
odict_keys(['par', 'buffer', 'conv2.weight', 'conv2.bias'])
Each entry maps a name to a tensor, so iterating over the dictionary yields pairs of the form {<class 'str'>: <class 'torch.Tensor'>, ...}.
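A minimal sketch of iterating over the dictionary, using a bare `nn.Linear(1, 2)` as an assumed stand-in for any module. It also shows that the same dictionary can be passed to `load_state_dict` on a module with the same structure.

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 2)
sd = model.state_dict()

for name, tensor in sd.items():
    # Each key is a str, each value a torch.Tensor (detached from the graph)
    print(name, type(name).__name__, type(tensor).__name__, tuple(tensor.shape))
# weight str Tensor (2, 1)
# bias str Tensor (2,)

# Restore the saved state into a fresh module with identical structure
clone = nn.Linear(1, 2)
clone.load_state_dict(sd)
print(torch.equal(clone.weight, model.weight))  # True
```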