我正在从深度学习书(第7章,CNN)中读到
implementation of im2col,其目的是将4维数组转换为2维.我不知道为什么在实现中有一个6维数组.我对作者使用的算法背后的想法非常感兴趣.
我试图搜索很多关于im2col实现的论文,但是没有一个像这样使用高维数组.我发现目前用于可视化im2col过程的材料是this paper – HAL Id: inria-00112631的图片
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
"""
Parameters
----------
input_data : (batch size, channel, height, width), or (N,C,H,W) at below
filter_h : kernel height
filter_w : kernel width
stride : size of stride
pad : size of padding
Returns
-------
col : two dimensional array
"""
N, C, H, W = input_data.shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col
最佳答案 让我们试着看看im2col的作用.它的输入是一堆彩色图像,堆栈具有尺寸图像id,颜色通道,垂直位置,水平位置.让我们简单地假设我们只有一个图像:
它首先做的是填充:
接下来,它将其切割成窗口.窗口的大小由filter_h / w控制,由strides重叠.
这是六个维度的来源:图像ID(示例中缺少,因为我们只有一个图像),网格高度/宽度,颜色通道.窗户高度/宽度.
现在的算法有点笨拙,它以错误的维度顺序组装输出,然后必须使用转置来纠正它.
最好先把它弄好:
def im2col_better(input_data, filter_h, filter_w, stride=1, pad=0):
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
N, C, H, W = img.shape
out_h = (H - filter_h)//stride + 1
out_w = (W - filter_w)//stride + 1
col = np.zeros((N, out_h, out_w, C, filter_h, filter_w))
for y in range(out_h):
for x in range(out_w):
col[:, y, x] = img[
..., y*stride:y*stride+filter_h, x*stride:x*stride+filter_w]
return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))
正如旁注:我们可以使用stride_tricks做得更好,并避免嵌套for循环:
def im2col_best(input_data, filter_h, filter_w, stride=1, pad=0):
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
N, C, H, W = img.shape
NN, CC, HH, WW = img.strides
out_h = (H - filter_h)//stride + 1
out_w = (W - filter_w)//stride + 1
col = np.lib.stride_tricks.as_strided(img, (N, out_h, out_w, C, filter_h, filter_w), (NN, stride * HH, stride * WW, CC, HH, WW)).astype(float)
return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))
算法做的最后一件事是重新整形,合并前三个维度(在我们的例子中只有两个,因为只有一个图像).红色箭头显示各个窗口如何排列成第一个新维度:
最后三个维度颜色通道,窗口中的y坐标,窗口中的x坐标合并到第二个输出维度.单个像素排列如黄色箭头所示: