Skip to content

Commit

Permalink
Merge pull request #615 from Abirdcfly/fixhistoryinqa
Browse files (browse the repository at this point in the history)
fix: setting history on the QA chain causes an error
  • Loading branch information
bjwswang authored Jan 23, 2024
2 parents c51099f + 8230951 commit cad071c
Show file tree
Hide file tree
Showing 13 changed files with 28 additions and 29 deletions.
5 changes: 2 additions & 3 deletions api/app-node/chain/v1alpha1/llmchain_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type CommonChainConfig struct {
// Usually this value is just empty
Model string `json:"model,omitempty"`
// MaxTokens is the maximum number of tokens to generate to use in a llm call.
// +kubebuilder:default=1024
// +kubebuilder:default=2048
MaxTokens int `json:"maxTokens,omitempty"`
// Temperature is the temperature for sampling to use in a llm call, between 0 and 1.
//+kubebuilder:validation:Minimum=0
Expand All @@ -57,8 +57,7 @@ type CommonChainConfig struct {
MinLength int `json:"minLength,omitempty"`
// MaxLength is the maximum length of the generated text in a llm call.
// +kubebuilder:validation:Minimum=10
// +kubebuilder:validation:Maximum=4096
// +kubebuilder:default=1024
// +kubebuilder:default=2048
MaxLength int `json:"maxLength,omitempty"`
// RepetitionPenalty is the repetition penalty for sampling in a llm call.
RepetitionPenalty float64 `json:"repetitionPenalty,omitempty"`
Expand Down
4 changes: 2 additions & 2 deletions apiserver/pkg/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func addDefaultValue(gApp *generated.Application, app *v1alpha1.Application) {
gApp.NumDocuments = pointer.Int(5)
gApp.ScoreThreshold = pointer.Float64(0.3)
gApp.Temperature = pointer.Float64(0.7)
gApp.MaxLength = pointer.Int(1024)
gApp.MaxTokens = pointer.Int(1024)
gApp.MaxLength = pointer.Int(2048)
gApp.MaxTokens = pointer.Int(2048)
gApp.ConversionWindowSize = pointer.Int(5)
}

Expand Down
5 changes: 2 additions & 3 deletions apiserver/pkg/chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/uuid"
"k8s.io/klog/v2"

"github.com/kubeagi/arcadia/api/base/v1alpha1"
Expand All @@ -43,7 +42,7 @@ var (
Conversations = map[string]Conversation{}
)

func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*ChatRespBody, error) {
func AppRun(ctx context.Context, req ChatReqBody, respStream chan string, messageID string) (*ChatRespBody, error) {
token := auth.ForOIDCToken(ctx)
c, err := client.GetClient(token)
if err != nil {
Expand Down Expand Up @@ -92,7 +91,7 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat
Debug: req.Debug,
}
}
messageID := string(uuid.NewUUID())

conversation.Messages = append(conversation.Messages, Message{
ID: messageID,
Query: req.Query,
Expand Down
7 changes: 5 additions & 2 deletions apiserver/service/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func chatHandler() gin.HandlerFunc {
if req.NewChat {
req.ConversationID = string(uuid.NewUUID())
}
messageID := string(uuid.NewUUID())
var response *chat.ChatRespBody
var err error

Expand All @@ -84,9 +85,10 @@ func chatHandler() gin.HandlerFunc {
}
}
}()
response, err = chat.AppRun(c.Request.Context(), req, respStream)
response, err = chat.AppRun(c.Request.Context(), req, respStream, messageID)
if err != nil {
c.SSEvent("error", chat.ChatRespBody{
MessageID: messageID,
ConversationID: req.ConversationID,
Message: err.Error(),
CreatedAt: time.Now(),
Expand Down Expand Up @@ -132,6 +134,7 @@ func chatHandler() gin.HandlerFunc {
clientDisconnected := c.Stream(func(w io.Writer) bool {
if msg, ok := <-respStream; ok {
c.SSEvent("", chat.ChatRespBody{
MessageID: messageID,
ConversationID: req.ConversationID,
Message: msg,
CreatedAt: time.Now(),
Expand All @@ -148,7 +151,7 @@ func chatHandler() gin.HandlerFunc {
klog.FromContext(c.Request.Context()).Info("end to receive messages")
} else {
// handle chat blocking mode
response, err = chat.AppRun(c.Request.Context(), req, nil)
response, err = chat.AppRun(c.Request.Context(), req, nil, messageID)
if err != nil {
c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()})
klog.FromContext(c.Request.Context()).Error(err, "error resp")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,13 @@ spec:
description: DisplayName defines datasource display name
type: string
maxLength:
default: 1024
default: 2048
description: MaxLength is the maximum length of the generated text
in a llm call.
maximum: 4096
minimum: 10
type: integer
maxTokens:
default: 1024
default: 2048
description: MaxTokens is the maximum number of tokens to generate
to use in a llm call.
type: integer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@ spec:
description: DisplayName defines datasource display name
type: string
maxLength:
default: 1024
default: 2048
description: MaxLength is the maximum length of the generated text
in a llm call.
maximum: 4096
minimum: 10
type: integer
maxTokens:
default: 1024
default: 2048
description: MaxTokens is the maximum number of tokens to generate
to use in a llm call.
type: integer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,13 @@ spec:
description: DisplayName defines datasource display name
type: string
maxLength:
default: 1024
default: 2048
description: MaxLength is the maximum length of the generated text
in a llm call.
maximum: 4096
minimum: 10
type: integer
maxTokens:
default: 1024
default: 2048
description: MaxTokens is the maximum number of tokens to generate
to use in a llm call.
type: integer
Expand Down
2 changes: 1 addition & 1 deletion deploy/charts/arcadia/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ apiVersion: v2
name: arcadia
description: A Helm chart(KubeBB Component) for KubeAGI Arcadia
type: application
version: 0.2.17
version: 0.2.18
appVersion: "0.1.0"

keywords:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,13 @@ spec:
description: DisplayName defines datasource display name
type: string
maxLength:
default: 1024
default: 2048
description: MaxLength is the maximum length of the generated text
in a llm call.
maximum: 4096
minimum: 10
type: integer
maxTokens:
default: 1024
default: 2048
description: MaxTokens is the maximum number of tokens to generate
to use in a llm call.
type: integer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@ spec:
description: DisplayName defines datasource display name
type: string
maxLength:
default: 1024
default: 2048
description: MaxLength is the maximum length of the generated text
in a llm call.
maximum: 4096
minimum: 10
type: integer
maxTokens:
default: 1024
default: 2048
description: MaxTokens is the maximum number of tokens to generate
to use in a llm call.
type: integer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,13 @@ spec:
description: DisplayName defines datasource display name
type: string
maxLength:
default: 1024
default: 2048
description: MaxLength is the maximum length of the generated text
in a llm call.
maximum: 4096
minimum: 10
type: integer
maxTokens:
default: 1024
default: 2048
description: MaxTokens is the maximum number of tokens to generate
to use in a llm call.
type: integer
Expand Down
1 change: 1 addition & 0 deletions pkg/appruntime/app_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, respStream
"question": input.Question,
"_answer_stream": respStream,
"_history": input.History,
"context": "",
}
visited := make(map[string]bool)
waitRunningNodes := list.New()
Expand Down
3 changes: 3 additions & 0 deletions pkg/appruntime/chain/retrievalqachain.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli dynamic.Interface, args
options := getChainOptions(instance.Spec.CommonChainConfig)

llmChain := chains.NewLLMChain(llm, prompt)
if history != nil {
llmChain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "")
}
var baseChain chains.Chain
var stuffDocuments *appretriever.KnowledgeBaseStuffDocuments
if knowledgeBaseRetriever, ok := v3.(*appretriever.KnowledgeBaseRetriever); ok {
Expand Down

0 comments on commit cad071c

Please sign in to comment.