From e0dfa316c65e9966bc7e82e33f2b7b57be9f755d Mon Sep 17 00:00:00 2001 From: Jett Wang Date: Thu, 23 Nov 2023 12:21:25 +0800 Subject: [PATCH] mindmap update --- main.py | 142 ++++++++++++++++++++++++++++++-------- tests/test.http | 2 +- tools/generate_mindmap.py | 4 +- 3 files changed, 118 insertions(+), 30 deletions(-) diff --git a/main.py b/main.py index c109caa..d632097 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,9 @@ from fastapi import File, UploadFile from fastapi.responses import JSONResponse from graphviz import Digraph +from concurrent.futures import ProcessPoolExecutor +from fastapi import FastAPI, BackgroundTasks +import asyncio load_dotenv() @@ -72,6 +75,13 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - app.add_middleware(LimitUploadSize, max_upload_size=1024 * 1024 * 10) +executor = ProcessPoolExecutor() + + +async def run_in_process(fn, *args): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(executor, fn, *args) + class TokenData(BaseModel): api_key: str @@ -272,52 +282,130 @@ async def create_image_ocr(file: UploadFile = File(...), td: TokenData = Depends os.unlink(tmp_path) -@app.post("/knowledge/mindmap/create", summary="Create a knowledge base mindmap from params", - description="Generating mind maps from given structured data", include_in_schema=False) -async def create_mindmap(item: MindmapItem, td: bool = Depends(verify_api_key)): +# @app.post("/knowledge/mindmap/create", summary="Create a knowledge base mindmap from params", +# description="Generating mind maps from given structured data", include_in_schema=False) +# async def create_mindmap(item: MindmapItem, td: bool = Depends(verify_api_key)): +# try: +# log.info(f"create_mindmap: {item}") +# # 创建并构建思维导图 +# graph = Digraph(comment=item.title, engine="sfdp") +# graph.attr(splines='curved', overlap='false', margin='0.4') # 设置图的大小为A4纸尺寸 +# build_mind_map(graph, item.title, None, structure=item.structure) +# fileuuid = str(uuid.uuid4()) +# graph.render(os.path.join(DATA_DIR, fileuuid), format='png', cleanup=True) +# server_url = os.environ.get("GPTS_API_SERVER") +# if server_url.endswith("/"): +# server_url = server_url[:-1] +# return RestResult(code=0, msg="success", result=dict(data=f"{server_url}/assets/{fileuuid}.png")) +# except Exception as e: +# log.error(f"create_mindmap error: {e}") +# raise HTTPException(status_code=500, detail=str(e)) + + +def create_mindmap_task(task): try: - log.info(f"create_mindmap: {item}") + log.info(f"generate_mindmap params: {task['content']}") + airesp = create_mindma_data_by_openai(task['content']) + log.info(f"generate_mindmap result: {airesp}") # 创建并构建思维导图 + # 将 JSON 字符串转换为 Python 字典 + data = json.loads(airesp) + # 使用 model_validate 方法创建 MindmapItem 实例 + item = MindmapItem.model_validate(data) graph = Digraph(comment=item.title, engine="sfdp") - graph.attr(splines='curved', overlap='false', margin='0.4') # 设置图的大小为A4纸尺寸 + graph.attr(splines='curved') build_mind_map(graph, item.title, None, structure=item.structure) - fileuuid = str(uuid.uuid4()) - graph.render(os.path.join(DATA_DIR, fileuuid), format='png', cleanup=True) + fileuuid = task["id"] + output_path = os.path.join(DATA_DIR, fileuuid) + graph.render(output_path, format='png', cleanup=True) + # 生成 DOT 文件 + dot_path = output_path + ".dot" + with open(dot_path, "w") as dot_file: + dot_file.write(graph.source) + server_url = os.environ.get("GPTS_API_SERVER") if server_url.endswith("/"): server_url = server_url[:-1] - return RestResult(code=0, msg="success", result=dict(data=f"{server_url}/assets/{fileuuid}.png")) + task["status"] = "done" + task["remark"] = "The task is completed, please access the URL information!" + task["image_url"] = f"{server_url}/assets/{fileuuid}.png" + task["dot_url"] = f"{server_url}/assets/{fileuuid}.dot" + with open(os.path.join(DATA_DIR, f"{fileuuid}.json"), "w") as f: + data = json.dumps(task) + f.write(data) except Exception as e: - log.error(f"create_mindmap error: {e}") - raise HTTPException(status_code=500, detail=str(e)) + import traceback + traceback.print_exc() -@app.get("/knowledge/mindmap/generate", summary="Create a knowledge base mindmap from query content", - description="Generating mind maps from given content") -async def generate_mindmap(content: str = Query(...), td: bool = Depends(verify_api_key)): +# @app.get("/knowledge/mindmap/generate", summary="Create a knowledge base mindmap from query content", +# description="Generating mind maps from given content") +# async def generate_mindmap(content: str = Query(...), td: bool = Depends(verify_api_key)): +# try: +# log.info(f"generate_mindmap params: {content}") +# airesp = await create_mindma_data_by_openai(content) +# log.info(f"generate_mindmap result: {airesp}") +# # 创建并构建思维导图 +# # 将 JSON 字符串转换为 Python 字典 +# data = json.loads(airesp) +# # 使用 model_validate 方法创建 MindmapItem 实例 +# item = MindmapItem.model_validate(data) +# graph = Digraph(comment=item.title, engine="sfdp") +# graph.attr(splines='curved') +# build_mind_map(graph, item.title, None, structure=item.structure) +# fileuuid = str(uuid.uuid4()) +# graph.render(os.path.join(DATA_DIR, fileuuid), format='png', cleanup=True) +# server_url = os.environ.get("GPTS_API_SERVER") +# if server_url.endswith("/"): +# server_url = server_url[:-1] +# return RestResult(code=0, msg="success", result=dict(data=f"{server_url}/assets/{fileuuid}.png")) +# except Exception as e: +# log.error(f"generate_mindmap error: {e}") +# raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/knowledge/mindmap/task/add", summary="Create a mindmap generate task from content", + description="Create a mindmap generate task from content") +async def generate_mindmap_task_add(background_tasks: BackgroundTasks, content: str = Query(...), + td: bool = Depends(verify_api_key)): try: - log.info(f"generate_mindmap params: {content}") - airesp = create_mindma_data_by_openai(content) - log.info(f"generate_mindmap result: {airesp}") - # 创建并构建思维导图 - # 将 JSON 字符串转换为 Python 字典 - data = json.loads(airesp) - # 使用 model_validate 方法创建 MindmapItem 实例 - item = MindmapItem.model_validate(data) - graph = Digraph(comment=item.title, engine="sfdp") - graph.attr(splines='curved') - build_mind_map(graph, item.title, None, structure=item.structure) - fileuuid = str(uuid.uuid4()) - graph.render(os.path.join(DATA_DIR, fileuuid), format='png', cleanup=True) + log.info(f"generate_mindmap_task_add params: {content}") server_url = os.environ.get("GPTS_API_SERVER") if server_url.endswith("/"): server_url = server_url[:-1] - return RestResult(code=0, msg="success", result=dict(data=f"{server_url}/assets/{fileuuid}.png")) + taskid = str(uuid.uuid4()) + task = dict( + id=taskid, + content=content, + status="pending", + image_url=f"{server_url}/assets/{taskid}.png", + dot_url=f"{server_url}/assets/{taskid}.dot", + status_url=f"{server_url}/knowledge/mindmap/task/result/{taskid}", + remark="The task is being processed, please remember to save the URL information and revisit it later!", + ) + with open(os.path.join(DATA_DIR, f"{taskid}.json"), "w") as f: + data = json.dumps(task) + f.write(data) + background_tasks.add_task(create_mindmap_task, task) + log.info(f"generate_mindmap_task_add result: {task}") + return RestResult(code=0, msg="task add success", result=dict(data=task)) except Exception as e: log.error(f"generate_mindmap error: {e}") raise HTTPException(status_code=500, detail=str(e)) +@app.get("/knowledge/mindmap/task/result/{taskid}", summary="Get the mindmap generate task result", + description="Get the mindmap generate task result") +async def generate_mindmap_task_result(taskid: str): + if not re.match(r'^[\w-]+$', taskid): + raise HTTPException(status_code=400, detail="Invalid task ID format") + + file_path = os.path.join(DATA_DIR, f"{taskid}.json") + if not os.path.isfile(file_path): + raise HTTPException(status_code=404, detail="File not found") + return FileResponse(file_path) + + if __name__ == "__main__": import uvicorn diff --git a/tests/test.http b/tests/test.http index 1005bfa..2582b1c 100644 --- a/tests/test.http +++ b/tests/test.http @@ -66,7 +66,7 @@ Authorization: Bearer a99e05501a0405531caf783eef419b56a5a32f57b64ae3b89587b3a0d5 ### -GET http://127.0.0.1:8700/knowledge/mindmap/generate?content=根据Python知识创建一个思维导图,最多60节点 +GET http://127.0.0.1:8700/knowledge/mindmap/task/add?content=根据C++基础创建一个思维导图 Accept: application/json Content-Type: application/json Authorization: Bearer a99e05501a0405531caf783eef419b56a5a32f57b64ae3b89587b3a0d5202ee167d80d727a1b8181 diff --git a/tools/generate_mindmap.py b/tools/generate_mindmap.py index 411f7ae..f12eae6 100644 --- a/tools/generate_mindmap.py +++ b/tools/generate_mindmap.py @@ -6,9 +6,9 @@ from main import MindmapItem -def generate_mindmap(): +async def generate_mindmap(): try: - airesp = create_mindma_data_by_openai("根据微积分基础整理一个学习计划思维导图") + airesp = await create_mindma_data_by_openai("根据微积分基础整理一个学习计划思维导图") # 创建并构建思维导图 data = json.loads(airesp) item = MindmapItem.model_validate(data)