Commit

Merge pull request #260 from patraxo/develop
Added fix for Claude 3; included the Messages API

Tested the functionality by deploying and asking a question in the Generative AI Query. It requires updating the Lambda function's configuration (BEDROCK_MODEL_ID = anthropic.claude-3-sonnet-20240229-v1:0). This will need to become a configurable value via a CloudFormation parameter, which will be addressed as part of updating the models for summarization (#259).
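For context, the shape of the change: Claude 3 models on Bedrock only accept the Messages API payload, while earlier Claude models use the Text Completions format. Below is a minimal standalone sketch of both request shapes and how each response is read. The model IDs match this PR; the helper name, MAX_TOKENS value, and prompt framing here are illustrative and not part of the diff.

import json
import boto3

MAX_TOKENS = 256  # illustrative; the Lambda defines its own limit

bedrock = boto3.client("bedrock-runtime")

def ask(model_id: str, prompt: str) -> str:
    """Sketch of the two Anthropic request/response formats handled in this PR."""
    if "claude-3" in model_id:
        # Claude 3: Messages API payload
        body = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": MAX_TOKENS,
            "messages": [{"role": "user", "content": prompt}],
        }
    else:
        # Older Claude models: Text Completions payload.
        # Text Completions expects the Human/Assistant framing inside the prompt;
        # in the PCA Lambda that framing is built by the caller.
        body = {
            "prompt": f"\n\nHuman: {prompt}\n\nAssistant:",
            "max_tokens_to_sample": MAX_TOKENS,
        }
    response = bedrock.invoke_model(modelId=model_id, body=json.dumps(body))
    payload = json.loads(response["body"].read())
    if "claude-3" in model_id:
        return payload["content"][0]["text"]  # Messages API response shape
    return payload["completion"]              # Text Completions response shape

# Example (assumes Bedrock access to this model in the account/region):
# print(ask("anthropic.claude-3-sonnet-20240229-v1:0", "Summarize this call."))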
kishd authored May 14, 2024
2 parents 6709d5d + 006aa4e commit 5355062
Showing 3 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pca-boto3-bedrock/template.yaml
@@ -6,7 +6,7 @@ Description: >
 Parameters:
   Boto3Version:
     Type: String
-    Default: "1.34.40"
+    Default: "1.34.101"
 
 Resources:
 
2 changes: 1 addition & 1 deletion pca-server/src/pca/requirements.txt
@@ -1 +1 @@
-boto3==1.15.7
+boto3==1.34.101
27 changes: 20 additions & 7 deletions pca-ui/src/genai/index.py
@@ -51,10 +51,18 @@ def get_bedrock_request_body(modelId, parameters, prompt):
     provider = modelId.split(".")[0]
     request_body = None
     if provider == "anthropic":
-        request_body = {
-            "prompt": prompt,
-            "max_tokens_to_sample": MAX_TOKENS
-        }
+        print(modelId)
+        if 'claude-3' in modelId:
+            request_body = {
+                "max_tokens": MAX_TOKENS,
+                "messages": [{"role": "user", "content": prompt}],
+                "anthropic_version": "bedrock-2023-05-31"
+            }
+        else:
+            request_body = {
+                "prompt": prompt,
+                "max_tokens_to_sample": MAX_TOKENS
+            }
         request_body.update(parameters)
     elif provider == "ai21":
         request_body = {
@@ -80,8 +88,13 @@ def get_bedrock_generate_text(modelId, response):
     provider = modelId.split(".")[0]
     generated_text = None
     if provider == "anthropic":
-        response_body = json.loads(response.get("body").read().decode())
-        generated_text = response_body.get("completion")
+        if 'claude-3' in modelId:
+            response_raw = json.loads(response.get("body").read().decode())
+            generated_text = response_raw.get('content')[0].get('text')
+
+        else:
+            response_body = json.loads(response.get("body").read().decode())
+            generated_text = response_body.get("completion")
     elif provider == "ai21":
         response_body = json.loads(response.get("body").read())
         generated_text = response_body.get("completions")[0].get("data").get("text")
@@ -242,4 +255,4 @@ def lambda_handler(event, context):
             "response":query_response
         })
     }
-    return response
+    return response
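Until the model ID becomes a CloudFormation parameter (#259), the manual configuration step from the commit description could be scripted roughly as below. This is a sketch under assumptions: the function name is a placeholder for the deployed GenAI query Lambda, not a name taken from this repo.

import boto3

lambda_client = boto3.client("lambda")
function_name = "pca-genai-query"  # placeholder; use the deployed GenAI query Lambda's name

# update_function_configuration replaces the entire Environment block,
# so fetch the current variables and merge rather than overwrite.
current = lambda_client.get_function_configuration(FunctionName=function_name)
variables = current.get("Environment", {}).get("Variables", {})
variables["BEDROCK_MODEL_ID"] = "anthropic.claude-3-sonnet-20240229-v1:0"

lambda_client.update_function_configuration(
    FunctionName=function_name,
    Environment={"Variables": variables},
)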
