diff --git a/src/metatrain/experimental/gap/trainer.py b/src/metatrain/experimental/gap/trainer.py index de859c9c0..b113ffe03 100644 --- a/src/metatrain/experimental/gap/trainer.py +++ b/src/metatrain/experimental/gap/trainer.py @@ -38,17 +38,16 @@ def train( raise ValueError("GAP only supports a single training dataset") if len(validation_datasets) != 1: raise ValueError("GAP only supports a single validation dataset") + outputs_dict = model.dataset_info.targets + if len(outputs_dict.keys()) > 1: + raise NotImplementedError("More than one output is not supported yet.") + output_name = next(iter(outputs_dict.keys())) # Perform checks on the datasets: logger.info("Checking datasets for consistency") check_datasets(train_datasets, validation_datasets) - logger.info("Training on device cpu") - - outputs_dict = model.dataset_info.targets - if len(outputs_dict.keys()) > 1: - raise NotImplementedError("More than one output is not supported yet.") - output_name = next(iter(outputs_dict.keys())) + logger.info(f"Training on device cpu with dtype {dtype}") # Calculate and set the composition weights: logger.info("Calculating composition weights") @@ -58,7 +57,6 @@ def train( model.set_composition_weights(target_name, composition_weights, species) logger.info("Setting up data loaders") - if len(train_datasets[0][0][output_name].keys) > 1: raise NotImplementedError( "Found more than 1 key in targets. Assuming " @@ -72,6 +70,8 @@ def train( ) model._keys = train_y.keys train_structures = [sample["system"] for sample in train_dataset] + + logger.info("Fitting composition energies") composition_energies = torch.zeros(len(train_y.block().values), dtype=dtype) for i, structure in enumerate(train_structures): for j, s in enumerate(species): @@ -88,12 +88,12 @@ def train( ) if len(train_y[0].gradients_list()) > 0: train_block.add_gradient("positions", train_y[0].gradient("positions")) - train_y = metatensor.torch.TensorMap( train_y.keys, [train_block], ) + logger.info("Calculating SOAP features") if len(train_y[0].gradients_list()) > 0: train_tensor = model._soap_torch_calculator.compute( train_structures, gradients=["positions"] @@ -113,11 +113,12 @@ def train( train_tensor = torch_tensor_map_to_core(train_tensor) train_y = torch_tensor_map_to_core(train_y) + logger.info("Selecting sparse points") lens = len(train_tensor[0].values) if model._sampler._n_to_select > lens: raise ValueError( - f"""number of sparse points ({model._sampler._n_to_select}) - should be smaller than the number of environments ({lens})""" + f"Number of sparse points ({model._sampler._n_to_select}) " + f"should be smaller than the number of environments ({lens})" ) sparse_points = model._sampler.fit_transform(train_tensor) sparse_points = metatensor.operations.remove_gradients(sparse_points) @@ -127,9 +128,7 @@ def train( else: alpha_forces = self.hypers["regularizer_forces"] - logger.info(f"Training on device cpu with dtype {dtype}") - logger.info("Fitting GAP") - + logger.info("Fitting GAP model") model._subset_of_regressors.fit( train_tensor, sparse_points,