Skip to content

Commit

Permalink
fix api bugs and different devices calculation bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chenr86 authored and jdb78 committed Sep 10, 2023
1 parent c4b1349 commit 25e0a28
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions docs/source/tutorials/ar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@
],
"source": [
"# calculate baseline absolute error\n",
"actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(device)\n",
"baseline_predictions = Baseline().predict(val_dataloader)\n",
"SMAPE()(baseline_predictions, actuals)"
]
Expand Down Expand Up @@ -519,7 +520,7 @@
}
],
"source": [
"actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])\n",
"actuals = torch.cat([y[0] for x, y in iter(val_dataloader)]).to(device)\n",
"predictions = best_model.predict(val_dataloader)\n",
"(actuals - predictions).abs().mean()"
]
Expand All @@ -538,7 +539,7 @@
"metadata": {},
"outputs": [],
"source": [
"raw_predictions, x = best_model.predict(val_dataloader, mode=\"raw\", return_x=True)"
"raw_predictions = best_model.predict(val_dataloader, mode=\"raw\", return_x=True)"
]
},
{
Expand Down Expand Up @@ -649,7 +650,7 @@
],
"source": [
"for idx in range(10): # plot 10 examples\n",
" best_model.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True)"
" best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True)"
]
},
{
Expand Down Expand Up @@ -797,7 +798,7 @@
],
"source": [
"for idx in range(10): # plot 10 examples\n",
" best_model.plot_interpretation(x, raw_predictions, idx=idx)"
" best_model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=idx)"
]
},
{
Expand Down

0 comments on commit 25e0a28

Please sign in to comment.