Keeping file names with a PyTorch DataLoader

Reposted from: https://gist.github.com/andrewjong/6b02ff237533b3b2c554701fb53d5c4d. This post is kept only as a personal study note; copyright belongs to the original author.

import torch
from torchvision import datasets

class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method that dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns 
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

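# EXAMPLE USAGE:
from torch.utils.data import DataLoader
from torchvision import transforms

# instantiate the dataset and dataloader
data_dir = "your/data_dir/here"
# ToTensor is needed so the default collate function can batch the images;
# without a transform, ImageFolder returns PIL images
dataset = ImageFolderWithPaths(data_dir, transform=transforms.ToTensor()) # our custom dataset
dataloader = DataLoader(dataset)

# iterate over data
for inputs, labels, paths in dataloader:
    # inputs: batch of image tensors, labels: tensor of class indices,
    # paths: file paths of the images in this batch
    print(inputs, labels, paths)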

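Because the point of returning paths is to know which file each result belongs to, the paths from the loop above can be paired with model outputs. The following is a minimal sketch and not part of the original gist: `collect_predictions` is a hypothetical helper, and it assumes each batch gives `paths` as a sequence of strings and the predictions as a sequence of integer class indices (for a tensor, call `.tolist()` first). The class names are assumed to come from something like `dataset.classes`.

```python
import os


def collect_predictions(paths, predicted_indices, class_names):
    """Map each image's file name to its predicted class name.

    Hypothetical helper for illustration; not part of torch or torchvision.
    paths: sequence of file path strings for one batch.
    predicted_indices: sequence of integer class indices, same length.
    class_names: list of class names indexed by class index.
    """
    if len(paths) != len(predicted_indices):
        raise ValueError("paths and predicted_indices must have equal length")
    results = {}
    for path, index in zip(paths, predicted_indices):
        # keep only the file name, dropping the directory part
        results[os.path.basename(path)] = class_names[int(index)]
    return results


# Example with made-up values standing in for one batch
batch_paths = ["data/cat/001.jpg", "data/dog/002.jpg"]
batch_predictions = [0, 1]
print(collect_predictions(batch_paths, batch_predictions, ["cat", "dog"]))
```

Keying on the bare file name is a simplification: if two class folders contain files with the same name, the later entry overwrites the earlier one, so using the full path as the key is safer in that case.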
 

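    Original author: 开飞机的小毛驴儿
    Original URL: https://blog.csdn.net/jzwong/article/details/108867297
    This article is reposted from the web solely to share knowledge; in case of infringement, please contact the blogger to have it removed.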