Skip to content

Commit

Permalink
Merge pull request #173 from nulib/4313-deployment-override
Browse files Browse the repository at this point in the history
Refactors chat handler for easier configuration handling.
  • Loading branch information
bmquinn authored Feb 5, 2024
2 parents 0dd7f7f + e667a7a commit 1d7908a
Show file tree
Hide file tree
Showing 24 changed files with 855 additions and 300 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
env:
AWS_ACCESS_KEY_ID: ci
AWS_SECRET_ACCESS_KEY: ci
SKIP_WEAVIATE_SETUP: 'True'
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ lerna-debug.log*

### Python ###
.coverage
htmlcov
__pycache__/
*.py[cod]
*$py.class
Expand Down
120 changes: 62 additions & 58 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,58 +1,62 @@
# NOTE(review): this is the pre-change side of the rendered diff. The diff
# rendering has stripped the leading tabs from recipe lines; in the real
# Makefile every command line under a target is tab-indented.
ifndef VERBOSE
.SILENT:
endif
# Environment suffix consumed by the `env` target: selects env.$(ENV).json
ENV=dev

# Print a one-line summary of every supported target.
help:
echo "make build | build the SAM project"
echo "make serve | run the SAM server locally"
echo "make clean | remove all installed dependencies and build artifacts"
echo "make deps | install all dependencies"
echo "make link | create hard links to allow for hot reloading of a built project"
echo "make secrets | symlink secrets files from ../tfvars"
echo "make style | run all style checks"
echo "make test | run all tests"
echo "make cover | run all tests with coverage"
echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)"
echo "make deps-node | install node dependencies"
echo "make deps-python | install python dependencies"
echo "make style-node | run node code style check"
echo "make style-python | run python code style check"
echo "make test-node | run node tests"
echo "make test-python | run python tests"
echo "make cover-node | run node tests with coverage"
echo "make cover-python | run python tests with coverage"
# Real file target: rebuild the SAM project only when the template or a
# dependency manifest changes.
# NOTE(review): prerequisites reference python/*requirements.txt while the
# python targets below already use chat/ — inconsistent in this version
# (corrected later in this commit).
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json python/requirements.txt python/src/requirements.txt
sam build --cached --parallel
deps-node:
cd node && npm ci
cover-node:
cd node && npm run test:coverage
style-node:
cd node && npm run prettier
test-node:
cd node && npm run test
deps-python:
cd chat/src && pip install -r requirements.txt
cover-python:
cd chat/src && coverage run --include='src/**/*' -m unittest -v && coverage report
style-python:
cd chat && ruff check .
test-python:
cd chat && python -m unittest -v
build: .aws-sam/build.toml
# Hard-link sources into the SAM build dirs so edits hot-reload without a rebuild.
link: build
cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
serve: link
sam local start-api --host 0.0.0.0 --log-file dc-api.log
# Aggregate convenience targets.
deps: deps-node deps-python
style: style-node style-python
test: test-node test-python
cover: cover-node cover-python
# Symlink the selected environment file into place.
env:
ln -fs ./env.${ENV}.json ./env.json
secrets:
ln -s ../tfvars/dc-api/* .
clean:
rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache
# Task runner for the dc-api SAM project (node API + python chat lambda).
# Run `make help` for a target summary. Set VERBOSE=1 to echo commands.
ifndef VERBOSE
.SILENT:
endif
# Environment suffix consumed by the `env` target: selects env.$(ENV).json
ENV=dev

# Command-style targets; none of them create a file of the same name.
.PHONY: help deps-node cover-node style-node test-node deps-python \
	cover-python cover-html-python style-python test-python python-version \
	build link serve deps style test cover env secrets clean

help:
	echo "make build | build the SAM project"
	echo "make serve | run the SAM server locally"
	echo "make clean | remove all installed dependencies and build artifacts"
	echo "make deps | install all dependencies"
	echo "make link | create hard links to allow for hot reloading of a built project"
	echo "make secrets | symlink secrets files from ../tfvars"
	echo "make style | run all style checks"
	echo "make test | run all tests"
	echo "make cover | run all tests with coverage"
	echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)"
	echo "make deps-node | install node dependencies"
	echo "make deps-python | install python dependencies"
	echo "make style-node | run node code style check"
	echo "make style-python | run python code style check"
	echo "make test-node | run node tests"
	echo "make test-python | run python tests"
	echo "make cover-node | run node tests with coverage"
	echo "make cover-python | run python tests with coverage"

# Real file target: rebuild the SAM project only when the template or a
# dependency manifest changes.
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt
	sam build --cached --parallel

deps-node:
	cd node && npm ci
cover-node:
	cd node && npm run test:coverage
style-node:
	cd node && npm run prettier
test-node:
	cd node && npm run test

deps-python:
	cd chat/src && pip install -r requirements.txt
cover-python: deps-python
	cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
cover-html-python: deps-python
	cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
style-python: deps-python
	cd chat && ruff check .
# BUGFIX: `PYTHONPATH=src:test &&` was a no-op (the assignment applied to an
# empty command and was never exported); export it alongside
# SKIP_WEAVIATE_SETUP so the unittest run actually sees it.
test-python: deps-python
	cd chat && export SKIP_WEAVIATE_SETUP=True PYTHONPATH=src:test && python -m unittest discover -v
python-version:
	cd chat && python --version

build: .aws-sam/build.toml
# Hard-link sources into the SAM build dirs so edits hot-reload without a rebuild.
link: build
	cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
	cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
serve: link
	sam local start-api --host 0.0.0.0 --log-file dc-api.log

# Aggregate convenience targets.
deps: deps-node deps-python
style: style-node style-python
test: test-node test-python
cover: cover-node cover-python

# Symlink the selected environment file into place.
env:
	ln -fs ./env.${ENV}.json ./env.json
secrets:
	ln -s ../tfvars/dc-api/* .
# BUGFIX: python sources moved from python/ to chat/ in this commit — clean
# the new paths, plus the htmlcov/ output produced by cover-html-python.
clean:
	rm -rf .aws-sam node/node_modules node/src/node_modules chat/**/__pycache__ chat/.coverage chat/.ruff_cache chat/htmlcov
3 changes: 1 addition & 2 deletions chat/dependencies/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
boto3~=1.34.13
langchain~=0.0.208
nbformat~=5.9.0
openai~=0.27.8
pandas~=2.0.2
pyjwt~=2.6.0
python-dotenv~=1.0.0
tiktoken~=0.4.0
Expand Down
216 changes: 216 additions & 0 deletions chat/src/event_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import os
import json

from dataclasses import dataclass, field
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate
from setup import (
weaviate_client,
weaviate_vector_store,
openai_chat_client,
)
from typing import List
from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template
from websocket import Websocket


CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
INDEX_NAME = "DCWork"
K_VALUE = 10
MAX_K = 100
TEMPERATURE = 0.2
TEXT_KEY = "title"
VERSION = "2023-07-01-preview"


@dataclass
class EventConfig:
    """
    Configuration for a single chat event.

    Values are derived in __post_init__ from the Lambda event's JSON body.
    Most settings fall back to the module-level defaults (or environment
    variables); a superuser token may override them via the payload — see
    _get_payload_value_with_superuser_check().
    """

    api_token: ApiToken = field(init=False)
    attributes: List[str] = field(init=False)
    azure_endpoint: str = field(init=False)
    azure_resource_name: str = field(init=False)
    debug_mode: bool = field(init=False)
    deployment_name: str = field(init=False)
    document_prompt: PromptTemplate = field(init=False)
    event: dict = field(default_factory=dict)
    index_name: str = field(init=False)
    is_logged_in: bool = field(init=False)
    k: int = field(init=False)
    openai_api_version: str = field(init=False)
    payload: dict = field(default_factory=dict)
    prompt_text: str = field(init=False)
    prompt: PromptTemplate = field(init=False)
    question: str = field(init=False)
    ref: str = field(init=False)
    request_context: dict = field(init=False)
    temperature: float = field(init=False)
    socket: Websocket = field(init=False, default=None)
    text_key: str = field(init=False)

    def __post_init__(self):
        # Derive every setting exactly once. (Previously azure_endpoint and
        # attributes were each computed twice; the attributes path may query
        # the Weaviate schema, so the duplicate call was also a wasted
        # network round trip.)
        self.payload = json.loads(self.event.get("body", "{}"))
        self.api_token = ApiToken(signed_token=self.payload.get("auth"))
        # azure_resource_name must resolve before azure_endpoint's default is built.
        self.azure_resource_name = self._get_azure_resource_name()
        self.azure_endpoint = self._get_azure_endpoint()
        self.debug_mode = self._is_debug_mode_enabled()
        self.deployment_name = self._get_deployment_name()
        self.index_name = self._get_index_name()
        self.is_logged_in = self.api_token.is_logged_in()
        self.k = self._get_k()
        self.openai_api_version = self._get_openai_api_version()
        self.prompt_text = self._get_prompt_text()
        self.request_context = self.event.get("requestContext", {})
        self.question = self.payload.get("question")
        self.ref = self.payload.get("ref")
        self.temperature = self._get_temperature()
        self.text_key = self._get_text_key()
        self.attributes = self._get_attributes()
        # document_prompt depends on self.attributes, so it comes last.
        self.document_prompt = self._get_document_prompt()
        self.prompt = PromptTemplate(template=self.prompt_text, input_variables=["question", "context"])

    def _get_payload_value_with_superuser_check(self, key, default):
        """Return payload[key] for superusers; everyone else gets the default."""
        if self.api_token.is_superuser():
            return self.payload.get(key, default)
        else:
            return default

    def _get_azure_endpoint(self):
        # Default endpoint is derived from the Azure resource name.
        default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
        return self._get_payload_value_with_superuser_check("azure_endpoint", default)

    def _get_azure_resource_name(self):
        """Resolve the Azure resource name from payload or environment.

        Raises:
            EnvironmentError: if neither source provides a value.
        """
        azure_resource_name = self._get_payload_value_with_superuser_check("azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME"))
        if not azure_resource_name:
            raise EnvironmentError(
                "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
            )
        return azure_resource_name

    def _get_deployment_name(self):
        return self._get_payload_value_with_superuser_check("deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))

    def _get_index_name(self):
        return self._get_payload_value_with_superuser_check("index", INDEX_NAME)

    def _get_k(self):
        # Cap the requested number of documents at MAX_K.
        value = self._get_payload_value_with_superuser_check("k", K_VALUE)
        return min(value, MAX_K)

    def _get_openai_api_version(self):
        return self._get_payload_value_with_superuser_check("openai_api_version", VERSION)

    def _get_prompt_text(self):
        return self._get_payload_value_with_superuser_check("prompt", prompt_template())

    def _get_temperature(self):
        return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)

    def _get_text_key(self):
        return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

    def _get_attributes(self):
        # Exclude the text key and "source": both are passed to the vector
        # store separately (see _setup_vector_store / _get_document_prompt).
        attributes = [
            item
            for item in self._get_request_attributes()
            if item not in [self._get_text_key(), "source"]
        ]
        return attributes

    def _get_request_attributes(self):
        """Attribute names for the query: payload override (superuser only),
        otherwise the property names from the Weaviate index schema."""
        if os.getenv("SKIP_WEAVIATE_SETUP"):
            # Test mode: avoid the network call to Weaviate entirely.
            return []

        attributes = self._get_payload_value_with_superuser_check("attributes", [])
        if attributes:
            return attributes
        else:
            client = weaviate_client()
            schema = client.schema.get(self._get_index_name())
            names = [prop["name"] for prop in schema.get("properties")]
            return names

    def _get_document_prompt(self):
        return PromptTemplate(
            template=document_template(self.attributes),
            input_variables=["page_content", "source"] + self.attributes,
        )

    def debug_message(self):
        """Build the debug payload sent over the websocket when debug_mode is on."""
        return {
            "type": "debug",
            "message": {
                "attributes": self.attributes,
                "azure_endpoint": self.azure_endpoint,
                "deployment_name": self.deployment_name,
                "index": self.index_name,
                "k": self.k,
                "openai_api_version": self.openai_api_version,
                "prompt": self.prompt_text,
                "question": self.question,
                "ref": self.ref,
                "temperature": self.temperature,
                "text_key": self.text_key,
            },
        }

    def setup_websocket(self, socket=None):
        """Attach a websocket; builds one from the request context if not given."""
        if socket is None:
            connection_id = self.request_context.get("connectionId")
            endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}'
            self.socket = Websocket(endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref)
        else:
            self.socket = socket
        return self.socket

    def setup_llm_request(self):
        """Initialize the vector store, chat client, and QA chain, in order."""
        self._setup_vector_store()
        self._setup_chat_client()
        self._setup_chain()

    def _setup_vector_store(self):
        self.weaviate = weaviate_vector_store(
            index_name=self.index_name,
            text_key=self.text_key,
            attributes=self.attributes + ["source"],
        )

    def _setup_chat_client(self):
        # Streaming responses are relayed to the websocket via the callback handler.
        self.client = openai_chat_client(
            deployment_name=self.deployment_name,
            openai_api_base=self.azure_endpoint,
            openai_api_version=self.openai_api_version,
            callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
            streaming=True,
        )

    def _setup_chain(self):
        self.chain = load_qa_with_sources_chain(
            self.client,
            chain_type=CHAIN_TYPE,
            prompt=self.prompt,
            document_prompt=self.document_prompt,
            document_variable_name=DOCUMENT_VARIABLE_NAME,
            verbose=self._to_bool(os.getenv("VERBOSE")),
        )

    def _is_debug_mode_enabled(self):
        # Debug mode is opt-in via payload AND restricted to superusers.
        debug = self.payload.get("debug", False)
        return debug and self.api_token.is_superuser()

    def _to_bool(self, val):
        """Converts a value to boolean. If the value is a string, it considers
        "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
        """
        if isinstance(val, str):
            return val.lower() not in ["", "no", "false", "0"]
        return bool(val)
Empty file added chat/src/handlers/__init__.py
Empty file.
Loading

0 comments on commit 1d7908a

Please sign in to comment.