Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use sigmoid postprocessing in unet2d example #676

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ outputs:
sample_tensor:
source: test_output.png
sha256: 7bce8b53bcd0a12487a61f953aafe0f3700652848980d1083964c5bcb9555eec
postprocessing:
- id: sigmoid
- id: ensure_dtype
kwargs:
dtype: float32

weights:
pytorch_state_dict:
Expand All @@ -105,7 +110,7 @@ weights:
architecture:
callable: UNet2d
source: unet2d.py
sha256: 7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1
sha256: 589f0c9e60fa00f015213cd251541bcbf9582644f3ecffb2e6f3a30d2af1931a
kwargs: { input_channels: 1, output_channels: 1 }
dependencies:
source: environment.yaml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ outputs:
reference_tensor: raw
scale: [1.0, 0.0, null, 1.0, 1.0]
offset: [0.0, 0.5, 0.5, 0.0, 0.0]
postprocessing:
- name: sigmoid

test_inputs: [test_input.npy]
test_outputs: [test_output_expanded.npy]
Expand All @@ -67,7 +69,7 @@ weights:
sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
source: https://zenodo.org/records/3446812/files/unet2d_weights.torch
architecture: unet2d_expand_output_shape.py:UNet2d
architecture_sha256: 80a886acc734f848a8e018d8063cfd7e003d7e20076583b28326bfdd6136be32
architecture_sha256: 1441e8cfaf387a98a1c0bb937d59a2e9d6c311a8912cd88b39c11ecff503ccfe
kwargs: { input_channels: 1, output_channels: 1 }
dependencies: conda:environment.yaml

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ outputs:
data_range: [-.inf, .inf]
halo: [0, 0, 32, 32]
shape: [1, 1, 512, 512]
postprocessing:
- name: sigmoid

dependencies: conda:environment.yaml

Expand All @@ -63,7 +65,7 @@ weights:
sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
source: https://zenodo.org/records/3446812/files/unet2d_weights.torch
architecture: unet2d.py:UNet2d
architecture_sha256: 7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1
architecture_sha256: 589f0c9e60fa00f015213cd251541bcbf9582644f3ecffb2e6f3a30d2af1931a
kwargs: { input_channels: 1, output_channels: 1 }
onnx:
sha256: f1f086d5e340f9d4d7001a1b62a2b835f9b87a2fb5452c4fe7d8cc821bdf539c
Expand Down
8 changes: 1 addition & 7 deletions example_descriptions/models/unet2d_nuclei_broad/unet2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def forward(self, input):


class UNet2d(nn.Module):
def __init__(self, input_channels, output_channels, training=False):
def __init__(self, input_channels, output_channels):
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
Expand All @@ -41,7 +41,6 @@ def __init__(self, input_channels, output_channels, training=False):
)

self.output = nn.Conv2d(16, self.output_channels, 1)
self.training = training

def conv_layer(self, in_channels, out_channels):
kernel_size = 3
Expand Down Expand Up @@ -78,9 +77,4 @@ def forward(self, input):

x = self.output(x)

# apply a sigmoid directly if we are in inference mode
if not self.training:
# postprocessing
x = torch.sigmoid(x)

return x
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def forward(self, input):


class UNet2d(nn.Module):
def __init__(self, input_channels, output_channels, training=False):
def __init__(self, input_channels, output_channels):
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
Expand All @@ -41,7 +41,6 @@ def __init__(self, input_channels, output_channels, training=False):
)

self.output = nn.Conv2d(16, self.output_channels, 1)
self.training = training

def conv_layer(self, in_channels, out_channels):
kernel_size = 3
Expand Down Expand Up @@ -78,11 +77,6 @@ def forward(self, input):

x = self.output(x)

# apply a sigmoid directly if we are in inference mode
if not self.training:
# postprocessing
x = torch.sigmoid(x)

# expand the shape across z
out_shape = tuple(x.shape)
expanded_shape = out_shape[:2] + (1,) + out_shape[2:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ outputs:
reference_tensor: raw
scale: [1.0, 0.0, 1.0, 1.0]
offset: [0.0, 0.5, 0.0, 0.0]
postprocessing:
- name: sigmoid

dependencies: conda:environment.yaml

Expand All @@ -72,7 +74,7 @@ weights:
source: https://zenodo.org/records/3446812/files/unet2d_weights.torch
sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
architecture: unet2d.py:UNet2d
architecture_sha256: 7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1
architecture_sha256: 589f0c9e60fa00f015213cd251541bcbf9582644f3ecffb2e6f3a30d2af1931a
kwargs: { input_channels: 1, output_channels: 1 }
onnx:
source: weights.onnx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ outputs:
reference_tensor: raw
scale: [1.0, 0.0, 1.0, 1.0]
offset: [0.0, 0.5, 0.0, 0.0]
postprocessing:
- name: sigmoid

test_inputs: [test_input.npy]
test_outputs: [test_output.npy]
Expand All @@ -69,7 +71,7 @@ weights:
source: https://zenodo.org/records/3446812/files/unet2d_weights.torch
sha256: e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2
architecture: unet2d.py:UNet2d
architecture_sha256: 7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1
architecture_sha256: 589f0c9e60fa00f015213cd251541bcbf9582644f3ecffb2e6f3a30d2af1931a
kwargs: { input_channels: 1, output_channels: 1 }
dependencies: conda:environment.yaml
onnx:
Expand Down
Loading