MPI parallel nearest neighbor search on a sphere
from __future__ import print_function # Python 3 compatible print function
"""find nearest neighbors for points on a sphere when locations to be compared
are distributed across MPI tasks"""
# to run: mpirun -np 4 python parallel_nn_sphere.py
# requires mpi4py ('conda install mpi4py' in anaconda python)
from mpi4py import MPI
import numpy as np
import sys
# generate lat,lon values for random points on the surface of a unit sphere
def random_sphere(npts):
    # generate random points on a sphere,
    # so that every small area on the sphere is expected
    # to have the same number of points.
    # http://mathworld.wolfram.com/SpherePointPicking.html
    u = np.random.uniform(0.,1.,size=npts)
    v = np.random.uniform(0.,1.,size=npts)
    lons = 2.*np.pi*u
    lats = np.arccos(2*v-1) - np.pi/2.
    return lons, lats
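# (note: drawing latitude as arccos(2*v-1) - pi/2 rather than uniformly in
# [-pi/2,pi/2] compensates for the shrinking area of latitude bands toward the
# poles, so the expected point density is uniform per unit area of the sphere.)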
# great circle distance
def gcdist(lon1,lat1,lon2,lat2):
    # compute great circle distance in radians between (lon1,lat1) and
    # (lon2,lat2).
    # lon,lat pairs given in radians - returned distance is in radians.
    # uses Haversine formula
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = (np.sin(dlat/2))**2 + np.cos(lat1) * np.cos(lat2) * (np.sin(dlon/2))**2
    # a can fall slightly outside [0,1] due to roundoff error, resulting in dist = NaN.
    a = a.clip(0.,1.)
    return 2.0 * np.arctan2( np.sqrt(a), np.sqrt(1-a) )
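# Optional sanity check of gcdist (illustrative only, not part of the search
# itself): two points a quarter of a great circle apart along the equator
# should be separated by pi/2 radians.
assert np.allclose(gcdist(0.,0.,np.pi/2.,0.), np.pi/2.)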
# function to find nearest neighbors
# (this brute-force search could be replaced by a more efficient kd-tree search)
def find_neighbors(lon1,lat1,lon2,lat2,radius):
    # lon1,lat1 is a single point
    # lon2,lat2 are vectors of lons/lats
    # return indices of lon2,lat2 within radius.
    r = gcdist(lon1,lat1,lon2,lat2)
    return np.where(r <= radius)[0]
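# A minimal sketch of the kd-tree alternative mentioned above, assuming scipy is
# available (the function name and the use of cKDTree are choices of this sketch,
# not part of the original script).  Points are mapped to 3-d Cartesian
# coordinates on the unit sphere and queried with the chord length
# 2*sin(radius/2), which corresponds to a great-circle radius given in radians.
def find_neighbors_kdtree(lon1,lat1,lon2,lat2,radius):
    from scipy.spatial import cKDTree # imported here so scipy is only needed if used
    # unit-sphere Cartesian coordinates of the points to be searched
    xyz = np.column_stack((np.cos(lat2)*np.cos(lon2),
                           np.cos(lat2)*np.sin(lon2),
                           np.sin(lat2)))
    tree = cKDTree(xyz)
    xyz1 = (np.cos(lat1)*np.cos(lon1), np.cos(lat1)*np.sin(lon1), np.sin(lat1))
    # chord length corresponding to an angular separation of 'radius'
    return np.asarray(tree.query_ball_point(xyz1, 2.*np.sin(radius/2.)), dtype=np.int64)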
# function to perform distributed nearest neighbor search
def find_distributed_neighbors(lon,lat,oblons,oblats,radius,comm=MPI.COMM_WORLD):
    # find neighbors within radius of lon,lat in oblons,oblats
    # oblons, oblats are distributed across MPI tasks.
    # returns oblons_close_all, oblats_close_all which contain
    # all neighbors. Must be called by all MPI tasks.
    # create arrays needed for MPI
    rank = comm.rank
    nprocs = comm.Get_size()
    lonstmp = np.empty(nprocs, np.float64)
    latstmp = np.empty(nprocs, np.float64)
    recvcounts = np.empty(nprocs, np.int64)
    # share this state variable location with all other tasks
    # NOTE: Allgather only works if there are the same number of state
    # locations on each task.
    comm.Allgather(lon,lonstmp)
    comm.Allgather(lat,latstmp)
    # lonstmp now contains a vector of length nprocs whose i'th value is the
    # state variable longitude location on task i.
    # now loop over the locations in lonstmp,latstmp
    # j is the task number that this location belongs to
    for j in range(nprocs):
        # check for a missing value; if found, set the neighbor arrays empty
        # and continue the loop.
        # if there are not the same number of points assigned to each task,
        # the arrays can be padded with nans.
        if np.isnan(lonstmp[j]) or np.isnan(latstmp[j]):
            oblons_close_all_tmp = np.array([],np.float64)
            oblats_close_all_tmp = np.array([],np.float64)
            if rank == j:
                oblons_close_all = oblons_close_all_tmp
                oblats_close_all = oblats_close_all_tmp
            continue
        # find the nearest neighbor ob locations for lonstmp[j],latstmp[j]
        # on this task
        indices = find_neighbors(lonstmp[j],latstmp[j],oblons,oblats,radius)
        oblons_close = oblons[indices]
        oblats_close = oblats[indices]
        # recvcounts is the number of nearest neighbors found on each task
        # ncount is the number of nearest neighbors on this task
        ncount = np.asarray(indices.size, np.int64)
        # send recvcounts to all tasks.
        comm.Allgather(ncount,recvcounts)
        # oblons/oblats_close_all_tmp hold all the nearest neighbors found on
        # all tasks. They only need to be allocated on the task responsible
        # for this state variable.
        if rank == j:
            ncount_all = recvcounts.sum()
            oblons_close_all_tmp = np.empty(ncount_all,np.float64)
            oblats_close_all_tmp = np.empty(ncount_all,np.float64)
        else:
            oblons_close_all_tmp = None
            oblats_close_all_tmp = None
        # displs is the 'displacement index vector' for Gatherv
        displs = np.zeros(nprocs,np.int64)
        for nrank in range(1,nprocs):
            displs[nrank] = displs[nrank-1] + recvcounts[nrank-1]
        # gather all nearest neighbors on the task responsible for this state
        # variable (rank=j).
        comm.Gatherv([oblons_close,recvcounts[rank],MPI.DOUBLE],
                     [oblons_close_all_tmp,tuple(recvcounts),tuple(displs),MPI.DOUBLE],root=j)
        comm.Gatherv([oblats_close,recvcounts[rank],MPI.DOUBLE],
                     [oblats_close_all_tmp,tuple(recvcounts),tuple(displs),MPI.DOUBLE],root=j)
        # save result on rank j
        if rank == j:
            oblons_close_all = oblons_close_all_tmp
            oblats_close_all = oblats_close_all_tmp
    return oblons_close_all, oblats_close_all
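# A minimal sketch (hypothetical helper, not part of the original script) of the
# NaN padding mentioned above: if tasks hold different numbers of state
# locations, each task's lon/lat arrays could be padded with NaNs up to a common
# length before looping over them, since the Allgather calls above assume equal
# counts on every task.
def pad_with_nans(vals, n):
    # return 'vals' extended to length n, with trailing entries set to NaN
    padded = np.full(n, np.nan, np.float64)
    padded[:vals.size] = vals
    return padded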
# get MPI task info
comm = MPI.COMM_WORLD
rank = comm.rank # the process ID (integers 0-3 for a 4-process run)
nprocs = comm.Get_size() # total number of MPI tasks
# total number of state variable locations, distributed evenly over tasks
npts = 1000
if npts % nprocs:
    if rank == 0: sys.stdout.write('npts must be divisible by nprocs, exiting ...\n')
    raise SystemExit
npts_pertask = npts // nprocs
xlons, xlats = random_sphere(npts_pertask)
# add a missing value lon/lat pair on the root task
if rank == 0: xlons[-1] = np.nan; xlats[-1] = np.nan
# total number of observation locations, distributed evenly over tasks.
nobs = 1000
if nobs % nprocs:
    if rank == 0: sys.stdout.write('nobs must be divisible by nprocs, exiting ...\n')
    raise SystemExit
nobs_pertask = nobs // nprocs
oblons, oblats = random_sphere(nobs_pertask)
# Allgather to get all ob locations on all tasks (only needed to check the result)
check_result = True
if check_result:
    oblons_all = np.empty(nobs,np.float64)
    oblats_all = np.empty(nobs,np.float64)
    comm.Allgather(oblons,oblons_all)
    comm.Allgather(oblats,oblats_all)
# nearest neighbor search radius (radians)
radius = 0.25
# measure walltime in this loop
t1 = MPI.Wtime()
# loop over state variables on each task
for i in range(npts_pertask):
    # find all neighbors for xlons[i],xlats[i] on this task, considering
    # oblons,oblats across all MPI tasks.
    oblons_close_all, oblats_close_all = \
        find_distributed_neighbors(xlons[i],xlats[i],oblons,oblats,radius)
    # check result
    if check_result and oblons_close_all.size > 0: # non-empty neighbors array
        # find the correct answer by searching all ob locations on each task
        # (this is just for checking the answer; the whole point of this approach
        # is to avoid having a global array of ob locations on each task)
        indices = find_neighbors(xlons[i],xlats[i],oblons_all,oblats_all,radius)
        oblons_close_all_check = oblons_all[indices]
        oblats_close_all_check = oblats_all[indices]
        difflons = np.abs(np.sort(oblons_close_all_check)-np.sort(oblons_close_all))
        difflats = np.abs(np.sort(oblats_close_all_check)-np.sort(oblats_close_all))
        if difflons.max() > 1.e-10 or difflats.max() > 1.e-10:
            print('incorrect result on rank',rank)
# print out the mean wall clock time spent in the above loop.
# it should be nearly constant with the number of MPI tasks, so this approach
# doesn't speed up the search, but it does reduce the memory overhead
# by eliminating the need for global arrays.
t = MPI.Wtime() - t1
tmean = np.array(0.,np.float64)
comm.Reduce(np.array(t,np.float64),tmean,op=MPI.SUM,root=0)
if rank == 0: print('total time=',tmean/nprocs)