Skip to content

Commit

Permalink
Improve restrict (#1124)
Browse files Browse the repository at this point in the history
  • Loading branch information
kinnala authored May 3, 2024
1 parent 976f1cc commit 19749db
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 15 deletions.
58 changes: 43 additions & 15 deletions skfem/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,37 +1079,65 @@ def trace(self, facets, mtype=None, project=None):
t
), facets

def restrict(self, elements):
def restrict(self,
elements,
return_mapping=False,
skip_boundaries=False,
skip_subdomains=False):
"""Restrict the mesh to a subset of elements.
Parameters
----------
elements
Criteria of which elements to include. This input is normalized
using ``self.normalize_elements``.
return_mapping
Optionally, return the index mapping for vertices.
skip_boundaries
Optionally, skip retagging boundaries.
skip_subdomains
Optionally, skip retagging subdomains.
"""
elements = self.normalize_elements(elements)
p, t, ix = self._reix(self.t[:, elements])
newt = np.zeros(self.t.shape[1], dtype=np.int64) - 1
newt[elements] = np.arange(len(elements), dtype=np.int64)
newf = np.zeros(self.facets.shape[1], dtype=np.int64) - 1
facets = np.unique(self.t2f[:, elements])
newf[facets] = np.arange(len(facets), dtype=np.int64)
return replace(

new_subdomains = None
if not skip_subdomains and self.subdomains is not None:
# map from old to new element index
newt = np.zeros(self.t.shape[1], dtype=np.int64) - 1
newt[elements] = np.arange(len(elements), dtype=np.int64)
# remove 'elements' from each subdomain and remap
new_subdomains = {
k: newt[np.intersect1d(self.subdomains[k],
elements).astype(np.int64)]
for k in self.subdomains
}

new_boundaries = None
if not skip_boundaries and self.boundaries is not None:
# map from old to new facet index
newf = np.zeros(self.facets.shape[1], dtype=np.int64) - 1
facets = np.unique(self.t2f[:, elements])
newf[facets] = np.arange(len(facets), dtype=np.int64)
new_boundaries = {k: newf[self.boundaries[k]]
for k in self.boundaries}
# filter facets not existing in the new mesh, value is -1
new_boundaries = {k: v[v >= 0]
for k, v in new_boundaries.items()}

out = replace(
self,
doflocs=p,
t=t,
_boundaries=({k: np.extract(newf[self.boundaries[k]] >= 0,
newf[self.boundaries[k]])
for k in self.boundaries}
if self.boundaries is not None else None),
_subdomains=({k: newt[np.intersect1d(self.subdomains[k],
elements).astype(np.int64)]
for k in self.subdomains}
if self.subdomains is not None else None),
_boundaries=new_boundaries,
_subdomains=new_subdomains,
)

if return_mapping:
return out, ix
return out

def remove_elements(self, elements):
"""Construct a new mesh by removing elements.
Expand Down
63 changes: 63 additions & 0 deletions tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,3 +792,66 @@ def test_incidence(mesh):
p2f = mesh.p2f
for itr in range(0, 50, 3):
assert np.sum((mesh.facets == itr).any(axis=0)) == len(p2f[:, itr].data)


def test_restrict_tags_boundary():

m = MeshTri().refined(3)
m = m.with_subdomains({
'left': lambda x: x[0] <= 0.5,
'bottom': lambda x: x[1] <= 0.5,
})

mr = m.restrict('left')

# test boundary retag
topleftp = m.p[0, np.unique(m.facets[:, m.boundaries['top']].flatten())]
topleftp = np.sort(topleftp[topleftp <= 0.5])
topleftpr = np.sort(mr.p[0, np.unique(mr.facets[:, mr.boundaries['top']].flatten())])

assert_array_equal(topleftp, topleftpr)


def test_restrict_tags_subdomain():

m = MeshTri().refined(3)
m = m.with_subdomains({
'left': lambda x: x[0] <= 0.5,
'bottom': lambda x: x[1] <= 0.5,
})

mr = m.restrict('left')

# test subdomain retag
bottomleftp = m.p[:, np.unique(m.t[:, m.subdomains['bottom']].flatten())]
bottomleftp = bottomleftp[:, bottomleftp[0] <= 0.5]
ix = np.argsort(bottomleftp[0] + 0.1 * bottomleftp[1])
bottomleftp = bottomleftp[:, ix]

bottomleftpr = mr.p[:, np.unique(mr.t[:, mr.subdomains['bottom']].flatten())]
ix = np.argsort(bottomleftpr[0] + 0.1 * bottomleftpr[1])
bottomleftpr = bottomleftpr[:, ix]

assert_array_equal(bottomleftp, bottomleftpr)


def test_restrict_reverse_map():

m = MeshTri().refined(3)
m = m.with_subdomains({
'left': lambda x: x[0] <= 0.5,
'bottom': lambda x: x[1] <= 0.5,
})

mr, ix = m.restrict('left', return_mapping=True)


p1 = mr.p
I = np.argsort(p1[0] + 0.1 * p1[1])
p1 = p1[:, I]

p2 = m.p[:, ix]
I = np.argsort(p2[0] + 0.1 * p2[1])
p2 = p2[:, I]

assert_array_equal(p1, p2)

0 comments on commit 19749db

Please sign in to comment.