From 543e5dba7ee44b9677d14209aa49508d35322f25 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Tue, 1 Oct 2024 02:18:45 -0700 Subject: [PATCH] Allow backbone to be any model, preprocessor to be any callable Consolidate saving into a saving class to match loading. Plenty more cleanup to do there probably, but this will at least bring our saving and loading routines into a common flow. --- .../layers/preprocessing/audio_converter.py | 10 +- .../layers/preprocessing/image_converter.py | 10 +- keras_hub/src/models/backbone.py | 12 +- keras_hub/src/models/preprocessor.py | 8 +- keras_hub/src/models/task.py | 40 ++--- keras_hub/src/models/task_test.py | 9 -- keras_hub/src/tokenizers/tokenizer.py | 9 +- keras_hub/src/utils/preset_utils.py | 152 +++++++++++------- 8 files changed, 126 insertions(+), 124 deletions(-) diff --git a/keras_hub/src/layers/preprocessing/audio_converter.py b/keras_hub/src/layers/preprocessing/audio_converter.py index 3ddf805abc..8e655702c4 100644 --- a/keras_hub/src/layers/preprocessing/audio_converter.py +++ b/keras_hub/src/layers/preprocessing/audio_converter.py @@ -2,11 +2,10 @@ from keras_hub.src.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) -from keras_hub.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -101,8 +100,5 @@ def save_to_preset(self, preset_dir): Args: preset_dir: The path to the local model preset directory. """ - save_serialized_object( - self, - preset_dir, - config_file=AUDIO_CONVERTER_CONFIG_FILE, - ) + saver = get_preset_saver(preset_dir) + saver.save_audio_converter(self) diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index b93e36e069..3ff6fc09ef 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -2,11 +2,10 @@ from keras_hub.src.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) -from keras_hub.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -110,8 +109,5 @@ def save_to_preset(self, preset_dir): Args: preset_dir: The path to the local model preset directory. """ - save_serialized_object( - self, - preset_dir, - config_file=IMAGE_CONVERTER_CONFIG_FILE, - ) + saver = get_preset_saver(preset_dir) + saver.save_image_converter(self) diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index fa065a3df5..dfe4b31b31 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -1,15 +1,10 @@ -import os - import keras from keras_hub.src.api_export import keras_hub_export from keras_hub.src.utils.keras_utils import assert_quantization_support -from keras_hub.src.utils.preset_utils import CONFIG_FILE -from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_metadata -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -193,9 +188,8 @@ def save_to_preset(self, preset_dir): Args: preset_dir: The path to the local model preset directory. """ - save_serialized_object(self, preset_dir, config_file=CONFIG_FILE) - self.save_weights(os.path.join(preset_dir, MODEL_WEIGHTS_FILE)) - save_metadata(self, preset_dir) + saver = get_preset_saver(preset_dir) + saver.save_backbone(self) def enable_lora(self, rank): """Enable Lora on the backbone. diff --git a/keras_hub/src/models/preprocessor.py b/keras_hub/src/models/preprocessor.py index c12b0481f0..f0569a36f8 100644 --- a/keras_hub/src/models/preprocessor.py +++ b/keras_hub/src/models/preprocessor.py @@ -8,7 +8,7 @@ from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -209,7 +209,5 @@ def save_to_preset(self, preset_dir): Args: preset_dir: The path to the local model preset directory. """ - save_serialized_object(self, preset_dir, config_file=self.config_name) - for layer in self._flatten_layers(include_self=False): - if hasattr(layer, "save_to_preset"): - layer.save_to_preset(preset_dir) + saver = get_preset_saver(preset_dir) + saver.save_preprocessor(self) diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index 080c67c221..fd603d517e 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -1,19 +1,16 @@ -import os - import keras from rich import console as rich_console from rich import markup from rich import table as rich_table from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.utils.keras_utils import print_msg from keras_hub.src.utils.pipeline_model import PipelineModel -from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE -from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -58,10 +55,15 @@ def __init__(self, *args, compile=True, **kwargs): self.compile() def preprocess_samples(self, x, y=None, sample_weight=None): - if self.preprocessor is not None: + # If `preprocessor` is `None`, return inputs unaltered. + if self.preprocessor is None: + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + # If `preprocessor` is `Preprocessor` subclass, pass labels as a kwarg. + if isinstance(self.preprocessor, Preprocessor): return self.preprocessor(x, y=y, sample_weight=sample_weight) - else: - return super().preprocess_samples(x, y, sample_weight) + # For other layers and callable, do not pass the label. + x = self.preprocessor(x) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) def __setattr__(self, name, value): # Work around setattr issues for Keras 2 and Keras 3 torch backend. @@ -232,17 +234,8 @@ def save_to_preset(self, preset_dir): Args: preset_dir: The path to the local model preset directory. """ - if self.preprocessor is None: - raise ValueError( - "Cannot save `task` to preset: `Preprocessor` is not initialized." - ) - - save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE) - if self.has_task_weights(): - self.save_task_weights(os.path.join(preset_dir, TASK_WEIGHTS_FILE)) - - self.preprocessor.save_to_preset(preset_dir) - self.backbone.save_to_preset(preset_dir) + saver = get_preset_saver(preset_dir) + saver.save_task(self) @property def layers(self): @@ -327,24 +320,25 @@ def add_layer(layer, info): info, ) - tokenizer = self.preprocessor.tokenizer + preprocessor = self.preprocessor + tokenizer = getattr(preprocessor, "tokenizer", None) if tokenizer: info = "Vocab size: " info += highlight_number(tokenizer.vocabulary_size()) add_layer(tokenizer, info) - image_converter = self.preprocessor.image_converter + image_converter = getattr(preprocessor, "image_converter", None) if image_converter: info = "Image size: " info += highlight_shape(image_converter.image_size()) add_layer(image_converter, info) - audio_converter = self.preprocessor.audio_converter + audio_converter = getattr(preprocessor, "audio_converter", None) if audio_converter: info = "Audio shape: " info += highlight_shape(audio_converter.audio_shape()) add_layer(audio_converter, info) # Print the to the console. - preprocessor_name = markup.escape(self.preprocessor.name) + preprocessor_name = markup.escape(preprocessor.name) console.print(bold_text(f'Preprocessor: "{preprocessor_name}"')) console.print(table) diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py index bcc0b791f4..90d641a609 100644 --- a/keras_hub/src/models/task_test.py +++ b/keras_hub/src/models/task_test.py @@ -144,12 +144,3 @@ def test_save_to_preset(self): ref_out = task.backbone.predict(data) new_out = restored_task.backbone.predict(data) self.assertAllClose(ref_out, new_out) - - @pytest.mark.large - def test_none_preprocessor(self): - model = TextClassifier.from_preset( - "bert_tiny_en_uncased", - preprocessor=None, - num_classes=2, - ) - self.assertEqual(model.preprocessor, None) diff --git a/keras_hub/src/tokenizers/tokenizer.py b/keras_hub/src/tokenizers/tokenizer.py index 914e41a21e..b97efae444 100644 --- a/keras_hub/src/tokenizers/tokenizer.py +++ b/keras_hub/src/tokenizers/tokenizer.py @@ -10,7 +10,7 @@ from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_file from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty from keras_hub.src.utils.tensor_utils import preprocessing_function @@ -189,11 +189,8 @@ def save_to_preset(self, preset_dir): Args: preset_dir: The path to the local model preset directory. """ - save_serialized_object(self, preset_dir, config_file=self.config_name) - subdir = self.config_name.split(".")[0] - asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir) - os.makedirs(asset_dir, exist_ok=True) - self.save_assets(asset_dir) + saver = get_preset_saver(preset_dir) + saver.save_tokenizer(self) @preprocessing_function def call(self, inputs, *args, training=None, **kwargs): diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index c628c2d012..186477d0e2 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -267,64 +267,6 @@ def check_file_exists(preset, path): return True -def get_tokenizer(layer): - """Get the tokenizer from any KerasHub model or layer.""" - # Avoid circular import. - from keras_hub.src.tokenizers.tokenizer import Tokenizer - - if isinstance(layer, Tokenizer): - return layer - if hasattr(layer, "tokenizer"): - return layer.tokenizer - if hasattr(layer, "preprocessor"): - return getattr(layer.preprocessor, "tokenizer", None) - return None - - -def recursive_pop(config, key): - """Remove a key from a nested config object""" - config.pop(key, None) - for value in config.values(): - if isinstance(value, dict): - recursive_pop(value, key) - - -# TODO: refactor saving routines into a PresetSaver class? -def make_preset_dir(preset): - os.makedirs(preset, exist_ok=True) - - -def save_serialized_object( - layer, - preset, - config_file=CONFIG_FILE, - config_to_skip=[], -): - make_preset_dir(preset) - config_path = os.path.join(preset, config_file) - config = keras.saving.serialize_keras_object(layer) - config_to_skip += ["compile_config", "build_config"] - for c in config_to_skip: - recursive_pop(config, c) - with open(config_path, "w") as config_file: - config_file.write(json.dumps(config, indent=4)) - - -def save_metadata(layer, preset): - from keras_hub.src.version_utils import __version__ as keras_hub_version - - keras_version = keras.version() if hasattr(keras, "version") else None - metadata = { - "keras_version": keras_version, - "keras_hub_version": keras_hub_version, - "parameter_count": layer.count_params(), - "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"), - } - metadata_path = os.path.join(preset, METADATA_FILE) - with open(metadata_path, "w") as metadata_file: - metadata_file.write(json.dumps(metadata, indent=4)) - - def _validate_backbone(preset): config_path = os.path.join(preset, CONFIG_FILE) if not os.path.exists(config_path): @@ -600,6 +542,13 @@ def get_preset_loader(preset): ) +def get_preset_saver(preset): + # Unlike loading, we only support one form of saving, the Keras format. + # We keep the rough API structure as loading for simplicity and to allow + # extensibility down the road if we need it. + return KerasPresetSaver(preset) + + class PresetLoader: def __init__(self, preset, config): self.config = config @@ -737,3 +686,90 @@ def load_preprocessor( preprocessor = load_serialized_object(preprocessor_json, **kwargs) preprocessor.load_preset_assets(self.preset) return preprocessor + + +class KerasPresetSaver: + def __init__(self, preset_dir): + os.makedirs(preset_dir, exist_ok=True) + self.preset_dir = preset_dir + + def save_backbone(self, backbone): + self._save_serialized_object(backbone, config_file=CONFIG_FILE) + backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE) + backbone.save_weights(backbone_weight_path) + self._save_metadata(backbone) + + def save_tokenizer(self, tokenizer): + config_file = TOKENIZER_CONFIG_FILE + if hasattr(tokenizer, "config_file"): + config_file = tokenizer.config_file + self._save_serialized_object(tokenizer, config_file) + # Save assets. + subdir = config_file.split(".")[0] + asset_dir = os.path.join(self.preset_dir, ASSET_DIR, subdir) + os.makedirs(asset_dir, exist_ok=True) + tokenizer.save_assets(asset_dir) + + def save_audio_converter(self, converter): + self._save_serialized_object(converter, AUDIO_CONVERTER_CONFIG_FILE) + + def save_image_converter(self, converter): + self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE) + + def save_task(self, task): + # Save task specific config and weights. + self._save_serialized_object(task, TASK_CONFIG_FILE) + if task.has_task_weights(): + task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE) + task.save_task_weights(task_weight_path) + # Save backbone. + if hasattr(task.backbone, "save_to_preset"): + task.backbone.save_to_preset(self.preset_dir) + else: + # Allow saving a `keras.Model` that is not a backbone subclass. + self.save_backbone(task.backbone) + # Save preprocessor. + if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"): + task.preprocessor.save_to_preset(self.preset_dir) + else: + # Allow saving a `keras.Layer` that is not a preprocessor subclass. + self.save_preprocessor(task.preprocessor) + + def save_preprocessor(self, preprocessor): + config_file = PREPROCESSOR_CONFIG_FILE + if hasattr(preprocessor, "config_file"): + config_file = preprocessor.config_file + self._save_serialized_object(preprocessor, config_file) + for layer in preprocessor._flatten_layers(include_self=False): + if hasattr(layer, "save_to_preset"): + layer.save_to_preset(self.preset_dir) + + def _recursive_pop(self, config, key): + """Remove a key from a nested config object""" + config.pop(key, None) + for value in config.values(): + if isinstance(value, dict): + self._recursive_pop(value, key) + + def _save_serialized_object(self, layer, config_file): + config_path = os.path.join(self.preset_dir, config_file) + config = keras.saving.serialize_keras_object(layer) + config_to_skip = ["compile_config", "build_config"] + for key in config_to_skip: + self._recursive_pop(config, key) + with open(config_path, "w") as config_file: + config_file.write(json.dumps(config, indent=4)) + + def _save_metadata(self, layer): + from keras_hub.src.version_utils import __version__ as keras_hub_version + + keras_version = keras.version() if hasattr(keras, "version") else None + metadata = { + "keras_version": keras_version, + "keras_hub_version": keras_hub_version, + "parameter_count": layer.count_params(), + "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"), + } + metadata_path = os.path.join(self.preset_dir, METADATA_FILE) + with open(metadata_path, "w") as metadata_file: + metadata_file.write(json.dumps(metadata, indent=4))