[Core] Make NPCStepper.generate_local() sync in local mode (#9)
* feat: use chat template for gguf model (llama.cpp model)

* change: variable name ('model' to 'llm') in NPCStepper.generate_local()

* change: NPCStepper.get_action is now synchronous and README is up to date

* change: update tests/test_step.py

* change: update README.md

* Update gigax/step.py

Co-authored-by: tdeborde <[email protected]>

* Update gigax/step.py

Co-authored-by: tdeborde <[email protected]>

* Update README.md

Co-authored-by: tdeborde <[email protected]>

* Update gigax/step.py

Co-authored-by: tdeborde <[email protected]>

* Update tests/test_step.py

Co-authored-by: tdeborde <[email protected]>

* update: README.md

---------

Co-authored-by: A-Mahla <>
Co-authored-by: tdeborde <[email protected]>
A-Mahla and tristandeborde authored Jun 19, 2024
1 parent 6fa19e3 commit c3c209d
Showing 3 changed files with 39 additions and 24 deletions.
32 changes: 25 additions & 7 deletions README.md
@@ -23,6 +23,7 @@ ______________________________________________________________________
pip install gigax
```


## Features

- [x] 🕹️ NPCs that `<speak>`, `<jump>`, `<attack>` and perform any other action you've defined
@@ -36,13 +37,22 @@ pip install gigax

Gigax has new releases and features on the way. Make sure to ⭐ star and 👀 watch this repository!


## Usage

### Model instantiation


* We provide various models on the [🤗 Huggingface hub](https://huggingface.co/Gigax):
* [NPC-LLM-7B](https://huggingface.co/Gigax/NPC-LLM-7B) (our Mistral-7B fine-tune)
* [NPC-LLM-3_8B](https://huggingface.co/Gigax/NPC-LLM-3_8B) (our Phi-3 fine-tune)
* [NPC-LLM-3_8B-128k](https://huggingface.co/Gigax/NPC-LLM-3_8B-128k) (our Phi-3 128k context length fine-tune)

* All these models are also available in [gguf](https://huggingface.co/docs/hub/en/gguf) format to run them on CPU using [llama_cpp](https://llama-cpp-python.readthedocs.io/en/latest/)
* [NPC-LLM-7B-GGUF](https://huggingface.co/Gigax/NPC-LLM-7B-GGUF)
* [NPC-LLM-3_8B-GGUF](https://huggingface.co/Gigax/NPC-LLM-3_8B-GGUF)
* [NPC-LLM-3_8B-128k-GGUF](https://huggingface.co/Gigax/NPC-LLM-3_8B-128k-GGUF)


* Start by instantiating one of them using outlines:
```py
@@ -51,22 +61,26 @@ from gigax.step import NPCStepper
from transformers import AutoTokenizer, AutoModelForCausalLM

# Download model from the Hub
model_name = "Gigax/NPC-LLM-7B"
llm = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
llm = Llama.from_pretrained(
repo_id="Gigax/NPC-LLM-3_8B-GGUF",
filename="npc-llm-3_8B.gguf"
# n_gpu_layers=-1, # Uncomment to use GPU acceleration
# n_ctx=2048, # Uncomment to increase the context window
)

# Our stepper takes in a Outlines model to enable guided generation
# This forces the model to follow our output format
model = models.Transformers(llm, tokenizer)
model = models.LlamaCpp(llm)

# Instantiate a stepper: handles prompting + output parsing
stepper = NPCStepper(model=model)
```


### Stepping an NPC


* From there, stepping an NPC is a one-liner:
```py
action = stepper.get_action(
action = await stepper.get_action(
context=context,
locations=locations,
NPCs=NPCs,
@@ -76,6 +90,7 @@ action = stepper.get_action(
)
```


* We provide classes to instantiate `Locations`, `NPCs`, etc. :
```py
from gigax.parse import CharacterAction
@@ -88,7 +103,9 @@ from gigax.scene import (
ParameterType,
)
# Use sample data
context = "Medieval world"
current_location = Location(name="Old Town", description="A quiet and peaceful town.")
locations = [current_location] # you can add more locations to the scene
NPCs = [
Character(
name="John the Brave",
@@ -121,6 +138,7 @@ events = [
]
```


## API

Contact us to [give our NPC API a try](https://tally.so/r/w7d2Rz) - we'll take care of model serving, NPC memory, and more!
5 changes: 2 additions & 3 deletions gigax/step.py
@@ -1,7 +1,6 @@
import time
import logging
import traceback

from openai import AsyncOpenAI
from gigax.prompt import NPCPrompt, llama_chat_template
from gigax.scene import (
@@ -71,7 +70,7 @@ async def generate_api(
# Return the NPC's response
return response.choices[0].message.content # type: ignore

async def generate_local(
def generate_local(
self,
prompt: str,
llm: models.LogitsGenerator,
@@ -144,7 +143,7 @@ async def get_action(

# Generate the response
if isinstance(self.model, models.LogitsGenerator):
res = await self.generate_local(
res = self.generate_local(
prompt,
self.model,
guided_regex.pattern,
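
The step.py hunks above turn `generate_local` into a plain function that is called without `await` from inside the still-`async` `get_action`. A minimal sketch of that calling pattern, assuming a hypothetical `SketchStepper` class that only stands in for `NPCStepper` (it is not the gigax implementation and does no real generation):

```python
import asyncio


class SketchStepper:
    """Hypothetical stand-in for NPCStepper; not the gigax implementation."""

    def generate_local(self, prompt: str) -> str:
        # Synchronous: returns a plain value, so callers must not await it.
        return f"echo: {prompt}"

    async def get_action(self, prompt: str) -> str:
        # Inside a coroutine, a sync helper is called like any other function.
        # It blocks the event loop until it returns.
        return self.generate_local(prompt)


# Drive the coroutine from synchronous code.
result = asyncio.run(SketchStepper().get_action("Aldren attacks"))
print(result)
```

Because the helper blocks the event loop while it runs, this shape fits best when nothing else needs to make progress on the loop during local generation.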
26 changes: 12 additions & 14 deletions tests/test_step.py
@@ -12,28 +12,28 @@
load_dotenv()


@pytest.mark.asyncio
async def test_stepper_local_llamacpp(
def test_stepper_local_llamacpp(
context: str,
locations: list[Location],
NPCs: list[Character],
protagonist: ProtagonistCharacter,
items: list[Item],
events: list[CharacterAction],
):
llm = Llama(
model_path="./models/Phi-3-mini-4k-instruct-q4.gguf", # path to GGUF file
n_ctx=4096, # The max sequence length to use - note that longer sequence lengths require much more resources
n_threads=8, # The number of CPU threads to use, tailor to your system and the resulting performance
n_gpu_layers=35, # The number of layers to offload to GPU, if you have GPU acceleration available. Set to 0 if no GPU acceleration is available on your system.
llm = Llama.from_pretrained(
repo_id="Gigax/NPC-LLM-3_8B-GGUF",
filename="npc-llm-3_8B.gguf"
# n_gpu_layers=-1, # Uncomment to use GPU acceleration
# seed=1337, # Uncomment to set a specific seed
# n_ctx=2048, # Uncomment to increase the context window
)

model = models.LlamaCpp(llm) # type: ignore
model = models.LlamaCpp(llm)

stepper = NPCStepper(model=model)

start = time.time()
action = await stepper.get_action(
action = stepper.get_action(
context=context,
locations=locations,
NPCs=NPCs,
@@ -46,8 +46,7 @@ async def test_stepper_local_llamacpp(
assert str(action) == "Aldren: Attack John the Brave"


@pytest.mark.asyncio
async def test_stepper_local_transformers(
def test_stepper_local_transformers(
context: str,
locations: list[Location],
NPCs: list[Character],
@@ -63,7 +62,7 @@ async def test_stepper_local_transformers(
# Get the NPC's input
stepper = NPCStepper(model=model)

action = await stepper.get_action(
action = stepper.get_action(
context=context,
locations=locations,
NPCs=NPCs,
@@ -75,8 +74,7 @@ async def test_stepper_local_transformers(
assert str(action) == "Aldren: Attack John the Brave"


@pytest.mark.asyncio
async def test_stepper_api(
def test_stepper_api(
context: str,
locations: list[Location],
NPCs: list[Character],
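
The test hunks above drop `@pytest.mark.asyncio` and `await`, while the `async def get_action(` hunk context in gigax/step.py suggests the method may still be a coroutine. If that is the case, a plain `def` test would need to run the coroutine explicitly. A minimal sketch, assuming a hypothetical `get_action_stub` coroutine in place of the real method:

```python
import asyncio
import inspect


async def get_action_stub(name: str) -> str:
    """Hypothetical coroutine standing in for an async get_action."""
    return f"{name}: Attack John the Brave"


def test_sync_caller() -> None:
    # Calling a coroutine function without await only creates a coroutine object.
    pending = get_action_stub("Aldren")
    assert inspect.iscoroutine(pending)
    pending.close()  # avoid a "never awaited" warning

    # asyncio.run drives the coroutine to completion from synchronous code.
    action = asyncio.run(get_action_stub("Aldren"))
    assert action == "Aldren: Attack John the Brave"


test_sync_caller()
```

If `get_action` is in fact fully synchronous in the final tree, none of this is needed and the tests can call it directly, as the diff shows.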
