在PyTorch中非常容易修改模型的参数,因此,很容易在基础网络模型的基础上进行,重新进行自己的模型训练,请参考下列将一个模型参数复制到另外一个模型的代码:
trained_dict = train_model.state_dict()
my_trained_dict = my_.state_dict()
for k, v in trained_dict.items():
for _k, _v in my_trained_dict.items():
if k.find(_k) > 0:
size_ = v.size()
if len(size_) == 1:
for k0 in range(size_[0]):
_v[k0] = v[k0]
else:
for k0 in range(size_[0]):
for k1 in range(size_[1]):
for k2 in range(size_[2]):
for k3 in range(size_[3]):
_v[k0, k1, k2, k3] = v[k0, k1, k2, k3]
my_model.load_state_dict(my_trained_dict)