深度学习发展到现在,各路大神都发展出了各种模型。在深度学习实现过程中最重要的最花时间的应该是数据预处理与后处理,会极大影响最后效果,至于模型,感觉像是拼乐高积木,一个模块一个模块地叠加,拼成最适合自己的模型。
1 数据预处理
1.1 图像切割
一般而言,训练集会是一整张大图,所以需要自己切割成小图训练,可以做切割,也可以在训练时划窗读取,最好先做切割,可以检查数据。切割的图片大小根据服务器性能来看,12G的GPU切为256或512的比较合适一些。
切割的时候最好有重叠的切割,至于重叠率可以根据实际情况自己做一些尝试,这样可以尽量避免将要识别的物体切割,导致模型训练时不能很好地识别该类物体。同理,在模型预测时,在边缘部分的预测结果也不准确,需要重叠切割,并且取中间部分的结果,舍弃边缘部分。
1.2 数据平衡
待识别的物体如果占比不平衡,则会很影响模型,比如草地占比99%,喷泉占比1%的数据,只要模型将所有物体都分类为草地,那么模型的准确率也会达到99。
所以我们希望模型的数据占比尽量达到均衡。为此有很多策略可以使用,先进行数据统计,数据分析后再制定策略。
1.2.1 数据统计
拿到数据后一般需要分析标签中各个类别的占比。
import gdal
import numpy as np
data_path = r'/home/fsl/image_2_label.png'
src = gdal.Open(data_path).ReadAsArray()
n_class0 = np.sum(np.where(src==0))
n_class1 = np.sum(np.where(src==1))
n_class2 = np.sum(np.where(src==2))
n_class3 = np.sum(np.where(src==3))
n_class4 = np.sum(np.where(src==4))
sum = src.shape[0]*src.shape[1]
print("背景:{},第一类:{},第二类:{},第三类:{},第四类:{}".format(n_class0/sum ,n_class1/sum ,n_class2/sum ,n_class3/sum ,n_class4/sum ))
1.2.2 策略制定
最简单的处理方法就是在数据切图的时候处理或在计算loss时处理。
比如背景占比0.7,玉米占比0.02,草地占比0.2,薏仁米占比0.08,这种情况下背景占比过高,在切图时判断这张小图背景占比是否高于7/8,若高于这个阈值,则丢掉这张图片,若背景占比低,则不作处理或增加这张图与上一张图的采样重叠率,这样可以增大非背景的像素数量。
其它三类占比大致差不多,但是玉米与薏仁米相对草地来说少了一个量级,所以可以对玉米与薏仁米占比大于7/8的小图做图像增强(反转,旋转等)。图像增强Pytorch与Tensorflow都有提供相应的库,可以直接调用。
其次,在计算loss时,可以增加小类别的权重,比如玉米与薏仁米的权重应该要比草地与背景的权重大。这篇博客可以参考。
1.3 图像增强(数据扩充)
一般什么时候会用到数据增强呢?当数据集较少,以及数据没有实际场景那么丰富的时候,比如实际场景中图片色彩可能偏红可能偏蓝,但是拿到的训练数据都是偏红的,那就需要对图片做图像增强,将色彩调整为偏蓝加入训练集。
一般而言,训练模型为了增加模型的适应性,都需要做图像增强,扩充图像的多样性。
pytorch做图像增强,tensorflow做图像增强。
2 后处理
2.1 模型预测
2.1.1 膨胀预测
由于在图像边界部分模型预测效果不好,所以直接将影像切成512*512的图来预测再拼接会导致每个512*512影像之间交接的部分存在明显接缝,所以需要有重叠的切图,然后预测结果只取每个512*512影像的中间部分的结果。
2.1.2 多模型预测
在一个训练集上训练多个模型,用多个模型的输出取平均可以很好的提升模型效果。
同样的,预测时,对图像做几个数据增强,分别输入模型进行训练,将输出取平均也会有一定效果。
2.2.2 图像增强
除了在训练时需要图像增强以外,在测试时也需要做图像增强,通过测试结果来取平均,可以一定程度上避免训练集缺乏多样性的问题,但是也会成倍增加测试时间。
def pred_aug(img, model):
img90 = torch.rot90(img, 1, dims=(2,3))
img_hori = torch.flip(img, [2])
img_vert = torch.flip(img, [3])
# 预测结果
pred = model(img)
pred90 = model(img90)
pred_hori = model(img_hori)
pred_vert = model(img_vert)
# 将影像预测结果逆操作回原图
pred90 = torch.rot90(pred90, 3, dims=(2,3))
pred_hori = torch.flip(pred_hori, [2])
pred_vert = torch.flip(pred_vert, [3])
# 做softmax
pred = torch.nn.functional.softmax(pred, dim=1)
pred90 = torch.nn.functional.softmax(pred90, dim=1)
pred_hori = torch.nn.functional.softmax(pred_hori, dim=1)
pred_vert = torch.nn.functional.softmax(pred_vert, dim=1)
pred = pred + pred90 + pred_hori + pred_vert
pred = torch.argmax(pred, dim=1)
pred.squeeze_()
return pred
2.2 模型结果赋予颜色
模型的预测结果一般是每个类别的概率值,需要先用argmax转换为类别值0,1,2… 然后再将这个类别值转换为rgb三个通道的值。将类别值转为rgb的代码如下所示:
def write_img(pred_images, filename, ori_image):
pred = pred_images[0] # pred的shape为(高,宽,类别数量)
COLORMAP = [[255,255,255],[0, 255, 0], [0, 0, 255], [0, 0, 255]] # 分别为0-3类对应的颜色
cm = np.array(COLORMAP).astype(np.uint8)
pred = np.argmax(np.array(pred), axis=2) # 此时pred的shape为(高,宽)
pred_val = cm[pred] # pred_val的shape为(高,宽,3)
# 将模型预测结果叠加在原图上
overlap = cv2.addWeighted(ori_image, 0.7, pred_val, 0.3, 0)
cv2.imwrite(os.path.join("gdrive", "My Drive", "data", "deeplab", filename.split("/")[-1]), overlap)
2.3 模型压缩
2.3.1 知识蒸馏
2.4 优化结果
2.4.1 区域连通,去除噪点
def area_connection(result, n_class,area_threshold,)
"""
result:预测影像
area_threshold:最小连通尺寸,小于该尺寸的都删掉
"""
result = to_categorical(result, num_classes=n_class, dtype='uint8') # 转为one-hot
for i in tqdm(range(n_class)):
# 去除小物体
result[:, :, i] = skimage.morphology.remove_small_objects(result[:, :, i] == 1, min_size=area_threshold, connectivity=1, in_place=True)
# 去除孔洞
result[:, :, i] = skimage.morphology.remove_small_holes(result[:, :, i] == 1, area_threshold=area_threshold, connectivity=1, in_place=True)
# 获取最终label
result = np.argmax(result, axis=2).astype(np.uint8)
return result
2.4.2 CRF条件随机场
3 其它
3.1 计算IOU
def cal_iou(target, pred,n_class=25):
"""
target是真实标签,shape为(h,w),像素值为0,1,2...
pred是预测结果,shape为(h,w),像素值为0,1,2...
n_class:为预测类别数量
"""
h,w = target.shape
# 转为one-hot,shape变为(h,w,n_class)
target_one_hot = np.eye(n_class)[target]
pred_one_hot = np.eye(n_class)[pred]
target_one_hot[target_one_hot!=0]=1
pred_one_hot[pred_one_hot!=0] = 1
join_result = target_one_hot*pred_one_hot
join_sum = np.sum(np.where(join_result==1)) # 计算相交的像素数量
pred_sum =np.sum(np.where(pred_one_hot==1)) # 计算预测结果非0得像素数
target_sum = np.sum(np.where(target_one_hot==1)) # 计算真实标签的非0得像素数
iou = join_sum/(pred_sum + target_sum - join_sum + 1e-6)
return iou