From 60543bf577c8d89321745231b0e8c1723a888971 Mon Sep 17 00:00:00 2001
From: younik <42100908+younik@users.noreply.github.com>
Date: Thu, 1 Feb 2024 20:29:29 +0000
Subject: [PATCH] Deploying to gh-pages from @ Farama-Foundation/Minari@26fb98edec72aba9de4dea6811240c690bfadf7d 🚀
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main/.buildinfo                               |  2 +-
 .../tutorials_python.zip                      | Bin 57251 -> 57239 bytes
 .../behavioral_cloning.py                     |  4 +-
 .../behavioral_cloning.ipynb                  |  4 +-
 .../tutorials_jupyter.zip                     | Bin 77453 -> 77441 bytes
 .../data_collector/data_collector/index.html  | 21 +-----
 .../minari/dataset/episode_data/index.html    |  4 +-
 .../minari/dataset/minari_dataset/index.html  |  2 +-
 main/_modules/minari/utils/index.html         | 60 +------------------
 main/api/data_collector/index.html            | 16 ++---
 .../minari_dataset/episode_data/index.html    | 14 ++--
 main/content/dataset_standards/index.html     | 18 +++---
 .../datasets/antmaze/large-diverse/index.html |  2 +-
 main/datasets/antmaze/large-play/index.html   |  2 +-
 .../antmaze/medium-diverse/index.html         |  2 +-
 main/datasets/antmaze/medium-play/index.html  |  2 +-
 .../datasets/antmaze/umaze-diverse/index.html |  2 +-
 main/datasets/antmaze/umaze/index.html        |  2 +-
 main/datasets/door/cloned/index.html          |  2 +-
 main/datasets/door/expert/index.html          |  2 +-
 main/datasets/door/human/index.html           |  2 +-
 main/datasets/hammer/cloned/index.html        |  2 +-
 main/datasets/hammer/expert/index.html        |  2 +-
 main/datasets/hammer/human/index.html         |  2 +-
 main/datasets/kitchen/complete/index.html     |  2 +-
 main/datasets/kitchen/mixed/index.html        |  2 +-
 main/datasets/kitchen/partial/index.html      |  2 +-
 .../minigrid/fourrooms-random/index.html      |  4 +-
 main/datasets/minigrid/fourrooms/index.html   |  4 +-
 main/datasets/pen/cloned/index.html           |  2 +-
 main/datasets/pen/expert/index.html           |  2 +-
 main/datasets/pen/human/index.html            |  2 +-
 .../datasets/pointmaze/large-dense/index.html |  2 +-
 main/datasets/pointmaze/large/index.html      |  2 +-
 .../pointmaze/medium-dense/index.html         |  2 +-
 main/datasets/pointmaze/medium/index.html     |  2 +-
 main/datasets/pointmaze/open-dense/index.html |  2 +-
 main/datasets/pointmaze/open/index.html       |  2 +-
 .../datasets/pointmaze/umaze-dense/index.html |  2 +-
 main/datasets/pointmaze/umaze/index.html      |  2 +-
 main/datasets/relocate/cloned/index.html      |  2 +-
 main/datasets/relocate/expert/index.html      |  2 +-
 main/datasets/relocate/human/index.html       |  2 +-
 main/genindex/index.html                      |  8 ++-
 main/objects.inv                              | Bin 1827 -> 1824 bytes
 main/searchindex.js                           |  2 +-
 .../behavioral_cloning/index.html             |  4 +-
 47 files changed, 77 insertions(+), 148 deletions(-)

diff --git a/main/.buildinfo b/main/.buildinfo
index 9e8f7cbb..eda4c395 100644
--- a/main/.buildinfo
+++ b/main/.buildinfo
@@ -1,4 +1,4 @@
 # Sphinx build info version 1
 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 1978296481ddd170ce69a8416903ec8e
+config: b9d761ee5ebfc4251fffc3fe599f90db
 tags: d77d1c0d9ca2f4c8421862c7c5a0d620
diff --git a/main/_downloads/315c4c52fb68082a731b192d944e2ede/tutorials_python.zip b/main/_downloads/315c4c52fb68082a731b192d944e2ede/tutorials_python.zip
index 2a419d70f34727d75f9d1a03b9e7ece327da4c0e..47636b5f4914641808891a45b3b4f7aaadb608b4 100644
GIT binary patch
delta 236
[base85-encoded binary deltas omitted]

diff --git a/main/_downloads/433fbd4ad5e11d67afb6f95e0ee37d2b/behavioral_cloning.py b/main/_downloads/433fbd4ad5e11d67afb6f95e0ee37d2b/behavioral_cloning.py
index 219d2f73..403e6040 100644
--- a/main/_downloads/433fbd4ad5e11d67afb6f95e0ee37d2b/behavioral_cloning.py
+++ b/main/_downloads/433fbd4ad5e11d67afb6f95e0ee37d2b/behavioral_cloning.py
@@ -5,7 +5,7 @@
 # %%%
 # We present here how to perform behavioral cloning on a Minari dataset using `PyTorch <https://pytorch.org/>`_.
 # We will start generating the dataset of the expert policy for the `CartPole-v1 <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`_ environment, which is a classic control problem.
-# The objective is to balance the pole on the cart, and we receive a reward of +1 for each successful timestep.
+# The objective is to balance the pole on the cart, and we receive a reward of +1 for each successful step.
 
 # %%
 # Imports
@@ -108,7 +108,7 @@ def collate_fn(batch):
     return {
         "id": torch.Tensor([x.id for x in batch]),
         "seed": torch.Tensor([x.seed for x in batch]),
-        "total_timesteps": torch.Tensor([x.total_timesteps for x in batch]),
+        "total_steps": torch.Tensor([x.total_steps for x in batch]),
         "observations": torch.nn.utils.rnn.pad_sequence(
             [torch.as_tensor(x.observations) for x in batch],
             batch_first=True
diff --git a/main/_downloads/92aa7a29195623bda8ca31d2e959f2a6/behavioral_cloning.ipynb b/main/_downloads/92aa7a29195623bda8ca31d2e959f2a6/behavioral_cloning.ipynb
index ccdaba37..25481d12 100644
--- a/main/_downloads/92aa7a29195623bda8ca31d2e959f2a6/behavioral_cloning.ipynb
+++ b/main/_downloads/92aa7a29195623bda8ca31d2e959f2a6/behavioral_cloning.ipynb
@@ -22,7 +22,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "We present here how to perform behavioral cloning on a Minari dataset using [PyTorch](https://pytorch.org/).\nWe will start generating the dataset of the expert policy for the [CartPole-v1](https://gymnasium.farama.org/environments/classic_control/cart_pole/) environment, which is a classic control problem.\nThe objective is to balance the pole on the cart, and we receive a reward of +1 for each successful timestep.\n\n"
+    "We present here how to perform behavioral cloning on a Minari dataset using [PyTorch](https://pytorch.org/).\nWe will start generating the dataset of the expert policy for the [CartPole-v1](https://gymnasium.farama.org/environments/classic_control/cart_pole/) environment, which is a classic control problem.\nThe objective is to balance the pole on the cart, and we receive a reward of +1 for each successful step.\n\n"
    ]
   },
   {
@@ -126,7 +126,7 @@
   },
   "outputs": [],
   "source": [
-    "def collate_fn(batch):\n    return {\n        \"id\": torch.Tensor([x.id for x in batch]),\n        \"seed\": torch.Tensor([x.seed for x in batch]),\n        \"total_timesteps\": torch.Tensor([x.total_timesteps for x in batch]),\n        \"observations\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.observations) for x in batch],\n            batch_first=True\n        ),\n        \"actions\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.actions) for x in batch],\n            batch_first=True\n        ),\n        \"rewards\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.rewards) for x in batch],\n            batch_first=True\n        ),\n        \"terminations\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.terminations) for x in batch],\n            batch_first=True\n        ),\n        \"truncations\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.truncations) for x in batch],\n            batch_first=True\n        )\n    }"
+    "def collate_fn(batch):\n    return {\n        \"id\": torch.Tensor([x.id for x in batch]),\n        \"seed\": torch.Tensor([x.seed for x in batch]),\n        \"total_steps\": torch.Tensor([x.total_steps for x in batch]),\n        \"observations\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.observations) for x in batch],\n            batch_first=True\n        ),\n        \"actions\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.actions) for x in batch],\n            batch_first=True\n        ),\n        \"rewards\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.rewards) for x in batch],\n            batch_first=True\n        ),\n        \"terminations\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.terminations) for x in batch],\n            batch_first=True\n        ),\n        \"truncations\": torch.nn.utils.rnn.pad_sequence(\n            [torch.as_tensor(x.truncations) for x in batch],\n            batch_first=True\n        )\n    }"
   ]
  },
  {
diff --git a/main/_downloads/a5659940aa3f8f568547d47752a43172/tutorials_jupyter.zip b/main/_downloads/a5659940aa3f8f568547d47752a43172/tutorials_jupyter.zip
index 73ff6cc811d0d6b4d321cf88406c62de84ca54a8..47bd9d86205c5fcfed0ebd566cf1df1c00358629 100644
GIT binary patch
delta 221
[base85-encoded binary deltas omitted]

diff --git a/main/_modules/minari/data_collector/data_collector/index.html b/main/_modules/minari/data_collector/data_collector/index.html
index 542e1123..b7dafb25 100644
--- a/main/_modules/minari/data_collector/data_collector/index.html
+++ b/main/_modules/minari/data_collector/data_collector/index.html
@@ -361,12 +361,10 @@ Source code for minari.data_collector.data_collector
 from __future__ import annotations
 
 import copy
-import inspect
 import os
 import secrets
 import shutil
 import tempfile
-import warnings
 from typing import Any, Callable, Dict, List, Optional, SupportsFloat, Type, Union
 
 import gymnasium as gym
@@ -382,6 +380,7 @@ Source code for minari.data_collector.data_collector
 )
 from minari.dataset.minari_dataset import MinariDataset
 from minari.dataset.minari_storage import MinariStorage
+from minari.utils import _generate_dataset_metadata, _generate_dataset_path
 
 
 # H5Py supports ints up to uint64
@@ -390,17 +389,6 @@ Source code for minari.data_collector.data_collector
 EpisodeBuffer = Dict[str, Any]  # TODO: narrow this down
 
-def __getattr__(name):
-    if name == "DataCollectorV0":
-        stacklevel = len(inspect.stack(0))
-        warnings.warn("DataCollectorV0 is deprecated and will be removed. Use DataCollector instead.", DeprecationWarning, stacklevel=stacklevel)
-        return DataCollector
-    elif name == "__path__":
-        return False  # see https://stackoverflow.com/a/60803436
-    else:
-        raise ImportError(f"cannot import name '{name}' from '{__name__}' ({__file__})")
-
-
 [docs]
 class DataCollector(gym.Wrapper):
@@ -719,8 +707,6 @@ Source code for minari.data_collector.data_collector
         Returns:
             MinariDataset
         """
-        # TODO: move the import to top of the file after removing minari.create_dataset_from_collector_env() in 0.5.0
-        from minari.utils import _generate_dataset_metadata, _generate_dataset_path
         dataset_path = _generate_dataset_path(dataset_id)
         metadata: Dict[str, Any] = _generate_dataset_metadata(
             dataset_id,
@@ -737,7 +723,7 @@ Source code for minari.data_collector.data_collector
             minari_version,
         )
 
-        self.save_to_disk(dataset_path, metadata)
+        self._save_to_disk(dataset_path, metadata)
 
         # will be able to calculate dataset size only after saving the disk, so updating the dataset metadata post `save_to_disk` method
@@ -746,7 +732,7 @@ Source code for minari.data_collector.data_collector
         dataset.storage.update_metadata(metadata)
         return dataset
 
-    def save_to_disk(
+    def _save_to_disk(
         self, path: str | os.PathLike, dataset_metadata: Dict[str, Any] = {}
     ):
         """Save all in-memory buffer data and move temporary files to a permanent location in disk.
@@ -755,7 +741,6 @@ Source code for minari.data_collector.data_collector
             path (str): path to store the dataset, e.g.: '/home/foo/datasets/data'
             dataset_metadata (Dict, optional): additional metadata to add to the dataset file. Defaults to {}.
         """
-        warnings.warn("This method is deprecated and will become private in v0.5.0.", DeprecationWarning, stacklevel=2)
         self._validate_buffer()
         self._storage.update_episodes(self._buffer)
         self._buffer.clear()
diff --git a/main/_modules/minari/dataset/episode_data/index.html b/main/_modules/minari/dataset/episode_data/index.html
index 65b93ab9..391e4c76 100644
--- a/main/_modules/minari/dataset/episode_data/index.html
+++ b/main/_modules/minari/dataset/episode_data/index.html
@@ -375,7 +375,7 @@ Source code for minari.dataset.episode_data
     id: int
     seed: Optional[int]
-    total_timesteps: int
+    total_steps: int
     observations: Any
     actions: Any
     rewards: np.ndarray
@@ -388,7 +388,7 @@ Source code for minari.dataset.episode_data
             "EpisodeData("
             f"id={repr(self.id)}, "
             f"seed={repr(self.seed)}, "
-            f"total_timesteps={self.total_timesteps}, "
+            f"total_steps={self.total_steps}, "
             f"observations={EpisodeData._repr_space_values(self.observations)}, "
             f"actions={EpisodeData._repr_space_values(self.actions)}, "
             f"rewards=ndarray of {len(self.rewards)} floats, "
diff --git a/main/_modules/minari/dataset/minari_dataset/index.html b/main/_modules/minari/dataset/minari_dataset/index.html
index 65523ef3..eaf75c72 100644
--- a/main/_modules/minari/dataset/minari_dataset/index.html
+++ b/main/_modules/minari/dataset/minari_dataset/index.html
@@ -627,7 +627,7 @@ Source code for minari.dataset.minari_dataset
         else:
             self._total_steps = sum(
                 self.storage.apply(
-                    lambda episode: episode["total_timesteps"],
+                    lambda episode: episode["total_steps"],
                     episode_indices=self.episode_indices,
                 )
             )
diff --git a/main/_modules/minari/utils/index.html b/main/_modules/minari/utils/index.html
index 9a0b9061..8aa75324 100644
--- a/main/_modules/minari/utils/index.html
+++ b/main/_modules/minari/utils/index.html
@@ -376,7 +376,6 @@ Source code for minari.utils
 from packaging.specifiers import InvalidSpecifier, SpecifierSet
 from packaging.version import Version
 
-from minari import DataCollector
 from minari.dataset.minari_dataset import MinariDataset
 from minari.dataset.minari_storage import MinariStorage
 from minari.storage.datasets_root_dir import get_dataset_path
@@ -935,63 +934,6 @@ Source code for minari.utils
 
 
 
-def create_dataset_from_collector_env(
-    dataset_id: str,
-    collector_env: DataCollector,
-    eval_env: Optional[str | gym.Env | EnvSpec] = None,
-    algorithm_name: Optional[str] = None,
-    author: Optional[str] = None,
-    author_email: Optional[str] = None,
-    code_permalink: Optional[str] = None,
-    ref_min_score: Optional[float] = None,
-    ref_max_score: Optional[float] = None,
-    expert_policy: Optional[Callable[[ObsType], ActType]] = None,
-    num_episodes_average_score: int = 100,
-    minari_version: Optional[str] = None,
-):
-    """Create a Minari dataset using the data collected from stepping with a Gymnasium environment wrapped with a `DataCollector` Minari wrapper.
-
-    The ``dataset_id`` parameter corresponds to the name of the dataset, with the syntax as follows:
-    ``(env_name-)(dataset_name)(-v(version))`` where ``env_name`` identifies the name of the environment used to generate the dataset ``dataset_name``.
-    This ``dataset_id`` is used to load the Minari datasets with :meth:`minari.load_dataset`.
-
-    Args:
-        dataset_id (str): name id to identify Minari dataset
-        collector_env (DataCollector): Gymnasium environment used to collect the buffer data
-        buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data
-        eval_env (Optional[str|gym.Env|EnvSpec]): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) to use for evaluation with the dataset. After loading the dataset, the environment can be recovered as follows: `MinariDataset.recover_environment(eval_env=True).
-                                                If None the `env` used to collect the buffer data should be used for evaluation.
-        algorithm_name (Optional[str], optional): name of the algorithm used to collect the data. Defaults to None.
-        author (Optional[str], optional): author that generated the dataset. Defaults to None.
-        author_email (Optional[str], optional): email of the author that generated the dataset. Defaults to None.
-        code_permalink (Optional[str], optional): link to relevant code used to generate the dataset. Defaults to None.
-        ref_min_score( Optional[float], optional): minimum reference score from the average returns of a random policy. This value is later used to normalize a score with :meth:`minari.get_normalized_score`. If default None the value will be estimated with a default random policy.
-        ref_max_score (Optional[float], optional: maximum reference score from the average returns of a hypothetical expert policy. This value is used in :meth:`minari.get_normalized_score`. Default None.
-        expert_policy (Optional[Callable[[ObsType], ActType], optional): policy to compute `ref_max_score` by averaging the returns over a number of episodes equal to  `num_episodes_average_score`.
-                                                                        `ref_max_score` and `expert_policy` can't be passed at the same time. Default to None
-        num_episodes_average_score (int): number of episodes to average over the returns to compute `ref_min_score` and `ref_max_score`. Default to 100.
-        minari_version (Optional[str], optional): Minari version specifier compatible with the dataset. If None (default) use the installed Minari version.
-
-    Returns:
-        MinariDataset
-    """
-    warnings.warn("This function is deprecated and will be removed in v0.5.0. Please use DataCollector.create_dataset() instead.", DeprecationWarning, stacklevel=2)
-    dataset = collector_env.create_dataset(
-        dataset_id=dataset_id,
-        eval_env=eval_env,
-        algorithm_name=algorithm_name,
-        author=author,
-        author_email=author_email,
-        code_permalink=code_permalink,
-        ref_min_score=ref_min_score,
-        ref_max_score=ref_max_score,
-        expert_policy=expert_policy,
-        num_episodes_average_score=num_episodes_average_score,
-        minari_version=minari_version,
-    )
-    return dataset
-
-
 
 [docs]
 def get_normalized_score(dataset: MinariDataset, returns: np.ndarray) -> np.ndarray:
@@ -1071,7 +1013,7 @@ Source code for minari.utils
         version += f" ({__version__} installed)"
 
     md_dict = {
-        "Total Timesteps": dataset_spec["total_steps"],
+        "Total steps": dataset_spec["total_steps"],
         "Total Episodes": dataset_spec["total_episodes"],
         "Dataset Observation Space": f"`{dataset_observation_space}`",
         "Dataset Action Space": f"`{dataset_action_space}`",
diff --git a/main/api/data_collector/index.html b/main/api/data_collector/index.html
index d003bf30..576eea16 100644
--- a/main/api/data_collector/index.html
+++ b/main/api/data_collector/index.html
@@ -460,13 +460,6 @@ Methods
-minari.DataCollector.close(self)
-
-    Close the DataCollector.
-
-    Clear buffer and close temporary directory.
-
 minari.DataCollector.create_dataset(self, dataset_id: str, eval_env: str | gym.Env | EnvSpec | None = None, algorithm_name: str | None = None, author: str | None = None, author_email: str | None = None, code_permalink: str | None = None, ref_min_score: float | None = None, ref_max_score: float | None = None, expert_policy: Callable[[ObsType], ActType] | None = None, num_episodes_average_score: int = 100, minari_version: str | None = None)
@@ -500,6 +493,13 @@ Methods
+minari.DataCollector.close(self)
+
+    Close the DataCollector.
+
+    Clear buffer and close temporary directory.
+
@@ -579,8 +579,8 @@ Methods
 [sidebar table-of-contents entries re-anchored for the reordered methods; markup not recoverable]
diff --git a/main/api/minari_dataset/episode_data/index.html b/main/api/minari_dataset/episode_data/index.html
index e71635b8..b97bead8 100644
--- a/main/api/minari_dataset/episode_data/index.html
+++ b/main/api/minari_dataset/episode_data/index.html
@@ -373,7 +373,7 @@ EpisodeData
-class minari.EpisodeData(id: int, seed: int | None, total_timesteps: int, observations: Any, actions: Any, rewards: ndarray, terminations: ndarray, truncations: ndarray, infos: dict)[source]
+class minari.EpisodeData(id: int, seed: int | None, total_steps: int, observations: Any, actions: Any, rewards: ndarray, terminations: ndarray, truncations: ndarray, infos: dict)[source]
 
 Contains the datasets data for a single episode.
 
 This is the object returned by minari.MinariDataset.sample_episodes.
 
@@ -393,22 +393,22 @@ Attributes
-EpisodeData.total_timesteps: int
-
-    The number of timesteps contained in this episode.
+
+EpisodeData.total_steps: int
+
+    The number of steps contained in this episode.
 
 EpisodeData.observations: Any
 
 The observations of the environment. The initial and final observations are included meaning that the number
-of observations will be increased by one compared to the number of timesteps
+of observations will be increased by one compared to the number of steps.
 
 EpisodeData.actions: Any
-
-    The actions taken in each episode timestep.
+
+    The actions taken in each episode step.
 
@@ -502,7 +502,7 @@ Attributes
 [sidebar table-of-contents entry "Attributes" re-anchored; markup not recoverable]
diff --git a/main/objects.inv b/main/objects.inv
index 9bbcb152b81a41a746d23dced3b20f057177e82b..1259972b3902a9f49e871c5ba8cfb6b5248c696a 100644
GIT binary patch
delta 1536
[base85-encoded binary deltas omitted]

[hunks for main/genindex/index.html and main/searchindex.js are not recoverable from this capture]

diff --git a/.../behavioral_cloning/index.html b/.../behavioral_cloning/index.html
@@ ... @@ Behavioral cloning with PyTorch
 We present here how to perform behavioral cloning on a Minari dataset using PyTorch. We will start generating the dataset of the expert policy for the CartPole-v1 environment, which is a classic control problem.
-The objective is to balance the pole on the cart, and we receive a reward of +1 for each successful timestep.
+The objective is to balance the pole on the cart, and we receive a reward of +1 for each successful step.
 
 Imports
 For this tutorial you will need the RL Baselines3 Zoo library, which you can install with pip install rl_zoo3.
@@ -463,7 +463,7 @@ Behavioral cloning with PyTorch
     return {
         "id": torch.Tensor([x.id for x in batch]),
         "seed": torch.Tensor([x.seed for x in batch]),
-        "total_timesteps": torch.Tensor([x.total_timesteps for x in batch]),
+        "total_steps": torch.Tensor([x.total_steps for x in batch]),
         "observations": torch.nn.utils.rnn.pad_sequence(
             [torch.as_tensor(x.observations) for x in batch],
             batch_first=True
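[editor's note] Downstream of this rename, a DataLoader sketch using the updated collate_fn from the tutorial; it assumes dataset is a loaded MinariDataset as shown earlier, and the batch size is illustrative:

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)
    for batch in loader:
        lengths = batch["total_steps"]        # key renamed from "total_timesteps"
        observations = batch["observations"]  # padded to the longest episode in the batch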