Cellpose unet model #257

Open · wants to merge 7 commits into main
1 change: 1 addition & 0 deletions dacapo/experiments/architectures/__init__.py
@@ -5,3 +5,4 @@
DummyArchitecture,
) # noqa
from .cnnectome_unet_config import CNNectomeUNetConfig, CNNectomeUNet # noqa
from .cellpose_unet_config import CellposUNetConfig, CellposeUnet # noqa
75 changes: 75 additions & 0 deletions dacapo/experiments/architectures/cellpose_unet.py
@@ -0,0 +1,75 @@
from cellpose.resnet_torch import CPnet
from .architecture import Architecture
from funlib.geometry import Coordinate


# example
# nout = 4
# sz = 3
# self.net = CPnet(
# nbase, nout, sz, mkldnn=False, conv_3D=True, max_pool=True, diam_mean=30.0
# )
# currently the input channels are embedded in nbase (e.g. nbase = [in_chan, 32, 64, 128, 256]),
# but they should be passed as a separate parameter
class CellposeUnet(Architecture):
def __init__(self, architecture_config):
super().__init__()
self._input_shape = Coordinate(architecture_config.input_shape)
self._nbase = architecture_config.nbase
self._sz = self._input_shape.dims
self._eval_shape_increase = Coordinate((0,) * self._sz)
self._nout = architecture_config.nout
print("conv_3D:", architecture_config.conv_3D)
self.unet = CPnet(
architecture_config.nbase,
architecture_config.nout,
self._sz,
architecture_config.mkldnn,
architecture_config.conv_3D,
architecture_config.max_pool,
architecture_config.diam_mean,
)
print(self.unet)

def forward(self, data):
"""
Forward pass of the CPnet model.

Args:
data (torch.Tensor): Input data.

Returns:
torch.Tensor: The upsampled feature tensor T1. The CPnet output head (self.unet.output) is not applied here.
"""
if self.unet.mkldnn:
data = data.to_mkldnn()
T0 = self.unet.downsample(data)
if self.unet.mkldnn:
style = self.unet.make_style(T0[-1].to_dense())
else:
style = self.unet.make_style(T0[-1])
# style0 = style
if not self.unet.style_on:
style = style * 0
T1 = self.unet.upsample(style, T0, self.unet.mkldnn)
# head layer
# T1 = self.unet.output(T1)
if self.unet.mkldnn:
T0 = [t0.to_dense() for t0 in T0]
T1 = T1.to_dense()
return T1

@property
def input_shape(self):
return self._input_shape

@property
def num_in_channels(self) -> int:
return self._nbase[0]

@property
def num_out_channels(self) -> int:
return self._nout

@property
def eval_shape_increase(self):
return self._eval_shape_increase
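
A minimal usage sketch of the wrapper above, assuming cellpose and torch are installed; the SimpleNamespace stand-in, the toy 72-cubed input, and the channel counts are illustrative only, and a real run would pass a CellposUNetConfig instead. Note that forward() returns only the upsample-path features, since the CPnet output head is left commented out.

import torch
from types import SimpleNamespace

from dacapo.experiments.architectures.cellpose_unet import CellposeUnet

# Illustrative stand-in for a CellposUNetConfig (hypothetical values).
config = SimpleNamespace(
    input_shape=(72, 72, 72),
    nbase=[1, 12, 24, 48, 96],  # first entry is the number of input channels
    nout=12,
    mkldnn=False,
    conv_3D=True,
    max_pool=True,
    diam_mean=30.0,
)

model = CellposeUnet(config)
x = torch.rand(1, 1, 72, 72, 72)  # (batch, channels, z, y, x); spatial dims chosen to stay divisible through the pooling levels
features = model(x)  # upsample-path features; the head layer (self.unet.output) is not applied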
41 changes: 41 additions & 0 deletions dacapo/experiments/architectures/cellpose_unet_config.py
@@ -0,0 +1,41 @@
import attr

from .architecture_config import ArchitectureConfig
from .cellpose_unet import CellposeUnet

from funlib.geometry import Coordinate

from typing import List, Optional


@attr.s
class CellposUNetConfig(ArchitectureConfig):
"""This class configures the CellPose based on
https://github.com/MouseLand/cellpose/blob/main/cellpose/resnet_torch.py
"""

architecture_type = CellposeUnet

input_shape: Coordinate = attr.ib(
metadata={
"help_text": "The shape of the data passed into the network during training."
}
)
nbase: List[int] = attr.ib(
metadata={
"help_text": "List of integers representing the number of channels in each layer of the downsample path."
}
)
nout: int = attr.ib(metadata={"help_text": "Number of output channels."})
mkldnn: Optional[bool] = attr.ib(
default=False, metadata={"help_text": "Whether to use MKL-DNN acceleration."}
)
conv_3D: bool = attr.ib(
default=False, metadata={"help_text": "Whether to use 3D convolution."}
)
max_pool: Optional[bool] = attr.ib(
default=True, metadata={"help_text": "Whether to use max pooling."}
)
diam_mean: Optional[float] = attr.ib(
default=30.0, metadata={"help_text": "Mean diameter of the cells."}
)
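
A short hedged sketch of how this config ties to the wrapper class above, assuming cellpose is installed; the name and values are illustrative and mirror the test fixture added below.

from dacapo.experiments.architectures import CellposUNetConfig

config = CellposUNetConfig(
    name="cellpose_unet_example",  # hypothetical name
    input_shape=(216, 216, 216),
    nbase=[1, 12, 24, 48, 96],
    nout=12,
    conv_3D=True,
)

# architecture_type points at CellposeUnet, so the architecture is built directly from the config.
architecture = config.architecture_type(config)
assert architecture.num_in_channels == 1    # nbase[0]
assert architecture.num_out_channels == 12  # nout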
1 change: 1 addition & 0 deletions pyproject.toml
@@ -58,6 +58,7 @@ dependencies = [
"scipy",
"upath",
"boto3",
"cellpose",
]

# extras
4 changes: 2 additions & 2 deletions tests/fixtures/__init__.py
@@ -1,11 +1,11 @@
from .db import options
from .architectures import dummy_architecture
from .architectures import dummy_architecture, cellpose_architecture
from .arrays import dummy_array, zarr_array, cellmap_array
from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
from .evaluators import binary_3_channel_evaluator
from .losses import dummy_loss
from .post_processors import argmax, threshold
from .predictors import distance_predictor, onehot_predictor
from .runs import dummy_run, distance_run, onehot_run
from .runs import dummy_run, distance_run, onehot_run, cellpose_run
from .tasks import dummy_task, distance_task, onehot_task
from .trainers import dummy_trainer, gunpowder_trainer
14 changes: 13 additions & 1 deletion tests/fixtures/architectures.py
@@ -1,4 +1,4 @@
from dacapo.experiments.architectures import DummyArchitectureConfig
from dacapo.experiments.architectures import DummyArchitectureConfig, CellposUNetConfig

import pytest

@@ -8,3 +8,15 @@ def dummy_architecture():
yield DummyArchitectureConfig(
name="dummy_architecture", num_in_channels=1, num_out_channels=12
)


@pytest.fixture
def cellpose_architecture():
yield CellposUNetConfig(
name="cellpose_architecture",
input_shape=(216, 216, 216),
nbase=[1, 12, 24, 48, 96],
nout=12,
conv_3D=True
# nbase=[1, 32, 64, 128, 256], nout = 32, conv_3D = True
)
18 changes: 18 additions & 0 deletions tests/fixtures/runs.py
@@ -55,3 +55,21 @@ def onehot_run(
repetition=0,
num_iterations=100,
)


@pytest.fixture()
def cellpose_run(
dummy_datasplit,
cellpose_architecture,
dummy_task,
dummy_trainer,
):
yield RunConfig(
name="cellpose_run",
task_config=dummy_task,
architecture_config=cellpose_architecture,
trainer_config=dummy_trainer,
datasplit_config=dummy_datasplit,
repetition=0,
num_iterations=100,
)
61 changes: 61 additions & 0 deletions tests/operations/test_cellpose.py
@@ -0,0 +1,61 @@
import numpy as np
from dacapo.store.create_store import create_stats_store
from ..fixtures import *

from dacapo.experiments import Run
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo.train import train_run
from pytest_lazy_fixtures import lf
import pytest

import logging

logging.basicConfig(level=logging.INFO)


# ignore the Metal 2.0 UserWarning raised on the Apple Paravirtual device,
# which does not support Metal 2.0
@pytest.mark.filterwarnings("ignore:.*Metal 2.0.*:UserWarning")
@pytest.mark.parametrize(
"run_config",
[
lf("cellpose_run"),
],
)
def test_train(
run_config,
):
print("Test train")
# create a store

store = create_config_store()
stats_store = create_stats_store()
weights_store = create_weights_store()

# store the configs

store.store_run_config(run_config)
run = Run(run_config)
print("Run created ")
print(run.model)

# # -------------------------------------

# # train

# weights_store.store_weights(run, 0)
# print("Weights stored")
# train_run(run)

# init_weights = weights_store.retrieve_weights(run.name, 0)
# final_weights = weights_store.retrieve_weights(run.name, run.train_until)

# for name, weight in init_weights.model.items():
# weight_diff = (weight - final_weights.model[name]).sum()
# assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff

# # assert train_stats and validation_scores are available

# training_stats = stats_store.retrieve_training_stats(run_config.name)

# assert training_stats.trained_until() == run_config.num_iterations