Created January 18, 2024 15:42
-
-
Save ekreutz/16c716aa90b74386637406fc948e9bbe to your computer and use it in GitHub Desktop.
Fast O(n) running maximum using numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from numba import njit
from numpy.typing import NDArray as array
@njit
def running_max(values: array, w: int) -> array:
    """Compute the running (sliding-window) maximum of ``values`` in O(n).

    Uses a monotonic double-ended queue of indices. Every index is pushed
    once and popped at most once, so the total work is O(n) regardless of
    ``w``. A naive rescan of each window costs O(n * w).

    Args:
        values: 1-D array of comparable numbers.
        w: Window length; must be >= 1.

    Returns:
        Array with the same length and dtype as ``values``. Element ``i`` is
        ``max(values[max(0, i - w + 1) : i + 1])``, so the first ``w - 1``
        outputs are maxima over shorter, partially filled windows.

    Raises:
        ValueError: If ``w < 1``.
    """
    if w < 1:
        # Without this check, w == 0 fails on the modulo below and a
        # negative w fails inside the allocation with an unclear message.
        raise ValueError("w must be a positive integer")

    n: int = len(values)
    # Every slot is assigned in the loop, so no zero-fill is needed.
    max_vals: array = np.empty(n, values.dtype)

    # numba has no deque, so emulate one with a circular buffer of indices.
    # The deque never holds more than min(w, n) indices at once, so size it
    # to that: a huge w must not trigger a huge allocation.
    cap: int = min(w, n)
    queue: array = np.empty(cap, np.int64)
    head: int = 0    # slot holding the oldest index (front)
    tail: int = -1   # slot holding the newest index (back)
    count: int = 0   # number of indices currently stored

    for i in range(n):
        # Pop from the front: indices that have slid out of the window.
        while count > 0 and queue[head] < i - w + 1:
            head = (head + 1) % cap
            count -= 1
        # Pop from the back: indices whose values are <= values[i]. They can
        # never be the maximum again, because values[i] is at least as large
        # and stays in the window longer.
        while count > 0 and values[queue[tail]] <= values[i]:
            tail = (tail + cap - 1) % cap  # step back without going negative
            count -= 1
        # Push the current index onto the back.
        tail = (tail + 1) % cap
        queue[tail] = i
        count += 1
        # Stored values are non-increasing from front to back, so the front
        # holds the window maximum.
        max_vals[i] = values[queue[head]]
    return max_vals
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.