Merge branch 'master' of github.com:ludwig-ai/ludwig into check_prompt_token_lengths
justinxzhao committed Oct 18, 2023
2 parents de93a9d + fb6d866 commit 4657e8f
Showing 41 changed files with 1,020 additions and 176 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/pytest.yml
@@ -30,7 +30,7 @@ jobs:
torchscript-version: 1.10.2
ray-version: 2.2.0
- python-version: "3.9"
pytorch-version: 2.0.0
pytorch-version: 2.1.0
torchscript-version: 1.10.2
ray-version: 2.3.0
- python-version: "3.10"
@@ -208,6 +208,7 @@ jobs:
- "integration_tests_c"
- "integration_tests_d"
- "integration_tests_e"
- "integration_tests_f"

env:
AWS_ACCESS_KEY_ID: ${{ secrets.LUDWIG_TESTS_AWS_ACCESS_KEY_ID }}
13 changes: 9 additions & 4 deletions .vscode/settings.json
@@ -3,10 +3,15 @@
120
],
"editor.formatOnSave": true,
"python.formatting.provider": "black",
"python.linting.enabled": true,
"python.linting.flake8Enabled": true,
"python.linting.flake8Args": [
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
},
"black-formatter.args": [
"--line-length",
"120"
],
"flake8.args": [
"--config=setup.cfg"
],
"python.testing.unittestEnabled": false,
11 changes: 8 additions & 3 deletions README.md
Expand Up @@ -24,7 +24,7 @@ Ludwig is a **low-code** framework for building **custom** AI models like **LLMs
Key features:

- 🛠 **Build custom models with ease:** a declarative YAML configuration file is all you need to train a state-of-the-art LLM on your data. Support for multi-task and multi-modality learning. Comprehensive config validation detects invalid parameter combinations and prevents runtime failures.
- ⚡ **Optimized for scale and efficiency:** automatic batch size selection, distributed training ([DDP](https://pytorch.org/tutorials/beginner/ddp_series_theory.html), [DeepSpeed](https://github.com/microsoft/DeepSpeed)), parameter efficient fine-tuning ([PEFT](https://github.com/huggingface/peft)), 4-bit quantization (QLoRA), and larger-than-memory datasets.
- ⚡ **Optimized for scale and efficiency:** automatic batch size selection, distributed training ([DDP](https://pytorch.org/tutorials/beginner/ddp_series_theory.html), [DeepSpeed](https://github.com/microsoft/DeepSpeed)), parameter efficient fine-tuning ([PEFT](https://github.com/huggingface/peft)), 4-bit quantization (QLoRA), paged and 8-bit optimizers, and larger-than-memory datasets.
- 📐 **Expert level control:** retain full control of your models down to the activation functions. Support for hyperparameter optimization, explainability, and rich metric visualizations.
- 🧱 **Modular and extensible:** experiment with different model architectures, tasks, features, and modalities with just a few parameter changes in the config. Think building blocks for deep learning.
- 🚢 **Engineered for production:** prebuilt [Docker](https://hub.docker.com/u/ludwigai) containers, native support for running with [Ray](https://www.ray.io/) on [Kubernetes](https://github.com/ray-project/kuberay), export models to [Torchscript](https://pytorch.org/docs/stable/jit.html) and [Triton](https://developer.nvidia.com/triton-inference-server), upload to [HuggingFace](https://huggingface.co/models) with one command.
@@ -52,8 +52,13 @@ pip install ludwig[full]

Want to take a quick peek at some of the Ludwig 0.8 features? Check out this Colab Notebook 🚀 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1lB4ALmEyvcMycE3Mlnsd7I3bc0zxvk39)

For a full tutorial, check out the official [getting started guide](https://ludwig-ai.github.io/ludwig-docs/latest/getting_started/),
or take a look at end-to-end [Examples](https://ludwig-ai.github.io/ludwig-docs/latest/examples).
Looking to fine-tune Llama-2 or Mistral? Check out these notebooks:

1. Fine-Tune Llama-2-7b: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1r4oSEwRJpYKBPM0M0RSh0pBEYK_gBKbe)
1. Fine-Tune Llama-2-13b: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1zmSEzqZ7v4twBrXagj1TE_C--RNyVAyu)
1. Fine-Tune Mistral-7b: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1i_8A1n__b7ljRWHzIsAdhO7u7r49vUm4)

For a full tutorial, check out the official [getting started guide](https://ludwig-ai.github.io/ludwig-docs/latest/getting_started/), or take a look at end-to-end [Examples](https://ludwig-ai.github.io/ludwig-docs/latest/examples).

## Large Language Model Fine-Tuning

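The README excerpt above leans on Ludwig's declarative configuration ("a declarative YAML configuration file is all you need"). As a hedged, minimal sketch of that workflow — not part of this diff; the dataset and column names are invented for illustration, and ECD defaults fill in everything not specified:

```python
import logging

import pandas as pd
import yaml

from ludwig.api import LudwigModel

# Tiny illustrative dataset; in practice this would be a real DataFrame or a file path.
df = pd.DataFrame(
    {
        "review": ["great product", "terrible support", "works as expected", "would not buy again"] * 5,
        "sentiment": ["positive", "negative", "positive", "negative"] * 5,
    }
)

# Declarative config: feature names and types are all that is required; defaults cover the rest.
config = yaml.safe_load(
    """
    input_features:
      - name: review
        type: text
    output_features:
      - name: sentiment
        type: category
    """
)

model = LudwigModel(config=config, logging_level=logging.INFO)
train_stats, preprocessed_data, output_directory = model.train(dataset=df)
```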
57 changes: 37 additions & 20 deletions ludwig/backend/base.py
@@ -20,7 +20,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, TYPE_CHECKING
from typing import Any, Callable, Generator, TYPE_CHECKING

import numpy as np
import pandas as pd
@@ -89,7 +89,7 @@ def initialize_pytorch(self, *args, **kwargs):

@contextmanager
@abstractmethod
def create_trainer(self, **kwargs) -> BaseTrainer:
def create_trainer(self, config: BaseTrainerConfig, model: BaseModel, **kwargs) -> Generator:
raise NotImplementedError()

@abstractmethod
@@ -146,7 +146,9 @@ def tune_batch_size(self, evaluator_cls: type[BatchSizeEvaluator], dataset_len:
raise NotImplementedError()

@abstractmethod
def batch_transform(self, df: DataFrame, batch_size: int, transform_fn: Callable, name: str = None) -> DataFrame:
def batch_transform(
self, df: DataFrame, batch_size: int, transform_fn: Callable, name: str | None = None
) -> DataFrame:
"""Applies `transform_fn` to every `batch_size` length batch of `df` and returns the result."""
raise NotImplementedError()

@@ -171,7 +173,9 @@ def read_binary_files(column: pd.Series, map_fn: Callable | None = None, file_si
with ThreadPoolExecutor() as executor: # number of threads is inferred
if isinstance(sample_fname, str):
if map_fn is read_audio_from_path: # bypass torchaudio issue that no longer takes in file-like objects
result = executor.map(lambda path: map_fn(path) if path is not None else path, column.values)
result = executor.map( # type: ignore[misc]
lambda path: map_fn(path) if path is not None else path, column.values
)
else:
result = executor.map(
lambda path: get_bytes_obj_from_path(path) if path is not None else path, column.values
@@ -186,7 +190,7 @@ def read_binary_files(column: pd.Series, map_fn: Callable | None = None, file_si
return pd.Series(result, index=column.index, name=column.name)

@staticmethod
def batch_transform(df: DataFrame, batch_size: int, transform_fn: Callable, name: str = None) -> DataFrame:
def batch_transform(df: DataFrame, batch_size: int, transform_fn: Callable, name: str | None = None) -> DataFrame:
name = name or "Batch Transform"
batches = to_batches(df, batch_size)
transform = transform_fn()
@@ -204,21 +208,11 @@ def initialize():
def initialize_pytorch(*args, **kwargs):
initialize_pytorch(*args, **kwargs)

def create_trainer(self, config: BaseTrainerConfig, model: BaseModel, **kwargs) -> BaseTrainer:
from ludwig.trainers.registry import get_llm_trainers_registry, get_trainers_registry

if model.type() == MODEL_LLM:
trainer_cls = get_from_registry(config.type, get_llm_trainers_registry())
else:
trainer_cls = get_from_registry(model.type(), get_trainers_registry())

return trainer_cls(config=config, model=model, **kwargs)

@staticmethod
def create_predictor(model: BaseModel, **kwargs):
from ludwig.models.predictor import get_predictor_cls

return get_predictor_cls(model.type())(model, **kwargs)
return get_predictor_cls(model.type())(model, **kwargs) # type: ignore[call-arg]

def sync_model(self, model):
pass
@@ -254,14 +248,16 @@ def is_coordinator() -> bool:
class LocalBackend(LocalPreprocessingMixin, LocalTrainingMixin, Backend):
BACKEND_TYPE = "local"

_shared_instance: LocalBackend

@classmethod
def shared_instance(cls):
def shared_instance(cls) -> LocalBackend:
"""Returns a shared singleton LocalBackend instance."""
if not hasattr(cls, "_shared_instance"):
cls._shared_instance = cls()
return cls._shared_instance

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(dataset_manager=PandasDatasetManager(self), **kwargs)

@property
@@ -280,6 +276,22 @@ def max_concurrent_trials(self, hyperopt_config: HyperoptConfigDict) -> int | No
# trial resources it wants, because there is no Ray Datasets process to compete with it for CPUs.
return None

def create_trainer(
self,
config: BaseTrainerConfig,
model: BaseModel,
**kwargs,
) -> BaseTrainer: # type: ignore[override]
from ludwig.trainers.registry import get_llm_trainers_registry, get_trainers_registry

trainer_cls: type
if model.type() == MODEL_LLM:
trainer_cls = get_from_registry(config.type, get_llm_trainers_registry())
else:
trainer_cls = get_from_registry(model.type(), get_trainers_registry())

return trainer_cls(config=config, model=model, **kwargs)


@DeveloperAPI
class DataParallelBackend(LocalPreprocessingMixin, Backend, ABC):
@@ -298,15 +310,20 @@ def initialize_pytorch(self, *args, **kwargs):
*args, local_rank=self._distributed.local_rank(), local_size=self._distributed.local_size(), **kwargs
)

def create_trainer(self, **kwargs) -> BaseTrainer:
def create_trainer(
self,
config: BaseTrainerConfig,
model: BaseModel,
**kwargs,
) -> BaseTrainer: # type: ignore[override]
from ludwig.trainers.trainer import Trainer

return Trainer(distributed=self._distributed, **kwargs)

def create_predictor(self, model: BaseModel, **kwargs):
from ludwig.models.predictor import get_predictor_cls

return get_predictor_cls(model.type())(model, distributed=self._distributed, **kwargs)
return get_predictor_cls(model.type())(model, distributed=self._distributed, **kwargs) # type: ignore[call-arg]

def sync_model(self, model):
# Model weights are only saved on the coordinator, so broadcast
10 changes: 9 additions & 1 deletion ludwig/config_validation/checks.py
@@ -581,7 +581,7 @@ def check_llm_finetuning_adaption_prompt_parameters(config: "ModelConfig"):
if config.adapter.type != "adaption_prompt":
return

from peft.tuners.adaption_prompt import TRANSFORMERS_MODEL_CONFIG
from peft.tuners.adaption_prompt.config import TRANSFORMERS_MODEL_CONFIG

# Adaption Config is currently only supported for Llama model types
model_config = _get_llm_model_config(config.base_model)
@@ -718,3 +718,11 @@ def check_prompt_requirements(config: "ModelConfig") -> None: # noqa: F821
"A template must contain at least one reference to a column or the sample keyword {__sample__} for "
"a JSON-serialized representation of non-output feature columns."
)


@register_config_check
def check_sample_ratio_and_size_compatible(config: "ModelConfig") -> None:
sample_ratio = config.preprocessing.sample_ratio
sample_size = config.preprocessing.sample_size
if sample_size is not None and sample_ratio < 1.0:
raise ConfigValidationError("sample_size cannot be used when sample_ratio < 1.0")
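The new `check_sample_ratio_and_size_compatible` check above makes `sample_size` and a `sample_ratio` below 1.0 mutually exclusive. A hedged sketch of a config that should now be rejected — assuming, as with the other registered checks, that validation runs when the model is constructed, and assuming `ConfigValidationError` lives in `ludwig.error`:

```python
import yaml

from ludwig.api import LudwigModel
from ludwig.error import ConfigValidationError  # assumed import path for the validation error

# Illustrative config only: it combines an absolute sample_size with sample_ratio < 1.0,
# which the new check rejects.
config = yaml.safe_load(
    """
    input_features:
      - name: review
        type: text
    output_features:
      - name: sentiment
        type: category
    preprocessing:
      sample_ratio: 0.5  # keep half of the rows...
      sample_size: 100   # ...while also requesting an absolute number of rows
    """
)

try:
    LudwigModel(config=config)
except ConfigValidationError as e:
    # Expected message: "sample_size cannot be used when sample_ratio < 1.0"
    print(e)
```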
2 changes: 2 additions & 0 deletions ludwig/constants.py
@@ -197,6 +197,8 @@
EPOCHS = "epochs"
BATCH_SIZE = "batch_size"
EVAL_BATCH_SIZE = "eval_batch_size"
EFFECTIVE_BATCH_SIZE = "effective_batch_size"
MAX_BATCH_SIZE = "max_batch_size"
DEFAULT_BATCH_SIZE = "auto"
FALLBACK_BATCH_SIZE = 128
# The smallest batch size that is supported on Ludwig.
3 changes: 3 additions & 0 deletions ludwig/contribs/wandb.py
@@ -67,3 +67,6 @@ def on_visualize_figure(self, fig):
logger.info("wandb.on_visualize_figure() called...")
if wandb.run:
wandb.log({"figure": fig})

def on_train_end(self, output_directory):
wandb.finish()
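The wandb contrib now closes its run via the `on_train_end` hook shown above. As a hedged sketch of the same hook on a user-defined callback — the `Callback` base class location and the `callbacks` argument to `LudwigModel` are assumptions based on how contribs like this one plug in, not something this diff shows:

```python
import logging

from ludwig.callbacks import Callback  # assumed location of the callback base class

logger = logging.getLogger(__name__)


class TrainEndLogger(Callback):
    """Illustrative callback mirroring the hook the wandb contrib now implements."""

    def on_train_end(self, output_directory):
        logger.info("Training finished; artifacts written to %s", output_directory)


# Hypothetical usage: pass callbacks when constructing the model, e.g.
#   model = LudwigModel(config=config, callbacks=[TrainEndLogger()])
```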
113 changes: 113 additions & 0 deletions ludwig/data/batcher/test_batcher.py
@@ -0,0 +1,113 @@
import logging

import pandas as pd
import yaml

from ludwig.api import LudwigModel
from ludwig.data.dataset.pandas import PandasDataset


def test_pandas_size():
df = pd.DataFrame(
{"name": ["joe", "janice", "sara"], "mask": ["green", "black", "pink"], "weapon": ["stick", "gun", "gun"]}
)
config = yaml.safe_load(
"""
model_type: llm
base_model: HuggingFaceH4/tiny-random-LlamaForCausalLM
input_features:
- name: name
type: text
preprocessing:
max_sequence_length: 256
column: name
output_features:
- name: weapon
type: text
preprocessing:
max_sequence_length: 256
column: weapon
preprocessing:
split:
type: random
probabilities:
- 1
- 0
- 0
"""
)
model = LudwigModel(config=config, logging_level=logging.INFO)
data = model.preprocess(df, skip_save_processed_input=False)
training_set = data[0]
assert training_set.size == len(df)

# Check if string loading works as well
# data[0].data_hdf5_fp is the string filepath to the cached data from preprocessing
data_from_str = PandasDataset(data[0].data_hdf5_fp, data[0].features, None)
assert data_from_str.size == len(df)


def test_pandas_batcher_use_all_samples():
df = pd.DataFrame(
{"name": ["joe", "janice", "sara"], "mask": ["green", "black", "pink"], "weapon": ["stick", "gun", "gun"]}
)
config = yaml.safe_load(
"""
model_type: llm
base_model: HuggingFaceH4/tiny-random-LlamaForCausalLM
input_features:
- name: name
type: text
preprocessing:
max_sequence_length: 256
column: name
output_features:
- name: weapon
type: text
preprocessing:
max_sequence_length: 256
column: weapon
preprocessing:
split:
type: random
probabilities:
- 1
- 0
- 0
"""
)
model = LudwigModel(config=config, logging_level=logging.INFO)
data = model.preprocess(df, skip_save_processed_input=False)
training_set = data[0]
features = training_set.dataset.keys()

batches = []
with training_set.initialize_batcher(batch_size=1) as batcher:
while not batcher.last_batch():
batch = batcher.next_batch()
batches.append(batch)
assert (len(batches)) == training_set.size

# Check to see if all items are used exactly once
for feature in features:
for i in range(len(training_set.dataset[feature])):
# Each of the arrays in the line below should contain the vector representation of a feature of sample i
assert (batches[i][feature].squeeze() == training_set.dataset[feature][i].squeeze()).all()

# Check if string loading works as well
batches = []
# data[0].data_hdf5_fp is the string filepath to the cached data from preprocessing
data_from_str = PandasDataset(data[0].data_hdf5_fp, data[0].features, None)
features = data_from_str.dataset.keys()

with data_from_str.initialize_batcher(batch_size=1) as batcher:
while not batcher.last_batch():
batch = batcher.next_batch()
batches.append(batch)
assert (len(batches)) == data_from_str.size

# Check to see if all items are used exactly once
for feature in features:
for i in range(len(data_from_str.dataset[feature])):
# Each of the arrays in the line below should contain the vector representation of a feature of sample i
assert (batches[i][feature].squeeze() == data_from_str.dataset[feature][i].squeeze()).all()
