Skip to content

Commit

Permalink
工作流自定义组件支持SDK调用 (#688)
Browse files Browse the repository at this point in the history
* 增加工作流Agent cookbook的链接

* 组件增加sdk调用
  • Loading branch information
userpj authored Jan 2, 2025
1 parent c2e7a38 commit f16d544
Show file tree
Hide file tree
Showing 21 changed files with 1,157 additions and 49 deletions.
92 changes: 92 additions & 0 deletions go/appbuilder/component_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package appbuilder

import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
)

type ComponentClient struct {
sdkConfig *SDKConfig
client HTTPClient
}

func NewComponentClient(config *SDKConfig) (*ComponentClient, error) {
if config == nil {
return nil, errors.New("config is nil")
}
client := config.HTTPClient
if client == nil {
client = &http.Client{Timeout: 300 * time.Second}
}
return &ComponentClient{sdkConfig: config, client: client}, nil
}

func (t *ComponentClient) Run(component, version, action string, stream bool, parameters map[string]any) (ComponentClientIterator, error) {
request := http.Request{}

urlSuffix := fmt.Sprintf("/components/%s", component)
if version != "" {
urlSuffix += fmt.Sprintf("/version/%s", version)
}
if action != "" {
if strings.Contains(urlSuffix, "?") {
urlSuffix += fmt.Sprintf("&action=%s", action)
} else {
urlSuffix += fmt.Sprintf("?action=%s", action)
}
}

serviceURL, err := t.sdkConfig.ServiceURLV2(urlSuffix)
if err != nil {
return nil, err
}

header := t.sdkConfig.AuthHeaderV2()
request.URL = serviceURL
request.Method = "POST"
header.Set("Content-Type", "application/json")
request.Header = header

req := ComponentRunRequest{
Stream: stream,
Parameters: parameters,
}
data, _ := json.Marshal(req)
request.Body = NopCloser(bytes.NewReader(data))
request.ContentLength = int64(len(data)) // 手动设置长度

t.sdkConfig.BuildCurlCommand(&request)
resp, err := t.client.Do(&request)
if err != nil {
return nil, err
}
requestID, err := checkHTTPResponse(resp)
if err != nil {
return nil, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}
r := NewSSEReader(1024*1024, bufio.NewReader(resp.Body))
if req.Stream {
return &ComponentClientStreamIterator{requestID: requestID, r: r, body: resp.Body}, nil
}
return &ComponentClientOnceIterator{body: resp.Body}, nil
}
199 changes: 199 additions & 0 deletions go/appbuilder/component_client_data.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package appbuilder

import (
"encoding/json"
"fmt"
"io"
"strings"
)

const (
SysOriginQuery = "_sys_origin_query"
SysFileUrls = "_sys_file_urls"
SysConversationID = "_sys_conversation_id"
SysEndUserID = "_sys_end_user_id"
SysChatHistory = "_sys_chat_history"
)

type ComponentRunRequest struct {
Stream bool `json:"stream"`
Parameters map[string]any `json:"parameters"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type ComponentRunResponse struct {
RequestID string `json:"request_id"`
Code string `json:"code"`
Message string `json:"message"`
Data ComponentRunResponseData `json:"data"`
}

type ComponentRunResponseData struct {
ConversationID string `json:"conversation_id"`
MessageID string `json:"message_id"`
TraceID string `json:"trace_id"`
UserID string `json:"user_id"`
EndUserID string `json:"end_user_id"`
IsCompletion bool `json:"is_completion"`
Role string `json:"role"`
Content []Content `json:"content"`
}

type Content struct {
Name string `json:"name"`
VisibleScope string `json:"visible_scope"`
RawData map[string]any `json:"raw_data"`
Usage map[string]any `json:"usage"`
Metrics map[string]any `json:"metrics"`
Type string `json:"type"`
Text map[string]any `json:"text"`
Event ComponentEvent `json:"event"`
}

type ComponentEvent struct {
ID string `json:"id"`
Status string `json:"status"`
Name string `json:"name"`
CreatedTime string `json:"created_time"`
ErrorCode string `json:"error_code"`
ErrorMessage string `json:"error_message"`
}

type Text struct {
Info string `json:"info"`
}

type Code struct {
Code string `json:"code"`
}

type Files struct {
Filename string `json:"filename"`
Url string `json:"url"`
}

type Urls struct {
Url string `json:"url"`
}

type OralText struct {
Info string `json:"info"`
}

type References struct {
Type string `json:"type"`
Source string `json:"source"`
DocID string `json:"doc_id"`
Title string `json:"title"`
Content string `json:"content"`
Extra map[string]any `json:"extra"`
}

type Image struct {
Filename string `json:"filename"`
Url string `json:"url"`
Byte []byte `json:"byte"`
}

type Chart struct {
Type string `json:"type"`
Data string `json:"data"`
}

type Audio struct {
Filename string `json:"filename"`
Url string `json:"url"`
Byte []byte `json:"byte"`
}

type PlanStep struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
Thought string `json:"thought"`
}

type Plan struct {
Detail string `json:"detail"`
Steps []PlanStep `json:"steps"`
}

type FunctionCall struct {
Thought string `json:"thought"`
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}

type Json struct {
Data string `json:"data"`
}

type ComponentClientIterator interface {
// Next 获取处理结果,如果返回error不为空,迭代器自动失效,不允许再调用此方法
Next() (*ComponentRunResponseData, error)
}

type ComponentClientStreamIterator struct {
requestID string
r *sseReader
body io.ReadCloser
}

func (t *ComponentClientStreamIterator) Next() (*ComponentRunResponseData, error) {
data, err := t.r.ReadMessageLine()
if err != nil && !(err == io.EOF) {
t.body.Close()
return nil, fmt.Errorf("requestID=%s, err=%v", t.requestID, err)
}
if err != nil && err == io.EOF {
t.body.Close()
return nil, err
}
if strings.HasPrefix(string(data), "data:") {
var resp ComponentRunResponse
if err := json.Unmarshal(data[5:], &resp); err != nil {
t.body.Close()
return nil, fmt.Errorf("requestID=%s, err=%v", t.requestID, err)
}
return &resp.Data, nil
}
// 非SSE格式关闭连接,并返回数据
t.body.Close()
return nil, fmt.Errorf("requestID=%s, body=%s", t.requestID, string(data))
}

// ComponentClientOnceIterator 非流式返回时对应的迭代器,只可迭代一次
type ComponentClientOnceIterator struct {
body io.ReadCloser
requestID string
}

func (t *ComponentClientOnceIterator) Next() (*ComponentRunResponseData, error) {
data, err := io.ReadAll(t.body)
if err != nil {
return nil, fmt.Errorf("requestID=%s, err=%v", t.requestID, err)
}
defer t.body.Close()
var resp ComponentRunResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, fmt.Errorf("requestID=%s, err=%v", t.requestID, err)
}
return &resp.Data, nil
}
89 changes: 89 additions & 0 deletions go/appbuilder/component_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package appbuilder

import (
"bytes"
"fmt"
"os"
"testing"
)

func TestComponentClient(t *testing.T) {
var logBuffer bytes.Buffer

// 设置环境变量
os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG")
os.Setenv("APPBUILDER_LOGFILE", "")

// 测试逻辑
config, err := NewSDKConfig("", "")
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new http client config failed: %v", err)
}

componentID := "44205c67-3980-41f7-aad4-37357b577fd0"
client, err := NewComponentClient(config)
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new ComponentClient instance failed")
}

parameters := map[string]any{
SysOriginQuery: "北京景点推荐",
}
i, err := client.Run(componentID, "latest", "", false, parameters)
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("run component failed: %v", err)
}

// test result
for answer, err := i.Next(); err == nil; answer, err = i.Next() {
data := answer.Content[0].Text
if data == nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("run component failed: data is nil")
}
}

i2, err := client.Run(componentID, "latest", "", true, parameters)
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("run component failed: %v", err)
}

// test stream result
var answerText any
for answer, err := i2.Next(); err == nil; answer, err = i2.Next() {
if len(answer.Content) == 0 {
continue
}
answerText = answer.Content[0].Text
}
if answerText == nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("run component failed: data is nil")
}

// 如果测试失败,则输出缓冲区中的日志
if t.Failed() {
fmt.Println(logBuffer.String())
} else { // else 紧跟在右大括号后面
// 测试通过,打印文件名和测试函数名
t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
}
}
Loading

0 comments on commit f16d544

Please sign in to comment.