Created October 6, 2022 15:03
-
-
Save GMoncrieff/624fc44be9adfb06fa4b6b90938ebbe8 to your computer and use it in GitHub Desktop.
Utility functions for data extraction in multidimensional arrays with non-rectilinear grids
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import xoak | |
import xarray as xr | |
import numpy as np | |
import pandas as pd | |
def select_points(data: xr.Dataset,
                  xc: str, yc: str,
                  xdat: np.ndarray, ydat: np.ndarray,
                  name: str = "points",
                  index_type: str = 'sklearn_geo_balltree') -> xr.Dataset:
    """Select the grid cells nearest to a set of query points.

    Builds an xoak index over the (possibly 2-D, non-rectilinear)
    coordinates ``xc`` and ``yc`` of ``data``. It then performs a
    nearest-neighbour selection for every query point ``(xdat[i], ydat[i])``.

    Side effect: the index is set via ``data.xoak.set_index`` on the
    ``data`` object passed in by the caller.

    NOTE(review): xoak documents that its ``'sklearn_geo_balltree'`` adapter
    expects coordinates as (latitude, longitude) in degrees, in that order.
    With the default ``index_type``, pass the latitude coordinate as ``xc``
    and the longitude coordinate as ``yc``, with the matching values in
    ``xdat``/``ydat``. Confirm this against the installed xoak version.

    Args:
        data (xr.Dataset): xarray dataset from which to select points.
        xc (str): name of the first index coordinate.
        yc (str): name of the second index coordinate.
        xdat (np.ndarray): query values for ``xc``, one per point.
        ydat (np.ndarray): query values for ``yc``, one per point. Must have
            the same length as ``xdat``.
        name (str, optional): name of the new dimension along which the
            selected points are laid out. Defaults to "points".
        index_type (str, optional): xoak index adapter name passed to
            ``xoak.set_index``. Defaults to 'sklearn_geo_balltree'.

    Returns:
        xr.Dataset: ``data`` reduced to the selected points along ``name``.
    """
    # Build the nearest-neighbour index over the two coordinates.
    data.xoak.set_index([xc, yc], index_type)
    # xoak.sel takes query coordinates as DataArrays that share a dimension.
    # That shared dimension (``name``) becomes the points dimension of the
    # result.
    query = xr.Dataset({
        xc: (name, xdat),
        yc: (name, ydat),
    })
    return data.xoak.sel({
        xc: query[xc],
        yc: query[yc],
    })
def add_neighbour_pixels(data: xr.Dataset,
                         xc: str, yc: str,
                         xlen: int = 3, ylen: int = 3,
                         zdim: str = 'z') -> xr.Dataset:
    """Stack each pixel's neighbourhood window into a single new dimension.

    For every pixel, a centred rolling window is built: ``xlen`` pixels along
    ``xc`` and ``ylen`` pixels along ``yc``. The window offsets become the
    temporary dimensions 'x_roll' and 'y_roll'. These are then stacked,
    without creating an index, into ``zdim`` of length ``xlen * ylen``.

    ``xlen``/``ylen`` are total window sizes that include the centre pixel.
    For example, 3 means the pixel plus one neighbour on each side; it is not
    the number of pixels added per side.

    NOTE(review): window positions that fall outside the grid are filled with
    xarray's ``construct`` default fill value. This is presumably NaN, which
    would promote integer variables to float; confirm.

    Args:
        data (xr.Dataset): xarray input dataset.
        xc (str): name of the x dimension to roll over.
        yc (str): name of the y dimension to roll over.
        xlen (int, optional): window size along ``xc``. Defaults to 3.
        ylen (int, optional): window size along ``yc``. Defaults to 3.
        zdim (str, optional): name of the stacked neighbourhood dimension.
            Defaults to 'z'.

    Returns:
        xr.Dataset: dataset with the ``zdim`` neighbourhood dimension added.
    """
    window_sizes = {xc: xlen, yc: ylen}
    window_dims = {xc: 'x_roll', yc: 'y_roll'}

    rolled = data.rolling(window_sizes, min_periods=1, center=True)
    # Expose each window as two extra dimensions (one offset axis per dim).
    windowed = rolled.construct(window_dims)
    # Flatten the two window axes into one; no MultiIndex is needed downstream.
    return windowed.stack({zdim: ('x_roll', 'y_roll')}, create_index=False)
def extract_and_label(data: xr.Dataset,
                      query: pd.DataFrame) -> xr.Dataset:
    """Extract 3x3 pixel neighbourhoods at labelled points and attach labels.

    Args:
        data (xr.Dataset): dataset on an ('x', 'y') grid with 'latitude' and
            'longitude' coordinates.
        query (pd.DataFrame): one row per point, with 'longitude',
            'latitude' and 'lab' columns.

    Returns:
        xr.Dataset: for each query point, its 3x3 neighbourhood stacked along
        'z', with points along 'index', merged with the 'lab' variable from
        ``query``.
    """
    # Convert to xarray. This assumes ``query`` has an unnamed index, so
    # pandas names the resulting dimension 'index'. That matches the points
    # dimension created below, letting the labels merge by position. Verify
    # this for callers that pass named indexes.
    query = query.to_xarray()
    lat_q = query.latitude.values
    lon_q = query.longitude.values
    # Expand dims by adding the neighbouring pixels (3x3 window -> 'z').
    data = add_neighbour_pixels(data, 'x', 'y', 3, 3, 'z')
    # Pair each coordinate name with its own values: latitude values go with
    # 'latitude' and longitude values with 'longitude'. Latitude comes first,
    # matching the index coordinate order used by select_points.
    data = select_points(data, 'latitude', 'longitude', lat_q, lon_q, 'index')
    # Attach the per-point labels.
    return data.merge(query['lab'])
def test():
    """Smoke-test extract_and_label on a small synthetic curvilinear grid."""
    # Three labelled query points (degrees).
    df = pd.DataFrame({
        'longitude': [-122.666, -122.669, -122.721],
        'latitude': [21.1519, 21.258, 21.139],
        'lab': [1, 1, 7],
    })
    # 5x5 curvilinear grid: 2-D latitude/longitude coordinates on (x, y).
    lats = np.array([[21.138, 21.14499, 21.15197, 21.15894, 21.16591],
                     [21.16287, 21.16986, 21.17684, 21.18382, 21.19079],
                     [21.18775, 21.19474, 21.20172, 21.2087, 21.21568],
                     [21.21262, 21.21962, 21.22661, 21.23359, 21.24056],
                     [21.2375, 21.2445, 21.25149, 21.25848, 21.26545]])
    lons = np.array([[-122.72, -122.69333, -122.66666, -122.63999, -122.61331],
                     [-122.7275, -122.70082, -122.67415, -122.64746, -122.62078],
                     [-122.735, -122.70832, -122.68163, -122.65494, -122.62825],
                     [-122.7425, -122.71582, -122.68912, -122.66243, -122.63573],
                     [-122.75001, -122.72332, -122.69662, -122.66992, -122.64321]])
    band = np.array([1, 2])
    speed = np.array([[[1, 2, 3, 4, 5],
                       [6, 7, 8, 9, 10],
                       [11, 12, 13, 14, 15],
                       [16, 17, 18, 19, 20],
                       [21, 22, 23, 24, 25]],
                      [[100, 200, 300, 400, 500],
                       [600, 700, 800, 900, 1000],
                       [1100, 1200, 1300, 1400, 1500],
                       [1600, 1700, 1800, 1900, 2000],
                       [2100, 2200, 2300, 2400, 2500]]])
    ds = xr.Dataset({'SPEED': (('band', 'x', 'y'), speed)},
                    coords={'latitude': (('x', 'y'), lats),
                            'longitude': (('x', 'y'), lons),
                            'band': ('band', band)},
                    attrs={'variable': 'Wind Speed'})
    ds = extract_and_label(data=ds, query=df)
    # Compare sizes as an unordered mapping. Dimension order is an xarray
    # implementation detail, and using Dataset.dims as a mapping is
    # deprecated in newer xarray in favour of Dataset.sizes.
    assert set(ds.sizes) == {'index', 'band', 'z'}, "dim names do not match expectation"
    assert dict(ds.sizes) == {'index': 3, 'band': 2, 'z': 9}, "dim lengths do not match expectation"
    assert ds['lab'].values.tolist() == [1, 1, 7], "labels do not match expectation"
    # The centre of each 3x3 window (z=4) must be the pixel nearest to the
    # query point. Those pixels are (x=0, y=2), (x=4, y=3) and (x=0, y=0),
    # giving band-0 values 3, 24 and 1.
    assert ds['SPEED'].isel(band=0, z=4).values.tolist() == [3, 24, 1], "centre pixels do not match expectation"
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment