diff --git a/README.md b/README.md index 55184f3..76844b2 100755 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is a repository for the following paper: -Yonetani, Taniai, Barekatain, Nishimura, Kanezaki, "Path Planning using Neural A\* Search", ICML, 2021 [[paper]](https://arxiv.org/abs/2009.07476) [[project page]](https://omron-sinicx.github.io/neural-astar/) +Ryo Yonetani*, Tatsunori Taniai*, Mohammadamin Barekatain, Mai Nishimura, Asako Kanezaki, "Path Planning using Neural A\* Search", ICML, 2021 [[paper]](https://arxiv.org/abs/2009.07476) [[project page]](https://omron-sinicx.github.io/neural-astar/) ## TL;DR diff --git a/example.ipynb b/example.ipynb index 9c806bd..438493f 100644 --- a/example.ipynb +++ b/example.ipynb @@ -5,7 +5,9 @@ "id": "5c38c3d8-b2af-41fd-ba2b-8d482b540535", "metadata": {}, "source": [ - "### (optional) Install Neural A* on Colab" + "### (Optional) Install Neural A* on Colab\n", + "\n", + "We highly recommend the use of GPUs for faster training/planning." ] }, { @@ -44,7 +46,6 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from tqdm import tqdm\n", @@ -64,7 +65,7 @@ "outputs": [], "source": [ "neural_astar = NeuralAstar(encoder_arch='CNN').to(device)\n", - "neural_astar.load_state_dict(torch.load(\"data/cnn_mazes.pt\"))\n", + "neural_astar.load_state_dict(torch.load(\"data/cnn_mazes.pt\", map_location=torch.device(device)))\n", "\n", "vanilla_astar = VanillaAstar().to(device)" ] @@ -87,7 +88,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 5/5 [00:04<00:00, 1.24it/s]\n" + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:17<00:00, 3.55s/it]\n" ] }, { @@ -442,7 +443,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -456,7 +457,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.10" + "version": "3.9.0" } }, "nbformat": 4,