I am looking for an efficient way to generate pairwise combinations from a 1D array. itertools is too slow when n > 1000.
E.g.

    [1, 2, 3, 4]

    magic code...

    Out[2]:
    array([[1, 2],
           [1, 3],
           [1, 4],
           [2, 3],
           [2, 4],
           [3, 4]])
The closest thing I have found is here.
One approach is to use numba, which is memory-efficient and therefore fast:
    import numpy as np
    from numba import njit

    @njit
    def pairwise_combs_numba(a):
        n = len(a)
        L = n*(n-1)//2  # number of pairs, C(n, 2)
        out = np.empty((L,2),dtype=a.dtype)
        iterID = 0  # next output row to fill
        for i in range(n):
            for j in range(i+1,n):
                out[iterID,0] = a[i]
                out[iterID,1] = a[j]
                iterID += 1
        return out
Another approach, based on NumPy, uses np.broadcast_to to get gridded views of the input and then masks them:
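    def pairwise_combs_mask(a):
        n = len(a)
        L = n*(n-1)//2  # number of pairs, C(n, 2)
        out = np.empty((L,2),dtype=a.dtype)
        m = ~np.tri(len(a),dtype=bool)  # True strictly above the diagonal (i < j)
        out[:,0] = np.broadcast_to(a[:,None],(n,n))[m]  # first element of each pair, a[i]
        out[:,1] = np.broadcast_to(a,(n,n))[m]  # second element of each pair, a[j]
        return out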
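To make the gridded views and the mask concrete, here is a small sketch on a 4-element array. The names `rows`, `cols` and `pairs` are illustrative only and do not come from the answer; `np.column_stack` is used here just to assemble the result for display.

```python
import numpy as np

a = np.array([1, 2, 3, 4])
n = len(a)

# Read-only (n, n) views of `a`; no data is copied.
rows = np.broadcast_to(a[:, None], (n, n))  # rows[i, j] == a[i]
cols = np.broadcast_to(a, (n, n))           # cols[i, j] == a[j]

# Boolean mask that is True strictly above the diagonal, i.e. where i < j.
m = ~np.tri(n, dtype=bool)

# Boolean indexing walks the mask in row-major order, so the pairs
# come out in the same order as the nested i/j loops.
pairs = np.column_stack((rows[m], cols[m]))
print(pairs)
```

Because `rows` and `cols` are views, the only sizeable temporary is the n-by-n boolean mask; the selected elements are copied only when the mask is applied.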
The same approaches extend to triplet combinations:
    @njit
    def triplet_combs_numba(a):
        n = len(a)
        L = n*(n-1)*(n-2)//6  # number of triplets, C(n, 3)
        out = np.empty((L,3),dtype=a.dtype)
        iterID = 0  # next output row to fill
        for i in range(n):
            for j in range(i+1,n):
                for k in range(j+1,n):
                    out[iterID,0] = a[i]
                    out[iterID,1] = a[j]
                    out[iterID,2] = a[k]
                    iterID += 1
        return out

    def triplet_combs_mask(a):
        n = len(a)
        L = n*(n-1)*(n-2)//6  # number of triplets, C(n, 3)
        out = np.empty((L,3),dtype=a.dtype)
        r = np.arange(n)
        # m[i,j,k] is True where i < j < k
        m = (r[:,None,None]<r[:,None]) & (r[:,None]<r)
        out[:,0] = np.broadcast_to(a[:,None,None],(n,n,n))[m]
        out[:,1] = np.broadcast_to(a[None,:,None],(n,n,n))[m]
        out[:,2] = np.broadcast_to(a[None,None,:],(n,n,n))[m]
        return out
Higher-order combinations extend in the same way.
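As a rough sketch of how the masking idea might generalize to k-element combinations: the function `combs_mask` and its helper `along` below are hypothetical names introduced for illustration, not part of the answer. The sketch builds a k-dimensional mask that is True where the indices are strictly increasing, then reads each output column from a broadcast view.

```python
import itertools
import numpy as np

def combs_mask(a, k):
    """k-element combinations of a 1D array via a k-dimensional mask (illustrative)."""
    a = np.asarray(a)
    n = len(a)
    r = np.arange(n)

    def along(v, axis):
        # Reshape a length-n vector so it varies only along `axis` of a k-D grid.
        shape = [1] * k
        shape[axis] = n
        return v.reshape(shape)

    # m[i0, i1, ..., i(k-1)] is True where i0 < i1 < ... < i(k-1).
    m = np.ones((n,) * k, dtype=bool)
    for d in range(k - 1):
        m &= along(r, d) < along(r, d + 1)

    out = np.empty((np.count_nonzero(m), k), dtype=a.dtype)
    for d in range(k):
        # Column d holds a[i_d] for every selected index tuple.
        out[:, d] = np.broadcast_to(along(a, d), m.shape)[m]
    return out

a = np.array([3, 9, 4, 1, 7])
# Sanity check against itertools on a small input.
assert np.array_equal(combs_mask(a, 3), np.array(list(itertools.combinations(a, 3))))
```

Note that the mask holds n**k booleans, so this sketch is only practical for small n or small k; the nested-loop numba version does not need such a mask.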
Sample run:
    In [54]: a = np.array([3,9,4,1,7])

    In [55]: pairwise_combs_numba(a)
    Out[55]:
    array([[3, 9],
           [3, 4],
           [3, 1],
           [3, 7],
           [9, 4],
           [9, 1],
           [9, 7],
           [4, 1],
           [4, 7],
           [1, 7]])

    In [56]: triplet_combs_numba(a)
    Out[56]:
    array([[3, 9, 4],
           [3, 9, 1],
           [3, 9, 7],
           [3, 4, 1],
           [3, 4, 7],
           [3, 1, 7],
           [9, 4, 1],
           [9, 4, 7],
           [9, 1, 7],
           [4, 1, 7]])
Timings (including Python's built-in itertools.combinations):
    In [68]: a = np.random.rand(4000)

    In [69]: %timeit pairwise_combs_numba(a)
        ...: %timeit pairwise_combs_mask(a)
        ...: %timeit list(itertools.combinations(a, 2))
    10 loops, best of 3: 52.2 ms per loop
    10 loops, best of 3: 146 ms per loop
    1 loop, best of 3: 597 ms per loop

    In [70]: a = np.random.rand(400)

    In [71]: %timeit triplet_combs_numba(a)
        ...: %timeit triplet_combs_mask(a)
        ...: %timeit list(itertools.combinations(a, 3))
    10 loops, best of 3: 98.5 ms per loop
    1 loop, best of 3: 352 ms per loop
    1 loop, best of 3: 795 ms per loop