-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Reranker: setence-sf and infinity, and some bugs fixed #306
base: main
Are you sure you want to change the base?
Conversation
|
||
|
||
@Reranker.register_reranker(batch=True) | ||
def ModuleReranker( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么从注册列表里面去掉了,而是变成了if self._name == "ModuleReranker":。我个人感觉可能是想在Rerank进行update的时候,把TrainableModule也Update了。
我在想Reranker可以支持对类进行注册。在__new__的时候,如果发现name: str = "ModuleReranker"在注册表中且是个class,则直接返回注册的类的实例化对象,看看能不能满足需求
@classmethod
def register_reranker(
cls: "Reranker", func: Optional[Callable] = None, batch: bool = False
):
def decorator(f):
if isinstance(f, type):
if batch: raise NotImplementedError('...')
cls.registered_reranker[f.__name__] = f
return f
else:
def wrapper(nodes, **kwargs):
if batch:
return f(nodes, **kwargs)
else:
results = [f(node, **kwargs) for node in nodes]
return [result for result in results if result]
cls.registered_reranker[f.__name__] = wrapper
return wrapper
return decorator(func) if func else decorator
class Reranker(ModuleBase, _PostProcess):
pass
@Reranker.register_reranker(batch=True)
class ModuleReranker(Reranker):
def __init__(self, ...):
pass
def forward():
pass
if type == 'embed' or ModelManager.get_model_type(model_name) == 'embed': | ||
if not type: | ||
type = ModelManager.get_model_type(model_name) | ||
if type in ('embed', 'reranker'): | ||
if lazyllm.config['default_embedding_engine'] == 'transformers' or not check_requirements('infinity_emb'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check_requirements可以加个@functools.lru_cache
if self._model_type == "reranker": | ||
target_name = 'rerank' | ||
else: | ||
target_name = 'embeddings' | ||
if lazyllm.config['mode'] == lazyllm.Mode.Display: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以找个时间,把Display模式完整的删掉,因为作为应用开发框架,Display模式没啥意义
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added TODO list.
lazyllm/tools/rag/document.py
Outdated
@@ -80,4 +80,7 @@ def forward(self, *args, **kw) -> List[DocNode]: | |||
return self._impl.retrieve(*args, **kw) | |||
|
|||
def __repr__(self): | |||
return lazyllm.make_repr("Module", "Document", manager=bool(self._manager)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return lazyllm.make_repr("Module", "Document", manager=hasattr(self, '_manager'))
lazyllm/tools/rag/rerank.py
Outdated
|
||
|
||
class Reranker(ModuleBase, _PostProcess): | ||
registered_reranker = dict() | ||
|
||
def __new__(cls, name: str = "ModuleReranker", *args, **kwargs): | ||
if name == "ModuleReranker": | ||
return super(Reranker, cls).__new__(cls.registered_reranker["ModuleReranker"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里要判断name在注册表且为class,然后断言这个class要是Reranker的subclass
Support Reranker:
Bugs fixed:
Document
no attr:_manager
ModuleBase
object,isinstance(xx, ModuleBase)
= False, this will cause cannot add into submodules. example: