Implemented logging token usage (solves AntonOsika#322) (AntonOsika#438)
* Implemented logging token usage

Token usage is now tracked and logged into memory/logs/token_usage

* Step names are now inferred from the function name

* Incorporated Anton's feedback

- Made TokenUsage a dataclass
- For token logging, step name is now inferred via inspect module

* Formatted (black/ruff)

* Update gpt_engineer/ai.py

Co-authored-by: Anton Osika <[email protected]>

* formatting

---------

Co-authored-by: Anton Osika <[email protected]>
UmerHA and AntonOsika authored Jul 3, 2023
1 parent 2b8e056 commit 8fd315d
Showing 4 changed files with 112 additions and 14 deletions.
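Taken together, the change threads a step_name argument through every model call so that the AI class can attribute token counts to pipeline steps and serialize them as CSV. A minimal sketch of the resulting flow (illustrative only: a real run needs OpenAI credentials, and the hard-coded step name below stands in for what curr_fn() infers in steps.py):

from gpt_engineer.ai import AI

ai = AI(model="gpt-4", temperature=0.1)

# Every step now passes its own name along with the prompts.
messages = ai.start(
    system="You are a helpful assistant.",
    user="Say hello.",
    step_name="simple_gen",  # steps.py derives this via curr_fn()
)

# One TokenUsage row is appended per ai.next() call; main.py writes
# the CSV rendering to memory/logs/token_usage at the end of a run.
print(ai.format_token_usage_log())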
92 changes: 89 additions & 3 deletions gpt_engineer/ai.py
@@ -2,25 +2,54 @@
 
 import logging
 
+from dataclasses import dataclass
 from typing import Dict, List
 
 import openai
+import tiktoken
 
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class TokenUsage:
+    step_name: str
+    in_step_prompt_tokens: int
+    in_step_completion_tokens: int
+    in_step_total_tokens: int
+    total_prompt_tokens: int
+    total_completion_tokens: int
+    total_tokens: int
+
+
 class AI:
     def __init__(self, model="gpt-4", temperature=0.1):
         self.temperature = temperature
         self.model = model
 
-    def start(self, system, user):
+        # initialize token usage log
+        self.cumulative_prompt_tokens = 0
+        self.cumulative_completion_tokens = 0
+        self.cumulative_total_tokens = 0
+        self.token_usage_log = []
+
+        try:
+            self.tokenizer = tiktoken.encoding_for_model(model)
+        except KeyError:
+            logger.debug(
+                f"Tiktoken encoder for model {model} not found. Using "
+                "cl100k_base encoder instead. The results may therefore be "
+                "inaccurate and should only be used as an estimate."
+            )
+            self.tokenizer = tiktoken.get_encoding("cl100k_base")
+
+    def start(self, system, user, step_name):
         messages = [
             {"role": "system", "content": system},
             {"role": "user", "content": user},
         ]
 
-        return self.next(messages)
+        return self.next(messages, step_name=step_name)
 
     def fsystem(self, msg):
         return {"role": "system", "content": msg}
@@ -31,7 +60,7 @@ def fuser(self, msg):
     def fassistant(self, msg):
         return {"role": "assistant", "content": msg}
 
-    def next(self, messages: List[Dict[str, str]], prompt=None):
+    def next(self, messages: List[Dict[str, str]], prompt=None, *, step_name=None):
         if prompt:
             messages += [{"role": "user", "content": prompt}]
 
@@ -52,8 +81,65 @@ def next(self, messages: List[Dict[str, str]], prompt=None):
         print()
         messages += [{"role": "assistant", "content": "".join(chat)}]
         logger.debug(f"Chat completion finished: {messages}")
+
+        self.update_token_usage_log(
+            messages=messages, answer="".join(chat), step_name=step_name
+        )
+
         return messages
 
+    def update_token_usage_log(self, messages, answer, step_name):
+        prompt_tokens = self.num_tokens_from_messages(messages)
+        completion_tokens = self.num_tokens(answer)
+        total_tokens = prompt_tokens + completion_tokens
+
+        self.cumulative_prompt_tokens += prompt_tokens
+        self.cumulative_completion_tokens += completion_tokens
+        self.cumulative_total_tokens += total_tokens
+
+        self.token_usage_log.append(
+            TokenUsage(
+                step_name=step_name,
+                in_step_prompt_tokens=prompt_tokens,
+                in_step_completion_tokens=completion_tokens,
+                in_step_total_tokens=total_tokens,
+                total_prompt_tokens=self.cumulative_prompt_tokens,
+                total_completion_tokens=self.cumulative_completion_tokens,
+                total_tokens=self.cumulative_total_tokens,
+            )
+        )
+
+    def format_token_usage_log(self):
+        result = "step_name,"
+        result += "prompt_tokens_in_step,completion_tokens_in_step,total_tokens_in_step"
+        result += ",total_prompt_tokens,total_completion_tokens,total_tokens\n"
+        for log in self.token_usage_log:
+            result += log.step_name + ","
+            result += str(log.in_step_prompt_tokens) + ","
+            result += str(log.in_step_completion_tokens) + ","
+            result += str(log.in_step_total_tokens) + ","
+            result += str(log.total_prompt_tokens) + ","
+            result += str(log.total_completion_tokens) + ","
+            result += str(log.total_tokens) + "\n"
+        return result
+
+    def num_tokens(self, txt):
+        return len(self.tokenizer.encode(txt))
+
+    def num_tokens_from_messages(self, messages):
+        """Returns the number of tokens used by a list of messages."""
+        n_tokens = 0
+        for message in messages:
+            n_tokens += (
+                4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            )
+            for key, value in message.items():
+                n_tokens += self.num_tokens(value)
+                if key == "name":  # if there's a name, the role is omitted
+                    n_tokens += -1  # role is always required and always 1 token
+        n_tokens += 2  # every reply is primed with <im_start>assistant
+        return n_tokens
 
 
 def fallback_model(model: str) -> str:
     try:
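The counting in num_tokens_from_messages follows the OpenAI cookbook heuristic for chat models: a fixed 4-token overhead per message, plus the encoded length of every field value, plus 2 tokens that prime the assistant's reply. The same arithmetic as a standalone sketch (counts assume the cl100k_base encoding, so the exact total is tokenizer-dependent):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

n_tokens = 0
for message in messages:
    n_tokens += 4  # <im_start>{role/name}\n{content}<im_end>\n
    for value in message.values():
        n_tokens += len(enc.encode(value))
n_tokens += 2  # reply is primed with <im_start>assistant
print(n_tokens)  # prompt-side estimate for this two-message chat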
2 changes: 2 additions & 0 deletions gpt_engineer/main.py
@@ -61,6 +61,8 @@ def main(
     if collect_consent():
         collect_learnings(model, temperature, steps, dbs)
 
+    dbs.logs["token_usage"] = ai.format_token_usage_log()
+
 
 if __name__ == "__main__":
     app()
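The value stored under dbs.logs["token_usage"] is the CSV string built by format_token_usage_log, one row per ai.next() call with running totals in the last three columns. A sample with invented numbers, purely to show the shape:

step_name,prompt_tokens_in_step,completion_tokens_in_step,total_tokens_in_step,total_prompt_tokens,total_completion_tokens,total_tokens
simple_gen,512,901,1413,512,901,1413
gen_entrypoint,230,58,288,742,959,1701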
31 changes: 20 additions & 11 deletions gpt_engineer/steps.py
@@ -1,3 +1,4 @@
+import inspect
 import json
 import re
 import subprocess
@@ -35,12 +36,17 @@ def get_prompt(dbs: DBs) -> str:
     return dbs.input["prompt"]
 
 
+def curr_fn() -> str:
+    """Get the name of the current function"""
+    return inspect.stack()[1].function
+
+
 # All steps below have the signature Step
 
 
 def simple_gen(ai: AI, dbs: DBs) -> List[dict]:
     """Run the AI on the main prompt and save the results"""
-    messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs))
+    messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn())
     to_files(messages[-1]["content"], dbs.workspace)
     return messages
 
@@ -52,7 +58,7 @@ def clarify(ai: AI, dbs: DBs) -> List[dict]:
     messages = [ai.fsystem(dbs.preprompts["qa"])]
     user_input = get_prompt(dbs)
     while True:
-        messages = ai.next(messages, user_input)
+        messages = ai.next(messages, user_input, step_name=curr_fn())
 
         if messages[-1]["content"].strip() == "Nothing more to clarify.":
             break
@@ -71,6 +77,7 @@ def clarify(ai: AI, dbs: DBs) -> List[dict]:
             messages = ai.next(
                 messages,
                 "Make your own assumptions and state them explicitly before starting",
+                step_name=curr_fn(),
             )
             print()
             return messages
@@ -97,7 +104,7 @@ def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
         ai.fsystem(f"Instructions: {dbs.input['prompt']}"),
     ]
 
-    messages = ai.next(messages, dbs.preprompts["spec"])
+    messages = ai.next(messages, dbs.preprompts["spec"], step_name=curr_fn())
 
     dbs.memory["specification"] = messages[-1]["content"]
 
@@ -108,7 +115,7 @@ def respec(ai: AI, dbs: DBs) -> List[dict]:
     messages = json.loads(dbs.logs[gen_spec.__name__])
     messages += [ai.fsystem(dbs.preprompts["respec"])]
 
-    messages = ai.next(messages)
+    messages = ai.next(messages, step_name=curr_fn())
     messages = ai.next(
         messages,
         (
@@ -119,6 +126,7 @@ def respec(ai: AI, dbs: DBs) -> List[dict]:
             "If you are satisfied with the specification, just write out the "
             "specification word by word again."
         ),
+        step_name=curr_fn(),
    )
 
     dbs.memory["specification"] = messages[-1]["content"]
@@ -135,7 +143,7 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
         ai.fuser(f"Specification:\n\n{dbs.memory['specification']}"),
     ]
 
-    messages = ai.next(messages, dbs.preprompts["unit_tests"])
+    messages = ai.next(messages, dbs.preprompts["unit_tests"], step_name=curr_fn())
 
     dbs.memory["unit_tests"] = messages[-1]["content"]
     to_files(dbs.memory["unit_tests"], dbs.workspace)
@@ -145,28 +153,26 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
 
 def gen_clarified_code(ai: AI, dbs: DBs) -> List[dict]:
     """Takes clarification and generates code"""
-
     messages = json.loads(dbs.logs[clarify.__name__])
 
     messages = [
         ai.fsystem(setup_sys_prompt(dbs)),
     ] + messages[1:]
-    messages = ai.next(messages, dbs.preprompts["use_qa"])
+    messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
 
     to_files(messages[-1]["content"], dbs.workspace)
     return messages
 
 
 def gen_code(ai: AI, dbs: DBs) -> List[dict]:
     # get the messages from previous step
-
     messages = [
         ai.fsystem(setup_sys_prompt(dbs)),
         ai.fuser(f"Instructions: {dbs.input['prompt']}"),
         ai.fuser(f"Specification:\n\n{dbs.memory['specification']}"),
         ai.fuser(f"Unit tests:\n\n{dbs.memory['unit_tests']}"),
     ]
-    messages = ai.next(messages, dbs.preprompts["use_qa"])
+    messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
     to_files(messages[-1]["content"], dbs.workspace)
     return messages
 
@@ -224,6 +230,7 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
             "if necessary.\n"
         ),
         user="Information about the codebase:\n\n" + dbs.workspace["all_output.txt"],
+        step_name=curr_fn(),
     )
     print()
 
@@ -240,7 +247,7 @@ def use_feedback(ai: AI, dbs: DBs):
         ai.fassistant(dbs.workspace["all_output.txt"]),
         ai.fsystem(dbs.preprompts["use_feedback"]),
     ]
-    messages = ai.next(messages, dbs.input["feedback"])
+    messages = ai.next(messages, dbs.input["feedback"], step_name=curr_fn())
     to_files(messages[-1]["content"], dbs.workspace)
     return messages
 
@@ -253,7 +260,9 @@ def fix_code(ai: AI, dbs: DBs):
         ai.fuser(code_output),
         ai.fsystem(dbs.preprompts["fix_code"]),
     ]
-    messages = ai.next(messages, "Please fix any errors in the code above.")
+    messages = ai.next(
+        messages, "Please fix any errors in the code above.", step_name=curr_fn()
+    )
     to_files(messages[-1]["content"], dbs.workspace)
     return messages
 
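curr_fn is what lets each step report its own name without hard-coding it: inspect.stack()[0] is curr_fn's own frame, so index [1] is the frame of whichever step function called it. A tiny standalone demonstration:

import inspect

def curr_fn() -> str:
    """Get the name of the current function"""
    return inspect.stack()[1].function

def simple_gen() -> str:
    return curr_fn()

print(simple_gen())  # prints: simple_gen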
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     'typer >= 0.3.2',
     'rudder-sdk-python == 2.0.2',
     'dataclasses-json == 0.5.7',
+    'tiktoken',
     'tabulate == 0.9.0',
 ]
 