Skip to content

Commit

Permalink
Add euler update prior to entering Newton iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
emmetfrancis committed Dec 14, 2023
1 parent 305e502 commit b25560e
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 11 deletions.
10 changes: 9 additions & 1 deletion examples/example2/example2_conservation_test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
")\n",
"from matplotlib import pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"import time\n",
"\n",
"logger = logging.getLogger(\"smart\")\n",
"logger.setLevel(logging.INFO)"
Expand Down Expand Up @@ -116,6 +117,7 @@
"# =============================================================================================\n",
"# name, initial concentration, concentration units, diffusion, diffusion units, compartment\n",
"A = Species(\"A\", 1.0, surf_unit, 0.01, D_unit, \"Cyto\")\n",
"# Aedge = Species(\"Aedge\", 1.0, edge_unit, 0.0, D_unit, \"PM\")\n",
"X = Species(\"X\", 1.0, edge_unit, 1.0, D_unit, \"PM\")\n",
"B = Species(\"B\", 0.0, edge_unit, 1.0, D_unit, \"PM\")\n",
"sc = SpeciesContainer()\n",
Expand Down Expand Up @@ -210,6 +212,7 @@
"source": [
"enforce_mass_conservation = [True, False]\n",
"for i in range(len(enforce_mass_conservation)):\n",
" start_time = time.time()\n",
" config_cur = config.Config()\n",
" config_cur.flags.update({\n",
" \"allow_unused_components\": True, \n",
Expand Down Expand Up @@ -239,14 +242,18 @@
" dx_cyto = d.Measure(\"dx\", domain=cytoMesh)\n",
" # ds_cyto = d.Measure(\"ds\", domain=cytoMesh)\n",
" volume = d.assemble(1.0*dx_cyto)\n",
" dx_pm = d.Measure(\"dx\", domain=model_cur.cc[\"PM\"].dolfin_mesh)\n",
" pmMesh = model_cur.cc[\"PM\"].dolfin_mesh\n",
" dx_pm = d.Measure(\"dx\", domain=pmMesh)\n",
" sa = d.assemble(1.0*dx_pm)\n",
"\n",
" Atot_vec = [sc[\"A\"].initial_condition * volume + sc[\"B\"].initial_condition * sa]\n",
" # Xtot_vec = [sc[\"X\"].initial_condition * sa + sc[\"B\"].initial_condition * sa]\n",
" # Set loglevel to warning in order not to pollute notebook output\n",
" logger.setLevel(logging.WARNING)\n",
"\n",
" store_map_pm = pmMesh.topology().mapping()[parent_mesh.dolfin_mesh.id()].vertex_map()\n",
" store_map_cyto = cytoMesh.topology().mapping()[parent_mesh.dolfin_mesh.id()].vertex_map()\n",
"\n",
" while True:\n",
" # Solve the system\n",
" model_cur.monolithic_solve()\n",
Expand All @@ -265,6 +272,7 @@
" break\n",
" plt.plot(tvec, 100*np.array(Atot_vec)/Atot_vec[0], \n",
" label=f\"Mass conservation = {enforce_mass_conservation[i]}\")\n",
" print(f\"Processing time = {time.time()-start_time} for mass conservation = {enforce_mass_conservation[i]}\")\n",
"plt.xlabel('Time (s)')\n",
"plt.ylabel('Normalized molecule count (%)')\n",
"plt.legend()"
Expand Down
2 changes: 1 addition & 1 deletion smart/mesh_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def compute_curvature(

# map kappab to mesh function
if half_mesh_data == "":
store_map_b = bmesh.topology().mapping()[0].vertex_map()
store_map_b = bmesh.topology().mapping()[ref_mesh.id()].vertex_map()
for j in range(len(store_map_b)):
cur_sub_idx = d.vertex_to_dof_map(Vb)[j]
kappa_mf.set_value(store_map_b[j], kappab.vector().get_local()[cur_sub_idx])
Expand Down
90 changes: 85 additions & 5 deletions smart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,11 +792,13 @@ def _init_4_2_define_dolfin_function_spaces(self):
self.child_meshes[compartment.name].dolfin_mesh, "P", 1
)

if self.parent_mesh.curvature != "":
scalarFunctionSpace = d.FunctionSpace(
self.child_meshes[compartment.name].dolfin_mesh, "P", 1
)
compartment.curv_func = self.mf0_to_fun(self.parent_mesh.curvature, scalarFunctionSpace)
if self.parent_mesh.curvature != "":
scalarFunctionSpace = d.FunctionSpace(
self.child_meshes[compartment.name].dolfin_mesh, "P", 1
)
compartment.curv_func = self.mf0_to_fun(
self.parent_mesh.curvature, scalarFunctionSpace
)

self.V = [compartment.V for compartment in self._active_compartments]
# Make the MixedFunctionSpace
Expand Down Expand Up @@ -1594,6 +1596,8 @@ def monolithic_solve(self):
are true.
"""
if self.config.flags["enforce_mass_conservation"]:
self.euler_update()
self.idx += 1
# start a timer for the total time step
self.stopwatches["Total time step"].start()
Expand Down Expand Up @@ -2138,3 +2142,79 @@ def mf0_to_fun(self, mf0, V):
dfunc.vector().set_local(values)
dfunc.vector().apply("insert")
return dfunc

def euler_update(self):
"""
Update estimate for volume variables projected onto the boundary
at the next time step (should improve convergence of mass conservation
adjusted simulations)
"""
ref_mesh = self.parent_mesh.dolfin_mesh
valsList = []
namesList = []
for f in self.fc:
# if f.topology in ["surface_to_volume", "volume_to_surface"]:
cur_sp = f.destination_species.name
cur_change = f.equation_lambda_eval("value")
if f.topology in ["surface", "volume"]:
continue
# cur_change = d.project(cur_change, self.sc[cur_sp].V)
# vals_vec = d.project(self.sc[cur_sp].u["u"], self.sc[cur_sp].V)
# vals_new = (vals_vec.vector().get_local() +
# float(self.dt)*cur_change.vector().get_local())
# vals = self.sc[cur_sp].u["u"].vector().get_local()
# vals[self.sc[cur_sp].dof_map] = vals_new
else:
surf_dim = f.surface.dimensionality
if f.destination_species.dimensionality != surf_dim: # only the 'volume' species
surf_sp = list(f.surface.species.keys())[0]
cur_change = d.project(cur_change, self.sc[surf_sp].V)
vals_vec = d.project(self.sc[cur_sp].u["u"], self.sc[cur_sp].V)
surf_map = (
f.surface.dolfin_mesh.topology().mapping()[ref_mesh.id()].vertex_map()
)
domain_map = (
f.destination_compartment.dolfin_mesh.topology()
.mapping()[ref_mesh.id()]
.vertex_map()
)
vals_new = vals_vec.vector().get_local()
for j in range(len(surf_map)):
j_domain = np.nonzero(np.equal(domain_map, surf_map[j]))
j_domain = j_domain[0][0]
cur_domain_idx = d.vertex_to_dof_map(self.sc[cur_sp].V)[j_domain]
cur_surf_idx = d.vertex_to_dof_map(self.sc[surf_sp].V)[j]
vals_new[cur_domain_idx] += (
float(self.dt) * cur_change.vector().get_local()[cur_surf_idx]
)
vals = self.sc[cur_sp].u["u"].vector().get_local()
vals[self.sc[cur_sp].dof_map] = vals_new
else:
continue
if cur_sp in namesList:
for j in range(len(namesList)):
if cur_sp == namesList[j]:
valsList[j] = (
valsList[j] + vals - self.sc[cur_sp].u["u"].vector().get_local()
)
else:
valsList.append(vals)
namesList.append(cur_sp)

# each compartment has an associated vector of unknowns
for c in self.cc:
vals_cur = c.u["u"].vector().get_local()
for i in range(len(valsList)):
if namesList[i] in c.species:
cur_dofmap = self.sc[namesList[i]].dof_map
vals_cur[cur_dofmap] = valsList[i][cur_dofmap]
c.u["u"].vector().set_local(vals_cur)
c.u["u"].vector().apply("insert")

if len(self.problem.global_sizes) == 1:
self._ubackend = self.u["u"]._functions[0].vector().vec().copy()
else:
self._ubackend = PETSc.Vec().createNest(
[usub.vector().vec().copy() for usub in self.u["u"]._functions],
comm=self.mpi_comm_world,
)
6 changes: 2 additions & 4 deletions smart/model_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,8 @@ def _post_init_get_flux_topology(self):
if not s.compartment.mesh.is_surface:
surf_space = d.FunctionSpace(self.surface.dolfin_mesh, "CG", 1)
# u or usplit?
self.proj_var.update({sname: d.interpolate(s.u["u"], surf_space)})
cur_interp = d.interpolate(s.u["u"], surf_space)
self.proj_var.update({sname: cur_interp})

def _post_init_get_flux_units(self):
"""
Expand Down Expand Up @@ -1624,9 +1625,6 @@ def form(self):
"""-1 factor because terms are defined as if they were on the
lhs of the equation :math:`F(u;v)=0`"""
x = d.SpatialCoordinate(self.destination_compartment.dolfin_mesh)
# if self.name == "r1 [B (f)]":
# self.equation_variables["A"] = d.interpolate(
# self.species["A"].u["u"], self.surface.V)*self.species["A"].concentration_units
if self.axisymm:
return (
d.Constant(-1)
Expand Down

0 comments on commit b25560e

Please sign in to comment.