Skip to content

Commit

Permalink
refactor: ✨ Add empty datasets to allow for no validation cases.
Browse files Browse the repository at this point in the history
Also switch to PyPi xarray-tensorstore for dependency.
  • Loading branch information
rhoadesScholar committed Oct 10, 2024
1 parent fcdb26c commit 0e8a7ff
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 15 deletions.
3 changes: 1 addition & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@

extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.autosectionlabel",
"sphinx.ext.napoleon",
"sphinx.ext.githubpages",
"sphinx.ext.viewcode",
# "sphinx.ext.linkcode",
"sphinx.ext.autosectionlabel",
"sphinx.ext.coverage",
"sphinx.ext.autosummary",
]

autodoc_default_options = {
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ dependencies = [
"pydantic_ome_ngff",
"xarray_ome_ngff",
"tensorstore",
"xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git",
# "xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git",
"xarray-tensorstore",
"universal_pathlib>=0.2.0",
"fsspec[s3,http]",
"cellpose",
Expand Down
11 changes: 11 additions & 0 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,14 @@ def reset_arrays(self, type: str = "target") -> None:
self.target_sources[array_name] = self.get_target_array(array_info)
else:
raise ValueError(f"Unknown dataset array type: {type}")

@staticmethod
def empty() -> "CellMapDataset":
"""Creates an empty dataset."""
empty_dataset = CellMapDataset("", "", [], {}, {})
empty_dataset.classes = []
empty_dataset._class_counts = {}
empty_dataset._class_weights = {}
empty_dataset._validation_indices = []

return empty_dataset
25 changes: 14 additions & 11 deletions src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,23 @@ def train_datasets_combined(self) -> CellMapMultiDataset:
@property
def validation_datasets_combined(self) -> CellMapMultiDataset:
"""A multi-dataset from the combination of all validation datasets."""
assert len(self.validation_datasets) > 0, "Validation datasets not loaded."
try:
return self._validation_datasets_combined
except AttributeError:
self._validation_datasets_combined = CellMapMultiDataset(
self.classes,
self.input_arrays,
self.target_arrays,
[
ds
for ds in self.validation_datasets
if self.force_has_data or ds.has_data
],
)
if len(self.validation_datasets) == 0:
UserWarning("Validation datasets not loaded.")
self._validation_datasets_combined = CellMapMultiDataset.empty()
else:
self._validation_datasets_combined = CellMapMultiDataset(
self.classes,
self.input_arrays,
self.target_arrays,
[
ds
for ds in self.validation_datasets
if self.force_has_data or ds.has_data
],
)
return self._validation_datasets_combined

@property
Expand Down
11 changes: 10 additions & 1 deletion src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,16 @@ def array(self) -> xarray.DataArray:
array_future = tensorstore.open(
spec, read=True, write=False, context=self.context
)
array = array_future.result()
try:
array = array_future.result()
except ValueError as e:
Warning(e)
UserWarning("Falling back to zarr3 driver")
spec["driver"] = "zarr3"
array_future = tensorstore.open(
spec, read=True, write=False, context=self.context
)
array = array_future.result()
data = xt._TensorStoreAdapter(array)
self._array = xarray.DataArray(data=data, coords=self.full_coords)
return self._array
Expand Down
11 changes: 11 additions & 0 deletions src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,14 @@ def set_spatial_transforms(
"""Sets the raw value transforms for each dataset in the training multi-dataset."""
for dataset in self.datasets:
dataset.spatial_transforms = spatial_transforms

@staticmethod
def empty() -> "CellMapMultiDataset":
"""Creates an empty dataset."""
empty_dataset = CellMapMultiDataset([], {}, {}, [CellMapDataset.empty()])
empty_dataset.classes = []
empty_dataset._class_counts = {}
empty_dataset._class_weights = {}
empty_dataset._validation_indices = []

return empty_dataset

0 comments on commit 0e8a7ff

Please sign in to comment.