Speeding up data loading in PyTorch

When using TensorFlow, you can convert your data into the TFRecord format to make data reading more efficient. Check nvidia-smi then and GPU utilization stays close to 100%. That feels great, a real blessing for the perfectionist.
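
For reference, a minimal sketch of what that TFRecord round trip looks like in TensorFlow (not from the original post; `samples`, the feature names, and the path "data.tfrecord" are placeholder assumptions):

import tensorflow as tf

# Write one serialized Example per sample; samples is assumed to be an
# iterable of (encoded image bytes, integer label) pairs.
with tf.io.TFRecordWriter("data.tfrecord") as writer:
    for img_bytes, label in samples:
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_bytes])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))
        writer.write(example.SerializeToString())

# Read the records back as a tf.data pipeline.
def parse(record):
    spec = {"image": tf.io.FixedLenFeature([], tf.string),
            "label": tf.io.FixedLenFeature([], tf.int64)}
    return tf.io.parse_single_example(record, spec)

dataset = (tf.data.TFRecordDataset("data.tfrecord")
           .map(parse)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))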

In PyTorch, on the other hand, the usual tool is the DataLoader,

and you may still be using it like this:

from torch.utils.data import DataLoader

training_data_loader = DataLoader(
    dataset=train_dataset,        # train_dataset and opts are defined elsewhere
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,              # pinned host memory enables asynchronous H2D copies
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # training code

That is not very friendly at training time, especially when the GPU is rented. Time is money, my friend!

For this situation, NVIDIA provides a solution:

NVIDIA's I/O acceleration example (github.com)

The key part of it is a prefetching wrapper around the dataloader:

import torch


# Prefetches the next batch to the GPU on a separate CUDA stream so that
# host-to-device copies overlap with the computation of the current batch.
class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        # Wait until the copy issued on the side stream has finished, hand back
        # the staged batch, and immediately start preloading the following one.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        self.preload()
        return input, target

After this upgrade, the training loop becomes:

prefetcher = data_prefetcher(training_data_loader)
data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # training code
    data, label = prefetcher.next()
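
To make the "# training code" placeholder concrete, here is a minimal sketch of the full loop; model, criterion, and optimizer are hypothetical placeholders assumed to be defined elsewhere:

prefetcher = data_prefetcher(training_data_loader)  # the loader above, with pin_memory=True
data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    output = model(data)             # forward pass; data and label are already on the GPU
    loss = criterion(output, label)  # model, criterion, optimizer are placeholders
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # while the GPU was busy, the prefetcher staged the next batch on its own stream
    data, label = prefetcher.next()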

And that's it: GPU utilization climbs right up. Of course, it is still best to swap the mechanical hard drive for an SSD.

    Original author: 体hi
    Original article: https://zhuanlan.zhihu.com/p/72956595