Support Reranker: sentence-transformers and infinity, and some bug fixes #306

Open · wants to merge 7 commits into base: main
Changes from 2 commits
14 changes: 8 additions & 6 deletions lazyllm/components/auto/autodeploy.py
@@ -19,16 +19,18 @@ def __new__(cls, base_model, source=lazyllm.config['model_source'], trust_remote
launcher=launchers.remote(ngpus=1), stream=False, type=None, **kw):
base_model = ModelManager(source).download(base_model)
model_name = get_model_name(base_model)
if type == 'embed' or ModelManager.get_model_type(model_name) == 'embed':
if not type:
type = ModelManager.get_model_type(model_name)
if type in ('embed', 'reranker'):
if lazyllm.config['default_embedding_engine'] == 'transformers' or not check_requirements('infinity_emb'):
Contributor:
check_requirements could be decorated with @functools.lru_cache.
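
For illustration, a minimal sketch of that suggestion; the real check_requirements in lazyllm may have a different signature, so this stands in for it:

```python
# Hedged sketch, not lazyllm's actual helper: memoize the requirement
# check so repeated AutoDeploy calls don't re-probe the environment.
import functools
import importlib.util

@functools.lru_cache(maxsize=None)
def check_requirements(package_name: str) -> bool:
    # find_spec returns None when the package cannot be imported.
    return importlib.util.find_spec(package_name) is not None
```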

return EmbeddingDeploy(launcher)
return EmbeddingDeploy(launcher, model_type=type)
else:
return deploy.Infinity(launcher)
elif type == 'sd' or ModelManager.get_model_type(model_name) == 'sd':
return deploy.Infinity(launcher, model_type=type)
elif type == 'sd':
return StableDiffusionDeploy(launcher)
elif type == 'stt' or ModelManager.get_model_type(model_name) == 'stt':
elif type == 'stt':
return SenseVoiceDeploy(launcher)
elif type == 'tts' or ModelManager.get_model_type(model_name) == 'tts':
elif type == 'tts':
return TTSDeploy(model_name, launcher=launcher)
map_name = model_map(model_name)
candidates = get_configer().query_deploy(lazyllm.config['gpu_type'], launcher.ngpus,
39 changes: 33 additions & 6 deletions lazyllm/components/deploy/infinity.py
@@ -9,7 +9,7 @@
lazyllm.config.add("default_embedding_engine", str, "", "DEFAULT_EMBEDDING_ENGINE")

class Infinity(LazyLLMDeployBase):
keys_name_handle = keys_name_handle = {
keys_name_handle = {
'inputs': 'input',
}
message_format = {
@@ -19,6 +19,7 @@ class Infinity(LazyLLMDeployBase):

def __init__(self,
launcher=launchers.remote(ngpus=1),
model_type='embed',
**kw,
):
super().__init__(launcher=launcher)
@@ -27,8 +28,25 @@ def __init__(self,
'port': None,
'batch-size': 256,
})
self._model_type = model_type
self.kw.check_and_update(kw)
self.random_port = False if 'port' in kw and kw['port'] else True
if self._model_type == "reranker":
self._update_reranker_message()

def _update_reranker_message(self):
self.keys_name_handle = {
'inputs': 'query',
}
self.message_format = {
'query': 'who are you ?',
'documents': ['string'],
'return_documents': False,
'raw_scores': False,
'top_n': 1,
'model': 'default/not-specified',
}
self.default_headers = {'Content-Type': 'application/json'}

def cmd(self, finetuned_model=None, base_model=None):
if not os.path.exists(finetuned_model) or \
@@ -51,19 +69,28 @@ def impl():
def geturl(self, job=None):
if job is None:
job = self.job
if self._model_type == "reranker":
target_name = 'rerank'
else:
target_name = 'embeddings'
if lazyllm.config['mode'] == lazyllm.Mode.Display:
Contributor:
We should find time to remove Display mode entirely; as an application-development framework, Display mode serves little purpose.

Collaborator (Author):
Added TODO list.

return 'http://{ip}:{port}/embeddings'
return f'http://<ip>:<port>/{target_name}'
else:
return f'http://{job.get_jobip()}:{self.kw["port"]}/embeddings'
return f'http://{job.get_jobip()}:{self.kw["port"]}/{target_name}'

@staticmethod
def extract_result(x, inputs):
try:
res_object = json.loads(x)
except Exception as e:
LOG.warning(f'JSONDecodeError on load {x}')
raise e
assert 'object' in res_object
object_type = res_object['object']
if object_type == 'embedding':
res_list = [item['embedding'] for item in res_object['data']]
if len(res_list) == 1 and type(inputs['input']) is str:
res_list = res_list[0]
return json.dumps(res_list)
except Exception as e:
LOG.warning(f'JSONDecodeError on load {x}')
raise e
elif object_type == 'rerank':
return [x['index'] for x in res_object['results']]
47 changes: 44 additions & 3 deletions lazyllm/components/embedding/embed.py
@@ -5,6 +5,9 @@
from lazyllm.thirdparty import transformers as tf
from lazyllm.thirdparty import torch

from sentence_transformers import CrossEncoder
import numpy as np


class LazyHuggingFaceEmbedding(object):
def __init__(self, base_embed, source=None, init=False):
@@ -45,13 +48,45 @@ def __reduce__(self):
init = bool(os.getenv('LAZYLLM_ON_CLOUDPICKLE', None) == 'ON' or self.init_flag)
return LazyHuggingFaceEmbedding.rebuild, (self.base_embed, init)

class LazyHuggingFaceRerank(object):
def __init__(self, base_rerank, source=None, init=False):
from ..utils.downloader import ModelManager
source = lazyllm.config['model_source'] if not source else source
self.base_rerank = ModelManager(source).download(base_rerank)
self.reranker = None
self.init_flag = lazyllm.once_flag()
if init:
lazyllm.call_once(self.init_flag, self.load_reranker)

def load_reranker(self):
self.reranker = CrossEncoder(self.base_rerank)

def __call__(self, inps):
lazyllm.call_once(self.init_flag, self.load_reranker)
query, documents, top_n = inps['query'], inps['documents'], inps['top_n']
query_pairs = [(query, doc) for doc in documents]
scores = self.reranker.predict(query_pairs)
sorted_indices = np.argsort(scores)[::-1]
if top_n > 0:
sorted_indices = sorted_indices[:top_n]
return sorted_indices.tolist()

@classmethod
def rebuild(cls, base_rerank, init):
return cls(base_rerank, init)

def __reduce__(self):
init = bool(os.getenv('LAZYLLM_ON_CLOUDPICKLE', None) == 'ON' or self.init_flag)
return LazyHuggingFaceRerank.rebuild, (self.base_rerank, init)

class EmbeddingDeploy():
message_format = None
keys_name_handle = None
default_headers = {'Content-Type': 'application/json'}

def __init__(self, launcher=None):
def __init__(self, launcher=None, model_type='embed'):
self.launcher = launcher
self._model_type = model_type

def __call__(self, finetuned_model=None, base_model=None):
if not os.path.exists(finetuned_model) or \
@@ -61,5 +96,11 @@ def __call__(self, finetuned_model=None, base_model=None):
LOG.warning(f"Note! That finetuned_model({finetuned_model}) is an invalid path, "
f"base_model({base_model}) will be used")
finetuned_model = base_model
return lazyllm.deploy.RelayServer(func=LazyHuggingFaceEmbedding(
finetuned_model), launcher=self.launcher)()
if self._model_type == 'embed':
return lazyllm.deploy.RelayServer(func=LazyHuggingFaceEmbedding(
finetuned_model), launcher=self.launcher)()
elif self._model_type == 'reranker':
return lazyllm.deploy.RelayServer(func=LazyHuggingFaceRerank(
finetuned_model), launcher=self.launcher)()
else:
raise RuntimeError(f'Unsupported model type: {self._model_type}.')
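
A hypothetical usage sketch for the new LazyHuggingFaceRerank, matching the keys its __call__ reads ('query', 'documents', 'top_n'); the model name matches the bge-reranker-large entry touched in the next file, and the output is illustrative:

```python
# Assumed call shape; the reranker returns document indices sorted by
# descending CrossEncoder score, truncated to top_n.
reranker = LazyHuggingFaceRerank('bge-reranker-large')
indices = reranker({
    'query': 'what is a cross-encoder?',
    'documents': ['a note on tokenizers',
                  'cross-encoders score query-document pairs'],
    'top_n': 1,
})
# -> e.g. [1]
```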
2 changes: 1 addition & 1 deletion lazyllm/components/utils/downloader/model_mapping.py
@@ -122,7 +122,7 @@
"huggingface": "BAAI/bge-reranker-large",
"modelscope": "Xorbits/bge-reranker-large"
},
"type": "embed"
"type": "reranker"
},
"chattts": {
"source": {
7 changes: 5 additions & 2 deletions lazyllm/module/module.py
@@ -18,11 +18,14 @@
from ..components.formatter import FormatterBase, EmptyFormatter
from ..components.utils import ModelManager
from ..flow import FlowBase, Pipeline, Parallel
from ..common.bind import _MetaBind
import uuid
from ..client import get_redis, redis_client


class ModuleBase(object):
# Use _MetaBind so that if x is bound to a ModuleBase, isinstance(x, ModuleBase) is True.
# Example: ActionModule.submodules uses isinstance(x, ModuleBase) to collect submodules.
class ModuleBase(metaclass=_MetaBind):
builder_keys = [] # keys in builder support Option by default

def __new__(cls, *args, **kw):
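
A hedged sketch of the idea behind _MetaBind; the mechanics and attribute names here are assumptions, not lazyllm's actual implementation. A metaclass can override __instancecheck__ so a bind() proxy wrapping a ModuleBase still passes isinstance checks:

```python
# Assumed mechanics: unwrap a bound proxy before the normal check, so
# ActionModule.submodules keeps collecting wrapped modules.
class _MetaBind(type):
    def __instancecheck__(cls, obj):
        target = getattr(obj, '_bound_target', obj)
        return super().__instancecheck__(target)
```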
5 changes: 4 additions & 1 deletion lazyllm/tools/rag/document.py
@@ -80,4 +80,7 @@ def forward(self, *args, **kw) -> List[DocNode]:
return self._impl.retrieve(*args, **kw)

def __repr__(self):
return lazyllm.make_repr("Module", "Document", manager=bool(self._manager))
Contributor:
return lazyllm.make_repr("Module", "Document", manager=hasattr(self, '_manager'))

if hasattr(self, '_manager'):
return lazyllm.make_repr("Module", "Document", manager=bool(self._manager))
else:
return lazyllm.make_repr("Module", "Document")
47 changes: 16 additions & 31 deletions lazyllm/tools/rag/rerank.py
@@ -1,10 +1,10 @@
from functools import lru_cache
from typing import Callable, List, Optional, Union
from lazyllm import ModuleBase, config, LOG

import lazyllm
from lazyllm import ModuleBase, LOG
from lazyllm.tools.rag.store import DocNode, MetadataMode
from lazyllm.components.utils.downloader import ModelManager
from .retriever import _PostProcess
import numpy as np


class Reranker(ModuleBase, _PostProcess):
@@ -15,10 +15,22 @@ def __init__(self, name: str = "ModuleReranker", target: Optional[str] = None,
super().__init__()
self._name = name
self._kwargs = kwargs
if self._name == "ModuleReranker":
self._reranker = lazyllm.TrainableModule(self._kwargs['model'])
_PostProcess.__init__(self, target, output_format, join)

def forward(self, nodes: List[DocNode], query: str = "") -> List[DocNode]:
results = self.registered_reranker[self._name](nodes, query=query, **self._kwargs)
if self._name == "ModuleReranker":
docs = [node.get_text(metadata_mode=MetadataMode.EMBED) for node in nodes]
top_n = self._kwargs['topk'] if 'topk' in self._kwargs else len(docs)
if self._reranker._deploy_type == lazyllm.deploy.Infinity:
sorted_indices = self._reranker(query, documents=docs, top_n=top_n)
else:
inps = {'query': query, 'documents': docs, 'top_n': top_n}
sorted_indices = self._reranker(inps)
results = [nodes[i] for i in sorted_indices]
else:
results = self.registered_reranker[self._name](nodes, query=query, **self._kwargs)
LOG.debug(f"Rerank use `{self._name}` and get nodes: {results}")
return self._post_process(results)

@@ -73,33 +85,6 @@ def KeywordFilter(
return None
return node


@lru_cache(maxsize=None)
def get_cross_encoder_model(model_name: str):
from sentence_transformers import CrossEncoder

model = ModelManager(config["model_source"]).download(model_name)
return CrossEncoder(model)


@Reranker.register_reranker(batch=True)
def ModuleReranker(
Contributor:
Why was ModuleReranker removed from the registry and replaced with "if self._name == 'ModuleReranker':"? My guess is that the intent is to update the TrainableModule as well whenever the Reranker is updated.

I'm wondering whether Reranker could support registering classes: in __new__, if name is found in the registry and is a class, return an instance of that registered class directly. See the sketch below; it may satisfy the need.

@classmethod
def register_reranker(cls: "Reranker", func: Optional[Callable] = None, batch: bool = False):
    def decorator(f):
        if isinstance(f, type):
            if batch:
                raise NotImplementedError('...')
            cls.registered_reranker[f.__name__] = f
            return f
        else:
            def wrapper(nodes, **kwargs):
                if batch:
                    return f(nodes, **kwargs)
                else:
                    results = [f(node, **kwargs) for node in nodes]
                    return [result for result in results if result]

            cls.registered_reranker[f.__name__] = wrapper
            return wrapper

    return decorator(func) if func else decorator

class Reranker(ModuleBase, _PostProcess):
    pass

@Reranker.register_reranker
class ModuleReranker(Reranker):
    def __init__(self, ...):
        pass

    def forward(self):
        pass

nodes: List[DocNode], model: str, query: str, topk: int = -1, **kwargs
) -> List[DocNode]:
if not nodes:
return []
cross_encoder = get_cross_encoder_model(model)
query_pairs = [
(query, node.get_text(metadata_mode=MetadataMode.EMBED)) for node in nodes
]
scores = cross_encoder.predict(query_pairs)
sorted_indices = np.argsort(scores)[::-1] # Descending order
if topk > 0:
sorted_indices = sorted_indices[:topk]

return [nodes[i] for i in sorted_indices]


# User-defined similarity decorator
def register_reranker(func=None, batch=False):
return Reranker.register_reranker(func, batch)
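
Putting the rework together, a sketch of the intended call pattern, mirroring the updated test below; the DocNode constructor argument is an assumption:

```python
import lazyllm
from lazyllm.tools.rag.store import DocNode
from lazyllm.tools.rag.rerank import Reranker

# Assumed DocNode construction; the test builds its nodes elsewhere.
nodes = [DocNode(text=t) for t in ('apple pie', 'banana split', 'cherry tart')]

reranker = Reranker(name='ModuleReranker', model='bge-reranker-large', topk=2)
reranker.start()  # deploys the TrainableModule backend (Infinity or transformers)
top_nodes = reranker.forward(nodes, query='cherry')
```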
38 changes: 20 additions & 18 deletions tests/advanced_tests/standard_test/test_reranker.py
@@ -1,5 +1,6 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import lazyllm
from lazyllm.tools.rag.store import DocNode
from lazyllm.tools.rag.rerank import Reranker, register_reranker

@@ -35,24 +36,25 @@ def test_keyword_filter_with_exclude_keys(self):
self.assertEqual(len(results), 2)
self.assertNotIn(self.doc2, results)

@patch("lazyllm.components.utils.downloader.ModelManager.download")
@patch("sentence_transformers.CrossEncoder")
def test_module_reranker(self, MockCrossEncoder, mock_download):
mock_model = MagicMock()
mock_download.return_value = "mock_model_path"
MockCrossEncoder.return_value = mock_model
mock_model.predict.return_value = [0.8, 0.6, 0.9]
def test_module_reranker(self):
env_key = 'LAZYLLM_DEFAULT_EMBEDDING_ENGINE'
test_cases = ['', 'transformers']
original_value = os.getenv(env_key, None)
for value in test_cases:
with self.subTest(value=value):
os.environ[env_key] = value
lazyllm.config.refresh(env_key)
reranker = Reranker(name="ModuleReranker", model="bge-reranker-large", topk=2)
reranker.start()
results = reranker.forward(self.nodes, query='cherry')

reranker = Reranker(name="ModuleReranker", model="dummy-model", topk=2)
results = reranker.forward(self.nodes, query=self.query)

self.assertEqual(len(results), 2)
self.assertEqual(
results[0].get_text(), self.doc3.get_text()
) # highest score
self.assertEqual(
results[1].get_text(), self.doc1.get_text()
) # second highest score
self.assertEqual(len(results), 2)
self.assertEqual(
results[0].get_text(), self.doc3.get_text()
) # highest score
if original_value:
os.environ[env_key] = original_value
lazyllm.config.refresh(env_key)

def test_register_reranker_decorator(self):
@register_reranker