Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PET-JAX #82

Closed
wants to merge 19 commits into from
Closed

PET-JAX #82

wants to merge 19 commits into from

Conversation

frostedoyster
Copy link
Collaborator

@frostedoyster frostedoyster commented Feb 16, 2024

Sorry @PicoCentauri but I couldn't figure out how to use the template.

TODO: everything, including solving a harmless bug where all models are initialized twice and figuring out some float32 composition inaccuracies


📚 Documentation preview 📚: https://metatensor-models--82.org.readthedocs.build/en/82/

@Luthaf
Copy link
Member

Luthaf commented Feb 19, 2024

Interesting! What is your plan for torch script export? Having both a JAX and Torch version, and loading weights from one inside the other?

If so, what is the advantage of the JAX version?

@frostedoyster frostedoyster marked this pull request as ready for review February 26, 2024 07:01
@frostedoyster frostedoyster changed the base branch from main to stray-dtype February 26, 2024 07:02
Base automatically changed from stray-dtype to main February 26, 2024 08:30
@PicoCentauri PicoCentauri deleted the pet-jax branch May 29, 2024 16:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants