diff --git a/docs/source/tutorials/ar.ipynb b/docs/source/tutorials/ar.ipynb index b07b8e44..ae76f8ef 100644 --- a/docs/source/tutorials/ar.ipynb +++ b/docs/source/tutorials/ar.ipynb @@ -231,7 +231,8 @@ ], "source": [ "# calculate baseline absolute error\n", - "actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(device)\n", "baseline_predictions = Baseline().predict(val_dataloader)\n", "SMAPE()(baseline_predictions, actuals)" ] @@ -519,7 +520,7 @@ } ], "source": [ - "actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])\n", + "actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(device)\n", "predictions = best_model.predict(val_dataloader)\n", "(actuals - predictions).abs().mean()" ] @@ -538,7 +539,7 @@ "metadata": {}, "outputs": [], "source": [ - "raw_predictions, x = best_model.predict(val_dataloader, mode=\"raw\", return_x=True)" + "raw_predictions = best_model.predict(val_dataloader, mode=\"raw\", return_x=True)" ] }, { @@ -649,7 +650,7 @@ ], "source": [ "for idx in range(10): # plot 10 examples\n", - " best_model.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True)" + " best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True)" ] }, { @@ -797,7 +798,7 @@ ], "source": [ "for idx in range(10): # plot 10 examples\n", - " best_model.plot_interpretation(x, raw_predictions, idx=idx)" + " best_model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=idx)" ] }, {