Skip to content

Commit

Permalink
知识库检索 API更新
Browse files Browse the repository at this point in the history
  • Loading branch information
yinjiaqi authored and yinjiaqi committed Jan 20, 2025
1 parent ca54970 commit 73ef777
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 30 deletions.
6 changes: 6 additions & 0 deletions go/appbuilder/knowledge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,12 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk
}

func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) {
// 检查 RankScoreThreshold 是否为 nil,如果是,则设置默认值
if req.RankScoreThreshold == nil {
defaultThreshold := 0.4
req.RankScoreThreshold = &defaultThreshold
}

request := http.Request{}
header := t.sdkConfig.AuthHeaderV2()
serviceURL, err := t.sdkConfig.ServiceURLV2("/knowledgebases/query")
Expand Down
64 changes: 44 additions & 20 deletions go/appbuilder/knowledge_base_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@

package appbuilder

// QueryType is the string-valued retrieval strategy of a knowledge-base
// query. QueryKnowledgeBaseRequest carries it in its "type" JSON field.
type QueryType string

// Retrieval strategies defined for QueryType.
const (
	// Fulltext selects full-text retrieval ("fulltext").
	Fulltext = QueryType("fulltext")
	// Semantic selects semantic retrieval ("semantic").
	Semantic = QueryType("semantic")
	// Hybrid selects hybrid retrieval ("hybrid").
	Hybrid = QueryType("hybrid")
)

const (
ContentTypeRawText = "raw_text"
ContentTypeQA = "qa"
Expand Down Expand Up @@ -277,6 +285,18 @@ type ElasticSearchRetrieveConfig struct {
Top int `json:"top"`
}

// VectorDBRetrieveConfig is the JSON shape of a vector-database retrieval
// node. All four fields are always emitted when marshalled (no omitempty).
//
// NOTE(review): the Python model in the same commit documents threshold as
// lying in [0, 1] and top in [0, 800]; nothing in this struct enforces those
// ranges, so confirm where validation is expected to happen.
type VectorDBRetrieveConfig struct {
	// Name is the node name, serialized as "name".
	Name string `json:"name"`
	// Type is the node type, serialized as "type".
	Type string `json:"type"`
	// Threshold is the score threshold, serialized as "threshold".
	Threshold float64 `json:"threshold"`
	// Top is the number of results to recall, serialized as "top".
	Top int `json:"top"`
}

// SmallToBigConfig is the JSON shape of a small-to-big pipeline node. It
// carries only the node's name and type; both are always emitted when
// marshalled (no omitempty).
type SmallToBigConfig struct {
	// Name is the node name, serialized as "name".
	Name string `json:"name"`
	// Type is the node type, serialized as "type".
	Type string `json:"type"`
}

type RankingConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Expand All @@ -291,13 +311,14 @@ type QueryPipelineConfig struct {
}

type QueryKnowledgeBaseRequest struct {
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *string `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *QueryType `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
RankScoreThreshold *float64 `json:"rank_score_threshold,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
}

type RowLine struct {
Expand All @@ -314,19 +335,22 @@ type ChunkLocation struct {
}

type Chunk struct {
ChunkID string `json:"chunk_id"`
KnowledgebaseID string `json:"knowledgebase_id"`
DocumentID string `json:"document_id"`
DocumentName string `json:"document_name"`
Meta map[string]any `json:"meta"`
Type string `json:"type"`
Content string `json:"content"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
RetrievalScore float64 `json:"retrieval_score"`
RankScore float64 `json:"rank_score"`
Locations ChunkLocation `json:"locations"`
Children []Chunk `json:"children"`
ChunkID string `json:"chunk_id"`
KnowledgebaseID string `json:"knowledgebase_id"`
DocumentID string `json:"document_id"`
DocumentName string `json:"document_name"`
Meta map[string]any `json:"meta"`
Type string `json:"type"`
Content string `json:"content"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
RetrievalScore float64 `json:"retrieval_score"`
RankScore float64 `json:"rank_score"`
Locations ChunkLocation `json:"locations"`
Children []Chunk `json:"children"`
NeighbourChunks []Chunk `json:"neighbour_chunks"`
OriginalChunkId string `json:"original_chunk_id"`
OriginalChunkOffset int `json:"original_chunk_offset"`
}

type QueryKnowledgeBaseResponse struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -734,12 +734,12 @@ public QueryKnowledgeBaseResponse queryKnowledgeBase(QueryKnowledgeBaseRequest r
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBase(String query, String type, Integer top, Integer skip,
public QueryKnowledgeBaseResponse queryKnowledgeBase(String query, String type, float rank_score_threshold, Integer top, Integer skip,
String[] knowledgebaseIDs, QueryKnowledgeBaseRequest.MetadataFilters filters,
QueryKnowledgeBaseRequest.QueryPipelineConfig pipelineConfig)
throws IOException, AppBuilderServerException {
String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;
QueryKnowledgeBaseRequest request = new QueryKnowledgeBaseRequest(query, type, top, skip, knowledgebaseIDs, filters, pipelineConfig);
QueryKnowledgeBaseRequest request = new QueryKnowledgeBaseRequest(query, type, rank_score_threshold,top, skip, knowledgebaseIDs, filters, pipelineConfig);
String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@
public class QueryKnowledgeBaseRequest {
private String query;
private String type;
private float rank_score_threshold = 0.4f;
private Integer top;
private Integer skip;
private String[] knowledgebase_ids;
private MetadataFilters metadata_filters;
private QueryPipelineConfig pipeline_config;

public QueryKnowledgeBaseRequest(String query, String type, Integer top, Integer skip,
public QueryKnowledgeBaseRequest(String query, String type, float rank_score_threshold, Integer top, Integer skip,
String[] knowledgebase_ids, MetadataFilters metadata_filters,
QueryPipelineConfig pipeline_config) {
this.query = query;
this.type = type;
this.rank_score_threshold = rank_score_threshold;
this.top = top;
this.skip = skip;
this.knowledgebase_ids = knowledgebase_ids;
Expand All @@ -39,6 +41,14 @@ public void setType(String type) {
this.type = type;
}

/**
 * Returns the rank-score threshold currently stored on this request.
 *
 * @return the value of the {@code rank_score_threshold} field
 */
public float getRank_score_threshold() {
    return this.rank_score_threshold;
}

/**
 * Replaces the rank-score threshold stored on this request.
 *
 * @param rank_score_threshold the new value; stored as given, without range checking
 */
public void setRank_score_threshold(float rank_score_threshold) {
    this.rank_score_threshold = rank_score_threshold;
}

public Integer getTop() {
return top;
}
Expand Down Expand Up @@ -217,6 +227,46 @@ public void setTop(Integer top) {
}
}

/**
 * Plain data holder for a vector-database retrieval node: a name, a type,
 * a score threshold and a result count. All fields are boxed/reference
 * types and stay {@code null} until set.
 */
public static class VectorDBRetrieveConfig {
    private String name;
    private String type;
    private Double threshold;
    private Integer top;

    /** @return the node name, or {@code null} if not set */
    public String getName() {
        return this.name;
    }

    /** @param name the node name to store */
    public void setName(String name) {
        this.name = name;
    }

    /** @return the node type, or {@code null} if not set */
    public String getType() {
        return this.type;
    }

    /** @param type the node type to store */
    public void setType(String type) {
        this.type = type;
    }

    /** @return the score threshold, or {@code null} if not set */
    public Double getThreshold() {
        return this.threshold;
    }

    /** @param threshold the score threshold to store; not range-checked here */
    public void setThreshold(Double threshold) {
        this.threshold = threshold;
    }

    /** @return the result count, or {@code null} if not set */
    public Integer getTop() {
        return this.top;
    }

    /** @param top the result count to store; not range-checked here */
    public void setTop(Integer top) {
        this.top = top;
    }
}


public static class RankingConfig {
private String name;
private String type;
Expand Down Expand Up @@ -265,6 +315,28 @@ public void setTop(Integer top) {
}
}

/**
 * Plain data holder for a small-to-big pipeline node; it carries only the
 * node's name and type. Both fields stay {@code null} until set.
 */
public static class SmallToBigConfig {
    private String name;
    private String type;

    /** @return the node name, or {@code null} if not set */
    public String getName() {
        return this.name;
    }

    /** @param name the node name to store */
    public void setName(String name) {
        this.name = name;
    }

    /** @return the node type, or {@code null} if not set */
    public String getType() {
        return this.type;
    }

    /** @param type the node type to store */
    public void setType(String type) {
        this.type = type;
    }
}

public static class QueryPipelineConfig {
private String id;
private List<Object> pipeline;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ public static class Chunk {
private float rank_score;
private ChunkLocation locations;
private List<Chunk> children;
private List<Chunk> neighbour_chunks;
private String original_chunk_id;
private Integer original_chunk_offset;

public String getChunk_id() { return chunk_id; }

Expand Down Expand Up @@ -96,6 +99,18 @@ public static class Chunk {
public List<Chunk> getChildren() { return children; }

public void setChildren(List<Chunk> children) { this.children = children; }

public List<Chunk> getNeighbour_chunks() { return neighbour_chunks; }

public void setNeighbour_chunks(List<Chunk> neighbour_chunks) { this.neighbour_chunks = neighbour_chunks; }

public String getOriginal_chunk_id() { return original_chunk_id; }

public void setOriginal_chunk_id(String original_chunk_id) { this.original_chunk_id = original_chunk_id; }

public Integer getOriginal_chunk_offset() { return original_chunk_offset; }

public void setOriginal_chunk_offset(Integer original_chunk_offset) { this.original_chunk_offset = original_chunk_offset; }
}

public static class ChunkLocation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ public void testQueryKnowledgeBaseV2() throws IOException, AppBuilderServerExcep
Files.readAllBytes(Paths.get("src/test/java/com/baidubce/appbuilder/files/query_knowledgebase.json")));
QueryKnowledgeBaseRequest request = gson.fromJson(requestJson, QueryKnowledgeBaseRequest.class);
QueryKnowledgeBaseResponse response = knowledgebase.queryKnowledgeBase(request.getQuery(),
request.getType(), request.getTop(), request.getSkip(),
request.getType(), request.getRank_score_threshold(), request.getTop(), request.getSkip(),
request.getKnowledgebase_ids(), request.getMetadata_filters(), request.getPipeline_config());
assertNotNull(response.getChunks().get(0).getChunk_id());
}
Expand Down
29 changes: 25 additions & 4 deletions python/core/console/knowledge_base/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, Field
from enum import Enum
from typing import Union, Optional, List


Expand Down Expand Up @@ -323,13 +324,28 @@ class PreRankingConfig(BaseModel):
None, description="得分归一化参数,不建议修改,默认50"
)

class QueryType(str, Enum):
    """Retrieval strategy for a knowledge-base query.

    Members subclass ``str``, so each one compares equal to its plain string
    value (e.g. ``QueryType.FULLTEXT == "fulltext"``).
    """

    FULLTEXT = "fulltext"  # full-text retrieval
    SEMANTIC = "semantic"  # semantic retrieval
    HYBRID = "hybrid"  # hybrid retrieval

class ElasticSearchRetrieveConfig(BaseModel):
class ElasticSearchRetrieveConfig(BaseModel): # 托管资源为共享资源 或 BES资源时使用该配置
name: str = Field(..., description="配置名称")
type: str = Field(None, description="elastic_search标志,该节点为es全文检索")
threshold: float = Field(None, description="得分阈值,默认0.1")
top: int = Field(None, description="召回数量,默认400")

class VectorDBRetrieveConfig(BaseModel):
    """Configuration of a vector-database retrieval node.

    Only ``name`` is required. ``type`` defaults to ``"vector_db"``,
    ``threshold`` to 0.1 (validated to lie in [0, 1]) and ``top`` to 400
    (validated to lie in [0, 800]). ``pre_ranking`` is an optional
    pre-ranking configuration and defaults to ``None``.
    """

    # Custom name of this node.
    name: str = Field(..., description="该节点的自定义名称。")
    # Node type; defaults to "vector_db".
    type: str = Field("vector_db", description="该节点的类型,默认为vector_db。")
    # Score threshold, constrained to [0, 1].
    threshold: Optional[float] = Field(
        0.1, description="得分阈值。取值范围:[0, 1]", ge=0.0, le=1.0
    )
    # Number of results to recall, constrained to [0, 800].
    top: Optional[int] = Field(
        400, description="召回数量。取值范围:[0, 800]", ge=0, le=800
    )
    # Optional pre-ranking configuration.
    pre_ranking: Optional[PreRankingConfig] = Field(None, description="粗排配置")

class SmallToBigConfig(BaseModel):
    """Configuration of a small_to_big pipeline node.

    ``name`` is required; ``type`` defaults to ``"small_to_big"``.
    """

    # Configuration name.
    name: str = Field(..., description="配置名称")
    # Node type marker; defaults to "small_to_big".
    type: str = Field(
        "small_to_big", description="small_to_big标志,该节点为small_to_big节点"
    )


class RankingConfig(BaseModel):
name: str = Field(..., description="配置名称")
Expand All @@ -341,24 +357,29 @@ class RankingConfig(BaseModel):
model_name: str = Field(None, description="ranking模型名(当前仅一种,暂不生效)")
top: int = Field(None, description="取切片top进行排序,默认20,最大400")


class QueryPipelineConfig(BaseModel):
id: str = Field(
None, description="配置唯一标识,如果用这个id,则引用已经配置好的QueryPipeline"
)
pipeline: list[Union[ElasticSearchRetrieveConfig, RankingConfig]] = Field(
pipeline: list[Union[ElasticSearchRetrieveConfig, RankingConfig, VectorDBRetrieveConfig, SmallToBigConfig]] = Field(
None, description="配置的Pipeline,如果没有用id,可以用这个对象指定一个新的配置"
)


class QueryKnowledgeBaseRequest(BaseModel):
query: str = Field(..., description="检索query")
type: str = Field(None, description="检索策略的枚举, fulltext:全文检索")
type: Optional[QueryType] = Field(None, description="检索策略的枚举, fulltext:全文检索, semantic:语义检索, hybrid:混合检索")
top: int = Field(None, description="返回结果数量")
skip: int = Field(
None,
description="跳过多少条记录, 通过top和skip可以实现类似分页的效果,比如top 10 skip 0,取第一页的10个,top 10 skip 10,取第二页的10个",
)
rank_score_threshold: float = Field(
0.4,
description="重排序匹配分阈值,只有rank_score大于等于该分值的切片重排序时才会被筛选出来。当且仅当,pipeline_config中配置了ranking节点时,该过滤条件生效。取值范围: [0, 1]。",
ge=0.0,
le=1.0,
)
knowledgebase_ids: list[str] = Field(..., description="知识库ID列表")
metadata_filters: MetadataFilters = Field(None, description="元数据过滤条件")
pipeline_config: QueryPipelineConfig = Field(None, description="检索配置")
Expand Down
6 changes: 4 additions & 2 deletions python/core/console/knowledge_base/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,10 +909,11 @@ def query_knowledge_base(
self,
query: str,
knowledgebase_ids: list[str],
type: str = None,
type: Optional[data_class.QueryType] = None,
metadata_filters: data_class.MetadataFilter = None,
pipeline_config: data_class.QueryPipelineConfig = None,
top: int = None,
rank_score_threshold: Optional[float] = 0.4,
top: int = 6,
skip: int = None,
) -> data_class.QueryKnowledgeBaseResponse:
"""
Expand All @@ -934,6 +935,7 @@ def query_knowledge_base(
type=type,
metadata_filters=metadata_filters,
pipeline_config=pipeline_config,
rank_score_threshold=rank_score_threshold,
top=top,
skip=skip,
)
Expand Down

0 comments on commit 73ef777

Please sign in to comment.