https://www.jianshu.com/p/4905bf8e06e5
上面这个链接主要给出了PyTorch如何保存和加载模型
今天遇到了单GPU保存模型,然后多GPU加载模型出现错误的情况。在此记录。
from collections import OrderedDict def load_pretrainedmodel(modelname, model_): pre_model = torch.load(modelname, map_location=lambda storage, loc: storage)["model"] #print(pre_model) if cuda: state_dict = OrderedDict() for k in pre_model.state_dict(): name = k if name[:7] != 'module' and torch.cuda.device_count() > 1: # loaded model is single GPU but we will train it in multiple GPUS! name = 'module.' + name #add 'module' elif name[:7] == 'module' and torch.cuda.device_count() == 1: # loaded model is multiple GPUs but we will train it in single GPU! name = k[7:]# remove `module.` state_dict[name] = pre_model.state_dict()[k] #print(name) model_.load_state_dict(state_dict) #model_.load_state_dict(torch.load(modelname)['model'].state_dict()) else: model_ = torch.load(modelname, map_location=lambda storage, loc: storage)["model"] return model_
由于多GPU的模型参数会多出‘module.’这个前缀,所以有时要加上有时要去掉。