[chore] Some refactoring in SentenceTransformer's encode method #2809

Open · wants to merge 5 commits into base: master

Changes from all commits
123 changes: 75 additions & 48 deletions sentence_transformers/SentenceTransformer.py
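For orientation, a condensed sketch of what the batch loop in encode looks like after this PR. It is not the full method (length sorting, the progress bar, the HPU deepcopy, and final output conversion are elided), and the helper names are the ones introduced in the diff below:

    for start_index in range(0, len(sentences), batch_size):
        features = self.tokenize(sentences_sorted[start_index : start_index + batch_size])
        if self.device.type == "hpu":
            self._pad_features(features)  # pad the token dimension up to the next power of two
        features = batch_to_device(features, device)
        features.update(extra_features)
        with torch.no_grad():
            out_features = self.forward(features)
        if output_value == "token_embeddings":
            all_embeddings.extend(self._process_token_embeddings(out_features))
        elif output_value == "sentence_embedding":
            all_embeddings.extend(
                self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
            )
        elif not output_value:  # None: return every forward output per sentence
            all_embeddings.extend(self._process_all_outputs(out_features))
        else:
            raise ValueError(f"Got unexpected value for 'output_value': {output_value}")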
@@ -537,46 +537,25 @@ def encode(
device = self.device

self.to(device)

all_embeddings = []

length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
sentences_batch = sentences_sorted[start_index : start_index + batch_size]
features = self.tokenize(sentences_batch)

if self.device.type == "hpu":
if "input_ids" in features:
curr_tokenize_len = features["input_ids"].shape
additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
features["input_ids"] = torch.cat(
(
features["input_ids"],
torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
),
-1,
)
features["attention_mask"] = torch.cat(
(
features["attention_mask"],
torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
),
-1,
)
if "token_type_ids" in features:
features["token_type_ids"] = torch.cat(
(
features["token_type_ids"],
torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
),
-1,
)
self._pad_features(features)

features = batch_to_device(features, device)
features.update(extra_features)

with torch.no_grad():
out_features = self.forward(features)

if self.device.type == "hpu":
out_features = copy.deepcopy(out_features)

@@ -585,29 +564,17 @@ def encode(
)

if output_value == "token_embeddings":
embeddings = []
for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
last_mask_id = len(attention) - 1
while last_mask_id > 0 and attention[last_mask_id].item() == 0:
last_mask_id -= 1

embeddings.append(token_emb[0 : last_mask_id + 1])
elif output_value is None: # Return all outputs
embeddings = []
for sent_idx in range(len(out_features["sentence_embedding"])):
row = {name: out_features[name][sent_idx] for name in out_features}
embeddings.append(row)
else: # Sentence embeddings
embeddings = out_features[output_value]
embeddings = embeddings.detach()
if normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

# fixes for #522 and #487 to avoid oom problems on gpu with large datasets
if convert_to_numpy:
embeddings = embeddings.cpu()

all_embeddings.extend(embeddings)
all_embeddings.extend(self._process_token_embeddings(out_features))
elif output_value == "sentence_embedding":
all_embeddings.extend(
self._process_sentence_embeddings(out_features, normalize_embeddings, convert_to_numpy)
)
elif not output_value:
all_embeddings.extend(self._process_all_outputs(out_features))
else:
raise ValueError(
f"Got unexpected value for 'output_value' : {output_value}. Valid values are 'token_embeddings', 'sentence_embedding' or None."
)

all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
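# Note (illustrative comment, not part of the diff): sentences were processed in
# descending-length order via length_sorted_idx for efficient batching; indexing with
# np.argsort(length_sorted_idx) here restores the caller's original input order.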

@@ -636,6 +603,66 @@ def encode(

return all_embeddings

@staticmethod
def _pad_features(features: Dict[str, torch.Tensor]) -> None:
"""
Pads the input features to the next power of 2 for compatibility with certain hardware accelerators.
"""
curr_tokenize_len = features["input_ids"].shape
additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
features["input_ids"] = torch.cat(
(
features["input_ids"],
torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
),
-1,
)
features["attention_mask"] = torch.cat(
(
features["attention_mask"],
torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
),
-1,
)
if "token_type_ids" in features:
features["token_type_ids"] = torch.cat(
(
features["token_type_ids"],
torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
),
-1,
)
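# Worked example (illustration, not part of the diff): for a batch tokenized to shape
# (batch_size, 37), 2 ** ceil(log2(37)) == 64, so 27 padding columns are appended.
# input_ids is padded with ones and attention_mask (and token_type_ids, when present)
# with zeros, so the extra positions remain masked; rounding the width up to a power
# of two keeps the set of distinct input shapes small for accelerators such as HPU.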

@staticmethod
def _process_token_embeddings(out_features: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
embeddings = []
for token_emb, attention in zip(out_features["token_embeddings"], out_features["attention_mask"]):
last_mask_id = len(attention) - 1
while last_mask_id > 0 and attention[last_mask_id].item() == 0:
last_mask_id -= 1
embeddings.append(token_emb[0 : last_mask_id + 1])
return embeddings
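# Example (illustration, not part of the diff): with attention_mask [1, 1, 1, 0, 0]
# the scan stops at last_mask_id == 2, so token_emb[0:3] is kept; each sequence's
# trailing padding positions are stripped from its token embeddings.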

@staticmethod
def _process_all_outputs(out_features: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
embeddings = []
for sent_idx in range(len(out_features["sentence_embedding"])):
row = {name: out_features[name][sent_idx] for name in out_features}
embeddings.append(row)
return embeddings

@staticmethod
def _process_sentence_embeddings(
out_features: Dict[str, torch.Tensor], normalize_embeddings: bool, convert_to_numpy: bool
) -> Union[List[np.ndarray], List[torch.Tensor]]:
embeddings = out_features["sentence_embedding"]
embeddings = embeddings.detach()
if normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
if convert_to_numpy:
embeddings = embeddings.cpu()
return embeddings
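# Note (illustration, not part of the diff): embeddings are detached before the
# optional L2 normalization, and .cpu() is only called when convert_to_numpy is set;
# the inline code this replaces tied that .cpu() call to avoiding GPU OOM with large
# datasets (issues #522 and #487).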

@property
def similarity_fn_name(self) -> Optional[str]:
"""Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
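For context, a hedged usage sketch of the encode call paths this PR reorganizes. The checkpoint name is only illustrative, and the keyword arguments shown (output_value, convert_to_numpy, normalize_embeddings) are the ones dispatched on in the refactored loop:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative checkpoint
    sentences = ["The cat sits on the mat.", "A quick brown fox jumps."]

    # Default path: one sentence embedding per input (handled by _process_sentence_embeddings).
    sentence_embs = model.encode(sentences, normalize_embeddings=True)

    # Per-token embeddings, trimmed of trailing padding (handled by _process_token_embeddings).
    token_embs = model.encode(sentences, output_value="token_embeddings", convert_to_numpy=False)

    # All forward outputs per sentence (handled by _process_all_outputs).
    all_outputs = model.encode(sentences, output_value=None, convert_to_numpy=False)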