Skip to content

Commit

Permalink
Fix for dtype and None dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurCamara committed Oct 3, 2024
1 parent 858a195 commit 1d808d9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion sentence_transformers/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
for key, value in tokenized.items():
batch[f"{column_name}_{key}"] = value
if prompt_len is not None:
- batch[f"{column_name}_prompt_length"] = torch.Tensor(
+ batch[f"{column_name}_prompt_length"] = torch.tensor(
[prompt_len] * len(values), device=batch[f"{column_name}_input_ids"].device, dtype=torch.long
)
return batch
Expand Down
6 changes: 5 additions & 1 deletion sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,10 +878,14 @@ def _load_from_checkpoint(self, checkpoint_path: str) -> None:
loaded_model = SentenceTransformer(checkpoint_path, trust_remote_code=self.model.trust_remote_code)
self.model.load_state_dict(loaded_model.state_dict())

- def maybe_add_dataset_name_column(self, dataset_dict: DatasetDict | Dataset) -> DatasetDict | Dataset:
+ def maybe_add_dataset_name_column(
+ self, dataset_dict: DatasetDict | Dataset | None
+ ) -> DatasetDict | Dataset | None:
"""
Check if the dataset_name should be added to the dataset. True if the dataset and loss are Dict or the prompts are mapping to dataset names.
"""
+ if dataset_dict is None:
+ return None
loss_is_dict = isinstance(self.loss, dict)
dataset_is_dict = isinstance(dataset_dict, DatasetDict)

Expand Down

0 comments on commit 1d808d9

Please sign in to comment.