的Pytorch的数据读取非常方便, 可以很容易地实现多线程数据预读. 我个人认为编程难度比TF小很多,而且灵活性也更高. (TF需要把文件名封装成list, 传入string_input_producer
, 这样可以得到一个queue
; 然后把这个queue
给一个WholeFileReader.read()
; 再把read()
回来的value用decode_jpeg()
解码; 然后再用一系列处理去clip, flip等等…)
Pytorch的数据读取主要包含三个类:
- Dataset
- DataLoader
- DataLoaderIter
这三者大致是一个依次封装的关系: 1.被装进2., 2.被装进3.
一. torch.utils.data.Dataset
是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:
__getitem__()
__len__()
第一个最为重要, 即每次怎么读数据. 以图片为例:
def __getitem__(self, index):
img_path, label = self.data[index].img_path, self.data[index].label
img = Image.open(img_path)
return img, label
值得一提的是, pytorch还提供了很多常用的transform, 在torchvision.transforms
里面, 本文中不多介绍, 我常用的有Resize
, RandomCrop
, Normalize
, ToTensor
(这个极为重要, 可以把一个PIL或numpy图片转为torch.Tensor
, 但是好像对numpy数组的转换比较受限, 所以这里建议在__getitem__()
里面用PIL来读图片, 而不是用skimage.io).
第二个比较简单, 就是返回整个数据集的长度:
def __len__(self):
return len(self.data)
二. torch.utils.data.DataLoader
类定义为:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
可以看到, 主要参数有这么几个:
dataset
: 即上面自定义的dataset.collate_fn
: 这个函数用来打包batch, 后面详细讲.num_worker
: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.
这个类其实就是下面将要讲的DataLoaderIter
的一个框架, 一共干了两件事: 1.定义了一堆成员变量, 到时候赋给DataLoaderIter
, 2.然后有一个__iter__()
函数, 把自己 “装进” DataLoaderIter
里面.
def __iter__(self):
return DataLoaderIter(self)
三. torch.utils.data.dataloader.DataLoaderIter
上面提到, DataLoaderIter
就是DataLoaderIter
的一个框架, 用来传给DataLoaderIter
一堆参数, 并把自己装进DataLoaderIter
里.
其实到这里就可以满足大多数训练的需求了, 比如
class CustomDataset(Dataset):
# 自定义自己的dataset
dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)
for data in dataloader:
# training...
在for 循环里, 总共有三点操作:
- 调用了
dataloader
的__iter__()
方法, 产生了一个DataLoaderIter
- 反复调用
DataLoaderIter
的__next__()
来得到batch, 具体操作就是, 多次调用dataset的__getitem__()
方法 (如果num_worker
>0就多线程调用), 然后用collate_fn
来把它们打包成batch. 中间还会涉及到shuffle
, 以及sample
的方法等, 这里就不多说了. - 当数据读完后,
__next__()
抛出一个StopIteration
异常,for
循环结束,dataloader
失效.
四. 又一层封装…
其实上面三个类已经可以搞定了, 但是我觉得这还不太符合我的需求, 就又写了一个类, 仅供参考
class DataProvider:
def __init__(self, batch_size, is_cuda):
self.batch_size = batch_size
self.dataset = Dataset_triple(self.batch_size,
transform_=transforms.Compose(
[transforms.Scale([224, 224]),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])]),
)
self.is_cuda = is_cuda # 是否将batch放到gpu上
self.dataiter = None
self.iteration = 0 # 当前epoch的batch数
self.epoch = 0 # 统计训练了多少个epoch
def build(self):
dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, drop_last=True)
self.dataiter = DataLoaderIter(dataloader)
def next(self):
if self.dataiter is None:
self.build()
try:
batch = self.dataiter.next()
self.iteration += 1
if self.is_cuda:
batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
return batch
except StopIteration: # 一个epoch结束后reload
self.epoch += 1
self.build()
self.iteration = 1 # reset and return the 1st batch
batch = self.dataiter.next()
if self.is_cuda:
batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
return batch