# Forked from finsberg/monodomain_external_operator_vs_splitting.py
# GitHub gist jorgensd/29cf779486ddb0fa295ce2b324a73ce5 (last active March 10, 2025 09:48).
#
# Compares an operator-splitting scheme with a dolfinx_external_operator
# formulation of the same monodomain problem with a manufactured solution.
from dataclasses import dataclass | |
import shutil | |
from mpi4py import MPI | |
from petsc4py import PETSc | |
import dolfinx.fem.petsc | |
from dolfinx_external_operator import ( | |
FEMExternalOperator, | |
evaluate_external_operators, | |
evaluate_operands, | |
replace_external_operators, | |
) | |
import numpy.typing as npt | |
import numpy as np | |
import basix.ufl | |
import ufl | |
def v_exact_func(x, t):
    """Return the manufactured potential v(x, t) = cos(2*pi*x) cos(2*pi*y) sin(t).

    Args:
        x: Spatial coordinate (indexable; components 0 and 1 are used).
        t: Time (number or UFL expression).

    Returns:
        A UFL expression.
    """
    two_pi = 2 * ufl.pi
    spatial = ufl.cos(two_pi * x[0]) * ufl.cos(two_pi * x[1])
    return spatial * ufl.sin(t)
def s_exact_func(x, t):
    """Return the manufactured ODE state s(x, t) = -cos(2*pi*x) cos(2*pi*y) cos(t).

    Args:
        x: Spatial coordinate (indexable; components 0 and 1 are used).
        t: Time (number or UFL expression).

    Returns:
        A UFL expression.
    """
    two_pi = 2 * ufl.pi
    cos_x = ufl.cos(two_pi * x[0])
    cos_y = ufl.cos(two_pi * x[1])
    return -cos_x * cos_y * ufl.cos(t)
def ode_exact_func(x, t):
    """Return the exact ODE state as the 2-component UFL vector (v, s)."""
    components = (v_exact_func(x, t), s_exact_func(x, t))
    return ufl.as_vector(components)
def ac_func(x, t):
    """Return the manufactured applied current I_s = 8*pi^2 cos(2*pi*x) cos(2*pi*y) sin(t).

    For v = cos(2*pi*x) cos(2*pi*y) sin(t) one has -Laplace(v) = 8*pi^2 v.
    This term therefore balances the diffusion term for unit conductivity
    (M = 1). Scale it by a scalar M for other conductivities.

    Args:
        x: Spatial coordinate (indexable; components 0 and 1 are used).
        t: Time (number or UFL expression).

    Returns:
        A UFL expression.
    """
    scale = 8 * ufl.pi**2
    return scale * ufl.cos(2 * ufl.pi * x[0]) * ufl.cos(2 * ufl.pi * x[1]) * ufl.sin(t)
@dataclass | |
class ODE: | |
time: dolfinx.fem.Constant | |
dt: float | |
parameters: npt.NDArray | |
def __call__(self, v: npt.NDArray, ode_states: npt.NDArray): | |
s_ode = ode_states[:, 1::2] # Extract s | |
states = np.vstack([v.flatten(), s_ode.flatten()]) | |
new_values = simple_ode_forward_euler( | |
states, self.time.value, self.dt, parameters=None | |
) | |
return new_values.T.flatten() | |
# def __call__(self, v_pde: npt.NDArray, s: npt.NDArray): | |
# s = s.flatten() | |
# v = v_pde.flatten() | |
# I_ion = -sf | |
# values = np.zeros_like(v_pde.shape[0],) | |
# values[0] = v - I_ion * self.dt | |
# values[1] = s + v * self.dt | |
# return I_ion | |
def simple_ode_forward_euler(states, t, dt, parameters):
    """Take one forward Euler step of dv/dt = -s, ds/dt = v.

    Args:
        states: Sequence/array whose first two rows are v and s.
        t: Current time (unused by this right-hand side).
        dt: Time step.
        parameters: Model parameters (unused by this right-hand side).

    Returns:
        A new array shaped like ``states`` holding the updated (v, s).
    """
    v, s = states
    # Every entry is overwritten below, so an uninitialised buffer suffices.
    updated = np.empty_like(states)
    updated[0] = v - dt * s
    updated[1] = s + dt * v
    return updated
def VTXWriter(mesh, path, functions, engine="BP5", write: bool = False):
    """Return a dolfinx VTX writer, or a no-op stand-in when ``write`` is False.

    Args:
        mesh: A mesh (any object with a ``comm`` attribute) or an MPI
            communicator. Callers in this script pass ``mesh.comm``, so both
            forms are accepted rather than failing on ``comm.comm``.
        path: Output ``.bp`` path.
        functions: Functions to write.
        engine: ADIOS2 engine name.
        write: When False, return a dummy whose ``write``/``close`` do nothing.

    Returns:
        An object with ``write(t)`` and ``close()`` methods.
    """
    if write:
        # Accept either a mesh or a communicator.
        comm = getattr(mesh, "comm", mesh)
        return dolfinx.io.VTXWriter(comm, path, functions, engine=engine)

    class DummyVTXWriter:
        """No-op writer exposing the same write/close interface."""

        def write(self, *args, **kwargs):
            pass

        def close(self):
            pass

    return DummyVTXWriter()
def splitting_scheme(N=50, M=1.0, dt=0.01, T=1.0, quad_degree=4, save_vtx=False):
    """Solve the manufactured monodomain problem with an operator-splitting scheme.

    Each time step does three things:

    1. Advance the ODE states (v, s) one forward Euler step at the
       quadrature points.
    2. Solve the backward Euler diffusion problem for v on a P1 space,
       using the ODE value of v as the previous value.
    3. Interpolate the PDE solution back to the quadrature points.

    Args:
        N: Number of cells per direction of the unit-square mesh.
        M: Scalar conductivity in the diffusion term (also scales the
            manufactured source so the exact solution stays valid).
        dt: Time step; must be positive.
        T: Final time. ``ceil(T / dt)`` steps are taken, with a small
            tolerance against floating-point round-off.
        quad_degree: Degree of the quadrature element and integration measure.
        save_vtx: If True, write ``v_exact`` and ``v_pde`` to ``splitting_scheme.bp``.

    Returns:
        The L2 norm of ``v_pde - v_exact`` at the final time.

    Raises:
        ValueError: If ``dt`` is not positive (the original loop never terminated).
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(comm, N, N, dolfinx.mesh.CellType.triangle)
    time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(0.0))
    x = ufl.SpatialCoordinate(mesh)
    # ac_func balances -Laplace(v) for unit conductivity; scale by M so the
    # manufactured solution is exact for any scalar conductivity.
    I_s = M * ac_func(x, time)

    el = basix.ufl.quadrature_element(
        scheme="default", degree=quad_degree, cell=mesh.ufl_cell().cellname()
    )
    V_ode = dolfinx.fem.functionspace(mesh, el)
    # v_ode starts at zero, matching v_exact_func at t = 0 (sin(0) = 0).
    v_ode = dolfinx.fem.Function(V_ode)
    s = dolfinx.fem.Function(V_ode)
    s.interpolate(
        dolfinx.fem.Expression(
            s_exact_func(x, time), V_ode.element.interpolation_points()
        )
    )
    # Row 0 holds v, row 1 holds s, one column per quadrature dof.
    states = np.zeros((2, s.x.array.size))
    states[1, :] = s.x.array
    states[0, :] = v_ode.x.array

    V_pde = dolfinx.fem.functionspace(mesh, ("P", 1))
    v_pde = dolfinx.fem.Function(V_pde, name="v_pde")

    C_m = 1.0
    dx = ufl.Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_degree})
    v = ufl.TrialFunction(V_pde)
    w = ufl.TestFunction(V_pde)
    # Backward Euler for C_m dv/dt = div(M grad v) + I_s, previous value v_ode.
    a = C_m * v * w * dx + dt * ufl.inner(M * ufl.grad(v), ufl.grad(w)) * dx
    L = C_m * v_ode * w * dx + dt * I_s * w * dx
    solver = dolfinx.fem.petsc.LinearProblem(a, L, u=v_pde)
    # The bilinear form is time-independent: assemble it once and reuse it.
    dolfinx.fem.petsc.assemble_matrix(solver.A, solver.a)
    solver.A.assemble()

    v_expr = dolfinx.fem.Expression(v_pde, V_ode.element.interpolation_points())
    v_exact_expr = dolfinx.fem.Expression(
        v_exact_func(x, time), V_pde.element.interpolation_points()
    )
    v_exact = dolfinx.fem.Function(V_pde, name="v_exact")

    path = "splitting_scheme.bp"
    shutil.rmtree(path, ignore_errors=True)
    vtx = VTXWriter(mesh, path, [v_exact, v_pde], engine="BP5", write=save_vtx)

    # A fixed step count replaces ``while time.value < T``. Repeated
    # ``time.value += dt`` drifts (ten additions of 0.1 give
    # 0.9999999999999999), which added a spurious step past T.
    num_steps = max(0, int(np.ceil(T / dt - 1e-8)))
    for step in range(num_steps):
        # 1) ODE step at the quadrature points.
        states[0, :] = v_ode.x.array
        states[:] = simple_ode_forward_euler(states, time.value, dt, parameters=None)
        v_ode.x.array[:] = states[0, :]

        # 2) Diffusion step: only the right-hand side changes.
        with solver.b.localForm() as b_loc:
            b_loc.set(0)
        dolfinx.fem.petsc.assemble_vector(solver.b, solver.L)
        solver.b.ghostUpdate(
            addv=PETSc.InsertMode.ADD,
            mode=PETSc.ScatterMode.REVERSE,
        )
        solver.solver.solve(solver.b, v_pde.x.petsc_vec)
        v_pde.x.scatter_forward()

        # 3) Feed the PDE solution back to the ODE points.
        v_ode.interpolate(v_expr)

        # Set time from the step index so it does not accumulate round-off.
        time.value = (step + 1) * dt
        v_exact.interpolate(v_exact_expr)
        vtx.write(time.value)

    error = dolfinx.fem.form((v_pde - v_exact) ** 2 * dx)
    E = np.sqrt(comm.allreduce(dolfinx.fem.assemble_scalar(error), MPI.SUM))
    vtx.close()
    return E
def external_operator(N=50, M=1.0, dt=0.01, T=1.0, quad_degree=4, save_vtx=False):
    """Solve the manufactured monodomain problem with the ODE as a FEMExternalOperator.

    The ODE update is expressed as an external operator of (v_pde, old ODE
    states) that appears in the right-hand side of the backward Euler
    diffusion problem. Each time step does three things:

    1. Evaluate the operator (one forward Euler ODE step).
    2. Solve the PDE.
    3. Store the new ODE states as the old ones.

    Args:
        N: Number of cells per direction of the unit-square mesh.
        M: Scalar conductivity in the diffusion term (also scales the
            manufactured source so the exact solution stays valid).
        dt: Time step; must be positive.
        T: Final time. ``ceil(T / dt)`` steps are taken, with a small
            tolerance against floating-point round-off.
        quad_degree: Degree of the quadrature element and integration measure.
        save_vtx: If True, write ``v_exact`` and ``v_pde`` to ``external_operator.bp``.

    Returns:
        The L2 norm of ``v_pde - v_exact`` at the final time.

    Raises:
        ValueError: If ``dt`` is not positive (the original loop never terminated).
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(comm, N, N, dolfinx.mesh.CellType.triangle)
    time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(0.0))
    x = ufl.SpatialCoordinate(mesh)
    # ac_func balances -Laplace(v) for unit conductivity; scale by M so the
    # manufactured solution is exact for any scalar conductivity.
    I_s = M * ac_func(x, time)

    # Vector-valued quadrature space holding (v, s) at every quadrature point.
    el = basix.ufl.quadrature_element(
        scheme="default",
        degree=quad_degree,
        cell=mesh.ufl_cell().cellname(),
        value_shape=(2,),
    )
    V_ode = dolfinx.fem.functionspace(mesh, el)
    ode_states_old = dolfinx.fem.Function(V_ode)
    ode_states_old.interpolate(
        dolfinx.fem.Expression(
            ode_exact_func(x, time), V_ode.element.interpolation_points()
        )
    )

    def f_external(derivatives: tuple[int, ...]):
        """Return the implementation for the requested derivative multi-index."""
        if derivatives == (0, 0):  # the operator itself, no differentiation
            return ODE(time, dt, parameters=None)
        # Previously the exception class was *returned*, deferring a confusing
        # failure to evaluation time; raise it instead.
        raise NotImplementedError(
            f"Derivative {derivatives} of the ODE external operator is not implemented"
        )

    V_pde = dolfinx.fem.functionspace(mesh, ("P", 1))
    v_pde = dolfinx.fem.Function(V_pde, name="v_pde")
    ode_states = FEMExternalOperator(
        v_pde, ode_states_old, function_space=V_ode, external_function=f_external
    )
    # Only the v component enters the PDE right-hand side.
    v_ode, _ = ufl.split(ode_states)

    C_m = 1.0
    dx = ufl.Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_degree})
    v = ufl.TrialFunction(V_pde)
    w = ufl.TestFunction(V_pde)
    # Backward Euler for C_m dv/dt = div(M grad v) + I_s; the previous value is
    # the ODE-updated v, so the ODE is solved before each PDE solve.
    a = C_m * v * w * dx + dt * ufl.inner(M * ufl.grad(v), ufl.grad(w)) * dx
    L = C_m * v_ode * w * dx + dt * I_s * w * dx
    # Replace the external operator by its coefficient so the form compiles.
    L_updated, operators = replace_external_operators(L)
    L_compiled = dolfinx.fem.form(L_updated)
    solver = dolfinx.fem.petsc.LinearProblem(a, L_compiled, u=v_pde)
    # The bilinear form is time-independent: assemble it once and reuse it.
    dolfinx.fem.petsc.assemble_matrix(solver.A, solver.a)
    solver.A.assemble()

    v_exact_expr = dolfinx.fem.Expression(
        v_exact_func(x, time), V_pde.element.interpolation_points()
    )
    v_exact = dolfinx.fem.Function(V_pde, name="v_exact")

    path = "external_operator.bp"
    shutil.rmtree(path, ignore_errors=True)
    vtx = VTXWriter(mesh, path, [v_exact, v_pde], engine="BP5", write=save_vtx)

    # A fixed step count replaces ``while time.value < T``. Repeated
    # ``time.value += dt`` drifts (ten additions of 0.1 give
    # 0.9999999999999999), which added a spurious step past T.
    num_steps = max(0, int(np.ceil(T / dt - 1e-8)))
    for step in range(num_steps):
        # 1) ODE step: evaluate the operands (v_pde, ode_states_old), then the
        #    operator; the result is read below via ode_states.ref_coefficient.
        coefficients = evaluate_operands(operators)
        evaluate_external_operators(operators, coefficients)

        # 2) Diffusion step: only the right-hand side changes.
        with solver.b.localForm() as b_loc:
            b_loc.set(0)
        dolfinx.fem.petsc.assemble_vector(solver.b, solver.L)
        solver.b.ghostUpdate(
            addv=PETSc.InsertMode.ADD,
            mode=PETSc.ScatterMode.REVERSE,
        )
        solver.solver.solve(solver.b, v_pde.x.petsc_vec)
        v_pde.x.scatter_forward()

        # 3) The new ODE states become the old states of the next step.
        ode_states_old.interpolate(ode_states.ref_coefficient)

        # Set time from the step index so it does not accumulate round-off.
        time.value = (step + 1) * dt
        v_exact.interpolate(v_exact_expr)
        vtx.write(time.value)

    error = dolfinx.fem.form((v_pde - v_exact) ** 2 * dx)
    E = np.sqrt(comm.allreduce(dolfinx.fem.assemble_scalar(error), MPI.SUM))
    vtx.close()
    return E
def _run_convergence_study(title, scheme, param_name, values, fixed_kwargs):
    """Run ``scheme`` for each value of ``param_name`` and print errors and observed orders.

    The observed order between consecutive runs is log2(e_prev / e_curr).
    That is the convergence rate when the parameter is halved between runs.
    """
    print(f"\n{title}")
    errors = []
    for value in values:
        err = scheme(**{param_name: value}, **fixed_kwargs)
        print(f"{param_name}={value}, error={err}")
        errors.append(err)
        if len(errors) > 1:
            order = np.log2(errors[-2] / errors[-1])
            print(f"Order of convergence: {order}")


def main():
    """Run spatial and temporal convergence studies for both schemes.

    Fixes the temporal studies, which printed the stale spatial variable
    (``dt={N}``) instead of the time step.
    """
    spatial_resolutions = [4, 8, 16, 32, 64]
    time_steps = [0.1 / 2**i for i in range(4)]
    # Small fixed dt so the spatial error dominates in the spatial studies.
    spatial_fixed = {"T": 1.0, "dt": 0.0005}

    studies = [
        ("Splitting scheme (spatial)", splitting_scheme, "N", spatial_resolutions, spatial_fixed),
        ("External operator (spatial)", external_operator, "N", spatial_resolutions, spatial_fixed),
        ("Splitting scheme (temporal)", splitting_scheme, "dt", time_steps, {}),
        ("External operator (temporal)", external_operator, "dt", time_steps, {}),
    ]
    for title, scheme, param_name, values, fixed_kwargs in studies:
        _run_convergence_study(title, scheme, param_name, values, fixed_kwargs)
# Run the convergence studies only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# End of gist. (GitHub page footer: sign in to GitHub to comment on this gist.)