Improve error when recursively calling Parallel #157
Indeed, there are multiple errors interacting here, and I'm not sure how to catch each of them. For your particular example:
```python
import objax
import jax.numpy as jn
import numpy as np

# 1. Create modules
mod = objax.nn.Conv2D(2, 4, 3)

def ell(x):
    return mod(x)  # before it was using p, I assume you meant using mod.

m = objax.Grad(ell, objax.VarCollection(), (0,))
p = objax.Parallel(m, mod.vars(), reduce=lambda x: x)

# 2. Replicate vars before using modules.
with mod.vars().replicate():
    print(p(np.ones((8*8, 2, 10, 10))))
```
No, that code is exactly what I meant. It's insane code, but it was minified from a bug I actually had. I meant calling
Okay, so I can suggest a few things we could catch in your example.
Yeah. I'm not sure yet what the issue is. This code is obviously wrong and stupid, but I don't know the "right" way to report that something has gone wrong with it. Maybe the recursive call into Parallel is where things go bad? Probably that should never happen. But it seems unfortunate to make the codebase uglier just to check explicitly for loops.
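One possible way to turn the recursive call into a readable error, sketched here under assumptions rather than as Objax's actual behavior: guard the parallel call with a per-wrapper "already running" flag and fail fast on re-entry. The names `forbid_reentry`, `RecursiveCallError`, and `parallel_call` are hypothetical and not part of Objax; `parallel_call` is a plain-Python stand-in for `objax.Parallel.__call__`, and the sketch assumes the recursive call happens on the same Python thread (as it would while JAX traces the function).

```python
import functools
import threading


class RecursiveCallError(RuntimeError):
    """Raised when a guarded callable is re-entered while it is still running."""


def forbid_reentry(name):
    """Hypothetical decorator (not part of Objax): fail fast with a readable
    message if the wrapped callable calls back into itself on the same thread."""

    def decorator(fn):
        # One flag per wrapped function and per thread, so unrelated wrappers
        # (e.g. two different parallel functions) never interfere.
        state = threading.local()

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if getattr(state, "active", False):
                raise RecursiveCallError(
                    f"{name} was called recursively: the function it wraps "
                    f"calls back into the same {name} object. Nesting a "
                    f"{name} call inside itself is not supported."
                )
            state.active = True
            try:
                return fn(*args, **kwargs)
            finally:
                # Reset even when the call fails, so later calls are not blocked.
                state.active = False

        return wrapper

    return decorator


# Plain-Python stand-in for the pattern in this issue: a parallel wrapper whose
# user function (ell) calls the wrapper again.
@forbid_reentry("Parallel")
def parallel_call(x):
    return ell(x)


def ell(x):
    return parallel_call(x)


if __name__ == "__main__":
    try:
        parallel_call(1.0)
    except RecursiveCallError as err:
        print(err)
```

The check costs one thread-local flag per wrapper and a few lines around the call, which may address the concern about making the codebase uglier. The trade-off is that it rejects every re-entrant call on the same thread, so it would only be appropriate if nesting a Parallel inside itself is never legitimate.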
Currently, if a parallel function recursively calls itself, you can get some incomprehensible error messages.
This is very low priority.
The error for this is
And figuring out what this means is more or less impossible.