PyTorch collate_fn usage

By default, DataLoader uses the collate_fn function to pack a series of images and targets into tensors (the first dimension of each tensor is the batch size). The default collate_fn expects all the images in a batch to have the same size because it uses torch.stack() to pack the images. If the images provided by the Dataset have variable sizes, you have to provide your own custom collate_fn. A simple example is shown below:

```python
import torch

# a simple custom collate function, just to show the idea
# `batch` is a list of tuples where the first element is an image tensor and
# the second element is the corresponding label
def my_collate(batch):
    data = [item[0] for item in batch]  # just form a list of tensors
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]
```
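To make the return value of my_collate concrete, here is a minimal sketch that feeds it a hand-built batch. The image sizes and labels are made-up values for illustration only; the point is that the images keep their individual shapes inside a plain list while the labels become a single int64 tensor.

```python
import torch

def my_collate(batch):
    # same logic as above: images stay in a list, labels become a LongTensor
    data = [item[0] for item in batch]
    target = torch.LongTensor([item[1] for item in batch])
    return [data, target]

# hypothetical batch: three RGB images with different heights/widths, int labels
batch = [
    (torch.rand(3, 32, 48), 0),
    (torch.rand(3, 64, 40), 2),
    (torch.rand(3, 28, 28), 1),
]

data, target = my_collate(batch)
print(type(data), len(data))               # a plain Python list holding 3 tensors
print([tuple(img.shape) for img in data])  # [(3, 32, 48), (3, 64, 40), (3, 28, 28)]
print(target, target.dtype)                # tensor([0, 2, 1]) torch.int64
```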

Reference: Writing Your Own Custom Dataset for Classification in PyTorch

By default, torch stacks the input images to form a tensor of size N*C*H*W, so every image in the batch must have the same height and width. To load a batch of variable-size input images, we have to use our own collate_fn to pack the batch.
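The stacking behaviour described above can be checked directly with the default collate function, which is importable as torch.utils.data.dataloader.default_collate. This is a small sketch with arbitrary tensor sizes: same-size images collapse into one N*C*H*W tensor, while a batch with mismatched widths makes the internal torch.stack() call raise a RuntimeError.

```python
import torch
from torch.utils.data.dataloader import default_collate

# same-size images: the default collate stacks them into one N*C*H*W tensor
same = [(torch.rand(3, 224, 224), i) for i in range(4)]
imgs, labels = default_collate(same)
print(imgs.shape)    # torch.Size([4, 3, 224, 224])
print(labels)        # tensor([0, 1, 2, 3])

# variable-size images: torch.stack() inside the default collate fails
mixed = [(torch.rand(3, 224, 224), 0), (torch.rand(3, 224, 300), 1)]
try:
    default_collate(mixed)
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("default collate failed:", err)
```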

For image classification, the input to collate_fn is a list of length batch_size. Each element is a tuple whose first element is the input image (a torch.FloatTensor) and whose second element is the image label, which is simply an int. Because the samples in a batch have different sizes, we store them in a list and store the corresponding labels in a torch.LongTensor. Then we put the image list and the label tensor into a list and return the result.
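Since the images come back as a list rather than one batched tensor, a model cannot consume them in a single forward call. One possible way to use such a batch, sketched here as an assumption rather than something from the original article, is to run the model on each image as a mini-batch of size 1 and concatenate the outputs. The tiny classifier below is hypothetical; nn.AdaptiveAvgPool2d makes its output size independent of the input height and width. It also shows why storing labels in a LongTensor is convenient: nn.CrossEntropyLoss expects int64 class indices as targets.

```python
import torch
import torch.nn as nn

# hypothetical tiny classifier; AdaptiveAvgPool2d removes the dependence on H and W
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),   # -> (N, 8, 1, 1) regardless of H, W
    nn.Flatten(),              # -> (N, 8)
    nn.Linear(8, 5),           # 5 hypothetical classes
)
criterion = nn.CrossEntropyLoss()

# shape of what the custom collate_fn returns: [list of C*H*W tensors, LongTensor]
data = [torch.rand(3, 32, 48), torch.rand(3, 64, 40), torch.rand(3, 28, 28)]
target = torch.LongTensor([0, 4, 2])

# run each image as its own batch of size 1, then concatenate the logits
logits = torch.cat([model(img.unsqueeze(0)) for img in data], dim=0)  # (3, 5)
loss = criterion(logits, target)   # class-index targets must be int64
loss.backward()
print(logits.shape, loss.item())
```

Running images one at a time is simple but gives up batched throughput; it is only one of several ways to handle variable-size inputs.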

Here is a very simple snippet that demonstrates how to write a custom collate_fn:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# a simple custom collate function, just to show the idea
def my_collate(batch):
    data = [item[0] for item in batch]    # keep variable-size images in a plain list
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)     # labels are ints, so they can be packed into one tensor
    return [data, target]


def show_image_batch(img_list, title=None):
    num = len(img_list)
    fig = plt.figure()
    for i in range(num):
        ax = fig.add_subplot(1, num, i + 1)
        # C*H*W tensor -> H*W*C array for matplotlib
        ax.imshow(img_list[i].numpy().transpose([1, 2, 0]))
        if title is not None:
            ax.set_title(title[i])
    plt.show()

# do not use RandomCrop, to show that the custom collate_fn can handle images of different sizes;
# Resize with a single int scales the shorter side to 224 and keeps the aspect ratio
# (transforms.Scale, used in older code, was deprecated in favor of Resize)
train_transforms = transforms.Compose([transforms.Resize(size=224),
                                       transforms.ToTensor(),
                                       ])

# change root to a valid dir on your system, see the ImageFolder documentation for more info
train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                     transform=train_transforms)

trainset = DataLoader(dataset=train_dataset,
                      batch_size=4,
                      shuffle=True,
                      collate_fn=my_collate,  # use custom collate function here
                      pin_memory=True)

trainiter = iter(trainset)
imgs, labels = next(trainiter)  # newer PyTorch iterators have no .next() method

# print(type(imgs), type(labels))
show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels.tolist()])
```
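The ImageFolder example above needs an image directory on disk. As a self-contained alternative for trying out my_collate, here is a sketch with a hypothetical RandomSizeImages dataset that generates random C*H*W tensors of varying height and width in memory; the class, its size range and its labels are all made up for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader

def my_collate(batch):
    data = [item[0] for item in batch]
    target = torch.LongTensor([item[1] for item in batch])
    return [data, target]

class RandomSizeImages(Dataset):
    """Hypothetical stand-in for ImageFolder: float C*H*W images of varying size."""
    def __init__(self, n=10, num_classes=3):
        g = torch.Generator().manual_seed(0)
        # pick a random (H, W) per sample, each between 20 and 59 pixels
        self.sizes = torch.randint(20, 60, (n, 2), generator=g).tolist()
        self.labels = [i % num_classes for i in range(n)]

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        return torch.rand(3, h, w), self.labels[idx]

loader = DataLoader(RandomSizeImages(), batch_size=4, shuffle=True,
                    collate_fn=my_collate)

for imgs, labels in loader:
    print([tuple(img.shape) for img in imgs], labels.tolist())
```

With 10 samples and batch_size=4 the loader yields batches of 4, 4 and 2 samples, each as a list of differently shaped tensors plus a LongTensor of labels.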

Reference: How to create a dataloader with variable-size input

A test case for DataLoader:

```python
import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing, target)
batch = 3

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,  # batch size
    # if the number of samples in the dataset is not divisible by batch_size,
    # the last batch simply uses whatever samples remain
    # for each field j (input, target): concatenate the samples along dim 0,
    # then add a leading dimension of size 1
    collate_fn=lambda x: (
        torch.cat(
            [x[i][j].unsqueeze(0) for i in range(len(x))], 0
        ).unsqueeze(0) for j in range(len(x[0]))
    )
)

for (i, j) in loader:
    print(i)
    print(j)
```
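To see what the lambda in the test case actually changes, the sketch below rewrites it as a named function (stack_with_leading_dim is a hypothetical name; it returns the same tensors, just as a tuple instead of a generator) and compares its batches with those of a DataLoader using the default collate_fn. The only difference is the extra leading dimension of size 1 added by the trailing .unsqueeze(0).

```python
import numpy as np
import torch
import torch.utils.data as Data

test = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))
torch_dataset = Data.TensorDataset(inputing, target)

def stack_with_leading_dim(batch):
    # same tensors as the lambda: for each field j (input, target), concatenate
    # the samples along dim 0, then add a leading dimension of size 1
    return tuple(
        torch.cat([sample[j].unsqueeze(0) for sample in batch], 0).unsqueeze(0)
        for j in range(len(batch[0]))
    )

custom = Data.DataLoader(torch_dataset, batch_size=3, collate_fn=stack_with_leading_dim)
default = Data.DataLoader(torch_dataset, batch_size=3)  # default collate_fn

for (x_c, y_c), (x_d, y_d) in zip(custom, default):
    # e.g. (1, 3, 3) vs (3, 3) for a full batch of inputs
    print(tuple(x_c.shape), tuple(x_d.shape))
```

Because 10 samples are split into batches of 3, the last batch holds a single sample; passing drop_last=True to DataLoader would discard it instead.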

Reference: The collate_fn parameter of DataLoader

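    Original author: pytorch
    Original URL: https://www.cnblogs.com/king-lps/p/10990304.html
    This article was reposted from the web solely to share knowledge; in case of infringement, please contact the blogger to have it removed.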