In general, the default collate_fn requires all images in a batch to have the same size, because it stacks them into a single tensor. When the images in a batch have different sizes, you can use a custom collate_fn: the images are no longer stacked, but are instead stored in a list, together with their corresponding labels, as in the following example:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# a simple custom collate function, just to show the idea:
# images are kept in a list instead of being stacked into a tensor
def my_collate(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]

def show_image_batch(img_list, title=None):
    num = len(img_list)
    fig = plt.figure()
    for i in range(num):
        ax = fig.add_subplot(1, num, i+1)
        ax.imshow(img_list[i].numpy().transpose([1, 2, 0]))
        ax.set_title(title[i])
    plt.show()

# do not use RandomCrop, so that images in a batch have different sizes
# and the custom collate_fn is actually exercised
train_transforms = transforms.Compose([
    transforms.Resize(size=224),  # transforms.Scale is deprecated; Resize keeps aspect ratio
    transforms.ToTensor(),
])

# change root to a valid dir in your system, see ImageFolder documentation for more info
train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                     transform=train_transforms)

trainset = DataLoader(dataset=train_dataset,
                      batch_size=4,
                      shuffle=True,
                      collate_fn=my_collate,  # use custom collate function here
                      pin_memory=True)

trainiter = iter(trainset)
imgs, labels = next(trainiter)  # trainiter.next() only works in Python 2
# print(type(imgs), type(labels))
show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels])
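If you would still like the batch to be a single stacked tensor rather than a list, a common alternative is to pad every image in the batch to the largest height and width before stacking. The following is a minimal sketch of that idea; the function name pad_collate and the zero-padding strategy are my own choices for illustration, not part of the example above:

import torch
import torch.nn.functional as F

def pad_collate(batch):
    # assumes each item is (img, label) with img of shape (C, H, W)
    max_h = max(item[0].shape[1] for item in batch)
    max_w = max(item[0].shape[2] for item in batch)
    padded = []
    for img, _ in batch:
        pad_h = max_h - img.shape[1]
        pad_w = max_w - img.shape[2]
        # F.pad's tuple is (left, right, top, bottom) for the last two dims
        padded.append(F.pad(img, (0, pad_w, 0, pad_h)))
    data = torch.stack(padded)
    target = torch.LongTensor([item[1] for item in batch])
    return [data, target]

To use it, simply pass collate_fn=pad_collate to the DataLoader instead of my_collate. Note that zero-padding changes the image content, so whether this is appropriate depends on your model; for visualization purposes the list-based my_collate above is usually enough.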