Skip to content

Commit

Permalink
Image classifier changes
Browse files Browse the repository at this point in the history
- Move image classifier implementation to the base class.
- Allow passing arbitrary layers as preprocessing.
  • Loading branch information
mattdangerw committed Oct 2, 2024
1 parent 6d80bcf commit 80146c7
Show file tree
Hide file tree
Showing 26 changed files with 339 additions and 972 deletions.
3 changes: 0 additions & 3 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
)
from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
from keras_hub.src.layers.preprocessing.resizing_image_converter import (
ResizingImageConverter,
)
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_hub.src.models.densenet.densenet_image_converter import (
DenseNetImageConverter,
Expand Down
155 changes: 128 additions & 27 deletions keras_hub/src/layers/preprocessing/image_converter.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,151 @@
import math

import numpy as np
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_hub.src.utils.keras_utils import standardize_data_format
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import get_preset_saver
from keras_hub.src.utils.python_utils import classproperty
from keras_hub.src.utils.tensor_utils import preprocessing_function


@keras_hub_export("keras_hub.layers.ImageConverter")
class ImageConverter(PreprocessingLayer):
"""Convert raw image for models that support image input.
"""Preprocess raw images into model ready inputs.
This class converts from raw images to model ready inputs. This conversion
proceeds in the following steps:
This class converts from raw images of any size, to preprocessed
images for pretrained model inputs. It is meant to be a convenient way to
write custom preprocessing code that is not model specific. This layer
should be instantiated via the `from_preset()` constructor, which will
create the correct subclass of this layer for the model preset.
1. Resize the image to `image_size`. If `image_size` is `None`, this
step will be skipped.
2. Rescale the image by multiplying by `scale`, which can be either global
or per channel. If `scale` is `None`, this step will be skipped.
3. Offset the image by adding `offset`, which can be either global
or per channel. If `offset` is `None`, this step will be skipped.
The layer will take as input a raw image tensor in the channels last or
channels first format, and output a preprocessed image input for modeling.
The exact structure of the output will vary per model, though in most cases
this layer will simply resize the image to the size needed by the model
input.
This tensor can be batched (rank 4), or unbatched (rank 3).
This layer can be used with the `from_preset()` constructor to load a layer
that will rescale and resize an image for a specific pretrained model.
Using the layer this way allows writing preprocessing code that does not
need updating when switching between model checkpoints.
Examples:
```python
# Resize images for `"pali_gemma_3b_224"`.
converter = keras_hub.layers.ImageConverter.from_preset("pali_gemma_3b_224")
converter(np.ones((2, 512, 512, 3))) # Output shape: (2, 224, 224, 3)
# Resize images for `"pali_gemma_3b_448"`.
converter = keras_hub.layers.ImageConverter.from_preset("pali_gemma_3b_448")
converter(np.ones((2, 512, 512, 3))) # Output shape: (2, 448, 448, 3)
# Resize raw images and scale them to [0, 1].
converter = keras_hub.layers.ImageConverter(
image_size=(128, 128),
scale=1 / 255.
)
converter(np.random.randint(0, 256, size=(2, 512, 512, 3)))
# Resize images to the specific size needed for a PaliGemma preset.
converter = keras_hub.layers.ImageConverter.from_preset(
"pali_gemma_3b_224"
)
converter(np.random.randint(0, 256, size=(2, 512, 512, 3)))
```
"""

backbone_cls = None

def __init__(
    self,
    image_size=None,
    scale=None,
    offset=None,
    crop_to_aspect_ratio=True,
    interpolation="bilinear",
    data_format=None,
    **kwargs,
):
    """Create the converter and its internal `Resizing` layer.

    Args:
        image_size: `(height, width)` tuple, or `None` to create the
            resizing layer without a target size.
        scale: value multiplied into the image in `call()`, either global
            or per channel. `None` skips the multiplication.
        offset: value added to the image in `call()`, either global or per
            channel. `None` skips the addition.
        crop_to_aspect_ratio: forwarded to `keras.layers.Resizing`.
        interpolation: forwarded to `keras.layers.Resizing`.
        data_format: forwarded to `keras.layers.Resizing`, and stored on
            the layer after `standardize_data_format()`.
        **kwargs: forwarded to the base layer constructor. The legacy keys
            `height`/`width` and `mean`/`variance` are consumed here and
            translated to `image_size`, `scale` and `offset`.
    """
    # TODO: old arg names. Delete this block after resaving Kaggle assets.
    if "height" in kwargs and "width" in kwargs:
        image_size = (kwargs.pop("height"), kwargs.pop("width"))
    if "variance" in kwargs and "mean" in kwargs:
        std = [math.sqrt(v) for v in kwargs.pop("variance")]
        # A legacy config may carry `mean`/`variance` without a `scale`;
        # dividing `None` would raise, so treat a missing scale as 1.0.
        # NOTE(review): assumes "no scale" meant "no rescaling" in the old
        # configs -- confirm against the saved assets.
        base_scale = 1.0 if scale is None else scale
        scale = [base_scale / s for s in std]
        offset = [-m / s for m, s in zip(kwargs.pop("mean"), std)]

    super().__init__(**kwargs)

    # Create the `Resizing` layer here even if it's not being used. That
    # allows us to make `image_size` a settable property.
    self.resizing = keras.layers.Resizing(
        height=image_size[0] if image_size else None,
        width=image_size[1] if image_size else None,
        crop_to_aspect_ratio=crop_to_aspect_ratio,
        interpolation=interpolation,
        data_format=data_format,
        dtype=self.dtype_policy,
        name="resizing",
    )
    self.scale = scale
    self.offset = offset
    self.crop_to_aspect_ratio = crop_to_aspect_ratio
    self.interpolation = interpolation
    self.data_format = standardize_data_format(data_format)

@property
def image_size(self):
    """Settable tuple of `(height, width)` ints. The output image shape.

    Returns `None` when the internal resizing layer has no target height,
    which is the state in which `call()` skips resizing.
    """
    if self.resizing.height is None:
        return None
    return (self.resizing.height, self.resizing.width)

@image_size.setter
def image_size(self, value):
    """Set the target `(height, width)`; `None` clears both dimensions."""
    size = (None, None) if value is None else value
    # The size lives on the internal `Resizing` layer, not on this layer.
    self.resizing.height = size[0]
    self.resizing.width = size[1]

@preprocessing_function
def call(self, inputs):
    """Resize, rescale and offset `inputs`, skipping any unset step.

    Args:
        inputs: the image or batch of images to convert.

    Returns:
        The converted image(s). When `image_size`, `scale` and `offset`
        are all `None`, `inputs` is returned unchanged.
    """
    # Start from the raw inputs so that skipping the resize step still
    # leaves `x` defined for the steps below.
    x = inputs
    if self.image_size is not None:
        x = self.resizing(x)
    if self.scale is not None:
        x = x * self._expand_non_channel_dims(self.scale, x)
    if self.offset is not None:
        x = x + self._expand_non_channel_dims(self.offset, x)
    return x

def _expand_non_channel_dims(self, value, inputs):
    """Insert size-1 axes into `value` so it broadcasts against `inputs`.

    Args:
        value: a scalar (global) or per-channel sequence, such as the
            layer's `scale` or `offset`.
        inputs: the image data `value` will be combined with. Rank 3 is
            treated as unbatched; anything else as batched.

    Returns:
        `value` with a size-1 axis at every non-channel position. The
        result comes from `ops.expand_dims` when `inputs` is a tensor and
        from `np.expand_dims` otherwise.
    """
    unbatched = len(ops.shape(inputs)) == 3
    channels_first = self.data_format == "channels_first"
    if unbatched:
        broadcast_dims = (1, 2) if channels_first else (0, 1)
    else:
        broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
    # A scalar has no channel axis to keep in place. The channels-first
    # axes (e.g. `(1, 2)`) would be out of range for the expanded rank, so
    # use the leading axes instead; the all-ones result broadcasts the
    # same way in either data format.
    if np.ndim(value) == 0:
        broadcast_dims = tuple(range(len(broadcast_dims)))
    # If inputs are not a tensor type, return a numpy array.
    # This might happen when running under tf.data.
    if ops.is_tensor(inputs):
        return ops.expand_dims(value, broadcast_dims)
    else:
        return np.expand_dims(value, broadcast_dims)

def get_config(self):
    """Return the base layer config extended with this layer's arguments."""
    converter_args = {
        "image_size": self.image_size,
        "scale": self.scale,
        "offset": self.offset,
        "interpolation": self.interpolation,
        "crop_to_aspect_ratio": self.crop_to_aspect_ratio,
    }
    merged = super().get_config()
    merged.update(converter_args)
    return merged

@classproperty
def presets(cls):
Expand Down Expand Up @@ -68,13 +173,6 @@ def from_preset(
You can run `cls.presets.keys()` to list all built-in presets available
on the class.
This constructor can be called in one of two ways. Either from the base
class like `keras_hub.models.ImageConverter.from_preset()`, or from a
model class like
`keras_hub.models.PaliGemmaImageConverter.from_preset()`. If calling
from the base class, the subclass of the returning object will be
inferred from the config in the preset directory.
Args:
preset: string. A built-in preset identifier, a Kaggle Models
handle, a Hugging Face handle, or a path to a local directory.
Expand All @@ -84,17 +182,20 @@ class like `keras_hub.models.ImageConverter.from_preset()`, or from a
Examples:
```python
batch = np.random.randint(0, 256, size=(2, 512, 512, 3))
# Resize images for `"pali_gemma_3b_224"`.
converter = keras_hub.layers.ImageConverter.from_preset(
"pali_gemma_3b_224"
)
converter(np.ones((2, 512, 512, 3))) # Output shape: (2, 224, 224, 3)
# Override arguments on the base class.
converter(batch) # Output shape: (2, 224, 224, 3)
# Resize images for `"pali_gemma_3b_448"` without cropping.
converter = keras_hub.layers.ImageConverter.from_preset(
"pali_gemma_3b_448",
crop_to_aspect_ratio=False,
)
converter(np.ones((2, 512, 512, 3))) # (2, 448, 448, 3)
converter(batch) # Output shape: (2, 448, 448, 3)
```
"""
loader = get_preset_loader(preset)
Expand Down
45 changes: 45 additions & 0 deletions keras_hub/src/layers/preprocessing/image_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
from keras import ops

from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
Expand All @@ -15,6 +16,50 @@


class ImageConverterTest(TestCase):
def test_resize_simple(self):
    """Resize and rescale a constant image using the `height`/`width` args."""
    white_image = np.ones((10, 10, 3)) * 255.0
    layer = ImageConverter(height=4, width=4, scale=1 / 255.0)
    result = layer(white_image)
    self.assertAllClose(result, ops.ones((4, 4, 3)))

def test_unbatched(self):
    """Per-channel scale and offset are applied to a single rank 3 image."""
    layer = ImageConverter(
        image_size=(4, 4),
        scale=(1.0 / 255.0, 0.8 / 255.0, 1.2 / 255.0),
        offset=(0.2, -0.1, 0.25),
    )
    result = layer(np.ones((10, 10, 3)) * 128)
    self.assertEqual(ops.shape(result), (4, 4, 3))
    # Expected value of each channel: 128 * scale + offset.
    expected_per_channel = (0.701961, 0.301569, 0.852353)
    for channel, expected in enumerate(expected_per_channel):
        self.assertAllClose(
            result[:, :, channel], np.ones((4, 4)) * expected
        )

def test_resize_batch(self):
    """Per-channel scale and offset are applied to a rank 4 batch."""
    layer = ImageConverter(
        image_size=(4, 4),
        scale=(1.0 / 255.0, 0.8 / 255.0, 1.2 / 255.0),
        offset=(0.2, -0.1, 0.25),
    )
    result = layer(np.ones((2, 10, 10, 3)) * 128)
    self.assertEqual(ops.shape(result), (2, 4, 4, 3))
    # Expected value of each channel: 128 * scale + offset.
    expected_per_channel = (0.701961, 0.301569, 0.852353)
    for channel, expected in enumerate(expected_per_channel):
        self.assertAllClose(
            result[:, :, :, channel], np.ones((2, 4, 4)) * expected
        )

def test_config(self):
    """A converter rebuilt from its own config produces the same outputs."""
    original = ImageConverter(
        image_size=(12, 20),
        scale=(0.25 / 255.0, 0.1 / 255.0, 0.5 / 255.0),
        offset=(0.2, -0.1, 0.25),
        crop_to_aspect_ratio=False,
        interpolation="nearest",
    )
    saved_config = original.get_config()
    restored = ImageConverter.from_config(saved_config)
    images = np.random.rand(4, 10, 20, 3) * 255
    self.assertAllClose(original(images), restored(images))

def test_preset_accessors(self):
pali_gemma_presets = set(PaliGemmaImageConverter.presets.keys())
all_presets = set(ImageConverter.presets.keys())
Expand Down
Loading

0 comments on commit 80146c7

Please sign in to comment.