Last active
December 30, 2021 13:32
-
-
Save ventusff/91628e5c98800b52f81b0cde55a71046 to your computer and use it in GitHub Desktop.
closest point on spline
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: Given a point, find the closest point to it on a spline curve (i.e. its projection onto the curve).
import numpy as np, matplotlib.pyplot as pp, mpl_toolkits.mplot3d as mp | |
import scipy.interpolate as si, scipy.optimize as so, scipy.spatial.distance as ssd | |
# Sample 3-D polyline: the x, y and z coordinates of eight points.
data = (
    (1, 2, 3, 4, 5, 6, 7, 8),
    (1, 2.5, 4, 5.5, 6.5, 8.5, 10, 12.5),
    (1, 2.5, 4.5, 6.5, 8, 8.5, 10, 12.5),
)
# Query point to project onto the spline.
p = (6.5, 9.5, 9)
# Fit an interpolating spline (s=0 means no smoothing). tck holds the knots,
# coefficients and degree; uu is the parameter value assigned to each point.
tck, uu = si.splprep(data, s=0)
# Return distance from 3d point p to a point on the spline at spline parameter u | |
def distToP(u): | |
s = si.splev(u, tck) | |
return ssd.euclidean(p, s) | |
# Search for the spline parameter u whose spline point is nearest to p,
# i.e. minimise distToP over u. fmin is the Nelder-Mead simplex method; it
# needs no derivative and starts here from the middle of the range, u = 0.5.
closestu = so.fmin(func=distToP, x0=0.5)
# Evaluate the spline at the optimum to get the closest point itself.
closest = si.splev(x=closestu, tck=tck)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: Given a point, find the closest point to it on a spline curve (i.e. its projection onto the curve).
import numpy as np | |
import scipy | |
import scipy.interpolate | |
import scipy.optimize | |
import scipy.spatial.distance | |
# Sample 3-D polyline: the x, y and z coordinates of eight points.
data = (
    (2, 4, 6, 8, 10, 12, 14, 16),
    (1, 2.5, 4, 5.5, 6.5, 8.5, 10, 12.5),
    (1, 2.5, 4.5, 6.5, 8, 8.5, 10, 12.5),
)
# Query point, stored as a (3, 1) float column vector.
p = np.array([[1], [9.5], [13]], dtype=float)
# Fit an interpolating (s=0: no smoothing) cubic (k=3) spline. tck holds the
# knots, coefficients and degree; uu is the parameter of each data point.
tck, uu = scipy.interpolate.splprep(data, s=0, k=3)
# Holds the last residual that distToP stores via `global tmp`; starts at 0.
tmp = 0
# Objective for the closest-point search (squared distance to p).
def distToP(u):
    """Return the *squared* Euclidean distance from ``p`` to the spline at ``u``.

    The square root is skipped; it does not change where the minimum lies.
    Reads the module-level spline ``tck`` and query point ``p``.

    Side effect: stores the residual (spline point minus ``p``) in the
    module-level ``tmp``.

    ``u`` is expected to be a length-1 array, as the scipy minimisers pass it.
    With a bare scalar ``u``, the (3,) spline point would broadcast against
    the (3, 1) ``p`` into a (3, 3) array.
    """
    global tmp
    point = np.array(scipy.interpolate.splev(u, tck))
    tmp = point - p
    return np.sum(np.square(tmp))
def distToPPrime(u):
    """Return d/du of ``distToP``, the squared distance from ``p`` to the spline.

    With f(u) = |S(u) - p|^2, the chain rule gives
    f'(u) = 2 * (S(u) - p) . S'(u).

    The residual S(u) - p is computed here at ``u`` itself. It is not read
    from the global ``tmp`` that ``distToP`` leaves behind, because that value
    belongs to whatever point the objective was last evaluated at. It is
    stale whenever the caller did not evaluate the objective at this same
    ``u`` just before, and it is 0 before the first objective call.

    Reads the module-level spline ``tck`` and query point ``p``. ``u`` may be
    a scalar or a length-1 array; it is promoted to 1-D so the spline values
    come out as (3, 1) columns that line up with ``p``.
    """
    u = np.atleast_1d(u)
    residual = np.array(scipy.interpolate.splev(u, tck)) - p
    tangent = np.array(scipy.interpolate.splev(u, tck, der=1))
    return (2 * residual * tangent).sum()
# Search for the spline parameter u whose spline point is nearest to p by
# minimising distToP (the squared distance) over u, starting from u = 0.5.
# fmin is the derivative-free Nelder-Mead simplex method; disp=False
# suppresses its convergence report.
closestu = scipy.optimize.fmin(func=distToP, x0=0.5, disp=False)

# Drop-in alternatives (uncomment exactly one). To benchmark them, wrap the
# call in a timing loop, e.g. `for _ in tqdm(range(10000)):` after
# `from tqdm import tqdm`.
#
# Multivariate solvers that accept bounds/constraints (none are passed here):
# closestu = scipy.optimize.fmin_l_bfgs_b(distToP, 0.5, fprime=distToPPrime, disp=False)[0]
# closestu = scipy.optimize.fmin_tnc(distToP, 0.5, fprime=distToPPrime, disp=False)[0]
# closestu = scipy.optimize.fmin_cobyla(distToP, 0.5, cons=[])
# closestu = scipy.optimize.fmin_slsqp(distToP, 0.5, fprime=distToPPrime, disp=False)
#
# Univariate solvers:
# closestu = scipy.optimize.fminbound(distToP, 0, 1)
# closestu = scipy.optimize.golden(distToP)
# closestu = scipy.optimize.brent(distToP)

# Evaluate the spline at the optimum to get the closest point itself.
closest = np.array(scipy.interpolate.splev(x=closestu, tck=tck))
print(closestu)
print(closest)
# NOTE: The solving ends here; everything below is only for plotting.
## Plotting
import numpy as np, matplotlib.pyplot as pp, mpl_toolkits.mplot3d as mp | |
# Approximate arc length of the spline between two parameter values.
def distAlong(tck, u = 0, v = 1, N = 1000):
    """Approximate the length of the spline ``tck`` between parameters ``u`` and ``v``.

    The curve is sampled at ``N`` evenly spaced parameter values in [u, v].
    The lengths of the straight segments joining consecutive samples are then
    summed. This polyline length never exceeds the true arc length and
    approaches it as ``N`` grows.
    """
    samples = np.array(scipy.interpolate.splev(np.linspace(u, v, N), tck))
    steps = np.diff(samples, axis=1)
    return np.sum(np.sqrt(np.sum(steps ** 2, axis=0)))
# Title text: half of the (polyline-approximated) total arc length.
s = "distance along spline to halfway point is %f" % (distAlong(tck)/2)
# Build a 3xN array of 3d points along the spline
N = 1001
spline = np.array(scipy.interpolate.splev(np.linspace(0, 1, N), tck))
# splprep parameterises the curve by normalised cumulative chord length of the
# *data* points, so u = 0.5 (sample N//2) is only roughly the arc-length
# midpoint. Instead, pick the first sample at which the cumulative length of
# the dense polyline reaches half of its total.
segment_lengths = np.sqrt(np.sum(np.diff(spline, axis=1) ** 2, axis=0))
cumulative = np.concatenate(([0.0], np.cumsum(segment_lengths)))
half = int(np.searchsorted(cumulative, cumulative[-1] / 2))
# Plot things!
# Create the 3-D axes via add_subplot: since matplotlib 3.6, `Axes3D(fig)` no
# longer adds itself to the figure (deprecated in 3.4), which left it blank.
fig = pp.figure(figsize=(8, 8))
ax = fig.add_subplot(projection="3d")
ax.plot((p[0].item(),), (p[1].item(),), (p[2].item(),), 'o', label='input point')
ax.plot(closest[0], closest[1], closest[2], 'o', label='closest point on spline')
ax.plot((spline[0, half],), (spline[1, half],), (spline[2, half],), 'o', label='halfway along spline')
ax.plot(data[0], data[1], data[2], label='raw data points')
ax.plot(spline[0], spline[1], spline[2], label='spline fit to data')
pp.legend(loc='lower right')
pp.title(s)
pp.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment