"""non-cycled 1d test of LETKF with multiscale localization""" | |
import numpy as np
from scipy.linalg import eigh, lapack, solve
from scipy.fft import rfft, irfft, rfftfreq
from argparse import ArgumentParser
# set random seed for reproducibility
#np.random.seed(42)
def syminv(C):
    # inverse of a square symmetric positive definite matrix
    # using eigenanalysis
    #evals, eigs = eigh(C)
    #evals = evals.clip(min=np.finfo(evals.dtype).eps)
    #C_inv = (eigs * (1.0 / evals)).dot(eigs.T)
    # using Cholesky decomp
    zz, info = lapack.dpotrf(C)
    C_inv, info = lapack.dpotri(zz)
    # lapack only returns the upper or lower triangular part
    C_inv = np.triu(C_inv) + np.triu(C_inv, k=1).T
    return C_inv
    # using linear solver
    #return solve(C, np.identity(C.shape[0]), sym_pos=True)
def gausscov(r,l,w):
    return w*np.exp(-(r/l)**2)
def expcov(r,l):
    return np.exp(-2.*r/l)
def getdist(i,j):
    """find distances between point i and other points j in 1d periodic domain"""
    ndim = len(j)
    return np.abs(np.remainder(i-j + ndim/2.,ndim)-ndim/2.)
def gasp_cohn(r):
    """Gaspari-Cohn localization function (goes to zero at r=1)"""
    eps = np.finfo(r.dtype).eps
    r = (np.abs(2*r)).clip(min=eps)
    loc = np.zeros(r.shape, r.dtype)
    loc = np.where(r<=1, -0.25*r**5+0.5*r**4+0.625*r**3-5./3.*r**2+1, loc)
    loc = np.where(np.logical_and(r > 1.,r <= 2.),
                   1./12.*r**5-0.5*r**4+0.625*r**3+5./3.*r**2-5.*r+4.-2./3./r, loc)
    return loc
# CL args.
parser = ArgumentParser(description='test multiscale LETKF localization')
parser.add_argument('--lscales', type=float, nargs='+',required=True, help='localization scales in grid points')
parser.add_argument('--band_cutoffs', type=float, nargs='+',required=True, help='wavenumber cutoffs separating the scale bands (len(lscales)-1 values)')
parser.add_argument('--cov_param', type=float, default=70, help='true covariance parameter')
parser.add_argument('--verbose', action='store_true', help='verbose output')
parser.add_argument('--l1norm', action='store_true', help='L1 instead of L2 error norm')
parser.add_argument('--nsamples', type=int, default=8, help='ensemble members')
parser.add_argument('--ntrials', type=int, default=100, help='number of trials')
parser.add_argument('--ndim', type=int, default=500, help='number of grid points')
parser.add_argument('--random_seed', type=int, default=0, help='random seed (default is to not set)')
parser.add_argument('--oberrvar', type=float, default=1.0, help='observation error variance')
args = parser.parse_args()
# update local namespace with CL args and values
locals().update(args.__dict__)
if verbose:
    print(args)
if random_seed > 0: # set random seed for reproducibility
    np.random.seed(random_seed)
# specify covariance and localization matrices
nlscales = len(lscales)
nband_cutoffs = len(band_cutoffs)
if nlscales > 1 and nband_cutoffs != nlscales-1:
    print('number of lscales must equal len(band_cutoffs)+1')
    raise SystemExit
local = np.zeros((nlscales,ndim,ndim),np.float64)
cov = np.zeros((ndim,ndim),np.float64) # B variance = 1
# fourier wavenumbers
wavenums = ndim*rfftfreq(ndim)[0 : (ndim // 2) + 1]
# define true cov as sum of gaussians with gaussian weighting
clscales = np.arange(1,ndim//2)
wts = np.exp(-(clscales/cov_param)**2)
wts = wts/wts.sum()
for wt,clscale in zip(wts,clscales):
    for i in range(ndim):
        dist = getdist(i,np.arange(ndim))
        cov[:,i] += gausscov(dist,clscale,wt)
# use exponential approximation
#for i in range(ndim):
#    dist = getdist(i,np.arange(ndim))
#    cov[:,i] = expcov(dist,cov_param)
#if verbose:
#    import matplotlib.pyplot as plt
#    plt.plot(np.arange(ndim), cov[ndim//2],color='k')
#    plt.title('cov')
#    plt.show()
#    raise SystemExit
for n,lscale in enumerate(lscales):
    for i in range(ndim):
        dist = getdist(i,np.arange(ndim))
        local[n,:,i] = gasp_cohn(dist/lscale) # gaspari-cohn polynomial (compact support)
local = local.clip(min=np.finfo(local.dtype).eps)
# eigenanalysis of true cov, compute optimal gain matrix
evals, evecs = eigh(cov)
evals = evals.clip(min=np.finfo(evals.dtype).eps)
scaled_evecs = np.dot(evecs, np.diag(np.sqrt(evals)))/np.sqrt(nsamples-1)
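# scaled_evecs is a matrix square root of the true covariance, pre-scaled by
# 1/sqrt(nsamples-1) so that perturbations drawn through it give a sample
# covariance np.dot(y, y.T) that estimates cov directly.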
kfopt = np.dot(cov, syminv(cov + oberrvar*np.eye(ndim)))
paopt = np.dot((np.eye(ndim) - kfopt), cov)
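# With H = I, kfopt = B*(B + R)^-1 is the optimal Kalman gain and
# paopt = (I - K)*B the corresponding analysis-error covariance; the
# ensemble-estimated gains below are scored against kfopt.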
#import matplotlib.pyplot as plt
#plt.plot(np.arange(ndim), kfopt[ndim//2],color='b')
#plt.plot(np.arange(ndim), pa[ndim//2],color='r')
#plt.show()
#raise SystemExit
if verbose:
    print('tr(paopt)/tr(pb) = ',np.trace(paopt)/np.trace(cov))
#l1norm = False # uncomment to force the L2 norm regardless of the --l1norm flag
if l1norm:
    kfopt_frobnorm = np.abs(kfopt).sum()
else:
    kfopt_frobnorm = (kfopt**2).sum()
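# kfopt_frobnorm is the normalization for the gain errors: the squared Frobenius
# norm of the optimal gain (or its L1 norm when --l1norm is set).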
# create square root of localization matrices for each local volume
dist = np.zeros((ndim,ndim),np.float64)
indx = np.zeros((ndim,ndim),bool) # based on longest length scale
for i in range(ndim):
    dist[i] = getdist(i,np.arange(ndim))
    indx[i] = dist[i] < np.abs(lscales[0])
sqrtlocalloc_lst=[]; neig_lst=[]
for n,lscale in enumerate(lscales):
    nlocal = np.zeros(ndim,int)
    for i in range(ndim):
        localloc = local[n][np.ix_(indx[i],indx[i])]
        nlocal = localloc.shape[0]
        # symmetric square root of localization (truncated eigenvector expansion)
        evalsl, evecsl = eigh(localloc)
        for ne in range(1,nlocal):
            percentvar = evalsl[-ne:].sum()/evalsl.sum()
            if percentvar > 0.99:
                neig = ne
                break
        evecs_norml = (evecsl*np.sqrt(evalsl/percentvar)).T
        if not i:
            neig_lst.append(neig)
            sqrtlocalloc = np.zeros((ndim,neig,nlocal),np.float64)
        sqrtlocalloc[i,...] = evecs_norml[nlocal-neig:nlocal,:]
    sqrtlocalloc_lst.append(sqrtlocalloc)
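# total size of the modulated ensemble: neig retained eigenvectors per scale band
# times nsamples members in each band.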
nsamples_tot=0
for n in range(nlscales):
    nsamples_tot += neig_lst[n]*nsamples
# run trials with different ensembles
meankferr_rloc = 0; meankferr_bloc = 0; meankferr_blocg = 0
kfmean_rloc = np.zeros((ndim,ndim),np.float64)
kfmean_bloc = np.zeros((ndim,ndim),np.float64)
kfmean_blocg = np.zeros((ndim,ndim),np.float64)
bandvar_mean = np.zeros(nlscales, np.float64)
totvar1=0; totvar2=0; totvar3=0
for ntrial in range(ntrials):
    # generate ensemble (x is an array of unit normal random numbers)
    x = np.random.normal(size=(ndim,nsamples))
    x = x - x.mean(axis=1)[:,np.newaxis] # zero mean
    # full ensemble
    y = np.dot(scaled_evecs,x)
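    # columns of y are pre-scaled perturbations of the true covariance, so
    # np.dot(y, y.T) is the (unlocalized) raw ensemble covariance estimate.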
    # spectral bandpass filtering (boxcar window).
    if nlscales == 1:
        yyl=[y]
    else:
        yyl=[]
        yfilt_save = np.zeros_like(y)
        yspec = rfft(y,axis=0)
        for n,sigma in enumerate(band_cutoffs):
            yfiltspec = np.where(wavenums[:,np.newaxis] < sigma, yspec, 0.+0.j)
            yfilt = irfft(yfiltspec,n=ndim,axis=0) # n=ndim keeps the grid length (also for odd ndim)
            yyl.append(yfilt-yfilt_save)
            yfilt_save=yfilt
        ysum = np.zeros_like(y)
        for n in range(nband_cutoffs):
            ysum += yyl[n]
        yyl.append(y-ysum)
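    # the bands partition y exactly: each entry of yyl holds the wavenumbers between
    # successive cutoffs and the final entry is the residual above the last cutoff,
    # so summing the bands recovers the full ensemble.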
    yy = np.asarray(yyl)
    bandvar = np.zeros(nlscales, np.float64)
    for n in range(nlscales):
        bandvar[n] = ((yy[n]**2).sum(axis=-1)/(nsamples-1)).mean()
    bandvar_mean += bandvar/nsamples
    yyall = yy.sum(axis=0)
    totvar1 += ((y**2).sum(axis=-1)/(nsamples-1)).mean()/nsamples
    totvar2 += ((yyall**2).sum(axis=-1)/(nsamples-1)).mean()/nsamples
    totvar3 += bandvar.sum()/nsamples
    #diff = y-yyall
    #print(diff.min(), diff.max(),totvar1,totvar2,bandvar.sum())
    #continue
    kfens_rloc = np.zeros((ndim,ndim),np.float64)
    kfens_bloc = np.zeros((ndim,ndim),np.float64)
    kfens_blocg = np.zeros((ndim,ndim),np.float64)
    # global solve
    cov_local = np.zeros((nlscales,ndim,ndim),np.float64)
    # no cross covariance (by construction)
    for n in range(nlscales):
        cov_local[n] = local[n]*np.dot(yy[n], yy[n].T)
    hpbhtinv = syminv(cov_local.sum(axis=0) + oberrvar*np.eye(ndim))
    for n in range(nlscales):
        kfens_blocg += np.dot(cov_local[n], hpbhtinv)
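    # kfens_blocg is the globally B-localized gain: each band covariance is Schur
    # (elementwise) localized, the bands are summed in (HPbHt + R), and cross-band
    # covariances are neglected by construction.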
    # local solve
    for i in range(ndim):
        # find local grid points (obs since H=I)
        if not ntrial:
            # use largest (first) localization scale to define local volume
            indx[i] = dist[i] < np.abs(lscales[0])
        ylocal_full = y[np.ix_(indx[i],np.ones(nsamples,bool))].T
        ylocal = np.zeros((nlscales,)+ylocal_full.shape,ylocal_full.dtype)
        for n in range(nlscales):
            ylocal[n] = yy[n][np.ix_(indx[i],np.ones(nsamples,bool))].T
        # R localization
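        # The solve below is the ETKF gain written in ensemble space, with R localization
        # applied by dividing the obs-error variance by the localization taper:
        #   K_i = yb_i^T (I + Yb Rloc^-1 Yb^T)^-1 Yb Rloc^-1,
        # where Yb stacks the local observation-space perturbations from all bands.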
        Yb_sqrtRinv_lst=[]; Yb_Rinv_lst=[]; ylocal_lst=[]
        for n in range(nlscales):
            taper = local[n,indx[i],i]
            Yb_sqrtRinv_lst.append(np.sqrt(taper/oberrvar)*ylocal[n])
            Yb_Rinv_lst.append((taper/oberrvar)*ylocal[n])
            ylocal_lst.append(yy[n,i,:])
        Yb_sqrtRinv = np.vstack(Yb_sqrtRinv_lst)
        Yb_Rinv = np.vstack(Yb_Rinv_lst)
        ytmp = np.concatenate(ylocal_lst)
        pa = np.eye(nsamples*nlscales) + np.dot(Yb_sqrtRinv, Yb_sqrtRinv.T)
        painv = syminv(pa); painv_YbRinv = np.dot(painv, Yb_Rinv)
        kfens_rloc[indx[i],i] = np.dot(ytmp, painv_YbRinv)
        # B localization: modulate the ensemble with eigenvectors of the 'local' localization matrix.
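        # Each member is multiplied elementwise by each retained (scaled) eigenvector of
        # the local localization matrix, giving neig*nsamples modulated members per band
        # whose sample covariance approximates the Schur-localized Pb.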
        Yb_sqrtRinv_lst=[]; ylocal_lst=[]
        for n,lscale in enumerate(lscales):
            neig = neig_lst[n]
            sqrtlocalloc = sqrtlocalloc_lst[n]
            nlocal = sqrtlocalloc.shape[-1]
            nsamples2 = neig*nsamples; nsamp2 = 0
            ylocal2 = np.zeros((nsamples2,nlocal),ylocal.dtype)
            ylocal = yy[n][np.ix_(indx[i],np.ones(nsamples,bool))].T
            for j in range(neig):
                for nsamp in range(nsamples):
                    ylocal2[nsamp2,:] = ylocal[nsamp,:]*sqrtlocalloc[i,neig-j-1,:]
                    nsamp2 += 1
            Yb_sqrtRinv = ylocal2/np.sqrt(oberrvar)
            Yb_sqrtRinv_lst.append(Yb_sqrtRinv)
            ylocal_lst.append(ylocal2[:,np.argmin(dist[i][indx[i]])])
        Yb_sqrtRinv = np.vstack(Yb_sqrtRinv_lst)
        ytmp = np.concatenate(ylocal_lst)
        painv = syminv(np.eye(nsamples_tot) + np.dot(Yb_sqrtRinv, Yb_sqrtRinv.T))
        kfens_bloc[indx[i],i] = np.dot(ytmp,np.dot(painv,Yb_sqrtRinv/np.sqrt(oberrvar)))
    # normalized error norm (L1 or Frobenius, depending on --l1norm)
    diff_rloc = kfens_rloc-kfopt
    diff_bloc = kfens_bloc-kfopt
    diff_blocg = kfens_blocg-kfopt
    if l1norm:
        kferr_rloc = np.abs(diff_rloc).sum()
        meankferr_rloc += (kferr_rloc/kfopt_frobnorm)/ntrials
        kferr_bloc = np.abs(diff_bloc).sum()
        meankferr_bloc += (kferr_bloc/kfopt_frobnorm)/ntrials
        kferr_blocg = np.abs(diff_blocg).sum()
        meankferr_blocg += (kferr_blocg/kfopt_frobnorm)/ntrials
    else:
        kferr_rloc = (diff_rloc**2).sum()
        meankferr_rloc += np.sqrt(kferr_rloc/kfopt_frobnorm)/ntrials
        kferr_bloc = (diff_bloc**2).sum()
        meankferr_bloc += np.sqrt(kferr_bloc/kfopt_frobnorm)/ntrials
        kferr_blocg = (diff_blocg**2).sum()
        meankferr_blocg += np.sqrt(kferr_blocg/kfopt_frobnorm)/ntrials
    kfmean_rloc += kfens_rloc/ntrials
    kfmean_bloc += kfens_bloc/ntrials
    kfmean_blocg += kfens_blocg/ntrials
#print(totvar1, totvar2, totvar3)
#print(bandvar_mean)
if verbose:
    import matplotlib.pyplot as plt
    x = np.linspace(-ndim//2,ndim//2-1,ndim)
    plt.plot(x,kfmean_rloc[ndim//2],'r',label='mean est K (Rloc)')
    plt.plot(x,kfmean_bloc[ndim//2],'b',label='mean est K (Bloc)')
    plt.plot(x,kfopt[ndim//2],'k',label='K true')
    plt.xlim(-lscale,lscale)
    plt.legend()
    plt.savefig('meangain.png')
    plt.show()
# print out mean error
print("lscale = %s Kerr_localRloc = %s Kerr_localBloc = %s Kerr_globalBloc = %s" %\
      (lscales[0],meankferr_rloc,meankferr_bloc,meankferr_blocg))