Skip to content

Commit

Permalink
(fix)[solver]: language (OpenSPG#19)
Browse files Browse the repository at this point in the history
* fix builder init

* fix language

* fix req
  • Loading branch information
northmachine authored Nov 4, 2024
1 parent 60dc618 commit 7b2cc22
Show file tree
Hide file tree
Showing 18 changed files with 46 additions and 43 deletions.
2 changes: 1 addition & 1 deletion KAG_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.5-beta1
0.5.2-beta1
2 changes: 1 addition & 1 deletion kag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@


__package_name__ = "openspg-kag"
__version__ = "0.5-beta1"
__version__ = "0.5.2-beta1"

from kag.common.env import init_env

Expand Down
4 changes: 2 additions & 2 deletions kag/solver/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from string import Template

from knext.project.client import ProjectClient

from kag.common.llm.client import LLMClient
import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -108,10 +107,11 @@ def __init__(
"""
self.host_addr = kwargs.get("KAG_PROJECT_HOST_ADDR") or os.getenv("KAG_PROJECT_HOST_ADDR")
self.project_id = kwargs.get("KAG_PROJECT_ID") or os.getenv("KAG_PROJECT_ID")
self.config = ProjectClient().get_config(self.project_id)

self._init_llm()
self.biz_scene = kwargs.get("KAG_PROMPT_BIZ_SCENE") or os.getenv("KAG_PROMPT_BIZ_SCENE", "default")
self.language = kwargs.get("KAG_PROMPT_LANGUAGE") or os.getenv("KAG_PROMPT_LANGUAGE", "en")
self.language = self.config.get("prompt").get("language") or os.getenv("KAG_PROMPT_LANGUAGE", "en")


def _init_llm(self):
Expand Down
10 changes: 5 additions & 5 deletions kag/solver/logic/core_modules/lf_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def __init__(self, query: str, project_id: str,
# Initialize executors for different operations.
self.retrieval_executor = RetrievalExecutor(query, self.kg_graph, self.schema, self.kg_retriever,
self.el,
self.dsl_runner, self.debug_info, text_similarity)
self.deduce_executor = DeduceExecutor(query, self.kg_graph, self.schema, self.op_runner, self.debug_info)
self.sort_executor = SortExecutor(query, self.kg_graph, self.schema, self.debug_info)
self.math_executor = MathExecutor(query, self.kg_graph, self.schema, self.debug_info)
self.dsl_runner, self.debug_info, text_similarity,KAG_PROJECT_ID = self.project_id)
self.deduce_executor = DeduceExecutor(query, self.kg_graph, self.schema, self.op_runner, self.debug_info, KAG_PROJECT_ID = self.project_id)
self.sort_executor = SortExecutor(query, self.kg_graph, self.schema, self.debug_info, KAG_PROJECT_ID = self.project_id)
self.math_executor = MathExecutor(query, self.kg_graph, self.schema, self.debug_info, KAG_PROJECT_ID = self.project_id)
self.output_executor = OutputExecutor(query, self.kg_graph, self.schema, self.el,
self.dsl_runner,
self.retrieval_executor.query_one_graph_cache, self.debug_info)
self.retrieval_executor.query_one_graph_cache, self.debug_info, KAG_PROJECT_ID = self.project_id)

self.with_sub_answer = os.getenv("KAG_QA_WITH_SUB_ANSWER", True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@


class DeduceExecutor(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, rule_runner: OpRunner, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, rule_runner: OpRunner, debug_info: dict,
**kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID')
self.rule_runner = rule_runner
self.op_register_map = {
'verify': self.rule_runner.run_verify_op,
Expand All @@ -25,10 +27,10 @@ def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, rule_r

def _deduce_call(self, node: DeduceNode, req_id: str, param: dict) -> list:
op_mapping = {
'choice': ChoiceOp(self.nl_query, self.kg_graph, self.schema, self.debug_info),
'multiChoice': MultiChoiceOp(self.nl_query, self.kg_graph, self.schema, self.debug_info),
'entailment': EntailmentOp(self.nl_query, self.kg_graph, self.schema, self.debug_info),
'judgement': JudgementOp(self.nl_query, self.kg_graph, self.schema, self.debug_info)
'choice': ChoiceOp(self.nl_query, self.kg_graph, self.schema, self.debug_info,KAG_PROJECT_ID = self.KAG_PROJECT_ID),
'multiChoice': MultiChoiceOp(self.nl_query, self.kg_graph, self.schema, self.debug_info,KAG_PROJECT_ID = self.KAG_PROJECT_ID),
'entailment': EntailmentOp(self.nl_query, self.kg_graph, self.schema, self.debug_info,KAG_PROJECT_ID = self.KAG_PROJECT_ID),
'judgement': JudgementOp(self.nl_query, self.kg_graph, self.schema, self.debug_info,KAG_PROJECT_ID = self.KAG_PROJECT_ID)
}
result = []
for op in node.deduce_ops:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


class ChoiceOp(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.prompt = PromptOp.load(self.biz_scene, "deduce_choice")(
language=self.language
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


class EntailmentOp(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.prompt = PromptOp.load(self.biz_scene, "deduce_entail")(
language=self.language
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


class JudgementOp(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.prompt = PromptOp.load(self.biz_scene, "deduce_judge")(
language=self.language
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


class MultiChoiceOp(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.prompt = PromptOp.load(self.biz_scene, "deduce_multi_choice")(
language=self.language
)
Expand Down
4 changes: 2 additions & 2 deletions kag/solver/logic/core_modules/op_executor/op_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class OpExecutor(KagBaseModule, ABC):
Each subclass must implement the execution and judgment functions.
"""
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
"""
Initializes the operator executor with necessary components.
Expand All @@ -23,7 +23,7 @@ def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_
schema (SchemaUtils): Semantic structure definition to assist in the parsing process.
debug_info (dict): Debug information dictionary to record debugging information during parsing.
"""
super().__init__()
super().__init__(**kwargs)
self.kg_graph = kg_graph
self.schema = schema
self.nl_query = nl_query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class MathExecutor(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)

def is_this_op(self, logic_node: LogicNode) -> bool:
return isinstance(logic_node, (CountNode, SumNode))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

class GetExecutor(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, el: EntityLinkerBase,
dsl_runner: DslRunner, cached_map: dict, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
dsl_runner: DslRunner, cached_map: dict, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)

self.el = el
self.dsl_runner = dsl_runner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@


class OutputExecutor(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, el: EntityLinkerBase, dsl_runner: DslRunner, cached_map: dict, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, el: EntityLinkerBase, dsl_runner: DslRunner, cached_map: dict, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID')
self.op_register_map = {
'get': GetExecutor(nl_query, kg_graph, schema, el, dsl_runner, cached_map, self.debug_info)
'get': GetExecutor(nl_query, kg_graph, schema, el, dsl_runner, cached_map, self.debug_info,KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID'))
}

def is_this_op(self, logic_node: LogicNode) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class GetSPOExecutor(OpExecutor):
"""
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, retrieval_spo: KGRetrieverABC,
el: EntityLinkerBase,
dsl_runner: DslRunner, query_one_graph_cache: dict, debug_info: dict, text_similarity: TextSimilarity=None):
dsl_runner: DslRunner, query_one_graph_cache: dict, debug_info: dict, text_similarity: TextSimilarity=None,**kwargs):
"""
Initializes the GetSPOExecutor with necessary components.
Expand All @@ -39,7 +39,7 @@ def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, retrie
query_one_graph_cache (dict): Cache for storing results of one-hop graph queries.
debug_info (dict): Debug information dictionary to record debugging information during parsing.
"""
super().__init__(nl_query, kg_graph, schema, debug_info)
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.retrieval_spo = retrieval_spo
self.dsl_runner = dsl_runner
self.query_one_graph_cache = query_one_graph_cache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


class SearchS(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)

def executor(self, logic_node: LogicNode, req_id: str, param: dict) -> KgGraph:
raise NotImplementedError("search s not impl")
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

class RetrievalExecutor(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, retrieval_spo: KGRetrieverABC, el: EntityLinkerBase,
dsl_runner: DslRunner, debug_info: dict, text_similarity: TextSimilarity=None):
super().__init__(nl_query, kg_graph, schema, debug_info)
dsl_runner: DslRunner, debug_info: dict, text_similarity: TextSimilarity=None,**kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)
self.query_one_graph_cache = {}
self.op_register_map = {
'get_spo': GetSPOExecutor(nl_query, kg_graph, schema, retrieval_spo, el, dsl_runner, self.query_one_graph_cache, self.debug_info, text_similarity),
'search_s': SearchS(nl_query, kg_graph, schema, self.debug_info)
'get_spo': GetSPOExecutor(nl_query, kg_graph, schema, retrieval_spo, el, dsl_runner, self.query_one_graph_cache, self.debug_info, text_similarity,KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID')),
'search_s': SearchS(nl_query, kg_graph, schema, self.debug_info,KAG_PROJECT_ID = kwargs.get('KAG_PROJECT_ID'))
}

def is_this_op(self, logic_node: LogicNode) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class SortExecutor(OpExecutor):
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict):
super().__init__(nl_query, kg_graph, schema, debug_info)
def __init__(self, nl_query: str, kg_graph: KgGraph, schema: SchemaUtils, debug_info: dict, **kwargs):
super().__init__(nl_query, kg_graph, schema, debug_info, **kwargs)

def is_this_op(self, logic_node: LogicNode) -> bool:
return isinstance(logic_node, SortNode)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ openai
python-docx
charset_normalizer==3.3.2
pdfminer.six==20231228
openspg-knext==0.5b1
openspg-knext==0.5.2b1
ollama

0 comments on commit 7b2cc22

Please sign in to comment.