From 25e0a28bb85cd6ec8fc6e66c73df29018a9730c5 Mon Sep 17 00:00:00 2001 From: dev <2384178109@qq.com> Date: Wed, 26 Jul 2023 11:46:15 +0800 Subject: [PATCH] fix api bugs and different devices calculation bug --- docs/source/tutorials/ar.ipynb | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/source/tutorials/ar.ipynb b/docs/source/tutorials/ar.ipynb index b07b8e44..ae76f8ef 100644 --- a/docs/source/tutorials/ar.ipynb +++ b/docs/source/tutorials/ar.ipynb @@ -231,7 +231,8 @@ ], "source": [ "# calculate baseline absolute error\n", - "actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(device)\n", "baseline_predictions = Baseline().predict(val_dataloader)\n", "SMAPE()(baseline_predictions, actuals)" ] @@ -519,7 +520,7 @@ } ], "source": [ - "actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])\n", + "actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(device)\n", "predictions = best_model.predict(val_dataloader)\n", "(actuals - predictions).abs().mean()" ] @@ -538,7 +539,7 @@ "metadata": {}, "outputs": [], "source": [ - "raw_predictions, x = best_model.predict(val_dataloader, mode=\"raw\", return_x=True)" + "raw_predictions = best_model.predict(val_dataloader, mode=\"raw\", return_x=True)" ] }, { @@ -649,7 +650,7 @@ ], "source": [ "for idx in range(10): # plot 10 examples\n", - " best_model.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True)" + " best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True)" ] }, { @@ -797,7 +798,7 @@ ], "source": [ "for idx in range(10): # plot 10 examples\n", - " best_model.plot_interpretation(x, raw_predictions, idx=idx)" + " best_model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=idx)" ] }, {