diff --git a/modelforge-curate/modelforge/curate/curate.py b/modelforge-curate/modelforge/curate/curate.py index 5a32719c..200afa31 100644 --- a/modelforge-curate/modelforge/curate/curate.py +++ b/modelforge-curate/modelforge/curate/curate.py @@ -12,12 +12,13 @@ import numpy as np import copy - +import os from typing import Union, List, Type, Optional from typing_extensions import Self from loguru import logger as log +from sqlitedict import SqliteDict class Record: @@ -438,6 +439,9 @@ def __init__( self, dataset_name: str, append_property: bool = False, + local_db_dir: Optional[str] = "./", + local_db_name: Optional[str] = None, + read_from_local_db=False, ): """ Class to hold a dataset of properties for a given dataset name @@ -449,12 +453,49 @@ def __init__( append_property: bool, optional, default=False If True, append an array to existing array if a property with the same name is added multiple times to a record. If False, an error will be raised if trying to add a property with the same name already exists in a record - Use True if data for configurations are stored in separate files/database entries and you want to combine them. + local_db_dir: str, optional, default="./" + Directory to store the local database + local_db_name: str, optional, default=None + Name of the cache database. If None, the dataset name will be used. + read_from_local_db: bool, optional, default=False + If True, use an existing database. + If False, removes the existing database and creates a new one. """ self.dataset_name = dataset_name self.records = {} self.append_property = append_property + self.local_db_dir = local_db_dir + + if local_db_name is None: + self.local_db_name = dataset_name.replace(" ", "_") + ".sqlite" + else: + self.local_db_name = local_db_name + + self.read_from_local_db = read_from_local_db + if self.read_from_local_db == False: + if os.path.exists(f"{self.local_db_dir}/{self.local_db_name}"): + log.warning( + f"Database file {self.local_db_name} already exists in {self.local_db_dir}. Removing it." + ) + self._remove_local_db() + else: + if not os.path.exists(f"{self.local_db_dir}/{self.local_db_name}"): + log.warning( + f"Database file {self.local_db_name} does not exist in {self.local_db_dir}" + ) + raise ValueError( + f"Database file {self.local_db_name} does not exist in {self.local_db_dir}." + ) + else: + # populate the records dict with the keys from the database + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + keys = list(db.keys()) + for key in keys: + self.records[key] = key def total_records(self): """ @@ -471,7 +512,8 @@ def total_configs(self): Get the total number of configurations in the dataset. """ total_config = 0 - for record in self.records.values(): + for record in self.records.keys(): + record = self.get_record(record) total_config += record.n_configs return total_config @@ -503,7 +545,12 @@ def create_record( f"Record with name {record_name} already exists in the dataset" ) - self.records[record_name] = Record(record_name, self.append_property) + self.records[record_name] = record_name + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + db[record_name] = Record(record_name, self.append_property) if properties is not None: self.add_properties_to_record(record_name, properties) @@ -528,8 +575,15 @@ def add_record(self, record: Record): raise ValueError( f"Record with name {record.name} already exists in the dataset." 
) + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + db[record.name] = record + + self.records[record.name] = record.name - self.records[record.name] = copy.deepcopy(record) + # self.records[record.name] = copy.deepcopy(record) def update_record(self, record: Record): """ @@ -555,8 +609,13 @@ def update_record(self, record: Record): raise ValueError( f"Record with name {record.name} does not exist in the dataset." ) + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + db[record.name] = record - self.records[record.name] = copy.deepcopy(record) + # self.records[record.name] = copy.deepcopy(record) def remove_record(self, record_name: str): """ @@ -574,6 +633,11 @@ def remove_record(self, record_name: str): assert isinstance(record_name, str) if record_name in self.records.keys(): self.records.pop(record_name) + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + db.pop(record_name) else: log.warning( f"Record with name {record_name} does not exist in the dataset." @@ -623,7 +687,14 @@ def add_property(self, record_name: str, property: Type[PropertyBaseModel]): ) self.create_record(record_name) - self.records[record_name].add_property(property) + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + record = db[record_name] + record.add_property(property) + db[record_name] = record + # self.records[record_name].add_property(property) def get_record(self, record_name: str): """ @@ -641,7 +712,11 @@ def get_record(self, record_name: str): assert isinstance(record_name, str) from copy import deepcopy - return deepcopy(self.records[record_name]) + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + return db[record_name] def slice_record(self, record_name: str, min: int = 0, max: int = -1) -> Record: """ @@ -663,22 +738,32 @@ def slice_record(self, record_name: str, min: int = 0, max: int = -1) -> Record: Record: A copy of the sliced record. """ assert isinstance(record_name, str) + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + return db[record_name].slice_record(min=min, max=max) - return self.records[record_name].slice_record(min=min, max=max) + # return self.records[record_name].slice_record(min=min, max=max) def subset_dataset( self, + new_dataset_name: str, total_records: Optional[int] = None, total_configurations: Optional[int] = None, max_configurations_per_record: Optional[int] = None, atomic_numbers_to_limit: Optional[np.ndarray] = None, max_force: Optional[unit.Quantity] = None, + local_db_dir: Optional[str] = None, + local_db_name: Optional[str] = None, ) -> Self: """ Subset the dataset to only include a certain species Parameters ---------- + new_dataset_name: str + Name of the new dataset that will be returned. Cannot be the same as the current dataset name. total_records: Optional[int], default=None Maximum number of records to include in the subset. total_configurations: Optional[int], default=None @@ -690,11 +775,19 @@ def subset_dataset( Any molecules that contain elements outside of this list will be igonored max_force: Optional[unit.Quantity], default=None If set, configurations with forces greater than this value will be removed. + local_db_dir: str, optional, default=None + Directory to store the local database for the new dataset. If not defined, will use the same directory as the current dataset. 
+ local_db_name: str, optional, default=None + Name of the cache database for the new dataset. If None, the dataset name will be used. + Returns ------- - SourceDataset: A copy of the subset of the dataset. + SourceDataset: A new dataset that corresponds to the desired subset. """ - + if new_dataset_name == self.dataset_name: + raise ValueError( + "New dataset name cannot be the same as the current dataset name." + ) if total_records is not None and total_configurations is not None: raise ValueError( "Cannot set both total_records and total_conformers. Please choose one." @@ -718,9 +811,17 @@ def subset_dataset( log.warning("Using all records in the dataset instead.") total_records = len(self.records) + if local_db_dir is None: + local_db_dir = self.local_db_dir + if local_db_name is None: + local_db_name = new_dataset_name.replace(" ", "_") + ".sqlite" + # create a new empty dataset, we will add records that meet the criteria to this dataset new_dataset = SourceDataset( - self.dataset_name, append_property=self.append_property + dataset_name=new_dataset_name, + append_property=self.append_property, + local_db_dir=local_db_dir, + local_db_name=local_db_name, ) # if we are limiting the total conformers @@ -728,52 +829,92 @@ def subset_dataset( total_configurations_to_add = total_configurations - for record_name in self.records.keys(): - - if total_configurations_to_add > 0: - - record = self.get_record(record_name) - - # if we have a max force, we will remove configurations with forces greater than the max_force - # we will just overwrite the record with the new record - if max_force is not None: - record = record.remove_high_force_configs(max_force) + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + for record_name in self.records.keys(): + + if total_configurations_to_add > 0: + + record = db[record_name] + + # if we have a max force, we will remove configurations with forces greater than the max_force + # we will just overwrite the record with the new record + if max_force is not None: + record = record.remove_high_force_configs(max_force) + + # we need to set the total configs in the record AFTER we have done any force filtering + n_configs = record.n_configs + + # if the record does not contain the appropriate atomic species, we will skip it + # and move on to the next iteration + if atomic_numbers_to_limit is not None: + if not record.contains_atomic_numbers( + atomic_numbers_to_limit + ): + continue + + # if we set a max number of configurations we want per record, consider that here + if max_configurations_per_record is not None: + n_configs_to_add = max_configurations_per_record + # if we have fewer than the max, just set to n_configs + if n_configs < max_configurations_per_record: + n_configs_to_add = n_configs + # now check to see if the n_configs_to_add is more than what we still need to hit max conformers + if n_configs_to_add > total_configurations_to_add: + n_configs_to_add = total_configurations_to_add + # even if we don't limit the number of configurations, we may still need to limit the number we + # include to satisfy the total number of configurations + else: + n_configs_to_add = n_configs + if n_configs_to_add > total_configurations_to_add: + n_configs_to_add = total_configurations_to_add - # we need to set the total configs in the record AFTER we have done any force filtering - n_configs = record.n_configs + record = record.slice_record(0, n_configs_to_add) + new_dataset.add_record(record) + total_configurations_to_add -= 
n_configs_to_add + return new_dataset - # if the record does not contain the appropriate atomic species, we will skip it - # and move on to the next iteration - if atomic_numbers_to_limit is not None: - if not record.contains_atomic_numbers(atomic_numbers_to_limit): - continue + elif total_records is not None: + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + total_records_to_add = total_records + for record_name in self.records.keys(): + if total_records_to_add > 0: + record = db[record_name] + # if we have a max force, we will remove configurations with forces greater than the max_force + # we will just overwrite the record with the new record + if max_force is not None: + record = record.remove_high_force_configs(max_force) + # if the record does not contain the appropriate atomic species, we will skip it + # and move on to the next iteration + if atomic_numbers_to_limit is not None: + if not record.contains_atomic_numbers( + atomic_numbers_to_limit + ): + continue + # if we have a max number of configurations we want per record, consider that here + if max_configurations_per_record is not None: + n_to_add = min( + max_configurations_per_record, record.n_configs + ) - # if we set a max number of configurations we want per record, consider that here - if max_configurations_per_record is not None: - n_configs_to_add = max_configurations_per_record - # if we have fewer than the max, just set to n_configs - if n_configs < max_configurations_per_record: - n_configs_to_add = n_configs - # now check to see if the n_configs_to_add is more than what we still need to hit max conformers - if n_configs_to_add > total_configurations_to_add: - n_configs_to_add = total_configurations_to_add - # even if we don't limit the number of configurations, we may still need to limit the number we - # include to satisfy the total number of configurations - else: - n_configs_to_add = n_configs - if n_configs_to_add > total_configurations_to_add: - n_configs_to_add = total_configurations_to_add - - record = record.slice_record(0, n_configs_to_add) - new_dataset.add_record(record) - total_configurations_to_add -= n_configs_to_add - return new_dataset + record = record.slice_record(0, n_to_add) - elif total_records is not None: - total_records_to_add = total_records - for record_name in self.records.keys(): - if total_records_to_add > 0: - record = self.get_record(record_name) + new_dataset.add_record(record) + total_records_to_add -= 1 + return new_dataset + # if we are not going to be limiting the total number of configurations or records + else: + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + for record_name in self.records.keys(): + record = db[record_name] # if we have a max force, we will remove configurations with forces greater than the max_force # we will just overwrite the record with the new record if max_force is not None: @@ -783,34 +924,12 @@ def subset_dataset( if atomic_numbers_to_limit is not None: if not record.contains_atomic_numbers(atomic_numbers_to_limit): continue - # if we have a max number of configurations we want per record, consider that here if max_configurations_per_record is not None: n_to_add = min(max_configurations_per_record, record.n_configs) - record = record.slice_record(0, n_to_add) new_dataset.add_record(record) - total_records_to_add -= 1 - return new_dataset - # if we are not going to be limiting the total number of configurations or records - else: - for record_name in 
self.records.keys(): - record = self.get_record(record_name) - # if we have a max force, we will remove configurations with forces greater than the max_force - # we will just overwrite the record with the new record - if max_force is not None: - record = record.remove_high_force_configs(max_force) - # if the record does not contain the appropriate atomic species, we will skip it - # and move on to the next iteration - if atomic_numbers_to_limit is not None: - if not record.contains_atomic_numbers(atomic_numbers_to_limit): - continue - if max_configurations_per_record is not None: - n_to_add = min(max_configurations_per_record, record.n_configs) - record = record.slice_record(0, n_to_add) - - new_dataset.add_record(record) - return new_dataset + return new_dataset def validate_record(self, record_name: str): """ @@ -831,107 +950,112 @@ def validate_record(self, record_name: str): validation_status = True # every record should have atomic numbers, positions, and energies # make sure atomic_numbers have been set - if self.records[record_name].atomic_numbers is None: - validation_status = False - log.error( - f"No atomic numbers set for record {record_name}. These are required." - ) - # raise ValueError( - # f"No atomic numbers set for record {record_name}. These are required." - # ) - - # ensure we added positions and energies as these are the minimum requirements for a dataset along with - # atomic_numbers - positions_in_properties = False - for property in self.records[record_name].per_atom.keys(): - if isinstance(self.records[record_name].per_atom[property], Positions): - positions_in_properties = True - break - if positions_in_properties == False: - validation_status = False - log.error( - f"No positions found in properties for record {record_name}. These are required." - ) - # raise ValueError( - # f"No positions found in properties for record {record_name}. These are required." - # ) - - # we need to ensure we have some type of energy defined - energy_in_properties = False - for property in self.records[record_name].per_system.keys(): - if isinstance(self.records[record_name].per_system[property], Energies): - energy_in_properties = True - break - - if energy_in_properties == False: - validation_status = False - log.error( - f"No energies found in properties for record {record_name}. These are required." - ) - # raise ValueError( - # f"No energies found in properties for record {record_name}. These are required." - # ) - - # run record validation for number of atoms - # this will check that all per_atom properties have the same number of atoms as the atomic numbers - if self.records[record_name]._validate_n_atoms() == False: - validation_status = False - log.error( - f"Number of atoms for properties in record {record_name} are not consistent." - ) - # raise ValueError( - # f"Number of atoms for properties in record {record_name} are not consistent." - # ) - # run record validation for number of configurations - # this will check that all properties have the same number of configurations - if self.records[record_name]._validate_n_configs() == False: - validation_status = False - log.error( - f"Number of configurations for properties in record {record_name} are not consistent." - ) - # raise ValueError( - # f"Number of configurations for properties in record {record_name} are not consistent." - # ) - - # check that the units provided are compatible with the expected units for the property type - # e.g., ensure things that should be length have units of length. 
- for property in self.records[record_name].per_atom.keys(): - - property_record = self.records[record_name].per_atom[property] - - property_type = property_record.property_type - # first double check that this did indeed get pushed to the correct sub dictionary - assert property_record.classification == PropertyClassification.per_atom + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + record = db[record_name] + if record.atomic_numbers is None: + validation_status = False + log.error( + f"No atomic numbers set for record {record_name}. These are required." + ) + # raise ValueError( + # f"No atomic numbers set for record {record_name}. These are required." + # ) - property_units = property_record.units - expected_units = GlobalUnitSystem.get_units(property_type) - # check to make sure units are compatible with the expected units for the property type - if not expected_units.is_compatible_with(property_units, "chem"): + # ensure we added positions and energies as these are the minimum requirements for a dataset along with + # atomic_numbers + positions_in_properties = False + for property in record.per_atom.keys(): + if isinstance(record.per_atom[property], Positions): + positions_in_properties = True + break + if positions_in_properties == False: validation_status = False log.error( - f"Unit of {property_record.name} is not compatible with the expected unit {expected_units} for record {record_name}." + f"No positions found in properties for record {record_name}. These are required." ) # raise ValueError( - # f"Unit of {property_record.name} is not compatible with the expected unit {expected_units} for record {record_name}." + # f"No positions found in properties for record {record_name}. These are required." # ) - for property in self.records[record_name].per_system.keys(): - property_record = self.records[record_name].per_system[property] + # we need to ensure we have some type of energy defined + energy_in_properties = False + for property in record.per_system.keys(): + if isinstance(record.per_system[property], Energies): + energy_in_properties = True + break - # check that the number of atoms in the property matches the number of atoms in the atomic numbers - property_type = property_record.property_type + if energy_in_properties == False: + validation_status = False + log.error( + f"No energies found in properties for record {record_name}. These are required." + ) + # raise ValueError( + # f"No energies found in properties for record {record_name}. These are required." + # ) - assert property_record.classification == PropertyClassification.per_system - expected_units = GlobalUnitSystem.get_units(property_type) - property_units = property_record.units - if not expected_units.is_compatible_with(property_units, "chem"): + # run record validation for number of atoms + # this will check that all per_atom properties have the same number of atoms as the atomic numbers + if record._validate_n_atoms() == False: validation_status = False log.error( - f"Unit of {property_record.name} is not compatible with the expected unit {expected_units} for record {record_name}." + f"Number of atoms for properties in record {record_name} are not consistent." ) # raise ValueError( - # f"Unit of {property_record.name} is not compatible with the expected unit {expected_units} for record {record_name}." + # f"Number of atoms for properties in record {record_name} are not consistent." 
# ) + # run record validation for number of configurations + # this will check that all properties have the same number of configurations + if record._validate_n_configs() == False: + validation_status = False + log.error( + f"Number of configurations for properties in record {record_name} are not consistent." + ) + # raise ValueError( + # f"Number of configurations for properties in record {record_name} are not consistent." + # ) + + # check that the units provided are compatible with the expected units for the property type + # e.g., ensure things that should be length have units of length. + for property in record.per_atom.keys(): + + property_record = record.per_atom[property] + + property_type = property_record.property_type + # first double check that this did indeed get pushed to the correct sub dictionary + assert property_record.classification == PropertyClassification.per_atom + + property_units = property_record.units + expected_units = GlobalUnitSystem.get_units(property_type) + # check to make sure units are compatible with the expected units for the property type + if not expected_units.is_compatible_with(property_units, "chem"): + validation_status = False + log.error( + f"Unit of {property_record.name} is not compatible with the expected unit {expected_units} for record {record_name}." + ) + # raise ValueError( + # f"Unit of {property_record.name} is not compatible with the expected unit {expected_units} for record {record_name}." + # ) + + for property in record.per_system.keys(): + property_record = record.per_system[property] + + # check that the number of atoms in the property matches the number of atoms in the atomic numbers + property_type = property_record.property_type + + assert ( + property_record.classification == PropertyClassification.per_system + ) + expected_units = GlobalUnitSystem.get_units(property_type) + property_units = property_record.units + if not expected_units.is_compatible_with(property_units, "chem"): + validation_status = False + log.error( + f"Unit of {property_record.name} is not compatible with the expected unit {expected_units} for record {record_name}." 
+ ) + return validation_status def validate_records(self): @@ -981,38 +1105,39 @@ def _generate_dataset_summary(self, checksum: str, file_name: str): output_dict["total_configurations"] = self.total_configs() key = list(self.records.keys())[0] - - temp_props = {} - temp_props["atomic_numbers"] = { - "classification": str(self.records[key].atomic_numbers.classification), - } - - for prop in self.records[key].per_atom.keys(): - temp_props[prop] = { - "classification": str(self.records[key].per_atom[prop].classification), - "units": str( - GlobalUnitSystem.get_units( - self.records[key].per_atom[prop].property_type - ) - ), + with SqliteDict( + f"{self.local_db_dir}/{self.local_db_name}", + autocommit=True, + ) as db: + record = db[key] + + temp_props = {} + temp_props["atomic_numbers"] = { + "classification": str(record.atomic_numbers.classification), } - for prop in self.records[key].per_system.keys(): - temp_props[prop] = { - "classification": str( - self.records[key].per_system[prop].classification - ), - "units": str( - GlobalUnitSystem.get_units( - self.records[key].per_system[prop].property_type - ) - ), - } + for prop in record.per_atom.keys(): + temp_props[prop] = { + "classification": str(record.per_atom[prop].classification), + "units": str( + GlobalUnitSystem.get_units(record.per_atom[prop].property_type) + ), + } + + for prop in record.per_system.keys(): + temp_props[prop] = { + "classification": str(record.per_system[prop].classification), + "units": str( + GlobalUnitSystem.get_units( + record.per_system[prop].property_type + ) + ), + } - for prop in self.records[key].meta_data.keys(): - temp_props[prop] = "meta_data" + for prop in record.meta_data.keys(): + temp_props[prop] = "meta_data" - output_dict["properties"] = temp_props + output_dict["properties"] = temp_props return output_dict @@ -1036,6 +1161,16 @@ def summary_to_json( with open(f"{file_path}/{file_name}", "w") as f: json.dump(dataset_summary, f, indent=4) + def _remove_local_db(self): + """ + Remove the local database file. + + Returns + ------- + + """ + os.remove(f"{self.local_db_dir}/{self.local_db_name}") + def to_hdf5(self, file_path: str, file_name: str): """ Write the dataset to an HDF5 file. 
@@ -1074,27 +1209,27 @@ def to_hdf5(self, file_path: str, file_name: str): with OpenWithLock(f"{full_file_path}.lockfile", "w") as lockfile: with h5py.File(full_file_path, "w") as f: - for record in tqdm(self.records.keys()): - record_group = f.create_group(record) + for record_name in tqdm(self.records.keys()): + record_group = f.create_group(record_name) + + record = self.get_record(record_name) record_group.create_dataset( "atomic_numbers", - data=self.records[record].atomic_numbers.value, - shape=self.records[record].atomic_numbers.value.shape, + data=record.atomic_numbers.value, + shape=record.atomic_numbers.value.shape, ) record_group["atomic_numbers"].attrs["format"] = str( - self.records[record].atomic_numbers.classification + record.atomic_numbers.classification ) record_group["atomic_numbers"].attrs["property_type"] = str( - self.records[record].atomic_numbers.property_type + record.atomic_numbers.property_type ) - record_group.create_dataset( - "n_configs", data=self.records[record].n_configs - ) + record_group.create_dataset("n_configs", data=record.n_configs) record_group["n_configs"].attrs["format"] = "meta_data" - for key, property in self.records[record].per_atom.items(): + for key, property in record.per_atom.items(): target_units = GlobalUnitSystem.get_units( property.property_type @@ -1112,7 +1247,7 @@ def to_hdf5(self, file_path: str, file_name: str): property.property_type ) - for key, property in self.records[record].per_system.items(): + for key, property in record.per_system.items(): target_units = GlobalUnitSystem.get_units( property.property_type ) @@ -1130,7 +1265,7 @@ def to_hdf5(self, file_path: str, file_name: str): property.property_type ) - for key, property in self.records[record].meta_data.items(): + for key, property in record.meta_data.items(): if isinstance(property.value, str): record_group.create_dataset( key, diff --git a/modelforge-curate/modelforge/curate/datasets/ani2x_curation.py b/modelforge-curate/modelforge/curate/datasets/ani2x_curation.py index 09daa72f..c7ff092c 100644 --- a/modelforge-curate/modelforge/curate/datasets/ani2x_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/ani2x_curation.py @@ -99,7 +99,7 @@ def _process_downloaded( conformers_counter = 0 - dataset = SourceDataset("ani2x") + dataset = SourceDataset(dataset_name="ani2x", local_db_dir=self.local_cache_dir) with h5py.File(input_file_name, "r") as hf: # The ani2x hdf5 file groups molecules by number of atoms # we need to break up each of these groups into individual molecules diff --git a/modelforge-curate/modelforge/curate/datasets/curation_baseclass.py b/modelforge-curate/modelforge/curate/datasets/curation_baseclass.py index 4fd22375..38e533a9 100644 --- a/modelforge-curate/modelforge/curate/datasets/curation_baseclass.py +++ b/modelforge-curate/modelforge/curate/datasets/curation_baseclass.py @@ -324,7 +324,19 @@ def to_hdf5( or max_force is not None or total_records is not None ): + import time + import random + + # generate a 5 digit random number to append to the dataset name + # using current time as the seed + # this database will be removed after the dataset is written + random.seed(time.time()) + number = random.randint(10000, 99999) + + new_dataset_name = f"{self.dataset.dataset_name}_temp_{number}" + dataset_trimmed = self.dataset.subset_dataset( + new_dataset_name=new_dataset_name, total_configurations=total_configurations, total_records=total_records, max_configurations_per_record=max_configurations_per_record, @@ -339,7 +351,13 @@ def to_hdf5( 
file_name=hdf5_file_name, file_path=output_file_dir, ) - return (dataset_trimmed.total_records(), dataset_trimmed.total_configs()) + n_total_records = dataset_trimmed.total_records() + n_total_configs = dataset_trimmed.total_configs() + + # remove the database associated with the temporarily created dataset + dataset_trimmed._remove_local_db() + + return (n_total_records, n_total_configs) else: self._write_hdf5_and_json_files( diff --git a/modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py b/modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py index 0c1b6827..e031df6b 100644 --- a/modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py @@ -219,7 +219,11 @@ def _process_downloaded( from numpy import newaxis - dataset = SourceDataset("PhAlkEthOH_openff", append_property=True) + dataset = SourceDataset( + dataset_name="PhAlkEthOH_openff", + append_property=True, + local_db_dir=self.local_cache_dir, + ) for filename, dataset_name in zip(filenames, dataset_names): input_file_name = f"{local_path_dir}/{filename}" diff --git a/modelforge-curate/modelforge/curate/datasets/qm9_curation.py b/modelforge-curate/modelforge/curate/datasets/qm9_curation.py index a955cd45..1d26ae3e 100644 --- a/modelforge-curate/modelforge/curate/datasets/qm9_curation.py +++ b/modelforge-curate/modelforge/curate/datasets/qm9_curation.py @@ -404,7 +404,7 @@ def _process_downloaded( # we do not need to do anything check unit_testing_max_conformers_per_record because qm9 only has a single conformer per record - dataset = SourceDataset("qm9") + dataset = SourceDataset(dataset_name="qm9", local_db_dir=self.local_cache_dir) for i, file in enumerate(tqdm(files, desc="processing", total=len(files))): record_temp = self._parse_xyzfile(f"{local_path_dir}/{file}") dataset.add_record(record_temp) diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py index b491dcd4..ae05d343 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py @@ -23,73 +23,12 @@ """ -def spice1_wrapper( - hdf5_file_name: str, - output_file_dir: str, - local_cache_dir: str, - force_download: bool = False, - version_select: str = "latest", - max_records=None, - max_conformers_per_record=None, - total_conformers=None, - limit_atomic_species=None, -): - """ - This curates and processes the SPICE2 dataset into an hdf5 file. - - - Parameters - ---------- - hdf5_file_name: str, required - Name of the hdf5 file that will be generated. - output_file_dir: str, required - Directory where the hdf5 file will be saved. - local_cache_dir: str, required - Directory where the intermediate data will be saved; in this case it will be tarred file downloaded - from figshare and the expanded archive that contains xyz files for each molecule in the dataset. - force_download: bool, optional, default=False - If False, we will use the tarred file that exists in the local_cache_dir (if it exists); - If True, the tarred file will be downloaded, even if it exists locally. - version_select: str, optional, default="latest" - The version of the dataset to use as defined in the associated yaml file. - If "latest", the most recent version will be used. - max_records: int, optional, default=None - The maximum number of records to process. 
- max_conformers_per_record: int, optional, default=None - The maximum number of conformers to process for each record. - total_conformers: int, optional, default=None - The total number of conformers to process. - limit_atomic_species: list, optional, default=None - A list of atomic species to limit the dataset to. Any molecules that contain elements outside of this list - will be ignored. If not defined, no filtering by atomic species will be performed. - - """ - from modelforge.curate.datasets.spice_1_curation import SPICE1Curation - - spice_1_data = SPICE1Curation( - hdf5_file_name=hdf5_file_name, - output_file_dir=output_file_dir, - local_cache_dir=local_cache_dir, - version_select=version_select, - ) - - spice_1_data.process( - force_download=force_download, - max_records=max_records, - max_conformers_per_record=max_conformers_per_record, - total_conformers=total_conformers, - limit_atomic_species=limit_atomic_species, - ) - print(f"Total records: {spice_1_data.total_records()}") - print(f"Total configs: {spice_1_data.total_configs()}") - - def main(): # define the location where to store and output the files import os local_prefix = os.path.expanduser("~/mf_datasets") - output_file_dir = f"{local_prefix}/hdf5_files" + output_file_dir = f"{local_prefix}/hdf5_files/spice1" local_cache_dir = f"{local_prefix}/spice1_dataset" # We'll want to provide some simple means of versioning @@ -99,58 +38,70 @@ def main(): version_select = f"v_0" # version v_0 corresponds to SPICE 1.1.4 + # start with processing the full dataset + from modelforge.curate.datasets.spice_1_curation import SPICE1Curation + + spice1_dataset = SPICE1Curation( + local_cache_dir=local_cache_dir, + version_select=version_select, + ) + + spice1_dataset.process(force_download=False) ani2x_elements = ["H", "C", "N", "O", "F", "Cl", "S"] - # curate SPICE 2.0.1 dataset with 1000 total conformers, max of 10 conformers per record + # curate SPICE 1.1.4 dataset with 1000 total configurations, max of 10 conformers per record # limited to the elements that will work with ANI2x hdf5_file_name = f"spice_1_dataset_v{version}_ntc_1000_HCNOFClS.hdf5" - spice1_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, - max_conformers_per_record=10, - total_conformers=1000, - limit_atomic_species=ani2x_elements, + spice1_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + total_configurations=1000, + max_configurations_per_record=10, + atomic_species_to_limit=ani2x_elements, ) - # curate the full SPICE 2.0.1 dataset, limited to the elements that will work with ANI2x + + print("SPICE1: 1000 configuration subset limited to ANI2x elements") + print(f"Total records: {spice1_dataset.total_records()}") + print(f"Total configs: {spice1_dataset.total_configs()}") + + # curate the full SPICE 1.1.4 dataset, limited to the elements that will work with ANI2x hdf5_file_name = f"spice_1_dataset_v{version}_HCNOFClS.hdf5" - spice1_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, - limit_atomic_species=ani2x_elements, + spice1_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + atomic_species_to_limit=ani2x_elements, ) - # curate the test SPICE 2.0.1 dataset with 1000 total conformers, max of 10 conformers per record + print("SPICE1: full dataset limited to ANI2x elements") + print(f"Total records: {spice1_dataset.total_records()}") + print(f"Total 
configs: {spice1_dataset.total_configs()}") + + # curate the test SPICE 1.1.4 dataset with 1000 total configurations, max of 10 configurations per record hdf5_file_name = f"spice_1_dataset_v{version}_ntc_1000.hdf5" - spice1_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, - max_conformers_per_record=10, - total_conformers=1000, + spice1_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + total_configurations=1000, + max_configurations_per_record=10, ) - # curate the full SPICE 2.0.1 dataset + print("SPICE1: 1000 configuration subset") + print(f"Total records: {spice1_dataset.total_records()}") + print(f"Total configs: {spice1_dataset.total_configs()}") + + # curate the full SPICE 1.1.4 dataset hdf5_file_name = f"spice_1_dataset_v{version}.hdf5" - spice1_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, + spice1_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir ) + print("SPICE1: full dataset") + print(f"Total records: {spice1_dataset.total_records()}") + print(f"Total configs: {spice1_dataset.total_configs()}") if __name__ == "__main__": diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice2.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice2.py index 326842bb..9649c104 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice2.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_spice2.py @@ -23,73 +23,12 @@ """ -def spice2_wrapper( - hdf5_file_name: str, - output_file_dir: str, - local_cache_dir: str, - force_download: bool = False, - version_select: str = "latest", - max_records=None, - max_conformers_per_record=None, - total_conformers=None, - limit_atomic_species=None, -): - """ - This curates and processes the SPICE2 dataset into an hdf5 file. - - - Parameters - ---------- - hdf5_file_name: str, required - Name of the hdf5 file that will be generated. - output_file_dir: str, required - Directory where the hdf5 file will be saved. - local_cache_dir: str, required - Directory where the intermediate data will be saved; in this case it will be tarred file downloaded - from figshare and the expanded archive that contains xyz files for each molecule in the dataset. - force_download: bool, optional, default=False - If False, we will use the tarred file that exists in the local_cache_dir (if it exists); - If True, the tarred file will be downloaded, even if it exists locally. - version_select: str, optional, default="latest" - The version of the dataset to use as defined in the associated yaml file. - If "latest", the most recent version will be used. - max_records: int, optional, default=None - The maximum number of records to process. - max_conformers_per_record: int, optional, default=None - The maximum number of conformers to process for each record. - total_conformers: int, optional, default=None - The total number of conformers to process. - limit_atomic_species: list, optional, default=None - A list of atomic species to limit the dataset to. Any molecules that contain elements outside of this list - will be ignored. If not defined, no filtering by atomic species will be performed. 
- - """ - from modelforge.curate.datasets.spice_2_curation import SPICE2Curation - - spice_2_data = SPICE2Curation( - hdf5_file_name=hdf5_file_name, - output_file_dir=output_file_dir, - local_cache_dir=local_cache_dir, - version_select=version_select, - ) - - spice_2_data.process( - force_download=force_download, - max_records=max_records, - max_conformers_per_record=max_conformers_per_record, - total_conformers=total_conformers, - limit_atomic_species=limit_atomic_species, - ) - print(f"Total records: {spice_2_data.total_records()}") - print(f"Total configs: {spice_2_data.total_configs()}") - - def main(): # define the location where to store and output the files import os local_prefix = os.path.expanduser("~/mf_datasets") - output_file_dir = f"{local_prefix}/hdf5_files" + output_file_dir = f"{local_prefix}/hdf5_files/spice2" local_cache_dir = f"{local_prefix}/spice2_dataset" # We'll want to provide some simple means of versioning @@ -99,58 +38,70 @@ def main(): version_select = f"v_0" # version v_0 corresponds to SPICE 2.0.1 + # start with processing the full dataset + from modelforge.curate.datasets.spice_2_curation import SPICE2Curation + + spice2_dataset = SPICE2Curation( + local_cache_dir=local_cache_dir, + version_select=version_select, + ) + + spice2_dataset.process(force_download=False) ani2x_elements = ["H", "C", "N", "O", "F", "Cl", "S"] - # curate SPICE 2.0.1 dataset with 1000 total conformers, max of 10 conformers per record + # curate SPICE 2.0.1 dataset with 1000 total configurations, max of 10 conformers per record # limited to the elements that will work with ANI2x hdf5_file_name = f"spice_2_dataset_v{version}_ntc_1000_HCNOFClS.hdf5" - spice2_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, - max_conformers_per_record=10, - total_conformers=1000, - limit_atomic_species=ani2x_elements, + spice2_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + total_configurations=1000, + max_configurations_per_record=10, + atomic_species_to_limit=ani2x_elements, ) + + print("SPICE2: 1000 configuration subset limited to ANI2x elements") + print(f"Total records: {spice2_dataset.total_records()}") + print(f"Total configs: {spice2_dataset.total_configs()}") + # curate the full SPICE 2.0.1 dataset, limited to the elements that will work with ANI2x hdf5_file_name = f"spice_2_dataset_v{version}_HCNOFClS.hdf5" - spice2_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, - limit_atomic_species=ani2x_elements, + spice2_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + atomic_species_to_limit=ani2x_elements, ) - # curate the test SPICE 2.0.1 dataset with 1000 total conformers, max of 10 conformers per record + print("SPICE2: full dataset limited to ANI2x elements") + print(f"Total records: {spice2_dataset.total_records()}") + print(f"Total configs: {spice2_dataset.total_configs()}") + + # curate the test SPICE 2.0.1 dataset with 1000 total configurations, max of 10 configurations per record hdf5_file_name = f"spice_2_dataset_v{version}_ntc_1000.hdf5" - spice2_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, - max_conformers_per_record=10, - total_conformers=1000, + spice2_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, + output_file_dir=output_file_dir, + total_configurations=1000, + 
max_configurations_per_record=10, ) + print("SPICE2: 1000 configuration subset") + print(f"Total records: {spice2_dataset.total_records()}") + print(f"Total configs: {spice2_dataset.total_configs()}") + # curate the full SPICE 2.0.1 dataset hdf5_file_name = f"spice_2_dataset_v{version}.hdf5" - spice2_wrapper( - hdf5_file_name, - output_file_dir, - local_cache_dir, - force_download=False, - version_select=version_select, + spice2_dataset.to_hdf5( + hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir ) + print("SPICE2: full dataset") + print(f"Total records: {spice2_dataset.total_records()}") + print(f"Total configs: {spice2_dataset.total_configs()}") if __name__ == "__main__": diff --git a/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm.py b/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm.py index a10c76ca..241b5132 100644 --- a/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm.py +++ b/modelforge-curate/modelforge/curate/datasets/scripts/curate_tmqm.py @@ -26,7 +26,7 @@ def main(): import os local_prefix = os.path.expanduser("~/mf_datasets") - output_file_dir = f"{local_prefix}/hdf5_files" + output_file_dir = f"{local_prefix}/hdf5_files/tmqm" local_cache_dir = f"{local_prefix}/tmqm_dataset" # We'll want to provide some simple means of versioning @@ -59,17 +59,17 @@ def main(): print(f"Total records: {total_records}") print(f"Total configs: {total_configs}") - # Curate the test dataset with 1000 total conformers + # Curate the test dataset with 1000 total configurations # only a single config per record hdf5_file_name = f"tmqm_dataset_v{version_out}_ntc_1000.hdf5" total_records, total_configs = tmqm.to_hdf5( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, - total_conformers=1000, + total_configurations=1000, ) - print(" 1000 conformer subset") + print(" 1000 configuration subset") print(f"Total records: {total_records}") print(f"Total configs: {total_configs}") @@ -80,7 +80,7 @@ def main(): total_records, total_configs = tmqm.to_hdf5( hdf5_file_name=f"tmqm_dataset_PdZnFeCu_CHPSONFClBr_v{version_out}.hdf5", output_file_dir=output_file_dir, - limit_atomic_species=[ + atomic_species_to_limit=[ "Pd", "Zn", "Fe", @@ -97,6 +97,101 @@ def main(): ], ) + print("Primary transition metals subset") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") + + # same dataset but with only 1000 total_configurations + + total_records, total_configs = tmqm.to_hdf5( + hdf5_file_name=f"tmqm_dataset_PdZnFeCu_CHPSONFClBr_v{version_out}_ntc_1000.hdf5", + output_file_dir=output_file_dir, + atomic_species_to_limit=[ + "Pd", + "Zn", + "Fe", + "Cu", + "C", + "H", + "P", + "S", + "O", + "N", + "F", + "Cl", + "Br", + ], + total_configurations=1000, + ) + + print("Primary transition metals 1000 configuration subset") + print(f"Total records: {total_records}") + print(f"Total configs: {total_configs}") + + # create a dataset with a second subset of transition metals + # Pd, Zn, Fe, Cu, Ni, Pt, Ir, Rh, Cr, Ag and the same organic elements as above + + total_records, total_configs = tmqm.to_hdf5( + hdf5_file_name=f"tmqm_dataset_PdZnFeCuNiPtIrRhCrAg_CHPSONFClBr_v{version_out}.hdf5", + output_file_dir=output_file_dir, + atomic_species_to_limit=[ + "Pd", + "Zn", + "Fe", + "Cu", + "Ni", + "Pt", + "Ir", + "Rh", + "Cr", + "Ag", + "C", + "H", + "P", + "S", + "O", + "N", + "F", + "Cl", + "Br", + ], + ) + print("Primary + second transition metals subset") + print(f"Total records: {total_records}") + print(f"Total configs: 
{total_configs}")
+
+    # same dataset but limited to 1000 total configurations
+    total_records, total_configs = tmqm.to_hdf5(
+        hdf5_file_name=f"tmqm_dataset_PdZnFeCuNiPtIrRhCrAg_CHPSONFClBr_v{version_out}_ntc_1000.hdf5",
+        output_file_dir=output_file_dir,
+        atomic_species_to_limit=[
+            "Pd",
+            "Zn",
+            "Fe",
+            "Cu",
+            "Ni",
+            "Pt",
+            "Ir",
+            "Rh",
+            "Cr",
+            "Ag",
+            "C",
+            "H",
+            "P",
+            "S",
+            "O",
+            "N",
+            "F",
+            "Cl",
+            "Br",
+        ],
+        total_configurations=1000,
+    )
+
+    print("Primary + second transition metals 1000 configuration subset")
+    print(f"Total records: {total_records}")
+    print(f"Total configs: {total_configs}")
+

 if __name__ == "__main__":
     main()
diff --git a/modelforge-curate/modelforge/curate/datasets/spice_1_curation.py b/modelforge-curate/modelforge/curate/datasets/spice_1_curation.py
index 17572f7d..d501784f 100644
--- a/modelforge-curate/modelforge/curate/datasets/spice_1_curation.py
+++ b/modelforge-curate/modelforge/curate/datasets/spice_1_curation.py
@@ -129,7 +129,9 @@

         input_file_name = f"{local_path_dir}/{name}"

-        dataset = SourceDataset("spice1")
+        dataset = SourceDataset(
+            dataset_name="spice1", local_db_dir=self.local_cache_dir
+        )

         with OpenWithLock(f"{input_file_name}.lockfile", "w") as lockfile:
             with h5py.File(input_file_name, "r") as hf:
diff --git a/modelforge-curate/modelforge/curate/datasets/spice_2_curation.py b/modelforge-curate/modelforge/curate/datasets/spice_2_curation.py
index 646cac35..ad91c998 100644
--- a/modelforge-curate/modelforge/curate/datasets/spice_2_curation.py
+++ b/modelforge-curate/modelforge/curate/datasets/spice_2_curation.py
@@ -140,7 +140,9 @@

         input_file_name = f"{local_path_dir}/{name}"

-        dataset = SourceDataset("spice2")
+        dataset = SourceDataset(
+            dataset_name="spice2", local_db_dir=self.local_cache_dir
+        )

         with OpenWithLock(f"{input_file_name}.lockfile", "w") as lockfile:
             with h5py.File(input_file_name, "r") as hf:
diff --git a/modelforge-curate/modelforge/curate/datasets/tmqm_curation.py b/modelforge-curate/modelforge/curate/datasets/tmqm_curation.py
index b81bf4dc..83b9b30a 100644
--- a/modelforge-curate/modelforge/curate/datasets/tmqm_curation.py
+++ b/modelforge-curate/modelforge/curate/datasets/tmqm_curation.py
@@ -174,7 +174,7 @@
         from modelforge.dataset.utils import _ATOMIC_ELEMENT_TO_NUMBER
         from modelforge.utils.misc import str_to_float

-        dataset = SourceDataset("tmqm")
+        dataset = SourceDataset(dataset_name="tmqm", local_db_dir=self.local_cache_dir)

         # aggregate the snapshot contents into a list
         snapshots = []
diff --git a/modelforge-curate/modelforge/curate/datasets/tmqm_xtb_curation.py b/modelforge-curate/modelforge/curate/datasets/tmqm_xtb_curation.py
index 1b75200f..9d6a7cd1 100644
--- a/modelforge-curate/modelforge/curate/datasets/tmqm_xtb_curation.py
+++ b/modelforge-curate/modelforge/curate/datasets/tmqm_xtb_curation.py
@@ -143,7 +143,9 @@
         from modelforge.utils.misc import OpenWithLock
         import h5py

-        dataset = SourceDataset("tmqm_xtb")
+        dataset = SourceDataset(
+            dataset_name="tmqm_xtb", local_db_dir=self.local_cache_dir
+        )
         with OpenWithLock(f"{local_path_dir}/{hdf5_file_name}.lockfile", "w") as f:
             with h5py.File(f"{local_path_dir}/{hdf5_file_name}", "r") as f:
                 for key in tqdm(f.keys()):
diff --git a/modelforge-curate/modelforge/curate/examples/basic_usage.ipynb b/modelforge-curate/modelforge/curate/examples/basic_usage.ipynb
index c9b67252..8918cbbe 100644
--- a/modelforge-curate/modelforge/curate/examples/basic_usage.ipynb
+++ b/modelforge-curate/modelforge/curate/examples/basic_usage.ipynb
@@ -1457,7 +1457,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.8"
+   "version": "3.11.8"
   }
  },
  "nbformat": 4,
diff --git a/modelforge-curate/modelforge/curate/tests/test_curate.py b/modelforge-curate/modelforge/curate/tests/test_curate.py
index 1761b94b..a9fc29a8 100644
--- a/modelforge-curate/modelforge/curate/tests/test_curate.py
+++ b/modelforge-curate/modelforge/curate/tests/test_curate.py
@@ -7,9 +7,13 @@
 from modelforge.curate.properties import *


-def test_source_dataset_init():
-    new_dataset = SourceDataset("test_dataset")
-    assert new_dataset.dataset_name == "test_dataset"
+def test_source_dataset_init(prep_temp_dir):
+    new_dataset = SourceDataset(
+        dataset_name="test_dataset1",
+        local_db_dir=str(prep_temp_dir),
+        local_db_name="test_dataset1.sqlite",
+    )
+    assert new_dataset.dataset_name == "test_dataset1"

     new_dataset.create_record("mol1")
     assert "mol1" in new_dataset.records
@@ -19,10 +23,14 @@
     assert len(new_dataset.records) == 2


-def test_dataset_create_record():
+def test_dataset_create_record(prep_temp_dir):
     # test creating a record that already exists
     # this will fail
-    new_dataset = SourceDataset("test_dataset")
+    new_dataset = SourceDataset(
+        "test_dataset2",
+        local_db_dir=str(prep_temp_dir),
+        local_db_name="test_dataset2.sqlite",
+    )
     new_dataset.create_record("mol1")
     assert "mol1" in new_dataset.records
     with pytest.raises(ValueError):
@@ -57,7 +65,7 @@
     assert "mol4" in new_dataset.records


-def test_add_properties_to_records_directly():
+def test_add_properties_to_records_directly(prep_temp_dir):
     record = Record(name="mol1")

     positions = Positions(value=[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], units="nanometer")
@@ -93,7 +101,11 @@

     assert record.n_configs == 2

-    new_dataset = SourceDataset("test_dataset")
+    new_dataset = SourceDataset(
+        dataset_name="test_dataset3",
+        local_db_dir=str(prep_temp_dir),
+        local_db_name="test_dataset3.sqlite",
+    )
     new_dataset.add_record(record)

     assert "mol1" in new_dataset.records.keys()
@@ -113,7 +125,6 @@
     assert "n_configs: cannot be determined" in out

     record.add_properties([positions, energies, atomic_numbers, smiles])
-    print(record)

     out, err = capsys.readouterr()
     assert "name: mol1" in out
@@ -141,7 +152,6 @@
         " name='smiles' value='[CH]' units= classification='meta_data' property_type='meta_data' n_configs=None n_atoms=None"
         in out
     )
-    print(record)


 def test_record_to_dict():
@@ -257,8 +267,12 @@
         record.add_property(energies)


-def test_add_properties():
-    new_dataset = SourceDataset("test_dataset")
+def test_add_properties(prep_temp_dir):
+    new_dataset = SourceDataset(
+        "test_dataset4",
+        local_db_dir=str(prep_temp_dir),
+        local_db_name="test_dataset4.sqlite",
+    )
     new_dataset.create_record("mol1")
     positions = Positions(value=[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], units="nanometer")
     energies = Energies(value=np.array([[0.1]]), units=unit.hartree)
@@ -289,7 +303,7 @@
     new_dataset.add_properties("mol1", [atomic_numbers])


-def test_slicing_properties():
+def test_slicing_properties(prep_temp_dir):
     record = Record(name="mol1")

     positions = Positions(
@@ -316,7 +330,11 @@
     assert
sliced1.per_system["energies"].value == [[0.1]] # check dataset level slicing, that just calls the record level slicing - new_dataset = SourceDataset("test_dataset") + new_dataset = SourceDataset( + "test_dataset5", + local_db_dir=str(prep_temp_dir), + local_db_name="test_dataset5.sqlite", + ) new_dataset.add_record(record) sliced2 = new_dataset.slice_record("mol1", 0, 1) @@ -329,8 +347,12 @@ def test_slicing_properties(): new_dataset.slice_record(record, 0, 1) -def test_counting_records(): - new_dataset = SourceDataset("test_dataset") +def test_counting_records(prep_temp_dir): + new_dataset = SourceDataset( + "test_dataset6", + local_db_dir=str(prep_temp_dir), + local_db_name="test_dataset6.sqlite", + ) new_dataset.create_record("mol1") new_dataset.create_record("mol2") @@ -360,8 +382,13 @@ def test_counting_records(): new_dataset.validate_records() -def test_append_properties(): - new_dataset = SourceDataset("test_dataset", append_property=True) +def test_append_properties(prep_temp_dir): + new_dataset = SourceDataset( + "test_dataset7", + append_property=True, + local_db_dir=str(prep_temp_dir), + local_db_name="test_dataset7.sqlite", + ) new_dataset.create_record("mol1") positions = Positions(value=[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], units="nanometer") @@ -439,7 +466,11 @@ def test_append_properties(): def test_write_hdf5(prep_temp_dir): - new_dataset = SourceDataset("test_dataset") + new_dataset = SourceDataset( + "test_dataset8", + local_db_dir=str(prep_temp_dir), + local_db_name="test_dataset8.sqlite", + ) new_dataset.create_record("mol1") positions = Positions(value=[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], units="nanometer") energies = Energies(value=np.array([[0.1]]), units=unit.hartree) @@ -536,15 +567,23 @@ def test_write_hdf5(prep_temp_dir): with open(str(prep_temp_dir / "test_dataset.json"), "r") as f: data = json.load(f) - assert data["dataset_name"] == "test_dataset" + assert data["dataset_name"] == "test_dataset8" assert data["total_records"] == new_dataset.total_records() assert data["total_configurations"] == new_dataset.total_configs() assert data["md5_checksum"] == checksum assert data["filename"] == "test_dataset.hdf5" -def test_dataset_validation(): - new_dataset = SourceDataset("test_dataset") +def test_dataset_validation(prep_temp_dir): + new_dataset = SourceDataset( + dataset_name="test_dataset9", + local_db_dir=str(prep_temp_dir), + local_db_name="test_dataset9.sqlite", + ) + + assert new_dataset.local_db_dir == str(prep_temp_dir) + assert new_dataset.local_db_name == "test_dataset9.sqlite" + new_dataset.create_record("mol1") positions = Positions(value=[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], units="nanometer") energies = Energies(value=np.array([[0.1]]), units=unit.hartree) @@ -569,8 +608,12 @@ def test_dataset_validation(): assert new_dataset.validate_records() == True -def test_dataset_subsetting(): - ds = SourceDataset("test_dataset") +def test_dataset_subsetting(prep_temp_dir): + ds = SourceDataset(dataset_name="test dataset10", local_db_dir=str(prep_temp_dir)) + + assert ds.local_db_dir == str(prep_temp_dir) + assert ds.local_db_name == "test_dataset10.sqlite" + positions = Positions( value=[ [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], @@ -594,50 +637,68 @@ def test_dataset_subsetting(): assert ds.total_records() == 10 # check total records - ds_subset = ds.subset_dataset(total_records=5) + ds_subset = ds.subset_dataset(new_dataset_name="test_dataset_sub1", total_records=5) assert ds_subset.total_configs() == 25 assert ds_subset.total_records() == 5 - ds_subset = 
ds.subset_dataset(total_records=3) + assert ds_subset.dataset_name == "test_dataset_sub1" + assert ds_subset.local_db_name == "test_dataset_sub1.sqlite" + assert ds_subset.local_db_dir == ds.local_db_dir + + ds_subset = ds.subset_dataset(new_dataset_name="test_dataset_sub2", total_records=3) assert ds_subset.total_configs() == 15 assert ds_subset.total_records() == 3 # check total_records and max_configurations_per_record - ds_subset = ds.subset_dataset(total_records=3, max_configurations_per_record=2) + ds_subset = ds.subset_dataset( + new_dataset_name="test_dataset_sub3", + total_records=3, + max_configurations_per_record=2, + ) assert ds_subset.total_configs() == 6 assert ds_subset.total_records() == 3 # check total_conformers - ds_subset = ds.subset_dataset(total_configurations=20) + ds_subset = ds.subset_dataset( + new_dataset_name="test_dataset_sub4", total_configurations=20 + ) assert ds_subset.total_configs() == 20 assert ds_subset.total_records() == 4 ds_subset = ds.subset_dataset( - total_configurations=20, max_configurations_per_record=2 + new_dataset_name="test_dataset_sub5", + total_configurations=20, + max_configurations_per_record=2, ) assert ds_subset.total_configs() == 20 assert ds_subset.total_records() == 10 ds_subset = ds.subset_dataset( - total_configurations=20, max_configurations_per_record=6 + new_dataset_name="test_dataset_sub6", + total_configurations=20, + max_configurations_per_record=6, ) assert ds_subset.total_configs() == 20 assert ds_subset.total_records() == 4 ds_subset = ds.subset_dataset( - total_configurations=11, max_configurations_per_record=4 + new_dataset_name="test_dataset_sub7", + total_configurations=11, + max_configurations_per_record=4, ) assert ds_subset.total_configs() == 11 assert ds_subset.total_records() == 3 ds_subset = ds.subset_dataset( - total_configurations=11, max_configurations_per_record=5 + new_dataset_name="test_dataset_sub8", + total_configurations=11, + max_configurations_per_record=5, ) assert ds_subset.total_configs() == 11 assert ds_subset.total_records() == 3 -def test_limit_atomic_numbers(): +def test_limit_atomic_numbers(prep_temp_dir): atomic_numbers = AtomicNumbers( value=np.array( @@ -673,14 +734,17 @@ def test_limit_atomic_numbers(): record = Record("mol1") record.add_properties([atomic_numbers, energies, positions]) - dataset = SourceDataset("test") + dataset = SourceDataset( + dataset_name="test_dataset11", local_db_dir=str(prep_temp_dir) + ) dataset.add_record(record) atomic_numbers_to_limit = np.array([8, 6, 1]) assert record.contains_atomic_numbers(atomic_numbers_to_limit) == True new_dataset = dataset.subset_dataset( - atomic_numbers_to_limit=atomic_numbers_to_limit + new_dataset_name="test_dataset11_sub", + atomic_numbers_to_limit=atomic_numbers_to_limit, ) assert new_dataset.total_records() == 1 @@ -689,12 +753,20 @@ def test_limit_atomic_numbers(): assert record.contains_atomic_numbers(atomic_numbers_to_limit) == False new_dataset = dataset.subset_dataset( - atomic_numbers_to_limit=atomic_numbers_to_limit + new_dataset_name="test_dataset11_sub2", + atomic_numbers_to_limit=atomic_numbers_to_limit, ) assert new_dataset.total_records() == 0 + # test that we fail if we give the same name for a subset + with pytest.raises(ValueError): + dataset.subset_dataset( + new_dataset_name="test_dataset11", + atomic_numbers_to_limit=atomic_numbers_to_limit, + ) + -def test_remove_high_force_configs(): +def test_remove_high_force_configs(prep_temp_dir): atomic_numbers = AtomicNumbers( value=np.array( @@ -791,11 +863,14 @@ def 
test_remove_high_force_configs(): # test filtering via the dataset - dataset = SourceDataset("test") + dataset = SourceDataset( + dataset_name="test_dataset12", local_db_dir=str(prep_temp_dir) + ) dataset.add_record(record) # effectively the same tests as above, but done on the dataset level trimmed_dataset = dataset.subset_dataset( + new_dataset_name="test_dataset12_sub1", total_configurations=5, max_force=30 * unit.kilojoule_per_mole / unit.nanometer, ) @@ -807,7 +882,9 @@ def test_remove_high_force_configs(): ) trimmed_dataset = dataset.subset_dataset( - total_configurations=5, max_force=31 * unit.kilojoule_per_mole / unit.nanometer + new_dataset_name="test_dataset12_sub2", + total_configurations=5, + max_force=31 * unit.kilojoule_per_mole / unit.nanometer, ) assert trimmed_dataset.total_records() == 1 assert trimmed_dataset.get_record("mol1").n_configs == 4 @@ -819,7 +896,9 @@ def test_remove_high_force_configs(): # consider now including other restrictions on the number of configurations, records, etc. # limit the number of configurations in total trimmed_dataset = dataset.subset_dataset( - total_configurations=3, max_force=31 * unit.kilojoule_per_mole / unit.nanometer + new_dataset_name="test_dataset12_sub3", + total_configurations=3, + max_force=31 * unit.kilojoule_per_mole / unit.nanometer, ) assert trimmed_dataset.total_records() == 1 assert trimmed_dataset.get_record("mol1").n_configs == 3 @@ -830,6 +909,7 @@ def test_remove_high_force_configs(): # total_configurations and max_configurations_per_record trimmed_dataset = dataset.subset_dataset( + new_dataset_name="test_dataset12_sub4", total_configurations=5, max_configurations_per_record=3, max_force=31 * unit.kilojoule_per_mole / unit.nanometer, @@ -852,7 +932,9 @@ def test_remove_high_force_configs(): # limit total_configurations trimmed_dataset = dataset.subset_dataset( - total_configurations=5, max_force=30 * unit.kilojoule_per_mole / unit.nanometer + new_dataset_name="test_dataset12_sub5", + total_configurations=5, + max_force=30 * unit.kilojoule_per_mole / unit.nanometer, ) assert trimmed_dataset.total_records() == 2 assert trimmed_dataset.get_record("mol1").n_configs == 3 @@ -861,6 +943,7 @@ def test_remove_high_force_configs(): # limit total_configurations and max_configurations_per_record trimmed_dataset = dataset.subset_dataset( + new_dataset_name="test_dataset12_sub6", total_configurations=6, max_configurations_per_record=3, max_force=31 * unit.kilojoule_per_mole / unit.nanometer, @@ -872,6 +955,7 @@ def test_remove_high_force_configs(): # Add in limiting of the atomic numbers trimmed_dataset = dataset.subset_dataset( + new_dataset_name="test_dataset12_sub7", total_configurations=6, max_configurations_per_record=3, max_force=30 * unit.kilojoule_per_mole / unit.nanometer, @@ -884,6 +968,7 @@ def test_remove_high_force_configs(): # same test but only restrictions on atomic_numbers and max_force trimmed_dataset = dataset.subset_dataset( + new_dataset_name="test_dataset12_sub8", max_force=30 * unit.kilojoule_per_mole / unit.nanometer, atomic_numbers_to_limit=[6, 1], ) @@ -894,7 +979,9 @@ def test_remove_high_force_configs(): # check toggling of total records trimmed_dataset = dataset.subset_dataset( - total_records=1, max_force=30 * unit.kilojoule_per_mole / unit.nanometer + new_dataset_name="test_dataset12_sub9", + total_records=1, + max_force=30 * unit.kilojoule_per_mole / unit.nanometer, ) assert trimmed_dataset.total_records() == 1 assert trimmed_dataset.get_record("mol1").n_configs == 3 @@ -902,7 +989,9 @@ def 
test_remove_high_force_configs(): # check toggling of total records trimmed_dataset = dataset.subset_dataset( - total_records=2, max_force=30 * unit.kilojoule_per_mole / unit.nanometer + new_dataset_name="test_dataset12_sub10", + total_records=2, + max_force=30 * unit.kilojoule_per_mole / unit.nanometer, ) assert trimmed_dataset.total_records() == 2 assert trimmed_dataset.get_record("mol1").n_configs == 3 @@ -911,6 +1000,7 @@ def test_remove_high_force_configs(): # make sure we can also exclude atomic numbers trimmed_dataset = dataset.subset_dataset( + new_dataset_name="test_dataset12_sub11", total_records=2, max_force=30 * unit.kilojoule_per_mole / unit.nanometer, atomic_numbers_to_limit=[6, 1], @@ -921,6 +1011,7 @@ def test_remove_high_force_configs(): # case where our atomic number filtering captures everything trimmed_dataset = dataset.subset_dataset( + new_dataset_name="test_dataset12_sub12", total_records=2, max_force=30 * unit.kilojoule_per_mole / unit.nanometer, atomic_numbers_to_limit=[6, 1, 8], @@ -929,3 +1020,34 @@ def test_remove_high_force_configs(): assert trimmed_dataset.get_record("mol1").n_configs == 3 assert trimmed_dataset.get_record("mol2").n_configs == 3 assert trimmed_dataset.total_configs() == 6 + + +def test_reading_from_db_file(prep_temp_dir): + record = Record(name="mol1") + + positions = Positions(value=[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], units="nanometer") + energies = Energies(value=np.array([[0.1]]), units=unit.hartree) + atomic_numbers = AtomicNumbers(value=np.array([[1], [6]])) + smiles = MetaData(name="smiles", value="[CH]") + + record.add_properties([positions, energies, atomic_numbers, smiles]) + + ds = SourceDataset( + dataset_name="test_dataset14", + local_db_dir=str(prep_temp_dir), + local_db_name="test_dataset14.sqlite", + ) + ds.add_record(record) + + assert ds.total_records() == 1 + assert "mol1" in ds.records.keys() + + # now let us read from the file + ds2 = SourceDataset( + dataset_name="test_dataset15", + local_db_dir=str(prep_temp_dir), + local_db_name="test_dataset14.sqlite", + read_from_local_db=True, + ) + assert ds2.total_records() == 1 + assert "mol1" in ds2.records.keys()
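
The following standalone sketch (not part of the patch) illustrates how the local-database-backed `SourceDataset` workflow exercised by these tests is intended to be used: records are cached in an on-disk SqliteDict, an existing cache can be reopened with `read_from_local_db=True`, and `subset_dataset` now requires a distinct `new_dataset_name`. The directory and dataset names below are placeholders, and the `openff.units` import is an assumption about where the `unit` object used in the tests comes from.

import os

import numpy as np
from openff.units import unit  # assumed source of the `unit` object used in the tests

from modelforge.curate.curate import Record, SourceDataset
from modelforge.curate.properties import AtomicNumbers, Energies, Positions

# the sqlite cache is written to <local_db_dir>/<local_db_name>; the directory
# must already exist (the tests rely on a pytest tmp-dir fixture for this)
os.makedirs("./example_db", exist_ok=True)

ds = SourceDataset(
    dataset_name="example_dataset",
    local_db_dir="./example_db",             # placeholder directory
    local_db_name="example_dataset.sqlite",  # defaults to "<dataset_name>.sqlite" if omitted
)

# build a record as in the tests above and push it into the on-disk cache
record = Record(name="mol1")
record.add_properties(
    [
        Positions(value=[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], units="nanometer"),
        Energies(value=np.array([[0.1]]), units=unit.hartree),
        AtomicNumbers(value=np.array([[1], [6]])),
    ]
)
ds.add_record(record)

# reopen the same cache later without re-curating; the record keys are read back
ds_cached = SourceDataset(
    dataset_name="example_dataset_reloaded",
    local_db_dir="./example_db",
    local_db_name="example_dataset.sqlite",
    read_from_local_db=True,
)
assert "mol1" in ds_cached.records

# subsetting now requires a new dataset name, since each dataset owns its own db file
ds_subset = ds.subset_dataset(new_dataset_name="example_dataset_subset", total_records=1)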