Skip to content

Commit

Permalink
SDK支持 rag检索API二期内容
Browse files Browse the repository at this point in the history
  • Loading branch information
yinjiaqi authored and yinjiaqi committed Jan 20, 2025
1 parent ca54970 commit d7880d9
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 27 deletions.
64 changes: 44 additions & 20 deletions go/appbuilder/knowledge_base_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@

package appbuilder

// QueryType names a retrieval strategy that can be sent as the "type" of a
// knowledge-base query request.
type QueryType string

// Retrieval strategies understood by the query API.
const (
	Fulltext = QueryType("fulltext") // full-text retrieval
	Semantic = QueryType("semantic") // semantic retrieval
	Hybrid   = QueryType("hybrid")   // hybrid retrieval
)

const (
ContentTypeRawText = "raw_text"
ContentTypeQA = "qa"
Expand Down Expand Up @@ -277,6 +285,18 @@ type ElasticSearchRetrieveConfig struct {
Top int `json:"top"`
}

type VectorDBRetrieveConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Threshold float64 `json:"threshold"`
Top int `json:"top"`
}

type SmallToBigConfig status {
Name string `json:"name"`
Type string `json:"type"`
}

type RankingConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Expand All @@ -291,13 +311,14 @@ type QueryPipelineConfig struct {
}

type QueryKnowledgeBaseRequest struct {
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *string `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *QueryType `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
rank_score_threshold float64 `json:"rank_score_threshold,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
}

type RowLine struct {
Expand All @@ -314,19 +335,22 @@ type ChunkLocation struct {
}

// Chunk is one slice returned by a knowledge-base query. Each field is listed
// exactly once; declaring a field name twice in a struct does not compile.
type Chunk struct {
	ChunkID         string         `json:"chunk_id"`
	KnowledgebaseID string         `json:"knowledgebase_id"`
	DocumentID      string         `json:"document_id"`
	DocumentName    string         `json:"document_name"`
	Meta            map[string]any `json:"meta"`
	Type            string         `json:"type"`
	Content         string         `json:"content"`
	CreateTime      string         `json:"create_time"`
	UpdateTime      string         `json:"update_time"`
	RetrievalScore  float64        `json:"retrieval_score"`
	RankScore       float64        `json:"rank_score"`
	Locations       ChunkLocation  `json:"locations"`
	// Children and NeighbourChunks nest further chunks of the same shape.
	Children        []Chunk `json:"children"`
	NeighbourChunks []Chunk `json:"neighbour_chunks"`
	// OriginalChunkId / OriginalChunkOffset mirror the "original_chunk_id" and
	// "original_chunk_offset" response keys.
	OriginalChunkId     string `json:"original_chunk_id"`
	OriginalChunkOffset int    `json:"original_chunk_offset"`
}

type QueryKnowledgeBaseResponse struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,24 @@

import java.util.List;

import com.baidubce.appbuilder.model.knowledgebase.QueryKnowledgeBaseRequest.PostRankingConfig.QueryPipelineConfig;

public class QueryKnowledgeBaseRequest {
private String query;
private String type;
private float rank_score_threshold;
private Integer top;
private Integer skip;
private String[] knowledgebase_ids;
private MetadataFilters metadata_filters;
private QueryPipelineConfig pipeline_config;

public QueryKnowledgeBaseRequest(String query, String type, Integer top, Integer skip,
public QueryKnowledgeBaseRequest(String query, String type, float rank_score_threshold, Integer top, Integer skip,
String[] knowledgebase_ids, MetadataFilters metadata_filters,
QueryPipelineConfig pipeline_config) {
this.query = query;
this.type = type;
this.rank_score_threshold = rank_score_threshold;
this.top = top;
this.skip = skip;
this.knowledgebase_ids = knowledgebase_ids;
Expand All @@ -39,6 +43,14 @@ public void setType(String type) {
this.type = type;
}

/**
 * Returns the rank-score threshold stored on this request.
 *
 * @return the current value of {@code rank_score_threshold}; a primitive, so never {@code null}
 */
public float getRank_score_threshold() {
return rank_score_threshold;
}

/**
 * Stores the rank-score threshold on this request.
 *
 * @param rank_score_threshold the value to assign to {@code rank_score_threshold}
 */
public void setRank_score_threshold(float rank_score_threshold) {
this.rank_score_threshold = rank_score_threshold;
}

/**
 * Returns the {@code top} value stored on this request.
 *
 * @return the current {@code top}
 */
public Integer getTop() {
return top;
}
Expand Down Expand Up @@ -217,6 +229,46 @@ public void setTop(Integer top) {
}
}

/**
 * Settings bean for a vector-database retrieval node: a name, a type, a score
 * threshold and a result count. All properties start out {@code null}.
 */
public static class VectorDBRetrieveConfig {
    private String name;
    private String type;
    private Double threshold;
    private Integer top;

    // ---- readers ----

    /** @return the node name, or {@code null} if unset */
    public String getName() {
        return this.name;
    }

    /** @return the node type, or {@code null} if unset */
    public String getType() {
        return this.type;
    }

    /** @return the score threshold, or {@code null} if unset */
    public Double getThreshold() {
        return this.threshold;
    }

    /** @return the result count, or {@code null} if unset */
    public Integer getTop() {
        return this.top;
    }

    // ---- writers ----

    /** @param name the node name to store */
    public void setName(String name) {
        this.name = name;
    }

    /** @param type the node type to store */
    public void setType(String type) {
        this.type = type;
    }

    /** @param threshold the score threshold to store */
    public void setThreshold(Double threshold) {
        this.threshold = threshold;
    }

    /** @param top the result count to store */
    public void setTop(Integer top) {
        this.top = top;
    }
}


public static class RankingConfig {
private String name;
private String type;
Expand Down Expand Up @@ -265,6 +317,28 @@ public void setTop(Integer top) {
}
}

/**
 * Settings bean for a small-to-big pipeline node; it carries only a name and
 * a type, both {@code null} until set.
 */
public static class SmallToBigConfig {
    private String name;
    private String type;

    /** @return the node name, or {@code null} if unset */
    public String getName() {
        return this.name;
    }

    /** @return the node type, or {@code null} if unset */
    public String getType() {
        return this.type;
    }

    /** @param name the node name to store */
    public void setName(String name) {
        this.name = name;
    }

    /** @param type the node type to store */
    public void setType(String type) {
        this.type = type;
    }
}

public static class QueryPipelineConfig {
private String id;
private List<Object> pipeline;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ public static class Chunk {
private float rank_score;
private ChunkLocation locations;
private List<Chunk> children;
private List<Chunk> neighbour_chunks;
private String original_chunk_id;
private Integer original_chunk_offset;

/** Returns the value of the {@code chunk_id} field. */
public String getChunk_id() { return chunk_id; }

Expand Down Expand Up @@ -96,6 +99,18 @@ public static class Chunk {
/** Returns the {@code children} list as stored (no copy is made). */
public List<Chunk> getChildren() { return children; }

/** Replaces the {@code children} list with the given reference. */
public void setChildren(List<Chunk> children) { this.children = children; }

/** Returns the {@code neighbour_chunks} list as stored (no copy is made). */
public List<Chunk> getNeighbour_chunks() { return neighbour_chunks; }

/** Replaces the {@code neighbour_chunks} list with the given reference. */
public void setNeighbour_chunks(List<Chunk> neighbour_chunks) { this.neighbour_chunks = neighbour_chunks; }

/** Returns the value of the {@code original_chunk_id} field. */
public String getOriginal_chunk_id() { return original_chunk_id; }

/** Sets the {@code original_chunk_id} field. */
public void setOriginal_chunk_id(String original_chunk_id) { this.original_chunk_id = original_chunk_id; }

/** Returns the value of the {@code original_chunk_offset} field; may be {@code null} since it is boxed. */
public Integer getOriginal_chunk_offset() { return original_chunk_offset; }

/** Sets the {@code original_chunk_offset} field. */
public void setOriginal_chunk_offset(Integer original_chunk_offset) { this.original_chunk_offset = original_chunk_offset; }
}

public static class ChunkLocation {
Expand Down
29 changes: 25 additions & 4 deletions python/core/console/knowledge_base/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel, Field
from enum import Enum
from typing import Union, Optional, List


Expand Down Expand Up @@ -323,13 +324,28 @@ class PreRankingConfig(BaseModel):
None, description="得分归一化参数,不建议修改,默认50"
)

class QueryType(str, Enum):
    """Retrieval strategies accepted by a knowledge-base query.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Full-text retrieval.
    FULLTEXT = "fulltext"
    # Semantic retrieval.
    SEMANTIC = "semantic"
    # Hybrid retrieval.
    HYBRID = "hybrid"

class ElasticSearchRetrieveConfig(BaseModel):
    """Elasticsearch full-text retrieval node of a query pipeline.

    Use this configuration when the hosting resource is a shared resource
    or a BES resource.

    Every field that defaults to ``None`` is annotated ``Optional`` so the
    annotation matches the default and an explicit ``None`` is accepted.
    """

    name: str = Field(..., description="配置名称")
    type: Optional[str] = Field(None, description="elastic_search标志,该节点为es全文检索")
    threshold: Optional[float] = Field(None, description="得分阈值,默认0.1")
    top: Optional[int] = Field(None, description="召回数量,默认400")

class VectorDBRetrieveConfig(BaseModel):
    """Vector-database retrieval node of a query pipeline.

    ``type`` defaults to ``"vector_db"``; ``threshold`` is constrained to
    [0, 1] and ``top`` to [0, 800]. An optional coarse-ranking
    configuration can be attached through ``pre_ranking``.
    """

    name: str = Field(
        ...,
        description="该节点的自定义名称。",
    )
    type: str = Field(
        "vector_db",
        description="该节点的类型,默认为vector_db。",
    )
    threshold: Optional[float] = Field(
        0.1,
        description="得分阈值。取值范围:[0, 1]",
        ge=0.0,
        le=1.0,
    )
    top: Optional[int] = Field(
        400,
        description="召回数量。取值范围:[0, 800]",
        ge=0,
        le=800,
    )
    pre_ranking: Optional[PreRankingConfig] = Field(
        None,
        description="粗排配置",
    )

class SmallToBigConfig(BaseModel):
    """Small-to-big node of a query pipeline.

    Carries only the node name and a ``type`` marker that defaults to
    ``"small_to_big"``.
    """

    name: str = Field(
        ...,
        description="配置名称",
    )
    type: str = Field(
        "small_to_big",
        description="small_to_big标志,该节点为small_to_big节点",
    )


class RankingConfig(BaseModel):
name: str = Field(..., description="配置名称")
Expand All @@ -341,24 +357,29 @@ class RankingConfig(BaseModel):
model_name: str = Field(None, description="ranking模型名(当前仅一种,暂不生效)")
top: int = Field(None, description="取切片top进行排序,默认20,最大400")


class QueryPipelineConfig(BaseModel):
    """Retrieval pipeline selection for a knowledge-base query.

    Either reference an already configured pipeline through ``id`` or
    describe a new one inline through ``pipeline``. Both fields default to
    ``None`` and are annotated ``Optional`` accordingly.
    """

    id: Optional[str] = Field(
        None, description="配置唯一标识,如果用这个id,则引用已经配置好的QueryPipeline"
    )
    # NOTE(review): ElasticSearchRetrieveConfig and VectorDBRetrieveConfig both
    # declare name/type/threshold/top. Confirm that union validation keeps
    # VectorDBRetrieveConfig entries (and their pre_ranking) intact rather than
    # coercing them to the first matching member.
    pipeline: Optional[
        list[
            Union[
                ElasticSearchRetrieveConfig,
                RankingConfig,
                VectorDBRetrieveConfig,
                SmallToBigConfig,
            ]
        ]
    ] = Field(
        None, description="配置的Pipeline,如果没有用id,可以用这个对象指定一个新的配置"
    )


class QueryKnowledgeBaseRequest(BaseModel):
    """Request body of a knowledge-base retrieval query.

    ``query`` and ``knowledgebase_ids`` are required. Every other field is
    optional and annotated ``Optional`` so that an explicit ``None`` is
    accepted. This matters for ``rank_score_threshold``: callers may pass
    ``None`` for it, and a plain ``float`` annotation would reject that
    value during validation.
    """

    query: str = Field(..., description="检索query")
    type: Optional[QueryType] = Field(None, description="检索策略的枚举, fulltext:全文检索, semantic:语义检索, hybrid:混合检索")
    top: Optional[int] = Field(None, description="返回结果数量")
    skip: Optional[int] = Field(
        None,
        description="跳过多少条记录, 通过top和skip可以实现类似分页的效果,比如top 10 skip 0,取第一页的10个,top 10 skip 10,取第二页的10个",
    )
    # Defaults to 0.4 when the argument is not supplied at all; bounded to [0, 1].
    rank_score_threshold: Optional[float] = Field(
        0.4,
        description="重排序匹配分阈值,只有rank_score大于等于该分值的切片重排序时才会被筛选出来。当且仅当,pipeline_config中配置了ranking节点时,该过滤条件生效。取值范围: [0, 1]。",
        ge=0.0,
        le=1.0,
    )
    knowledgebase_ids: list[str] = Field(..., description="知识库ID列表")
    metadata_filters: Optional[MetadataFilters] = Field(None, description="元数据过滤条件")
    pipeline_config: Optional[QueryPipelineConfig] = Field(None, description="检索配置")
Expand Down
6 changes: 4 additions & 2 deletions python/core/console/knowledge_base/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,10 +909,11 @@ def query_knowledge_base(
self,
query: str,
knowledgebase_ids: list[str],
type: str = None,
type: Optional[data_class.QueryType] = None,
metadata_filters: data_class.MetadataFilter = None,
pipeline_config: data_class.QueryPipelineConfig = None,
top: int = None,
rank_score_threshold: float = None,
top: int = 6,
skip: int = None,
) -> data_class.QueryKnowledgeBaseResponse:
"""
Expand All @@ -934,6 +935,7 @@ def query_knowledge_base(
type=type,
metadata_filters=metadata_filters,
pipeline_config=pipeline_config,
rank_score_threshold=rank_score_threshold,
top=top,
skip=skip,
)
Expand Down

0 comments on commit d7880d9

Please sign in to comment.