Saving and loading PyTorch models on CPU, single GPU, and multiple GPUs

https://www.jianshu.com/p/4905bf8e06e5

The link above explains how to save and load models in PyTorch.
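As a quick refresher, a minimal sketch of the usual save/load pattern (assuming PyTorch is installed; an in-memory buffer stands in for a checkpoint path such as `net.pth`):

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Recommended pattern: save only the state_dict, not the whole model object.
buf = io.BytesIO()  # stands in for a file path like "net.pth"
torch.save(model.state_dict(), buf)

buf.seek(0)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state = torch.load(buf, map_location=torch.device("cpu"))
model.load_state_dict(state)
```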

Today I hit an error when loading a model that had been saved on a single GPU into a multi-GPU setup. Recording the fix here.

import torch
from collections import OrderedDict

def load_pretrainedmodel(modelname, model_):
    # The checkpoint stores the whole model object under the "model" key;
    # map_location keeps all tensors on the CPU while loading.
    pre_model = torch.load(modelname, map_location=lambda storage, loc: storage)["model"]
    if torch.cuda.is_available():
        state_dict = OrderedDict()
        for k in pre_model.state_dict():
            name = k
            if not k.startswith('module.') and torch.cuda.device_count() > 1:
                # saved on a single GPU, but we will train on multiple GPUs
                name = 'module.' + k  # add the 'module.' prefix
            elif k.startswith('module.') and torch.cuda.device_count() == 1:
                # saved on multiple GPUs, but we will train on a single GPU
                name = k[len('module.'):]  # strip the 'module.' prefix
            state_dict[name] = pre_model.state_dict()[k]
        model_.load_state_dict(state_dict)
    else:
        model_ = pre_model
    return model_

Because the parameter names of a multi-GPU (`nn.DataParallel`) model carry an extra 'module.' prefix, the prefix sometimes has to be added and sometimes removed, depending on the direction of the transfer.
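The prefix handling can be isolated into a small, torch-free helper that works on any ordered mapping of parameter names. A sketch (`fix_module_prefix` is a name introduced here, not from the original post):

```python
from collections import OrderedDict

def fix_module_prefix(state_dict, multi_gpu):
    """Add or strip the 'module.' prefix so keys match the target model.

    multi_gpu=True  -> target is wrapped in nn.DataParallel, so every
                       key needs the 'module.' prefix.
    multi_gpu=False -> target is a plain model, so the prefix must go.
    """
    fixed = OrderedDict()
    for k, v in state_dict.items():
        if multi_gpu and not k.startswith('module.'):
            fixed['module.' + k] = v
        elif not multi_gpu and k.startswith('module.'):
            fixed[k[len('module.'):]] = v
        else:
            fixed[k] = v
    return fixed
```

Note that `startswith('module.')` (with the trailing dot) is the safe test; slicing the first seven characters and comparing against the six-character string `'module'` can never match and silently misclassifies every key.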

    Original author: pytorch
    Original address: https://www.cnblogs.com/amarr/p/10553354.html
    This article is reposted from the web for knowledge sharing only; if it infringes your rights, please contact the blogger for removal.