From b06697678c625337a327baf0516a03685cf73452 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Tue, 10 Sep 2024 10:23:28 -0700 Subject: [PATCH] Finish up resnet - Add presets. - Add converter script. - Add preprocessing with auto resizing. --- keras_nlp/api/layers/__init__.py | 3 + keras_nlp/api/models/__init__.py | 6 + keras_nlp/src/models/image_classifier.py | 5 - .../models/image_classifier_preprocessor.py | 83 ++++++++++++ keras_nlp/src/models/resnet/__init__.py | 6 + .../models/resnet/resnet_image_classifier.py | 10 +- .../resnet_image_classifier_preprocessor.py | 28 +++++ .../models/resnet/resnet_image_converter.py | 23 ++++ keras_nlp/src/models/resnet/resnet_presets.py | 95 ++++++++++++++ keras_nlp/src/models/text_classifier.py | 3 - keras_nlp/src/utils/preset_utils.py | 26 +--- keras_nlp/src/utils/timm/convert_resnet.py | 15 +++ keras_nlp/src/utils/timm/preset_loader.py | 22 +++- .../convert_resnet_checkpoints.py | 119 ++++++++++++++++++ 14 files changed, 411 insertions(+), 33 deletions(-) create mode 100644 keras_nlp/src/models/image_classifier_preprocessor.py create mode 100644 keras_nlp/src/models/resnet/resnet_image_classifier_preprocessor.py create mode 100644 keras_nlp/src/models/resnet/resnet_image_converter.py create mode 100644 keras_nlp/src/models/resnet/resnet_presets.py create mode 100644 tools/checkpoint_conversion/convert_resnet_checkpoints.py diff --git a/keras_nlp/api/layers/__init__.py b/keras_nlp/api/layers/__init__.py index 8b92cc11b0..7def279b19 100644 --- a/keras_nlp/api/layers/__init__.py +++ b/keras_nlp/api/layers/__init__.py @@ -53,6 +53,9 @@ from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, ) +from keras_nlp.src.models.resnet.resnet_image_converter import ( + ResNetImageConverter, +) from keras_nlp.src.models.whisper.whisper_audio_converter import ( WhisperAudioConverter, ) diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 1ebe476424..2b3ffbb30b 
100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -172,6 +172,9 @@ ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( @@ -231,6 +234,9 @@ from keras_nlp.src.models.resnet.resnet_image_classifier import ( ResNetImageClassifier, ) +from keras_nlp.src.models.resnet.resnet_image_classifier_preprocessor import ( + ResNetImageClassifierPreprocessor, +) from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_nlp.src.models.roberta.roberta_masked_lm_preprocessor import ( diff --git a/keras_nlp/src/models/image_classifier.py b/keras_nlp/src/models/image_classifier.py index f0cc031dbc..0606a29821 100644 --- a/keras_nlp/src/models/image_classifier.py +++ b/keras_nlp/src/models/image_classifier.py @@ -33,11 +33,6 @@ class ImageClassifier(Task): used to load a pre-trained config and weights. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Default compilation. - self.compile() - def compile( self, optimizer="auto", diff --git a/keras_nlp/src/models/image_classifier_preprocessor.py b/keras_nlp/src/models/image_classifier_preprocessor.py new file mode 100644 index 0000000000..c354169893 --- /dev/null +++ b/keras_nlp/src/models/image_classifier_preprocessor.py @@ -0,0 +1,83 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.tensor_utils import preprocessing_function + + +@keras_nlp_export("keras_nlp.models.ImageClassifierPreprocessor") +class ImageClassifierPreprocessor(Preprocessor): + """Base class for image classification preprocessing layers. + + `ImageClassifierPreprocessor` wraps a + `keras_nlp.layers.ImageConverter` to create a preprocessing layer for + image classification tasks. It is intended to be paired with a + `keras_nlp.models.ImageClassifier` task. + + All `ImageClassifierPreprocessor` layers take three inputs, `x`, `y`, and + `sample_weight`. `x`, the first input, should always be included. It can + be an image or batch of images. See examples below. `y` and `sample_weight` + are optional inputs that will be passed through unaltered. Usually, `y` will + be the classification label, and `sample_weight` will not be provided. + + The layer will output either `x`, an `(x, y)` tuple if labels were provided, + or an `(x, y, sample_weight)` tuple if labels and sample weight were + provided. `x` will be the input images after all model preprocessing has + been applied. + + All `ImageClassifierPreprocessor` tasks include a `from_preset()` + constructor which can be used to load a pre-trained config and vocabularies. + You can call the `from_preset()` constructor directly on this base class, in + which case the correct class for your model will be automatically + instantiated. + + Examples.
+ ```python + preprocessor = keras_nlp.models.ImageClassifierPreprocessor.from_preset( + "resnet_50", + ) + + # Resize a single image for resnet 50. + x = np.ones((512, 512, 3)) + x = preprocessor(x) + + # Resize a labeled image. + x, y = np.ones((512, 512, 3)), 1 + x, y = preprocessor(x, y) + + # Resize a batch of labeled images. + x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0] + x, y = preprocessor(x, y) + + # Use a `tf.data.Dataset`. + ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def __init__( + self, + image_converter=None, + **kwargs, + ): + super().__init__(**kwargs) + self.image_converter = image_converter + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + if self.image_converter: + x = self.image_converter(x) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/src/models/resnet/__init__.py b/keras_nlp/src/models/resnet/__init__.py index 3364a6bd16..a09d7a80bb 100644 --- a/keras_nlp/src/models/resnet/__init__.py +++ b/keras_nlp/src/models/resnet/__init__.py @@ -11,3 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_presets import backbone_presets +from keras_nlp.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, ResNetBackbone) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py index 815dc7fcca..5dc2f25828 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -16,6 +16,9 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.image_classifier import ImageClassifier from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_classifier_preprocessor import ( + ResNetImageClassifierPreprocessor, +) @keras_nlp_export("keras_nlp.models.ResNetImageClassifier") @@ -88,21 +91,22 @@ class ResNetImageClassifier(ImageClassifier): """ backbone_cls = ResNetBackbone + preprocessor_cls = ResNetImageClassifierPreprocessor def __init__( self, backbone, num_classes, - activation="softmax", + preprocessor=None, + activation=None, head_dtype=None, - preprocessor=None, # adding this dummy arg for saved model test - # TODO: once preprocessor flow is figured out, this needs to be updated **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy # === Layers === self.backbone = backbone + self.preprocessor = preprocessor self.output_dense = keras.layers.Dense( num_classes, activation=activation, diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_preprocessor.py b/keras_nlp/src/models/resnet/resnet_image_classifier_preprocessor.py new file mode 100644 index 0000000000..0da2c67deb --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_preprocessor.py @@ -0,0 +1,28 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not 
use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_converter import ( + ResNetImageConverter, +) + + +@keras_nlp_export("keras_nlp.models.ResNetImageClassifierPreprocessor") +class ResNetImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = ResNetBackbone + image_converter_cls = ResNetImageConverter diff --git a/keras_nlp/src/models/resnet/resnet_image_converter.py b/keras_nlp/src/models/resnet/resnet_image_converter.py new file mode 100644 index 0000000000..876dffc96d --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_converter.py @@ -0,0 +1,23 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.resizing_image_converter import ( + ResizingImageConverter, +) +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone + + +@keras_nlp_export("keras_nlp.layers.ResNetImageConverter") +class ResNetImageConverter(ResizingImageConverter): + backbone_cls = ResNetBackbone diff --git a/keras_nlp/src/models/resnet/resnet_presets.py b/keras_nlp/src/models/resnet/resnet_presets.py new file mode 100644 index 0000000000..ff092a1338 --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_presets.py @@ -0,0 +1,95 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ResNet preset configurations.""" + +backbone_presets = { + "resnet_18_imagenet": { + "metadata": { + "description": ( + "18-layer ResNet model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 11186112, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/1", + }, + "resnet_50_imagenet": { + "metadata": { + "description": ( + "50-layer ResNet model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." 
+ ), + "params": 23561152, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/1", + }, + "resnet_101_imagenet": { + "metadata": { + "description": ( + "101-layer ResNet model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 42605504, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/1", + }, + "resnet_152_imagenet": { + "metadata": { + "description": ( + "152-layer ResNet model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 58295232, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/1", + }, + "resnet_v2_50_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetV2 model pre-trained on the ImageNet 1k " + "dataset at a 224x224 resolution." + ), + "params": 23561152, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/1", + }, + "resnet_v2_101_imagenet": { + "metadata": { + "description": ( + "101-layer ResNetV2 model pre-trained on the ImageNet 1k " + "dataset at a 224x224 resolution." 
+ ), + "params": 42605504, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/1", + }, +} diff --git a/keras_nlp/src/models/text_classifier.py b/keras_nlp/src/models/text_classifier.py index d28985f67f..fd84cac3a3 100644 --- a/keras_nlp/src/models/text_classifier.py +++ b/keras_nlp/src/models/text_classifier.py @@ -63,9 +63,6 @@ class TextClassifier(Task): ``` """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - def compile( self, optimizer="auto", diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 5215ac7e59..a935fd160b 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -227,7 +227,7 @@ def get_file(preset, path): else: raise ValueError(message) elif os.path.exists(preset): - # Assume a local filepath. + # Assume a local filepath. local_path = os.path.join(preset, path) if not os.path.exists(local_path): raise FileNotFoundError( @@ -345,19 +345,9 @@ def save_metadata(layer, preset): metadata_file.write(json.dumps(metadata, indent=4)) -def _validate_tokenizer(preset, allow_incomplete=False): +def _validate_tokenizer(preset): if not check_file_exists(preset, TOKENIZER_CONFIG_FILE): - if allow_incomplete: - logging.warning( - f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`." - ) - return - else: - raise FileNotFoundError( - f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`. " - "To upload the model without a tokenizer, " - "set `allow_incomplete=True`." - ) + return config_path = get_file(preset, TOKENIZER_CONFIG_FILE) try: with open(config_path, encoding="utf-8") as config_file: @@ -485,7 +475,6 @@ def delete_model_card(preset): def upload_preset( uri, preset, - allow_incomplete=False, ): """Upload a preset directory to a model hub.
@@ -497,9 +486,6 @@ def upload_preset( `hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging Face Hub. preset: The path to the local model preset directory. - allow_incomplete: If True, allows the upload of presets without - a tokenizer configuration. Otherwise, a tokenizer - is required. """ # Check if preset directory exists. @@ -507,7 +493,7 @@ def upload_preset( raise FileNotFoundError(f"The preset directory {preset} doesn't exist.") _validate_backbone(preset) - _validate_tokenizer(preset, allow_incomplete) + _validate_tokenizer(preset) if uri.startswith(KAGGLE_PREFIX): if kagglehub is None: @@ -695,7 +681,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): kwargs["backbone"] = self.load_backbone( backbone_class, load_weights, **backbone_kwargs ) - if "preprocessor" not in kwargs: + if "preprocessor" not in kwargs and cls.preprocessor_cls: kwargs["preprocessor"] = self.load_preprocessor( cls.preprocessor_cls, ) @@ -760,7 +746,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): ) # We found a `task.json` with a complete config for our class. task = load_serialized_object(task_config, **kwargs) - if task.preprocessor is not None: + if task.preprocessor and task.preprocessor.tokenizer: task.preprocessor.tokenizer.load_preset_assets(self.preset) if load_weights: has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE) diff --git a/keras_nlp/src/utils/timm/convert_resnet.py b/keras_nlp/src/utils/timm/convert_resnet.py index 8671a8f6b3..81456c4225 100644 --- a/keras_nlp/src/utils/timm/convert_resnet.py +++ b/keras_nlp/src/utils/timm/convert_resnet.py @@ -158,3 +158,18 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): normalization_layer.input_mean = mean normalization_layer.input_variance = [s**2 for s in std] normalization_layer.build(normalization_layer._build_input_shape) + + +def convert_head(task, loader, timm_config): + v2 = "resnetv2_" in timm_config["architecture"] + prefix = "head.fc."
if v2 else "fc." + loader.port_weight( + task.output_dense.kernel, + hf_weight_key=prefix + "weight", + hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + ) + loader.port_weight( + task.output_dense.bias, + hf_weight_key=prefix + "bias", + ) + return task diff --git a/keras_nlp/src/utils/timm/preset_loader.py b/keras_nlp/src/utils/timm/preset_loader.py index 0f29007e1c..a853d5761f 100644 --- a/keras_nlp/src/utils/timm/preset_loader.py +++ b/keras_nlp/src/utils/timm/preset_loader.py @@ -13,6 +13,7 @@ # limitations under the License. """Convert timm models to KerasNLP.""" +from keras_nlp.src.models.image_classifier import ImageClassifier from keras_nlp.src.utils.preset_utils import PresetLoader from keras_nlp.src.utils.preset_utils import jax_memory_cleanup from keras_nlp.src.utils.timm import convert_resnet @@ -44,6 +45,23 @@ def load_backbone(self, cls, load_weights, **kwargs): self.converter.convert_weights(backbone, loader, self.config) return backbone + def load_task(self, cls, load_weights, load_task_weights, **kwargs): + if not load_task_weights or not issubclass(cls, ImageClassifier): + return super().load_task( + cls, load_weights, load_task_weights, **kwargs + ) + # Support loading the classification head for classifier models. + if "num_classes" not in kwargs: + kwargs["num_classes"] = self.config["num_classes"] + task = super().load_task(cls, load_weights, load_task_weights, **kwargs) + if load_weights: + with SafetensorLoader(self.preset, prefix="") as loader: + self.converter.convert_head(task, loader, self.config) + return task + def load_image_converter(self, cls, **kwargs): - # TODO. 
- return None + pretrained_cfg = self.config.get("pretrained_cfg", None) + if not pretrained_cfg or "input_size" not in pretrained_cfg: + return None + input_size = pretrained_cfg["input_size"] + return cls(width=input_size[1], height=input_size[2]) diff --git a/tools/checkpoint_conversion/convert_resnet_checkpoints.py b/tools/checkpoint_conversion/convert_resnet_checkpoints.py new file mode 100644 index 0000000000..a2eb95eb85 --- /dev/null +++ b/tools/checkpoint_conversion/convert_resnet_checkpoints.py @@ -0,0 +1,119 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert resnet checkpoints. 
+ +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_18_imagenet --upload_uri kaggle://kerashub/resnetv1/keras/resnet_18_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_50_imagenet --upload_uri kaggle://kerashub/resnetv1/keras/resnet_50_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_101_imagenet --upload_uri kaggle://kerashub/resnetv1/keras/resnet_101_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_152_imagenet --upload_uri kaggle://kerashub/resnetv1/keras/resnet_152_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_v2_50_imagenet --upload_uri kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_v2_101_imagenet --upload_uri kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet +""" + +import os +import shutil + +import keras +import numpy as np +import PIL +import timm +from absl import app +from absl import flags + +import keras_nlp + +PRESET_MAP = { + "resnet_18_imagenet": "timm/resnet18.a1_in1k", + "resnet_50_imagenet": "timm/resnet50.a1_in1k", + "resnet_101_imagenet": "timm/resnet101.a1_in1k", + "resnet_152_imagenet": "timm/resnet152.a1_in1k", + "resnet_v2_50_imagenet": "timm/resnetv2_50.a1h_in1k", + "resnet_v2_101_imagenet": "timm/resnetv2_101.a1h_in1k", +} +FLAGS = flags.FLAGS + + +flags.DEFINE_string( + "preset", + None, + "Must be a valid `CausalLM` preset from KerasNLP", + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}_int8"', + required=False, +) + + +def validate_output(keras_nlp_model, timm_model): + file = keras.utils.get_file( + origin=( + "https://storage.googleapis.com/keras-cv/" + "models/paligemma/cow_beach_1.png" + ) + ) + image = PIL.Image.open(file) + + # Call with 
Timm. + data_config = timm.data.resolve_model_data_config(timm_model) + transforms = timm.data.create_transform(**data_config, is_training=False) + timm_batch = transforms(image).unsqueeze(0) + timm_outputs = timm_model(timm_batch).detach().numpy() + # Call with Keras. + keras_outputs = keras_nlp_model.predict(np.array([image])) + + print("🔶 KerasNLP output:", keras_outputs[0, :10]) + print("🔶 TIMM output:", timm_outputs[0, :10]) + print("🔶 Difference:", np.mean(np.abs(keras_outputs - timm_outputs))) + + +def main(_): + preset = FLAGS.preset + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + timm_name = PRESET_MAP[preset] + + print("✅ Loaded TIMM model.") + timm_model = timm.create_model(timm_name, pretrained=True) + timm_model = timm_model.eval() + + print("✅ Loaded KerasNLP model.") + keras_nlp_model = keras_nlp.models.ImageClassifier.from_preset( + "hf://" + timm_name, + ) + + keras_nlp_model.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}") + + validate_output(keras_nlp_model, timm_model) + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_nlp.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)