官方文档
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
Example:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
torch.gather 函数用于从参数 t 选择性输出特定 index 的矩阵,输出矩阵的大小跟 index 的大小是一样的,torch.gather 的 dim 参数用来选择 index 作用的 axis。
构建 2×18×2×2 的矩阵 a,
# a[:,i,:,:] = i
a = torch.arange(18).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(2,18,2,2)
a[:,2,:,:]
Out[24]:
tensor([[[ 2., 2.],
[ 2., 2.]],
[[ 2., 2.],
[ 2., 2.]]])
a[:,16,:,:]
Out[25]:
tensor([[[ 16., 16.],
[ 16., 16.]],
[[ 16., 16.],
[ 16., 16.]]])
现在要通过 torch.gather 函数把 a 变成 offset[:,:9,:,:] = [0,2,…16],offset[:,9:,:,:] = [1,3,..,17]
N = 9
offsets_index = Variable(torch.cat([torch.arange(0, 2*N, 2), torch.arange(1, 2*N+1, 2)]), requires_grad=False).long()
offsets_index
Out[29]:
tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5,
7, 9, 11, 13, 15, 17])
offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*a.size())
offset = torch.gather(a, dim=1, index=offsets_index)
offset[:,0,:,:]
Out[34]:
tensor([[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]])
offset[:,1,:,:]
Out[35]:
tensor([[[ 2., 2.],
[ 2., 2.]],
[[ 2., 2.],
[ 2., 2.]]])
offset[:,8,:,:]
Out[39]:
tensor([[[ 16., 16.],
[ 16., 16.]],
[[ 16., 16.],
[ 16., 16.]]])
offset[:,9,:,:]
Out[40]:
tensor([[[ 1., 1.],
[ 1., 1.]],
[[ 1., 1.],
[ 1., 1.]]])
offset[:,17,:,:]
Out[41]:
tensor([[[ 17., 17.],
[ 17., 17.]],
[[ 17., 17.],
[ 17., 17.]]])
代码帖的有点多,主要是为了验证效果。
offset 的输出规则如下:
offset[i][j][k][s] = input[i][offsets_index[i][j][k][s]][k][s] # dim=1
因为 dim = 1,offsets_index 影响 axis = 1 的维度,offset[i][j][k][s] 由 input 根据 offsets_index 在 axis=1 维度用 offsets_index[i][j][k][s] 作为索引,其他的位置不变,同理其他维度改变就用 index[i][j][k][s] 作为对应 axis 的索引 。
最终输出 offset 的时候,offset[:][1][:][:] 的数据只是选择了 input 在 axis=1 上 input[:][2][:][:] 的所有数据,在第 axis=0,2,3 维度 input 的索引和 offset 是对应的,所以 offset 在相应位置上的数据和 input 一样。