diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index a287bcca88..6b85148caf 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -33,9 +33,6 @@ ) from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion from keras_hub.src.layers.preprocessing.random_swap import RandomSwap -from keras_hub.src.layers.preprocessing.resizing_image_converter import ( - ResizingImageConverter, -) from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index b93e36e069..6c1c576b93 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -1,47 +1,149 @@ +import math + +import keras +from keras import ops + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) +from keras_hub.src.utils.keras_utils import standardize_data_format from keras_hub.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE from keras_hub.src.utils.preset_utils import builtin_presets from keras_hub.src.utils.preset_utils import find_subclass from keras_hub.src.utils.preset_utils import get_preset_loader from keras_hub.src.utils.preset_utils import save_serialized_object from keras_hub.src.utils.python_utils import classproperty +from keras_hub.src.utils.tensor_utils import preprocessing_function @keras_hub_export("keras_hub.layers.ImageConverter") class ImageConverter(PreprocessingLayer): - """Convert raw image for models that support image input. + """Preprocess raw images into model-ready inputs. + + This class converts from raw images to model-ready inputs. This conversion + proceeds in the following steps: - This class converts from raw images of any size, to preprocessed - images for pretrained model inputs. It is meant to be a convenient way to - write custom preprocessing code that is not model specific. This layer - should be instantiated via the `from_preset()` constructor, which will - create the correct subclass of this layer for the model preset. + 1. Resize the image to `image_size`. If `image_size` is `None`, this + step will be skipped. + 2. Rescale the image by multiplying by `scale`, which can be either global + or per channel. If `scale` is `None`, this step will be skipped. + 3. Offset the image by adding `offset`, which can be either global + or per channel. If `offset` is `None`, this step will be skipped. The layer will take as input a raw image tensor in the channels last or channels first format, and output a preprocessed image input for modeling. - The exact structure of the output will vary per model, though in most cases - this layer will simply resize the image to the size needed by the model - input. + This tensor can be batched (rank 4), or unbatched (rank 3). + + This layer can be used with the `from_preset()` constructor to load a layer + that will rescale and resize an image for a specific pretrained model. + Using the layer this way allows writing preprocessing code that does not + need updating when switching between model checkpoints. Examples: ```python - # Resize images for `"pali_gemma_3b_224"`.
- converter = keras_hub.layers.ImageConverter.from_preset("pali_gemma_3b_224") - converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3) - # Resize images for `"pali_gemma_3b_448"`. - converter = keras_hub.layers.ImageConverter.from_preset("pali_gemma_3b_448") - converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 448, 448, 3) + # Resize raw images and scale them to [0, 1]. + converter = keras_hub.layers.ImageConverter( + image_size=(128, 128), + scale=1 / 255. + ) + converter(np.random.randint(0, 256, size=(2, 512, 512, 3))) + + # Resize images to the specific size needed for a PaliGemma preset. + converter = keras_hub.layers.ImageConverter.from_preset( + "pali_gemma_3b_224" + ) + converter(np.random.randint(0, 256, size=(2, 512, 512, 3))) ``` """ backbone_cls = None + def __init__( + self, + image_size=None, + scale=None, + offset=None, + crop_to_aspect_ratio=True, + interpolation="bilinear", + data_format=None, + **kwargs, + ): + # TODO: old arg names. Delete this block after resaving Kaggle assets. + if "height" in kwargs and "width" in kwargs: + image_size = (kwargs.pop("height"), kwargs.pop("width")) + if "variance" in kwargs and "mean" in kwargs: + std = [math.sqrt(v) for v in kwargs.pop("variance")] + # `scale` was optional in old configs; fold in no-op rescaling. + scale = 1.0 if scale is None else scale + scale = [scale / s for s in std] + offset = [-m / s for m, s in zip(kwargs.pop("mean"), std)] + + super().__init__(**kwargs) + + # Create the `Resizing` layer here even if it's not being used. That + # allows us to make `image_size` a settable property. + self.resizing = keras.layers.Resizing( + height=image_size[0] if image_size else None, + width=image_size[1] if image_size else None, + crop_to_aspect_ratio=crop_to_aspect_ratio, + interpolation=interpolation, + data_format=data_format, + dtype=self.dtype_policy, + name="resizing", + ) + self.scale = scale + self.offset = offset + self.crop_to_aspect_ratio = crop_to_aspect_ratio + self.interpolation = interpolation + self.data_format = standardize_data_format(data_format) + + @property def image_size(self): - """Returns the default size of a single image.""" - return (None, None) + """Settable tuple of `(height, width)` ints. The output image shape.""" + if self.resizing.height is None: + return None + return (self.resizing.height, self.resizing.width) + + @image_size.setter + def image_size(self, value): + if value is None: + value = (None, None) + self.resizing.height = value[0] + self.resizing.width = value[1] + + @preprocessing_function + def call(self, inputs): + x = inputs + if self.image_size is not None: + x = self.resizing(x) + if self.scale is not None: + x = x * self._expand_non_channel_dims(self.scale, x) + if self.offset is not None: + x = x + self._expand_non_channel_dims(self.offset, x) + return x + + def _expand_non_channel_dims(self, value, inputs): + unbatched = len(ops.shape(inputs)) == 3 + channels_first = self.data_format == "channels_first" + if unbatched: + broadcast_dims = (1, 2) if channels_first else (0, 1) + else: + broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2) + return ops.expand_dims(ops.array(value), broadcast_dims) + + def get_config(self): + config = super().get_config() + config.update( + { + "image_size": self.image_size, + "scale": self.scale, + "offset": self.offset, + "interpolation": self.interpolation, + "crop_to_aspect_ratio": self.crop_to_aspect_ratio, + } + ) + return config @classproperty def presets(cls): @@ -69,13 +171,6 @@ def from_preset( You can run `cls.presets.keys()` to list all built-in presets available on the class.
- This constructor can be called in one of two ways. Either from the base - class like `keras_hub.models.ImageConverter.from_preset()`, or from a - model class like - `keras_hub.models.PaliGemmaImageConverter.from_preset()`. If calling - from the base class, the subclass of the returning object will be - inferred from the config in the preset directory. - Args: preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory. @@ -85,17 +180,20 @@ class like `keras_hub.models.ImageConverter.from_preset()`, or from a Examples: ```python + batch = np.random.randint(0, 256, size=(2, 512, 512, 3)) + # Resize images for `"pali_gemma_3b_224"`. converter = keras_hub.layers.ImageConverter.from_preset( "pali_gemma_3b_224" ) - converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3) - # Override arguments on the base class. + converter(batch) # Output shape: (2, 224, 224, 3) + + # Resize images for `"pali_gemma_3b_448"` without cropping. converter = keras_hub.layers.ImageConverter.from_preset( "pali_gemma_3b_448", crop_to_aspect_ratio=False, ) - converter(np.ones(2, 512, 512, 3)) # (2, 448, 448, 3) + converter(batch) # Output shape: (2, 448, 448, 3) ``` """ loader = get_preset_loader(preset) diff --git a/keras_hub/src/layers/preprocessing/image_converter_test.py b/keras_hub/src/layers/preprocessing/image_converter_test.py index 3c0651ad72..ca6d79b3f4 100644 --- a/keras_hub/src/layers/preprocessing/image_converter_test.py +++ b/keras_hub/src/layers/preprocessing/image_converter_test.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from keras import ops from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( PaliGemmaBackbone, ) @@ -15,6 +16,50 @@ class ImageConverterTest(TestCase): + def test_resize_simple(self): + converter = ImageConverter(height=4, width=4, scale=1 / 255.0) + inputs = np.ones((10, 10, 3)) * 255.0 + outputs = converter(inputs) + self.assertAllClose(outputs, ops.ones((4, 4, 3))) + + def test_unbatched(self): + converter = ImageConverter( + image_size=(4, 4), + scale=(1.0 / 255.0, 0.8 / 255.0, 1.2 / 255.0), + offset=(0.2, -0.1, 0.25), + ) + inputs = np.ones((10, 10, 3)) * 128 + outputs = converter(inputs) + self.assertEqual(ops.shape(outputs), (4, 4, 3)) + self.assertAllClose(outputs[:, :, 0], np.ones((4, 4)) * 0.701961) + self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * 0.301569) + self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.852353) + + def test_resize_batch(self): + converter = ImageConverter( + image_size=(4, 4), + scale=(1.0 / 255.0, 0.8 / 255.0, 1.2 / 255.0), + offset=(0.2, -0.1, 0.25), + ) + inputs = np.ones((2, 10, 10, 3)) * 128 + outputs = converter(inputs) + self.assertEqual(ops.shape(outputs), (2, 4, 4, 3)) + self.assertAllClose(outputs[:, :, :, 0], np.ones((2, 4, 4)) * 0.701961) + self.assertAllClose(outputs[:, :, :, 1], np.ones((2, 4, 4)) * 0.301569) + self.assertAllClose(outputs[:, :, :, 2], np.ones((2, 4, 4)) * 0.852353) + + def test_config(self): + converter = ImageConverter( + image_size=(12, 20), + scale=(0.25 / 255.0, 0.1 / 255.0, 0.5 / 255.0), + offset=(0.2, -0.1, 0.25), + crop_to_aspect_ratio=False, + interpolation="nearest", + ) + clone = ImageConverter.from_config(converter.get_config()) + test_batch = np.random.rand(4, 10, 20, 3) * 255 + self.assertAllClose(converter(test_batch), clone(test_batch)) + def test_preset_accessors(self): pali_gemma_presets = set(PaliGemmaImageConverter.presets.keys())
all_presets = set(ImageConverter.presets.keys()) diff --git a/keras_hub/src/layers/preprocessing/resizing_image_converter.py b/keras_hub/src/layers/preprocessing/resizing_image_converter.py deleted file mode 100644 index 199cbdceba..0000000000 --- a/keras_hub/src/layers/preprocessing/resizing_image_converter.py +++ /dev/null @@ -1,138 +0,0 @@ -import keras -from keras import ops - -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.utils.keras_utils import standardize_data_format -from keras_hub.src.utils.tensor_utils import preprocessing_function - - -@keras_hub_export("keras_hub.layers.ResizingImageConverter") -class ResizingImageConverter(ImageConverter): - """An `ImageConverter` that simply resizes the input image. - - The `ResizingImageConverter` is a subclass of `ImageConverter` for models - that need to resize (and optionally rescale) image tensors before using them - for modeling. The layer will take as input a raw image tensor (batched or - unbatched) in the channels last or channels first format, and output a - resize tensor. - - Args: - height: int, the height of the output shape. - width: int, the width of the output shape. - scale: float or `None`. If set, the image we be rescaled with a - `keras.layers.Rescaling` layer, multiplying the image by this - scale. - mean: tuples of floats per channel or `None`. If set, the image will be - normalized per channel by subtracting mean. - If set, also set `variance`. - variance: tuples of floats per channel or `None`. If set, the image will - be normalized per channel by dividing by `sqrt(variance)`. - If set, also set `mean`. - crop_to_aspect_ratio: If `True`, resize the images without aspect - ratio distortion. When the original aspect ratio differs - from the target aspect ratio, the output image will be - cropped so as to return the - largest possible window in the image (of size `(height, width)`) - that matches the target aspect ratio. By default - (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved. - interpolation: String, the interpolation method. - Supports `"bilinear"`, `"nearest"`, `"bicubic"`, - `"lanczos3"`, `"lanczos5"`. Defaults to `"bilinear"`. - data_format: String, either `"channels_last"` or `"channels_first"`. - The ordering of the dimensions in the inputs. `"channels_last"` - corresponds to inputs with shape `(batch, height, width, channels)` - while `"channels_first"` corresponds to inputs with shape - `(batch, channels, height, width)`. It defaults to the - `image_data_format` value found in your Keras config file at - `~/.keras/keras.json`. If you never set it, then it will be - `"channels_last"`. - - Examples: - ```python - # Resize images for `"pali_gemma_3b_224"`. - converter = keras_hub.layers.ImageConverter.from_preset("pali_gemma_3b_224") - converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3) - # Resize images for `"pali_gemma_3b_224"`. - converter = keras_hub.layers.ImageConverter.from_preset("pali_gemma_3b_448") - converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 448, 448, 3) - ``` - """ - - def __init__( - self, - height, - width, - scale=None, - mean=None, - variance=None, - crop_to_aspect_ratio=True, - interpolation="bilinear", - data_format=None, - **kwargs, - ): - super().__init__(**kwargs) - # By default, we just do a simple resize. Any model can subclass this - # layer for preprocessing of a raw image to a model image input. 
- self.resizing = keras.layers.Resizing( - height=height, - width=width, - crop_to_aspect_ratio=crop_to_aspect_ratio, - interpolation=interpolation, - data_format=data_format, - dtype=self.dtype_policy, - name="resizing", - ) - if scale is not None: - self.rescaling = keras.layers.Rescaling( - scale=scale, - dtype=self.dtype_policy, - name="rescaling", - ) - else: - self.rescaling = None - if (mean is not None) != (variance is not None): - raise ValueError( - "Both `mean` and `variance` should be set or `None`. Received " - f"`mean={mean}`, `variance={variance}`." - ) - self.scale = scale - self.mean = mean - self.variance = variance - self.data_format = standardize_data_format(data_format) - - def image_size(self): - """Returns the preprocessed size of a single image.""" - return (self.resizing.height, self.resizing.width) - - @preprocessing_function - def call(self, inputs): - x = self.resizing(inputs) - if self.rescaling: - x = self.rescaling(x) - if self.mean is not None: - # Avoid `layers.Normalization` so this works batched and unbatched. - channels_first = self.data_format == "channels_first" - if len(ops.shape(inputs)) == 3: - broadcast_dims = (1, 2) if channels_first else (0, 1) - else: - broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2) - mean = ops.expand_dims(ops.array(self.mean), broadcast_dims) - std = ops.expand_dims(ops.sqrt(self.variance), broadcast_dims) - x = (x - mean) / std - return x - - def get_config(self): - config = super().get_config() - config.update( - { - "height": self.resizing.height, - "width": self.resizing.width, - "interpolation": self.resizing.interpolation, - "crop_to_aspect_ratio": self.resizing.crop_to_aspect_ratio, - "scale": self.scale, - "mean": self.mean, - "variance": self.variance, - } - ) - return config diff --git a/keras_hub/src/layers/preprocessing/resizing_image_converter_test.py b/keras_hub/src/layers/preprocessing/resizing_image_converter_test.py deleted file mode 100644 index 052ae9d526..0000000000 --- a/keras_hub/src/layers/preprocessing/resizing_image_converter_test.py +++ /dev/null @@ -1,67 +0,0 @@ -import numpy as np -from keras import ops - -from keras_hub.src.layers.preprocessing.resizing_image_converter import ( - ResizingImageConverter, -) -from keras_hub.src.tests.test_case import TestCase - - -class ResizingImageConverterTest(TestCase): - def test_resize_simple(self): - converter = ResizingImageConverter(height=4, width=4) - inputs = np.ones((10, 10, 3)) - outputs = converter(inputs) - self.assertAllClose(outputs, ops.ones((4, 4, 3))) - - def test_resize_one(self): - converter = ResizingImageConverter( - height=4, - width=4, - mean=(0.5, 0.7, 0.3), - variance=(0.25, 0.1, 0.5), - scale=1 / 255.0, - ) - inputs = np.ones((10, 10, 3)) * 128 - outputs = converter(inputs) - self.assertEqual(ops.shape(outputs), (4, 4, 3)) - self.assertAllClose(outputs[:, :, 0], np.ones((4, 4)) * 0.003922) - self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * -0.626255) - self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.285616) - - def test_resize_batch(self): - converter = ResizingImageConverter( - height=4, - width=4, - mean=(0.5, 0.7, 0.3), - variance=(0.25, 0.1, 0.5), - scale=1 / 255.0, - ) - inputs = np.ones((2, 10, 10, 3)) * 128 - outputs = converter(inputs) - self.assertEqual(ops.shape(outputs), (2, 4, 4, 3)) - self.assertAllClose(outputs[:, :, :, 0], np.ones((2, 4, 4)) * 0.003922) - self.assertAllClose(outputs[:, :, :, 1], np.ones((2, 4, 4)) * -0.626255) - self.assertAllClose(outputs[:, :, :, 2], np.ones((2, 4, 4)) * 
0.285616) - - def test_errors(self): - with self.assertRaises(ValueError): - ResizingImageConverter( - height=4, - width=4, - mean=(0.5, 0.7, 0.3), - ) - - def test_config(self): - converter = ResizingImageConverter( - width=12, - height=20, - mean=(0.5, 0.7, 0.3), - variance=(0.25, 0.1, 0.5), - scale=1 / 255.0, - crop_to_aspect_ratio=False, - interpolation="nearest", - ) - clone = ResizingImageConverter.from_config(converter.get_config()) - test_batch = np.random.rand(4, 10, 20, 3) * 255 - self.assertAllClose(converter(test_batch), clone(test_batch)) diff --git a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py index 7485204b7a..c0a5b1bb3e 100644 --- a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +++ b/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py @@ -1,5 +1,3 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.csp_darknet.csp_darknet_backbone import ( CSPDarkNetBackbone, @@ -9,111 +7,4 @@ @keras_hub_export("keras_hub.models.CSPDarkNetImageClassifier") class CSPDarkNetImageClassifier(ImageClassifier): - """CSPDarkNet image classifier task model. - - Args: - backbone: A `keras_hub.models.CSPDarkNetBackbone` instance. - num_classes: int. The number of classes to predict. - activation: `None`, str or callable. The activation function to use on - the `Dense` layer. Set `activation=None` to return the output - logits. Defaults to `"softmax"`. - - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - where `x` is a tensor and `y` is a integer from `[0, num_classes)`. - All `ImageClassifier` tasks include a `from_preset()` constructor which can - be used to load a pre-trained config and weights. - - Examples: - - Call `predict()` to run inference. - ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - classifier = keras_hub.models.CSPDarkNetImageClassifier.from_preset( - "csp_darknet_tiny_imagenet") - classifier.predict(images) - ``` - - Call `fit()` on a single batch. - ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - classifier = keras_hub.models.CSPDarkNetImageClassifier.from_preset( - "csp_darknet_tiny_imagenet") - classifier.fit(x=images, y=labels, batch_size=2) - ``` - - Call `fit()` with custom loss, optimizer and backbone. - ```python - classifier = keras_hub.models.CSPDarkNetImageClassifier.from_preset( - "csp_darknet_tiny_imagenet") - classifier.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - ) - classifier.backbone.trainable = False - classifier.fit(x=images, y=labels, batch_size=2) - ``` - - Custom backbone. 
- ```python - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - backbone = keras_hub.models.CSPDarkNetBackbone( - stackwise_num_filters=[128, 256, 512, 1024], - stackwise_depth=[3, 9, 9, 3], - block_type="basic_block", - image_shape = (224, 224, 3), - ) - classifier = keras_hub.models.CSPDarkNetImageClassifier( - backbone=backbone, - num_classes=4, - ) - classifier.fit(x=images, y=labels, batch_size=2) - ``` - """ - backbone_cls = CSPDarkNetBackbone - - def __init__( - self, - backbone, - num_classes, - activation="softmax", - preprocessor=None, # adding this dummy arg for saved model test - # TODO: once preprocessor flow is figured out, this needs to be updated - **kwargs, - ): - # === Layers === - self.backbone = backbone - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - name="predictions", - ) - - # === Functional Model === - inputs = self.backbone.input - x = self.backbone(inputs) - outputs = self.output_dense(x) - super().__init__( - inputs=inputs, - outputs=outputs, - **kwargs, - ) - - # === Config === - self.num_classes = num_classes - self.activation = activation - - def get_config(self): - # Backbone serialized in `super` - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "activation": self.activation, - } - ) - return config diff --git a/keras_hub/src/models/densenet/densenet_image_classifier.py b/keras_hub/src/models/densenet/densenet_image_classifier.py index 64f28521a3..748ccfc0e6 100644 --- a/keras_hub/src/models/densenet/densenet_image_classifier.py +++ b/keras_hub/src/models/densenet/densenet_image_classifier.py @@ -1,5 +1,3 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( @@ -10,131 +8,5 @@ @keras_hub_export("keras_hub.models.DenseNetImageClassifier") class DenseNetImageClassifier(ImageClassifier): - """DenseNet image classifier task model. - - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - where `x` is a tensor and `y` is a integer from `[0, num_classes)`. - All `ImageClassifier` tasks include a `from_preset()` constructor which can - be used to load a pre-trained config and weights. - - Args: - backbone: A `keras_hub.models.DenseNetBackbone` instance. - num_classes: int. The number of classes to predict. - activation: `None`, str or callable. The activation function to use on - the `Dense` layer. Set `activation=None` to return the output - logits. Defaults to `None`. - pooling: A pooling layer to use before the final classification layer, - must be one of "avg" or "max". Use "avg" for - `GlobalAveragePooling2D` and "max" for "GlobalMaxPooling2D. - preprocessor: A `keras_hub.models.DenseNetImageClassifierPreprocessor` - or `None`. If `None`, this model will not apply preprocessing, and - inputs should be preprocessed before calling the model. - - Examples: - - Call `predict()` to run inference. - ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - classifier = keras_hub.models.DenseNetImageClassifier.from_preset( - "densenet121_imagenet") - classifier.predict(images) - ``` - - Call `fit()` on a single batch. 
- ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - classifier = keras_hub.models.DenseNetImageClassifier.from_preset( - "densenet121_imagenet") - classifier.fit(x=images, y=labels, batch_size=2) - ``` - - Call `fit()` with custom loss, optimizer and backbone. - ```python - classifier = keras_hub.models.DenseNetImageClassifier.from_preset( - "densenet121_imagenet") - classifier.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - ) - classifier.backbone.trainable = False - classifier.fit(x=images, y=labels, batch_size=2) - ``` - - Custom backbone. - ```python - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - backbone = keras_hub.models.DenseNetBackbone( - stackwise_num_filters=[128, 256, 512, 1024], - stackwise_depth=[3, 9, 9, 3], - block_type="basic_block", - image_shape = (224, 224, 3), - ) - classifier = keras_hub.models.DenseNetImageClassifier( - backbone=backbone, - num_classes=4, - ) - classifier.fit(x=images, y=labels, batch_size=2) - ``` - """ - backbone_cls = DenseNetBackbone preprocessor_cls = DenseNetImageClassifierPreprocessor - - def __init__( - self, - backbone, - num_classes, - activation=None, - pooling="avg", - preprocessor=None, - **kwargs, - ): - # === Layers === - self.backbone = backbone - self.preprocessor = preprocessor - if pooling == "avg": - self.pooler = keras.layers.GlobalAveragePooling2D() - elif pooling == "max": - self.pooler = keras.layers.GlobalMaxPooling2D() - else: - raise ValueError( - "Unknown `pooling` type. Polling should be either `'avg'` or " - f"`'max'`. Received: pooling={pooling}." - ) - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - name="predictions", - ) - - # === Functional Model === - inputs = self.backbone.input - x = self.backbone(inputs) - x = self.pooler(x) - outputs = self.output_dense(x) - super().__init__( - inputs=inputs, - outputs=outputs, - **kwargs, - ) - - # === Config === - self.num_classes = num_classes - self.activation = activation - self.pooling = pooling - - def get_config(self): - # Backbone serialized in `super` - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "activation": self.activation, - "pooling": self.pooling, - } - ) - return config diff --git a/keras_hub/src/models/densenet/densenet_image_converter.py b/keras_hub/src/models/densenet/densenet_image_converter.py index 1579cd8ba1..b15c8012f6 100644 --- a/keras_hub/src/models/densenet/densenet_image_converter.py +++ b/keras_hub/src/models/densenet/densenet_image_converter.py @@ -1,10 +1,8 @@ from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.preprocessing.resizing_image_converter import ( - ResizingImageConverter, -) +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone @keras_hub_export("keras_hub.layers.DenseNetImageConverter") -class DenseNetImageConverter(ResizingImageConverter): +class DenseNetImageConverter(ImageConverter): backbone_cls = DenseNetBackbone diff --git a/keras_hub/src/models/feature_pyramid_backbone.py b/keras_hub/src/models/feature_pyramid_backbone.py index c2569a6c01..99549b3de5 100644 --- a/keras_hub/src/models/feature_pyramid_backbone.py +++ b/keras_hub/src/models/feature_pyramid_backbone.py @@ -15,7 +15,7 @@ class FeaturePyramidBackbone(Backbone): Example: ```python - input_data = 
np.random.uniform(0, 255, size=(2, 224, 224, 3)) + input_data = np.random.uniform(0, 256, size=(2, 224, 224, 3)) # Convert to feature pyramid output format using ResNet. backbone = ResNetBackbone.from_preset("resnet50") diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index 7c266dfe83..23945cf755 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -16,10 +16,142 @@ class ImageClassifier(Task): - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - labels where `x` is a string and `y` is a integer from `[0, num_classes)`. - All `ImageClassifier` tasks include a `from_preset()` constructor which can be - used to load a pre-trained config and weights. + Args: + backbone: A `keras_hub.models.Backbone` instance or a `keras.Model`. + num_classes: int. The number of classes to predict. + preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, + a `keras.Layer` instance, or a callable. If `None`, no preprocessing + will be applied to the inputs. + pooling: `"avg"` or `"max"`. The type of pooling to apply on backbone + output. Defaults to average pooling. + activation: `None`, str, or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `None`. + head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the classification head's computations and weights. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is a tensor and `y` is an integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + classifier = keras_hub.models.ImageClassifier.from_preset( + "resnet_50_imagenet" + ) + classifier.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + labels = [0, 3] + classifier = keras_hub.models.ImageClassifier.from_preset( + "resnet_50_imagenet" + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + classifier = keras_hub.models.ImageClassifier.from_preset( + "resnet_50_imagenet" + ) + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone.
+ ```python + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + labels = [0, 3] + backbone = keras_hub.models.ResNetBackbone( + stackwise_num_filters=[64, 64, 64], + stackwise_num_blocks=[2, 2, 2], + stackwise_num_strides=[1, 2, 2], + block_type="basic_block", + use_pre_activation=True, + ) + classifier = keras_hub.models.ImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` """ + def __init__( + self, + backbone, + num_classes, + preprocessor=None, + pooling="avg", + activation=None, + head_dtype=None, + **kwargs, + ): + head_dtype = head_dtype or backbone.dtype_policy + data_format = getattr(backbone, "data_format", None) + + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + if pooling == "avg": + self.pooler = keras.layers.GlobalAveragePooling2D( + data_format, dtype=head_dtype + ) + elif pooling == "max": + self.pooler = keras.layers.GlobalMaxPooling2D( + data_format, dtype=head_dtype + ) + else: + raise ValueError( + "Unknown `pooling` type. Pooling should be either `'avg'` or " + f"`'max'`. Received: pooling={pooling}." + ) + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + dtype=head_dtype, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + x = self.pooler(x) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + self.pooling = pooling + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "pooling": self.pooling, + "activation": self.activation, + } + ) + return config + def compile( self, optimizer="auto", diff --git a/keras_hub/src/models/image_classifier_preprocessor.py b/keras_hub/src/models/image_classifier_preprocessor.py index cb62844a4e..ee905f3e18 100644 --- a/keras_hub/src/models/image_classifier_preprocessor.py +++ b/keras_hub/src/models/image_classifier_preprocessor.py @@ -38,15 +38,15 @@ class ImageClassifierPreprocessor(Preprocessor): ) # Resize a single image for resnet 50. - x = np.ones((512, 512, 3)) + x = np.random.randint(0, 256, (512, 512, 3)) x = preprocessor(x) # Resize a labeled image. - x, y = np.ones((512, 512, 3)), 1 + x, y = np.random.randint(0, 256, (512, 512, 3)), 1 x, y = preprocessor(x, y) # Resize a batch of labeled images. - x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0] + x, y = [np.random.randint(0, 256, (512, 512, 3)), np.zeros((512, 512, 3))], [1, 0] x, y = preprocessor(x, y) # Use a `tf.data.Dataset`. diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py b/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py index 5fea71f417..0daac9327f 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +++ b/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py @@ -1,5 +1,3 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( @@ -9,111 +7,4 @@ @keras_hub_export("keras_hub.models.MiTImageClassifier") class MiTImageClassifier(ImageClassifier): - """MiTImageClassifier image classifier model. - - Args: - backbone: A `keras_hub.models.MiTBackbone` instance.
- num_classes: int. The number of classes to predict. - activation: `None`, str or callable. The activation function to use on - the `Dense` layer. Set `activation=None` to return the output - logits. Defaults to `"softmax"`. - - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - where `x` is a tensor and `y` is a integer from `[0, num_classes)`. - All `ImageClassifier` tasks include a `from_preset()` constructor which can - be used to load a pre-trained config and weights. - - Examples: - - Call `predict()` to run inference. - ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - classifier = keras_hub.models.MiTImageClassifier.from_preset( - "mit_b0_imagenet") - classifier.predict(images) - ``` - - Call `fit()` on a single batch. - ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - classifier = keras_hub.models.MixTransformerImageClassifier.from_preset( - "mit_b0_imagenet") - classifier.fit(x=images, y=labels, batch_size=2) - ``` - - Call `fit()` with custom loss, optimizer and backbone. - ```python - classifier = keras_hub.models.MiTImageClassifier.from_preset( - "mit_b0_imagenet") - classifier.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - ) - classifier.backbone.trainable = False - classifier.fit(x=images, y=labels, batch_size=2) - ``` - - Custom backbone. - ```python - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - backbone = keras_hub.models.MiTBackbone( - stackwise_num_filters=[128, 256, 512, 1024], - stackwise_depth=[3, 9, 9, 3], - block_type="basic_block", - image_shape = (224, 224, 3), - ) - classifier = keras_hub.models.MiTImageClassifier( - backbone=backbone, - num_classes=4, - ) - classifier.fit(x=images, y=labels, batch_size=2) - ``` - """ - backbone_cls = MiTBackbone - - def __init__( - self, - backbone, - num_classes, - activation="softmax", - preprocessor=None, # adding this dummy arg for saved model test - # TODO: once preprocessor flow is figured out, this needs to be updated - **kwargs, - ): - # === Layers === - self.backbone = backbone - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - name="predictions", - ) - - # === Functional Model === - inputs = self.backbone.input - x = self.backbone(inputs) - outputs = self.output_dense(x) - super().__init__( - inputs=inputs, - outputs=outputs, - **kwargs, - ) - - # === Config === - self.num_classes = num_classes - self.activation = activation - - def get_config(self): - # Backbone serialized in `super` - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "activation": self.activation, - } - ) - return config diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index ff02ec0c6a..96977bdf9f 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -1,5 +1,3 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone @@ -7,94 +5,4 @@ @keras_hub_export("keras_hub.models.MobileNetImageClassifier") class MobileNetImageClassifier(ImageClassifier): - """MobileNetV3 image classifier task model. 
- - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - where `x` is a tensor and `y` is a integer from `[0, num_classes)`. - All `ImageClassifier` tasks include a `from_preset()` constructor which can - be used to load a pre-trained config and weights. - - Args: - backbone: A `keras_hub.models.MobileNetBackbone` instance. - num_classes: int. The number of classes to predict. - activation: `None`, str or callable. The activation function to use on - the `Dense` layer. Set `activation=None` to return the output - logits. Defaults to `"softmax"`. - - Examples: - - Call `predict()` to run inference. - ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - classifier = keras_hub.models.MobileNetImageClassifier.from_preset( - "mobilenet_v3_small_imagenet") - classifier.predict(images) - ``` - - Custom backbone. - ```python - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - model = MobileNetBackbone( - stackwise_expansion = [1, 4, 6], - stackwise_filters = [4, 8, 16], - stackwise_kernel_size = [3, 3, 5], - stackwise_stride = [2, 2, 1], - stackwise_se_ratio = [ 0.25, None, 0.25], - stackwise_activation = ["relu", "relu", "hard_swish"], - output_filter=1280, - activation="hard_swish", - inverted_res_block=True, - ) - classifier = keras_hub.models.MobileNetImageClassifier( - backbone=backbone, - num_classes=4, - ) - classifier.fit(x=images, y=labels, batch_size=2) - ``` - """ - backbone_cls = MobileNetBackbone - - def __init__( - self, - backbone, - num_classes, - activation="softmax", - preprocessor=None, # adding this dummy arg for saved model test - # TODO: once preprocessor flow is figured out, this needs to be updated - **kwargs, - ): - # === Layers === - self.backbone = backbone - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - name="predictions", - ) - - # === Functional Model === - inputs = self.backbone.input - x = self.backbone(inputs) - outputs = self.output_dense(x) - super().__init__( - inputs=inputs, - outputs=outputs, - **kwargs, - ) - - # === Config === - self.num_classes = num_classes - self.activation = activation - - def get_config(self): - # Backbone serialized in `super` - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "activation": self.activation, - } - ) - return config diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_test.py b/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_test.py index dd24e86fce..1f53cdef04 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_test.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_test.py @@ -35,7 +35,9 @@ def setUp(self): tokenizer = PaliGemmaTokenizer( os.path.join(self.get_test_data_dir(), proto) ) - image_converter = PaliGemmaImageConverter(16, 16) + image_converter = PaliGemmaImageConverter( + image_size=(16, 16), + ) self.vocabulary_size = tokenizer.vocabulary_size() self.preprocessor = PaliGemmaCausalLMPreprocessor( tokenizer, diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py b/keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py index b623619025..c06905eb0d 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py @@ -1,12 +1,10 @@ from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.preprocessing.resizing_image_converter import ( - ResizingImageConverter, -) +from 
keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( PaliGemmaBackbone, ) @keras_hub_export("keras_hub.layers.PaliGemmaImageConverter") -class PaliGemmaImageConverter(ResizingImageConverter): +class PaliGemmaImageConverter(ImageConverter): backbone_cls = PaliGemmaBackbone diff --git a/keras_hub/src/models/resnet/resnet_backbone.py b/keras_hub/src/models/resnet/resnet_backbone.py index 088d0c7ddb..bc8def804a 100644 --- a/keras_hub/src/models/resnet/resnet_backbone.py +++ b/keras_hub/src/models/resnet/resnet_backbone.py @@ -51,16 +51,6 @@ class ResNetBackbone(FeaturePyramidBackbone): `True` for ResNetV2, `False` for ResNet. image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. - pooling: `None` or str. Pooling mode for feature extraction. Defaults - to `"avg"`. - - `None` means that the output of the model will be the 4D tensor - from the last convolutional block. - - `avg` means that global average pooling will be applied to the - output of the last convolutional block, resulting in a 2D - tensor. - - `max` means that global max pooling will be applied to the - output of the last convolutional block, resulting in a 2D - tensor. data_format: `None` or str. If specified, either `"channels_last"` or `"channels_first"`. The ordering of the dimensions in the inputs. `"channels_last"` corresponds to inputs with shape @@ -75,7 +65,7 @@ class ResNetBackbone(FeaturePyramidBackbone): Examples: ```python - input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3)) + input_data = np.random.uniform(0, 1, size=(2, 224, 224, 3)) # Pretrained ResNet backbone. model = keras_hub.models.ResNetBackbone.from_preset("resnet50") diff --git a/keras_hub/src/models/resnet/resnet_image_classifier.py b/keras_hub/src/models/resnet/resnet_image_classifier.py index 50c34df37b..e278974ed8 100644 --- a/keras_hub/src/models/resnet/resnet_image_classifier.py +++ b/keras_hub/src/models/resnet/resnet_image_classifier.py @@ -1,5 +1,3 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @@ -10,140 +8,5 @@ @keras_hub_export("keras_hub.models.ResNetImageClassifier") class ResNetImageClassifier(ImageClassifier): - """ResNet image classifier task model. - - Args: - backbone: A `keras_hub.models.ResNetBackbone` instance. - num_classes: int. The number of classes to predict. - activation: `None`, str or callable. The activation function to use on - the `Dense` layer. Set `activation=None` to return the output - logits. Defaults to `"softmax"`. - head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The - dtype to use for the classification head's computations and weights. - - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - where `x` is a tensor and `y` is a integer from `[0, num_classes)`. - All `ImageClassifier` tasks include a `from_preset()` constructor which can - be used to load a pre-trained config and weights. - - Examples: - - Call `predict()` to run inference. - ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - classifier = keras_hub.models.ResNetImageClassifier.from_preset( - "resnet_50_imagenet" - ) - classifier.predict(images) - ``` - - Call `fit()` on a single batch. 
- ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - classifier = keras_hub.models.ResNetImageClassifier.from_preset( - "resnet_50_imagenet" - ) - classifier.fit(x=images, y=labels, batch_size=2) - ``` - - Call `fit()` with custom loss, optimizer and backbone. - ```python - classifier = keras_hub.models.ResNetImageClassifier.from_preset( - "resnet_50_imagenet" - ) - classifier.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - ) - classifier.backbone.trainable = False - classifier.fit(x=images, y=labels, batch_size=2) - ``` - - Custom backbone. - ```python - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - backbone = keras_hub.models.ResNetBackbone( - stackwise_num_filters=[64, 64, 64], - stackwise_num_blocks=[2, 2, 2], - stackwise_num_strides=[1, 2, 2], - block_type="basic_block", - use_pre_activation=True, - pooling="avg", - ) - classifier = keras_hub.models.ResNetImageClassifier( - backbone=backbone, - num_classes=4, - ) - classifier.fit(x=images, y=labels, batch_size=2) - ``` - """ - backbone_cls = ResNetBackbone preprocessor_cls = ResNetImageClassifierPreprocessor - - def __init__( - self, - backbone, - num_classes, - preprocessor=None, - pooling="avg", - activation=None, - head_dtype=None, - **kwargs, - ): - head_dtype = head_dtype or backbone.dtype_policy - - # === Layers === - self.backbone = backbone - self.preprocessor = preprocessor - if pooling == "avg": - self.pooler = keras.layers.GlobalAveragePooling2D( - data_format=backbone.data_format, dtype=head_dtype - ) - elif pooling == "max": - self.pooler = keras.layers.GlobalAveragePooling2D( - data_format=backbone.data_format, dtype=head_dtype - ) - else: - raise ValueError( - "Unknown `pooling` type. Polling should be either `'avg'` or " - f"`'max'`. Received: pooling={pooling}." 
- ) - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - dtype=head_dtype, - name="predictions", - ) - - # === Functional Model === - inputs = self.backbone.input - x = self.backbone(inputs) - x = self.pooler(x) - outputs = self.output_dense(x) - super().__init__( - inputs=inputs, - outputs=outputs, - **kwargs, - ) - - # === Config === - self.num_classes = num_classes - self.activation = activation - self.pooling = pooling - - def get_config(self): - # Backbone serialized in `super` - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "pooling": self.pooling, - "activation": self.activation, - } - ) - return config diff --git a/keras_hub/src/models/resnet/resnet_image_converter.py b/keras_hub/src/models/resnet/resnet_image_converter.py index 34b3dd431c..5d7d777fe9 100644 --- a/keras_hub/src/models/resnet/resnet_image_converter.py +++ b/keras_hub/src/models/resnet/resnet_image_converter.py @@ -1,10 +1,8 @@ from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.preprocessing.resizing_image_converter import ( - ResizingImageConverter, -) +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone @keras_hub_export("keras_hub.layers.ResNetImageConverter") -class ResNetImageConverter(ResizingImageConverter): +class ResNetImageConverter(ImageConverter): backbone_cls = ResNetBackbone diff --git a/keras_hub/src/models/sam/sam_image_converter.py b/keras_hub/src/models/sam/sam_image_converter.py index 2ecf09878b..3ee206122a 100644 --- a/keras_hub/src/models/sam/sam_image_converter.py +++ b/keras_hub/src/models/sam/sam_image_converter.py @@ -1,10 +1,8 @@ from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.preprocessing.resizing_image_converter import ( - ResizingImageConverter, -) +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.sam.sam_backbone import SAMBackbone @keras_hub_export("keras_hub.layers.SAMImageConverter") -class SAMImageConverter(ResizingImageConverter): +class SAMImageConverter(ImageConverter): backbone_cls = SAMBackbone diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index 080c67c221..ea2df05bea 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -335,7 +335,7 @@ def add_layer(layer, info): image_converter = self.preprocessor.image_converter if image_converter: info = "Image size: " - info += highlight_shape(image_converter.image_size()) + info += highlight_shape(image_converter.image_size) add_layer(image_converter, info) audio_converter = self.preprocessor.audio_converter if audio_converter: diff --git a/keras_hub/src/models/vgg/vgg_backbone.py b/keras_hub/src/models/vgg/vgg_backbone.py index 7855e8a735..902f392962 100644 --- a/keras_hub/src/models/vgg/vgg_backbone.py +++ b/keras_hub/src/models/vgg/vgg_backbone.py @@ -57,7 +57,6 @@ def __init__( stackwise_num_repeats, stackwise_num_filters, image_shape=(224, 224, 3), - pooling="avg", **kwargs, ): @@ -76,10 +75,6 @@ def __init__( max_pool=True, name=f"block{stack_index + 1}", ) - if pooling == "avg": - x = layers.GlobalAveragePooling2D()(x) - elif pooling == "max": - x = layers.GlobalMaxPooling2D()(x) super().__init__(inputs=img_input, outputs=x, **kwargs) @@ -87,14 +82,12 @@ def __init__( self.stackwise_num_repeats = stackwise_num_repeats self.stackwise_num_filters = stackwise_num_filters self.image_shape = 
image_shape - self.pooling = pooling def get_config(self): return { "stackwise_num_repeats": self.stackwise_num_repeats, "stackwise_num_filters": self.stackwise_num_filters, "image_shape": self.image_shape, - "pooling": self.pooling, } diff --git a/keras_hub/src/models/vgg/vgg_backbone_test.py b/keras_hub/src/models/vgg/vgg_backbone_test.py index 7226e7a8c5..87e9ed6ef5 100644 --- a/keras_hub/src/models/vgg/vgg_backbone_test.py +++ b/keras_hub/src/models/vgg/vgg_backbone_test.py @@ -11,7 +11,6 @@ def setUp(self): "stackwise_num_repeats": [2, 3, 3], "stackwise_num_filters": [8, 64, 64], "image_shape": (16, 16, 3), - "pooling": "avg", } self.input_data = np.ones((2, 16, 16, 3), dtype="float32") @@ -20,7 +19,7 @@ def test_backbone_basics(self): cls=VGGBackbone, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output_shape=(2, 64), + expected_output_shape=(2, 2, 2, 64), run_mixed_precision_check=False, ) diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index c535f4ac06..40adf69911 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -1,5 +1,3 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone @@ -7,104 +5,4 @@ @keras_hub_export("keras_hub.models.VGGImageClassifier") class VGGImageClassifier(ImageClassifier): - """VGG16 image classifier task model. - - Args: - backbone: A `keras_hub.models.VGGBackbone` instance. - num_classes: int, number of classes to predict. - pooling: str, type of pooling layer. Must be one of "avg", "max". - activation: Optional `str` or callable, defaults to "softmax". The - activation function to use on the Dense layer. Set `activation=None` - to return the output logits. - - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - labels where `x` is a string and `y` is a integer from `[0, num_classes)`. - All `ImageClassifier` tasks include a `from_preset()` constructor which can be - used to load a pre-trained config and weights. - - Examples: - Train from preset - ```python - # Load preset and train - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - classifier = keras_hub.models.VGGImageClassifier.from_preset( - 'vgg_16_image_classifier') - classifier.fit(x=images, y=labels, batch_size=2) - - # Re-compile (e.g., with a new learning rate). - classifier.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - jit_compile=True, - ) - - # Access backbone programmatically (e.g., to change `trainable`). - classifier.backbone.trainable = False - # Fit again.
- classifier.fit(x=images, y=labels, batch_size=2) - ``` - Custom backbone - ```python - images = np.ones((2, 224, 224, 3), dtype="float32") - labels = [0, 3] - - backbone = keras_hub.models.VGGBackbone( - stackwise_num_repeats = [2, 2, 3, 3, 3], - stackwise_num_filters = [64, 128, 256, 512, 512], - image_shape = (224, 224, 3), - pooling = "avg", - ) - classifier = keras_hub.models.VGGImageClassifier( - backbone=backbone, - num_classes=4, - ) - classifier.fit(x=images, y=labels, batch_size=2) - ``` - """ - backbone_cls = VGGBackbone - - def __init__( - self, - backbone, - num_classes, - activation="softmax", - preprocessor=None, # adding this dummy arg for saved model test - # TODO: once preprocessor flow is figured out, this needs to be updated - **kwargs, - ): - # === Layers === - self.backbone = backbone - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - name="predictions", - ) - - # === Functional Model === - inputs = self.backbone.input - x = self.backbone(inputs) - outputs = self.output_dense(x) - - # Instantiate using Functional API Model constructor - super().__init__( - inputs=inputs, - outputs=outputs, - **kwargs, - ) - - # === Config === - self.num_classes = num_classes - self.activation = activation - - def get_config(self): - # Backbone serialized in `super` - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "activation": self.activation, - } - ) - return config diff --git a/keras_hub/src/models/vgg/vgg_image_classifier_test.py b/keras_hub/src/models/vgg/vgg_image_classifier_test.py index c2d18904c8..ce7b63bad4 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier_test.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier_test.py @@ -15,12 +15,12 @@ def setUp(self): stackwise_num_repeats=[2, 4, 4], stackwise_num_filters=[2, 16, 16], image_shape=(4, 4, 3), - pooling="max", ) self.init_kwargs = { "backbone": self.backbone, "num_classes": 2, "activation": "softmax", + "pooling": "max", } self.train_data = ( self.images, diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index 3fed71e67c..fe596d8200 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -14,7 +14,7 @@ def __init__(self, preset, config): architecture = self.config["architecture"] if "resnet" in architecture: self.converter = convert_resnet - if "densenet" in architecture: + elif "densenet" in architecture: self.converter = convert_densenet else: raise ValueError( @@ -52,20 +52,19 @@ def load_image_converter(self, cls, **kwargs): pretrained_cfg = self.config.get("pretrained_cfg", None) if not pretrained_cfg or "input_size" not in pretrained_cfg: return None - # This assumes the same basic setup for all timm preprocessing, and that - # all our image conversion will be via a `ResizingImageConverter. We may + # This assumes the same basic setup for all timm preprocessing. We may # need to extend this as we cover more model types. input_size = pretrained_cfg["input_size"] mean = pretrained_cfg["mean"] - variance = [s**2 for s in pretrained_cfg["std"]] + std = pretrained_cfg["std"] + scale = [1.0 / 255.0 / s for s in std] + offset = [-m / s for m, s in zip(mean, std)] interpolation = pretrained_cfg["interpolation"] if interpolation not in ("bilinear", "nearest", "bicubic"): interpolation = "bilinear" # Unsupported interpolation type.
return cls( - width=input_size[1], - height=input_size[2], - scale=1 / 255.0, - mean=mean, - variance=variance, + image_size=input_size[1:], + scale=scale, + offset=offset, interpolation=interpolation, )
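
One point worth spelling out for reviewers, since both the backward-compatibility block in `ImageConverter.__init__` and the timm loader above rely on it: the old `scale`/`mean`/`variance` triple folds exactly into the new `scale`/`offset` pair, because `(x * scale - mean) / std == x * (scale / std) + (-mean / std)` with `std = sqrt(variance)`. Below is a minimal sketch checking that identity against the new layer; it assumes this branch is installed, and the `(4, 4)` size, constant test image, and ImageNet statistics are illustrative choices, not values from this PR.

```python
import numpy as np
from keras import ops

from keras_hub.src.layers.preprocessing.image_converter import ImageConverter

# Old-style arguments, as a `ResizingImageConverter` config stored them.
old_scale = 1 / 255.0
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])  # i.e. variance = std**2

# Fold into the new-style pair, mirroring the TODO block in `__init__`.
new_scale = [old_scale / s for s in std]
new_offset = [-m / s for m, s in zip(mean, std)]

converter = ImageConverter(
    image_size=(4, 4),
    scale=new_scale,
    offset=new_offset,
)

# A constant image is unchanged by bilinear resizing, so the resize step
# drops out and only the channel-wise affine transform is compared.
x = np.full((10, 10, 3), 128.0, dtype="float32")
got = ops.convert_to_numpy(converter(x))  # Shape: (4, 4, 3).
want = (128.0 * old_scale - mean) / std  # Old-style normalization math.
np.testing.assert_allclose(got, np.broadcast_to(want, got.shape), rtol=1e-4)
```

The same algebra is why `load_image_converter` can emit `scale=[1 / 255 / s for s in std]` and `offset=[-m / s ...]` and still reproduce timm's rescale-then-normalize pipeline in a single multiply-add.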