feat(model): Add optional memoization to datasets during model training.
By setting the optional parameter `memoized_dataset_cache_size` to a value greater than 0, the corresponding number of datasets is kept in memory to avoid repeated conversion from `vaex` to `pandas` in settings where the same datasets are fitted repeatedly, e.g. hyperparameter tuning. Use with caution and always call `clear_load_dataset_cache` once finished to clear the cache.
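A minimal usage sketch of the new behaviour (a hedged illustration, not part of the commit): the target column name and hyperparameter grid are hypothetical, `data_schema`, `df_train`, and `df_validation` are assumed to already exist as a `DataSchema` and two `vaex` DataFrames, and the constructor and `fit_transform` call follow the `LGBMModel` docstring shown further down in this diff.

```python
from mleko.model.lgbm_model import LGBMModel

# Keep up to two converted datasets in memory across repeated fits.
model = LGBMModel(
    target="label",  # hypothetical target column name
    memoized_dataset_cache_size=2,
)

# Repeated fitting (e.g. hyperparameter tuning) reuses the memoized
# vaex-to-pandas conversion instead of redoing it on every call.
for hyperparameters in ({"learning_rate": 0.05}, {"learning_rate": 0.1}):
    booster, df_train_pred, df_validation_pred = model.fit_transform(
        data_schema, df_train, df_validation, hyperparameters
    )

# Always clear the in-memory cache once tuning is finished.
model.clear_load_dataset_cache()
```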
Erik Båvenstrand committed Apr 12, 2024
1 parent 1b26ddc commit 6a955dc
Showing 9 changed files with 264 additions and 30 deletions.
27 changes: 27 additions & 0 deletions mleko/dataset/data_schema.py
@@ -5,6 +5,7 @@
import copy
from typing import Literal

from mleko.cache.fingerprinters.dict_fingerprinter import DictFingerprinter
from mleko.utils.custom_logger import CustomLogger


@@ -53,6 +54,32 @@ def __init__(
"timedelta": sorted(list(timedelta)),
}

def __eq__(self, other: DataSchema) -> bool:
"""Check if two DataSchema objects are equal.
Args:
other: DataSchema object to compare with.
Returns:
True if the two DataSchema objects are equal, False otherwise.
"""
return isinstance(other, DataSchema) and DictFingerprinter().fingerprint(
self.to_dict()
) == DictFingerprinter().fingerprint(other.to_dict())

def __hash__(self) -> int:
"""Get the hash of the DataSchema.
Warning:
This method is not intended to be used for stable hashing across runs. Please
refer to the `DictFingerprinter` class in the `mleko.utils.fingerprinter` module
for stable hashing of the DataSchema.
Returns:
Hash of the DataSchema.
"""
return hash(DictFingerprinter().fingerprint(self.to_dict()))

def __repr__(self) -> str:
"""Get the string representation of DataSchema.
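With `__eq__` and `__hash__` now delegating to `DictFingerprinter`, two schemas describing the same features compare equal and hash identically within a run, which is what allows a `DataSchema` to take part in the `lru_cache` memoization added in `base_model.py` below. A minimal sketch, assuming feature groups omitted from the constructor default to empty:

```python
from mleko.dataset.data_schema import DataSchema

schema_a = DataSchema(numerical=["age", "income"], categorical=["country"])
schema_b = DataSchema(numerical=["age", "income"], categorical=["country"])

# Equality and hashing both go through DictFingerprinter().fingerprint(...),
# so identical schemas behave as the same key in hash-based caches.
assert schema_a == schema_b
assert hash(schema_a) == hash(schema_b)
```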
4 changes: 2 additions & 2 deletions mleko/dataset/transform/label_encoder_transformer.py
@@ -2,14 +2,14 @@

from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Hashable

import vaex
import vaex.array_types

from mleko.cache.fingerprinters import DictFingerprinter
from mleko.dataset.data_schema import DataSchema
from mleko.utils.custom_logger import CustomLogger
from mleko.utils.decorators import auto_repr
@@ -195,7 +195,7 @@ def _fingerprint(self) -> Hashable:
super()._fingerprint(),
self._allow_unseen,
self._encode_null,
json.dumps(self._label_dict, sort_keys=True) if self._label_dict is not None else None,
DictFingerprinter().fingerprint(self._label_dict),
)

def _fit_using_label_dict(self, feature: str, observed_labels: list[str]) -> bool:
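Replacing `json.dumps` with `DictFingerprinter` drops the explicit `None` guard, which suggests the fingerprinter accepts `None` directly. Its real implementation is not part of this diff; the sketch below is only a hypothetical stand-in for what a stable dictionary fingerprint can look like.

```python
import hashlib
import json


def fingerprint_dict(d: dict | None) -> str | None:
    """Hypothetical equivalent of DictFingerprinter().fingerprint():
    a key-order-independent digest that also tolerates None."""
    if d is None:
        return None
    payload = json.dumps(d, sort_keys=True, default=str)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()


assert fingerprint_dict({"b": 2, "a": 1}) == fingerprint_dict({"a": 1, "b": 2})
```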
77 changes: 77 additions & 0 deletions mleko/model/base_model.py
@@ -3,16 +3,19 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Hashable, Union

import pandas as pd
import vaex

from mleko.cache.fingerprinters import DictFingerprinter, VaexFingerprinter
from mleko.cache.handlers import JOBLIB_CACHE_HANDLER, VAEX_DATAFRAME_CACHE_HANDLER
from mleko.cache.lru_cache_mixin import LRUCacheMixin
from mleko.dataset.data_schema import DataSchema
from mleko.utils.custom_logger import CustomLogger
from mleko.utils.vaex_helpers import HashableVaexDataFrame, get_columns


logger = CustomLogger()
@@ -45,6 +48,7 @@ def __init__(
self,
features: list[str] | tuple[str, ...] | None,
ignore_features: list[str] | tuple[str, ...] | None,
memoized_dataset_cache_size: int | None,
cache_directory: str | Path,
cache_size: int,
) -> None:
@@ -54,18 +58,31 @@ def __init__(
The `features` and `ignore_features` arguments are mutually exclusive. If both are specified, a
`ValueError` is raised.
Warning:
The `memoized_dataset_cache_size` parameter is experimental and should be used with caution. It refers to
the number of datasets to keep in memory for speeding up repeated training. This can be useful when
hyperparameter tuning or cross-validation is performed, as the dataset does not need to be loaded from disk
every time. However, this can lead to memory issues if the dataset is too large. Specify 0 to disable the
cache. When finished with fitting and transforming, please call the `clear_load_dataset_cache` method to
clear the cache and free up memory.
Args:
features: List of feature names to be used by the model. If None, the default is all features
applicable to the model.
ignore_features: List of feature names to be ignored by the model. If None, the default is to
ignore no features.
memoized_dataset_cache_size: The number of datasets to keep in memory for speeding up repeated training.
When finished with fitting and transforming, please call the `clear_load_dataset_cache` method to clear
the cache and free up memory. Specify 0 to disable the cache.
cache_directory: Directory where the cache will be stored locally.
cache_size: The maximum number of entries to keep in the cache.
Raises:
ValueError: If both `features` and `ignore_features` are specified.
"""
super().__init__(cache_directory, cache_size)
self._memoized_dataset_cache_size = memoized_dataset_cache_size

if features is not None and ignore_features is not None:
msg = "Both `features` and `ignore_features` have been specified. The arguments are mutually exclusive."
logger.error(msg)
@@ -76,6 +93,24 @@ def __init__(
self._features: tuple[str, ...] | None = tuple(features) if features is not None else None
self._ignore_features: tuple[str, ...] = tuple(ignore_features) if ignore_features is not None else tuple()

self._memoized_load_dataset = lru_cache(maxsize=self._memoized_dataset_cache_size)(self._memoized_load_dataset)
"""Load the dataset into memory and memoize the result.
Warning:
This method should be used with caution, as it loads the entire dataset into memory as a pandas DataFrame.
The returned DataFrame will be memoized using the `functools.lru_cache` to avoid reloading the
dataset multiple times. The cache size is set to the `memoized_dataset_cache_size` attribute.
Args:
data_schema: The data schema of the dataframe.
dataframe: The dataframe to load, wrapped in a `HashableVaexDataFrame` object.
additional_features: Additional features to load, such as the target feature.
name: Name of the dataset to be used in the log message.
Returns:
A pandas DataFrame with the loaded data.
"""

def fit(
self,
data_schema: DataSchema,
@@ -307,6 +342,48 @@ def _feature_set(self, data_schema: DataSchema) -> list[str]:
else self._features
)

def _load_dataset(
self,
data_schema: DataSchema,
dataframe: vaex.DataFrame,
additional_features: list[str] | None = None,
) -> pd.DataFrame:
"""Load the dataset into memory.
Warning:
This method should be used with caution, as it loads the entire dataset into memory as a pandas DataFrame.
Args:
data_schema: The data schema of the dataframe.
dataframe: The dataframe to load.
additional_features: Additional features to load, such as the target feature.
Returns:
A pandas DataFrame with the loaded data.
"""
feature_names = self._feature_set(data_schema)
df = get_columns(dataframe, feature_names + (additional_features if additional_features else [])).to_pandas_df()
return df # type: ignore

def _memoized_load_dataset(
self,
data_schema: DataSchema,
dataframe: HashableVaexDataFrame,
additional_features: tuple[str, ...] | None = None,
name: str | None = None,
) -> pd.DataFrame:
if name is not None:
name = f"{name.strip()} "
else:
name = ""

logger.info(f"Loading the {name}dataset into memory.")
return self._load_dataset(data_schema, dataframe.df, list(additional_features) if additional_features else None)

def clear_load_dataset_cache(self) -> None:
"""Clears the cache for the `_memoized_load_dataset` method."""
self._memoized_load_dataset.cache_clear()

@abstractmethod
def _fit(
self,
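The constructor above re-binds `_memoized_load_dataset` through `functools.lru_cache` on the instance, so every model carries its own cache that `clear_load_dataset_cache` can later empty. A stripped-down sketch of the same per-instance memoization pattern (class and method names here are illustrative, not from mleko):

```python
from __future__ import annotations

from functools import lru_cache


class DatasetLoader:
    """Illustrative example of per-instance memoization with lru_cache."""

    def __init__(self, cache_size: int | None) -> None:
        # Re-binding the bound method on the instance gives every loader its
        # own LRU cache; wrapping at class level would share one cache across
        # all instances and keep `self` references alive.
        self._memoized_load = lru_cache(maxsize=cache_size)(self._memoized_load)

    def _memoized_load(self, key: str) -> str:
        print(f"expensive load for {key!r}")
        return key.upper()

    def clear_cache(self) -> None:
        self._memoized_load.cache_clear()


loader = DatasetLoader(cache_size=2)
loader._memoized_load("train")  # performs the load
loader._memoized_load("train")  # cache hit, nothing printed
loader.clear_cache()            # frees the memoized results
```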
49 changes: 26 additions & 23 deletions mleko/model/lgbm_model.py
@@ -16,7 +16,7 @@
from mleko.dataset.data_schema import DataSchema
from mleko.utils.custom_logger import CustomLogger
from mleko.utils.decorators import auto_repr
from mleko.utils.vaex_helpers import get_columns
from mleko.utils.vaex_helpers import HashableVaexDataFrame

from .base_model import BaseModel, HyperparametersType

@@ -63,6 +63,7 @@ def __init__(
ignore_features: list[str] | tuple[str, ...] | None = None,
random_state: int | None = 42,
verbosity: int = logging.INFO,
memoized_dataset_cache_size: int | None = 0,
cache_directory: str | Path = "data/lgbm-model",
cache_size: int = 1,
) -> None:
@@ -73,6 +74,13 @@ def __init__(
By default, all features are used. If ignore_features is provided, all features except the ones in
ignore_features will be used. If features is provided, only the features in features will be used.
Warning:
The `memoized_dataset_cache_size` parameter is experimental and should be used with caution. It refers to
the number of datasets to keep in memory for speeding up repeated training. This can be useful when
hyperparameter tuning or cross-validation is performed, as the dataset does not need to be loaded from disk
every time. However, this can lead to memory issues if the dataset is too large. Specify 0 to disable the
cache. When finished with fitting and transforming, please call the `clear_load_dataset_cache` method to
clear the cache and free up memory.
Args:
target: The name of the target feature.
@@ -85,6 +93,9 @@ def __init__(
ignore_features: The names of the features to be ignored.
random_state: The random state to be used for reproducibility.
verbosity: The verbosity level of the logger, will be passed to the LightGBM model.
memoized_dataset_cache_size: The number of datasets to keep in memory for speeding up repeated training.
When finished with fitting and transforming, please call the `clear_load_dataset_cache` method to clear
the cache and free up memory. Specify 0 to disable the cache.
cache_directory: The target directory where the model will be saved.
cache_size: The maximum number of entries to keep in the cache.
@@ -105,7 +116,7 @@ def __init__(
... )
>>> booster, df_train_pred, df_test_pred = model.fit_transform(data_schema, df_train, df_test, {})
"""
super().__init__(features, ignore_features, cache_directory, cache_size)
super().__init__(features, ignore_features, memoized_dataset_cache_size, cache_directory, cache_size)
lgb.register_logger(logger)

self._target = target
@@ -151,15 +162,23 @@ def _fit(
raise ValueError(msg)

validation_datasets: list[tuple[str, pd.DataFrame, pd.Series]] = []
logger.info("Loading the training dataset into memory.")
train_df = self._load_dataset(data_schema, train_dataframe)
train_df = self._memoized_load_dataset(
data_schema,
HashableVaexDataFrame(train_dataframe),
(self._target,),
name="training",
)
X_train = train_df[self._feature_set(data_schema)]
y_train = train_df[self._target]
validation_datasets.append(("train", X_train, y_train))

if validation_dataframe is not None:
logger.info("Loading the validation dataset into memory.")
validation_df = self._load_dataset(data_schema, validation_dataframe)
validation_df = self._memoized_load_dataset(
data_schema,
HashableVaexDataFrame(validation_dataframe),
(self._target,),
name="validation",
)
X_validation = validation_df[self._feature_set(data_schema)]
y_validation = validation_df[self._target]
validation_datasets.append(("validation", X_validation, y_validation))
@@ -204,9 +223,7 @@ def _transform(self, data_schema: DataSchema, dataframe: vaex.DataFrame) -> vaex
Returns:
The transformed dataframe.
"""
logger.info("Loading the dataset into memory.")
feature_names = self._feature_set(data_schema)
dataset = get_columns(dataframe, feature_names).to_pandas_df()
dataset = self._memoized_load_dataset(data_schema, HashableVaexDataFrame(dataframe))
df = dataframe.copy()

logger.info("Transforming the dataset.")
@@ -242,17 +259,3 @@ def _default_features(self, data_schema: DataSchema) -> tuple[str, ...]:
"""
features = data_schema.get_features(["numerical", "boolean", "categorical"])
return tuple(str(feature) for feature in features)

def _load_dataset(self, data_schema: DataSchema, dataframe: vaex.DataFrame) -> pd.DataFrame:
"""Load the dataset into memory.
Args:
data_schema: The data schema of the dataframe.
dataframe: The dataframe to load.
Returns:
A pandas DataFrame with the loaded data.
"""
feature_names = self._feature_set(data_schema)
df: pd.DataFrame = get_columns(dataframe, feature_names + [self._target]).to_pandas_df() # type: ignore
return df
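Because `lru_cache` keys on its call arguments, `_fit` and `_transform` now pass a `HashableVaexDataFrame` wrapper and a tuple (rather than a list) of additional features. A tiny illustration of that constraint, using generic names:

```python
from functools import lru_cache


@lru_cache(maxsize=4)
def load(features: tuple[str, ...]) -> list[str]:
    # Results are cached per argument value, so the argument must be hashable.
    return [feature.upper() for feature in features]


load(("age", "country"))    # fine: tuples are hashable
# load(["age", "country"])  # TypeError: unhashable type: 'list'
```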
28 changes: 28 additions & 0 deletions mleko/utils/vaex_helpers.py
@@ -2,6 +2,8 @@

from __future__ import annotations

from dataclasses import dataclass

import vaex


@@ -88,3 +90,29 @@ def get_indices(df: vaex.DataFrame, indices: list[int]) -> vaex.DataFrame:
selection = get_filtered_df(df, index.isin(indices))
selection.delete_virtual_column(idx_name)
return selection.extract()


@dataclass(frozen=True)
class HashableVaexDataFrame:
"""An immutable hashable wrapper around a `vaex.DataFrame`."""

df: vaex.DataFrame

def __eq__(self, other) -> bool:
"""Check if two `HashableVaexDataFrame` objects are equal.
Args:
other: `HashableVaexDataFrame` object to compare with.
Returns:
True if the two `HashableVaexDataFrame` objects are equal, False otherwise.
"""
return isinstance(other, HashableVaexDataFrame) and self.df.fingerprint() == other.df.fingerprint()

def __hash__(self) -> int:
"""Get the hash of the `HashableVaexDataFrame`.
Returns:
Hash of the `HashableVaexDataFrame`.
"""
return hash(self.df.fingerprint())
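The frozen wrapper above delegates equality and hashing to vaex's own `DataFrame.fingerprint()`, so two wrappers around the same dataframe behave as the same cache key. A quick sketch using an in-memory dataframe:

```python
import vaex

from mleko.utils.vaex_helpers import HashableVaexDataFrame

df = vaex.from_arrays(x=[1, 2, 3], y=["a", "b", "c"])

wrapped_1 = HashableVaexDataFrame(df)
wrapped_2 = HashableVaexDataFrame(df)

# Both wrappers fingerprint the same underlying dataframe, so they are equal
# and hash identically, which is exactly what lru_cache needs from a key.
assert wrapped_1 == wrapped_2
assert hash(wrapped_1) == hash(wrapped_2)
```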
44 changes: 44 additions & 0 deletions tests/dataset/test_data_schema.py
@@ -141,3 +141,47 @@ def test_get_feature_type(self):

with pytest.raises(ValueError):
data_schema.get_type("non_existent_feature")

def test_hash_match(self):
"""Should return the same hash for the same schema."""
data_schema1 = DataSchema(
numerical=["numerical1", "numerical2"],
categorical=["categorical1", "categorical2"],
boolean=["boolean1", "boolean2"],
datetime=["datetime1", "datetime2"],
timedelta=["timedelta1", "timedelta2"],
)
data_schema2 = DataSchema(
numerical=["numerical1", "numerical2"],
categorical=["categorical1", "categorical2"],
boolean=["boolean1", "boolean2"],
datetime=["datetime1", "datetime2"],
timedelta=["timedelta1", "timedelta2"],
)
assert hash(data_schema1) == hash(data_schema2)

def test_data_schema_equality_and_inequality(self):
"""Should test equality and inequality of DataSchema instances."""
data_schema1 = DataSchema(
numerical=["numerical1", "numerical2"],
categorical=["categorical1", "categorical2"],
boolean=["boolean1", "boolean2"],
datetime=["datetime1", "datetime2"],
timedelta=["timedelta1", "timedelta2"],
)
data_schema2 = DataSchema(
numerical=["numerical1", "numerical2"],
categorical=["categorical1", "categorical2"],
boolean=["boolean1", "boolean2"],
datetime=["datetime1", "datetime2"],
timedelta=["timedelta1", "timedelta2"],
)
data_schema3 = DataSchema(
numerical=["numerical1", "numerical2"],
categorical=["categorical1", "categorical2"],
boolean=["boolean1", "boolean2"],
datetime=["datetime1", "datetime2"],
timedelta=["timedelta1"],
)
assert data_schema1 == data_schema2
assert data_schema1 != data_schema3
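Tests for the new `HashableVaexDataFrame` are among the nine changed files but are collapsed in this view; a hypothetical pytest sketch in the same style as the `DataSchema` tests above might look like this:

```python
import vaex

from mleko.utils.vaex_helpers import HashableVaexDataFrame


class TestHashableVaexDataFrame:
    """Hypothetical tests; the actual test file is not expanded in this diff."""

    def test_equal_for_same_dataframe(self):
        """Should treat wrappers around the same dataframe as equal."""
        df = vaex.from_arrays(x=[1, 2, 3])
        assert HashableVaexDataFrame(df) == HashableVaexDataFrame(df)
        assert hash(HashableVaexDataFrame(df)) == hash(HashableVaexDataFrame(df))

    def test_unequal_for_different_dataframes(self):
        """Should treat wrappers around different dataframes as unequal."""
        df1 = vaex.from_arrays(x=[1, 2, 3])
        df2 = vaex.from_arrays(y=[1, 2, 3])
        assert HashableVaexDataFrame(df1) != HashableVaexDataFrame(df2)
```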