一尘不染

用Cython实现Numba的演奏

python

通常,使用Cython时,我能够与Numba媲美。但是,在此示例中,我没有这样做-Numba比我的Cython版本快4倍。

以下是Cython版本:

%%cython -c=-march=native -c=-O3
cimport numpy as np
import numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def cy_where(double[::1] df):
    cdef int i
    cdef int n = len(df)
    cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
    for i in range(n):
        if df[i]>0.5:
            output[i] = 2.0*df[i]
        else:
            output[i] = df[i]
    return output

这是Numba版本:

import numba as nb
@nb.njit
def nb_where(df):
    n = len(df)
    output = np.empty(n, dtype=np.float64)
    for i in range(n):
        if df[i]>0.5:
            output[i] = 2.0*df[i]
        else:
            output[i] = df[i]
    return output

经过测试,Cython版本与numpy的版本相当where,但明显不及Numba:

#Python3.6 + Cython 0.28.3 + gcc-7.2
import numpy
np.random.seed(0)
n = 10000000
data = np.random.random(n)

assert (cy_where(data)==nb_where(data)).all()
assert (np.where(data>0.5,2*data, data)==nb_where(data)).all()

%timeit cy_where(data)       # 179ms
%timeit nb_where(data)       # 49ms (!!)
%timeit np.where(data>0.5,2*data, data)  # 278 ms

Numba性能的原因是什么?在使用Cython时如何匹配?


正如@ max9111所建议的那样,通过使用连续的内存视图消除了跨步,但这并不能显着提高性能:

@cython.boundscheck(False)
@cython.wraparound(False)
def cy_where_cont(double[::1] df):
    cdef int i
    cdef int n = len(df)
    cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
    cdef double[::1] view = output  # view as continuous!
    for i in range(n):
        if df[i]>0.5:
            view[i] = 2.0*df[i]
        else:
            view[i] = df[i]
    return output

%timeit cy_where_cont(data)   #  165 ms

阅读 307

收藏
2021-01-20

共1个答案

一尘不染

这似乎完全由LLVM能够进行的优化驱动。如果我用clang编译cython示例,则两个示例之间的性能是相同的。就其价值而言,Windows上的MSVC表现出与numba类似的性能差异。

$ CC=clang ipython
<... setup code>

In [7]: %timeit cy_where(data)       # 179ms
   ...: %timeit nb_where(data)       # 49ms (!!)

30.8 ms ± 309 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
30.2 ms ± 498 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2021-01-20