diff --git a/gusto/__init__.py b/gusto/__init__.py index d7dc21faa..37f870157 100644 --- a/gusto/__init__.py +++ b/gusto/__init__.py @@ -11,7 +11,6 @@ from gusto.diagnostics import * # noqa from gusto.diffusion_methods import * # noqa from gusto.equations import * # noqa -from gusto.fml import * # noqa from gusto.forcing import * # noqa from gusto.initialisation_tools import * # noqa from gusto.io import * # noqa diff --git a/gusto/equations.py b/gusto/equations.py index 1c50bd4e6..c49d28ccf 100644 --- a/gusto/equations.py +++ b/gusto/equations.py @@ -1,22 +1,27 @@ """Objects describing geophysical fluid equations to be solved in weak form.""" from abc import ABCMeta -from firedrake import (TestFunction, Function, sin, pi, inner, dx, div, cross, - FunctionSpace, MixedFunctionSpace, TestFunctions, - TrialFunction, FacetNormal, jump, avg, dS_v, dS, - DirichletBC, conditional, SpatialCoordinate, - split, Constant, action) +from firedrake import ( + TestFunction, Function, sin, pi, inner, dx, div, cross, + FunctionSpace, MixedFunctionSpace, TestFunctions, TrialFunction, + FacetNormal, jump, avg, dS_v, dS, DirichletBC, conditional, + SpatialCoordinate, split, Constant, action +) +from firedrake.fml import ( + Term, all_terms, keep, drop, Label, subject, name_label, + replace_subject, replace_trial_function +) from gusto.fields import PrescribedFields -from gusto.fml import (Term, all_terms, keep, drop, Label, subject, name, - replace_subject, replace_trial_function) -from gusto.labels import (time_derivative, transport, prognostic, hydrostatic, - linearisation, pressure_gradient, coriolis) +from gusto.labels import ( + time_derivative, transport, prognostic, hydrostatic, linearisation, + pressure_gradient, coriolis +) from gusto.thermodynamics import exner_pressure -from gusto.common_forms import (advection_form, continuity_form, - vector_invariant_form, kinetic_energy_form, - advection_equation_circulation_form, - diffusion_form, linear_continuity_form, - linear_advection_form) +from gusto.common_forms import ( + advection_form, continuity_form, vector_invariant_form, + kinetic_energy_form, advection_equation_circulation_form, + diffusion_form, linear_continuity_form, linear_advection_form +) from gusto.active_tracers import ActiveTracer, Phases, TracerVariableType from gusto.configuration import TransportEquationType import ufl @@ -972,7 +977,7 @@ def __init__(self, domain, parameters, Omega=None, sponge=None, raise NotImplementedError('Only mixing ratio tracers are implemented') theta_v = theta / (Constant(1.0) + tracer_mr_total) - pressure_gradient_form = name(subject(prognostic( + pressure_gradient_form = name_label(subject(prognostic( cp*(-div(theta_v*w)*exner*dx + jump(theta_v*w, n)*avg(exner)*dS_v), 'u'), self.X), "pressure_gradient") @@ -1039,7 +1044,7 @@ def __init__(self, domain, parameters, Omega=None, sponge=None, mubar*sin((pi/2.)*(z-zc)/(H-zc))**2) self.mu = self.prescribed_fields("sponge", W_DG).interpolate(muexpr) - residual += name(subject(prognostic( + residual += name_label(subject(prognostic( self.mu*inner(w, domain.k)*inner(u, domain.k)*dx, 'u'), self.X), "sponge") if diffusion_options is not None: @@ -1144,7 +1149,7 @@ def __init__(self, domain, parameters, Omega=None, sponge=None, k = self.domain.k u = split(self.X)[0] - self.residual += name( + self.residual += name_label( subject( prognostic( -inner(k, self.tests[0]) * inner(k, u) * dx, "u"), @@ -1309,8 +1314,10 @@ def __init__(self, domain, parameters, Omega=None, # The p features here so that the div(u) evaluated in the "forcing" step # replaces the whole pressure field, rather than merely providing an # increment to it. - divergence_form = name(subject(prognostic(phi*(p-div(u))*dx, 'p'), self.X), - "incompressibility") + divergence_form = name_label( + subject(prognostic(phi*(p-div(u))*dx, 'p'), self.X), + "incompressibility" + ) residual = (mass_form + adv_form + divergence_form + pressure_gradient_form + gravity_form) diff --git a/gusto/fml/__init__.py b/gusto/fml/__init__.py deleted file mode 100644 index 5a424f231..000000000 --- a/gusto/fml/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from gusto.fml.form_manipulation_language import * # noqa -from gusto.fml.replacement import * # noqa diff --git a/gusto/fml/form_manipulation_language.py b/gusto/fml/form_manipulation_language.py deleted file mode 100644 index 062fddb04..000000000 --- a/gusto/fml/form_manipulation_language.py +++ /dev/null @@ -1,430 +0,0 @@ -"""A language for manipulating forms using labels.""" - -import ufl -import functools -import operator -from firedrake import Constant, Function - - -__all__ = ["Label", "Term", "LabelledForm", "identity", "drop", "all_terms", - "keep", "subject", "name"] - -# ---------------------------------------------------------------------------- # -# Core routines for filtering terms -# ---------------------------------------------------------------------------- # -identity = lambda t: t -drop = lambda t: None -all_terms = lambda t: True -keep = identity - - -# ---------------------------------------------------------------------------- # -# Term class -# ---------------------------------------------------------------------------- # -class Term(object): - """A Term object contains a form and its labels.""" - - __slots__ = ["form", "labels"] - - def __init__(self, form, label_dict=None): - """ - Args: - form (:class:`ufl.Form`): the form for this terms. - label_dict (dict, optional): dictionary of key-value pairs - corresponding to current form labels. Defaults to None. - """ - self.form = form - self.labels = label_dict or {} - - def get(self, label): - """ - Returns the value of a label. - - Args: - label (:class:`Label`): the label to return the value of. - - Returns: - The value of a label. - """ - return self.labels.get(label.label) - - def has_label(self, *labels, return_tuple=False): - """ - Whether the term has the specified labels attached to it. - - Args: - *labels (:class:`Label`): a label or series of labels. A tuple is - automatically returned if multiple labels are provided as - arguments. - return_tuple (bool, optional): if True, forces a tuple to be - returned even if only one label is provided as an argument. - Defaults to False. - - Returns: - bool or tuple: Booleans corresponding to whether the term has the - specified labels. - """ - if len(labels) == 1 and not return_tuple: - return labels[0].label in self.labels - else: - return tuple(self.has_label(l) for l in labels) - - def __add__(self, other): - """ - Adds a term or labelled form to this term. - - Args: - other (:class:`Term` or :class:`LabelledForm`): the term or labelled - form to add to this term. - - Returns: - :class:`LabelledForm`: a labelled form containing the terms. - """ - if self is NullTerm: - return other - if other is None or other is NullTerm: - return self - elif isinstance(other, Term): - return LabelledForm(self, other) - elif isinstance(other, LabelledForm): - return LabelledForm(self, *other.terms) - else: - return NotImplemented - - __radd__ = __add__ - - def __sub__(self, other): - """ - Subtracts a term or labelled form from this term. - - Args: - other (:class:`Term` or :class:`LabelledForm`): the term or labelled - form to subtract from this term. - - Returns: - :class:`LabelledForm`: a labelled form containing the terms. - """ - other = other * Constant(-1.0) - return self + other - - def __mul__(self, other): - """ - Multiplies this term by another quantity. - - Args: - other (float, :class:`Constant` or :class:`ufl.algebra.Product`): - the quantity to multiply this term by. If it is a float or int - then it is converted to a :class:`Constant` before the - multiplication. - - Returns: - :class:`Term`: the product of the term with the quantity. - """ - return Term(other*self.form, self.labels) - - __rmul__ = __mul__ - - def __truediv__(self, other): - """ - Divides this term by another quantity. - - Args: - other (float, :class:`Constant` or :class:`ufl.algebra.Product`): - the quantity to divide this term by. If it is a float or int - then it is converted to a :class:`Constant` before the - division. - - Returns: - :class:`Term`: the quotient of the term divided by the quantity. - """ - return self * (Constant(1.0) / other) - - -# This is necessary to be the initialiser for functools.reduce -NullTerm = Term(None) - - -# ---------------------------------------------------------------------------- # -# Labelled form class -# ---------------------------------------------------------------------------- # -class LabelledForm(object): - """ - A form, broken down into terms that pair individual forms with labels. - - The `LabelledForm` object holds a list of terms, which pair :class:`Form` - objects with :class:`Label`s. The `label_map` routine allows the terms to be - manipulated or selected based on particular filters. - """ - __slots__ = ["terms"] - - def __init__(self, *terms): - """ - Args: - *terms (:class:`Term`): terms to combine to make the `LabelledForm`. - - Raises: - TypeError: _description_ - """ - if len(terms) == 1 and isinstance(terms[0], LabelledForm): - self.terms = terms[0].terms - else: - if any([type(term) is not Term for term in list(terms)]): - raise TypeError('Can only pass terms or a LabelledForm to LabelledForm') - self.terms = list(terms) - - def __add__(self, other): - """ - Adds a form, term or labelled form to this labelled form. - - Args: - other (:class:`ufl.Form`, :class:`Term` or :class:`LabelledForm`): - the form, term or labelled form to add to this labelled form. - - Returns: - :class:`LabelledForm`: a labelled form containing the terms. - """ - if isinstance(other, ufl.Form): - return LabelledForm(*self, Term(other)) - elif type(other) is Term: - return LabelledForm(*self, other) - elif type(other) is LabelledForm: - return LabelledForm(*self, *other) - elif other is None: - return self - else: - return NotImplemented - - __radd__ = __add__ - - def __sub__(self, other): - """ - Subtracts a form, term or labelled form from this labelled form. - - Args: - other (:class:`ufl.Form`, :class:`Term` or :class:`LabelledForm`): - the form, term or labelled form to subtract from this labelled - form. - - Returns: - :class:`LabelledForm`: a labelled form containing the terms. - """ - if type(other) is Term: - return LabelledForm(*self, Constant(-1.)*other) - elif type(other) is LabelledForm: - return LabelledForm(*self, *[Constant(-1.)*t for t in other]) - elif other is None: - return self - else: - # Make new Term for other and subtract it - return LabelledForm(*self, Term(Constant(-1.)*other)) - - def __mul__(self, other): - """ - Multiplies this labelled form by another quantity. - - Args: - other (float, :class:`Constant` or :class:`ufl.algebra.Product`): - the quantity to multiply this labelled form by. All terms in - the form are multiplied. - - Returns: - :class:`LabelledForm`: the product of all terms with the quantity. - """ - return self.label_map(all_terms, lambda t: Term(other*t.form, t.labels)) - - def __truediv__(self, other): - """ - Divides this labelled form by another quantity. - - Args: - other (float, :class:`Constant` or :class:`ufl.algebra.Product`): - the quantity to divide this labelled form by. All terms in the - form are divided. - - Returns: - :class:`LabelledForm`: the quotient of all terms with the quantity. - """ - return self * (Constant(1.0) / other) - - __rmul__ = __mul__ - - def __iter__(self): - """Returns an iterable of the terms in the labelled form.""" - return iter(self.terms) - - def __len__(self): - """Returns the number of terms in the labelled form.""" - return len(self.terms) - - def label_map(self, term_filter, map_if_true=identity, - map_if_false=identity): - """ - Maps selected terms in the labelled form, returning a new labelled form. - - Args: - term_filter (func): a function to filter the labelled form's terms. - map_if_true (func, optional): how to map the terms for which the - term_filter returns True. Defaults to identity. - map_if_false (func, optional): how to map the terms for which the - term_filter returns False. Defaults to identity. - - Returns: - :class:`LabelledForm`: a new labelled form with the terms mapped. - """ - - new_labelled_form = LabelledForm( - functools.reduce(operator.add, - filter(lambda t: t is not None, - (map_if_true(t) if term_filter(t) else - map_if_false(t) for t in self.terms)), - # Need to set an initialiser, otherwise the label_map - # won't work if the term_filter is False for everything - # None does not work, as then we add Terms to None - # and the addition operation is defined from None - # rather than the Term. NullTerm solves this. - NullTerm)) - - # Drop the NullTerm - new_labelled_form.terms = list(filter(lambda t: t is not NullTerm, - new_labelled_form.terms)) - - return new_labelled_form - - @property - def form(self): - """ - Provides the whole form from the labelled form. - - Raises: - TypeError: if the labelled form has no terms. - - Returns: - :class:`ufl.Form`: the whole form corresponding to all the terms. - """ - # Throw an error if there is no form - if len(self.terms) == 0: - raise TypeError('The labelled form cannot return a form as it has no terms') - else: - return functools.reduce(operator.add, (t.form for t in self.terms)) - - -class Label(object): - """Object for tagging forms, allowing them to be manipulated.""" - - __slots__ = ["label", "default_value", "value", "validator"] - - def __init__(self, label, *, value=True, validator=None): - """ - Args: - label (str): the name of the label. - value (..., optional): the value for the label to take. Can be any - type (subject to the validator). Defaults to True. - validator (func, optional): function to check the validity of any - value later passed to the label. Defaults to None. - """ - self.label = label - self.default_value = value - self.validator = validator - - def __call__(self, target, value=None): - """ - Applies the label to a form or term. - - Args: - target (:class:`ufl.Form`, :class:`Term` or :class:`LabelledForm`): - the form, term or labelled form to be labelled. - value (..., optional): the value to attach to this label. Defaults - to None. - - Raises: - ValueError: if the `target` is not a :class:`ufl.Form`, - :class:`Term` or :class:`LabelledForm`. - - Returns: - :class:`Term` or :class:`LabelledForm`: a :class:`Term` is returned - if the target is a :class:`Term`, otherwise a - :class:`LabelledForm` is returned. - """ - # if value is provided, check that we have a validator function - # and validate the value, otherwise use default value - if value is not None: - assert self.validator, f'Label {self.label} requires a validator' - assert self.validator(value), f'Value {value} for label {self.label} does not satisfy validator' - self.value = value - else: - self.value = self.default_value - if isinstance(target, LabelledForm): - return LabelledForm(*(self(t, value) for t in target.terms)) - elif isinstance(target, ufl.Form): - return LabelledForm(Term(target, {self.label: self.value})) - elif isinstance(target, Term): - new_labels = target.labels.copy() - new_labels.update({self.label: self.value}) - return Term(target.form, new_labels) - else: - raise ValueError("Unable to label %s" % target) - - def remove(self, target): - """ - Removes a label from a term or labelled form. - - This removes any :class:`Label` with this `label` from - `target`. If called on an :class:`LabelledForm`, it acts termwise. - - Args: - target (:class:`Term` or :class:`LabelledForm`): term or labelled - form to have this label removed from. - - Raises: - ValueError: if the `target` is not a :class:`Term` or a - :class:`LabelledForm`. - """ - - if isinstance(target, LabelledForm): - return LabelledForm(*(self.remove(t) for t in target.terms)) - elif isinstance(target, Term): - try: - d = target.labels.copy() - d.pop(self.label) - return Term(target.form, d) - except KeyError: - return target - else: - raise ValueError("Unable to unlabel %s" % target) - - def update_value(self, target, new): - """ - Updates the label of a term or labelled form. - - This updates the value of any :class:`Label` with this `label` from - `target`. If called on an :class:`LabelledForm`, it acts termwise. - - Args: - target (:class:`Term` or :class:`LabelledForm`): term or labelled - form to have this label updated. - new (...): the new value for this label to take. - - Raises: - ValueError: if the `target` is not a :class:`Term` or a - :class:`LabelledForm`. - """ - - if isinstance(target, LabelledForm): - return LabelledForm(*(self.update_value(t, new) for t in target.terms)) - elif isinstance(target, Term): - try: - d = target.labels.copy() - d[self.label] = new - return Term(target.form, d) - except KeyError: - return target - else: - raise ValueError("Unable to relabel %s" % target) - - -# ---------------------------------------------------------------------------- # -# Some common labels -# ---------------------------------------------------------------------------- # - -subject = Label("subject", validator=lambda value: type(value) == Function) -name = Label("name", validator=lambda value: type(value) == str) diff --git a/gusto/fml/replacement.py b/gusto/fml/replacement.py deleted file mode 100644 index 9916204c0..000000000 --- a/gusto/fml/replacement.py +++ /dev/null @@ -1,221 +0,0 @@ -""" -Generic routines for replacing functions using FML. -""" - -import ufl -from .form_manipulation_language import Term, subject -from firedrake import split, MixedElement - -__all__ = ["replace_test_function", "replace_trial_function", - "replace_subject"] - - -# ---------------------------------------------------------------------------- # -# A general routine for building the replacement dictionary -# ---------------------------------------------------------------------------- # -def _replace_dict(old, new, old_idx, new_idx, replace_type): - """ - Build a dictionary to pass to the ufl.replace routine - The dictionary matches variables in the old term with those in the new - - Does not check types unless indexing is required (leave type-checking to ufl.replace) - """ - - mixed_old = type(old.ufl_element()) is MixedElement - mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement - - indexable_old = mixed_old - indexable_new = mixed_new or type(new) is tuple - - if mixed_old: - split_old = split(old) - if indexable_new: - split_new = new if type(new) is tuple else split(new) - - # check indices arguments are valid - if not indexable_old and old_idx is not None: - raise ValueError(f"old_idx should not be specified to replace_{replace_type}" - + f" when replaced {replace_type} of type {old} is not mixed.") - - if not indexable_new and new_idx is not None: - raise ValueError(f"new_idx should not be specified to replace_{replace_type} when" - + f" new {replace_type} of type {new} is not mixed or indexable.") - - if indexable_old and not indexable_new: - if old_idx is None: - raise ValueError(f"old_idx must be specified to replace_{replace_type} when replaced" - + f" {replace_type} of type {old} is mixed and new {replace_type}" - + f" of type {new} is not mixed or indexable.") - - if indexable_new and not indexable_old: - if new_idx is None: - raise ValueError(f"new_idx must be specified to replace_{replace_type} when new" - + f" {replace_type} of type {new} is mixed or indexable and" - + f" old {replace_type} of type {old} is not mixed.") - - if indexable_old and indexable_new: - # must be both True or both False - if (old_idx is None) ^ (new_idx is None): - raise ValueError("both or neither old_idx and new_idx must be specified to" - + f" replace_{replace_type} when old {replace_type} of type" - + f" {old} is mixed and new {replace_type} of type {new} is" - + " mixed or indexable.") - if old_idx is None: # both indexes are none - if len(split_old) != len(split_new): - raise ValueError(f"if neither index is specified to replace_{replace_type}" - + f" and both old {replace_type} of type {old} and new" - + f" {replace_type} of type {new} are mixed or indexable" - + f" then old of length {len(split_old)} and new of length {len(split_new)}" - + " must be the same length.") - - # make the replace_dict - - replace_dict = {} - - if not indexable_old and not indexable_new: - replace_dict[old] = new - - elif not indexable_old and indexable_new: - replace_dict[old] = split_new[new_idx] - - elif indexable_old and not indexable_new: - replace_dict[split_old[old_idx]] = new - - elif indexable_old and indexable_new: - if old_idx is None: # replace everything - for k, v in zip(split_old, split_new): - replace_dict[k] = v - else: # idxs are given - replace_dict[split_old[old_idx]] = split_new[new_idx] - - return replace_dict - - -# ---------------------------------------------------------------------------- # -# Replacement routines -# ---------------------------------------------------------------------------- # -def replace_test_function(new_test, old_idx=None, new_idx=None): - """ - A routine to replace the test function in a term with a new test function. - - Args: - new_test (:class:`TestFunction`): the new test function. - - Returns: - a function that takes in t, a :class:`Term`, and returns a new - :class:`Term` with form containing the new_test and labels=t.labels - """ - - def repl(t): - """ - Replaces the test function in a term with a new expression. This is - built around the ufl replace routine. - - Args: - t (:class:`Term`): the original term. - - Returns: - :class:`Term`: the new term. - """ - old_test = t.form.arguments()[0] - replace_dict = _replace_dict(old_test, new_test, - old_idx=old_idx, new_idx=new_idx, - replace_type='test') - - try: - new_form = ufl.replace(t.form, replace_dict) - except Exception as err: - error_message = f"{type(err)} raised by ufl.replace when trying to" \ - + f" replace_test_function with {new_test}" - raise type(err)(error_message) from err - - return Term(new_form, t.labels) - - return repl - - -def replace_trial_function(new_trial, old_idx=None, new_idx=None): - """ - A routine to replace the trial function in a term with a new expression. - - Args: - new (:class:`TrialFunction` or :class:`Function`): the new function. - - Returns: - a function that takes in t, a :class:`Term`, and returns a new - :class:`Term` with form containing the new_test and labels=t.labels - """ - - def repl(t): - """ - Replaces the trial function in a term with a new expression. This is - built around the ufl replace routine. - - Args: - t (:class:`Term`): the original term. - - Raises: - TypeError: if the form is linear. - - Returns: - :class:`Term`: the new term. - """ - if len(t.form.arguments()) != 2: - raise TypeError('Trying to replace trial function of a form that is not linear') - old_trial = t.form.arguments()[1] - replace_dict = _replace_dict(old_trial, new_trial, - old_idx=old_idx, new_idx=new_idx, - replace_type='trial') - - try: - new_form = ufl.replace(t.form, replace_dict) - except Exception as err: - error_message = f"{type(err)} raised by ufl.replace when trying to" \ - + f" replace_trial_function with {new_trial}" - raise type(err)(error_message) from err - - return Term(new_form, t.labels) - - return repl - - -def replace_subject(new_subj, old_idx=None, new_idx=None): - """ - A routine to replace the subject in a term with a new variable. - - Args: - new (:class:`ufl.Expr`): the new expression to replace the subject. - idx (int, optional): index of the subject in the equation's - :class:`MixedFunctionSpace`. Defaults to None. - """ - def repl(t): - """ - Replaces the subject in a term with a new expression. This is built - around the ufl replace routine. - - Args: - t (:class:`Term`): the original term. - - Raises: - ValueError: when the new expression and subject are not of - compatible sizes (e.g. a mixed function vs a non-mixed function) - - Returns: - :class:`Term`: the new term. - """ - - old_subj = t.get(subject) - replace_dict = _replace_dict(old_subj, new_subj, - old_idx=old_idx, new_idx=new_idx, - replace_type='subject') - - try: - new_form = ufl.replace(t.form, replace_dict) - except Exception as err: - error_message = f"{type(err)} raised by ufl.replace when trying to" \ - + f" replace_subject with {new_subj}" - raise type(err)(error_message) from err - - return Term(new_form, t.labels) - - return repl diff --git a/gusto/forcing.py b/gusto/forcing.py index 71ab8bfd0..27f460db9 100644 --- a/gusto/forcing.py +++ b/gusto/forcing.py @@ -4,7 +4,7 @@ Function, TrialFunctions, DirichletBC, LinearVariationalProblem, LinearVariationalSolver ) -from gusto.fml import drop, replace_subject, name +from firedrake.fml import drop, replace_subject, name_label from gusto.labels import ( transport, diffusion, time_derivative, hydrostatic, physics_label ) @@ -58,20 +58,26 @@ def __init__(self, equation, alpha): # the explicit forms are multiplied by (1-alpha) and moved to the rhs L_explicit = -(1-alpha)*dt*residual.label_map( - lambda t: t.has_label(time_derivative) or t.get(name) in implicit_terms or t.get(name) == "hydrostatic_form", + lambda t: + t.has_label(time_derivative) + or t.get(name_label) in implicit_terms + or t.get(name_label) == "hydrostatic_form", drop, replace_subject(self.x0)) # the implicit forms are multiplied by alpha and moved to the rhs L_implicit = -alpha*dt*residual.label_map( - lambda t: t.has_label(time_derivative) or t.get(name) in implicit_terms or t.get(name) == "hydrostatic_form", + lambda t: + t.has_label(time_derivative) + or t.get(name_label) in implicit_terms + or t.get(name_label) == "hydrostatic_form", drop, replace_subject(self.x0)) # now add the terms that are always fully implicit - if any(t.get(name) in implicit_terms for t in residual): + if any(t.get(name_label) in implicit_terms for t in residual): L_implicit -= dt*residual.label_map( - lambda t: t.get(name) in implicit_terms, + lambda t: t.get(name_label) in implicit_terms, replace_subject(self.x0), drop) @@ -79,12 +85,12 @@ def __init__(self, equation, alpha): if any([t.has_label(hydrostatic) for t in residual]): L_explicit += residual.label_map( - lambda t: t.get(name) == "hydrostatic_form", + lambda t: t.get(name_label) == "hydrostatic_form", replace_subject(self.x0), drop) L_implicit -= residual.label_map( - lambda t: t.get(name) == "hydrostatic_form", + lambda t: t.get(name_label) == "hydrostatic_form", replace_subject(self.x0), drop) diff --git a/gusto/labels.py b/gusto/labels.py index 7538de5a2..5aa658d85 100644 --- a/gusto/labels.py +++ b/gusto/labels.py @@ -2,8 +2,8 @@ import ufl from firedrake import Function +from firedrake.fml import Term, Label, LabelledForm from gusto.configuration import IntegrateByParts, TransportEquationType -from gusto.fml.form_manipulation_language import Term, Label, LabelledForm from types import MethodType dynamics_label = Label("dynamics", validator=lambda value: type(value) is str) diff --git a/gusto/linear_solvers.py b/gusto/linear_solvers.py index f991706a6..55bee0616 100644 --- a/gusto/linear_solvers.py +++ b/gusto/linear_solvers.py @@ -5,12 +5,14 @@ finite element spaces. """ -from firedrake import (split, LinearVariationalProblem, Constant, - LinearVariationalSolver, TestFunctions, TrialFunctions, - TestFunction, TrialFunction, lhs, rhs, FacetNormal, - div, dx, jump, avg, dS_v, dS_h, ds_v, ds_t, ds_b, ds_tb, inner, action, - dot, grad, Function, VectorSpaceBasis, BrokenElement, - FunctionSpace, MixedFunctionSpace, DirichletBC) +from firedrake import ( + split, LinearVariationalProblem, Constant, LinearVariationalSolver, + TestFunctions, TrialFunctions, TestFunction, TrialFunction, lhs, + rhs, FacetNormal, div, dx, jump, avg, dS_v, dS_h, ds_v, ds_t, ds_b, + ds_tb, inner, action, dot, grad, Function, VectorSpaceBasis, + BrokenElement, FunctionSpace, MixedFunctionSpace, DirichletBC +) +from firedrake.fml import Term, drop from firedrake.petsc import flatten_parameters from pyop2.profiling import timed_function, timed_region @@ -18,7 +20,6 @@ from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual from gusto.labels import linearisation, time_derivative, hydrostatic from gusto import thermodynamics -from gusto.fml.form_manipulation_language import Term, drop from gusto.recovery.recovery_kernels import AverageWeightings, AverageKernel from abc import ABCMeta, abstractmethod, abstractproperty diff --git a/gusto/physics.py b/gusto/physics.py index 76c8ee4c3..7bebece17 100644 --- a/gusto/physics.py +++ b/gusto/physics.py @@ -8,18 +8,20 @@ """ from abc import ABCMeta, abstractmethod +from firedrake import ( + Interpolator, conditional, Function, dx, sqrt, dot, min_value, + max_value, Constant, pi, Projector, grad, TestFunctions, split, + inner, TestFunction, exp, avg, outer, FacetNormal, + SpatialCoordinate, dS_v, NonlinearVariationalProblem, + NonlinearVariationalSolver +) +from firedrake.fml import identity, Term, subject from gusto.active_tracers import Phases, TracerVariableType from gusto.configuration import BoundaryLayerParameters from gusto.recovery import Recoverer, BoundaryMethod from gusto.equations import CompressibleEulerEquations -from gusto.fml import identity, Term, subject from gusto.labels import PhysicsLabel, transporting_velocity, transport, prognostic from gusto.logging import logger -from firedrake import (Interpolator, conditional, Function, dx, sqrt, dot, - min_value, max_value, Constant, pi, Projector, grad, - TestFunctions, split, inner, TestFunction, exp, avg, - outer, FacetNormal, SpatialCoordinate, dS_v, - NonlinearVariationalProblem, NonlinearVariationalSolver) from gusto import thermodynamics import ufl import math diff --git a/gusto/spatial_methods.py b/gusto/spatial_methods.py index bb050f3c5..2b1751545 100644 --- a/gusto/spatial_methods.py +++ b/gusto/spatial_methods.py @@ -4,7 +4,7 @@ """ from firedrake import split -from gusto.fml import Term, keep, drop +from firedrake.fml import Term, keep, drop from gusto.labels import prognostic __all__ = ['SpatialMethod'] diff --git a/gusto/time_discretisation.py b/gusto/time_discretisation.py index 4a0dcd4d6..c8e237428 100644 --- a/gusto/time_discretisation.py +++ b/gusto/time_discretisation.py @@ -6,20 +6,23 @@ """ from abc import ABCMeta, abstractmethod, abstractproperty -from firedrake import (Function, TestFunction, NonlinearVariationalProblem, - NonlinearVariationalSolver, DirichletBC, split, Constant) +import math +import numpy as np + +from firedrake import ( + Function, TestFunction, NonlinearVariationalProblem, + NonlinearVariationalSolver, DirichletBC, split, Constant +) +from firedrake.fml import ( + replace_subject, replace_test_function, Term, all_terms, drop +) from firedrake.formmanipulation import split_form from firedrake.utils import cached_property from gusto.configuration import EmbeddedDGOptions, RecoveryOptions -from gusto.fml import ( - replace_subject, replace_test_function, Term, all_terms, drop -) from gusto.labels import time_derivative, prognostic, physics_label from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual from gusto.wrappers import * -import math -import numpy as np __all__ = ["ForwardEuler", "BackwardEuler", "ExplicitMultistage", "ImplicitMultistage", diff --git a/gusto/timeloop.py b/gusto/timeloop.py index a7fc661b3..7c1c2a441 100644 --- a/gusto/timeloop.py +++ b/gusto/timeloop.py @@ -2,9 +2,9 @@ from abc import ABCMeta, abstractmethod, abstractproperty from firedrake import Function, Projector, split, Constant +from firedrake.fml import drop, Label, Term from pyop2.profiling import timed_stage from gusto.equations import PrognosticEquationSet -from gusto.fml import drop, Label, Term from gusto.fields import TimeLevelFields, StateFields from gusto.forcing import Forcing from gusto.labels import ( diff --git a/gusto/transport_methods.py b/gusto/transport_methods.py index 30e6fa974..4fc951b6d 100644 --- a/gusto/transport_methods.py +++ b/gusto/transport_methods.py @@ -2,11 +2,12 @@ Defines TransportMethod objects, which are used to solve a transport problem. """ -from firedrake import (dx, dS, dS_v, dS_h, ds_t, ds_b, ds_v, dot, inner, outer, - jump, grad, div, FacetNormal, Function, sign, avg, cross, - curl) +from firedrake import ( + dx, dS, dS_v, dS_h, ds_t, ds_b, ds_v, dot, inner, outer, jump, + grad, div, FacetNormal, Function, sign, avg, cross, curl +) +from firedrake.fml import Term, keep, drop from gusto.configuration import IntegrateByParts, TransportEquationType -from gusto.fml import Term, keep, drop from gusto.labels import prognostic, transport, transporting_velocity, ibp_label from gusto.logging import logger from gusto.spatial_methods import SpatialMethod diff --git a/gusto/wrappers.py b/gusto/wrappers.py index 79b650200..10c463ca7 100644 --- a/gusto/wrappers.py +++ b/gusto/wrappers.py @@ -5,11 +5,12 @@ """ from abc import ABCMeta, abstractmethod -from firedrake import (FunctionSpace, Function, BrokenElement, Projector, - Interpolator, VectorElement, Constant, as_ufl, dot, grad, - TestFunction) +from firedrake import ( + FunctionSpace, Function, BrokenElement, Projector, Interpolator, + VectorElement, Constant, as_ufl, dot, grad, TestFunction +) +from firedrake.fml import Term from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, SUPGOptions -from gusto.fml import Term from gusto.recovery import Recoverer, ReversibleRecoverer from gusto.labels import transporting_velocity import ufl diff --git a/integration-tests/physics/test_boundary_layer_mixing.py b/integration-tests/physics/test_boundary_layer_mixing.py index 1b0499e70..a72d0b3c1 100644 --- a/integration-tests/physics/test_boundary_layer_mixing.py +++ b/integration-tests/physics/test_boundary_layer_mixing.py @@ -6,6 +6,7 @@ from gusto.labels import physics_label from firedrake import (VectorFunctionSpace, PeriodicIntervalMesh, as_vector, exp, SpatialCoordinate, ExtrudedMesh, Function) +from firedrake.fml import identity import pytest diff --git a/integration-tests/physics/test_saturation_adjustment.py b/integration-tests/physics/test_saturation_adjustment.py index 767e0c787..3671114ea 100644 --- a/integration-tests/physics/test_saturation_adjustment.py +++ b/integration-tests/physics/test_saturation_adjustment.py @@ -10,6 +10,7 @@ from firedrake import (norm, Constant, PeriodicIntervalMesh, SpatialCoordinate, ExtrudedMesh, Function, sqrt, conditional) +from firedrake.fml import identity from netCDF4 import Dataset import pytest diff --git a/integration-tests/physics/test_static_adjustment.py b/integration-tests/physics/test_static_adjustment.py index 18ee13cea..2264d6675 100644 --- a/integration-tests/physics/test_static_adjustment.py +++ b/integration-tests/physics/test_static_adjustment.py @@ -8,6 +8,7 @@ from gusto.labels import physics_label from firedrake import (Constant, PeriodicIntervalMesh, SpatialCoordinate, ExtrudedMesh, Function) +from firedrake.fml import identity import pytest diff --git a/integration-tests/physics/test_suppress_vertical_wind.py b/integration-tests/physics/test_suppress_vertical_wind.py index 250916227..44f1e05ae 100644 --- a/integration-tests/physics/test_suppress_vertical_wind.py +++ b/integration-tests/physics/test_suppress_vertical_wind.py @@ -7,6 +7,7 @@ from gusto.labels import physics_label from firedrake import (Constant, PeriodicIntervalMesh, as_vector, sin, norm, SpatialCoordinate, ExtrudedMesh, Function, dot) +from firedrake.fml import identity def run_suppress_vertical_wind(dirname): diff --git a/integration-tests/physics/test_surface_fluxes.py b/integration-tests/physics/test_surface_fluxes.py index 09d424af4..fd4f92838 100644 --- a/integration-tests/physics/test_surface_fluxes.py +++ b/integration-tests/physics/test_surface_fluxes.py @@ -9,6 +9,7 @@ from gusto.labels import physics_label from firedrake import (norm, Constant, PeriodicIntervalMesh, as_vector, SpatialCoordinate, ExtrudedMesh, Function, conditional) +from firedrake.fml import identity import pytest diff --git a/integration-tests/physics/test_wind_drag.py b/integration-tests/physics/test_wind_drag.py index a00818089..6bf107a9d 100644 --- a/integration-tests/physics/test_wind_drag.py +++ b/integration-tests/physics/test_wind_drag.py @@ -7,6 +7,7 @@ from gusto.labels import physics_label from firedrake import (norm, Constant, PeriodicIntervalMesh, as_vector, dot, SpatialCoordinate, ExtrudedMesh, Function, conditional) +from firedrake.fml import identity import pytest diff --git a/integration-tests/test_fml.py b/integration-tests/test_fml.py deleted file mode 100644 index b827de8c6..000000000 --- a/integration-tests/test_fml.py +++ /dev/null @@ -1,120 +0,0 @@ -""" -Tests a full workflow for the Form Manipulation Language (FML). - -This uses an IMEX discretisation of the linear shallow-water equations on a -mixed function space. -""" -from gusto.fml import subject, replace_subject, keep, drop, Label -from firedrake import (PeriodicUnitSquareMesh, FunctionSpace, Constant, - MixedFunctionSpace, TestFunctions, Function, split, - inner, dx, SpatialCoordinate, as_vector, pi, sin, div, - NonlinearVariationalProblem, NonlinearVariationalSolver) - - -def test_fml(): - - # Define labels for shallow-water - time_derivative = Label("time_derivative") - transport = Label("transport") - pressure_gradient = Label("pressure_gradient") - explicit = Label("explicit") - implicit = Label("implicit") - - # ------------------------------------------------------------------------ # - # Set up finite element objects - # ------------------------------------------------------------------------ # - - # Two shallow-water constants - H = Constant(10000.) - g = Constant(10.) - - # Set up mesh and function spaces - dt = Constant(0.01) - Nx = 5 - mesh = PeriodicUnitSquareMesh(Nx, Nx) - spaces = [FunctionSpace(mesh, "BDM", 1), FunctionSpace(mesh, "DG", 1)] - W = MixedFunctionSpace(spaces) - - # Set up fields on a mixed function space - w, phi = TestFunctions(W) - X = Function(W) - u0, h0 = split(X) - - # Set up time derivatives - mass_form = time_derivative(subject(inner(u0, w)*dx + subject(inner(h0, phi)*dx), X)) - - # Height field transport form - transport_form = transport(subject(H*phi*div(u0)*dx, X)) - - # Pressure gradient term -- integrate by parts once - pressure_gradient_form = pressure_gradient(subject(-g*div(w)*h0*dx, X)) - - # Define IMEX scheme. Transport term explicit and pressure gradient implict. - # This is not necessarily a sensible scheme -- it's just a simple demo for - # how FML can be used. - transport_form = explicit(transport_form) - pressure_gradient_form = implicit(pressure_gradient_form) - - # Add terms together to give whole residual - residual = mass_form + transport_form + pressure_gradient_form - - # ------------------------------------------------------------------------ # - # Initial condition - # ------------------------------------------------------------------------ # - - # Constant flow but sinusoidal height field - x, _ = SpatialCoordinate(mesh) - u0, h0 = X.subfunctions - u0.interpolate(as_vector([1.0, 0.0])) - h0.interpolate(H + 0.01*H*sin(2*pi*x)) - - # ------------------------------------------------------------------------ # - # Set up time discretisation - # ------------------------------------------------------------------------ # - - X_np1 = Function(W) - - # Here we would normally set up routines for the explicit and implicit parts - # but as this is just a test, we'll do just a single explicit/implicit step - - # Explicit: just forward euler - explicit_lhs = residual.label_map(lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(X_np1), - map_if_false=drop) - - explicit_rhs = residual.label_map(lambda t: t.has_label(time_derivative) - or t.has_label(explicit), - map_if_true=keep, map_if_false=drop) - explicit_rhs = explicit_rhs.label_map(lambda t: t.has_label(time_derivative), - map_if_false=lambda t: -dt*t) - - # Implicit: just backward euler - implicit_lhs = residual.label_map(lambda t: t.has_label(time_derivative) - or t.has_label(implicit), - map_if_true=replace_subject(X_np1), - map_if_false=drop) - implicit_lhs = implicit_lhs.label_map(lambda t: t.has_label(time_derivative), - map_if_false=lambda t: dt*t) - - implicit_rhs = residual.label_map(lambda t: t.has_label(time_derivative), - map_if_false=drop) - - # ------------------------------------------------------------------------ # - # Set up and solve problems - # ------------------------------------------------------------------------ # - - explicit_residual = explicit_lhs - explicit_rhs - implicit_residual = implicit_lhs - implicit_rhs - - explicit_problem = NonlinearVariationalProblem(explicit_residual.form, X_np1) - explicit_solver = NonlinearVariationalSolver(explicit_problem) - - implicit_problem = NonlinearVariationalProblem(implicit_residual.form, X_np1) - implicit_solver = NonlinearVariationalSolver(implicit_problem) - - # Solve problems and update X_np1 - # In reality this would be within a time stepping loop! - explicit_solver.solve() - X.assign(X_np1) - implicit_solver.solve() - X.assign(X_np1) diff --git a/unit-tests/fml_tests/test_label.py b/unit-tests/fml_tests/test_label.py deleted file mode 100644 index 8b7fc7215..000000000 --- a/unit-tests/fml_tests/test_label.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -Tests FML's Label objects. -""" - -from firedrake import IntervalMesh, FunctionSpace, Function, TestFunction, dx -from gusto.configuration import TransportEquationType -from gusto.fml import Label, LabelledForm, Term -from ufl import Form -import pytest - - -@pytest.fixture -def label_and_values(label_type): - # Returns labels with different value validation - - bad_value = "bar" - - if label_type == "boolean": - # A label that is simply a string, whose value is Boolean - this_label = Label("foo") - good_value = True - new_value = False - - elif label_type == "integer": - # A label whose value is an integer - this_label = Label("foo", validator=lambda value: (type(value) == int and value < 9)) - good_value = 5 - bad_value = 10 - new_value = 7 - - elif label_type == "other": - # A label whose value is some other type - this_label = Label("foo", validator=lambda value: type(value) == TransportEquationType) - good_value = TransportEquationType.advective - new_value = TransportEquationType.conservative - - elif label_type == "function": - # A label whose value is an Function - this_label = Label("foo", validator=lambda value: type(value) == Function) - good_value, _ = setup_form() - new_value = Function(good_value.function_space()) - - return this_label, good_value, bad_value, new_value - - -def setup_form(): - # Create mesh and function space - L = 3.0 - n = 3 - mesh = IntervalMesh(n, L) - V = FunctionSpace(mesh, "DG", 0) - f = Function(V) - g = TestFunction(V) - form = f*g*dx - - return f, form - - -@pytest.fixture -def object_to_label(object_type): - # A series of different objects to be labelled - - if object_type == int: - return 10 - - else: - _, form = setup_form() - term = Term(form) - - if object_type == Form: - return form - - elif object_type == Term: - return term - - elif object_type == LabelledForm: - return LabelledForm(term) - - else: - raise ValueError(f'object_type {object_type} not implemented') - - -@pytest.mark.parametrize("label_type", ["boolean", "integer", - "other", "function"]) -@pytest.mark.parametrize("object_type", [LabelledForm, Term, Form, int]) -def test_label(label_type, object_type, label_and_values, object_to_label): - - label, good_value, bad_value, new_value = label_and_values - - # ------------------------------------------------------------------------ # - # Check label has correct name - # ------------------------------------------------------------------------ # - - assert label.label == "foo", "Label has incorrect name" - - # ------------------------------------------------------------------------ # - # Check we can't label unsupported objects - # ------------------------------------------------------------------------ # - - if object_type == int: - # Can't label integers, so check this fails and force end - try: - labelled_object = label(object_to_label) - except ValueError: - # Appropriate error has been returned so end the test - return - - # If we get here there has been an error - assert False, "Labelling an integer should throw an error" - - # ------------------------------------------------------------------------ # - # Test application of labels - # ------------------------------------------------------------------------ # - - if label_type == "boolean": - labelled_object = label(object_to_label) - - else: - # Check that passing an inappropriate label gives the correct error - try: - labelled_object = label(object_to_label, bad_value) - # If we get here the validator has not worked - assert False, 'The labelling validator has not worked for ' \ - + f'label_type {label_type} and object_type {object_type}' - - except AssertionError: - # Now label object properly - labelled_object = label(object_to_label, good_value) - - # ------------------------------------------------------------------------ # - # Check labelled form or term has been returned - # ------------------------------------------------------------------------ # - - if object_type == Term: - assert type(labelled_object) == Term, 'Labelled Term should be a ' \ - + f'be a Term and not type {type(labelled_object)}' - else: - assert type(labelled_object) == LabelledForm, 'Labelled Form should ' \ - + f'be a Labelled Form and not type {type(labelled_object)}' - - # ------------------------------------------------------------------------ # - # Test that the values are correct - # ------------------------------------------------------------------------ # - - if object_type == Term: - assert labelled_object.get(label) == good_value, 'Value of label ' \ - + f'should be {good_value} and not {labelled_object.get(label)}' - else: - assert labelled_object.terms[0].get(label) == good_value, 'Value of ' \ - + f'label should be {good_value} and not ' \ - + f'{labelled_object.terms[0].get(label)}' - - # ------------------------------------------------------------------------ # - # Test updating of values - # ------------------------------------------------------------------------ # - - # Check that passing an inappropriate label gives the correct error - try: - labelled_object = label.update_value(labelled_object, bad_value) - # If we get here the validator has not worked - assert False, 'The validator has not worked for updating label of ' \ - + f'label_type {label_type} and object_type {object_type}' - except AssertionError: - # Update new value - labelled_object = label.update_value(labelled_object, new_value) - - # Check that new value is correct - if object_type == Term: - assert labelled_object.get(label) == new_value, 'Updated value of ' \ - + f'label should be {new_value} and not {labelled_object.get(label)}' - else: - assert labelled_object.terms[0].get(label) == new_value, 'Updated ' \ - + f'value of label should be {new_value} and not ' \ - + f'{labelled_object.terms[0].get(label)}' - - # ------------------------------------------------------------------------ # - # Test removal of values - # ------------------------------------------------------------------------ # - - labelled_object = label.remove(labelled_object) - - # Try to see if object still has that label - if object_type == Term: - label_value = labelled_object.get(label) - else: - label_value = labelled_object.terms[0].get(label) - - # If we get here then the label has been extracted but it shouldn't have - assert label_value is None, f'The label {label_type} appears has not to ' \ - + f'have been removed for object_type {object_type}' diff --git a/unit-tests/fml_tests/test_label_map.py b/unit-tests/fml_tests/test_label_map.py deleted file mode 100644 index 0fb31563f..000000000 --- a/unit-tests/fml_tests/test_label_map.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -Tests FML's LabelledForm label_map routine. -""" - -from firedrake import IntervalMesh, FunctionSpace, Function, TestFunction, dx -from gusto.fml import Label, Term, identity, drop, all_terms - - -def test_label_map(): - - # ------------------------------------------------------------------------ # - # Set up labelled forms - # ------------------------------------------------------------------------ # - - # Some basic labels - foo_label = Label("foo") - bar_label = Label("bar", validator=lambda value: type(value) == int) - - # Create mesh, function space and forms - L = 3.0 - n = 3 - mesh = IntervalMesh(n, L) - V = FunctionSpace(mesh, "DG", 0) - f = Function(V) - g = Function(V) - test = TestFunction(V) - form_1 = f*test*dx - form_2 = g*test*dx - term_1 = foo_label(Term(form_1)) - term_2 = bar_label(Term(form_2), 5) - - labelled_form = term_1 + term_2 - - # ------------------------------------------------------------------------ # - # Test all_terms - # ------------------------------------------------------------------------ # - - # Passing all_terms should return the same labelled form - new_labelled_form = labelled_form.label_map(all_terms) - assert len(new_labelled_form) == len(labelled_form), \ - 'new_labelled_form should be the same as labelled_form' - for new_term, term in zip(new_labelled_form.terms, labelled_form.terms): - assert new_term == term, 'terms in new_labelled_form should be the ' + \ - 'same as those in labelled_form' - - # ------------------------------------------------------------------------ # - # Test identity and drop - # ------------------------------------------------------------------------ # - - # Get just the first term, which has the foo label - new_labelled_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), map_if_true=identity, map_if_false=drop - ) - assert len(new_labelled_form) == 1, 'new_labelled_form should be length 1' - for new_term in new_labelled_form.terms: - assert new_term.has_label(foo_label), 'All terms in ' + \ - 'new_labelled_form should have foo_label' - - # Give term_1 the bar label - new_labelled_form = labelled_form.label_map( - lambda t: t.has_label(bar_label), map_if_true=identity, - map_if_false=lambda t: bar_label(t, 0) - ) - assert len(new_labelled_form) == 2, 'new_labelled_form should be length 2' - for new_term in new_labelled_form.terms: - assert new_term.has_label(bar_label), 'All terms in ' + \ - 'new_labelled_form should have bar_label' - - # Test with a more complex filter, which should give an empty labelled_form - new_labelled_form = labelled_form.label_map( - lambda t: (t.has_label(bar_label) and t.get(bar_label) > 10), - map_if_true=identity, map_if_false=drop - ) - assert len(new_labelled_form) == 0, 'new_labelled_form should be length 0' diff --git a/unit-tests/fml_tests/test_labelled_form.py b/unit-tests/fml_tests/test_labelled_form.py deleted file mode 100644 index a0176adb3..000000000 --- a/unit-tests/fml_tests/test_labelled_form.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Tests FML's LabelledForm objects. -""" - -from firedrake import (IntervalMesh, FunctionSpace, Function, - TestFunction, dx, Constant) -from gusto.fml import Label, Term, LabelledForm -from ufl import Form - - -def test_labelled_form(): - - # ------------------------------------------------------------------------ # - # Set up labelled forms - # ------------------------------------------------------------------------ # - - # Some basic labels - lorem_label = Label("lorem", validator=lambda value: type(value) == str) - ipsum_label = Label("ipsum", validator=lambda value: type(value) == int) - - # Create mesh, function space and forms - L = 3.0 - n = 3 - mesh = IntervalMesh(n, L) - V = FunctionSpace(mesh, "DG", 0) - f = Function(V) - g = Function(V) - test = TestFunction(V) - form_1 = f*test*dx - form_2 = g*test*dx - term_1 = lorem_label(Term(form_1), 'i_have_lorem') - term_2 = ipsum_label(Term(form_2), 5) - - # ------------------------------------------------------------------------ # - # Test labelled forms have the correct number of terms - # ------------------------------------------------------------------------ # - - # Create from a single term - labelled_form_1 = LabelledForm(term_1) - assert len(labelled_form_1) == 1, 'LabelledForm should have 1 term' - - # Create from multiple terms - labelled_form_2 = LabelledForm(*[term_1, term_2]) - assert len(labelled_form_2) == 2, 'LabelledForm should have 2 terms' - - # Trying to create from two LabelledForms should give an error - try: - labelled_form_3 = LabelledForm(labelled_form_1, labelled_form_2) - # If we get here something has gone wrong - assert False, 'We should not be able to create LabelledForm ' + \ - 'from two LabelledForms' - except TypeError: - pass - - # Create from a single LabelledForm - labelled_form_3 = LabelledForm(labelled_form_1) - assert len(labelled_form_3) == 1, 'LabelledForm should have 1 term' - - # ------------------------------------------------------------------------ # - # Test getting form - # ------------------------------------------------------------------------ # - - assert type(labelled_form_1.form) is Form, 'The form belonging to the ' + \ - f'LabelledForm must be a Form, and not {type(labelled_form_1.form)}' - - assert type(labelled_form_2.form) is Form, 'The form belonging to the ' + \ - f'LabelledForm must be a Form, and not {type(labelled_form_2.form)}' - - assert type(labelled_form_3.form) is Form, 'The form belonging to the ' + \ - f'LabelledForm must be a Form, and not {type(labelled_form_3.form)}' - - # ------------------------------------------------------------------------ # - # Test addition and subtraction of labelled forms - # ------------------------------------------------------------------------ # - - # Add a Form to a LabelledForm - new_labelled_form = labelled_form_1 + form_2 - assert len(new_labelled_form) == 2, 'LabelledForm should have 2 terms' - - # Add a Term to a LabelledForm - new_labelled_form = labelled_form_1 + term_2 - assert len(new_labelled_form) == 2, 'LabelledForm should have 2 terms' - - # Add a LabelledForm to a LabelledForm - new_labelled_form = labelled_form_1 + labelled_form_2 - assert len(new_labelled_form) == 3, 'LabelledForm should have 3 terms' - - # Adding None to a LabelledForm should give the same LabelledForm - new_labelled_form = labelled_form_1 + None - assert new_labelled_form == labelled_form_1, 'Two LabelledForms should be equal' - - # Subtract a Form from a LabelledForm - new_labelled_form = labelled_form_1 - form_2 - assert len(new_labelled_form) == 2, 'LabelledForm should have 2 terms' - - # Subtract a Term from a LabelledForm - new_labelled_form = labelled_form_1 - term_2 - assert len(new_labelled_form) == 2, 'LabelledForm should have 2 terms' - - # Subtract a LabelledForm from a LabelledForm - new_labelled_form = labelled_form_1 - labelled_form_2 - assert len(new_labelled_form) == 3, 'LabelledForm should have 3 terms' - - # Subtracting None from a LabelledForm should give the same LabelledForm - new_labelled_form = labelled_form_1 - None - assert new_labelled_form == labelled_form_1, 'Two LabelledForms should be equal' - - # ------------------------------------------------------------------------ # - # Test multiplication and division of labelled forms - # ------------------------------------------------------------------------ # - - # Multiply by integer - new_labelled_form = labelled_form_1 * -4 - assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' - - # Multiply by float - new_labelled_form = labelled_form_1 * 12.4 - assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' - - # Multiply by Constant - new_labelled_form = labelled_form_1 * Constant(5.0) - assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' - - # Divide by integer - new_labelled_form = labelled_form_1 / (-8) - assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' - - # Divide by float - new_labelled_form = labelled_form_1 / (-6.2) - assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' - - # Divide by Constant - new_labelled_form = labelled_form_1 / Constant(0.01) - assert len(new_labelled_form) == 1, 'LabelledForm should have 1 term' diff --git a/unit-tests/fml_tests/test_replace_perp.py b/unit-tests/fml_tests/test_replace_perp.py deleted file mode 100644 index f095df545..000000000 --- a/unit-tests/fml_tests/test_replace_perp.py +++ /dev/null @@ -1,46 +0,0 @@ -# The perp routine should come from UFL when it is fully implemented there -from gusto import perp -from gusto.fml import subject, replace_subject, all_terms -from firedrake import (UnitSquareMesh, FunctionSpace, MixedFunctionSpace, - TestFunctions, Function, split, inner, dx, errornorm, - SpatialCoordinate, as_vector, TrialFunctions, solve) - - -def test_replace_perp(): - - # The test checks that if the perp operator is applied to the - # subject of a labelled form, the perp of the subject is found and - # replaced by the replace_subject function. This gave particular problems - # before the perp operator was defined - - # set up mesh and function spaces - the subject is defined on a - # mixed function space because the problem didn't occur otherwise - Nx = 5 - mesh = UnitSquareMesh(Nx, Nx) - spaces = [FunctionSpace(mesh, "BDM", 1), FunctionSpace(mesh, "DG", 1)] - W = MixedFunctionSpace(spaces) - - # set up labelled form with subject u - w, p = TestFunctions(W) - U0 = Function(W) - u0, _ = split(U0) - form = subject(inner(perp(u0), w)*dx, U0) - - # make a function to replace the subject with and give it some values - U1 = Function(W) - u1, _ = U1.subfunctions - x, y = SpatialCoordinate(mesh) - u1.interpolate(as_vector([1, 2])) - - u, D = TrialFunctions(W) - a = inner(u, w)*dx + D*p*dx - L = form.label_map(all_terms, replace_subject(U1, old_idx=0, new_idx=0)) - U2 = Function(W) - solve(a == L.form, U2) - - u2, _ = U2.subfunctions - U3 = Function(W) - u3, _ = U3.subfunctions - u3.interpolate(as_vector([-2, 1])) - - assert errornorm(u2, u3) < 1e-14 diff --git a/unit-tests/fml_tests/test_replacement.py b/unit-tests/fml_tests/test_replacement.py deleted file mode 100644 index 47eceb2c3..000000000 --- a/unit-tests/fml_tests/test_replacement.py +++ /dev/null @@ -1,374 +0,0 @@ -""" -Tests the different replacement routines from replacement.py -""" - -from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, - TestFunctions, TrialFunction, TrialFunctions, - Argument, - VectorFunctionSpace, dx, inner, split, grad) -from gusto.fml import (Label, subject, replace_subject, - replace_test_function, replace_trial_function, - drop, all_terms) -import pytest - -from collections import namedtuple - -ReplaceSubjArgs = namedtuple("ReplaceSubjArgs", "new_subj idxs error") -ReplaceArgsArgs = namedtuple("ReplaceArgsArgs", "new_arg idxs error replace_function arg_idx") - - -def ReplaceTestArgs(*args): - return ReplaceArgsArgs(*args, replace_test_function, 0) - - -def ReplaceTrialArgs(*args): - return ReplaceArgsArgs(*args, replace_trial_function, 1) - - -# some dummy labels -foo_label = Label("foo") -bar_label = Label("bar") - -nx = 2 -mesh = UnitSquareMesh(nx, nx) -V0 = FunctionSpace(mesh, 'CG', 1) -V1 = FunctionSpace(mesh, 'DG', 1) -W = V0*V1 -Vv = VectorFunctionSpace(mesh, 'CG', 1) -Wv = Vv*V1 - - -@pytest.fixture() -def primal_form(): - primal_subj = Function(V0) - primal_test = TestFunction(V0) - - primal_term1 = foo_label(subject(primal_subj*primal_test*dx, primal_subj)) - primal_term2 = bar_label(inner(grad(primal_subj), grad(primal_test))*dx) - - return primal_term1 + primal_term2 - - -def primal_subj_argsets(): - argsets = [ - ReplaceSubjArgs(Function(V0), {}, None), - ReplaceSubjArgs(Function(V0), {'new_idx': 0}, ValueError), - ReplaceSubjArgs(Function(V0), {'old_idx': 0}, ValueError), - ReplaceSubjArgs(Function(W), {'new_idx': 0}, None), - ReplaceSubjArgs(Function(W), {'new_idx': 1}, None), - ReplaceSubjArgs(split(Function(W)), {'new_idx': 1}, None), - ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), - ReplaceSubjArgs(Function(W), {'new_idx': 7}, IndexError) - ] - return argsets - - -def primal_test_argsets(): - argsets = [ - ReplaceTestArgs(TestFunction(V0), {}, None), - ReplaceTestArgs(TestFunction(V0), {'new_idx': 0}, ValueError), - ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, None), - ReplaceTestArgs(TestFunction(W), {'new_idx': 1}, None), - ReplaceTestArgs(TestFunctions(W), {'new_idx': 1}, None), - ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError) - ] - return argsets - - -def primal_trial_argsets(): - argsets = [ - ReplaceTrialArgs(TrialFunction(V0), {}, None), - ReplaceTrialArgs(TrialFunction(V0), {'new_idx': 0}, ValueError), - ReplaceTrialArgs(TrialFunction(W), {'new_idx': 0}, None), - ReplaceTrialArgs(TrialFunction(W), {'new_idx': 1}, None), - ReplaceTrialArgs(TrialFunctions(W), {'new_idx': 1}, None), - ReplaceTrialArgs(TrialFunction(W), {'new_idx': 7}, IndexError), - ReplaceTrialArgs(Function(V0), {}, None), - ReplaceTrialArgs(Function(V0), {'new_idx': 0}, ValueError), - ReplaceTrialArgs(Function(W), {'new_idx': 0}, None), - ReplaceTrialArgs(Function(W), {'new_idx': 1}, None), - ReplaceTrialArgs(split(Function(W)), {'new_idx': 1}, None), - ReplaceTrialArgs(Function(W), {'new_idx': 7}, IndexError), - ] - return argsets - - -@pytest.fixture -def mixed_form(): - mixed_subj = Function(W) - mixed_test = TestFunction(W) - - mixed_subj0, mixed_subj1 = split(mixed_subj) - mixed_test0, mixed_test1 = split(mixed_test) - - mixed_term1 = foo_label(subject(mixed_subj0*mixed_test0*dx, mixed_subj)) - mixed_term2 = bar_label(inner(grad(mixed_subj1), grad(mixed_test1))*dx) - - return mixed_term1 + mixed_term2 - - -def mixed_subj_argsets(): - argsets = [ - ReplaceSubjArgs(Function(W), {}, None), - ReplaceSubjArgs(Function(W), {'new_idx': 0, 'old_idx': 0}, None), - ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), - ReplaceSubjArgs(Function(W), {'new_idx': 0}, ValueError), - ReplaceSubjArgs(Function(V0), {'old_idx': 0}, None), - ReplaceSubjArgs(Function(V0), {'new_idx': 0}, ValueError), - ReplaceSubjArgs(split(Function(W)), {'new_idx': 0, 'old_idx': 0}, None), - ] - return argsets - - -def mixed_test_argsets(): - argsets = [ - ReplaceTestArgs(TestFunction(W), {}, None), - ReplaceTestArgs(TestFunctions(W), {}, None), - ReplaceTestArgs(TestFunction(W), {'old_idx': 0, 'new_idx': 0}, None), - ReplaceTestArgs(TestFunctions(W), {'old_idx': 0}, ValueError), - ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, ValueError), - ReplaceTestArgs(TestFunction(V0), {'old_idx': 0}, None), - ReplaceTestArgs(TestFunctions(V0), {'new_idx': 1}, ValueError), - ReplaceTestArgs(TestFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError) - ] - return argsets - - -def mixed_trial_argsets(): - argsets = [ - ReplaceTrialArgs(TrialFunction(W), {}, None), - ReplaceTrialArgs(TrialFunctions(W), {}, None), - ReplaceTrialArgs(TrialFunction(W), {'old_idx': 0, 'new_idx': 0}, None), - ReplaceTrialArgs(TrialFunction(V0), {'old_idx': 0}, None), - ReplaceTrialArgs(TrialFunctions(V0), {'new_idx': 1}, ValueError), - ReplaceTrialArgs(TrialFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError), - ReplaceTrialArgs(Function(W), {}, None), - ReplaceTrialArgs(split(Function(W)), {}, None), - ReplaceTrialArgs(Function(W), {'old_idx': 0, 'new_idx': 0}, None), - ReplaceTrialArgs(Function(V0), {'old_idx': 0}, None), - ReplaceTrialArgs(Function(V0), {'new_idx': 0}, ValueError), - ReplaceTrialArgs(Function(W), {'old_idx': 7, 'new_idx': 7}, IndexError), - ] - return argsets - - -@pytest.fixture -def vector_form(): - vector_subj = Function(Vv) - vector_test = TestFunction(Vv) - - vector_term1 = foo_label(subject(inner(vector_subj, vector_test)*dx, vector_subj)) - vector_term2 = bar_label(inner(grad(vector_subj), grad(vector_test))*dx) - - return vector_term1 + vector_term2 - - -def vector_subj_argsets(): - argsets = [ - ReplaceSubjArgs(Function(Vv), {}, None), - ReplaceSubjArgs(Function(V0), {}, ValueError), - ReplaceSubjArgs(Function(Vv), {'new_idx': 0}, ValueError), - ReplaceSubjArgs(Function(Vv), {'old_idx': 0}, ValueError), - ReplaceSubjArgs(Function(Wv), {'new_idx': 0}, None), - ReplaceSubjArgs(Function(Wv), {'new_idx': 1}, ValueError), - ReplaceSubjArgs(split(Function(Wv)), {'new_idx': 0}, None), - ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), - ReplaceSubjArgs(Function(W), {'new_idx': 7}, IndexError), - ] - return argsets - - -def vector_test_argsets(): - argsets = [ - ReplaceTestArgs(TestFunction(Vv), {}, None), - ReplaceTestArgs(TestFunction(V0), {}, ValueError), - ReplaceTestArgs(TestFunction(Vv), {'new_idx': 0}, ValueError), - ReplaceTestArgs(TestFunction(Wv), {'new_idx': 0}, None), - ReplaceTestArgs(TestFunction(Wv), {'new_idx': 1}, ValueError), - ReplaceTestArgs(TestFunctions(Wv), {'new_idx': 0}, None), - ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError), - ] - return argsets - - -@pytest.mark.parametrize('argset', primal_subj_argsets()) -def test_replace_subject_primal(primal_form, argset): - new_subj = argset.new_subj - idxs = argset.idxs - error = argset.error - - if error is None: - old_subj = primal_form.form.coefficients()[0] - - new_form = primal_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(new_subj, **idxs), - map_if_false=drop) - - # what if we only replace part of the subject? - if 'new_idx' in idxs: - split_new = new_subj if type(new_subj) is tuple else split(new_subj) - new_subj = split_new[idxs['new_idx']].ufl_operands[0] - - assert new_subj in new_form.form.coefficients() - assert old_subj not in new_form.form.coefficients() - - else: - with pytest.raises(error): - new_form = primal_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(new_subj, **idxs)) - - -@pytest.mark.parametrize('argset', mixed_subj_argsets()) -def test_replace_subject_mixed(mixed_form, argset): - new_subj = argset.new_subj - idxs = argset.idxs - error = argset.error - - if error is None: - old_subj = mixed_form.form.coefficients()[0] - - new_form = mixed_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(new_subj, **idxs), - map_if_false=drop) - - # what if we only replace part of the subject? - if 'new_idx' in idxs: - split_new = new_subj if type(new_subj) is tuple else split(new_subj) - new_subj = split_new[idxs['new_idx']].ufl_operands[0] - - assert new_subj in new_form.form.coefficients() - assert old_subj not in new_form.form.coefficients() - - else: - with pytest.raises(error): - new_form = mixed_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(new_subj, **idxs)) - - -@pytest.mark.parametrize('argset', vector_subj_argsets()) -def test_replace_subject_vector(vector_form, argset): - new_subj = argset.new_subj - idxs = argset.idxs - error = argset.error - - if error is None: - old_subj = vector_form.form.coefficients()[0] - - new_form = vector_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(new_subj, **idxs), - map_if_false=drop) - - # what if we only replace part of the subject? - if 'new_idx' in idxs: - split_new = new_subj if type(new_subj) is tuple else split(new_subj) - new_subj = split_new[idxs['new_idx']].ufl_operands[0].ufl_operands[0] - - assert new_subj in new_form.form.coefficients() - assert old_subj not in new_form.form.coefficients() - - else: - with pytest.raises(error): - new_form = vector_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(new_subj, **idxs)) - - -@pytest.mark.parametrize('argset', primal_test_argsets() + primal_trial_argsets()) -def test_replace_arg_primal(primal_form, argset): - new_arg = argset.new_arg - idxs = argset.idxs - error = argset.error - replace_function = argset.replace_function - arg_idx = argset.arg_idx - primal_form = primal_form.label_map(lambda t: t.has_label(subject), - replace_subject(TrialFunction(V0)), - drop) - - if error is None: - new_form = primal_form.label_map( - all_terms, - map_if_true=replace_function(new_arg, **idxs)) - - if 'new_idx' in idxs: - split_arg = new_arg if type(new_arg) is tuple else split(new_arg) - new_arg = split_arg[idxs['new_idx']].ufl_operands[0] - - if isinstance(new_arg, Argument): - assert new_form.form.arguments()[arg_idx] is new_arg - elif type(new_arg) is Function: - assert new_form.form.coefficients()[0] is new_arg - - else: - with pytest.raises(error): - new_form = primal_form.label_map( - all_terms, - map_if_true=replace_function(new_arg, **idxs)) - - -@pytest.mark.parametrize('argset', mixed_test_argsets() + mixed_trial_argsets()) -def test_replace_arg_mixed(mixed_form, argset): - new_arg = argset.new_arg - idxs = argset.idxs - error = argset.error - replace_function = argset.replace_function - arg_idx = argset.arg_idx - mixed_form = mixed_form.label_map(lambda t: t.has_label(subject), - replace_subject(TrialFunction(W)), - drop) - - if error is None: - new_form = mixed_form.label_map( - all_terms, - map_if_true=replace_function(new_arg, **idxs)) - - if 'new_idx' in idxs: - split_arg = new_arg if type(new_arg) is tuple else split(new_arg) - new_arg = split_arg[idxs['new_idx']].ufl_operands[0] - - if isinstance(new_arg, Argument): - assert new_form.form.arguments()[arg_idx] is new_arg - elif type(new_arg) is Function: - assert new_form.form.coefficients()[0] is new_arg - - else: - with pytest.raises(error): - new_form = mixed_form.label_map( - all_terms, - map_if_true=replace_function(new_arg, **idxs)) - - -@pytest.mark.parametrize('argset', vector_test_argsets()) -def test_replace_arg_vector(vector_form, argset): - new_arg = argset.new_arg - idxs = argset.idxs - error = argset.error - replace_function = argset.replace_function - arg_idx = argset.arg_idx - vector_form = vector_form.label_map(lambda t: t.has_label(subject), - replace_subject(TrialFunction(Vv)), - drop) - - if error is None: - new_form = vector_form.label_map( - all_terms, - map_if_true=replace_function(new_arg, **idxs)) - - if 'new_idx' in idxs: - split_arg = new_arg if type(new_arg) is tuple else split(new_arg) - new_arg = split_arg[idxs['new_idx']].ufl_operands[0] - - if isinstance(new_arg, Argument): - assert new_form.form.arguments()[arg_idx] is new_arg - elif type(new_arg) is Function: - assert new_form.form.coefficients()[0] is new_arg - - else: - with pytest.raises(error): - new_form = vector_form.label_map( - all_terms, - map_if_true=replace_function(new_arg, **idxs)) diff --git a/unit-tests/fml_tests/test_term.py b/unit-tests/fml_tests/test_term.py deleted file mode 100644 index 403a7096a..000000000 --- a/unit-tests/fml_tests/test_term.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Tests FML's Term objects. A term contains a form and labels. -""" - -from firedrake import (IntervalMesh, FunctionSpace, Function, - TestFunction, dx, Constant) -from gusto.fml import Label, Term, LabelledForm -import pytest - - -# Two methods of making a Term with Labels. Either pass them as a dict -# at the initialisation of the Term, or apply them afterwards -@pytest.mark.parametrize("initialise", ["from_dicts", "apply_labels"]) -def test_term(initialise): - - # ------------------------------------------------------------------------ # - # Set up terms - # ------------------------------------------------------------------------ # - - # Some basic labels - foo_label = Label("foo", validator=lambda value: type(value) == bool) - lorem_label = Label("lorem", validator=lambda value: type(value) == str) - ipsum_label = Label("ipsum", validator=lambda value: type(value) == int) - - # Dict for matching the label names to the label objects - all_labels = [foo_label, lorem_label, ipsum_label] - all_label_dict = {label.label: label for label in all_labels} - - # Create mesh, function space and forms - L = 3.0 - n = 3 - mesh = IntervalMesh(n, L) - V = FunctionSpace(mesh, "DG", 0) - f = Function(V) - g = Function(V) - h = Function(V) - test = TestFunction(V) - form = f*test*dx - - # Declare what the labels will be - label_dict = {'foo': True, 'lorem': 'etc', 'ipsum': 1} - - # Make terms - if initialise == "from_dicts": - term = Term(form, label_dict) - else: - term = Term(form) - - # Apply labels - for label_name, value in label_dict.items(): - term = all_label_dict[label_name](term, value) - - # ------------------------------------------------------------------------ # - # Test Term.get routine - # ------------------------------------------------------------------------ # - - for label in all_labels: - if label.label in label_dict.keys(): - # Check if label is attached to Term and it has correct value - assert term.get(label) == label_dict[label.label], \ - f'term should have label {label.label} with value equal ' + \ - f'to {label_dict[label.label]} and not {term.get(label)}' - else: - # Labelled shouldn't be attached to Term so this should return None - assert term.get(label) is None, 'term should not have ' + \ - f'label {label.label} but term.get(label) returns ' + \ - f'{term.get(label)}' - - # ------------------------------------------------------------------------ # - # Test Term.has_label routine - # ------------------------------------------------------------------------ # - - # Test has_label for each label one by one - for label in all_labels: - assert term.has_label(label) == (label.label in label_dict.keys()), \ - f'term.has_label giving incorrect value for {label.label}' - - # Test has_labels by passing all labels at once - has_labels = term.has_label(*all_labels, return_tuple=True) - for i, label in enumerate(all_labels): - assert has_labels[i] == (label.label in label_dict.keys()), \ - f'has_label for label {label.label} returning wrong value' - - # Check the return_tuple option is correct when only one label is passed - has_labels = term.has_label(*[foo_label], return_tuple=True) - assert len(has_labels) == 1, 'Length returned by has_label is ' + \ - f'incorrect, it is {len(has_labels)} but should be 1' - assert has_labels[0] == (label.label in label_dict.keys()), \ - f'has_label for label {label.label} returning wrong value' - - # ------------------------------------------------------------------------ # - # Test Term addition and subtraction - # ------------------------------------------------------------------------ # - - form_2 = g*test*dx - term_2 = ipsum_label(Term(form_2), 2) - - labelled_form_1 = term_2 + term - labelled_form_2 = term + term_2 - - # Adding two Terms should return a LabelledForm containing the Terms - assert type(labelled_form_1) is LabelledForm, 'The sum of two Terms ' + \ - f'should be a LabelledForm, not {type(labelled_form_1)}' - assert type(labelled_form_2) is LabelledForm, 'The sum of two Terms ' + \ - f'should be a LabelledForm, not {type(labelled_form_1)}' - - # Adding a LabelledForm to a Term should return a LabelledForm - labelled_form_3 = term + labelled_form_2 - assert type(labelled_form_3) is LabelledForm, 'The sum of a Term and ' + \ - f'Labelled Form should be a LabelledForm, not {type(labelled_form_3)}' - - labelled_form_1 = term_2 - term - labelled_form_2 = term - term_2 - - # Subtracting two Terms should return a LabelledForm containing the Terms - assert type(labelled_form_1) is LabelledForm, 'The difference of two ' + \ - f'Terms should be a LabelledForm, not {type(labelled_form_1)}' - assert type(labelled_form_2) is LabelledForm, 'The difference of two ' + \ - f'Terms should be a LabelledForm, not {type(labelled_form_1)}' - - # Subtracting a LabelledForm from a Term should return a LabelledForm - labelled_form_3 = term - labelled_form_2 - assert type(labelled_form_3) is LabelledForm, 'The differnce of a Term ' + \ - f'and a Labelled Form should be a LabelledForm, not {type(labelled_form_3)}' - - # Adding None to a Term should return the Term - new_term = term + None - assert term == new_term, 'Adding None to a Term should give the same Term' - - # ------------------------------------------------------------------------ # - # Test Term multiplication and division - # ------------------------------------------------------------------------ # - - # Multiplying a term by an integer should give a Term - new_term = term*3 - assert type(new_term) is Term, 'Multiplying a Term by an integer ' + \ - f'give a Term, not a {type(new_term)}' - - # Multiplying a term by a float should give a Term - new_term = term*19.0 - assert type(new_term) is Term, 'Multiplying a Term by a float ' + \ - f'give a Term, not a {type(new_term)}' - - # Multiplying a term by a Constant should give a Term - new_term = term*Constant(-4.0) - assert type(new_term) is Term, 'Multiplying a Term by a Constant ' + \ - f'give a Term, not a {type(new_term)}' - - # Dividing a term by an integer should give a Term - new_term = term/3 - assert type(new_term) is Term, 'Dividing a Term by an integer ' + \ - f'give a Term, not a {type(new_term)}' - - # Dividing a term by a float should give a Term - new_term = term/19.0 - assert type(new_term) is Term, 'Dividing a Term by a float ' + \ - f'give a Term, not a {type(new_term)}' - - # Dividing a term by a Constant should give a Term - new_term = term/Constant(-4.0) - assert type(new_term) is Term, 'Dividing a Term by a Constant ' + \ - f'give a Term, not a {type(new_term)}' - - # Multiplying a term by a Function should fail - try: - new_term = term*h - # If we get here we have failed - assert False, 'Multiplying a Term by a Function should fail' - except TypeError: - pass