yolo系列是目标识别的重头戏,为了更好的理解掌握它,我们必须从源码出发深刻理解代码。下面我们来讲解pytorch实现的yolov3源码。在讲解之前,大家应该具备相应的原理知识yolov1,yolov2,yolov3。
大部分同学在看论文时并不能把所有的知识全部掌握。我们必须结合代码(代码将理论变成实践),它是百分百还原理论的,也只有在掌握代码以及理论后,我们才能推陈出新有所收获,所以大家平时一定多接触代码,这里我们会结合yolov3的理论知识让大家真正在代码中理解思想。
下面我就train过程的代码进行讲解。在理解train过程之前,建议大家先了解inference的代码讲解。
PyTorch实现yolov3代码详细解密
数据读取:
Pytorch读取图片,主要通过Dataset类,Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它。
class Dataset(object):
"""An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
这里重点看getitem函数,getitem接收一个index,返回图片数据和labels。我们看yolov3的dataset。
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, augment=False):
with open(path, 'r') as file:
img_files = file.read().splitlines()
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
n = len(self.img_files)
assert n > 0, 'No images found in %s' % path
self.img_size = img_size
self.augment = augment
self.label_files = [
x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt')
for x in self.img_files]
# if n < 200: # preload all images into memory if possible
# self.imgs = [cv2.imread(img_files[i]) for i in range(n)]
def __len__(self):
return len(self.img_files)
def __getitem__(self, index):
img_path = self.img_files[index]
label_path = self.label_files[index]
# if hasattr(self, 'imgs'):
# img = self.imgs[index] # BGR
img = cv2.imread(img_path) # BGR
assert img is not None, 'File Not Found ' + img_path
h, w, _ = img.shape
img, ratio, padw, padh = letterbox(img, height=self.img_size)
#将每幅图resize成418*418
# Load labels
labels = []
if os.path.isfile(label_path):
with open(label_path, 'r') as file:
lines = file.read().splitlines()
x = np.array([x.split() for x in lines], dtype=np.float32)
if x.size > 0:
# Normalized xywh to pixel xyxy format
labels = x.copy()
labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw
labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh
labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw
labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh
print(labels)
# Augment image and labels
if self.augment:
img, labels = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))
nL = len(labels) # number of labels
if nL:
# convert xyxy to xywh
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) / self.img_size
if self.augment:
# random left-right flip
lr_flip = True
if lr_flip and random.random() > 0.5:
img = np.fliplr(img)
if nL:
labels[:, 1] = 1 - labels[:, 1]
# random up-down flip
ud_flip = False
if ud_flip and random.random() > 0.5:
img = np.flipud(img)
if nL:
labels[:, 2] = 1 - labels[:, 2]
labels_out = torch.zeros((nL, 6))
if nL:
labels_out[:, 1:] = torch.from_numpy(labels)
# Normalize
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img, dtype=np.float32) # uint8 to float32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
return torch.from_numpy(img), labels_out, img_path, (h, w)
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True),可以看到其中LoadImagesAndLabels类是Dataset的子类,init函数是正常的读取数据,我们主要看getitem,getitem接收一个index,就是img_files的索引,通过letterbox函数进行数据预处理将每幅图resize成418*418,labels里面存放的是ground truth的类别和坐标信息,因为图像resize了,所以labels中的坐标信息也要相对变化。最后返回处理后的img,labels,地址和宽高。
那么读取自己数据的基本流程就是:
1:制作存储了图像的路径和标签信息的txt
2:将这些信息转化为list,该list每一个元素对应一个样本
3:通过getitem函数,读取数据标签,并返回。
在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的__iter__(self),流程详细描述如下:
1.从dataset类中初始化txt,txt中有图片路径和标签
2.初始化DataLoder时,将dataset传入,从而使DataLoader拥有图片路径
3.在for i, (imgs, targets, _, _) in enumerate(dataloader):中,一个iteration进行时,读取一个batch的数据,enumerate将数据返回到imgs,targets中,imgs就是数据增强后的图像,labels就是处理后的标签。
4.读取过程中需要在class DataLoader()类中调用_DataLoderIter()
5.在 _DataLoderiter()类中跳到 next(self)函数,在该函数中通过indices = next(self.sample_iter)获取一个batch的indices,再通过batch=self.collate_fn()获取一个batch数据。
6.self.collate_fn中调用LoadImagesAndLabels类中的 getitem()函数,再函数中获取图片。
如此,我们第一步数据预处理就完成了,后面我们就可以把数据imgs放到模型里跑了。大家不要忽视这些代码,想真正弄懂,我们就要一步一步刨根问底。
下面一章,我们会根据程序复现训练过程的算法原理,讲解yolov3的loss是如何计算的。