我有一个...X n x m数组,比如a
,其中...代表任意数量的附加维度.为简单起见,我们称n维为"行",称m维为"列",即使该数组是高维的.
我还有一个长度为n的向量v
,它包含最后一个维度的索引(从0到m-1).我想创建一个数组b
,它使用这个向量来提取每一行的指示列.
使用循环可以很容易地做到这一点.以下是一个最小的工作示例:
import numpy as np
a = np.round(np.random.rand(2,3,4)*10)
v = [0, 2, 1]
print(a)
"""[[[ 1. 6. 9. 9.]
[ 1. 8. 4. 10.]
[ 0. 0. 5. 3.]]
[[ 7. 8. 1. 10.]
[ 7. 9. 7. 8.]
[ 3. 4. 8. 7.]]]
"""
b = []
for i in range(len(v)):
b.append(a.take(i, axis=-2).take(v[i], axis=-1))
b = np.asarray(b)
print(b)
"""
[[1. 7.]
[4. 7.]
[0. 4.]]
"""
有没有更聪明的方法在不循环的情况下进行这种索引?