Add ConvNeXt and refactor BaseModel #16

Status: Merged (3 commits) on Jan 19, 2024
Changes from all commits
15 changes: 13 additions & 2 deletions kimm/blocks/transformer_block.py
@@ -11,18 +11,29 @@ def apply_mlp_block(
     normalization=None,
     use_bias=True,
     dropout_rate=0.0,
+    use_conv_mlp=False,
     name="mlp_block",
 ):
     input_dim = inputs.shape[-1]
     output_dim = output_dim or input_dim

     x = inputs
-    x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
+    if use_conv_mlp:
+        x = layers.Conv2D(
+            hidden_dim, 1, use_bias=use_bias, name=f"{name}_fc1_conv2d"
+        )(x)
+    else:
+        x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
     x = layers.Activation(activation, name=f"{name}_act")(x)
     x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
     if normalization is not None:
         x = normalization(name=f"{name}_norm")(x)
-    x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
+    if use_conv_mlp:
+        x = layers.Conv2D(
+            output_dim, 1, use_bias=use_bias, name=f"{name}_fc2_conv2d"
+        )(x)
+    else:
+        x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
     x = layers.Dropout(dropout_rate, name=f"{name}_drop2")(x)
     return x

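A note on the new `use_conv_mlp` flag: on an NHWC tensor, a 1x1 `Conv2D` computes the same per-position projection as `Dense`, so the two branches are numerically interchangeable; the conv variant mainly keeps weight layouts compatible with convolutional checkpoints. A standalone sketch of the equivalence (shapes and sizes arbitrary, not kimm code):

import numpy as np
from keras import layers
from keras import ops
from keras import random

x = random.uniform((1, 7, 7, 64))  # NHWC feature map

dense = layers.Dense(128, use_bias=False)
conv = layers.Conv2D(128, 1, use_bias=False)

y_dense = dense(x)  # Dense maps the last axis: (1, 7, 7, 128)
conv.build(x.shape)
# Reuse the Dense kernel as a 1x1 conv kernel: (in, out) -> (1, 1, in, out).
conv.set_weights([dense.get_weights()[0].reshape(1, 1, 64, 128)])
y_conv = conv(x)

np.testing.assert_allclose(
    ops.convert_to_numpy(y_dense), ops.convert_to_numpy(y_conv), atol=1e-5
)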
2 changes: 2 additions & 0 deletions kimm/layers/attention.py
@@ -1,7 +1,9 @@
+import keras
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class Attention(layers.Layer):
     def __init__(
         self,
2 changes: 2 additions & 0 deletions kimm/layers/layer_scale.py
@@ -1,8 +1,10 @@
+import keras
 from keras import initializers
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class LayerScale(layers.Layer):
     def __init__(
         self,
2 changes: 2 additions & 0 deletions kimm/layers/position_embedding.py
@@ -1,7 +1,9 @@
+import keras
 from keras import layers
 from keras import ops


+@keras.saving.register_keras_serializable(package="kimm")
 class PositionEmbedding(layers.Layer):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
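The `register_keras_serializable` decorators added to these layers are what let a saved `.keras` archive be reloaded without passing `custom_objects`. A generic sketch of the pattern with a toy layer (not one of the kimm layers):

import keras
from keras import layers
from keras import ops


@keras.saving.register_keras_serializable(package="demo")
class Scale(layers.Layer):
    """Multiplies its input by a constant factor."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x):
        return ops.multiply(x, self.factor)

    def get_config(self):
        config = super().get_config()
        config.update({"factor": self.factor})
        return config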
1 change: 1 addition & 0 deletions kimm/models/__init__.py
@@ -1,5 +1,6 @@
 from kimm.models.base_model import BaseModel
 from kimm.models.convmixer import * # noqa:F403
+from kimm.models.convnext import * # noqa:F403
 from kimm.models.densenet import * # noqa:F403
 from kimm.models.efficientnet import * # noqa:F403
 from kimm.models.ghostnet import * # noqa:F403
171 changes: 111 additions & 60 deletions kimm/models/base_model.py
@@ -1,10 +1,13 @@
 import abc
+import pathlib
 import typing
+import urllib.parse

 from keras import KerasTensor
 from keras import backend
 from keras import layers
 from keras import models
+from keras import utils
 from keras.src.applications import imagenet_utils

@@ -14,53 +17,79 @@ def __init__(
         inputs,
         outputs,
         features: typing.Optional[typing.Dict[str, KerasTensor]] = None,
-        feature_keys: typing.Optional[typing.List[str]] = None,
         **kwargs,
     ):
-        self.feature_extractor = kwargs.pop("feature_extractor", False)
-        self.feature_keys = feature_keys
-        if self.feature_extractor:
-            if features is None:
-                raise ValueError(
-                    "`features` must be set when "
-                    f"`feature_extractor=True`. Received features={features}"
-                )
-            if self.feature_keys is None:
-                self.feature_keys = list(features.keys())
-            filtered_features = {}
-            for k in self.feature_keys:
-                if k not in features:
-                    raise KeyError(
-                        f"'{k}' is not a key of `features`. Available keys "
-                        f"are: {list(features.keys())}"
-                    )
-                filtered_features[k] = features[k]
-            # add outputs
-            if backend.is_keras_tensor(outputs):
-                filtered_features["TOP"] = outputs
-            super().__init__(inputs=inputs, outputs=filtered_features, **kwargs)
-        else:
+        if not hasattr(self, "_feature_extractor"):
             del features
             super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+        else:
+            if not hasattr(self, "_feature_keys"):
+                raise AttributeError(
+                    "`self._feature_keys` must be set when initializing "
+                    "BaseModel"
+                )
+            if self._feature_extractor:
+                if features is None:
+                    raise ValueError(
+                        "`features` must be set when `feature_extractor=True`. "
+                        f"Received features={features}"
+                    )
+                if self._feature_keys is None:
+                    self._feature_keys = list(features.keys())
+                filtered_features = {}
+                for k in self._feature_keys:
+                    if k not in features:
+                        raise KeyError(
+                            f"'{k}' is not a key of `features`. Available keys "
+                            f"are: {list(features.keys())}"
+                        )
+                    filtered_features[k] = features[k]
+                # Add outputs
+                if backend.is_keras_tensor(outputs):
+                    filtered_features["TOP"] = outputs
+                super().__init__(
+                    inputs=inputs, outputs=filtered_features, **kwargs
+                )
+            else:
+                del features
+                super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+        if hasattr(self, "_weights_url"):
+            self.load_pretrained_weights(self._weights_url)
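One consequence of the new `hasattr` dispatch is a fail-fast contract: if a subclass sets `_feature_extractor` but `_feature_keys` is missing (i.e. `set_properties()` was skipped or bypassed), construction raises immediately instead of mis-parsing kwargs later. A hypothetical subclass hitting the error path:

from keras import layers

from kimm.models.base_model import BaseModel


class BadNet(BaseModel):
    def __init__(self, **kwargs):
        # Set by hand instead of via set_properties(); `_feature_keys`
        # is never created.
        self._feature_extractor = False
        inputs = layers.Input(shape=(32, 32, 3))
        outputs = layers.Conv2D(8, 3)(inputs)
        super().__init__(inputs, outputs, **kwargs)


BadNet()  # AttributeError: `self._feature_keys` must be set ...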

-    def parse_kwargs(
+    def set_properties(
         self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224
     ):
-        result = {
-            "input_tensor": kwargs.pop("input_tensor", None),
-            "input_shape": kwargs.pop("input_shape", None),
-            "include_preprocessing": kwargs.pop("include_preprocessing", True),
-            "include_top": kwargs.pop("include_top", True),
-            "pooling": kwargs.pop("pooling", None),
-            "dropout_rate": kwargs.pop("dropout_rate", 0.0),
-            "classes": kwargs.pop("classes", 1000),
-            "classifier_activation": kwargs.pop(
-                "classifier_activation", "softmax"
-            ),
-            "weights": kwargs.pop("weights", "imagenet"),
-            "default_size": kwargs.pop("default_size", default_size),
-        }
-        return result
+        """Must be called in the initialization of the class.
+
+        This method will add the following common properties to the model
+        object:
+        - _input_shape
+        - _include_preprocessing
+        - _include_top
+        - _pooling
+        - _dropout_rate
+        - _classes
+        - _classifier_activation
+        - _weights
+        - _weights_url
+        - _default_size
+        """
+        self._input_shape = kwargs.pop("input_shape", None)
+        self._include_preprocessing = kwargs.pop("include_preprocessing", True)
+        self._include_top = kwargs.pop("include_top", True)
+        self._pooling = kwargs.pop("pooling", None)
+        self._dropout_rate = kwargs.pop("dropout_rate", 0.0)
+        self._classes = kwargs.pop("classes", 1000)
+        self._classifier_activation = kwargs.pop(
+            "classifier_activation", "softmax"
+        )
+        self._weights = kwargs.pop("weights", None)
+        self._weights_url = kwargs.pop("weights_url", None)
+        self._default_size = kwargs.pop("default_size", default_size)
+        # feature extractor
+        self._feature_extractor = kwargs.pop("feature_extractor", False)
+        self._feature_keys = kwargs.pop("feature_keys", None)
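Taken together with the new `__init__`, the intended calling pattern looks like this minimal, hypothetical subclass (layer sizes arbitrary; `set_properties()` must run before `super().__init__()` so the `_feature_*` attributes exist):

import keras
from keras import layers

from kimm.models.base_model import BaseModel


@keras.saving.register_keras_serializable(package="demo")
class TinyNet(BaseModel):
    def __init__(self, **kwargs):
        self.set_properties(kwargs)  # pops the common kwargs onto `self._*`
        inputs = layers.Input(shape=(32, 32, 3))
        x = self.build_preprocessing(inputs, "imagenet")

        features = {}
        x = layers.Conv2D(32, 3, 2, padding="same", name="stem")(x)
        features["STEM"] = x

        x = self.build_head(x)
        super().__init__(
            inputs=inputs, outputs=x, features=features, **kwargs
        )

Called as `TinyNet()`, this is a plain classifier; `TinyNet(feature_extractor=True)` returns a dict with the requested features plus the `"TOP"` output.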

     def determine_input_tensor(
         self,
@@ -87,10 +116,12 @@ def determine_input_tensor(
         if not backend.is_keras_tensor(input_tensor):
             x = layers.Input(tensor=input_tensor, shape=input_shape)
         else:
-            x = input_tensor
+            x = utils.get_source_inputs(input_tensor)
         return x
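`utils.get_source_inputs` walks the graph back from any intermediate tensor to its originating `Input`s, which is what lets callers pass a downstream `KerasTensor` as `input_tensor`. A small illustration:

from keras import layers
from keras import utils

base_in = layers.Input(shape=(224, 224, 3))
x = layers.Rescaling(1.0 / 255.0)(base_in)
print(utils.get_source_inputs(x))  # the original Input tensor(s)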

     def build_preprocessing(self, inputs, mode="imagenet"):
+        if self._include_preprocessing is False:
+            return inputs
         if mode == "imagenet":
             # [0, 255] to [0, 1] and apply ImageNet mean and variance
             x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
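The remainder of the "imagenet" branch is collapsed in this diff; presumably it applies the standard ImageNet statistics after the rescale. For reference, that normalization typically looks like the following (the constants are the usual ImageNet mean/std, an assumption rather than a quote from the diff):

from keras import layers


def imagenet_preprocessing(inputs):
    # [0, 255] -> [0, 1], then normalize with ImageNet channel statistics.
    x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
    x = layers.Normalization(
        mean=[0.485, 0.456, 0.406],
        variance=[0.229**2, 0.224**2, 0.225**2],
    )(x)
    return x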
@@ -118,15 +149,30 @@ def build_top(self, inputs, classes, classifier_activation, dropout_rate):
         )(x)
         return x

-    def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]):
-        self.include_preprocessing = parsed_kwargs["include_preprocessing"]
-        self.include_top = parsed_kwargs["include_top"]
-        self.pooling = parsed_kwargs["pooling"]
-        self.dropout_rate = parsed_kwargs["dropout_rate"]
-        self.classes = parsed_kwargs["classes"]
-        self.classifier_activation = parsed_kwargs["classifier_activation"]
-        # `self.weights` is used internally
-        self._weights = parsed_kwargs["weights"]
+    def build_head(self, inputs):
+        x = inputs
+        if self._include_top:
+            x = self.build_top(
+                x,
+                self._classes,
+                self._classifier_activation,
+                self._dropout_rate,
+            )
+        else:
+            if self._pooling == "avg":
+                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
+            elif self._pooling == "max":
+                x = layers.GlobalMaxPooling2D(name="max_pool")(x)
+        return x
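`build_head` covers the three usual head configurations; with the hypothetical `TinyNet` sketched above:

clf = TinyNet()                                   # classifier head via build_top
vec = TinyNet(include_top=False, pooling="avg")   # (batch, channels) embedding
fmap = TinyNet(include_top=False, pooling=None)   # raw 4D feature map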

+    def load_pretrained_weights(self, weights_url: typing.Optional[str] = None):
+        if weights_url is not None:
+            result = urllib.parse.urlparse(weights_url)
+            file_name = pathlib.Path(result.path).name
+            weights_path = utils.get_file(
+                file_name, weights_url, cache_subdir="kimm_models"
+            )
+            self.load_weights(weights_path)
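The cache file name is derived from the URL path; `utils.get_file` then downloads into `~/.keras/kimm_models/` and returns the local path. For example (hypothetical URL):

import pathlib
import urllib.parse

url = "https://example.com/download/v0.1/convnext_tiny.keras"
file_name = pathlib.Path(urllib.parse.urlparse(url).path).name
print(file_name)  # convnext_tiny.keras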

     @staticmethod
     @abc.abstractmethod
@@ -141,20 +187,25 @@ def get_config(self):
             # models.Model
             "name": self.name,
             "trainable": self.trainable,
-            # feature extractor
-            "feature_extractor": self.feature_extractor,
-            "feature_keys": self.feature_keys,
-            # common
             "input_shape": self.input_shape[1:],
-            "include_preprocessing": self.include_preprocessing,
-            "include_top": self.include_top,
-            "pooling": self.pooling,
-            "dropout_rate": self.dropout_rate,
-            "classes": self.classes,
-            "classifier_activation": self.classifier_activation,
+            # common
+            "include_preprocessing": self._include_preprocessing,
+            "include_top": self._include_top,
+            "pooling": self._pooling,
+            "dropout_rate": self._dropout_rate,
+            "classes": self._classes,
+            "classifier_activation": self._classifier_activation,
             "weights": self._weights,
+            "weights_url": self._weights_url,
+            # feature extractor
+            "feature_extractor": self._feature_extractor,
+            "feature_keys": self._feature_keys,
         }
         return config

     def fix_config(self, config: typing.Dict):
         return config

     @property
     def default_origin(self):
         return "https://github.com/james77777778/keras-aug/releases/download/v0.5.0"
1 change: 1 addition & 0 deletions kimm/models/base_model_test.py
@@ -8,6 +8,7 @@

 class SampleModel(BaseModel):
     def __init__(self, **kwargs):
+        self.set_properties(kwargs)
         inputs = layers.Input(shape=[224, 224, 3])

         features = {}
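With `set_properties()` in place, the test model can be built as a feature extractor; calling it then returns a dict keyed by the feature names plus `"TOP"` (a sketch based on the test above):

from keras import random

model = SampleModel(feature_extractor=True)
outputs = model(random.uniform((1, 224, 224, 3)))
print(sorted(outputs.keys()))  # feature keys plus "TOP"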
39 changes: 10 additions & 29 deletions kimm/models/convmixer.py
@@ -2,7 +2,6 @@

 import keras
 from keras import layers
-from keras import utils

 from kimm.models.base_model import BaseModel
 from kimm.utils import add_model_to_registry
@@ -42,6 +41,7 @@ def apply_convmixer_block(
     return x


+@keras.saving.register_keras_serializable(package="kimm")
 class ConvMixer(BaseModel):
     def __init__(
         self,
@@ -52,16 +52,16 @@ def __init__(
         activation: str = "relu",
         **kwargs,
     ):
-        parsed_kwargs = self.parse_kwargs(kwargs)
-        img_input = self.determine_input_tensor(
-            parsed_kwargs["input_tensor"],
-            parsed_kwargs["input_shape"],
-            parsed_kwargs["default_size"],
+        input_tensor = kwargs.pop("input_tensor", None)
+        self.set_properties(kwargs)
+        inputs = self.determine_input_tensor(
+            input_tensor,
+            self._input_shape,
+            self._default_size,
         )
-        x = img_input
+        x = inputs

-        if parsed_kwargs["include_preprocessing"]:
-            x = self.build_preprocessing(x, "imagenet")
+        x = self.build_preprocessing(x, "imagenet")

         # Prepare feature extraction
         features = {}
@@ -89,30 +89,11 @@ def __init__(
             features[f"BLOCK{i}"] = x

         # Head
-        if parsed_kwargs["include_top"]:
-            x = self.build_top(
-                x,
-                parsed_kwargs["classes"],
-                parsed_kwargs["classifier_activation"],
-                parsed_kwargs["dropout_rate"],
-            )
-        else:
-            if parsed_kwargs["pooling"] == "avg":
-                x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
-            elif parsed_kwargs["pooling"] == "max":
-                x = layers.GlobalMaxPooling2D(name="max_pool")(x)
-
-        # Ensure that the model takes into account
-        # any potential predecessors of `input_tensor`.
-        if parsed_kwargs["input_tensor"] is not None:
-            inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"])
-        else:
-            inputs = img_input
+        x = self.build_head(x)

         super().__init__(inputs=inputs, outputs=x, features=features, **kwargs)

-        # All references to `self` below this line
-        self.add_references(parsed_kwargs)
         self.depth = depth
         self.hidden_channels = hidden_channels
         self.patch_size = patch_size
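After the refactor, every common option flows through `set_properties()`, so a model can be configured entirely by keyword. An illustrative call (the argument names come from the visible attributes above; the full `ConvMixer` signature is collapsed in this diff, so treat this as an assumption):

from kimm.models.convmixer import ConvMixer

model = ConvMixer(
    depth=8,
    hidden_channels=256,
    patch_size=7,
    include_top=False,
    pooling="avg",
    weights=None,
)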