Skip to content

Commit

Permalink
fix: 🐛 Bugs and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Jul 26, 2024
1 parent 2c32ed6 commit 68c11d0
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 22 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
[![PyPI](https://img.shields.io/pypi/v/cellmap-data.svg?color=green)](https://pypi.org/project/cellmap-data)
[![Python Version](https://img.shields.io/pypi/pyversions/cellmap-data.svg?color=green)](https://python.org) -->

[![Build Docs](https://github.com/janelia-cellmap/cellmap-data/actions/workflows/docs.yml/badge.svg?branch=main)](https://janelia-cellmap.github.io/cellmap-data/)

Utility for loading CellMap data for machine learning training, utilizing PyTorch, Xarray, TensorStore, and PyDantic.

You can select classes to load to construct targets separately from the labels you want to predict. This allows you to train a model to predict a subset of labels, while still using all labels to construct the target from true negatives as well as true positives.
Expand Down
7 changes: 4 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
html_title = "CellMap-Data"
html_logo = "https://raw.githubusercontent.com/janelia-cellmap/cellmap-data/main/docs/source/_static/CellMapLogo.png"
html_favicon = "https://raw.githubusercontent.com/janelia-cellmap/cellmap-data/main/docs/source/_static/favicon.ico"
# html_theme_options = {
# # "show_navbar_depth": 3,
# }
html_theme_options = {
# "show_navbar_depth": 3,
"home_page_in_toc": True,
}
7 changes: 7 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ CellMap-Data: the Docs

This library provides a collection of classes and functions for working with cellmap data, specifically for the CellMap project team training machine learning models. The capabilities include loading data from the CellMap ZarrDataset, transforming data, and splitting data into training and validation sets. Functionality is not provided for writing data.


Contents
==============
.. autosummary::

:recursive:
:toctree:

Expand All @@ -25,6 +29,9 @@ This library provides a collection of classes and functions for working with cel
cellmap_data.utils


.. include:: ../../README.md
:parser: recommonmark


Links
==================
Expand Down
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ dependencies = [
"xarray_ome_ngff",
"tensorstore",
"xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git",
"universal_pathlib>=0.2.0",
"fsspec[s3,http]",
"cellpose",
"py_distance_transforms",
# "py_distance_transforms",
]

# extras
# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
[project.optional-dependencies]
test = ["pytest>=6.0", "pytest-cov"]
test = ["pytest>=6.0", "pytest-cov", "mypy"]
dev = [
"black",
"ipython",
Expand All @@ -56,7 +58,6 @@ dev = [
"snakeviz",
"sphinx",
"sphinx-book-theme",

]
all = [
"cellmap-data[dev,test]",
Expand Down Expand Up @@ -135,12 +136,13 @@ filterwarnings = ["error"]

# https://mypy.readthedocs.io/en/stable/config_file.html
[tool.mypy]
files = "src/**/"
files = "src/cellmap_data"
strict = true
disallow_any_generics = false
disallow_subclassing_any = false
show_error_codes = true
pretty = true
ignore_missing_imports = true

# # module specific overrides
# [[tool.mypy.overrides]]
Expand Down
12 changes: 6 additions & 6 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,12 @@ def apply_spatial_transforms(self, coords) -> torch.Tensor:
# Apply and spatial transformations that require the image array (e.g. transpose)
if self._current_spatial_transforms is not None:
for transform, params in self._current_spatial_transforms.items():
if transform in self.post_image_transforms:
if transform == "transpose":
new_order = [params[c] for c in self.axes]
data = np.transpose(data, new_order)
else:
raise ValueError(f"Unknown spatial transform: {transform}")
# if transform in self.post_image_transforms:
if transform == "transpose":
new_order = [params[c] for c in self.axes]
data = np.transpose(data, new_order)
# else:
# raise ValueError(f"Unknown spatial transform: {transform}")

return torch.tensor(data)

Expand Down
1 change: 1 addition & 0 deletions src/cellmap_data/transforms/targets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .distance import DistanceTransform, SignedDistanceTransform
from .cellpose import CellposeFlow
6 changes: 5 additions & 1 deletion src/cellmap_data/transforms/targets/cellpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,8 @@ def __init__(self, ndim: int, device: str | None = None):
self.device = _device

def __call__(self, masks):
return self.flows_func(masks)
flows, centers = self.flows_func((masks > 0).squeeze().numpy())
flows = flows[None, ...]
if self.ndim == 2:
flows = flows[None, ...]
return torch.tensor(flows)
16 changes: 8 additions & 8 deletions src/cellmap_data/transforms/targets/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def _transform(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
"""Forward pass."""
# TODO: Need to figure out how to prevent having inaccurate distance values at the edges --> precompute
# distance = self._transform(x[b, class_ind].nan_to_num(0))
distance = self._transform(x[b, class_ind])
distance[x[b, class_ind].isnan()] = torch.nan
x[b, class_ind] = distance
# distance = self._transform(x.nan_to_num(0))
distance = self._transform(x)
distance[x.isnan()] = torch.nan
x = distance
return x


Expand Down Expand Up @@ -99,8 +99,8 @@ def _transform(self, x: torch.Tensor):
def forward(self, x: torch.Tensor):
"""Forward pass."""
# TODO: Need to figure out how to prevent having inaccurate distance values at the edges --> precompute
# distance = self._transform(x[b, class_ind].nan_to_num(0))
distance = self._transform(x[b, class_ind])
distance[x[b, class_ind].isnan()] = torch.nan
x[b, class_ind] = distance
# distance = self._transform(x.nan_to_num(0))
distance = self._transform(x)
distance[x.isnan()] = torch.nan
x = distance
return x

0 comments on commit 68c11d0

Please sign in to comment.