diff --git a/circinus/agent.py b/circinus/agent.py
index 73e9835..a837067 100644
--- a/circinus/agent.py
+++ b/circinus/agent.py
@@ -1,12 +1,60 @@
 from pathlib import Path
 
+from attrs import define
+
 from circinus.fuzzer import UserInput, candidate_prompt, fuzzing_loop
+from circinus.knowledge_base import ChunkDTO, init_knowledge_base
 from circinus.llm import GPT
 
 
+@define(hash=True)
+class Source:
+    file: str
+    page: str
+    text: str
+
+    @staticmethod
+    def curate_sources(sources: list[ChunkDTO]) -> set['Source']:
+        curated_sources = set()
+
+        for chunk in sources:
+            doc_metadata = chunk.document.doc_metadata
+
+            file_name = doc_metadata.get('file_name', '-') if doc_metadata else '-'
+            page_label = doc_metadata.get('page_label', '-') if doc_metadata else '-'
+
+            source = Source(file=file_name, page=page_label, text=chunk.text)
+            curated_sources.add(source)
+
+        return curated_sources
+
+    def __str__(self):
+        return '\n'.join([
+            f'{self.file=}',
+            f'{self.page=}',
+            f'{self.text=}'
+        ])
+
+
 class Agent:
-    def __init__(self):
+    def __init__(self, config):
         self._llm = GPT()
+        self.knowledge_base = init_knowledge_base(config)
+
+    def search_in_docs(self):
+        pass
+
+    def add_docs(self, file_name: str, file_data: str | bytes) -> None:
+        self.knowledge_base.add(file_name, file_data)
+
+    def query_docs(self, message: str):
+        return Source.curate_sources(
+            sources=self.knowledge_base.retrieve(
+                text=message,
+                limit=4,
+                prev_next_chunks=0,
+            ),
+        )
 
     def fuzz(self, documentation: str, specification: str, code: str):
         prompts = candidate_prompt(
diff --git a/circinus/knowledge_base.py b/circinus/knowledge_base.py
index db0a2e3..b938606 100644
--- a/circinus/knowledge_base.py
+++ b/circinus/knowledge_base.py
@@ -19,6 +19,12 @@
 from llama_index.storage.docstore import BaseDocumentStore
 from llama_index.storage.index_store.types import BaseIndexStore
 
+from circinus.components import (
+    doc_store_component,
+    embedding_component,
+    index_store_component,
+    llm_component,
+)
 from circinus.vector_store import ContextFilter, VectorStore
 
 
@@ -216,32 +222,11 @@ def retrieve(
         return retrieved_nodes
 
 
-def main():
-    from circinus.components import (
-        doc_store_component,
-        embedding_component,
-        index_store_component,
-        llm_component,
-    )
-    from circinus.settings import load_config
-
-    config = load_config('config.toml')
-    knowledge_base = KnowledgeBase(
+def init_knowledge_base(config) -> KnowledgeBase:
+    return KnowledgeBase(
         llm=llm_component(config=config),
         vector_store_component=VectorStore(),
         embedding=embedding_component(config=config),
         document_store=doc_store_component(),
         index_store=index_store_component(),
     )
-
-    docs = knowledge_base.add(
-        file_name='blob_documentation.txt',
-        file_data=(Path(__file__).parent / 'blob_documentation.txt').read_text(encoding='utf-8')
-    )
-
-    print(docs)
-    print(knowledge_base.retrieve('Blob.prototype.size'))
-
-
-if __name__ == '__main__':
-    main()
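
A minimal usage sketch of the refactored API, not part of the patch. It assumes circinus.settings.load_config('config.toml') is still available (this change only drops it from the old main()) and that the documentation file sits next to the calling script:

    from pathlib import Path

    from circinus.agent import Agent
    from circinus.settings import load_config

    # Build the agent; __init__ now wires up the knowledge base from the config.
    agent = Agent(load_config('config.toml'))

    # Index a documentation file, then retrieve the chunks most relevant to a query.
    agent.add_docs(
        file_name='blob_documentation.txt',
        file_data=(Path(__file__).parent / 'blob_documentation.txt').read_text(encoding='utf-8'),
    )
    for source in agent.query_docs('Blob.prototype.size'):
        print(source)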