Skip to content

Commit

Permalink
address constants
Browse files Browse the repository at this point in the history
  • Loading branch information
ddkohler committed Jan 25, 2024
1 parent ce53625 commit a151998
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
8 changes: 5 additions & 3 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def squeeze(self, name=None, parent=None):
i.e. if the joint shape of the axes has an array dimension with length 1, this
array dimension is squeezed.
channels and variables that span beyond the axes are omitted.
channels and variables that span beyond the axes are removed.
Parameters
----------
Expand Down Expand Up @@ -408,8 +408,6 @@ def squeeze(self, name=None, parent=None):
}
new.attrs.update(attrs)

# TODO: deal with constants? establish new constants?

joint_shape = wt_kit.joint_shape(*[ai[:] for ai in self.axes])
cull_dims = [j == 1 for j in joint_shape]
sl = [0 if cull else slice(None) for cull in cull_dims]
Expand All @@ -427,6 +425,10 @@ def squeeze(self, name=None, parent=None):
kwargs["values"] = c[sl]
new.create_channel(**kwargs)

# inherit constants
for c in self.constants:
new.create_constant(c.expression)

new.transform(*self.axis_expressions)
return new

Expand Down
13 changes: 13 additions & 0 deletions tests/data/squeeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def test_squeeze():
assert d.ndim == 2
assert d.shape == (5, 4)

def test_constants():
d = wt.Data(name="test")
d.create_variable("x", values=np.array([1]).reshape(1,1))
d.create_constant("x")
d.create_variable("y", values=np.linspace(3,5,4).reshape(-1,1))
d.create_variable("z", values=np.linspace(0,1,6).reshape(1,-1))
d.transform("y")
ds = d.squeeze()
assert "x" in ds.constant_expressions
d.print_tree()
ds.print_tree()


if __name__ == "__main__":
test_squeeze()
test_constants()

0 comments on commit a151998

Please sign in to comment.