This happens because the first time we call `grad()`, JAX traces the function with a tracer. We update, AND SAVE, the BatchNorm state as the tracer object. Then, the second time we trace, we find the old tracer and blow up.
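The failure mode can be reproduced without Objax. In this minimal sketch, a hypothetical `StateVar` class (a plain Python holder, not Objax's actual `StateVar`) plays the role of the BatchNorm running mean, and the update is written back while `jax.grad` is tracing:

```python
import jax
import jax.numpy as jnp


class StateVar:
    """Hypothetical mutable holder standing in for Objax's StateVar."""

    def __init__(self, value):
        self.value = value


# Running statistic updated as a side effect, like BatchNorm's running mean.
running_mean = StateVar(jnp.zeros(()))


def loss(x):
    # Under jax.grad, `x` is a tracer, so this stores a tracer in the state.
    running_mean.value = 0.9 * running_mean.value + 0.1 * x
    return x ** 2


grad_fn = jax.grad(loss)
print(grad_fn(1.0))              # the first call succeeds
print(type(running_mean.value))  # but a tracer has leaked into the state
# Calling grad_fn(1.0) again reads the leaked tracer from a finished trace,
# which typically fails with jax.errors.UnexpectedTracerError.
```

After the first call, `running_mean.value` holds a tracer belonging to a trace that has already finished; any later trace that reads it is what blows up.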
This should be fixed by making sure that we never save a JAX tracer object into the BatchNorm state.
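The fix proposed above could look like the following sketch, again using a hypothetical `StateVar` holder (not Objax's class). The hypothetical `assign_if_concrete` method simply skips any update whose value is a `jax.core.Tracer`:

```python
import jax
import jax.numpy as jnp


class StateVar:
    """Hypothetical mutable holder standing in for Objax's StateVar."""

    def __init__(self, value):
        self.value = value

    def assign_if_concrete(self, new_value):
        # Naive guard: silently drop any update that is a JAX tracer.
        if not isinstance(new_value, jax.core.Tracer):
            self.value = new_value


running_mean = StateVar(jnp.zeros(()))


def loss(x):
    running_mean.assign_if_concrete(0.9 * running_mean.value + 0.1 * x)
    return x ** 2


grad_fn = jax.grad(loss)
grad_fn(1.0)
grad_fn(1.0)               # no tracer leaked, so repeated calls succeed...
print(running_mean.value)  # ...but the running mean is still zero
```

Repeated calls no longer crash, but every update made under a JAX transformation is silently dropped, so the statistic is never updated inside `grad` (and, since values are also tracers there, inside `jit`).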
We have to (temporarily) save JAX tracers into StateVars for the JAX machinery to work correctly. However, the original values have to be restored afterwards.
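A minimal sketch of the save-then-restore approach, assuming the caller passes the list of state variables explicitly. `grad_restoring` and `StateVar` are hypothetical names for illustration, not Objax API:

```python
import jax
import jax.numpy as jnp


class StateVar:
    """Hypothetical mutable holder standing in for Objax's StateVar."""

    def __init__(self, value):
        self.value = value


def grad_restoring(fn, state_vars):
    """Hypothetical jax.grad wrapper that restores `state_vars` after tracing."""
    grad_fn = jax.grad(fn)

    def wrapped(*args):
        saved = [v.value for v in state_vars]  # concrete values before tracing
        try:
            # Tracers may be written into the state while fn is traced.
            return grad_fn(*args)
        finally:
            for v, old in zip(state_vars, saved):
                v.value = old  # put the pre-trace values back

    return wrapped


running_mean = StateVar(jnp.zeros(()))


def loss(x):
    running_mean.value = 0.9 * running_mean.value + 0.1 * x
    return x ** 2


g = grad_restoring(loss, [running_mean])  # variable list passed by hand
print(g(1.0), g(2.0))  # repeated calls no longer hit a leaked tracer
```

The traced function still sees its own updates while it is being traced; only the leftover tracers are cleaned up afterwards. The variable list has to be supplied by hand, and this sketch simply discards the traced update rather than propagating the new value out of the transformation.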
So this points to a broader issue: automatically figuring out the list of variables used by a function.
If we know the list of variables, then we can properly handle this situation as well as some others (such as auto-detecting the variables needed for JIT).
Thus, the next step is to figure out a way to track all Objax variables used by a function.
One possible way is to make our own Objax tracer object that traces the variables used by functions.
The next step is to explore and prototype how to make such a tracer work with BatchNorm.
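One way such variable tracking could work, sketched in plain Python under the assumption that every variable access goes through a `value` property. The hypothetical `TrackedVar` reports each read and write to whichever trackers are active, and the hypothetical `used_vars` runs a function once and collects the variables it touched:

```python
class TrackedVar:
    """Hypothetical variable that reports every read and write to active trackers."""

    _trackers = []  # stack of sets, one per active used_vars() call

    def __init__(self, value):
        self._value = value

    def _record(self):
        for seen in TrackedVar._trackers:
            seen.add(self)

    @property
    def value(self):
        self._record()
        return self._value

    @value.setter
    def value(self, new_value):
        self._record()
        self._value = new_value


def used_vars(fn, *args, **kwargs):
    """Hypothetical helper: run fn once, return (result, set of TrackedVars it touched)."""
    seen = set()
    TrackedVar._trackers.append(seen)
    try:
        result = fn(*args, **kwargs)
    finally:
        TrackedVar._trackers.pop()
    return result, seen


a, b, unused = TrackedVar(1.0), TrackedVar(2.0), TrackedVar(3.0)


def f(x):
    b.value = a.value * x  # reads a, writes b
    return b.value


result, touched = used_vars(f, 5.0)
print(result, len(touched))  # 5.0 2
```

Because every active tracker on the stack is notified, nested transformations would each see the variables used inside them.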
This code snippet also includes the fix "don't save tracer into StateVars". Unfortunately, this is not the correct fix: it led to failures in many unit tests.
The real problem here is not the fact that tracers are saved into variables. Temporarily saving tracers is actually necessary for Objax to work. The problem is that after tracing is done, variable values should be restored to what they were before tracing.
To do this, Grad (or any other Objax transformation) has to know which variables were used inside the called function.
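Combining tracking with restoring gives a sketch of a `grad`-like transformation that discovers the variables on its own. Each hypothetical `TrackedVar` records its value before its first access during tracing, and the hypothetical `tracked_grad` (not Objax's `Grad`) puts those values back once `jax.grad` returns:

```python
import jax
import jax.numpy as jnp


class TrackedVar:
    """Hypothetical state variable that reports accesses to active trackers."""

    _trackers = []  # stack of dicts mapping var -> value before first access

    def __init__(self, value):
        self._value = value

    def _record(self):
        for originals in TrackedVar._trackers:
            originals.setdefault(self, self._value)  # keep only the first value

    @property
    def value(self):
        self._record()
        return self._value

    @value.setter
    def value(self, new_value):
        self._record()
        self._value = new_value


def tracked_grad(fn):
    """Hypothetical Grad: discovers the vars fn uses and restores them afterwards."""
    grad_fn = jax.grad(fn)

    def wrapped(*args):
        originals = {}
        TrackedVar._trackers.append(originals)
        try:
            return grad_fn(*args)
        finally:
            TrackedVar._trackers.pop()
            for var, old in originals.items():
                var._value = old  # undo tracers written during tracing

    return wrapped


running_mean = TrackedVar(jnp.zeros(()))


def loss(x):
    running_mean.value = 0.9 * running_mean.value + 0.1 * x
    return x ** 2


g = tracked_grad(loss)  # no variable list needed
print(g(1.0), g(2.0))   # repeated calls work
```

Because the snapshot is taken lazily on first access, no variable list needs to be passed in. As with an explicit list, this sketch discards the traced update itself; returning the new values out of the transformation is left out.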
The following code crashes.