
Commit

tweak
alexwitt23 committed Jun 22, 2020
1 parent 7130953 commit a94eb1a
Showing 1 changed file with 53 additions and 46 deletions.
99 changes: 53 additions & 46 deletions src/bifpn.py
@@ -5,8 +5,6 @@
import torch
import numpy as np


@dataclasses.dataclass
class node_param:
@@ -27,6 +25,17 @@ class node_param:
]


class Swish(torch.nn.Module):
""" Swish activation function presented here:
https://arxiv.org/pdf/1710.05941.pdf. """

def __init__(self) -> None:
super().__init__()

def __call__(self, x: torch.Tensor) -> torch.Tensor:
# Use the out-of-place sigmoid: torch.sigmoid_ would mutate x before the
# multiply and compute sigmoid(x) ** 2 instead of x * sigmoid(x).
return x * torch.sigmoid(x)

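For reference, Swish is f(x) = x * sigmoid(x), so it passes through zero and approaches the identity for large positive inputs. A quick sanity check of the module above (a minimal sketch, assuming any recent PyTorch):

activation = Swish()
x = torch.linspace(-3.0, 3.0, steps=5)  # [-3.0, -1.5, 0.0, 1.5, 3.0]
y = activation(x)
assert float(y[2]) == 0.0  # swish(0) == 0
assert torch.allclose(y, x * torch.sigmoid(x))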

def depthwise(in_channels: int, out_channels: int):
""" A depthwise separable linear layer. """
return [
@@ -59,17 +68,20 @@ def __init__(
out_channels: int,
num_bifpns: int,
bifpn_height: int = 5,
use_dw: bool = False,
levels: List[int] = [3, 4, 5],
) -> None:
"""
Args:
in_channels: A list of the incoming number of filters for each pyramid
level.
out_channels: The number of features used within the BiFPN.
num_bifpns: The number of BiFPN layers in the model.
bifpn_height: The number of feature maps to send into the BiFPN.
NOTE: might not be equal to len(levels).
use_dw: If True, use depthwise separable convolutions.
levels: The indices of the pyramid levels coming in to the BiFPN.
"""
super().__init__()
self.levels_in = levels
@@ -80,32 +92,28 @@ def __init__(
# level to form lower resolution levels.
if self.bifpn_height != len(self.levels_in):

# Before the first BiFPN layer, we need to downsample the incoming most
# 'low-resolution' level to create the necessary number of levels. Only
# before the first downsample is a pointwise conv applied, fixing the
# channel depth to the number of channels used inside the BiFPN.
self.downsample_convs = [
torch.nn.Sequential(
torch.nn.Conv2d(in_channels[-1], out_channels, kernel_size=1),
torch.nn.MaxPool2d(kernel_size=3, padding=1, stride=2),
)
] + [
torch.nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
for _ in range(self.bifpn_height - len(levels) - 1)
]
self.downsample_convs = torch.nn.Sequential(*self.downsample_convs)

# Specify the channels for the first bifpn layer. It will be the top most
# 'low-res' feature maps plus however many extra levels are made above
# (typically two).
level_channels = in_channels[-len(levels) :] + [out_channels] * (
self.bifpn_height - len(levels)
)
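# For example (hypothetical depths): with in_channels=[40, 112, 320],
# levels=[3, 4, 5], out_channels=64, and bifpn_height=5, level_channels
# is [40, 112, 320, 64, 64]: the raw backbone depths followed by one
# out_channels entry per synthesized top level.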

# Construct the BiFPN layers. We need to keep track of the channels of the
# inputs to each level. The first layer will see channel depths _not equal_
# to the internal feature depth, and the original input feature maps will
# be squashed to the proper channel depth _twice_.
channel_dict = {level: level_channels[idx] for idx, level in enumerate(levels)}
self.bifp_layers = torch.nn.Sequential()
for idx in range(num_bifpns):
@@ -114,9 +122,9 @@ def __init__(
BiFPNBlock(
channels=out_channels,
num_levels=bifpn_height,
levels_in=channel_dict
if idx == 0
else {level: out_channels for level in levels},
),
)

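As a usage sketch, the module wires up like this. The backbone channel depths below are placeholders, and the enclosing class is assumed to be named BiFPN (the fold hides its class statement):

# Hypothetical P3-P5 backbone depths, not values from this repository.
fpn = BiFPN(
in_channels=[40, 112, 320],
out_channels=64,
num_bifpns=3,
bifpn_height=5,
levels=[3, 4, 5],
)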
@@ -127,11 +135,12 @@ def __call__(self, feature_maps: collections.OrderedDict) -> List[torch.Tensor]:
Args:
feature_maps: Feature maps in sorted order of layer.
"""
# Make sure fpn gets the anticipated number of levels.
assert len(feature_maps) == len(self.levels_in), len(feature_maps)

# Apply the downsampling to form the top layers.
for layer in self.downsample_convs:

# Get the top most layer, which is the last entry in the dict.
top_level_idx, top_level_map = next(reversed(feature_maps.items()))
feature_maps[top_level_idx + 1] = layer(top_level_map)

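Each pass through that loop appends one coarser level, so with the default levels=[3, 4, 5] and bifpn_height=5 the two downsample modules synthesize levels 6 and 7:

# feature_maps before the loop: {3: P3, 4: P4, 5: P5}
# feature_maps after the loop:  {3: P3, 4: P4, 5: P5, 6: P6, 7: P7}
# where P6 = downsample_convs[0](P5) and P7 = downsample_convs[1](P6).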
@@ -142,33 +151,28 @@ class BiFPNBlock(torch.nn.Module):
""" Modular implementation of a single BiFPN layer. """

def __init__(
self, channels: int, num_levels: int, levels_in: Dict[int, int]
) -> None:
"""
Args:
channels: The number of channels in and out.
num_levels: The number of incoming feature pyramid levels.
levels_in: The channels for each input level to the block. This is
really important for the first block, which must adapt the channel
depth of the original feature levels twice.
"""
super().__init__()
self.num_levels = num_levels
self.combines = torch.nn.Sequential()
self.post_combines = torch.nn.Sequential()
self.index_offset = num_levels - len(levels_in) + 1
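# E.g. num_levels=5 with three input levels gives index_offset == 3,
# presumably the shift needed to map the relative node offsets in
# _NODE_PARAMS onto absolute pyramid level ids.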

# Create node combination and depthwise separable convolutions that will process
# the input feature maps.
for idx, node in enumerate(_NODE_PARAMS):

# Combine the nodes first.
self.combines.add_module(
f"combine_{node.offsets}",
CombineLevels(node, self.index_offset, channels, levels_in),
)

# Apply output convolution.
self.post_combines.add_module(
f"post_combine_{node.offsets}",
torch.nn.Sequential(
@@ -187,7 +191,7 @@ def __init__(
bias=True,
),
torch.nn.BatchNorm2d(channels, momentum=0.01, eps=1e-3),
Swish(),
),
)

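The fold hides the top of the post-combine Sequential; given the depthwise() helper defined earlier, a depthwise separable stack like the following is the likely shape (a sketch under that assumption, not the folded code itself):

torch.nn.Sequential(
# Per-channel 3x3 spatial filtering.
torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels, bias=False),
# 1x1 pointwise projection mixes information across channels.
torch.nn.Conv2d(channels, channels, kernel_size=1, bias=True),
torch.nn.BatchNorm2d(channels, momentum=0.01, eps=1e-3),
Swish(),
)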
@@ -199,20 +203,21 @@ def __call__(self, input_maps: collections.OrderedDict) -> collections.OrderedDict:
input_maps: The feature maps from each of the pyramid levels,
highest to lowest resolution.
"""
assert self.num_levels == len(input_maps)
for idx, (combine, post_combine_conv) in enumerate(
zip(self.combines, self.post_combines)
):
level_idx = next(reversed(input_maps.keys()))
input_maps[level_idx + 1] = post_combine_conv(combine(input_maps))

# Only return the last self.num_levels levels.
return collections.OrderedDict(
[
(idx - self.num_levels, level)
for idx, level in enumerate(input_maps.values())
if idx >= len(input_maps) - self.num_levels
]
)

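One subtlety in the return above is the re-keying as idx - self.num_levels. A worked example, assuming the standard BiFPN topology of 2 * num_levels - 2 combine nodes (an assumption, since _NODE_PARAMS is folded):

# input_maps arrives with keys {3, 4, 5, 6, 7}      -> 5 entries
# the loop appends 8 fused maps under keys 8..15    -> 13 entries
# the comprehension keeps the last 5 (keys 11..15) and re-keys them as
# idx - num_levels = 8..12 - 5 = 3..7, restoring the original level ids.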

class CombineLevels(torch.nn.Module):
@@ -239,7 +244,9 @@ def __init__(
for offset in self.offsets:
if offset in levels_in and levels_in[offset] != channels:
self.lateral_node = offset
self.lateral_conv = torch.nn.Conv2d(
levels_in[offset], channels, kernel_size=1
)

# Construct the resample module.
if param.upsample:
