Skip to content

Commit

Permalink
Add skeleton for exporter
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Dec 13, 2023
1 parent c63123d commit a049918
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
"ase",
"torch",
"hydra-core",
"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
#"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
"metatensor-core",
"metatensor-operations",
"metatensor-torch",
Expand Down
1 change: 0 additions & 1 deletion src/metatensor/models/cli/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def eval_model(
:param structure_path: Path to a structure file which should be considered for the
evaluation.
:param output_path: Path to save the predicted values
"""

model = load_model(model_path)
Expand Down
35 changes: 31 additions & 4 deletions src/metatensor/models/cli/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,41 @@


def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
if export_model.__doc__ is not None:
description = export_model.__doc__.split(r":param")[0]
else:
description = None

parser = subparser.add_parser(
"export",
description=export_model.__doc__,
description=description,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.set_defaults(callable="export_model")

parser.add_argument(
"-m",
"--model",
dest="model_path",
type=str,
required=True,
help="Path to a saved model",
)
parser.add_argument(
"-o",
"--output",
dest="output_path",
type=str,
required=False,
default="exported.pt",
help="Export path for the model.",
)


def export_model(model_path: str, output_path: str) -> None:
"""Export a pretrained model to run MD simulations
def export_model():
"""export a model"""
print("Run exort...")
:param model_path: Path to a saved model
:param output_path: Path to save the exported model
"""
raise NotImplementedError("model exporting is not implemented yet.")

0 comments on commit a049918

Please sign in to comment.