I have the following Numba function:
    import numba

    @numba.njit
    def count_in_range(arr, min_value, max_value):
        count = 0
        for a in arr:
            if min_value < a < max_value:
                count += 1
        return count

It counts how many values in the array fall within the given range.
However, I realized that I only needed to determine if they existed. So I modified it as follows:
    @numba.njit
    def count_in_range2(arr, min_value, max_value):
        count = 0
        for a in arr:
            if min_value < a < max_value:
                count += 1
                break  # <---- break here
        return count
However, this function turned out to be slower than the original. Under certain conditions, it is more than 10 times slower.
Benchmark code:
    from timeit import timeit

    import numpy as np

    rng = np.random.default_rng(0)
    arr = rng.random(10 * 1000 * 1000)

    # To compare on even conditions, choose the condition that does not terminate early.
    min_value = 0.5
    max_value = min_value - 1e-10
    assert not np.any(np.logical_and(min_value <= arr, arr <= max_value))

    n = 100
    for f in (count_in_range, count_in_range2):
        f(arr, min_value, max_value)  # warm-up call so compilation time is not measured
        elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n
        print(f"{f.__name__}: {elapsed * 1000:.3f} ms")
Result:
    count_in_range: 3.351 ms
    count_in_range2: 42.312 ms
Further experimenting, I found that the speed varies greatly depending on the search range (i.e. min_value and max_value).
At various search ranges:
    count_in_range2: 5.802 ms, range: (0.0, -1e-10)
    count_in_range2: 15.408 ms, range: (0.1, 0.09999999990000001)
    count_in_range2: 29.571 ms, range: (0.25, 0.2499999999)
    count_in_range2: 42.514 ms, range: (0.5, 0.4999999999)
    count_in_range2: 24.427 ms, range: (0.75, 0.7499999999)
    count_in_range2: 12.547 ms, range: (0.9, 0.8999999999)
    count_in_range2: 5.747 ms, range: (1.0, 0.9999999999)
Can someone explain to me what is going on?
I am using Numba 0.58.1 under Python 3.10.11. Confirmed on both Windows 10 and Ubuntu 22.04.
The slowdown you’re observing in count_in_range2 is most likely caused by how the JIT compiler (LLVM, which Numba uses as its backend) optimizes the loop. Without the break statement, the loop always visits every element and only accumulates a count, a pattern the compiler can typically auto-vectorize with SIMD instructions. Once you introduce the break, the loop has a data-dependent early exit, and the compiler usually falls back to scalar code that tests one element at a time.
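Loop optimizations such as vectorization depend on what the compiler can prove about the loop. With the early exit in count_in_range2, each element is probably handled by conditional branches, so the cost depends on how well the CPU can predict them. That would fit your measurements: with uniformly distributed data, the comparison min_value < a is hardest to predict when min_value is near 0.5 and almost perfectly predictable near 0.0 or 1.0, which matches your slowest and fastest timings.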
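One way to probe why the timing depends on the search range is to look at how often the first comparison in the loop, min_value < a, succeeds for each min_value. This is only a rough sketch, assuming uniform data as in the benchmark; the helper first_test_pass_rate is a made-up name for illustration, and min(p, 1 - p) is a crude proxy for branch unpredictability, not a measurement of actual mispredictions:

```python
import numpy as np

rng = np.random.default_rng(0)
arr = rng.random(1_000_000)  # uniform in [0, 1), like the benchmark data


def first_test_pass_rate(arr, min_value):
    """Hypothetical helper: fraction of elements for which `min_value < a` holds."""
    return float(np.mean(min_value < arr))


pass_rates = {}
for min_value in (0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0):
    p = first_test_pass_rate(arr, min_value)
    pass_rates[min_value] = p
    # A branch taken about half the time is the hardest one to predict.
    print(f"min_value={min_value}: pass rate {p:.3f}, "
          f"unpredictability proxy {min(p, 1 - p):.3f}")
```

If the proxy rises toward min_value = 0.5 and falls toward 0.0 and 1.0, it follows the same shape as the timings in the question, which is consistent with (but does not prove) a branch-prediction effect.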
To address this issue, you can try a different approach that doesn’t involve breaking the loop. For example, you can use boolean array operations to check the condition for the entire array and then sum the resulting boolean array:
    @numba.njit
    def count_in_range3(arr, min_value, max_value):
        return np.sum((min_value < arr) & (arr < max_value))
This way, you retain the benefits of Numba’s optimized loop without introducing the potential issues associated with breaking early. It’s worth noting that Numba performs well with boolean array operations.
Here’s how you can modify your benchmark code to include count_in_range3:
    for f in (count_in_range, count_in_range2, count_in_range3):
        f(arr, min_value, max_value)
        elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n
        print(f"{f.__name__}: {elapsed * 1000:.3f} ms")
This modification should provide consistent and efficient performance across different search ranges.
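If you still want the function to stop early once a match is found, one possible compromise is to scan the array in fixed-size chunks: the inner loop has no break, so it keeps the same shape as count_in_range, and the early exit happens only between chunks. This is a sketch under assumptions rather than a tuned implementation: the name any_in_range_chunked and the chunk_size default of 1024 are made up for illustration, and the no-op fallback decorator is there only so the example also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func


@njit
def any_in_range_chunked(arr, min_value, max_value, chunk_size=1024):
    n = arr.shape[0]
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        found = 0
        # Inner loop has no break, like the original count_in_range.
        for i in range(start, stop):
            if min_value < arr[i] < max_value:
                found += 1
        # Early exit is checked once per chunk instead of once per element.
        if found > 0:
            return True
    return False
```

You can add it to the benchmark loop in the same way as the other functions. Whether it actually beats count_in_range3 depends on your data and hardware, so it is worth measuring before relying on it.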