Non-hashable type error #1

Open
lucfra opened this issue Feb 22, 2022 · 6 comments

lucfra commented Feb 22, 2022

Dear all,

I've read the BCD-nets paper, which I found very interesting.
I am now trying to reproduce your results, but unfortunately I have run into the following error:

```
line 949, in <module>
    ) = parallel_gradient_step(
ValueError: Non-hashable static arguments are not supported.
An error occured during a call to 'parallel_gradient_step' while trying to hash
 an object of type <class 'numpy.ndarray'>,
 [[ 4.82755829e-01  2.30017473e+00  1.29824051e+00  1.94172572e+00 .....
```

which refers to this line of your code.
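For reference, here is a minimal standalone snippet (not from your codebase; the function is made up) that should trigger the same error, since jit's static arguments must be hashable and a plain numpy.ndarray is not:

```python
import functools
import jax
import numpy as np

# Argument 1 ("factor") is declared static, so jit must hash it at call time.
@functools.partial(jax.jit, static_argnums=(1,))
def scale(x, factor):
    return x * factor

scale(np.ones(3), 2.0)            # fine: a Python float is hashable
scale(np.ones(3), np.array(2.0))  # ValueError: Non-hashable static arguments...
```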

I must say that I have no experience with jax.
For context, I did not manage to install all the required packages using your environment.yml, so I went ahead with a manual installation instead. My jax version is 0.3.1.

P.S.: the code was not compatible right away. To make it runnable, I did the following:

  • Replaced jax.partial (which is no longer available) with functools.partial (I read that jax.partial was an accidental leak); a sketch of this change follows the list.
  • Copy-pasted _conv_transpose_padding into nux.util.convolution from the jax version you used, since I couldn't find _conv_transpose_padding in 0.3.1.
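
For anyone hitting the same incompatibility, here is a minimal sketch of the first change (the decorated function is hypothetical, for illustration only):

```python
# On jax <= 0.2.20 the code could use the accidentally-exported jax.partial:
#   @jax.partial(jax.jit, static_argnums=(0,))
# On newer jax, functools.partial is a drop-in replacement:
import functools
import jax

@functools.partial(jax.jit, static_argnums=(0,))
def loss_fn(num_samples, params):  # hypothetical function, not from the repo
    return params.sum() * num_samples
```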

Any help is much appreciated.

Cheers,
Luca

lucfra commented Mar 9, 2022

Hi all,
any news on this?

jithendaraa commented Mar 9, 2022

I think jax.partial was removed as of 0.2.21. If you use the same version as the authors (0.2.18), you don't really need to switch to functools.

@jithendaraa

Your error also seems related to this change from jax.partial to functools.partial. I am trying to reproduce the experiments as well, and though I am not fully done, I don't get the error you are getting from parallel_gradient_step. I would strongly suggest downgrading to jax 0.2.18. Or, if you must stay on 0.3.x (because of which other errors might pop up), you could try converting the np.ndarray variable into a jax.numpy array (so it is jit-compilable), or add that variable's index in the function definition to partial's static_argnums argument.
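
Roughly something like this (illustrative names only, not the actual parallel_gradient_step signature):

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np

# Only argument 2 ("num_steps", a hashable int) is static; the arrays are
# traced, so jit never needs to hash them.
@functools.partial(jax.jit, static_argnums=(2,))
def gradient_step(params, grads, num_steps):
    return params - 0.1 * grads * num_steps

params = jnp.zeros(3)
grads = np.random.randn(3)
out = gradient_step(params, jnp.asarray(grads), 10)
# Listing an array argument's index in static_argnums instead (e.g.
# static_argnums=(1, 2)) would reproduce the "Non-hashable static arguments"
# ValueError, because np.ndarray is not hashable.
```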

@jithendaraa

Also, does your jax installation have access to a GPU? (You can check this using jax.devices().)
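
For example:

```python
import jax
print(jax.devices())
# Prints something like [GpuDevice(id=0, process_index=0)] with GPU support,
# or [CpuDevice(id=0)] on a CPU-only install (the exact repr varies by version).
```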

lucfra commented Mar 14, 2022

Hi @jithendaraa,

thanks for your message. I initially tried installing 0.2.18, but with no luck. I'll try the second approach you mentioned.

Please let me know if you have luck reproducing the experiments.

Cheers,
Luca

@jithendaraa

Hey @lucfra,
That is a bit surprising, since I see no reason why it wouldn't work; jax.partial was removed only in versions >= 0.2.20, so technically it should work, and it does work for me. Regarding reproducing the experiments: the entire codebase runs and trains for me without issues. I have not fully run the experiments mentioned in the paper to reproduce the results, but I might in the near future.

Thanks,
Jith
