diff --git a/FIAT/crouzeix_raviart.py b/FIAT/crouzeix_raviart.py index 0979c9d4..8447a126 100644 --- a/FIAT/crouzeix_raviart.py +++ b/FIAT/crouzeix_raviart.py @@ -14,7 +14,6 @@ from FIAT.check_format_variant import check_format_variant from FIAT.quadrature_schemes import create_quadrature from FIAT.quadrature import FacetQuadratureRule -from FIAT.reference_element import make_lattice class CrouzeixRaviartDualSet(dual_set.DualSet): @@ -24,6 +23,9 @@ def __init__(self, ref_el, degree, variant, interpolant_deg): sd = ref_el.get_spatial_dimension() top = ref_el.get_topology() + if degree > 1 and sd != 2: + raise NotImplementedError("High-order Crouzeix-Raviart is only implemented on triangles.") + # Initialize empty nodes and entity_ids entity_ids = {dim: {entity: [] for entity in top[dim]} for dim in top} nodes = [] @@ -61,11 +63,9 @@ def __init__(self, ref_el, degree, variant, interpolant_deg): for i in sorted(top[dim]): cur = len(nodes) if dim == sd-1 and dim != 0: - verts = ref_el.get_vertices_of_subcomplex(top[dim][i]) - pts = make_lattice(verts, degree-1, variant="gl") + pts = ref_el.make_points(dim, i, degree-1, variant="gl", interior=0) else: pts = ref_el.make_points(dim, i, degree, variant="gll") - nodes.extend(functional.PointEvaluation(ref_el, x) for x in pts) entity_ids[dim][i].extend(range(cur, len(nodes))) diff --git a/FIAT/reference_element.py b/FIAT/reference_element.py index 7d9ff296..b4cd2c63 100644 --- a/FIAT/reference_element.py +++ b/FIAT/reference_element.py @@ -510,7 +510,7 @@ def compute_face_edge_tangents(self, dim, entity_id): v1.append(dest) return vert_coords[v1] - vert_coords[v0] - def make_points(self, dim, entity_id, order, variant=None): + def make_points(self, dim, entity_id, order, variant=None, interior=1): """Constructs a lattice of points on the entity_id:th facet of dimension dim. Order indicates how many points to include in each direction.""" @@ -520,7 +520,7 @@ def make_points(self, dim, entity_id, order, variant=None): entity_verts = \ self.get_vertices_of_subcomplex( self.get_topology()[dim][entity_id]) - return make_lattice(entity_verts, order, 1, variant=variant) + return make_lattice(entity_verts, order, interior=interior, variant=variant) else: raise ValueError("illegal dimension")