Small fixes
oandreeva-nv committed Sep 4, 2024
1 parent f9148c9 commit c3ff24b
Showing 2 changed files with 6 additions and 9 deletions.
AI_Agents_Guide/Constrained_Decoding/README.md (2 additions, 2 deletions)

````diff
@@ -162,7 +162,7 @@ assistant
 ### Example 2
 
 Optionally, we can also restrict an output to a specific schema. For example,
-in [`client.py`](./artifacts/client.py) we use a `pydentic` library to define the
+in [`client.py`](./artifacts/client.py) we use a `pydantic` library to define the
 following answer format:
 
 ```python
````
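The schema definition itself is collapsed in this diff view. For orientation, here is a minimal sketch of what such a `pydantic` answer format can look like; the `AnswerFormat` name and its fields are illustrative assumptions, not the tutorial's actual schema:

```python
from pydantic import BaseModel


class AnswerFormat(BaseModel):
    # Illustrative fields only; the tutorial's real schema is collapsed above.
    title: str
    year: int


# lm-format-enforcer can consume the model's JSON schema (pydantic v2 API):
schema = AnswerFormat.model_json_schema()
```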
````diff
@@ -279,7 +279,7 @@ self.executor = trtllm.Executor(model_path=...,
 ...
 ```
 
-Additionally, if you want to enable logits pos-processor for every request
+Additionally, if you want to enable logits post-processor for every request
 individually, you can do so via an additional `input` parameter.
 For example, in this tutorial we will add `logits_post_processor_name` in
 `inflight_batcher_llm/tensorrt_llm/config.pbtxt`:
````
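The referenced `config.pbtxt` change falls outside this hunk. As a sketch, an optional per-request string input in Triton's protobuf-text model config typically looks like the following; the `data_type` and `dims` values here are assumptions, not copied from the tutorial:

```
input [
  {
    name: "logits_post_processor_name"
    data_type: TYPE_STRING
    dims: [ -1 ]
    optional: true
  }
]
```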
AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py (4 additions, 7 deletions)

```diff
@@ -25,9 +25,8 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import json
-import typing
 from collections import defaultdict
-from typing import DefaultDict, Dict
+from typing import DefaultDict, Dict, List
 
 import torch
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer
```
```diff
@@ -102,7 +101,7 @@ def __call__(
         self,
         req_id: int,
         logits: torch.Tensor,
-        ids: typing.List[typing.List[int]],
+        ids: List[List[int]],
         stream_ptr: int,
     ):
         # Create a mask with negative infinity to block all tokens initially.
```
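For context, here is a minimal sketch of the masking technique this `__call__` implements: ask lm-format-enforcer's `TokenEnforcer` which token IDs may legally come next, then add a mask that is negative infinity everywhere else. The helper's name and exact wiring in `utils.py` may differ; `TokenEnforcer.get_allowed_tokens` is the library's documented API:

```python
from typing import List

import torch
from lmformatenforcer import TokenEnforcer  # same import as in utils.py


def apply_token_mask(
    logits: torch.Tensor, enforcer: TokenEnforcer, ids: List[List[int]]
) -> None:
    # Ask the enforcer which token IDs may legally come next for this prefix.
    allowed = enforcer.get_allowed_tokens(ids[0])
    # Start from a mask that blocks every token with -inf ...
    mask = torch.full_like(logits, float("-inf"))
    # ... then re-open only the allowed positions along the vocab axis.
    mask[..., allowed] = 0.0
    # Adding the mask keeps allowed logits unchanged and sends the rest to -inf.
    logits += mask
```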
```diff
@@ -140,12 +139,10 @@ def __call__(
         self,
         req_id: int,
         logits: torch.Tensor,
-        ids: typing.List[typing.List[int]],
+        ids: List[List[int]],
         stream_ptr: int,
     ):
-        # Initialize the FSM state dictionary if the input_ids are empty,
-        # as this means that the input_ids are the first tokens of the sequence.
-        seq_id = hash(tuple(ids[0]))
+        seq_id = None
         # If the prefix token IDs have changed we assume that we are dealing
         # with a new sample and reset the FSM state
         if (
```
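For context on the removed `seq_id = hash(tuple(ids[0]))` line: hashing the prefix token IDs is the usual outlines-style trick for keying FSM state per sequence, since a `defaultdict(int)` then yields state 0 for any unseen prefix. A rough sketch of that bookkeeping pattern follows; the `fsm` object with `next_state`/`allowed_token_ids` mirrors older outlines FSM interfaces and is an assumption, not a quote from `utils.py`:

```python
from collections import defaultdict
from typing import DefaultDict, List


class FsmStateSketch:
    """Illustrative per-prefix FSM-state bookkeeping keyed by hash(prefix)."""

    def __init__(self, fsm) -> None:
        # `fsm` is assumed to expose next_state(state, token) and
        # allowed_token_ids(state), in the spirit of outlines' FSM objects.
        self.fsm = fsm
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def allowed_tokens(self, ids: List[List[int]]) -> List[int]:
        seq_id = hash(tuple(ids[0]))            # key for the current prefix
        prev_seq_id = hash(tuple(ids[0][:-1]))  # prefix one token earlier
        # Advance by the token just sampled; an unseen previous prefix
        # falls back to state 0 via defaultdict(int).
        self.fsm_state[seq_id] = self.fsm.next_state(
            self.fsm_state[prev_seq_id], ids[0][-1]
        )
        return self.fsm.allowed_token_ids(self.fsm_state[seq_id])
```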
