diff --git a/src/models/world_model.py b/src/models/world_model.py index 0691034..0aa715e 100644 --- a/src/models/world_model.py +++ b/src/models/world_model.py @@ -97,7 +97,7 @@ def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValue def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs: Any) -> LossWithIntermediateLosses: with torch.no_grad(): - obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens # (BL, K) + obs_tokens = tokenizer.encode(batch['observations'], should_preprocess=True).tokens # (B, L, K) act_tokens = rearrange(batch['actions'], 'b l -> b l 1') tokens = rearrange(torch.cat((obs_tokens, act_tokens), dim=2), 'b l k1 -> b (l k1)') # (B, L(K+1))