From b4334f03dd47279958404be16fb45176152c62b8 Mon Sep 17 00:00:00 2001 From: Arianna Sole Date: Tue, 21 Jan 2025 14:44:59 +0100 Subject: [PATCH] Update instance_segmentation.py --- torchgeo/trainers/instance_segmentation.py | 50 ---------------------- 1 file changed, 50 deletions(-) diff --git a/torchgeo/trainers/instance_segmentation.py b/torchgeo/trainers/instance_segmentation.py index d28f47f610d..820c61e8a9d 100644 --- a/torchgeo/trainers/instance_segmentation.py +++ b/torchgeo/trainers/instance_segmentation.py @@ -194,53 +194,3 @@ def predict_step(self, batch: Any, batch_idx: int) -> Tensor: y_hat: Tensor = self.model(images) return y_hat - - - - - - - - - - - - - - -#================================================================= -# TESTING -#================================================================= - -def collate_fn(batch): - return tuple(zip(*batch)) - -train_dataset = VHR10(root="data", split="positive", transforms=None, download=True) -val_dataset = VHR10(root="data", split="positive", transforms=None) - -train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn) -val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn) - -task = InstanceSegmentationTask( - model="mask_rcnn", - backbone="resnet50", - weights=True, - num_classes=11, - lr=1e-3, - freeze_backbone=False -) - -trainer = pl.Trainer( - max_epochs=10, - accelerator="gpu" if torch.cuda.is_available() else "cpu", - devices=1 -) - -trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader) - -trainer.test(task, dataloaders=val_loader) - -test_sample = train_dataset[0] -test_image = test_sample["image"].unsqueeze(0) -predictions = task.predict_step({"image": test_image}, batch_idx=0) -print(predictions)