diff --git a/kimm/blocks/transformer_block.py b/kimm/blocks/transformer_block.py index 04485a3..f878d0a 100644 --- a/kimm/blocks/transformer_block.py +++ b/kimm/blocks/transformer_block.py @@ -11,18 +11,29 @@ def apply_mlp_block( normalization=None, use_bias=True, dropout_rate=0.0, + use_conv_mlp=False, name="mlp_block", ): input_dim = inputs.shape[-1] output_dim = output_dim or input_dim x = inputs - x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x) + if use_conv_mlp: + x = layers.Conv2D( + hidden_dim, 1, use_bias=use_bias, name=f"{name}_fc1_conv2d" + )(x) + else: + x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x) x = layers.Activation(activation, name=f"{name}_act")(x) x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x) if normalization is not None: x = normalization(name=f"{name}_norm")(x) - x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x) + if use_conv_mlp: + x = layers.Conv2D( + output_dim, 1, use_bias=use_bias, name=f"{name}_fc2_conv2d" + )(x) + else: + x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x) x = layers.Dropout(dropout_rate, name=f"{name}_drop2")(x) return x diff --git a/kimm/layers/attention.py b/kimm/layers/attention.py index 7eb90af..134d5fd 100644 --- a/kimm/layers/attention.py +++ b/kimm/layers/attention.py @@ -1,7 +1,9 @@ +import keras from keras import layers from keras import ops +@keras.saving.register_keras_serializable(package="kimm") class Attention(layers.Layer): def __init__( self, diff --git a/kimm/layers/layer_scale.py b/kimm/layers/layer_scale.py index 0afce2d..e39789b 100644 --- a/kimm/layers/layer_scale.py +++ b/kimm/layers/layer_scale.py @@ -1,8 +1,10 @@ +import keras from keras import initializers from keras import layers from keras import ops +@keras.saving.register_keras_serializable(package="kimm") class LayerScale(layers.Layer): def __init__( self, diff --git a/kimm/layers/position_embedding.py b/kimm/layers/position_embedding.py index 82670f3..9738aaa 100644 --- a/kimm/layers/position_embedding.py +++ b/kimm/layers/position_embedding.py @@ -1,7 +1,9 @@ +import keras from keras import layers from keras import ops +@keras.saving.register_keras_serializable(package="kimm") class PositionEmbedding(layers.Layer): def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py index bbd825a..f1540de 100644 --- a/kimm/models/__init__.py +++ b/kimm/models/__init__.py @@ -1,5 +1,6 @@ from kimm.models.base_model import BaseModel from kimm.models.convmixer import * # noqa:F403 +from kimm.models.convnext import * # noqa:F403 from kimm.models.densenet import * # noqa:F403 from kimm.models.efficientnet import * # noqa:F403 from kimm.models.ghostnet import * # noqa:F403 diff --git a/kimm/models/base_model.py b/kimm/models/base_model.py index 5a2e388..4f1d3f4 100644 --- a/kimm/models/base_model.py +++ b/kimm/models/base_model.py @@ -1,10 +1,13 @@ import abc +import pathlib import typing +import urllib.parse from keras import KerasTensor from keras import backend from keras import layers from keras import models +from keras import utils from keras.src.applications import imagenet_utils @@ -14,53 +17,79 @@ def __init__( inputs, outputs, features: typing.Optional[typing.Dict[str, KerasTensor]] = None, - feature_keys: typing.Optional[typing.List[str]] = None, **kwargs, ): - self.feature_extractor = kwargs.pop("feature_extractor", False) - self.feature_keys = feature_keys - if self.feature_extractor: - if features is None: - raise ValueError( - "`features` must be set when " - f"`feature_extractor=True`. Received features={features}" - ) - if self.feature_keys is None: - self.feature_keys = list(features.keys()) - filtered_features = {} - for k in self.feature_keys: - if k not in features: - raise KeyError( - f"'{k}' is not a key of `features`. Available keys " - f"are: {list(features.keys())}" - ) - filtered_features[k] = features[k] - # add outputs - if backend.is_keras_tensor(outputs): - filtered_features["TOP"] = outputs - super().__init__(inputs=inputs, outputs=filtered_features, **kwargs) - else: + if not hasattr(self, "_feature_extractor"): del features super().__init__(inputs=inputs, outputs=outputs, **kwargs) + else: + if not hasattr(self, "_feature_keys"): + raise AttributeError( + "`self._feature_keys` must be set when initializing " + "BaseModel" + ) + if self._feature_extractor: + if features is None: + raise ValueError( + "`features` must be set when `feature_extractor=True`. " + f"Received features={features}" + ) + if self._feature_keys is None: + self._feature_keys = list(features.keys()) + filtered_features = {} + for k in self._feature_keys: + if k not in features: + raise KeyError( + f"'{k}' is not a key of `features`. Available keys " + f"are: {list(features.keys())}" + ) + filtered_features[k] = features[k] + # Add outputs + if backend.is_keras_tensor(outputs): + filtered_features["TOP"] = outputs + super().__init__( + inputs=inputs, outputs=filtered_features, **kwargs + ) + else: + del features + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + if hasattr(self, "_weights_url"): + self.load_pretrained_weights(self._weights_url) - def parse_kwargs( + def set_properties( self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224 ): - result = { - "input_tensor": kwargs.pop("input_tensor", None), - "input_shape": kwargs.pop("input_shape", None), - "include_preprocessing": kwargs.pop("include_preprocessing", True), - "include_top": kwargs.pop("include_top", True), - "pooling": kwargs.pop("pooling", None), - "dropout_rate": kwargs.pop("dropout_rate", 0.0), - "classes": kwargs.pop("classes", 1000), - "classifier_activation": kwargs.pop( - "classifier_activation", "softmax" - ), - "weights": kwargs.pop("weights", "imagenet"), - "default_size": kwargs.pop("default_size", default_size), - } - return result + """Must be called in the initilization of the class. + + This method will add following common properties to the model object: + - input_shape + - include_preprocessing + - include_top + - pooling + - dropout_rate + - classes + - classifier_activation + - _weights + - weights_url + - default_size + """ + self._input_shape = kwargs.pop("input_shape", None) + self._include_preprocessing = kwargs.pop("include_preprocessing", True) + self._include_top = kwargs.pop("include_top", True) + self._pooling = kwargs.pop("pooling", None) + self._dropout_rate = kwargs.pop("dropout_rate", 0.0) + self._classes = kwargs.pop("classes", 1000) + self._classifier_activation = kwargs.pop( + "classifier_activation", "softmax" + ) + self._weights = kwargs.pop("weights", None) + self._weights_url = kwargs.pop("weights_url", None) + self._default_size = kwargs.pop("default_size", default_size) + # feature extractor + self._feature_extractor = kwargs.pop("feature_extractor", False) + self._feature_keys = kwargs.pop("feature_keys", None) + print("self._feature_keys", self._feature_keys) def determine_input_tensor( self, @@ -87,10 +116,12 @@ def determine_input_tensor( if not backend.is_keras_tensor(input_tensor): x = layers.Input(tensor=input_tensor, shape=input_shape) else: - x = input_tensor + x = utils.get_source_inputs(input_tensor) return x def build_preprocessing(self, inputs, mode="imagenet"): + if self._include_preprocessing is False: + return inputs if mode == "imagenet": # [0, 255] to [0, 1] and apply ImageNet mean and variance x = layers.Rescaling(scale=1.0 / 255.0)(inputs) @@ -118,15 +149,30 @@ def build_top(self, inputs, classes, classifier_activation, dropout_rate): )(x) return x - def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]): - self.include_preprocessing = parsed_kwargs["include_preprocessing"] - self.include_top = parsed_kwargs["include_top"] - self.pooling = parsed_kwargs["pooling"] - self.dropout_rate = parsed_kwargs["dropout_rate"] - self.classes = parsed_kwargs["classes"] - self.classifier_activation = parsed_kwargs["classifier_activation"] - # `self.weights` is been used internally - self._weights = parsed_kwargs["weights"] + def build_head(self, inputs): + x = inputs + if self._include_top: + x = self.build_top( + x, + self._classes, + self._classifier_activation, + self._dropout_rate, + ) + else: + if self._pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif self._pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + return x + + def load_pretrained_weights(self, weights_url: typing.Optional[str] = None): + if weights_url is not None: + result = urllib.parse.urlparse(weights_url) + file_name = pathlib.Path(result.path).name + weights_path = utils.get_file( + file_name, weights_url, cache_subdir="kimm_models" + ) + self.load_weights(weights_path) @staticmethod @abc.abstractmethod @@ -141,20 +187,25 @@ def get_config(self): # models.Model "name": self.name, "trainable": self.trainable, - # feature extractor - "feature_extractor": self.feature_extractor, - "feature_keys": self.feature_keys, - # common "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, + # common + "include_preprocessing": self._include_preprocessing, + "include_top": self._include_top, + "pooling": self._pooling, + "dropout_rate": self._dropout_rate, + "classes": self._classes, + "classifier_activation": self._classifier_activation, "weights": self._weights, + "weights_url": self._weights_url, + # feature extractor + "feature_extractor": self._feature_extractor, + "feature_keys": self._feature_keys, } return config def fix_config(self, config: typing.Dict): return config + + @property + def default_origin(self): + return "https://github.com/james77777778/keras-aug/releases/download/v0.5.0" diff --git a/kimm/models/base_model_test.py b/kimm/models/base_model_test.py index b3dc3e0..bfba31c 100644 --- a/kimm/models/base_model_test.py +++ b/kimm/models/base_model_test.py @@ -8,6 +8,7 @@ class SampleModel(BaseModel): def __init__(self, **kwargs): + self.set_properties(kwargs) inputs = layers.Input(shape=[224, 224, 3]) features = {} diff --git a/kimm/models/convmixer.py b/kimm/models/convmixer.py index 76ee965..c6bca47 100644 --- a/kimm/models/convmixer.py +++ b/kimm/models/convmixer.py @@ -2,7 +2,6 @@ import keras from keras import layers -from keras import utils from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry @@ -42,6 +41,7 @@ def apply_convmixer_block( return x +@keras.saving.register_keras_serializable(package="kimm") class ConvMixer(BaseModel): def __init__( self, @@ -52,16 +52,16 @@ def __init__( activation: str = "relu", **kwargs, ): - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -89,30 +89,11 @@ def __init__( features[f"BLOCK{i}"] = x # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.depth = depth self.hidden_channels = hidden_channels self.patch_size = patch_size diff --git a/kimm/models/convnext.py b/kimm/models/convnext.py new file mode 100644 index 0000000..b1bd8fe --- /dev/null +++ b/kimm/models/convnext.py @@ -0,0 +1,580 @@ +import typing + +import keras +from keras import initializers +from keras import layers + +from kimm import layers as kimm_layers +from kimm.blocks import apply_mlp_block +from kimm.models.base_model import BaseModel +from kimm.utils import add_model_to_registry + + +def apply_convnext_block( + inputs, + output_channels, + kernel_size, + strides, + mlp_ratio, + activation="gelu", + use_conv_mlp=False, + use_grn=False, + name="convnext_block", +): + input_channels = inputs.shape[-1] + hidden_channels = int(mlp_ratio * output_channels) + x = inputs + shortcut = inputs + + # Padding + padding = "same" + if strides > 1: + padding = "valid" + x = layers.ZeroPadding2D( + (kernel_size[0] // 2, kernel_size[1] // 2), name=f"{name}_pad" + )(x) + + # Depthwise + x = layers.DepthwiseConv2D( + kernel_size, + strides, + padding=padding, + use_bias=True, + name=f"{name}_conv_dw_dwconv2d", + )(x) + x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm")(x) + + # MLP + x = apply_mlp_block( + x, + hidden_channels, + output_channels, + activation, + use_bias=True, + use_conv_mlp=use_conv_mlp, + name=f"{name}_mlp", + ) + + # LayerScale + x = kimm_layers.LayerScale( + output_channels, initializers.Constant(1e-6), name=f"{name}_layerscale" + )(x) + + # Downsample + if input_channels != output_channels or strides != 1: + shortcut = layers.AveragePooling2D( + 2, strides=strides, name=f"{name}_pool" + )(shortcut) + if input_channels != output_channels: + shortcut = layers.Conv2D( + output_channels, 1, 1, use_bias=False, name=f"{name}_conv" + )(shortcut) + + x = layers.Add()([x, shortcut]) + return x + + +def apply_convnext_stage( + inputs, + depth: int, + output_channels: int, + kernel_size: int, + strides: int, + activation="gelu", + use_conv_mlp=False, + use_grn=False, + name="convnext_stage", +): + input_channels = inputs.shape[-1] + x = inputs + + # Downsample + if input_channels != output_channels or strides > 1: + ds_ks = 2 if strides > 1 else 1 + x = layers.LayerNormalization( + epsilon=1e-6, name=f"{name}_downsample_0" + )(x) + x = layers.Conv2D( + output_channels, + ds_ks, + strides, + padding="valid", + use_bias=True, + name=f"{name}_downsample_1_conv2d", + )(x) + + for i in range(depth): + x = apply_convnext_block( + x, + output_channels, + kernel_size, + 1, + mlp_ratio=4, + activation=activation, + use_conv_mlp=use_conv_mlp, + use_grn=use_grn, + name=f"{name}_blocks_{i}", + ) + return x + + +@keras.saving.register_keras_serializable(package="kimm") +class ConvNeXt(BaseModel): + def __init__( + self, + depths: typing.Sequence[int] = [3, 3, 9, 3], + hidden_channels: typing.Sequence[int] = [96, 192, 384, 768], + patch_size: int = 4, + kernel_size: int = 7, + activation: str = "gelu", + use_conv_mlp: bool = False, + **kwargs, + ): + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + ) + x = inputs + + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # Stem + x = layers.Conv2D( + hidden_channels[0], + patch_size, + patch_size, + use_bias=True, + name="stem_0_conv2d", + )(x) + x = layers.LayerNormalization(epsilon=1e-6, name="stem_1")(x) + features["STEM_S4"] = x + + # Blocks (4 stages) + current_stride = patch_size + for i in range(4): + strides = 2 if i > 0 else 1 + x = apply_convnext_stage( + x, + depths[i], + hidden_channels[i], + kernel_size, + strides, + activation, + use_conv_mlp, + use_grn=False, + name=f"stages_{i}", + ) + current_stride *= strides + # Add feature + features[f"BLOCK{i}_S{current_stride}"] = x + + # Head + x = self.build_head(x) + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.depths = depths + self.hidden_channels = hidden_channels + self.patch_size = patch_size + self.kernel_size = kernel_size + self.activation = activation + self.use_conv_mlp = use_conv_mlp + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = inputs + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.LayerNormalization(epsilon=1e-6, name="head_norm")(x) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S4"] + feature_keys.extend([f"BLOCK{i}_S{2**(i+2)}" for i in range(4)]) + return feature_keys + + def get_config(self): + config = super().get_config() + config.update( + { + "depths": self.depths, + "hidden_channels": self.hidden_channels, + "patch_size": self.patch_size, + "kernel_size": self.kernel_size, + "activation": self.activation, + "use_conv_mlp": self.use_conv_mlp, + } + ) + return config + + def fix_config(self, config): + unused_kwargs = [ + "depths", + "hidden_channels", + "patch_size", + "kernel_size", + "activation", + "use_conv_mlp", + ] + for k in unused_kwargs: + config.pop(k, None) + return config + + +""" +Model Definition +""" + + +class ConvNeXtAtto(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtAtto", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (2, 2, 6, 2), + (40, 80, 160, 320), + 4, + 7, + "gelu", + True, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class ConvNeXtFemto(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtFemto", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (2, 2, 6, 2), + (48, 96, 192, 384), + 4, + 7, + "gelu", + True, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class ConvNeXtPico(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtPico", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (2, 2, 6, 2), + (64, 128, 256, 512), + 4, + 7, + "gelu", + True, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class ConvNeXtNano(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtNano", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (2, 2, 8, 2), + (80, 160, 320, 640), + 4, + 7, + "gelu", + True, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class ConvNeXtTiny(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtTiny", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (3, 3, 9, 3), + (96, 192, 384, 768), + 4, + 7, + "gelu", + False, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class ConvNeXtSmall(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtSmall", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (3, 3, 27, 3), + (96, 192, 384, 768), + 4, + 7, + "gelu", + False, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class ConvNeXtBase(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtBase", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (3, 3, 27, 3), + (128, 256, 512, 1024), + 4, + 7, + "gelu", + False, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class ConvNeXtLarge(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtLarge", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (3, 3, 27, 3), + (192, 384, 768, 1536), + 4, + 7, + "gelu", + False, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class ConvNeXtXLarge(ConvNeXt): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "ConvNeXtXLarge", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + (3, 3, 27, 3), + (256, 512, 1024, 2048), + 4, + 7, + "gelu", + False, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +add_model_to_registry(ConvNeXtAtto, "imagenet") +add_model_to_registry(ConvNeXtFemto, "imagenet") +add_model_to_registry(ConvNeXtPico, "imagenet") +add_model_to_registry(ConvNeXtNano, "imagenet") +add_model_to_registry(ConvNeXtTiny, "imagenet") +add_model_to_registry(ConvNeXtSmall, "imagenet") +add_model_to_registry(ConvNeXtBase, "imagenet") +add_model_to_registry(ConvNeXtLarge, "imagenet") +add_model_to_registry(ConvNeXtXLarge, "imagenet") diff --git a/kimm/models/densenet.py b/kimm/models/densenet.py index 42cfee2..abb51fc 100644 --- a/kimm/models/densenet.py +++ b/kimm/models/densenet.py @@ -2,7 +2,6 @@ import keras from keras import layers -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.models import BaseModel @@ -65,6 +64,7 @@ def apply_dense_transition_block( return x +@keras.saving.register_keras_serializable(package="kimm") class DenseNet(BaseModel): def __init__( self, @@ -72,16 +72,16 @@ def __init__( num_blocks: typing.Sequence[int] = [6, 12, 24, 16], **kwargs, ): - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -125,30 +125,11 @@ def __init__( x = layers.ReLU()(x) # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.growth_rate = growth_rate self.num_blocks = num_blocks diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index 8297a29..04b69e0 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -3,7 +3,6 @@ import keras from keras import layers -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block @@ -124,6 +123,7 @@ def apply_edge_residual_block( return x +@keras.saving.register_keras_serializable(package="kimm") class EfficientNet(BaseModel): def __init__( self, @@ -173,16 +173,16 @@ def __init__( # TinyNet config round_fn = kwargs.pop("round_fn", math.ceil) - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -256,30 +256,11 @@ def __init__( ) # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.width = width self.depth = depth self.stem_channels = stem_channels @@ -349,12 +330,15 @@ def __init__( dropout_rate: float = 0.0, classes: int = 1000, classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet + weights: typing.Optional[str] = "imagenet", # TODO: imagenet config: typing.Union[str, typing.List] = "v1", name: str = "EfficientNetB0", **kwargs, ): kwargs = self.fix_config(kwargs) + if weights == "imagenet": + file_name = "efficientnetb0_tf_efficientnet_b0.ns_jft_in1k.keras" + kwargs["weights_url"] = f"{self.default_origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 1.0, diff --git a/kimm/models/ghostnet.py b/kimm/models/ghostnet.py index 14816f4..343fea8 100644 --- a/kimm/models/ghostnet.py +++ b/kimm/models/ghostnet.py @@ -3,7 +3,6 @@ import keras from keras import layers from keras import ops -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_se_block @@ -229,6 +228,7 @@ def apply_ghost_bottleneck( return out +@keras.saving.register_keras_serializable(package="kimm") class GhostNet(BaseModel): def __init__( self, @@ -251,18 +251,18 @@ def __init__( f"Received version={version}" ) - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], - require_flatten=parsed_kwargs["include_top"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + require_flatten=self._include_top, static_shape=True if version == "v2" else False, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -309,30 +309,11 @@ def __init__( ) # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.width = width self.config = config self.version = version diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py index 06b5a7b..451a23b 100644 --- a/kimm/models/inception_v3.py +++ b/kimm/models/inception_v3.py @@ -3,7 +3,6 @@ import keras from keras import layers -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.models import BaseModel @@ -202,19 +201,20 @@ def apply_inception_aux_block(inputs, classes, name="inception_aux_block"): return x +@keras.saving.register_keras_serializable(package="kimm") class InceptionV3Base(BaseModel): def __init__(self, has_aux_logits=False, **kwargs): - parsed_kwargs = self.parse_kwargs(kwargs, default_size=299) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], - require_flatten=parsed_kwargs["include_top"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs, 299) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + require_flatten=self._include_top, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -246,7 +246,7 @@ def __init__(self, has_aux_logits=False, **kwargs): if has_aux_logits: aux_logits = apply_inception_aux_block( - x, parsed_kwargs["classes"], "AuxLogits" + x, self._classes, "AuxLogits" ) x = apply_inception_d_block(x, "Mixed_7a") @@ -255,32 +255,13 @@ def __init__(self, has_aux_logits=False, **kwargs): features["BLOCK3_S32"] = x # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) if has_aux_logits: x = [x, aux_logits] super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.has_aux_logits = has_aux_logits @staticmethod diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py index 965e511..d6fe3d9 100644 --- a/kimm/models/mobilenet_v2.py +++ b/kimm/models/mobilenet_v2.py @@ -2,8 +2,6 @@ import typing import keras -from keras import layers -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block @@ -24,6 +22,7 @@ ] +@keras.saving.register_keras_serializable(package="kimm") class MobileNetV2(BaseModel): def __init__( self, @@ -42,16 +41,16 @@ def __init__( f"Received: config={config}" ) - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -102,30 +101,11 @@ def __init__( ) # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.width = width self.depth = depth self.fix_stem_and_head_channels = fix_stem_and_head_channels diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py index 7a1ac50..5da1a5e 100644 --- a/kimm/models/mobilenet_v3.py +++ b/kimm/models/mobilenet_v3.py @@ -3,7 +3,6 @@ import keras from keras import layers -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block @@ -80,6 +79,7 @@ ] +@keras.saving.register_keras_serializable(package="kimm") class MobileNetV3(BaseModel): def __init__( self, @@ -117,16 +117,16 @@ def __init__( bn_epsilon = kwargs.pop("bn_epsilon", 1e-5) padding = kwargs.pop("padding", None) - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -215,7 +215,7 @@ def __init__( features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x # Head - if parsed_kwargs["include_top"]: + if self._include_top: if fix_stem_and_head_channels: conv_head_channels = conv_head_channels else: @@ -226,29 +226,21 @@ def __init__( head_activation = force_activation or "hard_swish" x = self.build_top( x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], + self._classes, + self._classifier_activation, + self._dropout_rate, conv_head_channels=conv_head_channels, head_activation=head_activation, ) else: - if parsed_kwargs["pooling"] == "avg": + if self._pooling == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": + elif self._pooling == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input - super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.width = width self.depth = depth self.fix_stem_and_head_channels = fix_stem_and_head_channels @@ -647,12 +639,16 @@ def __init__( dropout_rate: float = 0.0, classes: int = 1000, classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet + weights: typing.Optional[str] = "imagenet", # TODO: imagenet config: typing.Union[str, typing.List] = "lcnet", name: str = "LCNet050", **kwargs, ): kwargs = self.fix_config(kwargs) + if weights == "imagenet": + origin = "https://github.com/james77777778/keras-aug/releases/download/v0.5.0" + file_name = "lcnet050_lcnet_050.ra2_in1k.keras" + kwargs["weights_url"] = f"{origin}/{file_name}" # default to TF configuration (bn_epsilon=1e-3 and padding="same") super().__init__( 0.5, diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index 4da9299..d96622a 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -4,7 +4,6 @@ import keras from keras import layers from keras import ops -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_inverted_residual_block @@ -161,6 +160,7 @@ def apply_mobilevit_block( return x +@keras.saving.register_keras_serializable(package="kimm") class MobileViT(BaseModel): def __init__( self, @@ -183,17 +183,17 @@ def __init__( f"Received: config={config}" ) - parsed_kwargs = self.parse_kwargs(kwargs, 256) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs, 256) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, static_shape=True, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "0_1") # Prepare feature extraction features = {} @@ -248,30 +248,11 @@ def __init__( ) # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.stem_channels = stem_channels self.head_channels = head_channels self.activation = activation @@ -392,12 +373,16 @@ def __init__( dropout_rate: float = 0.1, classes: int = 1000, classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet + weights: typing.Optional[str] = "imagenet", # TODO: imagenet config: str = "v1_xxs", name="MobileViTXXS", **kwargs, ): kwargs = self.fix_config(kwargs) + if weights == "imagenet": + origin = "https://github.com/james77777778/keras-aug/releases/download/v0.5.0" + file_name = "mobilevitxxs_mobilevit_xxs.cvnets_in1k.keras" + kwargs["weights_url"] = f"{origin}/{file_name}" super().__init__( 16, 320, diff --git a/kimm/models/models_test.py b/kimm/models/models_test.py index d89f4d7..bc10777 100644 --- a/kimm/models/models_test.py +++ b/kimm/models/models_test.py @@ -19,6 +19,19 @@ *((f"BLOCK{i}", [1, 32, 32, 768]) for i in range(32)), ], ), + # convnext + ( + kimm_models.ConvNeXtAtto.__name__, + kimm_models.ConvNeXtAtto, + 288, + [ + ("STEM_S4", [1, 72, 72, 40]), + ("BLOCK0_S4", [1, 72, 72, 40]), + ("BLOCK1_S8", [1, 36, 36, 80]), + ("BLOCK2_S16", [1, 18, 18, 160]), + ("BLOCK3_S32", [1, 9, 9, 320]), + ], + ), # densenet ( kimm_models.DenseNet121.__name__, @@ -323,7 +336,7 @@ ] -class ConvMixerTest(testing.TestCase, parameterized.TestCase): +class ModelTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters(MODEL_CONFIGS) def test_model_base(self, model_class, image_size, features): # TODO: test the correctness of the real image diff --git a/kimm/models/regnet.py b/kimm/models/regnet.py index 5a0a33d..4e61ccc 100644 --- a/kimm/models/regnet.py +++ b/kimm/models/regnet.py @@ -3,7 +3,6 @@ import keras import numpy as np from keras import layers -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_se_block @@ -142,6 +141,7 @@ def apply_bottleneck_block( return x +@keras.saving.register_keras_serializable(package="kimm") class RegNet(BaseModel): def __init__( self, @@ -155,16 +155,16 @@ def __init__( ): per_stage_config = _generate_regnet(w0, wa, wm, group_size, depth) - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -190,30 +190,11 @@ def __init__( features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.w0 = w0 self.wa = wa self.wm = wm diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index 05d7a0d..65d35d7 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -2,7 +2,6 @@ import keras from keras import layers -from keras import utils from kimm.blocks import apply_conv2d_block from kimm.models.base_model import BaseModel @@ -104,6 +103,7 @@ def apply_bottleneck_block( return x +@keras.saving.register_keras_serializable(package="kimm") class ResNet(BaseModel): def __init__( self, block_fn: str, num_blocks: typing.Sequence[int], **kwargs @@ -114,16 +114,16 @@ def __init__( f"Received: block_fn={block_fn}" ) - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -161,30 +161,11 @@ def __init__( features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.block_fn = block_fn self.num_blocks = num_blocks diff --git a/kimm/models/vgg.py b/kimm/models/vgg.py index e2f1895..23b241f 100644 --- a/kimm/models/vgg.py +++ b/kimm/models/vgg.py @@ -2,7 +2,6 @@ import keras from keras import layers -from keras import utils from kimm.models import BaseModel from kimm.utils import add_model_to_registry @@ -107,6 +106,7 @@ def apply_conv_mlp_layer( return x +@keras.saving.register_keras_serializable(package="kimm") class VGG(BaseModel): def __init__(self, config: typing.Union[str, typing.List], **kwargs): _available_configs = ["vgg11", "vgg13", "vgg16", "vgg19"] @@ -124,16 +124,16 @@ def __init__(self, config: typing.Union[str, typing.List], **kwargs): f"Received: config={config}" ) - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "imagenet") + x = self.build_preprocessing(x, "imagenet") # Prepare feature extraction features = {} @@ -171,30 +171,11 @@ def __init__(self, config: typing.Union[str, typing.List], **kwargs): x = apply_conv_mlp_layer(x, 4096, 7, 1.0, 0.0, name="pre_logits") # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.config = config @staticmethod @@ -248,7 +229,6 @@ def __init__( classifier_activation=classifier_activation, weights=weights, name=name, - default_size=224, **kwargs, ) @@ -281,7 +261,6 @@ def __init__( classifier_activation=classifier_activation, weights=weights, name=name, - default_size=224, **kwargs, ) @@ -314,7 +293,6 @@ def __init__( classifier_activation=classifier_activation, weights=weights, name=name, - default_size=224, **kwargs, ) @@ -347,7 +325,6 @@ def __init__( classifier_activation=classifier_activation, weights=weights, name=name, - default_size=224, **kwargs, ) diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index 2f2b60c..a5d69fb 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -2,7 +2,6 @@ import keras from keras import layers -from keras import utils from kimm import layers as kimm_layers from kimm.blocks import apply_transformer_block @@ -10,6 +9,7 @@ from kimm.utils import add_model_to_registry +@keras.saving.register_keras_serializable(package="kimm") class VisionTransformer(BaseModel): def __init__( self, @@ -23,22 +23,22 @@ def __init__( pos_dropout_rate: float = 0.0, **kwargs, ): - parsed_kwargs = self.parse_kwargs(kwargs, 384) - if parsed_kwargs["pooling"] is not None: + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs, 384) + if self._pooling is not None: raise ValueError( "`VisionTransformer` doesn't support `pooling`. " - f"Received: pooling={parsed_kwargs['pooling']}" + f"Received: pooling={self._pooling}" ) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, static_shape=True, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "-1_1") + x = self.build_preprocessing(x, "-1_1") # Prepare feature extraction features = {} @@ -72,25 +72,17 @@ def __init__( x = layers.LayerNormalization(epsilon=1e-6, name="norm")(x) # Head - if parsed_kwargs["include_top"]: + if self._include_top: x = self.build_top( x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], + self._classes, + self._classifier_activation, + self._dropout_rate, ) - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input - super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) self.patch_size = patch_size self.embed_dim = embed_dim self.depth = depth diff --git a/kimm/models/xception.py b/kimm/models/xception.py index 03e9cff..b67a549 100644 --- a/kimm/models/xception.py +++ b/kimm/models/xception.py @@ -2,7 +2,6 @@ import keras from keras import layers -from keras import utils from kimm.models import BaseModel from kimm.utils import add_model_to_registry @@ -65,19 +64,20 @@ def apply_xception_block( return x +@keras.saving.register_keras_serializable(package="kimm") class XceptionBase(BaseModel): def __init__(self, **kwargs): - parsed_kwargs = self.parse_kwargs(kwargs) - img_input = self.determine_input_tensor( - parsed_kwargs["input_tensor"], - parsed_kwargs["input_shape"], - parsed_kwargs["default_size"], + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, min_size=71, ) - x = img_input + x = inputs - if parsed_kwargs["include_preprocessing"]: - x = self.build_preprocessing(x, "-1_1") + x = self.build_preprocessing(x, "-1_1") # Prepare feature extraction features = {} @@ -127,30 +127,11 @@ def __init__(self, **kwargs): features["BLOCK3_S32"] = x # Head - if parsed_kwargs["include_top"]: - x = self.build_top( - x, - parsed_kwargs["classes"], - parsed_kwargs["classifier_activation"], - parsed_kwargs["dropout_rate"], - ) - else: - if parsed_kwargs["pooling"] == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif parsed_kwargs["pooling"] == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if parsed_kwargs["input_tensor"] is not None: - inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) - else: - inputs = img_input + x = self.build_head(x) super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line - self.add_references(parsed_kwargs) @staticmethod def available_feature_keys(): diff --git a/kimm/utils/timm_utils.py b/kimm/utils/timm_utils.py index d1e7a17..b26db04 100644 --- a/kimm/utils/timm_utils.py +++ b/kimm/utils/timm_utils.py @@ -64,9 +64,11 @@ def assign_weights( # conventional conv2d layer keras_weight.assign(np.transpose(torch_weight, [2, 3, 1, 0])) else: - print(keras_weight.shape) - print(torch_weight.shape) - raise ValueError(f"Failed to assign {keras_name}") + raise ValueError( + f"Failed to assign {keras_name}. " + f"keras weight shape={keras_weight.shape}, " + f"torch weight shape={torch_weight.shape}" + ) elif len(keras_weight.shape) == 2: # dense layer keras_weight.assign(np.transpose(torch_weight)) diff --git a/requirements.txt b/requirements.txt index 1e788cf..f54fabf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,15 @@ # Working GPU setup -# CUDA 11.8, CUDNN 8.8 - -# tensorflow>=2.14.1 -# --index-url https://download.pytorch.org/whl/cu118 -# torch==2.1.0 torchvision==0.16.0 +# CUDA 12.2, CUDNN 8.9 +# +# tensorflow==2.15.0.post1 +# +# --index-url https://download.pytorch.org/whl/cu121 +# torch torchvision +# # -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -# "jax[cuda11_local]" +# "jax[cuda12_local]" -# Following is for github runner and borrowed from -# https://github.com/keras-team/keras/blob/master/requirements.txt +# Following is for github runner tf-nightly-cpu==2.16.0.dev20240108 --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/shell/export.sh b/shell/export.sh index ef849b2..e788389 100755 --- a/shell/export.sh +++ b/shell/export.sh @@ -5,6 +5,7 @@ export CUDA_VISIBLE_DEVICES= export TF_CPP_MIN_LOG_LEVEL=3 export KERAS_BACKEND=tensorflow python3 -m tools.convert_convmixer_from_timm +python3 -m tools.convert_convnext_from_timm python3 -m tools.convert_densenet_from_timm python3 -m tools.convert_efficientnet_from_timm python3 -m tools.convert_ghostnet_from_timm diff --git a/tools/convert_convnext_from_timm.py b/tools/convert_convnext_from_timm.py new file mode 100644 index 0000000..c819155 --- /dev/null +++ b/tools/convert_convnext_from_timm.py @@ -0,0 +1,148 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import convnext +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "convnext_atto.d2_in1k", + "convnext_femto.d1_in1k", + "convnext_pico.d1_in1k", + "convnext_nano.in12k_ft_in1k", + "convnext_tiny.in12k_ft_in1k", + "convnext_small.in12k_ft_in1k", + "convnext_base.fb_in22k_ft_in1k", + "convnext_large.fb_in22k_ft_in1k", + "convnext_xlarge.fb_in22k_ft_in1k", +] +keras_model_classes = [ + convnext.ConvNeXtAtto, + convnext.ConvNeXtFemto, + convnext.ConvNeXtPico, + convnext.ConvNeXtNano, + convnext.ConvNeXtTiny, + convnext.ConvNeXtSmall, + convnext.ConvNeXtBase, + convnext.ConvNeXtLarge, + convnext.ConvNeXtXLarge, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + # prevent gamma to be replaced + is_layerscale = False + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + + # stem + torch_name = torch_name.replace("stem.0.conv2d.kernel", "stem.0.weight") + torch_name = torch_name.replace("stem.0.conv2d.bias", "stem.0.bias") + + # blocks + torch_name = torch_name.replace("dwconv2d.", "") + torch_name = torch_name.replace("conv2d.", "") + torch_name = torch_name.replace("conv.dw", "conv_dw") + if "layerscale" in torch_name: + is_layerscale = True + torch_name = torch_name.replace("layerscale.", "") + # head + torch_name = torch_name.replace("classifier", "head.fc") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + if not is_layerscale: + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_layerscale: + assign_weights(keras_name, keras_weight, torch_weights) + elif is_same_weights( + keras_name, keras_weight, torch_name, torch_weights + ): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-4) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}")