diff --git a/tdc/single_pred/mpc.py b/tdc/single_pred/mpc.py index 9891537c..d52c447a 100644 --- a/tdc/single_pred/mpc.py +++ b/tdc/single_pred/mpc.py @@ -12,7 +12,7 @@ warnings.filterwarnings("ignore") from . import single_pred_dataset -from ..utils import print_sys, fuzzy_search, property_dataset_load +from ..utils import create_fold_setting_cold, create_scaffold_split from ..metadata import dataset_names @@ -21,6 +21,7 @@ class MPC(single_pred_dataset.DataLoader): def __init__(self, name, path="./data"): self.name = name self.data = None + self.is_molace = False def get_from_gh(self, link): import pandas as pd @@ -37,12 +38,13 @@ def get_from_gh(self, link): self.data = pd.read_csv(io.StringIO(data.text)) return self.data - def get_data(self, link=None, get_from_gh=True): + def get_data(self, link=None, get_from_gh=True, **kwargs): link = link if link is not None else self.name if get_from_gh: return self.get_from_gh(link) # support direct interfface with MoleculeACE API as well from MoleculeACE import Data, Descriptors + self.is_molace = True try: self.data = Data(self.name) self.data(Descriptors.SMILES) @@ -52,8 +54,21 @@ def get_data(self, link=None, get_from_gh=True): .format(self.name)) return self.data - def get_split(self): + def get_split(self, method="scaffold", seed=42, frac=[0.7, 0.1, 0.2]): d = self.get_data() + if not self.is_molace: + if method == "scaffold": + return create_scaffold_split(d, + seed=seed, + frac=frac, + entity="SMILES") + elif method == "cold": + return create_fold_setting_cold(d, + seed=seed, + frac=frac, + entities="SMILES") + raise Exception( + "only scaffold or cold splits supported for the MPC task") train = pd.concat([d.x_train, d.y_train], axis=1) test = pd.concat([d.x_test, d.y_test], axis=1) return { diff --git a/tdc/test/test_dataloaders.py b/tdc/test/test_dataloaders.py index 9080c5c5..4ce470e7 100644 --- a/tdc/test/test_dataloaders.py +++ b/tdc/test/test_dataloaders.py @@ -126,7 +126,10 @@ def 
test_resource_dataverse_dataloader_raw_splits(self): def test_mpc(self): from tdc.single_pred.mpc import MPC - Xs = MPC(name = "https://raw.githubusercontent.com/bidd-group/MPCD/main/dataset/ADMET/DeepDelta_benchmark/Caco2.csv") + Xs = MPC( + name= + "https://raw.githubusercontent.com/bidd-group/MPCD/main/dataset/ADMET/DeepDelta_benchmark/Caco2.csv" + ) Xs_split = Xs.get_split() Xs_train = Xs_split["train"] Xs_test = Xs_split["test"]