Skip to content

Commit

Permalink
Remove duplicate code and slightly tweak assembly logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgensd committed May 15, 2024
1 parent d4dd39a commit 954eb8e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 100 deletions.
92 changes: 5 additions & 87 deletions smart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,93 +1314,10 @@ def get_block_system(self, Fsum, u):
I0.ufl_operands[0] == Ib0.ufl_operands[0](1) -> True
"""

# blocks/partitions are by compartment, not species
Fblock = d.extract_blocks(Fsum)

# =====================================================================
# doflin.fem.problem.MixedNonlinearVariationalProblem()
# =====================================================================
# basically is a wrapper around the cpp class that finalizes preparing
# F and J into the right format
# TODO: add dirichlet BCs

# Add in placeholders for empty blocks of F
if len(Fblock) != len(u):
Ftemp = [None for i in range(len(u))]
for Fi in Fblock:
Ftemp[Fi.arguments()[0].part()] = Fi
Fblock = Ftemp

# debug attempt
J = []
for Fi in Fblock:
for uj in u:
if Fi is None:
# pass
J.append(None)
else:
dFdu = expand_derivatives(d.derivative(Fi, uj))
J.append(dFdu)

# Check number of blocks in the residual and solution are coherent
assert len(J) == len(u) * len(u)
assert len(Fblock) == len(u)

# Decompose F blocks into subforms based on domain of integration
# Fblock = [F0, F1, ... , Fn] where the index is the compartment index
# Flist = [[F0(Omega_0), F0(Omega_1)], ..., [Fn(Omega_n)]]
# If a form has integrals on multiple domains, they are split into a list
Flist = list()
for idx, Fi in enumerate(Fblock):
if Fi is None or Fi.empty():
logger.warning(
f"F{idx} = F[{self.cc.get_index(idx).name}]) is empty",
extra=dict(format_type="warning"),
)
Flist.append([d.cpp.fem.Form(1, 0)])
else:
Fs = []
for Fsub in sub_forms_by_domain(Fi):
if Fsub is None or Fsub.empty():
domain = self.get_mesh_by_id(Fsub.mesh().id()).name
logger.warning(
f"F{idx} = F[{self.cc.get_index(idx).name}] "
"is empty on integration domain {domain}",
extra=dict(format_type="logred"),
)
Fs.append(d.cpp.fem.Form(1, 0))
else:
Fs.append(d.Form(Fsub))
Flist.append(Fs)

# Decompose J blocks into subforms based on domain of integration
Jlist = list()
for idx, Ji in enumerate(J):
idx_i, idx_j = divmod(idx, len(u))
if Ji is None or Ji.empty():
logger.warning(
f"J{idx_i}{idx_j} = dF[{self.cc.get_index(idx_i).name}])"
f"/du[{self.cc.get_index(idx_j).name}] is empty",
extra=dict(format_type="logred"),
)
Jlist.append([d.cpp.fem.Form(2, 0)])
else:
Js = []
for Jsub in sub_forms_by_domain(Ji):
if Jsub is None or Jsub.empty():
domain = self.get_mesh_by_id(Jsub.mesh().id()).name
logger.warning(
f"J{idx_i}{idx_j} = dF[{self.cc.get_index(idx_i).name}])"
f"/du[{self.cc.get_index(idx_j).name}]"
f"is empty on integration domain {domain}",
extra=dict(format_type="logred"),
)
Js.append(d.Form(Jsub))
Jlist.append(Js)

Flist = get_block_F(self, Fsum, u)
Jlist = get_block_J(self, Fsum, u)
global_sizes = [uj.function_space().dim() for uj in u]

# return Flist, Jlist
return Flist, Jlist, global_sizes

def get_global_sizes(self, u):
Expand All @@ -1426,14 +1343,15 @@ def get_block_F(self, Fsum, u):
# Fblock = [F0, F1, ... , Fn] where the index is the compartment index
# Flist = [[F0(Omega_0), F0(Omega_1)], ..., [Fn(Omega_n)]]
# If a form has integrals on multiple domains, they are split into a list
# If Fi(Omega_j) is not defined, None is inserted
Flist = list()
for idx, Fi in enumerate(Fblock):
if Fi is None or Fi.empty():
logger.warning(
f"F{idx} = F[{self.cc.get_index(idx).name}]) is empty",
extra=dict(format_type="warning"),
)
Flist.append([d.cpp.fem.Form(1, 0)])
Flist.append([None])
else:
Fs = []
for Fsub in sub_forms_by_domain(Fi):
Expand All @@ -1444,7 +1362,7 @@ def get_block_F(self, Fsum, u):
f"on integration domain {domain}",
extra=dict(format_type="logred"),
)
Fs.append(d.cpp.fem.Form(1, 0))
Fs.append(None)
else:
Fs.append(d.Form(Fsub))
Flist.append(Fs)
Expand Down
30 changes: 17 additions & 13 deletions smart/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,22 +209,22 @@ def init_petsc_vecnest(self):
Fpetsc = []
for j in range(dim):
Fsum = None
for k in range(len(self.Fforms[j])):
if self.Fforms[j][k].function_space(0) is None:
for k, form in enumerate(self.Fforms[j]):
if form is None:
logger.warning(
f"{self.Fjk_name(j,k)}] has no function space",
f"{self.Fjk_name(j,k)}] is not defined",
extra=dict(format_type="log"),
)
continue

# Note: This could be simplified once add_values in assemble mixed is fixed
tensor = d.PETScVector(self.comm)

if Fsum is None:
Fsum = d.assemble_mixed(self.Fforms[j][k], tensor=tensor)
Fsum = d.assemble_mixed(form, tensor=tensor)
else:
# Fsum.axpy(1, d.assemble_mixed(self.Fforms[j][k], tensor=tensor).vec(),
# structure=Fsum.Structure.DIFFERENT_NONZERO_PATTERN)
Fsum += d.assemble_mixed(self.Fforms[j][k], tensor=tensor)
Fsum += d.assemble_mixed(form, tensor=tensor)

if Fsum is None:
logger.debug(
Expand Down Expand Up @@ -348,14 +348,18 @@ def assemble_Fnest(self, Fnest):

for j in range(dim):
Fvecs.append([])
for k in range(len(self.Fforms[j])):
# , tensor=d.PETScVector(Fvecs[idx]))
Fvecs[j].append(d.as_backend_type(d.assemble_mixed(self.Fforms[j][k])))
# TODO: could probably speed this up by not using axpy if there is only one subform
# sum the vectors
# NOTE: This can be simplified once add_values is fixed
for k, form in enumerate(self.Fforms[j]):
if form is None:
logger.warning(
f"{self.Fjk_name(j,k)}] is not defined",
extra=dict(format_type="log"),
)
continue
Fvecs[j].append(d.as_backend_type(d.assemble_mixed(form)))
Fj_petsc[j].zeroEntries()
for k in range(len(self.Fforms[j])):
Fj_petsc[j].axpy(1, Fvecs[j][k].vec())
for k, vec in enumerate(Fvecs[j]):
Fj_petsc[j].axpy(1, vec.vec())

Fnest.assemble()
self.stopwatches["snes residual assemble"].pause()
Expand Down

0 comments on commit 954eb8e

Please sign in to comment.