
[WIP] Add support for Ray 2.8.1 #3817

Draft · wants to merge 15 commits into base: master
12 changes: 6 additions & 6 deletions .github/workflows/pytest.yml
@@ -28,15 +28,15 @@ jobs:
- python-version: "3.8"
pytorch-version: 2.0.0
torchscript-version: 1.10.2
-ray-version: 2.2.0
+ray-version: 2.8.1
- python-version: "3.9"
pytorch-version: 2.1.1
torchscript-version: 1.10.2
-ray-version: 2.3.1
+ray-version: 2.8.1
- python-version: "3.10"
pytorch-version: nightly
torchscript-version: 1.10.2
-ray-version: 2.3.1
+ray-version: 2.8.1
env:
PYTORCH: ${{ matrix.pytorch-version }}
MARKERS: ${{ matrix.test-markers }}
@@ -257,7 +257,7 @@ jobs:
cat requirements.txt | sed '/^torch[>=<\b]/d' | sed '/^torchtext/d' | sed '/^torchvision/d' | sed '/^torchaudio/d' > requirements-temp && mv requirements-temp requirements.txt
cat requirements_distributed.txt | sed '/^ray[\[]/d'
pip install torch==2.0.0 torchtext torchvision torchaudio
-pip install ray==2.3.0
+pip install ray==2.8.1
pip install '.[test]'
pip list
shell: bash
@@ -298,7 +298,7 @@ jobs:
cat requirements.txt | sed '/^torch[>=<\b]/d' | sed '/^torchtext/d' | sed '/^torchvision/d' | sed '/^torchaudio/d' > requirements-temp && mv requirements-temp requirements.txt
cat requirements_distributed.txt | sed '/^ray[\[]/d'
pip install torch==2.0.0 torchtext torchvision torchaudio
-pip install ray==2.3.0
+pip install ray==2.8.1
pip install '.[test]'
pip list
shell: bash
@@ -374,7 +374,7 @@ jobs:
pip --version
python -m pip install -U pip
pip install torch==2.0.0 torchtext
-pip install ray==2.3.0
+pip install ray==2.8.1
pip install '.'
pip list
shell: bash
100 changes: 46 additions & 54 deletions ludwig/backend/ray.py
@@ -17,6 +17,7 @@
import contextlib
import copy
import logging
+import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

@@ -29,11 +30,11 @@
from packaging import version
from ray import ObjectRef
from ray.air import session
-from ray.air.checkpoint import Checkpoint
-from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
+from ray.air.config import RunConfig, ScalingConfig
from ray.air.result import Result
+from ray.data import ActorPoolStrategy
+from ray.train._checkpoint import Checkpoint
from ray.train.base_trainer import TrainingFailedError
-from ray.train.torch import TorchCheckpoint
from ray.train.trainer import BaseTrainer as RayBaseTrainer
from ray.tune.tuner import Tuner
from ray.util.dask import ray_dask_get
@@ -52,6 +53,7 @@
init_dist_strategy,
LocalStrategy,
)
+from ludwig.globals import MODEL_WEIGHTS_FILE_NAME
from ludwig.models.base import BaseModel
from ludwig.models.predictor import BasePredictor, get_output_columns, get_predictor_cls
from ludwig.schema.trainer import ECDTrainerConfig, FineTuneTrainerConfig
@@ -66,10 +68,10 @@
from ludwig.types import HyperoptConfigDict, ModelConfigDict, TrainerConfigDict, TrainingSetMetadataDict
from ludwig.utils.batch_size_tuner import BatchSizeEvaluator
from ludwig.utils.dataframe_utils import is_dask_series_or_df, set_index_name
-from ludwig.utils.fs_utils import get_fs_and_path
+from ludwig.utils.fs_utils import get_fs_and_path, open_file
from ludwig.utils.misc_utils import get_from_registry
from ludwig.utils.system_utils import Resources
-from ludwig.utils.torch_utils import initialize_pytorch
+from ludwig.utils.torch_utils import get_torch_device, initialize_pytorch
from ludwig.utils.types import DataFrame, Series

_ray220 = version.parse(ray.__version__) >= version.parse("2.2.0")
@@ -212,14 +214,26 @@ def train_fn(
report_tqdm_to_ray=True,
**executable_kwargs,
)
-results = trainer.train(train_shard, val_shard, test_shard, return_state_dict=True, **kwargs)
+# Results is a tuple object of length 4 that has:
+# 1. The model state dict
+# 2. The training statistics
+# 3. The validation statistics
+# 4. The test statistics
+results: tuple = trainer.train(train_shard, val_shard, test_shard, return_state_dict=True, **kwargs)
torch.cuda.empty_cache()

+# Create a local directory to store checkpoint related data
+ckpt_dir = os.path.join(kwargs.get("save_path"), "checkpoint")
+os.makedirs(ckpt_dir, exist_ok=True)
+
+# Save the state dict to disk and load it back on the main process
+ckpt_path = os.path.join(ckpt_dir, MODEL_WEIGHTS_FILE_NAME)
+torch.save(results[0], ckpt_path)
+
# Passing objects containing Torch tensors as metrics is not supported as it will throw an
# exception on deserialization, so create a checkpoint and return via session.report() along
-# with the path of the checkpoint
-ckpt = Checkpoint.from_dict({"state_dict": results})
-torch_ckpt = TorchCheckpoint.from_checkpoint(ckpt)
+# with the path of the checkpoint on disk.
+ckpt: Checkpoint = Checkpoint.from_directory(ckpt_dir)

# The checkpoint is put in the object store and then retrieved by the Trainable actor to be reported to Tune.
# It is also persisted on disk by the Trainable (and synced to cloud, if configured to do so)
Expand All @@ -229,8 +243,11 @@ def train_fn(
metrics={
"validation_field": trainer.validation_field,
"validation_metric": trainer.validation_metric,
"train_results": results[1],
"val_results": results[2],
"test_results": results[3],
},
-checkpoint=torch_ckpt,
+checkpoint=ckpt,
)

except Exception:
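Reviewer note: this hunk swaps Ray 2.3's dict-backed `Checkpoint.from_dict`/`TorchCheckpoint` flow for Ray 2.8's directory-backed checkpoints. A minimal sketch of the reporting pattern, assuming a training loop run by `TorchTrainer`; the file name and metrics below are illustrative, not Ludwig's actual constants:

```python
import os
import tempfile

import torch
from ray.air import session
from ray.train import Checkpoint  # assumed public alias of ray.train._checkpoint.Checkpoint used above


def report_with_checkpoint(state_dict: dict, metrics: dict) -> None:
    # Tensors cannot ride along as metrics (they fail on deserialization),
    # so persist them to disk and report a directory-backed checkpoint.
    ckpt_dir = tempfile.mkdtemp()
    torch.save(state_dict, os.path.join(ckpt_dir, "model_weights.pt"))  # illustrative file name
    session.report(metrics=metrics, checkpoint=Checkpoint.from_directory(ckpt_dir))
```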
@@ -371,40 +388,6 @@ def __init__(self, trainer_kwargs: Dict[str, Any]) -> None:
**trainer_kwargs,
)

-def _get_dataset_configs(
-self,
-datasets: Dict[str, Any],
-stream_window_size: Dict[str, Union[None, float]],
-data_loader_kwargs: Dict[str, Any],
-) -> Dict[str, DatasetConfig]:
-"""Generates DatasetConfigs for each dataset passed into the trainer."""
-dataset_configs = {}
-for dataset_name, _ in datasets.items():
-if _ray230:
-# DatasetConfig.use_stream_api and DatasetConfig.stream_window_size have been removed as of Ray 2.3.
-# We need to use DatasetConfig.max_object_store_memory_fraction instead -> default to 20% when windowing
-# is enabled unless the end user specifies a different fraction.
-# https://docs.ray.io/en/master/ray-air/check-ingest.html?highlight=max_object_store_memory_fraction#enabling-streaming-ingest # noqa
-dataset_conf = DatasetConfig(
-split=True,
-max_object_store_memory_fraction=stream_window_size.get(dataset_name),
-)
-else:
-dataset_conf = DatasetConfig(
-split=True,
-use_stream_api=True,
-stream_window_size=stream_window_size.get(dataset_name),
-)
-
-if dataset_name == "train":
-# Mark train dataset as always required
-dataset_conf.required = True
-# Check data loader kwargs to see if shuffle should be enabled for the
-# train dataset. global_shuffle is False by default for all other datasets.
-dataset_conf.global_shuffle = data_loader_kwargs.get("shuffle", True)
-dataset_configs[dataset_name] = dataset_conf
-return dataset_configs
-
def run(
self,
train_loop_per_worker: Callable,
@@ -417,9 +400,13 @@
) -> Result:
dataset_config = None
if dataset is not None:
-data_loader_kwargs = data_loader_kwargs or {}
-stream_window_size = stream_window_size or {}
-dataset_config = self._get_dataset_configs(dataset, stream_window_size, data_loader_kwargs)
+dataset_config = ray.train.DataConfig(
+datasets_to_split="all",
+execution_options=ray.data.ExecutionOptions(
+preserve_order=data_loader_kwargs.get("shuffle", True),
+verbose_progress=True,
+),
+)

callbacks = callbacks or []
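Reviewer note: Ray 2.6+ replaced the per-dataset `DatasetConfig` objects (removed above) with a single `ray.train.DataConfig` passed to the trainer. A minimal sketch of how it plugs into a `TorchTrainer`; the worker count is illustrative:

```python
import ray.data
import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def make_trainer(train_loop_per_worker, datasets):
    # One DataConfig now governs every dataset: datasets_to_split="all"
    # shards each of them across workers, and ExecutionOptions tunes the
    # streaming ingest executor.
    return TorchTrainer(
        train_loop_per_worker,
        datasets=datasets,
        dataset_config=ray.train.DataConfig(
            datasets_to_split="all",
            execution_options=ray.data.ExecutionOptions(preserve_order=True),
        ),
        scaling_config=ScalingConfig(num_workers=2),
    )
```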

@@ -523,17 +510,22 @@ def train(
self._validation_metric = trainer_results.metrics["validation_metric"]

# Load model from checkpoint
-ckpt = TorchCheckpoint.from_checkpoint(trainer_results.checkpoint)
-results = ckpt.to_dict()["state_dict"]
+ckpt = trainer_results.checkpoint
+
+with open_file(os.path.join(ckpt.path, MODEL_WEIGHTS_FILE_NAME), "rb") as f:
+state_dict = torch.load(f, map_location=torch.device(get_torch_device()))

# load state dict back into the model
# use `strict=False` to account for PEFT training, where the saved state in the checkpoint
# might only contain the PEFT layers that were modified during training
-state_dict, *args = results
self.model.load_state_dict(state_dict, strict=False)
-results = (self.model, *args)

-return results
+return (
+self.model,
+trainer_results.metrics["train_results"],
+trainer_results.metrics["val_results"],
+trainer_results.metrics["test_results"],
+)
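Reviewer note: the driver-side load above reads the weights file straight out of the checkpoint directory rather than deserializing a dict-backed checkpoint. A minimal sketch of that step, assuming a directory-backed `Checkpoint`; the helper and default file name are hypothetical stand-ins for Ludwig's `MODEL_WEIGHTS_FILE_NAME`:

```python
import os

import torch


def load_weights_from_checkpoint(ckpt, model, weights_file="model_weights.pt"):
    # ckpt.path is the local directory backing a ray.train.Checkpoint.
    # strict=False tolerates PEFT-style checkpoints that contain only the
    # adapter layers modified during fine-tuning.
    with open(os.path.join(ckpt.path, weights_file), "rb") as f:
        state_dict = torch.load(f, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict, strict=False)
    return model
```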

def train_online(self, *args, **kwargs):
# TODO: When this is implemented we also need to update the
@@ -752,7 +744,7 @@ def batch_predict(
predictions = dataset.ds.map_batches(
batch_predictor,
batch_size=self.batch_size,
compute="actors",
compute=ActorPoolStrategy(),
batch_format="pandas",
num_cpus=num_cpus,
num_gpus=num_gpus,
@@ -1135,7 +1127,7 @@ def batch_transform(self, df: DataFrame, batch_size: int, transform_fn: Callable
ds = ds.map_batches(
transform_fn,
batch_size=batch_size,
compute="actors",
compute=ActorPoolStrategy(),
batch_format="pandas",
**self._get_transform_kwargs(),
)
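Reviewer note: both `map_batches` call sites replace the deprecated string shorthand `compute="actors"` with an explicit `ActorPoolStrategy`. A minimal standalone sketch of the Ray 2.8 form; the callable class and pool size are illustrative:

```python
import ray
from ray.data import ActorPoolStrategy


class AddOne:
    # A callable class makes map_batches run inside a pool of long-lived
    # actors, which is what compute="actors" used to request.
    def __call__(self, batch):
        batch["id"] = batch["id"] + 1
        return batch


ds = ray.data.range(8)
out = ds.map_batches(AddOne, compute=ActorPoolStrategy(size=2), batch_format="pandas")
print(out.take_all())
```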
4 changes: 2 additions & 2 deletions ludwig/data/dataframe/dask.py
@@ -26,7 +26,7 @@
from dask.diagnostics import ProgressBar
from packaging import version
from pyarrow.fs import FSSpecHandler, PyFileSystem
-from ray.data import Dataset, read_parquet
+from ray.data import ActorPoolStrategy, Dataset, read_parquet

from ludwig.api_annotations import DeveloperAPI
from ludwig.data.dataframe.base import DataFrameEngine
@@ -167,7 +167,7 @@ def map_batches(self, series, map_fn, enable_tensor_extension_casting=True):

with tensor_extension_casting(enable_tensor_extension_casting):
ds = ray.data.from_dask(series)
ds = ds.map_batches(map_fn, batch_format="pandas")
ds = ds.map_batches(map_fn, batch_format="pandas", compute=ActorPoolStrategy())
return ds.to_dask()

def apply_objects(self, df, apply_fn, meta=None):
33 changes: 24 additions & 9 deletions ludwig/data/dataset/ray.py
@@ -19,7 +19,7 @@
import queue
import threading
from functools import lru_cache
-from typing import Dict, Iterable, Iterator, Literal, Optional, Union
+from typing import Dict, Iterable, Literal, Optional, Union

import numpy as np
import pandas as pd
@@ -29,7 +29,9 @@
from pyarrow.fs import FSSpecHandler, PyFileSystem
from pyarrow.lib import ArrowInvalid
from ray.data import read_parquet
+from ray.data.dataset import Dataset as _Dataset
from ray.data.dataset_pipeline import DatasetPipeline
+from ray.data.iterator import DataIterator

from ludwig.api_annotations import DeveloperAPI
from ludwig.backend.base import Backend
@@ -49,6 +51,7 @@

logger = logging.getLogger(__name__)

+_ray_240 = version.parse(ray.__version__) >= version.parse("2.4.0")
_ray_230 = version.parse(ray.__version__) >= version.parse("2.3.0")
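Reviewer note: these module-level flags gate code paths on the installed Ray version, mirroring the existing `_ray_230` check. A minimal sketch of the pattern as it is used in `create_epoch_iter` below; the fallback is elided here:

```python
import ray
from packaging import version

_ray_240 = version.parse(ray.__version__) >= version.parse("2.4.0")


def make_epoch_iter(dataset_shard):
    # On Ray >= 2.4 the shard handed to the worker can serve as the
    # epoch iterator directly; older-Ray fallbacks are omitted.
    if _ray_240:
        return dataset_shard
    raise NotImplementedError("pre-2.4 code paths elided in this sketch")
```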


@@ -140,7 +143,8 @@ def initialize_batcher(
augmentation_pipeline=None,
):
yield RayDatasetBatcher(
-self.ds.repeat().iter_datasets(),
+# self.ds is a MaterializedDataset object - the iterator call returns a DataIterator object
+self.ds.iterator(),
self.features,
self.training_set_metadata,
batch_size,
@@ -234,7 +238,7 @@ def data_format(self):
class RayDatasetShard(Dataset):
def __init__(
self,
-dataset_shard: DatasetPipeline,
+dataset_shard: _Dataset,
features: Dict[str, FeatureConfigDict],
training_set_metadata: TrainingSetMetadataDict,
):
@@ -244,6 +248,10 @@ def __init__(
self.create_epoch_iter()

def create_epoch_iter(self) -> None:
+if _ray_240:
+self.epoch_iter = self.dataset_shard
+return
+
if _ray_230:
# In Ray >= 2.3, session.get_dataset_shard() returns a DatasetIterator object.
if isinstance(self.dataset_shard, ray.data.DatasetIterator):
@@ -289,7 +297,14 @@ def initialize_batcher(

@lru_cache(1)
def __len__(self):
-return next(self.epoch_iter).count()
+if isinstance(self.epoch_iter, DataIterator):
+num_rows = 0
+for block, meta in self.epoch_iter._to_block_iterator()[0]:
+num_rows += meta.num_rows
+return num_rows
+else:
+# self.epoch_iter is a ray.data.Dataset object
+return self.epoch_iter.count()
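Reviewer note: the new `__len__` counts rows through the private `DataIterator._to_block_iterator()` API. A public-API alternative would be to stream batches and sum their lengths; a sketch (slower, since it iterates the shard once):

```python
from ray.data.iterator import DataIterator


def count_rows(epoch_iter: DataIterator, batch_size: int = 4096) -> int:
    # Sum batch lengths instead of touching private block metadata.
    total = 0
    for batch in epoch_iter.iter_batches(batch_size=batch_size, batch_format="pandas"):
        total += len(batch)
    return total
```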

@property
def size(self):
@@ -306,7 +321,7 @@ def to_scalar_df(self, features: Optional[Iterable[BaseFeature]] = None) -> Data
class RayDatasetBatcher(Batcher):
def __init__(
self,
-dataset_epoch_iterator: Iterator[DatasetPipeline],
+dataset_epoch_iterator: _Dataset,
features: Dict[str, Dict],
training_set_metadata: TrainingSetMetadataDict,
batch_size: int,
@@ -364,7 +379,7 @@ def steps_per_epoch(self):
return math.ceil(self.samples_per_epoch / self.batch_size)

def _fetch_next_epoch(self):
-pipeline = next(self.dataset_epoch_iterator)
+pipeline = self.dataset_epoch_iterator

read_parallelism = 1
if read_parallelism == 1:
@@ -431,14 +446,14 @@ def augment_batch(df: pd.DataFrame) -> pd.DataFrame:

return augment_batch

-def _create_sync_reader(self, pipeline: DatasetPipeline):
+def _create_sync_reader(self, pipeline: _Dataset):
def sync_read():
for batch in pipeline.iter_batches(prefetch_blocks=0, batch_size=self.batch_size, batch_format="pandas"):
yield self._prepare_batch(batch)

return sync_read()

-def _create_async_reader(self, pipeline: DatasetPipeline):
+def _create_async_reader(self, pipeline: _Dataset):
q = queue.Queue(maxsize=100)
batch_size = self.batch_size
augment_batch = self._augment_batch_fn()
@@ -474,7 +489,7 @@ def async_read():

return async_read()

-def _create_async_parallel_reader(self, pipeline: DatasetPipeline, num_threads: int):
+def _create_async_parallel_reader(self, pipeline: _Dataset, num_threads: int):
q = queue.Queue(maxsize=100)

batch_size = self.batch_size