Add MobileViT #8

Merged · 2 commits · Jan 12, 2024
3 changes: 3 additions & 0 deletions kimm/blocks/__init__.py
@@ -1,3 +1,6 @@
from kimm.blocks.base_block import apply_activation
from kimm.blocks.base_block import apply_conv2d_block
from kimm.blocks.base_block import apply_se_block
from kimm.blocks.inverted_residual_block import apply_inverted_residual_block
from kimm.blocks.transformer_block import apply_mlp_block
from kimm.blocks.transformer_block import apply_transformer_block
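With these re-exports, downstream model files can import the shared blocks from the package namespace. A minimal sketch (not part of the diff):

from kimm.blocks import apply_inverted_residual_block
from kimm.blocks import apply_transformer_block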
77 changes: 77 additions & 0 deletions kimm/blocks/inverted_residual_block.py
@@ -0,0 +1,77 @@
from keras import layers

from kimm.blocks.base_block import apply_conv2d_block
from kimm.blocks.base_block import apply_se_block
from kimm.utils import make_divisible


def apply_inverted_residual_block(
    inputs,
    output_channels,
    depthwise_kernel_size=3,
    expansion_kernel_size=1,
    pointwise_kernel_size=1,
    strides=1,
    expansion_ratio=1.0,
    se_ratio=0.0,
    activation="swish",
    se_input_channels=None,
    se_activation=None,
    se_gate_activation="sigmoid",
    se_make_divisible_number=None,
    bn_epsilon=1e-5,
    padding=None,
    name="inverted_residual_block",
):
    input_channels = inputs.shape[-1]
    hidden_channels = make_divisible(input_channels * expansion_ratio)
    has_skip = strides == 1 and input_channels == output_channels

    x = inputs
    # Point-wise expansion
    x = apply_conv2d_block(
        x,
        hidden_channels,
        expansion_kernel_size,
        1,
        activation=activation,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_pw",
    )
    # Depth-wise convolution
    x = apply_conv2d_block(
        x,
        kernel_size=depthwise_kernel_size,
        strides=strides,
        activation=activation,
        use_depthwise=True,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_dw",
    )
    # Squeeze-and-excitation
    if se_ratio > 0:
        x = apply_se_block(
            x,
            se_ratio,
            activation=se_activation or activation,
            gate_activation=se_gate_activation,
            se_input_channels=se_input_channels,
            make_divisible_number=se_make_divisible_number,
            name=f"{name}_se",
        )
    # Point-wise linear projection
    x = apply_conv2d_block(
        x,
        output_channels,
        pointwise_kernel_size,
        1,
        activation=None,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_pwl",
    )
    if has_skip:
        x = layers.Add()([x, inputs])
    return x
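A usage sketch of the extracted block (illustrative only, not part of this PR; the input shape, channel counts, and layer names are assumptions). It builds a single stride-2 inverted residual block with an SE branch on a dummy channels-last image tensor:

import keras

from kimm.blocks import apply_inverted_residual_block

# Dummy 224x224 RGB input; channels-last layout is assumed.
inputs = keras.Input(shape=(224, 224, 3))
x = apply_inverted_residual_block(
    inputs,
    output_channels=16,
    depthwise_kernel_size=3,
    strides=2,
    expansion_ratio=4.0,
    se_ratio=0.25,  # enables the optional squeeze-and-excitation branch
    activation="relu6",  # the block itself defaults to "swish"
    name="example_ir",
)
model = keras.Model(inputs, x)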
67 changes: 67 additions & 0 deletions kimm/blocks/transformer_block.py
@@ -0,0 +1,67 @@
from keras import layers

from kimm import layers as kimm_layers


def apply_mlp_block(
    inputs,
    hidden_dim,
    output_dim=None,
    activation="gelu",
    normalization=None,
    use_bias=True,
    dropout_rate=0.0,
    name="mlp_block",
):
    input_dim = inputs.shape[-1]
    output_dim = output_dim or input_dim

    x = inputs
    x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
    x = layers.Activation(activation, name=f"{name}_act")(x)
    x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
    if normalization is not None:
        x = normalization(name=f"{name}_norm")(x)
    x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
    x = layers.Dropout(dropout_rate, name=f"{name}_drop2")(x)
    return x


def apply_transformer_block(
    inputs,
    dim,
    num_heads,
    mlp_ratio=4.0,
    use_qkv_bias=False,
    use_qk_norm=False,
    projection_dropout_rate=0.0,
    attention_dropout_rate=0.0,
    activation="gelu",
    name="transformer_block",
):
    x = inputs
    residual_1 = x

    x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm1")(x)
    x = kimm_layers.Attention(
        dim,
        num_heads,
        use_qkv_bias,
        use_qk_norm,
        attention_dropout_rate,
        projection_dropout_rate,
        name=f"{name}_attn",
    )(x)
    x = layers.Add()([residual_1, x])

    residual_2 = x
    x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm2")(x)
    x = apply_mlp_block(
        x,
        int(dim * mlp_ratio),
        activation=activation,
        dropout_rate=projection_dropout_rate,
        name=f"{name}_mlp",
    )
    x = layers.Add()([residual_2, x])
    return x
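A similar sketch for the transformer helpers (illustrative only; the token count and embedding width are assumptions, and kimm_layers.Attention is the existing layer used inside the block above). It applies one pre-norm transformer block to a sequence of patch embeddings:

import keras

from kimm.blocks import apply_transformer_block

# Dummy sequence of 196 tokens with an embedding width of 192.
tokens = keras.Input(shape=(196, 192))
x = apply_transformer_block(
    tokens,
    dim=192,
    num_heads=3,
    mlp_ratio=2.0,
    use_qkv_bias=True,
    name="example_transformer",
)
model = keras.Model(tokens, x)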
70 changes: 2 additions & 68 deletions kimm/models/efficientnet.py
@@ -8,6 +8,7 @@
from keras.src.applications import imagenet_utils

from kimm.blocks import apply_conv2d_block
from kimm.blocks import apply_inverted_residual_block
from kimm.blocks import apply_se_block
from kimm.models.feature_extractor import FeatureExtractor
from kimm.utils import make_divisible
@@ -130,73 +131,6 @@ def apply_depthwise_separation_block(
    return x


def apply_inverted_residual_block(
    inputs,
    output_channels,
    depthwise_kernel_size=3,
    expansion_kernel_size=1,
    pointwise_kernel_size=1,
    strides=1,
    expansion_ratio=1.0,
    se_ratio=0.0,
    activation="swish",
    bn_epsilon=1e-5,
    padding=None,
    name="inverted_residual_block",
):
    input_channels = inputs.shape[-1]
    hidden_channels = make_divisible(input_channels * expansion_ratio)
    has_skip = strides == 1 and input_channels == output_channels

    x = inputs
    # Point-wise expansion
    x = apply_conv2d_block(
        x,
        hidden_channels,
        expansion_kernel_size,
        1,
        activation=activation,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_pw",
    )
    # Depth-wise convolution
    x = apply_conv2d_block(
        x,
        kernel_size=depthwise_kernel_size,
        strides=strides,
        activation=activation,
        use_depthwise=True,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_dw",
    )
    # Squeeze-and-excitation
    if se_ratio > 0:
        x = apply_se_block(
            x,
            se_ratio,
            activation=activation,
            gate_activation="sigmoid",
            se_input_channels=input_channels,
            name=f"{name}_se",
        )
    # Point-wise linear projection
    x = apply_conv2d_block(
        x,
        output_channels,
        pointwise_kernel_size,
        1,
        activation=None,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_pwl",
    )
    if has_skip:
        x = layers.Add()([x, inputs])
    return x


def apply_edge_residual_block(
    inputs,
    output_channels,
@@ -271,7 +205,7 @@ def __init__(
        classes: int = 1000,
        classifier_activation: str = "softmax",
        weights: typing.Optional[str] = None,  # TODO: imagenet
        config: typing.Union[str, typing.List] = "default",
        config: typing.Union[str, typing.List] = "v1",
        **kwargs,
    ):
        _available_configs = [
52 changes: 2 additions & 50 deletions kimm/models/mobilenet_v2.py
@@ -8,6 +8,7 @@
from keras.src.applications import imagenet_utils

from kimm.blocks import apply_conv2d_block
from kimm.blocks import apply_inverted_residual_block
from kimm.models.feature_extractor import FeatureExtractor
from kimm.utils import make_divisible
from kimm.utils.model_registry import add_model_to_registry
@@ -58,55 +59,6 @@ def apply_depthwise_separation_block(
    return x


def apply_inverted_residual_block(
    inputs,
    output_channels,
    depthwise_kernel_size=3,
    expansion_kernel_size=1,
    pointwise_kernel_size=1,
    strides=1,
    expansion_ratio=1.0,
    activation="relu6",
    name="inverted_residual_block",
):
    input_channels = inputs.shape[-1]
    hidden_channels = make_divisible(input_channels * expansion_ratio)
    has_skip = strides == 1 and input_channels == output_channels

    x = inputs

    # Point-wise expansion
    x = apply_conv2d_block(
        x,
        hidden_channels,
        expansion_kernel_size,
        1,
        activation=activation,
        name=f"{name}_conv_pw",
    )
    # Depth-wise convolution
    x = apply_conv2d_block(
        x,
        kernel_size=depthwise_kernel_size,
        strides=strides,
        activation=activation,
        use_depthwise=True,
        name=f"{name}_conv_dw",
    )
    # Point-wise linear projection
    x = apply_conv2d_block(
        x,
        output_channels,
        pointwise_kernel_size,
        1,
        activation=None,
        name=f"{name}_conv_pwl",
    )
    if has_skip:
        x = layers.Add()([x, inputs])
    return x


class MobileNetV2(FeatureExtractor):
    def __init__(
        self,
@@ -189,7 +141,7 @@ def __init__(
                    )
                elif block_type == "ir":
                    x = apply_inverted_residual_block(
                        x, c, k, 1, 1, s, e, name=name
                        x, c, k, 1, 1, s, e, activation="relu6", name=name
                    )
                current_stride *= s
            features[f"BLOCK{current_block_idx}_S{current_stride}"] = x