Calling objax.Grad twice on a model with BatchNorm with respect to input params crashes #158

Open
carlini opened this issue Nov 21, 2020 · 2 comments
Assignees: AlexeyKurakin
Labels: bug, urgent

@carlini
Collaborator

carlini commented Nov 21, 2020

The following code crashes.

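    import objax
    import jax.numpy as jn
    import numpy as np

    # Small model: a convolution followed by batch normalization.
    mod = objax.nn.Sequential([objax.nn.Conv2D(2, 4, 3),
                               objax.nn.BatchNorm2D(4)])

    def ell(x):
        # training=True makes BatchNorm update its running statistics.
        return jn.sum(mod(x, training=True))

    # Differentiate with respect to the input (argument 0) only.
    m1 = objax.Grad(ell, {}, (0,))
    print(m1(np.ones((8*8*8,2,10,10))))

    # A second Grad over the same function fails.
    m2 = objax.Grad(ell, {}, (0,))
    print(m2(np.ones((8*8*8,2,10,10)))) # boom!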

This happens because the first time we call grad(), JAX traces the function with a tracer. We update, AND SAVE, the BatchNorm state as the tracer object. Then, the second time we trace, we find the old tracer and blow up.

This should be fixed by making sure that we never save a JAX tracer object into the BatchNorm state.
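
The failure mode described above can be modeled without JAX. The following is a toy sketch only: `ToyTracer`, `ToyState`, `StaleTracerError`, and `toy_grad_call` are hypothetical stand-ins invented for illustration, not Objax or JAX APIs, and the real tracer machinery is considerably more involved.

```python
import itertools

_trace_ids = itertools.count(1)


class ToyTracer:
    """Stand-in for a JAX tracer: only meaningful while its own trace is active."""

    def __init__(self, trace_id):
        self.trace_id = trace_id


class StaleTracerError(RuntimeError):
    """Raised when a value from a finished trace is used in a new trace."""


class ToyState:
    """Stand-in for a BatchNorm running statistic."""

    def __init__(self, value):
        self.value = value


def update_running_stat(state, trace_id):
    old = state.value
    # Using a tracer that belongs to an earlier, finished trace is an error.
    if isinstance(old, ToyTracer) and old.trace_id != trace_id:
        raise StaleTracerError(f"state holds a tracer from trace {old.trace_id}")
    # The new statistic depends on the traced input, so it is itself a tracer.
    state.value = ToyTracer(trace_id)


def toy_grad_call(state):
    trace_id = next(_trace_ids)
    update_running_stat(state, trace_id)
    # Nothing restores state.value here, so the tracer outlives its trace.


state = ToyState(0.0)
toy_grad_call(state)  # first call succeeds but leaves a tracer behind
leaked = isinstance(state.value, ToyTracer)

try:
    toy_grad_call(state)  # second call finds the stale tracer
    second_call_failed = False
except StaleTracerError:
    second_call_failed = True
```

In this model the first call appears to work, and the bug only surfaces on the second call, which matches the behavior of the reproduction script.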

@david-berthelot david-berthelot added bug Something isn't working urgent labels Nov 21, 2020
@AlexeyKurakin AlexeyKurakin self-assigned this Dec 1, 2020
@AlexeyKurakin
Member

We have to save a (temporary) JAX tracer into StateVars for the JAX machinery to work correctly. However, the original value has to be restored afterwards.

So this leads to a broader issue: automatically figuring out the list of variables used by a function.
If we know the list of variables, we can handle this situation properly, as well as some others (such as auto-detecting the variables needed for JIT).

Thus the next step would be to figure out a way to track all Objax variables used by a function.
One possible way is to make our own Objax Tracer object to trace variables used by functions.

The step after that is to explore and prototype how to make such a tracer work with BatchNorm.
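
One way such variable tracking could look is sketched below. This is a hedged, pure-Python illustration, not the Objax design: `TrackedVar`, `record_vars`, and the variable names are hypothetical, and the sketch simply records every variable whose value is read or written while a recorder is active.

```python
from contextlib import contextmanager

# Stack of active recorders; each recorder maps variable name -> variable.
_active_recorders = []


class TrackedVar:
    """Hypothetical variable that reports every read and write of its value."""

    def __init__(self, name, value):
        self.name = name
        self._value = value

    def _record(self):
        for recorder in _active_recorders:
            recorder[self.name] = self

    @property
    def value(self):
        self._record()
        return self._value

    @value.setter
    def value(self, new_value):
        self._record()
        self._value = new_value


@contextmanager
def record_vars():
    """Collect all TrackedVars touched inside the with-block."""
    recorder = {}
    _active_recorders.append(recorder)
    try:
        yield recorder
    finally:
        _active_recorders.pop()


weight = TrackedVar("conv.w", 2.0)
running_mean = TrackedVar("bn.running_mean", 0.0)
unused = TrackedVar("other.w", 1.0)


def forward(x):
    y = weight.value * x
    # Like BatchNorm in training mode, update a running statistic as a side effect.
    running_mean.value = 0.9 * running_mean.value + 0.1 * y
    return y


with record_vars() as used:
    out = forward(3.0)

used_names = sorted(used)
```

Here `used_names` contains only the two variables that `forward` actually touched, which is the information a transformation would need in order to save and restore them.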

@AlexeyKurakin
Member

Here is a code snippet which encapsulates this example in a unit test: https://github.com/google/objax/pull/189/files

This code snippet also includes the fix "don't save tracers into StateVars". Unfortunately, this is not the correct fix, and it led to failures in many unit tests.

The real problem here is not that tracers are saved into variables. Temporarily saving tracers is actually necessary for Objax to work. The problem is that after tracing is done, variable values should be restored to what they were before tracing.
To do this, Grad (or any other Objax transformation) has to know which variables were used inside the function that was called.
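
The restore step could be sketched roughly as follows, assuming the list of used variables is already known. `ToyVar`, `ToyTracer`, and `restore_after` are hypothetical names for illustration only and do not correspond to Objax code.

```python
from contextlib import contextmanager


class ToyVar:
    """Hypothetical stand-in for an Objax state variable."""

    def __init__(self, value):
        self.value = value


class ToyTracer:
    """Hypothetical stand-in for a JAX tracer."""


@contextmanager
def restore_after(variables):
    """Snapshot variable values, then put them back once tracing is done."""
    saved = [v.value for v in variables]
    try:
        yield
    finally:
        for var, old in zip(variables, saved):
            var.value = old


running_mean = ToyVar(0.0)


def traced_body():
    # During tracing the variable temporarily holds a tracer.
    running_mean.value = ToyTracer()


with restore_after([running_mean]):
    traced_body()
    held_tracer_during_trace = isinstance(running_mean.value, ToyTracer)

value_after_trace = running_mean.value
```

This sketch only shows the restore half. A real transformation would presumably also need to write the updated concrete statistics back into the variables after the transformed function runs, which is not modeled here.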
