diff --git a/app/core/config.py b/app/core/config.py index cf2082bd..9c70740d 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,20 +1,26 @@ +import os import secrets import sys import threading from pathlib import Path -from typing import Optional, List +from typing import Optional, List, Any, Type, Tuple from dotenv import set_key -from pydantic import BaseSettings, validator +from pydantic import BaseSettings, validator, BaseModel from app.log import logger from app.utils.system import SystemUtils +from app.utils.url import UrlUtils -class Settings(BaseSettings): +class ConfigModel(BaseModel): """ - 系统配置类 + Pydantic 配置模型,描述所有配置项及其类型和默认值 """ + + class Config: + extra = "ignore" # 忽略未定义的配置项 + # 项目名称 PROJECT_NAME = "MoviePilot" # 域名 格式;https://movie-pilot.org @@ -187,39 +193,109 @@ class Settings(BaseSettings): # 全局图片缓存,将媒体图片缓存到本地 GLOBAL_IMAGE_CACHE: bool = False - @validator("SUBSCRIBE_RSS_INTERVAL", - "COOKIECLOUD_INTERVAL", - "META_CACHE_EXPIRE", - pre=True, always=True) - def convert_int(cls, value): - if not value: - return 0 - try: - return int(value) - except (ValueError, TypeError): - raise ValueError(f"{value} 格式错误,不是有效数字!") - - @validator("COOKIECLOUD_ENABLE_LOCAL", - "SUBSCRIBE_SEARCH", - pre=True, always=True) - def convert_boolean(cls, value): - if not value: - return False - if str(value).upper() == "FALSE": - return False - return True + +class Settings(BaseSettings, ConfigModel): + """ + 系统配置类 + """ + + class Config: + case_sensitive = True + env_file = SystemUtils.get_env_path() + env_file_encoding = "utf-8" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # 初始化配置目录及子目录 + for path in [self.CONFIG_PATH, self.TEMP_PATH, self.LOG_PATH, self.COOKIE_PATH]: + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + # 如果是二进制程序,确保配置文件存在 + if SystemUtils.is_frozen(): + app_env_path = self.CONFIG_PATH / "app.env" + if not app_env_path.exists(): + SystemUtils.copy(self.INNER_CONFIG_PATH / "app.env", app_env_path) @validator("API_TOKEN", pre=True, always=True) def validate_api_token(cls, v): if not v: new_token = secrets.token_urlsafe(16) - logger.info(f"【API_TOKEN】未设置,已随机生成新的 API_TOKEN:{new_token}") + logger.info(f"'API_TOKEN' 未设置,已随机生成新的 API_TOKEN:{new_token}") set_key(str(SystemUtils.get_env_path()), "API_TOKEN", new_token) return new_token elif len(v) < 16: - logger.warning("API_TOKEN 长度不足 16 个字符,存在安全隐患,建议尽快更换为更复杂的密钥!") + logger.warning("'API_TOKEN' 长度不足 16 个字符,存在安全隐患,建议尽快更换为更复杂的密钥!") return v + @staticmethod + def generic_type_converter(value: Any, expected_type: Type, default: Any, field_name: str) -> Tuple[Any, bool]: + """ + 通用类型转换函数,根据预期类型转换值。如果转换失败,返回默认值 """ + if value is None: + return default, False + + original_value = value + if isinstance(value, str): + value = value.strip() + + try: + if expected_type is bool: + if isinstance(value, bool): + return value, False + if isinstance(value, str): + value_clean = value.lower() + bool_map = { + "false": False, "no": False, "0": False, "off": False, + "true": True, "yes": True, "1": True, "on": True + } + if value_clean in bool_map: + return bool_map[value_clean], value_clean != original_value.lower() + elif isinstance(value, (int, float)): + return bool(value), False + return default, False + elif expected_type is int: + if isinstance(value, int): + return value, False + if isinstance(value, str): + converted = int(value) + return converted, value != original_value + elif expected_type is float: + if isinstance(value, float): + return value, False + if isinstance(value, str): + converted = float(value) + return converted, value != original_value + elif expected_type is str: + return value, value != original_value + # # 后续考虑支持 list 类型的处理 + # elif expected_type is list: + # if isinstance(value, list): + # return value, False + # if isinstance(value, str): + # items = [item.strip() for item in value.split(",") if item.strip()] + # return items, items != original_value.split(",") + # 可根据需要添加更多类型处理 + else: + return value, False + except (ValueError, TypeError): + return default, True + + @validator('*', pre=True, always=True) + def generic_type_validator(cls, value: Any, field): + """ + 通用校验器 + """ + converted_value, needs_update = cls.generic_type_converter(value, field.type_, field.default, field.name) + if needs_update: + logger.error(f"字段 '{field.name}' 的值 '{value}' 无效,已使用 '{converted_value}' 进行替换") + if field.name in os.environ: + logger.warning(f"字段 '{field.name}' 已存在于环境变量中,请手动修改") + else: + set_key(SystemUtils.get_env_path(), field.name, + str(converted_value) if converted_value is not None else "") + logger.info(f"字段 '{field.name}' 已由应用修改并写入到 app.env 中") + return converted_value + @property def VERSION_FLAG(self) -> str: """ @@ -354,35 +430,7 @@ def VAPID(self): def MP_DOMAIN(self, url: str = None): if not self.APP_DOMAIN: return None - domain = self.APP_DOMAIN.rstrip("/") - if not domain.startswith("http"): - domain = "http://" + domain - if not url: - return domain - return domain + "/" + url.lstrip("/") - - def __init__(self, **kwargs): - super().__init__(**kwargs) - with self.CONFIG_PATH as p: - if not p.exists(): - p.mkdir(parents=True, exist_ok=True) - if SystemUtils.is_frozen(): - if not (p / "app.env").exists(): - SystemUtils.copy(self.INNER_CONFIG_PATH / "app.env", p / "app.env") - with self.TEMP_PATH as p: - if not p.exists(): - p.mkdir(parents=True, exist_ok=True) - with self.LOG_PATH as p: - if not p.exists(): - p.mkdir(parents=True, exist_ok=True) - with self.COOKIE_PATH as p: - if not p.exists(): - p.mkdir(parents=True, exist_ok=True) - - class Config: - case_sensitive = True - env_file = SystemUtils.get_env_path() - env_file_encoding = "utf-8" + return UrlUtils.combine_url(host=self.APP_DOMAIN, path=url) class GlobalVar(object):