Created
December 17, 2024 11:06
-
-
Save d-v-b/a3ffaace36f5b5f1f56671e05f21b06c to your computer and use it in GitHub Desktop.
A simple bounding-box-based array with caching
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "numpy",
# ]
# ///
import numpy as np | |
from typing import Iterable | |
# A bounding box: one half-open (start, stop) integer interval per array axis.
BBox = tuple[tuple[int, int], ...]
def slice_to_interval(slice, length: int) -> tuple[int, int]:
    """
    Convert a unit-step slice to a half-open ``(start, stop)`` interval.

    ``slice`` is the slice object to convert. Its step must be 1 or ``None``
    (the implicit unit step of ``a[i:j]``), and a ``start`` of ``None`` is
    treated as 0. The parameter name shadows the ``slice`` builtin inside
    this function; it is kept so that keyword callers keep working.

    ``length`` is the length of the axis being indexed. It is accepted for
    interface compatibility but not consulted: an open-ended ``stop`` is
    rejected rather than resolved against it.

    Raises ``ValueError`` if the step is neither 1 nor ``None``, or if
    ``stop`` is ``None``.
    """
    # ``a[i:j]`` produces a slice whose step is None, which means a step of 1.
    if slice.step not in (None, 1):
        raise ValueError("slice.step must be 1")
    if slice.stop is None:
        raise ValueError('slice.stop must not be None')
    # An omitted start means "from the beginning of the axis"; returning None
    # here would break the integer arithmetic done on intervals downstream.
    start = 0 if slice.start is None else slice.start
    return start, slice.stop
def slices_to_bbox(slices: Iterable[slice], shape: tuple[int, ...]) -> BBox:
    """
    Convert one slice per axis into a bounding box.

    Each slice is paired with the length of its axis and converted with
    ``slice_to_interval``. ``slices`` and ``shape`` must have the same number
    of entries; otherwise the strict ``zip`` raises ``ValueError``.
    """
    intervals = []
    for axis_slice, axis_length in zip(slices, shape, strict=True):
        intervals.append(slice_to_interval(axis_slice, axis_length))
    return tuple(intervals)
def intersect_interval(interval_a: tuple[int, int], interval_b: tuple[int, int]) -> tuple[int, int]:
    """
    Intersect two half-open ``(start, stop)`` intervals.

    Returns the overlapping ``(start, stop)`` interval, or the empty tuple
    ``()`` when the intervals do not overlap (intervals that merely touch at
    an endpoint count as non-overlapping).
    """
    # Each interval must begin before the other one ends for them to overlap.
    overlapping = interval_a[0] < interval_b[1] and interval_b[0] < interval_a[1]
    if not overlapping:
        return ()
    lower = max(interval_a[0], interval_b[0])
    upper = min(interval_a[1], interval_b[1])
    return lower, upper
def intersect_bbox(bbox_a: BBox, bbox_b: BBox) -> BBox:
    """
    Intersect two bounding boxes axis by axis.

    Returns a tuple with one ``(start, stop)`` interval per axis describing
    the region shared by ``bbox_a`` and ``bbox_b``. If the boxes are disjoint
    along any axis they share no region at all, and the empty tuple ``()`` is
    returned.

    Raises ``ValueError`` (from the strict ``zip``) if the two boxes do not
    have the same number of axes.
    """
    per_axis = tuple(
        intersect_interval(a, b) for a, b in zip(bbox_a, bbox_b, strict=True)
    )
    # ``intersect_interval`` yields () for a disjoint axis; one such axis is
    # enough to make the whole intersection empty.
    if any(len(interval) == 0 for interval in per_axis):
        return ()
    return per_axis
def contains_interval(interval_a: tuple[int, int], interval_b: tuple[int, int]) -> bool:
    """
    Tell whether ``interval_a`` fully covers ``interval_b``.

    True when ``interval_a`` starts at or before ``interval_b`` and ends at
    or after it; False otherwise.
    """
    # interval_b begins before interval_a does, so it cannot be covered.
    if interval_a[0] > interval_b[0]:
        return False
    return interval_a[1] >= interval_b[1]
def contains_bbox(bbox_a, bbox_b) -> bool:
    """
    Tell whether ``bbox_a`` fully covers ``bbox_b``.

    True only if, on every axis, the interval of ``bbox_a`` contains the
    corresponding interval of ``bbox_b``. The boxes are paired with a strict
    ``zip``, so a mismatch in the number of axes raises ``ValueError`` unless
    an earlier axis has already decided the answer.
    """
    for outer, inner in zip(bbox_a, bbox_b, strict=True):
        if not contains_interval(outer, inner):
            return False
    return True
class CachedArray:
    """
    Array wrapper that remembers the most recent region read from ``data``.

    The cache holds a single ``(bounding box, data)`` pair. A query whose
    bounding box lies entirely inside the cached box is answered by slicing
    the cached array; any other query is read from ``data`` and replaces the
    cache. ``_cache_size`` is stored on the instance but is not read by any
    method of this class.
    """

    data: np.ndarray
    _cache_size: tuple[int, ...]
    _cache: tuple[tuple[tuple[int, int],...], np.ndarray]  # (bounding box, data) tuple

    def __init__(self, data, _cache_size, _cache=()):
        self.data = data
        self._cache_size = _cache_size
        self._cache = _cache

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the wrapped array."""
        return self.data.shape

    def __getitem__(self, query: tuple[slice, ...]) -> np.ndarray:
        """
        Return the region selected by ``query``, one slice per axis.

        Prints which path was taken: empty cache, cache hit, or cache miss.
        """
        bbox_query = slices_to_bbox(query, self.shape)

        if len(self._cache) == 0:
            print('cache is empty, fetching data')
            return self._fetch_and_cache(query, bbox_query)

        bbox_cache, data_cache = self._cache
        if not contains_bbox(bbox_cache, bbox_query):
            print('cache is invalid, fetching data')
            return self._fetch_and_cache(query, bbox_query)

        print('using cache')
        # The cached array starts at the cached box's lower corner, so shift
        # the query by that corner to index into the cached array.
        shifted = tuple(
            (wanted[0] - held[0], wanted[1] - held[0])
            for wanted, held in zip(bbox_query, bbox_cache, strict=True)
        )
        local_slices = tuple(slice(lo, hi, 1) for lo, hi in shifted)
        return data_cache[local_slices]

    def _fetch_and_cache(self, query, bbox_query):
        """Read ``query`` from ``data`` and make it the new cache entry."""
        fetched = self.data[query]
        self._cache = (bbox_query, fetched)
        return fetched
# Demo: wrap a 4x4x4 array of consecutive integers in a CachedArray.
data = np.arange(64).reshape((4, 4, 4))
cached_data = CachedArray(data, _cache_size=[1,4,4])

# First read: the cache starts empty, so this fetches from `data`.
query = (slice(0, 1, 1), slice(0, 4, 1), slice(0, 4, 1))
res = cached_data[query]
# Same region again: answered from the cache.
res2 = cached_data[query]
assert np.array_equal(res, res2)

query2 = (slice(1, 2, 1), slice(0, 4, 1), slice(0, 4, 1))
# A region outside the cached bounding box: fetched from `data` again.
res3 = cached_data[query2]
# Repeating it is answered from the refreshed cache.
res4 = cached_data[query2]
assert np.array_equal(res3, res4)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment