
Commit

Updating tests and scripts to handle db backend. Small revamps to speed up curation
chrisiacovella committed Feb 27, 2025
1 parent 218adc1 commit 664080c
Showing 18 changed files with 379 additions and 389 deletions.
237 changes: 159 additions & 78 deletions modelforge-curate/modelforge/curate/curate.py

Large diffs are not rendered by default.

@@ -99,7 +99,9 @@ def _process_downloaded(

conformers_counter = 0

dataset = SourceDataset(dataset_name="ani2x", local_db_dir=self.local_cache_dir)
dataset = SourceDataset(
dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir
)
with h5py.File(input_file_name, "r") as hf:
# The ani2x hdf5 file groups molecules by number of atoms
# we need to break up each of these groups into individual molecules
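An aside that is not part of the diff: the only substantive change in this ani2x hunk is that the local database name now comes from the curation instance rather than the hard-coded string "ani2x". A minimal sketch of the assumed pattern follows; the import path and the stand-in class are guesses, and only the SourceDataset keyword arguments are taken from the hunk above.

# --- editor's sketch, not part of the commit ---
# Assumed import path; SourceDataset keyword arguments are copied from the hunk above.
from modelforge.curate.curate import SourceDataset

class SomeDatasetCuration:  # hypothetical stand-in for the real curation class
    def __init__(self, dataset_name: str, local_cache_dir: str = "./datasets_cache"):
        self.dataset_name = dataset_name
        self.local_cache_dir = local_cache_dir

    def _process_downloaded(self, input_file_name: str):
        # the local database is now keyed by the instance's dataset_name
        dataset = SourceDataset(
            dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir
        )
        return dataset
# --- end sketch ---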
26 changes: 26 additions & 0 deletions modelforge-curate/modelforge/curate/datasets/curation_baseclass.py
@@ -22,6 +22,7 @@ class DatasetCuration(ABC):

def __init__(
self,
dataset_name: str,
local_cache_dir: Optional[str] = "./datasets_cache",
version_select: str = "latest",
):
@@ -30,6 +31,8 @@ def __init__(
Parameters
----------
dataset_name: str, required
Name of the dataset to curate.
local_cache_dir: str, optional, default='./qm9_datafiles'
Location to save downloaded dataset.
version_select: str, optional, default='latest'
@@ -41,6 +44,8 @@
# make sure we can handle a path with a ~ in it
self.local_cache_dir = os.path.expanduser(local_cache_dir)
self.version_select = version_select
self.dataset_name = dataset_name

os.makedirs(self.local_cache_dir, exist_ok=True)

# initialize parameter information
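Not part of the commit: with dataset_name now a required constructor argument, a concrete curation class would presumably forward it through super().__init__. A hedged sketch, with the subclass name and defaults invented; only the base-class signature is taken from the hunk above.

# --- editor's sketch, not part of the commit ---
from modelforge.curate.datasets.curation_baseclass import DatasetCuration

class MyDatasetCuration(DatasetCuration):
    def __init__(
        self,
        local_cache_dir: str = "./datasets_cache",
        version_select: str = "latest",
    ):
        # dataset_name is fixed per curation class and handed to the base __init__
        super().__init__(
            dataset_name="my_dataset",
            local_cache_dir=local_cache_dir,
            version_select=version_select,
        )
# --- end sketch ---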
@@ -240,6 +245,7 @@ def to_hdf5(
total_configurations: Optional[int] = None,
atomic_species_to_limit: Optional[List[Union[str, int]]] = None,
max_force: Optional[unit.Quantity] = None,
final_configuration_only: Optional[bool] = False,
) -> Tuple[int, int]:
"""
Writes the dataset to an hdf5 file.
@@ -265,6 +271,8 @@ def to_hdf5(
These can be passed as a list of strings, e.g., ['C', 'H', 'O'] or as a list of atomic numbers, e.g., [6, 1, 8].
max_force: unit.Quantity, optional, default=None
Maximum force to include in the dataset. Any configuration with forces greater than this value will be excluded.
final_configuration_only: bool, optional, default=False
If True, only the final configuration of each record will be included in the dataset.
Returns
-------
@@ -323,6 +331,7 @@
or atomic_species_to_limit is not None
or max_force is not None
or total_records is not None
or final_configuration_only
):
import time
import random
@@ -342,6 +351,7 @@
max_configurations_per_record=max_configurations_per_record,
atomic_numbers_to_limit=atomic_numbers_to_limit,
max_force=max_force,
final_configuration_only=final_configuration_only,
)
if dataset_trimmed.total_records() == 0:
raise ValueError("No records found in the dataset after filtering.")
@@ -366,3 +376,19 @@
file_path=output_file_dir,
)
return (self.dataset.total_records(), self.dataset.total_configs())

def load_from_db(self, local_db_dir: str, local_db_name: str):
"""
Load the dataset from a local database.
Parameters
----------
local_db_name: str, required
Name of the local database to load the dataset from.
"""
self.dataset = SourceDataset(
dataset_name=self.dataset_name,
local_db_dir=local_db_dir,
local_db_name=local_db_name,
read_from_db=True,
)
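Also not part of the diff: a hedged sketch of how the two additions to this file might be used together, reloading a previously written database with load_from_db and then exporting only the final configuration of each record. The load_from_db keywords and final_configuration_only come from the signatures above; the file names, directory names, and the hdf5_file_name/output_file_dir keyword names are assumptions.

# --- editor's sketch, not part of the commit; paths and some keyword names are guesses ---
curation = MyDatasetCuration()  # hypothetical subclass, as in the sketch above

# reuse a database written by an earlier curation run instead of re-processing the source data
curation.load_from_db(
    local_db_dir="./datasets_cache",
    local_db_name="my_dataset.sqlite",
)

# final_configuration_only=True now joins the other filters in triggering the
# trimming path, so only the last configuration of each record is written out
n_records, n_configs = curation.to_hdf5(
    hdf5_file_name="my_dataset_final_only.hdf5",  # keyword name assumed
    output_file_dir="./hdf5_output",              # keyword name assumed
    final_configuration_only=True,
)
print(n_records, n_configs)
# --- end sketch ---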
59 changes: 40 additions & 19 deletions modelforge-curate/modelforge/curate/datasets/phalkethoh_curation.py
@@ -163,6 +163,9 @@ def _fetch_singlepoint_from_qcarchive(
if pbar is not None:
pbar.update(1)

from functools import lru_cache

@lru_cache(maxsize=None)
def _calculate_total_charge(
self, smiles: str
) -> Tuple[unit.Quantity, unit.Quantity]:
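A note that is not in the diff: caching _calculate_total_charge means the relatively expensive SMILES-to-charge lookup runs once per molecule rather than once per trajectory frame, since every state of a trajectory shares the same SMILES. The sketch below shows the same functools.lru_cache pattern on a plain function with a stubbed body; the real charge calculation is not shown in this hunk, so the body here is a deliberately naive placeholder. One general caveat: on an instance method, lru_cache keys the cache on (self, smiles) and keeps a reference to the instance alive for the lifetime of the cache.

# --- editor's sketch, not part of the commit ---
from functools import lru_cache

@lru_cache(maxsize=None)
def total_charge_from_smiles(smiles: str) -> int:
    # stand-in for the real toolkit call that derives the net molecular charge
    print(f"computing charge for {smiles}")
    return smiles.count("+") - smiles.count("-")  # naive placeholder, not chemically meaningful

total_charge_from_smiles("[H:1][C:2]([H:3])([H:4])[H:5]")  # computed and cached
total_charge_from_smiles("[H:1][C:2]([H:3])([H:4])[H:5]")  # cache hit, no recomputation
# --- end sketch ---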
@@ -220,7 +223,7 @@ def _process_downloaded(
from numpy import newaxis

dataset = SourceDataset(
dataset_name="PhAlkEthOH_openff",
dataset_name=self.dataset_name,
append_property=True,
local_db_dir=self.local_cache_dir,
)
@@ -258,6 +261,7 @@
# these properties only need to be added once
# so we need to check if
if not name in dataset.records.keys():
dataset.create_record(name)
source = MetaData(
name="source", value=input_file_name.replace(".sqlite", "")
)
@@ -299,32 +303,29 @@
# name = key.split("-")[0]
trajectory = spice_db[key][1]

name = f'{key[: key.rfind("-")]}_{trajectory[0][0]["molecule_"]["name"]}'
record = dataset.get_record(name)

for state in trajectory:
add_record = True
properties, config = state
name = (
f'{key[: key.rfind("-")]}_{properties["molecule_"]["name"]}'
)
smiles = (
dataset.records[name]
.meta_data[
"canonical_isomeric_explicit_hydrogen_mapped_smiles"
]
.value
)
# record = dataset.get_record(name)
smiles = record.meta_data[
"canonical_isomeric_explicit_hydrogen_mapped_smiles"
].value

total_charge_temp = self._calculate_total_charge(smiles)

total_charge = TotalCharge(
value=np.array(total_charge_temp.m).reshape(1, 1),
units=total_charge_temp.u,
)
dataset.add_property(name, total_charge)

positions = Positions(
value=config.reshape(1, -1, 3), units=unit.bohr
)
dataset.add_property(name, positions)

# Note need to typecast here because of a bug in the
# qcarchive entry: see issue: https://github.com/MolSSI/QCFractal/issues/766
@@ -339,7 +340,6 @@
).reshape(1, 1),
units=unit.hartree,
)
dataset.add_property(name, dispersion_correction_energy)

dft_total_energy = Energies(
name="dft_total_energy",
Expand All @@ -349,7 +349,6 @@ def _process_downloaded(
+ dispersion_correction_energy.value,
units=unit.hartree,
)
dataset.add_property(name, dft_total_energy)

dispersion_correction_gradient = Forces(
name="dispersion_correction_gradient",
Expand All @@ -360,14 +359,12 @@ def _process_downloaded(
).reshape(1, -1, 3),
units=unit.hartree / unit.bohr,
)
dataset.add_property(name, dispersion_correction_gradient)

dispersion_correction_force = Forces(
name="dispersion_correction_force",
value=-dispersion_correction_gradient.value,
units=unit.hartree / unit.bohr,
)
dataset.add_property(name, dispersion_correction_force)

dft_total_gradient = Forces(
name="dft_total_gradient",
Expand All @@ -377,14 +374,12 @@ def _process_downloaded(
+ dispersion_correction_gradient.value,
units=unit.hartree / unit.bohr,
)
dataset.add_property(name, dft_total_gradient)

dft_total_force = Forces(
name="dft_total_force",
value=-dft_total_gradient.value,
units=unit.hartree / unit.bohr,
)
dataset.add_property(name, dft_total_force)

scf_dipole = DipoleMomentPerSystem(
name="scf_dipole",
Expand All @@ -393,8 +388,34 @@ def _process_downloaded(
).reshape(1, 3),
units=unit.elementary_charge * unit.bohr,
)
dataset.add_property(name, scf_dipole)

record.add_properties(
[
total_charge,
positions,
dispersion_correction_energy,
dft_total_energy,
dispersion_correction_gradient,
dispersion_correction_force,
dft_total_gradient,
dft_total_force,
scf_dipole,
],
)
# dataset.add_properties(
# name,
# [
# total_charge,
# positions,
# dispersion_correction_energy,
# dft_total_energy,
# dispersion_correction_gradient,
# dispersion_correction_force,
# dft_total_gradient,
# dft_total_force,
# scf_dipole,
# ],
# )
dataset.update_record(record)
return dataset

def process(
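Finally, also not part of the diff: the net effect of this PhAlkEthOH hunk is that properties are no longer pushed one at a time with dataset.add_property(name, ...); instead the record is fetched once per molecule, all properties for a state are attached in a single record.add_properties([...]) call, and the record is synced back with dataset.update_record(record). A compressed sketch of that pattern follows, with the property construction abstracted behind a placeholder callable and the exact placement of update_record relative to the loops left as an assumption.

# --- editor's sketch, not part of the commit; property construction elided ---
def attach_state_properties(dataset, name, trajectory, build_properties):
    """Hedged illustration of the batched add_properties pattern used above."""
    record = dataset.get_record(name)  # the record was created earlier, once per molecule
    for state in trajectory:
        properties, config = state
        # build_properties stands in for the TotalCharge/Positions/Energies/Forces/
        # DipoleMomentPerSystem construction shown in the real diff
        record.add_properties(build_properties(properties, config))
    # write the accumulated properties back in one call; the placement of update_record
    # relative to the surrounding loops is inferred, not shown in the rendered diff
    dataset.update_record(record)
# --- end sketch ---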
4 changes: 3 additions & 1 deletion modelforge-curate/modelforge/curate/datasets/qm9_curation.py
@@ -404,7 +404,9 @@ def _process_downloaded(

# we do not need to do anything check unit_testing_max_conformers_per_record because qm9 only has a single conformer per record

dataset = SourceDataset(dataset_name="qm9", local_db_dir=self.local_cache_dir)
dataset = SourceDataset(
dataset_name=self.dataset_name, local_db_dir=self.local_cache_dir
)
for i, file in enumerate(tqdm(files, desc="processing", total=len(files))):
record_temp = self._parse_xyzfile(f"{local_path_dir}/{file}")
dataset.add_record(record_temp)
