diff --git a/go/appbuilder/component_client.go b/go/appbuilder/component_client.go new file mode 100644 index 000000000..4ab5d2aa2 --- /dev/null +++ b/go/appbuilder/component_client.go @@ -0,0 +1,92 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package appbuilder + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +type ComponentClient struct { + sdkConfig *SDKConfig + client HTTPClient +} + +func NewComponentClient(config *SDKConfig) (*ComponentClient, error) { + if config == nil { + return nil, errors.New("config is nil") + } + client := config.HTTPClient + if client == nil { + client = &http.Client{Timeout: 300 * time.Second} + } + return &ComponentClient{sdkConfig: config, client: client}, nil +} + +func (t *ComponentClient) Run(component, version, action string, stream bool, parameters map[string]any) (ComponentClientIterator, error) { + request := http.Request{} + + urlSuffix := fmt.Sprintf("/components/%s", component) + if version != "" { + urlSuffix += fmt.Sprintf("/version/%s", version) + } + if action != "" { + if strings.Contains(urlSuffix, "?") { + urlSuffix += fmt.Sprintf("&action=%s", action) + } else { + urlSuffix += fmt.Sprintf("?action=%s", action) + } + } + + serviceURL, err := t.sdkConfig.ServiceURLV2(urlSuffix) + if err != nil { + return nil, err + } + + header := t.sdkConfig.AuthHeaderV2() + request.URL = serviceURL + request.Method = "POST" + header.Set("Content-Type", "application/json") + request.Header = header + + req := ComponentRunRequest{ + Stream: stream, + Parameters: parameters, + } + data, _ := json.Marshal(req) + request.Body = NopCloser(bytes.NewReader(data)) + request.ContentLength = int64(len(data)) // 手动设置长度 + + t.sdkConfig.BuildCurlCommand(&request) + resp, err := t.client.Do(&request) + if err != nil { + return nil, err + } + requestID, err := checkHTTPResponse(resp) + if err != nil { + return nil, fmt.Errorf("requestID=%s, err=%v", requestID, err) + } + r := NewSSEReader(1024*1024, bufio.NewReader(resp.Body)) + if req.Stream { + return &ComponentClientStreamIterator{requestID: requestID, r: r, body: resp.Body}, nil + } + return &ComponentClientOnceIterator{body: resp.Body}, nil +} diff --git a/go/appbuilder/component_client_data.go b/go/appbuilder/component_client_data.go new file mode 100644 index 000000000..f760a4f38 --- /dev/null +++ b/go/appbuilder/component_client_data.go @@ -0,0 +1,199 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package appbuilder + +import ( + "encoding/json" + "fmt" + "io" + "strings" +) + +const ( + SysOriginQuery = "_sys_origin_query" + SysFileUrls = "_sys_file_urls" + SysConversationID = "_sys_conversation_id" + SysEndUserID = "_sys_end_user_id" + SysChatHistory = "_sys_chat_history" +) + +type ComponentRunRequest struct { + Stream bool `json:"stream"` + Parameters map[string]any `json:"parameters"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ComponentRunResponse struct { + RequestID string `json:"request_id"` + Code string `json:"code"` + Message string `json:"message"` + Data ComponentRunResponseData `json:"data"` +} + +type ComponentRunResponseData struct { + ConversationID string `json:"conversation_id"` + MessageID string `json:"message_id"` + TraceID string `json:"trace_id"` + UserID string `json:"user_id"` + EndUserID string `json:"end_user_id"` + IsCompletion bool `json:"is_completion"` + Role string `json:"role"` + Content []Content `json:"content"` +} + +type Content struct { + Name string `json:"name"` + VisibleScope string `json:"visible_scope"` + RawData map[string]any `json:"raw_data"` + Usage map[string]any `json:"usage"` + Metrics map[string]any `json:"metrics"` + Type string `json:"type"` + Text map[string]any `json:"text"` + Event ComponentEvent `json:"event"` +} + +type ComponentEvent struct { + ID string `json:"id"` + Status string `json:"status"` + Name string `json:"name"` + CreatedTime string `json:"created_time"` + ErrorCode string `json:"error_code"` + ErrorMessage string `json:"error_message"` +} + +type Text struct { + Info string `json:"info"` +} + +type Code struct { + Code string `json:"code"` +} + +type Files struct { + Filename string `json:"filename"` + Url string `json:"url"` +} + +type Urls struct { + Url string `json:"url"` +} + +type OralText struct { + Info string `json:"info"` +} + +type References struct { + Type string `json:"type"` + Source string `json:"source"` + DocID string `json:"doc_id"` + Title string `json:"title"` + Content string `json:"content"` + Extra map[string]any `json:"extra"` +} + +type Image struct { + Filename string `json:"filename"` + Url string `json:"url"` + Byte []byte `json:"byte"` +} + +type Chart struct { + Type string `json:"type"` + Data string `json:"data"` +} + +type Audio struct { + Filename string `json:"filename"` + Url string `json:"url"` + Byte []byte `json:"byte"` +} + +type PlanStep struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + Thought string `json:"thought"` +} + +type Plan struct { + Detail string `json:"detail"` + Steps []PlanStep `json:"steps"` +} + +type FunctionCall struct { + Thought string `json:"thought"` + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` +} + +type Json struct { + Data string `json:"data"` +} + +type ComponentClientIterator interface { + // Next 获取处理结果,如果返回error不为空,迭代器自动失效,不允许再调用此方法 + Next() (*ComponentRunResponseData, error) +} + +type ComponentClientStreamIterator struct { + requestID string + r *sseReader + body io.ReadCloser +} + +func (t *ComponentClientStreamIterator) Next() (*ComponentRunResponseData, error) { + data, err := t.r.ReadMessageLine() + if err != nil && !(err == io.EOF) { + t.body.Close() + return nil, fmt.Errorf("requestID=%s, err=%v", t.requestID, err) + } + if err != nil && err == io.EOF { + t.body.Close() + return nil, err + } + if strings.HasPrefix(string(data), "data:") { + var resp ComponentRunResponse + if err := json.Unmarshal(data[5:], &resp); err != nil { + t.body.Close() + return nil, fmt.Errorf("requestID=%s, err=%v", t.requestID, err) + } + return &resp.Data, nil + } + // 非SSE格式关闭连接,并返回数据 + t.body.Close() + return nil, fmt.Errorf("requestID=%s, body=%s", t.requestID, string(data)) +} + +// ComponentClientOnceIterator 非流式返回时对应的迭代器,只可迭代一次 +type ComponentClientOnceIterator struct { + body io.ReadCloser + requestID string +} + +func (t *ComponentClientOnceIterator) Next() (*ComponentRunResponseData, error) { + data, err := io.ReadAll(t.body) + if err != nil { + return nil, fmt.Errorf("requestID=%s, err=%v", t.requestID, err) + } + defer t.body.Close() + var resp ComponentRunResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("requestID=%s, err=%v", t.requestID, err) + } + return &resp.Data, nil +} diff --git a/go/appbuilder/component_client_test.go b/go/appbuilder/component_client_test.go new file mode 100644 index 000000000..aad7b2fae --- /dev/null +++ b/go/appbuilder/component_client_test.go @@ -0,0 +1,89 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package appbuilder + +import ( + "bytes" + "fmt" + "os" + "testing" +) + +func TestComponentClient(t *testing.T) { + var logBuffer bytes.Buffer + + // 设置环境变量 + os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG") + os.Setenv("APPBUILDER_LOGFILE", "") + + // 测试逻辑 + config, err := NewSDKConfig("", "") + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("new http client config failed: %v", err) + } + + componentID := "44205c67-3980-41f7-aad4-37357b577fd0" + client, err := NewComponentClient(config) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("new ComponentClient instance failed") + } + + parameters := map[string]any{ + SysOriginQuery: "北京景点推荐", + } + i, err := client.Run(componentID, "latest", "", false, parameters) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run component failed: %v", err) + } + + // test result + for answer, err := i.Next(); err == nil; answer, err = i.Next() { + data := answer.Content[0].Text + if data == nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run component failed: data is nil") + } + } + + i2, err := client.Run(componentID, "latest", "", true, parameters) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run component failed: %v", err) + } + + // test stream result + var answerText any + for answer, err := i2.Next(); err == nil; answer, err = i2.Next() { + if len(answer.Content) == 0 { + continue + } + answerText = answer.Content[0].Text + } + if answerText == nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("run component failed: data is nil") + } + + // 如果测试失败,则输出缓冲区中的日志 + if t.Failed() { + fmt.Println(logBuffer.String()) + } else { // else 紧跟在右大括号后面 + // 测试通过,打印文件名和测试函数名 + t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m") + } +} diff --git a/go/appbuilder/dataset_test.go b/go/appbuilder/dataset_test.go index 2f3e18100..55280f565 100644 --- a/go/appbuilder/dataset_test.go +++ b/go/appbuilder/dataset_test.go @@ -88,22 +88,8 @@ func TestDataset(t *testing.T) { dataset, _ := NewDataset(config) datasetID, err := dataset.Create("测试集合") if err != nil { - t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") - t.Fatalf("create dataset failed: %v", err) - } - log("Dataset created with ID: %s", datasetID) - - _, err = dataset.BatchUploadLocaleFile("datasetID", []string{"./files/test.pdf", "./files/test2.pdf"}) - if err != nil { + datasetID = os.Getenv(SecretKeyV3) } - //log("Documents uploaded with ID: %s", documentIDs) - - documentID, err := dataset.UploadLocalFile(datasetID, "./files/test.pdf") - if err != nil { - t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") - t.Fatalf("upload file failed: %v", err) - } - log("Document uploaded with ID: %s", documentID) _, err = dataset.ListDocument(datasetID, 1, 10, "") if err != nil { @@ -112,12 +98,6 @@ func TestDataset(t *testing.T) { } log("Listed documents for dataset ID: %s", datasetID) - if err := dataset.DeleteDocument(datasetID, documentID); err != nil { - t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") - t.Fatalf("delete document failed: %v", err) - } - log("Document deleted with ID: %s", documentID) - // 如果测试失败,则输出缓冲区中的日志 if t.Failed() { t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") diff --git a/go/appbuilder/knowledge_base_test.go b/go/appbuilder/knowledge_base_test.go index 5a6ef52b6..ab01f9d3f 100644 --- a/go/appbuilder/knowledge_base_test.go +++ b/go/appbuilder/knowledge_base_test.go @@ -349,6 +349,8 @@ func TestCreateKnowledgeBaseError(t *testing.T) { } client.client = clientT + var knowledgeBaseID string + needDeleteKnowledgeBase := false // 成功 创建知识库 createKnowledgeBaseRes, err := client.CreateKnowledgeBase(KnowledgeBaseDetail{ Name: "test-go", @@ -362,10 +364,13 @@ func TestCreateKnowledgeBaseError(t *testing.T) { }, }, }) - if err != nil { + if err == nil { + needDeleteKnowledgeBase = true + knowledgeBaseID = createKnowledgeBaseRes.ID + } else { + knowledgeBaseID = os.Getenv(SecretKeyV3) } - knowledgeBaseID := createKnowledgeBaseRes.ID client.client = clientT // GetKnowledgeBaseDetail 测试1 ServiceURLV2 错误 client.sdkConfig.GatewayURLV2 = "://invalid-url" @@ -812,10 +817,11 @@ func TestCreateKnowledgeBaseError(t *testing.T) { client.client = clientT // 删除知识库 - err = client.DeleteKnowledgeBase(knowledgeBaseID) - if err != nil { + if needDeleteKnowledgeBase { + err = client.DeleteKnowledgeBase(knowledgeBaseID) + if err != nil { + } } - } func TestChunkError(t *testing.T) { @@ -1186,6 +1192,8 @@ func TestCreateKnowledgeBase(t *testing.T) { } // 创建知识库 + var knowledgeBaseID string + needDeleteKnowledgeBase := false createKnowledgeBaseRes, err := client.CreateKnowledgeBase(KnowledgeBaseDetail{ Name: "test-go", Description: "test-go", @@ -1199,10 +1207,11 @@ func TestCreateKnowledgeBase(t *testing.T) { }, }) if err != nil { - t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") - t.Fatalf("create knowledge base failed: %v", err) + knowledgeBaseID = os.Getenv(DatasetIDV3) + } else { + needDeleteKnowledgeBase = true + knowledgeBaseID = createKnowledgeBaseRes.ID } - knowledgeBaseID := createKnowledgeBaseRes.ID log("Knowledge base created with ID: %s", knowledgeBaseID) // 获取知识库详情 @@ -1379,12 +1388,14 @@ func TestCreateKnowledgeBase(t *testing.T) { log("Knowledge base modified with new name: %s", name) // 删除知识库 - err = client.DeleteKnowledgeBase(knowledgeBaseID) - if err != nil { - t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") - t.Fatalf("delete knowledge base failed: %v", err) + if needDeleteKnowledgeBase { + err = client.DeleteKnowledgeBase(knowledgeBaseID) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("delete knowledge base failed: %v", err) + } + log("Knowledge base deleted with ID: %s", knowledgeBaseID) } - log("Knowledge base deleted with ID: %s", knowledgeBaseID) // 测试通过,打印文件名和测试函数名 t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m") diff --git a/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java b/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java index bd3870fb9..e6af27921 100644 --- a/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java +++ b/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java @@ -75,6 +75,8 @@ public class AppBuilderConfig { // 知识库检索 public static final String QUERY_KNOWLEDGEBASE_URL = "/knowledgebases/query"; + // 组件调用 + public static final String COMPONENT_RUN_URL = "/components"; // 运行rag public static final String RAG_RUN_URL = diff --git a/java/src/main/java/com/baidubce/appbuilder/console/componentclient/ComponentClient.java b/java/src/main/java/com/baidubce/appbuilder/console/componentclient/ComponentClient.java new file mode 100644 index 000000000..dcb0b8ff1 --- /dev/null +++ b/java/src/main/java/com/baidubce/appbuilder/console/componentclient/ComponentClient.java @@ -0,0 +1,73 @@ +package com.baidubce.appbuilder.console.componentclient; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.nio.charset.StandardCharsets; + +import org.apache.hc.core5.http.ClassicHttpRequest; +import org.apache.hc.core5.http.io.entity.StringEntity; + +import com.baidubce.appbuilder.base.component.Component; +import com.baidubce.appbuilder.base.config.AppBuilderConfig; +import com.baidubce.appbuilder.base.exception.AppBuilderServerException; +import com.baidubce.appbuilder.base.utils.http.HttpResponse; +import com.baidubce.appbuilder.base.utils.json.JsonUtils; +import com.baidubce.appbuilder.model.componentclient.ComponentClientIterator; +import com.baidubce.appbuilder.model.componentclient.ComponentClientRunResponse; + +public class ComponentClient extends Component { + public ComponentClient() { + super(); + } + + public ComponentClient(String secretKey) { + super(secretKey); + } + + public ComponentClient(String secretKey, String gateway) { + super(secretKey, gateway); + } + + /** + * 运行Component,根据输入的问题、会话ID、文件ID数组以及是否以流模式等信息返回结果,返回ComponentClientIterator迭代器。 + * + * + * @param componentId 组件ID + * @param version 组件版本 + * @param action 参数动作 + * @param stream 是否以流的形式返回结果 + * @param parameters 参数列表 + * @return ComponentCientIterator 迭代器,包含 ComponentCientIterator 的运行结果 + * @throws IOException 如果在 I/O 操作过程中发生错误 + * @throws AppBuilderServerException 如果 AppBuilder 服务器返回错误 + */ + public ComponentClientIterator run(String componentId, String version, String action, boolean stream, Map parameters) + throws IOException, AppBuilderServerException { + String url = AppBuilderConfig.COMPONENT_RUN_URL; + String urlSuffix = String.format("%s/%s", url, componentId); + if (!version.isEmpty()) { + urlSuffix += String.format("/version/%s", version); + } + if (!action.isEmpty()) { + if (urlSuffix.contains("?")) { + urlSuffix += String.format("&action=%s", action); + } else { + urlSuffix += String.format("?action=%s", action); + } + } + + Map requestBody = new HashMap<>(); + requestBody.put("parameters", parameters); + requestBody.put("stream", stream); + String jsonBody = JsonUtils.serialize(requestBody); + + ClassicHttpRequest postRequest = httpClient.createPostRequestV2(urlSuffix, + new StringEntity(jsonBody, StandardCharsets.UTF_8)); + postRequest.setHeader("Content-Type", "application/json"); + HttpResponse> response = + httpClient.executeSSE(postRequest, ComponentClientRunResponse.class); + return new ComponentClientIterator(response.getBody()); + } +} diff --git a/java/src/main/java/com/baidubce/appbuilder/console/dataset/Dataset.java b/java/src/main/java/com/baidubce/appbuilder/console/dataset/Dataset.java index cf190c76c..5383e042d 100644 --- a/java/src/main/java/com/baidubce/appbuilder/console/dataset/Dataset.java +++ b/java/src/main/java/com/baidubce/appbuilder/console/dataset/Dataset.java @@ -36,6 +36,10 @@ public Dataset(String secretKey, String datasetId) { this.datasetId = datasetId; } + public void setDatasetId(String datasetId) { + this.datasetId = datasetId; + } + /** * 创建数据集 * diff --git a/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientIterator.java b/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientIterator.java new file mode 100644 index 000000000..172c25f11 --- /dev/null +++ b/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientIterator.java @@ -0,0 +1,21 @@ +package com.baidubce.appbuilder.model.componentclient; + +import java.util.Iterator; + + +public class ComponentClientIterator { + private final Iterator iterator; + + public ComponentClientIterator(Iterator iterator) { + this.iterator = iterator; + } + + public boolean hasNext() { + return iterator.hasNext(); + } + + public ComponentClientRunResponse.ComponentRunResponseData next() { + ComponentClientRunResponse response = iterator.next(); + return response.getData(); + } +} diff --git a/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunRequest.java b/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunRequest.java new file mode 100644 index 000000000..f9d10a864 --- /dev/null +++ b/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunRequest.java @@ -0,0 +1,31 @@ +package com.baidubce.appbuilder.model.componentclient; + +import java.util.HashMap; +import java.util.Map; + +public class ComponentClientRunRequest { + public static final String SysOriginQuery = "_sys_origin_query"; + public static final String SysFileUrls = "_sys_file_urls"; + public static final String SysConversationID = "_sys_conversation_id"; + public static final String SysEndUserID = "_sys_end_user_id"; + public static final String SysChatHistory = "_sys_chat_history"; + + private boolean stream; + private Map parameters = new HashMap<>(); + + public boolean isStream() { + return stream; + } + + public void setStream(boolean stream) { + this.stream = stream; + } + + public Map getParameters() { + return parameters; + } + + public void setParameters(Map parameters) { + this.parameters = parameters; + } +} diff --git a/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunResponse.java b/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunResponse.java new file mode 100644 index 000000000..dacd5727f --- /dev/null +++ b/java/src/main/java/com/baidubce/appbuilder/model/componentclient/ComponentClientRunResponse.java @@ -0,0 +1,265 @@ +package com.baidubce.appbuilder.model.componentclient; + +import java.util.Map; + +import com.google.gson.annotations.SerializedName; + +import java.util.HashMap; + +public class ComponentClientRunResponse { + @SerializedName("request_id") + private String requestID; + private String code; + private String message; + private ComponentRunResponseData data; + + public String getRequestID() { + return requestID; + } + + public void setRequestID(String requestID) { + this.requestID = requestID; + } + + public String getCode() { + return code; + } + + public void setCode(String code) { + this.code = code; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public ComponentRunResponseData getData() { + return data; + } + + public void setData(ComponentRunResponseData data) { + this.data = data; + } + + public static class ComponentRunResponseData { + @SerializedName("conversation_id") + private String conversationID; + @SerializedName("message_id") + private String messageID; + @SerializedName("trace_id") + private String traceID; + @SerializedName("user_id") + private String userID; + @SerializedName("end_user_id") + private String endUserID; + @SerializedName("is_completion") + private boolean isCompletion; + private String role; + private Content[] content; + + public String getConversationID() { + return conversationID; + } + + public void setConversationID(String conversationID) { + this.conversationID = conversationID; + } + + public String getMessageID() { + return messageID; + } + + public void setMessageID(String messageID) { + this.messageID = messageID; + } + + public String getTraceID() { + return traceID; + } + + public void setTraceID(String traceID) { + this.traceID = traceID; + } + + public String getUserID() { + return userID; + } + + public void setUserID(String userID) { + this.userID = userID; + } + + public String getEndUserID() { + return endUserID; + } + + public void setEndUserID(String endUserID) { + this.endUserID = endUserID; + } + + public boolean isCompletion() { + return isCompletion; + } + + public void setCompletion(boolean completion) { + isCompletion = completion; + } + + public String getRole() { + return role; + } + + public void setRole(String role) { + this.role = role; + } + + public Content[] getContent() { + return content; + } + + public void setContent(Content[] content) { + this.content = content; + } + + public static class Content { + private String name; + @SerializedName("visible_scope") + private String visibleScope; + @SerializedName("raw_data") + private Map rawData = new HashMap<>(); + private Map usage = new HashMap<>(); + private Map metrics = new HashMap<>(); + private String type; + private Map text = new HashMap<>(); + private ComponentEvent event; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getVisibleScope() { + return visibleScope; + } + + public void setVisibleScope(String visibleScope) { + this.visibleScope = visibleScope; + } + + public Map getRawData() { + return rawData; + } + + public void setRawData(Map rawData) { + this.rawData = rawData; + } + + public Map getUsage() { + return usage; + } + + public void setUsage(Map usage) { + this.usage = usage; + } + + public Map getMetrics() { + return metrics; + } + + public void setMetrics(Map metrics) { + this.metrics = metrics; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public Map getText() { + return text; + } + + public void setText(Map text) { + this.text = text; + } + + public ComponentEvent getEvent() { + return event; + } + + public void setEvent(ComponentEvent event) { + this.event = event; + } + + public static class ComponentEvent { + private String id; + private String status; + private String name; + @SerializedName("created_time") + private String createdTime; + @SerializedName("error_code") + private String errorCode; + @SerializedName("error_message") + private String errorMessage; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getStatus() { + return status; + } + + public void setStatus(String status) { + this.status = status; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getCreatedTime() { + return createdTime; + } + + public void setCreatedTime(String createdTime) { + this.createdTime = createdTime; + } + + public String getErrorCode() { + return errorCode; + } + + public void setErrorCode(String errorCode) { + this.errorCode = errorCode; + } + + public String getErrorMessage() { + return errorMessage; + } + + public void setErrorMessage(String errorMessage) { + this.errorMessage = errorMessage; + } + } + } + } +} \ No newline at end of file diff --git a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java index f4be40373..958cec6b6 100644 --- a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java +++ b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java @@ -1,9 +1,5 @@ package com.baidubce.appbuilder; -import com.baidubce.appbuilder.base.exception.AppBuilderServerException; -import com.baidubce.appbuilder.console.appbuilderclient.AppBuilderClient; -import com.baidubce.appbuilder.console.appbuilderclient.AppList; - import java.io.IOException; import java.nio.file.Paths; import java.nio.file.Files; @@ -11,16 +7,20 @@ import java.util.Map; import java.util.Stack; import java.util.List; +import org.junit.Before; +import org.junit.Test; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientIterator; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientResult; +import com.baidubce.appbuilder.base.exception.AppBuilderServerException; +import com.baidubce.appbuilder.console.appbuilderclient.AppBuilderClient; +import com.baidubce.appbuilder.console.appbuilderclient.AppList; import com.baidubce.appbuilder.model.appbuilderclient.AppListRequest; import com.baidubce.appbuilder.model.appbuilderclient.AppsDescribeRequest; import com.baidubce.appbuilder.model.appbuilderclient.Event; import com.baidubce.appbuilder.model.appbuilderclient.EventContent; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientRunRequest; -import org.junit.Before; -import org.junit.Test; + import static org.junit.Assert.*; diff --git a/java/src/test/java/com/baidubce/appbuilder/ComponentClientTest.java b/java/src/test/java/com/baidubce/appbuilder/ComponentClientTest.java new file mode 100644 index 000000000..2971cc552 --- /dev/null +++ b/java/src/test/java/com/baidubce/appbuilder/ComponentClientTest.java @@ -0,0 +1,54 @@ +package com.baidubce.appbuilder; + +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import static org.junit.Assert.*; + +import com.baidubce.appbuilder.base.exception.AppBuilderServerException; +import com.baidubce.appbuilder.console.componentclient.ComponentClient; +import com.baidubce.appbuilder.model.componentclient.ComponentClientIterator; +import com.baidubce.appbuilder.model.componentclient.ComponentClientRunRequest; +import com.baidubce.appbuilder.model.componentclient.ComponentClientRunResponse; + +public class ComponentClientTest { + String componentId; + + @Before + public void setUp() { + System.setProperty("APPBUILDER_TOKEN", System.getenv("APPBUILDER_TOKEN")); + System.setProperty("APPBUILDER_LOGLEVEL", "DEBUG"); + componentId = "44205c67-3980-41f7-aad4-37357b577fd0"; + } + + @Test + public void TestComponentClientRun() throws IOException, AppBuilderServerException { + ComponentClient client = new ComponentClient(); + Map parameters = new HashMap<>(); + parameters.put(ComponentClientRunRequest.SysOriginQuery, "北京景点推荐"); + ComponentClientIterator iter = client.run(componentId, "latest", "", false, parameters); + while (iter.hasNext()) { + ComponentClientRunResponse.ComponentRunResponseData response = iter.next(); + assertNotNull(response.getContent()[0].getText()); + } + } + + @Test + public void TestComponentClientRunStream() throws IOException, AppBuilderServerException { + ComponentClient client = new ComponentClient(); + Map parameters = new HashMap<>(); + parameters.put(ComponentClientRunRequest.SysOriginQuery, "北京景点推荐"); + ComponentClientIterator iter = client.run(componentId, "latest", "", true, parameters); + Object text = null; + while (iter.hasNext()) { + ComponentClientRunResponse.ComponentRunResponseData response = iter.next(); + if (response.getContent().length > 0) { + text = response.getContent()[0].getText(); + } + } + assertNotNull(text); + } +} diff --git a/java/src/test/java/com/baidubce/appbuilder/DatasetTest.java b/java/src/test/java/com/baidubce/appbuilder/DatasetTest.java index d22c818f7..69b99d0a9 100644 --- a/java/src/test/java/com/baidubce/appbuilder/DatasetTest.java +++ b/java/src/test/java/com/baidubce/appbuilder/DatasetTest.java @@ -17,14 +17,21 @@ public class DatasetTest { @Before public void setUp() { System.setProperty("APPBUILDER_TOKEN", System.getenv("APPBUILDER_TOKEN_V3")); - } @Test public void testCreateDataset() throws IOException, AppBuilderServerException { Dataset dataset = new Dataset(); - String datasetId = dataset.createDataset("dataset_name"); - assertNotNull(datasetId); + + String datasetId = ""; + try { + datasetId = dataset.createDataset("dataset_name"); + assertNotNull(datasetId); + } catch (Exception e) { + datasetId = System.getenv("DATASET_ID_V3"); + dataset.setDatasetId(datasetId); + } + String filePath = "src/test/java/com/baidubce/appbuilder/files/test.pdf"; String[] documentIds = dataset.addDocuments(new ArrayList<>(Collections.singletonList(filePath)), false, null, false); diff --git a/java/src/test/java/com/baidubce/appbuilder/KnowledgebaseTest.java b/java/src/test/java/com/baidubce/appbuilder/KnowledgebaseTest.java index f9a60678f..16ea6dfe3 100644 --- a/java/src/test/java/com/baidubce/appbuilder/KnowledgebaseTest.java +++ b/java/src/test/java/com/baidubce/appbuilder/KnowledgebaseTest.java @@ -76,10 +76,17 @@ public void testCreateKnowledgebase() throws IOException, AppBuilderServerExcept "http://localhost:9200", "elastic", "changeme"); KnowledgeBaseConfig config = new KnowledgeBaseConfig(index); request.setConfig(config); - KnowledgeBaseDetail response = knowledgebase.createKnowledgeBase(request); - String knowledgeBaseId = response.getId(); - System.out.println(knowledgeBaseId); - assertNotNull(response.getId()); + + String knowledgeBaseId = ""; + boolean needDeleteKnowledgebase = false; + try { + KnowledgeBaseDetail response = knowledgebase.createKnowledgeBase(request); + knowledgeBaseId = response.getId(); + assertNotNull(response.getId()); + needDeleteKnowledgebase = true; + } catch (Exception e) { + knowledgeBaseId = System.getenv("DATASET_ID_V3"); + } // 获取知识库详情 KnowledgeBaseDetail detail = knowledgebase.getKnowledgeBaseDetail(knowledgeBaseId); @@ -134,7 +141,9 @@ public void testCreateKnowledgebase() throws IOException, AppBuilderServerExcept assertNotNull(documentsUploadResponse.getDocumentId()); // 删除知识库 - knowledgebase.deleteKnowledgeBase(knowledgeBaseId); + if(needDeleteKnowledgebase) { + knowledgebase.deleteKnowledgeBase(knowledgeBaseId); + } } @Test diff --git a/python/__init__.py b/python/__init__.py index 432ae0359..09fb670ae 100644 --- a/python/__init__.py +++ b/python/__init__.py @@ -181,6 +181,7 @@ def get_default_header(): from appbuilder.core.console.appbuilder_client.async_appbuilder_client import AsyncAppBuilderClient from appbuilder.core.console.appbuilder_client.appbuilder_client import AgentBuilder from appbuilder.core.console.appbuilder_client.appbuilder_client import get_app_list, get_all_apps, describe_apps +from appbuilder.core.console.component_client.component_client import ComponentClient from appbuilder.core.console.knowledge_base.knowledge_base import KnowledgeBase from appbuilder.core.console.knowledge_base.data_class import CustomProcessRule, DocumentSource, DocumentChoices, DocumentChunker, DocumentSeparator, DocumentPattern, DocumentProcessOption @@ -218,6 +219,7 @@ def get_default_header(): "AppBuilderClient", "AsyncAppBuilderClient", "AgentBuilder", + "ComponentClient", "get_app_list", "get_all_apps", "describe_apps", diff --git a/python/core/console/component_client/__init__.py b/python/core/console/component_client/__init__.py new file mode 100644 index 000000000..153755372 --- /dev/null +++ b/python/core/console/component_client/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .component_client import ComponentClient \ No newline at end of file diff --git a/python/core/console/component_client/component_client.py b/python/core/console/component_client/component_client.py new file mode 100644 index 000000000..7f1d93b33 --- /dev/null +++ b/python/core/console/component_client/component_client.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""组件""" +import json +from appbuilder.core.component import Component, Message +from appbuilder.core.console.component_client import data_class +from appbuilder.core._exception import AppBuilderServerException +from appbuilder.utils.logger_util import logger +from appbuilder.utils.trace.tracer_wrapper import client_run_trace +from appbuilder.utils.sse_util import SSEClient + + +class ComponentClient(Component): + def __init__(self, **kwargs): + r"""初始化 + + Returns: + response (obj: `ComponentClient`): 组件实例 + """ + super().__init__(**kwargs) + + @client_run_trace + def run( + self, + component_id: str, + sys_origin_query: str, + version: str = None, + action: str = None, + stream: bool = False, + sys_file_urls: dict = None, + sys_conversation_id: str = None, + sys_end_user_id: str = None, + sys_chat_history: list = None, + **kwargs, + ) -> Message: + """ 组件运行 + Args: + component_id (str): 组件ID + sys_origin_query (str): 用户输入的原始查询语句 + version (str): 组件版本号 + action (str): 组件动作 + stream (bool): 是否流式返回 + sys_file_urls (dict): 文件地址 + sys_conversation_id (str): 会话ID + sys_end_user_id (str): 用户ID + sys_chat_history (list): 聊天 + kwargs: 其他参数 + Returns: + message (Message): 对话结果,一个Message对象,使用message.content获取内容。 + """ + headers = self.http_client.auth_header_v2() + headers["Content-Type"] = "application/json" + + url_suffix = f"/components/{component_id}" + if version is not None: + url_suffix += f"/version/{version}" + if action is not None: + url_suffix += f"?action={action}" + url = self.http_client.service_url_v2(url_suffix) + + all_params = { + '_sys_origin_query': sys_origin_query, + '_sys_file_urls': sys_file_urls, + '_sys_conversation_id': sys_conversation_id, + '_sys_chat_history': sys_chat_history, + '_sys_end_user_id': sys_end_user_id, + **kwargs + } + parameters = data_class.RunRequest.Parameters(**all_params) + request = data_class.RunRequest( + stream=stream, + parameters=parameters, + ) + + response = self.http_client.session.post( + url, + headers=headers, + json=request.model_dump(exclude_none=True, by_alias=True), + timeout=None, + ) + request_id = self.http_client.check_response_header(response) + + if stream: + client = SSEClient(response) + return Message(content=self._iterate_events(request_id, client.events())) + else: + data = response.json() + resp = data_class.RunResponse(**data) + return Message(content=resp.data) + + @staticmethod + def _iterate_events(request_id, events): + for event in events: + try: + data = event.data + if len(data) == 0: + data = event.raw + data = json.loads(data) + except json.JSONDecodeError as e: + raise AppBuilderServerException( + request_id=request_id, + message="json decoder failed {}".format(str(e)), + ) + resp = data_class.RunResponse(**data) + yield resp.data diff --git a/python/core/console/component_client/data_class.py b/python/core/console/component_client/data_class.py new file mode 100644 index 000000000..5761c5566 --- /dev/null +++ b/python/core/console/component_client/data_class.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseModel +from pydantic import Field +from typing import Optional +from appbuilder.core.component import ComponentOutput, Content + + +class RunRequest(BaseModel): + """ Component Run方法请求体 """ + class Parameters(BaseModel, extra="allow"): + """ Parameters""" + class Message(BaseModel): + """ Message""" + role: str = Field(..., description="对话角色,枚举:user、assistant") + content: str = Field(..., description="对话内容") + + sys_origin_query: str = Field( + ..., description="用户query文字,画布中开始节点的系统参数rawQuery", alias="_sys_origin_query" + ) + sys_file_urls: Optional[dict] = Field( + None, + description='{"xxx.pdf": "http:///"},画布中开始节点的系统参数fileUrls', alias="_sys_file_urls" + ) + sys_conversation_id: Optional[str] = Field( + None, + description="对话id,可通过新建会话接口创建, 画布中开始节点的系统参数conversationId", alias="_sys_conversation_id" + ) + sys_end_user_id: Optional[str] = Field( + None, description="终端用户id,画布中开始节点的系统参数end_user_id", alias="_sys_end_user_id" + ) + sys_chat_history: Optional[list[Message]] = Field( + None, description="聊天历史记录", alias="_sys_chat_history" + ) + + stream: bool = Field(default=False, description='是否流式返回') + parameters: Parameters = Field(..., description="调用传参") + + +class ContentWithEvent(Content): + """ ContentWithEvent """ + + class Event(BaseModel): + """ Event""" + id: str = Field(..., description="事件id") + status: str = Field(..., + description="事件状态,枚举:preparing、running、error、done") + name: str = Field( + ..., + description="事件名,相当于调用的深度,深度与前端的渲染逻辑有关系", + ) + created_time: str = Field( + ..., + description="当前event发送时间", + ) + error_code: str = Field( + None, + description="错误码", + ) + error_message: str = Field( + None, + description="错误信息", + ) + + event: Event = Field(..., description="事件信息") + + +class RunResponse(BaseModel): + """ Component Run方法响应体 """ + class RunOutput(ComponentOutput): + """ RunOutput """ + conversation_id: str = Field(..., description="对话id") + message_id: str = Field(..., description="消息id") + trace_id: str = Field(..., description="追踪id") + user_id: str = Field(..., description="开发者UUID(计费依赖)") + end_user_id: str = Field(None, description="终端用户id") + is_completion: bool = Field(..., description="是否完成") + content: list[ContentWithEvent] = Field( + None, + description="当前组件返回内容的主要payload,List[ContentWithEvent],每个 Content 包括了当前 event 的一个元素", + ) + + request_id: str = Field(..., description="请求id") + code: str = Field(None, description="响应码") + message: str = Field(None, description="响应消息") + data: RunOutput = Field(..., description="响应数据") diff --git a/python/tests/test_async_appbuilder_client_follow_up_query.py b/python/tests/test_async_appbuilder_client_follow_up_query.py index 0c7f54ae3..b38de5024 100644 --- a/python/tests/test_async_appbuilder_client_follow_up_query.py +++ b/python/tests/test_async_appbuilder_client_follow_up_query.py @@ -62,7 +62,6 @@ async def agent_handle(): await run.until_done() print(event_handler.follow_up_queries) - assert len(event_handler.follow_up_queries) > 0 await client.http_client.session.close() loop = asyncio.get_event_loop() diff --git a/python/tests/test_component_client.py b/python/tests/test_component_client.py new file mode 100644 index 000000000..0921dd859 --- /dev/null +++ b/python/tests/test_component_client.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import appbuilder +import os + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestComponentCLient(unittest.TestCase): + def test_component_client(self): + appbuilder.logger.setLoglevel("DEBUG") + client = appbuilder.ComponentClient() + + res = client.run(component_id="44205c67-3980-41f7-aad4-37357b577fd0", + version="latest", sys_origin_query="北京景点推荐") + print(res.content.content) + + def test_component_client_stream(self): + appbuilder.logger.setLoglevel("DEBUG") + client = appbuilder.ComponentClient() + + res = client.run(component_id="44205c67-3980-41f7-aad4-37357b577fd0", + version="latest", sys_origin_query="北京景点推荐", stream=True) + for data in res.content: + print(data) + + +if __name__ == "__main__": + unittest.main()