From 4d98be3c29f4e499bfe7e9064186e5325d3be39a Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 14 Oct 2024 17:29:47 -0700 Subject: [PATCH] Fixed error with sentence-transformers/all-MiniLM-L6-v2 - fixes #9 --- CHANGELOG.md | 1 + lib/informers/pipelines.rb | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f084db9..8da0e96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 1.1.1 (unreleased) - Added `audio-classification` pipeline +- Fixed error with `sentence-transformers/all-MiniLM-L6-v2` ## 1.1.0 (2024-09-17) diff --git a/lib/informers/pipelines.rb b/lib/informers/pipelines.rb index 03976fd..131ab92 100644 --- a/lib/informers/pipelines.rb +++ b/lib/informers/pipelines.rb @@ -837,7 +837,7 @@ def call( if !model_output.nil? model_options[:output_names] = Array(model_output) elsif @model.instance_variable_get(:@output_names) == ["token_embeddings"] && pooling == "mean" && normalize - # optimization for sentence-transformers/all-MiniLM-L6-v2 + # optimization for previous revision of sentence-transformers/all-MiniLM-L6-v2 model_options[:output_names] = ["sentence_embedding"] pooling = "none" normalize = false @@ -1402,7 +1402,8 @@ def pipeline( results = load_items(classes, model, pretrained_options) results[:task] = task - if model == "sentence-transformers/all-MiniLM-L6-v2" + # for previous revision of sentence-transformers/all-MiniLM-L6-v2 + if model == "sentence-transformers/all-MiniLM-L6-v2" && results[:model].instance_variable_get(:@session).outputs.any? { |v| v[:name] == "token_embeddings" } results[:model].instance_variable_set(:@output_names, ["token_embeddings"]) end