pytorch模型参数

1、torch.nn.state_dict():

返回一个字典,保存着module的所有状态(state)。

parameters和persistent_buffers都会包含在字典中,字典的key就是parameter和buffer的names。

例子:

import torch
from torch.autograd import Variable
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv2 = nn.Linear(1, 2)
        self.vari = Variable(torch.rand([1]))
        self.par = nn.Parameter(torch.rand([1]))
        self.register_buffer("buffer", torch.randn([2,3]))

model = Model()
print(model.state_dict().keys())
odict_keys(['par', 'buffer', 'conv2.weight', 'conv2.bias'])

 

字典迭代形式{<class ‘str’>:<class ‘torch.Tensor’>, … }

    原文作者:pytorch
    原文地址: https://www.cnblogs.com/lucifer1997/p/11305150.html
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞