PyTorch修改模型参数

在PyTorch中非常容易修改模型的参数,因此,很容易在基础网络模型的基础上进行,重新进行自己的模型训练,请参考下列将一个模型参数复制到另外一个模型的代码:

trained_dict = train_model.state_dict()

my_trained_dict = my_.state_dict()
for k, v in trained_dict.items():
for _k, _v in my_trained_dict.items():
if k.find(_k) > 0:
size_ = v.size()
if len(size_) == 1:
for k0 in range(size_[0]):
_v[k0] = v[k0]
else:
for k0 in range(size_[0]):
for k1 in range(size_[1]):
for k2 in range(size_[2]):
for k3 in range(size_[3]):
_v[k0, k1, k2, k3] = v[k0, k1, k2, k3]

my_model.load_state_dict(my_trained_dict)

    原文作者:AI贾书军
    原文地址: https://zhuanlan.zhihu.com/p/27389873
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞