-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
valhassan
committed
Oct 22, 2024
1 parent
30d2ec3
commit d78bd01
Showing
1 changed file
with
112 additions
and
0 deletions.
There are no files selected for viewing
112 changes: 112 additions & 0 deletions
112
geo_deep_learning/tasks_with_models/segmentation_dofa.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import numpy as np | ||
import torch | ||
from pathlib import Path | ||
import matplotlib.pyplot as plt | ||
from torch import Tensor | ||
from typing import Any, Callable, Dict, List | ||
from lightning.pytorch import LightningModule, LightningDataModule | ||
from torchmetrics.classification import MulticlassJaccardIndex | ||
from torchmetrics.wrappers import ClasswiseWrapper | ||
from models.dofa.dofa_seg import DOFASeg | ||
from tools.script_model import script_model | ||
|
||
class SegmentationDOFA(LightningModule):
    """Semantic segmentation task built on a DOFA encoder-decoder.

    Trains, validates and tests a ``DOFASeg`` model with a user-supplied
    loss, reports a class-wise Jaccard index (IoU) at test time, and
    exports the best checkpoint to TorchScript when training ends.
    """

    def __init__(self,
                 encoder: str,
                 pretrained: bool,
                 image_size: tuple[int, int],
                 in_channels: int,
                 num_classes: int,
                 loss: Callable,
                 class_labels: List[str] = None,
                 **kwargs: Any):
        """
        Args:
            encoder: Name of the DOFA encoder variant to build.
            pretrained: Whether to load pretrained encoder weights.
            image_size: Spatial (height, width) of the input patches.
            in_channels: Number of input image channels (recorded in
                hyperparameters; read back at export time).
            num_classes: Number of segmentation classes.
            loss: Callable computing the loss from (logits, targets).
            class_labels: Optional human-readable class names; defaults
                to the string form of each class index.
            **kwargs: Ignored; accepted for config-driven instantiation.
        """
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.model = DOFASeg(encoder, pretrained, image_size, self.num_classes)
        self.loss = loss
        # average=None keeps one IoU per class; zero_division=nan leaves
        # classes absent from a batch out of the score instead of forcing 0.
        self.metric = MulticlassJaccardIndex(num_classes=num_classes,
                                             average=None,
                                             zero_division=np.nan)
        self.labels = [str(i) for i in range(num_classes)] if class_labels is None else class_labels
        self.classwise_metric = ClasswiseWrapper(self.metric, labels=self.labels)

    def forward(self, image: Tensor) -> Tensor:
        """Run the segmentation model and return per-class logits."""
        return self.model(image)

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Tensor:
        """Compute and log the training loss for one batch.

        Returns the loss tensor so Lightning's automatic optimization can
        run the backward pass.
        """
        # NOTE: an earlier debug block called loss.backward(retain_graph=True)
        # here (on batch_idx == 0) to print unused parameters. Under automatic
        # optimization Lightning performs backward itself after this method
        # returns, so the manual call double-accumulated gradients on the
        # first batch; the debug scaffolding has been removed.
        x = batch["image"]
        y = batch["label"].squeeze(1).long()
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss,
                 prog_bar=True, logger=True,
                 on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; return the predicted class map."""
        x = batch["image"]
        y = batch["label"].squeeze(1).long()
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        y_hat = y_hat.softmax(dim=1).argmax(dim=1)
        self.log('val_loss', loss,
                 prog_bar=True, logger=True,
                 on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True)
        return y_hat

    def test_step(self, batch, batch_idx):
        """Log the test loss and the class-wise Jaccard index for one batch."""
        x = batch["image"]
        y = batch["label"].squeeze(1).long()
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        y_hat = y_hat.softmax(dim=1).argmax(dim=1)
        test_metrics = self.classwise_metric(y_hat, y)
        test_metrics["loss"] = loss
        self.log_dict(test_metrics,
                      prog_bar=True, logger=True,
                      on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True)

    def on_train_end(self):
        """On the rank-zero process, export the best checkpoint to TorchScript."""
        if self.trainer.is_global_zero and self.trainer.checkpoint_callback is not None:
            best_model_path = self.trainer.checkpoint_callback.best_model_path
            if best_model_path:
                print(f"Best model path: {best_model_path}")
                best_model_dir = Path(best_model_path).parent
                best_model_name = Path(best_model_path).stem
                best_model_export_path = str(best_model_dir / f"{best_model_name}_scripted.pt")
                self.export_model(best_model_path, best_model_export_path, self.trainer.datamodule)

    def export_model(self, checkpoint_path: str, export_path: str, datamodule: LightningDataModule):
        """Reload the given checkpoint, script+trace it, and save to ``export_path``.

        Args:
            checkpoint_path: Path of the checkpoint to reload.
            export_path: Destination path for the TorchScript artifact.
            datamodule: Datamodule supplying ``patch_size`` for the dummy input.
        """
        # NOTE(review): assumes hyperparameters are nested under "init_args"
        # (LightningCLI layout) — confirm against how this task is launched.
        input_channels = self.hparams["init_args"]["in_channels"]
        map_location = "cpu" if self.device.type == "cpu" else "cuda"
        best_model = self.__class__.load_from_checkpoint(checkpoint_path, map_location=map_location)
        best_model.eval()

        scripted_model = script_model(best_model.model, datamodule, self.num_classes, from_logits=True)
        patch_size = datamodule.patch_size
        # Trace with a dummy batch of the training patch size so the exported
        # graph is specialized to the deployment input shape.
        dummy_input = torch.rand(1, input_channels, *patch_size, device=torch.device(map_location))
        traced_model = torch.jit.trace(scripted_model, dummy_input)
        torch.jit.save(traced_model, export_path)
        print("Model exported to TorchScript")