未经许可,严禁任何形式转载!
能力有限,欢迎批评指正。
引入
训练模型的时候,第一步需要将数据集进行加载和预处理,然后送入数据队列,通过数据队列送入网络中。分别通过 Dataset
和 Dataloader
实现。
创建数据集 Dataset
使用自带数据集
对于常用数据集,可以使用 torchvision.datasets
直接进行读取。torchvision.dataset
是 torch.utils.data.Dataset
的实现。它提供了如下的数据集:
1. MNIST
2. COCO (Captioning and Detection)
3. LSUN Classification
4. ImageFolder
5. Imagenet-12
6. CIFAR10 and CIFAR100
7. STL10
下面以 MNIST
为例,进行调用:
import torchvision
cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)
其中,root
表示数据集存在或下载的路径,train
表示是训练集还是测试集, download=True
表示若不存在,则下载。
自定义 Dataset
对于自建数据集,pytorch
中数据集处理有两种方式,一种是通过集成 torch.utils.data.DataSet
类,另一种是使用 torchvision.datasets.ImageFolder
类。前者自由度更高,适合所有数据处理,后者有一定使用条件。
torchvision.datasets.ImageFolder
应用场合
主要用于分类问题,当不同类别的图片分别处于同一目录下的不同子目录时(子目录名对应类别名,如下所示),可以使用该方式。支持的数据格式有:jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif
。
data/dog/1.jpg
data/dog/23.jpg
...
data/cat/134.jpg
data/cat/243.jpg
函数原型
ImageFolder
的函数原型为:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
# root : 根路径,即前面的 data
# transform : 对 PIL image 的操作(预处理)
# target_transform : 对 label 的预处理
# loader : 指定加载图片的函数,默认读取 PIL image 对象
实例
from torchvision.datasets import ImageFolder
dataset=ImageFolder('data', transform=transform)
#对应文件夹的label
print(dataset.class_to_idx)
#所有图片的路径和对应的label
print(dataset.imgs)
#输出第0张图片的大小
print(dataset[0][0].size())
输出如下所示:
{'cat': 0, 'dog': 1}
[('data/cat/deconvolution-unroll 2.jpg', 0), ('data/cat/deconvolution-unroll 3.jpg', 0), ('data/cat/deconvolution-unroll.jpg', 0), ('data/dog/deconvolution-unroll 2.jpg', 1), ('data/dog/deconvolution-unroll 3.jpg', 1), ('data/dog/deconvolution-unroll.jpg', 1)]
torch.Size([3, 224, 224])
torch.utils.data.DataSet
创建方式
使用该方法,需要继承 Dataset 类,并实现以下接口:
1. __init__() # 进行数据及相关参数的初始化
2. __getitem__() # 用于定义单个数据样本获取的规则方式,以及数据的预处理方式,最后返回数据
3. __len__() # 用于返回数据集的长度
应用场景
可以使用该方式来建立任何数据集的 Dataset
类。
有如下几点建议:
- 高负载的操作放在 __getitem__ 中
dataset
中尽量包含只读对象,避免修改任何可变对象
实例
示例代码如下所示:
import os
import torch
import torch.utils.data as data
from PIL import Image
def default_loader(path):
return Image.open(path).convert('RGB')
class myImageFloder(data.Dataset):
def __init__(self, root, label, transform = None, target_transform=None, loader=default_loader):
fh = open(label)
c=0
imgs=[]
class_names=[]
for line in fh.readlines():
if c==0:
# 解析属性名
class_names=[n.strip() for n in line.rstrip().split(' ')]
else:
cls = line.split()
fn = cls.pop(0)
if os.path.isfile(os.path.join(root, fn)):
# 解析文件名,及对应的属性
imgs.append((fn, tuple([float(v) for v in cls])))
c=c+1
self.root = root
self.imgs = imgs
self.classes = class_names
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
# 代开文件,获取数据
img = self.loader(os.path.join(self.root, fn))
if self.transform is not None:
# 数据预处理
img = self.transform(img)
# 返回数据,必须为 Tensor 格式
return img, torch.Tensor(label)
def __len__(self):
# 返回数据集的尺寸
return len(self.imgs)
def getName(self):
# 返回 class 名
return self.classes
数据预处理 transforms
pytorch
使用 torchvision.transforms
实现数据的预处理,使用 transforms.Compose
类进行封装,打包多个 transforms
,最后传递给 Dataset
对象。
详细参考笔记《图像预处理 transforms
》。
示例代码如下所示:
from torchvision import transforms
mytransform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
])
# torch.utils.data.DataLoader
cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True, transform = mytransform )
cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)
创建数据队列 Dataloader
数据集创建完毕后,需要创建数据集的加载器。通过初始化 torch.utils.data.Dataloader
类即可创建 Dataloader
。
实际上,该 loader
是一个可迭代对象,其内部实现为一个队列,将数据集按照指定的规则进行输出。在训练过程中,通过外部循环来控制数据集重复的次数。
类原型
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
参数定义如下所示:
* dataset (Dataset)
加载数据的数据集
* batch_size (int, optional)
每批加载多少个样本
* shuffle (bool, optional)
设置为“真”时,在每个epoch对数据打乱.(默认:False)
* sampler (Sampler, optional)
定义从数据集中提取样本的策略,返回一个样本
* batch_sampler (Sampler, optional)
like sampler, 返回一批样本. 与atch_size, shuffle, sampler和 drop_last互斥.
* num_workers (int, optional)
用于加载数据的子进程数。0表示数据将在主进程中加载。(默认:0)
线程多于必要的时候,数据读取线程返回到主线程反而会因为线程间通信减慢数据。
因此大了不好小了也不好。建议把模型,loss,优化器全注释了只跑一下数据流速度,确定最优值
* collate_fn (callable, optional)
合并样本列表以形成一个 mini-batch. # callable可调用对象
* pin_memory (bool, optional)
如果为 True, 将会提前申请 CUDA 内存,张量会被复制到 CUDA 固定内存中,然后再返回它们。
* drop_last (bool, optional)
设定为 True 如果数据集大小不能被批量大小整除的时候, 将丢掉最后一个不完整的batch,(默认:False).
* timeout (numeric, optional)
如果为正值,则为从工作人员收集批次的超时值。应始终是非负的。(默认:0)
* worker_init_fn (callable, optional)
If not None, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: None)
备注
pin_memory
就是锁页内存,创建DataLoader
时,设置pin_memory=True
,则意味着生成的Tensor
数据最开始是属于内存中的锁页内存,这样将内存的Tensor
转义到GPU
的显存就会更快一些。- 主机中的内存,有两种存在方式,一是锁页,二是不锁页,锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中
- 显卡中的显存全部是锁页内存,当计算机的内存充足的时候,可以设置
pin_memory=True
。当系统卡住,或者交换内存使用过多的时候,设置pin_memory=False
。因为pin_memory
与电脑硬件性能有关,pytorch
开发者不能确保每一个炼丹玩家都有高端设备,因此pin_memory
默认为False
。 - 如果机子的内存比较大,建议开启 pin_memory=Ture,如果开启后发现有卡顿现象或者内存占用过高,此时建议关闭。
示例代码
# 创建 loader 队列,用于生成数据 batch
train_data = MyDataSet(os.path.join(root_dir, 'train.txt'), transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
sampler
说明
当 shuffle 参数为 True 时,Dataloader
会调用随机采样器 RandomSampler 进行样本采样。默认的采样器为 SequentialSampler,它会按顺序逐一采样。
此外,Pytorch
还提供另一个采样器,即 WeightedRandomSampler,它会根据样本的权重选取数据,在样本不均衡时,可以用它进行重采样。
WeightedRandomSampler
构建 WeightedRandomSampler 需要提供两个参数:每个样本的权重,一共选取的样本总数 num_samples
,以及一个可选参数 replacement
。
权重越大的样本被选中的概率越大,待选取的样本数目一般小于总样本数。replacement 用于指定是否为可放回采样,默认为 True。若为 False 时,如果要采样的样本数大于总样本数,则可能会异常。
代码如下所示:
from torch.utils.data.sampler import WeightedRandomSampler
dataset = DogCat('data', transforms=transform)
# 创建 weights
weights = [2 if label == 1 else 1 for data, label in dataset]
# 创建 sampler
sampler = WeightedRandomSampler(weights, num_samples=9, replacement=True)
# 使用
dataloader = Dataloader(dataset, batch_size=3, sampler=sampler)