-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #173 from nulib/4313-deployment-override
Refactors chat handler for easier configuration handling.
- Loading branch information
Showing
24 changed files
with
855 additions
and
300 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ lerna-debug.log* | |
|
||
### Python ### | ||
.coverage | ||
htmlcov | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,62 @@ | ||
# Task runner for the SAM project. Node targets run inside node/, python
# targets run inside chat/.
ifndef VERBOSE
.SILENT:
endif
ENV=dev
# The link recipes use bash-only syntax ([[ ... ]]); make's default /bin/sh is
# not bash on every platform.
SHELL := /bin/bash

# None of these targets produce a file of the same name, so a stray file or
# directory called e.g. "test" or "build" must not make them look up to date.
.PHONY: help build link serve clean deps deps-node deps-python \
	style style-node style-python test test-node test-python \
	cover cover-node cover-python env secrets

help:
	echo "make build | build the SAM project"
	echo "make serve | run the SAM server locally"
	echo "make clean | remove all installed dependencies and build artifacts"
	echo "make deps | install all dependencies"
	echo "make link | create hard links to allow for hot reloading of a built project"
	echo "make secrets | symlink secrets files from ../tfvars"
	echo "make style | run all style checks"
	echo "make test | run all tests"
	echo "make cover | run all tests with coverage"
	echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)"
	echo "make deps-node | install node dependencies"
	echo "make deps-python | install python dependencies"
	echo "make style-node | run node code style check"
	echo "make style-python | run python code style check"
	echo "make test-node | run node tests"
	echo "make test-python | run python tests"
	echo "make cover-node | run node tests with coverage"
	echo "make cover-python | run python tests with coverage"

# Rebuild only when the template or a dependency manifest changes. The python
# manifests live under chat/ (the directory every python target below uses),
# not under python/.
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt
	sam build --cached --parallel

deps-node:
	cd node && npm ci
cover-node:
	cd node && npm run test:coverage
style-node:
	cd node && npm run prettier
test-node:
	cd node && npm run test

deps-python:
	cd chat/src && pip install -r requirements.txt
cover-python:
	cd chat/src && coverage run --include='src/**/*' -m unittest -v && coverage report
style-python:
	cd chat && ruff check .
test-python:
	cd chat && python -m unittest -v

build: .aws-sam/build.toml

# Hard-link each source file over its copy in every build directory so edits
# show up in the built project without rebuilding.
link: build
	cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
	cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done

serve: link
	sam local start-api --host 0.0.0.0 --log-file dc-api.log

deps: deps-node deps-python
style: style-node style-python
test: test-node test-python
cover: cover-node cover-python

env:
	ln -fs ./env.${ENV}.json ./env.json
secrets:
	ln -s ../tfvars/dc-api/* .

# Python artifacts are produced under chat/ (cover-python runs in chat/src,
# ruff runs in chat), so clean there instead of the stale python/ tree.
clean:
	rm -rf .aws-sam node/node_modules node/src/node_modules chat/.coverage chat/src/.coverage chat/.ruff_cache
	find chat -type d -name __pycache__ -prune -exec rm -rf {} +
# Task runner for the SAM project. Node targets run inside node/, python
# targets run inside chat/.
ifndef VERBOSE
.SILENT:
endif
ENV=dev
# The link recipes use bash-only syntax ([[ ... ]]); make's default /bin/sh is
# not bash on every platform.
SHELL := /bin/bash

# None of these targets produce a file of the same name, so a stray file or
# directory called e.g. "test" or "build" must not make them look up to date.
.PHONY: help build link serve clean deps deps-node deps-python \
	style style-node style-python test test-node test-python \
	cover cover-node cover-python cover-html-python python-version \
	env secrets

help:
	echo "make build | build the SAM project"
	echo "make serve | run the SAM server locally"
	echo "make clean | remove all installed dependencies and build artifacts"
	echo "make deps | install all dependencies"
	echo "make link | create hard links to allow for hot reloading of a built project"
	echo "make secrets | symlink secrets files from ../tfvars"
	echo "make style | run all style checks"
	echo "make test | run all tests"
	echo "make cover | run all tests with coverage"
	echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)"
	echo "make deps-node | install node dependencies"
	echo "make deps-python | install python dependencies"
	echo "make style-node | run node code style check"
	echo "make style-python | run python code style check"
	echo "make test-node | run node tests"
	echo "make test-python | run python tests"
	echo "make cover-node | run node tests with coverage"
	echo "make cover-python | run python tests with coverage"
	echo "make cover-html-python | run python tests with coverage and write an HTML report"
	echo "make python-version | print the python version used for the chat code"

# Rebuild only when the template or a dependency manifest changes.
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt
	sam build --cached --parallel

deps-node:
	cd node && npm ci
cover-node:
	cd node && npm run test:coverage
style-node:
	cd node && npm run prettier
test-node:
	cd node && npm run test

deps-python:
	cd chat/src && pip install -r requirements.txt
cover-python: deps-python
	cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
cover-html-python: deps-python
	cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
style-python: deps-python
	cd chat && ruff check .
# PYTHONPATH must be exported: a bare "PYTHONPATH=... && python" only sets a
# shell variable that the python child process never sees.
test-python: deps-python
	cd chat && export SKIP_WEAVIATE_SETUP=True && export PYTHONPATH=src:test && python -m unittest discover -v
python-version:
	cd chat && python --version

build: .aws-sam/build.toml

# Hard-link each source file over its copy in every build directory so edits
# show up in the built project without rebuilding.
link: build
	cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
	cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done

serve: link
	sam local start-api --host 0.0.0.0 --log-file dc-api.log

deps: deps-node deps-python
style: style-node style-python
test: test-node test-python
cover: cover-node cover-python

env:
	ln -fs ./env.${ENV}.json ./env.json
secrets:
	ln -s ../tfvars/dc-api/* .

# Python artifacts are produced under chat/ (coverage data and the htmlcov
# report from the cover targets, ruff's cache from style-python), so clean
# there instead of the stale python/ tree.
clean:
	rm -rf .aws-sam node/node_modules node/src/node_modules chat/.coverage chat/htmlcov chat/.ruff_cache
	find chat -type d -name __pycache__ -prune -exec rm -rf {} +
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
import os | ||
import json | ||
|
||
from dataclasses import dataclass, field | ||
from langchain.chains.qa_with_sources import load_qa_with_sources_chain | ||
from langchain.prompts import PromptTemplate | ||
from setup import ( | ||
weaviate_client, | ||
weaviate_vector_store, | ||
openai_chat_client, | ||
) | ||
from typing import List | ||
from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler | ||
from helpers.apitoken import ApiToken | ||
from helpers.prompts import document_template, prompt_template | ||
from websocket import Websocket | ||
|
||
|
||
CHAIN_TYPE = "stuff" | ||
DOCUMENT_VARIABLE_NAME = "context" | ||
INDEX_NAME = "DCWork" | ||
K_VALUE = 10 | ||
MAX_K = 100 | ||
TEMPERATURE = 0.2 | ||
TEXT_KEY = "title" | ||
VERSION = "2023-07-01-preview" | ||
|
||
|
||
@dataclass
class EventConfig:
    """
    The EventConfig class represents the configuration for an event.

    Every setting is resolved once, in ``__post_init__``, from the event's JSON
    ``body`` (stored as ``payload``). Settings that go through
    ``_get_payload_value_with_superuser_check`` take their value from the
    payload only when the API token is a superuser token; otherwise the default
    (a module constant, an environment variable, or the prompt helper) is used.
    """

    # Only ``event`` (and, nominally, ``payload``) are constructor arguments;
    # ``payload`` is always overwritten from ``event["body"]`` in __post_init__
    # and every other field is derived there.
    api_token: ApiToken = field(init=False)
    attributes: List[str] = field(init=False)
    azure_endpoint: str = field(init=False)
    azure_resource_name: str = field(init=False)
    debug_mode: bool = field(init=False)
    deployment_name: str = field(init=False)
    document_prompt: PromptTemplate = field(init=False)
    event: dict = field(default_factory=dict)
    index_name: str = field(init=False)
    is_logged_in: bool = field(init=False)
    k: int = field(init=False)
    openai_api_version: str = field(init=False)
    payload: dict = field(default_factory=dict)
    prompt_text: str = field(init=False)
    prompt: PromptTemplate = field(init=False)
    question: str = field(init=False)
    ref: str = field(init=False)
    request_context: dict = field(init=False)
    temperature: float = field(init=False)
    # Stays None until setup_websocket() is called.
    socket: Websocket = field(init=False, default=None)
    text_key: str = field(init=False)

    def __post_init__(self):
        """Parse the event body and resolve every derived setting.

        Raises:
            EnvironmentError: if no Azure resource name can be resolved
                (see ``_get_azure_resource_name``).
        """
        # ``body`` may be missing or explicitly null; treat both as an empty
        # JSON object instead of letting json.loads fail on None.
        self.payload = json.loads(self.event.get("body") or "{}")
        # The token must exist before any superuser-gated lookup below.
        self.api_token = ApiToken(signed_token=self.payload.get("auth"))
        self.azure_resource_name = self._get_azure_resource_name()
        self.azure_endpoint = self._get_azure_endpoint()
        self.debug_mode = self._is_debug_mode_enabled()
        self.deployment_name = self._get_deployment_name()
        self.index_name = self._get_index_name()
        self.is_logged_in = self.api_token.is_logged_in()
        self.k = self._get_k()
        self.openai_api_version = self._get_openai_api_version()
        self.prompt_text = self._get_prompt_text()
        self.request_context = self.event.get("requestContext", {})
        self.question = self.payload.get("question")
        self.ref = self.payload.get("ref")
        self.temperature = self._get_temperature()
        self.text_key = self._get_text_key()
        # Resolved exactly once: when the payload supplies no attributes this
        # asks Weaviate for the index schema.
        self.attributes = self._get_attributes()
        # Depends on self.attributes, so it must come after the line above.
        self.document_prompt = self._get_document_prompt()
        self.prompt = PromptTemplate(
            template=self.prompt_text, input_variables=["question", "context"]
        )

    def _get_payload_value_with_superuser_check(self, key, default):
        """Return ``payload[key]`` for superusers (``default`` if the key is
        absent); return ``default`` unconditionally for everyone else."""
        if self.api_token.is_superuser():
            return self.payload.get(key, default)
        return default

    def _get_azure_endpoint(self):
        """Return the Azure OpenAI endpoint URL, derived from the resource name
        unless a superuser payload overrides it."""
        default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
        return self._get_payload_value_with_superuser_check("azure_endpoint", default)

    def _get_azure_resource_name(self):
        """Return the Azure resource name from the (superuser) payload or the
        ``AZURE_OPENAI_RESOURCE_NAME`` environment variable.

        Raises:
            EnvironmentError: if neither source yields a non-empty value.
        """
        azure_resource_name = self._get_payload_value_with_superuser_check(
            "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
        )
        if not azure_resource_name:
            raise EnvironmentError(
                "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
            )
        return azure_resource_name

    def _get_deployment_name(self):
        """Return the LLM deployment name (default: ``AZURE_OPENAI_LLM_DEPLOYMENT_ID``)."""
        return self._get_payload_value_with_superuser_check(
            "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
        )

    def _get_index_name(self):
        """Return the index name (payload key ``index``, default ``INDEX_NAME``)."""
        return self._get_payload_value_with_superuser_check("index", INDEX_NAME)

    def _get_k(self):
        """Return the ``k`` setting, capped at ``MAX_K``."""
        value = self._get_payload_value_with_superuser_check("k", K_VALUE)
        return min(value, MAX_K)

    def _get_openai_api_version(self):
        """Return the OpenAI API version (default ``VERSION``)."""
        return self._get_payload_value_with_superuser_check("openai_api_version", VERSION)

    def _get_prompt_text(self):
        """Return the prompt template text (payload key ``prompt``, default
        ``prompt_template()``)."""
        return self._get_payload_value_with_superuser_check("prompt", prompt_template())

    def _get_temperature(self):
        """Return the temperature setting (default ``TEMPERATURE``)."""
        return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)

    def _get_text_key(self):
        """Return the text key (default ``TEXT_KEY``)."""
        return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

    def _get_attributes(self):
        """Return the request attributes minus the text key and ``source``."""
        # Resolve the exclusions once rather than once per candidate item.
        excluded = (self._get_text_key(), "source")
        return [item for item in self._get_request_attributes() if item not in excluded]

    def _get_request_attributes(self):
        """Return the candidate attribute names.

        Returns an empty list when ``SKIP_WEAVIATE_SETUP`` is set. Otherwise
        uses the (superuser) payload's ``attributes`` if non-empty, and falls
        back to the property names of the index schema fetched from Weaviate.
        """
        if os.getenv("SKIP_WEAVIATE_SETUP"):
            return []

        attributes = self._get_payload_value_with_superuser_check("attributes", [])
        if attributes:
            return attributes

        client = weaviate_client()
        schema = client.schema.get(self._get_index_name())
        return [prop["name"] for prop in schema.get("properties")]

    def _get_document_prompt(self):
        """Build the per-document prompt from ``self.attributes``."""
        return PromptTemplate(
            template=document_template(self.attributes),
            input_variables=["page_content", "source"] + self.attributes,
        )

    def debug_message(self):
        """Return a ``debug``-typed message dict echoing the resolved settings."""
        return {
            "type": "debug",
            "message": {
                "attributes": self.attributes,
                "azure_endpoint": self.azure_endpoint,
                "deployment_name": self.deployment_name,
                "index": self.index_name,
                "k": self.k,
                "openai_api_version": self.openai_api_version,
                "prompt": self.prompt_text,
                "question": self.question,
                "ref": self.ref,
                "temperature": self.temperature,
                "text_key": self.text_key,
            },
        }

    def setup_websocket(self, socket=None):
        """Store and return the socket to use for this event.

        When ``socket`` is None a ``Websocket`` is built from the request
        context (``connectionId``, ``domainName``, ``stage``); otherwise the
        given object is used as-is.
        """
        if socket is None:
            connection_id = self.request_context.get("connectionId")
            endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}'
            self.socket = Websocket(
                endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref
            )
        else:
            self.socket = socket
        return self.socket

    def setup_llm_request(self):
        """Create the vector store, chat client and QA chain, in that order.

        The chat client's callback handler is given ``self.socket``, which is
        None unless ``setup_websocket`` has been called first.
        """
        self._setup_vector_store()
        self._setup_chat_client()
        self._setup_chain()

    def _setup_vector_store(self):
        """Create ``self.weaviate`` for the configured index."""
        self.weaviate = weaviate_vector_store(
            index_name=self.index_name,
            text_key=self.text_key,
            # "source" was filtered out of self.attributes; add it back here.
            attributes=self.attributes + ["source"],
        )

    def _setup_chat_client(self):
        """Create ``self.client``, a streaming chat client reporting to the socket."""
        # NOTE(review): self.temperature is resolved and reported in
        # debug_message() but is not forwarded here — confirm whether
        # openai_chat_client should receive it.
        self.client = openai_chat_client(
            deployment_name=self.deployment_name,
            openai_api_base=self.azure_endpoint,
            openai_api_version=self.openai_api_version,
            callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
            streaming=True,
        )

    def _setup_chain(self):
        """Create ``self.chain`` from the chat client and the two prompts."""
        self.chain = load_qa_with_sources_chain(
            self.client,
            chain_type=CHAIN_TYPE,
            prompt=self.prompt,
            document_prompt=self.document_prompt,
            document_variable_name=DOCUMENT_VARIABLE_NAME,
            verbose=self._to_bool(os.getenv("VERBOSE")),
        )

    def _is_debug_mode_enabled(self):
        """Return True only when the payload asks for debug AND the token is a
        superuser token."""
        debug = self.payload.get("debug", False)
        # bool() so a falsy payload value (None, 0, "") yields False rather
        # than leaking the raw value into the bool-typed ``debug_mode`` field.
        return bool(debug and self.api_token.is_superuser())

    def _to_bool(self, val):
        """Converts a value to boolean. If the value is a string, it considers
        "", "no", "false", "0" (case-insensitively) as False. Otherwise, it
        returns the boolean of the value.
        """
        if isinstance(val, str):
            return val.lower() not in ("", "no", "false", "0")
        return bool(val)
Empty file.
Oops, something went wrong.