Skip to content

Instantly share code, notes, and snippets.

@d-v-b
Created December 17, 2024 11:06
Show Gist options
  • Save d-v-b/a3ffaace36f5b5f1f56671e05f21b06c to your computer and use it in GitHub Desktop.
Save d-v-b/a3ffaace36f5b5f1f56671e05f21b06c to your computer and use it in GitHub Desktop.
A simple bounding-box-based array with caching
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "numpy",
# ]
# ///
import numpy as np
from typing import Iterable
# A bounding box: one (start, stop) pair of ints per axis. The interval helpers
# in this file treat each pair as a half-open interval.
BBox = tuple[tuple[int, int], ...]
def slice_to_interval(slice, length: int) -> tuple[int, int]:
    """
    Convert a unit-step slice to a half-open ``(start, stop)`` interval.

    The slice is resolved against an axis of size ``length`` the same way
    sequence indexing resolves it: a missing start or stop is filled in,
    negative bounds are counted from the end of the axis, and out-of-range
    bounds are clamped to ``[0, length]``. The result therefore always
    consists of two non-negative ints.

    Parameters
    ----------
    slice
        The ``slice`` object to convert. Its step must be 1 or ``None``
        (``None`` is the default step and means 1).
    length
        Size of the axis the slice is applied to.

    Returns
    -------
    tuple[int, int]
        The resolved ``(start, stop)`` bounds.

    Raises
    ------
    ValueError
        If the slice has a step other than 1.
    """
    # ``a[0:4]`` produces a slice whose step is None, which is equivalent to 1.
    if slice.step not in (None, 1):
        raise ValueError("slice.step must be 1")
    # slice.indices() performs the None / negative / out-of-range resolution.
    start, stop, _ = slice.indices(length)
    return start, stop
def slices_to_bbox(slices: Iterable[slice], shape: tuple[int, ...]) -> BBox:
    """
    Build a bounding box from per-axis slices.

    Each slice is paired with the axis length at the same position in
    ``shape`` and converted with ``slice_to_interval``. The pairing is strict,
    so ``slices`` and ``shape`` must have the same number of entries;
    otherwise ``zip`` raises ``ValueError``.
    """
    intervals = []
    for axis_slice, axis_length in zip(slices, shape, strict=True):
        intervals.append(slice_to_interval(axis_slice, axis_length))
    return tuple(intervals)
def intersect_interval(interval_a: tuple[int, int], interval_b: tuple[int, int]) -> tuple[int, int]:
    """
    Intersect two half-open intervals.

    Returns the overlapping ``(start, stop)`` range. When the intervals do not
    overlap (including when they merely touch at an endpoint), the empty tuple
    ``()`` is returned instead.
    """
    a_start, a_stop = interval_a
    b_start, b_stop = interval_b

    # Half-open intervals are disjoint as soon as one starts at or after the
    # point where the other stops.
    disjoint = a_start >= b_stop or b_start >= a_stop
    if disjoint:
        return ()

    overlap_start = max(a_start, b_start)
    overlap_stop = min(a_stop, b_stop)
    return overlap_start, overlap_stop
def intersect_bbox(bbox_a: BBox, bbox_b: BBox) -> BBox:
    """
    Get the intersection between two bounding boxes.

    The boxes are intersected axis by axis with ``intersect_interval``, and the
    per-axis results are returned as a tuple with one entry per axis. An axis
    on which the two boxes are disjoint contributes whatever
    ``intersect_interval`` returns for disjoint intervals (an empty tuple), so
    callers should check for empty entries before using the result as a box.

    Raises ``ValueError`` (from the strict ``zip``) if the two boxes do not
    have the same number of axes.
    """
    # The original computed this tuple but never returned it, so the function
    # always produced None.
    return tuple(
        intersect_interval(a, b) for a, b in zip(bbox_a, bbox_b, strict=True)
    )
def contains_interval(interval_a: tuple[int, int], interval_b: tuple[int, int]) -> bool:
    """
    Tell whether ``interval_b`` lies entirely inside ``interval_a``.

    Both bounds are compared inclusively, so an interval contains itself.
    """
    outer_start, outer_stop = interval_a
    inner_start, inner_stop = interval_b

    # The inner interval must not begin before the outer one...
    if not outer_start <= inner_start:
        return False
    # ...and must not extend past its end.
    return outer_stop >= inner_stop
def contains_bbox(bbox_a, bbox_b) -> bool:
    """
    Tell whether ``bbox_b`` lies entirely inside ``bbox_a``.

    The boxes are compared axis by axis with ``contains_interval``; the result
    is True only if every axis of ``bbox_a`` contains the matching axis of
    ``bbox_b``. Checking stops at the first axis that fails.
    """
    for outer, inner in zip(bbox_a, bbox_b, strict=True):
        if not contains_interval(outer, inner):
            return False
    return True
class CachedArray:
    """
    Array wrapper that remembers the most recently fetched region.

    Indexing with a tuple of slices first checks whether the requested
    bounding box lies inside the cached one. If so, the answer is sliced out
    of the cached data; otherwise the region is read from ``data`` and
    replaces the cache. Only a single region is held at a time. A short
    message describing which path was taken is printed on every lookup.
    """

    data: np.ndarray
    # Stored by the constructor; not consulted by any method of this class.
    _cache_size: tuple[int, ...]
    # (bounding box, data) pair for the cached region; empty when nothing is cached.
    _cache: tuple[tuple[tuple[int, int], ...], np.ndarray]

    def __init__(self, data, _cache_size, _cache=()):
        self.data, self._cache_size, self._cache = data, _cache_size, _cache

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the wrapped array."""
        return self.data.shape

    def __getitem__(self, query: tuple[slice, ...]) -> np.ndarray:
        """
        Return the region selected by ``query``, served from the cache when
        the cached bounding box contains the requested one.
        """
        requested = slices_to_bbox(query, self.shape)

        if not self._cache:
            return self._fetch(query, requested, 'cache is empty, fetching data')

        cached_box, cached_values = self._cache
        if not contains_bbox(cached_box, requested):
            return self._fetch(query, requested, 'cache is invalid, fetching data')

        print('using cache')
        # Shift the requested bounds so they are relative to the origin of the
        # cached region rather than to the full array.
        local_index = tuple(
            slice(q_start - c_start, q_stop - c_start, 1)
            for (q_start, q_stop), (c_start, _c_stop) in zip(requested, cached_box, strict=True)
        )
        return cached_values[local_index]

    def _fetch(self, query, bbox, message):
        """Print ``message``, read ``query`` from the wrapped array, and cache it under ``bbox``."""
        print(message)
        fetched = self.data[query]
        self._cache = (bbox, fetched)
        return fetched
# Demo: exercise the cache on a small 4x4x4 array.
data = np.arange(4 ** 3).reshape(4,4,4)
cached_data = CachedArray(data, _cache_size=[1,4,4])
# First plane along axis 0, all of axes 1 and 2.
query = slice(0, 1, 1), slice(0, 4, 1), slice(0, 4, 1)
# The cache starts empty, so this first lookup fetches from `data` and fills it.
res = cached_data[query]
# Same query again: the cached bounding box contains it, so the cache is used.
res2 = cached_data[query]
assert np.array_equal(res, res2)
# Second plane along axis 0: lies outside the cached bounding box.
query2 = slice(1, 2, 1), slice(0, 4, 1), slice(0, 4, 1)
# Will not use the cache; the fetched region replaces the cached one.
res3 = cached_data[query2]
# Will use the cache, which now holds the region for query2.
res4 = cached_data[query2]
assert np.array_equal(res3, res4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment