Skip to content

Commit

Permalink
Deploying to gh-pages from @ 26fb98e 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Feb 1, 2024
1 parent ac6d2b2 commit 60543bf
Show file tree
Hide file tree
Showing 47 changed files with 77 additions and 148 deletions.
2 changes: 1 addition & 1 deletion main/.buildinfo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 1978296481ddd170ce69a8416903ec8e
config: b9d761ee5ebfc4251fffc3fe599f90db
tags: d77d1c0d9ca2f4c8421862c7c5a0d620
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# %%%
# We present here how to perform behavioral cloning on a Minari dataset using `PyTorch <https://pytorch.org/>`_.
# We will start by generating the dataset of the expert policy for the `CartPole-v1 <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`_ environment, which is a classic control problem.
# The objective is to balance the pole on the cart, and we receive a reward of +1 for each successful timestep.
# The objective is to balance the pole on the cart, and we receive a reward of +1 for each successful step.

# %%
# Imports
Expand Down Expand Up @@ -108,7 +108,7 @@ def collate_fn(batch):
return {
"id": torch.Tensor([x.id for x in batch]),
"seed": torch.Tensor([x.seed for x in batch]),
"total_timesteps": torch.Tensor([x.total_timesteps for x in batch]),
"total_steps": torch.Tensor([x.total_steps for x in batch]),
"observations": torch.nn.utils.rnn.pad_sequence(
[torch.as_tensor(x.observations) for x in batch],
batch_first=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We present here how to perform behavioral cloning on a Minari dataset using [PyTorch](https://pytorch.org/).\nWe will start generating the dataset of the expert policy for the [CartPole-v1](https://gymnasium.farama.org/environments/classic_control/cart_pole/) environment, which is a classic control problem.\nThe objective is to balance the pole on the cart, and we receive a reward of +1 for each successful timestep.\n\n"
"We present here how to perform behavioral cloning on a Minari dataset using [PyTorch](https://pytorch.org/).\nWe will start generating the dataset of the expert policy for the [CartPole-v1](https://gymnasium.farama.org/environments/classic_control/cart_pole/) environment, which is a classic control problem.\nThe objective is to balance the pole on the cart, and we receive a reward of +1 for each successful step.\n\n"
]
},
{
Expand Down Expand Up @@ -126,7 +126,7 @@
},
"outputs": [],
"source": [
"def collate_fn(batch):\n return {\n \"id\": torch.Tensor([x.id for x in batch]),\n \"seed\": torch.Tensor([x.seed for x in batch]),\n \"total_timesteps\": torch.Tensor([x.total_timesteps for x in batch]),\n \"observations\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.observations) for x in batch],\n batch_first=True\n ),\n \"actions\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.actions) for x in batch],\n batch_first=True\n ),\n \"rewards\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.rewards) for x in batch],\n batch_first=True\n ),\n \"terminations\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.terminations) for x in batch],\n batch_first=True\n ),\n \"truncations\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.truncations) for x in batch],\n batch_first=True\n )\n }"
"def collate_fn(batch):\n return {\n \"id\": torch.Tensor([x.id for x in batch]),\n \"seed\": torch.Tensor([x.seed for x in batch]),\n \"total_steps\": torch.Tensor([x.total_steps for x in batch]),\n \"observations\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.observations) for x in batch],\n batch_first=True\n ),\n \"actions\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.actions) for x in batch],\n batch_first=True\n ),\n \"rewards\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.rewards) for x in batch],\n batch_first=True\n ),\n \"terminations\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.terminations) for x in batch],\n batch_first=True\n ),\n \"truncations\": torch.nn.utils.rnn.pad_sequence(\n [torch.as_tensor(x.truncations) for x in batch],\n batch_first=True\n )\n }"
]
},
{
Expand Down
Binary file not shown.
21 changes: 3 additions & 18 deletions main/_modules/minari/data_collector/data_collector/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,10 @@ <h1>Source code for minari.data_collector.data_collector</h1><div class="highlig
<span></span><span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">annotations</span>

<span class="kn">import</span> <span class="nn">copy</span>
<span class="kn">import</span> <span class="nn">inspect</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">secrets</span>
<span class="kn">import</span> <span class="nn">shutil</span>
<span class="kn">import</span> <span class="nn">tempfile</span>
<span class="kn">import</span> <span class="nn">warnings</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">SupportsFloat</span><span class="p">,</span> <span class="n">Type</span><span class="p">,</span> <span class="n">Union</span>

<span class="kn">import</span> <span class="nn">gymnasium</span> <span class="k">as</span> <span class="nn">gym</span>
Expand All @@ -382,6 +380,7 @@ <h1>Source code for minari.data_collector.data_collector</h1><div class="highlig
<span class="p">)</span>
<span class="kn">from</span> <span class="nn">minari.dataset.minari_dataset</span> <span class="kn">import</span> <span class="n">MinariDataset</span>
<span class="kn">from</span> <span class="nn">minari.dataset.minari_storage</span> <span class="kn">import</span> <span class="n">MinariStorage</span>
<span class="kn">from</span> <span class="nn">minari.utils</span> <span class="kn">import</span> <span class="n">_generate_dataset_metadata</span><span class="p">,</span> <span class="n">_generate_dataset_path</span>


<span class="c1"># H5Py supports ints up to uint64</span>
Expand All @@ -390,17 +389,6 @@ <h1>Source code for minari.data_collector.data_collector</h1><div class="highlig
<span class="n">EpisodeBuffer</span> <span class="o">=</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> <span class="c1"># TODO: narrow this down</span>


<span class="k">def</span> <span class="fm">__getattr__</span><span class="p">(</span><span class="n">name</span><span class="p">):</span>
<span class="k">if</span> <span class="n">name</span> <span class="o">==</span> <span class="s2">&quot;DataCollectorV0&quot;</span><span class="p">:</span>
<span class="n">stacklevel</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">inspect</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;DataCollectorV0 is deprecated and will be removed. Use DataCollector instead.&quot;</span><span class="p">,</span> <span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="n">stacklevel</span><span class="o">=</span><span class="n">stacklevel</span><span class="p">)</span>
<span class="k">return</span> <span class="n">DataCollector</span>
<span class="k">elif</span> <span class="n">name</span> <span class="o">==</span> <span class="s2">&quot;__path__&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">False</span> <span class="c1"># see https://stackoverflow.com/a/60803436</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ImportError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;cannot import name &#39;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">&#39; from &#39;</span><span class="si">{</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">&#39; (</span><span class="si">{</span><span class="vm">__file__</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">)</span>


<div class="viewcode-block" id="DataCollector">
<a class="viewcode-back" href="../../../../api/data_collector/#minari.DataCollector">[docs]</a>
<span class="k">class</span> <span class="nc">DataCollector</span><span class="p">(</span><span class="n">gym</span><span class="o">.</span><span class="n">Wrapper</span><span class="p">):</span>
Expand Down Expand Up @@ -719,8 +707,6 @@ <h1>Source code for minari.data_collector.data_collector</h1><div class="highlig
<span class="sd"> Returns:</span>
<span class="sd"> MinariDataset</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># TODO: move the import to top of the file after removing minari.create_dataset_from_collector_env() in 0.5.0</span>
<span class="kn">from</span> <span class="nn">minari.utils</span> <span class="kn">import</span> <span class="n">_generate_dataset_metadata</span><span class="p">,</span> <span class="n">_generate_dataset_path</span>
<span class="n">dataset_path</span> <span class="o">=</span> <span class="n">_generate_dataset_path</span><span class="p">(</span><span class="n">dataset_id</span><span class="p">)</span>
<span class="n">metadata</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="n">_generate_dataset_metadata</span><span class="p">(</span>
<span class="n">dataset_id</span><span class="p">,</span>
Expand All @@ -737,7 +723,7 @@ <h1>Source code for minari.data_collector.data_collector</h1><div class="highlig
<span class="n">minari_version</span><span class="p">,</span>
<span class="p">)</span>

<span class="bp">self</span><span class="o">.</span><span class="n">save_to_disk</span><span class="p">(</span><span class="n">dataset_path</span><span class="p">,</span> <span class="n">metadata</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_save_to_disk</span><span class="p">(</span><span class="n">dataset_path</span><span class="p">,</span> <span class="n">metadata</span><span class="p">)</span>

<span class="c1"># will be able to calculate dataset size only after saving to disk, so updating the dataset metadata post `_save_to_disk` method</span>

Expand All @@ -746,7 +732,7 @@ <h1>Source code for minari.data_collector.data_collector</h1><div class="highlig
<span class="n">dataset</span><span class="o">.</span><span class="n">storage</span><span class="o">.</span><span class="n">update_metadata</span><span class="p">(</span><span class="n">metadata</span><span class="p">)</span>
<span class="k">return</span> <span class="n">dataset</span>

<span class="k">def</span> <span class="nf">save_to_disk</span><span class="p">(</span>
<span class="k">def</span> <span class="nf">_save_to_disk</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">os</span><span class="o">.</span><span class="n">PathLike</span><span class="p">,</span> <span class="n">dataset_metadata</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Save all in-memory buffer data and move temporary files to a permanent location in disk.</span>
Expand All @@ -755,7 +741,6 @@ <h1>Source code for minari.data_collector.data_collector</h1><div class="highlig
<span class="sd"> path (str): path to store the dataset, e.g.: &#39;/home/foo/datasets/data&#39;</span>
<span class="sd"> dataset_metadata (Dict, optional): additional metadata to add to the dataset file. Defaults to {}.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;This method is deprecated and will become private in v0.5.0.&quot;</span><span class="p">,</span> <span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="n">stacklevel</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_validate_buffer</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_storage</span><span class="o">.</span><span class="n">update_episodes</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_buffer</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
Expand Down
4 changes: 2 additions & 2 deletions main/_modules/minari/dataset/episode_data/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ <h1>Source code for minari.dataset.episode_data</h1><div class="highlight"><pre>

<span class="nb">id</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span>
<span class="n">total_timesteps</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">total_steps</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">observations</span><span class="p">:</span> <span class="n">Any</span>
<span class="n">actions</span><span class="p">:</span> <span class="n">Any</span>
<span class="n">rewards</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span>
Expand All @@ -388,7 +388,7 @@ <h1>Source code for minari.dataset.episode_data</h1><div class="highlight"><pre>
<span class="s2">&quot;EpisodeData(&quot;</span>
<span class="sa">f</span><span class="s2">&quot;id=</span><span class="si">{</span><span class="nb">repr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">id</span><span class="p">)</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;seed=</span><span class="si">{</span><span class="nb">repr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;total_timesteps=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">total_timesteps</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;total_steps=</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">total_steps</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;observations=</span><span class="si">{</span><span class="n">EpisodeData</span><span class="o">.</span><span class="n">_repr_space_values</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">observations</span><span class="p">)</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;actions=</span><span class="si">{</span><span class="n">EpisodeData</span><span class="o">.</span><span class="n">_repr_space_values</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actions</span><span class="p">)</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;rewards=ndarray of </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rewards</span><span class="p">)</span><span class="si">}</span><span class="s2"> floats, &quot;</span>
Expand Down
2 changes: 1 addition & 1 deletion main/_modules/minari/dataset/minari_dataset/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ <h1>Source code for minari.dataset.minari_dataset</h1><div class="highlight"><pr
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_total_steps</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">storage</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span>
<span class="k">lambda</span> <span class="n">episode</span><span class="p">:</span> <span class="n">episode</span><span class="p">[</span><span class="s2">&quot;total_timesteps&quot;</span><span class="p">],</span>
<span class="k">lambda</span> <span class="n">episode</span><span class="p">:</span> <span class="n">episode</span><span class="p">[</span><span class="s2">&quot;total_steps&quot;</span><span class="p">],</span>
<span class="n">episode_indices</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">episode_indices</span><span class="p">,</span>
<span class="p">)</span>
<span class="p">)</span>
Expand Down
Loading

0 comments on commit 60543bf

Please sign in to comment.