Pytorch 中的 torch.gather 函数

官方文档

torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by: 
    out[i][j][k] = input[index[i][j][k]][j][k] # dim=0 
    out[i][j][k] = input[i][index[i][j][k]][k] # dim=1 
    out[i][j][k] = input[i][j][index[i][j][k]] # dim=2 
Parameters: 
    input (Tensor) – The source tensor 
    dim (int) – The axis along which to index 
    index (LongTensor) – The indices of elements to gather 
    out (Tensor, optional) – Destination tensor 
 
Example: 
   >>> t = torch.Tensor([[1,2],[3,4]]) 
   >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) 
     1 1
     4 3 
[torch.FloatTensor of size 2x2]

torch.gather 函数用于从参数 t 选择性输出特定 index 的矩阵,输出矩阵的大小跟 index 的大小是一样的,torch.gather 的 dim 参数用来选择 index 作用的 axis。

构建 2×18×2×2 的矩阵 a,

# a[:,i,:,:] = i
a = torch.arange(18).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(2,18,2,2)
a[:,2,:,:]
Out[24]: 
tensor([[[ 2.,  2.],
         [ 2.,  2.]],
        [[ 2.,  2.],
         [ 2.,  2.]]])
a[:,16,:,:]
Out[25]: 
tensor([[[ 16.,  16.],
         [ 16.,  16.]],
        [[ 16.,  16.],
         [ 16.,  16.]]])

现在要通过 torch.gather 函数把 a 变成 offset[:,:9,:,:] = [0,2,…16],offset[:,9:,:,:] = [1,3,..,17]

N = 9
offsets_index = Variable(torch.cat([torch.arange(0, 2*N, 2), torch.arange(1, 2*N+1, 2)]), requires_grad=False).long()
offsets_index
Out[29]: 
tensor([  0,   2,   4,   6,   8,  10,  12,  14,  16,   1,   3,   5,
          7,   9,  11,  13,  15,  17])
offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*a.size())
offset = torch.gather(a, dim=1, index=offsets_index)
offset[:,0,:,:]
Out[34]: 
tensor([[[ 0.,  0.],
         [ 0.,  0.]],
        [[ 0.,  0.],
         [ 0.,  0.]]])
offset[:,1,:,:]
Out[35]: 
tensor([[[ 2.,  2.],
         [ 2.,  2.]],
        [[ 2.,  2.],
         [ 2.,  2.]]])
offset[:,8,:,:]
Out[39]: 
tensor([[[ 16.,  16.],
         [ 16.,  16.]],
        [[ 16.,  16.],
         [ 16.,  16.]]])
offset[:,9,:,:]
Out[40]: 
tensor([[[ 1.,  1.],
         [ 1.,  1.]],
        [[ 1.,  1.],
         [ 1.,  1.]]])
offset[:,17,:,:]
Out[41]: 
tensor([[[ 17.,  17.],
         [ 17.,  17.]],
        [[ 17.,  17.],
         [ 17.,  17.]]])

代码帖的有点多,主要是为了验证效果。

offset 的输出规则如下:

offset[i][j][k][s] = input[i][offsets_index[i][j][k][s]][k][s] # dim=1 

因为 dim = 1,offsets_index 影响 axis = 1 的维度,offset[i][j][k][s] 由 input 根据 offsets_index 在 axis=1 维度用 offsets_index[i][j][k][s] 作为索引,其他的位置不变,同理其他维度改变就用 index[i][j][k][s] 作为对应 axis 的索引 。

最终输出 offset 的时候,offset[:][1][:][:] 的数据只是选择了 input 在 axis=1 上 input[:][2][:][:] 的所有数据,在第 axis=0,2,3 维度 input 的索引和 offset 是对应的,所以 offset 在相应位置上的数据和 input 一样。

    原文作者:雪花飘满地
    原文地址: https://zhuanlan.zhihu.com/p/52464651
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞