diff --git a/docs/src/dev-docs/new-architecture.rst b/docs/src/dev-docs/new-architecture.rst
index 60f5129b9..396f07a0f 100644
--- a/docs/src/dev-docs/new-architecture.rst
+++ b/docs/src/dev-docs/new-architecture.rst
@@ -28,7 +28,7 @@ to these lines
         model=model,
         devices=[],
         train_datasets=[],
-        validation_datasets=[],
+        val_datasets=[],
         checkpoint_dir="path",
     )
 
@@ -53,7 +53,7 @@ In order to follow this, a new architecture has to define two classes
 when a user attempts to train an architecture with unsupported target and dataset
 combinations. Therefore, it is the responsibility of the architecture developer to
 verify if the model and the trainer support the provided train_datasets and
-validation_datasets passed to the Trainer, as well as the dataset_info passed to the
+val_datasets passed to the Trainer, as well as the dataset_info passed to the
 model.
 
 The ``ModelInterface`` is the main model class and must implement a
@@ -119,7 +119,7 @@ methods for ``train()``.
             model: ModelInterface,
             devices: List[torch.device],
             train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
-            validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
+            val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
             checkpoint_dir: str,
         ) -> None: ...
 
diff --git a/examples/ase/run_ase.py b/examples/ase/run_ase.py
index 66261cb76..2c9b48a05 100644
--- a/examples/ase/run_ase.py
+++ b/examples/ase/run_ase.py
@@ -55,8 +55,8 @@
 # Next, we initialize the simulation by extracting the initial positions from the
 # dataset file which we initially trained the model on.
 
-training_frames = ase.io.read("ethanol_reduced_100.xyz", ":")
-atoms = training_frames[0].copy()
+train_frames = ase.io.read("ethanol_reduced_100.xyz", ":")
+atoms = train_frames[0].copy()
 
 # %%
 #
@@ -168,7 +168,7 @@
 # To use the RDF code from ase we first have to define a unit cell for our systems.
 # We choose a cubic one with a side length of 10 Å.
 
-for atoms in training_frames:
+for atoms in train_frames:
     atoms.cell = 10 * np.ones(3)
     atoms.pbc = True
 
@@ -183,7 +183,7 @@
 # method.
 
 ana_traj = Analysis(trajectory)
-ana_train = Analysis(training_frames)
+ana_train = Analysis(train_frames)
 
 rdf_traj = ana_traj.get_rdf(rmax=5, nbins=50, elements=["C", "H"], return_dists=True)
 rdf_train = ana_train.get_rdf(rmax=5, nbins=50, elements=["C", "H"], return_dists=True)
diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py
index 60b336a93..d5438945b 100644
--- a/src/metatrain/cli/train.py
+++ b/src/metatrain/cli/train.py
@@ -265,13 +265,13 @@ def train_model(
     ###########################
 
     logger.info("Setting up validation set")
-    validation_options = options["validation_set"]
-    validation_datasets = []
-    if isinstance(validation_options, float):
-        validation_size = validation_options
-        train_size -= validation_size
+    val_options = options["validation_set"]
+    val_datasets = []
+    if isinstance(val_options, float):
+        val_size = val_options
+        train_size -= val_size
 
-        if validation_size <= 0 or validation_size >= 1:
+        if val_size <= 0 or val_size >= 1:
             raise ValueError(
-                "Validation set split must be greater than 0 and lesser than 1."
+                "Validation set split must be greater than 0 and less than 1."
             )
@@ -281,51 +281,43 @@ def train_model(
         generator.manual_seed(options["seed"])
 
         for i_dataset, train_dataset in enumerate(train_datasets):
-            train_dataset_new, validation_dataset = _train_test_random_split(
+            train_dataset_new, val_dataset = _train_test_random_split(
                 train_dataset=train_dataset,
                 train_size=train_size,
-                test_size=validation_size,
+                test_size=val_size,
                 generator=generator,
             )
             train_datasets[i_dataset] = train_dataset_new
-            validation_datasets.append(validation_dataset)
+            val_datasets.append(val_dataset)
 
     else:
-        validation_options_list = expand_dataset_config(validation_options)
-        check_options_list(validation_options_list)
+        val_options_list = expand_dataset_config(val_options)
+        check_options_list(val_options_list)
 
-        if len(validation_options_list) != len(train_options_list):
+        if len(val_options_list) != len(train_options_list):
             raise ValueError(
-                f"Validation dataset with length {len(validation_options_list)} has "
-                "a different size than the train datatset with length "
+                f"Validation dataset with length {len(val_options_list)} has "
+                "a different size than the train dataset with length "
                 f"{len(train_options_list)}."
             )
 
-        check_units(
-            actual_options=validation_options_list, desired_options=train_options_list
-        )
+        check_units(actual_options=val_options_list, desired_options=train_options_list)
 
-        for validation_options in validation_options_list:
-            validation_systems = read_systems(
-                filename=validation_options["systems"]["read_from"],
-                fileformat=validation_options["systems"]["file_format"],
+        for val_options in val_options_list:
+            val_systems = read_systems(
+                filename=val_options["systems"]["read_from"],
+                fileformat=val_options["systems"]["file_format"],
                 dtype=dtype,
             )
-            validation_targets, _ = read_targets(
-                conf=validation_options["targets"], dtype=dtype
-            )
-            validation_dataset = Dataset(
-                {"system": validation_systems, **validation_targets}
-            )
-            validation_datasets.append(validation_dataset)
+            val_targets, _ = read_targets(conf=val_options["targets"], dtype=dtype)
+            val_dataset = Dataset({"system": val_systems, **val_targets})
+            val_datasets.append(val_dataset)
 
     ###########################
     # CREATE DATASET_INFO #####
     ###########################
 
-    atomic_types = get_atomic_types(
-        train_datasets + train_datasets + validation_datasets
-    )
+    atomic_types = get_atomic_types(train_datasets + val_datasets)
 
     dataset_info = DatasetInfo(
         length_unit=train_options_list[0]["systems"]["length_unit"],
@@ -346,14 +338,13 @@ def train_model(
             f"Training dataset{index}:\n {train_dataset.get_stats(dataset_info)}"
         )
 
-    for i, validation_dataset in enumerate(validation_datasets):
-        if len(validation_datasets) == 1:
+    for i, val_dataset in enumerate(val_datasets):
+        if len(val_datasets) == 1:
             index = ""
         else:
             index = f" {i}"
         logger.info(
-            f"Validation dataset{index}:\n "
-            f"{validation_dataset.get_stats(dataset_info)}"
+            f"Validation dataset{index}:\n {val_dataset.get_stats(dataset_info)}"
         )
 
     for i, test_dataset in enumerate(test_datasets):
@@ -397,7 +388,7 @@ def train_model(
             model=model,
             devices=devices,
             train_datasets=train_datasets,
-            validation_datasets=validation_datasets,
+            val_datasets=val_datasets,
             checkpoint_dir=str(checkpoint_dir),
         )
     except Exception as e:
@@ -443,8 +434,8 @@ def train_model(
             return_predictions=False,
         )
 
-    for i, validation_dataset in enumerate(validation_datasets):
-        if len(validation_datasets) == 1:
+    for i, val_dataset in enumerate(val_datasets):
+        if len(val_datasets) == 1:
             extra_log_message = ""
         else:
             extra_log_message = f" with index {i}"
@@ -452,7 +443,7 @@ def train_model(
         logger.info(f"Evaluating validation dataset{extra_log_message}")
         _eval_targets(
             mts_atomistic_model,
-            validation_dataset,
+            val_dataset,
             dataset_info.targets,
             return_predictions=False,
         )
diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py
index 03618eff2..42860cd45 100644
--- a/src/metatrain/experimental/alchemical_model/trainer.py
+++ b/src/metatrain/experimental/alchemical_model/trainer.py
@@ -41,7 +41,7 @@ def train(
         model: AlchemicalModel,
         devices: List[torch.device],
         train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
-        validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
+        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
         checkpoint_dir: str,
     ):
         dtype = train_datasets[0][0]["system"].positions.dtype
@@ -57,12 +57,12 @@ def train(
 
         # Perform canonical checks on the datasets:
         logger.info("Checking datasets for consistency")
-        check_datasets(train_datasets, validation_datasets)
+        check_datasets(train_datasets, val_datasets)
 
         # Calculating the neighbor lists for the training and validation datasets:
         logger.info("Calculating neighbor lists for the datasets")
         requested_neighbor_lists = model.requested_neighbor_lists()
-        for dataset in train_datasets + validation_datasets:
+        for dataset in train_datasets + val_datasets:
             for i in range(len(dataset)):
                 system = dataset[i]["system"]
                 # The following line attaches the neighbors lists to the system,
@@ -112,9 +112,9 @@ def train(
                 model.alchemical_model.composition_weights.squeeze(0),
             )
         ]
-        validation_datasets = [
+        val_datasets = [
             remove_composition_from_dataset(
-                validation_datasets[0],
+                val_datasets[0],
                 model.atomic_types,
                 model.alchemical_model.composition_weights.squeeze(0),
             )
         ]
@@ -136,9 +136,9 @@ def train(
         train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)
 
         # Create dataloader for the validation datasets:
-        validation_dataloaders = []
-        for dataset in validation_datasets:
-            validation_dataloaders.append(
+        val_dataloaders = []
+        for dataset in val_datasets:
+            val_dataloaders.append(
                 DataLoader(
                     dataset=dataset,
                     batch_size=self.hypers["batch_size"],
@@ -146,9 +146,7 @@ def train(
                     collate_fn=collate_fn,
                 )
             )
-        validation_dataloader = CombinedDataLoader(
-            validation_dataloaders, shuffle=False
-        )
+        val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False)
 
         # Extract all the possible outputs and their gradients:
         outputs_list = []
@@ -190,7 +188,7 @@ def train(
         )
 
         # counters for early stopping:
-        best_validation_loss = float("inf")
+        best_val_loss = float("inf")
         epochs_without_improvement = 0
 
         # per-atom targets:
@@ -200,7 +198,7 @@ def train(
         logger.info("Starting training")
         for epoch in range(self.hypers["num_epochs"]):
             train_rmse_calculator = RMSEAccumulator()
-            validation_rmse_calculator = RMSEAccumulator()
+            val_rmse_calculator = RMSEAccumulator()
 
             train_loss = 0.0
             for batch in train_dataloader:
@@ -239,8 +237,8 @@ def train(
                 not_per_atom=["positions_gradients"] + per_structure_targets
             )
 
-            validation_loss = 0.0
-            for batch in validation_dataloader:
+            val_loss = 0.0
+            for batch in val_dataloader:
                 systems, targets = batch
                 assert len(systems[0].known_neighbor_lists()) > 0
                 systems = [system.to(device=device) for system in systems]
@@ -265,32 +263,32 @@ def train(
                 )
                 targets = average_by_num_atoms(targets, systems, per_structure_targets)
 
-                validation_loss_batch = loss_fn(predictions, targets)
-                validation_loss += validation_loss_batch.item()
-                validation_rmse_calculator.update(predictions, targets)
-            finalized_validation_info = validation_rmse_calculator.finalize(
+                val_loss_batch = loss_fn(predictions, targets)
+                val_loss += val_loss_batch.item()
+                val_rmse_calculator.update(predictions, targets)
+            finalized_val_info = val_rmse_calculator.finalize(
                 not_per_atom=["positions_gradients"] + per_structure_targets
             )
 
-            lr_scheduler.step(validation_loss)
+            lr_scheduler.step(val_loss)
 
             # Now we log the information:
             finalized_train_info = {"loss": train_loss, **finalized_train_info}
-            finalized_validation_info = {
-                "loss": validation_loss,
-                **finalized_validation_info,
+            finalized_val_info = {
+                "loss": val_loss,
+                **finalized_val_info,
             }
 
             if epoch == 0:
                 metric_logger = MetricLogger(
                     logobj=logger,
                     dataset_info=model.dataset_info,
-                    initial_metrics=[finalized_train_info, finalized_validation_info],
-                    names=["train", "validation"],
+                    initial_metrics=[finalized_train_info, finalized_val_info],
+                    names=["training", "validation"],
                 )
             if epoch % self.hypers["log_interval"] == 0:
                 metric_logger.log(
-                    metrics=[finalized_train_info, finalized_validation_info],
+                    metrics=[finalized_train_info, finalized_val_info],
                     epoch=epoch,
                 )
 
@@ -298,8 +296,8 @@ def train(
                 model.save_checkpoint(Path(checkpoint_dir) / f"model_{epoch}.ckpt")
 
             # early stopping criterion:
-            if validation_loss < best_validation_loss:
-                best_validation_loss = validation_loss
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
                 epochs_without_improvement = 0
             else:
                 epochs_without_improvement += 1
diff --git a/src/metatrain/experimental/gap/model.py b/src/metatrain/experimental/gap/model.py
index 06b314749..df63b4e9b 100644
--- a/src/metatrain/experimental/gap/model.py
+++ b/src/metatrain/experimental/gap/model.py
@@ -478,8 +478,8 @@ def __init__(
         structurewise_aggregate: bool = False,
     ):
         super().__init__()
-        valid_aggregate_types = ["sum", "mean"]
-        if aggregate_type not in valid_aggregate_types:
+        val_aggregate_types = ["sum", "mean"]
+        if aggregate_type not in val_aggregate_types:
             raise ValueError(
                 f"Given aggregate_type {aggregate_type!r} but only "
-                f"{aggregate_type!r} are supported."
+                f"{val_aggregate_types!r} are supported."
             )
@@ -604,8 +604,8 @@ def __init__(
         structurewise_aggregate: bool = False,
     ):
         super().__init__()
-        valid_aggregate_types = ["sum", "mean"]
-        if aggregate_type not in valid_aggregate_types:
+        val_aggregate_types = ["sum", "mean"]
+        if aggregate_type not in val_aggregate_types:
             raise ValueError(
                 f"Given aggregate_type {aggregate_type} but only "
-                f"{aggregate_type} are supported."
+                f"{val_aggregate_types} are supported."
             )
@@ -999,7 +999,7 @@ def __init__(
         self._weights = None
 
     def _set_kernel(self, kernel: Union[str, AggregateKernel], **kernel_kwargs):
-        valid_kernels = ["linear", "polynomial", "precomputed"]
+        val_kernels = ["linear", "polynomial", "precomputed"]
         aggregate_type = kernel_kwargs.get("aggregate_type", "sum")
         if aggregate_type != "sum":
             raise ValueError(
@@ -1017,7 +1017,7 @@ def _set_kernel(self, kernel: Union[str, AggregateKernel], **kernel_kwargs):
         else:
             raise ValueError(
                 f"kernel type {kernel!r} is not supported. Please use one "
-                f"of the valid kernels {valid_kernels!r}"
+                f"of the valid kernels {val_kernels!r}"
             )
 
     def fit(
@@ -1222,7 +1222,7 @@ def forward(self, T: TorchTensorMap) -> TorchTensorMap:
         return metatensor.torch.dot(k_tm, self._weights)
 
     def _set_kernel(self, kernel: Union[str, TorchAggregateKernel], **kernel_kwargs):
-        valid_kernels = ["linear", "polynomial", "precomputed"]
+        val_kernels = ["linear", "polynomial", "precomputed"]
         aggregate_type = kernel_kwargs.get("aggregate_type", "sum")
         if aggregate_type != "sum":
             raise ValueError(
@@ -1244,5 +1244,5 @@ def _set_kernel(self, kernel: Union[str, TorchAggregateKernel], **kernel_kwargs)
         else:
             raise ValueError(
                 f"kernel type {kernel!r} is not supported. Please use one "
-                f"of the valid kernels {valid_kernels!r}"
+                f"of the valid kernels {val_kernels!r}"
             )
diff --git a/src/metatrain/experimental/gap/trainer.py b/src/metatrain/experimental/gap/trainer.py
index b113ffe03..54858c553 100644
--- a/src/metatrain/experimental/gap/trainer.py
+++ b/src/metatrain/experimental/gap/trainer.py
@@ -26,7 +26,7 @@ def train(
         model: GAP,
         devices: List[torch.device],
         train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
-        validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
+        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
         checkpoint_dir: str,
     ):
         # checks
@@ -36,7 +36,7 @@ def train(
         target_name = next(iter(model.dataset_info.targets.keys()))
         if len(train_datasets) != 1:
             raise ValueError("GAP only supports a single training dataset")
-        if len(validation_datasets) != 1:
+        if len(val_datasets) != 1:
             raise ValueError("GAP only supports a single validation dataset")
         outputs_dict = model.dataset_info.targets
         if len(outputs_dict.keys()) > 1:
@@ -45,7 +45,7 @@ def train(
 
         # Perform checks on the datasets:
         logger.info("Checking datasets for consistency")
-        check_datasets(train_datasets, validation_datasets)
+        check_datasets(train_datasets, val_datasets)
 
         logger.info(f"Training on device cpu with dtype {dtype}")
 
diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py
index b8193d22f..93b99c57a 100644
--- a/src/metatrain/experimental/pet/trainer.py
+++ b/src/metatrain/experimental/pet/trainer.py
@@ -26,22 +26,22 @@ def train(
         model: WrappedPET,
         devices: List[torch.device],
         train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
-        validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
+        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
         checkpoint_dir: str,
    ):
         if len(train_datasets) != 1:
             raise ValueError("PET only supports a single training dataset")
-        if len(validation_datasets) != 1:
+        if len(val_datasets) != 1:
             raise ValueError("PET only supports a single validation dataset")
 
         if model.checkpoint_path is not None:
             self.hypers["FITTING_SCHEME"]["MODEL_TO_START_WITH"] = model.checkpoint_path
 
         logger.info("Checking datasets for consistency")
-        check_datasets(train_datasets, validation_datasets)
+        check_datasets(train_datasets, val_datasets)
 
         train_dataset = train_datasets[0]
-        validation_dataset = validation_datasets[0]
+        val_dataset = val_datasets[0]
 
         # dummy dataloaders due to https://github.com/lab-cosmo/metatensor/issues/521
         train_dataloader = DataLoader(
@@ -50,8 +50,8 @@ def train(
             shuffle=False,
             collate_fn=collate_fn,
         )
-        validation_dataloader = DataLoader(
-            validation_dataset,
+        val_dataloader = DataLoader(
+            val_dataset,
             batch_size=1,
             shuffle=False,
             collate_fn=collate_fn,
@@ -101,8 +101,8 @@ def train(
             )
             ase_train_dataset.append(ase_atoms)
 
-        ase_validation_dataset = []
-        for (system,), targets in validation_dataloader:
+        ase_val_dataset = []
+        for (system,), targets in val_dataloader:
             ase_atoms = system_to_ase(system)
             ase_atoms.info["energy"] = float(
                 targets[target_name].block().values.squeeze(-1).detach().cpu().numpy()
             )
@@ -117,13 +117,13 @@ def train(
                 .cpu()
                 .numpy()
             )
-            ase_validation_dataset.append(ase_atoms)
+            ase_val_dataset.append(ase_atoms)
 
         device = devices[0]  # only one device, as we don't support multi-gpu for now
 
         fit_pet(
             ase_train_dataset,
-            ase_validation_dataset,
+            ase_val_dataset,
             self.hypers,
             "pet",
             device,
diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py
index f34ab01f9..532605971 100644
--- a/src/metatrain/experimental/soap_bpnn/trainer.py
+++ b/src/metatrain/experimental/soap_bpnn/trainer.py
@@ -43,7 +43,7 @@ def train(
         model: SoapBpnn,
         devices: List[torch.device],
         train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
-        validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
+        val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
         checkpoint_dir: str,
     ):
         dtype = train_datasets[0][0]["system"].positions.dtype
@@ -117,9 +117,9 @@ def train(
         train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)
 
         # Create dataloader for the validation datasets:
-        validation_dataloaders = []
-        for dataset in validation_datasets:
-            validation_dataloaders.append(
+        val_dataloaders = []
+        for dataset in val_datasets:
+            val_dataloaders.append(
                 DataLoader(
                     dataset=dataset,
                     batch_size=self.hypers["batch_size"],
@@ -127,14 +127,12 @@ def train(
                     collate_fn=collate_fn,
                 )
             )
-        validation_dataloader = CombinedDataLoader(
-            validation_dataloaders, shuffle=False
-        )
+        val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False)
 
         # Extract all the possible outputs and their gradients:
-        training_targets = get_targets_dict(train_datasets, model.dataset_info)
+        train_targets = get_targets_dict(train_datasets, model.dataset_info)
         outputs_list = []
-        for target_name, target_info in training_targets.items():
+        for target_name, target_info in train_targets.items():
             outputs_list.append(target_name)
             for gradient_name in target_info.gradients:
                 outputs_list.append(f"{target_name}_{gradient_name}_gradients")
@@ -143,14 +141,14 @@ def train(
         for output_name in outputs_list:
             loss_weights_dict[output_name] = (
                 self.hypers["loss_weights"][
-                    to_external_name(output_name, training_targets)
+                    to_external_name(output_name, train_targets)
                 ]
-                if to_external_name(output_name, training_targets)
+                if to_external_name(output_name, train_targets)
                 in self.hypers["loss_weights"]
                 else 1.0
             )
         loss_weights_dict_external = {
-            to_external_name(key, training_targets): value
+            to_external_name(key, train_targets): value
             for key, value in loss_weights_dict.items()
         }
         logging.info(f"Training with loss weights: {loss_weights_dict_external}")
@@ -172,7 +170,7 @@ def train(
         )
 
         # counters for early stopping:
-        best_validation_loss = float("inf")
+        best_val_loss = float("inf")
         epochs_without_improvement = 0
 
         # per-atom targets:
@@ -182,7 +180,7 @@ def train(
         logger.info("Starting training")
         for epoch in range(self.hypers["num_epochs"]):
             train_rmse_calculator = RMSEAccumulator()
-            validation_rmse_calculator = RMSEAccumulator()
+            val_rmse_calculator = RMSEAccumulator()
 
             train_loss = 0.0
             for batch in train_dataloader:
@@ -197,7 +195,7 @@ def train(
                     model,
                     systems,
                     TargetInfoDict(
-                        **{key: training_targets[key] for key in targets.keys()}
+                        **{key: train_targets[key] for key in targets.keys()}
                     ),
                     is_training=True,
                 )
@@ -217,8 +215,8 @@ def train(
                 not_per_atom=["positions_gradients"] + per_structure_targets
             )
 
-            validation_loss = 0.0
-            for batch in validation_dataloader:
+            val_loss = 0.0
+            for batch in val_dataloader:
                 systems, targets = batch
                 systems = [system.to(device=device) for system in systems]
                 targets = {
@@ -228,7 +226,7 @@ def train(
                     model,
                     systems,
                     TargetInfoDict(
-                        **{key: training_targets[key] for key in targets.keys()}
+                        **{key: train_targets[key] for key in targets.keys()}
                     ),
                     is_training=False,
                 )
@@ -239,20 +237,20 @@ def train(
                 )
                 targets = average_by_num_atoms(targets, systems, per_structure_targets)
 
-                validation_loss_batch = loss_fn(predictions, targets)
-                validation_loss += validation_loss_batch.item()
-                validation_rmse_calculator.update(predictions, targets)
-            finalized_validation_info = validation_rmse_calculator.finalize(
+                val_loss_batch = loss_fn(predictions, targets)
+                val_loss += val_loss_batch.item()
+                val_rmse_calculator.update(predictions, targets)
+            finalized_val_info = val_rmse_calculator.finalize(
                 not_per_atom=["positions_gradients"] + per_structure_targets
             )
 
-            lr_scheduler.step(validation_loss)
+            lr_scheduler.step(val_loss)
 
             # Now we log the information:
             finalized_train_info = {"loss": train_loss, **finalized_train_info}
-            finalized_validation_info = {
-                "loss": validation_loss,
-                **finalized_validation_info,
+            finalized_val_info = {
+                "loss": val_loss,
+                **finalized_val_info,
             }
 
             if epoch == 0:
@@ -260,11 +258,11 @@ def train(
                 metric_logger = MetricLogger(
                     logobj=logger,
                     dataset_info=model.dataset_info,
-                    initial_metrics=[finalized_train_info, finalized_validation_info],
-                    names=["train", "validation"],
+                    initial_metrics=[finalized_train_info, finalized_val_info],
+                    names=["training", "validation"],
                 )
             if epoch % self.hypers["log_interval"] == 0:
                 metric_logger.log(
-                    metrics=[finalized_train_info, finalized_validation_info],
+                    metrics=[finalized_train_info, finalized_val_info],
                     epoch=epoch,
                 )
 
@@ -272,8 +270,8 @@ def train(
                 model.save_checkpoint(Path(checkpoint_dir) / f"model_{epoch}.ckpt")
 
             # early stopping criterion:
-            if validation_loss < best_validation_loss:
-                best_validation_loss = validation_loss
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
                 epochs_without_improvement = 0
             else:
                 epochs_without_improvement += 1
diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py
index 0d7e2d2c0..a7c750658 100644
--- a/src/metatrain/utils/data/dataset.py
+++ b/src/metatrain/utils/data/dataset.py
@@ -377,16 +377,16 @@ def collate_fn(batch: List[Dict[str, Any]]) -> Tuple[List, Dict[str, TensorMap]]
     return systems, collated_targets
 
 
-def check_datasets(train_datasets: List[Dataset], validation_datasets: List[Dataset]):
+def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]):
     """Check that the training and validation sets are compatible with one another
 
     Although these checks will not fit all use cases, most models would be expected
     to be able to use this function.
 
     :param train_datasets: A list of training datasets to check.
-    :param validation_datasets: A list of validation datasets to check
+    :param val_datasets: A list of validation datasets to check
     :raises TypeError: If the ``dtype`` within the datasets are inconsistent.
-    :raises ValueError: If the `validation_datasets` has a target that is not present in
+    :raises ValueError: If the ``val_datasets`` has a target that is not present in
         the ``train_datasets``.
     :raises ValueError: If the training or validation set contains chemical species
         or targets that are not present in the training set
@@ -399,31 +399,31 @@ def check_datasets(train_datasets: List[Dataset], validation_datasets: List[Data
         if actual_dtype != desired_dtype:
             raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`")
 
-    for validation_dataset in validation_datasets:
-        actual_dtype = validation_dataset[0]["system"].positions.dtype
+    for val_dataset in val_datasets:
+        actual_dtype = val_dataset[0]["system"].positions.dtype
         if actual_dtype != desired_dtype:
-            raise TypeError(f"{msg}{actual_dtype} found in `validation_datasets`")
+            raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`")
 
     # Get all targets in the training and validation sets:
     train_targets = get_all_targets(train_datasets)
-    validation_targets = get_all_targets(validation_datasets)
+    val_targets = get_all_targets(val_datasets)
 
     # Check that the validation sets do not have targets that are not in the
     # training sets:
-    for target in validation_targets:
+    for target in val_targets:
         if target not in train_targets:
             raise ValueError(
                 f"The validation dataset has a target ({target}) that is not present "
                 "in the training dataset."
             )
 
     # Get all the species in the training and validation sets:
-    all_training_species = get_atomic_types(train_datasets)
-    all_validation_species = get_atomic_types(validation_datasets)
+    all_train_species = get_atomic_types(train_datasets)
+    all_val_species = get_atomic_types(val_datasets)
 
     # Check that the validation sets do not have species that are not in the
     # training sets:
-    for species in all_validation_species:
-        if species not in all_training_species:
+    for species in all_val_species:
+        if species not in all_train_species:
             raise ValueError(
                 f"The validation dataset has a species ({species}) that is not in the "
                 "training dataset. This could be a result of a random train/validation "
diff --git a/tests/utils/data/test_dataset.py b/tests/utils/data/test_dataset.py
index 1fa9aad53..e00e0fd5a 100644
--- a/tests/utils/data/test_dataset.py
+++ b/tests/utils/data/test_dataset.py
@@ -446,40 +446,40 @@ def test_check_datasets():
     targets_ethanol, _ = read_targets(OmegaConf.create(conf_ethanol))
 
     # everything ok
-    training_set = Dataset({"system": systems_qm9, **targets_qm9})
-    validation_set = Dataset({"system": systems_qm9, **targets_qm9})
-    check_datasets([training_set], [validation_set])
+    train_set = Dataset({"system": systems_qm9, **targets_qm9})
+    val_set = Dataset({"system": systems_qm9, **targets_qm9})
+    check_datasets([train_set], [val_set])
 
     # extra species in validation dataset
-    training_set = Dataset({"system": systems_ethanol, **targets_qm9})
-    validation_set = Dataset({"system": systems_qm9, **targets_qm9})
+    train_set = Dataset({"system": systems_ethanol, **targets_qm9})
+    val_set = Dataset({"system": systems_qm9, **targets_qm9})
     with pytest.raises(ValueError, match="The validation dataset has a species"):
-        check_datasets([training_set], [validation_set])
+        check_datasets([train_set], [val_set])
 
     # extra targets in validation dataset
-    training_set = Dataset({"system": systems_qm9, **targets_qm9})
-    validation_set = Dataset({"system": systems_qm9, **targets_ethanol})
+    train_set = Dataset({"system": systems_qm9, **targets_qm9})
+    val_set = Dataset({"system": systems_qm9, **targets_ethanol})
     with pytest.raises(ValueError, match="The validation dataset has a target"):
-        check_datasets([training_set], [validation_set])
+        check_datasets([train_set], [val_set])
 
     # wrong dtype
     systems_qm9_64_bit = read_systems(
         RESOURCES_PATH / "qm9_reduced_100.xyz", dtype=torch.float64
     )
-    training_set_64_bit = Dataset({"system": systems_qm9_64_bit, **targets_qm9})
+    train_set_64_bit = Dataset({"system": systems_qm9_64_bit, **targets_qm9})
 
     match = (
         "`dtype` between datasets is inconsistent, found torch.float32 and "
-        "torch.float64 found in `validation_datasets`"
+        "torch.float64 found in `val_datasets`"
    )
     with pytest.raises(TypeError, match=match):
-        check_datasets([training_set], [training_set_64_bit])
+        check_datasets([train_set], [train_set_64_bit])
 
     match = (
         "`dtype` between datasets is inconsistent, found torch.float32 and "
         "torch.float64 found in `train_datasets`"
     )
     with pytest.raises(TypeError, match=match):
-        check_datasets([training_set, training_set_64_bit], [validation_set])
+        check_datasets([train_set, train_set_64_bit], [val_set])
 
 
 def test_collate_fn():
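
For orientation, the sketch below shows a minimal trainer conforming to the renamed
interface documented in ``new-architecture.rst`` above. It is illustrative only and not
part of the patch: the class name ``ExampleTrainer`` and its bare-bones body are
hypothetical, and the import path is inferred from the file layout shown in this diff
(``src/metatrain/utils/data/dataset.py``).

    from typing import List, Union

    import torch

    # Import path inferred from src/metatrain/utils/data/dataset.py above;
    # adjust if these names are re-exported elsewhere.
    from metatrain.utils.data.dataset import Dataset, check_datasets


    class ExampleTrainer:
        """Hypothetical trainer; only the train() signature mirrors the docs."""

        def train(
            self,
            model,
            devices: List[torch.device],
            train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
            val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
            checkpoint_dir: str,
        ) -> None:
            # Validation data now arrives as ``val_datasets`` (formerly
            # ``validation_datasets``); the canonical consistency check is unchanged.
            check_datasets(train_datasets, val_datasets)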