Skip to content

Commit

Permalink
Prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed May 30, 2024
1 parent 7334795 commit e7b293f
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


# When importing metatensor-torch, this will change the definition of the classes
# to include in the documentation
# to include the documentation
os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"
os.environ["RASCALINE_IMPORT_FOR_SPHINX"] = "1"

Expand Down
4 changes: 2 additions & 2 deletions docs/src/dev-docs/new-architecture.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ to these lines
checkpoint_dir="path",
)
model.save_checkpoint("final.ckpt")
model.save_checkpoint("model.ckpt")
mts_atomistic_model = model.export()
mts_atomistic_model.export("path", collect_extensions="extensions-dir/")
mts_atomistic_model.export("model.pt", collect_extensions="extensions/")
In order to follow this, a new architectures has two define two classes
Expand Down
15 changes: 0 additions & 15 deletions src/metatensor/models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,6 @@
from .utils.logging import setup_logging


# This import is necessary to avoid errors when loading an
# exported alchemical model, which depends on sphericart-torch.
# TODO: Remove this when https://github.com/lab-cosmo/metatensor/issues/512
# is ready
try:
import sphericart.torch # noqa: F401
except ImportError:
pass

try:
import rascaline.torch # noqa: F401
except ImportError:
pass


logger = logging.getLogger(__name__)


Expand Down
2 changes: 1 addition & 1 deletion src/metatensor/models/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
)
parser.add_argument(
"-e",
"--extdir",
"--extensions-dir",
type=str,
required=False,
dest="extensions_directory",
Expand Down
2 changes: 1 addition & 1 deletion src/metatensor/models/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ def export_model(model: Any, output: Union[Path, str] = "exported-model.pt") ->
torch.jit.save(model, path)
else:
mts_atomistic_model = model.export()
mts_atomistic_model.export(path)
mts_atomistic_model.export(path, collect_extensions="extensions")
2 changes: 1 addition & 1 deletion src/metatensor/models/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def train_model(
raise ArchitectureError(e)

mts_atomistic_model = model.export()
mts_atomistic_model.export(str(output_checked))
mts_atomistic_model.export(str(output_checked), collect_extensions="extensions")

###########################
# EVALUATE FINAL MODEL ####
Expand Down
3 changes: 2 additions & 1 deletion src/metatensor/models/utils/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
"ignore",
category=UserWarning,
message="neighbor",
) # TODO: this is not filtering out the warning for some reason
) # TODO: this is not filtering out the warning for some reason, therefore:
warnings.filterwarnings("ignore") # ignore all warnings if not in debug mode


def evaluate_model(
Expand Down

0 comments on commit e7b293f

Please sign in to comment.