Provide as much information as possible. At a minimum, this should include a description of your issue and steps to reproduce the problem. If possible, also provide a summary of what steps or workarounds you have already tried.
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): GCP Cloud TPU VM
Flax, jax, jaxlib versions: Flax 0.8.5, jax 0.4.31, jaxlib 0.4.31
Python version: 3.10
GPU/TPU model and memory: v5p
CUDA version (if applicable):
Problem you have encountered:
I'm trying to figure out the Flax counterpart to torch.nn.Conv1d, but I find that the implementations below produce the same output yet different grads after backward.
conv_torch.py:
from torch import nn


class BlockTorch(nn.Module):
    def __init__(self):
        super().__init__()
        # Depthwise (groups == channels) causal 1D convolution.
        self.conv1d = nn.Conv1d(
            in_channels=5120,
            out_channels=5120,
            bias=True,
            kernel_size=4,
            groups=5120,
            padding=3,
        )

    def forward(self, x):
        # x: (batch, seq_len, channels); Conv1d expects (batch, channels, seq_len).
        batch_size, seq_len, _ = x.shape
        x = self.conv1d(x.transpose(1, 2))[..., :seq_len]  # trim back to seq_len
        return x.transpose(1, 2)
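As a quick shape check (a hypothetical usage, not from the report): with padding=3 and kernel_size=4 the convolution lengthens the sequence by 3, and the [..., :seq_len] slice trims it again, so the block maps a (batch, seq_len, 5120) input back to the same shape:

import torch

block = BlockTorch()
x = torch.randn(2, 16, 5120)        # (batch, seq_len, channels)
y = block(x)
assert y.shape == (2, 16, 5120)     # depthwise causal conv preserves the shape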
conv_jax.py:
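The body of conv_jax.py is not shown in this report. As a rough sketch, a Flax depthwise causal convolution mirroring BlockTorch might look like the following (the class name BlockJax and its attributes are assumptions, not the reporter's code):

from flax import linen as nn


class BlockJax(nn.Module):
    features: int = 5120
    kernel_size: int = 4

    @nn.compact
    def __call__(self, x):
        # x: (batch, seq_len, features); Flax Conv is channels-last,
        # so no transposes are needed around the convolution.
        seq_len = x.shape[1]
        x = nn.Conv(
            features=self.features,
            kernel_size=(self.kernel_size,),
            feature_group_count=self.features,  # depthwise, like groups=5120
            padding=[(self.kernel_size - 1, self.kernel_size - 1)],  # like padding=3
            use_bias=True,
        )(x)
        # Keep only the first seq_len positions, mirroring [..., :seq_len] in BlockTorch.
        return x[:, :seq_len, :]

Note that the weight layouts differ: the torch depthwise Conv1d weight has shape (out_channels, 1, kernel_size), while the Flax kernel has shape (kernel_size, 1, features), so copying weights across frameworks requires a transpose.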
My main test script: test.py
All the modules in the test script load the same conv weight and bias. The output shows a small error between the conv layer outputs, but a big difference between the grads.
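test.py itself is not included in the report. A sketch of how such a comparison could be set up, using the hypothetical BlockJax above (the parameter path 'Conv_0' is also an assumption tied to that sketch), might look like this:

import numpy as np
import torch
import jax
import jax.numpy as jnp

from conv_torch import BlockTorch   # the reporter's module above
from conv_jax import BlockJax       # hypothetical module from the sketch above

torch_block = BlockTorch()
jax_block = BlockJax()

x_np = np.random.randn(2, 16, 5120).astype(np.float32)

# Copy the torch weights into the Flax layout.
# torch depthwise Conv1d weight: (out_channels, 1, kernel_size)
# Flax Conv kernel:              (kernel_size, 1, features)
kernel = torch_block.conv1d.weight.detach().numpy().transpose(2, 1, 0)
bias = torch_block.conv1d.bias.detach().numpy()
params = {'params': {'Conv_0': {'kernel': jnp.asarray(kernel),
                                'bias': jnp.asarray(bias)}}}

# Torch: forward and backward on a simple sum-of-outputs loss.
x_t = torch.tensor(x_np, requires_grad=True)
torch_block(x_t).sum().backward()
grad_torch = x_t.grad.numpy()

# JAX: gradient of the same loss with respect to the input.
grad_jax = jax.grad(lambda x: jax_block.apply(params, x).sum())(jnp.asarray(x_np))
grad_jax = np.asarray(grad_jax)

def wmape(a, b):
    # Weighted mean absolute percentage error.
    return np.abs(a - b).sum() / np.abs(a).sum()

print('input-grad WMAPE:', wmape(grad_torch, grad_jax))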
What you expected to happen:
The WMAPE (weighted mean absolute percentage error) of both the outputs and the grads should be small.
Steps to reproduce:
Whenever possible, please provide a minimal example. Please consider submitting it as a Colab link.