PyTorch对于变长文本的异步多进程读取

异步多进程读取数据对于在大规模的数据集上训练模型有重要的意义。

基于PyTorch的异步多进程IO的tutorial往往基于图像做demo进行讲解,对做CV的同学相对比较友好,torchvision支持的也比较好,相比之下torchtext文档写得很乱,缺乏教程,也欠缺灵活性。

这篇文章主要写给用PyTorch做NLP的同学,主要贡献在于提供了一种解决变长样本的方法。做CV的同学可以直接移步https://zhuanlan.zhihu.com/p/30934236,写的真的很好很清晰,本文也是借鉴了这篇文章的部分思路。

因为部分代码是我正在研究的工作,不便直接展示,所以本文少量代码会用伪代码代替。

本文主要包含三部分,DatasetDataloadercollate_fn

Prerequisite

为了实现异步多进程的IO,首先需要把每个example写成一个小文件,放到指定文件夹里。这里我自己用json格式存储example,大家随意。

(这里我暂时没有想到更好的办法)

例如:

{
"token":["x","x"],
"token_idx":[0,0], 
"tag":["y","y"], 
"tag_idx":[1,1]
}

Dataset

import glob
import os
import ujson as json
from torch.utils import data
class Dataset(data.Dataset):
    def __init__(self, dir):
        super(Dataset, self).__init__()
        self.data = glob.glob(dir + "/*.json")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        json_path = self.data[idx]
        with open(json_path, "r") as f:
            data = json.load(f)
        return data

基于torch.utils.data.Dataset,我们自己定义了一个Dataset类,初始化需要出入存放数据的文件夹地址。

Dataloader

Dataloader本质上就是一个iterator。没什么好讲的,自己看文档。

trainset = Dataset("./data/train")
trainloader = data.Dataloader(trainset, batch_size=24, num_workers=8, collate_fn=collate_fn)

collate_fn

我们在异步多进程得到一个batch之后和将batch输入模型之前,还需要对batch进一步处理,主要是padding。

import torch
from keras_preprocessing.sequence import pad_sequences
def collate_fn(batch):
    token_idx = list(map(lambda x: x["token_idx"], batch))
    tag_idx = list(map(lambda x: x["tag_idx"], batch))
    token_idx_pad = torch.LongTensor(pad_sequences(token_idx, padding="post"))
    tag_idx_pad = torch.LongTensor(pad_sequences(tag_idx, padding="post"))
    return token_idx_pad, tag_idx_pad

上面这个例子展示了在序列标注任务上如何使用collate_fn。我用了keras_preprocessing.sequence做padding,我个人觉得torch自带的torch.nn.utils.rnn.pad_sequence等几个相关的function用起来不大顺手。

Conclusion

综合在一起,如下:

trainset = Dataset("./data/train")
trainloader = data.Dataloader(trainset, batch_size=24, num_workers=8, collate_fn=collate_fn)
for token, tag in trainloader:
    ......

然后我们就可以愉快的异步多进程读取数据了。

妈妈再也不用担心我的内存不够用了。

以上方法主要的缺点就是需要生成大量的小文件。

如果哪位大神有更好的方法,欢迎指正,我也只不过是抛砖引玉而已。

感激不尽。

    原文作者:吴明昊
    原文地址: https://zhuanlan.zhihu.com/p/72014659
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞