diff --git a/src/kompressor/models/__init__.py b/src/kompressor/models/__init__.py new file mode 100644 index 0000000..90a52ad --- /dev/null +++ b/src/kompressor/models/__init__.py @@ -0,0 +1 @@ +from . import srcnn \ No newline at end of file diff --git a/src/kompressor/models/srcnn.py b/src/kompressor/models/srcnn.py new file mode 100644 index 0000000..27fd664 --- /dev/null +++ b/src/kompressor/models/srcnn.py @@ -0,0 +1,40 @@ +import flax.linen as nn + + +class SRCNN(nn.Module): + """ + A simple convolutional neural network for Super Resolution. + + Model based on Super-Resolution Convolutional Neural Network https://arxiv.org/pdf/1501.00092v3.pdf + """ + + encoding_features: int = 300 + encoding_kernel_size: int = 3 + padding: int = 1 + padding_features: int = 100 + output_kernel_size: int = 2 + patches: int = 5 + channels: int = 1 + + @nn.compact + def __call__(self, low_resolution): + features = nn.activation.relu( + nn.Conv( + features=300, + kernel_size=(self.encoding_kernel_size, self.encoding_kernel_size), + padding="VALID", + )(low_resolution) + ) + for _ in range(self.padding): + features = nn.activation.relu( + nn.Conv( + features=self.padding_features, kernel_size=(1, 1), padding="VALID" + )(features) + ) + features = nn.Conv( + features=self.patches * self.channels, + kernel_size=(self.output_kernel_size, self.output_kernel_size), + padding="VALID", + )(features) + batch, height, width = features.shape[0:3] + return features.reshape(batch, height, width, self.patches, self.channels)