Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 16, 2024
1 parent 67f659c commit b54f73d
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 141 deletions.
38 changes: 22 additions & 16 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,24 +858,15 @@ def _get_mock_input_td(
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize(
"batch",
[
(10,),
(
10,
3,
),
(),
],
)
def test_mlp(
@pytest.mark.parametrize("n_agent_inputs", [6, None])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_multiagent_mlp(
self,
n_agents,
centralised,
share_params,
batch,
n_agent_inputs=6,
n_agent_inputs,
n_agent_outputs=2,
):
torch.manual_seed(0)
Expand All @@ -887,6 +878,8 @@ def test_mlp(
share_params=share_params,
depth=2,
)
if n_agent_inputs is None:
n_agent_inputs = 6
td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch)
obs = td.get(("agents", "observation"))

Expand Down Expand Up @@ -924,14 +917,27 @@ def test_mlp(
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("channels", [3, None])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_cnn(
self, n_agents, centralised, share_params, batch, x=50, y=50, channels=3
def test_multiagent_cnn(
self,
n_agents,
centralised,
share_params,
batch,
channels,
x=50,
y=50,
):
torch.manual_seed(0)
cnn = MultiAgentConvNet(
n_agents=n_agents, centralised=centralised, share_params=share_params
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
in_features=channels,
)
if channels is None:
channels = 3
td = TensorDict(
{
"agents": TensorDict(
Expand Down
Loading

0 comments on commit b54f73d

Please sign in to comment.