Skip to content

Commit

Permalink
Finish up resnet
Browse files Browse the repository at this point in the history
- Add presets.
- Add converter script.
- Add preprocessing with auto resizing.
  • Loading branch information
mattdangerw committed Sep 13, 2024
1 parent a5e5d8f commit b066976
Show file tree
Hide file tree
Showing 14 changed files with 411 additions and 33 deletions.
3 changes: 3 additions & 0 deletions keras_nlp/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_nlp.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)
from keras_nlp.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
6 changes: 6 additions & 0 deletions keras_nlp/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@
)
from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.src.models.image_classifier import ImageClassifier
from keras_nlp.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
Expand Down Expand Up @@ -231,6 +234,9 @@
from keras_nlp.src.models.resnet.resnet_image_classifier import (
ResNetImageClassifier,
)
from keras_nlp.src.models.resnet.resnet_image_classifier_preprocessor import (
ResNetImageClassifierPreprocessor,
)
from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_masked_lm import RobertaMaskedLM
from keras_nlp.src.models.roberta.roberta_masked_lm_preprocessor import (
Expand Down
5 changes: 0 additions & 5 deletions keras_nlp/src/models/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ class ImageClassifier(Task):
used to load a pre-trained config and weights.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
optimizer="auto",
Expand Down
83 changes: 83 additions & 0 deletions keras_nlp/src/models/image_classifier_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import preprocessing_function


@keras_nlp_export("keras_nlp.models.ImageClassifierPreprocessor")
class ImageClassifierPreprocessor(Preprocessor):
"""Base class for image classification preprocessing layers.
`ImageClassifierPreprocessor` tasks wraps a
`keras_nlp.layers.ImageConverter` to create a preprocessing layer for
image classification tasks. It is intended to be paired with a
`keras_nlp.models.ImageClassifier` task.
All `ImageClassifierPreprocessor` take inputs three inputs, `x`, `y`, and
`sample_weight`. `x`, the first input, should always be included. It can
be a image or batch of images. See examples below. `y` and `sample_weight`
are optional inputs that will be passed through unaltered. Usually, `y` will
be the classification label, and `sample_weight` will not be provided.
The layer will output either `x`, an `(x, y)` tuple if labels were provided,
or an `(x, y, sample_weight)` tuple if labels and sample weight were
provided. `x` will be the input images after all model preprocessing has
been applied.
All `ImageClassifierPreprocessor` tasks include a `from_preset()`
constructor which can be used to load a pre-trained config and vocabularies.
You can call the `from_preset()` constructor directly on this base class, in
which case the correct class for your model will be automatically
instantiated.
Examples.
```python
preprocessor = keras_nlp.models.ImageClassifierPreprocessor.from_preset(
"resnet_50",
)
# Resize a single image for resnet 50.
x = np.ones((512, 512, 3))
x = preprocessor(x)
# Resize a labeled image.
x, y = np.ones((512, 512, 3)), 1
x, y = preprocessor(x, y)
# Resize a batch of labeled images.
x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0]
x, y = preprocessor(x, y)
# Use a `tf.data.Dataset`.
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
```
"""

def __init__(
self,
image_converter=None,
**kwargs,
):
super().__init__(**kwargs)
self.image_converter = image_converter

@preprocessing_function
def call(self, x, y=None, sample_weight=None):
if self.image_converter:
x = self.image_converter(x)
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
6 changes: 6 additions & 0 deletions keras_nlp/src/models/resnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_nlp.src.models.resnet.resnet_presets import backbone_presets
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, ResNetBackbone)
10 changes: 7 additions & 3 deletions keras_nlp/src/models/resnet/resnet_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.image_classifier import ImageClassifier
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_nlp.src.models.resnet.resnet_image_classifier_preprocessor import (
ResNetImageClassifierPreprocessor,
)


@keras_nlp_export("keras_nlp.models.ResNetImageClassifier")
Expand Down Expand Up @@ -88,21 +91,22 @@ class ResNetImageClassifier(ImageClassifier):
"""

backbone_cls = ResNetBackbone
preprocessor_cls = ResNetImageClassifierPreprocessor

def __init__(
self,
backbone,
num_classes,
activation="softmax",
preprocessor=None,
activation=None,
head_dtype=None,
preprocessor=None, # adding this dummy arg for saved model test
# TODO: once preprocessor flow is figured out, this needs to be updated
**kwargs,
):
head_dtype = head_dtype or backbone.dtype_policy

# === Layers ===
self.backbone = backbone
self.preprocessor = preprocessor
self.output_dense = keras.layers.Dense(
num_classes,
activation=activation,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_nlp.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)


@keras_nlp_export("keras_nlp.models.ResNetImageClassifierPreprocessor")
class ResNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
backbone_cls = ResNetBackbone
image_converter_cls = ResNetImageConverter
23 changes: 23 additions & 0 deletions keras_nlp/src/models/resnet/resnet_image_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.layers.preprocessing.resizing_image_converter import (
ResizingImageConverter,
)
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone


@keras_nlp_export("keras_nlp.layers.ResNetImageConverter")
class ResNetImageConverter(ResizingImageConverter):
backbone_cls = ResNetBackbone
95 changes: 95 additions & 0 deletions keras_nlp/src/models/resnet/resnet_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet preset configurations."""

backbone_presets = {
"resnet_18_imagenet": {
"metadata": {
"description": (
"18-layer ResNet model pre-trained on the ImageNet 1k dataset "
"at a 224x224 resolution."
),
"params": 11186112,
"official_name": "ResNet",
"path": "resnet",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/1",
},
"resnet_50_imagenet": {
"metadata": {
"description": (
"50-layer ResNet model pre-trained on the ImageNet 1k dataset "
"at a 224x224 resolution."
),
"params": 23561152,
"official_name": "ResNet",
"path": "resnet",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/1",
},
"resnet_101_imagenet": {
"metadata": {
"description": (
"101-layer ResNet model pre-trained on the ImageNet 1k dataset "
"at a 224x224 resolution."
),
"params": 42605504,
"official_name": "ResNet",
"path": "resnet",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/1",
},
"resnet_152_imagenet": {
"metadata": {
"description": (
"152-layer ResNet model pre-trained on the ImageNet 1k dataset "
"at a 224x224 resolution."
),
"params": 58295232,
"official_name": "ResNet",
"path": "resnet",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/1",
},
"resnet_v2_50_imagenet": {
"metadata": {
"description": (
"50-layer ResNetV2 model pre-trained on the ImageNet 1k "
"dataset at a 224x224 resolution."
),
"params": 23561152,
"official_name": "ResNet",
"path": "resnet",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/1",
},
"resnet_v2_101_imagenet": {
"metadata": {
"description": (
"101-layer ResNetV2 model pre-trained on the ImageNet 1k "
"dataset at a 224x224 resolution."
),
"params": 42605504,
"official_name": "ResNet",
"path": "resnet",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/1",
},
}
3 changes: 0 additions & 3 deletions keras_nlp/src/models/text_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ class TextClassifier(Task):
```
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def compile(
self,
optimizer="auto",
Expand Down
Loading

0 comments on commit b066976

Please sign in to comment.