Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save jorgensd/29cf779486ddb0fa295ce2b324a73ce5 to your computer and use it in GitHub Desktop.
Save jorgensd/29cf779486ddb0fa295ce2b324a73ce5 to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
import shutil
from mpi4py import MPI
from petsc4py import PETSc
import dolfinx.fem.petsc
from dolfinx_external_operator import (
FEMExternalOperator,
evaluate_external_operators,
evaluate_operands,
replace_external_operators,
)
import numpy.typing as npt
import numpy as np
import basix.ufl
import ufl
def v_exact_func(x, t):
    """Return the manufactured solution ``v = cos(2*pi*x0) * cos(2*pi*x1) * sin(t)``.

    ``x`` is indexed as ``x[0]``/``x[1]`` and ``t`` is passed to ``ufl.sin``;
    the result is a UFL expression.
    """
    wave_number = 2 * ufl.pi
    spatial_profile = ufl.cos(wave_number * x[0]) * ufl.cos(wave_number * x[1])
    return spatial_profile * ufl.sin(t)
def s_exact_func(x, t):
    """Return the manufactured solution ``s = -cos(2*pi*x0) * cos(2*pi*x1) * cos(t)``.

    ``x`` is indexed as ``x[0]``/``x[1]`` and ``t`` is passed to ``ufl.cos``;
    the result is a UFL expression.
    """
    wave_number = 2 * ufl.pi
    negated_x_profile = -ufl.cos(wave_number * x[0])
    y_profile = ufl.cos(wave_number * x[1])
    return negated_x_profile * y_profile * ufl.cos(t)
def ode_exact_func(x, t):
    """Return the exact ODE state as the UFL vector ``(v_exact, s_exact)``."""
    v_component = v_exact_func(x, t)
    s_component = s_exact_func(x, t)
    return ufl.as_vector((v_component, s_component))
def ac_func(x, t):
    """Return the source term ``8*pi**2 * cos(2*pi*x0) * cos(2*pi*x1) * sin(t)``.

    This is ``8*pi**2`` times the expression returned by ``v_exact_func``,
    built as a UFL expression from ``x[0]``, ``x[1]`` and ``t``.
    """
    amplitude = 8 * ufl.pi**2
    wave_number = 2 * ufl.pi
    x_profile = ufl.cos(wave_number * x[0])
    y_profile = ufl.cos(wave_number * x[1])
    return amplitude * x_profile * y_profile * ufl.sin(t)
@dataclass
class ODE:
    """Callable that advances the ``(v, s)`` states by one forward-Euler step.

    An instance is returned by ``f_external`` in ``external_operator`` as the
    external function evaluated for the undifferentiated operator.
    """

    # Time constant; its current ``value`` is forwarded to the ODE stepper.
    time: dolfinx.fem.Constant
    # Time-step size used for the forward-Euler update.
    dt: float
    # Model parameters forwarded to the stepper (callers in this file pass None).
    parameters: npt.NDArray

    def __call__(self, v: npt.NDArray, ode_states: npt.NDArray):
        """Step the ODE using ``v`` and the ``s`` component of ``ode_states``.

        Args:
            v: Values of the first state; flattened before use.
            ode_states: Two-dimensional array from which every second column
                (starting at column 1) is read as ``s``.
                NOTE(review): this assumes the states are interleaved along
                axis 1 as ``v, s, v, s, ...`` - confirm against the operand
                layout.

        Returns:
            A flat array with the updated ``v`` and ``s`` interleaved
            (``v0, s0, v1, s1, ...``).
        """
        s_ode = ode_states[:, 1::2]  # Extract s; the v entries of ode_states are unused
        states = np.vstack([v.flatten(), s_ode.flatten()])
        # Forward the stored parameters instead of a hard-coded None so the
        # dataclass field is actually honoured.
        new_values = simple_ode_forward_euler(
            states, self.time.value, self.dt, parameters=self.parameters
        )
        # Transposing before flattening interleaves the two state rows.
        return new_values.T.flatten()
def simple_ode_forward_euler(states, t, dt, parameters):
    """Take one forward-Euler step of ``dv/dt = -s``, ``ds/dt = v``.

    Args:
        states: Pair ``(v, s)``, e.g. an array whose first axis has length 2.
        t: Current time (unused; kept for a uniform stepper signature).
        dt: Time-step size.
        parameters: Model parameters (unused by this model).

    Returns:
        A new array stacking the updated ``v`` and ``s`` along axis 0. The
        input is not modified.
    """
    v, s = states
    # Build the result from the computed rows rather than writing into
    # ``np.zeros_like(states)``: the latter inherits the input dtype and would
    # silently truncate the update for integer-valued input.
    return np.stack([v - s * dt, s + v * dt])
class _NullVTXWriter:
    """No-op writer exposing the ``write``/``close`` calls used in this file."""

    def write(self, *args, **kwargs):
        pass

    def close(self):
        pass


def VTXWriter(mesh, path, functions, engine="BP5", write: bool = False):
    """Return a ``dolfinx.io.VTXWriter`` or, when ``write`` is false, a no-op stand-in.

    Args:
        mesh: Either an object with a ``comm`` attribute (such as a mesh) or
            the communicator itself. The callers in this file pass
            ``mesh.comm``; reading ``.comm`` from that unconditionally would
            fail, so both forms are accepted.
        path: Output path handed to the writer.
        functions: Functions handed to the writer.
        engine: Engine name handed to the writer.
        write: When false, nothing is written and a no-op object is returned.

    Returns:
        An object with ``write(...)`` and ``close()`` methods.
    """
    if not write:
        return _NullVTXWriter()
    # Fall back to the argument itself when it has no ``comm`` attribute.
    comm = getattr(mesh, "comm", mesh)
    return dolfinx.io.VTXWriter(comm, path, functions, engine=engine)
def splitting_scheme(N=50, M=1.0, dt=0.01, T=1.0, quad_degree=4, save_vtx=False):
    """Run the operator-splitting scheme and return the final L2 error of ``v``.

    Each time step first advances the ODE states at the quadrature points with
    ``simple_ode_forward_euler`` and then solves a linear problem for ``v_pde``
    whose right-hand side uses the updated ``v_ode``.

    Args:
        N: Number of cells per direction of the unit-square mesh.
        M: Coefficient multiplying ``grad(v)`` in the bilinear form.
        dt: Time-step size; must be positive.
        T: End time; stepping continues until the time reaches ``T``.
        quad_degree: Degree of the quadrature element and integration measure.
        save_vtx: Write ``v_exact`` and ``v_pde`` to ``splitting_scheme.bp``.

    Returns:
        ``sqrt(sum over ranks of integral((v_pde - v_exact)**2))`` after the
        last step.

    Raises:
        ValueError: If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(
        comm, N, N, dolfinx.cpp.mesh.CellType.triangle
    )
    time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(0.0))
    x = ufl.SpatialCoordinate(mesh)
    I_s = ac_func(x, time)

    # ODE states live at the quadrature points.
    el = basix.ufl.quadrature_element(
        scheme="default", degree=quad_degree, cell=mesh.ufl_cell().cellname()
    )
    V_ode = dolfinx.fem.functionspace(mesh, el)
    v_ode = dolfinx.fem.Function(V_ode)
    s = dolfinx.fem.Function(V_ode)
    s.interpolate(
        dolfinx.fem.Expression(
            s_exact_func(x, time), V_ode.element.interpolation_points()
        )
    )
    # v_ode is left at its initial zero value (v_exact_func contains sin(t)).
    states = np.zeros((2, s.x.array.size))
    states[0, :] = v_ode.x.array
    states[1, :] = s.x.array

    V_pde = dolfinx.fem.functionspace(mesh, ("P", 1))
    v_pde = dolfinx.fem.Function(V_pde, name="v_pde")
    C_m = 1.0
    dx = ufl.Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_degree})

    # Variational formulation of the PDE step.
    v = ufl.TrialFunction(V_pde)
    w = ufl.TestFunction(V_pde)
    a = C_m * v * w * dx + dt * ufl.inner(M * ufl.grad(v), ufl.grad(w)) * dx
    L = C_m * v_ode * w * dx + dt * I_s * w * dx
    solver = dolfinx.fem.petsc.LinearProblem(a, L, u=v_pde)
    # The matrix does not change between steps, so assemble it once.
    dolfinx.fem.petsc.assemble_matrix(solver.A, solver.a)
    solver.A.assemble()

    v_expr = dolfinx.fem.Expression(v_pde, V_ode.element.interpolation_points())
    v_exact_expr = dolfinx.fem.Expression(
        v_exact_func(x, time), V_pde.element.interpolation_points()
    )
    v_exact = dolfinx.fem.Function(V_pde, name="v_exact")

    path = "splitting_scheme.bp"
    shutil.rmtree(path, ignore_errors=True)
    # Pass the mesh itself: the wrapper reads ``mesh.comm`` internally.
    vtx = VTXWriter(mesh, path, [v_exact, v_pde], engine="BP5", write=save_vtx)

    # Fix the number of steps up front instead of testing an accumulated
    # floating-point time against T, which can add or drop a step.
    num_steps = max(int(np.ceil(T / dt - 1e-9)), 0)
    try:
        for step in range(num_steps):
            # ODE step on the quadrature-point states.
            states[0, :] = v_ode.x.array
            states[:] = simple_ode_forward_euler(
                states, time.value, dt, parameters=None
            )
            v_ode.x.array[:] = states[0, :]

            # PDE step: re-assemble the right-hand side and solve.
            with solver.b.localForm() as b_loc:
                b_loc.set(0)
            dolfinx.fem.petsc.assemble_vector(solver.b, solver.L)
            solver.b.ghostUpdate(
                addv=PETSc.InsertMode.ADD,
                mode=PETSc.ScatterMode.REVERSE,
            )
            solver.solver.solve(solver.b, v_pde.x.petsc_vec)
            v_pde.x.scatter_forward()

            # Transfer the PDE solution back to the quadrature points.
            v_ode.interpolate(v_expr)

            current_time = (step + 1) * dt
            time.value = current_time
            v_exact.interpolate(v_exact_expr)
            vtx.write(current_time)
    finally:
        vtx.close()

    error = dolfinx.fem.form((v_pde - v_exact) ** 2 * dx)
    return np.sqrt(comm.allreduce(dolfinx.fem.assemble_scalar(error), MPI.SUM))
def external_operator(N=50, M=1.0, dt=0.01, T=1.0, quad_degree=4, save_vtx=False):
    """Run the scheme with the ODE step wrapped in a ``FEMExternalOperator``.

    Each time step evaluates the external operator (an ``ODE`` instance acting
    on ``v_pde`` and the previous ODE states), solves the linear problem for
    ``v_pde`` and then stores the operator result as the new previous states.

    Args:
        N: Number of cells per direction of the unit-square mesh.
        M: Coefficient multiplying ``grad(v)`` in the bilinear form.
        dt: Time-step size; must be positive.
        T: End time; stepping continues until the time reaches ``T``.
        quad_degree: Degree of the quadrature element and integration measure.
        save_vtx: Write ``v_exact`` and ``v_pde`` to ``external_operator.bp``.

    Returns:
        ``sqrt(sum over ranks of integral((v_pde - v_exact)**2))`` after the
        last step.

    Raises:
        ValueError: If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(
        comm, N, N, dolfinx.cpp.mesh.CellType.triangle
    )
    time = dolfinx.fem.Constant(mesh, dolfinx.default_scalar_type(0.0))
    x = ufl.SpatialCoordinate(mesh)
    I_s = ac_func(x, time)

    # Two-component quadrature space holding (v, s) at the quadrature points.
    el = basix.ufl.quadrature_element(
        scheme="default",
        degree=quad_degree,
        cell=mesh.ufl_cell().cellname(),
        value_shape=(2,),
    )
    V_ode = dolfinx.fem.functionspace(mesh, el)
    ode_states_old = dolfinx.fem.Function(V_ode)
    ode_states_old.interpolate(
        dolfinx.fem.Expression(
            ode_exact_func(x, time), V_ode.element.interpolation_points()
        )
    )

    def f_external(derivatives: tuple[int, ...]):
        """Return the callable evaluating the operator for ``derivatives``."""
        if derivatives == (0, 0):  # no derivation, the function itself
            return ODE(time, dt, parameters=None)
        # Derivatives with respect to either operand are not provided. Raise
        # rather than return the exception class, which a caller would
        # otherwise invoke as if it were the evaluation function.
        raise NotImplementedError(
            f"No external function implemented for derivatives {derivatives}"
        )

    V_pde = dolfinx.fem.functionspace(mesh, ("P", 1))
    v_pde = dolfinx.fem.Function(V_pde, name="v_pde")
    ode_states = FEMExternalOperator(
        v_pde, ode_states_old, function_space=V_ode, external_function=f_external
    )
    v_ode, s_ode = ufl.split(ode_states)

    C_m = 1.0
    dx = ufl.Measure("dx", domain=mesh, metadata={"quadrature_degree": quad_degree})

    # Variational formulation of the PDE step.
    v = ufl.TrialFunction(V_pde)
    w = ufl.TestFunction(V_pde)
    a = C_m * v * w * dx + dt * ufl.inner(M * ufl.grad(v), ufl.grad(w)) * dx
    L = C_m * v_ode * w * dx + dt * (I_s) * w * dx
    # Swap the external operator for a coefficient so the form can be compiled.
    L_updated, operators = replace_external_operators(L)
    L_compiled = dolfinx.fem.form(L_updated)
    solver = dolfinx.fem.petsc.LinearProblem(a, L_compiled, u=v_pde)
    # The matrix does not change between steps, so assemble it once.
    dolfinx.fem.petsc.assemble_matrix(solver.A, solver.a)
    solver.A.assemble()

    v_exact_expr = dolfinx.fem.Expression(
        v_exact_func(x, time), V_pde.element.interpolation_points()
    )
    v_exact = dolfinx.fem.Function(V_pde, name="v_exact")

    path = "external_operator.bp"
    shutil.rmtree(path, ignore_errors=True)
    # Pass the mesh itself: the wrapper reads ``mesh.comm`` internally.
    vtx = VTXWriter(mesh, path, [v_exact, v_pde], engine="BP5", write=save_vtx)

    # Fix the number of steps up front instead of testing an accumulated
    # floating-point time against T, which can add or drop a step.
    num_steps = max(int(np.ceil(T / dt - 1e-9)), 0)
    try:
        for step in range(num_steps):
            # ODE step: evaluate the operands, then the external operator.
            coefficients = evaluate_operands(operators)
            evaluate_external_operators(operators, coefficients)

            # PDE step: re-assemble the right-hand side and solve.
            with solver.b.localForm() as b_loc:
                b_loc.set(0)
            dolfinx.fem.petsc.assemble_vector(solver.b, solver.L)
            solver.b.ghostUpdate(
                addv=PETSc.InsertMode.ADD,
                mode=PETSc.ScatterMode.REVERSE,
            )
            solver.solver.solve(solver.b, v_pde.x.petsc_vec)
            v_pde.x.scatter_forward()

            # Keep the evaluated operator values as the states for the next step.
            ode_states_old.interpolate(ode_states.ref_coefficient)

            current_time = (step + 1) * dt
            time.value = current_time
            v_exact.interpolate(v_exact_expr)
            vtx.write(current_time)
    finally:
        vtx.close()

    error = dolfinx.fem.form((v_pde - v_exact) ** 2 * dx)
    return np.sqrt(comm.allreduce(dolfinx.fem.assemble_scalar(error), MPI.SUM))
def _convergence_study(title, scheme, name, values, **fixed):
    """Run ``scheme`` for each value of ``name`` and print errors and rates.

    The rate is ``log2`` of the ratio of consecutive errors, so it is the
    convergence order only when consecutive values differ by a factor of two.
    Returns the list of errors.
    """
    print(f"\n{title}")
    errors = []
    for value in values:
        err = scheme(**{name: value}, **fixed)
        print(f"{name}={value}, error={err}")
        errors.append(err)
        if len(errors) > 1:
            order = np.log2(errors[-2] / errors[-1])
            print(f"Order of convergence: {order}")
    return errors


def main():
    """Print spatial and temporal convergence results for both schemes."""
    mesh_sizes = [4, 8, 16, 32, 64]
    time_steps = [0.1 / 2**i for i in range(4)]

    # Spatial studies use a small fixed time step.
    _convergence_study(
        "Splitting scheme (spatial)", splitting_scheme, "N", mesh_sizes,
        T=1.0, dt=0.0005,
    )
    _convergence_study(
        "External operator (spatial)", external_operator, "N", mesh_sizes,
        T=1.0, dt=0.0005,
    )
    # Temporal studies report the time step itself; the original printed the
    # last mesh size under the "dt=" label.
    _convergence_study(
        "Splitting scheme (temporal)", splitting_scheme, "dt", time_steps
    )
    _convergence_study(
        "External operator (temporal)", external_operator, "dt", time_steps
    )
# Run the convergence studies only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment