Last active
April 27, 2025 18:44
-
-
Save paraseba/8410f92b39da7b4b33505179f83161f0 to your computer and use it in GitHub Desktop.
Code for the blog post "Icechunk: Efficient storage of versioned array data"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
import numpy as np | |
import icechunk | |
import icechunk.xarray | |
import xarray | |
import zarr | |
# Directory in which the demo repository is created and measured.
ROOT = "/dev/shm/versioned-storage"

# Chunk shape of the main array along the x, y and t dimensions.
X_CHUNK_SIZE = 200
Y_CHUNK_SIZE = 200
T_CHUNK_SIZE = 1

# Initial array extent: 10 chunks along x, 10 along y and 5 along t.
X_SIZE = 10 * X_CHUNK_SIZE
Y_SIZE = 10 * Y_CHUNK_SIZE
T_SIZE = 5 * T_CHUNK_SIZE
def print_stats(path: str) -> None:
    """Print the number of chunk files and the total on-disk size under *path*.

    Every regular file found recursively below *path* contributes to the
    total size. A file is counted as a chunk when one component of its path
    *relative to* ``path`` is exactly ``"chunks"``.

    Args:
        path: Root directory of the repository to inspect.
    """
    root = Path(path)
    chunks = 0
    total_bytes = 0
    for entry in root.rglob("*"):
        if not entry.is_file():
            continue
        # Compare against the path relative to the root: a "chunks"
        # component in the repository's own location must not turn every
        # file into a chunk.
        if "chunks" in entry.relative_to(root).parts:
            chunks += 1
        total_bytes += entry.stat().st_size
    size = total_bytes / (1024.0 * 1024)
    print(f"Number of chunks: {chunks} - Total size: {size:.2f} MB")
def create_repo(path: str) -> icechunk.Repository:
    """Create a repository at *path* holding one random 3-D array.

    The dataset has a single variable ``array`` with dimensions
    ``("x", "y", "t")`` of shape ``(X_SIZE, Y_SIZE, T_SIZE)`` plus integer
    coordinates for each dimension. It is written to the ``main`` branch
    and committed as "Array created".

    Args:
        path: Directory for the local-filesystem storage backend.

    Returns:
        The newly created repository.
    """
    # NOTE(review): a threshold of 0 presumably disables chunk inlining so
    # that every chunk becomes its own object -- confirm against icechunk docs.
    repo = icechunk.Repository.create(
        storage=icechunk.local_filesystem_storage(path),
        config=icechunk.RepositoryConfig(inline_chunk_threshold_bytes=0),
    )

    values = np.random.rand(X_SIZE, Y_SIZE, T_SIZE)
    dataset = xarray.Dataset(
        data_vars={"array": (("x", "y", "t"), values)},
        coords={
            "x": ("x", np.arange(X_SIZE)),
            "y": ("y", np.arange(Y_SIZE)),
            "t": ("t", np.arange(T_SIZE)),
        },
    )

    # Per-variable chunk shapes handed to the writer.
    chunking = {
        "array": {"chunks": (X_CHUNK_SIZE, Y_CHUNK_SIZE, T_CHUNK_SIZE)},
        "x": {"chunks": (X_SIZE,)},
        "y": {"chunks": (Y_SIZE,)},
        "t": {"chunks": (1_000,)},
    }

    session = repo.writable_session("main")
    icechunk.xarray.to_icechunk(dataset, session, encoding=chunking)
    session.commit("Array created")
    return repo
def append(repo: icechunk.Repository) -> None:
    """Append one random ``(X_SIZE, Y_SIZE, 1)`` slab along ``t`` and commit.

    Args:
        repo: Repository whose ``main`` branch receives the new data.
    """
    session = repo.writable_session("main")
    slab = np.random.rand(X_SIZE, Y_SIZE, 1)
    extra_step = xarray.Dataset(
        data_vars={"array": (("x", "y", "t"), slab)},
    )
    icechunk.xarray.to_icechunk(extra_step, session, append_dim="t")
    session.commit("Array extended")
def update_slice(repo: icechunk.Repository) -> None:
    """Overwrite the middle ``t`` slice of ``array`` with new values and commit.

    The replacement is random data scaled by 42, written to the region
    ``t = T_SIZE // 2`` on the ``main`` branch.

    Args:
        repo: Repository to update.
    """
    session = repo.writable_session("main")
    replacement = xarray.Dataset(
        data_vars={
            "array": (("x", "y", "t"), np.random.rand(X_SIZE, Y_SIZE, 1) * 42),
        },
    )
    middle = T_SIZE // 2
    target = {"t": slice(middle, middle + 1)}
    icechunk.xarray.to_icechunk(replacement, session, region=target)
    session.commit("Array updated")
def update_metadata(repo: icechunk.Repository) -> None:
    """Set the attribute ``foo = "bar"`` on ``array``, commit, and verify it.

    Args:
        repo: Repository whose ``main`` branch is modified.

    Raises:
        AssertionError: If the committed attribute cannot be read back.
    """
    session = repo.writable_session("main")
    # "r+" (read/write, must already exist): the attributes are modified
    # below, so the group must not be opened read-only.
    group = zarr.open_group(store=session.store, mode="r+")
    array = group["array"]
    array.attrs["foo"] = "bar"
    session.commit("Array metadata updated")

    # Re-open through a fresh read-only session to check what was committed.
    session = repo.readonly_session("main")
    group = zarr.open_group(store=session.store, mode="r")
    array = group["array"]
    assert array.attrs["foo"] == "bar"
def compare_versions(repo: icechunk.Repository) -> None:
    """Print the max of the updated ``t`` slice in an older and the current version.

    The slice inspected is ``t = T_SIZE // 2``, the same index that
    ``update_slice`` overwrites.

    Args:
        repo: Repository to read from.
    """
    # Derived from T_SIZE rather than hard-coded so it keeps tracking the
    # slice written by update_slice() if the array extent changes.
    t_index = T_SIZE // 2

    def slice_max(session):
        """Return the maximum of ``array`` at ``t_index`` for *session*."""
        ds = xarray.open_zarr(session.store, consolidated=False)
        return ds.isel(t=t_index).max()["array"].values

    new = slice_max(repo.readonly_session(branch="main"))

    # NOTE(review): ancestry appears to yield snapshots newest-first, so [-2]
    # is the second-oldest snapshot. Right after update_slice() that is the
    # direct parent of the branch tip -- confirm if called at another point.
    parent = list(repo.ancestry(branch="main"))[-2]
    old = slice_max(repo.readonly_session(snapshot_id=parent.id))

    print(f"Old max: {old} - New max: {new}")
def main() -> None:
    """Run the demo: create, update, annotate, append, then expire and collect.

    After every step the on-disk statistics of ``ROOT`` are printed.
    """
    print("Writing 5 pancakes")
    repo = create_repo(ROOT)
    print_stats(ROOT)

    # (banner, mutation, optional follow-up run after the stats are printed)
    steps = (
        ("Updating 1 pancake", update_slice, compare_versions),
        ("Updating metadata", update_metadata, None),
        ("Appending 1 pancake", append, None),
    )
    for banner, mutate, follow_up in steps:
        print(banner)
        mutate(repo)
        print_stats(ROOT)
        if follow_up is not None:
            follow_up(repo)

    # Use the write time of the first snapshot yielded for "main" as the
    # cutoff for both snapshot expiration and garbage collection.
    cutoff = next(repo.ancestry(branch="main")).written_at
    repo.expire_snapshots(older_than=cutoff)
    gc_summary = repo.garbage_collect(delete_object_older_than=cutoff)
    print(gc_summary)
    print_stats(ROOT)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment