Merge branch 'master' of https://github.com/ludwig-ai/ludwig
connor-mccorm committed Oct 12, 2023
2 parents 379f6e1 + fd91478 commit e6b6bb9
Showing 14 changed files with 453 additions and 138 deletions.
46 changes: 23 additions & 23 deletions .github/workflows/pytest.yml
@@ -163,31 +163,31 @@ jobs:
run: |
RUN_PRIVATE=$IS_NOT_FORK LUDWIG_TEST_SUITE_TIMEOUT_S=5400 pytest -v --timeout 300 --durations 100 -m "$MARKERS and not slow and not combinatorial and not horovod or benchmark and not llm" --junitxml pytest.xml tests/regression_tests
# Skip Horovod installation for torch nightly.
# Skip Horovod and replace with DDP.
# https://github.com/ludwig-ai/ludwig/issues/3468
- name: Install Horovod if necessary
if: matrix.test-markers == 'distributed' && matrix.pytorch-version != 'nightly'
env:
HOROVOD_WITH_PYTORCH: 1
HOROVOD_WITHOUT_MPI: 1
HOROVOD_WITHOUT_TENSORFLOW: 1
HOROVOD_WITHOUT_MXNET: 1
run: |
pip install -r requirements_extra.txt
HOROVOD_BUILT=$(python -c "import horovod.torch; horovod.torch.nccl_built(); print('SUCCESS')" || true)
if [[ $HOROVOD_BUILT != "SUCCESS" ]]; then
pip uninstall -y horovod
pip install --no-cache-dir git+https://github.com/horovod/horovod.git@master
fi
horovodrun --check-build
shell: bash

# Skip Horovod tests for torch nightly.
# - name: Install Horovod if necessary
# if: matrix.test-markers == 'distributed' && matrix.pytorch-version != 'nightly'
# env:
# HOROVOD_WITH_PYTORCH: 1
# HOROVOD_WITHOUT_MPI: 1
# HOROVOD_WITHOUT_TENSORFLOW: 1
# HOROVOD_WITHOUT_MXNET: 1
# run: |
# pip install -r requirements_extra.txt
# HOROVOD_BUILT=$(python -c "import horovod.torch; horovod.torch.nccl_built(); print('SUCCESS')" || true)
# if [[ $HOROVOD_BUILT != "SUCCESS" ]]; then
# pip uninstall -y horovod
# pip install --no-cache-dir git+https://github.com/horovod/horovod.git@master
# fi
# horovodrun --check-build
# shell: bash

# Skip Horovod tests and replace with DDP.
# https://github.com/ludwig-ai/ludwig/issues/3468
- name: Horovod Tests
if: matrix.test-markers == 'distributed' && matrix.pytorch-version != 'nightly'
run: |
RUN_PRIVATE=$IS_NOT_FORK LUDWIG_TEST_SUITE_TIMEOUT_S=5400 pytest -v --timeout 300 --durations 100 -m "$MARKERS and horovod and not slow and not combinatorial and not llm" --junitxml pytest.xml tests/
# - name: Horovod Tests
# if: matrix.test-markers == 'distributed' && matrix.pytorch-version != 'nightly'
# run: |
# RUN_PRIVATE=$IS_NOT_FORK LUDWIG_TEST_SUITE_TIMEOUT_S=5400 pytest -v --timeout 300 --durations 100 -m "$MARKERS and horovod and not slow and not combinatorial and not llm" --junitxml pytest.xml tests/

- name: Upload Unit Test Results
if: ${{ always() && !env.ACT }}
51 changes: 36 additions & 15 deletions ludwig/api.py
@@ -444,6 +444,16 @@ def train(
`(training_set, validation_set, test_set)`.
`output_directory` filepath to where training results are stored.
"""
# Only reset the metadata if the model has not been trained before
if self.training_set_metadata:
logger.warning(
"This model has been trained before. Its architecture has been defined by the original training set "
"(for example, the number of possible categorical outputs). The current training data will be mapped "
"to this architecture. If you want to change the architecture of the model, please concatenate your "
"new training data with the original and train a new model from scratch."
)
training_set_metadata = self.training_set_metadata

if self._user_config.get(HYPEROPT):
print_boxed("WARNING")
logger.warning(HYPEROPT_WARNING)
@@ -1215,7 +1225,6 @@ def experiment(
data_format: Optional[str] = None,
experiment_name: str = "experiment",
model_name: str = "run",
model_load_path: Optional[str] = None,
model_resume_path: Optional[str] = None,
eval_split: str = TEST,
skip_save_training_description: bool = False,
@@ -1264,9 +1273,6 @@
the experiment.
:param model_name: (str, default: `'run'`) name of the model that is
being used.
:param model_load_path: (str, default: `None`) if this is specified the
loaded model will be used as initialization
(useful for transfer learning).
:param model_resume_path: (str, default: `None`) resumes training of
the model from the path specified. The config is restored.
In addition to config, training statistics and loss for
@@ -1347,7 +1353,6 @@
data_format=data_format,
experiment_name=experiment_name,
model_name=model_name,
model_load_path=model_load_path,
model_resume_path=model_resume_path,
skip_save_training_description=skip_save_training_description,
skip_save_training_statistics=skip_save_training_statistics,
@@ -2136,6 +2141,31 @@ def kfold_cross_validate(
return kfold_cv_stats, kfold_split_indices


def _get_compute_description(backend) -> Dict:
"""Returns the compute description for the backend."""
compute_description = {"num_nodes": backend.num_nodes}

if torch.cuda.is_available():
# Assumption: All nodes are of the same instance type.
# TODO: fix for Ray where workers may be of different skus
compute_description.update(
{
"gpus_per_node": torch.cuda.device_count(),
"arch_list": torch.cuda.get_arch_list(),
"gencode_flags": torch.cuda.get_gencode_flags(),
"devices": {},
}
)
for i in range(torch.cuda.device_count()):
compute_description["devices"][i] = {
"gpu_type": torch.cuda.get_device_name(i),
"device_capability": torch.cuda.get_device_capability(i),
"device_properties": str(torch.cuda.get_device_properties(i)),
}

return compute_description


@PublicAPI
def get_experiment_description(
config,
@@ -2179,15 +2209,6 @@ def get_experiment_description(

description["config"] = config
description["torch_version"] = torch.__version__

gpu_info = {}
if torch.cuda.is_available():
# Assumption: All nodes are of the same instance type.
# TODO: fix for Ray where workers may be of different skus
gpu_info = {"gpu_type": torch.cuda.get_device_name(0), "gpus_per_node": torch.cuda.device_count()}

compute_description = {"num_nodes": backend.num_nodes, **gpu_info}

description["compute"] = compute_description
description["compute"] = _get_compute_description(backend)

return description
86 changes: 53 additions & 33 deletions ludwig/backend/base.py
@@ -14,11 +14,13 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Type, Union
from typing import Any, Callable, Generator, TYPE_CHECKING

import numpy as np
import pandas as pd
@@ -48,14 +50,17 @@
from ludwig.utils.torch_utils import initialize_pytorch
from ludwig.utils.types import DataFrame, Series

if TYPE_CHECKING:
from ludwig.trainers.base import BaseTrainer


@DeveloperAPI
class Backend(ABC):
def __init__(
self,
dataset_manager: DatasetManager,
cache_dir: Optional[str] = None,
credentials: Optional[Dict[str, Dict[str, Any]]] = None,
cache_dir: str | None = None,
credentials: dict[str, dict[str, Any]] | None = None,
):
credentials = credentials or {}
self._dataset_manager = dataset_manager
@@ -84,7 +89,7 @@ def initialize_pytorch(self, *args, **kwargs):

@contextmanager
@abstractmethod
def create_trainer(self, **kwargs) -> "BaseTrainer": # noqa: F821
def create_trainer(self, config: BaseTrainerConfig, model: BaseModel, **kwargs) -> Generator:
raise NotImplementedError()

@abstractmethod
@@ -110,7 +115,7 @@ def supports_multiprocessing(self):
raise NotImplementedError()

@abstractmethod
def read_binary_files(self, column: Series, map_fn: Optional[Callable] = None) -> Series:
def read_binary_files(self, column: Series, map_fn: Callable | None = None) -> Series:
raise NotImplementedError()

@property
@@ -128,11 +133,11 @@ def get_available_resources(self) -> Resources:
raise NotImplementedError()

@abstractmethod
def max_concurrent_trials(self, hyperopt_config: HyperoptConfigDict) -> Union[int, None]:
def max_concurrent_trials(self, hyperopt_config: HyperoptConfigDict) -> int | None:
raise NotImplementedError()

@abstractmethod
def tune_batch_size(self, evaluator_cls: Type[BatchSizeEvaluator], dataset_len: int) -> int:
def tune_batch_size(self, evaluator_cls: type[BatchSizeEvaluator], dataset_len: int) -> int:
"""Returns best batch size (measured in samples / s) on the given evaluator.
The evaluator class will need to be instantiated on each worker in the backend cluster, then call
@@ -141,7 +146,9 @@ def tune_batch_size(evaluator_cls: Type[BatchSizeEvaluator], dataset_len:
raise NotImplementedError()

@abstractmethod
def batch_transform(self, df: DataFrame, batch_size: int, transform_fn: Callable, name: str = None) -> DataFrame:
def batch_transform(
self, df: DataFrame, batch_size: int, transform_fn: Callable, name: str | None = None
) -> DataFrame:
"""Applies `transform_fn` to every `batch_size` length batch of `df` and returns the result."""
raise NotImplementedError()

@@ -159,16 +166,16 @@ def supports_multiprocessing(self):
return True

@staticmethod
def read_binary_files(
column: pd.Series, map_fn: Optional[Callable] = None, file_size: Optional[int] = None
) -> pd.Series:
def read_binary_files(column: pd.Series, map_fn: Callable | None = None, file_size: int | None = None) -> pd.Series:
column = column.fillna(np.nan).replace([np.nan], [None]) # normalize NaNs to None

sample_fname = column.head(1).values[0]
with ThreadPoolExecutor() as executor: # number of threads is inferred
if isinstance(sample_fname, str):
if map_fn is read_audio_from_path: # bypass torchaudio issue that no longer takes in file-like objects
result = executor.map(lambda path: map_fn(path) if path is not None else path, column.values)
result = executor.map( # type: ignore[misc]
lambda path: map_fn(path) if path is not None else path, column.values
)
else:
result = executor.map(
lambda path: get_bytes_obj_from_path(path) if path is not None else path, column.values
@@ -183,7 +190,7 @@ def read_binary_files(
return pd.Series(result, index=column.index, name=column.name)

@staticmethod
def batch_transform(df: DataFrame, batch_size: int, transform_fn: Callable, name: str = None) -> DataFrame:
def batch_transform(df: DataFrame, batch_size: int, transform_fn: Callable, name: str | None = None) -> DataFrame:
name = name or "Batch Transform"
batches = to_batches(df, batch_size)
transform = transform_fn()
@@ -201,21 +208,11 @@ def initialize():
def initialize_pytorch(*args, **kwargs):
initialize_pytorch(*args, **kwargs)

def create_trainer(self, config: BaseTrainerConfig, model: BaseModel, **kwargs) -> "BaseTrainer": # noqa: F821
from ludwig.trainers.registry import get_llm_trainers_registry, get_trainers_registry

if model.type() == MODEL_LLM:
trainer_cls = get_from_registry(config.type, get_llm_trainers_registry())
else:
trainer_cls = get_from_registry(model.type(), get_trainers_registry())

return trainer_cls(config=config, model=model, **kwargs)

@staticmethod
def create_predictor(model: BaseModel, **kwargs):
from ludwig.models.predictor import get_predictor_cls

return get_predictor_cls(model.type())(model, **kwargs)
return get_predictor_cls(model.type())(model, **kwargs) # type: ignore[call-arg]

def sync_model(self, model):
pass
@@ -229,7 +226,7 @@ def is_coordinator() -> bool:
return True

@staticmethod
def tune_batch_size(evaluator_cls: Type[BatchSizeEvaluator], dataset_len: int) -> int:
def tune_batch_size(evaluator_cls: type[BatchSizeEvaluator], dataset_len: int) -> int:
evaluator = evaluator_cls()
return evaluator.select_best_batch_size(dataset_len)

@@ -251,14 +248,16 @@ def is_coordinator() -> bool:
class LocalBackend(LocalPreprocessingMixin, LocalTrainingMixin, Backend):
BACKEND_TYPE = "local"

_shared_instance: LocalBackend

@classmethod
def shared_instance(cls):
def shared_instance(cls) -> LocalBackend:
"""Returns a shared singleton LocalBackend instance."""
if not hasattr(cls, "_shared_instance"):
cls._shared_instance = cls()
return cls._shared_instance

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(dataset_manager=PandasDatasetManager(self), **kwargs)

@property
@@ -272,19 +271,35 @@ def num_training_workers(self) -> int:
def get_available_resources(self) -> Resources:
return Resources(cpus=psutil.cpu_count(), gpus=torch.cuda.device_count())

def max_concurrent_trials(self, hyperopt_config: HyperoptConfigDict) -> Union[int, None]:
def max_concurrent_trials(self, hyperopt_config: HyperoptConfigDict) -> int | None:
# Every trial will be run with Pandas and NO Ray Datasets. Allow Ray Tune to use all the
# trial resources it wants, because there is no Ray Datasets process to compete with it for CPUs.
return None

def create_trainer(
self,
config: BaseTrainerConfig,
model: BaseModel,
**kwargs,
) -> BaseTrainer: # type: ignore[override]
from ludwig.trainers.registry import get_llm_trainers_registry, get_trainers_registry

trainer_cls: type
if model.type() == MODEL_LLM:
trainer_cls = get_from_registry(config.type, get_llm_trainers_registry())
else:
trainer_cls = get_from_registry(model.type(), get_trainers_registry())

return trainer_cls(config=config, model=model, **kwargs)


@DeveloperAPI
class DataParallelBackend(LocalPreprocessingMixin, Backend, ABC):
BACKEND_TYPE = "deepspeed"

def __init__(self, **kwargs):
super().__init__(dataset_manager=PandasDatasetManager(self), **kwargs)
self._distributed: Optional[DistributedStrategy] = None
self._distributed: DistributedStrategy | None = None

@abstractmethod
def initialize(self):
@@ -295,15 +310,20 @@ def initialize_pytorch(self, *args, **kwargs):
*args, local_rank=self._distributed.local_rank(), local_size=self._distributed.local_size(), **kwargs
)

def create_trainer(self, **kwargs) -> "BaseTrainer": # noqa: F821
def create_trainer(
self,
config: BaseTrainerConfig,
model: BaseModel,
**kwargs,
) -> BaseTrainer: # type: ignore[override]
from ludwig.trainers.trainer import Trainer

return Trainer(distributed=self._distributed, **kwargs)

def create_predictor(self, model: BaseModel, **kwargs):
from ludwig.models.predictor import get_predictor_cls

return get_predictor_cls(model.type())(model, distributed=self._distributed, **kwargs)
return get_predictor_cls(model.type())(model, distributed=self._distributed, **kwargs) # type: ignore[call-arg]

def sync_model(self, model):
# Model weights are only saved on the coordinator, so broadcast
@@ -343,10 +363,10 @@ def get_available_resources(self) -> Resources:

return Resources(cpus=cpus, gpus=gpus)

def max_concurrent_trials(self, hyperopt_config: HyperoptConfigDict) -> Union[int, None]:
def max_concurrent_trials(self, hyperopt_config: HyperoptConfigDict) -> int | None:
# Return None since there is no Ray component
return None

def tune_batch_size(self, evaluator_cls: Type[BatchSizeEvaluator], dataset_len: int) -> int:
def tune_batch_size(self, evaluator_cls: type[BatchSizeEvaluator], dataset_len: int) -> int:
evaluator = evaluator_cls()
return evaluator.select_best_batch_size(dataset_len)
2 changes: 1 addition & 1 deletion ludwig/cli.py
@@ -56,7 +56,7 @@ def __init__(self):
init_config Initialize a user config from a dataset and targets
render_config Renders the fully populated config with all defaults set
check_install Runs a quick training run on synthetic data to verify installation status
upload Push trained model artifacts to a registry (e.g., HuggingFace Hub)
upload Push trained model artifacts to a registry (e.g., Predibase, HuggingFace Hub)
""",
)
parser.add_argument("command", help="Subcommand to run")