Skip to content

Latest commit

 

History

History
389 lines (281 loc) · 14.9 KB

关键算法设计报告.md

File metadata and controls

389 lines (281 loc) · 14.9 KB

ChiikaCode 关键算法设计报告

一、前端算法设计

1. 前端算法设计

在ChiikaCode插件中,前端的算法设计主要集中在以下几个方面:

1.1. 项目生成与文件创建

在ChiikaCode插件的前端开发中,项目生成和文件创建是核心功能之一。这个部分的算法设计关注于如何根据后端返回的数据(通常是一个包含目录和文件结构的JSON对象),递归地创建文件夹和文件。

深度优先遍历(DFS)

为了生成整个项目结构,我们通常需要递归地遍历每一个节点。在这种情况下,项目结构中可能包含文件夹和文件,我们采用深度优先遍历(DFS)方法来逐一访问每个节点。 每个节点会携带该文件或文件夹的相关信息,比如文件名、类型(文件或文件夹)、内容(对于文件来说是其内容,对于文件夹来说是其子节点)。 DFS遍历确保了文件夹结构和内容能够按正确的顺序创建。首先创建目录,接着在目录下创建文件。如果目录中还有子目录或文件,则继续递归创建。

假如有以下项目结构

Project
│
├── folder1
│   ├── file1.js
│   └── folder2
│       └── file2.js
└── file3.js
  • DFS遍历过程
  1. 首先,遍历到 Project 文件夹。
  2. 然后进入 folder1 文件夹,继续遍历。
  3. folder1 内,首先遇到 file1.js,创建该文件。
  4. 接着进入 folder2 文件夹,再遍历其中的 file2.js
  5. 最后,回到 Project 文件夹,创建 file3.js
  • 路径创建

使用 vscode.Uri 创建路径,并调用 createDirectory()createFile()

              +------------------------+
              |        Project          |
              +------------------------+
                     /        \
        +-----------------+    +-----------------+
        |    folder1      |    |     file3.js    |
        +-----------------+    +-----------------+
               /     \
      +-------------+   +---------------+
      |   file1.js  |   |    folder2     |
      +-------------+   +---------------+
                           |
                      +-----------+
                      |  file2.js |
                      +-----------+
使用vscode.Uri进行路径管理

为了确保每个文件和文件夹的路径都是准确无误的,我们使用了vscode.Uri API来处理路径的拼接和管理。 vscode.Uri.file(path)方法可以生成一个URI对象,用于表示文件路径,确保路径格式与文件系统兼容。 例如,在创建文件夹时,我们先生成文件夹的URI,然后通过vscode.workspace.fs.createDirectory()方法创建目录。如果是文件,则通过vscode.workspace.fs.createFile()来创建文件并写入内容。

根据节点类型选择创建方法

在遍历项目结构时,我们根据每个节点的类型(文件夹或文件)来调用不同的创建方法。对于文件夹,我们只需要调用createDirectory();对于文件,则需要通过createFile()和后续的WorkspaceEdit操作来创建文件并填充内容。

递归创建目录和文件

当遇到一个文件夹节点时,系统会首先检查该文件夹是否已存在,如果不存在则创建它。创建文件夹后,系统会继续递归处理该文件夹中的子节点。如果节点是文件,则会创建该文件并写入内容,文件创建的路径也是递归生成的。

1.2. 消息处理机制

前端通过onDidReceiveMessagewindow.addEventListener('message')监听来自后端的消息,并通过解析不同的消息命令来选择合适的处理逻辑。算法设计中需要考虑消息的高效分发与处理,保证用户交互的流畅性。

1.3. UI更新与数据绑定

在处理完后端请求并收到响应后,前端需要通过事件驱动机制更新UI。我们使用了原生的JavaScript事件机制,并结合前端框架(如HTML的DOM事件),来确保UI在收到数据后及时渲染。例如,生成的代码会通过responseContainer动态展示,用户交互后会即时反映到UI上。

二、后端关键算法设计

1. 文档加载与处理

在这个模块中,我们定义了多种文件类型加载器,并通过继承基类 BaseLoader 来实现文件的读取和处理。每个加载器会根据文件后缀或类型,加载文件内容。

1.1 定义文件加载器

我们为不同的文件类型定义了加载器,例如处理 .txt, .csv, .xlsx 等不同文件格式的加载逻辑。

import csv
import os
from openpyxl import load_workbook
from docx import Document as DocxDocument

class BaseLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        raise NotImplementedError("This method should be overridden by subclasses")

# 针对不同文件格式的加载器
class PythonLoader(BaseLoader):
    def load(self):
        with open(self.file_path, 'r', encoding='utf-8') as file:
            return file.read()

class CSVLoader(BaseLoader):
    def load(self):
        with open(self.file_path, newline='', encoding='utf-8') as csvfile:
            reader = csv.reader(csvfile)
            return list(reader)

class XlsxLoader(BaseLoader):
    def load(self):
        workbook = load_workbook(filename=self.file_path)
        sheet = workbook.active
        data = []
        for row in sheet.iter_rows(values_only=True):
            data.append(row)
        return data

class DocxLoader(BaseLoader):
    def load(self):
        doc = DocxDocument(self.file_path)
        full_text = []
        for para in doc.paragraphs:
            full_text.append(para.text)
        return '\n'.join(full_text)

1.2 加载文件函数

根据文件后缀,选择合适的加载器来处理不同格式的文件:

def load_file(file_path):
    if not os.path.exists(file_path):
        raise HTTPException(status_code=400, detail="文件路径不存在")

    # 根据文件类型选择不同的加载器
    if file_path.endswith('.py'):
        loader = PythonLoader(file_path)
    elif file_path.endswith('.csv'):
        loader = CSVLoader(file_path)
    elif file_path.endswith('.xlsx'):
        loader = XlsxLoader(file_path)
    elif file_path.endswith('.docx'):
        loader = DocxLoader(file_path)
    else:
        raise HTTPException(status_code=400, detail="不支持的文件类型")
    
    return loader.load()

1.3 文档分割

加载完文档后,我们需要对文档进行分割,以便后续的向量化处理:

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document

def split_documents(documents):
    if isinstance(documents, str):
        documents = [Document(page_content=documents)]
    elif isinstance(documents, list) and all(isinstance(item, str) for item in documents):
        documents = [Document(page_content=item) for item in documents]

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=65536, chunk_overlap=10)
    return text_splitter.split_documents(documents)
  • split_documents 函数利用 RecursiveCharacterTextSplitter 将文档分割成适合向量化的块,每个块最多 65536 个字符,重叠部分为 10 个字符。

2. 向量化与数据库存储

在这一步中,我们将文档向量化,存储到向量数据库中,确保后续可以快速检索相关文档。

2.1 嵌入向量化

我们使用 Hugging Face 提供的嵌入模型将文档转换为向量:

from langchain_community.embeddings import HuggingFaceBgeEmbeddings

def get_embedding():
    model_name = 'moka-ai/m3e-base'
    embedding = HuggingFaceBgeEmbeddings(
        model_name=model_name, 
        model_kwargs={'device': 'cpu'}, 
        encode_kwargs={'normalize_embeddings': True}
    )
    return embedding

2.2 向量数据库的构建

我们使用 Chroma 向量数据库将文档存储为向量:

from langchain_community.vectorstores import Chroma

def build_vector_db(documents, embedding):
    persist_directory = 'db'
    db = Chroma.from_documents(
        documents, 
        embedding, 
        persist_directory=persist_directory
    )
    return db
  • build_vector_db 将文档向量存储到 Chroma 数据库中,并在本地保存数据。

2.3 构建检索器

检索器用于从数据库中根据查询获取相关文档:

def build_retriever(db):
    return db.as_retriever()
  • build_retriever 创建了一个基于 Chroma 数据库的检索器。

3. RAG链生成

RAG(Retrieval-Augmented Generation)是结合检索与生成的架构,用于生成更为准确的答案。

3.1 定义RAG链

RAG链会先从数据库中检索相关文档,再将问题和上下文一起传递给大语言模型,生成最终的回答:

from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_community.chat_models import ChatOllama
from langchain.schema.runnable import RunnablePassthrough

def build_rag_chain(retriever):
    template = """
        根据context详细解释有关question的内容,并给出答案。
        Question: {question} Context: {context} Answer: """
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOllama(model='llama3.2:latest')
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()} |
        prompt |
        llm |
        StrOutputParser()
    )
    return rag_chain
  • build_rag_chain 方法通过 ChatPromptTemplateChatOllama 来构建一个基于检索的生成链条,确保问题和检索的上下文信息能够传递给 LLM。

3.2 使用RAG链进行问答

用户输入问题后,RAG链会根据问题检索数据库中的相关文档,然后生成最终的回答:

@app.post("/ask")
async def ask_question(request: QuestionRequest):
    if not rag_chain:
        raise HTTPException(status_code=400, detail="RAG链尚未初始化,请先上传文件路径")

    question = request.question
    answer = ""
    async for chunk in rag_chain.astream(question):
        answer += chunk
    return {"answer": answer}
  • ask_question 路由使用 rag_chain 进行检索和生成回答。

4. 代码生成与执行

这个模块是整个应用的核心,涉及到动态代码生成、错误检查和代码执行的过程。

4.1 生成代码

基于函数名称、参数和文档字符串生成代码:

from ollama import generate

def getCompleteBody(prefix: str, suffix: str) -> str:
    return generate(
        model='starcoder2:3b',
        prompt=prefix,
        suffix=suffix,
        options={'num_predict': 256, 'temperature': 1, 'top_p': 0.9}
    )['response']
  • getCompleteBody 函数通过 ollama.generate 来生成代码块,基于传入的提示和选项生成补全的代码。

4.2 保证代码无错误

为了确保代码没有语法错误,我们通过 exec 尝试执行生成的代码:

def getNoErrorCode(functionName: str, arguments: List[str], docString: str) -> str:
    executable = False
    while not executable:
        functionString = getOriginalCode(functionName, arguments, docString)
        try:
            exec(functionString, globals())
            executable = True
        except Exception as e:
            pass
    return functionString
  • getNoErrorCode 会生成代码并尝试执行,如果生成的代码有错误,它会重新生成,直到代码正确为止。

这个模块的主要功能是基于用户的需求和指定的编程语言,动态生成一个项目代码框架,并使用大模型生成代码。它的关键算法流程包括文档加载、代码结构生成、代码生成、项目存储与返回等部分。以下是该模块的算法总结:

5. 项目级代码生成

5.1 文档加载与处理

当用户请求生成代码时,首先需要根据用户的需求和编程语言来构建代码的结构。这一过程会涉及到对项目结构的获取和解析。文档加载器会通过多个子模块获取项目结构流,并生成相应的节点。核心步骤包括:

  • 获取项目结构流:通过 getRawStructureStream 函数,依据用户的需求和所选语言,获取代码项目的基本框架。
  • 解析项目结构:将获取到的项目结构字符串解析为 Node 对象,结构化表示文件和目录的层次。
structure_content = "".join(getRawStructureStream(question, language))
node = parseStructureString(structure_content, language)
  • 如果解析失败,将抛出异常 500 错误。

5.2 代码结构生成

一旦项目结构解析完成,接下来是基于结构为每个文件节点生成代码。这一步通过 getRawCodeStream 函数实现,涉及到以下步骤:

  • 获取代码生成流:通过用户的需求、项目结构和文件名等信息,构建一个原始字符流,并交给大模型进行生成。模型通过 llm_chain 来处理输入并输出相应的代码块。
  • 填充代码框架:在生成的代码框架中填充代码内容。根据项目结构,针对每个文件节点,通过调用 getRawCodeStream 获取代码。
for f_node in node.getFileNodes(lang_exts[language]):
    content = "".join(getRawCodeStream(None, node.getStrucureString(), f_node.getPath(), language))
    f_node.content = getLongestCodeBlock(content)
  • 提取最长的代码块:在所有生成的代码块中,选取最合适的代码段作为文件内容。getLongestCodeBlock 用来提取最长的有效代码块,确保代码完整性。

5.3 代码生成

代码生成的核心算法是通过调用外部模型(如 llm)来根据项目结构和需求生成代码。这里的关键点在于如何将用户需求和项目结构拼接到一起,并通过模型生成符合需求的代码:

llm_chain = (
    RunnablePassthrough() |
    prompt |
    llm |
    StrOutputParser()
)
  • RunnablePassthrough:保证用户输入数据通过管道流畅传递。
  • prompt:生成用于大模型的提示模板,包含用户需求、代码结构、语言等信息。
  • llm:大语言模型,根据输入的提示生成代码。
  • StrOutputParser:解析模型的输出,并将其格式化为最终的代码字符串。

5.4 项目结果输出

生成的代码框架和内容需要以一定格式返回给用户。最终,代码结构会被转换为 JSON 格式,并存储到文件中。这个过程确保用户能够获得一个标准的项目结构,同时也便于后续查看和使用。

result_json = node.getJsonDict()
with open('result.json', 'w', encoding='utf-8') as f:
    json.dump(result_json, f, indent=4)
  • getJsonDict:将项目结构(包括文件内容和节点信息)转换为 JSON 格式。
  • json.dump:将结构化的 JSON 数据写入本地文件,便于后续使用和查看。