Skip to content

Commit

Permalink
Added a configurable SRCNN Model.
Browse files Browse the repository at this point in the history
  • Loading branch information
GMW99 committed Jul 9, 2024
1 parent 6aab9ca commit aa88927
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/kompressor/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import srcnn
40 changes: 40 additions & 0 deletions src/kompressor/models/srcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import flax.linen as nn


class SRCNN(nn.Module):
"""
A simple convolutional neural network for Super Resolution.
Model based on Super-Resolution Convolutional Neural Network https://arxiv.org/pdf/1501.00092v3.pdf
"""

encoding_features: int = 300
encoding_kernel_size: int = 3
padding: int = 1
padding_features: int = 100
output_kernel_size: int = 2
patches: int = 5
channels: int = 1

@nn.compact
def __call__(self, low_resolution):
features = nn.activation.relu(
nn.Conv(
features=300,
kernel_size=(self.encoding_kernel_size, self.encoding_kernel_size),
padding="VALID",
)(low_resolution)
)
for _ in range(self.padding):
features = nn.activation.relu(
nn.Conv(
features=self.padding_features, kernel_size=(1, 1), padding="VALID"
)(features)
)
features = nn.Conv(
features=self.patches * self.channels,
kernel_size=(self.output_kernel_size, self.output_kernel_size),
padding="VALID",
)(features)
batch, height, width = features.shape[0:3]
return features.reshape(batch, height, width, self.patches, self.channels)

0 comments on commit aa88927

Please sign in to comment.