Misleading print makes users believe Jax is loaded #219

Open
dandarm opened this issue Jan 16, 2025 · 1 comment

Comments

@dandarm

dandarm commented Jan 16, 2025

In the init file:

try:
    print("Loaded Jax TimesFM.")
    from timesfm.timesfm_jax import TimesFmJax as TimesFm
    from timesfm import data_loader
except Exception as _:
    print("Loaded PyTorch TimesFM.")
    from timesfm.timesfm_torch import TimesFmTorch as TimesFm

Could you change the print statements so that users are not led to believe Jax was loaded correctly when it was not? The Jax message is printed before the import is attempted, so when the import fails both messages appear.
The output reads as if both the Jax and Torch versions were loaded, which is clearly not the case. This explains the hours I spent debugging why I could not load the Jax checkpoint as described in the finetuning notebook :)

Thanks, cheers

@rajatsen91
Collaborator

Good point, sorry about that. We can change the order of the print statements so that each message is printed only after its import succeeds, or if you want you can also submit a PR.
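For illustration, the reordering could look roughly like the sketch below. It is not the project's actual fix: `load_first_available` is a hypothetical helper written for this example, and it assumes that a failed backend surfaces as an `ImportError` or `AttributeError` (the original snippet catches the broader `Exception`). The point is only that the "Loaded ..." message is printed after the import has succeeded, never before.

```python
import importlib


def load_first_available(candidates):
    """Return (label, attribute) for the first backend that imports cleanly.

    candidates: iterable of (label, module_name, attribute_name) tuples,
    tried in order. Hypothetical helper, not part of timesfm.
    """
    errors = []
    for label, module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
            backend = getattr(module, attr_name)
        except (ImportError, AttributeError) as exc:
            # Remember why this backend was skipped instead of hiding it.
            errors.append(f"{label}: {exc!r}")
            continue
        # Report success only once the import has actually succeeded.
        print(f"Loaded {label} TimesFM.")
        return label, backend
    raise ImportError("No backend could be imported: " + "; ".join(errors))


# Possible use in the init file, with the module and class names quoted
# in the issue (left commented out so the sketch runs without timesfm):
#
# _, TimesFm = load_first_available([
#     ("Jax", "timesfm.timesfm_jax", "TimesFmJax"),
#     ("PyTorch", "timesfm.timesfm_torch", "TimesFmTorch"),
# ])
```

With this ordering, a failed Jax import produces only the PyTorch message, and the collected error text is still available if neither backend loads. The extra `data_loader` import from the Jax branch is left out of the sketch for brevity.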
