一尘不染

搜索具有多个值的Numpy数组

python

我有具有重复值的numpy 2d数组。

我正在搜索这样的数组。

In [104]: import numpy as np

In [105]: array = np.array

In [106]: a = array([[1, 2, 3],
     ...:            [1, 2, 3],
     ...:            [2, 5, 6],
     ...:            [3, 8, 9],
     ...:            [4, 8, 9],
     ...:            [4, 2, 3],
     ...:            [5, 2, 3])

In [107]: num_list = [1, 4, 5]

In [108]: for i in num_list :
     ...:     print(a[np.where(a[:,0] == num_list)])
     ...:
 [[1 2 3]
 [1 2 3]]
[[4 8 9]
 [4 2 3]]
[[5 2 3]]

输入是列表,其编号类似于列0的值。我想要的最终结果是任何形式的结果行,例如数组,列表或元组

array([[1, 2, 3],
       [1, 2, 3],
       [4, 8, 9],
       [4, 2, 3],
       [5, 2, 3]])

我的代码工作正常,但似乎不是pythonic。有没有更好的多值搜索策略?

就像a[np.where(a[:,0] == l)]只进行一次查找即可获取所有值的地方。

我的真实数组很大


阅读 145

收藏
2021-01-20

共1个答案

一尘不染

方法1:
使用np.in1d-

a[np.in1d(a[:,0], num_list)]

方法2:
使用np.searchsorted-

num_arr = np.sort(num_list) # Sort num_list and get as array

# Get indices of occurrences of first column in num_list
idx = np.searchsorted(num_arr, a[:,0])

# Take care of out of bounds cases
idx[idx==len(num_arr)] = 0

out = a[a[:,0] == num_arr[idx]]
2021-01-20