
Commit 47597a7

Merge branch 'main' into units
frostedoyster committed Jun 13, 2024
2 parents 315f159 + c6b594b commit 47597a7
Showing 10 changed files with 136 additions and 149 deletions.
6 changes: 3 additions & 3 deletions docs/src/dev-docs/new-architecture.rst
@@ -28,7 +28,7 @@ to these lines
model=model,
devices=[],
train_datasets=[],
validation_datasets=[],
val_datasets=[],
checkpoint_dir="path",
)
@@ -53,7 +53,7 @@ In order to follow this, a new architectures has two define two classes
when a user attempts to train an architecture with unsupported target and dataset
combinations. Therefore, it is the responsibility of the architecture developer to
verify if the model and the trainer support the provided train_datasets and
validation_datasets passed to the Trainer, as well as the dataset_info passed to the
val_datasets passed to the Trainer, as well as the dataset_info passed to the
model.

The ``ModelInterface`` is the main model class and must implement a
@@ -119,7 +119,7 @@ methods for ``train()``.
model: ModelInterface,
devices: List[torch.device],
train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
checkpoint_dir: str,
) -> None: ...
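
For architecture developers updating to the renamed keyword, the following is a minimal, hypothetical trainer skeleton matching the ``train()`` signature shown above. The class name, the ``torch.utils.data.Dataset`` stand-in for metatrain's own ``Dataset``, and the empty body are illustrative assumptions, not metatrain code.

from typing import List, Union

import torch
from torch.utils.data import Dataset, Subset


class MyTrainer:
    def train(
        self,
        model,
        devices: List[torch.device],
        train_datasets: List[Union[Dataset, Subset]],
        val_datasets: List[Union[Dataset, Subset]],
        checkpoint_dir: str,
    ) -> None:
        # A real implementation would build dataloaders from train_datasets and
        # val_datasets, run its optimization loop, and write checkpoints into
        # checkpoint_dir.
        ...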
8 changes: 4 additions & 4 deletions examples/ase/run_ase.py
@@ -55,8 +55,8 @@
# Next, we initialize the simulation by extracting the initial positions from the
# dataset file which we initially trained the model on.

training_frames = ase.io.read("ethanol_reduced_100.xyz", ":")
atoms = training_frames[0].copy()
train_frames = ase.io.read("ethanol_reduced_100.xyz", ":")
atoms = train_frames[0].copy()

# %%
#
@@ -168,7 +168,7 @@
# To use the RDF code from ase we first have to define a unit cell for our systems.
# We choose a cubic one with a side length of 10 Å.

for atoms in training_frames:
for atoms in train_frames:
atoms.cell = 10 * np.ones(3)
atoms.pbc = True

@@ -183,7 +183,7 @@
# method.

ana_traj = Analysis(trajectory)
ana_train = Analysis(training_frames)
ana_train = Analysis(train_frames)

rdf_traj = ana_traj.get_rdf(rmax=5, nbins=50, elements=["C", "H"], return_dists=True)
rdf_train = ana_train.get_rdf(rmax=5, nbins=50, elements=["C", "H"], return_dists=True)
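
As a self-contained illustration of the renamed variables in this example, a sketch that computes the training-set RDF on its own; it assumes ``ethanol_reduced_100.xyz`` is present and reuses only ASE calls that already appear in the diff.

import ase.io
import numpy as np
from ase.geometry.analysis import Analysis

# Read the frames the model was trained on and give each one a cubic
# 10 Å cell with periodic boundaries, as required for the RDF.
train_frames = ase.io.read("ethanol_reduced_100.xyz", ":")
for atoms in train_frames:
    atoms.cell = 10 * np.ones(3)
    atoms.pbc = True

ana_train = Analysis(train_frames)
rdf_train = ana_train.get_rdf(
    rmax=5, nbins=50, elements=["C", "H"], return_dists=True
)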
67 changes: 29 additions & 38 deletions src/metatrain/cli/train.py
@@ -265,13 +265,13 @@ def train_model(
###########################

logger.info("Setting up validation set")
validation_options = options["validation_set"]
validation_datasets = []
if isinstance(validation_options, float):
validation_size = validation_options
train_size -= validation_size
val_options = options["validation_set"]
val_datasets = []
if isinstance(val_options, float):
val_size = val_options
train_size -= val_size

if validation_size <= 0 or validation_size >= 1:
if val_size <= 0 or val_size >= 1:
raise ValueError(
"Validation set split must be greater than 0 and lesser than 1."
)
@@ -281,51 +281,43 @@ def train_model(
generator.manual_seed(options["seed"])

for i_dataset, train_dataset in enumerate(train_datasets):
train_dataset_new, validation_dataset = _train_test_random_split(
train_dataset_new, val_dataset = _train_test_random_split(
train_dataset=train_dataset,
train_size=train_size,
test_size=validation_size,
test_size=val_size,
generator=generator,
)

train_datasets[i_dataset] = train_dataset_new
validation_datasets.append(validation_dataset)
val_datasets.append(val_dataset)
else:
validation_options_list = expand_dataset_config(validation_options)
check_options_list(validation_options_list)
val_options_list = expand_dataset_config(val_options)
check_options_list(val_options_list)

if len(validation_options_list) != len(train_options_list):
if len(val_options_list) != len(train_options_list):
raise ValueError(
f"Validation dataset with length {len(validation_options_list)} has "
f"Validation dataset with length {len(val_options_list)} has "
"a different size than the train datatset with length "
f"{len(train_options_list)}."
)

check_units(
actual_options=validation_options_list, desired_options=train_options_list
)
check_units(actual_options=val_options_list, desired_options=train_options_list)

for validation_options in validation_options_list:
validation_systems = read_systems(
filename=validation_options["systems"]["read_from"],
fileformat=validation_options["systems"]["file_format"],
for val_options in val_options_list:
val_systems = read_systems(
filename=val_options["systems"]["read_from"],
fileformat=val_options["systems"]["file_format"],
dtype=dtype,
)
validation_targets, _ = read_targets(
conf=validation_options["targets"], dtype=dtype
)
validation_dataset = Dataset(
{"system": validation_systems, **validation_targets}
)
validation_datasets.append(validation_dataset)
val_targets, _ = read_targets(conf=val_options["targets"], dtype=dtype)
val_dataset = Dataset({"system": val_systems, **val_targets})
val_datasets.append(val_dataset)

###########################
# CREATE DATASET_INFO #####
###########################

atomic_types = get_atomic_types(
train_datasets + train_datasets + validation_datasets
)
atomic_types = get_atomic_types(train_datasets + val_datasets)

dataset_info = DatasetInfo(
length_unit=train_options_list[0]["systems"]["length_unit"],
@@ -346,14 +338,13 @@ def train_model(
f"Training dataset{index}:\n {train_dataset.get_stats(dataset_info)}"
)

for i, validation_dataset in enumerate(validation_datasets):
if len(validation_datasets) == 1:
for i, val_dataset in enumerate(val_datasets):
if len(val_datasets) == 1:
index = ""
else:
index = f" {i}"
logger.info(
f"Validation dataset{index}:\n "
f"{validation_dataset.get_stats(dataset_info)}"
f"Validation dataset{index}:\n {val_dataset.get_stats(dataset_info)}"
)

for i, test_dataset in enumerate(test_datasets):
@@ -397,7 +388,7 @@ def train_model(
model=model,
devices=devices,
train_datasets=train_datasets,
validation_datasets=validation_datasets,
val_datasets=val_datasets,
checkpoint_dir=str(checkpoint_dir),
)
except Exception as e:
@@ -443,16 +434,16 @@ def train_model(
return_predictions=False,
)

for i, validation_dataset in enumerate(validation_datasets):
if len(validation_datasets) == 1:
for i, val_dataset in enumerate(val_datasets):
if len(val_datasets) == 1:
extra_log_message = ""
else:
extra_log_message = f" with index {i}"

logger.info(f"Evaluating validation dataset{extra_log_message}")
_eval_targets(
mts_atomistic_model,
validation_dataset,
val_dataset,
dataset_info.targets,
return_predictions=False,
)
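
The float branch earlier in this file carves a validation split out of each training dataset with a seeded generator. A standalone sketch of the same idea, using plain ``torch.utils.data.random_split`` and a toy ``TensorDataset`` instead of metatrain's internal ``_train_test_random_split`` helper:

import torch
from torch.utils.data import TensorDataset, random_split

train_dataset = TensorDataset(torch.arange(100).unsqueeze(1))  # toy data

val_size = 0.1  # fraction taken from the training set
if val_size <= 0 or val_size >= 1:
    raise ValueError("Validation set split must be greater than 0 and lesser than 1.")

generator = torch.Generator()
generator.manual_seed(42)  # stands in for options["seed"]

n_val = int(round(val_size * len(train_dataset)))
train_dataset_new, val_dataset = random_split(
    train_dataset, [len(train_dataset) - n_val, n_val], generator=generator
)
print(len(train_dataset_new), len(val_dataset))  # 90 10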
54 changes: 26 additions & 28 deletions src/metatrain/experimental/alchemical_model/trainer.py
@@ -41,7 +41,7 @@ def train(
model: AlchemicalModel,
devices: List[torch.device],
train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
checkpoint_dir: str,
):
dtype = train_datasets[0][0]["system"].positions.dtype
@@ -57,12 +57,12 @@

# Perform canonical checks on the datasets:
logger.info("Checking datasets for consistency")
check_datasets(train_datasets, validation_datasets)
check_datasets(train_datasets, val_datasets)

# Calculating the neighbor lists for the training and validation datasets:
logger.info("Calculating neighbor lists for the datasets")
requested_neighbor_lists = model.requested_neighbor_lists()
for dataset in train_datasets + validation_datasets:
for dataset in train_datasets + val_datasets:
for i in range(len(dataset)):
system = dataset[i]["system"]
# The following line attaches the neighbors lists to the system,
@@ -112,9 +112,9 @@ def train(
model.alchemical_model.composition_weights.squeeze(0),
)
]
validation_datasets = [
val_datasets = [
remove_composition_from_dataset(
validation_datasets[0],
val_datasets[0],
model.atomic_types,
model.alchemical_model.composition_weights.squeeze(0),
)
@@ -136,19 +136,17 @@
train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)

# Create dataloader for the validation datasets:
validation_dataloaders = []
for dataset in validation_datasets:
validation_dataloaders.append(
val_dataloaders = []
for dataset in val_datasets:
val_dataloaders.append(
DataLoader(
dataset=dataset,
batch_size=self.hypers["batch_size"],
shuffle=False,
collate_fn=collate_fn,
)
)
validation_dataloader = CombinedDataLoader(
validation_dataloaders, shuffle=False
)
val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False)

# Extract all the possible outputs and their gradients:
outputs_list = []
@@ -190,7 +188,7 @@ def train(
)

# counters for early stopping:
best_validation_loss = float("inf")
best_val_loss = float("inf")
epochs_without_improvement = 0

# per-atom targets:
@@ -200,7 +198,7 @@
logger.info("Starting training")
for epoch in range(self.hypers["num_epochs"]):
train_rmse_calculator = RMSEAccumulator()
validation_rmse_calculator = RMSEAccumulator()
val_rmse_calculator = RMSEAccumulator()

train_loss = 0.0
for batch in train_dataloader:
@@ -239,8 +237,8 @@ def train(
not_per_atom=["positions_gradients"] + per_structure_targets
)

validation_loss = 0.0
for batch in validation_dataloader:
val_loss = 0.0
for batch in val_dataloader:
systems, targets = batch
assert len(systems[0].known_neighbor_lists()) > 0
systems = [system.to(device=device) for system in systems]
@@ -265,41 +263,41 @@
)
targets = average_by_num_atoms(targets, systems, per_structure_targets)

validation_loss_batch = loss_fn(predictions, targets)
validation_loss += validation_loss_batch.item()
validation_rmse_calculator.update(predictions, targets)
finalized_validation_info = validation_rmse_calculator.finalize(
val_loss_batch = loss_fn(predictions, targets)
val_loss += val_loss_batch.item()
val_rmse_calculator.update(predictions, targets)
finalized_val_info = val_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets
)

lr_scheduler.step(validation_loss)
lr_scheduler.step(val_loss)

# Now we log the information:
finalized_train_info = {"loss": train_loss, **finalized_train_info}
finalized_validation_info = {
"loss": validation_loss,
**finalized_validation_info,
finalized_val_info = {
"loss": val_loss,
**finalized_val_info,
}

if epoch == 0:
metric_logger = MetricLogger(
logobj=logger,
dataset_info=model.dataset_info,
initial_metrics=[finalized_train_info, finalized_validation_info],
names=["train", "validation"],
initial_metrics=[finalized_train_info, finalized_val_info],
names=["training", "validation"],
)
if epoch % self.hypers["log_interval"] == 0:
metric_logger.log(
metrics=[finalized_train_info, finalized_validation_info],
metrics=[finalized_train_info, finalized_val_info],
epoch=epoch,
)

if epoch % self.hypers["checkpoint_interval"] == 0:
model.save_checkpoint(Path(checkpoint_dir) / f"model_{epoch}.ckpt")

# early stopping criterion:
if validation_loss < best_validation_loss:
best_validation_loss = validation_loss
if val_loss < best_val_loss:
best_val_loss = val_loss
epochs_without_improvement = 0
else:
epochs_without_improvement += 1
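
The counters renamed in this hunk implement a standard patience-based early-stopping check. A minimal, self-contained sketch of that bookkeeping follows; the patience value and the dummy loss sequence are placeholders, not metatrain defaults.

best_val_loss = float("inf")
epochs_without_improvement = 0
patience = 5  # placeholder; the real value comes from the trainer hypers

for epoch, val_loss in enumerate([1.0, 0.8, 0.8, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7]):
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
    if epochs_without_improvement >= patience:
        print(f"Stopping early at epoch {epoch}")
        break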
16 changes: 8 additions & 8 deletions src/metatrain/experimental/gap/model.py
@@ -478,8 +478,8 @@ def __init__(
structurewise_aggregate: bool = False,
):
super().__init__()
valid_aggregate_types = ["sum", "mean"]
if aggregate_type not in valid_aggregate_types:
val_aggregate_types = ["sum", "mean"]
if aggregate_type not in val_aggregate_types:
raise ValueError(
f"Given aggregate_type {aggregate_type!r} but only "
f"{aggregate_type!r} are supported."
@@ -604,8 +604,8 @@ def __init__(
structurewise_aggregate: bool = False,
):
super().__init__()
valid_aggregate_types = ["sum", "mean"]
if aggregate_type not in valid_aggregate_types:
val_aggregate_types = ["sum", "mean"]
if aggregate_type not in val_aggregate_types:
raise ValueError(
f"Given aggregate_type {aggregate_type} but only "
f"{aggregate_type} are supported."
@@ -999,7 +999,7 @@ def __init__(
self._weights = None

def _set_kernel(self, kernel: Union[str, AggregateKernel], **kernel_kwargs):
valid_kernels = ["linear", "polynomial", "precomputed"]
val_kernels = ["linear", "polynomial", "precomputed"]
aggregate_type = kernel_kwargs.get("aggregate_type", "sum")
if aggregate_type != "sum":
raise ValueError(
@@ -1017,7 +1017,7 @@ def _set_kernel(self, kernel: Union[str, AggregateKernel], **kernel_kwargs):
else:
raise ValueError(
f"kernel type {kernel!r} is not supported. Please use one "
f"of the valid kernels {valid_kernels!r}"
f"of the valid kernels {val_kernels!r}"
)

def fit(
@@ -1222,7 +1222,7 @@ def forward(self, T: TorchTensorMap) -> TorchTensorMap:
return metatensor.torch.dot(k_tm, self._weights)

def _set_kernel(self, kernel: Union[str, TorchAggregateKernel], **kernel_kwargs):
valid_kernels = ["linear", "polynomial", "precomputed"]
val_kernels = ["linear", "polynomial", "precomputed"]
aggregate_type = kernel_kwargs.get("aggregate_type", "sum")
if aggregate_type != "sum":
raise ValueError(
@@ -1244,5 +1244,5 @@ def _set_kernel(self, kernel: Union[str, TorchAggregateKernel], **kernel_kwargs)
else:
raise ValueError(
f"kernel type {kernel!r} is not supported. Please use one "
f"of the valid kernels {valid_kernels!r}"
f"of the valid kernels {val_kernels!r}"
)
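
The GAP hunks above all follow the same option-validation pattern: check a string option against the supported values and raise a ``ValueError`` otherwise. A generic, hypothetical sketch of that pattern (names are placeholders, not the metatrain API):

def set_aggregate(aggregate_type: str) -> None:
    valid_aggregate_types = ["sum", "mean"]
    if aggregate_type not in valid_aggregate_types:
        raise ValueError(
            f"Given aggregate_type {aggregate_type!r} but only "
            f"{valid_aggregate_types!r} are supported."
        )
    print(f"Using {aggregate_type!r} aggregation")


set_aggregate("mean")   # ok
# set_aggregate("max")  # raises ValueError listing the supported types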
