From 0b005a3e14b3575349f81b0a23c4ed8838fb9c68 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Tue, 1 Oct 2024 02:18:45 -0700
Subject: [PATCH] Allow backbone to be any functional, preprocessor any
 callable

Consolidate saving into a saving class to match loading. Plenty more
cleanup to do there probably, but this will at least bring our saving
and loading routines into a common flow.
---
 .../layers/preprocessing/audio_converter.py   |  10 +-
 .../layers/preprocessing/image_converter.py   |  10 +-
 keras_hub/src/models/backbone.py              |  12 +-
 keras_hub/src/models/preprocessor.py          |   8 +-
 .../models/resnet/resnet_image_classifier.py  |   5 +-
 keras_hub/src/models/task.py                  |  46 +++--
 keras_hub/src/models/task_test.py             |  29 +++-
 keras_hub/src/tokenizers/tokenizer.py         |   9 +-
 keras_hub/src/utils/preset_utils.py           | 164 +++++++++++-------
 9 files changed, 166 insertions(+), 127 deletions(-)

diff --git a/keras_hub/src/layers/preprocessing/audio_converter.py b/keras_hub/src/layers/preprocessing/audio_converter.py
index 3ddf805abc..8e655702c4 100644
--- a/keras_hub/src/layers/preprocessing/audio_converter.py
+++ b/keras_hub/src/layers/preprocessing/audio_converter.py
@@ -2,11 +2,10 @@
 from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
-from keras_hub.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 
 
@@ -101,8 +100,5 @@ def save_to_preset(self, preset_dir):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(
-            self,
-            preset_dir,
-            config_file=AUDIO_CONVERTER_CONFIG_FILE,
-        )
+        saver = get_preset_saver(preset_dir)
+        saver.save_audio_converter(self)
diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py
index b93e36e069..3ff6fc09ef 100644
--- a/keras_hub/src/layers/preprocessing/image_converter.py
+++ b/keras_hub/src/layers/preprocessing/image_converter.py
@@ -2,11 +2,10 @@
 from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
-from keras_hub.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 
 
@@ -110,8 +109,5 @@ def save_to_preset(self, preset_dir):
         Args:
             preset_dir: The path to the local model preset directory.
""" - save_serialized_object( - self, - preset_dir, - config_file=IMAGE_CONVERTER_CONFIG_FILE, - ) + saver = get_preset_saver(preset_dir) + saver.save_image_converter(self) diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index fa065a3df5..dfe4b31b31 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -1,15 +1,10 @@ -import os - import keras from keras_hub.src.api_export import keras_hub_export from keras_hub.src.utils.keras_utils import assert_quantization_support -from keras_hub.src.utils.preset_utils import CONFIG_FILE -from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_metadata -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -193,9 +188,8 @@ def save_to_preset(self, preset_dir): Args: preset_dir: The path to the local model preset directory. """ - save_serialized_object(self, preset_dir, config_file=CONFIG_FILE) - self.save_weights(os.path.join(preset_dir, MODEL_WEIGHTS_FILE)) - save_metadata(self, preset_dir) + saver = get_preset_saver(preset_dir) + saver.save_backbone(self) def enable_lora(self, rank): """Enable Lora on the backbone. diff --git a/keras_hub/src/models/preprocessor.py b/keras_hub/src/models/preprocessor.py index c12b0481f0..f0569a36f8 100644 --- a/keras_hub/src/models/preprocessor.py +++ b/keras_hub/src/models/preprocessor.py @@ -8,7 +8,7 @@ from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -209,7 +209,5 @@ def save_to_preset(self, preset_dir): Args: preset_dir: The path to the local model preset directory. 
""" - save_serialized_object(self, preset_dir, config_file=self.config_name) - for layer in self._flatten_layers(include_self=False): - if hasattr(layer, "save_to_preset"): - layer.save_to_preset(preset_dir) + saver = get_preset_saver(preset_dir) + saver.save_preprocessor(self) diff --git a/keras_hub/src/models/resnet/resnet_image_classifier.py b/keras_hub/src/models/resnet/resnet_image_classifier.py index 50c34df37b..364b52b738 100644 --- a/keras_hub/src/models/resnet/resnet_image_classifier.py +++ b/keras_hub/src/models/resnet/resnet_image_classifier.py @@ -96,17 +96,18 @@ def __init__( **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy + data_format = getattr(backbone, "data_format", None) # === Layers === self.backbone = backbone self.preprocessor = preprocessor if pooling == "avg": self.pooler = keras.layers.GlobalAveragePooling2D( - data_format=backbone.data_format, dtype=head_dtype + data_format=data_format, dtype=head_dtype ) elif pooling == "max": self.pooler = keras.layers.GlobalAveragePooling2D( - data_format=backbone.data_format, dtype=head_dtype + data_format=data_format, dtype=head_dtype ) else: raise ValueError( diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index 080c67c221..2e10ab2d2a 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -1,19 +1,17 @@ -import os - import keras from rich import console as rich_console from rich import markup from rich import table as rich_table from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.utils.keras_utils import print_msg from keras_hub.src.utils.pipeline_model import PipelineModel -from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE -from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader -from keras_hub.src.utils.preset_utils import save_serialized_object +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.python_utils import classproperty @@ -58,10 +56,15 @@ def __init__(self, *args, compile=True, **kwargs): self.compile() def preprocess_samples(self, x, y=None, sample_weight=None): - if self.preprocessor is not None: + # If `preprocessor` is `None`, return inputs unaltered. + if self.preprocessor is None: + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) + # If `preprocessor` is `Preprocessor` subclass, pass labels as a kwarg. + if isinstance(self.preprocessor, Preprocessor): return self.preprocessor(x, y=y, sample_weight=sample_weight) - else: - return super().preprocess_samples(x, y, sample_weight) + # For other layers and callable, do not pass the label. + x = self.preprocessor(x) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) def __setattr__(self, name, value): # Work around setattr issues for Keras 2 and Keras 3 torch backend. @@ -178,7 +181,10 @@ def from_preset( loader = get_preset_loader(preset) backbone_cls = loader.check_backbone_class() # Detect the correct subclass if we need to. - if cls.backbone_cls != backbone_cls: + if ( + issubclass(backbone_cls, Backbone) + and cls.backbone_cls != backbone_cls + ): cls = find_subclass(preset, cls, backbone_cls) # Specifically for classifiers, we never load task weights if # num_classes is supplied. 
@@ -232,17 +238,8 @@ def save_to_preset(self, preset_dir):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        if self.preprocessor is None:
-            raise ValueError(
-                "Cannot save `task` to preset: `Preprocessor` is not initialized."
-            )
-
-        save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE)
-        if self.has_task_weights():
-            self.save_task_weights(os.path.join(preset_dir, TASK_WEIGHTS_FILE))
-
-        self.preprocessor.save_to_preset(preset_dir)
-        self.backbone.save_to_preset(preset_dir)
+        saver = get_preset_saver(preset_dir)
+        saver.save_task(self)
 
     @property
     def layers(self):
@@ -327,24 +324,25 @@ def add_layer(layer, info):
                 info,
             )
 
-        tokenizer = self.preprocessor.tokenizer
+        preprocessor = self.preprocessor
+        tokenizer = getattr(preprocessor, "tokenizer", None)
         if tokenizer:
             info = "Vocab size: "
             info += highlight_number(tokenizer.vocabulary_size())
             add_layer(tokenizer, info)
-        image_converter = self.preprocessor.image_converter
+        image_converter = getattr(preprocessor, "image_converter", None)
         if image_converter:
             info = "Image size: "
             info += highlight_shape(image_converter.image_size())
             add_layer(image_converter, info)
-        audio_converter = self.preprocessor.audio_converter
+        audio_converter = getattr(preprocessor, "audio_converter", None)
         if audio_converter:
             info = "Audio shape: "
             info += highlight_shape(audio_converter.audio_shape())
             add_layer(audio_converter, info)
 
         # Print the to the console.
-        preprocessor_name = markup.escape(self.preprocessor.name)
+        preprocessor_name = markup.escape(preprocessor.name)
         console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
         console.print(table)
 
diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py
index bcc0b791f4..250d1d15b3 100644
--- a/keras_hub/src/models/task_test.py
+++ b/keras_hub/src/models/task_test.py
@@ -2,12 +2,16 @@
 import pathlib
 
 import keras
+import numpy as np
 import pytest
 
 from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
 from keras_hub.src.models.causal_lm import CausalLM
 from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
 from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.models.resnet.resnet_image_classifier import (
+    ResNetImageClassifier,
+)
 from keras_hub.src.models.task import Task
 from keras_hub.src.models.text_classifier import TextClassifier
 from keras_hub.src.tests.test_case import TestCase
@@ -146,10 +150,23 @@ def test_save_to_preset(self):
         self.assertAllClose(ref_out, new_out)
 
     @pytest.mark.large
-    def test_none_preprocessor(self):
-        model = TextClassifier.from_preset(
-            "bert_tiny_en_uncased",
-            preprocessor=None,
-            num_classes=2,
+    def test_save_to_preset_custom_backbone_and_preprocessor(self):
+        preprocessor = keras.layers.Rescaling(1 / 255.0)
+        inputs = keras.Input(shape=(None, None, 3))
+        outputs = keras.layers.Dense(8)(inputs)
+        backbone = keras.Model(inputs, outputs)
+        # TODO: update to ImageClassifier after other PR.
+        task = ResNetImageClassifier(
+            backbone=backbone,
+            preprocessor=preprocessor,
+            num_classes=10,
         )
-        self.assertEqual(model.preprocessor, None)
+
+        save_dir = self.get_temp_dir()
+        task.save_to_preset(save_dir)
+        batch = np.random.randint(0, 256, size=(2, 224, 224, 3))
+        expected = task.predict(batch)
+
+        restored_task = ResNetImageClassifier.from_preset(save_dir)
+        actual = restored_task.predict(batch)
+        self.assertAllClose(expected, actual)
diff --git a/keras_hub/src/tokenizers/tokenizer.py b/keras_hub/src/tokenizers/tokenizer.py
index 914e41a21e..b97efae444 100644
--- a/keras_hub/src/tokenizers/tokenizer.py
+++ b/keras_hub/src/tokenizers/tokenizer.py
@@ -10,7 +10,7 @@
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_file
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 
@@ -189,11 +189,8 @@ def save_to_preset(self, preset_dir):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(self, preset_dir, config_file=self.config_name)
-        subdir = self.config_name.split(".")[0]
-        asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir)
-        os.makedirs(asset_dir, exist_ok=True)
-        self.save_assets(asset_dir)
+        saver = get_preset_saver(preset_dir)
+        saver.save_tokenizer(self)
 
     @preprocessing_function
     def call(self, inputs, *args, training=None, **kwargs):
diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index c628c2d012..65af19df7f 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -267,64 +267,6 @@ def check_file_exists(preset, path):
     return True
 
 
-def get_tokenizer(layer):
-    """Get the tokenizer from any KerasHub model or layer."""
-    # Avoid circular import.
-    from keras_hub.src.tokenizers.tokenizer import Tokenizer
-
-    if isinstance(layer, Tokenizer):
-        return layer
-    if hasattr(layer, "tokenizer"):
-        return layer.tokenizer
-    if hasattr(layer, "preprocessor"):
-        return getattr(layer.preprocessor, "tokenizer", None)
-    return None
-
-
-def recursive_pop(config, key):
-    """Remove a key from a nested config object"""
-    config.pop(key, None)
-    for value in config.values():
-        if isinstance(value, dict):
-            recursive_pop(value, key)
-
-
-# TODO: refactor saving routines into a PresetSaver class?
-def make_preset_dir(preset):
-    os.makedirs(preset, exist_ok=True)
-
-
-def save_serialized_object(
-    layer,
-    preset,
-    config_file=CONFIG_FILE,
-    config_to_skip=[],
-):
-    make_preset_dir(preset)
-    config_path = os.path.join(preset, config_file)
-    config = keras.saving.serialize_keras_object(layer)
-    config_to_skip += ["compile_config", "build_config"]
-    for c in config_to_skip:
-        recursive_pop(config, c)
-    with open(config_path, "w") as config_file:
-        config_file.write(json.dumps(config, indent=4))
-
-
-def save_metadata(layer, preset):
-    from keras_hub.src.version_utils import __version__ as keras_hub_version
-
-    keras_version = keras.version() if hasattr(keras, "version") else None
-    metadata = {
-        "keras_version": keras_version,
-        "keras_hub_version": keras_hub_version,
-        "parameter_count": layer.count_params(),
-        "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
-    }
-    metadata_path = os.path.join(preset, METADATA_FILE)
-    with open(metadata_path, "w") as metadata_file:
-        metadata_file.write(json.dumps(metadata, indent=4))
-
-
 def _validate_backbone(preset):
     config_path = os.path.join(preset, CONFIG_FILE)
     if not os.path.exists(config_path):
@@ -518,6 +460,8 @@ def load_serialized_object(config, **kwargs):
 def check_config_class(config):
     """Validate a preset is being loaded on the correct class."""
     registered_name = config["registered_name"]
+    if registered_name in ("Functional", "Sequential"):
+        return keras.Model
     cls = keras.saving.get_registered_object(registered_name)
     if cls is None:
         raise ValueError(
@@ -600,6 +544,13 @@ def get_preset_loader(preset):
     )
 
 
+def get_preset_saver(preset):
+    # Unlike loading, we only support one form of saving: Keras serialized
+    # configs and saved weights. We keep the same rough API structure as
+    # loading just for simplicity.
+    return KerasPresetSaver(preset)
+
+
 class PresetLoader:
     def __init__(self, preset, config):
         self.config = config
@@ -684,7 +635,8 @@ def load_backbone(self, cls, load_weights, **kwargs):
     def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs):
         tokenizer_config = load_json(self.preset, config_name)
         tokenizer = load_serialized_object(tokenizer_config, **kwargs)
-        tokenizer.load_preset_assets(self.preset)
+        if hasattr(tokenizer, "load_preset_assets"):
+            tokenizer.load_preset_assets(self.preset)
         return tokenizer
 
     def load_audio_converter(self, cls, **kwargs):
@@ -709,7 +661,9 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             )
         # We found a `task.json` with a complete config for our class.
         task = load_serialized_object(task_config, **kwargs)
-        if task.preprocessor:
+        if task.preprocessor and hasattr(
+            task.preprocessor, "load_preset_assets"
+        ):
             task.preprocessor.load_preset_assets(self.preset)
         if load_weights:
             has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
@@ -735,5 +689,93 @@ def load_preprocessor(
             return super().load_preprocessor(cls, **kwargs)
         # We found a `preprocessing.json` with a complete config for our class.
         preprocessor = load_serialized_object(preprocessor_json, **kwargs)
-        preprocessor.load_preset_assets(self.preset)
+        if hasattr(preprocessor, "load_preset_assets"):
+            preprocessor.load_preset_assets(self.preset)
         return preprocessor
+
+
+class KerasPresetSaver:
+    def __init__(self, preset_dir):
+        os.makedirs(preset_dir, exist_ok=True)
+        self.preset_dir = preset_dir
+
+    def save_backbone(self, backbone):
+        self._save_serialized_object(backbone, config_file=CONFIG_FILE)
+        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
+        backbone.save_weights(backbone_weight_path)
+        self._save_metadata(backbone)
+
+    def save_tokenizer(self, tokenizer):
+        config_file = TOKENIZER_CONFIG_FILE
+        if hasattr(tokenizer, "config_file"):
+            config_file = tokenizer.config_file
+        self._save_serialized_object(tokenizer, config_file)
+        # Save assets.
+        subdir = config_file.split(".")[0]
+        asset_dir = os.path.join(self.preset_dir, ASSET_DIR, subdir)
+        os.makedirs(asset_dir, exist_ok=True)
+        tokenizer.save_assets(asset_dir)
+
+    def save_audio_converter(self, converter):
+        self._save_serialized_object(converter, AUDIO_CONVERTER_CONFIG_FILE)
+
+    def save_image_converter(self, converter):
+        self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
+
+    def save_task(self, task):
+        # Save task-specific config and weights.
+        self._save_serialized_object(task, TASK_CONFIG_FILE)
+        if task.has_task_weights():
+            task_weight_path = os.path.join(self.preset_dir, TASK_WEIGHTS_FILE)
+            task.save_task_weights(task_weight_path)
+        # Save backbone.
+        if hasattr(task.backbone, "save_to_preset"):
+            task.backbone.save_to_preset(self.preset_dir)
+        else:
+            # Allow saving a `keras.Model` that is not a backbone subclass.
+            self.save_backbone(task.backbone)
+        # Save preprocessor.
+        if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
+            task.preprocessor.save_to_preset(self.preset_dir)
+        elif task.preprocessor:
+            # Allow saving a `keras.Layer` that is not a preprocessor subclass.
+            self.save_preprocessor(task.preprocessor)
+
+    def save_preprocessor(self, preprocessor):
+        config_file = PREPROCESSOR_CONFIG_FILE
+        if hasattr(preprocessor, "config_file"):
+            config_file = preprocessor.config_file
+        self._save_serialized_object(preprocessor, config_file)
+        for layer in preprocessor._flatten_layers(include_self=False):
+            if hasattr(layer, "save_to_preset"):
+                layer.save_to_preset(self.preset_dir)
+
+    def _recursive_pop(self, config, key):
+        """Remove a key from a nested config object"""
+        config.pop(key, None)
+        for value in config.values():
+            if isinstance(value, dict):
+                self._recursive_pop(value, key)
+
+    def _save_serialized_object(self, layer, config_file):
+        config_path = os.path.join(self.preset_dir, config_file)
+        config = keras.saving.serialize_keras_object(layer)
+        config_to_skip = ["compile_config", "build_config"]
+        for key in config_to_skip:
+            self._recursive_pop(config, key)
+        with open(config_path, "w") as config_file:
+            config_file.write(json.dumps(config, indent=4))
+
+    def _save_metadata(self, layer):
+        from keras_hub.src.version_utils import __version__ as keras_hub_version
+
+        keras_version = keras.version() if hasattr(keras, "version") else None
+        metadata = {
+            "keras_version": keras_version,
+            "keras_hub_version": keras_hub_version,
+            "parameter_count": layer.count_params(),
+            "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
+        }
+        metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
+        with open(metadata_path, "w") as metadata_file:
+            metadata_file.write(json.dumps(metadata, indent=4))
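
What the first half of the change buys in practice, end to end. The sketch
below mirrors test_save_to_preset_custom_backbone_and_preprocessor from the
diff; the "./my_preset" path, the Dense(8) feature size, and num_classes=10
are illustrative, and ResNetImageClassifier simply stands in for any task
class until the ImageClassifier follow-up lands.

import keras
import numpy as np

from keras_hub.src.models.resnet.resnet_image_classifier import (
    ResNetImageClassifier,
)

# Any functional `keras.Model` can now be the backbone, and any callable
# (here a bare `Rescaling` layer) the preprocessor.
preprocessor = keras.layers.Rescaling(1 / 255.0)
inputs = keras.Input(shape=(None, None, 3))
outputs = keras.layers.Dense(8)(inputs)
backbone = keras.Model(inputs, outputs)

task = ResNetImageClassifier(
    backbone=backbone,
    preprocessor=preprocessor,
    num_classes=10,
)

# `save_to_preset` routes through `KerasPresetSaver.save_task`, which falls
# back to `save_backbone`/`save_preprocessor` for a plain model and layer.
task.save_to_preset("./my_preset")

# On load, the patched `check_config_class` maps the "Functional" registered
# name back to `keras.Model`, so the custom backbone round-trips.
restored = ResNetImageClassifier.from_preset("./my_preset")
batch = np.random.randint(0, 256, size=(2, 224, 224, 3))
np.testing.assert_allclose(
    task.predict(batch), restored.predict(batch), atol=1e-5
)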
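
And the consolidation itself, seen from the caller's side: every
save_to_preset above now reduces to two lines against the saver, mirroring
how from_preset goes through get_preset_loader. A rough sketch of that
symmetry, reusing the backbone and preprocessor from the previous example;
the file names in the comments are the usual values of the preset_utils
constants, assumed here rather than shown in this diff.

import keras

from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import get_preset_saver

preset_dir = "./my_preset"  # illustrative path

# Saving: one saver object owns the directory and the file conventions.
saver = get_preset_saver(preset_dir)  # a KerasPresetSaver; creates the dir
saver.save_backbone(backbone)  # config.json, model.weights.h5, metadata.json
saver.save_preprocessor(preprocessor)  # preprocessor.json

# Loading: the mirrored entry point inspects what is on disk.
loader = get_preset_loader(preset_dir)
# With this patch a plain functional backbone reports `keras.Model`.
assert loader.check_backbone_class() is keras.Model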