From f00194411da8d355dfc635c59fe73390cf94a3e0 Mon Sep 17 00:00:00 2001 From: LMissher <37818979+LMissher@users.noreply.github.com> Date: Thu, 4 Jan 2024 18:09:45 +0800 Subject: [PATCH] fix conv1d->conv2d in STNorm (#117) --- baselines/STNorm/arch/stnorm_arch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/baselines/STNorm/arch/stnorm_arch.py b/baselines/STNorm/arch/stnorm_arch.py index 81465ee4..3afbed2d 100644 --- a/baselines/STNorm/arch/stnorm_arch.py +++ b/baselines/STNorm/arch/stnorm_arch.py @@ -104,16 +104,16 @@ def __init__(self, num_nodes, tnorm_bool, snorm_bool, in_dim, out_dim, channels, self.filter_convs.append(nn.Conv2d( in_channels=num * channels, out_channels=channels, kernel_size=(1, kernel_size), dilation=new_dilation)) - self.gate_convs.append(nn.Conv1d( + self.gate_convs.append(nn.Conv2d( in_channels=num * channels, out_channels=channels, kernel_size=(1, kernel_size), dilation=new_dilation)) # 1x1 convolution for residual connection self.residual_convs.append( - nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) + nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) # 1x1 convolution for skip connection self.skip_convs.append( - nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) + nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) new_dilation *= 2 receptive_field += additional_scope additional_scope *= 2