From f5ddbc6592b3f639e95e9120172dba568e9f8502 Mon Sep 17 00:00:00 2001
From: Agus
Date: Tue, 28 Jan 2025 20:05:54 +0100
Subject: [PATCH] Add `Checkpointer` step (#1114)

---
 .../how_to_guides/advanced/checkpointing.md |  59 ++++++++
 mkdocs.yml                                  |   1 +
 src/distilabel/distiset.py                  |  12 +-
 src/distilabel/pipeline/base.py             |   7 +
 src/distilabel/pipeline/write_buffer.py     |   7 +-
 src/distilabel/steps/__init__.py            |   2 +
 src/distilabel/steps/checkpointer.py        | 131 ++++++++++++++++++
 .../utils/mkdocs/components_gallery.py      |   2 +
 8 files changed, 216 insertions(+), 5 deletions(-)
 create mode 100644 docs/sections/how_to_guides/advanced/checkpointing.md
 create mode 100644 src/distilabel/steps/checkpointer.py

diff --git a/docs/sections/how_to_guides/advanced/checkpointing.md b/docs/sections/how_to_guides/advanced/checkpointing.md
new file mode 100644
index 0000000000..9607177da6
--- /dev/null
+++ b/docs/sections/how_to_guides/advanced/checkpointing.md
@@ -0,0 +1,59 @@
+# Push data to the hub while the pipeline is running
+
+Long-running pipelines can be resource-intensive, and ensuring everything is functioning as expected is crucial. To make this process seamless, the [HuggingFaceHubCheckpointer][distilabel.steps.checkpointer.HuggingFaceHubCheckpointer] step has been designed to integrate directly into the pipeline workflow.
+
+The [`HuggingFaceHubCheckpointer`](https://distilabel.argilla.io/dev/sections/getting_started/quickstart/) allows you to periodically save your generated data as a Hugging Face Dataset at configurable intervals (every `input_batch_size` examples generated).
+
+Just add the [`HuggingFaceHubCheckpointer`](https://distilabel.argilla.io/dev/sections/getting_started/quickstart/) to your pipeline like any other step.
+
+## Sample pipeline with dummy data to see the checkpoint strategy in action
+
+The following pipeline starts from a fake dataset with dummy data, passes it through a `DoNothing` step (any other step or combination of steps works here, but this one is useful to explore the behavior), and uses the [`HuggingFaceHubCheckpointer`](https://distilabel.argilla.io/dev/sections/getting_started/quickstart/) step to push the data to the hub.
+
+```python
+from datasets import Dataset
+
+from distilabel.pipeline import Pipeline
+from distilabel.steps import HuggingFaceHubCheckpointer
+from distilabel.steps.base import Step, StepInput
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from distilabel.typing import StepOutput
+
+dataset = Dataset.from_dict({"a": [1, 2, 3, 4] * 50, "b": [5, 6, 7, 8] * 50})
+
+class DoNothing(Step):
+    def process(self, *inputs: StepInput) -> "StepOutput":
+        for input in inputs:
+            yield input
+
+with Pipeline(name="pipeline-with-checkpoints") as pipeline:
+    text_generation = DoNothing(
+        input_batch_size=60
+    )
+    checkpoint = HuggingFaceHubCheckpointer(
+        repo_id="username/streaming_test_1",  # (1)
+        private=True,
+        input_batch_size=50  # (2)
+    )
+    text_generation >> checkpoint
+
+
+if __name__ == "__main__":
+    distiset = pipeline.run(
+        dataset=dataset,
+        use_cache=False
+    )
+    distiset.push_to_hub(repo_id="username/streaming_test")
+```
+
+1. The name of the dataset used for the checkpoints, which can be different from the one used for the final distiset.
+This dataset will contain less information than the final distiset to keep the pushes fast while the pipeline is running.
+2. The `input_batch_size` determines how often the data is pushed to the Hugging Face Hub.
+If the process is really slow (say, a big model), a value like 100 may be appropriate; for smaller models or pipelines that generate data faster, 10,000 may be more relevant. It's best to explore the right value for a given use case.
+
+The final datasets can be found in the following links:
+
+- Checkpoint dataset: [distilabel-internal-testing/streaming_test_1](https://huggingface.co/datasets/distilabel-internal-testing/streaming_test_1)
+
+- Final distiset: [distilabel-internal-testing/streaming_test](https://huggingface.co/datasets/distilabel-internal-testing/streaming_test)
diff --git a/mkdocs.yml b/mkdocs.yml
index 24e5ca9b74..40435102de 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -193,6 +193,7 @@ nav:
           - Exporting data to Argilla: "sections/how_to_guides/advanced/argilla.md"
           - Structured data generation: "sections/how_to_guides/advanced/structured_generation.md"
           - Offline Batch Generation: "sections/how_to_guides/advanced/offline_batch_generation.md"
+          - Push data to the hub while the pipeline is running: "sections/how_to_guides/advanced/checkpointing.md"
           - Specifying requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md"
           - Load groups and execution stages: "sections/how_to_guides/advanced/load_groups_and_execution_stages.md"
           - Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md"
diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py
index f44d20a3ab..1a43737ae7 100644
--- a/src/distilabel/distiset.py
+++ b/src/distilabel/distiset.py
@@ -286,9 +286,15 @@ def _extract_readme_metadata(
         Returns:
             The metadata extracted from the README.md file of the dataset repository as a dict.
         """
-        readme_path = Path(
-            hf_hub_download(repo_id, "README.md", repo_type="dataset", token=token)
-        )
+        import requests
+
+        try:
+            readme_path = Path(
+                hf_hub_download(repo_id, "README.md", repo_type="dataset", token=token)
+            )
+        except requests.exceptions.HTTPError:
+            # This can fail when using the checkpoint step, as the repository may not have a README.md yet
+            return {}
         # Remove the '---' from the metadata
         metadata = re.findall(r"---\n(.*?)\n---", readme_path.read_text(), re.DOTALL)[0]
         metadata = yaml.safe_load(metadata)
diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py
index 2a0c89abd5..cf1e315246 100644
--- a/src/distilabel/pipeline/base.py
+++ b/src/distilabel/pipeline/base.py
@@ -175,6 +175,7 @@ def __init__(
         cache_dir: Optional[Union[str, "PathLike"]] = None,
         enable_metadata: bool = False,
         requirements: Optional[List[str]] = None,
+        dump_batch_size: int = 50,
     ) -> None:
         """Initialize the `BasePipeline` instance.
 
@@ -189,6 +190,9 @@
             requirements: List of requirements that must be installed to run the pipeline.
                 Defaults to `None`, but can be helpful to inform in a pipeline to be shared
                 that this requirements must be installed.
+            dump_batch_size: Determines how often the write buffer is dumped to a file, i.e.
+                how many rows are accumulated per leaf step before its buffer is considered
+                full and written to disk. Defaults to 50.
""" self.name = name or _PIPELINE_DEFAULT_NAME self.description = description @@ -234,6 +238,8 @@ def __init__( self._log_queue: Union["Queue[Any]", None] = None + self._dump_batch_size = dump_batch_size + def __enter__(self) -> Self: """Set the global pipeline instance when entering a pipeline context.""" _GlobalPipelineManager.set_pipeline(self) @@ -1022,6 +1028,7 @@ def _setup_write_buffer(self, use_cache: bool = True) -> None: ].use_cache for step_name in self.dag }, + dump_batch_size=self._dump_batch_size, ) def _print_load_stages_info(self) -> None: diff --git a/src/distilabel/pipeline/write_buffer.py b/src/distilabel/pipeline/write_buffer.py index 3fdb037e14..8b7df7d3f9 100644 --- a/src/distilabel/pipeline/write_buffer.py +++ b/src/distilabel/pipeline/write_buffer.py @@ -38,6 +38,7 @@ def __init__( path: "PathLike", leaf_steps: Set[str], steps_cached: Optional[Dict[str, bool]] = None, + dump_batch_size: int = 50, ) -> None: """ Args: @@ -48,6 +49,9 @@ def __init__( use_cache. We will use this to determine whether we have to read a previous parquet table to concatenate before saving the cached datasets. + dump_batch_size: Determines the frequency of writing the buffer to the file, + as it will determine when the buffer is full and we should write to the file. + Defaults to 50 (every 50 elements in the buffer we can check for writes). Raises: ValueError: If the path is not a directory. @@ -64,9 +68,8 @@ def __init__( self._buffers: Dict[str, List[Dict[str, Any]]] = { step: [] for step in leaf_steps } - # TODO: make this configurable self._buffers_dump_batch_size: Dict[str, int] = { - step: 50 for step in leaf_steps + step: dump_batch_size for step in leaf_steps } self._buffer_last_schema = {} self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps} diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index 661704c7d8..772501638f 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -21,6 +21,7 @@ StepInput, StepResources, ) +from distilabel.steps.checkpointer import HuggingFaceHubCheckpointer from distilabel.steps.clustering.dbscan import DBSCAN from distilabel.steps.clustering.text_clustering import TextClustering from distilabel.steps.clustering.umap import UMAP @@ -76,6 +77,7 @@ "GeneratorStepOutput", "GlobalStep", "GroupColumns", + "HuggingFaceHubCheckpointer", "KeepColumns", "LoadDataFromDicts", "LoadDataFromDisk", diff --git a/src/distilabel/steps/checkpointer.py b/src/distilabel/steps/checkpointer.py new file mode 100644 index 0000000000..d9d48b1cb5 --- /dev/null +++ b/src/distilabel/steps/checkpointer.py @@ -0,0 +1,131 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import json
+import tempfile
+from typing import TYPE_CHECKING, Optional
+
+from pydantic import PrivateAttr
+
+from distilabel.steps.base import Step, StepInput
+
+if TYPE_CHECKING:
+    from distilabel.typing import StepOutput
+
+from huggingface_hub import HfApi
+
+
+class HuggingFaceHubCheckpointer(Step):
+    """Special type of step that uploads the data to a Hugging Face Hub dataset.
+
+    A `Step` that uploads the data to a Hugging Face Hub dataset. The data is uploaded in JSONL format
+    to a dedicated Hugging Face dataset, which can be different from the one where the final distiset of
+    the pipeline is saved. Every `input_batch_size` inputs, a new file is created in the `repo_id`
+    repository. There will be a different config folder per incoming (leaf) step of the pipeline, and
+    the files in each folder are numbered sequentially. As a write happens every `input_batch_size`
+    inputs, it's advisable not to set a small value for this step, as frequent uploads will slow down
+    the process.
+
+    Attributes:
+        repo_id:
+            The ID of the repository to push to in the following format: `<user>/<dataset_name>` or
+            `<org>/<dataset_name>`. Also accepts `<dataset_name>`, which will default to the namespace
+            of the logged-in user.
+        private:
+            Whether the dataset repository should be set to private or not. Only affects repository creation:
+            a repository that already exists will not be affected by that parameter.
+        token:
+            An optional authentication token for the Hugging Face Hub. If no token is passed, will default
+            to the token saved locally when logging in with `huggingface-cli login`. Will raise an error
+            if no token is passed and the user is not logged in.
+
+    Categories:
+        - helper
+
+    Examples:
+        Save checkpoints of the generated data to a Hugging Face Hub dataset:
+
+        ```python
+        from typing import TYPE_CHECKING
+        from datasets import Dataset
+
+        from distilabel.models import InferenceEndpointsLLM
+        from distilabel.pipeline import Pipeline
+        from distilabel.steps import HuggingFaceHubCheckpointer
+        from distilabel.steps.base import Step, StepInput
+        from distilabel.steps.tasks import TextGeneration
+
+        if TYPE_CHECKING:
+            from distilabel.typing import StepOutput
+
+        # Create a dummy dataset
+        dataset = Dataset.from_dict({"instruction": ["tell me lies"] * 100})
+
+        with Pipeline(name="pipeline-with-checkpoints") as pipeline:
+            text_generation = TextGeneration(
+                llm=InferenceEndpointsLLM(
+                    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+                    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+                ),
+                template="Follow the following instruction: {{ instruction }}"
+            )
+            checkpoint = HuggingFaceHubCheckpointer(
+                repo_id="username/streaming_checkpoint",
+                private=True,
+                input_batch_size=50  # Will write the data to the dataset every 50 inputs
+            )
+            text_generation >> checkpoint
+        ```
+    """
+
+    repo_id: str
+    private: bool = True
+    token: Optional[str] = None
+
+    _counter: int = PrivateAttr(0)
+
+    def load(self) -> None:
+        super().load()
+        if self.token is None:
+            from distilabel.utils.huggingface import get_hf_token
+
+            self.token = get_hf_token(self.__class__.__name__, "token")
+
+        self._api = HfApi(token=self.token)
+        # Create the repo if it doesn't exist
+        if not self._api.repo_exists(repo_id=self.repo_id, repo_type="dataset"):
+            self._logger.info(f"Creating repo {self.repo_id}")
+            self._api.create_repo(
+                repo_id=self.repo_id, repo_type="dataset", private=self.private
+            )
+
+    def process(self, *inputs: StepInput) -> "StepOutput":
+        for i, input in enumerate(inputs):
+            # Each section of *inputs corresponds to a different configuration of the pipeline
+            with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl") as temp_file:
+                for item in input:
+                    json_line = json.dumps(item, ensure_ascii=False)
+                    temp_file.write(json_line + "\n")
+                temp_file.flush()  # Ensure buffered rows reach disk before uploading the file by name
+                try:
+                    self._api.upload_file(
+                        path_or_fileobj=temp_file.name,
+                        path_in_repo=f"config-{i}/train-{str(self._counter).zfill(5)}.jsonl",
+                        repo_id=self.repo_id,
+                        repo_type="dataset",
+                        commit_message=f"Checkpoint {i}-{self._counter}",
+                    )
+                    self._logger.info(f"⬆️ Uploaded checkpoint {i}-{self._counter}")
+                finally:
+                    self._counter += 1
+
+        yield from inputs
diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py
index 77225b9baa..9168faeeb5 100644
--- a/src/distilabel/utils/mkdocs/components_gallery.py
+++ b/src/distilabel/utils/mkdocs/components_gallery.py
@@ -95,6 +95,7 @@
     "save": ":material-content-save:",
     "image-generation": ":material-image:",
     "labelling": ":label:",
+    "helper": ":fontawesome-solid-kit-medical:",
 }
 
 _STEP_CATEGORY_TO_DESCRIPTION = {
@@ -116,6 +117,7 @@
     "save": "Save steps are used to save the data.",
     "image-generation": "Image generation steps are used to generate images based on a given prompt.",
     "labelling": "Labelling steps are used to label the data.",
+    "helper": "Helper steps are used to perform auxiliary tasks during the pipeline execution.",
 }
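The checkpointer uploads plain JSONL shards named `config-{i}/train-XXXXX.jsonl` to the checkpoint repository, so the partial data can be inspected while the pipeline is still running. A minimal sketch of how those shards could be read back with the `datasets` JSON loader; the repository id and the `config-0` folder are placeholders taken from the example above, not fixed names:

```python
from datasets import load_dataset

# Load every JSONL checkpoint shard pushed so far for the first configuration.
# `data_files` accepts glob patterns resolved against the dataset repository root.
checkpoints = load_dataset(
    "username/streaming_test_1",   # placeholder: the checkpoint repo created by HuggingFaceHubCheckpointer
    data_files="config-0/*.jsonl",
    split="train",
)
print(checkpoints)
```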
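The `dump_batch_size` argument added to `BasePipeline.__init__` is forwarded to the write buffer and controls how many rows are accumulated per leaf step before the local cache files are written. A minimal sketch of configuring it, assuming `Pipeline` exposes the argument unchanged from `BasePipeline`; the pipeline name and the dummy data are illustrative placeholders:

```python
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts

with Pipeline(
    name="pipeline-with-bigger-dumps",
    dump_batch_size=200,  # assumption: flush the local write buffer every 200 rows instead of the default 50
) as pipeline:
    loader = LoadDataFromDicts(data=[{"instruction": "dummy"}] * 1000)

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
```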