Skip to content

Commit

Permalink
Merge pull request #2855 from InfinityPacer/feature/event
Browse files Browse the repository at this point in the history
  • Loading branch information
jxxghp authored Oct 15, 2024
2 parents e99913f + 9548409 commit 541a3d6
Showing 1 changed file with 59 additions and 25 deletions.
84 changes: 59 additions & 25 deletions app/core/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import time
import traceback
import uuid
from queue import PriorityQueue, Empty
from typing import Callable, Dict, List, Union, Optional
from functools import lru_cache
from queue import Empty, PriorityQueue
from typing import Callable, Dict, List, Optional, Union

from app.helper.message import MessageHelper
from app.helper.thread import ThreadHelper
from app.log import logger
from app.schemas.types import EventType, ChainEventType
from app.schemas.types import ChainEventType, EventType
from app.utils.limit import ExponentialBackoffRateLimiter
from app.utils.singleton import Singleton

Expand Down Expand Up @@ -254,24 +255,46 @@ def visualize_handlers(self) -> List[Dict]:
handler_info.append(handler_dict)
return handler_info

@staticmethod
def __get_handler_identifier(target: Union[Callable, type]) -> str:
@classmethod
@lru_cache(maxsize=1000)
def __get_handler_identifier(cls, target: Union[Callable, type]) -> Optional[str]:
"""
获取处理器或处理器类的唯一标识符,包括模块名和类名
获取处理器或处理器类的唯一标识符,包括模块名和类名/方法名
:param target: 处理器函数或类
:return: 唯一标识符
"""
if isinstance(target, type):
# 如果是类,使用模块名和类名
module_name = target.__module__
class_name = target.__qualname__
return f"{module_name}.{class_name}"
else:
# 如果是函数或方法,使用 inspect.getmodule 来获取模块名
module = inspect.getmodule(target)
# 统一使用 inspect.getmodule 来获取模块名
module = inspect.getmodule(target)
module_name = module.__name__ if module else "unknown_module"

# 使用 __qualname__ 获取目标的限定名
qualname = target.__qualname__
return f"{module_name}.{qualname}"

@classmethod
@lru_cache(maxsize=1000)
def __get_class_from_callable(cls, handler: Callable) -> Optional[str]:
"""
获取可调用对象所属类的唯一标识符
:param handler: 可调用对象(函数、方法等)
:return: 类的唯一标识符
"""
# 对于绑定方法,通过 __self__.__class__ 获取类
if inspect.ismethod(handler) and hasattr(handler, "__self__"):
return cls.__get_handler_identifier(handler.__self__.__class__)

# 对于类实例(实现了 __call__ 方法)
if not inspect.isfunction(handler) and hasattr(handler, "__call__"):
handler_cls = handler.__class__
return cls.__get_handler_identifier(handler_cls)

# 对于未绑定方法、静态方法、类方法,使用 __qualname__ 提取类信息
qualname_parts = handler.__qualname__.split(".")
if len(qualname_parts) > 1:
class_name = ".".join(qualname_parts[:-1])
module = inspect.getmodule(handler)
module_name = module.__name__ if module else "unknown_module"
qualname = target.__qualname__
return f"{module_name}.{qualname}"
return f"{module_name}.{class_name}"

def __is_handler_enabled(self, handler: Callable) -> bool:
"""
Expand All @@ -283,7 +306,7 @@ def __is_handler_enabled(self, handler: Callable) -> bool:
handler_id = self.__get_handler_identifier(handler)

# 获取处理器所属类的唯一标识符
class_id = self.__get_handler_identifier(handler.__self__.__class__) if hasattr(handler, '__self__') else None
class_id = self.__get_class_from_callable(handler)

# 检查处理器或类是否被禁用,只要其中之一被禁用则返回 False
if handler_id in self.__disabled_handlers or (class_id is not None and class_id in self.__disabled_classes):
Expand Down Expand Up @@ -386,18 +409,29 @@ def __get_class_instance(class_name: str):
"""
# 检查类是否在全局变量中
if class_name in globals():
class_obj = globals()[class_name]()
else:
# 如果类不在全局变量中,尝试动态导入模块并创建实例
# 导入模块,除了插件和Command,只有chain能响应事件
try:
module = importlib.import_module(f"app.chain.{class_name[:-5].lower()}")
class_obj = getattr(module, class_name)()
class_obj = globals()[class_name]()
return class_obj
except Exception as e:
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
logger.error(f"事件处理出错:创建全局类实例出错:{str(e)} - {traceback.format_exc()}")
return None

return class_obj
# 如果类不在全局变量中,尝试动态导入模块并创建实例
try:
# 导入模块,除了插件和Command,只有chain能响应事件
if not class_name.endswith("Chain"):
logger.debug(f"事件处理出错:无效的 Chain 类名: {class_name},类名必须以 'Chain' 结尾")
return None
module_name = f"app.chain.{class_name[:-5].lower()}"
module = importlib.import_module(module_name)
if hasattr(module, class_name):
class_obj = getattr(module, class_name)()
return class_obj
else:
logger.debug(f"事件处理出错:模块 {module_name} 中没有找到类 {class_name}")
except Exception as e:
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
return None

def __broadcast_consumer_loop(self):
"""
Expand Down

0 comments on commit 541a3d6

Please sign in to comment.