Skip to content

Commit

Permalink
Added trained notebookw
Browse files Browse the repository at this point in the history
  • Loading branch information
shubham-sri committed Feb 17, 2023
1 parent 7156be3 commit 746d529
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 93 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1 +1,30 @@
# MRIGan


## Dataset

Dataset is placed in the `data` folder. `T1` and `T2` dataset placed under respective folders. All path are defined under utils module in `paths.py` file.

## Model

Model related code is placed under `model` folder. `mri_gan.py` contains the intialization, training and ploting code. `generator.py` and `discriminator.py` contains the generator and discriminator model respectively. `conv_block.py` contains the convolution block used in both generator and discriminator.

## Architecture
Downsample block used skip connection from input to output. Upsample block used skip connection from input to output and from output to input. The skip connection is used to preserve the spatial information. The architecture is shown below.

Generator Architecture is U-Net like architecture. And discriminator is PatchGAN architecture.

## Training

Training is done using `mri_gan.py` file. The training is done in two steps. First step is to train the generator and discriminator separately. Second step is to train the generator and discriminator together. The training is done for 260 epochs. The training is done using Adam optimizer with learning rate 0.0002 and beta1 0.5. The training is done on A6000 GPU.

### Results
GIF of the results is shown below.
![GIF](https://storage.googleapis.com/deeplearning-archive/MRIGan/cyclegan.gif)

Final epoch results are shown below.

![Final Epoch](./output/final_epoch.png)

Model archive is below.
[Model Archive](https://storage.googleapis.com/deeplearning-archive/MRIGan/models.zip)
Binary file added output/final_epoch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 60 additions & 0 deletions src/model/conv_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,47 @@ def __init__(
use_bias=False,
)

self.conv2 = tf.keras.layers.Conv2D(
filters=filters,
kernel_size=kernel_size,
strides=1,
padding='same',
kernel_initializer=initializer,
use_bias=False,
)

self.conv3 = tf.keras.layers.Conv2D(
filters=filters,
kernel_size=1,
strides=1,
padding='same',
kernel_initializer=initializer,
use_bias=False,
)

self.instance_norm = InstanceNormalization()

self.leaky_relu = tf.keras.layers.LeakyReLU()

self.concat = tf.keras.layers.Concatenate()


def call(self, inputs, training=None):
x = self.conv1(inputs, training=training)
if self.apply_norm:
x = self.instance_norm(x, training=training)
x1 = self.leaky_relu(x, training=training)

x = self.conv2(x1, training=training)
x = self.leaky_relu(x, training=training)

x = self.conv3(x, training=training)
if self.apply_norm:
x = self.instance_norm(x, training=training)
x = self.leaky_relu(x, training=training)

x = self.concat([x, x1])

return x


Expand Down Expand Up @@ -70,12 +101,41 @@ def __init__(
self.activation = tf.keras.layers.Activation('tanh')
else:
self.activation = tf.keras.layers.ReLU()

self.conv2 = tf.keras.layers.Conv2D(
filters=filters,
kernel_size=3,
strides=1,
padding='same',
kernel_initializer=initializer,
use_bias=False,
)

self.conv3 = tf.keras.layers.Conv2D(
filters=filters,
kernel_size=1,
strides=1,
padding='same',
kernel_initializer=initializer,
use_bias=False,
)


def call(self, inputs, training=None):
x = self.conv1(inputs, training=training)
if self.apply_dropout:
x = self.dropout(x, training=training)
x1 = self.activation(x, training=training)

x = self.conv2(x1, training=training)
x = self.activation(x, training=training)

x = self.conv3(x, training=training)
x = self.activation(x, training=training)

x = tf.keras.layers.Concatenate()([x, x1])

if self.apply_dropout:
x = self.dropout(x, training=training)
return x

18 changes: 16 additions & 2 deletions src/model/mri_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,29 @@ def plot_images(self, clear_output=True, save_path="./images"):

# plot example t1 only in first row and predict t2 in second row, no need to plot t2
prediction_t1_to_t2 = self.generator_t1_to_t2(self.example_t1, training=False)
prediction_t2_to_t1 = self.generator_t2_to_t1(self.example_t2, training=False)

plt.figure(figsize=(10,10))
for i in range(4):
plt.subplot(4, 2, 2*i+1)
# plt t1 to t2 and t2 to t1
plt.subplot(4, 4, i*4+1)
plt.imshow(self.example_t1[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.title("T1 Original")
plt.axis('off')

plt.subplot(4, 2, 2*i+2)
plt.subplot(4, 4, i*4+2)
plt.imshow(prediction_t1_to_t2[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.title("T1 to T2")
plt.axis('off')

plt.subplot(4, 4, i*4+3)
plt.imshow(self.example_t2[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.title("T2 Original")
plt.axis('off')

plt.subplot(4, 4, i*4+4)
plt.imshow(prediction_t2_to_t1[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.title("T2 to T1")
plt.axis('off')

plt.savefig(
Expand Down
277 changes: 186 additions & 91 deletions train.ipynb

Large diffs are not rendered by default.

0 comments on commit 746d529

Please sign in to comment.