Finish up resnet
- Add presets.
- Add converter script.
- Add preprocessing with auto resizing.
mattdangerw committed Sep 14, 2024
1 parent a5e5d8f commit a28b611
Showing 20 changed files with 529 additions and 99 deletions.
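Before the per-file diffs, a rough sketch of the workflow these changes add up to. It assumes the `resnet_18_imagenet` preset exercised in the new tests; the `num_classes` argument and the automatically attached preprocessor are assumptions about standard `keras_nlp` task behavior, not something this diff shows directly.

```python
import numpy as np
import keras_nlp

# Hedged sketch: load a classifier from a registered ResNet preset. The
# preprocessor added in this commit is assumed to attach automatically and
# to resize raw images to the preset's expected input size.
classifier = keras_nlp.models.ResNetImageClassifier.from_preset(
    "resnet_18_imagenet",
    num_classes=2,  # assumption: head size for a downstream task
)
images = np.random.uniform(size=(2, 512, 512, 3)).astype("float32")
predictions = classifier.predict(images)  # preprocessing resizes the inputs
```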
3 changes: 3 additions & 0 deletions keras_nlp/api/layers/__init__.py
@@ -53,6 +53,9 @@
from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_nlp.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)
from keras_nlp.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
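The new `ResNetImageConverter` export is what provides the "auto resizing" mentioned in the commit message. A hedged sketch of standalone use, assuming the converter can be loaded from the same preset (its constructor signature is not shown in this diff):

```python
import numpy as np
import keras_nlp

# Assumption: the converter is saved with the preset and resizes images to
# the size the ResNet backbone expects.
converter = keras_nlp.layers.ResNetImageConverter.from_preset("resnet_18_imagenet")
resized = converter(np.ones((1, 512, 512, 3), dtype="float32"))
```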
6 changes: 6 additions & 0 deletions keras_nlp/api/models/__init__.py
@@ -172,6 +172,9 @@
)
from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.src.models.image_classifier import ImageClassifier
from keras_nlp.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
@@ -231,6 +234,9 @@
from keras_nlp.src.models.resnet.resnet_image_classifier import (
ResNetImageClassifier,
)
from keras_nlp.src.models.resnet.resnet_image_classifier_preprocessor import (
ResNetImageClassifierPreprocessor,
)
from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_masked_lm import RobertaMaskedLM
from keras_nlp.src.models.roberta.roberta_masked_lm_preprocessor import (
5 changes: 0 additions & 5 deletions keras_nlp/src/models/image_classifier.py
@@ -33,11 +33,6 @@ class ImageClassifier(Task):
used to load a pre-trained config and weights.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
optimizer="auto",
83 changes: 83 additions & 0 deletions keras_nlp/src/models/image_classifier_preprocessor.py
@@ -0,0 +1,83 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import preprocessing_function


@keras_nlp_export("keras_nlp.models.ImageClassifierPreprocessor")
class ImageClassifierPreprocessor(Preprocessor):
"""Base class for image classification preprocessing layers.
`ImageClassifierPreprocessor` tasks wraps a
`keras_nlp.layers.ImageConverter` to create a preprocessing layer for
image classification tasks. It is intended to be paired with a
`keras_nlp.models.ImageClassifier` task.
All `ImageClassifierPreprocessor` take inputs three inputs, `x`, `y`, and
`sample_weight`. `x`, the first input, should always be included. It can
be a image or batch of images. See examples below. `y` and `sample_weight`
are optional inputs that will be passed through unaltered. Usually, `y` will
be the classification label, and `sample_weight` will not be provided.
The layer will output either `x`, an `(x, y)` tuple if labels were provided,
or an `(x, y, sample_weight)` tuple if labels and sample weight were
provided. `x` will be the input images after all model preprocessing has
been applied.
All `ImageClassifierPreprocessor` tasks include a `from_preset()`
constructor which can be used to load a pre-trained config and vocabularies.
You can call the `from_preset()` constructor directly on this base class, in
which case the correct class for your model will be automatically
instantiated.
Examples.
```python
preprocessor = keras_nlp.models.ImageClassifierPreprocessor.from_preset(
"resnet_50",
)
# Resize a single image for resnet 50.
x = np.ones((512, 512, 3))
x = preprocessor(x)
# Resize a labeled image.
x, y = np.ones((512, 512, 3)), 1
x, y = preprocessor(x, y)
# Resize a batch of labeled images.
x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0]
x, y = preprocessor(x, y)
# Use a `tf.data.Dataset`.
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
```
"""

def __init__(
self,
image_converter=None,
**kwargs,
):
super().__init__(**kwargs)
self.image_converter = image_converter

@preprocessing_function
def call(self, x, y=None, sample_weight=None):
if self.image_converter:
x = self.image_converter(x)
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
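The model-specific subclass added by this commit (`resnet_image_classifier_preprocessor.py`) is not rendered above. A plausible sketch of how such a subclass pairs this base class with its converter; the `backbone_cls` and `image_converter_cls` attribute names are assumptions about the library's wiring, not copied from the diff:

```python
from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.image_classifier_preprocessor import (
    ImageClassifierPreprocessor,
)
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_nlp.src.models.resnet.resnet_image_converter import ResNetImageConverter


@keras_nlp_export("keras_nlp.models.ResNetImageClassifierPreprocessor")
class ResNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
    # Assumed class-level wiring so `from_preset()` can find the matching
    # backbone and image converter for this preprocessor.
    backbone_cls = ResNetBackbone
    image_converter_cls = ResNetImageConverter
```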
6 changes: 6 additions & 0 deletions keras_nlp/src/models/resnet/__init__.py
@@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_nlp.src.models.resnet.resnet_presets import backbone_presets
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, ResNetBackbone)
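Registering the presets against `ResNetBackbone` is what lets the generic base classes resolve them by name. A small sketch, assuming the `resnet_18_imagenet` preset name from the new tests:

```python
import keras_nlp

# After registration, the base class dispatches to the right subclass.
backbone = keras_nlp.models.Backbone.from_preset("resnet_18_imagenet")
print(type(backbone).__name__)  # expected: "ResNetBackbone"
```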
19 changes: 3 additions & 16 deletions keras_nlp/src/models/resnet/resnet_backbone.py
@@ -57,7 +57,7 @@ class ResNetBackbone(FeaturePyramidBackbone):
stackwise_num_blocks: list of ints. The number of blocks for each stack.
stackwise_num_strides: list of ints. The number of strides for each
stack.
block_type: str. The block type to stack. One of `"basic_block"` or
block_type: str. The block type to stack. One of `"basic_block"`,
`"bottleneck_block"`, `"basic_block_vd"` or
`"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
@@ -126,7 +126,6 @@ def __init__(
use_pre_activation=False,
include_rescaling=True,
image_shape=(None, None, 3),
pooling="avg",
data_format=None,
dtype=None,
**kwargs,
@@ -285,20 +284,9 @@ def __init__(
)(x)
x = layers.Activation("relu", dtype=dtype, name="post_relu")(x)

if pooling == "avg":
feature_map_output = layers.GlobalAveragePooling2D(
data_format=data_format, dtype=dtype
)(x)
elif pooling == "max":
feature_map_output = layers.GlobalMaxPooling2D(
data_format=data_format, dtype=dtype
)(x)
else:
feature_map_output = x

super().__init__(
inputs=image_input,
outputs=feature_map_output,
outputs=x,
dtype=dtype,
**kwargs,
)
@@ -313,8 +301,8 @@ def __init__(
self.use_pre_activation = use_pre_activation
self.include_rescaling = include_rescaling
self.image_shape = image_shape
self.pooling = pooling
self.pyramid_outputs = pyramid_outputs
self.data_format = data_format

def get_config(self):
config = super().get_config()
@@ -329,7 +317,6 @@ def get_config(self):
"use_pre_activation": self.use_pre_activation,
"include_rescaling": self.include_rescaling,
"image_shape": self.image_shape,
"pooling": self.pooling,
}
)
return config
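With the `pooling` argument gone, the backbone always returns the final feature map, and any pooling or classification head lives in the task instead. A rough sketch of recovering pooled features outside the backbone, using the `resnet_18_imagenet` preset and the (1, 7, 7, 512) output shape asserted in the tests below:

```python
import numpy as np
import keras
import keras_nlp

backbone = keras_nlp.models.ResNetBackbone.from_preset("resnet_18_imagenet")
images = np.ones((1, 224, 224, 3), dtype="float32")
feature_map = backbone(images)  # (1, 7, 7, 512) per the preset test
pooled = keras.layers.GlobalAveragePooling2D()(feature_map)  # (1, 512)
```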
117 changes: 68 additions & 49 deletions keras_nlp/src/models/resnet/resnet_backbone_test.py
@@ -29,74 +29,93 @@ def setUp(self):
"stackwise_num_blocks": [2, 2, 2],
"stackwise_num_strides": [1, 2, 2],
"image_shape": (None, None, 3),
"pooling": "avg",
"block_type": "bottleneck_block",
"use_pre_activation": False,
}
self.input_size = 64
self.input_data = ops.ones((2, self.input_size, self.input_size, 3))

@parameterized.named_parameters(
("v1_basic", False, "basic_block"),
("v1_bottleneck", False, "bottleneck_block"),
("v2_basic", True, "basic_block"),
("v2_bottleneck", True, "bottleneck_block"),
("vd_basic", False, "basic_block_vd"),
("vd_bottleneck", False, "bottleneck_block_vd"),
("basic", "basic_block"),
("bottleneck", "bottleneck_block"),
)
def test_backbone_basics(self, use_pre_activation, block_type):
init_kwargs = self.init_kwargs.copy()
init_kwargs.update(
{
"block_type": block_type,
"use_pre_activation": use_pre_activation,
}
def test_backbone_basics(self, block_type):
feature_size = 64 if block_type == "basic_block" else 256
self.run_vision_backbone_test(
cls=ResNetBackbone,
init_kwargs={**self.init_kwargs, "block_type": block_type},
input_data=self.input_data,
expected_output_shape=(2, 4, 4, feature_size),
expected_pyramid_output_keys=["P2", "P3", "P4"],
expected_pyramid_image_sizes=[(16, 16), (8, 8), (4, 4)],
)
if block_type in ("basic_block_vd", "bottleneck_block_vd"):
init_kwargs.update(
{
"input_conv_filters": [32, 32, 64],
"input_conv_kernel_sizes": [3, 3, 3],
}
)

@parameterized.named_parameters(
("basic", "basic_block"),
("bottleneck", "bottleneck_block"),
)
def test_backbone_v2(self, block_type):
feature_size = 64 if block_type == "basic_block" else 256
self.run_vision_backbone_test(
cls=ResNetBackbone,
init_kwargs=init_kwargs,
init_kwargs={
**self.init_kwargs,
"block_type": block_type,
"use_pre_activation": True,
},
input_data=self.input_data,
expected_output_shape=(
(2, 64)
if block_type in ("basic_block", "basic_block_vd")
else (2, 256)
),
expected_output_shape=(2, 4, 4, feature_size),
expected_pyramid_output_keys=["P2", "P3", "P4"],
expected_pyramid_image_sizes=[(16, 16), (8, 8), (4, 4)],
)

@parameterized.named_parameters(
("v1_basic", False, "basic_block"),
("v1_bottleneck", False, "bottleneck_block"),
("v2_basic", True, "basic_block"),
("v2_bottleneck", True, "bottleneck_block"),
("vd_basic", False, "basic_block_vd"),
("vd_bottleneck", False, "bottleneck_block_vd"),
("basic", "basic_block_vd"),
("bottleneck", "bottleneck_block_vd"),
)
@pytest.mark.large
def test_saved_model(self, use_pre_activation, block_type):
init_kwargs = self.init_kwargs.copy()
init_kwargs.update(
{
def test_backbone_vd(self, block_type):
feature_size = 64 if block_type == "basic_block_vd" else 256
self.run_vision_backbone_test(
cls=ResNetBackbone,
init_kwargs={
**self.init_kwargs,
"block_type": block_type,
"use_pre_activation": use_pre_activation,
"image_shape": (None, None, 3),
}
"input_conv_filters": [32, 32, 64],
"input_conv_kernel_sizes": [3, 3, 3],
},
input_data=self.input_data,
expected_output_shape=(2, 4, 4, feature_size),
expected_pyramid_output_keys=["P2", "P3", "P4"],
expected_pyramid_image_sizes=[(16, 16), (8, 8), (4, 4)],
)
if block_type in ("basic_block_vd", "bottleneck_block_vd"):
init_kwargs.update(
{
"input_conv_filters": [32, 32, 64],
"input_conv_kernel_sizes": [3, 3, 3],
}
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=ResNetBackbone,
init_kwargs=init_kwargs,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@pytest.mark.large
def test_smallest_preset(self):
image_batch = self.load_test_image((224, 224))[None, ...]
self.run_preset_test(
cls=ResNetBackbone,
preset="resnet_18_imagenet",
input_data=image_batch,
expected_output_shape=(1, 7, 7, 512),
# The forward pass from a preset should be stable!
expected_partial_output=ops.array(
[0.008969, 0.015136, 0.028074, 0.594599, 0.002846]
),
)

@pytest.mark.extra_large
def test_all_presets(self):
for preset in ResNetBackbone.presets:
self.run_preset_test(
cls=ResNetBackbone,
preset=preset,
input_data=self.input_data,
)