ClassificationTask example #789
-
What happens when you change …
-
I was able to train and test fine for 50 epochs with the following code, after installing the latest torchgeo release (v0.3.1):

```python
import pytorch_lightning as pl
from torchgeo.datasets import UCMerced
from torchgeo.datamodules import UCMercedDataModule
from torchgeo.trainers import ClassificationTask

# Parameters
data_dir = "./data"
num_classes = 21
channels = 3
batch_size = 4
num_workers = 2
backbone = "resnet18"
weights = "imagenet"
lr = 0.01
lr_schedule_patience = 5
epochs = 50

# Download dataset and dataset splits
dataset = UCMerced(data_dir, download=True, checksum=True)

# Instantiate datamodule, classifier task, and callbacks
datamodule = UCMercedDataModule(
    root_dir=data_dir,
    batch_size=batch_size,
    num_workers=num_workers,
)
task = ClassificationTask(
    classification_model=backbone,
    weights=weights,
    num_classes=num_classes,
    in_channels=channels,
    loss="ce",
    learning_rate=lr,
    learning_rate_schedule_patience=lr_schedule_patience,
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,  # keep only the checkpoint with the lowest val_loss
    save_last=True,  # and also the most recent checkpoint
)
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,  # stop after 10 validation checks without improvement
)

# Train
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=epochs,
    gpus=1,  # PL 1.x argument; Lightning 2.x uses accelerator="gpu", devices=1
)
trainer.fit(model=task, datamodule=datamodule)

# Test
test_metrics = trainer.test(model=task, datamodule=datamodule)
"""
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   test_AverageAccuracy     0.7456777691841125
       test_F1Score         0.7523809671401978
     test_JaccardIndex      0.6187251210212708
   test_OverallAccuracy     0.7523809671401978
        test_loss           0.9027961492538452
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
"""
```

Here are the TensorBoard train/val loss and overall accuracy plots over steps:

[Figure: TensorBoard train/val loss and overall accuracy plots]
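In the test table above, `test_F1Score` and `test_OverallAccuracy` are identical (0.7523809671401978). That is what you would expect if the F1 score is micro-averaged: with single-label multiclass predictions, every wrong prediction counts as exactly one false positive (for the predicted class) and one false negative (for the true class), so pooled precision, pooled recall, micro F1, and overall accuracy all coincide. `test_AverageAccuracy` differs because it weights each class equally. The averaging modes here are an assumption inferred from the metric names and values, not read from torchgeo's source, and `classification_metrics` is a hypothetical pure-Python helper that only illustrates the arithmetic:

```python
from collections import Counter
from typing import Dict, List


def classification_metrics(
    y_true: List[int], y_pred: List[int], num_classes: int
) -> Dict[str, float]:
    """Overall (micro) accuracy, average (macro) accuracy, and micro-averaged F1."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and the same length")
    pairs = list(zip(y_true, y_pred))

    # Overall accuracy: fraction of all samples predicted correctly.
    overall_accuracy = sum(t == p for t, p in pairs) / len(pairs)

    # Average accuracy: per-class recall, averaged over classes present in y_true.
    support = Counter(y_true)
    hits = Counter(t for t, p in pairs if t == p)
    recalls = [hits[c] / support[c] for c in range(num_classes) if support[c] > 0]
    average_accuracy = sum(recalls) / len(recalls)

    # Micro F1: pool true positives, false positives and false negatives over all classes.
    tp = fp = fn = 0
    for c in range(num_classes):
        tp += sum(t == c and p == c for t, p in pairs)
        fp += sum(t != c and p == c for t, p in pairs)
        fn += sum(t == c and p != c for t, p in pairs)
    if tp == 0:
        micro_f1 = 0.0
    else:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        micro_f1 = 2 * precision * recall / (precision + recall)

    return {
        "overall_accuracy": overall_accuracy,
        "average_accuracy": average_accuracy,
        "micro_f1": micro_f1,
    }


# Toy example: 6 samples, 3 classes, 2 mistakes.
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 2, 2]
metrics = classification_metrics(y_true, y_pred, num_classes=3)
# overall_accuracy and micro_f1 are both 4/6; average_accuracy is 13/18
print(metrics)
```

With equal class sizes the two accuracies would agree more closely, so a gap between them mostly reflects how errors are spread across classes in the test split.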
-
Aw man, @isaaccorley beat me to it by a few minutes, but I crushed his score :)... Here's mine (a lower LR works better): https://gist.github.com/calebrob6/ebfbf202977b2f8f27409ddeaac479d2
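Since the learning rate matters here, it may help to see what `learning_rate_schedule_patience=5` in the script above does to it over a run. Assuming, as the parameter name suggests, that it feeds a PyTorch `ReduceLROnPlateau`-style scheduler watching `val_loss`, the hypothetical helper below simulates that rule in pure Python using PyTorch's default `factor=0.1` and relative `threshold=1e-4`, with no cooldown or minimum LR. It is a simplified model, not torchgeo's or PyTorch's code:

```python
from typing import List


def plateau_lr_schedule(
    val_losses: List[float],
    initial_lr: float,
    patience: int,
    factor: float = 0.1,
    threshold: float = 1e-4,
) -> List[float]:
    """Learning rate in effect after each epoch under a simplified
    reduce-on-plateau rule (mode='min', relative threshold, no cooldown, no min_lr)."""
    lr = initial_lr
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        # Relative improvement test: the loss must beat best * (1 - threshold).
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce only once the count of bad epochs exceeds patience.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


# Two improving epochs, then a flat val_loss for 12 epochs.
losses = [1.0, 0.9] + [0.95] * 12
lrs = plateau_lr_schedule(losses, initial_lr=0.01, patience=5)
for epoch, lr in enumerate(lrs):
    print(f"epoch {epoch:2d}  val_loss={losses[epoch]:.2f}  lr={lr:g}")
```

Note the strict `>`: with `patience=5` the rate drops on the sixth consecutive epoch without improvement, so it is reduced before the `EarlyStopping(patience=10)` callback in the same script would give up.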
-
Hi
I'd like to create an example notebook that trains a ClassificationTask on the UC Merced dataset using PyTorch Lightning. I adapted the approach in https://torchgeo.readthedocs.io/en/latest/tutorials/trainers.html, but without success.
I have:
However, training runs for only 1 epoch and then exits. Can anyone steer me in the right direction?
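When a run ends earlier than expected, one thing to rule out is the early-stopping callback. The hypothetical `early_stop_epoch` helper below is a simplified pure-Python model of a patience-based rule like the `EarlyStopping(monitor="val_loss", min_delta=0.00, patience=10)` callback in the script posted earlier in this thread. It is not Lightning's implementation, and it assumes one validation run per epoch (Lightning's default):

```python
import math
from typing import List, Optional


def early_stop_epoch(
    val_losses: List[float], patience: int, min_delta: float = 0.0
) -> Optional[int]:
    """Return the 0-based epoch after which training would stop, or None.

    Simplified min-mode rule: an epoch counts as an improvement only if
    val_loss < best - min_delta, and training stops once `patience`
    consecutive epochs pass without one.
    """
    best = math.inf
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


print(early_stop_epoch([0.9], patience=10))  # None: the first epoch always improves on +inf
losses = [1.0, 0.8, 0.7] + [0.75] * 10
print(early_stop_epoch(losses, patience=10))  # 12: ten flat epochs after the best at epoch 2
```

Under this rule the first validation always counts as an improvement, so a finite `val_loss` cannot end training after a single epoch. Lightning's own callback additionally stops as soon as the monitored value becomes NaN or infinite (its `check_finite` option, on by default), which this sketch leaves out, and Trainer limits such as `max_epochs` apply independently of any callback.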