在ChiikaCode插件中,前端的算法设计主要集中在以下几个方面:
在ChiikaCode插件的前端开发中,项目生成和文件创建是核心功能之一。这个部分的算法设计关注于如何根据后端返回的数据(通常是一个包含目录和文件结构的JSON对象),递归地创建文件夹和文件。
为了生成整个项目结构,我们通常需要递归地遍历每一个节点。在这种情况下,项目结构中可能包含文件夹和文件,我们采用深度优先遍历(DFS)方法来逐一访问每个节点。 每个节点会携带该文件或文件夹的相关信息,比如文件名、类型(文件或文件夹)、内容(对于文件来说是其内容,对于文件夹来说是其子节点)。 DFS遍历确保了文件夹结构和内容能够按正确的顺序创建。首先创建目录,接着在目录下创建文件。如果目录中还有子目录或文件,则继续递归创建。
假如有以下项目结构
Project
│
├── folder1
│ ├── file1.js
│ └── folder2
│ └── file2.js
└── file3.js
- DFS遍历过程
- 首先,遍历到
Project
文件夹。 - 然后进入
folder1
文件夹,继续遍历。 - 在
folder1
内,首先遇到file1.js
,创建该文件。 - 接着进入
folder2
文件夹,再遍历其中的file2.js
。 - 最后,回到
Project
文件夹,创建file3.js
。
- 路径创建
使用 vscode.Uri
创建路径,并调用 createDirectory()
或 createFile()
。
+------------------------+
| Project |
+------------------------+
/ \
+-----------------+ +-----------------+
| folder1 | | file3.js |
+-----------------+ +-----------------+
/ \
+-------------+ +---------------+
| file1.js | | folder2 |
+-------------+ +---------------+
|
+-----------+
| file2.js |
+-----------+
为了确保每个文件和文件夹的路径都是准确无误的,我们使用了vscode.Uri
API来处理路径的拼接和管理。
vscode.Uri.file(path)
方法可以生成一个URI对象,用于表示文件路径,确保路径格式与文件系统兼容。
例如,在创建文件夹时,我们先生成文件夹的URI,然后通过vscode.workspace.fs.createDirectory()
方法创建目录。如果是文件,则通过vscode.workspace.fs.createFile()
来创建文件并写入内容。
在遍历项目结构时,我们根据每个节点的类型(文件夹或文件)来调用不同的创建方法。对于文件夹,我们只需要调用createDirectory()
;对于文件,则需要通过createFile()
和后续的WorkspaceEdit
操作来创建文件并填充内容。
当遇到一个文件夹节点时,系统会首先检查该文件夹是否已存在,如果不存在则创建它。创建文件夹后,系统会继续递归处理该文件夹中的子节点。如果节点是文件,则会创建该文件并写入内容,文件创建的路径也是递归生成的。
前端通过onDidReceiveMessage
和window.addEventListener('message')
监听来自后端的消息,并通过解析不同的消息命令来选择合适的处理逻辑。算法设计中需要考虑消息的高效分发与处理,保证用户交互的流畅性。
在处理完后端请求并收到响应后,前端需要通过事件驱动机制更新UI。我们使用了原生的JavaScript事件机制,并结合前端框架(如HTML的DOM事件),来确保UI在收到数据后及时渲染。例如,生成的代码会通过responseContainer
动态展示,用户交互后会即时反映到UI上。
在这个模块中,我们定义了多种文件类型加载器,并通过继承基类 BaseLoader
来实现文件的读取和处理。每个加载器会根据文件后缀或类型,加载文件内容。
我们为不同的文件类型定义了加载器,例如处理 .txt
, .csv
, .xlsx
等不同文件格式的加载逻辑。
import csv
import os
from openpyxl import load_workbook
from docx import Document as DocxDocument
class BaseLoader:
def __init__(self, file_path):
self.file_path = file_path
def load(self):
raise NotImplementedError("This method should be overridden by subclasses")
# 针对不同文件格式的加载器
class PythonLoader(BaseLoader):
def load(self):
with open(self.file_path, 'r', encoding='utf-8') as file:
return file.read()
class CSVLoader(BaseLoader):
def load(self):
with open(self.file_path, newline='', encoding='utf-8') as csvfile:
reader = csv.reader(csvfile)
return list(reader)
class XlsxLoader(BaseLoader):
def load(self):
workbook = load_workbook(filename=self.file_path)
sheet = workbook.active
data = []
for row in sheet.iter_rows(values_only=True):
data.append(row)
return data
class DocxLoader(BaseLoader):
def load(self):
doc = DocxDocument(self.file_path)
full_text = []
for para in doc.paragraphs:
full_text.append(para.text)
return '\n'.join(full_text)
根据文件后缀,选择合适的加载器来处理不同格式的文件:
def load_file(file_path):
if not os.path.exists(file_path):
raise HTTPException(status_code=400, detail="文件路径不存在")
# 根据文件类型选择不同的加载器
if file_path.endswith('.py'):
loader = PythonLoader(file_path)
elif file_path.endswith('.csv'):
loader = CSVLoader(file_path)
elif file_path.endswith('.xlsx'):
loader = XlsxLoader(file_path)
elif file_path.endswith('.docx'):
loader = DocxLoader(file_path)
else:
raise HTTPException(status_code=400, detail="不支持的文件类型")
return loader.load()
加载完文档后,我们需要对文档进行分割,以便后续的向量化处理:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
def split_documents(documents):
if isinstance(documents, str):
documents = [Document(page_content=documents)]
elif isinstance(documents, list) and all(isinstance(item, str) for item in documents):
documents = [Document(page_content=item) for item in documents]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=65536, chunk_overlap=10)
return text_splitter.split_documents(documents)
split_documents
函数利用RecursiveCharacterTextSplitter
将文档分割成适合向量化的块,每个块最多65536
个字符,重叠部分为10
个字符。
在这一步中,我们将文档向量化,存储到向量数据库中,确保后续可以快速检索相关文档。
我们使用 Hugging Face 提供的嵌入模型将文档转换为向量:
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
def get_embedding():
model_name = 'moka-ai/m3e-base'
embedding = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
return embedding
我们使用 Chroma
向量数据库将文档存储为向量:
from langchain_community.vectorstores import Chroma
def build_vector_db(documents, embedding):
persist_directory = 'db'
db = Chroma.from_documents(
documents,
embedding,
persist_directory=persist_directory
)
return db
build_vector_db
将文档向量存储到Chroma
数据库中,并在本地保存数据。
检索器用于从数据库中根据查询获取相关文档:
def build_retriever(db):
return db.as_retriever()
build_retriever
创建了一个基于Chroma
数据库的检索器。
RAG(Retrieval-Augmented Generation)是结合检索与生成的架构,用于生成更为准确的答案。
RAG链会先从数据库中检索相关文档,再将问题和上下文一起传递给大语言模型,生成最终的回答:
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_community.chat_models import ChatOllama
from langchain.schema.runnable import RunnablePassthrough
def build_rag_chain(retriever):
template = """
根据context详细解释有关question的内容,并给出答案。
Question: {question} Context: {context} Answer: """
prompt = ChatPromptTemplate.from_template(template)
llm = ChatOllama(model='llama3.2:latest')
rag_chain = (
{"context": retriever, "question": RunnablePassthrough()} |
prompt |
llm |
StrOutputParser()
)
return rag_chain
build_rag_chain
方法通过ChatPromptTemplate
和ChatOllama
来构建一个基于检索的生成链条,确保问题和检索的上下文信息能够传递给 LLM。
用户输入问题后,RAG链会根据问题检索数据库中的相关文档,然后生成最终的回答:
@app.post("/ask")
async def ask_question(request: QuestionRequest):
if not rag_chain:
raise HTTPException(status_code=400, detail="RAG链尚未初始化,请先上传文件路径")
question = request.question
answer = ""
async for chunk in rag_chain.astream(question):
answer += chunk
return {"answer": answer}
ask_question
路由使用rag_chain
进行检索和生成回答。
这个模块是整个应用的核心,涉及到动态代码生成、错误检查和代码执行的过程。
基于函数名称、参数和文档字符串生成代码:
from ollama import generate
def getCompleteBody(prefix: str, suffix: str) -> str:
return generate(
model='starcoder2:3b',
prompt=prefix,
suffix=suffix,
options={'num_predict': 256, 'temperature': 1, 'top_p': 0.9}
)['response']
getCompleteBody
函数通过ollama.generate
来生成代码块,基于传入的提示和选项生成补全的代码。
为了确保代码没有语法错误,我们通过 exec
尝试执行生成的代码:
def getNoErrorCode(functionName: str, arguments: List[str], docString: str) -> str:
executable = False
while not executable:
functionString = getOriginalCode(functionName, arguments, docString)
try:
exec(functionString, globals())
executable = True
except Exception as e:
pass
return functionString
getNoErrorCode
会生成代码并尝试执行,如果生成的代码有错误,它会重新生成,直到代码正确为止。
这个模块的主要功能是基于用户的需求和指定的编程语言,动态生成一个项目代码框架,并使用大模型生成代码。它的关键算法流程包括文档加载、代码结构生成、代码生成、项目存储与返回等部分。以下是该模块的算法总结:
当用户请求生成代码时,首先需要根据用户的需求和编程语言来构建代码的结构。这一过程会涉及到对项目结构的获取和解析。文档加载器会通过多个子模块获取项目结构流,并生成相应的节点。核心步骤包括:
- 获取项目结构流:通过
getRawStructureStream
函数,依据用户的需求和所选语言,获取代码项目的基本框架。 - 解析项目结构:将获取到的项目结构字符串解析为
Node
对象,结构化表示文件和目录的层次。
structure_content = "".join(getRawStructureStream(question, language))
node = parseStructureString(structure_content, language)
- 如果解析失败,将抛出异常
500
错误。
一旦项目结构解析完成,接下来是基于结构为每个文件节点生成代码。这一步通过 getRawCodeStream
函数实现,涉及到以下步骤:
- 获取代码生成流:通过用户的需求、项目结构和文件名等信息,构建一个原始字符流,并交给大模型进行生成。模型通过
llm_chain
来处理输入并输出相应的代码块。 - 填充代码框架:在生成的代码框架中填充代码内容。根据项目结构,针对每个文件节点,通过调用
getRawCodeStream
获取代码。
for f_node in node.getFileNodes(lang_exts[language]):
content = "".join(getRawCodeStream(None, node.getStrucureString(), f_node.getPath(), language))
f_node.content = getLongestCodeBlock(content)
- 提取最长的代码块:在所有生成的代码块中,选取最合适的代码段作为文件内容。
getLongestCodeBlock
用来提取最长的有效代码块,确保代码完整性。
代码生成的核心算法是通过调用外部模型(如 llm
)来根据项目结构和需求生成代码。这里的关键点在于如何将用户需求和项目结构拼接到一起,并通过模型生成符合需求的代码:
llm_chain = (
RunnablePassthrough() |
prompt |
llm |
StrOutputParser()
)
RunnablePassthrough
:保证用户输入数据通过管道流畅传递。prompt
:生成用于大模型的提示模板,包含用户需求、代码结构、语言等信息。llm
:大语言模型,根据输入的提示生成代码。StrOutputParser
:解析模型的输出,并将其格式化为最终的代码字符串。
生成的代码框架和内容需要以一定格式返回给用户。最终,代码结构会被转换为 JSON 格式,并存储到文件中。这个过程确保用户能够获得一个标准的项目结构,同时也便于后续查看和使用。
result_json = node.getJsonDict()
with open('result.json', 'w', encoding='utf-8') as f:
json.dump(result_json, f, indent=4)
getJsonDict
:将项目结构(包括文件内容和节点信息)转换为 JSON 格式。json.dump
:将结构化的 JSON 数据写入本地文件,便于后续使用和查看。