diff --git a/AI_Agents_Guide/Constrained_Decoding/README.md b/AI_Agents_Guide/Constrained_Decoding/README.md index 0fbf37d0..28e07417 100644 --- a/AI_Agents_Guide/Constrained_Decoding/README.md +++ b/AI_Agents_Guide/Constrained_Decoding/README.md @@ -162,7 +162,7 @@ assistant ### Example 2 Optionally, we can also restrict an output to a specific schema. For example, -in [`client.py`](./artifacts/client.py) we use a `pydentic` library to define the +in [`client.py`](./artifacts/client.py) we use a `pydantic` library to define the following answer format: ```python @@ -279,7 +279,7 @@ self.executor = trtllm.Executor(model_path=..., ... ``` -Additionally, if you want to enable logits pos-processor for every request +Additionally, if you want to enable logits post-processor for every request individually, you can do so via an additional `input` parameter. For example, in this tutorial we will add `logits_post_processor_name` in `inflight_batcher_llm/tensorrt_llm/config.pbtxt`: diff --git a/AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py b/AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py index fb476d49..70ef4237 100644 --- a/AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py +++ b/AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py @@ -25,9 +25,8 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json -import typing from collections import defaultdict -from typing import DefaultDict, Dict +from typing import DefaultDict, Dict, List import torch from lmformatenforcer import JsonSchemaParser, TokenEnforcer @@ -102,7 +101,7 @@ def __call__( self, req_id: int, logits: torch.Tensor, - ids: typing.List[typing.List[int]], + ids: List[List[int]], stream_ptr: int, ): # Create a mask with negative infinity to block all tokens initially. @@ -140,12 +139,10 @@ def __call__( self, req_id: int, logits: torch.Tensor, - ids: typing.List[typing.List[int]], + ids: List[List[int]], stream_ptr: int, ): - # Initialize the FSM state dictionary if the input_ids are empty, - # as this means that the input_ids are the first tokens of the sequence. - seq_id = hash(tuple(ids[0])) + seq_id = None # If the prefix token IDs have changed we assume that we are dealing # with a new sample and reset the FSM state if (