diff --git a/WrightTools/data/_data.py b/WrightTools/data/_data.py index f15d12e5..5a307e53 100644 --- a/WrightTools/data/_data.py +++ b/WrightTools/data/_data.py @@ -1935,7 +1935,7 @@ def split( # axis ------------------------------------------------------------------------------------ old_expr = self.axis_expressions old_units = self.units - out = wt_collection.Collection(name=f"{self.name}_split", parent=parent) + out = wt_collection.Collection(name=f"{self.natural_name}_split", parent=parent) if isinstance(expression, int): if units is None: units = self._axes[expression].units @@ -1969,11 +1969,8 @@ def split( omasks.append(None) cuts.append(None) for i in range(len(positions) - 1): - out.create_data(f"{self.name}_{i:0>3}") + out.create_data(f"{self.natural_name}_{i:0>3}") - if inherit_attrs: - for d in out.values(): - {d.attrs[k]: self.attrs[k] for k in self.attrs.keys() if k not in d.attrs.keys()} for var in self.variables: for i, (imask, omask, cut) in enumerate(zip(masks, omasks, cuts)): if omask is None: @@ -2047,6 +2044,10 @@ def split( for ax, u in zip(self.axes, old_units): ax.convert(u) + if inherit_attrs: + for d in out.values(): + {d.attrs[k]: self.attrs[k] for k in self.attrs.keys() if k not in d.attrs.keys()} + return out def transform(self, *axes, verbose=True): diff --git a/tests/data/split.py b/tests/data/split.py index 5b2e8b61..fd4fbb44 100755 --- a/tests/data/split.py +++ b/tests/data/split.py @@ -122,7 +122,7 @@ def test_split_parent(): a = wt.data.from_PyCMDS(p) parent = wt.Collection() split = a.split(1, [1500], parent=parent) - assert "split" in parent + assert f"{a.natural_name}_split" in parent assert split.filepath == parent.filepath assert len(split) == 2 a.close()