Pytorch Distributed 初始化方法
参考文献
https://pytorch.org/docs/master/distributed.html
代码
https://github.com/overfitover/pytorch-distributed
欢迎来star me.
初始化
torch.distributed.init_process_group(backend, init_method='env://', **kwargs)
参数说明
- backend(str): 后端选择,包括 gloo、nccl、mpi(早期版本中的 tcp 后端自 PyTorch 1.0 起已被移除,CPU 上请使用 gloo)
- init_method(str, optional): 指定如何初始化进程组的 URL,即各进程之间交换连接信息、互相发现的方式
- world_size(int, optional): 参与工作的进程数
- rank(int, optional): 当前进程的rank
- group_name(str, optional): 用来标记这组进程。
init_method
有三种方法:
- file:// 共享文件系统
- tcp:// 通过 TCP 初始化:指定 rank 0 进程上所有进程都能访问的 IP 地址和端口(旧版本还支持 IP 组播地址)
- env:// 环境变量 (默认是这个)
env
#!/usr/bin/env python
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import time
def run(rank, size):
    """Placeholder for the distributed workload; does nothing.

    Args:
        rank: Rank of the calling process.
        size: Total number of processes in the group.
    """
    return None
def init_processes(rank, size, fn, backend='gloo',
                   master_addr='162.128.0.22', master_port='29555'):
    """Initialize the distributed environment, then run ``fn`` inside it.

    The default ``env://`` rendezvous is used: the master address and port
    are exported through ``MASTER_ADDR`` / ``MASTER_PORT`` before the process
    group is created.

    Args:
        rank: Rank of the calling process.
        size: Number of participating processes (the world size).
        fn: Callable invoked as ``fn(rank, size)`` once the group is up.
        backend: Communication backend passed to ``init_process_group``.
        master_addr: Address of the master (rank 0) host. The default is a
            placeholder; replace it with the IP of your own machine.
        master_port: Free port on the master host (string or int).
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=size)
    try:
        # Every rank seeds the CUDA RNG with the same value.
        torch.cuda.manual_seed(1)
        fn(rank, size)
        print("MM")
        print(dist.get_rank())
        print(dist.get_world_size())
        print(dist.is_available())
    finally:
        # Tear the group down even when fn raises, so its resources are
        # released before the process exits.
        dist.destroy_process_group()
def main():
    """Launch two worker processes running ``init_processes`` and wait for both."""
    world_size = 2
    workers = [
        Process(target=init_processes, args=(rank, world_size, run))
        for rank in range(world_size)
    ]
    for worker in workers:
        worker.start()
    # Block until every worker has exited.
    for worker in workers:
        worker.join()
if __name__ == "__main__":
    # Time the whole launch, from spawning the workers until all have exited.
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way a time.time() difference can.
    start_time = time.perf_counter()
    main()
    end_time = time.perf_counter()
    print("耗时:", end_time - start_time)
注意
将162.128.0.22换成自己的IP地址。
tcp
import torch
import torch.distributed as dist
import argparse
from time import sleep
from random import randint
from torch.multiprocessing import Process
def initialize(rank, world_size, ip, port):
    """Join the process group through a ``tcp://`` rendezvous.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of processes in the group.
        ip: Address of the rank 0 process; it must be reachable from
            every participating process.
        port: Port used for the rendezvous on that address.
    """
    # The dedicated 'tcp' backend was removed in PyTorch 1.0 and is rejected
    # by init_process_group. 'gloo' is the CPU backend and works with the
    # tcp:// init method, which is what this example demonstrates.
    dist.init_process_group(
        backend='gloo',
        init_method='tcp://{}:{}'.format(ip, port),
        rank=rank,
        world_size=world_size,
    )
    print("MM")
def main():
    """Parse the command line and launch one local process per rank.

    ``--world-size`` selects how many processes are spawned (2 when it is
    omitted). ``--rank`` is still accepted for compatibility but is not
    used: this launcher starts every rank itself, numbered from 0 up to
    ``world_size - 1``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ip', type=str, default='162.128.0.22')
    parser.add_argument('--port', type=str, default='20000')
    parser.add_argument('--rank', '-r', type=int)
    parser.add_argument('--world-size', '-s', type=int)
    args = parser.parse_args()
    print(args)

    # Honour --world-size instead of silently ignoring it; fall back to the
    # two-process demo when the option is not given.
    size = 2 if args.world_size is None else args.world_size
    processes = []
    for i in range(size):
        p = Process(target=initialize, args=(i, size, args.ip, args.port))
        p.start()
        processes.append(p)
    # Wait for every rank to finish before returning.
    for p in processes:
        p.join()
# Script entry point: run the launcher only when executed directly.
if __name__ == "__main__":
    main()
注意
将162.128.0.22换成自己的IP地址。
共享文件
import argparse
from random import randint
from time import sleep

import torch.distributed as dist
from torch.multiprocessing import Process
def initialize(
    rank,
    world_size,
    init_method='file:///home/yxk/Documents/Deeplearningoflidar139/overfitover/share',
):
    """Join the process group through a shared-file rendezvous.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of processes in the group.
        init_method: ``file://`` URL of the rendezvous file on a file system
            visible to every process. The default is a machine-specific
            placeholder; pass a path that is valid on your own system.
            The file is created when missing but is not removed afterwards,
            so delete it before the next initialization that reuses the path.
    """
    dist.init_process_group(
        backend='gloo',
        init_method=init_method,
        rank=rank,
        world_size=world_size,
    )
    print("MM")
def main():
    """Launch two worker processes running ``initialize`` and wait for both."""
    world_size = 2
    workers = [
        Process(target=initialize, args=(rank, world_size))
        for rank in range(world_size)
    ]
    for worker in workers:
        worker.start()
    # Block until every worker has exited.
    for worker in workers:
        worker.join()
# Script entry point: run the launcher only when executed directly.
if __name__ == "__main__":
    main()
注意
init_method: 需要以 file:// 开头,后面是共享文件系统上一个尚不存在的文件的路径(该文件所在的目录必须已经存在)。如果文件不存在,文件系统初始化会自动创建该文件,但不会删除它。因此你需要在下一次调用 init_process_group 之前清除该文件。