Skip to content

Commit

Permalink
fix
Browse files — browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jan 31, 2025
1 parent 27d4efb commit eaf25fd
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum_benchmark/backends/pytorch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def process_quantization_config(self) -> None:

self.logger.info("\t+ Processing AutoQuantization config")
self.quantization_config = AutoQuantizationConfig.from_dict(
dict(**getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
)

@property
Expand Down Expand Up @@ -404,9 +404,9 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict

@torch.inference_mode()
def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
    """Run a single prefill (prompt-processing) generation step.

    Restricting generation to exactly one new token isolates the prefill
    stage of the model, which is what this benchmark path measures.

    Args:
        inputs: Model inputs passed through to ``generate`` (e.g. input ids,
            attention mask).
        kwargs: Generation kwargs; both ``max_new_tokens`` and
            ``min_new_tokens`` must be set to 1.

    Returns:
        The output of ``self.pretrained_model.generate``.

    Raises:
        AssertionError: If ``max_new_tokens`` and ``min_new_tokens`` are not
            both equal to 1.
    """
    # Deduplicated: the source text carried this assertion twice (pre- and
    # post-reformat variants of the same check); one occurrence suffices.
    assert kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1, (
        "For prefilling, max_new_tokens and min_new_tokens must be equal to 1"
    )
    return self.pretrained_model.generate(**inputs, **kwargs)

@torch.inference_mode()
Expand Down

0 comments on commit eaf25fd

Please sign in to comment.