小能豆

为什么 Numba 没有改进这个递归函数

py

我有一个结构非常简单的真/假值数组:

# the real array has hundreds of thousands of items
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)

我想遍历这个数组并输出发生变化的位置(真变为假或相反)。为此,我整理了两种不同的方法:

  • 递归二分查找(查看所有值是否相同,如果不相同,则分成两部分,然后递归)
  • 纯迭代搜索(循环遍历所有元素并与前一个/下一个元素进行比较)

两个版本都给出了我想要的结果,但是 Numba 对其中一个版本的影响比另一个版本更大。使用 300k 个值的虚拟数组,性能结果如下:

300k 个元素的数组的性能结果

  • 纯 Python 二分查找运行时间为 11 毫秒
  • 纯 Python 迭代搜索运行时间为 1.1 秒(比二分搜索慢 100 倍)
  • Numba 二进制搜索运行时间为 5 毫秒(比纯 Python 快 2 倍)
  • Numba 迭代搜索运行时间为 900 µs (比纯 Python 快 1,200 倍)

因此,当使用 Numba 时,binary_search 比 iterative_search 慢 5 倍,而理论上它应该快 100 倍(如果适当加速,预计应该在 9 µs 内运行)。

如何才能让 Numba 加速二进制搜索,就像加速迭代搜索一样?

两种方法的代码(以及示例position数组)都可以在这个公共 gist 中找到:https ://gist.github.com/JivanRoquet/d58989aa0a4598e060ec2c705b9f3d8f

注意:Numba 未binary_search()在对象模式下运行,因为当提及时nopython=True,它不会抱怨并愉快地编译该函数。


阅读 19

收藏
2024-12-10

共1个答案

小能豆

在你描述的情况中,使用 Numba 加速二分查找算法时,性能不如预期,尤其是在与迭代搜索方法的比较中。对于这种情况,Numba的加速效果通常是受限的,因为它的优化更多适用于线性(逐项)处理,而非递归或分治类型的算法,如二分查找。

可能的原因:

  1. 递归开销:二分查找的递归特性可能导致堆栈操作,这对于 JIT 编译的优化来说,通常不如线性循环那样高效。尤其是当 nopython=True 时,Numba 更倾向于优化简单的线性迭代。
  2. 数据访问模式:Numba 在处理大规模数据时,尤其是数组的顺序访问时,会大幅提升性能。然而,二分查找的方式不适合顺序访问,因为它每次都要跳跃性地访问数组,这可能导致性能瓶颈。

解决方法:

1. 转换为迭代版本的二分查找

为了利用 Numba 的高效迭代优化,你可以将递归的二分查找转换为迭代版本。这样,算法仍然会采用二分查找的思路,但会避免递归的开销。

2. 优化 Numba 的使用

可以尝试使用 numbaprange 或显式并行化来优化 iterative_search,同时确保将函数 nopython=True

示例代码:

我将提供一种迭代版本的二分查找实现,使用 Numba 提升其性能。

import numpy as np
import numba
from numba import jit

# 迭代版二分查找 (避免递归)
@jit(nopython=True)
def binary_search_iterative(arr):
    changes = []
    start, end = 0, len(arr) - 1
    prev = arr[start]

    while start <= end:
        mid = (start + end) // 2
        current = arr[mid]

        # 检查当前和前一个元素是否发生变化
        if current != prev:
            changes.append(mid)

        prev = current

        # 继续向左或向右查找
        if current < arr[mid]:
            end = mid - 1
        else:
            start = mid + 1

    return changes


# 测试数据
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)

# 使用 Numba 加速的二分查找
changes = binary_search_iterative(positions)

print("发生变化的索引:", changes)

总结:

  1. 递归的二分查找不适合 Numba 的优化:递归在 Numba 中的性能开销较大,特别是如果每次调用都涉及堆栈的创建和销毁。
  2. 迭代版二分查找可以得到更好的加速:Numba 通过对线性访问优化的迭代版本提供了更好的加速效果。
  3. 避免过多的分支和递归:对于这种类型的性能敏感任务,尽量简化算法并使用迭代式方法,而不是递归。

通过这些方法,你可以显著提升二分查找的执行速度,特别是对于大规模数据。

2024-12-10