Feat/evaluations (GoogleCloudPlatform#217)
* chore: cleanup reqs

* feat: add map inputs for faster processing

* feat: add playbook processing support; add support for direct intent trigger

* fix: cleanup creds code and remove outdated code

* feat: new eval tool

* feat: add script for making code context file

* feat: add retry decorators

* fix: expand testing folder

* feat: add evaluations support

* fix: linting

* chore: update reqs

* fix: dataframe setup and reporting

* chore: bump version to 1.12.0
kmaphoenix authored Aug 20, 2024
1 parent 2b5f675 commit c838940
Showing 14 changed files with 2,791 additions and 54 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
*.py[cod]
*.sw[op]

# specific files
code-context.txt

# cloud shell
.theia

7 changes: 5 additions & 2 deletions Makefile
@@ -12,9 +12,9 @@ pfreeze:

test:
@if [ -n "$(v)" ]; then \
pytest tests/dfcx_scrapi/core/$(f) -vv; \
pytest tests/dfcx_scrapi/$(f) -vv; \
else \
pytest tests/dfcx_scrapi/core/$(f); \
pytest tests/dfcx_scrapi/$(f); \
fi

lint:
@@ -38,3 +38,6 @@ fix:
build:
python3 -m build
pip uninstall dfcx-scrapi -y

context-file:
find . -name "*.py" -print0 | xargs -0 -I {} sh -c 'echo "=== {} ==="; cat {}' > code-context.txt
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,9 +1,9 @@
google-cloud-aiplatform>=1.39.0
google-cloud-dialogflow-cx>=1.34.0
google-cloud-discoveryengine>=0.11.10
google-auth>=2.27.0
google-oauth
oauth2client
pyparsing==2.4.7
pandas
tabulate
gspread==5.10.0
@@ -14,7 +14,8 @@ pylint==2.8.3
pytest==6.0.2
pytest-cov==2.11.1
pytest-xdist==2.1.0
pyyaml==5.4
pyyaml
rouge-score
torch
transformers
sentencepiece
tqdm
5 changes: 3 additions & 2 deletions setup.py
@@ -23,7 +23,7 @@

setup(
name='dfcx-scrapi',
version='1.11.2',
version='1.12.0',
description='A high level scripting API for bot builders, developers, and\
maintainers.',
long_description=long_description,
@@ -45,5 +45,6 @@
package_dir={'':'src'},
packages=find_packages(where='src'),
python_requires='>=3.6, <4',
install_requires=['google-cloud-dialogflow-cx']

install_requires=['google-cloud-dialogflow-cx', 'rouge-score']
)
3 changes: 3 additions & 0 deletions src/dfcx_scrapi/core/playbooks.py
@@ -41,6 +41,7 @@ def __init__(
creds_dict: Dict = None,
creds=None,
scope=False,
playbooks_map: Dict[str, str] = None
):
super().__init__(
creds_path=creds_path,
@@ -59,6 +60,8 @@ def __init__(
credentials=self.creds, client_options=client_options
)

self.playbooks_map = playbooks_map

@staticmethod
def build_instructions_from_list(
instructions: List[str]) -> List[types.Playbook.Step]:
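For illustration, a minimal usage sketch of the new `playbooks_map` parameter (not part of this diff; the agent ID and map entry are hypothetical, and `agent_id` is assumed to be an existing constructor argument as in other SCRAPI classes):

from dfcx_scrapi.core.playbooks import Playbooks

agent_id = "projects/my-project/locations/global/agents/my-agent"

# Reusing a map built earlier avoids an extra round of lookup API calls.
playbooks_map = {
    f"{agent_id}/playbooks/1234": "Steering Playbook",
}

pb = Playbooks(agent_id=agent_id, playbooks_map=playbooks_map)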
160 changes: 158 additions & 2 deletions src/dfcx_scrapi/core/scrapi_base.py
@@ -17,19 +17,56 @@
import logging
import json
import re
import time
import functools
import threading
import vertexai
from collections import defaultdict
from typing import Dict, Any
from typing import Dict, Any, Iterable

from google.api_core import exceptions
from google.cloud.dialogflowcx_v3beta1 import types
from google.oauth2 import service_account
from google.auth.transport.requests import Request
from google.protobuf import json_format # type: ignore
from google.protobuf import json_format
from google.protobuf import field_mask_pb2, struct_pb2

from vertexai.generative_models import GenerativeModel
from vertexai.language_models import TextEmbeddingModel, TextGenerationModel

from proto.marshal.collections import repeated
from proto.marshal.collections import maps

_INTERVAL_SENTINEL = object()

# The following models are supported for Metrics and Evaluations, either for
# Text Embeddings or for providing Generations / Predictions.
SYS_INSTRUCT_MODELS = [
"gemini-1.0-pro-002",
"gemini-1.5-pro-001",
"gemini-1.5-flash-001"
]

NON_SYS_INSTRUCT_GEM_MODELS = [
"gemini-1.0-pro-001"
]

ALL_GEMINI_MODELS = SYS_INSTRUCT_MODELS + NON_SYS_INSTRUCT_GEM_MODELS

TEXT_GENERATION_MODELS = [
"text-bison@002",
"text-unicorn@001"
]
EMBEDDING_MODELS_NO_DIMENSIONALITY = [
"textembedding-gecko@001",
"textembedding-gecko@003",
"textembedding-gecko-multilingual@001"
]
ALL_EMBEDDING_MODELS = EMBEDDING_MODELS_NO_DIMENSIONALITY + [
"text-embedding-004"
]

ALL_GENERATIVE_MODELS = ALL_GEMINI_MODELS + TEXT_GENERATION_MODELS

class ScrapiBase:
"""Core Class for managing Auth and other shared functions."""
@@ -354,6 +391,56 @@ def _get_solution_type(solution_type: str) -> int:

return solution_map[solution_type]

@staticmethod
def is_valid_sys_instruct_model(llm_model: str) -> bool:
"""Validate whether the model allows system instructions."""
valid_sys_instruct = True
if llm_model in NON_SYS_INSTRUCT_GEM_MODELS:
valid_sys_instruct = False

return valid_sys_instruct

def build_generative_model(
self,
llm_model: str,
system_instructions: str = None
) -> GenerativeModel:
"""Build the GenertiveModel object and sys instructions as required."""
valid_sys_intruct = self.is_valid_sys_instruct_model(llm_model)

if valid_sys_intruct and system_instructions:
return GenerativeModel(
llm_model, system_instruction=system_instructions)

elif not valid_sys_instruct and system_instructions:
raise ValueError(
f"Model `{llm_model}` does not support System Instructions"
)
else:
return GenerativeModel(llm_model)
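
For illustration, a sketch of the expected behavior (not part of this diff; assumes `vertexai.init` has already been called, e.g. via `init_vertex` below):

base = ScrapiBase()

# gemini-1.5-flash-001 is in SYS_INSTRUCT_MODELS, so instructions are attached.
model = base.build_generative_model(
    "gemini-1.5-flash-001", system_instructions="Answer in one sentence.")

# gemini-1.0-pro-001 is in NON_SYS_INSTRUCT_GEM_MODELS, so this raises ValueError:
# base.build_generative_model("gemini-1.0-pro-001", system_instructions="...")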

def model_setup(self, llm_model: str, system_instructions: str = None):
"""Create a new LLM instance from user inputs."""
if llm_model in ALL_EMBEDDING_MODELS:
return TextEmbeddingModel.from_pretrained(llm_model)

elif llm_model in ALL_GEMINI_MODELS:
return self.build_generative_model(llm_model, system_instructions)

elif llm_model in TEXT_GENERATION_MODELS:
return TextGenerationModel.from_pretrained(llm_model)

else:
raise ValueError(f"LLM Model `{llm_model}` not supported.")

def init_vertex(self, agent_id: str):
"""Use the Agent ID to parse out relevant fields and init Vertex API."""
parts = self._parse_resource_path("agent", agent_id)
project_id = parts.get("project")
location = parts.get("location")

vertexai.init(project=project_id, location=location)
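
For illustration, a sketch combining the two helpers above, reusing `base` from the previous sketch (not part of this diff; the agent ID is hypothetical):

agent_id = "projects/my-project/locations/us-central1/agents/my-agent"
base.init_vertex(agent_id)  # vertexai.init(project="my-project", location="us-central1")

embedder = base.model_setup("text-embedding-004")  # TextEmbeddingModel
generator = base.model_setup("text-unicorn@001")  # TextGenerationModel
gemini = base.model_setup("gemini-1.5-pro-001", "Be brief.")  # GenerativeModel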

def _build_data_store_parent(self, location: str) -> str:
"""Build the Parent ID needed for Discovery Engine API calls."""
return (f"projects/{self.project_id}/locations/{location}/collections/"
@@ -473,3 +560,72 @@ def wrapper(self, *args, **kwargs):
wrapper.calls_api = True

return wrapper

def should_retry(err: exceptions.GoogleAPICallError) -> bool:
"""Helper function for deciding whether we should retry the error or not."""
return isinstance(err, (exceptions.TooManyRequests, exceptions.ServerError))

def ratelimit(rate: float):
"""Decorator that controls the frequency of function calls."""
seconds_per_event = 1.0 / rate
lock = threading.Lock()
bucket = 0
last = 0

def decorate(func):
def rate_limited_function(*args, **kwargs):
nonlocal last, bucket
while True:
with lock:
now = time.time()
bucket += now - last
last = now

# cap the bucket so idle time cannot bank up capacity for a large burst
bucket = min(bucket, seconds_per_event)

# if bucket is less than `seconds_per_event` then we have to wait
# `seconds_per_event` - `bucket` seconds until a new "token" is
# refilled
delay = max(seconds_per_event - bucket, 0)

if delay == 0:
# consuming a token and breaking out of the delay loop to perform
# the function call
bucket -= seconds_per_event
break
time.sleep(delay)
return func(*args, **kwargs)
return rate_limited_function
return decorate
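
`ratelimit` implements a small token bucket: each call consumes `seconds_per_event` worth of bucket, and the bucket refills with wall-clock time, capped at one token. For illustration (not part of this diff):

@ratelimit(2.0)  # allow roughly two calls per second
def ping(i):
    print(f"call {i} at {time.time():.2f}")

for i in range(5):
    ping(i)  # after the first call, successive calls are spaced ~0.5s apart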

def retry_api_call(retry_intervals: Iterable[float]):
"""Decorator for retrying certain GoogleAPICallError exception types."""
def decorate(func):
def retried_api_call_func(*args, **kwargs):
interval_iterator = iter(retry_intervals)
while True:
try:
return func(*args, **kwargs)
except exceptions.GoogleAPICallError as err:
if not should_retry(err):
raise

interval = next(interval_iterator, _INTERVAL_SENTINEL)
if interval is _INTERVAL_SENTINEL:
raise

print(f"retrying api call: {err}")
time.sleep(interval)
return retried_api_call_func
return decorate
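
`retry_api_call` retries only errors accepted by `should_retry` (429s and 5xx server errors), sleeping through the supplied intervals and re-raising once they are exhausted. For illustration (not part of this diff; the client call is hypothetical):

@retry_api_call([1.0, 2.0, 4.0])  # up to three retries: sleep 1s, 2s, then 4s
def list_playbooks(client, parent):
    return client.list_playbooks(parent=parent)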

def handle_api_error(func):
"""Decorator that chatches GoogleAPICallError exception and returns None."""
def handled_api_error_func(*args, **kwargs):
try:
return func(*args, **kwargs)
except exceptions.GoogleAPICallError as err:
print(f"failed api call: {err}")
return None

return handled_api_error_func
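
`handle_api_error` converts a failed call into a `None` return, so callers can branch on the result instead of wrapping every call site in try/except. For illustration (not part of this diff; names are hypothetical):

@handle_api_error
def get_playbook(client, name):
    return client.get_playbook(name=name)

playbook = get_playbook(client, "projects/.../playbooks/1234")
if playbook is None:
    ...  # the failure was already printed; handle it explicitly here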
