diff --git a/examples/simple_train.py b/examples/simple_train.py
new file mode 100644
index 0000000..a4bf502
--- /dev/null
+++ b/examples/simple_train.py
@@ -0,0 +1,41 @@
+# %%
+import torch
+from cellmap_data import CellMapDataset, CellMapDataLoader
+from torchvision.models import resnet18
+
+# %%
+# Define the dataset files to use
+dataset_dict = {
+    "train": {"raw": "train_data.zarr/raw", "gt": "train_data.zarr/gt", "weight": 1.0},
+    "val": {"raw": "val_data.zarr/raw", "gt": "val_data.zarr/gt"},
+    "test": {"raw": "test_data.zarr/raw", "gt": "test_data.zarr/gt"},
+}
+
+# %%
+# Create the dataset and dataloader
+dataset = CellMapDataset(dataset_dict)
+dataloader = CellMapDataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
+
+# %%
+# Create the network
+model = resnet18(num_classes=2)
+
+# %%
+# Define the loss function and optimizer
+loss = torch.nn.MSELoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+# %%
+# Train the network
+for epoch in range(10):
+    for i, data in enumerate(dataloader):
+        inputs, targets = data
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss_value = loss(outputs, targets)
+        loss_value.backward()
+        optimizer.step()
+        print(f"Epoch {epoch}, Batch {i}, Loss {loss_value.item()}")
+# %%
+# Save the trained model
+torch.save(model.state_dict(), "model.pth")
diff --git a/pyproject.toml b/pyproject.toml
index 0e7f05d..9259fd0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ test = ["pytest>=6.0", "pytest-cov"]
 dev = [
     "black",
     "ipython",
+    "jupyter",
     "mypy",
     "pdbpp",
     "pre-commit",
diff --git a/src/cellmap_data/__init__.py b/src/cellmap_data/__init__.py
index 7419624..ec9266f 100644
--- a/src/cellmap_data/__init__.py
+++ b/src/cellmap_data/__init__.py
@@ -14,3 +14,6 @@
 
 __author__ = "Jeff Rhoades"
 __email__ = "rhoadesj@hhmi.org"
+
+from .dataset import CellMapDataset
+from .dataloader import CellMapDataLoader
diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py
new file mode 100644
index 0000000..194a2c7
--- /dev/null
+++ b/src/cellmap_data/dataloader.py
@@ -0,0 +1,29 @@
+from torch.utils.data import DataLoader
+from cellmap_data.load import transforms
+
+
+class CellMapDataLoader(DataLoader):
+    """Batch loader for CellMap datasets.
+
+    Subclass of ``torch.utils.data.DataLoader`` that adds no behavior of its
+    own yet; every argument is forwarded to the base class unchanged.
+    """
+
+    def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0):
+        """Initialize the underlying ``DataLoader``.
+
+        Args:
+            dataset: Dataset to draw samples from.
+            batch_size: Number of samples per batch.
+            shuffle: Whether to reshuffle the samples every epoch.
+            num_workers: Number of worker subprocesses; 0 loads in the main
+                process.
+        """
+        # The base initializer has to run: it builds the sampler and batching
+        # state that iteration relies on.
+        super().__init__(
+            dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=num_workers,
+        )
diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py
index f75404e..432683f 100644
--- a/src/cellmap_data/dataset.py
+++ b/src/cellmap_data/dataset.py
@@ -1,24 +1,27 @@
 # %%
 import csv
+from typing import Dict, Optional
 from torch.utils.data import Dataset
 import tensorstore as tswift
+from fibsem_tools import read, read_xarray
 
 
 # %%
 class CellMapDataset(Dataset):
-    def __init__(self, ...):
-        ...
+    def __init__(
+        self, dataset_dict: Optional[Dict[str, Dict[str, str | float]]] = None
+    ):
+        self.dataset_dict = dataset_dict
+        self.construct()
 
-    def __len__(self):
-        ...
+    def __len__(self): ...
 
-    def __getitem__(self, idx):
-        ...
+    def __getitem__(self, idx): ...
 
     def from_csv(self, csv_path):
         # Load file data from csv file
         dataset_dict = {}
-        with open(csv_path, 'r') as f:
+        with open(csv_path, "r") as f:
             reader = csv.reader(f)
             for row in reader:
                 if row[0] not in dataset_dict:
@@ -29,7 +32,11 @@ def from_csv(self, csv_path):
                     dataset_dict[row[0]]["weight"] = row[3]
                 else:
                     dataset_dict[row[0]]["weight"] = 1.0
-        
+
         self.dataset_dict = dataset_dict
-        
+        self.construct()
+
+    def construct(self): ...
+
+
 # %%