Skip to content

Commit

Permalink
fix conv1d->conv2d in STNorm (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
LMissher authored Jan 4, 2024
1 parent b611d86 commit f001944
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions baselines/STNorm/arch/stnorm_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,16 @@ def __init__(self, num_nodes, tnorm_bool, snorm_bool, in_dim, out_dim, channels,
self.filter_convs.append(nn.Conv2d(
in_channels=num * channels, out_channels=channels, kernel_size=(1, kernel_size), dilation=new_dilation))

self.gate_convs.append(nn.Conv1d(
self.gate_convs.append(nn.Conv2d(
in_channels=num * channels, out_channels=channels, kernel_size=(1, kernel_size), dilation=new_dilation))

# 1x1 convolution for residual connection
self.residual_convs.append(
nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=(1, 1)))
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1)))

# 1x1 convolution for skip connection
self.skip_convs.append(
nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=(1, 1)))
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1)))
new_dilation *= 2
receptive_field += additional_scope
additional_scope *= 2
Expand Down

0 comments on commit f001944

Please sign in to comment.