Skip to content

Commit

Permalink
Merge branch '0.6_dev' into v0.6_solver
Browse files Browse the repository at this point in the history
  • Loading branch information
royzhao committed Dec 24, 2024
2 parents 7796b61 + f281b57 commit adf8e01
Show file tree
Hide file tree
Showing 197 changed files with 34,512 additions and 64 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ repos:
kag/solver/logic/core_modules/parser/logic_node_parser.py
)$
- repo: https://github.com/pycqa/flake8
rev: 4.0.1
hooks:
- id: flake8
files: ^kag/.*\.py$
4 changes: 4 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
recursive-include kag *
recursive-exclude kag/examples *
global-exclude *.pyc
global-exclude *.pyo
global-exclude *.pyd
global-exclude __pycache__
Empty file added kag/bridge/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions kag/bridge/spg_server_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import json
import kag.interface as interface


class SPGServerBridge:
    """Thin, stateless bridge exposing KAG builder components to the SPG server.

    Each call builds the requested component(s) from a dict (or JSON string)
    config via the registry in ``kag.interface`` and runs them on the input.
    """

    def __init__(self):
        # No per-instance state is needed; the bridge is purely a dispatcher.
        pass

    def run_reader(self, config, input_data):
        """Scan ``input_data`` and convert every scanned item into chunks.

        Args:
            config: dict or JSON string with "scanner" and "reader" entries,
                each a component config.
            input_data: value accepted by the configured scanner's ``generate``.

        Returns:
            list: concatenation of the reader's output for every scanned item.
        """
        cfg = json.loads(config) if isinstance(config, str) else config
        scanner = interface.ScannerABC.from_config(cfg["scanner"])
        reader = interface.ReaderABC.from_config(cfg["reader"])
        # Flatten all reader outputs into a single list (same as repeated +=).
        return [
            chunk
            for item in scanner.generate(input_data)
            for chunk in reader.invoke(item)
        ]

    def run_component(self, component_name, component_config, input_data):
        """Instantiate the named component and invoke it on ``input_data``.

        Args:
            component_name: attribute name of a component class exported by
                ``kag.interface``.
            component_config: dict or JSON string config for that component.
            input_data: payload forwarded to the component's ``invoke``.

        Returns:
            Whatever the component's ``invoke`` returns.
        """
        if isinstance(component_config, str):
            component_config = json.loads(component_config)
        component_cls = getattr(interface, component_name)
        return component_cls.from_config(component_config).invoke(input_data)
7 changes: 4 additions & 3 deletions kag/builder/component/reader/markdown_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def process_text_with_links(element):
current_content = []

# Traverse all elements
for element in soup.find_all(
all_elements = soup.find_all(
[
"h1",
"h2",
Expand All @@ -147,7 +147,8 @@ def process_text_with_links(element):
"pre",
"code",
]
):
)
for element in all_elements:
if element.name.startswith("h") and not is_in_code_block(element):
# Only process headers that are not in code blocks
# Handle title logic
Expand All @@ -166,7 +167,7 @@ def process_text_with_links(element):
stack[-1].children.append(new_node)
stack.append(new_node)

elif element.name in ["pre", "code"]:
elif element.name in ["code"]:
# Preserve code blocks as is
text = element.get_text()
if text:
Expand Down
47 changes: 19 additions & 28 deletions kag/builder/component/scanner/csv_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,18 @@

@ScannerABC.register("csv")
class CSVScanner(ScannerABC):
"""
A class for reading CSV files and converting them into a list of dictionaries.
This class inherits from `ScannerABC` and provides functionality to read CSV files.
It can either return the entire row as a dictionary or split the row into multiple dictionaries
based on specified columns.
Attributes:
cols (List[str]): A list of column names to be processed. If None, the entire row is returned as a dictionary.
rank (int): The rank of the current process (used for distributed processing).
world_size (int): The total number of processes (used for distributed processing).
"""

def __init__(self, cols: List[str] = None, rank: int = 0, world_size: int = 1):
"""
Initializes the CSVScanner with optional columns, rank, and world size.
Args:
cols (List[str], optional): A list of column names to be processed. Defaults to None.
- If not specified, each row of the CSV file will be returned as a single dictionary.
- If specified, each row will be split into multiple dictionaries, one for each specified column.
rank (int, optional): The rank of the current process. Defaults to None.
world_size (int, optional): The total number of processes. Defaults to None.
"""
def __init__(
    self,
    header: bool = True,
    col_names: List[str] = None,
    col_ids: List[int] = None,
    rank: int = 0,
    world_size: int = 1,
):
    """Initialize the CSV scanner.

    Args:
        header (bool, optional): Whether the CSV file has a header row.
            When False the file is read with ``header=None`` so pandas
            assigns integer column labels. Defaults to True.
        col_names (List[str], optional): Column names to extract per row;
            when set, only these columns are emitted. Defaults to None.
        col_ids (List[int], optional): Column indices to extract per row;
            only consulted when ``col_names`` is not given. Defaults to None.
            If both are None, each whole row is returned as one dict.
        rank (int, optional): Shard index of this process — presumably used
            by the base class for distributed sharding; see ScannerABC.
            Defaults to 0.
        world_size (int, optional): Total number of shards. Defaults to 1.
    """
    super().__init__(rank=rank, world_size=world_size)
    self.header = header
    self.col_names = col_names
    self.col_ids = col_ids

@property
def input_types(self) -> Input:
Expand All @@ -65,14 +51,19 @@ def load_data(self, input: Input, **kwargs) -> List[Output]:
Returns:
List[Output]: A list of dictionaries containing the processed data.
"""
data = pd.read_csv(input, dtype=str)
if self.cols is None:
input = self.download_data(input)
if self.header:
data = pd.read_csv(input, dtype=str)
else:
data = pd.read_csv(input, dtype=str, header=None)
col_keys = self.col_names if self.col_names else self.col_ids
if col_keys is None:
return data.to_dict(orient="records")

contents = []
for _, row in data.iterrows():
for k, v in row.items():
if k in self.cols:
if k in col_keys:
v = str(v)
name = v[:5] + "..." + v[-5:]
contents.append(
Expand Down
1 change: 1 addition & 0 deletions kag/builder/component/scanner/json_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def load_data(self, input: Input, **kwargs) -> List[Output]:
Raises:
ValueError: If there is an error reading the JSON data or if the input is not a valid JSON array or object.
"""
input = self.download_data(input)
try:
if os.path.exists(input):
corpus = self._read_from_file(input)
Expand Down
23 changes: 13 additions & 10 deletions kag/builder/component/scanner/yuque_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, token: str):
@property
def input_types(self) -> Type[Input]:
"""The type of input this Runnable object accepts specified as a type annotation."""
return str
return Union[str, List[str]]

@property
def output_types(self) -> Type[Output]:
Expand Down Expand Up @@ -90,12 +90,15 @@ def load_data(self, input: Input, **kwargs) -> List[Output]:
List[Output]: A list of strings, where each string contains the token and the URL of each document.
"""
url = input
data = self.get_yuque_api_data(url)
if isinstance(data, dict):
# for single yuque doc
return [f"{self.token}@{url}"]
output = []
for item in data:
slug = item["slug"]
output.append(os.path.join(url, slug))
return [f"{self.token}@{url}" for url in output]
if isinstance(url, str):
data = self.get_yuque_api_data(url)
if isinstance(data, dict):
# for single yuque doc
return [f"{self.token}@{url}"]
output = []
for item in data:
slug = item["slug"]
output.append(os.path.join(url, slug))
return [f"{self.token}@{url}" for url in output]
else:
return [f"{self.token}@{x}" for x in url]
6 changes: 3 additions & 3 deletions kag/builder/default_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def __init__(
self,
reader: ReaderABC,
splitter: SplitterABC,
extractor: ExtractorABC,
vectorizer: VectorizerABC,
writer: SinkWriterABC,
extractor: ExtractorABC = None,
vectorizer: VectorizerABC = None,
writer: SinkWriterABC = None,
post_processor: PostProcessorABC = None,
):
"""
Expand Down
1 change: 1 addition & 0 deletions kag/builder/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def process(data, data_id, data_abstract):

futures = []
print(f"Processing {input}")
success = 0
try:
with ThreadPoolExecutor(self.num_chains) as executor:
for item in self.scanner.generate(input):
Expand Down
12 changes: 12 additions & 0 deletions kag/common/checkpointer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ def size(self):

raise NotImplementedError("size not implemented yet.")

def __contains__(self, key):
    """Defines the behavior of the `in` operator for the object.

    Delegates to ``exists`` so callers can write ``key in checkpointer``
    instead of calling ``exists`` directly.

    Args:
        key (str): The key to check for existence in the checkpoint.

    Returns:
        bool: True if the key exists in the checkpoint, False otherwise.
    """

    return self.exists(key)


class CheckpointerManager:
"""
Expand Down
4 changes: 3 additions & 1 deletion kag/examples/2wiki/builder/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import logging
from kag.common.registry import import_modules_from_path

Expand All @@ -27,6 +28,7 @@ def buildKB(file_path):

if __name__ == "__main__":
import_modules_from_path(".")
file_path = "./data/2wiki_sub_corpus.json"
dir_path = os.path.dirname(__file__)
file_path = os.path.join(dir_path, "data/2wiki_sub_corpus.json")

buildKB(file_path)
15 changes: 13 additions & 2 deletions kag/examples/2wiki/solver/evaFor2wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from kag.common.conf import KAG_CONFIG
from kag.common.registry import import_modules_from_path

from kag.common.checkpointer import CheckpointerManager

logger = logging.getLogger(__name__)


class EvaFor2wiki:

"""
init for kag client
"""
Expand Down Expand Up @@ -43,13 +44,22 @@ def qa(self, query):
def parallelQaAndEvaluate(
self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
):
ckpt = CheckpointerManager.get_checkpointer(
{"type": "zodb", "ckpt_dir": "ckpt"}
)

def process_sample(data):
try:
sample_idx, sample = data
sample_id = sample["_id"]
question = sample["question"]
gold = sample["answer"]
prediction, traceLog = self.qa(question)
if question in ckpt:
print(f"found existing answer to question: {question}")
prediction, traceLog = ckpt.read_from_ckpt(question)
else:
prediction, traceLog = self.qa(question)
ckpt.write_to_ckpt(question, (prediction, traceLog))

evalObj = Evaluate()
metrics = evalObj.getBenchMark([prediction], [gold])
Expand Down Expand Up @@ -107,6 +117,7 @@ def process_sample(data):
res_metrics[item_key] = item_value / total_metrics["processNum"]
else:
res_metrics[item_key] = total_metrics["processNum"]
CheckpointerManager.close()
return res_metrics


Expand Down
16 changes: 12 additions & 4 deletions kag/examples/hotpotqa/solver/evaForHotpotqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from kag.common.conf import KAG_CONFIG
from kag.common.registry import import_modules_from_path

from kag.common.checkpointer import CheckpointerManager

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -39,17 +41,22 @@ def qa(self, query):
def parallelQaAndEvaluate(
self, qaFilePath, resFilePath, threadNum=1, upperLimit=10, run_failed=False
):
ckpt = CheckpointerManager.get_checkpointer(
{"type": "zodb", "ckpt_dir": "ckpt"}
)

def process_sample(data):
try:
sample_idx, sample = data
sample_id = sample["_id"]
question = sample["question"]
gold = sample["answer"]
if "prediction" not in sample.keys():
prediction, traceLog = self.qa(question)
if question in ckpt:
print(f"found existing answer to question: {question}")
prediction, traceLog = ckpt.read_from_ckpt(question)
else:
prediction = sample["prediction"]
traceLog = sample["traceLog"]
prediction, traceLog = self.qa(question)
ckpt.write_to_ckpt(question, (prediction, traceLog))

evaObj = Evaluate()
metrics = evaObj.getBenchMark([prediction], [gold])
Expand Down Expand Up @@ -107,6 +114,7 @@ def process_sample(data):
res_metrics[item_key] = item_value / total_metrics["processNum"]
else:
res_metrics[item_key] = total_metrics["processNum"]
CheckpointerManager.close()
return res_metrics


Expand Down
14 changes: 13 additions & 1 deletion kag/examples/musique/solver/evaForMusique.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from kag.examples.utils import delay_run
from kag.solver.logic.solver_pipeline import SolverPipeline

from kag.common.checkpointer import CheckpointerManager

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -43,13 +45,22 @@ def qaWithoutLogicForm(self, query):
def parallelQaAndEvaluate(
self, qaFilePath, resFilePath, threadNum=1, upperLimit=10
):
ckpt = CheckpointerManager.get_checkpointer(
{"type": "zodb", "ckpt_dir": "ckpt"}
)

def process_sample(data):
try:
sample_idx, sample = data
sample_id = sample["id"]
question = sample["question"]
gold = sample["answer"]
prediction, traceLog = self.qaWithoutLogicForm(question)
if question in ckpt:
print(f"found existing answer to question: {question}")
prediction, traceLog = ckpt.read_from_ckpt(question)
else:
prediction, traceLog = self.qa(question)
ckpt.write_to_ckpt(question, (prediction, traceLog))

evaObj = Evaluate()
metrics = evaObj.getBenchMark([prediction], [gold])
Expand Down Expand Up @@ -107,6 +118,7 @@ def process_sample(data):
res_metrics[item_key] = item_value / total_metrics["processNum"]
else:
res_metrics[item_key] = total_metrics["processNum"]
CheckpointerManager.close()
return res_metrics


Expand Down
Loading

0 comments on commit adf8e01

Please sign in to comment.