Created
June 25, 2019 21:43
-
-
Save egafni/4625c32bfddee88c701518fbf7814573 to your computer and use it in GitHub Desktop.
method to save/load xarray multi-index
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import xarray as xr | |
import pandas as pd | |
def encode_multiindices(dataset):
    """
    Replace every pandas MultiIndex in ``dataset`` with plain coordinates.

    Provides a way to encode multiindices for saving to a netcdf file.
    Adapted from https://github.com/pydata/xarray/issues/1077#issuecomment-436015893

    For a multi-indexed dimension ``<idx>`` the following coordinates are
    added and the ``<idx>`` coordinate itself is dropped:

    * ``__<idx>__<level>`` -- the unique values of each level
    * ``__<idx>`` -- the integer codes of the index, with dimensions
      ``__<idx>__names__`` (the level names) and ``__<idx>__num__``
      (``0 .. len(index) - 1``), tagged with the attribute
      ``__is_multiindex = 1`` so that ``decode_multiindices`` can find it

    :param dataset: the dataset to encode; it is not modified in place
    :return: the encoded dataset, or ``dataset`` itself when it has no
        MultiIndex
    """
    for idx_name, index in list(dataset.indexes.items()):
        if not isinstance(index, pd.MultiIndex):
            continue
        if idx_name not in dataset.dims:
            # NOTE(review): newer xarray releases appear to list every level
            # name in ``dataset.indexes`` as well, each mapped to the same
            # MultiIndex.  Only the entry keyed by the dimension is encoded.
            continue

        temp_name = '__' + idx_name
        names_dim = '{}__names__'.format(temp_name)
        num_dim = '{}__num__'.format(temp_name)

        # pandas 0.24 renamed ``MultiIndex.labels`` to ``MultiIndex.codes``
        # and later removed the old attribute.
        codes = index.codes if hasattr(index, 'codes') else index.labels

        new_coords = {'{}__{}'.format(temp_name, level_name): level_values.rename(None)
                      for level_name, level_values in zip(index.names, index.levels)}
        new_coords[temp_name] = xr.DataArray(codes,
                                             dims=(names_dim, num_dim),
                                             coords={names_dim: index.names,
                                                     num_dim: list(range(len(index)))},
                                             attrs={'__is_multiindex': 1})

        # Drop the level coordinates explicitly when they exist as real
        # coordinates, instead of relying on them being removed implicitly
        # together with the index.
        to_drop = [idx_name] + [level_name for level_name in index.names
                                if level_name in dataset.coords]
        # ``Dataset.drop`` is deprecated in favour of ``drop_vars``; fall back
        # to ``drop`` on xarray versions that predate ``drop_vars``.
        drop = getattr(dataset, 'drop_vars', None) or dataset.drop
        dataset = drop(to_drop).assign_coords(**new_coords)
    return dataset
def decode_multiindices(dataset):
    """
    Rebuild the pandas MultiIndex coordinates of an encoded dataset.

    Provides a way to load a dataset encoded by `encode_multiindices`.

    Every coordinate whose name starts with ``__`` and whose attributes
    contain a truthy ``__is_multiindex`` is treated as the code table of a
    MultiIndex.  The index is rebuilt from that table and the matching
    ``__<idx>__<level>`` coordinates and assigned as coordinate ``<idx>``;
    the helper coordinates are then dropped.

    :param dataset: a dataset as produced by ``encode_multiindices``
    :return: the decoded dataset, or ``dataset`` itself when it contains no
        encoded MultiIndex
    """
    # pandas 0.24 renamed the ``labels`` constructor keyword to ``codes`` and
    # later removed the old spelling; pick whichever this pandas understands.
    codes_keyword = 'codes' if hasattr(pd.MultiIndex, 'codes') else 'labels'

    for temp_name, da in list(dataset.coords.items()):
        if not (temp_name.startswith('__') and da.attrs.get('__is_multiindex', False)):
            continue

        name = temp_name[2:]  # strip the leading '__' marker
        # ``tolist`` yields plain Python objects for the level names rather
        # than numpy scalars.
        level_names = da.coords['{}__names__'.format(temp_name)].values.tolist()
        level_coords = ['{}__{}'.format(temp_name, level_name) for level_name in level_names]
        levels = [dataset.coords[coord_name].values for coord_name in level_coords]

        index = pd.MultiIndex(levels=levels, names=level_names,
                              **{codes_keyword: da.values})
        dataset = dataset.assign_coords(**{name: index})

        # ``Dataset.drop`` is deprecated in favour of ``drop_vars``; fall back
        # to ``drop`` on xarray versions that predate ``drop_vars``.
        drop = getattr(dataset, 'drop_vars', None) or dataset.drop
        dataset = drop(level_coords + list(da.dims) + [temp_name])
    return dataset
def to_netcdf_with_multiindex(dataset, engine='netcdf4', *args, **kwargs):
    """
    Encode the MultiIndex coordinates of ``dataset`` and write it with ``to_netcdf``.

    Note that ``engine`` is the second positional parameter, so the target
    path and any other ``to_netcdf`` options are best passed as keywords
    (e.g. ``path=...``); a second positional argument would be taken as the
    engine.

    :param dataset: the dataset to save; it is passed through
        ``encode_multiindices`` first
    :param engine: forwarded to ``to_netcdf`` as ``engine``
    :param args: extra positional arguments forwarded to ``to_netcdf``
    :param kwargs: extra keyword arguments forwarded to ``to_netcdf``
    :return: whatever ``to_netcdf`` returns (previously this was discarded)
    """
    encoded = encode_multiindices(dataset)
    return encoded.to_netcdf(*args, engine=engine, **kwargs)
def open_with_multiindex(*args, **kwargs):
    """
    Open a dataset with ``xr.open_dataset`` and decode its MultiIndex coordinates.

    :param args: positional arguments forwarded to ``xr.open_dataset``
    :param kwargs: keyword arguments forwarded to ``xr.open_dataset``
    :return: the opened dataset after passing it through ``decode_multiindices``
    """
    encoded = xr.open_dataset(*args, **kwargs)
    return decode_multiindices(encoded)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment