From 664080c6f26eedabcb7afc33d4c6656d4b5a02f5 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 27 Feb 2025 00:27:01 -0800 Subject: [PATCH] Updating tests and scripts to handle db bag end. Small revamps to speed up curation --- modelforge-curate/modelforge/curate/curate.py | 237 ++++++++++++------ .../curate/datasets/ani2x_curation.py | 4 +- .../curate/datasets/curation_baseclass.py | 26 ++ .../curate/datasets/phalkethoh_curation.py | 59 +++-- .../curate/datasets/qm9_curation.py | 4 +- .../datasets/scripts/curate_PhAlkEthOH.py | 139 ++++------ .../curate/datasets/scripts/curate_ani2x.py | 107 ++------ .../curate/datasets/scripts/curate_qm9.py | 95 ++----- .../curate/datasets/scripts/curate_spice1.py | 25 +- .../curate/datasets/scripts/curate_spice2.py | 25 +- .../curate/datasets/scripts/curate_tmqm.py | 1 + .../datasets/scripts/curate_tmqm_xtb.py | 4 +- .../curate/datasets/spice_1_curation.py | 2 +- .../curate/datasets/spice_2_curation.py | 2 +- .../curate/datasets/tmqm_curation.py | 12 +- .../curate/datasets/tmqm_xtb_curation.py | 2 +- .../modelforge/curate/tests/test_curate.py | 2 +- .../curate/tests/test_curation_baseclass.py | 22 +- 18 files changed, 379 insertions(+), 389 deletions(-) diff --git a/modelforge-curate/modelforge/curate/curate.py b/modelforge-curate/modelforge/curate/curate.py index 200afa31..c2f22e2f 100644 --- a/modelforge-curate/modelforge/curate/curate.py +++ b/modelforge-curate/modelforge/curate/curate.py @@ -585,6 +585,36 @@ def add_record(self, record: Record): # self.records[record.name] = copy.deepcopy(record) + def add_records(self, records: List[Record]): + """ + Add a list of records to the dataset. + + Parameters + ---------- + records: List[Record] + List of records to add to the dataset. + + Returns + ------- + + """ + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + for i in range(len(records)): + record_name = records[i].name + + if record_name in self.records.keys(): + log.warning( + f"Record with name {record_name} already exists in the dataset." + ) + raise ValueError( + f"Record with name {record_name} already exists in the dataset." + ) + + db[record_name] = records[i] + def update_record(self, record: Record): """ Update a record in the dataset by overwriting the existing record with the input. @@ -660,9 +690,24 @@ def add_properties( ------- """ + assert isinstance(record_name, str) + # check if the record exists; if it does not add it + if record_name not in self.records.keys(): + log.info( + f"Record with name {record_name} does not exist in the dataset. Creating it now." 
+ ) + self.create_record(record_name) - for property in properties: - self.add_property(record_name, property) + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + record = db[record_name] + record.add_properties(properties) + db[record_name] = record + + # for property in properties: + # self.add_property(record_name, property) def add_property(self, record_name: str, property: Type[PropertyBaseModel]): """ @@ -754,6 +799,7 @@ def subset_dataset( max_configurations_per_record: Optional[int] = None, atomic_numbers_to_limit: Optional[np.ndarray] = None, max_force: Optional[unit.Quantity] = None, + final_configuration_only: Optional[bool] = False, local_db_dir: Optional[str] = None, local_db_name: Optional[str] = None, ) -> Self: @@ -775,6 +821,8 @@ def subset_dataset( Any molecules that contain elements outside of this list will be igonored max_force: Optional[unit.Quantity], default=None If set, configurations with forces greater than this value will be removed. + final_configuration_only: Optional[bool], default=False + If True, only the final configuration of each record will be included in the subset. local_db_dir: str, optional, default=None Directory to store the local database for the new dataset. If not defined, will use the same directory as the current dataset. local_db_name: str, optional, default=None @@ -793,6 +841,14 @@ def subset_dataset( "Cannot set both total_records and total_conformers. Please choose one." ) + if ( + final_configuration_only == True + and max_configurations_per_record is not None + ): + raise ValueError( + "Cannot set final_configuration_only=True and total_conformers. Please choose one." + ) + if total_configurations is not None: if total_configurations > self.total_configs(): log.warning( @@ -864,6 +920,7 @@ def subset_dataset( # now check to see if the n_configs_to_add is more than what we still need to hit max conformers if n_configs_to_add > total_configurations_to_add: n_configs_to_add = total_configurations_to_add + # even if we don't limit the number of configurations, we may still need to limit the number we # include to satisfy the total number of configurations else: @@ -871,9 +928,17 @@ def subset_dataset( if n_configs_to_add > total_configurations_to_add: n_configs_to_add = total_configurations_to_add - record = record.slice_record(0, n_configs_to_add) - new_dataset.add_record(record) - total_configurations_to_add -= n_configs_to_add + if final_configuration_only: + record = record.slice_record( + record.n_configs - 1, record.n_configs + ) + new_dataset.add_record(record) + total_configurations_to_add -= 1 + + else: + record = record.slice_record(0, n_configs_to_add) + new_dataset.add_record(record) + total_configurations_to_add -= n_configs_to_add return new_dataset elif total_records is not None: @@ -904,6 +969,11 @@ def subset_dataset( record = record.slice_record(0, n_to_add) + if final_configuration_only: + record = record.slice_record( + record.n_configs - 1, record.n_configs + ) + new_dataset.add_record(record) total_records_to_add -= 1 return new_dataset @@ -927,6 +997,10 @@ def subset_dataset( if max_configurations_per_record is not None: n_to_add = min(max_configurations_per_record, record.n_configs) record = record.slice_record(0, n_to_add) + if final_configuration_only: + record = record.slice_record( + record.n_configs - 1, record.n_configs + ) new_dataset.add_record(record) return new_dataset @@ -1209,84 +1283,42 @@ def to_hdf5(self, file_path: str, file_name: str): with 
OpenWithLock(f"{full_file_path}.lockfile", "w") as lockfile: with h5py.File(full_file_path, "w") as f: - for record_name in tqdm(self.records.keys()): - record_group = f.create_group(record_name) - - record = self.get_record(record_name) - - record_group.create_dataset( - "atomic_numbers", - data=record.atomic_numbers.value, - shape=record.atomic_numbers.value.shape, - ) - record_group["atomic_numbers"].attrs["format"] = str( - record.atomic_numbers.classification - ) - record_group["atomic_numbers"].attrs["property_type"] = str( - record.atomic_numbers.property_type - ) - - record_group.create_dataset("n_configs", data=record.n_configs) - record_group["n_configs"].attrs["format"] = "meta_data" + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + ) as db: + for record_name in tqdm(self.records.keys()): + record_group = f.create_group(record_name) - for key, property in record.per_atom.items(): + record = db[record_name] - target_units = GlobalUnitSystem.get_units( - property.property_type - ) record_group.create_dataset( - key, - data=unit.Quantity(property.value, property.units) - .to(target_units, "chem") - .magnitude, - shape=property.value.shape, + "atomic_numbers", + data=record.atomic_numbers.value, + shape=record.atomic_numbers.value.shape, ) - record_group[key].attrs["u"] = str(target_units) - record_group[key].attrs["format"] = str(property.classification) - record_group[key].attrs["property_type"] = str( - property.property_type + record_group["atomic_numbers"].attrs["format"] = str( + record.atomic_numbers.classification ) - - for key, property in record.per_system.items(): - target_units = GlobalUnitSystem.get_units( - property.property_type - ) - record_group.create_dataset( - key, - data=unit.Quantity(property.value, property.units) - .to(target_units, "chem") - .magnitude, - shape=property.value.shape, + record_group["atomic_numbers"].attrs["property_type"] = str( + record.atomic_numbers.property_type ) - record_group[key].attrs["u"] = str(target_units) - record_group[key].attrs["format"] = str(property.classification) - record_group[key].attrs["property_type"] = str( - property.property_type - ) + record_group.create_dataset("n_configs", data=record.n_configs) + record_group["n_configs"].attrs["format"] = "meta_data" - for key, property in record.meta_data.items(): - if isinstance(property.value, str): - record_group.create_dataset( - key, - data=property.value, - dtype=dt, - ) - record_group[key].attrs["u"] = str(property.units) - record_group[key].attrs["format"] = str( - property.classification - ) - record_group[key].attrs["property_type"] = str( + for key, property in record.per_atom.items(): + + target_units = GlobalUnitSystem.get_units( property.property_type ) - - elif isinstance(property.value, (float, int)): - record_group.create_dataset( key, - data=property.value, + data=unit.Quantity(property.value, property.units) + .to(target_units, "chem") + .magnitude, + shape=property.value.shape, ) - record_group[key].attrs["u"] = str(property.units) + record_group[key].attrs["u"] = str(target_units) record_group[key].attrs["format"] = str( property.classification ) @@ -1294,22 +1326,71 @@ def to_hdf5(self, file_path: str, file_name: str): property.property_type ) - elif isinstance(property.value, np.ndarray): - + for key, property in record.per_system.items(): + target_units = GlobalUnitSystem.get_units( + property.property_type + ) record_group.create_dataset( - key, data=property.value, shape=property.value.shape + key, + 
data=unit.Quantity(property.value, property.units) + .to(target_units, "chem") + .magnitude, + shape=property.value.shape, ) - record_group[key].attrs["u"] = str(property.units) + + record_group[key].attrs["u"] = str(target_units) record_group[key].attrs["format"] = str( property.classification ) record_group[key].attrs["property_type"] = str( property.property_type ) - else: - raise ValueError( - f"Unsupported type ({type(property.value)}) for metadata {key}" - ) + + for key, property in record.meta_data.items(): + if isinstance(property.value, str): + record_group.create_dataset( + key, + data=property.value, + dtype=dt, + ) + record_group[key].attrs["u"] = str(property.units) + record_group[key].attrs["format"] = str( + property.classification + ) + record_group[key].attrs["property_type"] = str( + property.property_type + ) + + elif isinstance(property.value, (float, int)): + + record_group.create_dataset( + key, + data=property.value, + ) + record_group[key].attrs["u"] = str(property.units) + record_group[key].attrs["format"] = str( + property.classification + ) + record_group[key].attrs["property_type"] = str( + property.property_type + ) + + elif isinstance(property.value, np.ndarray): + + record_group.create_dataset( + key, data=property.value, shape=property.value.shape + ) + record_group[key].attrs["u"] = str(property.units) + record_group[key].attrs["format"] = str( + property.classification + ) + record_group[key].attrs["property_type"] = str( + property.property_type + ) + else: + raise ValueError( + f"Unsupported type ({type(property.value)}) for metadata {key}" + ) from modelforge.utils.remote import calculate_md5_checksum hdf5_checksum = calculate_md5_checksum(file_path=file_path, file_name=file_name) diff --git a/modelforge-curate/modelforge/curate/datasets/ani2x_curation.py b/modelforge-curate/modelforge/curate/datasets/ani2x_curation.py index c7ff092c..e6d721e6 100644 --- a/modelforge-curate/modelforge/curate/datasets/ani2x_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/ani2x_curation.py @@ -99,7 +99,9 @@ def _process_downloaded( conformers_counter = 0 - dataset = SourceDataset(dataset_name="ani2x", local_db_dir=self.local_cache_dir) + dataset = SourceDataset( + dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir + ) with h5py.File(input_file_name, "r") as hf: # The ani2x hdf5 file groups molecules by number of atoms # we need to break up each of these groups into individual molecules diff --git a/modelforge-curate/modelforge/curate/datasets/curation_baseclass.py b/modelforge-curate/modelforge/curate/datasets/curation_baseclass.py index 38e533a9..73c98d6a 100644 --- a/modelforge-curate/modelforge/curate/datasets/curation_baseclass.py +++ b/modelforge-curate/modelforge/curate/datasets/curation_baseclass.py @@ -22,6 +22,7 @@ class DatasetCuration(ABC): def __init__( self, + dataset_name: str, local_cache_dir: Optional[str] = "./datasets_cache", version_select: str = "latest", ): @@ -30,6 +31,8 @@ def __init__( Parameters ---------- + dataset_name: str, required + Name of the dataset to curate. local_cache_dir: str, optional, default='./qm9_datafiles' Location to save downloaded dataset. 
version_select: str, optional, default='latest' @@ -41,6 +44,8 @@ def __init__( # make sure we can handle a path with a ~ in it self.local_cache_dir = os.path.expanduser(local_cache_dir) self.version_select = version_select + self.dataset_name = dataset_name + os.makedirs(self.local_cache_dir, exist_ok=True) # initialize parameter information @@ -240,6 +245,7 @@ def to_hdf5( total_configurations: Optional[int] = None, atomic_species_to_limit: Optional[List[Union[str, int]]] = None, max_force: Optional[unit.Quantity] = None, + final_configuration_only: Optional[bool] = False, ) -> Tuple[int, int]: """ Writes the dataset to an hdf5 file. @@ -265,6 +271,8 @@ def to_hdf5( These can be passed as a list of strings, e.g., ['C', 'H', 'O'] or as a list of atomic numbers, e.g., [6, 1, 8]. max_force: unit.Quantity, optional, default=None Maximum force to include in the dataset. Any configuration with forces greater than this value will be excluded. + final_configuration_only: bool, optional, default=False + If True, only the final configuration of each record will be included in the dataset. Returns ------- @@ -323,6 +331,7 @@ def to_hdf5( or atomic_species_to_limit is not None or max_force is not None or total_records is not None + or final_configuration_only ): import time import random @@ -342,6 +351,7 @@ def to_hdf5( max_configurations_per_record=max_configurations_per_record, atomic_numbers_to_limit=atomic_numbers_to_limit, max_force=max_force, + final_configuration_only=final_configuration_only, ) if dataset_trimmed.total_records() == 0: raise ValueError("No records found in the dataset after filtering.") @@ -366,3 +376,19 @@ def to_hdf5( file_path=output_file_dir, ) return (self.dataset.total_records(), self.dataset.total_configs()) + + def load_from_db(self, local_db_dir: str, local_db_name: str): + """ + Load the dataset from a local database. + + Parameters + ---------- + local_db_name: str, required + Name of the local database to load the dataset from. 
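+        local_db_dir: str, required
+            Directory where the local database file is stored.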
+ """ + self.dataset = SourceDataset( + dataset_name=self.dataset_name, + local_db_dir=local_db_dir, + local_db_name=local_db_name, + read_from_db=True, + ) diff --git a/modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py b/modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py index e031df6b..89a182ba 100644 --- a/modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py @@ -163,6 +163,9 @@ def _fetch_singlepoint_from_qcarchive( if pbar is not None: pbar.update(1) + from functools import lru_cache + + @lru_cache(maxsize=None) def _calculate_total_charge( self, smiles: str ) -> Tuple[unit.Quantity, unit.Quantity]: @@ -220,7 +223,7 @@ def _process_downloaded( from numpy import newaxis dataset = SourceDataset( - dataset_name="PhAlkEthOH_openff", + dataset_name=self.dataset_name, append_property=True, local_db_dir=self.local_cache_dir, ) @@ -258,6 +261,7 @@ def _process_downloaded( # these properties only need to be added once # so we need to check if if not name in dataset.records.keys(): + dataset.create_record(name) source = MetaData( name="source", value=input_file_name.replace(".sqlite", "") ) @@ -299,19 +303,18 @@ def _process_downloaded( # name = key.split("-")[0] trajectory = spice_db[key][1] + name = f'{key[: key.rfind("-")]}_{trajectory[0][0]["molecule_"]["name"]}' + record = dataset.get_record(name) + for state in trajectory: - add_record = True properties, config = state name = ( f'{key[: key.rfind("-")]}_{properties["molecule_"]["name"]}' ) - smiles = ( - dataset.records[name] - .meta_data[ - "canonical_isomeric_explicit_hydrogen_mapped_smiles" - ] - .value - ) + # record = dataset.get_record(name) + smiles = record.meta_data[ + "canonical_isomeric_explicit_hydrogen_mapped_smiles" + ].value total_charge_temp = self._calculate_total_charge(smiles) @@ -319,12 +322,10 @@ def _process_downloaded( value=np.array(total_charge_temp.m).reshape(1, 1), units=total_charge_temp.u, ) - dataset.add_property(name, total_charge) positions = Positions( value=config.reshape(1, -1, 3), units=unit.bohr ) - dataset.add_property(name, positions) # Note need to typecast here because of a bug in the # qcarchive entry: see issue: https://github.com/MolSSI/QCFractal/issues/766 @@ -339,7 +340,6 @@ def _process_downloaded( ).reshape(1, 1), units=unit.hartree, ) - dataset.add_property(name, dispersion_correction_energy) dft_total_energy = Energies( name="dft_total_energy", @@ -349,7 +349,6 @@ def _process_downloaded( + dispersion_correction_energy.value, units=unit.hartree, ) - dataset.add_property(name, dft_total_energy) dispersion_correction_gradient = Forces( name="dispersion_correction_gradient", @@ -360,14 +359,12 @@ def _process_downloaded( ).reshape(1, -1, 3), units=unit.hartree / unit.bohr, ) - dataset.add_property(name, dispersion_correction_gradient) dispersion_correction_force = Forces( name="dispersion_correction_force", value=-dispersion_correction_gradient.value, units=unit.hartree / unit.bohr, ) - dataset.add_property(name, dispersion_correction_force) dft_total_gradient = Forces( name="dft_total_gradient", @@ -377,14 +374,12 @@ def _process_downloaded( + dispersion_correction_gradient.value, units=unit.hartree / unit.bohr, ) - dataset.add_property(name, dft_total_gradient) dft_total_force = Forces( name="dft_total_force", value=-dft_total_gradient.value, units=unit.hartree / unit.bohr, ) - dataset.add_property(name, dft_total_force) scf_dipole = DipoleMomentPerSystem( 
name="scf_dipole", @@ -393,8 +388,34 @@ def _process_downloaded( ).reshape(1, 3), units=unit.elementary_charge * unit.bohr, ) - dataset.add_property(name, scf_dipole) - + record.add_properties( + [ + total_charge, + positions, + dispersion_correction_energy, + dft_total_energy, + dispersion_correction_gradient, + dispersion_correction_force, + dft_total_gradient, + dft_total_force, + scf_dipole, + ], + ) + # dataset.add_properties( + # name, + # [ + # total_charge, + # positions, + # dispersion_correction_energy, + # dft_total_energy, + # dispersion_correction_gradient, + # dispersion_correction_force, + # dft_total_gradient, + # dft_total_force, + # scf_dipole, + # ], + # ) + dataset.update_record(record) return dataset def process( diff --git a/modelforge-curate/modelforge/curate/datasets/qm9_curation.py b/modelforge-curate/modelforge/curate/datasets/qm9_curation.py index 1d26ae3e..487cb486 100644 --- a/modelforge-curate/modelforge/curate/datasets/qm9_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/qm9_curation.py @@ -404,7 +404,9 @@ def _process_downloaded( # we do not need to do anything check unit_testing_max_conformers_per_record because qm9 only has a single conformer per record - dataset = SourceDataset(dataset_name="qm9", local_db_dir=self.local_cache_dir) + dataset = SourceDataset( + dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir + ) for i, file in enumerate(tqdm(files, desc="processing", total=len(files))): record_temp = self._parse_xyzfile(f"{local_path_dir}/{file}") dataset.add_record(record_temp) diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_PhAlkEthOH.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_PhAlkEthOH.py index 0be9948b..d0b306da 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_PhAlkEthOH.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_PhAlkEthOH.py @@ -10,70 +10,6 @@ """ -def PhAlkEthOH_openff_wrapper( - hdf5_file_name: str, - output_file_dir: str, - local_cache_dir: str, - force_download: bool = False, - version_select: str = "latest", - max_records=None, - max_conformers_per_record=None, - total_conformers=None, - max_force=None, - final_conformer_only=False, -): - """ - This curates and processes the SPICE 114 dataset at the OpenFF level of theory into an hdf5 file. - - - Parameters - ---------- - hdf5_file_name: str, required - Name of the hdf5 file that will be generated. - output_file_dir: str, required - Directory where the hdf5 file will be saved. - local_cache_dir: str, required - Directory where the intermediate data will be saved; in this case it will be tarred file downloaded - from figshare and the expanded archive that contains xyz files for each molecule in the dataset. - force_download: bool, optional, default=False - If False, we will use the tarred file that exists in the local_cache_dir (if it exists); - If True, the tarred file will be downloaded, even if it exists locally. - version_select: str, optional, default="latest" - The version of the dataset to use as defined in the associated yaml file. - If "latest", the most recent version will be used. - max_records: int, optional, default=None - The maximum number of records to process. - max_conformers_per_record: int, optional, default=None - The maximum number of conformers to process for each record. - total_conformers: int, optional, default=None - The total number of conformers to process. 
- max_force: float, optional, default=None - The maximum force to allow in the dataset. Any conformers with forces greater than this value will be ignored. - final_conformer_only: bool, optional, default=False - If True, only the final conformer for each molecule will be processed. If False, all conformers will be processed. - - """ - from modelforge.curate.datasets.phalkethoh_curation import PhAlkEthOHCuration - - PhAlkEthOH_dataset = PhAlkEthOHCuration( - hdf5_file_name=hdf5_file_name, - output_file_dir=output_file_dir, - local_cache_dir=local_cache_dir, - version_select=version_select, - ) - PhAlkEthOH_dataset.process( - force_download=force_download, - max_records=max_records, - max_conformers_per_record=max_conformers_per_record, - total_conformers=total_conformers, - n_threads=1, - max_force=max_force, - final_conformer_only=final_conformer_only, - ) - print(f"Total records: {PhAlkEthOH_dataset.total_records()}") - print(f"Total conformers: {PhAlkEthOH_dataset.total_configs()}") - - def main(): from openff.units import unit @@ -91,62 +27,71 @@ def main(): # version of the dataset to curate version_select = f"v_0" + from modelforge.curate.datasets.phalkethoh_curation import PhAlkEthOHCuration + + PhAlkEthOH_openff = PhAlkEthOHCuration( + dataset_name="PhAlkEthOH_openff", + local_cache_dir=local_cache_dir, + version_select=version_select, + ) + + PhAlkEthOH_openff.process(force_download=False) + # curate dataset with 1000 total conformers, max of 10 conformers per record hdf5_file_name = f"PhAlkEthOH_openff_dataset_v{version}_ntc_1000.hdf5" - PhAlkEthOH_openff_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, - max_records=1000, + n_total_records, n_total_configs = PhAlkEthOH_openff.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, total_conformers=1000, max_conformers_per_record=10, max_force=1.0 * unit.hartree / unit.bohr, ) + print("1000 conformer subset") + print(f"Total records: {n_total_records}") + print(f"Total configs: {n_total_configs}") # curate the full dataset hdf5_file_name = f"PhAlkEthOH_openff_dataset_v{version}.hdf5" print("total dataset") - PhAlkEthOH_openff_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, + + n_total_records, n_total_configs = PhAlkEthOH_openff.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, max_force=1.0 * unit.hartree / unit.bohr, ) + print(f"Total records: {n_total_records}") + print(f"Total configs: {n_total_configs}") - # curate dataset with 1000 total conformers, max of 10 conformers per record + # curate dataset with 1000 total conformers, last only hdf5_file_name = f"PhAlkEthOH_openff_dataset_v{version}_ntc_1000_minimal.hdf5" - PhAlkEthOH_openff_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, - max_records=1000, - total_conformers=1000, - max_conformers_per_record=10, + n_total_records, n_total_configs = PhAlkEthOH_openff.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + total_configurations=1000, max_force=1.0 * unit.hartree / unit.bohr, - final_conformer_only=True, + final_configuration_only=True, ) + print("1000 conformer subset last configurations only") + print(f"Total records: {n_total_records}") + print(f"Total configs: {n_total_configs}") + + # curate the full dataset last config only - # curate the full dataset hdf5_file_name = 
f"PhAlkEthOH_openff_dataset_v{version}_minimal.hdf5" - print("total dataset") - PhAlkEthOH_openff_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, + + n_total_records, n_total_configs = PhAlkEthOH_openff.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, max_force=1.0 * unit.hartree / unit.bohr, - final_conformer_only=True, + final_configuration_only=True, ) + print("full dataset last configurations only") + print(f"Total records: {n_total_records}") + print(f"Total configs: {n_total_configs}") + if __name__ == "__main__": main() diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_ani2x.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_ani2x.py index a9e95ba7..50187b72 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_ani2x.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_ani2x.py @@ -19,76 +19,10 @@ """ -def ani2x_wrapper( - hdf5_file_name: str, - output_file_dir: str, - local_cache_dir: str, - force_download: bool = False, - version_select: str = "latest", - max_records=None, - max_conformers_per_record=None, - total_conformers=None, -): - """ - This fetches and processes the ANI2x dataset into a curated hdf5 file. - - The ANI-2x data set includes properties for small organic molecules that contain - H, C, N, O, S, F, and Cl. This dataset contains 9651712 conformers for 200,000 - This will fetch data generated with the wB97X/631Gd level of theory - used in the original ANI-2x paper, calculated using Gaussian 09 - - Citation: Devereux, C, Zubatyuk, R., Smith, J. et al. - "Extending the applicability of the ANI deep learning molecular potential to sulfur and halogens." - Journal of Chemical Theory and Computation 16.7 (2020): 4192-4202. - https://doi.org/10.1021/acs.jctc.0c00121 - - DOI for dataset: 10.5281/zenodo.10108941 - - Parameters - ---------- - hdf5_file_name: str, required - Name of the hdf5 file that will be generated. - output_file_dir: str, required - Directory where the hdf5 file will be saved. - local_cache_dir: str, required - Directory where the intermediate data will be saved; in this case it will be tarred file downloaded - from figshare and the expanded archive that contains xyz files for each molecule in the dataset. - force_download: bool, optional, default=False - If False, we will use the tarred file that exists in the local_cache_dir (if it exists); - If True, the tarred file will be downloaded, even if it exists locally. - version_select: str, optional, default="latest" - The version of the dataset to use as defined in the associated yaml file. - If "latest", the most recent version will be used. - max_records: int, optional, default=None - The maximum number of records to process. - max_conformers_per_record: int, optional, default=None - The maximum number of conformers to process for each record. - total_conformers: int, optional, default=None - The total number of conformers to process. 
- - """ - from modelforge.curate.datasets.ani2x_curation import ANI2xCuration - - ani2x = ANI2xCuration( - hdf5_file_name=hdf5_file_name, - output_file_dir=output_file_dir, - local_cache_dir=local_cache_dir, - version_select=version_select, - ) - - ani2x.process( - force_download=force_download, - max_records=max_records, - max_conformers_per_record=max_conformers_per_record, - total_conformers=total_conformers, - ) - print(f"Total records: {ani2x.total_records()}") - print(f"Total configs: {ani2x.total_configs()}") - - def main(): # define the location where to store and output the files import os + from modelforge.curate.datasets.ani2x_curation import ANI2xCuration local_prefix = os.path.expanduser("~/mf_datasets") output_file_dir = f"{local_prefix}/hdf5_files" @@ -101,30 +35,41 @@ def main(): # version of the dataset to curate version_select = f"v_0" + ani2x_dataset = ANI2xCuration( + dataset_name="ani2x", + local_cache_dir=local_cache_dir, + version_select=version_select, + ) + + ani2x_dataset.process(force_download=False) + # curate ANI2x test dataset with 1000 total conformers, max of 10 conformers per record hdf5_file_name = f"ani2x_dataset_v{version}_ntc_1000.hdf5" - ani2x_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - max_conformers_per_record=10, - version_select=version_select, - total_conformers=1000, + n_total_records, n_total_configs = ani2x_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + total_configurations=1000, + max_configurations_per_record=10, ) + print("1000 configuration dataset") + print(f"Total records: {n_total_records}") + print(f"Total configs: {n_total_configs}") # curate the full ANI-2x dataset hdf5_file_name = f"ani2x_dataset_v{version}.hdf5" - ani2x_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, + # full datset + + n_total_records, n_total_configs = ani2x_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, ) + print("full dataset dataset") + print(f"Total records: {n_total_records}") + print(f"Total configs: {n_total_configs}") + if __name__ == "__main__": main() diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_qm9.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_qm9.py index 748cec8f..5799026a 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_qm9.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_qm9.py @@ -17,68 +17,12 @@ """ -def qm9_wrapper( - hdf5_file_name: str, - output_file_dir: str, - local_cache_dir: str, - force_download: bool = False, - version_select: str = "latest", - max_records=None, - max_conformers_per_record=None, - total_conformers=None, -): - """ - This instantiates and calls the QM9Curation class to generate the hdf5 file for the QM9 dataset. - - Parameters - ---------- - hdf5_file_name: str, required - Name of the hdf5 file that will be generated. - output_file_dir: str, required - Directory where the hdf5 file will be saved. - local_cache_dir: str, required - Directory where the intermediate data will be saved; in this case it will be tarred file downloaded - from figshare and the expanded archive that contains xyz files for each molecule in the dataset. - force_download: bool, optional, default=False - If False, we will use the tarred file that exists in the local_cache_dir (if it exists); - If True, the tarred file will be downloaded, even if it exists locally. 
- version_select: str, optional, default="latest" - The version of the dataset to use as defined in the associated yaml file. - If "latest", the most recent version will be used. - max_records: int, optional, default=None - The maximum number of records to process. - max_conformers_per_record: int, optional, default=None - The maximum number of conformers to process for each record. - total_conformers: int, optional, default=None - The total number of conformers to process. - - - """ - from modelforge.curate.datasets.qm9_curation import QM9Curation - - qm9 = QM9Curation( - hdf5_file_name=hdf5_file_name, - output_file_dir=output_file_dir, - local_cache_dir=local_cache_dir, - version_select=version_select, - ) - - qm9.process( - force_download=force_download, - max_records=max_records, - max_conformers_per_record=max_conformers_per_record, - total_conformers=total_conformers, - ) - print(f"Total records: {qm9.total_records()}") - print(f"Total configurations: {qm9.total_configs()}") - - def main(): # define the location where to store and output the files import os local_prefix = os.path.expanduser("~/mf_datasets") - output_file_dir = f"{local_prefix}/hdf5_files" + output_file_dir = f"{local_prefix}/hdf5_files/qm9_dataset" local_cache_dir = f"{local_prefix}/qm9_dataset" # We'll want to provide some simple means of versioning @@ -87,28 +31,37 @@ def main(): # version of the dataset to curate version_select = f"v_0" # Curate the test dataset with 1000 total conformers - hdf5_file_name = f"qm9_dataset_v{version}_ntc_1000.hdf5" + from modelforge.curate.datasets.qm9_curation import QM9Curation - qm9_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, + qm9_dataset = QM9Curation( + dataset_name="qm9", + local_cache_dir=local_cache_dir, version_select=version_select, - max_conformers_per_record=1, # there is only one conformer per molecule in the QM9 dataset - total_conformers=1000, ) + qm9_dataset.process(force_download=False) # Curates the full dataset hdf5_file_name = f"qm9_dataset_v{version}.hdf5" - qm9_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, + n_total_records, n_total_configs = qm9_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir + ) + print("full dataset") + print(f"Total records: {n_total_records}") + print(f"Total configs: {n_total_configs}") + + # Curates the test dataset with 1000 total conformers + # only a single config per record + + hdf5_file_name = f"qm9_dataset_v{version}_ntc_1000.hdf5" + n_total_records, n_total_configs = qm9_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + total_configurations=1000, ) + print(" 1000 configuration subset") + print(f"Total records: {n_total_records}") + print(f"Total configs: {n_total_configs}") if __name__ == "__main__": diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py index ae05d343..020ae1c5 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py @@ -42,6 +42,7 @@ def main(): from modelforge.curate.datasets.spice_1_curation import SPICE1Curation spice1_dataset = SPICE1Curation( + dataset_name="spice1", local_cache_dir=local_cache_dir, version_select=version_select, ) @@ -54,7 +55,7 @@ def main(): # limited to the elements that will work with ANI2x 
hdf5_file_name = f"spice_1_dataset_v{version}_ntc_1000_HCNOFClS.hdf5" - spice1_dataset.to_hdf5( + total_records, total_configs = spice1_dataset.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, total_configurations=1000, @@ -63,26 +64,26 @@ def main(): ) print("SPICE1: 1000 configuration subset limited to ANI2x elements") - print(f"Total records: {spice1_dataset.total_records()}") - print(f"Total configs: {spice1_dataset.total_configs()}") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") # curate the full SPICE 1.1.4 dataset, limited to the elements that will work with ANI2x hdf5_file_name = f"spice_1_dataset_v{version}_HCNOFClS.hdf5" - spice1_dataset.to_hdf5( + total_records, total_configs = spice1_dataset.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, atomic_species_to_limit=ani2x_elements, ) print("SPICE1: full dataset limited to ANI2x elements") - print(f"Total records: {spice1_dataset.total_records()}") - print(f"Total configs: {spice1_dataset.total_configs()}") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") # curate the test SPICE 1.1.4 dataset with 1000 total configurations, max of 10 configurations per record hdf5_file_name = f"spice_1_dataset_v{version}_ntc_1000.hdf5" - spice1_dataset.to_hdf5( + total_records, total_configs = spice1_dataset.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, total_configurations=1000, @@ -90,18 +91,18 @@ def main(): ) print("SPICE1: 1000 configuration subset") - print(f"Total records: {spice1_dataset.total_records()}") - print(f"Total configs: {spice1_dataset.total_configs()}") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") # curate the full SPICE 1.1.4 dataset hdf5_file_name = f"spice_1_dataset_v{version}.hdf5" - spice1_dataset.to_hdf5( + total_records, total_configs = spice1_dataset.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir ) print("SPICE1: full dataset") - print(f"Total records: {spice1_dataset.total_records()}") - print(f"Total configs: {spice1_dataset.total_configs()}") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") if __name__ == "__main__": diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice2.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice2.py index 9649c104..7342c80a 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice2.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice2.py @@ -42,6 +42,7 @@ def main(): from modelforge.curate.datasets.spice_2_curation import SPICE2Curation spice2_dataset = SPICE2Curation( + dataset_name="spice2", local_cache_dir=local_cache_dir, version_select=version_select, ) @@ -54,7 +55,7 @@ def main(): # limited to the elements that will work with ANI2x hdf5_file_name = f"spice_2_dataset_v{version}_ntc_1000_HCNOFClS.hdf5" - spice2_dataset.to_hdf5( + total_records, total_configs = spice2_dataset.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, total_configurations=1000, @@ -63,26 +64,26 @@ def main(): ) print("SPICE2: 1000 configuration subset limited to ANI2x elements") - print(f"Total records: {spice2_dataset.total_records()}") - print(f"Total configs: {spice2_dataset.total_configs()}") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") # curate the full SPICE 2.0.1 dataset, limited to the elements that will 
work with ANI2x hdf5_file_name = f"spice_2_dataset_v{version}_HCNOFClS.hdf5" - spice2_dataset.to_hdf5( + total_records, total_configs = spice2_dataset.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, atomic_species_to_limit=ani2x_elements, ) print("SPICE2: full dataset limited to ANI2x elements") - print(f"Total records: {spice2_dataset.total_records()}") - print(f"Total configs: {spice2_dataset.total_configs()}") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") # curate the test SPICE 2.0.1 dataset with 1000 total configurations, max of 10 configurations per record hdf5_file_name = f"spice_2_dataset_v{version}_ntc_1000.hdf5" - spice2_dataset.to_hdf5( + total_records, total_configs = spice2_dataset.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, total_configurations=1000, @@ -90,18 +91,18 @@ def main(): ) print("SPICE2: 1000 configuration subset") - print(f"Total records: {spice2_dataset.total_records()}") - print(f"Total configs: {spice2_dataset.total_configs()}") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") # curate the full SPICE 2.0.1 dataset hdf5_file_name = f"spice_2_dataset_v{version}.hdf5" - spice2_dataset.to_hdf5( + total_records, total_configs = spice2_dataset.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir ) print("SPICE2: full dataset") - print(f"Total records: {spice2_dataset.total_records()}") - print(f"Total configs: {spice2_dataset.total_configs()}") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") if __name__ == "__main__": diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm.py index 241b5132..bbf994b2 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm.py @@ -40,6 +40,7 @@ def main(): from modelforge.curate.datasets.tmqm_curation import tmQMCuration tmqm = tmQMCuration( + dataset_name="tmqm", local_cache_dir=local_cache_dir, version_select=version_select, ) diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm_xtb.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm_xtb.py index 84fbcac4..cb1d170b 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm_xtb.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm_xtb.py @@ -66,7 +66,9 @@ def main(): version_select = f"v_{version}" tmqm_xtb = tmQMXTBCuration( - local_cache_dir=local_cache_dir, version_select=version_select + dataset_name="tmqm_xtb", + local_cache_dir=local_cache_dir, + version_select=version_select, ) tmqm_xtb.process(force_download=False) diff --git a/modelforge-curate/modelforge/curate/datasets/spice_1_curation.py b/modelforge-curate/modelforge/curate/datasets/spice_1_curation.py index d501784f..04be9b5d 100644 --- a/modelforge-curate/modelforge/curate/datasets/spice_1_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/spice_1_curation.py @@ -130,7 +130,7 @@ def _process_downloaded( input_file_name = f"{local_path_dir}/{name}" dataset = SourceDataset( - dataset_name="spice1", local_db_dir=self.local_cache_dir + dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir ) with OpenWithLock(f"{input_file_name}.lockfile", "w") as lockfile: diff --git a/modelforge-curate/modelforge/curate/datasets/spice_2_curation.py 
b/modelforge-curate/modelforge/curate/datasets/spice_2_curation.py index ad91c998..0b38fe9e 100644 --- a/modelforge-curate/modelforge/curate/datasets/spice_2_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/spice_2_curation.py @@ -141,7 +141,7 @@ def _process_downloaded( input_file_name = f"{local_path_dir}/{name}" dataset = SourceDataset( - dataset_name="spice2", local_db_dir=self.local_cache_dir + dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir ) with OpenWithLock(f"{input_file_name}.lockfile", "w") as lockfile: diff --git a/modelforge-curate/modelforge/curate/datasets/tmqm_curation.py b/modelforge-curate/modelforge/curate/datasets/tmqm_curation.py index 83b9b30a..e2fd8c5d 100644 --- a/modelforge-curate/modelforge/curate/datasets/tmqm_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/tmqm_curation.py @@ -174,7 +174,9 @@ def _process_downloaded( from modelforge.dataset.utils import _ATOMIC_ELEMENT_TO_NUMBER from modelforge.utils.misc import str_to_float - dataset = SourceDataset(dataset_name="tmqm", local_db_dir=self.local_cache_dir) + dataset = SourceDataset( + dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir + ) # aggregate the snapshot contents into a list snapshots = [] @@ -274,9 +276,7 @@ def _process_downloaded( partial_charges = PartialCharges( value=np.array(charges).reshape(1, -1, 1), units=unit.elementary_charge ) - dataset.add_properties( - record_name=record_name, properties=[partial_charges] - ) + dataset.add_property(record_name=record_name, property=partial_charges) columns = [] csv_temp_dict = {} @@ -285,7 +285,9 @@ def _process_downloaded( with open(csv_input_file) as csv_file: csv_reader = csv.reader(csv_file, delimiter=";") line_count = 0 - for row in tqdm(csv_reader): + for row in tqdm( + csv_reader, desc="Processing csv file", total=len(snapshot_charges) + ): if line_count == 0: columns = row line_count += 1 diff --git a/modelforge-curate/modelforge/curate/datasets/tmqm_xtb_curation.py b/modelforge-curate/modelforge/curate/datasets/tmqm_xtb_curation.py index 9d6a7cd1..98a33cf7 100644 --- a/modelforge-curate/modelforge/curate/datasets/tmqm_xtb_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/tmqm_xtb_curation.py @@ -144,7 +144,7 @@ def _process_downloaded( import h5py dataset = SourceDataset( - dataset_name="tmqm_xtb", local_db_dir=self.local_cache_dir + dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir ) with OpenWithLock(f"{local_path_dir}/{hdf5_file_name}.lockfile", "w") as f: with h5py.File(f"{local_path_dir}/{hdf5_file_name}", "r") as f: diff --git a/modelforge-curate/modelforge/curate/tests/test_curate.py b/modelforge-curate/modelforge/curate/tests/test_curate.py index a9fc29a8..183cc272 100644 --- a/modelforge-curate/modelforge/curate/tests/test_curate.py +++ b/modelforge-curate/modelforge/curate/tests/test_curate.py @@ -27,7 +27,7 @@ def test_dataset_create_record(prep_temp_dir): # test creating a record that already exists # this will fail new_dataset = SourceDataset( - "test_dataset2", + dataset_name="test_dataset2", local_db_dir=str(prep_temp_dir), local_db_name="test_dataset2.sqlite", ) diff --git a/modelforge-curate/modelforge/curate/tests/test_curation_baseclass.py b/modelforge-curate/modelforge/curate/tests/test_curation_baseclass.py index 84155d40..99012340 100644 --- a/modelforge-curate/modelforge/curate/tests/test_curation_baseclass.py +++ b/modelforge-curate/modelforge/curate/tests/test_curation_baseclass.py @@ -10,10 +10,10 @@ from 
modelforge.curate.datasets.curation_baseclass import DatasetCuration -def setup_test_dataset(local_cache_dir): +def setup_test_dataset(dataset_name, local_cache_dir): class TestCuration(DatasetCuration): def _init_dataset_parameters(self): - self.dataset = SourceDataset("test_dataset") + self.dataset = SourceDataset(dataset_name=self.dataset_name) for i in range(5): atomic_numbers = AtomicNumbers(value=[[6 + i], [1]]) positions = Positions( @@ -39,7 +39,7 @@ def _init_dataset_parameters(self): record.add_properties([atomic_numbers, positions, energy, forces]) self.dataset.add_record(record) - return TestCuration(local_cache_dir=local_cache_dir) + return TestCuration(dataset_name=dataset_name, local_cache_dir=local_cache_dir) def test_dipolemoment_calculation(): @@ -281,7 +281,7 @@ def _init_dataset_parameters(self): value=np.linalg.norm(scf_dipole_moment.value).reshape(1, 1), units=unit.elementary_charge * unit.nanometer, ) - dataset_curation = TestCuration() + dataset_curation = TestCuration(dataset_name="test_dataset") dipole_moment_comp = dataset_curation.compute_dipole_moment( atomic_numbers=atomic_numbers, positions=positions, partial_charges=charges @@ -311,7 +311,7 @@ def _init_dataset_parameters(self): def test_base_convert_element_string_to_atomic_number(prep_temp_dir): - curated_dataset = setup_test_dataset(str(prep_temp_dir)) + curated_dataset = setup_test_dataset("test_dataset_1a", str(prep_temp_dir)) output = curated_dataset._convert_element_list_to_atomic_numbers(["C", "H"]) assert np.all(output == np.array([6, 1])) @@ -319,7 +319,7 @@ def test_base_convert_element_string_to_atomic_number(prep_temp_dir): def test_base_operations(prep_temp_dir): output_dir = f"{prep_temp_dir}/test_base_operations" - curated_dataset = setup_test_dataset(output_dir) + curated_dataset = setup_test_dataset("test_dataset_1b", output_dir) # test writing the dataset n_record, n_configs = curated_dataset.to_hdf5( @@ -437,6 +437,14 @@ def test_base_operations(prep_temp_dir): assert n_configs == 5 assert os.path.exists(f"{output_dir}/test_species5.hdf5") + n_record, n_configs = curated_dataset.to_hdf5( + hdf5_file_name="test_species5.hdf5", + output_file_dir=output_dir, + final_configuration_only=True, + ) + assert n_record == 5 + assert n_configs == 5 + # test to see if we can remove high energy configurations # anything greater than 2.5 should exlcude the last record n_record, n_configs = curated_dataset.to_hdf5( @@ -490,7 +498,7 @@ def test_base_operations(prep_temp_dir): # make the original dataset empty with pytest.raises(ValueError): - empty_dataset = SourceDataset("empty_dataset") + empty_dataset = SourceDataset(dataset_name="empty_dataset") curated_dataset.dataset = empty_dataset n_record, n_configs = curated_dataset.to_hdf5( hdf5_file_name="test_energy.hdf5",