pytorch 读数据接口 制作数据集 data.dataset

【吐槽】

啊,代码,你这个大猪蹄子

自己写了cifar10的数据接口,跟官方接口load的数据一样,

沾沾自喜,以为自己会写数据接口了

几天之后,突然想,自己的代码为啥有点慢呢,这数据集不大啊

用了官方接口,真快啊。。。

啊啊啊啊啊啊啊啊

 

但这是好事,至少我明白了一点知识对吧

 

【lesson】

看了cifar10的接口,发现自己在数据集初始化的地方写的太少了,应该在初始化的时候就把所有数据读进来,这样的话在__getitem__的时候才能快。

 

 人家的初始化:

 if self.train:
            self.train_data = []
            self.train_labels = []
            for fentry in self.train_list:
                f = fentry[0]
                file = os.path.join(self.root, self.base_folder, f)
                fo = open(file, 'rb')
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
                self.train_data.append(entry['data'])
                if 'labels' in entry:
                    self.train_labels += entry['labels']
                else:
                    self.train_labels += entry['fine_labels']
                fo.close()

            self.train_data = np.concatenate(self.train_data)
            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

 

人家的getitem

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

 

 

自己:(都写到getitem里面了)

 def __init__(self, root, transforms=transform(), train=True, test=False): self.root = root self.transform = transforms self.train = train self.test = test if self.test: self.train = False def __getitem__(self, item): x = math.floor(item / 10000) + 1 y = item % 10000 if not self.train and not self.test: x = 5 y = 5000+item imgpath = os.path.join(self.root, "data_batch_"+str(x)) with open(imgpath, 'rb') as fo: dict = pickle.load(fo, encoding='bytes') d_decode = {} for k,v in dict.items(): d_decode[k.decode('utf8')] = v dict = d_decode data = dict['data'][y] # 3*32*32==3072 data = np.reshape(data,(3,32,32)) data = data.transpose(1,2,0) data = self.transform(data) label = dict['labels'][y] # label = torch.from_numpy(label) return data, label

 

 

 

附自己的代码和人家的代码全部

人家:

 1 base_folder = 'cifar-10-batches-py'
 2     url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
 3     filename = "cifar-10-python.tar.gz"
 4     tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
 5     train_list = [
 6         ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
 7         ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
 8         ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
 9         ['data_batch_4', '634d18415352ddfa80567beed471001a'],
10         ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
11     ]
12 
13     test_list = [
14         ['test_batch', '40351d587109b95175f43aff81a1287e'],
15     ]
16 
17     def __init__(self, root, train=True,
18                  transform=None, target_transform=None,
19                  download=False):
20         self.root = os.path.expanduser(root)
21         self.transform = transform
22         self.target_transform = target_transform
23         self.train = train  # training set or test set
24 
25         if download:
26             self.download()
27 
28         if not self._check_integrity():
29             raise RuntimeError('Dataset not found or corrupted.' +
30                                ' You can use download=True to download it')
31 
32         # now load the picked numpy arrays
33         if self.train:
34             self.train_data = []
35             self.train_labels = []
36             for fentry in self.train_list:
37                 f = fentry[0]
38                 file = os.path.join(self.root, self.base_folder, f)
39                 fo = open(file, 'rb')
40                 if sys.version_info[0] == 2:
41                     entry = pickle.load(fo)
42                 else:
43                     entry = pickle.load(fo, encoding='latin1')
44                 self.train_data.append(entry['data'])
45                 if 'labels' in entry:
46                     self.train_labels += entry['labels']
47                 else:
48                     self.train_labels += entry['fine_labels']
49                 fo.close()
50 
51             self.train_data = np.concatenate(self.train_data)
52             self.train_data = self.train_data.reshape((50000, 3, 32, 32))
53             self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
54         else:
55             f = self.test_list[0][0]
56             file = os.path.join(self.root, self.base_folder, f)
57             fo = open(file, 'rb')
58             if sys.version_info[0] == 2:
59                 entry = pickle.load(fo)
60             else:
61                 entry = pickle.load(fo, encoding='latin1')
62             self.test_data = entry['data']
63             if 'labels' in entry:
64                 self.test_labels = entry['labels']
65             else:
66                 self.test_labels = entry['fine_labels']
67             fo.close()
68             self.test_data = self.test_data.reshape((10000, 3, 32, 32))
69             self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
70 
71     def __getitem__(self, index):
72         """
73         Args:
74             index (int): Index
75 
76         Returns:
77             tuple: (image, target) where target is index of the target class.
78         """
79         if self.train:
80             img, target = self.train_data[index], self.train_labels[index]
81         else:
82             img, target = self.test_data[index], self.test_labels[index]
83 
84         # doing this so that it is consistent with all other datasets
85         # to return a PIL Image
86         img = Image.fromarray(img)
87 
88         if self.transform is not None:
89             img = self.transform(img)
90 
91         if self.target_transform is not None:
92             target = self.target_transform(target)
93 
94         return img, target

 

    原文作者:pytorch
    原文地址: https://www.cnblogs.com/yexiaoqi/p/10510960.html
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞