Skip to content

Commit

Permalink
Allow backbone to be any model, preprocessor to be any callable
Browse files Browse the repository at this point in the history
Consolidate saving into a saving class to match loading.
Plenty more cleanup to do there probably, but this will at
least bring our saving and loading routines into a common
flow.
  • Loading branch information
mattdangerw committed Oct 1, 2024
1 parent a77595e commit 543e5db
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 124 deletions.
10 changes: 3 additions & 7 deletions keras_hub/src/layers/preprocessing/audio_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_hub.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -101,8 +100,5 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(
self,
preset_dir,
config_file=AUDIO_CONVERTER_CONFIG_FILE,
)
saver = get_preset_saver(preset_dir)
saver.save_audio_converter(self)
10 changes: 3 additions & 7 deletions keras_hub/src/layers/preprocessing/image_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_hub.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -110,8 +109,5 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(
self,
preset_dir,
config_file=IMAGE_CONVERTER_CONFIG_FILE,
)
saver = get_preset_saver(preset_dir)
saver.save_image_converter(self)
12 changes: 3 additions & 9 deletions keras_hub/src/models/backbone.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import os

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.utils.keras_utils import assert_quantization_support
from keras_hub.src.utils.preset_utils import CONFIG_FILE
from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_metadata
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -193,9 +188,8 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(self, preset_dir, config_file=CONFIG_FILE)
self.save_weights(os.path.join(preset_dir, MODEL_WEIGHTS_FILE))
save_metadata(self, preset_dir)
saver = get_preset_saver(preset_dir)
saver.save_backbone(self)

def enable_lora(self, rank):
"""Enable Lora on the backbone.
Expand Down
8 changes: 3 additions & 5 deletions keras_hub/src/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -209,7 +209,5 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(self, preset_dir, config_file=self.config_name)
for layer in self._flatten_layers(include_self=False):
if hasattr(layer, "save_to_preset"):
layer.save_to_preset(preset_dir)
saver = get_preset_saver(preset_dir)
saver.save_preprocessor(self)
40 changes: 17 additions & 23 deletions keras_hub/src/models/task.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import os

import keras
from rich import console as rich_console
from rich import markup
from rich import table as rich_table

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.utils.keras_utils import print_msg
from keras_hub.src.utils.pipeline_model import PipelineModel
from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE
from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -58,10 +55,15 @@ def __init__(self, *args, compile=True, **kwargs):
self.compile()

def preprocess_samples(self, x, y=None, sample_weight=None):
    """Run the task's `preprocessor` over a batch of samples.

    Args:
        x: Input features.
        y: Optional labels.
        sample_weight: Optional per-sample weights.

    Returns:
        A packed `(x, y, sample_weight)` tuple with `x` preprocessed.
    """
    # If `preprocessor` is `None`, return inputs unaltered.
    if self.preprocessor is None:
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
    # If `preprocessor` is a `Preprocessor` subclass, it knows how to handle
    # labels and sample weights, so pass them through as kwargs.
    if isinstance(self.preprocessor, Preprocessor):
        return self.preprocessor(x, y=y, sample_weight=sample_weight)
    # For any other layer or callable, transform the features only and do
    # not pass the label.
    x = self.preprocessor(x)
    return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

def __setattr__(self, name, value):
# Work around setattr issues for Keras 2 and Keras 3 torch backend.
Expand Down Expand Up @@ -232,17 +234,8 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
if self.preprocessor is None:
raise ValueError(
"Cannot save `task` to preset: `Preprocessor` is not initialized."
)

save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE)
if self.has_task_weights():
self.save_task_weights(os.path.join(preset_dir, TASK_WEIGHTS_FILE))

self.preprocessor.save_to_preset(preset_dir)
self.backbone.save_to_preset(preset_dir)
saver = get_preset_saver(preset_dir)
saver.save_task(self)

@property
def layers(self):
Expand Down Expand Up @@ -327,24 +320,25 @@ def add_layer(layer, info):
info,
)

tokenizer = self.preprocessor.tokenizer
preprocessor = self.preprocessor
tokenizer = getattr(preprocessor, "tokenizer", None)
if tokenizer:
info = "Vocab size: "
info += highlight_number(tokenizer.vocabulary_size())
add_layer(tokenizer, info)
image_converter = self.preprocessor.image_converter
image_converter = getattr(preprocessor, "image_converter", None)
if image_converter:
info = "Image size: "
info += highlight_shape(image_converter.image_size())
add_layer(image_converter, info)
audio_converter = self.preprocessor.audio_converter
audio_converter = getattr(preprocessor, "audio_converter", None)
if audio_converter:
info = "Audio shape: "
info += highlight_shape(audio_converter.audio_shape())
add_layer(audio_converter, info)

# Print the summary to the console.
preprocessor_name = markup.escape(self.preprocessor.name)
preprocessor_name = markup.escape(preprocessor.name)
console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
console.print(table)

Expand Down
9 changes: 0 additions & 9 deletions keras_hub/src/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,3 @@ def test_save_to_preset(self):
ref_out = task.backbone.predict(data)
new_out = restored_task.backbone.predict(data)
self.assertAllClose(ref_out, new_out)

@pytest.mark.large
def test_none_preprocessor(self):
model = TextClassifier.from_preset(
"bert_tiny_en_uncased",
preprocessor=None,
num_classes=2,
)
self.assertEqual(model.preprocessor, None)
9 changes: 3 additions & 6 deletions keras_hub/src/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_file
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty
from keras_hub.src.utils.tensor_utils import preprocessing_function

Expand Down Expand Up @@ -189,11 +189,8 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(self, preset_dir, config_file=self.config_name)
subdir = self.config_name.split(".")[0]
asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir)
os.makedirs(asset_dir, exist_ok=True)
self.save_assets(asset_dir)
saver = get_preset_saver(preset_dir)
saver.save_tokenizer(self)

@preprocessing_function
def call(self, inputs, *args, training=None, **kwargs):
Expand Down
152 changes: 94 additions & 58 deletions keras_hub/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,64 +267,6 @@ def check_file_exists(preset, path):
return True


def get_tokenizer(layer):
"""Get the tokenizer from any KerasHub model or layer."""
# Avoid circular import.
from keras_hub.src.tokenizers.tokenizer import Tokenizer

if isinstance(layer, Tokenizer):
return layer
if hasattr(layer, "tokenizer"):
return layer.tokenizer
if hasattr(layer, "preprocessor"):
return getattr(layer.preprocessor, "tokenizer", None)
return None


def recursive_pop(config, key):
"""Remove a key from a nested config object"""
config.pop(key, None)
for value in config.values():
if isinstance(value, dict):
recursive_pop(value, key)


# TODO: refactor saving routines into a PresetSaver class?
def make_preset_dir(preset):
os.makedirs(preset, exist_ok=True)


def save_serialized_object(
layer,
preset,
config_file=CONFIG_FILE,
config_to_skip=[],
):
make_preset_dir(preset)
config_path = os.path.join(preset, config_file)
config = keras.saving.serialize_keras_object(layer)
config_to_skip += ["compile_config", "build_config"]
for c in config_to_skip:
recursive_pop(config, c)
with open(config_path, "w") as config_file:
config_file.write(json.dumps(config, indent=4))


def save_metadata(layer, preset):
from keras_hub.src.version_utils import __version__ as keras_hub_version

keras_version = keras.version() if hasattr(keras, "version") else None
metadata = {
"keras_version": keras_version,
"keras_hub_version": keras_hub_version,
"parameter_count": layer.count_params(),
"date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
}
metadata_path = os.path.join(preset, METADATA_FILE)
with open(metadata_path, "w") as metadata_file:
metadata_file.write(json.dumps(metadata, indent=4))


def _validate_backbone(preset):
config_path = os.path.join(preset, CONFIG_FILE)
if not os.path.exists(config_path):
Expand Down Expand Up @@ -600,6 +542,13 @@ def get_preset_loader(preset):
)


def get_preset_saver(preset):
    """Return a saver object for writing a preset to `preset`.

    Unlike loading, only one saving format is supported: the Keras format.
    The API shape intentionally mirrors the loader's for simplicity and to
    leave room for extensibility down the road if we need it.
    """
    return KerasPresetSaver(preset)


class PresetLoader:
def __init__(self, preset, config):
self.config = config
Expand Down Expand Up @@ -737,3 +686,90 @@ def load_preprocessor(
preprocessor = load_serialized_object(preprocessor_json, **kwargs)
preprocessor.load_preset_assets(self.preset)
return preprocessor


class KerasPresetSaver:
    """Writes KerasHub layers and models to a local preset directory.

    Each `save_*` method serializes one component (backbone, tokenizer,
    converters, task, preprocessor) to JSON config files, weights files
    and asset directories inside `preset_dir`.
    """

    def __init__(self, preset_dir):
        # Create the preset directory up front so every `save_*` call can
        # assume it exists.
        os.makedirs(preset_dir, exist_ok=True)
        self.preset_dir = preset_dir

    def save_backbone(self, backbone):
        """Save a backbone's config, weights and metadata to the preset."""
        self._save_serialized_object(backbone, config_file=CONFIG_FILE)
        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
        backbone.save_weights(backbone_weight_path)
        self._save_metadata(backbone)

    def save_tokenizer(self, tokenizer):
        """Save a tokenizer's config and vocabulary assets to the preset."""
        # Respect a custom config filename if the tokenizer declares one.
        config_file = getattr(tokenizer, "config_file", TOKENIZER_CONFIG_FILE)
        self._save_serialized_object(tokenizer, config_file)
        # Save assets (e.g. vocabularies) under a subdir named after the
        # config file stem, so multiple tokenizers can coexist.
        subdir = config_file.split(".")[0]
        asset_dir = os.path.join(self.preset_dir, ASSET_DIR, subdir)
        os.makedirs(asset_dir, exist_ok=True)
        tokenizer.save_assets(asset_dir)

    def save_audio_converter(self, converter):
        """Save an audio converter's config to the preset."""
        self._save_serialized_object(converter, AUDIO_CONVERTER_CONFIG_FILE)

    def save_image_converter(self, converter):
        """Save an image converter's config to the preset."""
        self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)

    def save_task(self, task):
        """Save a task with its backbone and (optional) preprocessor."""
        # Save task specific config and weights.
        self._save_serialized_object(task, TASK_CONFIG_FILE)
        if task.has_task_weights():
            task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE)
            task.save_task_weights(task_weight_path)
        # Save backbone.
        if hasattr(task.backbone, "save_to_preset"):
            task.backbone.save_to_preset(self.preset_dir)
        else:
            # Allow saving a `keras.Model` that is not a backbone subclass.
            self.save_backbone(task.backbone)
        # Save preprocessor. A preprocessor is optional; skip when absent
        # rather than attempting to serialize `None`.
        if task.preprocessor is None:
            return
        if hasattr(task.preprocessor, "save_to_preset"):
            task.preprocessor.save_to_preset(self.preset_dir)
        else:
            # Allow saving a `keras.Layer` that is not a preprocessor subclass.
            self.save_preprocessor(task.preprocessor)

    def save_preprocessor(self, preprocessor):
        """Save a preprocessor's config and all saveable sub-layers."""
        # Respect a custom config filename if the preprocessor declares one.
        config_file = getattr(
            preprocessor, "config_file", PREPROCESSOR_CONFIG_FILE
        )
        self._save_serialized_object(preprocessor, config_file)
        # Recursively save nested layers (tokenizers, converters, ...).
        for layer in preprocessor._flatten_layers(include_self=False):
            if hasattr(layer, "save_to_preset"):
                layer.save_to_preset(self.preset_dir)

    def _recursive_pop(self, config, key):
        """Remove a key from a nested config object."""
        config.pop(key, None)
        for value in config.values():
            if isinstance(value, dict):
                self._recursive_pop(value, key)

    def _save_serialized_object(self, layer, config_file):
        """Serialize `layer` to JSON and write it to `config_file`."""
        config_path = os.path.join(self.preset_dir, config_file)
        config = keras.saving.serialize_keras_object(layer)
        # Presets store only architecture; drop compile/build state.
        for key in ("compile_config", "build_config"):
            self._recursive_pop(config, key)
        # Avoid shadowing the `config_file` argument with the file handle.
        with open(config_path, "w") as file:
            file.write(json.dumps(config, indent=4))

    def _save_metadata(self, layer):
        """Write version, parameter-count and timestamp metadata as JSON."""
        from keras_hub.src.version_utils import __version__ as keras_hub_version

        # `keras.version` does not exist on older Keras releases.
        keras_version = keras.version() if hasattr(keras, "version") else None
        metadata = {
            "keras_version": keras_version,
            "keras_hub_version": keras_hub_version,
            "parameter_count": layer.count_params(),
            "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
        }
        metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
        with open(metadata_path, "w") as metadata_file:
            metadata_file.write(json.dumps(metadata, indent=4))

0 comments on commit 543e5db

Please sign in to comment.