Skip to content

Commit

Permalink
Update naming and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 19, 2024
1 parent ab9364b commit dd0ecda
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions spleenseg/plotting/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,28 @@ def plot_imageAndLabel(
plt.savefig(str(savefile))


def plot_trainingMetrics(training: data.TrainingLog, savefile: Path) -> None:
def plot_trainingMetrics(
log: data.TrainingLog, training: data.TrainingParams, savefile: Path
) -> None:
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(training.loss_per_epoch))]
y = training.loss_per_epoch
x = [i + 1 for i in range(len(log.loss_per_epoch))]
y = log.loss_per_epoch
plt.xlabel("epoch")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [training.val_interval * (i + 1) for i in range(len(training.metric_values))]
y = training.metric_values
x = [training.val_interval * (i + 1) for i in range(len(log.metric_per_epoch))]
y = log.metric_per_epoch
plt.xlabel("epoch")
plt.plot(x, y)
plt.savefig(str(savefile))


def plot_IODo(input: dict[str, torch.Tensor], output: torch.Tensor, title: str) -> None:
def plot_bestModelOnValidate(
input: dict[str, torch.Tensor], output: torch.Tensor, title: str, savefile: Path
) -> None:
plt.figure("check", (18, 6))
plt.subplot(1, 3, 1)
plt.title(f"image {title}")
Expand All @@ -48,4 +52,4 @@ def plot_IODo(input: dict[str, torch.Tensor], output: torch.Tensor, title: str)
plt.subplot(1, 3, 3)
plt.title(f"output {title}")
plt.imshow(torch.argmax(output, dim=1).detach().cpu()[0, :, :, 80])
plt.show()
plt.savefig(str(savefile))

0 comments on commit dd0ecda

Please sign in to comment.