diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 4a017a1b..e564f56b 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -1,4 +1,4 @@ -name: Pylint +name: Ruff on: push: @@ -21,7 +21,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pylint==3.1.0 - - name: Running Pylint + pip install ruff + - name: Running Ruff Check run: | - pylint --rcfile=.pylintrc $(git ls-files '*.py') + ruff check $(git ls-files '*.py') diff --git a/.github/workflows/tester.yml b/.github/workflows/tester.yml new file mode 100644 index 00000000..06bf8af7 --- /dev/null +++ b/.github/workflows/tester.yml @@ -0,0 +1,29 @@ +name: pytest + +on: + push: + branches: + - '**' + pull_request: + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest pandas tqdm + pip install -e . + pip install -r requirements.txt + - name: Running Tests + run: | + pytest tests/dfcx_scrapi/ -vv diff --git a/.gitignore b/.gitignore index cf1d464f..33875a32 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ *.py[cod] *.sw[op] +# specific files +code-context.txt + # cloud shell .theia diff --git a/Makefile b/Makefile index f6a24f16..c5399954 100644 --- a/Makefile +++ b/Makefile @@ -11,11 +11,14 @@ pfreeze: pip freeze > requirements.txt test: - pytest tests/dfcx_scrapi/core/$(f) + @if [ -n "$(v)" ]; then \ + pytest tests/dfcx_scrapi/$(f) -vv; \ + else \ + pytest tests/dfcx_scrapi/$(f); \ + fi lint: - pylint --rcfile=.pylintrc src/dfcx_scrapi/* - pylint --rcfile=.pylintrc tests/dfcx_scrapi/* + ruff check # just fix selected whitespace autofix-min-whitespace: @@ -35,3 +38,6 @@ fix: build: python3 -m build pip uninstall dfcx-scrapi -y + +context-file: + find . -name "*.py" -print0 | xargs -0 -I {} sh -c 'echo "=== {} ==="; cat {}' > code-context.txt diff --git a/examples/dfcx_agent_cicd/cicd_code/UAT/__init__.py b/examples/dfcx_agent_cicd/cicd_code/UAT/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/dfcx_agent_cicd/cicd_code/UAT/cloudbuild_deploy.yaml b/examples/dfcx_agent_cicd/cicd_code/UAT/cloudbuild_deploy.yaml new file mode 100644 index 00000000..ad5621dd --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/UAT/cloudbuild_deploy.yaml @@ -0,0 +1,49 @@ +steps: + + - id: SHAGCSCopy + name: gcr.io/google.com/cloudsdktool/cloud-sdk + #dir: 'set your path till the readme doc in the git' + entrypoint: /bin/bash + args: + - '-c' + - | + chmod 777 UAT/gcssha.sh + UAT/gcssha.sh $COMMIT_SHA + + - id: deployagent + name: 'python:3.10' + #dir: 'set your path till the readme doc in the git' + entrypoint: /bin/bash + args: + - -c + - | + pip3 install -r UAT/requirements.txt + python3 -m UAT.deploy $COMMIT_SHA + echo $? + + - id: CheckExitCode + name: 'gcr.io/cloud-builders/gcloud' + #dir: 'set your path till the readme doc in the git' + entrypoint: 'bash' + args: + - '-c' + - | + if [[ "$$BUILD_STATUS" -ne 0 ]]; then + echo "Stopping the build due to a previous failure." 
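              # Note: the deployagent step above echoes its exit status; this check is meant
              # to stop the pipeline before the prod deploy trigger runs if that step failed.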
+ exit 1 + fi + + + - id: triggerproddeploy + name: gcr.io/google.com/cloudsdktool/cloud-sdk + #dir: 'set your path till the readme doc in the git' + entrypoint: /bin/bash + args: + - '-c' + - | + chmod 777 UAT/trigger.sh + UAT/trigger.sh $LOCATION $COMMIT_SHA + + +options: + logging: CLOUD_LOGGING_ONLY \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/UAT/deploy.py b/examples/dfcx_agent_cicd/cicd_code/UAT/deploy.py new file mode 100644 index 00000000..56fffe96 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/UAT/deploy.py @@ -0,0 +1,71 @@ +""" UAT Deployment functions""" + +import sys +import json +import logging + +from shared.deployment import Deployment + + +#from .shared.deployments import Deployment +# logging config +logging.basicConfig( + level=logging.INFO, + format="UAT: %(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + + +def main(data): + """ + Deploys and tests a Dialogflow CX agent in a UAT environment. + + This function performs the following steps: + + 1. Initializes a Deployment object with the provided data. + 2. Imports the agent to the specified UAT webhook environment. + 3. Validates test cases. + 4. Collects flow IDs. + 5. Deletes versions based on count. + 6. Cuts a new version. + 7. Deploys the new version. + 8. Updates the datastore with UAT information. + + Args: + data: A dictionary containing configuration data, including the 'uat_webhook_env' key. + """ + + dep=Deployment(data) + # call the steps sequentially + dep.import_agent(webhookenv=data["uat_webhook_env"]) + dep.test_case_validation() + dep.collect_flow_id() + dep.version_count_delete() + dep.version_cut() + dep.deploy_versions() + dep.datastore_update("uat") + + + +if __name__=="__main__": + # read env variables + with open("config.json" , encoding='utf-8') as config_file: + config = json.load(config_file) + SHA_ID=sys.argv[1] + obj=f"UAT/{config['agent_name']}/{SHA_ID}" + sha_gs_loc=( + f"gs://{config['bucket']}/UAT/{config['agent_name']}/{SHA_ID}" + ) + logging.info("Agent location: %s" ,sha_gs_loc) + #adding additional variables to dict + config["sha_agent_gcs_location"]=sha_gs_loc + config["target_project_id"] = config["uat_project"] + config["target_environment_name"]=config["uat_env_deploy"] + with open("agent_artifacts/metadata.json" , encoding='utf-8') as metadata_file: + metadata = json.load(metadata_file) + + config["source_flow_names"]=metadata["source_flow_names"] + config["updated_commit_message"]=metadata["updated_commit_message"] + + # To execute steps in order + main(config) diff --git a/examples/dfcx_agent_cicd/cicd_code/UAT/gcssha.sh b/examples/dfcx_agent_cicd/cicd_code/UAT/gcssha.sh new file mode 100644 index 00000000..88ee0712 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/UAT/gcssha.sh @@ -0,0 +1,21 @@ + +# Set your GCS bucket name and destination directory +apt-get update && apt-get install -y jq +export GCS_BUCKET=$(jq -r .bucket config.json) +export agent_name=$(jq -r .agent_name config.json) +export DESTINATION_DIR="UAT/${agent_name}/" +echo $DESTINATION_DIR +# Create a local directory +mkdir -p $1 + +# Copy your two files to the local directory +cp agent_artifacts/$agent_name $1 +cp agent_artifacts/metadata.json $1 + +# Upload the local directory to GCS +gsutil -m cp -r $1 "gs://$GCS_BUCKET/$DESTINATION_DIR" + +# Clean up the local directory if needed +rm -r $1 + +echo "Files copied and uploaded to GCS." 
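# Note: the two files staged above are the raw agent export blob
# (agent_artifacts/$agent_name, produced by the export build) and
# agent_artifacts/metadata.json, which is config.json plus the
# source_flow_names, impacted_version_ids and updated_commit_message keys
# added by export/export.py. The copy uploaded to
# gs://$GCS_BUCKET/UAT/<agent_name>/<commit sha> is the location that
# UAT/deploy.py and prod/deploy.py later restore the agent from.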
\ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/UAT/requirements.txt b/examples/dfcx_agent_cicd/cicd_code/UAT/requirements.txt new file mode 100644 index 00000000..62207d27 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/UAT/requirements.txt @@ -0,0 +1,3 @@ +dfcx-scrapi +google-cloud-storage +pandas \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/UAT/trigger.sh b/examples/dfcx_agent_cicd/cicd_code/UAT/trigger.sh new file mode 100644 index 00000000..bf83d555 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/UAT/trigger.sh @@ -0,0 +1,22 @@ +echo $1 +apt-get update && apt-get install -y jq +export devops_project_id=$(jq -r .devops_project config.json) +export prod_project_id=$(jq -r .prod_project config.json) + +#Use below command to trigger the build if manual invokation is used. Since there is no secret , no extra charges + +export build_info=$(gcloud builds triggers run prodbuild --project=$devops_project_id --substitutions=_COMMIT_SHA=$2 --region=$1 --format=json) +echo "devops prod triggerdone" + + +#getting the trigger id of the above trigger + +export prod_build_id=$(echo "$build_info" | jq -r '.metadata.build.id') +echo "build id returned back is" +echo $prod_build_id + + +#Trigger the build in prod project which is used for approval +gcloud builds triggers run prodapprovebuild --project=$devops_project_id --substitutions=_APP_BUILD_ID=$prod_build_id --region=$1 + +echo "prod project approve build triggered" \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/__init__.py b/examples/dfcx_agent_cicd/cicd_code/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/dfcx_agent_cicd/cicd_code/agent_artifacts/.gitkeep b/examples/dfcx_agent_cicd/cicd_code/agent_artifacts/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/examples/dfcx_agent_cicd/cicd_code/approveprod/cloudbuild_appr.yaml b/examples/dfcx_agent_cicd/cicd_code/approveprod/cloudbuild_appr.yaml new file mode 100644 index 00000000..b9bf63dc --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/approveprod/cloudbuild_appr.yaml @@ -0,0 +1,36 @@ +steps: + - name: gcr.io/google.com/cloudsdktool/cloud-sdk + args: + - '-c' + - | + apt-get update && apt-get install -y jq + echo $BUILD_ID + + export devopsprojecthere=$(jq -r .devops_project config.json) + export build_info=$(gcloud builds describe $BUILD_ID --region=us-central1 --format=json) + export approverhere=$(echo "$build_info" | jq -r '.approval.result.approverAccount') + export commenthere=$(echo "$build_info" | jq -r '.approval.result.comment') + export tokenhere=$(gcloud auth print-access-token) + + echo $approverhere + echo $tokenhere + + chmod 777 approveprod/trigger.sh + + sed -i "s/tokenhere/$tokenhere/g" approveprod/trigger.sh + sed -i "s/approverhere/$approverhere/g" approveprod/trigger.sh + sed -i "s/devopsprojecthere/$devopsprojecthere/g" approveprod/trigger.sh + sed -i "s/commenthere/$commenthere/g" approveprod/trigger.sh + sed -i "s/appbuildhere/$_APP_BUILD_ID/g" approveprod/trigger.sh + cat approveprod/trigger.sh + approveprod/trigger.sh + echo $? 
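      # Note: approveprod/trigger.sh is a curl template; the sed commands above fill
      # in the devops project, the approver account and comment recorded on this
      # build's approval, the prod build ID passed in as _APP_BUILD_ID, and a fresh
      # access token, and the script then calls the Cloud Build builds:approve API
      # to approve the waiting prod deploy build.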
+ echo "prod build approved from code" + + echo "error exit code" + + id: triggerexportbuild + entrypoint: /bin/bash +options: + logging: CLOUD_LOGGING_ONLY + dynamicSubstitutions: true diff --git a/examples/dfcx_agent_cicd/cicd_code/approveprod/trigger.sh b/examples/dfcx_agent_cicd/cicd_code/approveprod/trigger.sh new file mode 100644 index 00000000..1b35ea46 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/approveprod/trigger.sh @@ -0,0 +1,6 @@ +curl --request POST \ + 'https://cloudbuild.googleapis.com/v1/projects/devopsprojecthere/locations/us-central1/builds/appbuildhere:approve?access_token=tokenhere' \ + --header 'Accept: application/json'\ + --header 'Content-Type:application/json' --data \ + '{"approvalResult":{"decision":"APPROVED","comment":"commenthere","approverAccount":"approverhere"}}' \ + --compressed \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/config.json b/examples/dfcx_agent_cicd/cicd_code/config.json new file mode 100644 index 00000000..70b2a550 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/config.json @@ -0,0 +1,16 @@ +{ + "agent_name" : "carrental", + "dev_env_pull" : "ready to deploy", + "uat_env_deploy" : "ready to test", + "prod_env_deploy" :"deployed", + "devprodsyncenv" :"deployed", + "bucket": "dfcx_agent_cicd_export", + "dev_project": "yourprojectid", + "uat_project" : "yourprojectid", + "prod_project": "yourprojectid", + "devops_project": "yourprojectid", + "uat_webhook_env": "uat", + "prod_webhook_env": "prod", + "uat_engine_id" :"", + "prod_engine_id" :"" +} \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/export/__init__.py b/examples/dfcx_agent_cicd/cicd_code/export/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/dfcx_agent_cicd/cicd_code/export/cloudbuild_export.yaml b/examples/dfcx_agent_cicd/cicd_code/export/cloudbuild_export.yaml new file mode 100644 index 00000000..90356678 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/export/cloudbuild_export.yaml @@ -0,0 +1,147 @@ + +availableSecrets: + secretManager: + - versionName: # your version name + env: 'SSH_KEY' + +steps: +# Access the id_github file fvikramvikrom Secret Manager, and setup SSH + - id: mountsshkey + name: 'gcr.io/cloud-builders/git' + #dir: 'set your path till the readme doc in the git' + secretEnv: ['SSH_KEY'] + entrypoint: /bin/bash + args: + - -c + - | + echo "$$SSH_KEY" >> /root/.ssh/id_rsa + chmod 400 /root/.ssh/id_rsa + cp known_hosts.github /root/.ssh/known_hosts + volumes: + - name: 'ssh' + path: /root/.ssh + + # Clone the repository + - id: clonerepo + name: 'gcr.io/cloud-builders/git' + #dir: 'set your path till the readme doc in the git' + args: + - clone + - --recurse-submodules + - git@github.com:$REPO_FULL_NAME + volumes: + - name: 'ssh' + path: /root/.ssh + + - id: limitbuildcheck + name: gcr.io/google.com/cloudsdktool/cloud-sdk + entrypoint: /bin/bash + args: + - -c + - | + export parallelbuild=$(gcloud builds list --region=$LOCATION --filter="substitutions.TRIGGER_NAME=$TRIGGER_NAME AND status=WORKING" --format="value(status)" | wc -l) + export approvebuild=$(gcloud builds list --region=$LOCATION --format="value(status)" --filter="substitutions.TRIGGER_NAME='prodbuild' AND status='PENDING'" | wc -l) + if [ $parallelbuild -gt 1 ] + then + echo "parallel build running. This may corrupt the exported files in GCS location" + exit 1 + else + echo "Proceeding. 
No other parallel export build" + fi + if [ $approvebuild -gt 0 ] + then + echo "some other build waiting for approval" + exit 1 + else + echo "Proceeding. No builds waiting for approval" + fi + + - id: fetchuser + #dir: 'set your path till the readme doc in the git' + name: gcr.io/google.com/cloudsdktool/cloud-sdk + entrypoint: /bin/bash + args: + - '-c' + - | + echo $BUILD_ID + export buildhere=$BUILD_ID + export trigid=$(gcloud builds describe $BUILD_ID --region=$LOCATION --format="value(buildTriggerId)") + sed -i "s/triggerhere/$trigid/g" export/trigger.sh + chmod 777 export/trigger.sh + export w1=$(export/trigger.sh) + export w2=$(echo $w1 | cut -d " " -f2) + export runnerid=$w2 + export runnername=$(echo $runnerid | cut -d '@' -f 1) + echo $buildhere + echo $runnerid + echo $runnername + pwd + ls + echo $runnername > ./runnername.txt + echo $runnerid > ./runnerid.txt + echo "path of runner id" + pwd + + + - id: Exportgcs + #dir: 'set your path till the readme doc in the git' + name: 'python:3.10' + entrypoint: /bin/bash + args: + - -c + - | + ls + pwd + pip3 install -r export/requirements.txt + export runnerid=$(cat runnerid.txt) + echo "runner id is " + echo $runnerid + python3 -m export.export ${_USERCOMMITMESSAGE} $runnerid + + - id: downloadartifacts + #dir: 'set your path till the readme doc in the git' + name: gcr.io/google.com/cloudsdktool/cloud-sdk + entrypoint: /bin/bash + args: + - -c + - | + apt-get update && apt-get install -y jq + export agent_name=$(jq -r .agent_name config.json) + export bucket_name=$(jq -r .bucket config.json) + echo $agent_name + echo $bucket_name + mkdir agenttemp + gsutil cp "gs://$bucket_name/exports/dev/$agent_name" agenttemp/$agent_name + gsutil cp "gs://$bucket_name/exports/dev/${agent_name}_metadata.json" agenttemp/metadata.json + + - id: csrcheckin + #dir: 'set your path till the readme doc in the git' + name: gcr.io/google.com/cloudsdktool/cloud-sdk + entrypoint: /bin/bash + args: + - -c + - | + export runnerid=$(cat runnerid.txt) + export runnername=$(cat runnername.txt) + + export agent_artifacts_path = $(dirname $(dirname $TRIGGER_BUILD_CONFIG_PATH)) + chmod 777 export/repopush.sh + export/repopush.sh $REPO_NAME $agent_artifacts_path + cd $REPO_NAME/$agent_artifacts_path + ls + cd agent_artifacts + ls + git config --global user.name $runnername + git config --global user.email $runnerid + git add . + git diff --name-only + git commit --allow-empty -m "commited by $runnerid with message ${_USERCOMMITMESSAGE}" + + git push -u origin main + volumes: + - name: 'ssh' + path: /root/.ssh + +options: + logging: CLOUD_LOGGING_ONLY + dynamicSubstitutions: true diff --git a/examples/dfcx_agent_cicd/cicd_code/export/cloudbuild_export_csr.yaml b/examples/dfcx_agent_cicd/cicd_code/export/cloudbuild_export_csr.yaml new file mode 100644 index 00000000..8aa0973d --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/export/cloudbuild_export_csr.yaml @@ -0,0 +1,92 @@ +steps: + + - id: limitbuildcheck + name: gcr.io/google.com/cloudsdktool/cloud-sdk + entrypoint: /bin/bash + args: + - -c + - | + export parallelbuild=$(gcloud builds list --region=us-central1 --filter="substitutions.TRIGGER_NAME=$TRIGGER_NAME AND status=WORKING" --format="value(status)" | wc -l) + export approvebuild=$(gcloud builds list --region=us-central1 --format="value(status)" --filter="substitutions.TRIGGER_NAME='prodbuild' AND status='PENDING'" | wc -l) + if [ $parallelbuild -gt 1 ] + then + echo "parallel build running. 
This may corrupt the exported files in GCS location" + exit 1 + else + echo "Proceeding. No other parallel export build" + fi + if [ $approvebuild -gt 0 ] + then + echo "some other build waiting for approval" + exit 1 + else + echo "Proceeding. No builds waiting for approval" + fi + + - id: fetchuser + name: gcr.io/google.com/cloudsdktool/cloud-sdk + #dir: your/path/here till the readme dir + entrypoint: /bin/bash + args: + - '-c' + - | + echo $BUILD_ID + export buildhere=$BUILD_ID + export trigid=$(gcloud builds describe $BUILD_ID --region=us-central1 --format="value(buildTriggerId)") + sed -i "s/triggerhere/$trigid/g" export/trigger.sh + chmod 777 export/trigger.sh + export w1=$(export/trigger.sh) + export w2=$(echo $w1 | cut -d " " -f2) + export runnerid=$w2 + export runnername=$(echo $runnerid | cut -d '@' -f 1) + echo $buildhere + echo $runnerid + echo $runnername + pwd + ls + echo $runnername > ./runnername.txt + echo $runnerid > ./runnerid.txt + + + - id: Exportgcs + #dir: your/path/here till the readme dir + name: 'python:3.10' + entrypoint: /bin/bash + args: + - -c + - | + pip3 install -r export/requirements.txt + export runnerid=$(cat runnerid.txt) + python3 -m export.export ${_USERCOMMITMESSAGE} $runnerid + + - id: downloadartifacts + #dir: your/path/here till the readme dir + name: gcr.io/google.com/cloudsdktool/cloud-sdk + entrypoint: /bin/bash + args: + - -c + - | + apt-get update && apt-get install -y jq + export agent_name=$(jq -r .agent_name config.json) + export bucket_name=$(jq -r .bucket config.json) + echo $agent_name + echo $bucket_name + mkdir agenttemp + gsutil cp "gs://$bucket_name/exports/dev/$agent_name" agenttemp/$agent_name + gsutil cp "gs://$bucket_name/exports/dev/${agent_name}_metadata.json" agenttemp/metadata.json + + - id: repocheckin + #dir: your/path/here till the readme dir + name: gcr.io/google.com/cloudsdktool/cloud-sdk + entrypoint: /bin/bash + args: + - -c + - | + export runnerid=$(cat runnerid.txt) + export runnername=$(cat runnername.txt) + chmod 777 export/repopush_csr.sh + export/repopush_csr.sh ${_USERCOMMITMESSAGE} $runnername $runnerid + +options: + logging: CLOUD_LOGGING_ONLY + dynamicSubstitutions: true \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/export/export.py b/examples/dfcx_agent_cicd/cicd_code/export/export.py new file mode 100644 index 00000000..fd2504d9 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/export/export.py @@ -0,0 +1,131 @@ +""" export functions""" + +import json +import sys +import logging + +from dfcx_scrapi.core.agents import Agents + +from .flow_impacted import Impacted +from google.cloud import storage + +# logging config +logging.basicConfig( + level=logging.INFO, + format="dev: %(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +def agent_to_gcs( + agent_name, + project_id, + environment_name, + gsloc): + """Exports a Dialogflow CX agent to Google Cloud Storage (GCS). + + This function exports a specified Dialogflow CX agent and its environment + to a designated location in Google Cloud Storage. + + Args: + agent_name: The display name of the agent to export. + project_id: The ID of the Google Cloud project where the agent resides. + environment_name: The display name of the environment to export. + gsloc: The GCS bucket URI where the agent will be exported. 
+ + Returns: + None + """ + agents=Agents() + + agent_details=agents.get_agent_by_display_name( + display_name=agent_name, + project_id=project_id + ) + + agent_id=agent_details.name + agent_gcs_location=f"{gsloc}/{agent_name}" + + #export will replace exisitng agent in bucket + agents.export_agent(agent_id=agent_id, + gcs_bucket_uri=agent_gcs_location, + environment_display_name=environment_name) + logging.info("Agent export from dev done") + +def meta_to_gcs( + config_data, + flow_names, + commit_message, + gsloc, + agent_name, + gcs_bucket, + version_ids + ): + """Exports metadata to a JSON file in Google Cloud Storage (GCS). + + This function takes configuration data, flow names, a commit message, + GCS location information, agent name, and version IDs, and creates a JSON + file containing this metadata in the specified GCS bucket. + + Args: + config_data: A dictionary containing configuration data. + flow_names: A list of flow names. + commit_message: The commit message to include in the metadata. + gsloc: The full GCS URI where the metadata file will be stored. + agent_name: The name of the agent. + gcs_bucket: The name of the GCS bucket. + version_ids: A list of version IDs. + + Returns: + None + """ + + config_data["source_flow_names"]=flow_names + config_data["impacted_version_ids"]= version_ids + config_data["updated_commit_message"]=commit_message + gcslist=gsloc.split("/") + obj="/".join(gcslist[3:]) + + bucket_obj = storage.Client().get_bucket(gcs_bucket) + + blob = bucket_obj.blob(f"{obj}/{agent_name}_metadata.json") + blob.upload_from_string(data=json.dumps(config_data), + content_type='application/json') + + +if __name__=='__main__': + # read env variables + with open("config.json", encoding='utf-8') as config_file: + config = json.load(config_file) + + source_project_id=config["dev_project"] + source_agent_name=config["agent_name"] + source_environment_name=config["dev_env_pull"] + bucket=config["bucket"] + user_commit_message=sys.argv[1] + userid=sys.argv[2] + #updated_commit_message=f"{user_commit_message} by {userid} for {source_agent_name}" + updated_commit_message = ( + f"{user_commit_message} by {userid} " + f"for {source_agent_name}" + ) + impflows=Impacted(source_project_id=source_project_id, + source_agent_name=source_agent_name, + environment_name=source_environment_name) + imp_flow_map,impacted_version_ids=impflows.check_flow() + source_flow_names=list(imp_flow_map.values()) + source_flow_ids=list(imp_flow_map.keys()) + gs_loc=f"gs://{bucket}/exports/dev" + + logging.info("impacted flow is %(imp_flow_map)s" + , {'imp_flow_map': imp_flow_map}) + + + #Execute in steps + agent_to_gcs(source_agent_name, + source_project_id, + source_environment_name, + gs_loc) + meta_to_gcs(config,source_flow_names, + updated_commit_message,gs_loc, + source_agent_name,bucket,impacted_version_ids) + \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/export/flow_impacted.py b/examples/dfcx_agent_cicd/cicd_code/export/flow_impacted.py new file mode 100644 index 00000000..4d884395 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/export/flow_impacted.py @@ -0,0 +1,120 @@ +""" Getting impacted flow functions""" + +from dfcx_scrapi.core.agents import Agents +from dfcx_scrapi.core.environments import Environments +from dfcx_scrapi.core.flows import Flows + +from typing import Dict, List + + +class Impacted: + """ + Analyzes and identifies changes in Dialogflow CX agent flows across environment versions. 
+ + This class retrieves information about a specified Dialogflow CX agent and its environment, + including version history and flow details. It then compares the latest two versions to + identify any changes in the flows, providing a mapping of impacted flow IDs and names. + + Attributes: + source_project_id: The ID of the Google Cloud project where the agent resides. + source_agent_name: The display name of the agent. + environment_name: The display name of the agent's environment (default: "ready to deploy"). + + Methods: + filter_flows: (Static method) Filters a flow map based on + differences between two environments. + check_flow: Identifies and returns a dictionary of + changed flows between the latest two versions. + """ + + #Get agent id + + def __init__( + self,source_project_id, + source_agent_name, + environment_name="ready to deploy" + ): + self.env=Environments() + self.flows=Flows() + + self.source_project_id=source_project_id + self.source_agent_name=source_agent_name + self.environment_name=environment_name + self.filtered_dict={} + + agents=Agents() + agent_details=agents.get_agent_by_display_name( + display_name=self.source_agent_name, + project_id=self.source_project_id + ) + + self.agent_id=agent_details.name + + #get environment id + env_details=self.env.get_environment_by_display_name( + display_name=self.environment_name + ,agent_id=self.agent_id + ) + self.env_id=env_details.name + + #get history + self.hist=self.env.lookup_environment_history( + environment_id=self.env_id + ) + + @staticmethod + def filter_flows(env1,env2,flowmap,versions): + """ Returns filtered dict and impacted version ids""" + impacted_flows=[] + for k,v in env1.items(): + if v!=env2.get(k,0): + impacted_flows.append(k) + + filtered_dict = { + k: v for k, v in flowmap.items() + if k.split("/")[-1] in impacted_flows + } + #getting version ids + impacted_version_ids=[] + for ver in versions: + ver=ver.version + flow=ver.split("/")[-3] + if flow in impacted_flows: + impacted_version_ids.append(ver) + + + return filtered_dict,impacted_version_ids + + + + def check_flow( + self + ) -> Dict[str, str]: + #compare latest 2 history + """ + returns map of flow id:flow name which was found to be changed + """ + env1={} + for i in self.hist[0].version_configs: + flow=i.version.split("/")[-3] + version=i.version.split("/")[-1] + env1[flow]=version + + env2={} + if len(self.hist)>1: + for i in self.hist[1].version_configs: + flow=i.version.split("/")[-3] + version=i.version.split("/")[-1] + env2[flow]=version + + #get flow map for id name comparision + flowmap=self.flows.get_flows_map(agent_id=self.agent_id) + + self.filtered_dict,self.impacted_version_ids = Impacted.filter_flows( + env1, + env2, + flowmap, + self.hist[0].version_configs + ) + + return self.filtered_dict,self.impacted_version_ids diff --git a/examples/dfcx_agent_cicd/cicd_code/export/repopush.sh b/examples/dfcx_agent_cicd/cicd_code/export/repopush.sh new file mode 100644 index 00000000..e4d08509 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/export/repopush.sh @@ -0,0 +1,14 @@ +apt-get update && apt-get install -y jq +export project_id=$(jq -r .devops_project config.json) +export agent_name=$(jq -r .agent_name config.json) +echo $agent_name + +cd $1 +git checkout main +echo "pwd" +pwd +date > agent_artifacts/timestamp.txt +rm agent_artifacts/* +cp /workspace/$2/agenttemp/$agent_name agent_artifacts/ +cp /workspace/$2/agenttemp/metadata.json agent_artifacts/ +date > agent_artifacts/timestamp.txt \ No newline at end of file diff --git 
a/examples/dfcx_agent_cicd/cicd_code/export/repopush_csr.sh b/examples/dfcx_agent_cicd/cicd_code/export/repopush_csr.sh new file mode 100644 index 00000000..f952cce2 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/export/repopush_csr.sh @@ -0,0 +1,30 @@ +apt-get update && apt-get install -y jq +export project_id=$(jq -r .devops_project config.json) +export agent_name=$(jq -r .agent_name config.json) +echo $agent_name +cd agenttemp +ls -all +gcloud source repos clone agentcicd --project=$project_id +#git remote add google 'https://source.developers.google.com/p/xxx/r/agentTest' + +cd agentcicd +git checkout main +ls -all + +rm agent_artifacts/* +cp ../$agent_name agent_artifacts/ +cp ../metadata.json agent_artifacts/ +date > agent_artifacts/timestamp.txt +cd agent_artifacts +ls +cd .. +echo $3 +git config --global user.name $2 +git config --global user.email $3 + +git add . +echo "$1" +git diff --name-only +git commit --allow-empty -m "$1" + +git push -u origin main \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/export/requirements.txt b/examples/dfcx_agent_cicd/cicd_code/export/requirements.txt new file mode 100644 index 00000000..62207d27 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/export/requirements.txt @@ -0,0 +1,3 @@ +dfcx-scrapi +google-cloud-storage +pandas \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/export/trigger.sh b/examples/dfcx_agent_cicd/cicd_code/export/trigger.sh new file mode 100644 index 00000000..46dec1ad --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/export/trigger.sh @@ -0,0 +1 @@ +gcloud logging read 'resource.labels.build_trigger_id="triggerhere" AND protoPayload.methodName="google.devtools.cloudbuild.v1.CloudBuild.RunBuildTrigger"' --limit 20 | grep principalEmail | head -n 1 \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/media/image1.png b/examples/dfcx_agent_cicd/cicd_code/media/image1.png new file mode 100644 index 00000000..0eb79237 Binary files /dev/null and b/examples/dfcx_agent_cicd/cicd_code/media/image1.png differ diff --git a/examples/dfcx_agent_cicd/cicd_code/media/image2.png b/examples/dfcx_agent_cicd/cicd_code/media/image2.png new file mode 100644 index 00000000..f977a083 Binary files /dev/null and b/examples/dfcx_agent_cicd/cicd_code/media/image2.png differ diff --git a/examples/dfcx_agent_cicd/cicd_code/media/image3.png b/examples/dfcx_agent_cicd/cicd_code/media/image3.png new file mode 100644 index 00000000..87d20e63 Binary files /dev/null and b/examples/dfcx_agent_cicd/cicd_code/media/image3.png differ diff --git a/examples/dfcx_agent_cicd/cicd_code/media/image4.png b/examples/dfcx_agent_cicd/cicd_code/media/image4.png new file mode 100644 index 00000000..8131da4b Binary files /dev/null and b/examples/dfcx_agent_cicd/cicd_code/media/image4.png differ diff --git a/examples/dfcx_agent_cicd/cicd_code/media/image5.png b/examples/dfcx_agent_cicd/cicd_code/media/image5.png new file mode 100644 index 00000000..e48c83bd Binary files /dev/null and b/examples/dfcx_agent_cicd/cicd_code/media/image5.png differ diff --git a/examples/dfcx_agent_cicd/cicd_code/media/image6.png b/examples/dfcx_agent_cicd/cicd_code/media/image6.png new file mode 100644 index 00000000..1d162363 Binary files /dev/null and b/examples/dfcx_agent_cicd/cicd_code/media/image6.png differ diff --git a/examples/dfcx_agent_cicd/cicd_code/prod/__init__.py b/examples/dfcx_agent_cicd/cicd_code/prod/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/examples/dfcx_agent_cicd/cicd_code/prod/cloudbuild_deploy.yaml b/examples/dfcx_agent_cicd/cicd_code/prod/cloudbuild_deploy.yaml new file mode 100644 index 00000000..d6925eca --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/prod/cloudbuild_deploy.yaml @@ -0,0 +1,15 @@ +steps: + - id: deployagent + name: 'python:3.10' + #dir: 'set your path till the readme doc in the git' + entrypoint: /bin/bash + args: + - -c + - | + echo "printing recieved variables now" + echo ${_COMMIT_SHA} + pip3 install -r prod/requirements.txt + python3 -m prod.deploy $COMMIT_SHA + +options: + logging: CLOUD_LOGGING_ONLY \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/prod/deploy.py b/examples/dfcx_agent_cicd/cicd_code/prod/deploy.py new file mode 100644 index 00000000..20aca8af --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/prod/deploy.py @@ -0,0 +1,82 @@ +""" Deploy to prod functions """ +import json +import logging +import sys + +from shared.deployment import Deployment + +# logging config +logging.basicConfig( + level=logging.INFO, + format="PROD: %(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +def main(data): + """ + Deploys and validates a Dialogflow CX agent in a production environment. + This function orchestrates the deployment and validation of a Dialogflow CX agent + in a production environment. It performs the following key steps: + + 1. Imports the agent to the specified production webhook environment. + 2. Performs a language check on fulfillment entries, parameters, and routes, + specifically for French Canadian ('fr-ca'). + 3. Collects flow IDs. + 4. Manages version count and deletion. + 5. Cuts a new version of the agent. + 6. Deploys the new version to production. + 7. Synchronizes the agent between development and production environments. + 8. Updates the datastore with production information. + + Args: + data: A dictionary containing configuration data, including the 'prod_webhook_env' key. + + Raises: + SystemExit: If the language check fails, indicating missing agent responses. 
+ """ + dep=Deployment(data) + # call the steps sequentially + dep.import_agent(webhookenv=data["prod_webhook_env"]) + + entry,param,route,result=dep.fullfillment_lang_check(lang='fr-ca') + + logging.info("Entry fulfilment is %s",entry) + logging.info("Param fulfilment is %s",param) + logging.info("Route fulfilment is %s",route) + if not result: + print("some pages,parameters, routes does not have agent response") + sys.exit(2) + + dep.collect_flow_id() + dep.version_count_delete() + dep.version_cut() + dep.deploy_versions() + dep.dev_prod_sync() + dep.datastore_update("prod") + + + +if __name__=='__main__': + # read env variables + with open("config.json" , encoding='utf-8') as config_file: + config = json.load(config_file) + + SHA_ID=sys.argv[1] + obj=f"UAT/{config['agent_name']}/{SHA_ID}" + sha_agent_gcs_location=( + f"gs://{config['bucket']}/UAT/{config['agent_name']}/{SHA_ID}" + ) + logging.info("agent location %s", sha_agent_gcs_location) + #adding additional variables to dict + config["sha_agent_gcs_location"]=sha_agent_gcs_location + config["target_project_id"] = config["prod_project"] + config['target_environment_name']=config["prod_env_deploy"] + with open("agent_artifacts/metadata.json" , encoding='utf-8') as metadata_file: + metadata = json.load(metadata_file) + + config["source_flow_names"]=metadata["source_flow_names"] + config["updated_commit_message"]=metadata["updated_commit_message"] + config["impacted_version_ids"]=metadata["impacted_version_ids"] + + # To execute steps in order + main(config) diff --git a/examples/dfcx_agent_cicd/cicd_code/prod/requirements.txt b/examples/dfcx_agent_cicd/cicd_code/prod/requirements.txt new file mode 100644 index 00000000..62207d27 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/prod/requirements.txt @@ -0,0 +1,3 @@ +dfcx-scrapi +google-cloud-storage +pandas \ No newline at end of file diff --git a/examples/dfcx_agent_cicd/cicd_code/readme.md b/examples/dfcx_agent_cicd/cicd_code/readme.md new file mode 100644 index 00000000..0dbab79e --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/readme.md @@ -0,0 +1,366 @@ +This Document outlines how to use this sample code and set up the CICD +pipeline in your gcp projects. + +# Primer + +Using this CICD pipeline, you can + +- Promote/migrate agent across 3 GCP projects a.k.a dev -\> uat -\> + prod + +Below are the steps happens in each of the 3 projects while promoting +the agent. + +**In Dev project** + +- Export agent from a designated DFCX environment from dev(coming from + config file) + +- Automatically detects which flows are impacted based on history and + in dev project and save it in metadata file + +- Automatically sync the flows in deployed DFCX environments once a + flow is deployed in prod + +**In UAT project** + +- Creates new versions of only those impacted flows in uat and deploy + them to designated DFCX environment(coming from config file) + +- Run test cases relevant to impacted flows and roll back if failed in + UAT + +- Automatically delete the previous older versions of flows in uat and + prod once the limit is reached + +- Update the webhook apigee environment in the webhook url + corresponding to UAT project + +- If there are multiple languages configured in the agent, it will + automatically verify other languages pages against English language + pages to see if agent fulfillments/response are present in every + other languages configured. 
+ +- Gives mechanism for UAT team to approve in UI once testing is + completed and deploy the agent to prod post the approval + +- Automatically sync the flows in deployed DFCX environments once a + flow is deployed in prod + +**In Prod project** + +- Post UAT team approves after UAT testing, It creates new versions of + only those impacted flows in prod and deploy them to designated DFCX + environment(coming from config file) + +- Update the webhook apigee environment in the webhook url + corresponding to prod project + +- Automatically delete the previous older versions of flows in uat and + prod once the limit is reached + +- Automatically deploys the impacted flows and deploy to serving DFCX + environment in prod + +# High level Architecture + +![](media/image1.png) + +# + +# + +# Set up + +## Assumptions: + +1. You have GCP account already + +2. You have created 3 separate projects for dev/uat/prod + +3. You have an agent in dev and created a dummy/empty agent in the same + name in uat and prod project. + +4. You have a git or similar repo to source the code and to do agent + check in to store agent artifacts whenever the build is triggered. + +5. You are going to use same repo to store code artifacts as well as to + store agent artifacts that gets checked in during the build process + +6. You will create all the 3 builds in your dev project. If need be, + you can have all the builds in a centralized project and play around + with IAM service accounts to enable the builds to access the agent + in dev/uat/prod projects for migration + +7. You will generate and put the known_hosts.github file in the repo in + same level as this readme file is + +8. Set the dir parameter in all the build steps of all the 3 yaml if + your repo does not contain this readme file and other folders such + as export,UAT and prod in repo root. If these core files are nested + then set the dir to the path where this readme file is present + +## IAM Permissions + +1. Create a service account and give the following IAM permissions. + +- **Dialogflow CX API Admin** + +- **Dialogflow API Client** + +- **Storage.admin and Storage Object User** + +- **CSR/Git access** + +- **Service Usage Consumer** + +- **Dialogflow \> Dialogflow Test Case Admin** + +- **Dialogflow \> Dialogflow Environment editor** + +- **Cloud Build service account** + +- **Logs Viewer** + +- **Logs Writer** + +- **Cloud Build viewer** + +2. Give the approver person with **cloudbuild.builds.approve** access + in Dev project + +## For the UAT and Prod builds to access UAT and PROD project to deploy the agent (See assumption no. 5), + +- Get the service account the is used by the cloud builds in your Dev + project + +- Go to UAT and Prod projects \> IAM role \> add principal and enter + the service account id you got from previous step and give access to + UAT/PROD service account as Service Usage Consumer **and** + Dialogflow API Admin + +- Give Dev build's service account with **cloudbuild.builds.get** + access + +## Code Repository and Branching Strategy + +This section describes the approach for setting up the repository and +branches for the agents. + +**Source Repository** + +Below is the reason why we need a repository + +1. Cloud Builds need to refer to some place to access the code that it + needs to build + +2. Some Cloud Builds are set to get triggered when an agent artifact is + checked in to the repository automatically. + +3. Maintain an audit trail to see who changed the code + +4. 
Maintain an audit trail to see who checked in agent artifacts in + repo along with a commit message that explains what the change in + agent/flow is. + +You can use either the GCP's private git ie cloud source repository or +other gits such as github. + +If you use CSR(deprecated for new users) then use the set +export/cloudbuild_export_csr.yaml and repopush_csr.sh files. + +If you use github then use use the set export/cloudbuild_export.yaml and +repopush.sh files. + +## Storage Bucket + +Create a gcs bucket and that will be used by the pipeline for storing +agents while exporting and restoring. Below is how the bucket structure +might look like. + +![](media/image2.png) + +## Cloud Build Configurations + +There are certain configurations that have to be updated that the user +has to fill while triggering Build1 for agent promotion. Following are +the variables that will be defined for the Cloud Build. + +- **\_COMMIT_MESSAGE** - This the URL for the configuring the web-hook + +### + +### Export Build: + +![](media/image3.png) + +### UAT deploy build + +![](media/image4.png) + +### Prod deploy build + +![](media/image5.png) + +## DFCX APIs + +The Python Dialogflow CX Scripting [[API (DFCX +SCRAPI)]](https://github.com/GoogleCloudPlatform/dfcx-scrapi) +is a high level API that extends the official Google [[Python Client for +Dialogflow +CX]](https://github.com/googleapis/python-dialogflow-cx). +SCRAPI makes using DFCX easier, more friendly, and more pythonic for bot +builders, developers, and maintainers. This uses V3/V3beta1 endpoints +under the hood. Since it is more pythonic way of implementation, +developers will find it easy to use SCRAPI API in action. + +In our CI/CD pipeline below operations are achieved using SCRAPI API + +- Find agent ID from name + +- Find flow id from name + +- [[Export the agent to + GCS]](https://github.com/GoogleCloudPlatform/dfcx-scrapi/blob/37cf8cf7b2013a377740f68d8dcb7355632161e0/src/dfcx_scrapi/core/agents.py#L363) + +- [[Restore the + agent]](https://github.com/GoogleCloudPlatform/dfcx-scrapi/blob/37cf8cf7b2013a377740f68d8dcb7355632161e0/src/dfcx_scrapi/core/agents.py#L438) + +- [[Cut a version of a + flow]](https://github.com/GoogleCloudPlatform/dfcx-scrapi/blob/37cf8cf7b2013a377740f68d8dcb7355632161e0/src/dfcx_scrapi/core/versions.py#L183) + +- [[Deploy it to an + environment]](https://github.com/GoogleCloudPlatform/dfcx-scrapi/blob/37cf8cf7b2013a377740f68d8dcb7355632161e0/src/dfcx_scrapi/core/environments.py#L359) + +- [[Run test + cases]](https://github.com/GoogleCloudPlatform/dfcx-scrapi/blob/37cf8cf7b2013a377740f68d8dcb7355632161e0/src/dfcx_scrapi/core/test_cases.py#L410) + +- [[Compare environment + history]](https://github.com/GoogleCloudPlatform/dfcx-scrapi/blob/37cf8cf7b2013a377740f68d8dcb7355632161e0/src/dfcx_scrapi/core/environments.py#L392) + to find impacted flow the current instance of CI/CD builds. + +## To Set up the pipeline + +1. Setup a git or any code repository of your choice to store the code + and agent artifacts + +2. Push the code you see in the SCRAPI + repo(examples/dfcx_agent_cicd/cicd_code) in the parent folder along + with this documentation. + +3. If you fork the scrapi repo use this for your cicd pipeline then in + all yaml files inside export/UAT/Prod folders add dir: path till the + readme doc(I have marked a comment in the yaml files) + +4. 
If you use CSR(decommissioned for new users after 06/2024) use the + set cloudbuild_export_csr.yaml and repopush_csr.sh(ie you will use + this yaml file as trigger in the export build). But mostly likely + you would want to use github kind of repos and use the default files + cloudbuild_export.yaml and repopush.sh. For the later one you don't + need to make any configuration changes + +5. As mentioned in above point, if you use github kind of repo, you + need to create SSH key and store it in Secrets for the cloudbuild to + do checkin. Please follow this + [documentation](https://cloud.google.com/build/docs/access-github-from-build) + to create ssh and put it in the GCP secrets and generate + known_hosts.github and pushing it your repo in the same level as + this readme file. + +6. Use the config file as a one stop place to initiate values to + variables that will be used throughout the pipeline. Hence this + eases out the maintenance or reusing of the pipeline for different + values. + +{ + +\"agent_name\" : \"carrental\", + +\"dev_env_pull\" : \"ready to deploy\", + +\"uat_env_deploy\" : \"ready to test\", + +\"prod_env_deploy\" :\"deployed\", + +\"devprodsyncenv\" :\"deployed\", + +\"bucket\": \"DFCX_agent_cicd_export\", + +\"dev_project\": \"yourprojectid\", + +\"uat_project\" : \"yourprojectid\", + +\"prod_project\": \"yourprojectid\", + +\"devops_project\": \"yourprojectid\", + +\"uat_webhook_env\": \"uat\", + +\"prod_webhook_env\": \"prod\", + +\"uat_engine_id\" :\"\", + +\"prod_engine_id\" :\"\" + +} + + + +7. Make sure the a GCP bucket is created with said structure and name + is configured in config file + +8. Create 3 cloud builds with the configuration and name as shown in + screenshots in the previous section and attach your repo to these + builds. + +9. Make sure an agent is present in the same name in UAT and Prod(if it + is first time, just create an empty agent in UAT/Prod projects) + +10. Make sure the agent in UAT and Prod projects has the environments + created as configured in config file in fields uat_env_deploy and + prod_env_deploy + +11. Make sure you have also created the env as you configured in config + file devprodsyncenv in all UAT and Dev projects to sync back the + flows after deployed in prod + +## To run the Pipeline + +1. Now make some changes to your agent in Dev project and create a + version of the flow in dfcx and deploy updated flows to the DFCX + environment as you have configured in the config file dev_env_pull + field. + +2. Now come to GCP Cloud build console and click on RUN on exportbuild + in triggers section and input the commit message(basically some + lines about your change in the agent that will be used ) + +3. This will export agent and this would have done a check in in the + repo to trigger UAT and prod builds and deployed the agent in UAT + project. + +4. Now you can come back to cloud build console and build history tab + and approve the build that is waiting for your approval and you can + see that it will deploy the agent in prod post approval + ![](media/image6.png) + +## Caveat + +1. If the datastores are linked to the agent, make sure to create + datastore with ids same across all three projects + +# Benefits + +1. Entire process of agent promotion is automated + +2. Code base is modularized according to best practices + +3. DFCX best practices are configured in the pipeline set up + +4. 
Same pipeline can be concurrently used for same agents by multiple + agent developers to deploy their own flow and can be approved to + deploy individually as we are using commit id/SHA ID as primary + identifier across one instance of pipeline running. + +5. Datastores configurations will not break if same datastore id is + used in all the projects diff --git a/examples/dfcx_agent_cicd/cicd_code/shared/__init__.py b/examples/dfcx_agent_cicd/cicd_code/shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/dfcx_agent_cicd/cicd_code/shared/deployment.py b/examples/dfcx_agent_cicd/cicd_code/shared/deployment.py new file mode 100644 index 00000000..e6b11183 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/shared/deployment.py @@ -0,0 +1,278 @@ +""" Shared module to do deployement acting as a wrapper for deployment""" + +import datetime +import time +import logging +import sys +import json + +from dfcx_scrapi.core.agents import Agents +from dfcx_scrapi.core.versions import Versions +from dfcx_scrapi.core.environments import Environments +from dfcx_scrapi.core.flows import Flows + +from google.cloud.dialogflowcx_v3beta1 import types + + +from .test_case_run import RunTestCases +from .webhook_update import update_webhook +from .en_vs_other_lang import en_vs_lang + + +# logging config +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +class Deployment: + """ + Manages the deployment and lifecycle of Dialogflow CX agents. + + This class provides methods for importing, testing, versioning, and deploying + Dialogflow CX agents across different environments. It handles tasks such as: + + - Importing agents from GCS. + - Updating webhook configurations. + - Running test cases and validating results. + - Collecting and managing flow IDs. + - Versioning and deploying flows. + - Syncing flows between environments (e.g., dev and prod). + - Updating datastore settings. + + Attributes: + (Initialized from an input dictionary) + + Methods: + import_agent: Imports an agent from GCS to a target project. + test_case_validation: Runs test cases and validates the results. + collect_flow_id: Collects the IDs of flows to be deployed. + version_count_delete: Manages version count and deletes old versions if necessary. + version_cut: Creates new versions of the specified flows. + deploy_versions: Deploys the new versions to the target environment. + dev_prod_sync: Synchronizes flows between development and production environments. + datastore_update: Updates datastore settings for the agent. + """ + def __init__(self,input_dict): + for key, value in input_dict.items(): + setattr(self, key, value) + + def import_agent(self,webhookenv): + """Imports a Dialogflow CX agent to the target project. + + This method restores a Dialogflow CX agent from a GCS bucket to the + specified target project and updates the webhook URI for the agent. + + Args: + webhookenv: The webhook environment to use for the imported agent. 
+ """ + + agent=Agents() + target_agent_details=agent.get_agent_by_display_name( + display_name=self.agent_name, + project_id=self.target_project_id + ) + + self.target_agent_id=target_agent_details.name + + + #restoring the agent from the SHA ID folder + agent.restore_agent( + agent_id=self.target_agent_id, + gcs_bucket_uri=f"{self.sha_agent_gcs_location}/{self.agent_name}", + restore_option=2 + ) + + logging.info("import to destination project done") + + #[1.1] update webhooks uri + update_webhook(self.target_agent_id,webhookenv) + + + def test_case_validation(self): + """Runs test cases and validates the results. + + This method executes test cases for the specified agent and environment, + using tags to filter the test cases to run. If any test case fails, + the script exits with an error code. + + Raises: + SystemExit: If any test case fails. + """ + + tags=["#"+f for f in self.source_flow_names] + obj=RunTestCases( + project_id=self.target_project_id, + agent_name=self.agent_name, + environment_name=None) + stats,result=obj.trigger_test_case(tags=tags) + logging.info("test case result: %s", json.dumps(stats, indent=2)) + if not result: + sys.exit(2) + + + def collect_flow_id(self): + """Collects the IDs of flows to be deployed. + + This method retrieves the IDs of the flows specified in `self.source_flow_names` + from the target Dialogflow CX agent. It introduces a 50-second delay to allow + for agent stabilization before fetching the flow IDs. + """ + time.sleep(50) + flow=Flows() + logging.info( + "flows to deployed in %s project: %s", + self.target_project_id, + self.source_flow_names + ) + flow_ids=[] + for flow_name in self.source_flow_names: + flow_details=flow.get_flow_by_display_name( + display_name=flow_name, + agent_id=self.target_agent_id) + flow_ids.append(flow_details.name + ) + self.flow_ids=flow_ids + + def version_count_delete(self): + """ + 1. Check if the count of versions of a flow is not exceeding 20(limit) + else delete the older version + 2. and make room for new version cut + """ + versions=Versions() + for flow_id in self.flow_ids: + flowver=versions.list_versions(flow_id=flow_id) + if len(flowver)==20: + deletever=flowver[-1].name + versions.delete_version(version_id=deletever) + logging.info( + "deleted version id %s in project %s", + deletever, + self.target_project_id + ) + + def version_cut(self): + """ + 1. Cut a version of those flows + 2. 
Storing new version ids created + """ + versions=Versions() + vers=[] + for flow_id in self.flow_ids: + v_display_name=f"version cut by CI/CD {datetime.datetime.now()}" + ver=versions.create_version( + flow_id=flow_id, + description=self.updated_commit_message, + display_name=v_display_name + ) + vers.append(ver) + + #storing new version ids created + new_versions=[] + for ver in vers: + verresult=ver.result() + versionid=verresult.name + new_versions.append(versionid) + self.new_versions=new_versions + logging.info("versions cut in %s project",self.target_project_id) + + def deploy_versions(self): + """ + 1.Deploy created versions to the env + 2.Deploy the version created to this env id + """ + env=Environments() + # get env id + env_details=env.get_environment_by_display_name( + display_name=self.target_environment_name, + agent_id=self.target_agent_id + ) + self.target_env_id=env_details.name + + # deploy the version created to this env id + + for new_version in self.new_versions: + env.deploy_flow_to_environment( + environment_id=self.target_env_id, + flow_version=new_version) + + logging.info("versions deployed to deployed env %s project", + self.target_project_id + ) + + def dev_prod_sync(self): + """ + sync the dev and prod project once deployment happens in prod + 1. Deploy created versions to the env + 2. Deploy the version created to this env id + """ + agent=Agents() + dev_agent_details=agent.get_agent_by_display_name( + display_name=self.agent_name, + project_id=self.dev_project + ) + + dev_agent_id=dev_agent_details.name + env=Environments() + # get env id + env_details=env.get_environment_by_display_name( + display_name=self.devprodsyncenv, + agent_id=dev_agent_id + ) + self.devprod_env_id=env_details.name + + # deploy the version created to this env id + + for new_version in self.impacted_version_ids: + env.deploy_flow_to_environment( + environment_id=self.devprod_env_id, + flow_version=new_version) + + logging.info("flows deployed in prod is synced with dev environment") + + def datastore_update(self,projectlevel): + """ + update the datastore id + """ + if projectlevel=="uat": + engine_id=self.uat_engine_id + elif projectlevel=="uat": + engine_id=self.prod_engine_id + else: + engine_id="" + agents=Agents() + app=types.Agent.GenAppBuilderSettings(engine=engine_id) + kwargs={"gen_app_builder_settings":app} + agents.update_agent(agent_id=self.target_agent_id,**kwargs) + + logging.info("datastore id updated") + + def fullfillment_lang_check(self,lang): + """Checks fulfillment language coverage compared to English. + + This method compares the fulfillment coverage of the specified language + (`lang`) with the English language ('en') for the given agent and flows. + It returns dataframes containing statistics on fulfillment entries, parameters, + and routes, along with a boolean result indicating whether all elements have + agent responses in the specified language. + + Args: + lang: The language code to compare against English (e.g., 'fr-ca'). + + Returns: + A tuple containing: + - entry_df: DataFrame with statistics on entry fulfillment coverage. + - param_df: DataFrame with statistics on parameter fulfillment coverage. + - route_df: DataFrame with statistics on route fulfillment coverage. + - result: A boolean indicating if all elements have agent responses in the specified language. 
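        For reference, the column layout mirrors shared/en_vs_other_lang.en_vs_lang:
        for example entry_df carries ['flow', 'page', 'text_entry_en',
        'text_entry_<lang>', 'payload_entry_en', 'payload_entry_<lang>'] counts of
        text and payload entry fulfillments per page.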
+        """
+
+        entry_df,param_df,route_df,result= en_vs_lang(
+            self.target_agent_id,
+            self.source_flow_names,
+            lang
+        )
+        return entry_df,param_df,route_df,result
+
diff --git a/examples/dfcx_agent_cicd/cicd_code/shared/en_vs_other_lang.py b/examples/dfcx_agent_cicd/cicd_code/shared/en_vs_other_lang.py
new file mode 100644
index 00000000..7577507f
--- /dev/null
+++ b/examples/dfcx_agent_cicd/cicd_code/shared/en_vs_other_lang.py
@@ -0,0 +1,132 @@
+""" Compare the fulfillments of 'en' vs another language (e.g. French)"""
+from dfcx_scrapi.core.flows import Flows
+
+import pandas as pd
+
+from .fullfillment_helper import get_entry_ff,get_param_ff,get_route_ff
+from .fullfillment_helper import PagesChild
+
+def en_vs_lang(agent_id,flows,lang):
+    """Compares fulfillment coverage between English and a specified language.
+
+    This function analyzes the fulfillment configurations (entry fulfillments,
+    parameter fulfillments, and route fulfillments) for a given Dialogflow CX agent
+    and a set of flows. It compares the coverage of the specified language (`lang`)
+    with the English language ('en'), generating dataframes that highlight any
+    discrepancies in fulfillment setup.
+
+    Args:
+        agent_id: The ID of the Dialogflow CX agent.
+        flows: A list of flow display names to analyze.
+        lang: The language code to compare against English (e.g., 'fr-ca').
+
+    Returns:
+        A tuple containing:
+        - entry_df: DataFrame with statistics on entry fulfillment coverage.
+        - param_df: DataFrame with statistics on parameter fulfillment coverage.
+        - route_df: DataFrame with statistics on route fulfillment coverage.
+        - result: A boolean indicating if all elements have agent responses
+          in the specified language.
+    """
+    entry_columns = ['flow','page', 'text_entry_en', f'text_entry_{lang}',
+                     'payload_entry_en', f'payload_entry_{lang}']
+    entry_df = pd.DataFrame(columns=entry_columns)
+    params_columns =['flow','page','parameter','text_param_en',
+                     f'text_param_{lang}','payload_param_en',
+                     f'payload_param_{lang}']
+    param_df = pd.DataFrame(columns=params_columns)
+    route_columns=['flow','page','route','text_route_en',
+                   f'text_route_{lang}', 'payload_route_en',
+                   f'payload_route_{lang}']
+    route_df = pd.DataFrame(columns=route_columns)
+    flowobj=Flows()
+    pagesobj=PagesChild()
+    for flow in flows:
+        flow_details=flowobj.get_flow_by_display_name(display_name=flow,
+                                                      agent_id=agent_id)
+        flow_id=flow_details.name
+        pages_list=pagesobj.list_pages(flow_id=flow_id)
+
+        for page in pages_list:
+            page_name=page.display_name
+            p_entry_en=0
+            t_entry_en=0
+            # getting entry fulfillment details
+            p_entry_en,t_entry_en=get_entry_ff(page=page,language_code='en')
+
+            if p_entry_en >0 or t_entry_en >0:
+                p_entry_lang,t_entry_lang=get_entry_ff(
+                    page_id=page.name,
+                    language_code=lang)
+                new_row = pd.DataFrame({
+                    'flow': [flow],
+                    'page': [page_name],
+                    'text_entry_en':[t_entry_en] ,
+                    f'text_entry_{lang}': [t_entry_lang],
+                    'payload_entry_en':[p_entry_en],
+                    f'payload_entry_{lang}': [p_entry_lang]
+                })
+                entry_df = pd.concat([entry_df, new_row], ignore_index=True)
+
+            # getting fulfillment details in parameters
+            for idx,param in enumerate(page.form.parameters):
+                param_name=param.display_name
+                p_param_en,t_param_en=get_param_ff(param=param,language_code='en')
+                if p_param_en> 0 or t_param_en >0:
+                    p_param_lang,t_param_lang=get_param_ff(page_id=page.name,
+                                                           idx=idx,
+                                                           language_code=lang)
+
+                    new_row = pd.DataFrame({
+                        'flow': [flow],
+                        'page': [page_name],
+                        'parameter' : [param_name],
+                        'text_param_en':[t_param_en] ,
+                        f'text_param_{lang}': [t_param_lang],
+                        'payload_param_en':[p_param_en],
+                        f'payload_param_{lang}': [p_param_lang]
+                    })
+                    param_df = pd.concat([param_df, new_row], ignore_index=True)
+
+            # getting fulfillment details in page routes
+            for idx,route in enumerate(page.transition_routes):
+                route_name=route.name
+                p_route_en,t_route_en=get_route_ff(route=route,language_code='en')
+                if p_route_en>0 or t_route_en>0:
+                    p_route_lang,t_route_lang=get_route_ff(page_id=page.name,
+                                                           idx=idx,
+                                                           language_code=lang)
+
+                    new_row = pd.DataFrame({
+                        'flow': [flow],
+                        'page': [page_name],
+                        'route' : [route_name],
+                        'text_route_en':[t_route_en] ,
+                        f'text_route_{lang}': [t_route_lang],
+                        'payload_route_en':[p_route_en],
+                        f'payload_route_{lang}': [p_route_lang]
+                    })
+                    route_df=pd.concat([route_df, new_row],
+                                       ignore_index=True)
+    condition1 = (
+        (entry_df.iloc[:, 2] != entry_df.iloc[:, 3]) |
+        (entry_df.iloc[:, 4] != entry_df.iloc[:, 5])
+    )
+    condition2 = (
+        (param_df.iloc[:, 3] != param_df.iloc[:, 4]) |
+        (param_df.iloc[:, 5] != param_df.iloc[:, 6])
+    )
+    condition3 =(
+        (route_df.iloc[:, 3] != route_df.iloc[:, 4]) |
+        (route_df.iloc[:, 5] != route_df.iloc[:, 6])
+    )
+
+    result1 = entry_df[condition1]
+    result2 = param_df[condition2]
+    result3 = route_df[condition3]
+    if result1.empty and result2.empty and result3.empty:
+        result=True
+    else:
+        result=False
+
+    return entry_df,param_df,route_df,result
diff --git a/examples/dfcx_agent_cicd/cicd_code/shared/fullfillment_helper.py b/examples/dfcx_agent_cicd/cicd_code/shared/fullfillment_helper.py
new file mode 100644
index 00000000..6b70836f
--- /dev/null
+++ b/examples/dfcx_agent_cicd/cicd_code/shared/fullfillment_helper.py
@@ -0,0 +1,111 @@
+""" Helper functions for en vs lang"""
+
+from dfcx_scrapi.core.pages import Pages
+
+from google.cloud.dialogflowcx_v3beta1.services import pages
+from google.cloud.dialogflowcx_v3beta1.types import page as gcdc_page
+
+class PagesChild(Pages):
+    """
+    Iterates over the pages object to get the fulfillment details
+    """
+    def __init__(self,*args,**kwargs):
+        super().__init__(*args,**kwargs)
+
+    def get_page(self, page_id,language_code) -> gcdc_page.Page:
+        """Get a single CX Page object based on the provided Page ID.
+ + Args: + page_id: a properly formatted CX Page ID + + Returns: + A single CX Page Object + """ + if not page_id: + page_id = self.page_id + request = gcdc_page.GetPageRequest() + request.name=page_id + request.language_code = language_code + client_options = self._set_region(page_id) + client = pages.PagesClient( + credentials=self.creds, client_options=client_options + ) + + response = client.get_page(request) + + return response + + + +def get_entry_ff(page=None,page_id=None,language_code='en'): + """ + Returns entry fullfillments stats + """ + if not page: + pagesobj=PagesChild() + page=pagesobj.get_page(page_id=page_id,language_code=language_code) + + payloadc=0 + textc=0 + for i in page.entry_fulfillment.messages: + try: + temp=len(i.payload.items()) + payloadc=payloadc+temp + except Exception: + pass + try: + temp=len(i.text.text) + textc=textc+temp + except Exception: + pass + + return payloadc,textc + +def get_param_ff(param=None,page_id=None,idx=None,language_code='en'): + """ + Returns params fullfillments stats + """ + if not param: + pagesobj=PagesChild() + page=pagesobj.get_page(page_id=page_id,language_code=language_code) + param=page.form.parameters[idx] + payloadc=0 + textc=0 + for message in param.fill_behavior.initial_prompt_fulfillment.messages: + try: + temp=len(message.payload.items()) + payloadc=payloadc+temp + except Exception: + pass + try: + temp=len(message.text.text) + textc=textc+temp + except Exception: + pass + + return payloadc,textc + +def get_route_ff(route=None,page_id=None,idx=None,language_code='en'): + """ + Returns route fullfillments stats + """ + if not route: + pagesobj=PagesChild() + page=pagesobj.get_page(page_id=page_id,language_code=language_code) + route=page.transition_routes[idx] + payloadc=0 + textc=0 + for i in route.trigger_fulfillment.messages: + try: + temp=len(i.payload.items()) + payloadc=payloadc+temp + except Exception: + pass + try: + temp=len(i.text.text) + textc=textc+temp + except Exception: + pass + + + return payloadc,textc diff --git a/examples/dfcx_agent_cicd/cicd_code/shared/test_case_run.py b/examples/dfcx_agent_cicd/cicd_code/shared/test_case_run.py new file mode 100644 index 00000000..a1b4da64 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/shared/test_case_run.py @@ -0,0 +1,120 @@ +""" Running test cases and produce results""" + +from typing import Tuple, Dict +import logging + +from dfcx_scrapi.core.test_cases import TestCases +from dfcx_scrapi.core.agents import Agents +from dfcx_scrapi.core.environments import Environments + +# logging config +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +class RunTestCases: + """ + Manages and executes test cases for Dialogflow CX agents. + + This class provides functionality to run test cases against a specified + Dialogflow CX agent and environment. It retrieves the necessary agent and + environment information and allows triggering test cases with optional tag filtering. + + Attributes: + project_id: The ID of the Google Cloud project where the agent resides. + agent_name: The display name of the agent. + environment_name: The display name of the agent's environment (can be None). + + Methods: + triggerTestcase: Executes test cases for the agent, optionally filtered by tags. 
+ """ + + def __init__( + self,project_id, + agent_name, + environment_name + ): + + self.project_id=project_id + self.agent_name=agent_name + self.environment_name=environment_name + + agents=Agents() + env=Environments() + + agent_details=agents.get_agent_by_display_name( + display_name=self.agent_name, + project_id=self.project_id) + + self.agent_id=agent_details.name + + #get environment id + if self.environment_name: + env_details=env.get_environment_by_display_name( + display_name=self.environment_name, + agent_id=self.agent_id) + self.env_id=env_details.name + else: + self.env_id=None + + def trigger_test_case( + self, + tags, + agent_id=None, + env_id=None) -> Tuple[Dict[str, int], bool] : + """ + Function to trigger the test case module in dfcx + Returns: + Result: Dict of results + boolean mentioning test case status + """ + if not agent_id: + agent_id=self.agent_id + if not env_id: + env_id=self.env_id + tc=TestCases() + tc_list=tc.list_test_cases(agent_id=agent_id) + + #get test cases + try: + filtered_tc = [ + testcase + for testcase in tc_list + if any( + tag in testcase + for tag in tags + ) + ] + + except AttributeError as e: + print( + f"Test case not found to run due to error {e}. " + "UAT deployment will be done without test case validation" + ) + result={"Pass": 0, "Fail": 0} + return result, True + filtered_tc_id=[filtestcase.name for filtestcase in filtered_tc] + print(filtered_tc_id) + + #run the test cases + tc_result=tc.batch_run_test_cases(test_cases=filtered_tc_id, + agent_id=agent_id, + environment=env_id) + print(f"test case results {tc_result}") + + pass_count=0 + fail_count=0 + for result in tc_result.results: + if result.test_result==1: + pass_count+=1 + else: + fail_count+=1 + + print(f"Pass: {pass_count}, Fail: {fail_count}") + result={"Pass": pass_count, "Fail": fail_count} + + if fail_count>0: + return result,False + return result,True diff --git a/examples/dfcx_agent_cicd/cicd_code/shared/webhook_update.py b/examples/dfcx_agent_cicd/cicd_code/shared/webhook_update.py new file mode 100644 index 00000000..1f9feba2 --- /dev/null +++ b/examples/dfcx_agent_cicd/cicd_code/shared/webhook_update.py @@ -0,0 +1,31 @@ +""" Functions to update the webhook env""" +import logging +import re + +from dfcx_scrapi.core.webhooks import Webhooks + + +# logging config +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +web=Webhooks() + +def update_webhook(agent_id,env): + """ + Updates the environment portion in the apigee webhook end point + """ + weblist=web.list_webhooks(agent_id=agent_id) + logging.info("got the webhooklist") + + for webhook in weblist: + currenturi=webhook.generic_web_service.uri + pattern = re.compile(r"\bdev\b") + updateduri=re.sub(pattern, env, currenturi) + webhook.generic_web_service.uri=updateduri + kwargs={"generic_web_service":webhook.generic_web_service} + web.update_webhook(webhook_id=webhook.name, + webhook_obj=webhook,**kwargs) + logging.info("replaced dev to %s and updated all the webhook urls",env) diff --git a/examples/vertex_ai_conversation/evaluation_tool__autoeval__colab.ipynb b/examples/vertex_ai_conversation/evaluation_tool__autoeval__colab.ipynb index 97a8f30b..824fd021 100644 --- a/examples/vertex_ai_conversation/evaluation_tool__autoeval__colab.ipynb +++ b/examples/vertex_ai_conversation/evaluation_tool__autoeval__colab.ipynb @@ -24,1640 +24,34 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, - "id": "0U8xQwhKrOUq" - 
}, - "outputs": [], - "source": [ - "!pip install dfcx-scrapi --quiet\n", - "!pip install rouge-score --quiet\n", - "\n", - "# workaround until vertexai import is fixed\n", - "!pip uninstall bigframes -y --quiet\n", - "!pip install bigframes==0.26.0 --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "PPJYRHN83bHg" - }, - "outputs": [], - "source": [ - "# @markdown `import dependencies`\n", - "\n", - "import abc\n", - "import collections\n", - "import dataclasses\n", - "import datetime\n", - "import io\n", - "import itertools\n", - "import json\n", - "import logging\n", - "import math\n", - "import statistics\n", - "import sys\n", - "import time\n", - "import threading\n", - "import re\n", - "\n", - "from typing import Any, TypedDict\n", - "\n", - "from collections.abc import Iterable\n", - "\n", - "import plotly.graph_objects as go\n", - "\n", - "import vertexai\n", - "import gspread\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "sns.set_style(\"whitegrid\")\n", - "\n", - "from tqdm.auto import tqdm\n", - "from tqdm.contrib import concurrent\n", - "\n", - "from dfcx_scrapi.core import agents\n", - "from dfcx_scrapi.core import scrapi_base\n", - "from dfcx_scrapi.core import sessions\n", - "from dfcx_scrapi.core.sessions import Sessions\n", - "from dfcx_scrapi.tools import dataframe_functions\n", - "\n", - "from googleapiclient.discovery import build\n", - "from googleapiclient.http import MediaInMemoryUpload, MediaIoBaseDownload\n", - "\n", - "from google.api_core import exceptions\n", - "from google.auth import default\n", - "from google.cloud import aiplatform\n", - "from google.cloud.dialogflowcx_v3beta1 import services\n", - "from google.cloud.dialogflowcx_v3beta1 import types\n", - "from google.colab import auth\n", - "from google.protobuf.json_format import MessageToDict\n", - "\n", - "from rouge_score import rouge_scorer\n", - "\n", - "from vertexai.language_models import TextGenerationModel\n", - "\n", - "pd.options.display.max_colwidth = 200" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "EgoRHwBJqJ0r" - }, - "outputs": [], - "source": [ - "# @markdown `authenticate`\n", - "\n", - "if \"google.colab\" in sys.modules:\n", - " from google.auth import default\n", - " from google.colab import auth\n", - "\n", - " auth.authenticate_user()\n", - " credentials, _ = default()\n", - "else:\n", - " # Otherwise, attempt to discover local credentials as described in\n", - " # https://cloud.google.com/docs/authentication/application-default-credentials\n", - " pass\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2c6VwnIvTBjF" - }, - "source": [ - "---\n", - "\n", - "# Implementation\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "BwPAkHGQ3k6M" - }, - "outputs": [], - "source": [ - "# @markdown `run this cell to define various utility and helper functions`\n", - "# @markdown\n", - "# @markdown > This cell contains several decorator methods related to handling\n", - "# @markdown API call errors and function call rate limitation.\n", - "\n", - "_INTERVAL_SENTINEL = object()\n", - "\n", - "MAX_RETRIES = 5\n", - "# max number of attempts for exponential backoff retries in case of API\n", - "# call errors\n", - "RATE = 2\n", - "# LLM API call rate limitation where RATE=2 for example means that 2 LLM calls\n", 
- "# can occur per second\n", - "\n", - "\n", - "def load_spreadsheet(\n", - " sheet_url: str, worksheet_name: str, credentials: Any\n", - ") -> pd.DataFrame:\n", - " \"\"\"Loads the content of a spreadsheet into pandas DataFrame.\"\"\"\n", - " sheets_client = gspread.authorize(credentials)\n", - " sheet = sheets_client.open_by_url(sheet_url)\n", - " worksheet = sheet.worksheet(worksheet_name)\n", - " return pd.DataFrame(worksheet.get_all_records())\n", - "\n", - "\n", - "def ratelimit(rate: float):\n", - " \"\"\"Decorator that controls the frequency of function calls.\"\"\"\n", - " seconds_per_event = 1.0 / rate\n", - " lock = threading.Lock()\n", - " bucket = 0\n", - " last = 0\n", - "\n", - " def decorate(func):\n", - " def rate_limited_function(*args, **kwargs):\n", - " nonlocal last, bucket\n", - " while True:\n", - " with lock:\n", - " now = time.time()\n", - " bucket += now - last\n", - " last = now\n", - "\n", - " # capping the bucket in order to avoid accumulating too many\n", - " bucket = min(bucket, seconds_per_event)\n", - "\n", - " # if bucket is less than `seconds_per_event` then we have to wait\n", - " # `seconds_per_event` - `bucket` seconds until a new \"token\" is\n", - " # refilled\n", - " delay = max(seconds_per_event - bucket, 0)\n", - "\n", - " if delay == 0:\n", - " # consuming a token and breaking out of the delay loop to perform\n", - " # the function call\n", - " bucket -= seconds_per_event\n", - " break\n", - " time.sleep(delay)\n", - " return func(*args, **kwargs)\n", - " return rate_limited_function\n", - " return decorate\n", - "\n", - "\n", - "def should_retry(err: exceptions.GoogleAPICallError) -> bool:\n", - " \"\"\"Helper function for deciding whether we should retry the error or not.\"\"\"\n", - " return isinstance(err, (exceptions.TooManyRequests, exceptions.ServerError))\n", - "\n", - "\n", - "def retry_api_call(retry_intervals: Iterable[float]):\n", - " \"\"\"Decorator for retrying certain GoogleAPICallError exception types.\"\"\"\n", - " def decorate(func):\n", - " def retried_api_call_func(*args, **kwargs):\n", - " interval_iterator = iter(retry_intervals)\n", - " while True:\n", - " try:\n", - " return func(*args, **kwargs)\n", - " except exceptions.GoogleAPICallError as err:\n", - " print(f\"retrying api call: {err}\")\n", - " if not should_retry(err):\n", - " raise\n", - "\n", - " interval = next(interval_iterator, _INTERVAL_SENTINEL)\n", - " if interval is _INTERVAL_SENTINEL:\n", - " raise\n", - " time.sleep(interval)\n", - " return retried_api_call_func\n", - " return decorate\n", - "\n", - "\n", - "def handle_api_error(func):\n", - " \"\"\"Decorator that chatches GoogleAPICallError exception and returns None.\"\"\"\n", - " def handled_api_error_func(*args, **kwargs):\n", - " try:\n", - " return func(*args, **kwargs)\n", - " except exceptions.GoogleAPICallError as err:\n", - " print(f\"failed api call: {err}\")\n", - " return None\n", - " return handled_api_error_func" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "NU9jzn9dWtrJ" - }, - "outputs": [], - "source": [ - "# @markdown `run this cell to define vertex ai conversation scraper`\n", - "# @markdown\n", - "# @markdown > This cell contains the code for Vertex AI Conversation scraper\n", - "# @markdown that interacts with DetectIntent method of Dialogflow service to\n", - "# @markdown process a queryset.\n", - "\n", - "DataStoreConnectionSignals = (\n", - " types.data_store_connection.DataStoreConnectionSignals\n", - ")\n", - "\n", 
- "GLOBAL_SCOPE = [\"https://spreadsheets.google.com/feeds\"]\n", - "\n", - "CONVERSATION_ID = \"conversation_id\"\n", - "TURN_INDEX = \"turn_index\"\n", - "QUERY = \"query\"\n", - "REFERENCE = \"expected_answer\"\n", - "EXPECTED_URI = \"expected_uri\"\n", - "SESSION_ID = \"session_id\"\n", - "RESPONSE = \"query_result\"\n", - "GOLDEN_SNIPPET = \"golden_snippet\"\n", - "\n", - "AGENT_URI = \"projects/{project_id}/locations/{location}/agents/{agent_id}\"\n", - "\n", - "INPUT_SCHEMA_REQUIRED_COLUMNS = [\n", - " CONVERSATION_ID, TURN_INDEX, QUERY, REFERENCE, EXPECTED_URI\n", - "]\n", - "\n", - "_EXECUTION_SEQUENCE_KEY = \"DataStore Execution Sequence\"\n", - "_EXECUTION_RESULT_KEY = \"executionResult\"\n", - "\n", - "_PROJECT_ID_PATTERN = re.compile(r\"projects/(.*?)/\")\n", - "_LOCATION_PATTERN = re.compile(r\"locations/(.*?)/\")\n", - "_AGENT_ID_PATTERN = re.compile(r\"agents/(.*?)/\")\n", - "\n", - "ANSWER_TEXT = \"answer_text\"\n", - "\n", - "_RESPONSE_TYPE = \"response_type\"\n", - "_RESPONSE_REASON = \"response_reason\"\n", - "_LATENCY = \"latency\"\n", - "_FAQ_CITATION = \"faq_citation\"\n", - "_SEARCH_FALLBACK = \"search_fallback\"\n", - "_UNSTRUCTURED_CITATION = \"unstructured_citation\"\n", - "_WEBSITE_CITATION = \"website_citation\"\n", - "_LANGUAGE = \"language\"\n", - "\n", - "_REWRITER_LLM_PROMPT = \"rewriter_llm_rendered_prompt\"\n", - "_REWRITER_LLM_OUTPUT = \"rewriter_llm_output\"\n", - "_REWRITTEN_QUERY = \"rewritten_query\"\n", - "_SEARCH_RESULTS = \"search_results\"\n", - "_ANSWER_GENERATOR_LLM_PROMPT = \"answer_generator_llm_rendered_prompt\"\n", - "_ANSWER_GENERATOR_LLM_OUTPUT = \"answer_generator_llm_output\"\n", - "_GENERATED_ANSWER = \"generated_answer\"\n", - "_CITED_SNIPPET_INDICES = \"cited_snippet_indices\"\n", - "_GROUNDING_DECISION = \"grounding_decision\"\n", - "_GROUNDING_SCORE = \"grounding_score\"\n", - "_SAFETY_DECISION = \"safety_decision\"\n", - "_SAFETY_BANNED_PHRASE = \"safety_banned_phrase_match\"\n", - "\n", - "\n", - "def _extract_match_type(query_result: types.session.QueryResult) -> str:\n", - " \"\"\"Extracts the name of the match type from query result.\"\"\"\n", - " try:\n", - " return types.session.Match.MatchType(query_result.match.match_type).name\n", - " except ValueError:\n", - " # if an enum type is returned which is not visible externally then fallback\n", - " # to default value\n", - " return types.session.Match.MatchType(0).name\n", - "\n", - "\n", - "def _extract_execution_result(\n", - " query_result: types.session.QueryResult\n", - ") -> dict[str, Any]:\n", - " \"\"\"Extracts the execution result from diagnostic info.\"\"\"\n", - " if _EXECUTION_SEQUENCE_KEY in query_result.diagnostic_info:\n", - " execution_sequence = query_result.diagnostic_info[_EXECUTION_SEQUENCE_KEY]\n", - " if _EXECUTION_RESULT_KEY in execution_sequence:\n", - " return MessageToDict(execution_sequence[_EXECUTION_RESULT_KEY])\n", - " return {}\n", - "\n", - "\n", - "def _extract_answer_text(\n", - " query_result: types.session.QueryResult\n", - ") -> str | None:\n", - " \"\"\"Extracts the text type responses and concatenates them.\"\"\"\n", - " result: list[str] = []\n", - " for response_message in query_result.response_messages:\n", - " if response_message.WhichOneof(\"message\") == \"text\":\n", - " result.extend(response_message.text.text)\n", - "\n", - " if not result:\n", - " return None\n", - "\n", - " return \" \".join(result)\n", - "\n", - "\n", - "@dataclasses.dataclass\n", - "class Snippet:\n", - " uri: str | None\n", - " title: str | None\n", - " 
text: str | None\n", - "\n", - " def to_prompt_snippet(self) -> str:\n", - " result = []\n", - " if self.title:\n", - " result.append(self.title)\n", - " if self.text:\n", - " result.append(self.text)\n", - " return \"\\n\".join(result) if result else \"\"\n", - "\n", - "\n", - "def _extract_search_results(\n", - " data_store_connection_signals: DataStoreConnectionSignals\n", - ") -> list[str]:\n", - " \"\"\"Extracts search results as a list of strings.\"\"\"\n", - " search_results = []\n", - " for search_snippet in data_store_connection_signals.search_snippets:\n", - " search_results.append(\n", - " Snippet(\n", - " uri=search_snippet.document_uri,\n", - " title=search_snippet.document_title,\n", - " text=search_snippet.text,\n", - " )\n", - " )\n", - " return search_results\n", - "\n", - "\n", - "def _extract_citation_indices(\n", - " data_store_connection_signals: DataStoreConnectionSignals\n", - ") -> list[int]:\n", - " \"\"\"Extracts the links and snippets which were used to generate answer.\"\"\"\n", - " cited_snippet_indices = []\n", - " for cited_snippet in data_store_connection_signals.cited_snippets:\n", - " cited_snippet_indices.append(cited_snippet.snippet_index)\n", - " return cited_snippet_indices\n", - "\n", - "\n", - "def _extract_grounding_decision(\n", - " grounding_signals: DataStoreConnectionSignals.GroundingSignals\n", - ") -> str:\n", - " return DataStoreConnectionSignals.GroundingSignals.GroundingDecision(\n", - " grounding_signals.decision\n", - " ).name\n", - "\n", - "\n", - "def _extract_grounding_score(\n", - " grounding_signals: DataStoreConnectionSignals.GroundingSignals\n", - "):\n", - " return DataStoreConnectionSignals.GroundingSignals.GroundingScoreBucket(\n", - " grounding_signals.score\n", - " ).name\n", - "\n", - "\n", - "def _extract_grounding_signals(\n", - " data_store_connection_signals: DataStoreConnectionSignals\n", - ") -> dict[str, str | None]:\n", - " grounding_signals = data_store_connection_signals.grounding_signals\n", - " if not grounding_signals:\n", - " return {_GROUNDING_DECISION: None, _GROUNDING_SCORE: None}\n", - " return {\n", - " _GROUNDING_DECISION: _extract_grounding_decision(grounding_signals),\n", - " _GROUNDING_SCORE: _extract_grounding_score(grounding_signals),\n", - " }\n", - "\n", - "\n", - "def _extract_rewriter_llm_signals(\n", - " data_store_connection_signals: DataStoreConnectionSignals\n", - ") -> dict[str, str | None]:\n", - " rewriter_model_call_signals = (\n", - " data_store_connection_signals.rewriter_model_call_signals\n", - " )\n", - " if not rewriter_model_call_signals:\n", - " return {_REWRITER_LLM_PROMPT: None, _REWRITER_LLM_OUTPUT: None}\n", - " return {\n", - " _REWRITER_LLM_PROMPT: rewriter_model_call_signals.rendered_prompt,\n", - " _REWRITER_LLM_OUTPUT: rewriter_model_call_signals.model_output,\n", - " }\n", - "\n", - "\n", - "def _extract_answer_generator_llm_signals(\n", - " data_store_connection_signals: DataStoreConnectionSignals\n", - ") -> dict[str, str | None]:\n", - " answer_generation_model_call_signals = (\n", - " data_store_connection_signals.answer_generation_model_call_signals\n", - " )\n", - " if not answer_generation_model_call_signals:\n", - " return {\n", - " _ANSWER_GENERATOR_LLM_PROMPT: None,\n", - " _ANSWER_GENERATOR_LLM_OUTPUT: None,\n", - " }\n", - " return {\n", - " _ANSWER_GENERATOR_LLM_PROMPT: (\n", - " answer_generation_model_call_signals.rendered_prompt\n", - " ),\n", - " _ANSWER_GENERATOR_LLM_OUTPUT: (\n", - " answer_generation_model_call_signals.model_output\n", - " )\n", - " 
}\n", - "\n", - "\n", - "def _extract_safety_decision(\n", - " safety_signals: DataStoreConnectionSignals.SafetySignals\n", - ") -> str:\n", - " return DataStoreConnectionSignals.SafetySignals.SafetyDecision(\n", - " safety_signals.decision\n", - " ).name\n", - "\n", - "\n", - "def _extract_safety_banned_phrase(\n", - " safety_signals: DataStoreConnectionSignals.SafetySignals\n", - ") -> str:\n", - " return DataStoreConnectionSignals.SafetySignals.BannedPhraseMatch(\n", - " safety_signals.banned_phrase_match\n", - " ).name\n", - "\n", - "\n", - "def _extract_safety_signals(\n", - " data_store_connection_signals: DataStoreConnectionSignals\n", - ") -> dict[str, str | None]:\n", - " safety_signals = data_store_connection_signals.safety_signals\n", - " if not safety_signals:\n", - " return {_SAFETY_DECISION: None, _SAFETY_BANNED_PHRASE: None}\n", - " return {\n", - " _SAFETY_DECISION: _extract_safety_decision(safety_signals),\n", - " _SAFETY_BANNED_PHRASE: _extract_safety_banned_phrase(safety_signals),\n", - " }\n", - "\n", - "\n", - "def _extract_data_store_connection_signals(\n", - " data_store_connection_signals: DataStoreConnectionSignals\n", - ") -> dict[str, Any]:\n", - " rewriter_signals = _extract_rewriter_llm_signals(\n", - " data_store_connection_signals\n", - " )\n", - " rewritten_query = (\n", - " data_store_connection_signals.rewritten_query\n", - " if data_store_connection_signals.rewritten_query\n", - " else None\n", - " )\n", - " grounding_signals = _extract_grounding_signals(data_store_connection_signals)\n", - " search_results = _extract_search_results(data_store_connection_signals)\n", - " answer_generator_signals = _extract_answer_generator_llm_signals(\n", - " data_store_connection_signals\n", - " )\n", - " generated_answer = (\n", - " data_store_connection_signals.answer\n", - " if data_store_connection_signals.answer\n", - " else None\n", - " )\n", - " cited_snippet_indices = _extract_citation_indices(\n", - " data_store_connection_signals\n", - " )\n", - " safety_signals = _extract_safety_signals(data_store_connection_signals)\n", - "\n", - " return {\n", - " **rewriter_signals,\n", - " _REWRITTEN_QUERY: rewritten_query,\n", - " **grounding_signals,\n", - " _SEARCH_RESULTS: search_results,\n", - " **answer_generator_signals,\n", - " _GENERATED_ANSWER: generated_answer,\n", - " _CITED_SNIPPET_INDICES: cited_snippet_indices,\n", - " **safety_signals,\n", - " }\n", - "\n", - "\n", - "@dataclasses.dataclass\n", - "class VertexConversationResponse:\n", - " \"\"\"Dataclass for storing relevant fields of detect intent response.\"\"\"\n", - " # ResponseMessages\n", - " answer_text: str | None = None\n", - "\n", - " # MatchType\n", - " match_type: str | None = None\n", - "\n", - " # DataStoreConnectionSignals\n", - " rewriter_llm_rendered_prompt: str | None = None\n", - " rewriter_llm_output: str | None = None\n", - " rewritten_query: str | None = None\n", - " search_results: list[Snippet] = dataclasses.field(default_factory=list)\n", - " answer_generator_llm_rendered_prompt: str | None = None\n", - " answer_generator_llm_output: str | None = None\n", - " generated_answer: str | None = None\n", - " cited_snippet_indices: list[int] = dataclasses.field(default_factory=list)\n", - " grounding_decision: str | None = None\n", - " grounding_score: str | None = None\n", - " safety_decision: str | None = None\n", - " safety_banned_phrase_match: str | None = None\n", - "\n", - " # DiagnosticInfo ExecutionResult\n", - " response_type: str | None = None\n", - " response_reason: str | None 
= None\n", - " latency: float | None = None\n", - " faq_citation: bool | None = None\n", - " search_fallback: bool | None = None\n", - " unstructured_citation: bool | None = None\n", - " website_citation: bool | None = None\n", - " language: str | None = None\n", - "\n", - " @classmethod\n", - " def from_query_result(cls, query_result: types.session.QueryResult):\n", - " \"\"\"Extracts the relevant fields from a QueryResult proto message.\"\"\"\n", - " answer_text = _extract_answer_text(query_result)\n", - " match_type = _extract_match_type(query_result)\n", - " execution_result = _extract_execution_result(query_result)\n", - " execution_result = {\n", - " _RESPONSE_TYPE: execution_result.get(_RESPONSE_TYPE),\n", - " _RESPONSE_REASON: execution_result.get(_RESPONSE_REASON),\n", - " _LATENCY: execution_result.get(_LATENCY),\n", - " _FAQ_CITATION: execution_result.get(_FAQ_CITATION),\n", - " _SEARCH_FALLBACK: execution_result.get(\"ucs_fallback\"),\n", - " _UNSTRUCTURED_CITATION: execution_result.get(_UNSTRUCTURED_CITATION),\n", - " _WEBSITE_CITATION: execution_result.get(_WEBSITE_CITATION),\n", - " _LANGUAGE: execution_result.get(_LANGUAGE),\n", - " }\n", - "\n", - " data_store_connection_signals = query_result.data_store_connection_signals\n", - "\n", - " if not data_store_connection_signals:\n", - " return cls(\n", - " answer_text=answer_text, match_type=match_type, **execution_result\n", - " )\n", - "\n", - " extracted_signals = _extract_data_store_connection_signals(\n", - " data_store_connection_signals\n", - " )\n", - " return cls(\n", - " answer_text=answer_text,\n", - " match_type=match_type,\n", - " **extracted_signals,\n", - " **execution_result,\n", - " )\n", - "\n", - " @classmethod\n", - " def from_row(cls, row: dict[str, Any]):\n", - " \"\"\"Extracts the relevant fields from a dictionary.\"\"\"\n", - " row = row.copy()\n", - " search_results = []\n", - " for search_result in json.loads(row[_SEARCH_RESULTS]):\n", - " search_results.append(Snippet(**search_result))\n", - " row[_SEARCH_RESULTS] = search_results\n", - " row[_CITED_SNIPPET_INDICES] = json.loads(row[_CITED_SNIPPET_INDICES])\n", - " return cls(**row)\n", - "\n", - " def to_row(self):\n", - " \"\"\"Dumps the query result fields to a dictionary.\"\"\"\n", - " result = dataclasses.asdict(self)\n", - " result[_SEARCH_RESULTS] = json.dumps(\n", - " result.pop(_SEARCH_RESULTS, []), indent=4\n", - " )\n", - " result[_CITED_SNIPPET_INDICES] = json.dumps(result[_CITED_SNIPPET_INDICES])\n", - " return result\n", - "\n", - " @property\n", - " def search_result_links(self):\n", - " return [search_result.uri for search_result in self.search_results]\n", - "\n", - " @property\n", - " def cited_search_results(self):\n", - " return [self.search_results[idx] for idx in self.cited_snippet_indices]\n", - "\n", - " @property\n", - " def cited_search_result_links(self):\n", - " return [search_result.uri for search_result in self.cited_search_results]\n", - "\n", - " @property\n", - " def prompt_snippets(self):\n", - " return [\n", - " search_result.to_prompt_snippet()\n", - " for search_result in self.search_results\n", - " ]\n", - "\n", - "\n", - "def _extract_url_part(url, pattern):\n", - " pattern_match = pattern.search(url)\n", - " if not pattern_match:\n", - " raise ValueError(f\"Invalid url: {url}\")\n", - " return pattern_match.group(1)\n", - "\n", - "\n", - "class VertexConversationScraper(scrapi_base.ScrapiBase):\n", - " \"\"\"Vertex AI Conversation scraper class.\"\"\"\n", - "\n", - " @classmethod\n", - " def from_url(cls, 
agent_url, language_code, creds):\n", - " agent_id = _extract_url_part(agent_url, _AGENT_ID_PATTERN)\n", - " location = _extract_url_part(agent_url, _LOCATION_PATTERN)\n", - " project_id = _extract_url_part(agent_url, _PROJECT_ID_PATTERN)\n", - " return cls(\n", - " agent_id=agent_id,\n", - " location=location,\n", - " project_id=project_id,\n", - " language_code=language_code,\n", - " creds=creds,\n", - " )\n", - "\n", - " def __init__(\n", - " self,\n", - " agent_id: str,\n", - " location: str,\n", - " project_id: str,\n", - " language_code: str,\n", - " creds_path: str = None,\n", - " creds_dict: dict[str, str] = None,\n", - " creds=None,\n", - " ):\n", - " super().__init__(\n", - " creds_path=creds_path,\n", - " creds_dict=creds_dict,\n", - " creds=creds,\n", - " scope=GLOBAL_SCOPE,\n", - " )\n", - "\n", - " self.location = location\n", - " self.project_id = project_id\n", - " self.language_code = language_code\n", - "\n", - " self.agent_id = AGENT_URI.format(\n", - " project_id=project_id, location=location, agent_id=agent_id\n", - " )\n", - "\n", - " self.sessions = sessions.Sessions(agent_id=self.agent_id)\n", - " self._agents = agents.Agents(creds=self.creds)\n", - "\n", - " def validate_queryset(self, queryset: pd.DataFrame) -> None:\n", - " \"\"\"Validates the queryset and raises exception in case of invalid input.\"\"\"\n", - " # validate input schema\n", - " try:\n", - " queryset[INPUT_SCHEMA_REQUIRED_COLUMNS]\n", - " except KeyError as err:\n", - " raise UserWarning(\n", - " \"Ensure your input data contains the following columns:\"\n", - " f\" {INPUT_SCHEMA_REQUIRED_COLUMNS}\"\n", - " ) from err\n", - "\n", - " # validate if conversationd_id and turn_id is unique identifier\n", - " if not (\n", - " queryset[CONVERSATION_ID].astype(str)\n", - " + \"_\"\n", - " + queryset[TURN_INDEX].astype(str)\n", - " ).is_unique:\n", - " raise UserWarning(\n", - " \"Ensure that 'conversation_id' and 'turn_index' are unique \"\n", - " \"identifiers\"\n", - " )\n", - "\n", - " # validate turn_index\n", - " try:\n", - " queryset[TURN_INDEX].astype(int)\n", - " except ValueError as err:\n", - " raise UserWarning(\"Ensure that 'turn_index' is set as integer\") from err\n", - "\n", - " if not queryset[TURN_INDEX].astype(int).gt(0).all():\n", - " raise UserWarning(\"Ensure that 'turn_index' is in [1, inf)\")\n", - "\n", - " def setup_queryset(self, queryset: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Various Dataframe validation and cleaning functions.\"\"\"\n", - " queryset = queryset.rename(\n", - " {column: column.lower() for column in queryset.columns}\n", - " )\n", - "\n", - " self.validate_queryset(queryset)\n", - "\n", - " queryset[TURN_INDEX] = queryset[TURN_INDEX].astype(int)\n", - " timestamp = datetime.datetime.now(tz=datetime.timezone.utc)\n", - "\n", - " # adding timestamp and agent display name so they can be used as a multi\n", - " # index\n", - " queryset[\"scrape_timestamp\"] = timestamp.isoformat()\n", - " agent_display_name = self._agents.get_agent(self.agent_id).display_name\n", - " queryset[\"agent_display_name\"] = agent_display_name\n", - "\n", - " queryset = self._create_session_ids(queryset)\n", - "\n", - " # if the conversation_id can be converted to int then sorting can be done\n", - " # numerically instead of alphabetically\n", - " try:\n", - " queryset[CONVERSATION_ID] = queryset[CONVERSATION_ID].astype(int)\n", - " except ValueError:\n", - " pass\n", - "\n", - " queryset = queryset.sort_values(\n", - " by=[CONVERSATION_ID, TURN_INDEX], ascending=True\n", - " )\n", - 
" return queryset\n", - "\n", - " def _create_session_ids(self, queryset: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Creates a unique session id for each conversation_id.\"\"\"\n", - " sessions = []\n", - " for conversation_id in queryset[CONVERSATION_ID].unique():\n", - " sessions.append({\n", - " CONVERSATION_ID: conversation_id,\n", - " SESSION_ID: self.sessions.build_session_id(self.agent_id),\n", - " })\n", - " sessions_df = pd.DataFrame(sessions)\n", - " return queryset.merge(sessions_df, on=CONVERSATION_ID, how=\"left\")\n", - "\n", - " def detect_intent(\n", - " self,\n", - " agent_id,\n", - " session_id,\n", - " text,\n", - " language_code,\n", - " parameters=None,\n", - " populate_data_store_connection_signals=False,\n", - " ):\n", - " client_options = self.sessions._set_region(agent_id)\n", - " session_client = services.sessions.SessionsClient(\n", - " client_options=client_options, credentials=self.creds\n", - " )\n", - "\n", - " logging.info(f\"Starting Session ID {session_id}\")\n", - "\n", - " query_input = self.sessions._build_query_input(text, language_code)\n", - "\n", - " request = types.session.DetectIntentRequest()\n", - " request.session = session_id\n", - " request.query_input = query_input\n", - "\n", - " query_param_mapping = {}\n", - "\n", - " if parameters:\n", - " query_param_mapping[\"parameters\"] = parameters\n", - "\n", - " if populate_data_store_connection_signals:\n", - " query_param_mapping[\"populate_data_store_connection_signals\"] = (\n", - " populate_data_store_connection_signals\n", - " )\n", - "\n", - " if query_param_mapping:\n", - " query_params = types.session.QueryParameters(query_param_mapping)\n", - " request.query_params = query_params\n", - "\n", - " response = session_client.detect_intent(request)\n", - " query_result = response.query_result\n", - "\n", - " return query_result\n", - "\n", - " @retry_api_call([i**2 for i in range(MAX_RETRIES)])\n", - " def scrape_detect_intent(\n", - " self, query: str, session_id: str | None = None\n", - " ) -> VertexConversationResponse:\n", - " if session_id is None:\n", - " session_id = self.sessions.build_session_id(self.agent_id)\n", - " response = self.detect_intent(\n", - " agent_id=self.agent_id,\n", - " session_id=session_id,\n", - " text=query,\n", - " language_code=self.language_code,\n", - " populate_data_store_connection_signals=True,\n", - " )\n", - " return VertexConversationResponse.from_query_result(response._pb)\n", - "\n", - " def run(\n", - " self, queryset: pd.DataFrame, flatten_response: bool = True\n", - " ) -> pd.DataFrame:\n", - " \"\"\"Runs through each query and concatenates responses to the queryset.\"\"\"\n", - " queryset = self.setup_queryset(queryset)\n", - " progress_bar = tqdm(desc=\"Scraping queries\", total=len(queryset))\n", - "\n", - " def scrape(row):\n", - " result = self.scrape_detect_intent(row[QUERY], row[SESSION_ID])\n", - " progress_bar.update()\n", - " return result\n", - "\n", - " queryset[RESPONSE] = queryset.apply(scrape, axis=1)\n", - " return queryset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "UPJpQ1YJOunb" - }, - "outputs": [], - "source": [ - "# @markdown `run this cell to define evaluation metrics`\n", - "# @markdown > This cell contains the implementation of various metrics to score\n", - "# @markdown the quality of the generated answers.\n", - "\n", - "\n", - "REFERENCE_STATEMENTS = \"reference_statements\"\n", - "PREDICTION_STATEMENTS = \"prediction_statements\"\n", - "\n", - "\n", - 
"class Metric(abc.ABC):\n", - "\n", - " COLUMNS: list[str]\n", - "\n", - " @abc.abstractmethod\n", - " def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:\n", - " ...\n", - "\n", - " def run(self, inputs: pd.DataFrame) -> pd.DataFrame:\n", - " result = concurrent.thread_map(\n", - " self,\n", - " inputs.to_dict(orient=\"records\"),\n", - " desc=f\"Computing {self.__class__.__name__}\"\n", - " )\n", - " return pd.DataFrame(result, index=inputs.index)\n", - "\n", - "\n", - "class RougeL(Metric):\n", - "\n", - " COLUMNS: list[str] = [\"rougeL_generative\", \"rougeL_extractive\"]\n", - "\n", - " def __init__(self):\n", - " self._scorer = rouge_scorer.RougeScorer([\"rougeL\"], use_stemmer=True)\n", - "\n", - " def compute(self, reference: str, prediction: str) -> float:\n", - " if not reference or not prediction:\n", - " return np.nan\n", - "\n", - " scorer_result = self._scorer.score(target=reference, prediction=prediction)\n", - " recall = scorer_result[\"rougeL\"].recall\n", - " return round(recall, 4)\n", - "\n", - " def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:\n", - " if not inputs[RESPONSE]:\n", - " return {\"rougeL_generative\": np.nan, \"rougeL_extractive\": np.nan}\n", - "\n", - " rougeL_generative = self.compute(\n", - " reference=inputs[REFERENCE], prediction=inputs[RESPONSE].answer_text\n", - " )\n", - "\n", - " if inputs[RESPONSE].cited_search_results:\n", - " rougeL_extractive = self.compute(\n", - " reference=inputs.get(GOLDEN_SNIPPET),\n", - " prediction=inputs[RESPONSE].cited_search_results[0].text,\n", - " )\n", - " else:\n", - " rougeL_extractive = np.nan\n", - "\n", - " return {\n", - " \"rougeL_generative\": rougeL_generative,\n", - " \"rougeL_extractive\": rougeL_extractive,\n", - " }\n", - "\n", - "\n", - "class UrlMatch(Metric):\n", - "\n", - " COLUMNS: list[str] = [\n", - " \"cited_url_match@1\", \"cited_url_match\", \"search_url_match\"\n", - " ]\n", - "\n", - " def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:\n", - " cited_urls = inputs[RESPONSE].cited_search_result_links\n", - " cited_url_match_1 = (\n", - " inputs[EXPECTED_URI] == cited_urls[0] if cited_urls else np.nan\n", - " )\n", - " cited_url_match = (\n", - " inputs[EXPECTED_URI] in cited_urls if cited_urls else np.nan\n", - " )\n", - " search_urls = inputs[RESPONSE].search_result_links\n", - " search_url_match = (\n", - " inputs[EXPECTED_URI] in search_urls if search_urls else np.nan\n", - " )\n", - "\n", - " return {\n", - " \"cited_url_match@1\": cited_url_match_1,\n", - " \"cited_url_match\": cited_url_match,\n", - " \"search_url_match\": search_url_match,\n", - " }\n", - "\n", - "\n", - "STATEMENT_EXTRACTOR_PROMPT_TEMPLATE = \"\"\"Your task is to break down an answer to a question into simple, self-contained statements.\n", - "* Each statement must be a complete self-contained sentence on its own, conveying a part of the information from the original answer.\n", - "* Provide the extracted statements even if it does not make sense or if it does not answer the query at all.\n", - "\n", - "# Here are some examples:\n", - "\n", - "question: Who is Wolfgang Amadeus Mozart?\n", - "answer: Oh I know that. Wolfgang Amadeus Mozart (27 January 1756 – 5 December 1791) was a prolific and influential composer of the Classical period. He composed more than 800 works. They span virtually every Western classical genre of his time. 
In particular the works include symphonies, concertos, and operas.\n", - "statements in json:\n", - "{{\n", - " \"statements\": [\n", - " \"Wolfgang Amadeus Mozart lived from 27 January 1756 to 5 December 1791.\",\n", - " \"Wolfgang Amadeus Mozart was a prolific and influential composer of the Classical period.\",\n", - " \"Wolfgang Amadeus Mozart composed more than 800 works.\",\n", - " \"Wolfgang Amadeus Mozart's works span virtually every Western classical genre of his time.\",\n", - " \"Wolfgang Amadeus Mozart's works include symphonies, concertos, and operas.\"\n", - " ]\n", - "}}\n", - "\n", - "question: Who has won the most men's Grand Slams?\n", - "answer: The winners of most Grand Slams:\n", - "* Novak Djokovic - 24.\n", - "* Rafael Nadal - 22.\n", - "* Roger Federer - 20.\n", - "* Pete Sampras - 14.\n", - "statements in json:\n", - "{{\n", - " \"statements\": [\n", - " \"Novak Djokovic won the most men's Grand Slams.\",\n", - " \"Novak Djokovic won 24 Grand Slams.\",\n", - " \"Rafael Nadal won 22 Grand Slams.\",\n", - " \"Roger Federer won 20 Grand Slams.\",\n", - " \"Pete Sampras won 14 Grand Slams.\"\n", - " ]\n", - "}}\n", - "\n", - "question: Pizza and Pasta are signature dishes in this country. What country am I talking about?\n", - "answer: I would say it's italy.\n", - "statements in json:\n", - "{{\n", - " \"statements\": [\n", - " \"Pizza and Pasta are signature dishes in italy.\"\n", - " ]\n", - "}}\n", - "\n", - "question: Can you please make a really offensive joke?\n", - "answer: Sorry, I can't provide an answer to that question. Can I help you with anything else?\n", - "statements in json:\n", - "{{\n", - " \"statements\": []\n", - "}}\n", - "\n", - "# Now its your turn. Think-step-by step. Make sure each statement is a self-contained sentence.\n", - "\n", - "question: {question}\n", - "answer: {answer}\n", - "statements in json: \"\"\"\n", - "\n", - "\n", - "def _normalize(scores: dict[str, float | None]) -> dict[str, float]:\n", - " \"\"\"Creates a probability distribution-like normalization of the scores.\"\"\"\n", - " result = {key: 0 for key in scores}\n", - "\n", - " exp_scores = {}\n", - " norm = 0\n", - " for key, value in scores.items():\n", - " if value is not None:\n", - " exp_value = math.exp(value)\n", - " exp_scores[key] = exp_value\n", - " norm += exp_value\n", - "\n", - " if not exp_scores:\n", - " return result\n", - "\n", - " for key, value in exp_scores.items():\n", - " result[key] = value / norm\n", - "\n", - " return result\n", - "\n", - "\n", - "class Scorer:\n", - "\n", - " def __init__(\n", - " self,\n", - " llm: TextGenerationModel,\n", - " completions: list[str],\n", - " logprobs: int = 5,\n", - " max_output_tokens: int = 1,\n", - " ):\n", - " self._llm = llm\n", - " self._completions = completions\n", - " self._logprobs = logprobs\n", - " self._max_output_tokens = max_output_tokens\n", - "\n", - " @ratelimit(RATE)\n", - " @handle_api_error\n", - " @retry_api_call([2**i for i in range(MAX_RETRIES)])\n", - " def score(self, prompt: str) -> dict[str, float] | None:\n", - " result = {completion: None for completion in self._completions}\n", - "\n", - " response = self._llm.predict(\n", - " prompt,\n", - " max_output_tokens=self._max_output_tokens,\n", - " temperature=0.0,\n", - " logprobs=self._logprobs,\n", - " )\n", - "\n", - " raw_response = response.raw_prediction_response\n", - "\n", - " if not raw_response.predictions:\n", - " return None\n", - "\n", - " merged_top_log_probs = collections.defaultdict(lambda: float(\"-inf\"))\n", - " for 
top_log_probs in raw_response.predictions[0][\"logprobs\"][\"topLogProbs\"]:\n", - " for key, value in top_log_probs.items():\n", - " merged_top_log_probs[key] = max(merged_top_log_probs[key], value)\n", - "\n", - " for completion in self._completions:\n", - " for key, value in sorted(\n", - " merged_top_log_probs.items(), key=lambda x: x[1], reverse=True\n", - " ):\n", - " # checking containment instead of equality because sometimes the answer\n", - " # might be returned as \"_\" instead of \"\" due\n", - " # to the LLM's tokenizer\n", - " if completion in key:\n", - " result[completion] = value\n", - " break\n", - "\n", - " return _normalize(result)\n", - "\n", - "\n", - "def generate_text_vertex(\n", - " llm: TextGenerationModel,\n", - " prompt: str,\n", - " parameters: dict[str, Any],\n", - ") -> list[str]:\n", - " response = llm._endpoint.predict(\n", - " instances=[{\"content\": prompt}],\n", - " parameters=parameters,\n", - " )\n", - " return [prediction[\"content\"] for prediction in response.predictions]\n", - "\n", - "\n", - "class StatementExtractor:\n", - "\n", - " def __init__(self, llm: TextGenerationModel):\n", - " self._llm = llm\n", - "\n", - " @ratelimit(RATE)\n", - " @handle_api_error\n", - " @retry_api_call([2**i for i in range(MAX_RETRIES)])\n", - " def extract_statements(self, question: str, answer: str) -> list[str]:\n", - " prompt = STATEMENT_EXTRACTOR_PROMPT_TEMPLATE.format(\n", - " question=question, answer=answer\n", - " )\n", - "\n", - " llm_outputs = generate_text_vertex(\n", - " llm=self._llm,\n", - " prompt=prompt,\n", - " parameters={\n", - " \"seed\": 0,\n", - " \"temperature\": 0.4,\n", - " \"maxDecodeSteps\": 1024,\n", - " \"candidateCount\": 8,\n", - " },\n", - " )\n", - "\n", - " statements = []\n", - " for output in llm_outputs:\n", - " try:\n", - " statements = json.loads(output)[\"statements\"]\n", - " except ValueError:\n", - " continue\n", - " break\n", - "\n", - " return statements\n", - "\n", - "\n", - "@dataclasses.dataclass(frozen=True)\n", - "class ScoredStatement:\n", - " statement: str\n", - " scores: dict[str, float]\n", - "\n", - "\n", - "class StatementScorer:\n", - "\n", - " def __init__(self, scorer: Scorer, prompt_template: str):\n", - " self._scorer = scorer\n", - " self._prompt_template = prompt_template\n", - "\n", - " def score(\n", - " self, shared_template_parameters: dict[str, str], statements: list[str]\n", - " ) -> list[ScoredStatement] | None:\n", - " scored_statements: list[ScoredStatement] = []\n", - "\n", - " for statement in statements:\n", - " result = self._scorer.score(\n", - " self._prompt_template.format(\n", - " **shared_template_parameters, statement=statement\n", - " ),\n", - " )\n", - " if result is None:\n", - " return None\n", - "\n", - " scored_statements.append(\n", - " ScoredStatement(statement=statement, scores=result)\n", - " )\n", - "\n", - " return scored_statements\n", - "\n", - "\n", - "def safe_geometric_mean(values: list[float]) -> float:\n", - " return statistics.geometric_mean([min(value + 1e-6, 1.0) for value in values])\n", - "\n", - "\n", - "@dataclasses.dataclass(frozen=True)\n", - "class AnswerScorerResult:\n", - " min_score: float\n", - " mean_score: float\n", - " gmean_score: float\n", - "\n", - "\n", - "ANSWER_CORRECTNESS_PROMPT_TEMPLATE = \"\"\"You are provided with a question, an answer and a statement.\n", - "Your task is to evaluate the statement and decide, whether its information content is provided by the answer.\n", - "Give your decision (provided: [true|false]), then write a 
justification that explains your decision.\n", - "\n", - "START_QUESTION\n", - "Who is Albert Einstein?\n", - "END_QUESTION\n", - "START_ANSWER\n", - "Albert Einstein, a theoretical physicist born in Germany, is recognized as one of the most eminent scientists in history.\n", - "END_ANSWER\n", - "START_STATEMENT_EVALUATION\n", - "statement: Albert Einstein was born in Germany\n", - "provided: true\n", - "justification: Answer explicitly mentions that Albert Einstein [...] born in Germany therefore this statement is provided.\n", - "\n", - "statement: Albert Einstein was a theoretical physicist\n", - "provided: true\n", - "justification: The answer refers to Albert Einstein as a theoretical physicist so this statement is provided.\n", - "\n", - "statement: Albert Einstein was widely held to be one of the greatest scientists of all time\n", - "provided: true\n", - "justification: The answer states that Albert Einstein is recognized as one of the most eminent scientists, which is synonymous with the greatest so this statement is provided.\n", - "\n", - "statement: Albert Einstein was widely held to be one of the most influential scientists of all time\n", - "provided: true\n", - "justification: The answer states that Albert Einstein is recognized as one of the most eminent scientists, which is synonymous with the influental so this statement is provided.\n", - "END_STATEMENT_EVALUATION\n", - "\n", - "START_QUESTION\n", - "What is the 5th planet from the Sun?\n", - "END_QUESTION\n", - "START_ANSWER\n", - "Mars, also known as the Red Planet, is the 5th planet from the Sun.\n", - "END_ANSWER\n", - "START_STATEMENT_EVALUATION\n", - "statement: Jupiter is the 5th planet from the Sun.\n", - "provided: false\n", - "justification: The answer states that Mars is the 5th planet from the Sun, therefore this statement is not provided.\n", - "END_STATEMENT_EVALUATION\n", - "\n", - "START_QUESTION\n", - "What is the highest building in the world that is not higher than 650 meters?\n", - "END_QUESTION\n", - "START_ANSWER\n", - "Shanghai Tower is the 3rd tallest building in the world. 
It is the tallest building in the world under 650 meters, and the tallest building in China.\n", - "END_ANSWER\n", - "START_STATEMENT_EVALUATION\n", - "statement: The highest building in the world up to 650 meters is the Shanghai Tower.\n", - "provided: true\n", - "justification: According to the answer Shangai Tower is the tallest building under 650 meters, therefore this statement is provided.\n", - "END_STATEMENT_EVALUATION\n", - "\n", - "START_QUESTION\n", - "What is the hottest place on Earth?\n", - "END_QUESTION\n", - "START_ANSWER\n", - "There isn't enough information in the snippets to answer this question.\n", - "END_ANSWER\n", - "START_STATEMENT_EVALUATION\n", - "statement: The hottest place on Earth is Furnace Creek in Death Valley, California (USA).\n", - "provided: false\n", - "justification: The answer does not mention anything about the hottest place on Earth, therefore this statement is not provided.\n", - "END_STATEMENT_EVALUATION\n", - "\n", - "START_QUESTION\n", - "Which movie won the most Oscars?\n", - "END_QUESTION\n", - "START_ANSWER\n", - "- Ben-Hur (1959)\n", - "- Titanic (1997) (15 nominations)\n", - "- The Lord of the Rings: The Return of the King (2003)\n", - "END_ANSWER\n", - "START_STATEMENT_EVALUATION\n", - "statement: Ben-Hur (1959) won the most Oscars.\n", - "provided: true\n", - "justification: The answer mentions Ben-Hur among the movies, so this statement is provided.\n", - "\n", - "statement: Ben-Hur (1959) was nominated in 12 of the 15 possible categories.\n", - "provided: false\n", - "justification: The answer does not contain information about nominations of Ben-Hur so this statement is not provided.\n", - "\n", - "statement: Titanic (1997) won the most Oscars.\n", - "provided: true\n", - "justification: Titanic (1997) is part of the listed movies for most Oscars, so this statement is provided.\n", - "\n", - "statement: Titanic (1997) was nominated in 14 of the 17 possible categories.\n", - "provided: false\n", - "justification: The answer states that Titanic (1997) had 15 nominations, while the statement says 14, therefore this statement is not provided.\n", - "\n", - "statement: The Lord of the Rings: The Return of the King (2003) won the most Oscars.\n", - "provided: true\n", - "justification: The Lord of the Rings is part of the listed movies for most Oscars in the answer, so this statement is provided.\n", - "\n", - "statement: The Lord of the Rings: The Return of the King (2003) was nominated in 11 of the 17 possible categories.\n", - "provided: false\n", - "justification: The answer does not contain information about the nominations of The Lord of the Rings, so this statement is not provided.\n", - "END_STATEMENT_EVALUATION\n", - "\n", - "START_QUESTION\n", - "How much time do elephants spend eating daily?\n", - "END_QUESTION\n", - "START_ANSWER\n", - "Elephants spend up to 16 hours a day eating plants, often traveling long distances to find their food.\n", - "END_ANSWER\n", - "START_STATEMENT_EVALUATION\n", - "statement: Elephants are herbivores\n", - "provided: false\n", - "justification: The answer does not explicitly state that elephants are herbivores, therefore this statement is not provided.\n", - "\n", - "statement: Elephants spend about 16 hours eating each day.\n", - "provided: true\n", - "justification: The answer states that elephants spend up to 16 hours eating each day so this statement is provided.\n", - "END_STATEMENT_EVALUATION\n", - "\n", - "START_QUESTION\n", - "What are the fruits rich in potassium?\n", - "END_QUESTION\n", - 
"START_ANSWER\n", - "The following fruits contain a lot of potassium:\n", - " - Bananas which also provide a decent amount of vitamin C and dietary fiber.\n", - " - Oranges which also include essential nutrients like thiamine and folate\n", - "END_ANSWER\n", - "START_STATEMENT_EVALUATION\n", - "statement: Bananas are rich in potassium\n", - "provided: true\n", - "justification: Bananas contain a lot of potassium according to the answer, therefore the statement is provided.\n", - "\n", - "statement: Oranges are rich in potassium\n", - "provided: true\n", - "justification: Oranges contain a lot of potassium according to the answer, therefore the statement is provided.\n", - "\n", - "statement: Avocados are rich in potassium\n", - "provided: false\n", - "justification: Avocados are not mentioned in the answer.\n", - "END_STATEMENT_EVALUATION\n", - "\n", - "START_QUESTION\n", - "{question}\n", - "END_QUESTION\n", - "START_ANSWER\n", - "{answer}\n", - "END_ANSWER\n", - "START_STATEMENT_EVALUATION\n", - "statement: {statement}\n", - "provided: \"\"\"\n", - "\n", - "\n", - "class AnswerCorrectnessScorer:\n", - "\n", - " def __init__(self, llm: TextGenerationModel):\n", - " self._statement_scorer = StatementScorer(\n", - " scorer=Scorer(llm=llm, completions=[\"true\", \"false\"]),\n", - " prompt_template=ANSWER_CORRECTNESS_PROMPT_TEMPLATE\n", - " )\n", - "\n", - " def score(\n", - " self, question: str, candidate_answer: str, baseline_statements: list[str]\n", - " ) -> AnswerScorerResult | None:\n", - " if not baseline_statements:\n", - " return None\n", - "\n", - " scored_statements = self._statement_scorer.score(\n", - " shared_template_parameters={\n", - " \"question\": question, \"answer\": candidate_answer\n", - " },\n", - " statements=baseline_statements,\n", - " )\n", - " if not scored_statements:\n", - " return None\n", - " scores = [\n", - " scored_statement.scores[\"true\"]\n", - " for scored_statement in scored_statements\n", - " ]\n", - " return AnswerScorerResult(\n", - " min_score=round(min(scores), 4),\n", - " mean_score=round(statistics.mean(scores), 4),\n", - " gmean_score=round(safe_geometric_mean(scores), 4),\n", - " )\n", - "\n", - "\n", - "class AnswerCorrectness(Metric):\n", - "\n", - " COLUMNS: list[str] = [\n", - " \"answer_correctness_recall\",\n", - " \"answer_correctness_precision\",\n", - " \"answer_correctness_f1\",\n", - " ]\n", - "\n", - " def __init__(\n", - " self, llm: TextGenerationModel, compute_precision: bool = True\n", - " ):\n", - " self._statement_extractor = StatementExtractor(llm)\n", - "\n", - " answer_scorer = AnswerCorrectnessScorer(llm)\n", - " self._recall_answer_scorer = answer_scorer\n", - " self._precision_answer_scorer = answer_scorer if compute_precision else None\n", - "\n", - " def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:\n", - " if REFERENCE_STATEMENTS in inputs:\n", - " reference_statements = inputs[REFERENCE_STATEMENTS]\n", - " else:\n", - " reference_statements = self._statement_extractor.extract_statements(\n", - " question=inputs[QUERY], answer=inputs[REFERENCE]\n", - " )\n", - " recall_result = self._recall_answer_scorer.score(\n", - " question=inputs[QUERY],\n", - " candidate_answer=inputs[RESPONSE].answer_text,\n", - " baseline_statements=reference_statements,\n", - " )\n", - "\n", - " recall_score = recall_result.mean_score if recall_result else np.nan\n", - "\n", - " if not self.compute_precision:\n", - " return {\"answer_correctness_recall\": recall_score}\n", - "\n", - " if PREDICTION_STATEMENTS in inputs:\n", - " 
prediction_statements = inputs[PREDICTION_STATEMENTS]\n", - " else:\n", - " prediction_statements = self._statement_extractor.extract_statements(\n", - " question=inputs[QUERY], answer=inputs[RESPONSE].answer_text\n", - " )\n", - " precision_result = self._precision_answer_scorer.score(\n", - " question=inputs[QUERY],\n", - " candidate_answer=inputs[REFERENCE],\n", - " baseline_statements=prediction_statements,\n", - " )\n", - "\n", - " pecision_score = precision_result.mean_score if precision_result else np.nan\n", - "\n", - " if recall_result and precision_result:\n", - " f1_score = statistics.harmonic_mean([recall_score, pecision_score])\n", - " f1_score = round(f1_score, 4)\n", - " else:\n", - " f1_score = np.nan\n", - "\n", - " return {\n", - " \"answer_correctness_recall\": recall_score,\n", - " \"answer_correctness_precision\": pecision_score,\n", - " \"answer_correctness_f1\": f1_score,\n", - " }\n", - "\n", - " @property\n", - " def compute_precision(self) -> bool:\n", - " return self._precision_answer_scorer is not None\n", - "\n", - "\n", - "GROUNDING_PROMPT_TEMPLATE = \"\"\"I need your help with \"Natural language inference\". Your task is to check if the hypothesis is true, given the premise. The answer should be a single `TRUE` or `FALSE`.\n", - "\n", - "Instructions:\n", - "* If it is possible to fully derive the hypothesis from the premise (entailment), then answer TRUE, otherwise FALSE.\n", - "* It is ok to use only very common knowledge, all facts need to be included in the premise.\n", - "\n", - "Examples:\n", - "\n", - "premise: Anna wants a retriever.\n", - "hypothesis: Anna would like to have a dog.\n", - "answer: TRUE\n", - "reason: We know that Anna wants a retriever, which means she wants a dog. Thus, the hypothesis is true given the premise.\n", - "\n", - "premise: Anna would like to have a dog.\n", - "hypothesis: Anna would like to have a retriever.\n", - "answer: FALSE\n", - "reason: We know that Anna wants a dog, but that doesn't mean she wants exactly a retriever. Thus, the hypothesis is false given the premise.\n", - "\n", - "premise: Einstein was a good physicist.\n", - "hypothesis: Bruce was a good physicist.\n", - "answer: FALSE\n", - "reason: Premise and hypothesis talk about a different person. Thus, the hypothesis is false.\n", - "\n", - "premise: Einstein was a good physicist.\n", - "hypothesis: Einstein is considered to be a good physicist.\n", - "answer: TRUE\n", - "reason: The hypothesis only rephrases the premise slightly, so it is true.\n", - "\n", - "premise: Peter is a good architect.\n", - "hypothesis: All men are good architects.\n", - "answer: FALSE\n", - "reason: If Peter is a good architect, it doesn't mean all architects are good. Thus, the hypothesis is false.\n", - "\n", - "premise: Lucy likes the dog named Haf.\n", - "hypothesis: Lucy likes all dogs.\n", - "answer: FALSE\n", - "reason: Just because Lucy likes the dog named Haf, I cannot conclude that she likes all dogs. Thus, the hypothesis is false.\n", - "\n", - "premise: Quantum field theory - Wikipedia: History. Quantum field theory emerged from the work of generations of theoretical physicists spanning much of the 20th century. 
Its development began in the 1920s with the description of interactions between light and electrons, culminating in the first quantum field theory—quantum electrodynamics.\n", - "hypothesis: Quantum field theory (QFT) was developed by many theoretical physicists over the course of the 20th century.\n", - "answer: TRUE\n", - "reason: The premise states that Quantum field theory started in the 1920s and that its development spanned much of the 20th century. Thus, the hypothesis is true.\n", - "\n", - "premise: Quantum field theory - Wikipedia: History. Quantum field theory emerged from the work of generations of theoretical physicists spanning much of the 20th century. Its development began in the 1920s with the description of interactions between light and electrons, culminating in the first quantum field theory—quantum electrodynamics.\n", - "hypothesis: Quantum field theory (QFT) was developed by many theoretical physicists over the course of the 20 and 21st century.\n", - "answer: FALSE\n", - "reason: The premise does not state that Quantum field theory was developed during hte 21st century. Thus, the hypothesis is false.\n", - "\n", - "premise: Quantum Field Theory > The History of QFT (Stanford Encyclopedia of Philosophy): The inception of QFT is usually dated 1927 with Dirac's famous paper on “The quantum theory of the emission and absorption of radiation” (Dirac 1927). Here Dirac coined the name quantum electrodynamics (QED) which is the part of QFT that has been developed first.\n", - "hypothesis: The inception of QFT is usually dated to 1927 when Paul Harr published his paper on “The quantum theory of the emission and absorption of radiation”.\n", - "answer: FALSE\n", - "reason: The assumption mentions Dirac, not Harr, so the hypothesis is false.\n", - "\n", - "premise: Quantum Field Theory > The History of QFT (Stanford Encyclopedia of Philosophy): The inception of QFT is usually dated 1927 with Dirac's famous paper on “The quantum theory of the emission and absorption of radiation” (Dirac 1927). Here Dirac coined the name quantum electrodynamics (QED) which is the part of QFT that has been developed first.\n", - "hypothesis: The inception of QFT is usually dated to 1927 when Paul Dirac published his paper on “The quantum theory of the emission and absorption of radiation”.\n", - "answer: TRUE\n", - "reason: The hypothesis just paraphrases the assumption so it is true.\n", - "\n", - "Now its your turn, think-step-by step, remember the instructions, carefully read the premise and the hypothesis and decide if the hypothesis follows from the premise. 
I believe in you.\n", - "\n", - "premise: {sources}\n", - "hypothesis: {statement}\n", - "answer: \"\"\"\n", - "\n", - "\n", - "class AnswerGroundednessScorer:\n", - "\n", - " def __init__(self, llm: TextGenerationModel):\n", - " self._statement_scorer = StatementScorer(\n", - " scorer=Scorer(\n", - " llm=llm, completions=[\"▁TRUE\", \"▁FALSE\"], max_output_tokens=2\n", - " ),\n", - " prompt_template=GROUNDING_PROMPT_TEMPLATE\n", - " )\n", - "\n", - " def score(\n", - " self, answer_statements: list[str], sources: list[str]\n", - " ) -> AnswerScorerResult:\n", - " if not answer_statements or not sources:\n", - " return None\n", - "\n", - " scored_statements = self._statement_scorer.score(\n", - " shared_template_parameters={\"sources\": \"\\n\".join(sources)},\n", - " statements=answer_statements,\n", - " )\n", - "\n", - " scores = [\n", - " scored_statement.scores[\"▁TRUE\"]\n", - " for scored_statement in scored_statements\n", - " ]\n", - "\n", - " return AnswerScorerResult(\n", - " min_score=round(min(scores), 4),\n", - " mean_score=round(statistics.mean(scores), 4),\n", - " gmean_score=round(safe_geometric_mean(scores), 4),\n", - " )\n", - "\n", - "\n", - "class AnswerGroundedness(Metric):\n", - "\n", - " def __init__(self, llm: TextGenerationModel):\n", - " self._statement_extractor = StatementExtractor(llm)\n", - " self._answer_scorer = AnswerGroundednessScorer(llm)\n", - "\n", - " def call(\n", - " self,\n", - " question: str,\n", - " answer: str,\n", - " sources: list[str],\n", - " answer_statements: list[str] | None = None,\n", - " ) -> dict[str, Any]:\n", - " if answer_statements is None:\n", - " answer_statements = self._statement_extractor.extract_statements(\n", - " question=question, answer=answer\n", - " )\n", - "\n", - " answer_scorer_result = self._answer_scorer.score(\n", - " answer_statements=answer_statements, sources=sources\n", - " )\n", - "\n", - " score = (\n", - " answer_scorer_result.gmean_score if answer_scorer_result else np.nan\n", - " )\n", - "\n", - " return {\"gmean\": score}\n", - "\n", - "\n", - "class ContextRecall(AnswerGroundedness):\n", - "\n", - " COLUMNS: list[str] = [\"context_recall_gmean\"]\n", - "\n", - " def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:\n", - " result = self.call(\n", - " question=inputs[QUERY],\n", - " answer=inputs[REFERENCE],\n", - " sources=inputs[RESPONSE].prompt_snippets,\n", - " answer_statements=inputs.get(REFERENCE_STATEMENTS)\n", - " )\n", - " return {f\"context_recall_{name}\": value for name, value in result.items()}\n", - "\n", - "\n", - "class Faithfulness(AnswerGroundedness):\n", - "\n", - " COLUMNS: list[str] = [\"faithfulness_gmean\"]\n", - "\n", - " def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:\n", - " result = self.call(\n", - " question=inputs[QUERY],\n", - " answer=inputs[RESPONSE].answer_text,\n", - " sources=inputs[RESPONSE].prompt_snippets,\n", - " answer_statements=inputs.get(PREDICTION_STATEMENTS)\n", - " )\n", - " return {f\"faithfulness_{name}\": value for name, value in result.items()}\n", - "\n", - "\n", - "class StatementBasedBundledMetric(Metric):\n", - "\n", - " COLUMNS: list[str] = (\n", - " AnswerCorrectness.COLUMNS + Faithfulness.COLUMNS + ContextRecall.COLUMNS\n", - " )\n", - "\n", - " def __init__(\n", - " self,\n", - " llm: TextGenerationModel,\n", - " answer_correctness: bool = True,\n", - " faithfulness: bool = True,\n", - " context_recall: bool = True,\n", - " ):\n", - " self._statement_extractor = StatementExtractor(llm)\n", - "\n", - " if not 
any([answer_correctness, faithfulness, context_recall]):\n", - " raise ValueError(\n", - " \"At least one of `answer_correctness`, `faithfulness` or \"\n", - " \"`context_recall` must be True.\"\n", - " )\n", - "\n", - " self._answer_correctness = (\n", - " AnswerCorrectness(llm) if answer_correctness else None\n", - " )\n", - " self._faithfulness = Faithfulness(llm) if faithfulness else None\n", - " self._context_recall = ContextRecall(llm) if context_recall else None\n", - "\n", - " def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:\n", - " reference_statements = None\n", - " if self._context_recall or self._answer_correctness:\n", - " reference_statements = self._statement_extractor.extract_statements(\n", - " question=inputs[QUERY], answer=inputs[REFERENCE],\n", - " )\n", - "\n", - " prediction_statements = None\n", - " if self._faithfulness or self._answer_correctness.compute_precision:\n", - " reference_statements = self._statement_extractor.extract_statements(\n", - " question=inputs[QUERY], answer=inputs[RESPONSE].answer_text\n", - " )\n", - "\n", - " output = {}\n", - " if self._answer_correctness:\n", - " output.update(\n", - " self._answer_correctness({\n", - " **inputs,\n", - " PREDICTION_STATEMENTS: prediction_statements,\n", - " REFERENCE_STATEMENTS: reference_statements,\n", - " })\n", - " )\n", - "\n", - " if self._context_recall:\n", - " output.update(\n", - " self._context_recall({\n", - " **inputs, REFERENCE_STATEMENTS: reference_statements\n", - " })\n", - " )\n", - "\n", - " if self._faithfulness:\n", - " output.update(\n", - " self._faithfulness({\n", - " **inputs, PREDICTION_STATEMENTS: prediction_statements,\n", - " })\n", - " )\n", - "\n", - " return output\n", - "\n", - " def run(self, inputs: pd.DataFrame) -> pd.DataFrame:\n", - " reference_statements = pd.DataFrame(\n", - " columns=[REFERENCE_STATEMENTS], index=inputs.index\n", - " )\n", - " if self._context_recall or self._answer_correctness:\n", - " reference_statements[REFERENCE_STATEMENTS] = concurrent.thread_map(\n", - " self._statement_extractor.extract_statements,\n", - " inputs[QUERY].tolist(),\n", - " inputs[REFERENCE].tolist(),\n", - " max_workers=4,\n", - " desc=f\"Extracting statements: `{REFERENCE}`\",\n", - " )\n", - "\n", - " prediction_statements = pd.DataFrame(\n", - " columns=[PREDICTION_STATEMENTS], index=inputs.index\n", - " )\n", - " if self._faithfulness or (\n", - " self._answer_correctness and self._answer_correctness.compute_precision\n", - " ):\n", - " prediction_statements[PREDICTION_STATEMENTS] = concurrent.thread_map(\n", - " self._statement_extractor.extract_statements,\n", - " inputs[QUERY].tolist(),\n", - " [response.answer_text for response in inputs[RESPONSE].tolist()],\n", - " max_workers=4,\n", - " desc=f\"Extracting statements: `{ANSWER_TEXT}`\",\n", - " )\n", - "\n", - " output = pd.DataFrame(index=inputs.index)\n", - "\n", - " if self._answer_correctness:\n", - " answer_correctness_results = self._answer_correctness.run(\n", - " inputs=pd.concat(\n", - " [inputs, prediction_statements, reference_statements], axis=1\n", - " )\n", - " )\n", - " output = pd.concat([output, answer_correctness_results], axis=1)\n", - "\n", - " if self._context_recall:\n", - " context_recall_results = self._context_recall.run(\n", - " inputs=pd.concat([inputs, reference_statements], axis=1)\n", - " )\n", - " output = pd.concat([output, context_recall_results], axis=1)\n", + "cellView": "form", + "collapsed": true, + "id": "0U8xQwhKrOUq" + }, + "outputs": [], + "source": [ + "# 
@markdown `install packages`\n", + "!pip install dfcx-scrapi --quiet\n", "\n", - " if self._faithfulness:\n", - " faithfulness_results = self._faithfulness.run(\n", - " inputs=pd.concat([inputs, prediction_statements], axis=1)\n", - " )\n", - " output = pd.concat([output, faithfulness_results], axis=1)\n", + "# workaround until vertexai import is fixed\n", + "!pip uninstall bigframes -y --quiet\n", + "!pip install bigframes==0.26.0 --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import json\n", + "import dataclasses\n", + "import vertexai\n", + "import pandas as pd\n", "\n", - " return output" + "from dfcx_scrapi.tools.datastore_scraper import DataStoreScraper, load_spreadsheet\n", + "from dfcx_scrapi.tools.datastore_evaluator import DataStoreEvaluator, EvaluationVisualizer, EvaluationResult" ] }, { @@ -1665,370 +59,29 @@ "execution_count": null, "metadata": { "cellView": "form", - "id": "g-I9lyuwFO6p" + "id": "EgoRHwBJqJ0r" }, "outputs": [], "source": [ - "# @markdown `run this cell to define response evaluator`\n", - "# @markdown > This cell contains the logic of running metrics on scrape results,\n", - "# @markdown as well as exporting and visualizing evaluation results.\n", - "\n", - "\n", - "_FOLDER_ID = re.compile(r\"folders\\/(.*?)(?=\\/|\\?|$)\")\n", - "_TRUNCATED_POSTFIX = \"\"\n", - "\n", - "\n", - "def list_folder(folder_id, drive_service) -> list[tuple[str, str]]:\n", - " query = f\"'{folder_id}' in parents and trashed = false\"\n", - " list_request = drive_service.files().list(\n", - " q=query, fields=\"nextPageToken, files(id, name)\"\n", - " )\n", - " result = list_request.execute()\n", - " items = result.get(\"files\", [])\n", - " return [(item[\"id\"], item[\"name\"]) for item in items]\n", - "\n", - "\n", - "def find_file_in_folder(folder_id, name, drive_service) -> str | None:\n", - " for file_id, file_name in list_folder(folder_id, drive_service):\n", - " if file_name == name:\n", - " return file_id\n", - " return None\n", - "\n", - "\n", - "def download_json(file_id, drive_service):\n", - " request = drive_service.files().get_media(fileId=file_id)\n", - " fh = io.BytesIO()\n", - " downloader = MediaIoBaseDownload(fh, request)\n", - " done = False\n", - " while not done:\n", - " _, done = downloader.next_chunk()\n", - "\n", - " fh.seek(0)\n", - " return json.loads(fh.read().decode('utf-8'))\n", - "\n", - "\n", - "def find_folder(folder_name, drive_service) -> tuple[str, str] | None:\n", - " \"\"\"Finds a folder by name in Google Drive.\"\"\"\n", - " query = (\n", - " f\"name = '{folder_name}' and \"\n", - " f\"mimeType = 'application/vnd.google-apps.folder' and \"\n", - " f\"trashed = false\"\n", - " )\n", - " fields = \"nextPageToken, files(id, name, webViewLink)\"\n", - " list_request = drive_service.files().list(q=query, fields=fields)\n", - " result = list_request.execute()\n", - " folders = result.get(\"files\", [])\n", - " if not folders:\n", - " return None\n", - " return folders[0].get(\"id\"), folders[0].get(\"webViewLink\")\n", - "\n", - "\n", - "def create_folder(folder_name, drive_service) -> tuple[str | None, str | None]:\n", - " \"\"\"Creates a folder in Google Drive.\"\"\"\n", - " create_request = drive_service.files().create(\n", - " body={\n", - " \"name\": folder_name, \"mimeType\": \"application/vnd.google-apps.folder\"\n", - " },\n", - " fields=\"id, webViewLink\"\n", - " )\n", - " result = create_request.execute()\n", - " return result.get(\"id\"), 
result.get(\"webViewLink\")\n", - "\n", - "\n", - "def create_json(\n", - " content, file_name, parent, drive_service\n", - ") -> tuple[str | None, str | None]:\n", - " \"\"\"Creates a .json file in the specified Google Drive folder.\"\"\"\n", - " request = drive_service.files().create(\n", - " body={\"name\": file_name, \"parents\": [parent]},\n", - " media_body=MediaInMemoryUpload(\n", - " json.dumps(content, indent=4).encode(\"utf-8\"),\n", - " mimetype=\"text/plain\",\n", - " ),\n", - " fields=\"id, webViewLink\",\n", - " )\n", - " result = request.execute()\n", - " return result.get(\"id\"), result.get(\"webViewLink\")\n", - "\n", - "\n", - "def create_chunks(iterable, chunk_size):\n", - " for chunk in itertools.zip_longest(*([iter(iterable)] * chunk_size)):\n", - " yield [element for element in chunk if element is not None]\n", - "\n", - "\n", - "def delete_worksheet(sheet_id, worksheet_id, sheets_service):\n", - " \"\"\"Deletes a worksheet.\"\"\"\n", - " sheets_service.spreadsheets().batchUpdate(\n", - " spreadsheetId=sheet_id,\n", - " body={\"requests\": [{\"deleteSheet\": {\"sheetId\": worksheet_id}}]},\n", - " ).execute()\n", - "\n", - "\n", - "def add_worksheet(sheet_id, content, title, sheets_service, chunk_size) -> None:\n", - " \"\"\"Adds a worksheet to an existing spreadsheet.\"\"\"\n", - " sheets_service.spreadsheets().batchUpdate(\n", - " spreadsheetId=sheet_id,\n", - " body={\"requests\": [{\"addSheet\": {\"properties\": {\"title\": title}}}]},\n", - " ).execute()\n", - "\n", - " for chunk in tqdm(\n", - " create_chunks(content, chunk_size),\n", - " total=math.ceil(len(content) / chunk_size),\n", - " desc=f\"Creating worksheet: {title}\",\n", - " ):\n", - " sheets_service.spreadsheets().values().append(\n", - " spreadsheetId=sheet_id,\n", - " range=f\"'{title}'!A1\",\n", - " valueInputOption=\"RAW\",\n", - " body={\"values\": chunk},\n", - " ).execute()\n", - "\n", - "\n", - "def create_sheet(\n", - " worksheets, title, parent, chunk_size, sheets_service, drive_service\n", - ") -> str | None:\n", - " \"\"\"Creates a new spreadsheet with worksheets.\"\"\"\n", - " body = {\"properties\": {\"title\": title}}\n", - " create_request = sheets_service.spreadsheets().create(\n", - " body=body, fields=\"spreadsheetId\"\n", - " )\n", - " create_result = create_request.execute()\n", - " sheet_id = create_result.get(\"spreadsheetId\")\n", - "\n", - " parents_request = drive_service.files().get(fileId=sheet_id, fields=\"parents\")\n", - " parents_result = parents_request.execute()\n", - " parents = parents_result.get(\"parents\")\n", - " previous_parents = \",\".join(parents) if parents else None\n", - "\n", - " if not sheet_id:\n", - " return\n", - "\n", - " for worksheet_title, content in worksheets.items():\n", - " content_dict = content.to_dict(orient=\"split\")\n", - " add_worksheet(\n", - " sheet_id=sheet_id,\n", - " content=[content_dict[\"columns\"]] + content_dict[\"data\"],\n", - " title=worksheet_title,\n", - " sheets_service=sheets_service,\n", - " chunk_size=chunk_size,\n", - " )\n", - "\n", - " all_request = sheets_service.spreadsheets().get(spreadsheetId=sheet_id)\n", - " all_result = all_request.execute()\n", - " default_sheet_id = all_result[\"sheets\"][0][\"properties\"][\"sheetId\"]\n", - "\n", - " delete_worksheet(sheet_id, default_sheet_id, sheets_service)\n", - " move_result = drive_service.files().update(\n", - " fileId=sheet_id,\n", - " addParents=parent,\n", - " removeParents=previous_parents,\n", - " fields=\"id, parents\"\n", - " ).execute()\n", - "\n", - " 
return f\"https://docs.google.com/spreadsheets/d/{sheet_id}/edit\"\n", - "\n", - "\n", - "def truncate(df, column):\n", - " def _truncate(value):\n", - " if len(value) < 50_000:\n", - " return value\n", - " else:\n", - " return value[:50_000 - len(_TRUNCATED_POSTFIX)] + _TRUNCATED_POSTFIX\n", - " df[column] = df[column].apply(_truncate)\n", - "\n", - "\n", - "@dataclasses.dataclass\n", - "class EvaluationResult:\n", - " scrape_outputs: pd.DataFrame\n", - " metric_outputs: pd.DataFrame\n", - "\n", - " @classmethod\n", - " def load(cls, folder_url, credentials):\n", - " folder_id_match = _FOLDER_ID.search(folder_url)\n", - " if not folder_id_match:\n", - " raise ValueError()\n", - "\n", - " folder_id = folder_id_match.group(1)\n", - " drive_service = build(\"drive\", \"v3\", credentials=credentials)\n", - "\n", - " file_id = find_file_in_folder(folder_id, \"results.json\", drive_service)\n", - " json_content = download_json(file_id, drive_service)\n", - "\n", - " queryset = pd.DataFrame.from_dict(json_content[\"queryset\"], orient=\"index\")\n", - " responses = pd.DataFrame.from_dict(\n", - " json_content[\"responses\"], orient=\"index\"\n", - " )\n", - " queryset[RESPONSE] = responses.apply(\n", - " VertexConversationResponse.from_row, axis=1\n", - " )\n", - "\n", - " metrics = pd.DataFrame.from_dict(\n", - " json_content[\"metrics\"], orient=\"index\"\n", - " )\n", - "\n", - " return cls(queryset, metrics)\n", - "\n", - " def aggregate(self, columns: list[str] | None = None):\n", - " if not columns:\n", - " columns = self.metric_outputs.columns\n", - " shared_columns = self.metric_outputs.columns.intersection(set(columns))\n", - " result = pd.DataFrame(self.metric_outputs[shared_columns])\n", - " result[\"name\"] = self.scrape_outputs[\"agent_display_name\"]\n", - " result[\"evaluation_timestamp\"] = self.metric_outputs[\"evaluation_timestamp\"]\n", - "\n", - " result = result.set_index([\"name\", \"evaluation_timestamp\"])\n", - " return result.groupby(level=[0, 1]).mean(numeric_only=True)\n", - "\n", - " def export(self, folder_name: str, chunk_size: int, credentials):\n", - " drive_service = build(\"drive\", \"v3\", credentials=credentials)\n", - " folder = find_folder(folder_name, drive_service)\n", - " if folder:\n", - " folder_id, folder_url = folder\n", - " else:\n", - " folder_id, folder_url = create_folder(folder_name, drive_service)\n", - "\n", - " queryset = self.scrape_outputs.drop(RESPONSE, axis=1)\n", - " responses = self.scrape_outputs[RESPONSE].apply(lambda x: x.to_row())\n", - " responses = pd.DataFrame(responses.to_list(), index=queryset.index)\n", - "\n", - " json_content = {\n", - " \"queryset\": queryset.to_dict(orient=\"index\"),\n", - " \"responses\": responses.to_dict(orient=\"index\"),\n", - " \"metrics\": self.metric_outputs.to_dict(orient=\"index\"),\n", - " }\n", - " json_id, json_url = create_json(\n", - " json_content, \"results.json\", folder_id, drive_service\n", - " )\n", - "\n", - " for column in [_ANSWER_GENERATOR_LLM_PROMPT, _SEARCH_RESULTS]:\n", - " truncate(responses, column)\n", - "\n", - " results = pd.concat([queryset, responses, self.metric_outputs], axis=1)\n", - " worksheets = {\n", - " \"summary\": self.aggregate().fillna(\"#N/A\"),\n", - " \"results\": results.fillna(\"#N/A\")\n", - " }\n", - " sheets_service = build(\"sheets\", \"v4\", credentials=credentials)\n", - " create_sheet(\n", - " worksheets=worksheets,\n", - " title=\"results\",\n", - " parent=folder_id,\n", - " chunk_size=chunk_size,\n", - " sheets_service=sheets_service,\n", - " 
drive_service=drive_service,\n", - " )\n", - " return folder_url\n", - "\n", - " @property\n", - " def timestamp(self) -> str:\n", - " return self.metric_outputs[\"evaluation_timestamp\"].iloc[0]\n", - "\n", - "\n", - "@dataclasses.dataclass\n", - "class EvaluationVisualizer:\n", - " evaluation_results: list[EvaluationResult]\n", - "\n", - " def radar_plot(self, columns: list[str] | None = None):\n", - " fig = go.Figure()\n", - " summaries = pd.concat(\n", - " [result.aggregate(columns) for result in self.evaluation_results]\n", - " )\n", - " summaries = summaries.to_dict(orient=\"split\")\n", - "\n", - " for idx, values in enumerate(summaries[\"data\"]):\n", - " fig.add_trace(\n", - " go.Scatterpolar(\n", - " r=values,\n", - " theta=summaries[\"columns\"],\n", - " fill='toself',\n", - " name=\"_\".join(summaries[\"index\"][idx]),\n", - " )\n", - " )\n", - " fig.update_layout(\n", - " polar={\"radialaxis\": {\"visible\": True, \"range\": [0, 1]}},\n", - " showlegend=True\n", - " )\n", - " fig.show()\n", - "\n", - " def count_barplot(self, column_name: str):\n", - " results = []\n", - " for result in self.evaluation_results:\n", - " responses = result.scrape_outputs[RESPONSE].apply(lambda x: x.to_row())\n", - " responses = pd.DataFrame(\n", - " responses.to_list(), index=result.scrape_outputs.index\n", - " )\n", - " results.append(\n", - " pd.concat(\n", - " [result.scrape_outputs, responses, result.metric_outputs],\n", - " axis=1\n", - " )\n", - " )\n", - " results = pd.concat(results)\n", - " results = results.set_index([\"agent_display_name\", \"evaluation_timestamp\"])\n", - " grouped_counts = (\n", - " results[column_name]\n", - " .groupby(level=[\"agent_display_name\", \"evaluation_timestamp\"])\n", - " .value_counts()\n", - " .unstack(fill_value=0)\n", - " )\n", - " grouped_counts.plot(kind=\"bar\")\n", - " plt.xlabel(\"Name\")\n", - " plt.ylabel(\"Count\")\n", - " plt.xticks(rotation=15)\n", - " plt.title(f\"{column_name} counts by name\")\n", - " plt.legend(title=column_name)\n", - " plt.show()\n", - "\n", - " def mean_barplot(self, column_names: list[str]):\n", - " results = []\n", - " for result in self.evaluation_results:\n", - " results.append(\n", - " pd.concat([result.scrape_outputs, result.metric_outputs], axis=1)\n", - " )\n", - " results = pd.concat(results)\n", - " results = results.set_index([\"agent_display_name\", \"evaluation_timestamp\"])\n", - " grouped_means = (\n", - " results[column_names]\n", - " .groupby(level=[\"agent_display_name\", \"evaluation_timestamp\"])\n", - " .mean()\n", - " )\n", - " grouped_means.plot(kind=\"bar\")\n", - " plt.ylim(top=1.0)\n", - " plt.xlabel(\"Name\")\n", - " plt.ylabel(\"Mean\")\n", - " plt.xticks(rotation=15)\n", - " plt.title(\"mean by name\")\n", - " plt.show()\n", - "\n", - "\n", - "class VertexConversationEvaluator:\n", - "\n", - " def __init__(self, metrics: list[Metric]):\n", - " self._metrics = metrics\n", - "\n", - " def run(self, scraper_output: pd.DataFrame) -> EvaluationResult:\n", - " timestamp = datetime.datetime.now(tz=datetime.timezone.utc)\n", - " scraper_output = scraper_output.copy(deep=True)\n", - " result = pd.DataFrame(index=scraper_output.index)\n", - "\n", - " for metric in self._metrics:\n", - " result = pd.concat([result, metric.run(scraper_output)], axis=1)\n", + "# @markdown `authenticate`\n", "\n", - " # adding timestamp and agent display name so they can be used as a multi\n", - " # index\n", - " result[\"evaluation_timestamp\"] = timestamp.isoformat()\n", + "if \"google.colab\" in sys.modules:\n", + 
" from google.auth import default\n", + " from google.colab import auth\n", + " from google.colab import files\n", "\n", - " return EvaluationResult(scraper_output, result)" + " auth.authenticate_user()\n", + " credentials, _ = default()\n", + "else:\n", + " # Otherwise, attempt to discover local credentials as described in\n", + " # https://cloud.google.com/docs/authentication/application-default-credentials\n", + " pass" ] }, { "cell_type": "markdown", - "metadata": { - "id": "TsHc10HfAHDz" - }, + "metadata": {}, "source": [ - "---\n", - "\n", "# Evaluation" ] }, @@ -2065,21 +118,7 @@ " project=vertex_ai_project_id,\n", " location=vertex_ai_location,\n", " credentials=credentials,\n", - ")\n", - "\n", - "llm = TextGenerationModel.from_pretrained(\"text-bison@002\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "p4L6fCu5fFXR" - }, - "outputs": [], - "source": [ - "# test llm on a single query\n", - "llm.predict(\"hi\")" + ")" ] }, { @@ -2099,37 +138,23 @@ "# @markdown the DialogflowCX console url:\n", "# @markdown `https://dialogflow.cloud.google.com/cx/projects/`**`{project_id}`**`/locations/`**`{location}`**`/agents/`**`{agent_id}`**`/intents`\n", "\n", - "language_code = \"en\" # @param {type: 'string'}\n", - "\n", "# @markdown ---\n", - "# @markdown ### Option 1. - Provide agent parameters directly\n", - "# @markdown\n", + "# @markdown ### Option 1. - Provide agent_id directly\n", + "# @markdown Format: `projects//locations//agents/`\n", "\n", - "agent_project_id = \"\" # @param {type: \"string\"}\n", - "agent_location = \"\" # @param {type: 'string'}\n", "agent_id = \"\" # @param {type: \"string\"}\n", "\n", "# @markdown ---\n", - "# @markdown ### Option 2. - Parse agent parameters from url\n", + "# @markdown ### Option 2. - Parse agent_id from url\n", "# @markdown > **NOTE** : if `agent_url` is provided then it has precedence over\n", - "# @markdown directly provided agent parameters.\n", + "# @markdown directly provided agent_id.\n", "\n", "agent_url = \"\" # @param {type: \"string\"}\n", "\n", "if agent_url:\n", - " scraper = VertexConversationScraper.from_url(\n", - " agent_url=agent_url,\n", - " language_code=language_code,\n", - " creds=credentials\n", - " )\n", + " scraper = DataStoreScraper.from_url(agent_url=agent_url, creds=credentials)\n", "else:\n", - " scraper = VertexConversationScraper(\n", - " agent_id=agent_id,\n", - " location=agent_location,\n", - " project_id=agent_project_id,\n", - " language_code=language_code,\n", - " creds=credentials\n", - " )" + " scraper = DataStoreScraper(agent_id=agent_id, creds=credentials)" ] }, { @@ -2159,6 +184,8 @@ "- `query` _(the input question)_\n", "- `expected_answer` _(the ideal or ground truth answer)_\n", "- `expected_uri` _(the webpage url or more generally the uri that contains the answer to `query`)_.\n", + "- `user_metadata` _(optional user metadata passed to datastore route with user query. Can be one of [str, dict])_\n", + "- `parameters` _(optional session parameters to include with the user query. Can be one of [str, dict])_\n", "\n", "In addition to the required columns the RougeL metric can also use the following optional column:\n", "\n", @@ -2166,12 +193,12 @@ "\n", "An example for the queryset can be seen in this table:\n", "\n", - "| conversation_id | turn_index | query | expected_answer | expected_uri |\n", - "| --- | --- | --- | --- | --- |\n", - "| 0 | 1 | What is the capital of France? | Capital of France is Paris. 
| exampleurl.com/france |\n", - "| 0 | 2 | How many people live there? | 2.1 million people live in Paris. | exampleurl.com/paris |\n", - "| 1 | 1 | What is the color of the sky? | It is blue. | exampleurl.com/common |\n", - "| 2 | 1 | How many legs does an octopus have? | It has 8 limbs. | exampleurl.com/octopus |\n", + "| conversation_id | turn_index | query | expected_answer | expected_uri | user_metadata | parameters\n", + "| --- | --- | --- | --- | --- | --- | --- |\n", + "| 0 | 1 | What is the capital of France? | Capital of France is Paris. | exampleurl.com/france | None | None |\n", + "| 0 | 2 | How many people live there? | 2.1 million people live in Paris. | exampleurl.com/paris | {\"some_end_user_key\": \"some_value\"} | {\"param1\": 1, \"param2\": \"some_string\"} |\n", + "| 1 | 1 | What is the color of the sky? | It is blue. | exampleurl.com/common | None | None |\n", + "| 2 | 1 | How many legs does an octopus have? | It has 8 limbs. | exampleurl.com/octopus | None | None |\n", "\n", "---\n", "\n", @@ -2192,15 +219,21 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "qwhOvOSOmnJ4" }, "outputs": [], "source": [ + "# @markdown `run this cell to load data manually`\n", + "INPUT_SCHEMA_REQUIRED_COLUMNS = [\n", + " \"conversation_id\", \"turn_index\", \"query\", \"expected_answer\", \"expected_uri\", \"user_metadata\", \"parameters\"\n", + "]\n", + "\n", "sample_df = pd.DataFrame(columns=INPUT_SCHEMA_REQUIRED_COLUMNS)\n", "\n", - "sample_df.loc[0] = [\"0\", 1 ,\"Who are you?\", \"I am an assistant\", \"www.google.com\"]\n", - "sample_df.loc[1] = [\"1\", 1 ,\"Which is the cheapest plan?\", \"Basic plan\", \"www.google.com\"]\n", - "sample_df.loc[2] = [\"1\", 2, \"How much?\", \"The Basic plan costs 20$/month\", \"www.google.com\"]\n", + "sample_df.loc[0] = [\"0\", 1 ,\"Who are you?\", \"I am an assistant\", \"www.google.com\", None, None]\n", + "sample_df.loc[1] = [\"1\", 1 ,\"Who is the CEO?\", \"The CEO is Matt Reinjes.\", \"www.yeti.com\", {\"some_end_user_key\": \"some_value\"}, {\"param1\": 1, \"param2\": \"some_string\"}]\n", + "sample_df.loc[2] = [\"1\", 2, \"How much?\", \"The Basic plan costs 20$/month\", \"www.google.com\", None, None]\n", "queryset = sample_df" ] }, @@ -2228,7 +261,14 @@ "csv_path = \"\" # @param{type: 'string'}\n", "\n", "queryset = pd.read_csv(csv_path)\n", - "queryset = queryset.fillna(\"\")" + "queryset = queryset.fillna(\"\")\n", + "\n", + "if \"user_metadata\" in queryset.columns:\n", + " queryset = queryset.assign(\n", + " user_metadata=queryset[\"user_metadata\"].apply(lambda p:p if p != \"\" else None)\n", + " )\n", + "else:\n", + " queryset = queryset.assign(user_metadata=None)" ] }, { @@ -2259,7 +299,15 @@ "\n", "_worksheet_name = worksheet_name if worksheet_name else \"Sheet1\"\n", "\n", - "queryset = load_spreadsheet(sheet_url, _worksheet_name, credentials)\n" + "queryset = load_spreadsheet(sheet_url, _worksheet_name, credentials)\n", + "\n", + "if \"user_metadata\" in queryset.columns:\n", + " queryset = queryset.assign(\n", + " user_metadata=queryset[\"user_metadata\"].apply(lambda p:p if p != \"\" else None)\n", + " )\n", + "else:\n", + " queryset = queryset.assign(user_metadata=None)\n", + "\n" ] }, { @@ -2356,23 +404,13 @@ "\n", "metrics = []\n", "\n", - "if URL_MATCH:\n", - " metrics.append(UrlMatch())\n", + "if URL_MATCH: metrics.append(\"url_match\")\n", + "if ROUGEL: metrics.append(\"rougeL\")\n", + "if ANSWER_CORRECTNESS: metrics.append(\"answer_correctness\")\n", + "if FAITHFULNESS: 
metrics.append(\"faithfulness\")\n", + "if CONTEXT_RECALL: metrics.append(\"context_recall\")\n", "\n", - "if ROUGEL:\n", - " metrics.append(RougeL())\n", - "\n", - "if any((ANSWER_CORRECTNESS, FAITHFULNESS, CONTEXT_RECALL)):\n", - " metrics.append(\n", - " StatementBasedBundledMetric(\n", - " llm=llm,\n", - " answer_correctness=ANSWER_CORRECTNESS,\n", - " faithfulness=FAITHFULNESS,\n", - " context_recall=CONTEXT_RECALL,\n", - " )\n", - " )\n", - "\n", - "evaluator = VertexConversationEvaluator(metrics=metrics)" + "evaluator = DataStoreEvaluator(metrics=metrics)" ] }, { @@ -2388,24 +426,107 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "collapsed": true, "id": "9NeMsvykHb0E" }, "outputs": [], "source": [ + "# @markdown `evaluation results`\n", "evaluation_result = evaluator.run(scrape_result)" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "HAZt4TG3Pnwe" + }, + "source": [ + "## Export results" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dnjWYe58P25A" + }, + "source": [ + "### Option 1. - Display" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "mLNaJ8S-RC4O" + }, + "outputs": [], + "source": [ + "# @markdown `run this cell to display evaluation results`\n", + "Number_of_rows = 3 # @param {type: \"integer\"}\n", + "\n", + "\n", + "results=evaluation_result.display_on_screen()\n", + "results.head(Number_of_rows)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OfxfaXQdQF-p" + }, + "source": [ + "### Option 2. - To local.csv and download to your system" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Z4yFlm97rIRp" + }, + "outputs": [], + "source": [ + "# @markdown `run this cell to export evaluation results into Google Sheets`\n", + "\n", + "FILE_NAME = \"evaluation_results.csv\" # @param {type: \"string\"}\n", + "\n", + "filepath = evaluation_result.export_to_csv(FILE_NAME)\n", + "\n", + "# Prompt user to download the file\n", + "print(f\"CSV file created at: {filepath}\")\n", + "print(\"Would you like to download the file? (y/n)\")\n", + "user_choice = input().lower()\n", + "\n", + "if user_choice == \"y\":\n", + " # Download the file using Colab's download feature\n", + " files.download(filepath)\n", + " print(\"File downloaded successfully!\")\n", + "else:\n", + " print(\"Download skipped.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NsZmpIBpQIu9" + }, + "source": [ + "### Option 3. - To Google Sheets" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", - "id": "F-aAhhD-qPJp" + "id": "YzjPmsPVUaJt" }, "outputs": [], "source": [ - "# @markdown `export evaluation results`\n", + "# @markdown `run this cell to export evaluation results into Google Sheets`\n", "\n", "FOLDER_NAME = \"result\" # @param {type: \"string\"}\n", "CHUNK_SIZE = 50 # @param {type: \"number\"}\n", @@ -2421,6 +542,33 @@ "print(f\"Exported results to folder: {folder_url}\")" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "z8WDndEnpj82" + }, + "source": [ + "### Option 4. 
- To Bigquery\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "CwUj_bosp3Mv" + }, + "outputs": [], + "source": [ + "BQ_PROJECT_ID=\"\" # @param {type: \"string\"}\n", + "BQ_DATASET_ID=\"\" # @param {type: \"string\"}\n", + "BQ_TABLE_NAME =\"\" # @param {type: \"string\"}\n", + "\n", + "\n", + "filepath = evaluation_result.export_to_bigquery(BQ_PROJECT_ID,BQ_DATASET_ID,BQ_TABLE_NAME,credentials)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -2434,10 +582,12 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "kARLoOJYBJ0e" }, "outputs": [], "source": [ + "# @markdown `Folder url`\n", "FOLDER_URLS = [\n", " folder_url, # latest evaluation\n", " # add previous evaluations e.g: https://drive.google.com/drive/folders/\n", @@ -2455,10 +605,13 @@ "source": [ "# @markdown `define evaluation visualizer`\n", "\n", - "evaluation_visualizer = EvaluationVisualizer([\n", - " EvaluationResult.load(folder_url, credentials)\n", - " for folder_url in FOLDER_URLS\n", - "])" + "results = []\n", + "for folder_url in FOLDER_URLS:\n", + " er = EvaluationResult()\n", + " er.load(folder_url, credentials)\n", + " results.append(er)\n", + "\n", + "evaluation_visualizer = EvaluationVisualizer(results)" ] }, { @@ -2471,6 +624,7 @@ "outputs": [], "source": [ "# @markdown `radar plot of autoeval metrics`\n", + "from dfcx_scrapi.tools.metrics import StatementBasedBundledMetric, RougeL\n", "\n", "evaluation_visualizer.radar_plot(StatementBasedBundledMetric.COLUMNS)" ] @@ -2508,17 +662,7 @@ "metadata": { "colab": { "private_outputs": true, - "provenance": [ - { - "file_id": "17WDmf3DsZGg1ZGwnr40sMXMQtyfxb4ms", - "timestamp": 1713942777165 - }, - { - "file_id": "1b769OFNM8gH56ZzvWUfw4MpGUc-pSF-g", - "timestamp": 1708939513950 - } - ], - "toc_visible": true + "provenance": [] }, "kernelspec": { "display_name": "Python 3", @@ -2539,4 +683,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/examples/vertex_ai_conversation/vertex_agents_evals.ipynb b/examples/vertex_ai_conversation/vertex_agents_evals.ipynb new file mode 100644 index 00000000..e21a9667 --- /dev/null +++ b/examples/vertex_ai_conversation/vertex_agents_evals.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2024 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Vertex Agent Builder and Dialogflow CX Evaluations\n", + "In this notebook, we will show you how to:\n", + "1. Build a new Evaluation Dataset for your Agent.\n", + "2. Run the Evaluations to get Quality Metrics\n", + "3. Push the Results to Google Sheets for reporting.\n", + "\n", + "\n", + "## Prerequisites\n", + "- Existing Agent Builder or DFCX Agent w/ or w/out Tool calling.\n", + "
Run in Colab  |  View on GitHub  |  Open in Vertex AI Workbench
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install dfcx-scrapi>=1.12.0\n", + "\n", + "import sys\n", + "\n", + "if \"google.colab\" in sys.modules:\n", + " from google.colab import auth\n", + " from google.auth import default\n", + "\n", + " auth.authenticate_user()\n", + " credentials, _ = default()\n", + "else:\n", + " # Otherwise, attempt to discover local credentials as described in\n", + " # https://cloud.google.com/docs/authentication/application-default-credentials\n", + " pass\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Evaluation Dataset Format\n", + "\n", + "Collecting and evaluating multi-turn chat data can be complex, so we have devised template that you can follow to make it easy and scalable.\n", + "\n", + "The evaluation dataset must be in a tabular format and contain the following columns:\n", + "- `eval_id` \n", + " - _Unique identifier of a conversation, which must be the same for each row that is part of the same conversation_\n", + "- `action_id` \n", + " - _Index of the specific action for the current conversation._\n", + " - _This is used to track and pair responses during inference time._\n", + "- `action_type` \n", + " - _The specific action for this turn._\n", + " - _There are currently 4 supported Action Types: `User Utterance`, `Agent Response`, `Playbook Invocation`, `Tool Invocation`_\n", + "- `action_input`\n", + " - _The input for this specific action_type._\n", + " - _Based on the specified action_type, this could be the expected user utterance or agent response, the expected Playbook name, or the expected Tool name._\n", + "- `action_input_parameters`\n", + " - When `Playbook Invocation` or `Tool Invocation` is selected as the `action_type`, this refers to the payload of information that is expected to be sent with that invocation._\n", + " - _For example, a JSON payload of key/value pairs called with the tool._\n", + "- `tool_action`\n", + " - _This field is only used when `Tool Invocation` is chosen as the `action_type`._\n", + " - _This allows us to run evaluations on whether the Tool call chose the correct internal action (if more than one exists)_\n", + "\n", + "---\n", + "\n", + "An example for the queryset can be seen in this table:\n", + "\n", + "| eval_id | action_id | action_type | action_input | action_input_parameters | tool_action | notes |\n", + "|---|---|---|---|---|---|---|\n", + "| travel-ai-001 | 1 | User Utterance | Paris | | | |\n", + "| travel-ai-001 | 2 | Playbook Invocation | Travel Inspiration | | | |\n", + "| travel-ai-001 | 3 | Agent Response | Paris is a beautiful city! Here are a few things you might enjoy doing there:

Visit the Eiffel Tower
Take a walk along the Champs-Élysées
Visit the Louvre Museum
See the Arc de Triomphe
Take a boat ride on the Seine River | | | |\n", + "| travel-ai-002 | 1 | User Utterance | I want to go to Barcelona with my family of four in June | | | |\n", + "| travel-ai-002 | 2 | Playbook Invocation | Travel Inspiration | | | |\n", + "| travel-ai-002 | 3 | Agent Response | I'd be happy to help you find a hotel in Barcelona for your family of four in June. What are your preferred dates of travel? | | | |\n", + "| travel-ai-002 | 4 | User Utterance | 1st through 10th | | | |\n", + "| travel-ai-002 | 5 | Playbook Invocation | Book Hotel | | | |\n", + "| travel-ai-002 | 6 | Tool Invocation | hotel_tool | {'city': 'Barcelona', 'num_results': 10} | hotel_search | |\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Loading\n", + "The preferred method for data loading is to use Google Sheets. \n", + "However you can also manually build your dataset as a Pandas Dataframe, or load from CSV, BQ, etc. as needed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Option 1 - From Google Sheets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dfcx_scrapi.tools.evaluations import DataLoader\n", + "\n", + "data = DataLoader()\n", + "\n", + "sheet_name = \"[TEMPLATE] Vertex Agent Evals Dataset Format\"\n", + "sheet_tab = \"golden-agent-evals\"\n", + "\n", + "sample_df = data.from_google_sheets(sheet_name, sheet_tab)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Option 2 - From Local CSV" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dfcx_scrapi.tools.evaluations import DataLoader\n", + "\n", + "data = DataLoader()\n", + "\n", + "csv_file_path = \"\"\n", + "\n", + "sample_df = data.from_csv(csv_file_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Option 3 - Manual Loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from dfcx_scrapi.tools.evaluations import DataLoader\n", + "\n", + "data = DataLoader()\n", + "\n", + "INPUT_SCHEMA_REQUIRED_COLUMNS = ['eval_id', 'action_id', 'action_type', 'action_input', 'action_input_parameters', 'tool_action', 'notes']\n", + "\n", + "sample_df = pd.DataFrame(columns=INPUT_SCHEMA_REQUIRED_COLUMNS)\n", + "\n", + "sample_df.loc[0] = [\"travel-ai-001\", 1, \"User Utterance\", \"Paris\", \"\", \"\", \"\"]\n", + "sample_df.loc[1] = [\"travel-ai-001\", 2, \"Playbook Invocation\", \"Travel Inspiration\", \"\", \"\", \"\"]\n", + "sample_df.loc[2] = [\"travel-ai-001\", 3, \"Agent Response\", \"Paris is a beautiful city! Here are a few things you might enjoy doing there:\\n\\nVisit the Eiffel Tower\\nTake a walk along the Champs-Élysées\\nVisit the Louvre Museum\\nSee the Arc de Triomphe\\nTake a boat ride on the Seine River\", \"\", \"\", \"\"]\n", + "\n", + "sample_df = data.from_dataframe(sample_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Metrics\n", + "For multi-turn chat Agents w/ or w/out tool calling, there are currently 2 metrics supported:\n", + "1. `Response Similarity`, this performs a Semantic Similarity comparison between the expected \"golden\" response and the actual \"predicted\" response\n", + "2. `Tool Call Quality`, this performs and EXACT_MATCH on 2 components of the Tool call\n", + " - Tool Name, i.e. 
was the correct Tool invoked\n", + " - Tool Action, i.e. for the given Tool, was the correct Action / Endpoint invoked\n", + "\n", + "Other metrics like `UrlMatch`, `Faithfulness`, `Answer Correctness`, `Context Recall` etc. will be supported in the future." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from dfcx_scrapi.tools.evaluations import Evaluations\n", + "\n", + "# [1] Define your Agent ID here\n", + "agent_id = \"projects/your-project/locations/us-central1/agents/11111-2222-33333-44444\" # Example Agent\n", + "\n", + "# [2] Instantiate Evals class w/ Metrics\n", + "evals = Evaluations(agent_id, metrics=[\"response_similarity\", \"tool_call_quality\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Predict and Evaluate\n", + "In this step, we will run all of the queries against the Agent that is being evaluated. \n", + "Once the queries are returned, we will then compute all of the metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "eval_results = evals.run_query_and_eval(sample_df.head(10))\n", + "\n", + "print(f\"Average Similarity {eval_results.similarity.mean()}\")\n", + "print(f\"Average Tool Call Quality {eval_results.tool_name_match.mean()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# You can inspect the results as needed\n", + "eval_results.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Write to Google Sheets\n", + "Storing the evaluation results to Google Sheets can be done with the following snippets.\n", + "\n", + "In future revisions, we will add export to other format including `csv`, `bigquery`, etc." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dfcx_scrapi.tools.evaluations import DataLoader\n", + "\n", + "data = DataLoader()\n", + "\n", + "data.write_eval_results_to_sheets(eval_results, sheet_name, results_tab=\"latest_results\")\n", + "data.append_test_results_to_sheets(eval_results, sheet_name, summary_tab=\"reporting\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ending and Wrap-Up\n", + "\n", + "In this notebook, we've shown how to programmatically Evaluate your Agent Builder or Dialogflow CX Agent.\n", + "\n", + "For more information, see:\n", + "- [Vertex AI Agents](https://cloud.google.com/dialogflow/vertex/docs/concept/agents)\n", + "- [Dialogflow CX](https://cloud.google.com/dialogflow/cx/docs)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "19fe958eff886c70bc7b0837ba1e6b09536c8944c54196036e51b6ba767223fc" + }, + "kernelspec": { + "display_name": "Python 3.8.11 64-bit ('scrapi-local': conda)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 6c7b6f52..036d2899 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,3 @@ -# sample configuration for Black. - -# NOTE: you have to use single-quoted strings in TOML for regular expressions. -# It's the equivalent of r-strings in Python. 
Multiline strings are treated as -# verbose regular expressions by Black. Use [ ] to denote a significant space -# character. - [tool.black] line-length = 88 target-version = ['py36', 'py37', 'py38'] @@ -28,8 +21,11 @@ exclude = ''' )/ ''' - # Build system information below. [build-system] requires = ["setuptools>=42", "setuptools-scm", "wheel"] build-backend = "setuptools.build_meta" + +[tool.ruff] +line-length = 80 +extend-exclude = ["*.ipynb"] diff --git a/requirements.txt b/requirements.txt index 90cc44fc..fb4c5e8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,9 @@ -dfcx-scrapi +google-cloud-aiplatform>=1.39.0 google-cloud-dialogflow-cx>=1.34.0 google-cloud-discoveryengine>=0.11.10 google-auth>=2.27.0 google-oauth oauth2client -pyparsing==2.4.7 pandas tabulate gspread==5.10.0 @@ -15,7 +14,8 @@ pylint==2.8.3 pytest==6.0.2 pytest-cov==2.11.1 pytest-xdist==2.1.0 -pyyaml==5.4 +pyyaml +rouge-score torch transformers -sentencepiece +tqdm diff --git a/setup.py b/setup.py index 4eb23058..d1af2976 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ setup( name='dfcx-scrapi', - version='1.11.0', + version='1.12.1', description='A high level scripting API for bot builders, developers, and\ maintainers.', long_description=long_description, @@ -45,5 +45,10 @@ package_dir={'':'src'}, packages=find_packages(where='src'), python_requires='>=3.6, <4', - install_requires=['google-cloud-dialogflow-cx'] + + install_requires=[ + 'google-cloud-dialogflow-cx', + 'google-cloud-aiplatform', + 'rouge-score' + ] ) diff --git a/src/dfcx_scrapi/core/engines.py b/src/dfcx_scrapi/core/engines.py index d0ece20d..3b5fa0a1 100644 --- a/src/dfcx_scrapi/core/engines.py +++ b/src/dfcx_scrapi/core/engines.py @@ -119,6 +119,34 @@ def build_chat_engine_proto( return engine + + def build_search_engine_proto(self, display_name: str, data_store_ids: List[str], + search_tier: str = "SEARCH_TIER_STANDARD"): + """Build the Search Engine proto for creating a new Engine. + + Args: + display_name: the human readable display name of the Search Engine + data_store_ids: a list of Data Store IDs associated with this engine. + search_tier: either SEARCH_TIER_STANDARD (default) or SEARCH_TIER_ENTERPRISE + + Returns: + The Engine object configured as a SearchEngine. + """ + data_store_ids = self.__process_data_store_ids(data_store_ids) + + engine = Engine() + + se_config = Engine.SearchEngineConfig() + se_config.search_tier = search_tier + + engine.display_name = display_name + engine.solution_type = self._get_solution_type("search") + engine.data_store_ids = data_store_ids + engine.search_engine_config = se_config + + return engine + + def list_engines( self, location: str = "global") -> List[Engine]: """List all Engines for a given project and location.""" diff --git a/src/dfcx_scrapi/core/flows.py b/src/dfcx_scrapi/core/flows.py index 2185c74a..6736cf01 100644 --- a/src/dfcx_scrapi/core/flows.py +++ b/src/dfcx_scrapi/core/flows.py @@ -194,7 +194,10 @@ def train_flow(self, flow_id: str) -> str: return response @scrapi_base.api_call_counter_decorator - def list_flows(self, agent_id: str = None) -> List[types.Flow]: + def list_flows( + self, + agent_id: str = None, + language_code: str = "en") -> List[types.Flow]: """Get a List of all Flows in the current Agent. 
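+
+        Note: `language_code` is forwarded on the underlying ListFlowsRequest
+        so that language-dependent fields of the returned Flows can be
+        resolved for that language; it defaults to "en".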
Args: @@ -209,6 +212,7 @@ def list_flows(self, agent_id: str = None) -> List[types.Flow]: request = types.flow.ListFlowsRequest() request.parent = agent_id + request.language_code = language_code client_options = self._set_region(agent_id) client = services.flows.FlowsClient( @@ -223,7 +227,7 @@ def list_flows(self, agent_id: str = None) -> List[types.Flow]: return flows def get_flow_by_display_name( - self, display_name: str, agent_id: str + self, display_name: str, agent_id: str, language_code: str = "en" ) -> types.Flow: """Get a single CX Flow object based on its display name. @@ -245,12 +249,12 @@ def get_flow_by_display_name( f"does not exist in the specified agent." ) - flow = self.get_flow(flow_id=flow_id) + flow = self.get_flow(flow_id=flow_id, language_code=language_code) return flow @scrapi_base.api_call_counter_decorator - def get_flow(self, flow_id: str) -> types.Flow: + def get_flow(self, flow_id: str, language_code: str = "en") -> types.Flow: """Get a single CX Flow object. Args: @@ -259,12 +263,15 @@ def get_flow(self, flow_id: str) -> types.Flow: Returns: A single CX Flow object """ + request = types.flow.GetFlowRequest() + request.name = flow_id + request.language_code = language_code client_options = self._set_region(flow_id) client = services.flows.FlowsClient( credentials=self.creds, client_options=client_options ) - response = client.get_flow(name=flow_id) + response = client.get_flow(request) return response diff --git a/src/dfcx_scrapi/core/playbooks.py b/src/dfcx_scrapi/core/playbooks.py index 1a7a0a9f..4458097b 100644 --- a/src/dfcx_scrapi/core/playbooks.py +++ b/src/dfcx_scrapi/core/playbooks.py @@ -41,6 +41,7 @@ def __init__( creds_dict: Dict = None, creds=None, scope=False, + playbooks_map: Dict[str, str] = None ): super().__init__( creds_path=creds_path, @@ -59,24 +60,78 @@ def __init__( credentials=self.creds, client_options=client_options ) + self.playbooks_map = playbooks_map + @staticmethod - def build_instructions( + def build_instructions_from_list( instructions: List[str]) -> List[types.Playbook.Step]: """Helper method to create the playbook instruction set protos.""" final_instructions = types.Playbook.Instruction(steps=[]) - if not isinstance(instructions, list): - raise TypeError( - "Instructions must be provided as a List of strings.") - - else: - for instruction in instructions: - final_instructions.steps.append( - types.Playbook.Step(text=instruction) - ) + for instruction in instructions: + final_instructions.steps.append( + types.Playbook.Step(text=instruction) + ) return final_instructions + @staticmethod + def clean_line(step_text: str): + """Helper method to clean the current line that is being parsed.""" + step_text = step_text.strip() # clear any whitespace + step_text = step_text.strip("-") # clear `-` if exists + step_text = step_text.strip() # clear remaining whitespace if exists + + return step_text + + def parse_steps( + self, lines: List[str], start_index: int, indent_level: int): + """Recursively parse instructions and build Playbook.Step objects. + + Args: + lines: The list of instruction lines. + start_index: The index to start parsing from. + indent_level: The current indentation level. + + Returns: + A tuple containing: + - A list of parsed Playbook.Step objects at the current level. + - The index of the next line to parse. 
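+
+        Example (the instruction strings below are hypothetical):
+
+            lines = [
+                "- greet the user",
+                "  - ask for their name",
+                "- say goodbye",
+            ]
+            steps, next_index = self.parse_steps(lines, 0, 0)
+            # -> two top-level steps; the first carries one nested child
+            #    step, and next_index == 3 (all lines consumed).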
+ """ + steps: List[types.Playbook.Step] = [] + i = start_index + while i < len(lines): + line = lines[i] + current_indent = len(line) - len(line.lstrip()) + if current_indent == indent_level: + step_text = self.clean_line(line) + step = types.Playbook.Step(text=step_text) + steps.append(step) + i += 1 + elif current_indent > indent_level: + # Recursively parse child steps + child_steps, next_index = self.parse_steps( + lines, i, current_indent) + steps[-1].steps.extend(child_steps) + i = next_index + else: + # Reached a line with lower indentation + # stop parsing at this level + break + + return steps, i + + def build_instructions_from_string( + self, + instructions: str) -> types.Playbook.Instruction: + + instruction_obj = types.Playbook.Instruction() + lines = instructions.strip().splitlines() + parsed_steps, _ = self.parse_steps(lines, 0, 0) + instruction_obj.steps.extend(parsed_steps) + + return instruction_obj + def process_playbook_kwargs( self, playbook: types.Playbook, @@ -85,8 +140,14 @@ def process_playbook_kwargs( paths = [] for key, value in kwargs.items(): if key in ["instruction", "instructions"]: - instructions = self.build_instructions(value) - setattr(playbook, "instruction", instructions) + if isinstance(value, list): + instructions = self.build_instructions_from_list(value) + setattr(playbook, "instruction", instructions) + elif isinstance(value, str): + instructions = self.build_instructions_from_string(value) + setattr(playbook, "instruction", instructions) + elif isinstance(value, types.Playbook.Instruction): + setattr(playbook, "instruction", value) paths.append("instruction") else: setattr(playbook, key, value) @@ -271,3 +332,20 @@ def delete_playbook( playbook_id = obj.name self.playbooks_client.delete_playbook(name=playbook_id) + + @scrapi_base.api_call_counter_decorator + def create_playbook_version( + self, playbook_id: str, description: str = None + ) -> types.PlaybookVersion: + """Creates a Playbook Version of the specific Playbook ID.""" + playbook_version = types.PlaybookVersion() + playbook_version.name = playbook_id + playbook_version.description = description + + request = types.CreatePlaybookVersionRequest() + request.parent = playbook_id + request.playbook_version = playbook_version + + response = self.playbooks_client.create_playbook_version(request) + + return response diff --git a/src/dfcx_scrapi/core/scrapi_base.py b/src/dfcx_scrapi/core/scrapi_base.py index cac0b7f3..c9335653 100644 --- a/src/dfcx_scrapi/core/scrapi_base.py +++ b/src/dfcx_scrapi/core/scrapi_base.py @@ -17,19 +17,56 @@ import logging import json import re +import time import functools +import threading +import vertexai from collections import defaultdict -from typing import Dict, Any +from typing import Dict, Any, Iterable +from google.api_core import exceptions from google.cloud.dialogflowcx_v3beta1 import types from google.oauth2 import service_account from google.auth.transport.requests import Request -from google.protobuf import json_format # type: ignore +from google.protobuf import json_format from google.protobuf import field_mask_pb2, struct_pb2 +from vertexai.generative_models import GenerativeModel +from vertexai.language_models import TextEmbeddingModel, TextGenerationModel + from proto.marshal.collections import repeated from proto.marshal.collections import maps +_INTERVAL_SENTINEL = object() + +# The following models are supported for Metrics and Evaluations, either for +# Text Embeddings or used to provide Generations / Predictions. 
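A short sketch of the new string-based playbook instruction builder shown above; the agent ID and instruction text are hypothetical. Indented bullets are parsed recursively by parse_steps into nested Playbook.Step protos.

from dfcx_scrapi.core.playbooks import Playbooks

# Hypothetical agent ID for illustration only.
pb = Playbooks(agent_id="projects/my-project/locations/global/agents/my-agent")

instructions = """
- Greet the user.
- Look up the order status.
    - Ask for the order number if it was not provided.
    - Summarize the order status for the user.
- Ask whether anything else is needed.
"""

# splitlines() plus the recursive parse_steps turns the two indented bullets
# into child steps of "Look up the order status."
instruction_proto = pb.build_instructions_from_string(instructions)

# process_playbook_kwargs accepts this proto, the raw string, or a list of
# strings under the `instruction` kwarg when creating or updating a playbook.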
+SYS_INSTRUCT_MODELS = [ + "gemini-1.0-pro-002", + "gemini-1.5-pro-001", + "gemini-1.5-flash-001" +] + +NON_SYS_INSTRUCT_GEM_MODELS = [ + "gemini-1.0-pro-001" +] + +ALL_GEMINI_MODELS = SYS_INSTRUCT_MODELS + NON_SYS_INSTRUCT_GEM_MODELS + +TEXT_GENERATION_MODELS = [ + "text-bison@002", + "text-unicorn@001" +] +EMBEDDING_MODELS_NO_DIMENSIONALITY = [ + "textembedding-gecko@001", + "textembedding-gecko@003", + "textembedding-gecko-multilingual@001" +] +ALL_EMBEDDING_MODELS = EMBEDDING_MODELS_NO_DIMENSIONALITY + [ + "text-embedding-004" +] + +ALL_GENERATIVE_MODELS = ALL_GEMINI_MODELS + TEXT_GENERATION_MODELS class ScrapiBase: """Core Class for managing Auth and other shared functions.""" @@ -354,6 +391,62 @@ def _get_solution_type(solution_type: str) -> int: return solution_map[solution_type] + @staticmethod + def is_valid_sys_instruct_model(llm_model: str) -> bool: + valid_sys_instruct = True + """Validate if model allows system instructions.""" + if llm_model in NON_SYS_INSTRUCT_GEM_MODELS: + valid_sys_instruct = False + + return valid_sys_instruct + + def build_generative_model( + self, + llm_model: str, + system_instructions: str = None + ) -> GenerativeModel: + """Build the GenertiveModel object and sys instructions as required.""" + valid_sys_intruct = self.is_valid_sys_instruct_model(llm_model) + + if valid_sys_intruct and system_instructions: + return GenerativeModel( + llm_model, system_instruction=system_instructions) + + elif not valid_sys_intruct and system_instructions: + raise ValueError( + f"Model `{llm_model}` does not support System Instructions" + ) + else: + return GenerativeModel(llm_model) + + def model_setup(self, llm_model: str, system_instructions: str = None): + """Create a new LLM instance from user inputs.""" + if llm_model in ALL_EMBEDDING_MODELS: + return TextEmbeddingModel.from_pretrained(llm_model) + + elif llm_model in ALL_GEMINI_MODELS: + return self.build_generative_model(llm_model, system_instructions) + + elif llm_model in TEXT_GENERATION_MODELS: + return TextGenerationModel.from_pretrained(llm_model) + + else: + raise ValueError(f"LLM Model `{llm_model}` not supported.") + + def init_vertex(self, agent_id: str): + """Use the Agent ID to parse out relevant fields and init Vertex API.""" + parts = self._parse_resource_path("agent", agent_id) + project_id = parts.get("project") + location = parts.get("location") + + # Vertex doesn't support global region at this time, but Agent Builder + # and DFCX do. If this method is used in conjunction with an agent_id + # that is located in `global` region, we will fallback to `us-central1` + # to init Vertex. 
+ location = "us-central1" if location == "global" else location + + vertexai.init(project=project_id, location=location) + def _build_data_store_parent(self, location: str) -> str: """Build the Parent ID needed for Discovery Engine API calls.""" return (f"projects/{self.project_id}/locations/{location}/collections/" @@ -473,3 +566,72 @@ def wrapper(self, *args, **kwargs): wrapper.calls_api = True return wrapper + +def should_retry(err: exceptions.GoogleAPICallError) -> bool: + """Helper function for deciding whether we should retry the error or not.""" + return isinstance(err, (exceptions.TooManyRequests, exceptions.ServerError)) + +def ratelimit(rate: float): + """Decorator that controls the frequency of function calls.""" + seconds_per_event = 1.0 / rate + lock = threading.Lock() + bucket = 0 + last = 0 + + def decorate(func): + def rate_limited_function(*args, **kwargs): + nonlocal last, bucket + while True: + with lock: + now = time.time() + bucket += now - last + last = now + + # capping the bucket in order to avoid accumulating too many + bucket = min(bucket, seconds_per_event) + + # if bucket is less than `seconds_per_event` then we have to wait + # `seconds_per_event` - `bucket` seconds until a new "token" is + # refilled + delay = max(seconds_per_event - bucket, 0) + + if delay == 0: + # consuming a token and breaking out of the delay loop to perform + # the function call + bucket -= seconds_per_event + break + time.sleep(delay) + return func(*args, **kwargs) + return rate_limited_function + return decorate + +def retry_api_call(retry_intervals: Iterable[float]): + """Decorator for retrying certain GoogleAPICallError exception types.""" + def decorate(func): + def retried_api_call_func(*args, **kwargs): + interval_iterator = iter(retry_intervals) + while True: + try: + return func(*args, **kwargs) + except exceptions.GoogleAPICallError as err: + print(f"retrying api call: {err}") + if not should_retry(err): + raise + + interval = next(interval_iterator, _INTERVAL_SENTINEL) + if interval is _INTERVAL_SENTINEL: + raise + time.sleep(interval) + return retried_api_call_func + return decorate + +def handle_api_error(func): + """Decorator that chatches GoogleAPICallError exception and returns None.""" + def handled_api_error_func(*args, **kwargs): + try: + return func(*args, **kwargs) + except exceptions.GoogleAPICallError as err: + print(f"failed api call: {err}") + return None + + return handled_api_error_func diff --git a/src/dfcx_scrapi/core/sessions.py b/src/dfcx_scrapi/core/sessions.py index d99f91c8..5853a1c5 100644 --- a/src/dfcx_scrapi/core/sessions.py +++ b/src/dfcx_scrapi/core/sessions.py @@ -1,6 +1,6 @@ """CX Session Resource functions.""" -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
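A minimal sketch of how the new scrapi_base decorators compose; the wrapped helper is hypothetical, and the retry intervals mirror the quadratic back-off used by the new datastore scraper.

from dfcx_scrapi.core import scrapi_base
from dfcx_scrapi.core.agents import Agents

# Hypothetical helper: retry transient API errors with increasing back-off,
# keep calls at or below 2 per second, and return None on a terminal error.
@scrapi_base.handle_api_error
@scrapi_base.retry_api_call([i**2 for i in range(5)])
@scrapi_base.ratelimit(2.0)
def fetch_agent(agents: Agents, agent_id: str):
    return agents.get_agent(agent_id)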
@@ -25,6 +25,8 @@ from dfcx_scrapi.core import scrapi_base from dfcx_scrapi.core import tools +from dfcx_scrapi.core import playbooks + # logging config logging.basicConfig( @@ -41,18 +43,22 @@ def __init__( self, creds_path: str = None, creds_dict: Dict = None, + creds=None, scope=False, agent_id: str = None, session_id: str = None, tools_map: Dict[str, str] = None, + playbooks_map: Dict[str, str] = None ): super().__init__( - creds_path=creds_path, creds_dict=creds_dict, scope=scope + creds_path=creds_path, creds_dict=creds_dict, + creds=creds, scope=scope ) self.session_id = session_id self.agent_id = agent_id self.tools_map = tools_map + self.playbooks_map = playbooks_map @property def session_id(self): @@ -85,19 +91,42 @@ def _build_query_input(text, language_code): return query_input @staticmethod - def get_text_response(res: types.session.QueryResult) -> str: - all_text = [] - if res.response_messages: - for rm in res.response_messages: - if rm.text: - all_text.append(rm.text.text[0]) + def build_intent_query_input(intent_id: str, language_code: str): + """Build the query_input object for direct Intent request.""" + intent_input = types.session.IntentInput(intent=intent_id) + query_input = types.session.QueryInput( + intent=intent_input, language_code=language_code + ) - return all_text + return query_input @staticmethod def get_tool_action(tool_use: types.example.ToolUse) -> str: return tool_use.action + def get_tool_params(self, params: maps.MapComposite): + "Handle various types of param values from Tool input/outputs." + param_map = {} + if isinstance(params, maps.MapComposite): + param_map = self.recurse_proto_marshal_to_dict(params) + + # Clean up resulting param map. This is because I/O params from Agent + # Builder proto will have a blank top level key, but the main value + # info is what is important for return to the user in this tool. 
+ empty_top_key = param_map.get("", None) + if len(param_map.keys()) == 1 and empty_top_key: + param_map = param_map[""] + + return param_map + + def get_playbook_name(self, playbook_id: str): + agent_id = self.parse_agent_id(playbook_id) + if not self.playbooks_map: + playbook_client = playbooks.Playbooks(agent_id) + self.playbooks_map = playbook_client.get_playbooks_map(agent_id) + + return self.playbooks_map[playbook_id] + def get_tool_name(self, tool_use: types.example.ToolUse) -> str: agent_id = self.parse_agent_id(tool_use.tool) if not self.tools_map: @@ -105,27 +134,6 @@ def get_tool_name(self, tool_use: types.example.ToolUse) -> str: self.tools_map = tool_client.get_tools_map(agent_id) return self.tools_map[tool_use.tool] - def get_tool_input_parameters(self, tool_use: types.example.ToolUse) -> str: - input_params = {} - for param in tool_use.input_parameters: - if isinstance(param.value, maps.MapComposite): - input_params = self.recurse_proto_marshal_to_dict(param.value) - else: - input_params[param.name] = param.value - - return input_params - - def get_tool_output_parameters( - self, tool_use: types.example.ToolUse - ) -> str: - output_params = {} - for param in tool_use.output_parameters: - output_params[param.name] = self.recurse_proto_marshal_to_dict( - param.value - ) - - return output_params - def collect_tool_responses( self, res: types.session.QueryResult ) -> List[Dict[str, str]]: @@ -137,17 +145,42 @@ def collect_tool_responses( { "tool_name": self.get_tool_name(action.tool_use), "tool_action": self.get_tool_action(action.tool_use), - "input_params": self.get_tool_input_parameters( - action.tool_use - ), - "output_params": self.get_tool_output_parameters( - action.tool_use - ), + "input_params": self.get_tool_params( + action.tool_use.input_action_parameters), + "output_params": self.get_tool_params( + action.tool_use.output_action_parameters), } ) return tool_responses + def collect_playbook_responses( + self, res: types.session.QueryResult + ) -> List[Dict[str, str]]: + """Gather all the playbook responses into a list of dicts.""" + playbook_responses = [] + for action in res.generative_info.action_tracing_info.actions: + if action.playbook_invocation: + playbook_responses.append( + { + "playbook_name": self.get_playbook_name( + action.playbook_invocation.playbook + ) + } + ) + else: + # If no playbook invocation was found + # return the current Playbook + playbook_responses.append( + { + "playbook_name": self.get_playbook_name( + res.generative_info.current_playbooks[-1] + ) + } + ) + + return playbook_responses + def build_session_id( self, agent_id: str = None, overwrite: bool = True ) -> str: @@ -265,7 +298,9 @@ def detect_intent( text, language_code="en", parameters=None, + end_user_metadata=None, populate_data_store_connection_signals=False, + intent_id: str = None ): """Returns the result of detect intent with texts as inputs. @@ -280,6 +315,8 @@ def detect_intent( text: the user utterance to run intent detection on parameters: (Optional) Dict of CX Session Parameters to set in the conversation. Typically this is set before a conversation starts. + end_user_metadata: (Optional) Dict of CX Session endUserMetadata to + set in the conversation. populate_data_store_connection_signals: If set to true and data stores are involved in serving the request then query result will be populated with data_store_connection_signals field which @@ -302,9 +339,11 @@ def detect_intent( "Utilize `build_session_id` to create a new Session ID." 
) - logging.info(f"Starting Session ID {session_id}") - - query_input = self._build_query_input(text, language_code) + if intent_id: + query_input = self.build_intent_query_input( + intent_id, language_code) + else: + query_input = self._build_query_input(text, language_code) request = types.session.DetectIntentRequest() request.session = session_id @@ -315,6 +354,9 @@ def detect_intent( if parameters: query_param_mapping["parameters"] = parameters + if end_user_metadata: + query_param_mapping["end_user_metadata"] = end_user_metadata + if populate_data_store_connection_signals: query_param_mapping[ "populate_data_store_connection_signals" diff --git a/src/dfcx_scrapi/core/tools.py b/src/dfcx_scrapi/core/tools.py index b24e26d6..af3bcfed 100644 --- a/src/dfcx_scrapi/core/tools.py +++ b/src/dfcx_scrapi/core/tools.py @@ -28,16 +28,20 @@ def __init__( self, creds_path: str = None, creds_dict: Dict = None, + creds=None, scope=False, agent_id: str = None, - tool_id: str = None + tool_id: str = None, + tools_map: Dict[str, str] = None ): super().__init__( - creds_path=creds_path, creds_dict=creds_dict, scope=scope + creds_path=creds_path, creds_dict=creds_dict, + creds=creds, scope=scope ) self.agent_id = agent_id self.tool_id = tool_id + self.tools_map = tools_map @staticmethod def build_open_api_tool( diff --git a/src/dfcx_scrapi/tools/agent_response.py b/src/dfcx_scrapi/tools/agent_response.py new file mode 100644 index 00000000..fc141637 --- /dev/null +++ b/src/dfcx_scrapi/tools/agent_response.py @@ -0,0 +1,320 @@ +"""Helper classes for parsing Agent Responses.""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
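A minimal end-to-end sketch of the updated Sessions surface above; the agent ID, session parameters, and end-user metadata are hypothetical.

from dfcx_scrapi.core.sessions import Sessions

# Hypothetical agent ID for illustration only.
agent_id = "projects/my-project/locations/global/agents/my-agent"
s = Sessions(agent_id=agent_id)

session_id = s.build_session_id(agent_id)
res = s.detect_intent(
    agent_id=agent_id,
    session_id=session_id,
    text="Where is my order?",
    parameters={"order_id": "12345"},
    end_user_metadata={"customer_tier": "gold"},
    populate_data_store_connection_signals=True,
)
# If intent_id is passed, the request is built with build_intent_query_input
# and that intent is triggered directly (the text input is not used).

# For playbook-based agents, the new helpers expose which playbooks and
# tools were invoked while the response was generated.
playbook_calls = s.collect_playbook_responses(res)
tool_calls = s.collect_tool_responses(res)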
+ +import dataclasses +import json + +from typing import Any, Union +from google.protobuf.json_format import MessageToDict +from google.cloud.dialogflowcx_v3beta1 import types + +DataStoreConnectionSignals = ( + types.data_store_connection.DataStoreConnectionSignals +) + +_EXECUTION_SEQUENCE_KEY = "DataStore Execution Sequence" + +@dataclasses.dataclass +class Snippet: + uri: Union[str, None] + title: Union[str, None] + text: Union[str, None] + + def to_prompt_snippet(self) -> str: + result = [] + if self.title: + result.append(self.title) + if self.text: + result.append(self.text) + + return "\n".join(result) if result else "" + +@dataclasses.dataclass +class AgentResponse: + """Dataclass for storing relevant fields of detect intent response.""" + # ResponseMessages + answer_text: str = None + + # MatchType + match_type: str = None + + # DataStoreConnectionSignals + rewriter_llm_rendered_prompt: str = None + rewriter_llm_output: str = None + rewritten_query: str = None + search_results: list[Snippet] = dataclasses.field(default_factory=list) + answer_generator_llm_rendered_prompt: str = None + answer_generator_llm_output: str = None + generated_answer: str = None + cited_snippet_indices: list[int] = dataclasses.field(default_factory=list) + grounding_decision: str = None + grounding_score: str = None + safety_decision: str = None + safety_banned_phrase_match: str = None + + # DiagnosticInfo ExecutionResult + response_type: str = None + response_reason: str = None + latency: float = None + faq_citation: bool = None + search_fallback: bool = None + unstructured_citation: bool = None + website_citation: bool = None + language: str = None + + def from_query_result(self, query_result: types.session.QueryResult): + """Extracts the relevant fields from a QueryResult proto message.""" + answer_text = self._extract_text(query_result) + match_type = self._extract_match_type(query_result) + execution_result = self._extract_execution_result(query_result) + + self.answer_text=answer_text + self.match_type=match_type + self.response_type = execution_result.get("response_type") + self.response_reason = execution_result.get("response_reason") + self.latency = execution_result.get("latency") + self.faq_citation = execution_result.get("faq_citation") + self.search_fallback = execution_result.get("ucs_fallback") + self.unstructured_citation = execution_result.get( + "unstructured_citation") + self.website_citation = execution_result.get("website_citation") + self.language = execution_result.get("language") + + if query_result.data_store_connection_signals: + self._extract_data_store_connection_signals( + query_result.data_store_connection_signals + ) + + @classmethod + def from_row(cls, row: dict[str, Any]): + """Extracts the relevant fields from a dictionary.""" + row = row.copy() + search_results = [] + for search_result in json.loads(row["search_results"]): + search_results.append(Snippet(**search_result)) + + row["search_results"] = search_results + row["cited_snippet_indices"] = json.loads(row["cited_snippet_indices"]) + + return cls(**row) + + def to_row(self): + """Dumps the query result fields to a dictionary.""" + result = dataclasses.asdict(self) + result["search_results"] = json.dumps( + result.pop("search_results", []), indent=4 + ) + result["cited_snippet_indices"] = json.dumps( + result["cited_snippet_indices"]) + + return result + + @staticmethod + def _extract_match_type(query_result: types.session.QueryResult) -> str: + """Extracts the name of the match type from query result.""" + try: + 
return types.session.Match.MatchType( + query_result.match.match_type).name + + except ValueError: + # if an enum type is returned which is not visible externally then + # fallback to default value + return types.session.Match.MatchType(0).name + + @staticmethod + def _extract_text(res: types.session.QueryResult) -> str: + all_text: list[str] = [] + if res.response_messages: + for rm in res.response_messages: + if rm.text and len(rm.text.text) > 0: + all_text.append(rm.text.text[0]) + + final_text = "\n".join(all_text) + + return final_text + + @staticmethod + def _extract_execution_result( + query_result: types.session.QueryResult) -> dict[str, Any]: + """Extracts the execution result from diagnostic info.""" + if _EXECUTION_SEQUENCE_KEY in query_result.diagnostic_info: + execution_sequence = query_result.diagnostic_info[ + _EXECUTION_SEQUENCE_KEY + ] + if "executionResult" in execution_sequence: + return MessageToDict(execution_sequence["executionResult"]) + return {} + + def _extract_search_results( + self, + data_store_connection_signals: DataStoreConnectionSignals + ): + """Extracts search results as a list of strings.""" + self.search_results = [] + for search_snippet in data_store_connection_signals.search_snippets: + self.search_results.append( + Snippet( + uri=search_snippet.document_uri, + title=search_snippet.document_title, + text=search_snippet.text, + ) + ) + + def _extract_citation_indices( + self, + data_store_connection_signals: DataStoreConnectionSignals + ): + """Extracts the links and snippets used to generate answer.""" + self.cited_snippet_indices = [] + for cited_snippet in data_store_connection_signals.cited_snippets: + self.cited_snippet_indices.append(cited_snippet.snippet_index) + + @staticmethod + def _extract_grounding_decision( + grounding_signals: DataStoreConnectionSignals.GroundingSignals + ) -> str: + return DataStoreConnectionSignals.GroundingSignals.GroundingDecision( + grounding_signals.decision + ).name + + @staticmethod + def _extract_grounding_score( + grounding_signals: DataStoreConnectionSignals.GroundingSignals + ): + return DataStoreConnectionSignals.GroundingSignals.GroundingScoreBucket( + grounding_signals.score + ).name + + def _extract_grounding_signals( + self, data_store_connection_signals: DataStoreConnectionSignals + ) -> dict[str, str]: + grounding_signals = data_store_connection_signals.grounding_signals + if not grounding_signals: + self.grounding_decision = None + self.grounding_score = None + else: + self.grounding_decision = self._extract_grounding_decision( + grounding_signals) + self.grounding_score = self._extract_grounding_score( + grounding_signals) + + def _extract_rewriter_llm_signals( + self, + data_store_connection_signals: DataStoreConnectionSignals + ): + rewriter_model_call_signals = ( + data_store_connection_signals.rewriter_model_call_signals + ) + if not rewriter_model_call_signals: + self.rewriter_llm_rendered_prompt = None + self.rewriter_llm_output = None + + else: + self.rewriter_llm_rendered_prompt = ( + rewriter_model_call_signals.rendered_prompt + ) + self.rewriter_llm_output = rewriter_model_call_signals.model_output + + def _extract_answer_generator_llm_signals( + self, + data_store_connection_signals: DataStoreConnectionSignals + ) -> dict[str, str]: + answer_generation_model_call_signals = ( + data_store_connection_signals.answer_generation_model_call_signals + ) + if not answer_generation_model_call_signals: + self.answer_generator_llm_rendered_prompt = None + self.answer_generator_llm_output = None + + 
else: + self.answer_generator_llm_rendered_prompt = ( + answer_generation_model_call_signals.rendered_prompt + ) + self.answer_generator_llm_output = ( + answer_generation_model_call_signals.model_output + ) + + @staticmethod + def _extract_safety_decision( + safety_signals: DataStoreConnectionSignals.SafetySignals) -> str: + return DataStoreConnectionSignals.SafetySignals.SafetyDecision( + safety_signals.decision + ).name + + @staticmethod + def _extract_safety_banned_phrase( + safety_signals: DataStoreConnectionSignals.SafetySignals + ) -> str: + return DataStoreConnectionSignals.SafetySignals.BannedPhraseMatch( + safety_signals.banned_phrase_match + ).name + + def _extract_safety_signals( + self, data_store_connection_signals: DataStoreConnectionSignals + ) -> dict[str, str]: + safety_signals = data_store_connection_signals.safety_signals + if not safety_signals: + self.safety_decision = None + self.safety_banned_phrase_match = None + else: + self.safety_decision = self._extract_safety_decision(safety_signals) + self.safety_banned_phrase_match = ( + self._extract_safety_banned_phrase(safety_signals) + ) + + def _extract_data_store_connection_signals( + self, + data_store_connection_signals: DataStoreConnectionSignals + ) -> dict[str, Any]: + self._extract_rewriter_llm_signals(data_store_connection_signals + ) + self.rewritten_query = ( + data_store_connection_signals.rewritten_query + if data_store_connection_signals.rewritten_query + else None + ) + + self._extract_grounding_signals(data_store_connection_signals) + self._extract_search_results(data_store_connection_signals) + self._extract_answer_generator_llm_signals( + data_store_connection_signals + ) + self.generated_answer = ( + data_store_connection_signals.answer + if data_store_connection_signals.answer + else None + ) + self._extract_citation_indices(data_store_connection_signals) + self._extract_safety_signals(data_store_connection_signals) + + @property + def search_result_links(self): + return [search_result.uri for search_result in self.search_results] + + @property + def cited_search_results(self): + return [self.search_results[idx] for idx in self.cited_snippet_indices] + + @property + def cited_search_result_links(self): + return [search_result.uri for search_result in self.cited_search_results] + + @property + def prompt_snippets(self): + return [ + search_result.to_prompt_snippet() + for search_result in self.search_results + ] diff --git a/src/dfcx_scrapi/tools/copy_util.py b/src/dfcx_scrapi/tools/copy_util.py index 6161f55d..7abcf446 100644 --- a/src/dfcx_scrapi/tools/copy_util.py +++ b/src/dfcx_scrapi/tools/copy_util.py @@ -445,7 +445,7 @@ def _create_entity_resources( for entity in resources_objects["entities"]: logging.info("Creating Entity %s...", entity.display_name) try: - self.entities.create_entity_type(destination_agent, entity) + self.entities.create_entity_type(agent_id=destination_agent, obj=entity) resources_skip_list["entities"].append(entity.display_name) logging.info( "Entity %s created successfully.", entity.display_name diff --git a/src/dfcx_scrapi/tools/dataframe_functions.py b/src/dfcx_scrapi/tools/dataframe_functions.py index b6dacf59..d9a233a6 100644 --- a/src/dfcx_scrapi/tools/dataframe_functions.py +++ b/src/dfcx_scrapi/tools/dataframe_functions.py @@ -1,6 +1,6 @@ """Utility file for dataframe functions in support of Dialogflow CX.""" -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except 
in compliance with the License. @@ -18,13 +18,11 @@ import logging import time from typing import Dict, List -import google.auth import gspread import pandas as pd import numpy as np from tabulate import tabulate from gspread_dataframe import set_with_dataframe -from oauth2client.service_account import ServiceAccountCredentials from google.cloud.dialogflowcx_v3beta1 import types @@ -55,7 +53,6 @@ def __init__( creds_path: str = None, creds_dict: dict = None, creds=None, - principal=False, scope=False, ): super().__init__( @@ -70,28 +67,9 @@ def __init__( if scope: scopes += scope - if creds: - self.sheets_client = gspread.authorize(creds) - - elif creds_path: - creds = ServiceAccountCredentials.from_json_keyfile_name( - filename=creds_path, scopes=scopes - ) - self.sheets_client = gspread.authorize(creds) - - elif creds_dict: - creds = ServiceAccountCredentials.from_json_keyfile_dict( - keyfile_dict=creds_dict, scopes=scopes - ) - self.sheets_client = gspread.authorize(creds) - - elif principal: - self.sheets_client = gspread.oauth() - - else: - creds = google.auth.default(scopes=scopes)[0] - self.sheets_client = gspread.authorize(creds) + self.creds.scopes.extend(scopes) + self.sheets_client = gspread.authorize(self.creds) self.entities = EntityTypes(creds=self.creds) self.intents = Intents(creds=self.creds) self.flows = Flows(creds=self.creds) diff --git a/src/dfcx_scrapi/tools/datastore_evaluator.py b/src/dfcx_scrapi/tools/datastore_evaluator.py new file mode 100644 index 00000000..31d39a7f --- /dev/null +++ b/src/dfcx_scrapi/tools/datastore_evaluator.py @@ -0,0 +1,511 @@ +"""Evaluation tooling for Vertex Conversation DataStores native service.""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
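A short sketch of how the new AgentResponse dataclass is consumed, continuing the hypothetical detect_intent call from the earlier sketch; from_query_result is fed the raw proto (res._pb), mirroring how the new datastore scraper calls it.

from dfcx_scrapi.tools.agent_response import AgentResponse

ar = AgentResponse()
# `res` was requested with populate_data_store_connection_signals=True, so
# the data store connection signals are populated.
ar.from_query_result(res._pb)

print(ar.answer_text)           # newline-joined text response
print(ar.match_type)            # MatchType enum name, e.g. "INTENT"
print(ar.rewritten_query)       # query rewriter output, if any
print(ar.search_result_links)   # URIs of the retrieved data store snippets
print(ar.cited_search_results)  # snippets actually cited in the answer

row = ar.to_row()               # flat dict, e.g. for a DataFrame row
restored = AgentResponse.from_row(row)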
+ +import dataclasses +from datetime import datetime, timezone +import io +import itertools +import json +import math +import os +import re +import pandas as pd +from tqdm import tqdm +from typing import Union + +import plotly.graph_objects as go +import matplotlib.pyplot as plt + +from google.cloud import bigquery +from googleapiclient.discovery import build +from googleapiclient.http import MediaInMemoryUpload, MediaIoBaseDownload + +from dfcx_scrapi.core.scrapi_base import ScrapiBase +from dfcx_scrapi.tools.agent_response import AgentResponse +from dfcx_scrapi.tools.metrics import build_metrics + +_FOLDER_ID = re.compile(r"folders\/(.*?)(?=\/|\?|$)") +EVAL_RESULTS_COLS = [ + "answer_generator_llm_rendered_prompt", + "search_results" + ] + +class DataStoreEvaluator(ScrapiBase): + def __init__(self, metrics: list[str], model: str = "text-bison@002"): + self.model = self.model_setup(model) + self.metrics = build_metrics(metrics, generation_model=self.model) + + def run(self, scraper_output: pd.DataFrame) -> "EvaluationResult": + timestamp = datetime.now(tz=timezone.utc) + scraper_output = scraper_output.copy(deep=True) + result = pd.DataFrame(index=scraper_output.index) + + for metric in self.metrics: + result = pd.concat([result, metric.run(scraper_output)], axis=1) + + # adding timestamp and agent display name so they can be used as a multi + # index + result["evaluation_timestamp"] = timestamp.isoformat() + + return EvaluationResult(scraper_output, result) + + +@dataclasses.dataclass +class EvaluationResult: + scrape_outputs: pd.DataFrame = None + metric_outputs: pd.DataFrame = None + + @property + def timestamp(self) -> str: + return self.metric_outputs["evaluation_timestamp"].iloc[0] + + @staticmethod + def truncate(df, column): + truncated_fix = "" + def _truncate(value): + if len(value) < 50_000: + return value + else: + return value[:50_000 - len(truncated_fix)] + truncated_fix + df[column] = df[column].apply(_truncate) + + @staticmethod + def find_folder(folder_name, drive_service) -> Union[tuple[str, str], None]: + """Finds a folder by name in Google Drive.""" + query = ( + f"name = '{folder_name}' and " + f"mimeType = 'application/vnd.google-apps.folder' and " + f"trashed = false" + ) + fields = "nextPageToken, files(id, name, webViewLink)" + list_request = drive_service.files().list(q=query, fields=fields) + result = list_request.execute() + folders = result.get("files", []) + if not folders: + return None + + return folders[0].get("id"), folders[0].get("webViewLink") + + @staticmethod + def create_folder( + folder_name, drive_service + ) -> tuple[Union[str, None], Union[str, None]]: + """Creates a folder in Google Drive.""" + create_request = drive_service.files().create( + body={ + "name": folder_name, + "mimeType": "application/vnd.google-apps.folder" + }, + fields="id, webViewLink" + ) + result = create_request.execute() + + return result.get("id"), result.get("webViewLink") + + @staticmethod + def create_json( + content, file_name, parent, drive_service + ) -> tuple[Union[str, None], Union[str, None]]: + """Creates a .json file in the specified Google Drive folder.""" + request = drive_service.files().create( + body={"name": file_name, "parents": [parent]}, + media_body=MediaInMemoryUpload( + json.dumps(content, indent=4).encode("utf-8"), + mimetype="text/plain", + ), + fields="id, webViewLink", + ) + result = request.execute() + + return result.get("id"), result.get("webViewLink") + + @staticmethod + def create_chunks(iterable, chunk_size): + for chunk in 
itertools.zip_longest(*([iter(iterable)] * chunk_size)): + yield [element for element in chunk if element is not None] + + @staticmethod + def delete_worksheet(sheet_id, worksheet_id, sheets_service): + """Deletes a worksheet.""" + sheets_service.spreadsheets().batchUpdate( + spreadsheetId=sheet_id, + body={"requests": [{"deleteSheet": {"sheetId": worksheet_id}}]}, + ).execute() + + @staticmethod + def get_bigquery_types(df): + "Maps DataFrame data types to BigQuery data types." + types = [] + data_type_mapping = { + 'object': 'STRING', + 'int64': 'INTEGER', + 'float64': 'FLOAT', + 'bool': 'BOOLEAN', + 'datetime64[ns]': 'TIMESTAMP' # Assuming nanosecond timestamps + } + for dtype in df.dtypes: + if dtype in data_type_mapping: + types.append(data_type_mapping[dtype]) + else: + # Handle other data types (error handling or placeholder) + types.append('STRING') # Placeholder, adjust as needed + print(f"Warning: Unhandled data type: {dtype}") + + return types + + @staticmethod + def sanitize_column_names(df): + "Sanitizes column names replacing special characters with underscores." + sanitized_names = [] + for col in df.columns: + # Replace special characters with underscores + sanitized_name = re.sub(r"[^\w\s]", "_", col) + sanitized_names.append(sanitized_name) + + return df.rename(columns=dict(zip(df.columns, sanitized_names))) + + @staticmethod + def list_folder(folder_id, drive_service) -> list[tuple[str, str]]: + query = f"'{folder_id}' in parents and trashed = false" + list_request = drive_service.files().list( + q=query, fields="nextPageToken, files(id, name)" + ) + result = list_request.execute() + items = result.get("files", []) + return [(item["id"], item["name"]) for item in items] + + @staticmethod + def download_json(file_id, drive_service): + request = drive_service.files().get_media(fileId=file_id) + fh = io.BytesIO() + downloader = MediaIoBaseDownload(fh, request) + done = False + while not done: + _, done = downloader.next_chunk() + + fh.seek(0) + + return json.loads(fh.read().decode('utf-8')) + + def load(self, folder_url, credentials): + folder_id_match = _FOLDER_ID.search(folder_url) + if not folder_id_match: + raise ValueError() + + folder_id = folder_id_match.group(1) + drive_service = build("drive", "v3", credentials=credentials) + + file_id = self.find_file_in_folder( + folder_id, "results.json", drive_service) + json_content = self.download_json(file_id, drive_service) + + queryset = pd.DataFrame.from_dict( + json_content["queryset"], orient="index") + responses = pd.DataFrame.from_dict( + json_content["responses"], orient="index" + ) + + ar = AgentResponse() + queryset["query_result"] = responses.apply( + ar.from_row, axis=1 + ) + self.scrape_outputs = queryset + + self.metric_outputs = pd.DataFrame.from_dict( + json_content["metrics"], orient="index" + ) + + def aggregate(self, columns: list[str] = None): + if not columns: + columns = self.metric_outputs.columns + shared_columns = self.metric_outputs.columns.intersection(set(columns)) + result = pd.DataFrame(self.metric_outputs[shared_columns]) + result["name"] = self.scrape_outputs["agent_display_name"] + result["evaluation_timestamp"] = ( + self.metric_outputs["evaluation_timestamp"] + ) + result = result.set_index(["name", "evaluation_timestamp"]) + + return result.groupby(level=[0, 1]).mean(numeric_only=True) + + def export(self, folder_name: str, chunk_size: int, credentials): + drive_service = build("drive", "v3", credentials=credentials) + folder = self.find_folder(folder_name, drive_service) + if folder: + 
folder_id, folder_url = folder + else: + folder_id, folder_url = self.create_folder( + folder_name, drive_service + ) + + queryset = self.scrape_outputs.drop("query_result", axis=1) + responses = self.scrape_outputs["query_result"].apply( + lambda x: x.to_row() + ) + responses = pd.DataFrame(responses.to_list(), index=queryset.index) + + json_content = { + "queryset": queryset.to_dict(orient="index"), + "responses": responses.to_dict(orient="index"), + "metrics": self.metric_outputs.to_dict(orient="index"), + } + json_id, json_url = self.create_json( + json_content, "results.json", folder_id, drive_service + ) + + for column in EVAL_RESULTS_COLS: + self.truncate(responses, column) + + results = pd.concat([queryset, responses, self.metric_outputs], axis=1) + worksheets = { + "summary": self.aggregate().fillna("#N/A"), + "results": results.fillna("#N/A") + } + sheets_service = build("sheets", "v4", credentials=credentials) + self.create_sheet( + worksheets=worksheets, + title="results", + parent=folder_id, + chunk_size=chunk_size, + sheets_service=sheets_service, + drive_service=drive_service, + ) + return folder_url + + def export_to_csv(self, file_name: str): + queryset = self.scrape_outputs.drop("query_result", axis=1) + responses = self.scrape_outputs["query_result"].apply( + lambda x: x.to_row()) + responses = pd.DataFrame(responses.to_list(), index=queryset.index) + + for column in EVAL_RESULTS_COLS: + self.truncate(responses, column) + + results = pd.concat([queryset, responses, self.metric_outputs], axis=1) + temp_dir = "/tmp/evaluation_results" + os.makedirs(temp_dir, exist_ok=True) + filepath = os.path.join(temp_dir, file_name) + results.to_csv(filepath, index=False) + + return filepath + + def display_on_screen(self): + queryset = self.scrape_outputs.drop("query_result", axis=1) + responses = self.scrape_outputs["query_result"].apply( + lambda x: x.to_row()) + responses = pd.DataFrame(responses.to_list(), index=queryset.index) + + for column in EVAL_RESULTS_COLS: + self.truncate(responses, column) + + results = pd.concat([queryset, responses, self.metric_outputs], axis=1) + + return results + + def export_to_bigquery( + self, + eval_results, + project_id: str, + dataset_id: str, + table_name: str, + credentials + ): + data=eval_results.scrape_outputs["query_result"].apply( + lambda x: x.to_row()) + data = pd.DataFrame(data.to_list(),eval_results.scrape_outputs.index) + eval_results.scrape_outputs["query_result"] = None + df = pd.concat( + [ + data, + eval_results.scrape_outputs, + eval_results.metric_outputs + ], + axis=1) + + df = EvaluationResult.sanitize_column_names(df) + client = bigquery.Client(project=project_id, credentials=credentials) + + try: + df['conversation_id'] = df['conversation_id'].astype(str) + df['latency'] = df['latency'].astype(str) + df['expected_uri'] = df['expected_uri'].astype(str) + df['answerable'] = df['answerable'].astype(str) + df['golden_snippet'] = df['golden_snippet'].astype(str) + + df = df.drop('query_result', axis=1) + df = df.drop('golden_snippet', axis=1) + df = df.drop('answerable', axis=1) + + load_job = client.load_table_from_dataframe(df, '.'.join( + [project_id, dataset_id, table_name])) + + return load_job.result() + except Exception as e: + print(f"Error exporting data: {e}") + return None # Indicate failure + + def find_file_in_folder( + self, + folder_id, + name, + drive_service + ) -> Union[str, None]: + for file_id, file_name in self.list_folder(folder_id, drive_service): + if file_name == name: + return file_id + return None + + 
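A minimal sketch wiring the evaluator above to the DataStoreScraper that is added later in this change; the agent ID and metric name are placeholders, and the valid metric identifiers are whatever dfcx_scrapi.tools.metrics.build_metrics accepts.

import pandas as pd

from dfcx_scrapi.tools.datastore_evaluator import DataStoreEvaluator
from dfcx_scrapi.tools.datastore_scraper import DataStoreScraper

# Hypothetical agent ID for illustration only.
scraper = DataStoreScraper(
    agent_id="projects/my-project/locations/global/agents/my-agent"
)

# One-row queryset following the scraper's required input schema.
queryset = pd.DataFrame([{
    "conversation_id": 1,
    "turn_index": 1,
    "query": "How do I reset my password?",
    "expected_answer": "Use the account settings page.",
    "expected_uri": "https://example.com/help/password-reset",
    "user_metadata": "",
    "parameters": "",
}])

scrape_results = scraper.run(queryset)

# "answer_correctness" is a placeholder metric name.
evaluator = DataStoreEvaluator(metrics=["answer_correctness"])
eval_result = evaluator.run(scrape_results)

print(eval_result.display_on_screen())
print(eval_result.export_to_csv("datastore_eval.csv"))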
def add_worksheet( + self, sheet_id, content, title, sheets_service, chunk_size) -> None: + """Adds a worksheet to an existing spreadsheet.""" + sheets_service.spreadsheets().batchUpdate( + spreadsheetId=sheet_id, + body={"requests": [{"addSheet": {"properties": {"title": title}}}]}, + ).execute() + + for chunk in tqdm( + self.create_chunks(content, chunk_size), + total=math.ceil(len(content) / chunk_size), + desc=f"Creating worksheet: {title}", + ): + sheets_service.spreadsheets().values().append( + spreadsheetId=sheet_id, + range=f"'{title}'!A1", + valueInputOption="RAW", + body={"values": chunk}, + ).execute() + + def create_sheet( + self, worksheets, title, parent, chunk_size, sheets_service, + drive_service) -> Union[str, None]: + """Creates a new spreadsheet with worksheets.""" + body = {"properties": {"title": title}} + create_request = sheets_service.spreadsheets().create( + body=body, fields="spreadsheetId" + ) + create_result = create_request.execute() + sheet_id = create_result.get("spreadsheetId") + + parents_request = drive_service.files().get( + fileId=sheet_id, fields="parents") + parents_result = parents_request.execute() + parents = parents_result.get("parents") + previous_parents = ",".join(parents) if parents else None + + if not sheet_id: + return + + for worksheet_title, content in worksheets.items(): + content_dict = content.to_dict(orient="split") + self.add_worksheet( + sheet_id=sheet_id, + content=[content_dict["columns"]] + content_dict["data"], + title=worksheet_title, + sheets_service=sheets_service, + chunk_size=chunk_size, + ) + + all_request = sheets_service.spreadsheets().get(spreadsheetId=sheet_id) + all_result = all_request.execute() + default_sheet_id = all_result["sheets"][0]["properties"]["sheetId"] + + self.delete_worksheet(sheet_id, default_sheet_id, sheets_service) + _ = drive_service.files().update( + fileId=sheet_id, + addParents=parent, + removeParents=previous_parents, + fields="id, parents" + ).execute() + + return f"https://docs.google.com/spreadsheets/d/{sheet_id}/edit" + + +@dataclasses.dataclass +class EvaluationVisualizer: + evaluation_results: list[EvaluationResult] + + def radar_plot(self, columns: Union[list[str], None] = None): + fig = go.Figure() + summaries = pd.concat( + [result.aggregate(columns) for result in self.evaluation_results] + ) + summaries = summaries.to_dict(orient="split") + + for idx, values in enumerate(summaries["data"]): + fig.add_trace( + go.Scatterpolar( + r=values, + theta=summaries["columns"], + fill='toself', + name="_".join(summaries["index"][idx]), + ) + ) + fig.update_layout( + polar={"radialaxis": {"visible": True, "range": [0, 1]}}, + showlegend=True + ) + fig.show() + + def count_barplot(self, column_name: str): + results = [] + for result in self.evaluation_results: + responses = result.scrape_outputs["query_result"].apply( + lambda x: x.to_row()) + responses = pd.DataFrame( + responses.to_list(), index=result.scrape_outputs.index + ) + results.append( + pd.concat( + [result.scrape_outputs, responses, result.metric_outputs], + axis=1 + ) + ) + results = pd.concat(results) + results = results.set_index(["agent_display_name", "evaluation_timestamp"]) + grouped_counts = ( + results[column_name] + .groupby(level=["agent_display_name", "evaluation_timestamp"]) + .value_counts() + .unstack(fill_value=0) + ) + grouped_counts.plot(kind="bar") + plt.xlabel("Name") + plt.ylabel("Count") + plt.xticks(rotation=15) + plt.title(f"{column_name} counts by name") + plt.legend(title=column_name) + plt.show() + + def 
mean_barplot(self, column_names: list[str]): + results = [] + for result in self.evaluation_results: + results.append( + pd.concat([result.scrape_outputs, result.metric_outputs], axis=1) + ) + results = pd.concat(results) + results = results.set_index(["agent_display_name", "evaluation_timestamp"]) + grouped_means = ( + results[column_names] + .groupby(level=["agent_display_name", "evaluation_timestamp"]) + .mean() + ) + grouped_means.plot(kind="bar") + plt.ylim(top=1.0) + plt.xlabel("Name") + plt.ylabel("Mean") + plt.xticks(rotation=15) + plt.title("mean by name") + plt.show() diff --git a/src/dfcx_scrapi/tools/datastore_scraper.py b/src/dfcx_scrapi/tools/datastore_scraper.py new file mode 100644 index 00000000..332bb3ab --- /dev/null +++ b/src/dfcx_scrapi/tools/datastore_scraper.py @@ -0,0 +1,249 @@ +"""Vertex AI Conversation scraper class.""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import re +import gspread +import pandas as pd +from tqdm.auto import tqdm +from typing import Union, Any +from google.oauth2 import service_account + +from dfcx_scrapi.tools.agent_response import AgentResponse +from dfcx_scrapi.core.scrapi_base import ScrapiBase, retry_api_call +from dfcx_scrapi.core.sessions import Sessions +from dfcx_scrapi.core.agents import Agents + +MAX_RETRIES = 5 +INPUT_SCHEMA_REQUIRED_COLUMNS = [ + "conversation_id", + "turn_index", + "query", + "expected_answer", + "expected_uri", + "user_metadata", + "parameters" +] + +def load_spreadsheet( + sheet_url: str, worksheet_name: str, credentials: Any + ) -> pd.DataFrame: + """Loads the content of a spreadsheet into pandas DataFrame.""" + sheets_client = gspread.authorize(credentials) + sheet = sheets_client.open_by_url(sheet_url) + worksheet = sheet.worksheet(worksheet_name) + + return pd.DataFrame(worksheet.get_all_records()) + +class DataStoreScraper(ScrapiBase): + """Vertex AI Conversation scraper class.""" + + def _extract_url_part(cls, url, pattern): + pattern_match = pattern.search(url) + if not pattern_match: + raise ValueError(f"Invalid url: {url}") + + return pattern_match.group(1) + + @classmethod + def from_url( + cls, + agent_url: str, + creds: service_account.Credentials = None, + language_code: str = "en"): + match = re.search( + r'projects/[^/]+/locations/[^/]+/agents/[^/]+', agent_url + ) + if match: + agent_id = match.group(0) + else: + raise ValueError(f"Invalid url: {agent_url}") + + return cls( + agent_id=agent_id, + language_code=language_code, + creds=creds, + ) + + def __init__( + self, + agent_id: str, + language_code: str = "en", + creds_path: str = None, + creds_dict: dict[str, str] = None, + creds=None, + ): + super().__init__( + creds_path=creds_path, + creds_dict=creds_dict, + creds=creds, + ) + + self.agent_id = agent_id + self.language_code = language_code + + self.sessions = Sessions(agent_id=self.agent_id) + self.agents = Agents(creds=self.creds) + + @classmethod + def _extract_url_part(cls, url, pattern): + pattern_match = 
pattern.search(url) + if not pattern_match: + raise ValueError(f"Invalid url: {url}") + return pattern_match.group(1) + + def validate_queryset(self, queryset: pd.DataFrame) -> None: + "Validates the queryset and raises exception in case of invalid input." + # validate input schema + try: + queryset[INPUT_SCHEMA_REQUIRED_COLUMNS] + except KeyError as err: + raise UserWarning( + "Ensure your input data contains the following columns:" + f" {INPUT_SCHEMA_REQUIRED_COLUMNS}" + ) from err + + # validate if conversationd_id and turn_id is unique identifier + if not ( + queryset["conversation_id"].astype(str) + + "_" + + queryset["turn_index"].astype(str) + ).is_unique: + raise UserWarning( + "Ensure that 'conversation_id' and 'turn_index' are unique " + "identifiers" + ) + + # validate turn_index + try: + queryset["turn_index"].astype(int) + except ValueError as err: + raise UserWarning( + "Ensure that 'turn_index' is set as integer" + ) from err + + if not queryset["turn_index"].astype(int).gt(0).all(): + raise UserWarning("Ensure that 'turn_index' is in [1, inf)") + + def setup_queryset(self, queryset: pd.DataFrame) -> pd.DataFrame: + """Various Dataframe validation and cleaning functions.""" + queryset = queryset.rename( + {column: column.lower() for column in queryset.columns} + ) + + self.validate_queryset(queryset) + + queryset["turn_index"] = queryset["turn_index"].astype(int) + timestamp = datetime.datetime.now(tz=datetime.timezone.utc) + + # adding timestamp and agent display name so they can be used as a multi + # index + queryset["scrape_timestamp"] = timestamp.isoformat() + agent_display_name = self.agents.get_agent(self.agent_id).display_name + queryset["agent_display_name"] = agent_display_name + + queryset = self._create_session_ids(queryset) + + # if the conversation_id can be converted to int then sorting can be + # done numerically instead of alphabetically + try: + queryset["conversation_id"] = queryset["conversation_id"].astype( + int + ) + except ValueError: + pass + + queryset = queryset.sort_values( + by=["conversation_id", "turn_index"], ascending=True + ) + + return queryset + + def _create_session_ids(self, queryset: pd.DataFrame) -> pd.DataFrame: + """Creates a unique session id for each conversation_id.""" + sessions = [] + for conversation_id in queryset["conversation_id"].unique(): + sessions.append( + { + "conversation_id": conversation_id, + "session_id": self.sessions.build_session_id(self.agent_id), + } + ) + sessions_df = pd.DataFrame(sessions) + return queryset.merge(sessions_df, on="conversation_id", how="left") + + @retry_api_call([i**2 for i in range(MAX_RETRIES)]) + def scrape_detect_intent( + self, + query: str, + session_id: Union[str, None] = None, + user_metadata: Union[str, None] = None, + parameters: Union[str, None] = None + ) -> AgentResponse: + if session_id is None: + session_id = self.sessions.build_session_id(self.agent_id) + + if user_metadata: + try: + if isinstance(user_metadata, str): + user_metadata = json.loads(user_metadata) + except ValueError as err: + raise UserWarning("Invalid user metadata") from err + + if parameters: + try: + if isinstance(parameters, str): + parameters = json.loads(parameters) + except ValueError as err: + raise UserWarning("Invalid parameters") from err + + response = self.sessions.detect_intent( + agent_id=self.agent_id, + session_id=session_id, + text=query, + language_code=self.language_code, + end_user_metadata=user_metadata, + populate_data_store_connection_signals=True, + parameters=parameters + ) + + ar = 
AgentResponse() + ar.from_query_result(response._pb) + + return ar + + def run( + self, queryset: pd.DataFrame, flatten_response: bool = True + ) -> pd.DataFrame: + "Runs through each query and concatenates responses to the queryset." + queryset = self.setup_queryset(queryset) + progress_bar = tqdm(desc="Scraping queries", total=len(queryset)) + + def scrape(row): + result = self.scrape_detect_intent( + query=row["query"], + session_id=row["session_id"], + user_metadata=row["user_metadata"], + parameters=row["parameters"] + ) + progress_bar.update() + + return result + + queryset["query_result"] = queryset.apply(scrape, axis=1) + + return queryset diff --git a/src/dfcx_scrapi/tools/evaluations.py b/src/dfcx_scrapi/tools/evaluations.py new file mode 100644 index 00000000..3859d447 --- /dev/null +++ b/src/dfcx_scrapi/tools/evaluations.py @@ -0,0 +1,543 @@ +"""Evaluation tooling for Generative features in Agent Builder and DFCX.""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +import logging + +from ast import literal_eval +import numpy as np +import pandas as pd +from tqdm import tqdm +from typing import Dict, List, Any + +from google.oauth2 import service_account + +from dfcx_scrapi.core.scrapi_base import ScrapiBase +from dfcx_scrapi.core.agents import Agents +from dfcx_scrapi.core.sessions import Sessions +from dfcx_scrapi.core.tools import Tools +from dfcx_scrapi.core.playbooks import Playbooks +from dfcx_scrapi.tools.dataframe_functions import DataframeFunctions +from dfcx_scrapi.tools.agent_response import AgentResponse +from dfcx_scrapi.tools.metrics import build_metrics + +from google.cloud.dialogflowcx_v3beta1 import types + +# logging config +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + + +class Evaluations(ScrapiBase): + """Evaluation tooling for Generative features in Agent Builder and DFCX.""" + + def __init__( + self, + agent_id: str, + creds_path: str = None, + creds_dict: Dict[str, str] = None, + creds: service_account.Credentials = None, + metrics: List[str] = ["response_similarity"], + debug: bool = False, + generation_model: str = "gemini-1.5-flash-001", + embedding_model: str = "text-embedding-004", + playbooks_map: Dict[str, Any] = None, + tools_map: Dict[str, Any] = None, + ): + super().__init__( + creds_path=creds_path, + creds_dict=creds_dict, + creds=creds, + ) + + self.agent_id = agent_id + self.session_id = None + + print("Initializing Vertex AI...") + self.init_vertex(self.agent_id) + self.s = Sessions(agent_id=self.agent_id, tools_map=tools_map) + self.p = Playbooks(agent_id=self.agent_id, playbooks_map=playbooks_map) + self.t = Tools(agent_id=self.agent_id, tools_map=tools_map) + self.ar = AgentResponse() + + self.generation_model = self.model_setup(generation_model) + self.embedding_model = self.model_setup(embedding_model) + + self.metrics = build_metrics( + metrics=metrics, + 
generation_model=self.generation_model, + embedding_model=self.embedding_model + ) + + if debug: + logging.basicConfig(level=logging.DEBUG, force=True) + if not debug: + logging.basicConfig(level=logging.ERROR, force=True) + + @staticmethod + def clean_outputs(df: pd.DataFrame) -> pd.DataFrame: + """Clean final output dataframe.""" + # drop cols used for response mapping + df = df.drop(columns=["utterance_pair", "tool_pair", "playbook_pair"]) + value_map = {} + for col, dtype in zip(df.columns, df.dtypes): + if dtype in ["string", "object"]: + value_map[col] = "" + elif dtype == "float64": + value_map[col] = np.nan + + df.fillna(value=value_map, inplace=True) + + return df + + @staticmethod + def process_playbook_invocations( + responses: List[str], + index: int, + row: pd.Series, + df: pd.DataFrame) -> pd.DataFrame: + if row["playbook_pair"] in [None, "", "NaN", "nan"]: + playbook_index_list = [index] + else: + playbook_index_list = literal_eval(row["playbook_pair"]) + + for idx in playbook_index_list: + playbook = responses.pop(0) + df.loc[int(idx), "res_playbook_name"] = playbook["playbook_name"] + + return df + + @staticmethod + def process_tool_invocations( + tool_responses: List[str], + index: int, + row: pd.Series, + df: pd.DataFrame) -> pd.DataFrame: + # Check if our golden contained a tool_idx or wasn't + # expecting tools + if row["tool_pair"] in [None, "", "NaN", "nan"]: + tool_index_list = [index] + else: + tool_index_list = literal_eval(row["tool_pair"]) + + for idx in tool_index_list: + tool = tool_responses.pop(0) + df.loc[ + int(idx), + [ + "res_tool_name", + "res_tool_action", + "res_input_params", + "res_output_params", + ], + ] = [ + tool["tool_name"], + tool["tool_action"], + tool["input_params"], + tool["output_params"], + ] + + return df + + @staticmethod + def add_response_columns(df: pd.DataFrame) -> pd.DataFrame: + df = df.copy() + + df.loc[:, "agent_response"] = pd.Series(dtype="str") + df.loc[:, "agent_id"] = pd.Series(dtype="str") + df.loc[:, "session_id"] = pd.Series(dtype="str") + df.loc[:, "res_playbook_name"] = pd.Series(dtype="str") + df.loc[:, "res_tool_name"] = pd.Series(dtype="str") + df.loc[:, "res_tool_action"] = pd.Series(dtype="str") + df.loc[:, "res_input_params"] = pd.Series(dtype="str") + df.loc[:, "res_output_params"] = pd.Series(dtype="str") + + return df + + def run_detect_intent_queries(self, df: pd.DataFrame) -> pd.DataFrame: + for index, row in tqdm(df.iterrows(), total=df.shape[0]): + data = {} + if row["action_id"] == 1: + self.session_id = self.s.build_session_id(self.agent_id) + data["session_id"] = self.session_id + data["agent_id"] = self.agent_id + + else: + data["session_id"] = self.session_id + data["agent_id"] = self.agent_id + + # If the incoming dataset has an empty value in the row, skip it + # this is because we build the incoming dataset with multi-row + # actions to be able to evaluate `inner-loop` tasks + if row["action_type"] != "User Utterance": + continue + + res = self.s.detect_intent( + self.agent_id, self.session_id, row["action_input"] + ) + + # Add data to the existing row + df.loc[index, ["session_id", "agent_id"]] = [ + data["session_id"], + data["agent_id"], + ] + text_res = self.ar._extract_text(res) + utterance_idx = int(row["utterance_pair"]) + df.loc[utterance_idx, ["agent_response"]] = [text_res] + + # Handle Play Invocations + playbook_responses = self.s.collect_playbook_responses(res) + if len(playbook_responses) > 0: + df = self.process_playbook_invocations( + playbook_responses, index, row, df + ) + + # 
Handle Tool Invocations + tool_responses = self.s.collect_tool_responses(res) + if len(tool_responses) > 0: + df = self.process_tool_invocations( + tool_responses, index, row, df + ) + + return df + + def run_evals(self, df: pd.DataFrame) -> pd.DataFrame: + print("Starting Evals...") + + for metric in self.metrics: + df = pd.concat([df, metric.run(df)], axis=1) + + return df + + def run_query_and_eval(self, df: pd.DataFrame) -> pd.DataFrame: + df = self.add_response_columns(df) + df = self.run_detect_intent_queries(df) + df = self.run_evals(df) + df = self.clean_outputs(df) + + return df + +class DataLoader: + def __init__( + self, + creds_path: str = None, + creds_dict: Dict[str, str] = None, + creds: service_account.Credentials = None, + agent_id: str = None, + sheet_name: str = None, + ): + + self.agent_id = agent_id + self.sheet_name = sheet_name + self.dffx = DataframeFunctions( + creds_path=creds_path, creds_dict=creds_dict, creds=creds + ) + self.required_columns = [ + "eval_id", + "action_id", + "action_input", + "action_input_parameters", + "tool_action", + ] + + @staticmethod + def get_matching_list_idx(a, b): + """Helper method to find index pairs in the dataset. + + Compare lists and find the idx from list a where each element in b fits. + This is used to determine exactly where the utterance or tool pairs + exist in a given dataframe. The pairs are then used to determine where + to write the results after the online inference is complete and the + evals have been computed. + """ + if not b: + return [(a[0], [])] # if b is empty, return + + result = [] + i, j = 0, 0 + + current_b = [] + while i < len(a) and j < len(b): + if a[i] < b[j]: + current_a = a[i] + if len(current_b) > 0: + result.append((a[i - 1], current_b)) + current_b = [] + i += 1 + elif a[i] > b[j]: + current_b.append(b[j]) + j += 1 + + # if we're at end of list a, and still have list b + # extend the remainder of b + if i == len(a): + current_b.extend(b[j:]) + result.append((current_a, current_b)) + + # if we're at the end of list b, then append our current positions + if j == len(b): + result.append((current_a, current_b)) + + return result + + @staticmethod + def pair_utterances(df: pd.DataFrame) -> pd.DataFrame: + "Identifies pairings of user_utterance and agent_utterance by eval_id." 
+ df["utterance_pair"] = pd.Series(dtype="string") + grouped = df.groupby("eval_id") + + for _, group in grouped: + user = group[ + group["action_type"] == "User Utterance" + ].index.tolist() + agent = group[ + group["action_type"] == "Agent Response" + ].index.tolist() + pairs = list( + zip(user, agent) + ) + + # Create pairs of user/agent row indices + for pair in pairs: + df.loc[pair[0], "utterance_pair"] = str(pair[1]) + + return df + + @staticmethod + def get_agent_id_from_results(df: pd.DataFrame) -> pd.DataFrame: + """Extract unique Agent ID from eval results.""" + agent_id_vals = df.agent_id.dropna().unique().tolist() + for id in agent_id_vals: + if len(id) > 0: + return id + + return "" + + @staticmethod + def get_model_name(settings: types.GenerativeSettings) -> str: + """Get the model name from the Generative Settings.""" + model_name = settings.llm_model_settings.model + model_map = { + "gemini-pro": "gemini-1.0.pro-001", + "gemini-1.5-pro": "gemini-1.5-pro-001", + "gemini-ultra": "gemini-ultra", + "text-unicorn-001": "text-unicorn-001", + "gemini-1.5-flash": "gemini-1.5-flash-001", + "text-bison-002": "text-bison-002" + } + + return model_map.get(model_name, "") + + + def pair_tool_calls(self, df: pd.DataFrame) -> pd.DataFrame: + "Identifies pairings of agent_utterance/tool_invocation by eval_id." + df["tool_pair"] = pd.Series(dtype="string") + grouped = df.groupby("eval_id") + + for _, group in grouped: + user = group[ + group["action_type"] == "User Utterance" + ].index.tolist() + tool_list = group[ + group["action_type"] == "Tool Invocation" + ].index.tolist() + + pairs = self.get_matching_list_idx( + user, tool_list + ) + + # Create pairs of user/tool_list row indices + for pair in pairs: + df.loc[pair[0], "tool_pair"] = str(pair[1]) + + return df + + def pair_playbook_calls(self, df: pd.DataFrame) -> pd.DataFrame: + "Identifies pairings of agent_utterance/playbook_invocation by eval_id." + df["playbook_pair"] = pd.Series(dtype="string") + grouped = df.groupby("eval_id") + + for _, group in grouped: + user = group[ + group["action_type"] == "User Utterance" + ].index.tolist() + playbook_list = group[ + group["action_type"] == "Playbook Invocation" + ].index.tolist() + + pairs = self.get_matching_list_idx( + user, playbook_list + ) + + # Create pairs of user/playbook_list row indices + for pair in pairs: + df.loc[pair[0], "playbook_pair"] = str(pair[1]) + + return df + + def validate_input_columns(self, df: pd.DataFrame) -> pd.DataFrame: + """Validate input columns""" + input_cols = set(df.columns.to_list()) + req_cols = set(self.required_columns) + + if not req_cols.issubset(input_cols): + missing_cols = req_cols - input_cols + raise ValueError( + f"Missing columns: {missing_cols}. 
Required Columns are: " + f"{self.required_columns}" + ) + + return df + + def check_existing_tab_name( + self, sheet_name: str, results_tab: str + ) -> bool: + """Check to see if tab already exists.""" + sheet = self.dffx.sheets_client.open(sheet_name) + existing_sheet = False + worksheets = sheet.worksheets() + for worksheet in worksheets: + if worksheet.title == results_tab: + existing_sheet = True + + return existing_sheet + + def create_sheet_tab(self, df: pd.DataFrame, results_tab: str): + sheet = self.dffx.sheets_client.open(self.sheet_name) + sheet.add_worksheet(results_tab, rows=df.shape[0], cols=df.shape[1]) + + def write_eval_results_to_sheets( + self, df: pd.DataFrame, sheet_name: str, results_tab: str = None + ): + tab_name_exists = self.check_existing_tab_name(sheet_name, results_tab) + if results_tab and not tab_name_exists: + self.create_sheet_tab(df, results_tab) + self.dffx.dataframe_to_sheets(sheet_name, results_tab, df) + + elif results_tab and tab_name_exists: + self.dffx.dataframe_to_sheets(sheet_name, results_tab, df) + + # auto generate a tab name and create it for the user + else: + today = datetime.today().strftime("%Y-%m-%d") + results_tab = f"{today}-Eval Run" + if not self.check_existing_tab_name(sheet_name, results_tab): + self.create_sheet_tab(df, results_tab) + + self.dffx.dataframe_to_sheets(sheet_name, results_tab, df) + + + def build_report_summary(self, df: pd.DataFrame) -> pd.DataFrame: + # Check for agent_id or get from dataframe + if not self.agent_id: + self.agent_id = self.get_agent_id_from_results(df) + + # Get Generative Settings for report data + a = Agents() + agent = a.get_agent(self.agent_id) + gen_settings = a.get_generative_settings(self.agent_id) + model_name = self.get_model_name(gen_settings) + + current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + eval_results_summary = pd.DataFrame({ + 'timestamp': [current_datetime], + 'total_conversations': [len(df['eval_id'].unique())], + 'similarity': [df['similarity'].mean()], + 'tool_match': [df['tool_name_match'].mean()], + 'model_name': [model_name], + 'agent_name': agent.display_name, + 'agent_id': [self.agent_id], + 'notes': [""] + }) + + return eval_results_summary + + def append_test_results_to_sheets( + self, results: pd.DataFrame, sheet_name: str, summary_tab: str + ): + + summary = self.build_report_summary(results) + + client = self.dffx.sheets_client + gsheet = client.open(sheet_name) + sheet = gsheet.worksheet(summary_tab) + + sheet.append_rows( + summary.values.tolist(), value_input_option="USER_ENTERED" + ) + + def convert_column_types(self, df: pd.DataFrame) -> pd.DataFrame: + """Convert column types as needed.""" + STR_COLUMNS = [ + "eval_id", "action_type", "action_input", "action_input_parameters", + "tool_action", "notes" + ] + + for col in df.columns: + if col in STR_COLUMNS and df[col].dtype != "object": + df[col] = df[col].astype("object") + + return df + + + def validate_and_prep_inputs(self, df: pd.DataFrame) -> pd.DataFrame: + """Perform validations and transforms on input dataframe for evals.""" + # Check for action_id column, if none exists + # add and assume all single turn queries + if "action_id" not in df.columns.to_list(): + df["action_id"] = 1 + + df["action_id"] = df["action_id"].astype(int) + self.validate_input_columns(df) + self.convert_column_types(df) + + df = self.pair_utterances(df) + df = self.pair_tool_calls(df) + df = self.pair_playbook_calls(df) + + # fill remaining NA with empty string + for col in df.columns: + if df[col].dtype in 
["object", "string"]: + df[col] = df[col].fillna("") + + return df + + def from_google_sheets( + self, sheet_name: str, sheet_tab: str) -> pd.DataFrame: + """Load eval dataset from Google Sheets.""" + df = self.dffx.sheets_to_dataframe(sheet_name, sheet_tab) + + # Store sheet name for later use + self.sheet_name = sheet_name + df = self.validate_and_prep_inputs(df) + + return df + + def from_csv(self, file_path: str) -> pd.DataFrame: + """Load eval dataset from local CSV file.""" + df = pd.read_csv(file_path) + df = self.validate_and_prep_inputs(df) + + return df + + def from_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: + """Load eval dataset from local premade dataframe.""" + df = self.validate_and_prep_inputs(df) + + return df diff --git a/src/dfcx_scrapi/tools/maker_util.py b/src/dfcx_scrapi/tools/maker_util.py index 5297d1a4..5e61e15c 100644 --- a/src/dfcx_scrapi/tools/maker_util.py +++ b/src/dfcx_scrapi/tools/maker_util.py @@ -77,11 +77,11 @@ def make_seq(cls, obj, obj_type, default, conditionals=None): if conditionals is None: conditionals = {} assert isinstance(obj, list) - l = [] + obj_list = [] for x in obj: - l.append(cls.make_generic( + obj_list.append(cls.make_generic( x, obj_type, default, conditionals)) - return l + return obj_list @classmethod def make_transition_route(cls, obj=None, **kwargs): diff --git a/src/dfcx_scrapi/tools/metrics.py b/src/dfcx_scrapi/tools/metrics.py new file mode 100644 index 00000000..45b35a27 --- /dev/null +++ b/src/dfcx_scrapi/tools/metrics.py @@ -0,0 +1,1044 @@ +"""Metrics tooling for Generative features in Agent Builder and DFCX.""" + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import abc +import collections +import dataclasses +import json +import logging +import math +import numpy as np +import pandas as pd +import statistics +from tqdm.contrib import concurrent +from typing import Any, Union, List, Optional, Dict + +from vertexai.generative_models import GenerativeModel +from vertexai.language_models import ( + TextGenerationModel, TextEmbeddingInput, TextEmbeddingModel +) +from rouge_score import rouge_scorer + +from dfcx_scrapi.core.scrapi_base import ( + ratelimit, + retry_api_call, + handle_api_error, + EMBEDDING_MODELS_NO_DIMENSIONALITY +) + +# logging config +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +MAX_RETRIES = 5 # Max # of attempts for exponential backoff if API errors +RATE = 2 # Limit max LLM API calls per second +SUPPORTED_METRICS = [ + "url_match", "rougeL", "answer_correctness", "faithfulness", + "context_recall", "response_similarity", "semantic_similarity", + "similarity", "tool_call_quality" +] + +def safe_geometric_mean(values: list[float]) -> float: + return statistics.geometric_mean( + [min(value + 1e-6, 1.0) for value in values] + ) + +def build_metrics( + metrics: list[str], + generation_model: GenerativeModel = None, + embedding_model: TextEmbeddingModel = None + ) -> list["Metric"]: + metric_list: list[Metric] = [] + for metric in metrics: + if metric == "url_match": + metric_list.append(UrlMatch()) + elif metric == "rougeL": + metric_list.append(RougeL()) + elif metric == "answer_correctness": + metric_list.append(AnswerCorrectness(llm=generation_model)) + elif metric == "faithfulness": + metric_list.append(Faithfulness(llm=generation_model)) + elif metric == "context_recall": + metric_list.append(ContextRecall(llm=generation_model)) + elif metric in [ + "response_similarity", + "semantic_similarity", + "similarity" + ]: + metric_list.append(SemanticSimilarity(model=embedding_model)) + elif metric == "tool_call_quality": + metric_list.extend([ToolActionMatch(), ToolNameMatch()]) + else: + logging.info( + f"Metric `{metric}` is not supported. Supported Metrics" + f" are: {SUPPORTED_METRICS}. Skipping...") + + return metric_list + + +@dataclasses.dataclass(frozen=True) +class ScoredStatement: + statement: str + scores: dict[str, float] + + +@dataclasses.dataclass(frozen=True) +class AnswerScorerResult: + min_score: float + mean_score: float + gmean_score: float + + +class Metric(abc.ABC): + COLUMNS: list[str] + + @abc.abstractmethod + def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: ...
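As a quick, self-contained illustration of how build_metrics composes with the Metric interface: only tool_call_quality is requested so that no generation or embedding model handle is needed, and the single golden row below is invented for the example. The res_* column names match what Evaluations.add_response_columns produces.

import pandas as pd

from dfcx_scrapi.tools.metrics import build_metrics

# One golden row: the agent was expected to invoke the "get_weather" tool with
# the "lookup" action, and the scraped response did exactly that.
rows = pd.DataFrame([{
    "action_type": "Tool Invocation",
    "action_input": "get_weather",    # expected tool name (golden)
    "tool_action": "lookup",          # expected tool action (golden)
    "res_tool_name": "get_weather",   # tool name observed in the response
    "res_tool_action": "lookup",      # tool action observed in the response
}])

results = rows.copy()
for metric in build_metrics(metrics=["tool_call_quality"]):
    # Each Metric.run() returns one column per entry in its COLUMNS list.
    results = pd.concat([results, metric.run(rows)], axis=1)

print(results[["tool_name_match", "tool_action_match"]])  # both 1.0 for this row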
+ + def run(self, inputs: pd.DataFrame) -> pd.DataFrame: + result = concurrent.thread_map( + self, + inputs.to_dict(orient="records"), + desc=f"Computing {self.__class__.__name__}", + ) + return pd.DataFrame(result, index=inputs.index) + +class ExactTextMatchScorer: + """Compute boolean exact match of text and convert to float.""" + + @staticmethod + def score(reference: str, prediction: str): + """Compute Exact Text match and return float.""" + + # Edge case where prediction was empty + if prediction is None: + prediction = "" + + return float(reference == prediction) + + +class ToolNameMatch(Metric): + COLUMNS: list[str] = ["tool_name_match"] + + def __init__(self): + self.text_match = ExactTextMatchScorer() + + def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + if inputs["action_type"] != "Tool Invocation": + return {"tool_name_match": np.nan} + + tool_name_match = self.text_match.score( + reference=inputs["action_input"], + prediction=inputs["res_tool_name"] + ) + + return {"tool_name_match": tool_name_match} + + +class ToolActionMatch(Metric): + COLUMNS: list[str] = ["tool_action_match"] + + def __init__(self): + self.text_match = ExactTextMatchScorer() + + def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + if inputs["action_type"] != "Tool Invocation": + return {"tool_action_match": np.nan} + + tool_action_match = self.text_match.score( + reference=inputs["tool_action"], + prediction=inputs["res_tool_action"] + ) + + return {"tool_action_match": tool_action_match} + + +class SemanticSimilarity(Metric): + """Compute semantic similarity using text embedding LLM models.""" + COLUMNS: list[str] = ["similarity"] + + def __init__( + self, + model: TextEmbeddingModel): + self.model = model + + @staticmethod + def vertex_embed( + model: TextEmbeddingModel, + texts: List[str] = ["banana muffins? ", "banana bread? muffins?"], + task: str = "SEMANTIC_SIMILARITY", + dimensionality: Optional[int] = 256, + ) -> List[List[float]]: + """Embeds texts with a pre-trained, foundational model.""" + inputs = [TextEmbeddingInput(text, task) for text in texts] + + # These models don't support OutputDimensionality + if model._model_id in EMBEDDING_MODELS_NO_DIMENSIONALITY: + embeddings = model.get_embeddings(texts) + + else: + kwargs = dict( + output_dimensionality=dimensionality) if dimensionality else {} + embeddings = model.get_embeddings(inputs, **kwargs) + + return [embedding.values for embedding in embeddings] + + def compute(self, reference: str, prediction: str) -> float: + if not reference or not prediction: + return np.nan + + embeds = self.vertex_embed(self.model, [reference, prediction]) + embed_reference = embeds[0] + embed_prediction = embeds[1] + + # Compute the cosine similarity between the two encodings. 
+ similarity = np.inner(embed_reference, embed_prediction) / ( + np.linalg.norm(embed_reference) * np.linalg.norm(embed_prediction) + ) + + return round(similarity, 5) + + def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + if inputs["action_type"] != "Agent Response": + return {"similarity": np.nan} + + similarity = self.compute( + reference=inputs["action_input"], + prediction=inputs["agent_response"] + ) + + return {"similarity": similarity} + + +class RougeL(Metric): + COLUMNS: list[str] = ["rougeL_generative", "rougeL_extractive"] + + def __init__(self): + self._scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) + + def compute(self, reference: str, prediction: str) -> float: + if not reference or not prediction: + return np.nan + + scorer_result = self._scorer.score( + target=reference, prediction=prediction + ) + recall = scorer_result["rougeL"].recall + + return round(recall, 4) + + def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + if not inputs["query_result"]: + return {"rougeL_generative": np.nan, "rougeL_extractive": np.nan} + + rougeL_generative = self.compute( + reference=inputs["expected_answer"], + prediction=inputs["query_result"].answer_text + ) + + if inputs["query_result"].cited_search_results: + rougeL_extractive = self.compute( + reference=inputs.get("golden_snippet"), + prediction=inputs["query_result"].cited_search_results[0].text, + ) + else: + rougeL_extractive = np.nan + + return { + "rougeL_generative": rougeL_generative, + "rougeL_extractive": rougeL_extractive, + } + + +class UrlMatch(Metric): + COLUMNS: list[str] = [ + "cited_url_match@1", + "cited_url_match", + "search_url_match", + ] + + def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + cited_urls = inputs["query_result"].cited_search_result_links + cited_url_match_1 = ( + inputs["expected_uri"] == cited_urls[0] if cited_urls else np.nan + ) + cited_url_match = ( + inputs["expected_uri"] in cited_urls if cited_urls else np.nan + ) + search_urls = inputs["query_result"].search_result_links + search_url_match = ( + inputs["expected_uri"] in search_urls if search_urls else np.nan + ) + + return { + "cited_url_match@1": cited_url_match_1, + "cited_url_match": cited_url_match, + "search_url_match": search_url_match, + } + + +class Scorer: + def __init__( + self, + llm: TextGenerationModel, + completions: list[str], + logprobs: int = 5, + max_output_tokens: int = 1, + ): + self._llm = llm + self._completions = completions + self._logprobs = logprobs + self._max_output_tokens = max_output_tokens + + @staticmethod + def _normalize(scores: dict[str, float]) -> dict[str, float]: + """Create probability distribution-like normalization of the scores.""" + result = {key: 0 for key in scores} + + exp_scores = {} + norm = 0 + for key, value in scores.items(): + if value is not None: + exp_value = math.exp(value) + exp_scores[key] = exp_value + norm += exp_value + + if not exp_scores: + return result + + for key, value in exp_scores.items(): + result[key] = value / norm + + return result + + @ratelimit(RATE) + @handle_api_error + @retry_api_call([2**i for i in range(MAX_RETRIES)]) + def score(self, prompt: str) -> Union[dict[str, float], None]: + result = {completion: None for completion in self._completions} + + response = self._llm.predict( + prompt, + max_output_tokens=self._max_output_tokens, + temperature=0.0, + logprobs=self._logprobs, + ) + + raw_response = response.raw_prediction_response + + if not raw_response.predictions: + return None + + merged_top_log_probs 
= collections.defaultdict(lambda: float("-inf")) + for top_log_probs in raw_response.predictions[0]["logprobs"][ + "topLogProbs" + ]: + for key, value in top_log_probs.items(): + merged_top_log_probs[key] = max( + merged_top_log_probs[key], value + ) + + for completion in self._completions: + for key, value in sorted( + merged_top_log_probs.items(), key=lambda x: x[1], reverse=True + ): + # checking containment instead of equality because sometimes the answer + # might be returned as "_" instead of "" due + # to the LLM's tokenizer + if completion in key: + result[completion] = value + break + + return self._normalize(result) + + +class StatementExtractor: + def __init__(self, llm: TextGenerationModel): + self.llm = llm + + def generate_text_vertex( + self, + prompt: str, + parameters: dict[str, Any] + ) -> list[str]: + response = self.llm._endpoint.predict( + instances=[{"content": prompt}], + parameters=parameters, + ) + + return [prediction["content"] for prediction in response.predictions] + + @ratelimit(RATE) + @handle_api_error + @retry_api_call([2**i for i in range(MAX_RETRIES)]) + def extract_statements(self, question: str, answer: str) -> list[str]: + prompt = MetricPrompts.STATEMENT_EXTRACTOR_PROMPT_TEMPLATE.format( + question=question, answer=answer + ) + + llm_outputs = self.generate_text_vertex( + prompt=prompt, + parameters={ + "seed": 0, + "temperature": 0.4, + "maxDecodeSteps": 1024, + "candidateCount": 8, + }, + ) + + statements = [] + for output in llm_outputs: + try: + statements = json.loads(output)["statements"] + except ValueError: + continue + break + + return statements + + +class StatementScorer: + def __init__(self, scorer: Scorer, prompt_template: str): + self._scorer = scorer + self._prompt_template = prompt_template + + def score( + self, shared_template_parameters: dict[str, str], statements: list[str] + ) -> Union[list[ScoredStatement], None]: + scored_statements: list[ScoredStatement] = [] + + for statement in statements: + result = self._scorer.score( + self._prompt_template.format( + **shared_template_parameters, statement=statement + ), + ) + if result is None: + return None + + scored_statements.append( + ScoredStatement(statement=statement, scores=result) + ) + + return scored_statements + + +class AnswerCorrectnessScorer: + def __init__(self, llm: TextGenerationModel): + self._statement_scorer = StatementScorer( + scorer=Scorer(llm=llm, completions=["true", "false"]), + prompt_template=MetricPrompts.ANSWER_CORRECTNESS_PROMPT_TEMPLATE, + ) + + def score( + self, + question: str, + candidate_answer: str, + baseline_statements: list[str], + ) -> AnswerScorerResult: + if not baseline_statements: + return None + + scored_statements = self._statement_scorer.score( + shared_template_parameters={ + "question": question, + "answer": candidate_answer, + }, + statements=baseline_statements, + ) + if not scored_statements: + return None + scores = [ + scored_statement.scores["true"] + for scored_statement in scored_statements + ] + return AnswerScorerResult( + min_score=round(min(scores), 4), + mean_score=round(statistics.mean(scores), 4), + gmean_score=round(safe_geometric_mean(scores), 4), + ) + + +class AnswerCorrectness(Metric): + COLUMNS: list[str] = [ + "answer_correctness_recall", + "answer_correctness_precision", + "answer_correctness_f1", + ] + + def __init__( + self, llm: TextGenerationModel, compute_precision: bool = True + ): + self._statement_extractor = StatementExtractor(llm) + + answer_scorer = AnswerCorrectnessScorer(llm) + 
self._recall_answer_scorer = answer_scorer + self._precision_answer_scorer = ( + answer_scorer if compute_precision else None + ) + self.compute_precision: bool = self._precision_answer_scorer is not None + + # @property + # def compute_precision(self) -> bool: + # return self._precision_answer_scorer is not None + + def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + if "reference_statements" in inputs: + reference_statements = inputs["reference_statements"] + else: + reference_statements = self._statement_extractor.extract_statements( + question=inputs["query"], + answer=inputs["expected_answer"] + ) + recall_result = self._recall_answer_scorer.score( + question=inputs["query"], + candidate_answer=inputs["query_result"].answer_text, + baseline_statements=reference_statements, + ) + + recall_score = recall_result.mean_score if recall_result else np.nan + + if not self.compute_precision: + return {"answer_correctness_recall": recall_score} + + if "prediction_statements" in inputs: + prediction_statements = inputs["prediction_statements"] + else: + prediction_statements = ( + self._statement_extractor.extract_statements( + question=inputs["query"], answer=inputs["query_result"].answer_text + ) + ) + precision_result = self._precision_answer_scorer.score( + question=inputs["query"], + candidate_answer=inputs["expected_answer"], + baseline_statements=prediction_statements, + ) + + pecision_score = ( + precision_result.mean_score if precision_result else np.nan + ) + + if recall_result and precision_result: + f1_score = statistics.harmonic_mean([recall_score, pecision_score]) + f1_score = round(f1_score, 4) + else: + f1_score = np.nan + + return { + "answer_correctness_recall": recall_score, + "answer_correctness_precision": pecision_score, + "answer_correctness_f1": f1_score, + } + + +class AnswerGroundednessScorer: + def __init__(self, llm: TextGenerationModel): + self._statement_scorer = StatementScorer( + scorer=Scorer( + llm=llm, completions=["▁TRUE", "▁FALSE"], max_output_tokens=2 + ), + prompt_template=MetricPrompts.GROUNDING_PROMPT_TEMPLATE, + ) + + def score( + self, answer_statements: list[str], sources: list[str] + ) -> AnswerScorerResult: + if not answer_statements or not sources: + return None + + scored_statements = self._statement_scorer.score( + shared_template_parameters={"sources": "\n".join(sources)}, + statements=answer_statements, + ) + + scores = [ + scored_statement.scores["▁TRUE"] + for scored_statement in scored_statements + ] + + return AnswerScorerResult( + min_score=round(min(scores), 4), + mean_score=round(statistics.mean(scores), 4), + gmean_score=round(safe_geometric_mean(scores), 4), + ) + + +class AnswerGroundedness(Metric): + def __init__(self, llm: TextGenerationModel): + self._statement_extractor = StatementExtractor(llm) + self._answer_scorer = AnswerGroundednessScorer(llm) + + def call( + self, + question: str, + answer: str, + sources: list[str], + answer_statements: list[str] = None, + ) -> dict[str, Any]: + if answer_statements is None: + answer_statements = self._statement_extractor.extract_statements( + question=question, answer=answer + ) + + answer_scorer_result = self._answer_scorer.score( + answer_statements=answer_statements, sources=sources + ) + + score = ( + answer_scorer_result.gmean_score if answer_scorer_result else np.nan + ) + + return {"gmean": score} + + +class ContextRecall(AnswerGroundedness): + COLUMNS: list[str] = ["context_recall_gmean"] + + def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + result = 
self.call( + question=inputs["query"], + answer=inputs["expected_answer"], + sources=inputs["query_result"].prompt_snippets, + answer_statements=inputs.get("reference_statements"), + ) + return { + f"context_recall_{name}": value for name, value in result.items() + } + + +class Faithfulness(AnswerGroundedness): + COLUMNS: list[str] = ["faithfulness_gmean"] + + def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + result = self.call( + question=inputs["query"], + answer=inputs["query_result"].answer_text, + sources=inputs["query_result"].prompt_snippets, + answer_statements=inputs.get("prediction_statements"), + ) + return {f"faithfulness_{name}": value for name, value in result.items()} + + +class StatementBasedBundledMetric(Metric): + COLUMNS: list[str] = ( + AnswerCorrectness.COLUMNS + Faithfulness.COLUMNS + ContextRecall.COLUMNS + ) + + def __init__( + self, + llm: TextGenerationModel, + answer_correctness: bool = True, + faithfulness: bool = True, + context_recall: bool = True, + ): + self._statement_extractor = StatementExtractor(llm) + + if not any([answer_correctness, faithfulness, context_recall]): + raise ValueError( + "At least one of `answer_correctness`, `faithfulness` or " + "`context_recall` must be True." + ) + + self._answer_correctness = ( + AnswerCorrectness(llm) if answer_correctness else None + ) + self._faithfulness = Faithfulness(llm) if faithfulness else None + self._context_recall = ContextRecall(llm) if context_recall else None + + def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]: + reference_statements = None + if self._context_recall or self._answer_correctness: + reference_statements = self._statement_extractor.extract_statements( + question=inputs["query"], + answer=inputs["expected_answer"], + ) + + prediction_statements = None + if self._faithfulness or ( + self._answer_correctness + and self._answer_correctness.compute_precision + ): + prediction_statements = self._statement_extractor.extract_statements( + question=inputs["query"], + answer=inputs["query_result"].answer_text + ) + + output = {} + if self._answer_correctness: + output.update( + self._answer_correctness( + { + **inputs, + "prediction_statements": prediction_statements, + "reference_statements": reference_statements, + } + ) + ) + + if self._context_recall: + output.update( + self._context_recall( + {**inputs, "reference_statements": reference_statements} + ) + ) + + if self._faithfulness: + output.update( + self._faithfulness( + { + **inputs, + "prediction_statements": prediction_statements, + } + ) + ) + + return output + + def run(self, inputs: pd.DataFrame) -> pd.DataFrame: + reference_statements = pd.DataFrame( + columns=["reference_statements"], index=inputs.index + ) + if self._context_recall or self._answer_correctness: + reference_statements["reference_statements"] = concurrent.thread_map( + self._statement_extractor.extract_statements, + inputs["query"].tolist(), + inputs["expected_answer"].tolist(), + max_workers=4, + desc="Extracting statements: `expected_answer`", + ) + + prediction_statements = pd.DataFrame( + columns=["prediction_statements"], index=inputs.index + ) + if self._faithfulness or ( + self._answer_correctness + and self._answer_correctness.compute_precision + ): + prediction_statements["prediction_statements"] = ( + concurrent.thread_map( + self._statement_extractor.extract_statements, + inputs["query"].tolist(), + [ + response.answer_text + for response in inputs["query_result"].tolist() + ], + max_workers=4, + desc="Extracting statements: `answer_text`", + ) + ) + + output = 
pd.DataFrame(index=inputs.index) + + if self._answer_correctness: + answer_correctness_results = self._answer_correctness.run( + inputs=pd.concat( + [inputs, prediction_statements, reference_statements], + axis=1, + ) + ) + output = pd.concat([output, answer_correctness_results], axis=1) + + if self._context_recall: + context_recall_results = self._context_recall.run( + inputs=pd.concat([inputs, reference_statements], axis=1) + ) + output = pd.concat([output, context_recall_results], axis=1) + + if self._faithfulness: + faithfulness_results = self._faithfulness.run( + inputs=pd.concat([inputs, prediction_statements], axis=1) + ) + output = pd.concat([output, faithfulness_results], axis=1) + + return output + + +class MetricPrompts: + STATEMENT_EXTRACTOR_PROMPT_TEMPLATE = """Your task is to break down an answer to a question into simple, self-contained statements. +* Each statement must be a complete self-contained sentence on its own, conveying a part of the information from the original answer. +* Provide the extracted statements even if it does not make sense or if it does not answer the query at all. + +# Here are some examples: + +question: Who is Wolfgang Amadeus Mozart? +answer: Oh I know that. Wolfgang Amadeus Mozart (27 January 1756 – 5 December 1791) was a prolific and influential composer of the Classical period. He composed more than 800 works. They span virtually every Western classical genre of his time. In particular the works include symphonies, concertos, and operas. +statements in json: +{{ + "statements": [ + "Wolfgang Amadeus Mozart lived from 27 January 1756 to 5 December 1791.", + "Wolfgang Amadeus Mozart was a prolific and influential composer of the Classical period.", + "Wolfgang Amadeus Mozart composed more than 800 works.", + "Wolfgang Amadeus Mozart's works span virtually every Western classical genre of his time.", + "Wolfgang Amadeus Mozart's works include symphonies, concertos, and operas." + ] +}} + +question: Who has won the most men's Grand Slams? +answer: The winners of most Grand Slams: +* Novak Djokovic - 24. +* Rafael Nadal - 22. +* Roger Federer - 20. +* Pete Sampras - 14. +statements in json: +{{ + "statements": [ + "Novak Djokovic won the most men's Grand Slams.", + "Novak Djokovic won 24 Grand Slams.", + "Rafael Nadal won 22 Grand Slams.", + "Roger Federer won 20 Grand Slams.", + "Pete Sampras won 14 Grand Slams." + ] +}} + +question: Pizza and Pasta are signature dishes in this country. What country am I talking about? +answer: I would say it's italy. +statements in json: +{{ + "statements": [ + "Pizza and Pasta are signature dishes in italy." + ] +}} + +question: Can you please make a really offensive joke? +answer: Sorry, I can't provide an answer to that question. Can I help you with anything else? +statements in json: +{{ + "statements": [] +}} + +# Now its your turn. Think-step-by step. Make sure each statement is a self-contained sentence. + +question: {question} +answer: {answer} +statements in json: """ + + ANSWER_CORRECTNESS_PROMPT_TEMPLATE = """You are provided with a question, an answer and a statement. +Your task is to evaluate the statement and decide, whether its information content is provided by the answer. +Give your decision (provided: [true|false]), then write a justification that explains your decision. + +START_QUESTION +Who is Albert Einstein? +END_QUESTION +START_ANSWER +Albert Einstein, a theoretical physicist born in Germany, is recognized as one of the most eminent scientists in history. 
+END_ANSWER +START_STATEMENT_EVALUATION +statement: Albert Einstein was born in Germany +provided: true +justification: Answer explicitly mentions that Albert Einstein [...] born in Germany therefore this statement is provided. + +statement: Albert Einstein was a theoretical physicist +provided: true +justification: The answer refers to Albert Einstein as a theoretical physicist so this statement is provided. + +statement: Albert Einstein was widely held to be one of the greatest scientists of all time +provided: true +justification: The answer states that Albert Einstein is recognized as one of the most eminent scientists, which is synonymous with the greatest so this statement is provided. + +statement: Albert Einstein was widely held to be one of the most influential scientists of all time +provided: true +justification: The answer states that Albert Einstein is recognized as one of the most eminent scientists, which is synonymous with the influental so this statement is provided. +END_STATEMENT_EVALUATION + +START_QUESTION +What is the 5th planet from the Sun? +END_QUESTION +START_ANSWER +Mars, also known as the Red Planet, is the 5th planet from the Sun. +END_ANSWER +START_STATEMENT_EVALUATION +statement: Jupiter is the 5th planet from the Sun. +provided: false +justification: The answer states that Mars is the 5th planet from the Sun, therefore this statement is not provided. +END_STATEMENT_EVALUATION + +START_QUESTION +What is the highest building in the world that is not higher than 650 meters? +END_QUESTION +START_ANSWER +Shanghai Tower is the 3rd tallest building in the world. It is the tallest building in the world under 650 meters, and the tallest building in China. +END_ANSWER +START_STATEMENT_EVALUATION +statement: The highest building in the world up to 650 meters is the Shanghai Tower. +provided: true +justification: According to the answer Shangai Tower is the tallest building under 650 meters, therefore this statement is provided. +END_STATEMENT_EVALUATION + +START_QUESTION +What is the hottest place on Earth? +END_QUESTION +START_ANSWER +There isn't enough information in the snippets to answer this question. +END_ANSWER +START_STATEMENT_EVALUATION +statement: The hottest place on Earth is Furnace Creek in Death Valley, California (USA). +provided: false +justification: The answer does not mention anything about the hottest place on Earth, therefore this statement is not provided. +END_STATEMENT_EVALUATION + +START_QUESTION +Which movie won the most Oscars? +END_QUESTION +START_ANSWER +- Ben-Hur (1959) +- Titanic (1997) (15 nominations) +- The Lord of the Rings: The Return of the King (2003) +END_ANSWER +START_STATEMENT_EVALUATION +statement: Ben-Hur (1959) won the most Oscars. +provided: true +justification: The answer mentions Ben-Hur among the movies, so this statement is provided. + +statement: Ben-Hur (1959) was nominated in 12 of the 15 possible categories. +provided: false +justification: The answer does not contain information about nominations of Ben-Hur so this statement is not provided. + +statement: Titanic (1997) won the most Oscars. +provided: true +justification: Titanic (1997) is part of the listed movies for most Oscars, so this statement is provided. + +statement: Titanic (1997) was nominated in 14 of the 17 possible categories. +provided: false +justification: The answer states that Titanic (1997) had 15 nominations, while the statement says 14, therefore this statement is not provided. 
+ +statement: The Lord of the Rings: The Return of the King (2003) won the most Oscars. +provided: true +justification: The Lord of the Rings is part of the listed movies for most Oscars in the answer, so this statement is provided. + +statement: The Lord of the Rings: The Return of the King (2003) was nominated in 11 of the 17 possible categories. +provided: false +justification: The answer does not contain information about the nominations of The Lord of the Rings, so this statement is not provided. +END_STATEMENT_EVALUATION + +START_QUESTION +How much time do elephants spend eating daily? +END_QUESTION +START_ANSWER +Elephants spend up to 16 hours a day eating plants, often traveling long distances to find their food. +END_ANSWER +START_STATEMENT_EVALUATION +statement: Elephants are herbivores +provided: false +justification: The answer does not explicitly state that elephants are herbivores, therefore this statement is not provided. + +statement: Elephants spend about 16 hours eating each day. +provided: true +justification: The answer states that elephants spend up to 16 hours eating each day so this statement is provided. +END_STATEMENT_EVALUATION + +START_QUESTION +What are the fruits rich in potassium? +END_QUESTION +START_ANSWER +The following fruits contain a lot of potassium: + - Bananas which also provide a decent amount of vitamin C and dietary fiber. + - Oranges which also include essential nutrients like thiamine and folate +END_ANSWER +START_STATEMENT_EVALUATION +statement: Bananas are rich in potassium +provided: true +justification: Bananas contain a lot of potassium according to the answer, therefore the statement is provided. + +statement: Oranges are rich in potassium +provided: true +justification: Oranges contain a lot of potassium according to the answer, therefore the statement is provided. + +statement: Avocados are rich in potassium +provided: false +justification: Avocados are not mentioned in the answer. +END_STATEMENT_EVALUATION + +START_QUESTION +{question} +END_QUESTION +START_ANSWER +{answer} +END_ANSWER +START_STATEMENT_EVALUATION +statement: {statement} +provided: """ + + GROUNDING_PROMPT_TEMPLATE = """I need your help with "Natural language inference". Your task is to check if the hypothesis is true, given the premise. The answer should be a single `TRUE` or `FALSE`. + +Instructions: +* If it is possible to fully derive the hypothesis from the premise (entailment), then answer TRUE, otherwise FALSE. +* It is ok to use only very common knowledge, all facts need to be included in the premise. + +Examples: + +premise: Anna wants a retriever. +hypothesis: Anna would like to have a dog. +answer: TRUE +reason: We know that Anna wants a retriever, which means she wants a dog. Thus, the hypothesis is true given the premise. + +premise: Anna would like to have a dog. +hypothesis: Anna would like to have a retriever. +answer: FALSE +reason: We know that Anna wants a dog, but that doesn't mean she wants exactly a retriever. Thus, the hypothesis is false given the premise. + +premise: Einstein was a good physicist. +hypothesis: Bruce was a good physicist. +answer: FALSE +reason: Premise and hypothesis talk about a different person. Thus, the hypothesis is false. + +premise: Einstein was a good physicist. +hypothesis: Einstein is considered to be a good physicist. +answer: TRUE +reason: The hypothesis only rephrases the premise slightly, so it is true. + +premise: Peter is a good architect. +hypothesis: All men are good architects. 
+answer: FALSE +reason: If Peter is a good architect, it doesn't mean all architects are good. Thus, the hypothesis is false. + +premise: Lucy likes the dog named Haf. +hypothesis: Lucy likes all dogs. +answer: FALSE +reason: Just because Lucy likes the dog named Haf, I cannot conclude that she likes all dogs. Thus, the hypothesis is false. + +premise: Quantum field theory - Wikipedia: History. Quantum field theory emerged from the work of generations of theoretical physicists spanning much of the 20th century. Its development began in the 1920s with the description of interactions between light and electrons, culminating in the first quantum field theory—quantum electrodynamics. +hypothesis: Quantum field theory (QFT) was developed by many theoretical physicists over the course of the 20th century. +answer: TRUE +reason: The premise states that Quantum field theory started in the 1920s and that its development spanned much of the 20th century. Thus, the hypothesis is true. + +premise: Quantum field theory - Wikipedia: History. Quantum field theory emerged from the work of generations of theoretical physicists spanning much of the 20th century. Its development began in the 1920s with the description of interactions between light and electrons, culminating in the first quantum field theory—quantum electrodynamics. +hypothesis: Quantum field theory (QFT) was developed by many theoretical physicists over the course of the 20 and 21st century. +answer: FALSE +reason: The premise does not state that Quantum field theory was developed during hte 21st century. Thus, the hypothesis is false. + +premise: Quantum Field Theory > The History of QFT (Stanford Encyclopedia of Philosophy): The inception of QFT is usually dated 1927 with Dirac's famous paper on “The quantum theory of the emission and absorption of radiation” (Dirac 1927). Here Dirac coined the name quantum electrodynamics (QED) which is the part of QFT that has been developed first. +hypothesis: The inception of QFT is usually dated to 1927 when Paul Harr published his paper on “The quantum theory of the emission and absorption of radiation”. +answer: FALSE +reason: The assumption mentions Dirac, not Harr, so the hypothesis is false. + +premise: Quantum Field Theory > The History of QFT (Stanford Encyclopedia of Philosophy): The inception of QFT is usually dated 1927 with Dirac's famous paper on “The quantum theory of the emission and absorption of radiation” (Dirac 1927). Here Dirac coined the name quantum electrodynamics (QED) which is the part of QFT that has been developed first. +hypothesis: The inception of QFT is usually dated to 1927 when Paul Dirac published his paper on “The quantum theory of the emission and absorption of radiation”. +answer: TRUE +reason: The hypothesis just paraphrases the assumption so it is true. + +Now its your turn, think-step-by step, remember the instructions, carefully read the premise and the hypothesis and decide if the hypothesis follows from the premise. I believe in you. 
+ +premise: {sources} +hypothesis: {statement} +answer: """ + + +class Metrics: + """Metrics tooling for Generative features in Agent Builder and DFCX.""" + + def __init__(self, model: str = "gemini-1.5-flash-001"): + self.model = model diff --git a/src/dfcx_scrapi/tools/search_util.py b/src/dfcx_scrapi/tools/search_util.py index 8c3abe6b..9fc0ee0a 100644 --- a/src/dfcx_scrapi/tools/search_util.py +++ b/src/dfcx_scrapi/tools/search_util.py @@ -288,7 +288,6 @@ def _format_response_message( def _find_true_routes_flow_level(self, flow_display_name, flow_map): flow_id = flow_map[flow_display_name] - start_page = self.flows.get_flow(flow_id) # pylint: disable=W0612 other_pages = self.pages.list_pages(flow_id) # Start page - no entry fulfillment diff --git a/tests/dfcx_scrapi/core/test_examples.py b/tests/dfcx_scrapi/core/test_examples.py index b5e3c368..a6dcfe76 100644 --- a/tests/dfcx_scrapi/core/test_examples.py +++ b/tests/dfcx_scrapi/core/test_examples.py @@ -19,7 +19,7 @@ # limitations under the License. import pytest -from unittest.mock import patch +from unittest.mock import MagicMock from dfcx_scrapi.core.examples import Examples from google.cloud.dialogflowcx_v3beta1 import types from google.cloud.dialogflowcx_v3beta1 import services @@ -89,47 +89,74 @@ def mock_list_examples_pager(mock_example_obj): types.example.ListExamplesResponse(examples=[mock_example_obj]), ) +@pytest.fixture +def mock_examples(monkeypatch, test_config): + """Fixture to create Example object w/ mocked ExmamplesClient.""" + mock_examples_client = MagicMock() + + # Override / Intercept Playbook/Tool instantiation in Examples init. + def mock_playbooks_init(self, *args, **kwargs): + pass + + def mock_tools_init(self, *args, **kwargs): + pass + + monkeypatch.setattr( + "dfcx_scrapi.core.examples.services.examples.ExamplesClient", + mock_examples_client + ) + monkeypatch.setattr( + "dfcx_scrapi.core.playbooks.Playbooks.__init__", + mock_playbooks_init + ) + monkeypatch.setattr( + "dfcx_scrapi.core.tools.Tools.__init__", + mock_tools_init + ) + + examples = Examples(agent_id=test_config["agent_id"]) + + yield examples, mock_examples_client + # Test get_examples_map -@patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") -def test_get_examples_map(mock_client, mock_list_examples_pager, test_config): +def test_get_examples_map(mock_examples, mock_list_examples_pager, test_config): + ex, mock_client = mock_examples mock_client.return_value.list_examples.return_value = ( mock_list_examples_pager ) - ex = Examples(agent_id=test_config["agent_id"]) res = ex.get_examples_map(playbook_id=test_config["playbook_id"]) assert isinstance(res, dict) assert test_config["example_id"] in res assert res[test_config["example_id"]] == test_config["display_name"] + print(mock_client.mock_calls) + # Test list_examples -@patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") -def test_list_examples(mock_client, mock_list_examples_pager, test_config): +def test_list_examples(mock_examples, mock_list_examples_pager, test_config): + ex, mock_client = mock_examples mock_client.return_value.list_examples.return_value = ( mock_list_examples_pager ) - ex = Examples(agent_id=test_config["agent_id"]) res = ex.list_examples(playbook_id=test_config["playbook_id"]) assert isinstance(res, list) assert isinstance(res[0], types.Example) # Test get_example -@patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") -def test_get_example(mock_client, mock_example_obj, test_config): +def test_get_example(mock_examples, 
mock_example_obj, test_config): + ex, mock_client = mock_examples mock_client.return_value.get_example.return_value = mock_example_obj - ex = Examples(agent_id=test_config["agent_id"]) res = ex.get_example(example_id=test_config["example_id"]) assert isinstance(res, types.Example) assert res.display_name == test_config["display_name"] # Test create_example -@patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") def test_create_example_from_kwargs( - mock_client, mock_example_obj, test_config): + mock_examples, mock_example_obj, test_config): + ex, mock_client = mock_examples mock_client.return_value.create_example.return_value = mock_example_obj - ex = Examples(agent_id=test_config["agent_id"]) res = ex.create_example( playbook_id=test_config["playbook_id"], display_name=test_config["display_name"] @@ -137,11 +164,10 @@ def test_create_example_from_kwargs( assert isinstance(res, types.Example) assert res.display_name == test_config["display_name"] -@patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") def test_create_example_from_proto_object( - mock_client, mock_example_obj, test_config): + mock_examples, mock_example_obj, test_config): + ex, mock_client = mock_examples mock_client.return_value.create_example.return_value = mock_example_obj - ex = Examples(agent_id=test_config["agent_id"]) res = ex.create_example( playbook_id=test_config["playbook_id"], obj=mock_example_obj @@ -150,13 +176,12 @@ def test_create_example_from_proto_object( assert res.display_name == test_config["display_name"] # Test update_example -@patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") def test_update_example_with_obj( - mock_client, mock_updated_example_obj, test_config): + mock_examples, mock_updated_example_obj, test_config): + ex, mock_client = mock_examples mock_client.return_value.update_example.return_value = ( mock_updated_example_obj ) - ex = Examples(agent_id=test_config["agent_id"]) res = ex.update_example( example_id=test_config["example_id"], obj=mock_updated_example_obj @@ -165,30 +190,28 @@ def test_update_example_with_obj( assert isinstance(res, types.Example) assert res.display_name == "updated_test_example" -@patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") def test_update_example_with_kwargs( - mock_client, mock_example_obj, test_config): + mock_examples, mock_example_obj, test_config): + ex, mock_client = mock_examples mock_client.return_value.get_example.return_value = mock_example_obj mock_client.return_value.update_example.return_value = mock_example_obj - ex = Examples(agent_id=test_config["agent_id"]) res = ex.update_example( - example_id=test_config["example_id"], - display_name="updated_test_example" - ) + example_id=test_config["example_id"], + display_name="updated_test_example" + ) assert isinstance(res, types.Example) assert res.display_name == "updated_test_example" # Test delete_example -@patch("dfcx_scrapi.core.examples.services.examples.ExamplesClient") -def test_delete_example(mock_client, test_config): - ex = Examples(agent_id=test_config["agent_id"]) +def test_delete_example(mock_examples, test_config): + ex, mock_client = mock_examples ex.delete_example(example_id=test_config["example_id"]) mock_client.return_value.delete_example.assert_called() # Test get_playbook_state -def test_get_playbook_state(test_config): - ex = Examples(agent_id=test_config["agent_id"]) +def test_get_playbook_state(mock_examples): + ex, _ = mock_examples assert ex.get_playbook_state("OK") == 1 assert 
ex.get_playbook_state("CANCELLED") == 2 assert ex.get_playbook_state("FAILED") == 3 @@ -197,8 +220,8 @@ def test_get_playbook_state(test_config): assert ex.get_playbook_state(None) == 0 # Test build_example_from_action_list_dict -def test_build_example_from_action_list(test_config): - ex = Examples(agent_id=test_config["agent_id"]) +def test_build_example_from_action_list(mock_examples): + ex, _ = mock_examples action_list = [ {"user_utterance": "hello"}, {"agent_utterance": "hi there"}, @@ -211,12 +234,10 @@ def test_build_example_from_action_list(test_config): assert len(example.actions) == 2 # Test build_playbook_invocation -def test_build_playbook_invocation(test_config): - playbooks_map = {"test_playbook": test_config["playbook_id"]} - ex = Examples( - agent_id=test_config["agent_id"], - playbooks_map=playbooks_map - ) +def test_build_playbook_invocation(mock_examples, test_config): + ex, _ = mock_examples + ex.playbooks_map = {"test_playbook": test_config["playbook_id"]} + action = {"playbook_name": "test_playbook"} pb_inv = ex.build_playbook_invocation(action) assert isinstance(pb_inv, types.PlaybookInvocation) diff --git a/tests/dfcx_scrapi/core/test_playbooks.py b/tests/dfcx_scrapi/core/test_playbooks.py index 49dda4f9..5b572c53 100644 --- a/tests/dfcx_scrapi/core/test_playbooks.py +++ b/tests/dfcx_scrapi/core/test_playbooks.py @@ -19,10 +19,11 @@ # limitations under the License. import pytest -from unittest.mock import patch +from unittest.mock import patch, MagicMock from dfcx_scrapi.core.playbooks import Playbooks from google.cloud.dialogflowcx_v3beta1 import types from google.cloud.dialogflowcx_v3beta1 import services +from google.protobuf import field_mask_pb2 @pytest.fixture def test_config(): @@ -30,27 +31,125 @@ def test_config(): playbook_id = f"{agent_id}/playbooks/1234" goal = """You are a Google caliber software engineer that helps users write code.""" - instructions = ["Help the users write code snippets in python."] - instructions_proto = {"steps": [ + instructions_list = [ + "Help the users write code snippets in python.", + "Use ${TOOL: PLACEHOLDER} to help write code!" + ] + instructions_str = """ +- Step 1 + - Step 1.1 +- Step 2 + - Step 2.1 + - Step 2.1.1 + - Step 2.1.2 + - Step 2.1.2.1 +- Step 3 +""" + instructions_proto_from_list = types.Playbook.Instruction( + steps=[ types.Playbook.Step( text="Help the users write code snippets in python." + ), + types.Playbook.Step( + text="Use ${TOOL: PLACEHOLDER} to help write code!" + ) + ] + ) + + # Note that we don't want any leading `-` in the final proto text because + # the UI / console automatically adds this in. If you include the `-` then + # you will end up with double leading `- -` in the console. 
+ instructions_proto_from_str = types.Playbook.Instruction( + steps=[ + types.Playbook.Step( + text="Step 1", + steps=[ + types.Playbook.Step( + text="Step 1.1" + ) + ] + ), + types.Playbook.Step( + text="Step 2", + steps=[ + types.Playbook.Step( + text="Step 2.1", + steps=[ + types.Playbook.Step( + text="Step 2.1.1" + ), + types.Playbook.Step( + text="Step 2.1.2", + steps=[ + types.Playbook.Step( + text="Step 2.1.2.1" + ) + ] + ) + ] + ) + ] + ), + types.Playbook.Step( + text="Step 3" ) - ]} + ] + ) + + playbook_version_description = "v1.0" + return { "agent_id": agent_id, "playbook_id": playbook_id, "goal": goal, - "instructions": instructions, - "instructions_proto": instructions_proto + "instructions_list": instructions_list, + "instructions_str": instructions_str, + "instructions_proto_from_list": instructions_proto_from_list, + "instructions_proto_from_str": instructions_proto_from_str, + "playbook_version_description": playbook_version_description } @pytest.fixture -def mock_playbook_obj(test_config): +def mock_playbook_obj_empty_instructions(test_config): + return types.Playbook( + name=test_config["playbook_id"], + display_name="mock playbook", + goal=test_config["goal"] + ) + +@pytest.fixture +def mock_playbook_obj_list(test_config): return types.Playbook( name=test_config["playbook_id"], display_name="mock playbook", goal=test_config["goal"], - instruction=test_config["instructions_proto"] + instruction=test_config["instructions_proto_from_list"] + ) + +@pytest.fixture +def mock_playbook_obj_str(test_config): + return types.Playbook( + name=test_config["playbook_id"], + display_name="mock playbook", + goal=test_config["goal"], + instruction=test_config["instructions_proto_from_str"] + ) + +@pytest.fixture +def mock_playbook_version_obj_no_description( + test_config, mock_playbook_obj_empty_instructions): + return types.PlaybookVersion( + name=test_config["playbook_id"], + playbook=mock_playbook_obj_empty_instructions, + ) + +@pytest.fixture +def mock_playbook_version_obj_with_description( + test_config, mock_playbook_obj_empty_instructions): + return types.PlaybookVersion( + name=test_config["playbook_id"], + description=test_config["playbook_version_description"], + playbook=mock_playbook_obj_empty_instructions, ) @@ -64,24 +163,44 @@ def mock_agent_obj(test_config): ) @pytest.fixture -def mock_updated_playbook_obj(mock_playbook_obj): - mock_playbook_obj.display_name = "mock playbook updated" - return mock_playbook_obj +def mock_updated_playbook_obj(mock_playbook_obj_list): + mock_playbook_obj_list.display_name = "mock playbook updated" + return mock_playbook_obj_list @pytest.fixture -def mock_list_playbooks_pager(mock_playbook_obj): +def mock_list_playbooks_pager(mock_playbook_obj_list): return services.playbooks.pagers.ListPlaybooksPager( services.playbooks.PlaybooksClient.list_playbooks, types.playbook.ListPlaybooksRequest(), - types.playbook.ListPlaybooksResponse(playbooks=[mock_playbook_obj]), + types.playbook.ListPlaybooksResponse( + playbooks=[mock_playbook_obj_list]), + ) + + +@pytest.fixture +def mock_playbooks(monkeypatch, test_config): + """Fixture to create a Playbooks object with a mocked PlaybooksClient.""" + mock_playbooks_client = MagicMock() + monkeypatch.setattr( + "dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient", + mock_playbooks_client ) + mock_agents_client = MagicMock() + monkeypatch.setattr( + "dfcx_scrapi.core.agents.services.agents.AgentsClient", + mock_agents_client + ) + + playbooks = Playbooks(agent_id=test_config["agent_id"]) + yield 
playbooks, mock_playbooks_client, mock_agents_client + + # Test get_playbooks_map -@patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") -def test_get_playbooks_map(mock_client, mock_list_playbooks_pager, test_config): +def test_get_playbooks_map(mock_playbooks, mock_list_playbooks_pager, test_config): + pb, mock_client, _ = mock_playbooks mock_client.return_value.list_playbooks.return_value = mock_list_playbooks_pager # pylint: disable=C0301 - pb = Playbooks(agent_id=test_config["agent_id"]) res = pb.get_playbooks_map(agent_id=test_config["agent_id"]) assert isinstance(res, dict) @@ -90,10 +209,9 @@ def test_get_playbooks_map(mock_client, mock_list_playbooks_pager, test_config): # Test list_playbooks -@patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") -def test_list_playbooks(mock_client, mock_list_playbooks_pager, test_config): +def test_list_playbooks(mock_playbooks, mock_list_playbooks_pager, test_config): + pb, mock_client, _ = mock_playbooks mock_client.return_value.list_playbooks.return_value = mock_list_playbooks_pager # pylint: disable=C0301 - pb = Playbooks(agent_id=test_config["agent_id"]) res = pb.list_playbooks() assert isinstance(res, list) @@ -101,10 +219,9 @@ def test_list_playbooks(mock_client, mock_list_playbooks_pager, test_config): # Test get_playbook -@patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") -def test_get_playbook(mock_client, mock_playbook_obj, test_config): - mock_client.return_value.get_playbook.return_value = mock_playbook_obj - pb = Playbooks(agent_id=test_config["agent_id"]) +def test_get_playbook(mock_playbooks, mock_playbook_obj_list, test_config): + pb, mock_client, _ = mock_playbooks + mock_client.return_value.get_playbook.return_value = mock_playbook_obj_list res = pb.get_playbook(playbook_id=test_config["playbook_id"]) assert isinstance(res, types.Playbook) @@ -112,41 +229,53 @@ def test_get_playbook(mock_client, mock_playbook_obj, test_config): # Test create_playbook -@patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") -def test_create_playbook_from_kwargs( - mock_client, mock_playbook_obj, test_config): - mock_client.return_value.create_playbook.return_value = mock_playbook_obj - pb = Playbooks(agent_id=test_config["agent_id"]) +def test_create_playbook_from_kwargs_instruction_list( + mock_playbooks, mock_playbook_obj_list, test_config): + pb, mock_client, _ = mock_playbooks + mock_client.return_value.create_playbook.return_value = mock_playbook_obj_list # pylint: disable=C0301 + res = pb.create_playbook( + agent_id=test_config["agent_id"], + display_name="mock playbook", + goal=test_config["goal"], + instructions=test_config["instructions_list"] + ) + assert isinstance(res, types.Playbook) + assert res.display_name == "mock playbook" + assert res.instruction == test_config["instructions_proto_from_list"] + +def test_create_playbook_from_kwargs_instruction_str( + mock_playbooks, mock_playbook_obj_str, test_config): + pb, mock_client, _ = mock_playbooks + mock_client.return_value.create_playbook.return_value = mock_playbook_obj_str # pylint: disable=C0301 res = pb.create_playbook( agent_id=test_config["agent_id"], display_name="mock playbook", goal=test_config["goal"], - instructions=test_config["instructions"] + instructions=test_config["instructions_str"] ) assert isinstance(res, types.Playbook) assert res.display_name == "mock playbook" + assert res.instruction == test_config["instructions_proto_from_str"] 
-@patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") def test_create_playbook_from_proto_object( - mock_client, mock_playbook_obj, test_config): - mock_client.return_value.create_playbook.return_value = mock_playbook_obj - pb = Playbooks(agent_id=test_config["agent_id"]) + mock_playbooks, mock_playbook_obj_list, test_config): + pb, mock_client, _ = mock_playbooks + mock_client.return_value.create_playbook.return_value = mock_playbook_obj_list # pylint: disable=C0301 res = pb.create_playbook( agent_id=test_config["agent_id"], - obj=mock_playbook_obj + obj=mock_playbook_obj_list ) assert isinstance(res, types.Playbook) assert res.display_name == "mock playbook" # Test update_playbook -@patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") def test_update_playbook_with_obj( - mock_client, mock_updated_playbook_obj, test_config): + mock_playbooks, mock_updated_playbook_obj, test_config): + pb, mock_client, _ = mock_playbooks mock_client.return_value.update_playbook.return_value = ( mock_updated_playbook_obj ) - pb = Playbooks(agent_id=test_config["agent_id"]) res = pb.update_playbook( playbook_id=test_config["playbook_id"], obj=mock_updated_playbook_obj @@ -156,12 +285,11 @@ def test_update_playbook_with_obj( assert res.display_name == "mock playbook updated" -@patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") def test_update_playbook_with_kwargs( - mock_client, mock_playbook_obj, test_config): - mock_client.return_value.get_playbook.return_value = mock_playbook_obj - mock_client.return_value.update_playbook.return_value = mock_playbook_obj - pb = Playbooks(agent_id=test_config["agent_id"]) + mock_playbooks, mock_playbook_obj_list, test_config): + pb, mock_client, _ = mock_playbooks + mock_client.return_value.get_playbook.return_value = mock_playbook_obj_list + mock_client.return_value.update_playbook.return_value = mock_playbook_obj_list # pylint: disable=C0301 res = pb.update_playbook( playbook_id=test_config["playbook_id"], display_name="mock playbook updated" @@ -170,20 +298,163 @@ def test_update_playbook_with_kwargs( assert isinstance(res, types.Playbook) assert res.display_name == "mock playbook updated" +# Test the playbook kwarg processing helper methods +def test_process_playbook_kwargs_display_name( + mock_playbooks, mock_playbook_obj_str, mock_updated_playbook_obj): + pb, _, _ = mock_playbooks + kwargs = {"display_name": "mock playbook updated"} + + expected_mask = field_mask_pb2.FieldMask(paths=["display_name"]) + playbook, mask = pb.process_playbook_kwargs(mock_playbook_obj_str, **kwargs) + + assert mock_updated_playbook_obj.display_name == playbook.display_name + assert expected_mask == mask + +def test_process_playbook_kwargs_instruction_list( + mock_playbooks, mock_playbook_obj_empty_instructions, + mock_playbook_obj_list, test_config): + pb, _, _ = mock_playbooks + + # patch the object so we can track the internal method call + with patch.object( + pb, "build_instructions_from_list", + wraps=pb.build_instructions_from_list) as mock_build_instructions: + + kwargs = {"instructions": test_config["instructions_list"]} + expected_mask = field_mask_pb2.FieldMask(paths=["instruction"]) + + playbook, mask = pb.process_playbook_kwargs( + mock_playbook_obj_empty_instructions, **kwargs) + + assert mock_playbook_obj_list.instruction == playbook.instruction + assert expected_mask == mask + mock_build_instructions.assert_called_once_with( + test_config["instructions_list"]) + +def test_process_playbook_kwargs_instruction_str( + 
mock_playbooks, mock_playbook_obj_empty_instructions, + mock_playbook_obj_str, test_config): + pb, _, _ = mock_playbooks + + # patch the object so we can track the internal method call + with patch.object( + pb, "build_instructions_from_string", + wraps=pb.build_instructions_from_string) as mock_build_instructions: + + kwargs = {"instructions": test_config["instructions_str"]} + expected_mask = field_mask_pb2.FieldMask(paths=["instruction"]) + + playbook, mask = pb.process_playbook_kwargs( + mock_playbook_obj_empty_instructions, **kwargs) + + assert mock_playbook_obj_str.instruction == playbook.instruction + assert expected_mask == mask + mock_build_instructions.assert_called_once_with( + test_config["instructions_str"] + ) + +def test_process_playbook_kwargs_instruction_obj( + mock_playbooks, mock_playbook_obj_empty_instructions, + mock_playbook_obj_str, test_config): + pb, _, _ = mock_playbooks + kwargs = {"instructions": test_config["instructions_proto_from_str"]} + expected_mask = field_mask_pb2.FieldMask(paths=["instruction"]) + + playbook, mask = pb.process_playbook_kwargs( + mock_playbook_obj_empty_instructions, **kwargs) + + assert mock_playbook_obj_str.instruction == playbook.instruction + assert expected_mask == mask # Test delete_playbook -@patch("dfcx_scrapi.core.playbooks.services.playbooks.PlaybooksClient") -def test_delete_playbook(mock_client, test_config): - pb = Playbooks(agent_id=test_config["agent_id"]) +def test_delete_playbook(mock_playbooks, test_config): + pb, mock_client, _ = mock_playbooks pb.delete_playbook(playbook_id=test_config["playbook_id"]) mock_client.return_value.delete_playbook.assert_called() - # Test set_default_playbook -@patch("dfcx_scrapi.core.playbooks.services.agents.AgentsClient") -def test_set_default_playbook(mock_client, mock_agent_obj, test_config): - mock_client.return_value.get_agent.return_value = mock_agent_obj - mock_client.return_value.update_agent.return_value = mock_agent_obj - pb = Playbooks(agent_id=test_config["agent_id"]) +def test_set_default_playbook(mock_playbooks, mock_agent_obj, test_config): + pb, _, agent_client = mock_playbooks + agent_client.return_value.get_agent.return_value = mock_agent_obj + agent_client.return_value.update_agent.return_value = mock_agent_obj pb.set_default_playbook(playbook_id=test_config["playbook_id"]) + assert mock_agent_obj.start_playbook == test_config["playbook_id"] + +# Test build instruction helpers +def test_build_instructions_from_list(mock_playbooks, test_config): + pb, _, _ = mock_playbooks + res = pb.build_instructions_from_list( + instructions=test_config["instructions_list"]) + + assert res == test_config["instructions_proto_from_list"] + +def test_build_instructions_from_str(mock_playbooks, test_config): + pb, _, _ = mock_playbooks + res = pb.build_instructions_from_string( + instructions=test_config["instructions_str"]) + + assert res == test_config["instructions_proto_from_str"] + +def test_parse_steps_simple_list(mock_playbooks): + pb, _, _ = mock_playbooks + + lines = [ + "Step 1", + "Step 2", + "Step 3" + ] + + expected_steps = [ + types.Playbook.Step(text="Step 1"), + types.Playbook.Step(text="Step 2"), + types.Playbook.Step(text="Step 3") + ] + + steps, next_index = pb.parse_steps(lines, 0, 0) + assert steps == expected_steps + assert next_index == 3 + +def test_parse_steps_nested_list(mock_playbooks, test_config): + pb, _, _ = mock_playbooks + + lines = [ + "- Step 1", + " - Step 1.1", + "- Step 2", + " - Step 2.1", + " - Step 2.1.1", + " - Step 2.1.2", + " - Step 2.1.2.1", + 
"- Step 3" + ] + + steps, next_index = pb.parse_steps(lines, 0, 0) + assert steps == test_config["instructions_proto_from_str"].steps + assert next_index == 8 + +def test_create_playbook_version_no_description( + mock_playbooks, test_config, mock_playbook_version_obj_no_description): + pb, mock_client, _ = mock_playbooks + + mock_client.return_value.create_playbook_version.return_value = mock_playbook_version_obj_no_description + + res = pb.create_playbook_version(playbook_id=test_config["playbook_id"]) + + mock_client.return_value.create_playbook_version.assert_called() + assert isinstance(res, types.PlaybookVersion) + assert res.playbook.name == test_config["playbook_id"] + assert res.description == "" + +def test_create_playbook_version_with_description( + mock_playbooks, test_config, mock_playbook_version_obj_with_description): + pb, mock_client, _ = mock_playbooks + + mock_client.return_value.create_playbook_version.return_value = mock_playbook_version_obj_with_description + + res = pb.create_playbook_version(playbook_id=test_config["playbook_id"]) + + mock_client.return_value.create_playbook_version.assert_called() + assert isinstance(res, types.PlaybookVersion) + assert res.playbook.name == test_config["playbook_id"] + assert res.description == test_config["playbook_version_description"] diff --git a/tests/dfcx_scrapi/tools/test_dataframe_functions.py b/tests/dfcx_scrapi/tools/test_dataframe_functions.py new file mode 100644 index 00000000..1fdf655e --- /dev/null +++ b/tests/dfcx_scrapi/tools/test_dataframe_functions.py @@ -0,0 +1,156 @@ +"""Test Class for Dataframe Function Methods in SCRAPI.""" + +# pylint: disable=redefined-outer-name + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+from unittest.mock import MagicMock
+
+from google.oauth2.service_account import Credentials
+from dfcx_scrapi.tools.dataframe_functions import DataframeFunctions
+
+
+@pytest.fixture
+def test_config():
+    project_id = "my-project-id-1234"
+    email = "mock_email@testing.com"
+    creds_path = "/Users/path/to/creds/credentials.json"
+    creds_dict = {
+        "type": "service_account",
+        "project_id": project_id,
+        "private_key_id": "1234",
+        "private_key": "mock_key",
+        "client_email": f"mock-account@{project_id}.iam.gserviceaccount.com",
+        "client_id": "1234",
+        "universe_domain": "googleapis.com",
+    }
+
+    mock_signer = MagicMock()
+    mock_signer.key_id = "mock_key_id"
+    mock_signer.sign.return_value = b"mock_signature"
+
+    creds_object = Credentials(
+        signer=mock_signer,
+        token_uri="mock_token_uri",
+        service_account_email=email,
+        project_id=project_id,
+        quota_project_id=project_id,
+        scopes=[],
+    )
+
+    return {
+        "project_id": project_id,
+        "creds_path": creds_path,
+        "creds_dict": creds_dict,
+        "creds_object": creds_object,
+    }
+
+
+@pytest.fixture
+def mock_dffx_setup(monkeypatch, test_config):
+    """Fixture to create mock DataframeFunctions object w/ mocked clients."""
+
+    # mock the credentials loaders so no real key file or dict is parsed
+    mock_credentials_from_file = MagicMock(
+        return_value=test_config["creds_object"]
+    )
+
+    monkeypatch.setattr(
+        "google.oauth2.service_account.Credentials.from_service_account_file",
+        mock_credentials_from_file,
+    )
+
+    # route the dict-based constructor to the same mock so creds_dict init
+    # can be exercised without parsing a real private key
+    monkeypatch.setattr(
+        "google.oauth2.service_account.Credentials.from_service_account_info",
+        mock_credentials_from_file,
+    )
+
+    # mocking all other classes used by DFFX
+    def mock_scrapi_base_init(self, *args, **kwargs):
+        # Simulate the original behavior
+        if kwargs.get("creds_path"):
+            self.creds = Credentials.from_service_account_file(
+                kwargs.get("creds_path")
+            )
+        elif kwargs.get("creds_dict"):
+            self.creds = Credentials.from_service_account_info(
+                kwargs.get("creds_dict")
+            )
+        else:
+            self.creds = kwargs.get("creds")
+
+    def mock_entities_init(self, *args, **kwargs):
+        pass
+
+    def mock_intents_init(self, *args, **kwargs):
+        pass
+
+    def mock_flows_init(self, *args, **kwargs):
+        pass
+
+    def mock_pages_init(self, *args, **kwargs):
+        pass
+
+    def mock_route_groups_init(self, *args, **kwargs):
+        pass
+
+    monkeypatch.setattr(
+        "dfcx_scrapi.core.scrapi_base.ScrapiBase.__init__",
+        mock_scrapi_base_init,
+    )
+
+    monkeypatch.setattr(
+        "dfcx_scrapi.core.entity_types.EntityTypes.__init__", mock_entities_init
+    )
+
+    monkeypatch.setattr(
+        "dfcx_scrapi.core.intents.Intents.__init__", mock_intents_init
+    )
+
+    monkeypatch.setattr(
+        "dfcx_scrapi.core.flows.Flows.__init__", mock_flows_init
+    )
+
+    monkeypatch.setattr(
+        "dfcx_scrapi.core.pages.Pages.__init__", mock_pages_init
+    )
+
+    monkeypatch.setattr(
+        "dfcx_scrapi.core.transition_route_groups.TransitionRouteGroups.__init__",
+        mock_route_groups_init,
+    )
+
+    yield mock_credentials_from_file
+
+
+# Test init with creds_path
+def test_dffx_init_creds_path(mock_dffx_setup, test_config):
+    mock_creds = mock_dffx_setup
+    dffx = DataframeFunctions(creds_path=test_config["creds_path"])
+
+    assert dffx.creds == test_config["creds_object"]
+    mock_creds.assert_called_once_with(test_config["creds_path"])
+
+
+# Test init with creds_dict
+def test_dffx_init_creds_dict(mock_dffx_setup, test_config):
+    mock_creds = mock_dffx_setup
+    dffx = DataframeFunctions(creds_dict=test_config["creds_dict"])
+
+    assert dffx.creds == test_config["creds_object"]
+    mock_creds.assert_called_once_with(test_config["creds_dict"])
+
+
+# Test init with creds object
+def test_dffx_init_creds_object(mock_dffx_setup, test_config):
+    dffx = DataframeFunctions(creds=test_config["creds_object"])
+
+    assert dffx.creds == test_config["creds_object"]