Commit

Merge branch 'master' of github.com:ludwig-ai/ludwig into release-0.8
justinxzhao committed Oct 13, 2023
2 parents b1f5ead + 062f958 commit a1dea34
Showing 26 changed files with 824 additions and 118 deletions.
13 changes: 9 additions & 4 deletions .vscode/settings.json
@@ -3,10 +3,15 @@
120
],
"editor.formatOnSave": true,
"python.formatting.provider": "black",
"python.linting.enabled": true,
"python.linting.flake8Enabled": true,
"python.linting.flake8Args": [
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
},
"black-formatter.args": [
"--line-length",
"120"
],
"flake8.args": [
"--config=setup.cfg"
],
"python.testing.unittestEnabled": false,
11 changes: 8 additions & 3 deletions README.md
@@ -24,7 +24,7 @@ Ludwig is a **low-code** framework for building **custom** AI models like **LLMs
Key features:

- 🛠 **Build custom models with ease:** a declarative YAML configuration file is all you need to train a state-of-the-art LLM on your data. Support for multi-task and multi-modality learning. Comprehensive config validation detects invalid parameter combinations and prevents runtime failures.
- ⚡ **Optimized for scale and efficiency:** automatic batch size selection, distributed training ([DDP](https://pytorch.org/tutorials/beginner/ddp_series_theory.html), [DeepSpeed](https://github.com/microsoft/DeepSpeed)), parameter efficient fine-tuning ([PEFT](https://github.com/huggingface/peft)), 4-bit quantization (QLoRA), and larger-than-memory datasets.
- ⚡ **Optimized for scale and efficiency:** automatic batch size selection, distributed training ([DDP](https://pytorch.org/tutorials/beginner/ddp_series_theory.html), [DeepSpeed](https://github.com/microsoft/DeepSpeed)), parameter efficient fine-tuning ([PEFT](https://github.com/huggingface/peft)), 4-bit quantization (QLoRA), paged and 8-bit optimizers, and larger-than-memory datasets.
- 📐 **Expert level control:** retain full control of your models down to the activation functions. Support for hyperparameter optimization, explainability, and rich metric visualizations.
- 🧱 **Modular and extensible:** experiment with different model architectures, tasks, features, and modalities with just a few parameter changes in the config. Think building blocks for deep learning.
- 🚢 **Engineered for production:** prebuilt [Docker](https://hub.docker.com/u/ludwigai) containers, native support for running with [Ray](https://www.ray.io/) on [Kubernetes](https://github.com/ray-project/kuberay), export models to [Torchscript](https://pytorch.org/docs/stable/jit.html) and [Triton](https://developer.nvidia.com/triton-inference-server), upload to [HuggingFace](https://huggingface.co/models) with one command.
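
To make the "declarative configuration" bullet concrete, here is a minimal, hedged sketch of the equivalent Python usage; the dataset path and column names are hypothetical and not part of this commit:

```python
# Minimal Ludwig usage sketch (illustrative only; "reviews.csv", "review" and
# "sentiment" are assumed names, not taken from this repository).
from ludwig.api import LudwigModel

config = {
    "input_features": [{"name": "review", "type": "text"}],
    "output_features": [{"name": "sentiment", "type": "category"}],
    "trainer": {"epochs": 3},
}

model = LudwigModel(config)
# train() returns (training_statistics, preprocessed_data, output_directory)
train_stats, _, output_dir = model.train(dataset="reviews.csv")
```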
@@ -52,8 +52,13 @@ pip install ludwig[full]

Want to take a quick peek at some of the Ludwig 0.8 features? Check out this Colab Notebook 🚀 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1lB4ALmEyvcMycE3Mlnsd7I3bc0zxvk39)

For a full tutorial, check out the official [getting started guide](https://ludwig-ai.github.io/ludwig-docs/latest/getting_started/),
or take a look at end-to-end [Examples](https://ludwig-ai.github.io/ludwig-docs/latest/examples).
Looking to fine-tune Llama-2 or Mistral? Check out these notebooks:

1. Fine-Tune Llama-2-7b: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1r4oSEwRJpYKBPM0M0RSh0pBEYK_gBKbe)
1. Fine-Tune Llama-2-13b: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1zmSEzqZ7v4twBrXagj1TE_C--RNyVAyu)
1. Fine-Tune Mistral-7b: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1i_8A1n__b7ljRWHzIsAdhO7u7r49vUm4)

For a full tutorial, check out the official [getting started guide](https://ludwig-ai.github.io/ludwig-docs/latest/getting_started/), or take a look at end-to-end [Examples](https://ludwig-ai.github.io/ludwig-docs/latest/examples).

## Large Language Model Fine-Tuning

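As a hedged illustration of the fine-tuning workflow the notebooks above walk through (the base model, column names, and hyperparameters below are assumptions, not taken from this commit), a Ludwig 0.8-style LLM config expressed as a Python dict:

```python
# Illustrative LLM fine-tuning config combining the features called out above:
# a LoRA adapter (PEFT) plus 4-bit quantization (QLoRA). All values are assumed.
llm_config = {
    "model_type": "llm",
    "base_model": "meta-llama/Llama-2-7b-hf",
    "input_features": [{"name": "instruction", "type": "text"}],
    "output_features": [{"name": "response", "type": "text"}],
    "adapter": {"type": "lora"},
    "quantization": {"bits": 4},
    "trainer": {"type": "finetune", "epochs": 1, "batch_size": 1},
}

# from ludwig.api import LudwigModel
# LudwigModel(llm_config).train(dataset="instruction_tuning_data.csv")
```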
46 changes: 36 additions & 10 deletions ludwig/api.py
@@ -444,6 +444,16 @@ def train(
`(training_set, validation_set, test_set)`.
`output_directory` filepath to where training results are stored.
"""
# Only reset the metadata if the model has not been trained before
if self.training_set_metadata:
logger.warning(
"This model has been trained before. Its architecture has been defined by the original training set "
"(for example, the number of possible categorical outputs). The current training data will be mapped "
"to this architecture. If you want to change the architecture of the model, please concatenate your "
"new training data with the original and train a new model from scratch."
)
training_set_metadata = self.training_set_metadata

if self._user_config.get(HYPEROPT):
print_boxed("WARNING")
logger.warning(HYPEROPT_WARNING)
@@ -2131,6 +2141,31 @@ def kfold_cross_validate(
return kfold_cv_stats, kfold_split_indices


def _get_compute_description(backend) -> Dict:
"""Returns the compute description for the backend."""
compute_description = {"num_nodes": backend.num_nodes}

if torch.cuda.is_available():
# Assumption: All nodes are of the same instance type.
# TODO: fix for Ray where workers may be of different skus
compute_description.update(
{
"gpus_per_node": torch.cuda.device_count(),
"arch_list": torch.cuda.get_arch_list(),
"gencode_flags": torch.cuda.get_gencode_flags(),
"devices": {},
}
)
for i in range(torch.cuda.device_count()):
compute_description["devices"][i] = {
"gpu_type": torch.cuda.get_device_name(i),
"device_capability": torch.cuda.get_device_capability(i),
"device_properties": str(torch.cuda.get_device_properties(i)),
}

return compute_description
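
A hedged usage sketch of the new helper (the printed values depend entirely on the machine it runs on; `LocalBackend` is only one possible backend):

```python
# Illustrative only: build a compute description for the local backend.
from ludwig.api import _get_compute_description
from ludwig.backend.base import LocalBackend

description = _get_compute_description(LocalBackend.shared_instance())
print(description["num_nodes"])           # typically 1 for the local backend
print(description.get("gpus_per_node"))   # present only when CUDA is available
```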


@PublicAPI
def get_experiment_description(
config,
@@ -2174,15 +2209,6 @@ def get_experiment_description(

description["config"] = config
description["torch_version"] = torch.__version__

gpu_info = {}
if torch.cuda.is_available():
# Assumption: All nodes are of the same instance type.
# TODO: fix for Ray where workers may be of different skus
gpu_info = {"gpu_type": torch.cuda.get_device_name(0), "gpus_per_node": torch.cuda.device_count()}

compute_description = {"num_nodes": backend.num_nodes, **gpu_info}

description["compute"] = compute_description
description["compute"] = _get_compute_description(backend)

return description
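
To illustrate the retraining behavior that the new warning in `train()` describes, a hedged sketch (dataset paths and column names are hypothetical):

```python
# Illustrative only: the second train() call reuses the metadata computed from the
# first training set, so the model architecture (e.g., category vocabulary) stays fixed.
from ludwig.api import LudwigModel

config = {
    "input_features": [{"name": "text", "type": "text"}],
    "output_features": [{"name": "label", "type": "category"}],
}

model = LudwigModel(config)
model.train(dataset="batch_1.csv")  # metadata is computed and stored on the model here
model.train(dataset="batch_2.csv")  # warning is logged; batch_2 is mapped onto the existing metadata
```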
57 changes: 37 additions & 20 deletions ludwig/backend/base.py
@@ -20,7 +20,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, TYPE_CHECKING
from typing import Any, Callable, Generator, TYPE_CHECKING

import numpy as np
import pandas as pd
@@ -89,7 +89,7 @@ def initialize_pytorch(self, *args, **kwargs):

@contextmanager
@abstractmethod
def create_trainer(self, **kwargs) -> BaseTrainer:
def create_trainer(self, config: BaseTrainerConfig, model: BaseModel, **kwargs) -> Generator:
raise NotImplementedError()

@abstractmethod
@@ -146,7 +146,9 @@ def tune_batch_size(self, evaluator_cls: type[BatchSizeEvaluator], dataset_len:
raise NotImplementedError()

@abstractmethod
def batch_transform(self, df: DataFrame, batch_size: int, transform_fn: Callable, name: str = None) -> DataFrame:
def batch_transform(
self, df: DataFrame, batch_size: int, transform_fn: Callable, name: str | None = None
) -> DataFrame:
"""Applies `transform_fn` to every `batch_size` length batch of `df` and returns the result."""
raise NotImplementedError()
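
A self-contained, hedged sketch of the `batch_transform` contract described in the docstring (this mirrors the documented behavior, not Ludwig's actual implementation; the concrete version appears further down in this diff):

```python
# Illustrative batch_transform: split into batch_size-row slices, apply the transform
# produced by the zero-argument factory transform_fn, and concatenate the results.
from __future__ import annotations

import pandas as pd


def batch_transform(df: pd.DataFrame, batch_size: int, transform_fn, name: str | None = None) -> pd.DataFrame:
    transform = transform_fn()  # transform_fn is a factory, matching the concrete version below
    batches = [df.iloc[i : i + batch_size] for i in range(0, len(df), batch_size)]
    return pd.concat(transform(batch) for batch in batches)


# Usage: double a column two rows at a time.
doubled = batch_transform(
    pd.DataFrame({"x": [1, 2, 3, 4, 5]}),
    batch_size=2,
    transform_fn=lambda: (lambda batch: batch.assign(x=batch.x * 2)),
)
```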

@@ -171,7 +173,9 @@ def read_binary_files(column: pd.Series, map_fn: Callable | None = None, file_si
with ThreadPoolExecutor() as executor: # number of threads is inferred
if isinstance(sample_fname, str):
if map_fn is read_audio_from_path: # bypass torchaudio issue that no longer takes in file-like objects
result = executor.map(lambda path: map_fn(path) if path is not None else path, column.values)
result = executor.map( # type: ignore[misc]
lambda path: map_fn(path) if path is not None else path, column.values
)
else:
result = executor.map(
lambda path: get_bytes_obj_from_path(path) if path is not None else path, column.values
@@ -186,7 +190,7 @@ def read_binary_files(column: pd.Series, map_fn: Callable | None = None, file_si
return pd.Series(result, index=column.index, name=column.name)

@staticmethod
def batch_transform(df: DataFrame, batch_size: int, transform_fn: Callable, name: str = None) -> DataFrame:
def batch_transform(df: DataFrame, batch_size: int, transform_fn: Callable, name: str | None = None) -> DataFrame:
name = name or "Batch Transform"
batches = to_batches(df, batch_size)
transform = transform_fn()
@@ -204,21 +208,11 @@ def initialize():
def initialize_pytorch(*args, **kwargs):
initialize_pytorch(*args, **kwargs)

def create_trainer(self, config: BaseTrainerConfig, model: BaseModel, **kwargs) -> BaseTrainer:
from ludwig.trainers.registry import get_llm_trainers_registry, get_trainers_registry

if model.type() == MODEL_LLM:
trainer_cls = get_from_registry(config.type, get_llm_trainers_registry())
else:
trainer_cls = get_from_registry(model.type(), get_trainers_registry())

return trainer_cls(config=config, model=model, **kwargs)

@staticmethod
def create_predictor(model: BaseModel, **kwargs):
from ludwig.models.predictor import get_predictor_cls

return get_predictor_cls(model.type())(model, **kwargs)
return get_predictor_cls(model.type())(model, **kwargs) # type: ignore[call-arg]

def sync_model(self, model):
pass
@@ -254,14 +248,16 @@ def is_coordinator() -> bool:
class LocalBackend(LocalPreprocessingMixin, LocalTrainingMixin, Backend):
BACKEND_TYPE = "local"

_shared_instance: LocalBackend

@classmethod
def shared_instance(cls):
def shared_instance(cls) -> LocalBackend:
"""Returns a shared singleton LocalBackend instance."""
if not hasattr(cls, "_shared_instance"):
cls._shared_instance = cls()
return cls._shared_instance

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(dataset_manager=PandasDatasetManager(self), **kwargs)

@property
@@ -280,6 +276,22 @@ def max_concurrent_trials(self, hyperopt_config: HyperoptConfigDict) -> int | No
# trial resources it wants, because there is no Ray Datasets process to compete with it for CPUs.
return None

def create_trainer(
self,
config: BaseTrainerConfig,
model: BaseModel,
**kwargs,
) -> BaseTrainer: # type: ignore[override]
from ludwig.trainers.registry import get_llm_trainers_registry, get_trainers_registry

trainer_cls: type
if model.type() == MODEL_LLM:
trainer_cls = get_from_registry(config.type, get_llm_trainers_registry())
else:
trainer_cls = get_from_registry(model.type(), get_trainers_registry())

return trainer_cls(config=config, model=model, **kwargs)


@DeveloperAPI
class DataParallelBackend(LocalPreprocessingMixin, Backend, ABC):
@@ -298,15 +310,20 @@ def initialize_pytorch(self, *args, **kwargs):
*args, local_rank=self._distributed.local_rank(), local_size=self._distributed.local_size(), **kwargs
)

def create_trainer(self, **kwargs) -> BaseTrainer:
def create_trainer(
self,
config: BaseTrainerConfig,
model: BaseModel,
**kwargs,
) -> BaseTrainer: # type: ignore[override]
from ludwig.trainers.trainer import Trainer

return Trainer(distributed=self._distributed, **kwargs)

def create_predictor(self, model: BaseModel, **kwargs):
from ludwig.models.predictor import get_predictor_cls

return get_predictor_cls(model.type())(model, distributed=self._distributed, **kwargs)
return get_predictor_cls(model.type())(model, distributed=self._distributed, **kwargs) # type: ignore[call-arg]

def sync_model(self, model):
# Model weights are only saved on the coordinator, so broadcast
2 changes: 1 addition & 1 deletion ludwig/cli.py
@@ -56,7 +56,7 @@ def __init__(self):
init_config Initialize a user config from a dataset and targets
render_config Renders the fully populated config with all defaults set
check_install Runs a quick training run on synthetic data to verify installation status
upload Push trained model artifacts to a registry (e.g., HuggingFace Hub)
upload Push trained model artifacts to a registry (e.g., Predibase, HuggingFace Hub)
""",
)
parser.add_argument("command", help="Subcommand to run")
8 changes: 8 additions & 0 deletions ludwig/config_validation/checks.py
@@ -718,3 +718,11 @@ def check_prompt_requirements(config: "ModelConfig") -> None: # noqa: F821
"A template must contain at least one reference to a column or the sample keyword {__sample__} for "
"a JSON-serialized representation of non-output feature columns."
)


@register_config_check
def check_sample_ratio_and_size_compatible(config: "ModelConfig") -> None:
sample_ratio = config.preprocessing.sample_ratio
sample_size = config.preprocessing.sample_size
if sample_size is not None and sample_ratio < 1.0:
raise ConfigValidationError("sample_size cannot be used when sample_ratio < 1.0")
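
A hedged illustration of a config that the new check rejects (the feature names are made up; the point is that `sample_size` and a `sample_ratio` below 1.0 are now mutually exclusive):

```python
# Illustrative only: this preprocessing section asks to keep 50% of the rows and also
# exactly 1000 rows, which is contradictory, so validation now fails fast.
config = {
    "input_features": [{"name": "text", "type": "text"}],
    "output_features": [{"name": "label", "type": "category"}],
    "preprocessing": {"sample_ratio": 0.5, "sample_size": 1000},
}
# Validating this config (for example, by constructing a LudwigModel from it) raises:
#   ConfigValidationError: sample_size cannot be used when sample_ratio < 1.0
```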
23 changes: 13 additions & 10 deletions ludwig/data/dataset/base.py
@@ -14,10 +14,13 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import contextlib
from abc import ABC, abstractmethod
from typing import Iterable, Optional
from typing import Iterable

from ludwig.data.batcher.base import Batcher
from ludwig.distributed import DistributedStrategy
from ludwig.features.base_feature import BaseFeature
from ludwig.utils.defaults import default_random_seed
@@ -26,7 +29,7 @@

class Dataset(ABC):
@abstractmethod
def __len__(self):
def __len__(self) -> int:
raise NotImplementedError()

@contextlib.contextmanager
@@ -38,36 +41,36 @@ def initialize_batcher(
random_seed: int = default_random_seed,
ignore_last: bool = False,
distributed: DistributedStrategy = None,
):
) -> Batcher:
raise NotImplementedError()

@abstractmethod
def to_df(self, features: Optional[Iterable[BaseFeature]] = None) -> DataFrame:
def to_df(self, features: Iterable[BaseFeature] | None = None) -> DataFrame:
raise NotImplementedError()

@abstractmethod
def to_scalar_df(self, features: Optional[Iterable[BaseFeature]] = None) -> DataFrame:
def to_scalar_df(self, features: Iterable[BaseFeature] | None = None) -> DataFrame:
raise NotImplementedError()

@property
def in_memory_size_bytes(self):
def in_memory_size_bytes(self) -> int:
raise NotImplementedError()


class DatasetManager(ABC):
@abstractmethod
def create(self, dataset, config, training_set_metadata):
def create(self, dataset, config, training_set_metadata) -> Dataset:
raise NotImplementedError()

@abstractmethod
def save(self, cache_path, dataset, config, training_set_metadata, tag):
def save(self, cache_path, dataset, config, training_set_metadata, tag) -> Dataset:
raise NotImplementedError()

@abstractmethod
def can_cache(self, skip_save_processed_input):
def can_cache(self, skip_save_processed_input) -> bool:
raise NotImplementedError()

@property
@abstractmethod
def data_format(self):
def data_format(self) -> str:
raise NotImplementedError()
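
The `from __future__ import annotations` import added at the top of this file is what makes the `Iterable[BaseFeature] | None` unions above legal on older interpreters; a standalone, hedged illustration:

```python
# Standalone illustration (not Ludwig code): with postponed evaluation of annotations,
# PEP 604 unions such as `Iterable[str] | None` are never evaluated at runtime, so the
# syntax works even on Python 3.8/3.9, where `SomeType | None` would otherwise raise
# "TypeError: unsupported operand type(s) for |".
from __future__ import annotations

from typing import Iterable


def join_features(features: Iterable[str] | None = None) -> str:
    return ",".join(features) if features is not None else ""


print(join_features(["age", "income"]))  # -> "age,income"
```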