Colab and PAX error #226

Open
Mhdaw opened this issue Jan 22, 2025 · 7 comments

Comments

@Mhdaw

Mhdaw commented Jan 22, 2025

Hello all, I tried to use the TimesFM PAX version in Colab, but it is not working and I get this error:

---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
<ipython-input-7-f2ffa192cd9e> in <cell line: 0>()
      1 # For PAX
----> 2 tfm = timesfm.TimesFm(
      3       hparams=timesfm.TimesFmHparams(
      4           backend="cpu",
      5           per_core_batch_size=32,

4 frames
/usr/local/lib/python3.11/dist-packages/timesfm/timesfm_base.py in __init__(self, hparams, checkpoint)
    237     self._horizon_start = self.context_len - self.input_patch_len
    238     self.__post_init__()
--> 239     self.load_from_checkpoint(checkpoint)
    240 
    241   def load_from_checkpoint(self, checkpoint: TimesFmCheckpoint) -> None:

/usr/local/lib/python3.11/dist-packages/timesfm/timesfm_torch.py in load_from_checkpoint(self, checkpoint)
     62           "torch_model.ckpt")
     63     self._model = ppd.PatchedTimeSeriesDecoder(self._model_config)
---> 64     loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)
     65     logging.info("Loading checkpoint from %s", checkpoint_path)
     66     self._model.load_state_dict(loaded_checkpoint)

/usr/local/lib/python3.11/dist-packages/torch/serialization.py in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1317         pickle_load_args["encoding"] = "utf-8"
   1318 
-> 1319     with _open_file_like(f, "rb") as opened_file:
   1320         if _is_zipfile(opened_file):
   1321             # The zipfile reader is going to advance the current file position.

/usr/local/lib/python3.11/dist-packages/torch/serialization.py in _open_file_like(name_or_buffer, mode)
    657 def _open_file_like(name_or_buffer, mode):
    658     if _is_path(name_or_buffer):
--> 659         return _open_file(name_or_buffer, mode)
    660     else:
    661         if "w" in mode:

/usr/local/lib/python3.11/dist-packages/torch/serialization.py in __init__(self, name, mode)
    638 class _open_file(_opener):
    639     def __init__(self, name, mode):
--> 640         super().__init__(open(name, mode))
    641 
    642     def __exit__(self, *args):

FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--google--timesfm-1.0-200m/snapshots/8775f7531211ac864b739fe776b0b255c277e2be/torch_model.ckpt'

Also, when I try the fine-tuning notebook it does not work either: it fails to install praxis and, I think, lingvo.
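
For anyone debugging the same FileNotFoundError: the traceback shows timesfm_torch.py calling torch.load on torch_model.ckpt inside the cached snapshot of google/timesfm-1.0-200m. One quick way to confirm what that snapshot actually contains is to list it. The sketch below uses only the standard library. inspect_snapshot is a hypothetical helper name, not part of timesfm, and the directory is the one printed in the traceback above.

```python
from pathlib import Path


def inspect_snapshot(snapshot_dir, expected="torch_model.ckpt"):
    """Return (expected_file_present, sorted list of files) for a snapshot directory.

    Hypothetical helper for illustration; not part of timesfm.
    """
    root = Path(snapshot_dir)
    if not root.is_dir():
        return False, []
    files = sorted(
        p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file()
    )
    return (root / expected).is_file(), files


# Snapshot directory copied from the traceback above.
SNAPSHOT_DIR = (
    "/root/.cache/huggingface/hub/models--google--timesfm-1.0-200m/"
    "snapshots/8775f7531211ac864b739fe776b0b255c277e2be"
)

found, files = inspect_snapshot(SNAPSHOT_DIR)
print("torch_model.ckpt present:", found)
for name in files:
    print("  ", name)
```

If torch_model.ckpt is absent from the listing, the downloaded checkpoint does not match the installed backend.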

@rajatsen91
Collaborator

Hi, it seems that you have installed the PyTorch version. If the Python version in your Colab is 3.11, pip install timesfm[torch] installs the torch version, and you then have to use the torch checkpoint.

If the Python version is 3.10.x, pip install timesfm[pax] installs the JAX version and you can use the JAX checkpoint.
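
The rule of thumb in the comment above can be checked from a notebook cell before installing anything. This is a minimal sketch of that rule only. suggested_timesfm_extra is a hypothetical helper, not a timesfm function, and it deliberately returns None for Python versions the comment does not mention.

```python
import sys


def suggested_timesfm_extra(version_info=sys.version_info):
    """Map a Python version to the pip extra named in the comment above.

    Hypothetical helper: 3.10 -> "pax", 3.11 -> "torch", anything else -> None.
    """
    major, minor = version_info[0], version_info[1]
    if (major, minor) == (3, 10):
        return "pax"
    if (major, minor) == (3, 11):
        return "torch"
    return None


extra = suggested_timesfm_extra()
if extra is None:
    print("Python %d.%d is not covered by the comment above" % sys.version_info[:2])
else:
    print(f"pip install timesfm[{extra}]")
```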

@Mhdaw
Author

Mhdaw commented Jan 22, 2025

Can you share a version for fine-tuning the PyTorch model, or a guide? I'm trying to write one myself but I'm not sure about the dataset (dataloader) and the loss computation (MAE or quantile loss?).
Thanks
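
On the loss question: the thread does not say which loss the TimesFM fine-tuning code uses, so the following is only a generic illustration of the two candidates named above, written in plain Python so it runs without any ML framework. mae and pinball_loss are hypothetical helper names. The pinball (quantile) loss for quantile q is max(q * (y - y_hat), (q - 1) * (y - y_hat)), averaged over points; at q = 0.5 it equals half the MAE.

```python
def mae(y_true, y_pred):
    """Mean absolute error over paired sequences of equal length."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def pinball_loss(y_true, y_pred, q):
    """Mean pinball (quantile) loss for a single quantile q in (0, 1)."""
    if not 0.0 < q < 1.0:
        raise ValueError("q must be strictly between 0 and 1")
    total = 0.0
    for t, p in zip(y_true, y_pred):
        diff = t - p
        # Under-prediction (diff > 0) is weighted by q,
        # over-prediction (diff < 0) by (1 - q).
        total += max(q * diff, (q - 1.0) * diff)
    return total / len(y_true)


# Toy values, made up purely for illustration.
y_true = [1.0, 2.0, 3.0]
y_pred = [1.0, 3.0, 5.0]

print("MAE:", mae(y_true, y_pred))
print("pinball q=0.5:", pinball_loss(y_true, y_pred, 0.5))
print("pinball q=0.9:", pinball_loss(y_true, y_pred, 0.9))
```

In a real fine-tuning loop the same formulas would be applied to tensors, typically averaged over several quantiles.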

@rajatsen91
Collaborator

@Mhdaw
Author

Mhdaw commented Jan 22, 2025

Thanks for your fast response.

@khoat007

I ran into the same issue. I also switched the Colab runtime to Python 3.10, ran
pip install timesfm[pax], and loaded the checkpoint, but it didn't work. It was fine last week. Hopefully your PR will fix the issue. There are so many moving parts in this great repo and in timesfm. Google Colab recently moved to Python 3.11, and that broke fine-tuning.
@rajatsen91, would it be possible to provide a stable GCP instance that we can use?
Thanks,

@rajatsen91
Collaborator

We will try to release a Docker image for the pax and torch versions. Once the PR is released, fine-tuning will be available in torch, so that should work on Python 3.11.x.

@khoat007

Hi @rajatsen91,
Could you include notebooks that use torch for fine-tuning and for loading a fine-tuned model for testing? We are developing some solutions with your model, and currently we can't load our previously fine-tuned models anymore! Below is the code, but we can't even get past this line:

"from timesfm import patched_decoder" <-----

import timesfm  # timesfm foundational model
import gc  # Garbage Collector interface
import numpy as np
import pandas as pd
from timesfm import patched_decoder  # decoder in patches (N=L/p)
from timesfm import data_loader


ModuleNotFoundError Traceback (most recent call last)
in <cell line: 0>()
3 import numpy as np
4 import pandas as pd
----> 5 from timesfm import patched_decoder # decoder in patches (N=L/p)
6 from timesfm import data_loader

/usr/local/lib/python3.11/dist-packages/timesfm/patched_decoder.py in
23 from jax import lax
24 import jax.numpy as jnp
---> 25 from praxis import base_layer
26 from praxis import base_model
27 from praxis import layers

ModuleNotFoundError: No module named 'praxis'


NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.

For Torch

# Directory for fine-tuning
CHECKPOINT_DIR = './finetune/health/2025-01-15_00.48.30'

tfm_200M = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        context_len=context_len,
        horizon_len=horizon_len,
        input_patch_len=32,
        output_patch_len=128,
        num_layers=20,
        model_dims=1280,
        per_core_batch_size=32,
        backend="gpu",  # This is using 'pytorch'
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        path=CHECKPOINT_DIR),
)
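
The ModuleNotFoundError earlier in this comment comes from timesfm/patched_decoder.py importing jax and praxis at module level, as the traceback shows. A small pre-flight check makes that failure explicit before the import is attempted. This is a sketch using only the standard library. missing_modules is a hypothetical helper, and the dependency list contains only the two packages visible in the traceback, so the real module may need more.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names from `names` that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Only the packages visible in the traceback above; the full list may be longer.
PAX_DEPS = ["jax", "praxis"]

missing = missing_modules(PAX_DEPS)
if missing:
    print("Skipping timesfm.patched_decoder; missing:", ", ".join(missing))
else:
    print("PAX dependencies found; the import can be attempted")
```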

3 participants