一、所使用的函数介绍
1. find_classes
def find_classes(dir):
# 得到指定目录下的所有文件,并将其名字和指定目录的路径合并
# 以数组的形式存在classes中
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
# 使用sort()进行简单的排序
classes.sort()
# 将其保存的路径排序后简单地映射到 0 ~ [ len(classes)-1] 的数字上
class_to_idx = {classes[i]: i for i in range(len(classes))}
# 返回存放路径的数组和存放其映射后的序号的数组
return classes, class_to_idx
2. has_file_allowed_extension
def has_file_allowed_extension(filename, extensions):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
# 将文件的名变成小写
filename_lower = filename.lower()
# endswith() 方法用于判断字符串是否以指定后缀结尾
# 如果以指定后缀结尾返回True,否则返回False
return any(filename_lower.endswith(ext) for ext in extensions)
3. make_dataset
def make_dataset(dir, class_to_idx, extensions):
images = []
# expanduser把path中包含的"~"和"~user"转换成用户目录
# 主要还是在Linux之类的系统中使用,在不包含"~"和"~user"时
# dir不变
dir = os.path.expanduser(dir)
# 排序后按顺序通过for循环dir路径下的所有文件名
for target in sorted(os.listdir(dir)):
# 将路径拼合
d = os.path.join(dir, target)
# 如果拼接后不是文件目录,则跳出这次循环
if not os.path.isdir(d):
continue
# os.walk(d) 返回的fnames是当前d目录下所有的文件名
# 注意:第一个for其实就只循环一次,返回的fnames 是一个数组
for root, _, fnames in sorted(os.walk(d)):
# 循环每一个文件名
for fname in sorted(fnames):
# 文件的后缀名是否符合给定
if has_file_allowed_extension(fname, extensions):
# 组合路径
path = os.path.join(root, fname)
# 将组合后的路径和该文件位于哪一个序号的文件夹下的序号
# 组成元祖
item = (path, class_to_idx[target])
# 将其存入数组中
images.append(item)
return images
注意:
下面三个函数都是加载图像的函数,用于ImageFolder类中
4. pil_loader
def pil_loader(path):
# open path as file to avoid ResourceWarning
# (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
5. accimage_loader
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
6. default_loader
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
用于定义读入文件的格式
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
二、关键类
1. DatasetFolder
class DatasetFolder(data.Dataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (list[string]): A list of allowed extensions.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
"""
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
# 得到root下的文件路径数组和文件映射后的序号数组
classes, class_to_idx = find_classes(root)
# 得到所有文件的路径和其所在文件夹的序号所组成的集合的数组
samples = make_dataset(root, class_to_idx, extensions)
# 如果在指定的路径上没有得到文件那么便抛出一个异常
if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.root = root
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.transform = transform
self.target_transform = target_transform
# 如果在类中定义了__getitem__()方法,那么其实例对象(假设为P)
# 就可以这样P[key]取值。
# 当实例对象做P[key]运算时,就会调用类中的__getitem__(key)方法
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
# 得到文件路径和其所属的文件夹的序号
path, target = self.samples[index]
# 加载数据
sample = self.loader(path)
# 是否对读入的数据进行处理
# 主要包括转化张量和一些数据增强的方法
if self.transform is not None:
sample = self.transform(sample)
# 是否对所属的文件夹的序号进行处理
# 由于torchvision是按文件夹来得到数据的标签值
# 所以这里的序号其实就是分类的标签
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
# 数据数量
def __len__(self):
return len(self.samples)
# 生成报告,报告一些必要信息
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
2. ImageFolder
# 继承自DatasetFolder,只是在DatasetFolder基础上将加载的文件格式
# 数据加载函数给定义了
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
# default_loader 数据加载函数
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
# root 路径
# IMG_EXTENSIONS 定义了读取的文件类型
# transform和target_transform 数据和标签处理
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
transform=transform,
target_transform=target_transform)
self.imgs = self.samples