Pytorch多GPU训练:DataParallel

devices_ids=[1,2,3] #使用GPU 1,2,3

net=net.to(devices_ids[0]) #首先把模型放在多GPU中的第一个GPU上

net=torch.nn.DataParallel(net,device_ids=devices_ids,output_device=devices_ids[0]) #使用多GPU训练模型,并把模型输出到output_devices上。

x.to(devices_ids[0])
y.to(devices_ids[0])

preds=net(x).to(devices_ids[0]) #所有的张量都必须放在devices_ids[0]上

    原文作者:shayxu
    原文地址: https://zhuanlan.zhihu.com/p/39450949
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞