Hi, I'm trying to set up a DQN agent with a graph attention layer. The agent can take one of 3 actions. For some reason, when I run the training function, I see the following error:
  File "main.py", line 151, in <module>
    main()
  File "main.py", line 41, in main
    run_trial(args, args.tr)
  File "main.py", line 146, in run_trial
    agent.observe(obs, rew, done, info)
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/agents/agent.py", line 33, in observe
    self.agents[agent_id].observe(observation[agent_id], reward[agent_id], done, info)
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/agents/colight.py", line 189, in observe
    self.agent.observe(observation, reward, done, False)
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/agent.py", line 164, in observe
    self.batch_observe([obs], [reward], [done], [reset])
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/agents/dqn.py", line 586, in batch_observe
    return self._batch_observe_train(
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/agents/dqn.py", line 548, in _batch_observe_train
    self.replay_updater.update_if_necessary(self.t)
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/replay_buffer.py", line 356, in update_if_necessary
    self.update_func(transitions)
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/agents/dqn.py", line 351, in update
    loss = self._compute_loss(exp_batch, errors_out=errors_out)
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/agents/dqn.py", line 441, in _compute_loss
    y, t = self._compute_y_and_t(exp_batch)
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/agents/dqn.py", line 421, in _compute_y_and_t
    batch_q = torch.reshape(qout.evaluate_actions(batch_actions), (batch_size, 1))
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/action_value.py", line 70, in evaluate_actions
    return self.q_values.gather(dim=1, index=index).flatten()
RuntimeError: Size does not match at dimension 0 expected index [32, 1] to be smaller than self [1, 3] apart from dimension 1
I'm somewhat lost as to where to go from here. The network trains absolutely fine when a Conv2D layer is used, but fails with the graph attention layer, even though the output dimensions appear to be the same.
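For reference, this is the kind of quick shape check I'd expect to pass (the nn.Linear below is just a placeholder for my actual network):

```python
# Quick shape check (nn.Linear is a placeholder, not my actual GAT network):
import torch
import torch.nn as nn

NUM_ACTIONS = 3
model = nn.Linear(16, NUM_ACTIONS)  # stand-in for the real Q-network
batch = torch.zeros(32, 16)         # dummy batch of 32 observations
print(model(batch).shape)           # expect torch.Size([32, 3]): one row per sample
```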
Thanks
Judging from the error message, this looks like a shape issue with your network's output:
  File "/Users/behradkoohy/sumo-scratchpad/RESCO/venv/lib/python3.8/site-packages/pfrl/action_value.py", line 70, in evaluate_actions
    return self.q_values.gather(dim=1, index=index).flatten()
RuntimeError: Size does not match at dimension 0 expected index [32, 1] to be smaller than self [1, 3] apart from dimension 1
This suggests that q_values, the output of your network, has a shape of (1, 3). That is probably not what you intended: for DiscreteActionValue, q_values must have a shape of (batch_size, num_actions), so with a minibatch of 32 and 3 actions the expected shape is (32, 3). It would help to check why the output comes out as (1, 3).
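For illustration, here's a minimal sketch of that contract and one common way a graph layer can break it (the network below is a toy stand-in, not your code): if a reduction such as pooling or a reshape runs over dim 0, the batch axis gets folded away and the output is (1, num_actions) regardless of batch size.

```python
# Toy illustration of the (batch_size, num_actions) contract.
# Nothing here is the reporter's actual network; shapes are made up.
import torch
import torch.nn as nn
from pfrl.action_value import DiscreteActionValue

batch_size, num_nodes, feat_dim, num_actions = 32, 4, 8, 3
x = torch.zeros(batch_size, num_nodes, feat_dim)  # batch of graph observations

# Correct: flatten per sample so dim 0 stays the batch dimension.
head = nn.Linear(num_nodes * feat_dim, num_actions)
good = DiscreteActionValue(head(x.reshape(x.shape[0], -1)))
print(good.q_values.shape)  # torch.Size([32, 3])

# Buggy: pooling over dim 0 treats the batch axis like a node axis and
# collapses it, yielding (1, num_actions) for any batch size -- the same
# (1, 3) shape as in the error above.
pool_head = nn.Linear(feat_dim, num_actions)
bad = DiscreteActionValue(pool_head(x.mean(dim=(0, 1))).unsqueeze(0))
print(bad.q_values.shape)  # torch.Size([1, 3])
```

If the GAT layer was written for a single graph (node features of shape (num_nodes, feat_dim) with no batch axis), that would be the first place I'd check: such layers often appear to work at action-selection time, where one observation is passed in, and only break once the replay buffer starts sampling minibatches.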