【Rant】
Ah, code, you big disappointment.
I wrote my own CIFAR-10 data interface, and it loaded exactly the same data as the official one.
I was pretty smug about it, thinking I'd finally learned how to write a dataset class.
A few days later it hit me: why is my code kind of slow? This dataset isn't even big.
Switched to the official interface, and wow, it's fast...
Aaaaargh.
Still, it's a good thing; at least I learned something from it, right?
【Lesson】
Reading the official CIFAR-10 interface, I realized I had done far too little in the dataset's initialization: all the data should be read into memory in __init__, and only then can __getitem__ be fast.
Their __init__:
if self.train:
    self.train_data = []
    self.train_labels = []
    for fentry in self.train_list:
        f = fentry[0]
        file = os.path.join(self.root, self.base_folder, f)
        fo = open(file, 'rb')
        if sys.version_info[0] == 2:
            entry = pickle.load(fo)
        else:
            entry = pickle.load(fo, encoding='latin1')
        self.train_data.append(entry['data'])
        if 'labels' in entry:
            self.train_labels += entry['labels']
        else:
            self.train_labels += entry['fine_labels']
        fo.close()

    self.train_data = np.concatenate(self.train_data)
    self.train_data = self.train_data.reshape((50000, 3, 32, 32))
    self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
Their __getitem__:
def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    if self.train:
        img, target = self.train_data[index], self.train_labels[index]
    else:
        img, target = self.test_data[index], self.test_labels[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
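The reason this split matters: the DataLoader calls __getitem__ once per sample, every epoch, so whatever happens in there is repeated tens of thousands of times, while __init__ runs only once. A minimal usage sketch of the official class (the batch_size and num_workers values here are just placeholders I picked, not anything from the original post):

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# all five batch files are read, unpickled and concatenated here, inside __init__
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True,
                                          transform=transforms.ToTensor())

# each iteration only indexes the in-memory array and applies the transform
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)

for imgs, labels in train_loader:  # ~50000/64 batches per epoch
    pass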
Mine (everything crammed into __getitem__):
def __init__(self, root, transforms=transform(), train=True, test=False):
    self.root = root
    self.transform = transforms
    self.train = train
    self.test = test
    if self.test:
        self.train = False

def __getitem__(self, item):
    x = math.floor(item / 10000) + 1
    y = item % 10000
    if not self.train and not self.test:
        x = 5
        y = 5000 + item
    imgpath = os.path.join(self.root, "data_batch_" + str(x))
    # a whole batch file is opened and unpickled on every single call
    with open(imgpath, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    d_decode = {}
    for k, v in dict.items():
        d_decode[k.decode('utf8')] = v
    dict = d_decode
    data = dict['data'][y]  # 3*32*32 == 3072
    data = np.reshape(data, (3, 32, 32))
    data = data.transpose(1, 2, 0)
    data = self.transform(data)
    label = dict['labels'][y]
    # label = torch.from_numpy(label)
    return data, label
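For contrast, here is a rough sketch of how the same loader could be restructured along the lines of the official one: do all the file reading once in __init__ and leave only indexing plus the transform in __getitem__. This is a hedged reconstruction, not code from the original post; the class name and variable names are made up for illustration.

import os
import pickle
import numpy as np
from torch.utils.data import Dataset

class Cifar10Preloaded(Dataset):
    """Hypothetical rewrite: read every batch file once, up front."""

    def __init__(self, root, transform=None, train=True):
        self.transform = transform
        files = (["data_batch_%d" % i for i in range(1, 6)]
                 if train else ["test_batch"])
        data, labels = [], []
        for name in files:
            with open(os.path.join(root, name), 'rb') as fo:
                entry = pickle.load(fo, encoding='latin1')
            data.append(entry['data'])
            labels += entry['labels']
        # (N, 3072) -> (N, 32, 32, 3), kept in memory for the whole run
        self.data = np.concatenate(data).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        self.labels = labels

    def __getitem__(self, index):
        # no file I/O here: just index the preloaded array
        img, label = self.data[index], self.labels[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data)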
Appendix: my code and the official code in full.
Theirs:
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
    ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
    ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
    ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
    ['data_batch_4', '634d18415352ddfa80567beed471001a'],
    ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]

test_list = [
    ['test_batch', '40351d587109b95175f43aff81a1287e'],
]

def __init__(self, root, train=True,
             transform=None, target_transform=None,
             download=False):
    self.root = os.path.expanduser(root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set

    if download:
        self.download()

    if not self._check_integrity():
        raise RuntimeError('Dataset not found or corrupted.' +
                           ' You can use download=True to download it')

    # now load the picked numpy arrays
    if self.train:
        self.train_data = []
        self.train_labels = []
        for fentry in self.train_list:
            f = fentry[0]
            file = os.path.join(self.root, self.base_folder, f)
            fo = open(file, 'rb')
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                entry = pickle.load(fo, encoding='latin1')
            self.train_data.append(entry['data'])
            if 'labels' in entry:
                self.train_labels += entry['labels']
            else:
                self.train_labels += entry['fine_labels']
            fo.close()

        self.train_data = np.concatenate(self.train_data)
        self.train_data = self.train_data.reshape((50000, 3, 32, 32))
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
    else:
        f = self.test_list[0][0]
        file = os.path.join(self.root, self.base_folder, f)
        fo = open(file, 'rb')
        if sys.version_info[0] == 2:
            entry = pickle.load(fo)
        else:
            entry = pickle.load(fo, encoding='latin1')
        self.test_data = entry['data']
        if 'labels' in entry:
            self.test_labels = entry['labels']
        else:
            self.test_labels = entry['fine_labels']
        fo.close()
        self.test_data = self.test_data.reshape((10000, 3, 32, 32))
        self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    if self.train:
        img, target = self.train_data[index], self.train_labels[index]
    else:
        img, target = self.test_data[index], self.test_labels[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
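If you want to see the speed difference for yourself, one quick way is to time a single pass over each dataset. A rough sketch (official_set / my_set are placeholder names for an instance of the official class and of my class above; the actual numbers will depend on the machine and disk):

import time
from torch.utils.data import DataLoader

def time_one_epoch(dataset, batch_size=64):
    # iterate the whole dataset once and return the wall-clock time
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    start = time.time()
    for _imgs, _labels in loader:
        pass
    return time.time() - start

# e.g. print(time_one_epoch(official_set), time_one_epoch(my_set))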