diff --git a/sentence_transformers/data_collator.py b/sentence_transformers/data_collator.py
index 2633822ac..bfa54f40a 100644
--- a/sentence_transformers/data_collator.py
+++ b/sentence_transformers/data_collator.py
@@ -59,7 +59,7 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
             for key, value in tokenized.items():
                 batch[f"{column_name}_{key}"] = value
             if prompt_len is not None:
-                batch[f"{column_name}_prompt_length"] = torch.Tensor(
+                batch[f"{column_name}_prompt_length"] = torch.tensor(
                     [prompt_len] * len(values), device=batch[f"{column_name}_input_ids"].device, dtype=torch.long
                 )
         return batch
diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py
index 766f8c2d1..96c1c7bdf 100644
--- a/sentence_transformers/trainer.py
+++ b/sentence_transformers/trainer.py
@@ -878,10 +878,14 @@ def _load_from_checkpoint(self, checkpoint_path: str) -> None:
 
         loaded_model = SentenceTransformer(checkpoint_path, trust_remote_code=self.model.trust_remote_code)
         self.model.load_state_dict(loaded_model.state_dict())
 
-    def maybe_add_dataset_name_column(self, dataset_dict: DatasetDict | Dataset) -> DatasetDict | Dataset:
+    def maybe_add_dataset_name_column(
+        self, dataset_dict: DatasetDict | Dataset | None
+    ) -> DatasetDict | Dataset | None:
         """
         Check if the the dataset_name should be added to the dataset. True if the dataset and loss are Dict or the prompts are mapping to dataset names.
         """
+        if dataset_dict is None:
+            return None
         loss_is_dict = isinstance(self.loss, dict)
         dataset_is_dict = isinstance(dataset_dict, DatasetDict)