Commit

Revamped SourceDataset to store records in an SQLite database rather than in Python memory.
chrisiacovella committed Feb 27, 2025
1 parent e333971 commit 218adc1
Showing 14 changed files with 732 additions and 450 deletions.
543 changes: 339 additions & 204 deletions modelforge-curate/modelforge/curate/curate.py

Large diffs are not rendered by default.
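The revamped curate.py is not rendered above, but the idea named in the commit message — keeping curation records in an SQLite file instead of an in-memory Python structure — can be sketched roughly as below. This is an illustrative toy, not the actual modelforge implementation; the table layout, column names, and JSON serialization are all assumptions.

import json
import sqlite3


class RecordStore:
    """Toy illustration of backing a record collection with SQLite instead of a dict."""

    def __init__(self, db_path: str):
        self.conn = sqlite3.connect(db_path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS records (name TEXT PRIMARY KEY, payload TEXT)"
        )

    def add_record(self, name: str, record: dict) -> None:
        # serialize and upsert; the record no longer has to stay resident in memory
        self.conn.execute(
            "INSERT OR REPLACE INTO records (name, payload) VALUES (?, ?)",
            (name, json.dumps(record)),
        )
        self.conn.commit()

    def get_record(self, name: str):
        row = self.conn.execute(
            "SELECT payload FROM records WHERE name = ?", (name,)
        ).fetchone()
        return json.loads(row[0]) if row else None

    def total_records(self) -> int:
        return self.conn.execute("SELECT COUNT(*) FROM records").fetchone()[0]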

@@ -99,7 +99,7 @@ def _process_downloaded(

conformers_counter = 0

dataset = SourceDataset("ani2x")
dataset = SourceDataset(dataset_name="ani2x", local_db_dir=self.local_cache_dir)
with h5py.File(input_file_name, "r") as hf:
# The ani2x hdf5 file groups molecules by number of atoms
# we need to break up each of these groups into individual molecules

@@ -324,7 +324,19 @@ def to_hdf5(
or max_force is not None
or total_records is not None
):
import time
import random

# generate a 5 digit random number to append to the dataset name
# using current time as the seed
# this database will be removed after the dataset is written
random.seed(time.time())
number = random.randint(10000, 99999)

new_dataset_name = f"{self.dataset.dataset_name}_temp_{number}"

dataset_trimmed = self.dataset.subset_dataset(
new_dataset_name=new_dataset_name,
total_configurations=total_configurations,
total_records=total_records,
max_configurations_per_record=max_configurations_per_record,
@@ -339,7 +351,13 @@
file_name=hdf5_file_name,
file_path=output_file_dir,
)
return (dataset_trimmed.total_records(), dataset_trimmed.total_configs())
n_total_records = dataset_trimmed.total_records()
n_total_configs = dataset_trimmed.total_configs()

# remove the database associated with the temporarily created dataset
dataset_trimmed._remove_local_db()

return (n_total_records, n_total_configs)
else:

self._write_hdf5_and_json_files(
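The practical effect of the block above: a filtered export now builds a throwaway subset dataset backed by its own temporary SQLite file (named with a random 5-digit suffix), writes the HDF5 file from that subset, and deletes the temporary database before returning the counts. A usage sketch, using only parameters that appear in this diff; the file name, output directory, and the pre-built spice1_dataset object are placeholders:

# spice1_dataset stands in for an already-processed curation object
# (see the SPICE1 script changes further down in this commit)
n_records, n_configs = spice1_dataset.to_hdf5(
    hdf5_file_name="spice_1_subset.hdf5",  # placeholder name
    output_file_dir="/tmp/hdf5_files",  # placeholder directory
    total_configurations=1000,
    max_configurations_per_record=10,
)
# the temporary "<dataset_name>_temp_<NNNNN>" database created by subset_dataset
# has already been removed via _remove_local_db() by the time this returns
print(n_records, n_configs)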

@@ -219,7 +219,11 @@ def _process_downloaded(

from numpy import newaxis

dataset = SourceDataset("PhAlkEthOH_openff", append_property=True)
dataset = SourceDataset(
dataset_name="PhAlkEthOH_openff",
append_property=True,
local_db_dir=self.local_cache_dir,
)

for filename, dataset_name in zip(filenames, dataset_names):
input_file_name = f"{local_path_dir}/{filename}"

@@ -404,7 +404,7 @@ def _process_downloaded(

# we do not need to check unit_testing_max_conformers_per_record because qm9 only has a single conformer per record

dataset = SourceDataset("qm9")
dataset = SourceDataset(dataset_name="qm9", local_db_dir=self.local_cache_dir)
for i, file in enumerate(tqdm(files, desc="processing", total=len(files))):
record_temp = self._parse_xyzfile(f"{local_path_dir}/{file}")
dataset.add_record(record_temp)
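The ani2x, PhAlkEthOH, and qm9 scripts above all switch to the new keyword-only constructor, passing local_db_dir so each dataset's SQLite file lands in the script's cache directory. A minimal sketch of the pattern; the import path and cache directory are assumptions, and the record object is treated as opaque since its construction is not part of this diff:

from modelforge.curate.curate import SourceDataset  # import path assumed from the repo layout

# records now live in an SQLite file under local_db_dir rather than in Python memory
dataset = SourceDataset(dataset_name="qm9", local_db_dir="/tmp/qm9_cache")  # placeholder directory

# record_temp would come from a parser such as _parse_xyzfile in the qm9 script
# dataset.add_record(record_temp)

print(dataset.total_records(), dataset.total_configs())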
143 changes: 47 additions & 96 deletions modelforge-curate/modelforge/curate/datasets/scripts/curate_spice1.py
@@ -23,73 +23,12 @@
"""


def spice1_wrapper(
hdf5_file_name: str,
output_file_dir: str,
local_cache_dir: str,
force_download: bool = False,
version_select: str = "latest",
max_records=None,
max_conformers_per_record=None,
total_conformers=None,
limit_atomic_species=None,
):
"""
This curates and processes the SPICE2 dataset into an hdf5 file.
Parameters
----------
hdf5_file_name: str, required
Name of the hdf5 file that will be generated.
output_file_dir: str, required
Directory where the hdf5 file will be saved.
local_cache_dir: str, required
Directory where the intermediate data will be saved; in this case it will be tarred file downloaded
from figshare and the expanded archive that contains xyz files for each molecule in the dataset.
force_download: bool, optional, default=False
If False, we will use the tarred file that exists in the local_cache_dir (if it exists);
If True, the tarred file will be downloaded, even if it exists locally.
version_select: str, optional, default="latest"
The version of the dataset to use as defined in the associated yaml file.
If "latest", the most recent version will be used.
max_records: int, optional, default=None
The maximum number of records to process.
max_conformers_per_record: int, optional, default=None
The maximum number of conformers to process for each record.
total_conformers: int, optional, default=None
The total number of conformers to process.
limit_atomic_species: list, optional, default=None
A list of atomic species to limit the dataset to. Any molecules that contain elements outside of this list
will be ignored. If not defined, no filtering by atomic species will be performed.
"""
from modelforge.curate.datasets.spice_1_curation import SPICE1Curation

spice_1_data = SPICE1Curation(
hdf5_file_name=hdf5_file_name,
output_file_dir=output_file_dir,
local_cache_dir=local_cache_dir,
version_select=version_select,
)

spice_1_data.process(
force_download=force_download,
max_records=max_records,
max_conformers_per_record=max_conformers_per_record,
total_conformers=total_conformers,
limit_atomic_species=limit_atomic_species,
)
print(f"Total records: {spice_1_data.total_records()}")
print(f"Total configs: {spice_1_data.total_configs()}")


def main():
# define the location where to store and output the files
import os

local_prefix = os.path.expanduser("~/mf_datasets")
output_file_dir = f"{local_prefix}/hdf5_files"
output_file_dir = f"{local_prefix}/hdf5_files/spice1"
local_cache_dir = f"{local_prefix}/spice1_dataset"

# We'll want to provide some simple means of versioning
@@ -99,58 +38,70 @@ def main():
version_select = f"v_0"

# version v_0 corresponds to SPICE 1.1.4
# start with processing the full dataset
from modelforge.curate.datasets.spice_1_curation import SPICE1Curation

spice1_dataset = SPICE1Curation(
local_cache_dir=local_cache_dir,
version_select=version_select,
)

spice1_dataset.process(force_download=False)

ani2x_elements = ["H", "C", "N", "O", "F", "Cl", "S"]

# curate SPICE 2.0.1 dataset with 1000 total conformers, max of 10 conformers per record
# curate SPICE 1.1.4 dataset with 1000 total configurations, max of 10 configurations per record
# limited to the elements that will work with ANI2x
hdf5_file_name = f"spice_1_dataset_v{version}_ntc_1000_HCNOFClS.hdf5"

spice1_wrapper(
hdf5_file_name,
output_file_dir,
local_cache_dir,
force_download=False,
version_select=version_select,
max_conformers_per_record=10,
total_conformers=1000,
limit_atomic_species=ani2x_elements,
spice1_dataset.to_hdf5(
hdf5_file_name=hdf5_file_name,
output_file_dir=output_file_dir,
total_configurations=1000,
max_configurations_per_record=10,
atomic_species_to_limit=ani2x_elements,
)
# curate the full SPICE 2.0.1 dataset, limited to the elements that will work with ANI2x

print("SPICE1: 1000 configuration subset limited to ANI2x elements")
print(f"Total records: {spice1_dataset.total_records()}")
print(f"Total configs: {spice1_dataset.total_configs()}")

# curate the full SPICE 1.1.4 dataset, limited to the elements that will work with ANI2x
hdf5_file_name = f"spice_1_dataset_v{version}_HCNOFClS.hdf5"

spice1_wrapper(
hdf5_file_name,
output_file_dir,
local_cache_dir,
force_download=False,
version_select=version_select,
limit_atomic_species=ani2x_elements,
spice1_dataset.to_hdf5(
hdf5_file_name=hdf5_file_name,
output_file_dir=output_file_dir,
atomic_species_to_limit=ani2x_elements,
)

# curate the test SPICE 2.0.1 dataset with 1000 total conformers, max of 10 conformers per record
print("SPICE1: full dataset limited to ANI2x elements")
print(f"Total records: {spice1_dataset.total_records()}")
print(f"Total configs: {spice1_dataset.total_configs()}")

# curate the test SPICE 1.1.4 dataset with 1000 total configurations, max of 10 configurations per record
hdf5_file_name = f"spice_1_dataset_v{version}_ntc_1000.hdf5"

spice1_wrapper(
hdf5_file_name,
output_file_dir,
local_cache_dir,
force_download=False,
version_select=version_select,
max_conformers_per_record=10,
total_conformers=1000,
spice1_dataset.to_hdf5(
hdf5_file_name=hdf5_file_name,
output_file_dir=output_file_dir,
total_configurations=1000,
max_configurations_per_record=10,
)

# curate the full SPICE 2.0.1 dataset
print("SPICE1: 1000 configuration subset")
print(f"Total records: {spice1_dataset.total_records()}")
print(f"Total configs: {spice1_dataset.total_configs()}")

# curate the full SPICE 1.1.4 dataset
hdf5_file_name = f"spice_1_dataset_v{version}.hdf5"

spice1_wrapper(
hdf5_file_name,
output_file_dir,
local_cache_dir,
force_download=False,
version_select=version_select,
spice1_dataset.to_hdf5(
hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir
)
print("SPICE1: full dataset")
print(f"Total records: {spice1_dataset.total_records()}")
print(f"Total configs: {spice1_dataset.total_configs()}")


if __name__ == "__main__":