Skip to content

Commit

Permalink
Merging interleave_data bugfix into release #patch
Browse files Browse the repository at this point in the history
  • Loading branch information
jgostick authored Jul 14, 2020
2 parents 299becb + 1a9ad7d commit 720157f
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 107 deletions.
61 changes: 45 additions & 16 deletions openpnm/core/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,10 +1165,7 @@ def interleave_data(self, prop):
>>> print(g1['pore.label']) # 'pore.label' is defined on pn, not g1
[False False False False]
"""
element = self._parse_element(prop.split('.')[0], single=True)
N = self.project.network._count(element)

# Fetch sources list depending on object type?
# Fetch sources list depending on type of self
proj = self.project
if self._isa() in ['network', 'geometry']:
sources = list(proj.geometries().values())
Expand All @@ -1179,13 +1176,45 @@ def interleave_data(self, prop):
else:
raise Exception('Unrecognized object type, cannot find dependents')

# Get generalized element and array length
element = self._parse_element(prop.split('.')[0], single=True)
N = self.project.network._count(element)

# Attempt to fetch the requested array from each object
arrs = [item.get(prop, None) for item in sources]
arrs = [obj.get(prop, None) for obj in sources]

# Check for missing sources, and add None to arrs if necessary
if N > sum([obj._count(element) for obj in sources]):
arrs.append(None)

# Obtain list of locations for inserting values
locs = [self._get_indices(element, item.name) for item in sources]
sizes = [np.size(a) for a in arrs]

if np.all([item is None for item in arrs]): # prop not found anywhere
raise KeyError(prop)

# --------------------------------------------------------------------
# Let's start by handling the easy cases first
if not any([a is None for a in arrs]):
# All objs present and array found on all objs
shape = list(arrs[0].shape)
shape[0] = N
types = [a.dtype for a in arrs]
if len(set(types)) == 1:
# All types are the same
temp_arr = np.ones(shape, dtype=types[0])
for vals, inds in zip(arrs, locs):
temp_arr[inds] = vals
return temp_arr # Return early because it's just easier
elif all([a.dtype in [float, int, bool] for a in arrs]):
# All types are numeric, make float
temp_arr = np.ones(shape, dtype=float)
for vals, inds in zip(arrs, locs):
temp_arr[inds] = vals
return temp_arr # Return early because it's just easier

# ---------------------------------------------------------------------
# Now handle the complicated cases
# Check the general type of each array
atype = []
for a in arrs:
Expand All @@ -1211,6 +1240,7 @@ def interleave_data(self, prop):
temp_arr = np.zeros((N, item.shape[1]), dtype=item.dtype)
temp_arr.fill(dummy_val[atype[0]])

sizes = [np.size(a) for a in arrs]
# Convert int arrays to float IF NaNs are expected
if temp_arr.dtype.name.startswith('int') and \
(np.any([i is None for i in arrs]) or np.sum(sizes) != N):
Expand All @@ -1228,16 +1258,15 @@ def interleave_data(self, prop):
# Importing unyt significantly adds to our import time, we also
# currently don't use this package extensively, so we're not going
# to support it for now.
"""
if any([hasattr(a, 'units') for a in arrs]):
[a.convert_to_mks() for a in arrs if hasattr(a, 'units')]
units = [a.units.__str__() for a in arrs if hasattr(a, 'units')]
if len(units) > 0:
if len(set(units)) == 1:
temp_arr *= np.array([1]) * getattr(unyt, units[0])
else:
raise Exception('Units on the interleaved array are not equal')
"""

# if any([hasattr(a, 'units') for a in arrs]):
# [a.convert_to_mks() for a in arrs if hasattr(a, 'units')]
# units = [a.units.__str__() for a in arrs if hasattr(a, 'units')]
# if len(units) > 0:
# if len(set(units)) == 1:
# temp_arr *= np.array([1]) * getattr(unyt, units[0])
# else:
# raise Exception('Units on the interleaved array are not equal')

return temp_arr

Expand Down
89 changes: 0 additions & 89 deletions tests/unit/core/BaseTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,95 +712,6 @@ def test_map_pores_missing(self):
b = self.geo22.map_pores(pores=Ps, origin=self.net2)
assert len(b) == 0

def test_interleave_data_bool(self):
net = op.network.Cubic(shape=[2, 2, 2])
Ps = net.pores('top')
geom1 = op.geometry.GenericGeometry(network=net, pores=Ps)
Ps = net.pores('bottom')
geom2 = op.geometry.GenericGeometry(network=net, pores=Ps)
# Ensure Falses return in missing places
geom1['pore.blah'] = True
assert np.all(~geom2['pore.blah'])
assert np.sum(net['pore.blah']) == 4
# Ensure all Trues returned now
geom2['pore.blah'] = True
assert np.all(geom2['pore.blah'])
assert np.sum(net['pore.blah']) == 8

def test_interleave_data_int(self):
net = op.network.Cubic(shape=[2, 2, 2])
Ps = net.pores('top')
geom1 = op.geometry.GenericGeometry(network=net, pores=Ps)
Ps = net.pores('bottom')
geom2 = op.geometry.GenericGeometry(network=net, pores=Ps)
geom1['pore.blah'] = 1
# Ensure ints are returned geom1
assert 'int' in geom1['pore.blah'].dtype.name
# Ensure nans are returned on geom2
assert np.all(np.isnan(geom2['pore.blah']))
# Ensure interleaved array is float with nans
assert 'float' in net['pore.blah'].dtype.name
# Ensure missing values are floats
assert np.sum(np.isnan(net['pore.blah'])) == 4

def test_interleave_data_float(self):
net = op.network.Cubic(shape=[2, 2, 2])
Ps = net.pores('top')
geom1 = op.geometry.GenericGeometry(network=net, pores=Ps)
Ps = net.pores('bottom')
geom2 = op.geometry.GenericGeometry(network=net, pores=Ps)
geom1['pore.blah'] = 1.0
# Ensure flaots are returned geom1
assert 'float' in geom1['pore.blah'].dtype.name
# Ensure nans are returned on geom2
assert np.all(np.isnan(geom2['pore.blah']))
# Ensure interleaved array is float with nans
assert 'float' in net['pore.blah'].dtype.name
# Ensure missing values are floats
assert np.sum(np.isnan(net['pore.blah'])) == 4

def test_interleave_data_object(self):
net = op.network.Cubic(shape=[2, 2, 2])
Ps = net.pores('top')
geom1 = op.geometry.GenericGeometry(network=net, pores=Ps)
Ps = net.pores('bottom')
_ = op.geometry.GenericGeometry(network=net, pores=Ps)
geom1['pore.blah'] = [[1, 2], [1, 2, 3], [1, 2, 3, 4], [1]]
assert 'object' in net['pore.blah'].dtype.name
# Ensure missing elements are None
assert np.sum([item is None for item in net['pore.blah']]) == 4

def test_interleave_data_key_error(self):
net = op.network.Cubic(shape=[2, 2, 2])
Ps = net.pores('top')
geom1 = op.geometry.GenericGeometry(network=net, pores=Ps)
Ps = net.pores('bottom')
geom2 = op.geometry.GenericGeometry(network=net, pores=Ps)
with pytest.raises(KeyError):
net['pore.blah']
with pytest.raises(KeyError):
geom1['pore.blah']
with pytest.raises(KeyError):
geom2['pore.blah']

def test_interleave_data_float_missing_geometry(self):
net = op.network.Cubic(shape=[2, 2, 2])
geom = op.geometry.GenericGeometry(network=net, pores=[0, 1, 2])
geom['pore.blah'] = 1.0
assert np.any(np.isnan(net['pore.blah']))

def test_interleave_data_int_missing_geometry(self):
net = op.network.Cubic(shape=[2, 2, 2])
geom = op.geometry.GenericGeometry(network=net, pores=[0, 1, 2])
geom['pore.blah'] = 1
assert np.any(np.isnan(net['pore.blah']))

def test_interleave_data_bool_missing_geometry(self):
net = op.network.Cubic(shape=[2, 2, 2])
geom = op.geometry.GenericGeometry(network=net, pores=[0, 1, 2])
geom['pore.blah'] = True
assert np.sum(net['pore.blah']) == geom.Np

def test_getitem_with_no_matches(self):
self.geo.pop('pore.blah', None)
with pytest.raises(KeyError):
Expand Down
Loading

0 comments on commit 720157f

Please sign in to comment.