Skip to content

Commit

Permalink
ESPCN Model
Browse files Browse the repository at this point in the history
This is an implementation of the ESPCN architecture
  • Loading branch information
GMW99 committed Jul 12, 2024
1 parent 78cf10d commit 8bf6d14
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/kompressor/models/cnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .srcnn import SRCNN
from .vdsr import VDSR
from .esdr import EDSR
from .espcn import ESPCN
50 changes: 50 additions & 0 deletions src/kompressor/models/cnn/espcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import flax.linen as nn


class ESPCN(nn.Module):
"""
Efficient Sub-Pixel Convolutional Neural Network https://arxiv.org/pdf/1609.05158v2.pdf
"""

name: str = "ESPCN"
features: int = 64
input_conv_kernel_size: int = 5
encoding_conv_kernel_size: int = 3
output_kernel_size: int = 2
patches: int = 5
channels: int = 1

@nn.compact
def __call__(self, low_resolution):
features = nn.activation.relu(
nn.Conv(
features=self.features,
kernel_size=(self.input_conv_kernel_size, self.input_conv_kernel_size),
)(low_resolution)
)
features = nn.activation.relu(
nn.Conv(
features=self.features,
kernel_size=(
self.encoding_conv_kernel_size,
self.encoding_conv_kernel_size,
),
)(features)
)
features = nn.activation.relu(
nn.Conv(
features=self.features // 2,
kernel_size=(
self.encoding_conv_kernel_size,
self.encoding_conv_kernel_size,
),
padding="VALID",
)(features)
)
features = nn.Conv(
features=self.patches * self.channels,
kernel_size=(self.output_kernel_size, self.output_kernel_size),
padding="VALID",
)(features)
batch, height, width = features.shape[0:3]
return features.reshape(batch, height, width, self.patches, self.channels)

0 comments on commit 8bf6d14

Please sign in to comment.