
Allow exporting models from remote locations #352

Merged: 8 commits from download-model into main on Nov 8, 2024

Conversation

@PicoCentauri (Contributor)

@PicoCentauri commented Oct 3, 2024

Fixes #343

To allow this, I added a new function load_model which loads a model either from disk or from a URL and returns it. The CLI syntax is not changed, and from Python one has to do

from metatrain.utils.io import load_model

model = load_model(
    path="https://my.url.com/fancy_model.ckpt",
    architecture_name=""experimental.soap_bpnn",
)
model.export()

It also works for already exported models, even without the architecture_name:

model = load_model("https://my.url.com/fancy_model.pt")

which makes models directly usable for MD, for example inside the MetatensorCalculator for ASE.
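
For illustration, a minimal sketch of that ASE use case (the URL and the structure are placeholders, and the calculator import path is assumed to be metatensor.torch.atomistic.ase_calculator):

from ase.build import molecule
from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator

from metatrain.utils.io import load_model

# load the already exported model straight from a (placeholder) URL ...
model = load_model("https://my.url.com/fancy_model.pt")

# ... and attach it to an ASE structure as a calculator
atoms = molecule("H2O")
atoms.calc = MetatensorCalculator(model)
print(atoms.get_potential_energy())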

We can simplify the imports but let me know if you are happy with the API.

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

📚 Documentation preview 📚: https://metatrain--352.org.readthedocs.build/en/352/

@abmazitov (Contributor)

Very cool, thank you @PicoCentauri! I actually like how this works! There are a few things, however, which are still not totally clear to me.

  1. What are we planning to store remotely in the end? Is it a checkpoint (.ckpt file) or an exported model (.pt file)? I would personally vote for the latter, but in this case I'm not sure if the extensions can be exported properly.
  2. Does urllib.urlretrieve use caching?
  3. AFAIK, MetatensorCalculator uses both the model and the extensions directory as input arguments. Does it mean that we have to first export the model to store the extensions on disk, and then load both the model and the extensions back to the MetatensorCalculator? Maybe @Luthaf could say more on this?

@PicoCentauri (Contributor Author)

  1. What are we planning to store remotely in the end? Is it a checkpoint (.ckpt file) or an exported model (.pt file). I would personally vote for the latter, but in this case I'm not sure if the extensions can be exported properly.

This is something I discussed as well with @Luthaf and @frostedoyster. Storing already exported models (.pt) is nice and preferred for standalone models. But, if your architecture uses extensions, we have to rebuild these extensions for the platform on which you want to run the downloaded model. That is why we should maybe store the final checkpoints for these cases. For a smoother user experience, models should keep a version (see also #351) to avoid confusing errors when trying to create the extensions on export.

  1. Does urllib.urlretrieve use caching?

I am not sure, but I don't think so. It creates a temporary file. I know caching is everybody's darling, but it can be hard to implement: for example, based on which hash do we create the cache entry, the URL or something derived from the model itself? I will look into this and will add caching in a future version.
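
As a hedged illustration of that open question (not what this PR implements): a cache keyed only on the URL is simple, but it would not notice a changed model behind the same URL.

import hashlib

def cache_key(url: str) -> str:
    # hash only the URL: cheap and deterministic, but a re-uploaded model under
    # the same URL would silently reuse the stale cached file
    return hashlib.sha256(url.encode("utf-8")).hexdigest()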

  1. AFAIK, MetatensorCalculator uses both the model and the extensions directory as input arguments. Does it mean that we have to first export the model to store the extensions on disk, and then load both the model and the extensions back to the MetatensorCalculator? Maybe @Luthaf could say more on this?

Yes, we have to recreate the extensions. The extensions also depend on the version of the architecture, so we may have to keep every version of the architecture around. See also my answer to your first point.

@abmazitov (Contributor)

abmazitov commented Oct 4, 2024

Okay, I'm fine with saving checkpoints, either with or without TorchScripted models, if we need this. However, saving the checkpoint to disk and then loading it back to activate the model with extensions seems a bit counterintuitive... Maybe we can make the load_model function store the extensions automatically and return an instance of MetatensorAtomisticModel? I.e. load_model could fetch the checkpoint from the URL, export the model with its extensions to a specific place on disk (see also the comment below), and then return the loaded MetatensorAtomisticModel with the extensions already loaded.

I also asked ChatGPT what it thinks, and one good idea was to create a ~/.metatensor/ directory to store the cached checkpoints and models. In this case, for every version of a model we can create a folder with the checkpoint, the exported model itself, and its extensions, and access it later. If we come up with a proper naming convention so that every model and every version has its own unique directory, we can solve the caching problem as well (and avoid exporting every time load_model is called).
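
A minimal sketch of that layout, with a hypothetical helper (not metatrain API), assuming the model name and version are known:

from pathlib import Path

def cache_dir(model_name: str, version: str) -> Path:
    # one unique directory per model and version under ~/.metatensor/, e.g.
    # ~/.metatensor/soap_bpnn/1.0.0/ holding model.ckpt, model.pt and extensions/
    directory = Path.home() / ".metatensor" / model_name / version
    directory.mkdir(parents=True, exist_ok=True)
    return directory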

@PicoCentauri force-pushed the download-model branch 2 times, most recently from c94be49 to 64b5e8b on October 7, 2024 13:31
@PicoCentauri marked this pull request as ready for review on October 7, 2024 13:31
@PicoCentauri (Contributor Author)

Hmm, I mean what you want is currently possible via

import torch

from metatrain.cli import export_model
from metatrain.utils.io import load_model

model = torch.jit.load(export_model(load_model("experimental.pet", 'https://XXX.com/mymodel.ckpt')))

I wouldn't wrap this whole functionality into one function called load_model. Each part is used by different parts of metatrain.

But we can provide something that implements this workflow for the Python API that we are planning to write. What do you think?

@PicoCentauri (Contributor Author)

Regarding the PR in general, I would add caching and the actual example for the Python API once we have all the ingredients together.

@Luthaf (Member)

Luthaf commented Oct 7, 2024

AFAIK, MetatensorCalculator uses both the model and the extensions directory as input arguments. Does it mean that we have to first export the model to store the extensions on disk, and then load both the model and the extensions back to the MetatensorCalculator? Maybe @Luthaf could say more on this?

So, if you have a MetatensorAtomisticModel instance, you already have all extensions loaded, and you can create a calculator straight away. The extensions argument is only useful when trying to load the model from a path.

In the current version of the code, the extensions will be loaded when loading the architecture (architecture = import_architecture(architecture_name)).

I'm not sure what happens with architecture.__model__.load_checkpoint() though; is it required to return a MetatensorAtomisticModel? If so, this should work:

model = load_model("experimental.pet", 'https://XXX.com/mymodel.ckpt')
calculator = MetatensorCalculator(model)

If not, maybe we should clarify a bit how this feature interacts with checkpoint/exported models.

But, if your architecture uses extensions we have to rebuild these extensions for the platform you want to run the downloaded model. That is why we should maybe store the final checkpoints for these

I'm not sure I see why we would need to store checkpoints to be able to re-create the extensions? Importing the architecture should be enough to load all extensions (ignoring for now all questions of API stability & versioning).

@PicoCentauri (Contributor Author)

model = load_model("experimental.pet", 'https://XXX.com/mymodel.ckpt')
calculator = MetatensorCalculator(model)

Yes, I think this should indeed work!

I'm not sure I see why we would need to store checkpoints to be able to re-create the extensions? Importing the architecture should be enough to load all extensions (ignoring for now all questions of API stability & versioning).

For ASE, sure, but what if you want to run the full-fledged command line experience? To have the correct extensions exported you need a checkpoint or the architecture name (of course ignoring for now all questions of API stability & versioning).

While writing this I see that a model plus the architecture name is really enough to construct the extensions. Maybe we should change the code to always write the extensions. Currently we are just writing the model with a new name:

if is_exported(model):
    logger.info(f"The model is already exported. Saving it to `{path}`.")
    torch.jit.save(model, path)
else:
    extensions_path = "extensions/"
    logger.info(
        f"Exporting model to '{path}' and extensions to '{extensions_path}'"
    )
    mts_atomistic_model = model.export()
    mts_atomistic_model.save(path, collect_extensions=extensions_path)
    logger.info("Model exported successfully")

but probably we should do something like

if not is_exported(model):
    model = model.export()

extensions_path = "extensions/"
model.save(path, collect_extensions=extensions_path)
logger.info(f"Model exported to '{path}' and extensions to '{extensions_path}'")

Does this make sense?

@Luthaf (Member)

Luthaf commented Oct 8, 2024

but probably we should do something like [...]

Yes, this looks a lot cleaner!

@PicoCentauri (Contributor Author)

It is much cleaner, but unfortunately once a model is exported and reloaded the save method comes from torch and not from metatensor. You get an error like this when you try to re-export:

In [7]: model.save("foo.pt", collect_extensions="extensions/")
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 1
----> 1 model.save("foo.pt", collect_extensions="extensions/")

File ~/repos/lab-cosmo/metatrain/.venv/lib/python3.12/site-packages/torch/jit/_script.py:753, in RecursiveScriptModule.save(self, f, **kwargs)
    744 def save(self, f, **kwargs):
    745     r"""Save with a file-like object.
    746 
    747     save(f, _extra_files={})
   (...)
    751     DO NOT confuse these two functions when it comes to the 'f' parameter functionality.
    752     """
--> 753     return self._c.save(str(f), **kwargs)

TypeError: save(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.ScriptModule, filename: str, _extra_files: dict[str, str] = {}) -> None

Invoked with: <torch.ScriptObject object at 0x12f855d20>, 'foo.pt'; kwargs: collect_extensions='extensions/'

So I think we should try to expose our save function when we export, but I don't know if this is possible.

@PicoCentauri (Contributor Author)

This PR needs metatensor/metatensor#761 to be merged, and a metatensor-torch release, before it can be continued and finished.

@frostedoyster (Collaborator)

Sorry for the random comment, but it is important to keep in mind that MetatensorAtomisticModels are not torchscripted unless they're saved to a file and re-loaded. This might be important for optimal speed

@PicoCentauri (Contributor Author)

Getting closer. Now on re-export I get an error that module.forward.__annotations__ is missing for an already exported model:

module = RecursiveScriptModule(
  original_name=SoapBpnn
  (soap_calculator): RecursiveScriptModule(original_name=SoapPowerSpec...ecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(original_name=CompositionModel)
  )
)

    def _check_annotation(module: torch.nn.Module):
        # check annotations on forward
>       annotations = module.forward.__annotations__

@PicoCentauri (Contributor Author)

Sorry for the random comment, but it is important to keep in mind that MetatensorAtomisticModels are not torchscripted unless they're saved to a file and re-loaded. This might be important for optimal speed

Can't you script them without saving and reloading?

@frostedoyster (Collaborator)

You can, but that's not what we're doing for now (see MetatensorAtomisticModel class in metatensor)

@Luthaf (Member)

Luthaf commented Nov 6, 2024

You should be able to do

inner = ...
model = MetatensorAtomisticModel(inner, ...)

scripted = torch.jit.script(model)

And then use scripted, without loading/unloading.

You would lose the ability to save the model though, unless we refactor the code for this a bit (make it a freestanding function, and call save_atomistic_model(self) in MetatensorAtomisticModel.save).
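
A rough sketch of that refactor; save_atomistic_model is hypothetical here and the extension collection is elided. The point is only that a freestanding function could also be called on an already scripted model:

import torch

def save_atomistic_model(model, path, collect_extensions=None):
    # the real implementation would collect the TorchScript extensions into
    # `collect_extensions` before writing the file
    if not isinstance(model, torch.jit.ScriptModule):
        model = torch.jit.script(model)
    torch.jit.save(model, path)

# usage on a scripted model, which otherwise only exposes torch's own save():
# save_atomistic_model(scripted, "exported.pt", collect_extensions="extensions/")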

@frostedoyster (Collaborator) left a comment

Do we have a test with a checkpoint export from a remote location? Are you working on it?


.. code-block:: bash

    mtt export model.ckpt -o model.pt
    mtt export experimental.soap_bpnn model.ckpt -o model.pt
Collaborator

Completely unrelated, but I will open an issue to have this removed (it should be easy and I don't like it too much)

Contributor Author

You don't like the syntax itself or the line in the docs?

We can't remove the syntax itself, because we need the corresponding architecture name to load a checkpoint.

Collaborator

Yes, I would ideally like to go back to mtt export model.ckpt -o model.pt

Collaborator

Having the name on the command line should be avoidable by requiring an architecture_name field in the checkpoint (but then this rule must be enforced for all architectures and added to "how to add a new architecture")
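
A hedged sketch of that rule, assuming checkpoints are plain torch.save dictionaries and using a hypothetical architecture_name entry:

import torch

def read_architecture_name(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "architecture_name" not in checkpoint:
        raise ValueError(
            f"checkpoint '{checkpoint_path}' does not declare an architecture_name; "
            "pass the architecture explicitly on the command line instead"
        )
    return checkpoint["architecture_name"]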

Contributor Author

Yes we could do this. Good idea!

Comment on lines +62 to +64
extras = # architectures used in the package tests
    soap-bpnn
    pet
Collaborator

Is PET actually used in the package tests? I can't find it

Contributor Author

It is hidden a bit. I need the PET requirements to be installed to check for a wrong architecture name in this test.
I did not find another way to trigger the test.

def test_load_model_unknown_model():
    architecture_name = "experimental.pet"
    path = RESOURCES_PATH / "model-32-bit.ckpt"
    match = (
        f"path '{path}' is not a valid model file for the {architecture_name} "
        "architecture"
    )
    with pytest.raises(ValueError, match=match):
        load_model(path, architecture_name=architecture_name)

Collaborator

Ahh fair. My concern was that PET takes a while to install (thanks to its compiled extension). I will open an issue to track this

Contributor Author

Yeah, I am also not happy, maybe one can also monkeypatch this...

Collaborator

Don't worry, we should be able to take it out soon-ish

@PicoCentauri (Contributor Author)

PicoCentauri commented Nov 8, 2024

Do we have a test with a checkpoint export from a remote location? Are you working on it?

Everything is already there in test_load_model_checkpoint and test_load_model_exported. I updated the docstring of load_model to make it clearer that we can load both checkpoints and exported models.

@pytest.mark.parametrize(
    "path",
    [
        RESOURCES_PATH / "model-32-bit.ckpt",
        str(RESOURCES_PATH / "model-32-bit.ckpt"),
        f"file:{str(RESOURCES_PATH / 'model-32-bit.ckpt')}",
    ],
)
def test_load_model_checkpoint(path):
    model = load_model(path, architecture_name="experimental.soap_bpnn")
    assert type(model) is SoapBpnn


@pytest.mark.parametrize(
    "path",
    [
        RESOURCES_PATH / "model-32-bit.pt",
        str(RESOURCES_PATH / "model-32-bit.pt"),
        f"file:{str(RESOURCES_PATH / 'model-32-bit.pt')}",
    ],
)
def test_load_model_exported(path):
    model = load_model(path, architecture_name="experimental.soap_bpnn")
    assert type(model) is MetatensorAtomisticModel

@frostedoyster (Collaborator)

Sorry for my ignorance, what is f"file:{str(RESOURCES_PATH / 'model-32-bit.ckpt')}"?
I was looking for an https:// test, but perhaps it's the same

@PicoCentauri (Contributor Author)

PicoCentauri commented Nov 8, 2024

Sorry for my ignorance, what is f"file:{str(RESOURCES_PATH / 'model-32-bit.ckpt')}"?
I was looking for an https:// test, but perhaps it's the same

No, it is fine. Yes, it is the same, because we are using urllib to do the heavy lifting of downloading files with a common API. If the path is a supported URL format recognized by urlparse, we use urlretrieve, which "downloads" the file to a temporary folder and returns the path. urlparse recognizes prefixes like https:// and ftp://, but also file:. So using file: in the test triggers the "url" branch of the code, and it should also work with real remote locations like https://. If it doesn't, this is a problem of urllib.
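
For illustration, a minimal sketch of that dispatch (a simplification of what the comment describes, not the exact metatrain code):

from urllib.parse import urlparse
from urllib.request import urlretrieve

def resolve_path(path) -> str:
    # if urlparse finds a scheme (https, ftp, file, ...), fetch the file to a
    # temporary location and return that local path; otherwise use the path as-is
    if urlparse(str(path)).scheme:
        local_path, _headers = urlretrieve(str(path))
        return local_path
    return str(path)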

@frostedoyster (Collaborator) left a comment

Amazing, and also thanks for the explanation! It's all ready IMO

@PicoCentauri merged commit 5eefa32 into main on Nov 8, 2024
12 checks passed
@PicoCentauri deleted the download-model branch on November 8, 2024 14:08