-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Reranker: sentence-sf and infinity, and some bugs fixed #306
base: main
Are you sure you want to change the base?
Changes from 2 commits
beca67e
3a4008c
b01a2df
f82a1f6
ac78e2f
98ada27
d0cd35f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
lazyllm.config.add("default_embedding_engine", str, "", "DEFAULT_EMBEDDING_ENGINE") | ||
|
||
class Infinity(LazyLLMDeployBase): | ||
keys_name_handle = keys_name_handle = { | ||
keys_name_handle = { | ||
'inputs': 'input', | ||
} | ||
message_format = { | ||
|
@@ -19,6 +19,7 @@ class Infinity(LazyLLMDeployBase): | |
|
||
def __init__(self, | ||
launcher=launchers.remote(ngpus=1), | ||
model_type='embed', | ||
**kw, | ||
): | ||
super().__init__(launcher=launcher) | ||
|
@@ -27,8 +28,25 @@ def __init__(self, | |
'port': None, | ||
'batch-size': 256, | ||
}) | ||
self._model_type = model_type | ||
self.kw.check_and_update(kw) | ||
self.random_port = False if 'port' in kw and kw['port'] else True | ||
if self._model_type == "reranker": | ||
self._update_reranker_message() | ||
|
||
def _update_reranker_message(self): | ||
self.keys_name_handle = { | ||
'inputs': 'query', | ||
} | ||
self.message_format = { | ||
'query': 'who are you ?', | ||
'documents': ['string'], | ||
'return_documents': False, | ||
'raw_scores': False, | ||
'top_n': 1, | ||
'model': 'default/not-specified', | ||
} | ||
self.default_headers = {'Content-Type': 'application/json'} | ||
|
||
def cmd(self, finetuned_model=None, base_model=None): | ||
if not os.path.exists(finetuned_model) or \ | ||
|
@@ -51,19 +69,28 @@ def impl(): | |
def geturl(self, job=None): | ||
if job is None: | ||
job = self.job | ||
if self._model_type == "reranker": | ||
target_name = 'rerank' | ||
else: | ||
target_name = 'embeddings' | ||
if lazyllm.config['mode'] == lazyllm.Mode.Display: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以找个时间,把Display模式完整的删掉,因为作为应用开发框架,Display模式没啥意义 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added TODO list. |
||
return 'http://{ip}:{port}/embeddings' | ||
return f'http://<ip>:<port>/{target_name}' | ||
else: | ||
return f'http://{job.get_jobip()}:{self.kw["port"]}/embeddings' | ||
return f'http://{job.get_jobip()}:{self.kw["port"]}/{target_name}' | ||
|
||
@staticmethod | ||
def extract_result(x, inputs): | ||
try: | ||
res_object = json.loads(x) | ||
except Exception as e: | ||
LOG.warning(f'JSONDecodeError on load {x}') | ||
raise e | ||
assert 'object' in res_object | ||
object_type = res_object['object'] | ||
if object_type == 'embedding': | ||
res_list = [item['embedding'] for item in res_object['data']] | ||
if len(res_list) == 1 and type(inputs['input']) is str: | ||
res_list = res_list[0] | ||
return json.dumps(res_list) | ||
except Exception as e: | ||
LOG.warning(f'JSONDecodeError on load {x}') | ||
raise e | ||
elif object_type == 'rerank': | ||
return [x['index'] for x in res_object['results']] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,4 +80,7 @@ def forward(self, *args, **kw) -> List[DocNode]: | |
return self._impl.retrieve(*args, **kw) | ||
|
||
def __repr__(self): | ||
return lazyllm.make_repr("Module", "Document", manager=bool(self._manager)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. return lazyllm.make_repr("Module", "Document", manager=hasattr(self, '_manager')) |
||
if hasattr(self, '_manager'): | ||
return lazyllm.make_repr("Module", "Document", manager=bool(self._manager)) | ||
else: | ||
return lazyllm.make_repr("Module", "Document") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
from functools import lru_cache | ||
from typing import Callable, List, Optional, Union | ||
from lazyllm import ModuleBase, config, LOG | ||
|
||
import lazyllm | ||
from lazyllm import ModuleBase, LOG | ||
from lazyllm.tools.rag.store import DocNode, MetadataMode | ||
from lazyllm.components.utils.downloader import ModelManager | ||
from .retriever import _PostProcess | ||
import numpy as np | ||
|
||
|
||
class Reranker(ModuleBase, _PostProcess): | ||
|
@@ -15,10 +15,22 @@ def __init__(self, name: str = "ModuleReranker", target: Optional[str] = None, | |
super().__init__() | ||
self._name = name | ||
self._kwargs = kwargs | ||
if self._name == "ModuleReranker": | ||
self._reranker = lazyllm.TrainableModule(self._kwargs['model']) | ||
_PostProcess.__init__(self, target, output_format, join) | ||
|
||
def forward(self, nodes: List[DocNode], query: str = "") -> List[DocNode]: | ||
results = self.registered_reranker[self._name](nodes, query=query, **self._kwargs) | ||
if self._name == "ModuleReranker": | ||
docs = [node.get_text(metadata_mode=MetadataMode.EMBED) for node in nodes] | ||
top_n = self._kwargs['topk'] if 'topk' in self._kwargs else len(docs) | ||
if self._reranker._deploy_type == lazyllm.deploy.Infinity: | ||
sorted_indices = self._reranker(query, documents=docs, top_n=top_n) | ||
else: | ||
inps = {'query': query, 'documents': docs, 'top_n': top_n} | ||
sorted_indices = self._reranker(inps) | ||
results = [nodes[i] for i in sorted_indices] | ||
else: | ||
results = self.registered_reranker[self._name](nodes, query=query, **self._kwargs) | ||
LOG.debug(f"Rerank use `{self._name}` and get nodes: {results}") | ||
return self._post_process(results) | ||
|
||
|
@@ -73,33 +85,6 @@ def KeywordFilter( | |
return None | ||
return node | ||
|
||
|
||
@lru_cache(maxsize=None) | ||
def get_cross_encoder_model(model_name: str): | ||
from sentence_transformers import CrossEncoder | ||
|
||
model = ModelManager(config["model_source"]).download(model_name) | ||
return CrossEncoder(model) | ||
|
||
|
||
@Reranker.register_reranker(batch=True) | ||
def ModuleReranker( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里为什么从注册列表里面去掉了,而是变成了if self._name == "ModuleReranker":。我个人感觉可能是想在Rerank进行update的时候,把TrainableModule也Update了。 我在想Reranker可以支持对类进行注册。在__new__的时候,如果发现name: str = "ModuleReranker"在注册表中且是个class,则直接返回注册的类的实例化对象,看看能不能满足需求
|
||
nodes: List[DocNode], model: str, query: str, topk: int = -1, **kwargs | ||
) -> List[DocNode]: | ||
if not nodes: | ||
return [] | ||
cross_encoder = get_cross_encoder_model(model) | ||
query_pairs = [ | ||
(query, node.get_text(metadata_mode=MetadataMode.EMBED)) for node in nodes | ||
] | ||
scores = cross_encoder.predict(query_pairs) | ||
sorted_indices = np.argsort(scores)[::-1] # Descending order | ||
if topk > 0: | ||
sorted_indices = sorted_indices[:topk] | ||
|
||
return [nodes[i] for i in sorted_indices] | ||
|
||
|
||
# User-defined similarity decorator | ||
def register_reranker(func=None, batch=False): | ||
return Reranker.register_reranker(func, batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check_requirements可以加个@functools.lru_cache