Last active
October 17, 2018 01:20
-
-
Save zkytony/bbcc70be83bc094947d9f9b89d18dc27 to your computer and use it in GitHub Desktop.
Use SymPy to plot series functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
plots series functions | |
""" | |
import math | |
import sympy | |
from pprint import pprint | |
import matplotlib.pyplot as plt | |
# Selectors for get_series(): restrict the sum to even or to odd indices n.
EVEN = 1
ODD = 2


def get_series(f, x, n_limit=7, only=None, **variables):
    """Return the partial sum of the series whose n-th term is ``f(n, x, **variables)``.

    :param f: callable ``f(n, x, **variables)`` returning the n-th term.
    :param x: independent variable (usually a SymPy symbol), passed through to ``f``.
    :param n_limit: largest index n included in the sum (inclusive).
    :param only: ``None`` sums n = 1..n_limit, ``EVEN`` sums n = 2, 4, ...,
        ``ODD`` sums n = 1, 3, ... (all capped at ``n_limit``).
    :param variables: extra keyword arguments forwarded to every call of ``f``.
    :returns: the sum of the selected terms, of whatever type ``f`` returns.
    :raises ValueError: if ``only`` is not ``None``/``EVEN``/``ODD``, or if
        ``n_limit`` is below the first selected index (the sum would be empty).
    """
    if only is None:
        start, step = 1, 1
    elif only == EVEN:
        start, step = 2, 2
    elif only == ODD:
        start, step = 1, 2
    else:
        raise ValueError("only must be None, EVEN or ODD, got %r" % (only,))

    indices = range(start, n_limit + 1, step)
    if not indices:
        raise ValueError("n_limit=%r is below the first index %d; the series is empty"
                         % (n_limit, start))

    # Seed with the first term rather than 0 so the result keeps f's return
    # type (e.g. a SymPy expression that callers later .subs()).
    total = f(indices[0], x, **variables)
    for n in indices[1:]:
        total += f(n, x, **variables)
    return total
def sample_points(low, up, num_points):
    """Return ``num_points`` evenly spaced x values starting at ``low``.

    The spacing is ``(up - low) / num_points``, so ``up`` itself is excluded.
    Returns an empty list when ``up <= low``.

    :raises ValueError: if ``num_points`` is not positive.
    """
    if num_points <= 0:
        raise ValueError("num_points must be positive, got %r" % (num_points,))
    if up <= low:
        return []
    # float() prevents integer division (Python 2) from collapsing the step to 0.
    step_size = float(up - low) / num_points
    # Each point is computed from its index, so round-off cannot accumulate and
    # add or drop a sample near ``up``.
    return [low + i * step_size for i in range(num_points)]


def plot_function(y, x, low=-2, up=2, num_points=30, save_file=None, label="f(x)", style="--"):
    """Plot the SymPy expression ``y`` as a function of the symbol ``x``.

    The curve is drawn on the current matplotlib figure together with the x/y
    axes lines, axis labels and a legend.

    :param y: SymPy expression in ``x``; it must evaluate to a real number after
        substituting a numeric ``x``.
    :param x: the SymPy symbol substituted with each sample value.
    :param low: first sample value.
    :param up: upper end of the interval (exclusive).
    :param num_points: number of samples.
    :param save_file: if given, the figure is saved to ``figures/<save_file>.png``
        and then closed. Otherwise it stays open so more curves can be added.
    :param label: legend label for this curve.
    :param style: matplotlib format string for the curve.
    """
    xvals = sample_points(low, up, num_points)
    yvals = [float(y.subs([(x, xval)])) for xval in xvals]
    plt.plot(xvals, yvals, style, label=label)
    plt.axhline(y=0, color='k')
    plt.axvline(x=0, color='k')
    plt.xlabel("x")
    plt.ylabel("f(x)")
    plt.legend(loc="lower right")
    if save_file is not None:
        plt.savefig("figures/%s.png" % save_file)
        plt.close()
def plot_series(f, n_limit=7, low=-2, up=2, num_points=30, save=False):
    """Plot the partial sum over n = 1..n_limit of the series whose n-th term is ``f(n, x)``.

    :param f: callable ``f(n, x)`` returning the n-th term as a SymPy expression.
    :param n_limit: largest index included in the partial sum.
    :param low: lower end of the plotted x interval.
    :param up: upper end of the plotted x interval.
    :param num_points: number of sample points.
    :param save: when true, save the plot to ``figures/<f.__name__>.png``.
        Otherwise leave it on the current figure.
    """
    x = sympy.symbols('x')
    y = get_series(f, x, n_limit=n_limit)
    # plot_function needs the symbol to substitute sample values into y.
    plot_function(y, x, low=low, up=up, num_points=num_points,
                  save_file=f.__name__ if save else None)
### Problem specific ### | |
def f1a(n, x):
    """Return the n-th term of the problem-1a series, a coefficient times sin(n*x)."""
    # Numerator factor of the coefficient; math.pi (a float) is used as in the
    # rest of the file, so sin(pi*n) evaluates to a tiny float, not exactly 0.
    scaled = 2 * (math.pi * n - sympy.sin(math.pi * n))
    return scaled * sympy.sin(n * x) / (math.pi * n ** 2)
def f1b(n, x):
    """Return the n-th term of the problem-1b series, a coefficient times sin(n*x)."""
    # Coefficient derivation, as recorded by the author:
    #   1/pi * ( integral_{-pi}^{-pi/2} (-pi/2) sin(nx) dx
    #          + integral_{-pi/2}^{pi/2} x sin(nx) dx
    #          + integral_{pi/2}^{pi} (pi/2) sin(nx) dx )
    first = 2 * sympy.sin(math.pi * n / 2) / (math.pi * n ** 2)
    second = sympy.cos(math.pi * n) / n
    return (first - second) * sympy.sin(n * x)
def f2a(n, x):
    """Return the n-th term sin(n*x)/n of the problem-2a series.

    Dividing the SymPy expression by ``n`` keeps the coefficient an exact
    rational, e.g. sin(3*x)/3. A Python ``1 / n`` would be a rounded float on
    Python 3 and would truncate to 0 for n > 1 on Python 2.
    """
    return sympy.sin(n * x) / n
def prob_2a(n_limit=6, save=False):
    """Print and plot the problem-2a partial sum.

    y = 2 * (sum of f2a over odd n) - 2 * (sum of f2a over even n), n <= n_limit.

    :param n_limit: largest index n included.
    :param save: when true, save to ``figures/prob_2a_<n_limit>.png``.
    """
    x = sympy.symbols('x')
    odd_part = get_series(f2a, x, n_limit=n_limit, only=ODD)
    even_part = get_series(f2a, x, n_limit=n_limit, only=EVEN)
    y = 2*odd_part - 2*even_part
    print("y = " + str(y))
    if save:
        target = "prob_2a_%d" % n_limit
    else:
        target = None
    plot_function(y, x, save_file=target, label="f(x) (%d terms)" % n_limit)
def f2b(n, x):
    """Return the n-th term cos(n*x)/n**2 of the problem-2b series.

    Dividing the SymPy expression keeps the coefficient an exact rational.
    A Python ``1 / n**2`` would be a rounded float, or 0 on Python 2.
    """
    return sympy.cos(n * x) / n ** 2
def prob_2b(n_limit=5, save=False):
    """Print and plot the problem-2b partial sum.

    y = 1/2 + 4/pi**2 * sum of f2b over odd n <= n_limit.

    :param n_limit: largest index n included (only odd n contribute).
    :param save: when true, save to ``figures/prob_2b_<n_limit>.png``.
    """
    x = sympy.symbols('x')
    y = 0.5 + (4 / math.pi**2) * get_series(f2b, x, n_limit=n_limit, only=ODD)
    print("y = " + str(y))
    # Only odd n are summed, so the term count is the number of odd n in 1..n_limit.
    # (A plt.ylabel() set here would be overwritten by plot_function anyway.)
    n_terms = (n_limit + 1) // 2
    plot_function(y, x, save_file=None if not save else "prob_2b_%d" % n_limit,
                  label="f(x) (%d terms)" % n_terms)
def f4a_r(n, x):
    """Return the n-th term sin(n*x)/n of the problem-4a forcing series r.

    Dividing the SymPy expression keeps the coefficient an exact rational
    instead of a rounded (or, on Python 2, truncated) ``1/n``.
    """
    return sympy.sin(n * x) / n
def f4a_y(n, x, c):
    """Return the n-th term A_n*cos(n*x) + B_n*sin(n*x) of the problem-4a response y.

    ``c`` is the parameter appearing in the coefficients (a SymPy symbol in
    prob_4a, substituted with numbers later).
    """
    denom = -n**4 + (2 - c**2)*n**2 - 1
    cos_coeff = 4*c / (math.pi*denom)
    # c occurs in both numerator and denominator; a SymPy symbol cancels
    # automatically, while a plain numeric c=0 would divide by zero.
    sin_coeff = (n**2 - 1)*4*c / (n*c*math.pi*denom)
    cos_part = cos_coeff * sympy.cos(n*x)
    sin_part = sin_coeff * sympy.sin(n*x)
    return cos_part + sin_part
def f4b_r(n, x):
    """Return the n-th term (-1)**(n+1) * 12/n**3 * sin(n*x) of the problem-4b forcing series.

    Multiplying before dividing by ``n**3`` lets SymPy keep an exact rational
    coefficient. A Python ``12/n**3`` would be a rounded float on Python 3 and
    floor-divide on Python 2 (e.g. 12/27 == 0).
    """
    return (-1)**(n + 1) * 12 * sympy.sin(n * x) / n**3
def f4b_y(n, x, c):
    """Return the n-th term A_n*cos(n*x) + B_n*sin(n*x) of the problem-4b response y.

    ``c`` is the parameter appearing in the coefficients (a SymPy symbol in
    prob_4b, substituted with numbers later).
    """
    sign = (-1)**(n+1)
    denom = -n**4 + (2 - c**2)*n**2 - 1
    cos_coeff = sign*12*c / (n**2*denom)
    # As in f4a_y, c cancels symbolically here; numeric c=0 would divide by zero.
    sin_coeff = sign*12*(n**2 - 1)*c / (n**3*c*denom)
    cos_part = cos_coeff * sympy.cos(n*x)
    sin_part = sin_coeff * sympy.sin(n*x)
    return cos_part + sin_part
def _plot_forced_response(r, y, c, x, c_vals, n_limit, save, tag):
    """Print and plot the forcing r and, for each c in ``c_vals``, the response y.

    With ``save`` the r plot is written to ``figures/prob_<tag>_r_<n_limit>.png``
    and the response curves to ``figures/prob_<tag>_y_<n_limit>.png``, and both
    figures are closed. Otherwise everything is drawn on the current figure,
    which is then shown.
    """
    print("r = " + str(r))
    plot_function(r, x, save_file=None if not save else "prob_%s_r_%d" % (tag, n_limit),
                  label=r"r(t) (n $\leq$ %d)" % n_limit)
    print("y(c,x) = " + str(y))
    for c_val in c_vals:
        y_c = y.subs([(c, c_val)])
        print("y(x) = " + str(y_c))
        plot_function(y_c, x,
                      label=r"y(t), c=%f (n $\leq$ %d)" % (c_val, n_limit))
    # The response curves are saved together, after the loop.
    if save:
        plt.savefig("figures/prob_%s_y_%d.png" % (tag, n_limit))
        # Close like plot_function does, so later plots start on a fresh figure.
        plt.close()
    else:
        plt.show()


def prob_4a(n_limit=5, save=False, c_vals=()):
    """Problem 4a: plot r = 4/pi * sum of f4a_r over odd n, and y for each c in ``c_vals``.

    :param n_limit: largest index n included in both series.
    :param save: save the figures under ``figures/`` instead of showing them.
    :param c_vals: numeric values substituted for the symbol c; one y curve each.
    """
    x = sympy.symbols('x')
    c = sympy.symbols('c')
    r = 4/math.pi * get_series(f4a_r, x, n_limit=n_limit, only=ODD)
    # NOTE(review): r sums only odd n, but y sums every n = 1..n_limit.
    # Confirm the even-n terms of y are intended.
    y = get_series(f4a_y, x, n_limit=n_limit, c=c)
    _plot_forced_response(r, y, c, x, c_vals, n_limit, save, "4a")


def prob_4b(n_limit=5, save=False, c_vals=()):
    """Problem 4b: plot r = 12 * sum of f4b_r over odd n, and y for each c in ``c_vals``.

    :param n_limit: largest index n included in both series.
    :param save: save the figures under ``figures/`` instead of showing them.
    :param c_vals: numeric values substituted for the symbol c; one y curve each.
    """
    x = sympy.symbols('x')
    c = sympy.symbols('c')
    # NOTE(review): f4b_r already carries a factor 12, so r is scaled by 144 in
    # total, and r sums only odd n while y sums all n. Confirm both against the
    # problem statement.
    r = 12 * get_series(f4b_r, x, n_limit=n_limit, only=ODD)
    y = get_series(f4b_y, x, n_limit=n_limit, c=c)
    _plot_forced_response(r, y, c, x, c_vals, n_limit, save, "4b")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment