flax counterpart for torch.nn.Conv1d #4188

Open
Liyang90 opened this issue Sep 10, 2024 · 0 comments

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): GCP Cloud TPU VM
  • Flax 0.8.5, jax 0.4.31, jaxlib 0.4.31
  • Python version: 3.10
  • GPU/TPU model and memory: v5p
  • CUDA version (if applicable):

Problem you have encountered:

I'm trying to figure out the Flax counterpart for torch.nn.Conv1d, but I find that the implementations below produce (nearly) the same output yet clearly different gradients after the backward pass.

conv_torch.py:

from torch import nn


class BlockTorch(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1d = nn.Conv1d(
            in_channels=5120,
            out_channels=5120,
            bias=True,
            kernel_size=4,
            groups=5120,
            padding=3,
        )

    def forward(
        self,
        x,
    ):
        batch_size, seq_len, _ = x.shape
        x = self.conv1d(x.transpose(1, 2))[..., :seq_len]
        return x.transpose(1, 2)
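
For reference (my own note, not part of the scripts): torch.nn.Conv1d is channels-first, which is why BlockTorch transposes to (batch, channels, length) and back, and its weight is laid out as (out_channels, in_channels // groups, kernel_size). flax.linen.Conv is channels-last, with a kernel laid out as (kernel_size, in_features // groups, features). A minimal shape check with small, hypothetical sizes:

import torch
import jax
import jax.numpy as jnp
from flax import linen as nn

# Hypothetical small sizes, just to inspect the parameter layouts.
feat, k = 8, 4

conv_t = torch.nn.Conv1d(in_channels=feat, out_channels=feat, kernel_size=k, groups=feat, padding=k - 1)
print(conv_t.weight.shape)               # torch.Size([8, 1, 4]) -> (out_channels, in_channels // groups, kernel_size)

conv_f = nn.Conv(features=feat, kernel_size=[k], feature_group_count=feat, padding='CAUSAL')
params = conv_f.init(jax.random.key(0), jnp.ones((2, 16, feat)))
print(params['params']['kernel'].shape)  # (4, 1, 8) -> (kernel_size, in_features // groups, features)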

conv_jax.py:

from flax import linen as nn
import jax
import jax.numpy as jnp


class BlockJAX_0(nn.Module):
  kernel: jax.Array
  bias: jax.Array

  def setup(self):

    def kernel_init(key, shape, dtype):
      assert self.kernel.shape == shape
      return self.kernel.astype(dtype)
    
    def bias_init(key, shape, dtype):
      assert self.bias.shape == shape
      return self.bias.astype(dtype)

    self.conv1d = nn.Conv(features=5120,
                          kernel_size=[4],
                          feature_group_count=5120,
                          padding='CAUSAL',
                          use_bias=True,
                          kernel_init=kernel_init,
                          bias_init=bias_init,
                          )
    
  def __call__(self, x):
    x = self.conv1d(x)
    return x
  

class BlockJAX_1(nn.Module):
  kernel: jax.Array
  bias: jax.Array

  def setup(self):

    def kernel_init(key, shape, dtype):
      assert self.kernel.shape == shape
      return self.kernel.astype(dtype)
    
    def bias_init(key, shape, dtype):
      assert self.bias.shape == shape
      return self.bias.astype(dtype)

    self.conv1d = nn.Conv(features=5120,
                          kernel_size=[4],
                          feature_group_count=5120,
                          padding=3,
                          use_bias=True,
                          kernel_init=kernel_init,
                          bias_init=bias_init,
                          )
    
  def __call__(self, x):
    (b, l, d) = x.shape
    x = self.conv1d(x)[:, :l, :]
    return x
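
As far as I can tell, the two Flax variants should be equivalent: padding='CAUSAL' pads only on the left with kernel_size - 1 zeros, while padding=3 pads both sides, and the [:, :l, :] crop throws away the right-side overhang, mirroring what BlockTorch does. A quick equivalence check with small, hypothetical sizes:

import jax
import jax.numpy as jnp
from flax import linen as nn

feat, k = 8, 4
x = jax.random.normal(jax.random.key(1), (2, 16, feat))

causal = nn.Conv(features=feat, kernel_size=[k], feature_group_count=feat, padding='CAUSAL')
padded = nn.Conv(features=feat, kernel_size=[k], feature_group_count=feat, padding=k - 1)

# Both variants use the same kernel shape, so the same params can be applied to either.
params = causal.init(jax.random.key(0), x)
out_causal = causal.apply(params, x)
out_padded = padded.apply(params, x)[:, :x.shape[1], :]  # crop the right-side overhang
print(jnp.allclose(out_causal, out_padded))              # expected: True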

My main test script test.py:

import numpy as np
from numpy.random import MT19937
from numpy.random import RandomState, SeedSequence
rs = RandomState(MT19937(SeedSequence(123456789)))

import jax
import jax.numpy as jnp

import torch

from conv_jax import BlockJAX_0, BlockJAX_1
from conv_torch import BlockTorch

# prepare common weights and inputs
kernel = rs.normal(size=(4, 1, 5120))
bias = rs.normal(size=(5120,))
input = rs.normal(size=(4, 4096, 5120))

# torch module forward and backward
torch.set_printoptions(precision=7)
conv_torch = BlockTorch()
state_dict = conv_torch.state_dict()
state_dict["conv1d.weight"] = torch.from_numpy(kernel).to(torch.float32).transpose(0, 2)
state_dict["conv1d.bias"] = torch.from_numpy(bias).to(torch.float32)
conv_torch.load_state_dict(state_dict)
conv_torch.zero_grad()
output_torch = conv_torch(torch.from_numpy(input).to(torch.float32))
loss_torch = output_torch.mean()
loss_torch.backward()

# flax module forward and backward
def jax_forward_backward(model, params, input):
  def forward(params, input):
    output = model.apply(params, input)
    loss = jnp.mean(output)
    return loss, output

  forward_backward_fn = jax.value_and_grad(forward, has_aux=True)
  (loss, output), grad = forward_backward_fn(params, input)
  return loss, output, grad

input_jax = jnp.array(input)
kernel_jax = jnp.array(kernel)
bias_jax = jnp.array(bias)

conv_jax_0 = BlockJAX_0(kernel_jax, bias_jax)
rng = jax.random.key(0)
params_jax_0 = conv_jax_0.init(rng, input_jax)

loss_jax_0, output_jax_0, grad_jax_0 = jax_forward_backward(conv_jax_0, params_jax_0, input_jax)

conv_jax_1 = BlockJAX_1(kernel_jax, bias_jax)
params_jax_1 = conv_jax_1.init(rng, input_jax)

loss_jax_1, output_jax_1, grad_jax_1 = jax_forward_backward(conv_jax_1, params_jax_1, input_jax)

print("================================================")
print(f"conv_torch.conv1d.weight.grad.T shape: {conv_torch.conv1d.weight.grad.T.shape}")
print(conv_torch.conv1d.weight.grad.T)
print("================================================")
print(f'grad_jax_0["params"]["conv1d"]["kernel"] {grad_jax_0["params"]["conv1d"]["kernel"].shape}')
print(grad_jax_0["params"]["conv1d"]["kernel"])
print("================================================")

def wmape(a, b):
  return np.sum(np.abs(a - b)) / np.sum(np.abs(a))

print(f"losses: {(loss_torch, loss_jax_0, loss_jax_1)}")

print("Outputs WMAPE:")
print(wmape(output_torch.detach().numpy(), np.array(output_jax_0)))
print(wmape(output_torch.detach().numpy(), np.array(output_jax_1)))

print("Grads WMAPE:")
print(wmape(conv_torch.conv1d.weight.grad.T.detach().numpy(), np.array(grad_jax_0["params"]["conv1d"]["kernel"])))
print(wmape(conv_torch.conv1d.weight.grad.T.detach().numpy(), np.array(grad_jax_1["params"]["conv1d"]["kernel"])))
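
A sanity check I would append to test.py (not in the script above): confirm that after load_state_dict the torch weight, transposed back, exactly matches the Flax kernel, so the gradient comparison really is between identically initialized layers.

w_torch = conv_torch.conv1d.weight.detach().numpy().T          # (5120, 1, 4) -> (4, 1, 5120)
w_jax = np.array(params_jax_0["params"]["conv1d"]["kernel"])   # (4, 1, 5120)
print(np.allclose(w_torch, w_jax))                             # expected: True

b_torch = conv_torch.conv1d.bias.detach().numpy()
b_jax = np.array(params_jax_0["params"]["conv1d"]["bias"])
print(np.allclose(b_torch, b_jax))                             # expected: True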

All the modules in the test script load the same conv weight and bias. The output:

conv_torch.conv1d.weight.grad.T shape: torch.Size([4, 1, 5120])
tensor([[[-1.0676654e-06, -6.8133699e-07,  6.4320474e-08,  ...,
          -6.2070239e-07, -1.3589821e-06, -4.0437203e-06]],

        [[-1.0596159e-06, -6.6314635e-07,  6.1584302e-08,  ...,
          -6.3736798e-07, -1.4141298e-06, -4.0575105e-06]],

        [[-1.0563341e-06, -6.9564146e-07,  1.6096948e-08,  ...,
          -6.5698919e-07, -1.4593004e-06, -4.0748873e-06]],

        [[-1.0387610e-06, -6.8761653e-07,  7.2919590e-09,  ...,
          -6.7382234e-07, -1.4510101e-06, -4.0458808e-06]]])
================================================
grad_jax_0["params"]["conv1d"]["kernel"] (4, 1, 5120)
[[[-1.0676664e-06 -6.8133664e-07  6.4321057e-08 ...  5.7032690e-07
    3.1477473e-06  1.3152874e-06]]

 [[-1.0596164e-06 -6.6314561e-07  6.1584501e-08 ...  5.5595456e-07
    3.1550721e-06  1.3070267e-06]]

 [[-1.0563347e-06 -6.9564129e-07  1.6097601e-08 ...  6.2066499e-07
    3.1343752e-06  1.2927190e-06]]

 [[-1.0387618e-06 -6.8761614e-07  7.2924422e-09 ...  5.8497437e-07
    3.1675968e-06  1.2691942e-06]]]
================================================

losses: (tensor(0.0189537, grad_fn=<MeanBackward0>), Array(0.01895385, dtype=float32), Array(0.01895385, dtype=float32))

Outputs WMAPE:
0.0013432346
0.0013432346

Grads WMAPE:
0.7058088
0.7058088

It shows a small error between the conv layer outputs, but a large difference between the gradients.
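
A follow-up diagnostic I would try (output not shown above) is a per-channel WMAPE over the kernel gradient, to see whether the mismatch is spread uniformly or concentrated in particular channels:

g_torch = conv_torch.conv1d.weight.grad.T.detach().numpy()   # (4, 1, 5120)
g_jax = np.array(grad_jax_0["params"]["conv1d"]["kernel"])   # (4, 1, 5120)
per_channel = np.sum(np.abs(g_torch - g_jax), axis=(0, 1)) / np.sum(np.abs(g_torch), axis=(0, 1))
print(per_channel.min(), per_channel.max())  # how unevenly are the channels affected?
print(np.argsort(per_channel)[-10:])         # channels with the largest relative error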

What you expected to happen:

The WMAPE of both outputs and grads should be small.

Steps to reproduce:

Run test.py above, with conv_torch.py and conv_jax.py alongside it, and compare the printed WMAPE values for the outputs and the gradients.
