Skip to content

Commit

Permalink
Allow backbone to be any functional, preprocessor any callable
Browse files Browse the repository at this point in the history
Consolidate saving into a saving class to match loading.
Plenty more cleanup to do there probably, but this will at
least bring our saving and loading routines into a common
flow.
  • Loading branch information
mattdangerw committed Oct 1, 2024
1 parent a77595e commit 0b005a3
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 127 deletions.
10 changes: 3 additions & 7 deletions keras_hub/src/layers/preprocessing/audio_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_hub.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -101,8 +100,5 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(
self,
preset_dir,
config_file=AUDIO_CONVERTER_CONFIG_FILE,
)
saver = get_preset_saver(preset_dir)
saver.save_audio_converter(self)
10 changes: 3 additions & 7 deletions keras_hub/src/layers/preprocessing/image_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_hub.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -110,8 +109,5 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(
self,
preset_dir,
config_file=IMAGE_CONVERTER_CONFIG_FILE,
)
saver = get_preset_saver(preset_dir)
saver.save_image_converter(self)
12 changes: 3 additions & 9 deletions keras_hub/src/models/backbone.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import os

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.utils.keras_utils import assert_quantization_support
from keras_hub.src.utils.preset_utils import CONFIG_FILE
from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_metadata
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -193,9 +188,8 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(self, preset_dir, config_file=CONFIG_FILE)
self.save_weights(os.path.join(preset_dir, MODEL_WEIGHTS_FILE))
save_metadata(self, preset_dir)
saver = get_preset_saver(preset_dir)
saver.save_backbone(self)

def enable_lora(self, rank):
"""Enable Lora on the backbone.
Expand Down
8 changes: 3 additions & 5 deletions keras_hub/src/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -209,7 +209,5 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(self, preset_dir, config_file=self.config_name)
for layer in self._flatten_layers(include_self=False):
if hasattr(layer, "save_to_preset"):
layer.save_to_preset(preset_dir)
saver = get_preset_saver(preset_dir)
saver.save_preprocessor(self)
5 changes: 3 additions & 2 deletions keras_hub/src/models/resnet/resnet_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,18 @@ def __init__(
**kwargs,
):
head_dtype = head_dtype or backbone.dtype_policy
data_format = getattr(backbone, "data_format", None)

# === Layers ===
self.backbone = backbone
self.preprocessor = preprocessor
if pooling == "avg":
self.pooler = keras.layers.GlobalAveragePooling2D(
data_format=backbone.data_format, dtype=head_dtype
data_format=data_format, dtype=head_dtype
)
elif pooling == "max":
self.pooler = keras.layers.GlobalAveragePooling2D(
data_format=backbone.data_format, dtype=head_dtype
data_format=data_format, dtype=head_dtype
)
else:
raise ValueError(
Expand Down
46 changes: 22 additions & 24 deletions keras_hub/src/models/task.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import os

import keras
from rich import console as rich_console
from rich import markup
from rich import table as rich_table

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.utils.keras_utils import print_msg
from keras_hub.src.utils.pipeline_model import PipelineModel
from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE
from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty


Expand Down Expand Up @@ -58,10 +56,15 @@ def __init__(self, *args, compile=True, **kwargs):
self.compile()

def preprocess_samples(self, x, y=None, sample_weight=None):
if self.preprocessor is not None:
# If `preprocessor` is `None`, return inputs unaltered.
if self.preprocessor is None:
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
# If `preprocessor` is a `Preprocessor` subclass, pass labels as a kwarg.
if isinstance(self.preprocessor, Preprocessor):
return self.preprocessor(x, y=y, sample_weight=sample_weight)
else:
return super().preprocess_samples(x, y, sample_weight)
# For other layers and callables, do not pass the label.
x = self.preprocessor(x)
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

def __setattr__(self, name, value):
# Work around setattr issues for Keras 2 and Keras 3 torch backend.
Expand Down Expand Up @@ -178,7 +181,10 @@ def from_preset(
loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
# Detect the correct subclass if we need to.
if cls.backbone_cls != backbone_cls:
if (
issubclass(backbone_cls, Backbone)
and cls.backbone_cls != backbone_cls
):
cls = find_subclass(preset, cls, backbone_cls)
# Specifically for classifiers, we never load task weights if
# num_classes is supplied. We handle this in the task base class because
Expand Down Expand Up @@ -232,17 +238,8 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
if self.preprocessor is None:
raise ValueError(
"Cannot save `task` to preset: `Preprocessor` is not initialized."
)

save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE)
if self.has_task_weights():
self.save_task_weights(os.path.join(preset_dir, TASK_WEIGHTS_FILE))

self.preprocessor.save_to_preset(preset_dir)
self.backbone.save_to_preset(preset_dir)
saver = get_preset_saver(preset_dir)
saver.save_task(self)

@property
def layers(self):
Expand Down Expand Up @@ -327,24 +324,25 @@ def add_layer(layer, info):
info,
)

tokenizer = self.preprocessor.tokenizer
preprocessor = self.preprocessor
tokenizer = getattr(preprocessor, "tokenizer", None)
if tokenizer:
info = "Vocab size: "
info += highlight_number(tokenizer.vocabulary_size())
add_layer(tokenizer, info)
image_converter = self.preprocessor.image_converter
image_converter = getattr(preprocessor, "image_converter", None)
if image_converter:
info = "Image size: "
info += highlight_shape(image_converter.image_size())
add_layer(image_converter, info)
audio_converter = self.preprocessor.audio_converter
audio_converter = getattr(preprocessor, "audio_converter", None)
if audio_converter:
info = "Audio shape: "
info += highlight_shape(audio_converter.audio_shape())
add_layer(audio_converter, info)

# Print the table to the console.
preprocessor_name = markup.escape(self.preprocessor.name)
preprocessor_name = markup.escape(preprocessor.name)
console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
console.print(table)

Expand Down
29 changes: 23 additions & 6 deletions keras_hub/src/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
import pathlib

import keras
import numpy as np
import pytest

from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.models.resnet.resnet_image_classifier import (
ResNetImageClassifier,
)
from keras_hub.src.models.task import Task
from keras_hub.src.models.text_classifier import TextClassifier
from keras_hub.src.tests.test_case import TestCase
Expand Down Expand Up @@ -146,10 +150,23 @@ def test_save_to_preset(self):
self.assertAllClose(ref_out, new_out)

@pytest.mark.large
def test_none_preprocessor(self):
model = TextClassifier.from_preset(
"bert_tiny_en_uncased",
preprocessor=None,
num_classes=2,
def test_save_to_preset_custom_backbone_and_preprocessor(self):
preprocessor = keras.layers.Rescaling(1 / 255.0)
inputs = keras.Input(shape=(None, None, 3))
outputs = keras.layers.Dense(8)(inputs)
backbone = keras.Model(inputs, outputs)
# TODO: update to ImageClassifier after other PR.
task = ResNetImageClassifier(
backbone=backbone,
preprocessor=preprocessor,
num_classes=10,
)
self.assertEqual(model.preprocessor, None)

save_dir = self.get_temp_dir()
task.save_to_preset(save_dir)
batch = np.random.randint(0, 256, size=(2, 224, 224, 3))
expected = task.predict(batch)

restored_task = ResNetImageClassifier.from_preset(save_dir)
actual = restored_task.predict(batch)
self.assertAllClose(expected, actual)
9 changes: 3 additions & 6 deletions keras_hub/src/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_file
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty
from keras_hub.src.utils.tensor_utils import preprocessing_function

Expand Down Expand Up @@ -189,11 +189,8 @@ def save_to_preset(self, preset_dir):
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(self, preset_dir, config_file=self.config_name)
subdir = self.config_name.split(".")[0]
asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir)
os.makedirs(asset_dir, exist_ok=True)
self.save_assets(asset_dir)
saver = get_preset_saver(preset_dir)
saver.save_tokenizer(self)

@preprocessing_function
def call(self, inputs, *args, training=None, **kwargs):
Expand Down
Loading

0 comments on commit 0b005a3

Please sign in to comment.