diff --git a/smart/model.py b/smart/model.py index 809abf5..cfd01a2 100644 --- a/smart/model.py +++ b/smart/model.py @@ -49,7 +49,7 @@ ) from .solvers import smartSNESProblem -from .units import unit +from . import unit logger = logging.getLogger(__name__) @@ -632,7 +632,8 @@ def _init_2_4_check_for_unused_parameters_species_compartments(self): raise ValueError(print_str) def _init_2_5_link_compartments_to_species(self): - """Linking compartments and compartment dimensionality to species""" + """Linking compartments and compartment dimensionality to species, + check for consistency of diffusion coefficient definition""" logger.debug( "Linking compartments and compartment dimensionality to species", extra=dict(format_type="log"), @@ -640,6 +641,12 @@ def _init_2_5_link_compartments_to_species(self): for species in self.sc: species.compartment = self.cc[species.compartment_name] species.dimensionality = self.cc[species.compartment_name].dimensionality + # convert diffusion coeff to consistent with mesh + diffusion_conversion = species.diffusion_.to( + species.compartment.compartment_**2 / unit.s + ) + species.diffusion_ = species.compartment.compartment_**2 / unit.s + species.D *= diffusion_conversion.magnitude def _init_2_6_link_species_to_compartments(self): """Links species to compartments - a species is considered to be @@ -1077,7 +1084,7 @@ def _init_5_2_create_variational_forms(self): for flux in self.fc: # -1 factor in flux.form means this is a lhs term form_type = "boundary_reaction" if flux.is_boundary_condition else "domain_reaction" - flux_form_units = flux.equation_units * flux.measure_units + flux_form_ = flux.equation_ * flux.measure_ # Determine if flux is linear w.r.t. compartment functions # Use flux.is_linear_wrt_comp and combine with linear_wrt_comp # (prioritizing former). If compartment is not relevant to flux then it is linear @@ -1090,7 +1097,7 @@ def _init_5_2_create_variational_forms(self): flux.form, flux.destination_species, form_type, - flux_form_units, + flux_form_, True, linearity_dict, ) @@ -1105,15 +1112,15 @@ def _init_5_2_create_variational_forms(self): v = species.v D = species.D dx = species.compartment.mesh.dx - Dform_units = ( - species.diffusion_units - * species.concentration_units - * species.compartment.compartment_units ** (species.compartment.dimensionality - 2) + Dform_ = ( + species.diffusion_ + * species.concentration_ + * species.compartment.compartment_ ** (species.compartment.dimensionality - 2) ) - mass_form_units = ( - species.concentration_units + mass_form_ = ( + species.concentration_ / unit.s - * species.compartment.compartment_units**species.compartment.dimensionality + * species.compartment.compartment_**species.compartment.dimensionality ) # diffusion term if species.D == 0: @@ -1123,11 +1130,6 @@ def _init_5_2_create_variational_forms(self): extra=dict(format_type="log"), ) else: - if Dform_units != mass_form_units: # unit conversion for consistency - diffusion_conversion = species.diffusion_units.to( - species.compartment.compartment_units**2 / unit.s - ) - D *= diffusion_conversion.magnitude D_constant = d.Constant(D, name=f"D_{species.name}") if self.config.flags["axisymmetric_model"]: Dform = x[0] * D_constant * d.inner(d.grad(u), d.grad(v)) * dx @@ -1141,7 +1143,7 @@ def _init_5_2_create_variational_forms(self): Dform, species, "diffusion", - Dform_units, + Dform_, True, linear_wrt_comp, ) @@ -1158,7 +1160,7 @@ def _init_5_2_create_variational_forms(self): Muform, species, "mass_u", - mass_form_units, + mass_form_, True, linear_wrt_comp, ) @@ -1173,7 +1175,7 @@ def _init_5_2_create_variational_forms(self): Munform, species, "mass_un", - mass_form_units, + mass_form_, True, linear_wrt_comp, )