Commit

Added rerank pipeline

ankane committed Aug 28, 2024
1 parent fb0a248 commit e12e58f
Showing 7 changed files with 101 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,5 +1,6 @@
## 1.0.2 (unreleased)

- Added `rerank` pipeline
- Added support for `nomic-ai/nomic-embed-text-v1`
- Added support for `intfloat/e5-base-v2` to `Model`
- Added support for `BAAI/bge-base-en-v1.5` to `Model`
19 changes: 19 additions & 0 deletions README.md
@@ -130,6 +130,18 @@ model = Informers.pipeline("feature-extraction", "BAAI/bge-base-en-v1.5", quanti
embeddings = model.(input, pooling: "mean", normalize: true)
```

### mixedbread-ai/mxbai-rerank-base-v1

[Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1) [unreleased]

```ruby
query = "How many people live in London?"
docs = ["Around 9 Million people live in London", "London is known for its financial district"]

model = Informers.pipeline("rerank", "mixedbread-ai/mxbai-rerank-base-v1", quantized: false)
result = model.(query, docs)
```
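
The rerank pipeline returns one hash per document, sorted by relevance score in descending order (see `RerankPipeline#call` in `lib/informers/pipelines.rb` below). A rough sketch of what `result` holds for the example above, with scores taken from the test added in this commit (exact values may differ slightly by platform):

```ruby
result
# => [
#   {doc_id: 0, score: 0.984},  # "Around 9 Million people live in London"
#   {doc_id: 1, score: 0.139}   # "London is known for its financial district"
# ]
```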

### Other

You can use the feature extraction pipeline directly.
@@ -171,6 +183,13 @@ extractor = Informers.pipeline("feature-extraction")
extractor.("We are very happy to show you the 🤗 Transformers library.")
```

Reranking [unreleased]

```ruby
ranker = Informers.pipeline("rerank")
ranker.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
```
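
`RerankPipeline#call` (see `lib/informers/pipelines.rb` below) also accepts `return_documents:` and `top_k:` options. A minimal sketch of using them together, building on the example above (the exact score value is not asserted in this commit's tests):

```ruby
ranker = Informers.pipeline("rerank")
ranker.("Who created Ruby?", ["Matz created Ruby", "Another doc"], return_documents: true, top_k: 1)
# => [{doc_id: 0, score: ..., text: "Matz created Ruby"}]
```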

## Credits

This library was ported from [Transformers.js](https://github.com/xenova/transformers.js) and is available under the same license.
7 changes: 7 additions & 0 deletions lib/informers/models.rb
@@ -205,6 +205,12 @@ class NomicBertPreTrainedModel < PreTrainedModel
class NomicBertModel < NomicBertPreTrainedModel
end

class DebertaV2PreTrainedModel < PreTrainedModel
end

class DebertaV2Model < DebertaV2PreTrainedModel
end

class DistilBertPreTrainedModel < PreTrainedModel
end

@@ -226,6 +232,7 @@ def call(model_inputs)
MODEL_MAPPING_NAMES_ENCODER_ONLY = {
"bert" => ["BertModel", BertModel],
"nomic_bert" => ["NomicBertModel", NomicBertModel],
"deberta-v2" => ["DebertaV2Model", DebertaV2Model],
"distilbert" => ["DistilBertModel", DistilBertModel]
}

43 changes: 43 additions & 0 deletions lib/informers/pipelines.rb
@@ -308,6 +308,40 @@ def call(
end
end

class RerankPipeline < Pipeline
def initialize(**options)
super(**options)
end

def call(
query,
documents,
return_documents: false,
top_k: nil
)
model_inputs = @tokenizer.([query] * documents.size,
text_pair: documents,
padding: true,
truncation: true
)

outputs = @model.(model_inputs)

result =
Utils.sigmoid(outputs[0].map(&:first))
.map.with_index { |s, i| {doc_id: i, score: s} }
.sort_by { |v| -v[:score] }

if return_documents
result.each do |v|
v[:text] = documents[v[:doc_id]]
end
end

top_k ? result.first(top_k) : result
end
end

SUPPORTED_TASKS = {
"text-classification" => {
tokenizer: AutoTokenizer,
@@ -344,6 +378,15 @@ def call(
model: "Xenova/all-MiniLM-L6-v2"
},
type: "text"
},
"rerank" => {
tokenizer: AutoTokenizer,
pipeline: RerankPipeline,
model: AutoModel,
default: {
model: "mixedbread-ai/mxbai-rerank-base-v1"
},
type: "text"
}
}

6 changes: 6 additions & 0 deletions lib/informers/tokenizers.rb
@@ -83,12 +83,18 @@ class BertTokenizer < PreTrainedTokenizer
# self.return_token_type_ids = true
end

class DebertaV2Tokenizer < PreTrainedTokenizer
# TODO
# self.return_token_type_ids = true
end

class DistilBertTokenizer < PreTrainedTokenizer
end

class AutoTokenizer
TOKENIZER_CLASS_MAPPING = {
"BertTokenizer" => BertTokenizer,
"DebertaV2Tokenizer" => DebertaV2Tokenizer,
"DistilBertTokenizer" => DistilBertTokenizer
}

17 changes: 17 additions & 0 deletions test/model_test.rb
@@ -120,4 +120,21 @@ def test_bge_base
assert_elements_in_delta [0.00029264, -0.0619305, -0.06199387], embeddings[0][..2]
assert_elements_in_delta [-0.07482512, -0.0770234, 0.03398684], embeddings[-1][..2]
end

# https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1
def test_mxbai_rerank
query = "How many people live in London?"
docs = ["Around 9 Million people live in London", "London is known for its financial district"]

model = Informers.pipeline("rerank", "mixedbread-ai/mxbai-rerank-base-v1", quantized: false)
result = model.(query, docs, return_documents: true)

assert_equal 0, result[0][:doc_id]
assert_in_delta 0.984, result[0][:score]
assert_equal docs[0], result[0][:text]

assert_equal 1, result[1][:doc_id]
assert_in_delta 0.139, result[1][:score]
assert_equal docs[1], result[1][:text]
end
end
8 changes: 8 additions & 0 deletions test/pipeline_test.rb
@@ -58,6 +58,14 @@ def test_feature_extraction
assert_in_delta (-0.3130), output[-1][-1][-1]
end

def test_rerank
ranker = Informers.pipeline("rerank")
result = ranker.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
assert_equal 2, result.size
assert_equal 0, result[0][:doc_id]
assert_equal 1, result[1][:doc_id]
end

def test_progress_callback
msgs = []
extractor = Informers.pipeline("feature-extraction", progress_callback: ->(msg) { msgs << msg })
