From 006ad44d2983aade26d399bca6750b6026281d69 Mon Sep 17 00:00:00 2001 From: Dongyanmio Date: Sun, 27 Oct 2024 10:05:20 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=9C=AC=E5=9C=B0?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refactor: 使用内存存放文件列表 --- .gitignore | 1 + core/__init__.py | 91 ++++++++++++++++++++-------- core/filesdb.py | 120 ------------------------------------- core/routes/openbmclapi.py | 23 +++---- core/routes/services.py | 16 +++-- core/types.py | 56 +++++++++++++++-- core/utils.py | 34 ++++++++++- test.py | 50 ++++++++++++---- 8 files changed, 209 insertions(+), 182 deletions(-) delete mode 100644 core/filesdb.py diff --git a/.gitignore b/.gitignore index 7f45a70..ca5e2ac 100644 --- a/.gitignore +++ b/.gitignore @@ -163,6 +163,7 @@ cython_debug/ # iodine-at-home config.yml +logs/ data/ files/ plugins/ \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py index a4f624d..266511f 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -5,6 +5,7 @@ import asyncio import uvicorn import importlib +from pathlib import Path from pluginbase import PluginBase from fastapi import FastAPI, Response from datetime import datetime, timezone @@ -23,8 +24,7 @@ import core.utils as utils from core.logger import logger from core.config import config -from core.types import Cluster, oclm -from core.filesdb import FilesDB +from core.types import Cluster, oclm, filesdb from core.dns.cloudflare import cf_client # 路由库 @@ -33,17 +33,17 @@ from core.routes.services import router as services_router from core.routes.api.v0 import router as api_v0_router + # 网页部分 @asynccontextmanager async def lifespan(app: FastAPI): - logger.info( - f"正在 {config.get('host')}:{config.get('port')} 上监听服务器..." - ) + init_filelist() + logger.info(filesdb.url_list) + logger.info(f"正在 {config.get('host')}:{config.get('port')} 上监听服务器...") yield - async with FilesDB() as db: - await db.close() logger.success("主控退出成功。") + app = FastAPI( title="iodine@home", summary="开源的文件分发主控,并尝试兼容 OpenBMCLAPI 客户端", @@ -52,7 +52,7 @@ async def lifespan(app: FastAPI): "name": "The MIT License", "url": "https://raw.githubusercontent.com/ZeroNexis/iodine-at-home/main/LICENSE", }, - lifespan=lifespan + lifespan=lifespan, ) app.include_router(agent_router, prefix="/openbmclapi-agent") @@ -69,6 +69,7 @@ async def lifespan(app: FastAPI): allow_headers=["*"], ) + # 插件部分 async def load_plugins(): global app @@ -81,7 +82,9 @@ async def load_plugins(): if hasattr(plugin, "__API__") and plugin.__API__: if hasattr(plugin, "router"): app.include_router(plugin.router, prefix=f"/{plugin.__NAMESPACE__}") - logger.success(f"已注册插件 API 路由:{plugin.__NAMESPACE__}, {plugin.router.routes}") + logger.success( + f"已注册插件 API 路由:{plugin.__NAMESPACE__}, {plugin.router.routes}" + ) else: logger.warning( f"插件「{plugin.__NAME__}」未定义 Router ,无法加载该插件的路径!" @@ -93,21 +96,23 @@ async def load_plugins(): sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") socket = ASGIApp(sio) + # 核心功能 @app.middleware("http") -async def _(request,call_next): +async def _(request, call_next): start_time = datetime.now() response = await call_next(request) process_time = (datetime.now() - start_time).total_seconds() - response_size = len(response.body) if hasattr(response, 'body') else 0 - referer = request.headers.get('Referer') - user_agent = request.headers.get('user-agent', '-') + response_size = len(response.body) if hasattr(response, "body") else 0 + referer = request.headers.get("Referer") + user_agent = request.headers.get("user-agent", "-") logger.info( f"Serve {response.status_code} | {process_time:.2f}s | {response_size}B | " - f"{request.client.host} | {request.method} | {request.url.path} | \"{user_agent}\" | \"{referer}\"" + f'{request.client.host} | {request.method} | {request.url.path} | "{user_agent}" | "{referer}"' ) return response + ## 节点端连接时 @sio.on("connect") async def on_connect(sid, *args): @@ -162,36 +167,51 @@ async def on_cluster_request_cert(sid, *args): if cluster_is_exist == False: return [{"message": "错误: 节点似乎并不存在,请检查配置文件"}] logger.debug(f"节点 {cluster.id} 请求证书") - if cluster.cert_fullchain != "" and cluster.cert_privkey != "" and cluster.cert_expiry != "" and cluster.cert_expiry > datetime.now(pytz.utc).strftime('%Y-%m-%dT%H:%M:%S+00:00'): + if ( + cluster.cert_fullchain != "" + and cluster.cert_privkey != "" + and cluster.cert_expiry != "" + and cluster.cert_expiry + > datetime.now(pytz.utc).strftime("%Y-%m-%dT%H:%M:%S+00:00") + ): return [ - None, { + None, + { "_id": cluster.id, "clusterId": cluster.id, "cert": cluster.cert_fullchain, "key": cluster.cert_privkey, "expires": cluster.cert_expiry, - "__v": 0 - } + "__v": 0, + }, ] else: - cert, key = await cf_client.get_certificate(f"{cluster.id}.{config.get('cluster-certificate.domain')}") + cert, key = await cf_client.get_certificate( + f"{cluster.id}.{config.get('cluster-certificate.domain')}" + ) if cert == None or key == None: return [{"message": "错误: 证书获取失败,请重新尝试。"}] current_time = datetime.now(pytz.utc) future_time = current_time + relativedelta(months=3) - formatted_time = future_time.astimezone(pytz.utc).strftime('%Y-%m-%dT%H:%M:%S+00:00') - await cluster.edit(cert_fullchain=cert, cert_privkey=key, cert_expiry=formatted_time) + formatted_time = future_time.astimezone(pytz.utc).strftime( + "%Y-%m-%dT%H:%M:%S+00:00" + ) + await cluster.edit( + cert_fullchain=cert, cert_privkey=key, cert_expiry=formatted_time + ) return [ - None, { + None, + { "_id": cluster.id, "clusterId": cluster.id, "cert": cluster.cert_fullchain, "key": cluster.cert_privkey, "expires": cluster.cert_expiry, - "__v": 0 - } + "__v": 0, + }, ] + ## 节点启动时 @sio.on("enable") async def on_cluster_enable(sid, data: dict, *args): @@ -250,7 +270,8 @@ async def on_cluster_enable(sid, data: dict, *args): else: logger.debug(f"{cluster.id} 测速失败: {bandwidth[1]}") return [{"message": f"错误: {bandwidth[1]}"}] - + + ## 节点保活时 @sio.on("keep-alive") async def on_cluster_keep_alive(sid, data, *args): @@ -265,6 +286,7 @@ async def on_cluster_keep_alive(sid, data, *args): ) return [None, datetime.now(timezone.utc).isoformat()] + @sio.on("disable") ## 节点禁用时 async def on_cluster_disable(sid, *args): session = await sio.get_session(sid) @@ -280,12 +302,29 @@ async def on_cluster_disable(sid, *args): logger.debug(f"节点 {cluster.id} 尝试禁用集群失败: 节点没有启用") return [None, True] + +def init_filelist(): + filelist = utils.scan_files(Path("./files/")) + for file in filelist: + hash = utils.get_file_hash(f"./{file}") + size = utils.get_file_size(f"./{file}") + mtime = utils.get_file_mtime(f"./{file}") + filesdb.append(hash=hash, url=f"{file}", size=size, mtime=mtime) + + def init(): + Path("./files/").mkdir(exist_ok=True) logger.clear() logger.info("加载中……") try: asyncio.run(load_plugins()) app.mount("/", socket) - uvicorn.run(app, host=config.get('host'), port=config.get(path='port'), log_level='warning', access_log=False) + uvicorn.run( + app, + host=config.get("host"), + port=config.get(path="port"), + log_level="warning", + access_log=False, + ) except Exception as e: logger.error(e) diff --git a/core/filesdb.py b/core/filesdb.py deleted file mode 100644 index ee6a825..0000000 --- a/core/filesdb.py +++ /dev/null @@ -1,120 +0,0 @@ -import os -import atexit -import asyncio -import aiosqlite - - -class FilesDB: - def __init__(self): - self.conn = None - self.cursor = None - - async def __aenter__(self): - await self.connect() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - async def connect(self): - if not os.path.exists("./data/database.db"): - raise FileNotFoundError("数据库文件不存在") - self.conn = await aiosqlite.connect("./data/database.db") - self.cursor = await self.conn.cursor() - - async def close(self): - if self.conn: - await self.conn.close() - self.conn = None - self.cursor = None - - async def create_table(self): - await self.conn.execute( - """ - CREATE TABLE IF NOT EXISTS FILELIST - ( - HASH TEXT PRIMARY KEY, - PATH TEXT, - URL TEXT, - SIZE INTEGER, - MTIME INTEGER, - SOURCE TEXT - ) - """ - ) - await self.conn.commit() - await self.close() - - async def new_file( - self, - hash: str, - path: str = "", - url: str = "", - size: int = 0, - mtime: int = 0, - source: str = "local", - ): - await self.connect() - await self.conn.execute( - """ - INSERT INTO FILELIST ( - HASH, - PATH, - URL, - SIZE, - MTIME, - SOURCE - ) VALUES (?, ?, ?, ?, ?, ?) - """, - ( - hash, - path, - url, - size, - mtime, - source, - ), - ) - await self.conn.commit() - return True - - async def delete_file(self, hash: str): - await self.conn.execute( - """ - DELETE FROM FILELIST WHERE HASH = ? - """, - (hash,), - ) - await self.conn.commit() - return True - - async def delete_all(self): - await self.conn.execute( - """ - DELETE FROM FILELIST - """, - ) - await self.conn.commit() - return True - - async def find_one(self, key: str, value: str): - async with self.conn.execute( - f""" - SELECT * FROM FILELIST WHERE {key} = ? - """, - (value,), - ) as cursor: - row = await cursor.fetchone() - if row: - columns = [desc[0] for desc in cursor.description] - result = dict(zip(columns, row)) - else: - result = False - return result - - async def get_all(self): - async with self.conn.execute("SELECT * FROM FILELIST") as cursor: - rows = await cursor.fetchall() - columns = [desc[0] for desc in cursor.description] - result = [dict(zip(columns, row)) for row in rows] - return result \ No newline at end of file diff --git a/core/routes/openbmclapi.py b/core/routes/openbmclapi.py index 2d25e49..df418b3 100644 --- a/core/routes/openbmclapi.py +++ b/core/routes/openbmclapi.py @@ -5,9 +5,8 @@ from fastapi.responses import FileResponse, HTMLResponse, PlainTextResponse # 本地库 -from core.types import Avro from core.logger import logger -from core.filesdb import FilesDB +from core.types import Avro, filesdb router = APIRouter() @@ -21,15 +20,13 @@ def get_configuration(response: Response): @router.get("/files", summary="文件列表", tags=["nodes"]) async def get_filesList(): - async with FilesDB() as filesdb: - filelist = await filesdb.get_all() avro = Avro() - avro.writeVarInt(len(filelist)) # 写入文件数量 - for file in filelist: - avro.writeString(f"/{file['SOURCE']}/{file['HASH']}") # 路径 - avro.writeString(file['HASH']) # 哈希 - avro.writeVarInt(file['SIZE']) # 文件大小 - avro.writeVarInt(file['MTIME']) # 修改时间 + avro.writeVarInt(len(filesdb.hash_list)) # 写入文件数量 + for i in range(len(filesdb.hash_list)): + avro.writeString(f"/{filesdb.url_list[i]}") # 路径 + avro.writeString(filesdb.hash_list[i]) # 哈希 + avro.writeVarInt(filesdb.size_list[i]) # 文件大小 + avro.writeVarInt(filesdb.mtime_list[i]) # 修改时间 avro.write(b"\x00") result = pyzstd.compress(avro.io.getvalue()) avro.io.close() @@ -38,7 +35,11 @@ async def get_filesList(): @router.get("/download/{hash}", summary="应急同步", tags=["nodes"]) async def download_file_from_ctrl(hash: str): - raise HTTPException(404, detail="未找到该文件") + filedata = await filesdb.find(hash) + if filedata: + return FileResponse(f"./{filedata['PATH']}") + else: + raise HTTPException(404, detail="未找到该文件") @router.post("/report", summary="上报异常", tags=["nodes"]) diff --git a/core/routes/services.py b/core/routes/services.py index 562b9f8..2a16146 100644 --- a/core/routes/services.py +++ b/core/routes/services.py @@ -6,27 +6,25 @@ # 本地库 import core.utils as utils -from core.types import oclm, Cluster +from core.types import oclm, Cluster, filesdb from core.logger import logger -from core.filesdb import FilesDB router = APIRouter() -@router.get("/files/{path}", summary="通过 PATH 下载普通文件", tags=["public"]) -async def download_path_file(hash: str): - async with FilesDB() as filesdb: - filedata = await filesdb.find_one("PATH", hash) +@router.get("/files/{path:path}", summary="通过 PATH 下载普通文件", tags=["public"]) +async def download_path_file(path: str): + filedata = filesdb.find(None, f"files/{path}") if filedata: if len(oclm) == 0: - return RedirectResponse(filedata["URL"], 302) + return FileResponse(Path(f"./{filedata['url']}")) else: cluster = Cluster(oclm.random()) await cluster.initialize() - sign = utils.get_sign(filedata["HASH"], cluster.secret) + sign = utils.get_sign(filedata['hash'], cluster.secret) url = utils.get_url( - cluster.host, cluster.port, f"/download/{filedata['HASH']}", sign + cluster.host, cluster.port, f"/download/{filedata['hash']}", sign ) return RedirectResponse(url, 302) else: diff --git a/core/types.py b/core/types.py index 7fea504..7f30e00 100644 --- a/core/types.py +++ b/core/types.py @@ -50,7 +50,7 @@ async def edit( runtime: str = None, cert_fullchain: str = None, cert_privkey: str = None, - cert_expiry: str = None + cert_expiry: str = None, ): result = await cdb.edit_cluster( self.id, @@ -67,7 +67,7 @@ async def edit( runtime, cert_fullchain, cert_privkey, - cert_expiry + cert_expiry, ) if result: await self.initialize() @@ -114,15 +114,63 @@ def update(self, cluster_id: str, weight: float): def include(self, cluster_id: str): return cluster_id in self.id_list - + def random(self) -> str: return choices(self.id_list, self.weight_list)[0] - oclm = OCLManager() +class FilesDB: + def __init__(self): + self.hash_list = [] + self.size_list = [] + self.mtime_list = [] + self.url_list = [] + + def append(self, hash: str, size: int, mtime: int, url: str): + if hash not in self.hash_list: + self.hash_list.append(hash) + self.size_list.append(size) + self.mtime_list.append(mtime) + self.url_list.append(url) + + def remove(self, hash: str): + if hash in self.hash_list: + self.size_list.remove(self.size_list[self.hash_list.index(hash)]) + self.mtime_list.remove(self.mtime_list[self.hash_list.index(hash)]) + self.url_list.remove(self.url_list[self.hash_list.index(hash)]) + self.hash_list.remove(hash) + + def find(self, hash: str | None = None, url: str | None = None): + if hash is not None: + if hash in self.hash_list: + return { + "hash": hash, + "size": self.size_list[self.hash_list.index(hash)], + "mtime": self.mtime_list[self.hash_list.index(hash)], + "url": self.url_list[self.hash_list.index(hash)], + } + else: + return None + elif url is not None: + if url in self.url_list: + return { + "hash": self.hash_list[self.url_list.index(url)], + "size": self.size_list[self.url_list.index(url)], + "mtime": self.mtime_list[self.url_list.index(url)], + "url": url, + } + else: + return None + else: + return None + + +filesdb = FilesDB() + + # 本段修改自 TTB-Network/python-openbmclapi 中部分代码 # 仓库链接: https://github.com/TTB-Network/python-openbmclapi # 源代码使用 MIT License 协议开源 | Copyright (c) 2024 TTB-Network diff --git a/core/utils.py b/core/utils.py index a790090..e53d716 100644 --- a/core/utils.py +++ b/core/utils.py @@ -1,4 +1,5 @@ # 第三方库 +import os import jwt import time import httpx @@ -83,7 +84,9 @@ async def measure_cluster(size: int, cluster: Cluster): try: start_time = time.time() async with httpx.AsyncClient() as client: - response = await client.get(url, headers={"User-Agent": const.user_agent}, timeout=10) + response = await client.get( + url, headers={"User-Agent": const.user_agent}, timeout=10 + ) end_time = time.time() elapsed_time = end_time - start_time # 计算测速时间 @@ -91,3 +94,32 @@ async def measure_cluster(size: int, cluster: Cluster): return [True, bandwidth] except Exception as e: return [False, e] + + +# 遍历指定目录及其子目录中的所有文件 +def scan_files(directory): + result = [] + for root, dirs, files in os.walk(directory): + dirs[:] = [d for d in dirs if not d.startswith('.')] + for file in files: + if not file.startswith('.'): + # 打印文件的完整路径 + path = str(os.path.join(root, file)) + result.append(path.replace("\\", "/")) + return result + + +def get_file_mtime(file_path): + return int(os.path.getmtime(file_path)) + + +def get_file_size(file_path): + return os.path.getsize(file_path) + + +def get_file_hash(file_path): + sha1_hash = hashlib.sha1() + with open(file_path, "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha1_hash.update(byte_block) + return sha1_hash.hexdigest() diff --git a/test.py b/test.py index f915e47..bb60bee 100644 --- a/test.py +++ b/test.py @@ -1,15 +1,43 @@ -import pytz -from datetime import datetime, timedelta -from dateutil.relativedelta import relativedelta +import tqdm -# 获取当前时间 -current_time = datetime.now(pytz.utc) +class FilesDB: + def __init__(self): + self.hash_list = [] + self.size_list = [] + self.mtime_list = [] + self.url_list = [] -# 加上三个月的时间 -future_time = current_time + relativedelta(months=3) + def append(self, hash: str, size: int, mtime: int, url: str): + if hash not in self.hash_list: + self.hash_list.append(hash) + self.size_list.append(size) + self.mtime_list.append(mtime) + self.url_list.append(url) -# 将时间转换为指定格式的字符串 -formatted_time = future_time.astimezone(pytz.utc).strftime('%Y-%m-%dT%H:%M:%S+00:00') + def remove(self, hash: str): + if hash in self.hash_list: + self.size_list.remove(self.size_list[self.hash_list.index(hash)]) + self.mtime_list.remove(self.mtime_list[self.hash_list.index(hash)]) + self.url_list.remove(self.url_list[self.hash_list.index(hash)]) + self.hash_list.remove(hash) -# 输出结果 -print(formatted_time) \ No newline at end of file + def find(self, hash: str | None = None, url: str | None = None): + if hash is not None: + if hash in self.hash_list: + return hash, self.size_list[self.hash_list.index(hash)], self.mtime_list[self.hash_list.index(hash)], self.url_list[self.hash_list.index(hash)], + else: + return None, None, None, None + elif url is not None: + if url in self.url_list: + return self.size_list[self.url_list.index(url)], self.mtime_list[self.url_list.index(url)], self.hash_list[self.url_list.index(url)], url + else: + return None, None, None, None + else: + return None, None, None, None + +filesdb = FilesDB() + +for i in tqdm.tqdm(range(1000)): + filesdb.append(f"hash{i}", i, i, f"url{i}") + +print(filesdb.find(url="url500")) \ No newline at end of file