Skip to content

Commit

Permalink
Update instance_segmentation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ariannasole23 authored Jan 21, 2025
1 parent 0fa7b07 commit b4334f0
Showing 1 changed file with 0 additions and 50 deletions.
50 changes: 0 additions & 50 deletions torchgeo/trainers/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,53 +194,3 @@ def predict_step(self, batch: Any, batch_idx: int) -> Tensor:
y_hat: Tensor = self.model(images)
return y_hat















#=================================================================
# TESTING
#=================================================================

def collate_fn(batch):
return tuple(zip(*batch))

train_dataset = VHR10(root="data", split="positive", transforms=None, download=True)
val_dataset = VHR10(root="data", split="positive", transforms=None)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

task = InstanceSegmentationTask(
model="mask_rcnn",
backbone="resnet50",
weights=True,
num_classes=11,
lr=1e-3,
freeze_backbone=False
)

trainer = pl.Trainer(
max_epochs=10,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
devices=1
)

trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)

trainer.test(task, dataloaders=val_loader)

test_sample = train_dataset[0]
test_image = test_sample["image"].unsqueeze(0)
predictions = task.predict_step({"image": test_image}, batch_idx=0)
print(predictions)

0 comments on commit b4334f0

Please sign in to comment.