# polygon resampling
# Gist sfinkens/dd55c4a8792292c847cdfe8a74f862e1, created January 31, 2020 16:26.
from pyresample import create_area_def | |
import shapely.geometry | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import geopandas | |
import pandas as pd | |
def show_area_defs(source_area_def, target_area_def):
    """Draw each of the two area definitions on its own cartopy map.

    For every area a new figure is created whose axes use the projection
    returned by the area's ``to_cartopy_crs()``. Coastlines are added and an
    all-zero image with the area's shape is drawn over the projection bounds.
    """
    # NOTE(review): the indentation of this file was lost; ``plt.show()`` is
    # assumed to run once after both figures exist -- confirm.
    areas = [source_area_def, target_area_def]
    for current in areas:
        projection = current.to_cartopy_crs()
        # Only the axes are needed; the figure handle is discarded.
        axes = plt.subplots(subplot_kw={'projection': projection})[1]
        axes.coastlines()
        blank = np.zeros(current.shape)
        axes.imshow(blank, transform=projection, extent=projection.bounds)
    plt.show()
def get_boxes(area_def):
    """Return one rectangular shapely polygon per grid cell of *area_def*.

    Every box is centred on a coordinate pair from
    ``area_def.get_proj_coords()`` and extends half of
    ``area_def.resolution`` to each side in x and y. The returned list follows
    the order of the raveled coordinate arrays.
    """
    xs, ys = area_def.get_proj_coords()
    cell_width, cell_height = area_def.resolution
    half_w = cell_width / 2
    half_h = cell_height / 2

    boxes = []
    for cx, cy in zip(xs.ravel(), ys.ravel()):
        cell = shapely.geometry.box(cx - half_w, cy - half_h,
                                    cx + half_w, cy + half_h)
        boxes.append(cell)
    return boxes
def resample_polygon(data, source_area, target_area):
    """Resample gridded data to a target grid by intersecting cell polygons.

    Every source and target grid cell is represented by a rectangle (see
    ``get_boxes``). Source rectangles are reprojected to the target
    projection and intersected with the target rectangles. The value of a
    target cell is the unweighted mean of all source cells intersecting it.

    Args:
        data: Array with one value per source grid cell (``data.ravel()``
            must be ordered like the cells of ``source_area``). NaN marks
            missing values.
        source_area: Area definition of the grid ``data`` lives on.
        target_area: Area definition of the grid to resample to.

    Returns:
        Tuple ``(res, nobs, intersect, intersect_mean)``:

        * ``res`` -- float array of shape ``target_area.shape`` holding the
          mean source value per target cell; NaN where nothing intersects.
        * ``nobs`` -- int array of shape ``target_area.shape`` holding the
          number of non-NaN source cells intersecting each target cell.
        * ``intersect`` -- GeoDataFrame of all source/target intersections.
        * ``intersect_mean`` -- ``intersect`` dissolved by target cell with
          mean aggregation.

    Raises:
        ValueError: If the number of elements in ``data`` differs from the
            number of cells in ``source_area``.
    """
    # A size mismatch would otherwise pair values with the wrong polygons
    # (or drop some) when the frames below are assembled.
    if np.size(data) != source_area.size:
        raise ValueError(
            'data has {} elements but source_area has {} grid cells'.format(
                np.size(data), source_area.size))

    # Compute polygons representing source grid cells and combine them together with the data
    # in a data frame
    source_poly = geopandas.GeoSeries(get_boxes(source_area), crs=source_area.proj_dict)
    valid_data = ~np.isnan(data)
    source_data = pd.DataFrame({'source_data': data.ravel(),
                                'source_count': valid_data.ravel().astype(int)})
    df_source = geopandas.GeoDataFrame(geometry=source_poly, data=source_data)

    # Compute polygons representing target grid cells and combine them together with the grid index
    # in a data frame
    target_idx = pd.DataFrame({'target_idx': np.arange(target_area.size, dtype=int)})
    target_poly = geopandas.GeoSeries(get_boxes(target_area), crs=target_area.proj_dict)
    df_target = geopandas.GeoDataFrame(geometry=target_poly, data=target_idx)

    # Transform source polygons to target projection; remove polygons invalidated by the
    # transformation
    df_source_trans = df_source.to_crs(target_area.proj_dict)
    area = df_source_trans.area
    df_source_trans = df_source_trans[~(area.isna() | np.isinf(area))]

    # Compute mean of all source polygons intersecting a target gridcell. The value associated with
    # a source polygon can be used multiple times if that polygon intersects with multiple target
    # grid cells.
    # TODO: Use area of intersection as weights for a weighted mean
    intersect = geopandas.overlay(df_source_trans, df_target, how='intersection')
    intersect_mean = intersect.dissolve(by='target_idx', aggfunc='mean')

    # Keep track of number of intersections per grid cell. Only the summed
    # column is needed, so a plain groupby avoids unioning the geometries a
    # second time.
    nobs_per_cell = intersect.groupby('target_idx')['source_count'].sum()

    # Broadcast results to target grid
    res = np.full(target_area.size, np.nan, dtype=float)
    res[intersect_mean.index.to_numpy()] = intersect_mean['source_data'].to_numpy()
    res = res.reshape(target_area.shape)

    nobs = np.zeros(target_area.size, dtype=int)
    nobs[nobs_per_cell.index.to_numpy()] = nobs_per_cell.to_numpy()
    nobs = nobs.reshape(target_area.shape)

    return res, nobs, intersect, intersect_mean
if __name__ == '__main__':
    # Demo: resample a synthetic field from a geostationary grid to a global
    # lat/lon grid and plot the results.
    semi_major_axis = 6378169.0
    sat_height = 35785831.0
    n_pixels = 36

    geos_projection = {'proj': 'geos', 'lon_0': 0, 'a': semi_major_axis,
                       'b': 6356583.8, 'h': sat_height}
    geos_extent = [-5570248.686685662, -5567248.28340708,
                   5567248.28340708, 5570248.686685662]
    source_area = create_area_def(area_id='geos',
                                  projection=geos_projection,
                                  width=n_pixels, height=n_pixels,
                                  area_extent=geos_extent)
    target_area = create_area_def(area_id='regular',
                                  projection={'proj': 'latlon'},
                                  width=36*2, height=18*2,
                                  area_extent=(-180.0, -90.0, 180.0, 90.0))

    # Running index as data, with a square patch of missing values.
    data = np.arange(source_area.size, dtype=float).reshape(source_area.shape)
    data[10:15, 10:15] = np.nan

    resampled, nobs, intersect, intersect_mean = resample_polygon(
        data, source_area=source_area, target_area=target_area)

    # Arrays on the target grid
    for image, title in ((nobs, '# of observations'),
                         (resampled, 'Result on target grid')):
        plt.figure()
        plt.imshow(image)
        plt.colorbar()
        plt.title(title)

    # Polygon views, coloured by source value
    for frame, title in ((intersect, 'Intersection'),
                         (intersect_mean, 'Mean of intersections')):
        frame.plot(column='source_data', legend=True)
        plt.title(title)

    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment