python – numpy多维索引和函数’take’

在一周中的奇数天,我几乎理解numpy中的多维索引. Numpy有一个功能
‘take’似乎做我想要的但是有额外的奖励我可以控制如果索引超出范围会发生什么

具体来说,我有一个三维数组要求作为查找表

lut = np.ones([13,13,13],np.bool)

以及一个2×2的3长向量数组,作为表中的索引

arr = np.arange(12).reshape([2,2,3]) % 13 

IIUC,如果我要写lut [arr]那么arr被视为2x2x3数字数组,当它们被用作lut的索引时,它们每个都返回一个13×13数组.这解释了为什么lut [arr] .shape是(2,2,3,13,13).

我可以通过写作让它做我想做的事

lut[ arr[:,:,0],arr[:,:,1],arr[:,:,2] ] #(is there a better way to write this?)

现在这三个术语就好像它们已被压缩以生成一个2×2的元组数组而lut [< tuple>]从lut生成一个元素.最终结果是来自lut的2×2条目数组,正如我想要的那样.

我已经阅读了’take’功能的文档……

This function does the same thing as “fancy” indexing
(indexing arrays using arrays); however, it can be easier
to use if you need elements along a given axis.

axis : int, optional
The axis over which to select values.

也许是天真的,我认为设置axis = 2我会得到三个值用作3元组来执行查找但实际上

np.take(lut,arr).shape =  (2, 2, 3)
np.take(lut,arr,axis=0).shape =  (2, 2, 3, 13, 13)
np.take(lut,arr,axis=1).shape =  (13, 2, 2, 3, 13)
np.take(lut,arr,axis=2).shape =  (13, 13, 2, 2, 3)

所以我很清楚我不明白发生了什么.任何人都可以告诉我如何实现我想要的东西吗?

最佳答案 我们可以计算线性指数,然后使用
np.take

np.take(lut, np.ravel_multi_index(arr.T, lut.shape)).T

如果您对替代方案持开放态度,我们可以将indices数组重新整形为2D,转换为元组,使用它索引到数据数组中,为我们提供1D,可以将其重新形成为2D –

lut[tuple(arr.reshape(-1,arr.shape[-1]).T)].reshape(arr.shape[:2])

样品运行 –

In [49]: lut = np.random.randint(11,99,(13,13,13))

In [50]: arr = np.arange(12).reshape([2,2,3])

In [51]: lut[ arr[:,:,0],arr[:,:,1],arr[:,:,2] ] # Original approach
Out[51]: 
array([[41, 21],
       [94, 22]])

In [52]: np.take(lut, np.ravel_multi_index(arr.T, lut.shape)).T
Out[52]: 
array([[41, 21],
       [94, 22]])

In [53]: lut[tuple(arr.reshape(-1,arr.shape[-1]).T)].reshape(arr.shape[:2])
Out[53]: 
array([[41, 21],
       [94, 22]])

我们可以避免np.take方法的双重转置,就像这样 –

In [55]: np.take(lut, np.ravel_multi_index(arr.transpose(2,0,1), lut.shape))
Out[55]: 
array([[41, 21],
       [94, 22]])

推广到通用维度的多维数组

这可以推广到通用号的ndarray.昏暗的,像这样 –

np.take(lut, np.ravel_multi_index(np.rollaxis(arr,-1,0), lut.shape))

基于元组的方法应该没有任何改变.

这是一个相同的样本运行 –

In [95]: lut = np.random.randint(11,99,(13,13,13,13))

In [96]: arr = np.random.randint(0,13,(2,3,4,4))

In [97]: lut[ arr[:,:,:,0] , arr[:,:,:,1],arr[:,:,:,2],arr[:,:,:,3] ]
Out[97]: 
array([[[95, 11, 40, 75],
        [38, 82, 11, 38],
        [30, 53, 69, 21]],

       [[61, 74, 33, 94],
        [90, 35, 89, 72],
        [52, 64, 85, 22]]])

In [98]: np.take(lut, np.ravel_multi_index(np.rollaxis(arr,-1,0), lut.shape))
Out[98]: 
array([[[95, 11, 40, 75],
        [38, 82, 11, 38],
        [30, 53, 69, 21]],

       [[61, 74, 33, 94],
        [90, 35, 89, 72],
        [52, 64, 85, 22]]])
点赞