Created July 28, 2019 17:42
Save TimSchmeier/c38dadaa2e53a7bb719a82664f1cebc0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scipy.special | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# Synthetic data set: the cubic 4x^3 - 2x^2 + x sampled at 200 evenly spaced
# points in [-10, 10], with Gaussian noise (standard deviation 500) added to y.
x = np.linspace(-10, 10, 200)
y = 4 * x**3 - 2 * x**2 + x
y += np.random.normal(size=len(y)) * 500
plt.plot(x, y, 'ro')
plt.show()
class Bezier(object):
    """Evaluate Bezier curves on a fixed grid of ``numt`` parameter values.

    Control points are passed to ``decasteljau`` / ``bernstein`` one
    coordinate at a time (e.g. all x values, then all y values); the curve
    parameter ``t`` takes ``numt`` evenly spaced values in ``[0, 1]``.
    """

    def __init__(self, numt=50):
        self.numt = numt
        self.t = np.linspace(0, 1, numt)

    def decasteljau(self, x):
        """Evaluate the curve at every ``t`` in ``self.t`` with de Casteljau.

        Returns a list with one entry per parameter value; each entry is the
        array returned by ``recurrence`` (shape ``(1,)`` for 1-D ``x``).
        """
        return [self.recurrence(x, t) for t in self.t]

    def recurrence(self, x, t):
        """Collapse control points ``x`` to the single curve point at ``t``.

        Each pass replaces the points by the linear interpolations
        ``(1 - t) * x[i] + t * x[i + 1]`` of neighbouring pairs, until only
        one point is left. The loop is iterative rather than recursive, so
        large control polygons do not hit Python's recursion limit. It is
        vectorized, so ``x`` may also be an ``(n, d)`` array of points.

        Raises:
            ValueError: if ``x`` contains no control points.
        """
        points = np.asarray(x)
        if points.shape[0] == 0:
            raise ValueError("need at least one control point")
        while points.shape[0] > 1:
            points = (1 - t) * points[:-1] + t * points[1:]
        return points

    def bernstein(self, x):
        """Evaluate the curve in Bernstein-polynomial form over ``self.t``.

        C(t) = sum_i binom(n, i) * t**i * (1 - t)**(n - i) * x[i], with
        n = len(x) - 1. Returns an array with one value per ``t``.
        For very high degrees (roughly n > 1000) ``binom`` overflows float64
        and the result can become inf/nan; ``decasteljau`` only forms convex
        combinations and stays finite there.
        """
        degree = len(x) - 1
        s = 0
        for i in range(len(x)):
            s += scipy.special.binom(degree, i) * (self.t**i) * ((1 - self.t)**(degree - i)) * x[i]
        return s
# Treat every data point as a Bezier control point and draw the resulting
# curve twice: once via de Casteljau, once via the Bernstein form.
b = Bezier()

fitted_bezier_x, fitted_bezier_y = b.decasteljau(x), b.decasteljau(y)
plt.plot(x, y, 'ro')
plt.plot(fitted_bezier_x, fitted_bezier_y)
plt.show()

fitted_bernstein_x, fitted_bernstein_y = b.bernstein(x), b.bernstein(y)
plt.plot(x, y, 'ro')
plt.plot(fitted_bernstein_x, fitted_bernstein_y)
plt.show()
class BSpline(object):
    """B-spline curve whose control points are the data points ``(x, y)``.

    Each coordinate gets its own uniform knot vector spanning that
    coordinate's range, plus a parameter grid of ``numt`` values over the
    same range. ``run`` returns basis matrices ``mx``/``my`` of shape
    ``(len(x), numt)``: entry ``[i, k]`` is the degree-``degree`` basis
    function N_i at the k-th parameter value. ``project`` turns a basis
    matrix plus control coordinates into curve coordinates.

    The knot vector is uniform and not clamped. As a result the basis
    functions only sum to one on ``[knots[degree], knots[len(x)]]``, and
    outside that span the projected curve is pulled toward zero.
    """

    def __init__(self, x, y, degree, numt=1000):
        self.degree = degree
        self.x = x
        self.y = y
        # n control points of degree p need n + p + 1 knots.
        nknots = self.degree + len(self.x) + 1
        self.xknots = self.get_knots(self.x, nknots)
        self.yknots = self.get_knots(self.y, nknots)
        self.xt = self.get_t(self.x, numt)
        self.yt = self.get_t(self.y, numt)

    def get_knots(self, v, nknots):
        """Return ``nknots`` linearly spaced knots from min(v) to max(v)."""
        return np.linspace(np.min(v), np.max(v), nknots)

    def get_t(self, v, numt):
        """Return ``numt`` parameter values in [min(v), max(v)).

        The right end is excluded because ``isin`` uses half-open intervals,
        so t == max(v) would fall outside every degree-0 basis function.
        """
        return np.linspace(np.min(v), np.max(v), numt, endpoint=False)

    def isin(self, t, i, knots):
        """Degree-0 basis: 1 where knots[i] <= t < knots[i + 1], else 0.

        Returns a Python int for scalar ``t`` and an int array for array ``t``.
        """
        inside = np.logical_and(t >= knots[i], t < knots[i + 1])
        if np.ndim(inside) == 0:
            return 1 if inside else 0
        return inside.astype(int)

    @staticmethod
    def _safe_ratio(num, den):
        """Return num / den, or zeros when den == 0 (Cox-de Boor 0/0 := 0)."""
        if den == 0:
            return np.zeros_like(num, dtype=float)
        return num / den

    def spline_basis(self, t, i, j, knots):
        """Evaluate basis function N_{i,j}(t) by the Cox-de Boor recursion.

        ``t`` may be a scalar or a numpy array. Terms whose knot span has zero
        width (repeated knots, or constant data) contribute 0 instead of
        dividing by zero, so clamped knot vectors are supported.
        """
        if j == 0:
            return self.isin(t, i, knots)
        a = self._safe_ratio(t - knots[i], knots[i + j] - knots[i])
        b = self._safe_ratio(knots[i + j + 1] - t, knots[i + j + 1] - knots[i + 1])
        return (a * self.spline_basis(t, i, j - 1, knots)
                + b * self.spline_basis(t, i + 1, j - 1, knots))

    def _run_one(self, v, t, knots):
        """Build the (len(v), len(t)) matrix of N_{i,degree}(t_k).

        Each basis function is evaluated over the whole parameter array at
        once, so the exponential Cox-de Boor recursion runs once per basis
        function rather than once per (basis, t) pair.
        """
        t = np.asarray(t, dtype=float)
        m = np.zeros((len(v), len(t)))
        for i in range(len(v)):
            m[i, :] = self.spline_basis(t, i, self.degree, knots)
        return m

    def run(self):
        """Return the basis matrices ``(mx, my)`` for the x and y coordinates."""
        my = self._run_one(self.y, self.yt, self.yknots)
        mx = self._run_one(self.x, self.xt, self.xknots)
        return mx, my

    def project(self, m, v):
        """Combine control coordinates ``v`` with basis matrix ``m``: v @ m."""
        return np.dot(v, m)

    def plot(self, mx, my):
        """Plot the data points and the fitted curve, then show the figure."""
        plt.plot(self.x, self.y, 'ro')
        plt.plot(self.project(mx, self.x), self.project(my, self.y))
        plt.show()
# Degree-4 B-spline using every data point as a control point.
bs = BSpline(x, y, 4)
mx, my = bs.run()
# Basis functions: row i of ``my`` is N_i over the parameter grid, and
# plt.plot draws one line per column, so plot the transpose against bs.yt.
plt.plot(bs.yt, my.T)
plt.show()
# Fit (BSpline.plot already calls plt.show()).
bs.plot(mx, my)
class NURBS(BSpline):
    """Rational B-spline: a ``BSpline`` with one weight per control point.

    The basis matrix is ``w_i * N_i(t) / sum_j w_j * N_j(t)``. At every
    parameter value where the weighted basis does not vanish, the basis
    values therefore sum to one.
    """

    def __init__(self, x, y, degree, numt=1000, w=None):
        # TODO: do we want to constrain ``w`` so the fitted curve's curvature
        # is negative?
        if w is not None:
            w = np.asarray(w)
            # Explicit check instead of ``assert`` so validation survives
            # ``python -O``; AssertionError kept for existing callers.
            if w.shape != np.shape(x):
                raise AssertionError("w and x must have same shape")
            self.w = w
        else:
            self.w = np.ones(len(x))
        super(NURBS, self).__init__(x, y, degree, numt)

    def _run_one(self, v, t, knots):
        """Return the weighted, column-normalized basis matrix for ``v``.

        Columns where every weighted basis value is zero are 0/0 and come out
        as NaN (with degree >= 1 this happens at the first sample,
        t == knots[0]). That case is expected, so the numpy divide/invalid
        warnings are silenced; matplotlib skips NaN points when plotting.
        """
        m = super(NURBS, self)._run_one(v, t, knots)
        # Scale row i (basis function i) by its control-point weight.
        m = m * np.reshape(self.w, (-1, 1))
        with np.errstate(divide='ignore', invalid='ignore'):
            return m / m.sum(axis=0)
# Cubic NURBS with random weights in [0, 10).
nurb = NURBS(x, y, 3, w=np.random.uniform(size=x.shape[0])*10)
mx, my = nurb.run()
# Basis functions: row i of ``my`` is basis i over the parameter grid, so
# plot the transpose (plt.plot draws one line per column).
plt.plot(nurb.yt, my.T)
plt.show()
# Fit (BSpline.plot already calls plt.show()).
nurb.plot(mx, my)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.