Skip to content

Commit

Permalink
ENH: Support MultiDiscrete & MultiBinary (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
younik authored Oct 20, 2024
1 parent 480d34c commit 5042059
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 18 deletions.
9 changes: 6 additions & 3 deletions minari/dataset/_storages/arrow_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from minari.dataset.minari_storage import MinariStorage


_FIXEDLIST_SPACES = (gym.spaces.Box, gym.spaces.MultiDiscrete, gym.spaces.MultiBinary)


class ArrowStorage(MinariStorage):
FORMAT = "arrow"

Expand Down Expand Up @@ -173,7 +176,7 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0):
names.append(str(i))
arrays.append(_encode_space(space[i], value, pad=pad))
return pa.StructArray.from_arrays(arrays, names=names)
elif isinstance(space, gym.spaces.Box):
elif isinstance(space, _FIXEDLIST_SPACES):
values = np.asarray(values)
assert values.shape[1:] == space.shape
values = values.reshape(values.shape[0], -1)
Expand All @@ -183,7 +186,7 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0):
elif isinstance(space, gym.spaces.Discrete):
values = np.asarray(values).reshape(-1, 1)
values = np.pad(values, ((0, pad), (0, 0)))
return pa.array(values.squeeze(-1), type=pa.int32())
return pa.array(values.squeeze(-1), type=pa.from_numpy_dtype(space.dtype))
else:
if not isinstance(values, list):
values = list(values)
Expand All @@ -203,7 +206,7 @@ def _decode_space(space, values: pa.Array):
for i, subspace in enumerate(space.spaces)
]
)
elif isinstance(space, gym.spaces.Box):
elif isinstance(space, _FIXEDLIST_SPACES):
data = np.stack(values.to_numpy(zero_copy_only=False))
return data.reshape(-1, *space.shape)
elif isinstance(space, gym.spaces.Discrete):
Expand Down
48 changes: 46 additions & 2 deletions minari/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,36 @@ def _serialize_discrete(space: spaces.Discrete, to_string=True) -> Union[Dict, s
return result


@serialize_space.register(spaces.MultiDiscrete)
def _serialize_multi_discrete(
space: spaces.MultiDiscrete, to_string=True
) -> Union[Dict, str]:
result = {}
result["type"] = "MultiDiscrete"
result["dtype"] = str(space.dtype)
result["nvec"] = space.nvec.tolist()
result["start"] = space.start.tolist()

if to_string:
result = json.dumps(result)
return result


@serialize_space.register(spaces.MultiBinary)
def _serialize_multi_binary(
space: spaces.MultiBinary, to_string=True
) -> Union[Dict, str]:
result = {"type": "MultiBinary", "n": space.n}

if to_string:
result = json.dumps(result)
return result


@serialize_space.register(spaces.Dict)
def _serialize_dict(space: spaces.Dict, to_string=True) -> Union[Dict, str]:
result = {"type": "Dict", "subspaces": {}}

for key in space.spaces.keys():
result["subspaces"][key] = serialize_space(space.spaces[key], to_string=False)

Expand All @@ -57,6 +84,7 @@ def _serialize_dict(space: spaces.Dict, to_string=True) -> Union[Dict, str]:
@serialize_space.register(spaces.Tuple)
def _serialize_tuple(space: spaces.Tuple, to_string=True) -> Union[Dict, str]:
result = {"type": "Tuple", "subspaces": []}

for subspace in space.spaces:
result["subspaces"].append(serialize_space(subspace, to_string=False))

Expand Down Expand Up @@ -130,10 +158,10 @@ def _deserialize_dict(space_dict: Dict) -> spaces.Dict:
def _deserialize_box(space_dict: Dict) -> spaces.Box:
assert space_dict["type"] == "Box"
shape = tuple(space_dict["shape"])
dtype = np.dtype(space_dict["dtype"])
dtype = space_dict["dtype"]
low = np.array(space_dict["low"], dtype=dtype)
high = np.array(space_dict["high"], dtype=dtype)
return spaces.Box(low=low, high=high, shape=shape, dtype=dtype) # type: ignore
return spaces.Box(low=low, high=high, shape=shape, dtype=dtype)


@deserialize_space.register("Discrete")
Expand All @@ -144,6 +172,22 @@ def _deserialize_discrete(space_dict: Dict) -> spaces.Discrete:
return spaces.Discrete(n=n, start=start)


@deserialize_space.register("MultiDiscrete")
def _deserialize_multi_discrete(space_dict: Dict) -> spaces.MultiDiscrete:
assert space_dict["type"] == "MultiDiscrete"
return spaces.MultiDiscrete(
nvec=space_dict["nvec"],
dtype=space_dict["dtype"],
start=space_dict["start"],
)


@deserialize_space.register("MultiBinary")
def _deserialize_multi_binary(space_dict: Dict) -> spaces.MultiBinary:
assert space_dict["type"] == "MultiBinary"
return spaces.MultiBinary(n=space_dict["n"])


@deserialize_space.register("Text")
def _deserialize_text(space_dict: Dict) -> spaces.Text:
assert space_dict["type"] == "Text"
Expand Down
31 changes: 18 additions & 13 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
)

cartpole_test_dataset = [("cartpole/test-v0", "CartPole-v1")]
dummy_box_dataset = [("dummy-box/test-v0", "DummyBoxEnv-v0")]
dummy_box_dataset = [
("dummy-box/test-v0", "DummyBoxEnv-v0"),
]
dummy_text_dataset = [("dummy-text/test-v0", "DummyTextEnv-v0")]

# Note: Doesn't include the text dataset, since this is often handled separately
dummy_test_datasets = [
("dummy-dict/test-v0", "DummyDictEnv-v0"),
("dummy-tuple/test-v0", "DummyTupleEnv-v0"),
("dummy-combo/test-v0", "DummyComboEnv-v0"),
("dummy-multi-dim-box/test-v0", "DummyMultiDimensionalBoxEnv-v0"),
("dummy-multidim-space/test-v0", "DummyMultiSpaceEnv-v0"),
("dummy-tuple-discrete-box/test-v0", "DummyTupleDiscreteBoxEnv-v0"),
("nested/namespace/dummy-dict/test-v0", "DummyDictEnv-v0"),
("dummy-single-step/test-v0", "DummySingleStepEnv-v0"),
Expand Down Expand Up @@ -95,25 +99,25 @@ def _get_info(self):
return super()._get_info() if self.timestep % 2 == 0 else {}


class DummyMultiDimensionalBoxEnv(gym.Env):
class DummyMultiDimensionalBoxEnv(DummyEnv):
def __init__(self):
super().__init__()
self.action_space = spaces.Box(
low=-1, high=4, shape=(2, 2, 2), dtype=np.float32
)
self.observation_space = spaces.Box(
low=-1, high=4, shape=(3, 3, 3), dtype=np.float32
)

def step(self, action):
terminated = self.timestep > 5
self.timestep += 1

return self.observation_space.sample(), 0, terminated, False, {}
class DummyMultiSpaceEnv(DummyEnv):
def __init__(self):
super().__init__()
self.action_space = spaces.MultiBinary(10)
self.observation_space = spaces.MultiDiscrete([10, 2, 4], dtype=np.int32)

def reset(self, seed=None, options=None):
self.timestep = 0
self.observation_space.seed(seed)
return self.observation_space.sample(), {}
def _get_info(self):
return {"timestep": np.array([self.timestep])}


class DummyTupleDiscreteBoxEnv(DummyEnv):
Expand Down Expand Up @@ -256,12 +260,13 @@ def __init__(self):

test_spaces = [
gym.spaces.Box(low=-1, high=4, shape=(2,), dtype=np.float32),
gym.spaces.Box(low=-1, high=4, shape=(3,), dtype=np.float32),
gym.spaces.Box(low=-1, high=4, shape=(2, 2, 2), dtype=np.float32),
gym.spaces.Box(low=-1, high=4, shape=(3, 3, 3), dtype=np.float32),
gym.spaces.Text(max_length=10, min_length=10),
gym.spaces.Text(max_length=10, min_length=10, seed=42),
gym.spaces.Text(max_length=20, charset=unicode_charset),
gym.spaces.Text(max_length=10, charset="01"),
gym.spaces.Discrete(10, seed=7, start=2),
gym.spaces.MultiDiscrete([5, 2, 3], dtype=np.int32),
gym.spaces.MultiBinary(10, seed=1),
gym.spaces.Tuple(
(
gym.spaces.Discrete(1),
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def register_dummy_envs():
"DummySingleStepEnv",
"DummyInconsistentInfoEnv",
"DummyMultiDimensionalBoxEnv",
"DummyMultiSpaceEnv",
"DummyTupleDiscreteBoxEnv",
"DummyDictEnv",
"DummyTupleEnv",
Expand Down

0 comments on commit 5042059

Please sign in to comment.