From b3d56841574c50eb65a4fbfde8dbd39a408e947b Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 17 Jan 2024 17:48:09 +0800 Subject: [PATCH 1/3] Add `RegNet` --- kimm/models/__init__.py | 1 + kimm/models/regnet.py | 1157 +++++++++++++++++++++++++++++ kimm/models/regnet_test.py | 61 ++ tools/convert_regnet_from_timm.py | 171 +++++ 4 files changed, 1390 insertions(+) create mode 100644 kimm/models/regnet.py create mode 100644 kimm/models/regnet_test.py create mode 100644 tools/convert_regnet_from_timm.py diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py index fff901c..f2e256a 100644 --- a/kimm/models/__init__.py +++ b/kimm/models/__init__.py @@ -6,5 +6,6 @@ from kimm.models.mobilenet_v2 import * # noqa:F403 from kimm.models.mobilenet_v3 import * # noqa:F403 from kimm.models.mobilevit import * # noqa:F403 +from kimm.models.regnet import * # noqa:F403 from kimm.models.resnet import * # noqa:F403 from kimm.models.vision_transformer import * # noqa:F403 diff --git a/kimm/models/regnet.py b/kimm/models/regnet.py new file mode 100644 index 0000000..5a0a33d --- /dev/null +++ b/kimm/models/regnet.py @@ -0,0 +1,1157 @@ +import typing + +import keras +import numpy as np +from keras import layers +from keras import utils + +from kimm.blocks import apply_conv2d_block +from kimm.blocks import apply_se_block +from kimm.models.base_model import BaseModel +from kimm.utils import add_model_to_registry + + +def _adjust_widths_and_groups(widths, groups, expansion_ratio): + def _quantize_float(f, q): + return int(round(f / q) * q) + + bottleneck_widths = [int(w * b) for w, b in zip(widths, expansion_ratio)] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)] + bottleneck_widths = [ + _quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups) + ] + widths = [ + int(w_bot / b) for w_bot, b in zip(bottleneck_widths, expansion_ratio) + ] + return widths, groups + + +def _generate_regnet( + width_init, + width_slope, + width_mult, + group_size, + depth, + quant=8, + expansion_ratio=1.0, +): + widths_cont = np.arange(depth) * width_slope + width_init + width_exps = np.round(np.log(widths_cont / width_init) / np.log(width_mult)) + widths = ( + np.round( + np.divide(width_init * np.power(width_mult, width_exps), quant) + ) + * quant + ) + num_stages = len(np.unique(widths)) + groups = np.array([group_size for _ in range(num_stages)]) + + widths = np.array(widths).astype(int).tolist() + stage_gs = groups.astype(int).tolist() + + # Convert to per-stage format + stage_widths, stage_depths = np.unique(widths, return_counts=True) + stage_e = [expansion_ratio for _ in range(num_stages)] + stage_strides = [] + for _ in range(num_stages): + stride = 2 + stage_strides.append(stride) + + # Adjust the compatibility of ws and gws + stage_widths, stage_gs = _adjust_widths_and_groups( + stage_widths, stage_gs, stage_e + ) + per_stage_args = [ + params + for params in zip( + stage_widths, + stage_strides, + stage_depths, + stage_e, + stage_gs, + ) + ] + return per_stage_args + + +def apply_bottleneck_block( + inputs, + output_channels: int, + strides: int = 1, + expansion_ratio: float = 1.0, + group_size: int = 1, + se_ratio: float = 0.25, + activation="relu", + linear_out: bool = False, + name="bottleneck_block", +): + input_channels = inputs.shape[-1] + expansion_channels = int(round(output_channels * expansion_ratio)) + groups = expansion_channels // group_size + + shortcut = inputs + x = inputs + x = apply_conv2d_block( + x, + 
expansion_channels, + 1, + 1, + activation=activation, + name=f"{name}_conv1", + ) + x = apply_conv2d_block( + x, + expansion_channels, + 3, + strides, + groups=groups, + activation=activation, + name=f"{name}_conv2", + ) + if se_ratio > 0.0: + x = apply_se_block( + x, + se_ratio, + activation, + se_input_channels=input_channels, + name=f"{name}_se", + ) + x = apply_conv2d_block( + x, + output_channels, + 1, + 1, + activation=None, + name=f"{name}_conv3", + ) + + # downsampling + if strides != 1 or input_channels != output_channels: + shortcut = apply_conv2d_block( + shortcut, + output_channels, + 1, + strides, + activation=None, + name=f"{name}_downsample", + ) + + x = layers.Add(name=f"{name}_add")([x, shortcut]) + if not linear_out: + x = layers.Activation(activation=activation, name=f"{name}")(x) + return x + + +class RegNet(BaseModel): + def __init__( + self, + w0: int = 80, + wa: float = 42.64, + wm: float = 2.66, + group_size: int = 24, + depth: int = 21, + se_ratio: float = 0.0, + **kwargs, + ): + per_stage_config = _generate_regnet(w0, wa, wm, group_size, depth) + + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + ) + x = img_input + + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # stem + stem_channels = 32 + x = apply_conv2d_block( + x, stem_channels, 3, 2, activation="relu", name="stem" + ) + features["STEM_S2"] = x + + # stages + current_stride = 2 + for current_stage_idx, params in enumerate(per_stage_config): + c, s, d, e, g = params + current_stride *= s + # blocks + for current_block_idx in range(d): + s = s if current_block_idx == 0 else 1 + name = f"s{current_stage_idx + 1}_b{current_block_idx + 1}" + x = apply_bottleneck_block(x, c, s, e, g, se_ratio, name=name) + # add feature + features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x + + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) + else: + if parsed_kwargs["pooling"] == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif parsed_kwargs["pooling"] == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. 
+ if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) + else: + inputs = img_input + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.add_references(parsed_kwargs) + self.w0 = w0 + self.wa = wa + self.wm = wm + self.group_size = group_size + self.depth = depth + self.se_ratio = se_ratio + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])] + ) + return feature_keys + + def get_config(self): + config = super().get_config() + config.update( + { + "w0": self.w0, + "wa": self.wa, + "wm": self.wm, + "group_size": self.group_size, + "depth": self.depth, + "se_ratio": self.se_ratio, + } + ) + return config + + def fix_config(self, config): + unused_kwargs = ["w0", "wa", "wm", "group_size", "depth", "se_ratio"] + for k in unused_kwargs: + config.pop(k, None) + return config + + +""" +Model Definition +""" + + +class RegNetX002(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX002", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 24, + 36.44, + 2.49, + 8, + 13, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY002(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY002", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 24, + 36.44, + 2.49, + 8, + 13, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX004(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX004", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 24, + 24.48, + 2.54, + 16, + 22, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class 
RegNetY004(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY004", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 48, + 27.89, + 2.09, + 8, + 16, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX006(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX006", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 48, + 36.97, + 2.24, + 24, + 16, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY006(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY006", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 48, + 32.54, + 2.32, + 16, + 15, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX008(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX008", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 56, + 35.73, + 2.28, + 16, + 16, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY008(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: 
typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY008", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 56, + 38.84, + 2.4, + 16, + 14, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX016(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX016", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 80, + 34.01, + 2.25, + 24, + 18, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY016(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY016", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 48, + 20.71, + 2.65, + 24, + 27, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX032(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX032", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 88, + 26.31, + 2.25, + 48, + 25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY032(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY032", + **kwargs, + ): + kwargs = 
self.fix_config(kwargs) + super().__init__( + 80, + 42.63, + 2.66, + 24, + 21, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX040(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX040", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 96, + 38.65, + 2.43, + 40, + 23, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY040(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY040", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 96, + 31.41, + 2.24, + 64, + 22, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX064(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX064", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 184, + 60.83, + 2.07, + 56, + 17, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY064(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY064", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 112, + 33.22, + 2.27, + 72, + 25, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, 
+ dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX080(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX080", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 80, + 49.56, + 2.88, + 120, + 23, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY080(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY080", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 192, + 76.82, + 2.19, + 56, + 17, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX120(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX120", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 168, + 73.36, + 2.37, + 112, + 19, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY120(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY120", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 168, + 73.36, + 2.37, + 112, + 19, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX160(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor 
= None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX160", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 216, + 55.59, + 2.1, + 128, + 22, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY160(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY160", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 200, + 106.23, + 2.48, + 112, + 18, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetX320(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetX320", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 320, + 69.86, + 2.0, + 168, + 23, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RegNetY320(RegNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, + name: str = "RegNetY320", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 232, + 115.89, + 2.53, + 232, + 20, + 0.25, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +add_model_to_registry(RegNetX002, "imagenet") +add_model_to_registry(RegNetY002, "imagenet") +add_model_to_registry(RegNetX004, "imagenet") +add_model_to_registry(RegNetY004, "imagenet") +add_model_to_registry(RegNetX006, "imagenet") +add_model_to_registry(RegNetY006, "imagenet") 
+add_model_to_registry(RegNetX008, "imagenet") +add_model_to_registry(RegNetY008, "imagenet") +add_model_to_registry(RegNetX016, "imagenet") +add_model_to_registry(RegNetY016, "imagenet") +add_model_to_registry(RegNetX032, "imagenet") +add_model_to_registry(RegNetY032, "imagenet") +add_model_to_registry(RegNetX040, "imagenet") +add_model_to_registry(RegNetY040, "imagenet") +add_model_to_registry(RegNetX064, "imagenet") +add_model_to_registry(RegNetY064, "imagenet") +add_model_to_registry(RegNetX080, "imagenet") +add_model_to_registry(RegNetY080, "imagenet") +add_model_to_registry(RegNetX120, "imagenet") +add_model_to_registry(RegNetY120, "imagenet") +add_model_to_registry(RegNetX160, "imagenet") +add_model_to_registry(RegNetY160, "imagenet") +add_model_to_registry(RegNetX320, "imagenet") +add_model_to_registry(RegNetY320, "imagenet") diff --git a/kimm/models/regnet_test.py b/kimm/models/regnet_test.py new file mode 100644 index 0000000..2c75012 --- /dev/null +++ b/kimm/models/regnet_test.py @@ -0,0 +1,61 @@ +import pytest +from absl.testing import parameterized +from keras import models +from keras import random +from keras.src import testing + +from kimm.models.regnet import RegNetX002 +from kimm.models.regnet import RegNetY002 + + +class ResNetTest(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + [(RegNetX002.__name__, RegNetX002), (RegNetY002.__name__, RegNetY002)] + ) + def test_resnet_base(self, model_class): + # TODO: test the correctness of the real image + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class() + + y = model(x, training=False) + + self.assertEqual(y.shape, (1, 1000)) + + @parameterized.named_parameters( + [(RegNetX002.__name__, RegNetX002), (RegNetY002.__name__, RegNetY002)] + ) + def test_resnet_feature_extractor(self, model_class): + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class(feature_extractor=True) + + y = model(x, training=False) + + self.assertIsInstance(y, dict) + self.assertContainsSubset( + model_class.available_feature_keys(), + list(y.keys()), + ) + self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 32]) + self.assertEqual(list(y["BLOCK0_S4"].shape), [1, 56, 56, 24]) + self.assertEqual(list(y["BLOCK1_S8"].shape), [1, 28, 28, 56]) + self.assertEqual(list(y["BLOCK2_S16"].shape), [1, 14, 14, 152]) + self.assertEqual(list(y["BLOCK3_S32"].shape), [1, 7, 7, 368]) + + @pytest.mark.serialization + @parameterized.named_parameters( + [ + (RegNetX002.__name__, RegNetX002, 224), + (RegNetY002.__name__, RegNetY002, 224), + ] + ) + def test_resnet_serialization(self, model_class, image_size): + x = random.uniform([1, image_size, image_size, 3]) * 255.0 + temp_dir = self.get_temp_dir() + model1 = model_class() + y1 = model1(x, training=False) + model1.save(temp_dir + "/model.keras") + + model2 = models.load_model(temp_dir + "/model.keras") + y2 = model2(x, training=False) + + self.assertAllClose(y1, y2) diff --git a/tools/convert_regnet_from_timm.py b/tools/convert_regnet_from_timm.py new file mode 100644 index 0000000..33372bf --- /dev/null +++ b/tools/convert_regnet_from_timm.py @@ -0,0 +1,171 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import regnet +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils 
import separate_torch_state_dict + +timm_model_names = [ + "regnetx_002.pycls_in1k", + "regnety_002.pycls_in1k", + "regnetx_004.pycls_in1k", + "regnety_004.tv2_in1k", + "regnetx_006.pycls_in1k", + "regnety_006.pycls_in1k", + "regnetx_008.tv2_in1k", + "regnety_008.pycls_in1k", + "regnetx_016.tv2_in1k", + "regnety_016.tv2_in1k", + "regnetx_032.tv2_in1k", + "regnety_032.ra_in1k", + "regnetx_040.pycls_in1k", + "regnety_040.ra3_in1k", + "regnetx_064.pycls_in1k", + "regnety_064.ra3_in1k", + "regnetx_080.tv2_in1k", + "regnety_080.ra3_in1k", + "regnetx_120.pycls_in1k", + "regnety_120.sw_in12k_ft_in1k", + "regnetx_160.tv2_in1k", + "regnety_160.swag_ft_in1k", + "regnetx_320.tv2_in1k", + "regnety_320.swag_ft_in1k", +] +keras_model_classes = [ + regnet.RegNetX002, + regnet.RegNetY002, + regnet.RegNetX004, + regnet.RegNetY004, + regnet.RegNetX006, + regnet.RegNetY006, + regnet.RegNetX008, + regnet.RegNetY008, + regnet.RegNetX016, + regnet.RegNetY016, + regnet.RegNetX032, + regnet.RegNetY032, + regnet.RegNetX040, + regnet.RegNetY040, + regnet.RegNetX064, + regnet.RegNetY064, + regnet.RegNetX080, + regnet.RegNetY080, + regnet.RegNetX120, + regnet.RegNetY120, + regnet.RegNetX160, + regnet.RegNetY160, + regnet.RegNetX320, + regnet.RegNetY320, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + # print(timm_model_name, keras_model_class.__name__) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # stem + torch_name = torch_name.replace("stem_conv2d", "stem.conv") + # blocks + torch_name = torch_name.replace("conv2d", "conv") + # se + torch_name = torch_name.replace("se.conv.reduce", "se.fc1") + torch_name = torch_name.replace("se.conv.expand", "se.fc2") + # head + torch_name = torch_name.replace("classifier", "head.fc") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. 
" + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + try: + np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) + except AssertionError as e: + print(timm_model_name, keras_model_class.__name__) + raise e + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") From 2a2200e4133c36d530488c717d55ea189e743532 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 17 Jan 2024 18:11:33 +0800 Subject: [PATCH 2/3] Update `export.sh` --- shell/export.sh | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/shell/export.sh b/shell/export.sh index 27c209a..7597cdf 100755 --- a/shell/export.sh +++ b/shell/export.sh @@ -1,11 +1,18 @@ #!/bin/bash +set -Euxo pipefail + export CUDA_VISIBLE_DEVICES= export TF_CPP_MIN_LOG_LEVEL=3 -python3 -m tools.convert_densenet_from_timm && -python3 -m tools.convert_efficientnet_from_timm && -python3 -m tools.convert_ghostnet_from_timm && -python3 -m tools.convert_inception_v3_from_timm && -python3 -m tools.convert_mobilenet_v2_from_timm && -python3 -m tools.convert_mobilenet_v3_from_timm && -python3 -m tools.convert_mobilevit_from_timm && +export KERAS_BACKEND=tensorflow +python3 -m tools.convert_densenet_from_timm +python3 -m tools.convert_efficientnet_from_timm +python3 -m tools.convert_ghostnet_from_timm +python3 -m tools.convert_inception_v3_from_timm +python3 -m tools.convert_mobilenet_v2_from_timm +python3 -m tools.convert_mobilenet_v3_from_timm +python3 -m tools.convert_mobilevit_from_timm +python3 -m tools.convert_regnet_from_timm +python3 -m tools.convert_resnet_from_timm +python3 -m tools.convert_vit_from_timm + echo "Export finished successfully!" 
From d57dda37c9b3582c8aec628a22ab55ca9c784bae Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 17 Jan 2024 18:11:58 +0800 Subject: [PATCH 3/3] Nit --- kimm/models/regnet_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kimm/models/regnet_test.py b/kimm/models/regnet_test.py index 2c75012..906586b 100644 --- a/kimm/models/regnet_test.py +++ b/kimm/models/regnet_test.py @@ -8,11 +8,11 @@ from kimm.models.regnet import RegNetY002 -class ResNetTest(testing.TestCase, parameterized.TestCase): +class RegNetTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( [(RegNetX002.__name__, RegNetX002), (RegNetY002.__name__, RegNetY002)] ) - def test_resnet_base(self, model_class): + def test_regnet_base(self, model_class): # TODO: test the correctness of the real image x = random.uniform([1, 224, 224, 3]) * 255.0 model = model_class() @@ -24,7 +24,7 @@ def test_resnet_base(self, model_class): @parameterized.named_parameters( [(RegNetX002.__name__, RegNetX002), (RegNetY002.__name__, RegNetY002)] ) - def test_resnet_feature_extractor(self, model_class): + def test_regnet_feature_extractor(self, model_class): x = random.uniform([1, 224, 224, 3]) * 255.0 model = model_class(feature_extractor=True) @@ -48,7 +48,7 @@ def test_resnet_feature_extractor(self, model_class): (RegNetY002.__name__, RegNetY002, 224), ] ) - def test_resnet_serialization(self, model_class, image_size): + def test_regnet_serialization(self, model_class, image_size): x = random.uniform([1, image_size, image_size, 3]) * 255.0 temp_dir = self.get_temp_dir() model1 = model_class()
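A usage sketch mirroring `regnet_test.py` above: the new classes work as ImageNet classifiers out of the box, and `feature_extractor=True` (the kwarg exercised in the tests) returns a dict of intermediate feature maps keyed by `STEM_S2` and `BLOCK{i}_S{j}`. Weights are random here; loading pretrained weights is assumed to go through the converted `.keras` files produced by `tools/convert_regnet_from_timm.py`.

    from keras import random

    from kimm.models.regnet import RegNetY002

    # Classifier: (1, 1000) outputs from a 224x224 input.
    model = RegNetY002()
    x = random.uniform([1, 224, 224, 3]) * 255.0
    print(model(x, training=False).shape)  # (1, 1000)

    # Feature extractor: multi-scale feature maps for downstream tasks.
    backbone = RegNetY002(feature_extractor=True)
    features = backbone(x, training=False)
    print(features["BLOCK3_S32"].shape)  # (1, 7, 7, 368), per the test above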