Skip to content

Instantly share code, notes, and snippets.

@egafni
Created June 25, 2019 21:43
Show Gist options
  • Save egafni/4625c32bfddee88c701518fbf7814573 to your computer and use it in GitHub Desktop.
Save egafni/4625c32bfddee88c701518fbf7814573 to your computer and use it in GitHub Desktop.
method to save/load xarray multi-index
import xarray as xr
import pandas as pd
def encode_multiindices(dataset):
    """
    Replace every pandas MultiIndex of `dataset` with plain coordinates so
    that the dataset can be saved to a netcdf file.
    Adapted from https://github.com/pydata/xarray/issues/1077#issuecomment-436015893

    For a MultiIndex coordinate named ``idx`` the returned dataset holds:

    * ``__idx__<level_name>``: the unique values of each level;
    * ``__idx``: the integer codes of the MultiIndex, with dimensions
      ``__idx__names__`` (level names) and ``__idx__num__`` (position in the
      index), tagged with the attribute ``__is_multiindex = 1``.

    The original ``idx`` coordinate is dropped.  Use `decode_multiindices`
    to reverse the transformation.

    :param dataset: the xarray dataset to encode (it is not modified in place)
    :return: a dataset without MultiIndex coordinates
    """
    # Snapshot the items: `dataset` is rebound inside the loop.
    for idx_name, index in list(dataset.indexes.items()):
        if not isinstance(index, pd.MultiIndex):
            continue
        # NOTE(review): newer xarray releases appear to list the level names
        # in `indexes` as well (each mapped to the same MultiIndex); only the
        # entry keyed by the dimension itself should be encoded -- confirm.
        if idx_name not in dataset.dims:
            continue

        temp_name = '__' + idx_name
        names_dim = '{}__names__'.format(temp_name)
        num_dim = '{}__num__'.format(temp_name)

        # `MultiIndex.labels` was renamed to `codes` in pandas 0.24 and later
        # removed; fall back to `labels` only on old pandas.
        codes = index.codes if hasattr(index, 'codes') else index.labels

        new_coords = {'{}__{}'.format(temp_name, level_name): level_values.rename(None)
                      for level_name, level_values in zip(index.names, index.levels)}
        new_coords[temp_name] = xr.DataArray(
            codes,
            dims=(names_dim, num_dim),
            coords={names_dim: index.names,
                    num_dim: list(range(len(index)))},
            attrs={'__is_multiindex': 1})
        dataset = dataset.drop(idx_name).assign_coords(**new_coords)
    return dataset
def decode_multiindices(dataset):
    """
    Provides a way to load a dataset encoded by `encode_multiindices`.

    Every coordinate whose name starts with ``__`` and that carries a truthy
    ``__is_multiindex`` attribute is turned back into a pandas MultiIndex
    coordinate (named without the ``__`` prefix), and the helper coordinates
    that stored its levels, level names and codes are dropped.

    :param dataset: an xarray dataset, typically produced by
        `encode_multiindices` and read back from disk
    :return: a dataset with the MultiIndex coordinates restored
    """
    # The constructor keyword `labels` was renamed to `codes` in pandas 0.24
    # and later removed; pick whichever the installed pandas understands.
    codes_kwarg = 'codes' if hasattr(pd.MultiIndex, 'codes') else 'labels'

    # Snapshot the items: `dataset` is rebound inside the loop.
    for temp_name, da in list(dataset.coords.items()):
        if not (temp_name.startswith('__') and da.attrs.get('__is_multiindex', False)):
            continue

        name = temp_name[2:]  # strip the '__' prefix added by the encoder
        level_names = da.coords['{}__names__'.format(temp_name)].values
        level_coord_names = ['{}__{}'.format(temp_name, level_name)
                             for level_name in level_names]
        levels = [dataset.coords[coord_name].values
                  for coord_name in level_coord_names]

        index_kwargs = {'levels': levels, 'names': level_names, codes_kwarg: da.values}
        dataset = dataset.assign_coords(**{name: pd.MultiIndex(**index_kwargs)})
        # Remove the per-level coordinates, the two helper dimension
        # coordinates of the codes array, and the codes array itself.
        dataset = dataset.drop(level_coord_names + list(da.dims) + [temp_name])
    return dataset
def to_netcdf_with_multiindex(dataset, engine='netcdf4', *args, **kwargs):
    """
    Encode the MultiIndex coordinates of `dataset` with `encode_multiindices`
    and write the result with `Dataset.to_netcdf`.

    NOTE: because `engine` precedes `*args`, the first positional argument
    after `dataset` is taken as the engine, not the path.  Pass the target
    as a keyword (e.g. ``path=...``) or give the engine explicitly first.

    :param dataset: the xarray dataset to save
    :param engine: netcdf engine forwarded to `to_netcdf`
    :param args: extra positional arguments forwarded to `to_netcdf`
    :param kwargs: extra keyword arguments forwarded to `to_netcdf`
    :return: whatever `to_netcdf` returns (previously this was discarded)
    """
    encoded = encode_multiindices(dataset)
    return encoded.to_netcdf(*args, engine=engine, **kwargs)
def open_with_multiindex(*args, **kwargs):
    """
    Open a dataset with `xr.open_dataset` and pass it through
    `decode_multiindices`.

    All positional and keyword arguments are forwarded unchanged to
    `xr.open_dataset`.

    :return: the dataset returned by `decode_multiindices`
    """
    encoded = xr.open_dataset(*args, **kwargs)
    decoded = decode_multiindices(encoded)
    return decoded
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment