diff --git a/keras_nlp/api/layers/__init__.py b/keras_nlp/api/layers/__init__.py index 8b92cc11b0..7def279b19 100644 --- a/keras_nlp/api/layers/__init__.py +++ b/keras_nlp/api/layers/__init__.py @@ -53,6 +53,9 @@ from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, ) +from keras_nlp.src.models.resnet.resnet_image_converter import ( + ResNetImageConverter, +) from keras_nlp.src.models.whisper.whisper_audio_converter import ( WhisperAudioConverter, ) diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index 1ebe476424..2b3ffbb30b 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -172,6 +172,9 @@ ) from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.src.models.image_classifier import ImageClassifier +from keras_nlp.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import ( @@ -231,6 +234,9 @@ from keras_nlp.src.models.resnet.resnet_image_classifier import ( ResNetImageClassifier, ) +from keras_nlp.src.models.resnet.resnet_image_classifier_preprocessor import ( + ResNetImageClassifierPreprocessor, +) from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_nlp.src.models.roberta.roberta_masked_lm_preprocessor import ( diff --git a/keras_nlp/src/models/image_classifier.py b/keras_nlp/src/models/image_classifier.py index f0cc031dbc..0606a29821 100644 --- a/keras_nlp/src/models/image_classifier.py +++ b/keras_nlp/src/models/image_classifier.py @@ -33,11 +33,6 @@ class ImageClassifier(Task): used to load a pre-trained config and weights. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Default compilation. - self.compile() - def compile( self, optimizer="auto", diff --git a/keras_nlp/src/models/image_classifier_preprocessor.py b/keras_nlp/src/models/image_classifier_preprocessor.py new file mode 100644 index 0000000000..c354169893 --- /dev/null +++ b/keras_nlp/src/models/image_classifier_preprocessor.py @@ -0,0 +1,83 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.tensor_utils import preprocessing_function + + +@keras_nlp_export("keras_nlp.models.ImageClassifierPreprocessor") +class ImageClassifierPreprocessor(Preprocessor): + """Base class for image classification preprocessing layers. + + `ImageClassifierPreprocessor` tasks wraps a + `keras_nlp.layers.ImageConverter` to create a preprocessing layer for + image classification tasks. It is intended to be paired with a + `keras_nlp.models.ImageClassifier` task. + + All `ImageClassifierPreprocessor` take inputs three inputs, `x`, `y`, and + `sample_weight`. `x`, the first input, should always be included. It can + be a image or batch of images. See examples below. `y` and `sample_weight` + are optional inputs that will be passed through unaltered. Usually, `y` will + be the classification label, and `sample_weight` will not be provided. + + The layer will output either `x`, an `(x, y)` tuple if labels were provided, + or an `(x, y, sample_weight)` tuple if labels and sample weight were + provided. `x` will be the input images after all model preprocessing has + been applied. + + All `ImageClassifierPreprocessor` tasks include a `from_preset()` + constructor which can be used to load a pre-trained config and vocabularies. + You can call the `from_preset()` constructor directly on this base class, in + which case the correct class for your model will be automatically + instantiated. + + Examples. + ```python + preprocessor = keras_nlp.models.ImageClassifierPreprocessor.from_preset( + "resnet_50", + ) + + # Resize a single image for resnet 50. + x = np.ones((512, 512, 3)) + x = preprocessor(x) + + # Resize a labeled image. + x, y = np.ones((512, 512, 3)), 1 + x, y = preprocessor(x, y) + + # Resize a batch of labeled images. + x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0] + x, y = preprocessor(x, y) + + # Use a `tf.data.Dataset`. + ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def __init__( + self, + image_converter=None, + **kwargs, + ): + super().__init__(**kwargs) + self.image_converter = image_converter + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + if self.image_converter: + x = self.image_converter(x) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/src/models/resnet/__init__.py b/keras_nlp/src/models/resnet/__init__.py index 3364a6bd16..a09d7a80bb 100644 --- a/keras_nlp/src/models/resnet/__init__.py +++ b/keras_nlp/src/models/resnet/__init__.py @@ -11,3 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_presets import backbone_presets +from keras_nlp.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, ResNetBackbone) diff --git a/keras_nlp/src/models/resnet/resnet_backbone.py b/keras_nlp/src/models/resnet/resnet_backbone.py index ca1de9b090..c9a794b3c8 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone.py +++ b/keras_nlp/src/models/resnet/resnet_backbone.py @@ -57,7 +57,7 @@ class ResNetBackbone(FeaturePyramidBackbone): stackwise_num_blocks: list of ints. The number of blocks for each stack. stackwise_num_strides: list of ints. The number of strides for each stack. - block_type: str. The block type to stack. One of `"basic_block"` or + block_type: str. The block type to stack. One of `"basic_block"`, `"bottleneck_block"`, `"basic_block_vd"` or `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and @@ -126,7 +126,6 @@ def __init__( use_pre_activation=False, include_rescaling=True, image_shape=(None, None, 3), - pooling="avg", data_format=None, dtype=None, **kwargs, @@ -285,20 +284,9 @@ def __init__( )(x) x = layers.Activation("relu", dtype=dtype, name="post_relu")(x) - if pooling == "avg": - feature_map_output = layers.GlobalAveragePooling2D( - data_format=data_format, dtype=dtype - )(x) - elif pooling == "max": - feature_map_output = layers.GlobalMaxPooling2D( - data_format=data_format, dtype=dtype - )(x) - else: - feature_map_output = x - super().__init__( inputs=image_input, - outputs=feature_map_output, + outputs=x, dtype=dtype, **kwargs, ) @@ -313,8 +301,8 @@ def __init__( self.use_pre_activation = use_pre_activation self.include_rescaling = include_rescaling self.image_shape = image_shape - self.pooling = pooling self.pyramid_outputs = pyramid_outputs + self.data_format = data_format def get_config(self): config = super().get_config() @@ -329,7 +317,6 @@ def get_config(self): "use_pre_activation": self.use_pre_activation, "include_rescaling": self.include_rescaling, "image_shape": self.image_shape, - "pooling": self.pooling, } ) return config diff --git a/keras_nlp/src/models/resnet/resnet_backbone_test.py b/keras_nlp/src/models/resnet/resnet_backbone_test.py index 33d4debac1..d6882823c7 100644 --- a/keras_nlp/src/models/resnet/resnet_backbone_test.py +++ b/keras_nlp/src/models/resnet/resnet_backbone_test.py @@ -29,74 +29,93 @@ def setUp(self): "stackwise_num_blocks": [2, 2, 2], "stackwise_num_strides": [1, 2, 2], "image_shape": (None, None, 3), - "pooling": "avg", + "block_type": "bottleneck_block", + "use_pre_activation": False, } self.input_size = 64 self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) @parameterized.named_parameters( - ("v1_basic", False, "basic_block"), - ("v1_bottleneck", False, "bottleneck_block"), - ("v2_basic", True, "basic_block"), - ("v2_bottleneck", True, "bottleneck_block"), - ("vd_basic", False, "basic_block_vd"), - ("vd_bottleneck", False, "bottleneck_block_vd"), + ("basic", "basic_block"), + ("bottleneck", "bottleneck_block"), ) - def test_backbone_basics(self, use_pre_activation, block_type): - init_kwargs = self.init_kwargs.copy() - init_kwargs.update( - { - "block_type": block_type, - "use_pre_activation": use_pre_activation, - } + def test_backbone_basics(self, block_type): + feature_size = 64 if block_type == "basic_block" else 256 + self.run_vision_backbone_test( + cls=ResNetBackbone, + init_kwargs={**self.init_kwargs, "block_type": block_type}, + input_data=self.input_data, + expected_output_shape=(2, 4, 4, feature_size), + expected_pyramid_output_keys=["P2", "P3", "P4"], + expected_pyramid_image_sizes=[(16, 16), (8, 8), (4, 4)], ) - if block_type in ("basic_block_vd", "bottleneck_block_vd"): - init_kwargs.update( - { - "input_conv_filters": [32, 32, 64], - "input_conv_kernel_sizes": [3, 3, 3], - } - ) + + @parameterized.named_parameters( + ("basic", "basic_block"), + ("bottleneck", "bottleneck_block"), + ) + def test_backbone_v2(self, block_type): + feature_size = 64 if block_type == "basic_block" else 256 self.run_vision_backbone_test( cls=ResNetBackbone, - init_kwargs=init_kwargs, + init_kwargs={ + **self.init_kwargs, + "block_type": block_type, + "use_pre_activation": True, + }, input_data=self.input_data, - expected_output_shape=( - (2, 64) - if block_type in ("basic_block", "basic_block_vd") - else (2, 256) - ), + expected_output_shape=(2, 4, 4, feature_size), expected_pyramid_output_keys=["P2", "P3", "P4"], expected_pyramid_image_sizes=[(16, 16), (8, 8), (4, 4)], ) @parameterized.named_parameters( - ("v1_basic", False, "basic_block"), - ("v1_bottleneck", False, "bottleneck_block"), - ("v2_basic", True, "basic_block"), - ("v2_bottleneck", True, "bottleneck_block"), - ("vd_basic", False, "basic_block_vd"), - ("vd_bottleneck", False, "bottleneck_block_vd"), + ("basic", "basic_block_vd"), + ("bottleneck", "bottleneck_block_vd"), ) - @pytest.mark.large - def test_saved_model(self, use_pre_activation, block_type): - init_kwargs = self.init_kwargs.copy() - init_kwargs.update( - { + def test_backbone_vd(self, block_type): + feature_size = 64 if block_type == "basic_block_vd" else 256 + self.run_vision_backbone_test( + cls=ResNetBackbone, + init_kwargs={ + **self.init_kwargs, "block_type": block_type, - "use_pre_activation": use_pre_activation, - "image_shape": (None, None, 3), - } + "input_conv_filters": [32, 32, 64], + "input_conv_kernel_sizes": [3, 3, 3], + }, + input_data=self.input_data, + expected_output_shape=(2, 4, 4, feature_size), + expected_pyramid_output_keys=["P2", "P3", "P4"], + expected_pyramid_image_sizes=[(16, 16), (8, 8), (4, 4)], ) - if block_type in ("basic_block_vd", "bottleneck_block_vd"): - init_kwargs.update( - { - "input_conv_filters": [32, 32, 64], - "input_conv_kernel_sizes": [3, 3, 3], - } - ) + + @pytest.mark.large + def test_saved_model(self): self.run_model_saving_test( cls=ResNetBackbone, - init_kwargs=init_kwargs, + init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + @pytest.mark.large + def test_smallest_preset(self): + image_batch = self.load_test_image((224, 224))[None, ...] + self.run_preset_test( + cls=ResNetBackbone, + preset="resnet_18_imagenet", + input_data=image_batch, + expected_output_shape=(1, 7, 7, 512), + # The forward pass from a preset should be stable! + expected_partial_output=ops.array( + [0.008969, 0.015136, 0.028074, 0.594599, 0.002846] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in ResNetBackbone.presets: + self.run_preset_test( + cls=ResNetBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier.py b/keras_nlp/src/models/resnet/resnet_image_classifier.py index 815dc7fcca..c9844b12b0 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier.py @@ -16,6 +16,9 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.image_classifier import ImageClassifier from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_classifier_preprocessor import ( + ResNetImageClassifierPreprocessor, +) @keras_nlp_export("keras_nlp.models.ResNetImageClassifier") @@ -88,21 +91,36 @@ class ResNetImageClassifier(ImageClassifier): """ backbone_cls = ResNetBackbone + preprocessor_cls = ResNetImageClassifierPreprocessor def __init__( self, backbone, num_classes, - activation="softmax", + preprocessor=None, + pooling="avg", + activation=None, head_dtype=None, - preprocessor=None, # adding this dummy arg for saved model test - # TODO: once preprocessor flow is figured out, this needs to be updated **kwargs, ): head_dtype = head_dtype or backbone.dtype_policy # === Layers === self.backbone = backbone + self.preprocessor = preprocessor + if pooling == "avg": + self.pooler = keras.layers.GlobalAveragePooling2D( + data_format=backbone.data_format, dtype=head_dtype + ) + elif pooling == "max": + self.pooler = keras.layers.GlobalAveragePooling2D( + data_format=backbone.data_format, dtype=head_dtype + ) + else: + raise ValueError( + "Unknown `pooling` type. Polling should be either `'avg'` or " + f"`'max'`. Received: pooling={pooling}." + ) self.output_dense = keras.layers.Dense( num_classes, activation=activation, @@ -113,6 +131,7 @@ def __init__( # === Functional Model === inputs = self.backbone.input x = self.backbone(inputs) + x = self.pooler(x) outputs = self.output_dense(x) super().__init__( inputs=inputs, @@ -123,6 +142,7 @@ def __init__( # === Config === self.num_classes = num_classes self.activation = activation + self.pooling = pooling def get_config(self): # Backbone serialized in `super` @@ -130,6 +150,7 @@ def get_config(self): config.update( { "num_classes": self.num_classes, + "pooling": self.pooling, "activation": self.activation, } ) diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_preprocessor.py b/keras_nlp/src/models/resnet/resnet_image_classifier_preprocessor.py new file mode 100644 index 0000000000..0da2c67deb --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_preprocessor.py @@ -0,0 +1,28 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_nlp.src.models.resnet.resnet_image_converter import ( + ResNetImageConverter, +) + + +@keras_nlp_export("keras_nlp.models.ResNetImageClassifierPreprocessor") +class ResNetImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = ResNetBackbone + image_converter_cls = ResNetImageConverter diff --git a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py index da06c80320..d9f689c6d6 100644 --- a/keras_nlp/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_nlp/src/models/resnet/resnet_image_classifier_test.py @@ -35,11 +35,11 @@ def setUp(self): use_pre_activation=True, image_shape=(16, 16, 3), include_rescaling=False, - pooling="avg", ) self.init_kwargs = { "backbone": self.backbone, "num_classes": 2, + "pooling": "avg", "activation": "softmax", } self.train_data = (self.images, self.labels) @@ -66,3 +66,14 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.images, ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in ResNetImageClassifier.presets: + self.run_preset_test( + cls=ResNetImageClassifier, + preset=preset, + init_kwargs={"num_classes": 2}, + input_data=self.images, + expected_output_shape=(2, 2), + ) diff --git a/keras_nlp/src/models/resnet/resnet_image_converter.py b/keras_nlp/src/models/resnet/resnet_image_converter.py new file mode 100644 index 0000000000..876dffc96d --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_image_converter.py @@ -0,0 +1,23 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.resizing_image_converter import ( + ResizingImageConverter, +) +from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone + + +@keras_nlp_export("keras_nlp.layers.ResNetImageConverter") +class ResNetImageConverter(ResizingImageConverter): + backbone_cls = ResNetBackbone diff --git a/keras_nlp/src/models/resnet/resnet_presets.py b/keras_nlp/src/models/resnet/resnet_presets.py new file mode 100644 index 0000000000..8030e7257f --- /dev/null +++ b/keras_nlp/src/models/resnet/resnet_presets.py @@ -0,0 +1,95 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ResNet preset configurations.""" + +backbone_presets = { + "resnet_18_imagenet": { + "metadata": { + "description": ( + "18-layer ResNet model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 11186112, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/2", + }, + "resnet_50_imagenet": { + "metadata": { + "description": ( + "50-layer ResNet model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 23561152, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/2", + }, + "resnet_101_imagenet": { + "metadata": { + "description": ( + "101-layer ResNet model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 42605504, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/2", + }, + "resnet_152_imagenet": { + "metadata": { + "description": ( + "152-layer ResNet model pre-trained on the ImageNet 1k dataset " + "at a 224x224 resolution." + ), + "params": 58295232, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/2", + }, + "resnet_v2_50_imagenet": { + "metadata": { + "description": ( + "50-layer ResNetV2 model pre-trained on the ImageNet 1k " + "dataset at a 224x224 resolution." + ), + "params": 23561152, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/2", + }, + "resnet_v2_101_imagenet": { + "metadata": { + "description": ( + "101-layer ResNetV2 model pre-trained on the ImageNet 1k " + "dataset at a 224x224 resolution." + ), + "params": 42605504, + "official_name": "ResNet", + "path": "resnet", + "model_card": "https://arxiv.org/abs/2110.00476", + }, + "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/2", + }, +} diff --git a/keras_nlp/src/models/text_classifier.py b/keras_nlp/src/models/text_classifier.py index d28985f67f..fd84cac3a3 100644 --- a/keras_nlp/src/models/text_classifier.py +++ b/keras_nlp/src/models/text_classifier.py @@ -63,9 +63,6 @@ class TextClassifier(Task): ``` """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - def compile( self, optimizer="auto", diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 7e2d0660b5..b661e387c4 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -18,6 +18,7 @@ import re import keras +import numpy as np import tensorflow as tf from absl.testing import parameterized from keras import ops @@ -493,6 +494,7 @@ def run_vision_backbone_test( run_mixed_precision_check=run_mixed_precision_check, run_quantization_check=run_quantization_check, ) + if expected_pyramid_output_keys: backbone = cls(**init_kwargs) model = keras.models.Model( @@ -522,6 +524,12 @@ def run_vision_backbone_test( input_data = ops.transpose(input_data, axes=(2, 0, 1)) elif len(input_data_shape) == 4: input_data = ops.transpose(input_data, axes=(0, 3, 1, 2)) + if len(expected_output_shape) == 3: + x = expected_output_shape + expected_output_shape = (x[0], x[2], x[1]) + elif len(expected_output_shape) == 4: + x = expected_output_shape + expected_output_shape = (x[0], x[3], x[1], x[2]) if "image_shape" in init_kwargs: init_kwargs = init_kwargs.copy() init_kwargs["image_shape"] = tuple( @@ -631,3 +639,8 @@ def compare(actual, expected): def get_test_data_dir(self): return str(pathlib.Path(__file__).parent / "test_data") + + def load_test_image(self, size): + path = os.path.join(self.get_test_data_dir(), "test_image.png") + img = keras.utils.load_img(path, target_size=size) + return np.array(img) diff --git a/keras_nlp/src/tests/test_data/test_image.png b/keras_nlp/src/tests/test_data/test_image.png new file mode 100644 index 0000000000..f83575262e Binary files /dev/null and b/keras_nlp/src/tests/test_data/test_image.png differ diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 5215ac7e59..a935fd160b 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -227,7 +227,7 @@ def get_file(preset, path): else: raise ValueError(message) elif os.path.exists(preset): - # Assume a local filepath. + # Assume a local filepath.pyth local_path = os.path.join(preset, path) if not os.path.exists(local_path): raise FileNotFoundError( @@ -345,19 +345,9 @@ def save_metadata(layer, preset): metadata_file.write(json.dumps(metadata, indent=4)) -def _validate_tokenizer(preset, allow_incomplete=False): +def _validate_tokenizer(preset): if not check_file_exists(preset, TOKENIZER_CONFIG_FILE): - if allow_incomplete: - logging.warning( - f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`." - ) - return - else: - raise FileNotFoundError( - f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`. " - "To upload the model without a tokenizer, " - "set `allow_incomplete=True`." - ) + return config_path = get_file(preset, TOKENIZER_CONFIG_FILE) try: with open(config_path, encoding="utf-8") as config_file: @@ -485,7 +475,6 @@ def delete_model_card(preset): def upload_preset( uri, preset, - allow_incomplete=False, ): """Upload a preset directory to a model hub. @@ -497,9 +486,6 @@ def upload_preset( `hf://[/]` will be uploaded to the Hugging Face Hub. preset: The path to the local model preset directory. - allow_incomplete: If True, allows the upload of presets without - a tokenizer configuration. Otherwise, a tokenizer - is required. """ # Check if preset directory exists. @@ -507,7 +493,7 @@ def upload_preset( raise FileNotFoundError(f"The preset directory {preset} doesn't exist.") _validate_backbone(preset) - _validate_tokenizer(preset, allow_incomplete) + _validate_tokenizer(preset) if uri.startswith(KAGGLE_PREFIX): if kagglehub is None: @@ -695,7 +681,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): kwargs["backbone"] = self.load_backbone( backbone_class, load_weights, **backbone_kwargs ) - if "preprocessor" not in kwargs: + if "preprocessor" not in kwargs and cls.preprocessor_cls: kwargs["preprocessor"] = self.load_preprocessor( cls.preprocessor_cls, ) @@ -760,7 +746,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): ) # We found a `task.json` with a complete config for our class. task = load_serialized_object(task_config, **kwargs) - if task.preprocessor is not None: + if task.preprocessor and task.preprocessor.tokenizer: task.preprocessor.tokenizer.load_preset_assets(self.preset) if load_weights: has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE) diff --git a/keras_nlp/src/utils/timm/convert_resnet.py b/keras_nlp/src/utils/timm/convert_resnet.py index 8671a8f6b3..9ede5e8b73 100644 --- a/keras_nlp/src/utils/timm/convert_resnet.py +++ b/keras_nlp/src/utils/timm/convert_resnet.py @@ -158,3 +158,17 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): normalization_layer.input_mean = mean normalization_layer.input_variance = [s**2 for s in std] normalization_layer.build(normalization_layer._build_input_shape) + + +def convert_head(task, loader, timm_config): + v2 = "resnetv2_" in timm_config["architecture"] + prefix = "head.fc." if v2 else "fc." + loader.port_weight( + task.output_dense.kernel, + hf_weight_key=prefix + "weight", + hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + ) + loader.port_weight( + task.output_dense.bias, + hf_weight_key=prefix + "bias", + ) diff --git a/keras_nlp/src/utils/timm/preset_loader.py b/keras_nlp/src/utils/timm/preset_loader.py index 0f29007e1c..69476f378c 100644 --- a/keras_nlp/src/utils/timm/preset_loader.py +++ b/keras_nlp/src/utils/timm/preset_loader.py @@ -13,6 +13,7 @@ # limitations under the License. """Convert timm models to KerasNLP.""" +from keras_nlp.src.models.image_classifier import ImageClassifier from keras_nlp.src.utils.preset_utils import PresetLoader from keras_nlp.src.utils.preset_utils import jax_memory_cleanup from keras_nlp.src.utils.timm import convert_resnet @@ -44,6 +45,22 @@ def load_backbone(self, cls, load_weights, **kwargs): self.converter.convert_weights(backbone, loader, self.config) return backbone + def load_task(self, cls, load_weights, load_task_weights, **kwargs): + if not load_task_weights or not issubclass(cls, ImageClassifier): + return super().load_task( + cls, load_weights, load_task_weights, **kwargs + ) + # Support loading the classification head for classifier models. + kwargs["num_classes"] = self.config["num_classes"] + task = super().load_task(cls, load_weights, load_task_weights, **kwargs) + if load_task_weights: + with SafetensorLoader(self.preset, prefix="") as loader: + self.converter.convert_head(task, loader, self.config) + return task + def load_image_converter(self, cls, **kwargs): - # TODO. - return None + pretrained_cfg = self.config.get("pretrained_cfg", None) + if not pretrained_cfg or "input_size" not in pretrained_cfg: + return None + input_size = pretrained_cfg["input_size"] + return cls(width=input_size[1], height=input_size[2]) diff --git a/requirements-common.txt b/requirements-common.txt index 4e90ca9fab..2bdc4a5720 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -18,3 +18,4 @@ rouge-score sentencepiece tensorflow-datasets safetensors +pillow diff --git a/tools/checkpoint_conversion/convert_resnet_checkpoints.py b/tools/checkpoint_conversion/convert_resnet_checkpoints.py new file mode 100644 index 0000000000..5ac72874e8 --- /dev/null +++ b/tools/checkpoint_conversion/convert_resnet_checkpoints.py @@ -0,0 +1,125 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert resnet checkpoints. + +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_18_imagenet --upload_uri kaggle://kerashub/resnetv1/keras/resnet_18_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_50_imagenet --upload_uri kaggle://kerashub/resnetv1/keras/resnet_50_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_101_imagenet --upload_uri kaggle://kerashub/resnetv1/keras/resnet_101_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_152_imagenet --upload_uri kaggle://kerashub/resnetv1/keras/resnet_152_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_v2_50_imagenet --upload_uri kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet +python tools/checkpoint_conversion/convert_resnet_checkpoints.py \ + --preset resnet_v2_101_imagenet --upload_uri kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet +""" + +import os +import shutil + +import keras +import numpy as np +import PIL +import timm +import torch +from absl import app +from absl import flags + +import keras_nlp + +PRESET_MAP = { + "resnet_18_imagenet": "timm/resnet18.a1_in1k", + "resnet_50_imagenet": "timm/resnet50.a1_in1k", + "resnet_101_imagenet": "timm/resnet101.a1_in1k", + "resnet_152_imagenet": "timm/resnet152.a1_in1k", + "resnet_v2_50_imagenet": "timm/resnetv2_50.a1h_in1k", + "resnet_v2_101_imagenet": "timm/resnetv2_101.a1h_in1k", +} +FLAGS = flags.FLAGS + + +flags.DEFINE_string( + "preset", + None, + "Must be a valid `ResNet` preset from KerasNLP", + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}_int8"', + required=False, +) + + +def validate_output(keras_model, timm_model): + file = keras.utils.get_file( + origin=( + "https://storage.googleapis.com/keras-cv/" + "models/paligemma/cow_beach_1.png" + ) + ) + image = PIL.Image.open(file) + batch = np.array([image]) + + # Call with Timm. + timm_batch = keras_model.preprocessor(batch) + timm_batch = keras.ops.transpose(timm_batch, axes=(0, 3, 1, 2)) / 255.0 + timm_batch = torch.from_numpy(np.array(timm_batch)) + timm_outputs = timm_model(timm_batch).detach().numpy() + timm_label = np.argmax(timm_outputs[0]) + # Call with Keras. + keras_outputs = keras_model.predict(batch) + keras_label = np.argmax(keras_outputs[0]) + + print("🔶 Keras output:", keras_outputs[0, :10]) + print("🔶 TIMM output:", timm_outputs[0, :10]) + print("🔶 Difference:", np.mean(np.abs(keras_outputs - timm_outputs))) + print("🔶 Keras label:", keras_label) + print("🔶 TIMM label:", timm_label) + + +def main(_): + preset = FLAGS.preset + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + timm_name = PRESET_MAP[preset] + + print("✅ Loaded TIMM model.") + timm_model = timm.create_model(timm_name, pretrained=True) + timm_model = timm_model.eval() + + print("✅ Loaded KerasNLP model.") + keras_model = keras_nlp.models.ImageClassifier.from_preset( + "hf://" + timm_name, + ) + + keras_model.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}") + + validate_output(keras_model, timm_model) + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_nlp.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)