Fix torch.compile on the Llama3.2 vision model
iseeyuan committed Oct 9, 2024
1 parent 286527c commit 254978d
Showing 1 changed file: torchchat/generate.py (6 additions, 2 deletions)
@@ -42,7 +42,9 @@
 from torchchat.model import Model, ModelType
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info

+# torch._dynamo.config.capture_scalar_outputs = True
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+# torch._dynamo.config.suppress_errors = True

 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
@@ -415,7 +417,9 @@ def decode_one_token(
     x = x.view(1, -1)
     if model.config.model_type == ModelType.Flamingo:
         assert batch is not None, "Flamingo requires batch"
-        mask = batch["causal_mask"][None, input_pos.item(), None, :]
+        # breakpoint()
+        # start_pos = input_pos.item()
+        mask = batch["causal_mask"][None, input_pos, None, :].view(1, 1, -1)
         encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
         logits = model(
             x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos
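The fix has two parts: it enables torch._dynamo.config.capture_dynamic_output_shape_ops so dynamo can trace ops whose output shapes depend on runtime data, and it replaces the input_pos.item() lookup in decode_one_token with plain tensor indexing plus a .view(1, 1, -1), so building the mask slice no longer needs a Python int at trace time. The standalone sketch below is a hedged illustration of the second part, not code from the repository; the build_mask helper, the 8x8 toy mask, and the use of fullgraph=True are assumptions made only for this example.

import torch

# Mirrors the flag enabled in this commit so dynamo can handle
# data-dependent output shapes without falling back to eager.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# Toy stand-in for batch["causal_mask"] (assumed 8x8 here).
causal_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))

@torch.compile(fullgraph=True)
def build_mask(input_pos: torch.Tensor) -> torch.Tensor:
    # Indexing with the input_pos tensor keeps the lookup inside the
    # compiled graph; calling input_pos.item() here would force a graph
    # break unless capture_scalar_outputs were also enabled.
    return causal_mask[None, input_pos, None, :].view(1, 1, -1)

# input_pos has shape (1,), as in decode_one_token; the trailing .view
# collapses the extra dim so the result matches the old .item() path.
print(build_mask(torch.tensor([3])).shape)  # torch.Size([1, 1, 8])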
