我有一个结构非常简单的真/假值数组:
# the real array has hundreds of thousands of items positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)
我想遍历这个数组并输出发生变化的位置(真变为假或相反)。为此,我整理了两种不同的方法:
两个版本都给出了我想要的结果,但是 Numba 对其中一个版本的影响比另一个版本更大。使用 300k 个值的虚拟数组,性能结果如下:
300k 个元素的数组的性能结果 纯 Python 二分查找运行时间为 11 毫秒 纯 Python 迭代搜索运行时间为 1.1 秒(比二分搜索慢 100 倍) Numba 二进制搜索运行时间为 5 毫秒(比纯 Python 快 2 倍) Numba 迭代搜索运行时间为 900 µs (比纯 Python 快 1,200 倍)
因此,当使用 Numba 时,binary_search 比 iterative_search 慢 5 倍,而理论上它应该快 100 倍(如果适当加速,预计应该在 9 µs 内运行)。
如何才能让 Numba 加速二进制搜索,就像加速迭代搜索一样?
两种方法的代码(以及示例position数组)都可以在这个公共 gist 中找到:https ://gist.github.com/JivanRoquet/d58989aa0a4598e060ec2c705b9f3d8f
position
注意:Numba 未binary_search()在对象模式下运行,因为当提及时nopython=True,它不会抱怨并愉快地编译该函数。
binary_search()
nopython=True
在你描述的情况中,使用 Numba 加速二分查找算法时,性能不如预期,尤其是在与迭代搜索方法的比较中。对于这种情况,Numba的加速效果通常是受限的,因为它的优化更多适用于线性(逐项)处理,而非递归或分治类型的算法,如二分查找。
为了利用 Numba 的高效迭代优化,你可以将递归的二分查找转换为迭代版本。这样,算法仍然会采用二分查找的思路,但会避免递归的开销。
可以尝试使用 numba 的 prange 或显式并行化来优化 iterative_search,同时确保将函数 nopython=True。
numba
prange
iterative_search
我将提供一种迭代版本的二分查找实现,使用 Numba 提升其性能。
import numpy as np import numba from numba import jit # 迭代版二分查找 (避免递归) @jit(nopython=True) def binary_search_iterative(arr): changes = [] start, end = 0, len(arr) - 1 prev = arr[start] while start <= end: mid = (start + end) // 2 current = arr[mid] # 检查当前和前一个元素是否发生变化 if current != prev: changes.append(mid) prev = current # 继续向左或向右查找 if current < arr[mid]: end = mid - 1 else: start = mid + 1 return changes # 测试数据 positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool) # 使用 Numba 加速的二分查找 changes = binary_search_iterative(positions) print("发生变化的索引:", changes)
通过这些方法,你可以显著提升二分查找的执行速度,特别是对于大规模数据。