IIUC,您想要搜索列中arr2
中1
都是1
的位置上的arr1
:
arr1 = np.array(
[
[0.1, 0.02, 0.2, 0.3, 0.013, 0.7, 0.7, 0.11, 0.18, 0.6],
[0.23, 0.02, 0.1, 0.1, 0.011, 0.3, 0.4, 0.4, 0.4, 0.5],
]
)
arr2 = np.array([[0, 1, 0, 0, 0, 0, 1, 1, 0, 1], [1, 0, 1, 0, 0, 0, 0, 1, 1, 0]])
out = arr1[:, np.all(arr2, axis=0)]
print(out)
打印:
[[0.11]
[0.4 ]]
如果要查找所有组合:
unique = np.unique(arr2.T, axis=0)
for row in unique:
print("Combination:")
print(row)
print()
print(arr1[:, np.all(arr2 == row.reshape(arr2.shape[0], -1), axis=0)])
print("-" * 80)
打印:
Combination:
[0 0]
[[0.3 0.013 0.7 ]
[0.1 0.011 0.3 ]]
--------------------------------------------------------------------------------
Combination:
[0 1]
[[0.1 0.2 0.18]
[0.23 0.1 0.4 ]]
--------------------------------------------------------------------------------
Combination:
[1 0]
[[0.02 0.7 0.6 ]
[0.02 0.4 0.5 ]]
--------------------------------------------------------------------------------
Combination:
[1 1]
[[0.11]
[0.4 ]]
--------------------------------------------------------------------------------