Skip to content

Commit

Permalink
Remove set_default_dtype and get_default_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Mar 20, 2024
1 parent c46dd5a commit a7f21d7
Show file tree
Hide file tree
Showing 29 changed files with 124 additions and 116 deletions.
2 changes: 1 addition & 1 deletion src/metatensor/models/cli/conf/base.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
device: "cpu"
base_precision: 64
base_precision: 32
seed: null
7 changes: 6 additions & 1 deletion src/metatensor/models/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ def eval_model(
)
logger.info("Setting up evaluation set.")

# TODO: once https://github.com/lab-cosmo/metatensor/pull/551 is merged and released
# use capabilities instead of this workaround
dtype = next(model.parameters()).dtype

if isinstance(output, str):
output = Path(output)

Expand All @@ -194,11 +198,12 @@ def eval_model(
eval_systems = read_systems(
filename=options["systems"]["read_from"],
fileformat=options["systems"]["file_format"],
dtype=dtype,
)

# Predict targets
if hasattr(options, "targets"):
eval_targets = read_targets(options["targets"])
eval_targets = read_targets(options["targets"], dtype=dtype)
eval_dataset = Dataset(system=eval_systems, energy=eval_targets["energy"])
_eval_targets(model, eval_dataset)
else:
Expand Down
22 changes: 9 additions & 13 deletions src/metatensor/models/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def _train_model_hydra(options: DictConfig) -> None:
necessary options for dataset preparation, model hyperparameters, and training.
"""
if options["base_precision"] == 64:
torch.set_default_dtype(torch.float64)
dtype = torch.float64
elif options["base_precision"] == 32:
torch.set_default_dtype(torch.float32)
dtype = torch.float32
elif options["base_precision"] == 16:
torch.set_default_dtype(torch.float16)
dtype = torch.float16
else:
raise ValueError("Only 64, 32 or 16 are possible values for `base_precision`.")

Expand Down Expand Up @@ -239,11 +239,9 @@ def _train_model_hydra(options: DictConfig) -> None:
train_systems = read_systems(
filename=train_options["systems"]["read_from"],
fileformat=train_options["systems"]["file_format"],
dtype=torch.get_default_dtype(),
)
train_targets = read_targets(
conf=train_options["targets"], dtype=torch.get_default_dtype()
dtype=dtype,
)
train_targets = read_targets(conf=train_options["targets"], dtype=dtype)
train_datasets.append(Dataset(system=train_systems, **train_targets))

train_size = 1.0
Expand Down Expand Up @@ -293,11 +291,9 @@ def _train_model_hydra(options: DictConfig) -> None:
test_systems = read_systems(
filename=test_options["systems"]["read_from"],
fileformat=test_options["systems"]["file_format"],
dtype=torch.get_default_dtype(),
)
test_targets = read_targets(
conf=test_options["targets"], dtype=torch.get_default_dtype()
dtype=dtype,
)
test_targets = read_targets(conf=test_options["targets"], dtype=dtype)
test_dataset = Dataset(system=test_systems, **test_targets)
test_datasets.append(test_dataset)

Expand Down Expand Up @@ -346,10 +342,10 @@ def _train_model_hydra(options: DictConfig) -> None:
validation_systems = read_systems(
filename=validation_options["systems"]["read_from"],
fileformat=validation_options["systems"]["file_format"],
dtype=torch.get_default_dtype(),
dtype=dtype,
)
validation_targets = read_targets(
conf=validation_options["targets"], dtype=torch.get_default_dtype()
conf=validation_options["targets"], dtype=dtype
)
validation_dataset = Dataset(
system=validation_systems, **validation_targets
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ase
import torch
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelCapabilities,
Expand Down Expand Up @@ -31,7 +30,7 @@ def test_prediction_subset():

alchemical_model = Model(capabilities, DEFAULT_HYPERS["model"])
system = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
system = systems_to_torch(system, dtype=torch.get_default_dtype())
system = systems_to_torch(system)
system = get_system_with_neighbors_lists(
system, alchemical_model.requested_neighbors_lists()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ def test_rotational_invariance():
system = ase.io.read(DATASET_PATH)
original_system = copy.deepcopy(system)
system.rotate(48, "y")
original_system = systems_to_torch(original_system, dtype=torch.get_default_dtype())
original_system = systems_to_torch(original_system)
original_system = get_system_with_neighbors_lists(
original_system, alchemical_model.requested_neighbors_lists()
)
system = systems_to_torch(system, dtype=torch.get_default_dtype())
system = systems_to_torch(system)
system = get_system_with_neighbors_lists(
system, alchemical_model.requested_neighbors_lists()
)
Expand All @@ -63,7 +63,7 @@ def test_rotational_invariance():
check_consistency=True,
)

assert torch.allclose(
torch.testing.assert_close(
original_output["energy"].block().values,
rotated_output["energy"].block().values,
)
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def test_regression_init():

# Predict on the first five systems
systems = ase.io.read(DATASET_PATH, ":5")
systems = [
systems_to_torch(system, dtype=torch.get_default_dtype()) for system in systems
]
systems = [systems_to_torch(system) for system in systems]
systems = [
get_system_with_neighbors_lists(
system, alchemical_model.requested_neighbors_lists()
Expand All @@ -71,7 +69,9 @@ def test_regression_init():

expected_output = torch.tensor([[-1.9819], [0.1507], [1.6116], [3.4118], [0.8383]])

assert torch.allclose(output["U0"].block().values, expected_output, atol=1e-4)
torch.testing.assert_close(
output["U0"].block().values, expected_output, rtol=1e-05, atol=1e-4
)


def test_regression_train():
Expand All @@ -83,7 +83,7 @@ def test_regression_train():
np.random.seed(0)
torch.manual_seed(0)

systems = read_systems(DATASET_PATH, dtype=torch.get_default_dtype())
systems = read_systems(DATASET_PATH)
conf = {
"U0": {
"quantity": "energy",
Expand All @@ -95,7 +95,7 @@ def test_regression_train():
"virial": False,
}
}
targets = read_targets(OmegaConf.create(conf), dtype=torch.get_default_dtype())
targets = read_targets(OmegaConf.create(conf))
dataset = Dataset(system=systems, U0=targets["U0"])

hypers = DEFAULT_HYPERS.copy()
Expand Down Expand Up @@ -137,4 +137,6 @@ def test_regression_train():
[[-118.6454], [-106.1644], [-137.0310], [-164.7832], [-139.8678]]
)

assert torch.allclose(output["U0"].block().values, expected_output, atol=1e-4)
torch.testing.assert_close(
output["U0"].block().values, expected_output, rtol=1e-05, atol=1e-4
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
np.random.seed(0)
torch.manual_seed(0)

systems = read_systems(DATASET_PATH, dtype=torch.get_default_dtype())
systems = read_systems(DATASET_PATH)
nl_options = NeighborsListOptions(
cutoff=5.0,
full_list=True,
Expand All @@ -50,18 +50,18 @@

def test_systems_to_torch_alchemical_batch():
batch_dict = systems_to_torch_alchemical_batch(systems, nl_options)
assert torch.allclose(batch_dict["positions"], batch.pos)
assert torch.allclose(batch_dict["cells"], batch.cell)
assert torch.allclose(batch_dict["numbers"], batch.numbers)
torch.testing.assert_close(batch_dict["positions"], batch.pos)
torch.testing.assert_close(batch_dict["cells"], batch.cell)
torch.testing.assert_close(batch_dict["numbers"], batch.numbers)
index_1, counts_1 = torch.unique(batch_dict["batch"], return_counts=True)
index_2, counts_2 = torch.unique(batch.batch, return_counts=True)
assert torch.allclose(index_1, index_2)
assert torch.allclose(counts_1, counts_2)
torch.testing.assert_close(index_1, index_2)
torch.testing.assert_close(counts_1, counts_2)
offset_1, counts_1 = torch.unique(batch_dict["edge_offsets"], return_counts=True)
offset_2, counts_2 = torch.unique(batch.edge_offsets, return_counts=True)
assert torch.allclose(offset_1, offset_2)
assert torch.allclose(counts_1, counts_2)
assert torch.allclose(batch_dict["batch"], batch.batch)
torch.testing.assert_close(offset_1, offset_2)
torch.testing.assert_close(counts_1, counts_2)
torch.testing.assert_close(batch_dict["batch"], batch.batch)


def test_alchemical_model_inference():
Expand Down Expand Up @@ -114,4 +114,4 @@ def test_alchemical_model_inference():
edge_offsets=batch.edge_offsets,
batch=batch.batch,
)
assert torch.allclose(output["energy"].block().values, original_output)
torch.testing.assert_close(output["energy"].block().values, original_output)
6 changes: 4 additions & 2 deletions src/metatensor/models/experimental/alchemical_model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,10 @@ def train(
model.set_basis_normalization_factor(average_number_of_neighbors)

device = devices[0] # only one device, as we don't support multi-gpu for now
logger.info(f"Training on device {device}")
model.to(device)
dtype = train_datasets[0][0].system.positions.dtype

logger.info(f"training on device {device} with dtype {dtype}")
model.to(device=device, dtype=dtype)

# Calculate and set the composition weights for all targets:
for target_name in novel_capabilities.outputs.keys():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ def get_average_number_of_atoms(
"""
average_number_of_atoms = []
for dataset in datasets:
dtype = dataset[0].system.positions.dtype
num_atoms = []
for i in range(len(dataset)):
system = dataset[i].system
num_atoms.append(len(system))
average_number_of_atoms.append(
torch.mean(torch.tensor(num_atoms).to(torch.get_default_dtype()))
)
average_number_of_atoms.append(torch.mean(torch.tensor(num_atoms, dtype=dtype)))
return torch.tensor(average_number_of_atoms)


Expand All @@ -38,6 +37,7 @@ def get_average_number_of_neighbors(
average_number_of_neighbors = []
for dataset in datasets:
num_neighbors = []
dtype = dataset[0].system.positions.dtype
for i in range(len(dataset)):
system = dataset[i].system
known_neighbors_lists = system.known_neighbors_lists()
Expand All @@ -51,7 +51,7 @@ def get_average_number_of_neighbors(
num_neighbors.append(
torch.mean(
torch.unique(nl.samples["first_atom"], return_counts=True)[1].to(
torch.get_default_dtype()
dtype
)
)
)
Expand Down
2 changes: 1 addition & 1 deletion src/metatensor/models/experimental/pet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def train(
continue_from: Optional[str] = None,
output_dir: str = ".",
):
if torch.get_default_dtype() != torch.float32:
if train_datasets[0][0].system.positions.dtype != torch.float32:
raise ValueError("PET only supports float32")
if len(dataset_info.targets) != 1:
raise ValueError("PET only supports a single target")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_continue(monkeypatch, tmp_path):
monkeypatch.chdir(tmp_path)
shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")

systems = read_systems(DATASET_PATH, dtype=torch.get_default_dtype())
systems = read_systems(DATASET_PATH)

capabilities = ModelCapabilities(
length_unit="Angstrom",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_prediction_subset_elements():

system = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
soap_bpnn(
[systems_to_torch(system, dtype=torch.get_default_dtype())],
[systems_to_torch(system)],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)

Expand Down Expand Up @@ -55,7 +55,7 @@ def test_prediction_subset_atoms():
)

energy_monomer = soap_bpnn(
[systems_to_torch(system_monomer).to(torch.get_default_dtype())],
[systems_to_torch(system_monomer)],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)

Expand All @@ -77,12 +77,12 @@ def test_prediction_subset_atoms():
)

energy_dimer = soap_bpnn(
[systems_to_torch(system_far_away_dimer).to(torch.get_default_dtype())],
[systems_to_torch(system_far_away_dimer)],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)

energy_monomer_in_dimer = soap_bpnn(
[systems_to_torch(system_far_away_dimer).to(torch.get_default_dtype())],
[systems_to_torch(system_far_away_dimer)],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
selected_atoms=selection_labels,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_rotational_invariance():
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)

assert torch.allclose(
torch.testing.assert_close(
original_output["energy"].block().values,
rotated_output["energy"].block().values,
)
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,21 @@ def test_regression_init():
systems = ase.io.read(DATASET_PATH, ":5")

output = soap_bpnn(
[
systems_to_torch(system, dtype=torch.get_default_dtype())
for system in systems
],
[systems_to_torch(system) for system in systems],
{"U0": soap_bpnn.capabilities.outputs["U0"]},
)
expected_output = torch.tensor([[0.0739], [0.0758], [0.1782], [-0.3517], [-0.3251]])

assert torch.allclose(output["U0"].block().values, expected_output, rtol=1e-3)
torch.testing.assert_close(
output["U0"].block().values, expected_output, rtol=1e-3, atol=1e-08
)


def test_regression_train():
"""Perform a regression test on the model when
trained for 2 epoch on a small dataset"""

systems = read_systems(DATASET_PATH, dtype=torch.get_default_dtype())
systems = read_systems(DATASET_PATH)

conf = {
"U0": {
Expand All @@ -67,7 +66,7 @@ def test_regression_train():
"virial": False,
}
}
targets = read_targets(OmegaConf.create(conf), dtype=torch.get_default_dtype())
targets = read_targets(OmegaConf.create(conf))
dataset = Dataset(system=systems, U0=targets["U0"])

hypers = DEFAULT_HYPERS.copy()
Expand All @@ -91,4 +90,6 @@ def test_regression_train():
[[-40.3951], [-56.4275], [-76.4008], [-77.3751], [-93.4227]]
)

assert torch.allclose(output["U0"].block().values, expected_output, rtol=1e-3)
torch.testing.assert_close(
output["U0"].block().values, expected_output, rtol=1e-3, atol=1e-08
)
9 changes: 5 additions & 4 deletions src/metatensor/models/experimental/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ def train(
)

device = devices[0] # only one device, as we don't support multi-gpu for now
logger.info(f"Training on device {device}")
model.to(device)
dtype = train_datasets[0][0].system.positions.dtype

logger.info(f"training on device {device} with dtype {dtype}")
model.to(device=device, dtype=dtype)

hypers_training = hypers["training"]

Expand All @@ -119,11 +121,10 @@ def train(
f"For {target_name}, model will proceed with "
"user-supplied composition weights"
)

cur_weight_dict = hypers_training["fixed_composition_weights"][target_name]
species = []
num_species = len(cur_weight_dict)
fixed_weights = torch.zeros(num_species, device=device)
fixed_weights = torch.zeros(num_species, dtype=dtype, device=device)

for ii, (key, weight) in enumerate(cur_weight_dict.items()):
species.append(key)
Expand Down
1 change: 1 addition & 0 deletions src/metatensor/models/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def check_datasets(
(if ``false``) upon detection of a chemical species or target in the
validation set that is not present in the training set.
"""
# TODO: Check that `dtypes` are consistent within datasets

# Get all targets in the training and validation sets:
train_targets = get_all_targets(train_datasets)
Expand Down
Loading

0 comments on commit a7f21d7

Please sign in to comment.