Skip to content

Commit

Permalink
docs: 📝 Add example training call.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 21, 2024
1 parent 861dddd commit 3246e1d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 9 deletions.
41 changes: 41 additions & 0 deletions examples/simple_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# %%
import torch
from cellmap_data import CellMapDataset, CellMapDataLoader
from torchvision.models import resnet18

# %%
# Define the dataset files to use
dataset_dict = {
"train": {"raw": "train_data.zarr/raw", "gt": "train_data.zarr/gt", "weight": 1.0},
"val": {"raw": "val_data.zarr/raw", "gt": "val_data.zarr/gt"},
"test": {"raw": "test_data.zarr/raw", "gt": "test_data.zarr/gt"},
}

# %%
# Create the dataset and dataloader
dataset = CellMapDataset(dataset_dict)
dataloader = CellMapDataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# %%
# Create the network
model = resnet18(num_classes=2)

# %%
# Define the loss function and optimizer
loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# %%
# Train the network
for epoch in range(10):
for i, data in enumerate(dataloader):
inputs, targets = data
optimizer.zero_grad()
outputs = model(inputs)
loss_value = loss(outputs, targets)
loss_value.backward()
optimizer.step()
print(f"Epoch {epoch}, Batch {i}, Loss {loss_value.item()}")
# %%
# Save the trained model
torch.save(model.state_dict(), "model.pth")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ test = ["pytest>=6.0", "pytest-cov"]
dev = [
"black",
"ipython",
"jupyter",
"mypy",
"pdbpp",
"pre-commit",
Expand Down
3 changes: 3 additions & 0 deletions src/cellmap_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@

__author__ = "Jeff Rhoades"
__email__ = "[email protected]"

from .dataset import CellMapDataset
from .dataloader import CellMapDataLoader
6 changes: 6 additions & 0 deletions src/cellmap_data/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from torch.utils.data import DataLoader
from cellmap_data.load import transforms


class CellMapDataLoader(DataLoader):
def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0): ...
25 changes: 16 additions & 9 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
# %%
import csv
from typing import Dict, Optional
from torch.utils.data import Dataset
import tensorstore as tswift
from fibsem_tools import read, read_xarray


# %%
class CellMapDataset(Dataset):
def __init__(self, ...):
...
def __init__(
self, dataset_dict: Optional[Dict[str, Dict[str, str | float]]] = None
):
self.dataset_dict = dataset_dict
self.construct()

def __len__(self):
...
def __len__(self): ...

def __getitem__(self, idx):
...
def __getitem__(self, idx): ...

def from_csv(self, csv_path):
# Load file data from csv file
dataset_dict = {}
with open(csv_path, 'r') as f:
with open(csv_path, "r") as f:
reader = csv.reader(f)
for row in reader:
if row[0] not in dataset_dict:
Expand All @@ -29,7 +32,11 @@ def from_csv(self, csv_path):
dataset_dict[row[0]]["weight"] = row[3]
else:
dataset_dict[row[0]]["weight"] = 1.0

self.dataset_dict = dataset_dict

self.construct()

def construct(self): ...


# %%

0 comments on commit 3246e1d

Please sign in to comment.