Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(llms/bedrock): update testcase and update aws go sdk version for add claude-3-5 #915

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ jobs:
# https://github.com/actions/setup-go#caching-dependency-files-and-build-outputs
cache: false
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
uses: golangci/golangci-lint-action@v6.0.1
with:
args: --timeout=4m
version: v1.57.2
version: v1.59.1
build-examples:
runs-on: ubuntu-latest
steps:
Expand Down
6 changes: 5 additions & 1 deletion .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ linters:
- varnamelen
- nlreturn
- gomnd
- goerr113
- err113
- wrapcheck # TODO: we should probably enable this one (at least for new code).
- testpackage
- nolintlint # see https://github.com/golangci/golangci-lint/issues/3228.
Expand All @@ -31,6 +31,10 @@ linters:
- perfsprint
- musttag
- tagalign # Impractical for schema-defined output parser, which relies heavily on struct tagging.
- mnd
- canonicalheader
- intrange
- testifylint

linters-settings:
cyclop:
Expand Down
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ require (
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.12 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect
Expand Down Expand Up @@ -190,9 +190,9 @@ require (
github.com/Masterminds/sprig/v3 v3.2.3
github.com/PuerkitoBio/goquery v1.8.1
github.com/amikos-tech/chroma-go v0.1.2
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2 v1.30.0
github.com/aws/aws-sdk-go-v2/config v1.27.12
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.11.0
github.com/cohere-ai/tokenizer v1.1.2
github.com/fatih/color v1.17.0
github.com/gage-technologies/mistral-go v1.0.0
Expand Down
16 changes: 8 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:W
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/aws/aws-sdk-go v1.42.27/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc=
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2 v1.30.0 h1:6qAwtzlfcTtcL8NHtbDQAqgM5s6NDipQTkPxyH/6kAA=
github.com/aws/aws-sdk-go-v2 v1.30.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/config v1.27.12 h1:vq88mBaZI4NGLXk8ierArwSILmYHDJZGJOeAc/pzEVQ=
Expand All @@ -85,14 +85,14 @@ github.com/aws/aws-sdk-go-v2/credentials v1.17.12 h1:PVbKQ0KjDosI5+nEdRMU8ygEQDm
github.com/aws/aws-sdk-go-v2/credentials v1.17.12/go.mod h1:jlWtGFRtKsqc5zqerHZYmKmRkUXo3KPM14YJ13ZEjwE=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12 h1:SJ04WXGTwnHlWIODtC5kJzKbeuHt+OUNOgKg7nfnUGw=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12/go.mod h1:FkpvXhA92gb3GE9LD6Og0pHHycTxW7xGpnEh5E7Opwo=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12 h1:hb5KgeYfObi5MHkSSZMEudnIvX30iB+E21evI4r6BnQ=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12/go.mod h1:CroKe/eWJdyfy9Vx4rljP5wTUjNJfb+fPz1uMYUhEGM=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1 h1:vTHgBjsGhgKWWIgioxd7MkBH5Ekr8C6Cb+/8iWf1dpc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.11.0 h1:wHTY1k+myd0QIZevhf2XiKF4rLs37vlLguJV6LFjUQ0=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.11.0/go.mod h1:vHk9LI9clsbT8DYUmHtBxinKBlnp4XvxqyaCXA7J2bY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo=
Expand Down
5 changes: 5 additions & 0 deletions llms/bedrock/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## Resource

- [Amazon Bedrock Runtime - Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Operations_Amazon_Bedrock_Runtime.html)
- [Use the Converse API - Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html)
- [aws-sdk-go-v2/service/bedrockruntime/CHANGELOG.md at main · aws/aws-sdk-go-v2](https://github.com/aws/aws-sdk-go-v2/blob/main/service/bedrockruntime/CHANGELOG.md)
38 changes: 27 additions & 11 deletions llms/bedrock/bedrockllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bedrock
import (
"context"
"errors"
"strings"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
Expand Down Expand Up @@ -59,7 +60,11 @@ func (l *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio
}

// GenerateContent implements llms.Model.
func (l *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
func (l *LLM) GenerateContent(
ctx context.Context,
messages []llms.MessageContent,
options ...llms.CallOption,
) (*llms.ContentResponse, error) {
if l.CallbacksHandler != nil {
l.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
}
Expand All @@ -71,17 +76,28 @@ func (l *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
opt(&opts)
}

m, err := processMessages(messages)
if err != nil {
return nil, err
}

res, err := l.client.CreateCompletion(ctx, opts.Model, m, opts)
if err != nil {
if l.CallbacksHandler != nil {
l.CallbacksHandler.HandleLLMError(ctx, err)
var res *llms.ContentResponse
if strings.HasPrefix(opts.Model, "anthropic") { // nolint: nestif
var err error
res, err = l.client.Converse(ctx, opts.Model, messages, &opts)
if err != nil {
if l.CallbacksHandler != nil {
l.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}
} else {
m, err := processMessages(messages)
if err != nil {
return nil, err
}
res, err = l.client.CreateCompletion(ctx, opts.Model, m, opts)
if err != nil {
if l.CallbacksHandler != nil {
l.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
}
return nil, err
}

if l.CallbacksHandler != nil {
Expand Down
146 changes: 119 additions & 27 deletions llms/bedrock/bedrockllm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"os"
"testing"
"unicode/utf8"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/stretchr/testify/assert"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/bedrock"
)
Expand All @@ -22,11 +24,9 @@ func setUpTest() (*bedrockruntime.Client, error) {

func TestAmazonOutput(t *testing.T) {
t.Parallel()

if os.Getenv("TEST_AWS") != "true" {
t.Skip("Skipping test, requires AWS access")
}

client, err := setUpTest()
if err != nil {
t.Fatal(err)
Expand All @@ -51,36 +51,128 @@ func TestAmazonOutput(t *testing.T) {
},
}

type testcase struct {
model string
}

// All the test models.
models := []string{
bedrock.ModelAi21J2MidV1,
bedrock.ModelAi21J2UltraV1,
bedrock.ModelAmazonTitanTextLiteV1,
bedrock.ModelAmazonTitanTextExpressV1,
bedrock.ModelAnthropicClaudeV3Sonnet,
bedrock.ModelAnthropicClaudeV3Haiku,
bedrock.ModelAnthropicClaudeV21,
bedrock.ModelAnthropicClaudeV2,
bedrock.ModelAnthropicClaudeInstantV1,
bedrock.ModelCohereCommandTextV14,
bedrock.ModelCohereCommandLightTextV14,
bedrock.ModelMetaLlama213bChatV1,
bedrock.ModelMetaLlama270bChatV1,
bedrock.ModelMetaLlama38bInstructV1,
bedrock.ModelMetaLlama370bInstructV1,
tests := []testcase{
{model: bedrock.ModelAi21J2MidV1},
{model: bedrock.ModelAi21J2UltraV1},
{model: bedrock.ModelAmazonTitanTextLiteV1},
{model: bedrock.ModelAmazonTitanTextExpressV1},
{model: bedrock.ModelAnthropicClaudeV3Sonnet},
{model: bedrock.ModelAnthropicClaudeV3Haiku},
{model: bedrock.ModelAnthropicClaudeV21},
{model: bedrock.ModelAnthropicClaudeV2},
{model: bedrock.ModelAnthropicClaudeV35Sonnet},
{model: bedrock.ModelAnthropicClaudeInstantV1},
{model: bedrock.ModelCohereCommandTextV14},
{model: bedrock.ModelCohereCommandLightTextV14},
{model: bedrock.ModelMetaLlama213bChatV1},
{model: bedrock.ModelMetaLlama270bChatV1},
{model: bedrock.ModelMetaLlama38bInstructV1},
{model: bedrock.ModelMetaLlama370bInstructV1},
}

ctx := context.Background()

for _, model := range models {
t.Logf("Model output for %s:-", model)
for _, tt := range tests {
t.Run(tt.model, func(t *testing.T) {
t.Parallel()
t.Logf("Model output for %s:-", tt.model)
resp, err := llm.GenerateContent(ctx, msgs, llms.WithModel(tt.model), llms.WithMaxTokens(512))
if err != nil {
t.Fatal(err)
}
assert.NotEmpty(t, resp.Choices)
for i, choice := range resp.Choices {
t.Logf("Choice %d: %s", i, choice.Content)
}
assert.Greater(t, utf8.RuneCountInString(resp.Choices[0].Content), 5)
})
}
}

func TestBedrockConverseStream(t *testing.T) {
t.Parallel()
if os.Getenv("TEST_AWS") != "true" {
t.Skip("Skipping test, requires AWS access")
}

client, err := setUpTest()
if err != nil {
t.Fatal(err)
}
llm, err := bedrock.New(bedrock.WithClient(client))
if err != nil {
t.Fatalf("%v", err)
}

msgs := []llms.MessageContent{
{
Role: llms.ChatMessageTypeSystem,
Parts: []llms.ContentPart{
llms.TextPart("You know all about AI."),
},
},
{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{
llms.TextPart("Explain AI in 10 words or less."),
},
},
}

type testcase struct {
model string
}

// All the test models.
tests := []testcase{
{model: bedrock.ModelAi21J2MidV1},
{model: bedrock.ModelAi21J2UltraV1},
{model: bedrock.ModelAmazonTitanTextLiteV1},
{model: bedrock.ModelAmazonTitanTextExpressV1},
{model: bedrock.ModelAnthropicClaudeV3Sonnet},
{model: bedrock.ModelAnthropicClaudeV3Haiku},
{model: bedrock.ModelAnthropicClaudeV21},
{model: bedrock.ModelAnthropicClaudeV2},
{model: bedrock.ModelAnthropicClaudeV35Sonnet},
{model: bedrock.ModelAnthropicClaudeInstantV1},
{model: bedrock.ModelCohereCommandTextV14},
{model: bedrock.ModelCohereCommandLightTextV14},
{model: bedrock.ModelMetaLlama213bChatV1},
{model: bedrock.ModelMetaLlama270bChatV1},
{model: bedrock.ModelMetaLlama38bInstructV1},
{model: bedrock.ModelMetaLlama370bInstructV1},
}

ctx := context.Background()
streamFunc := func(_ context.Context, _ []byte) error {
// t.Logf("Stream chunk: %s", string(chunk))
return nil
}

resp, err := llm.GenerateContent(ctx, msgs, llms.WithModel(model), llms.WithMaxTokens(512))
if err != nil {
t.Fatal(err)
}
for i, choice := range resp.Choices {
t.Logf("Choice %d: %s", i, choice.Content)
}
for _, tt := range tests {
t.Run(tt.model, func(t *testing.T) {
t.Parallel()
t.Logf("Model output for %s:-", tt.model)
resp, err := llm.GenerateContent(
ctx,
msgs,
llms.WithModel(tt.model),
llms.WithMaxTokens(512),
llms.WithStreamingFunc(streamFunc),
)
if err != nil {
t.Fatal(err)
}
assert.NotEmpty(t, resp.Choices)
for i, choice := range resp.Choices {
t.Logf("Choice %d: %s", i, choice.Content)
}
assert.Greater(t, utf8.RuneCountInString(resp.Choices[0].Content), 5)
})
}
}
Loading
Loading