diff --git a/test/data/test_molecule.py b/test/data/test_molecule.py index b1908e6a..957e7710 100644 --- a/test/data/test_molecule.py +++ b/test/data/test_molecule.py @@ -36,6 +36,36 @@ def test_feature(self): mol = data.Molecule.from_smiles(self.smiles, graph_feature="ecfp") self.assertTrue((mol.graph_feature > 0).any(), "Incorrect ECFP feature") + def test_feature_with_kwargs(self): + from torchdrug.core import Registry as R + @R.register('features.atom.my_features') + def my_features(atom, i, j): + return [i, j] + + @R.register('features.bond.my_features') + def my_features(bond, i, j): + return [i, j] + + @R.register('features.molecule.my_features') + def my_features(mol, i, j): + return [i, j] + + expected_node_features = torch.tensor([1,2]).repeat((6,1)) + expected_edge_features = torch.tensor([1, 2]).repeat((12, 1)) + expected_graph_features = torch.tensor([1, 2]) + + m = data.Molecule.from_smiles("C1=CC=CC=C1", + node_feature="my_features", + node_feature_kwargs=dict(i=1, j=2), + edge_feature="my_features", + edge_feature_kwargs=dict(i=1, j=2), + graph_feature="my_features", + graph_feature_kwargs=dict(i=1, j=2)) + + assert (m.node_feature == expected_node_features).all() + assert (m.edge_feature == expected_edge_features).all() + assert (m.graph_feature == expected_graph_features).all() + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/torchdrug/data/feature.py b/torchdrug/data/feature.py index dca97701..1ef356b1 100644 --- a/torchdrug/data/feature.py +++ b/torchdrug/data/feature.py @@ -47,7 +47,7 @@ def onehot(x, vocab, allow_unknown=False): # TODO: this one is too slow @R.register("features.atom.default") -def atom_default(atom): +def atom_default(atom, **kwargs): """Default atom feature. 
Features: @@ -83,7 +83,7 @@ def atom_default(atom): @R.register("features.atom.center_identification") -def atom_center_identification(atom): +def atom_center_identification(atom, **kwargs): """Reaction center identification atom feature. Features: @@ -107,7 +107,7 @@ def atom_center_identification(atom): @R.register("features.atom.synthon_completion") -def atom_synthon_completion(atom): +def atom_synthon_completion(atom, **kwargs): """Synthon completion atom feature. Features: @@ -133,7 +133,7 @@ def atom_synthon_completion(atom): @R.register("features.atom.symbol") -def atom_symbol(atom): +def atom_symbol(atom, **kwargs): """Symbol atom feature. Features: @@ -143,7 +143,7 @@ def atom_symbol(atom): @R.register("features.atom.explicit_property_prediction") -def atom_explicit_property_prediction(atom): +def atom_explicit_property_prediction(atom, **kwargs): """Explicit property prediction atom feature. Features: @@ -165,7 +165,7 @@ def atom_explicit_property_prediction(atom): @R.register("features.atom.property_prediction") -def atom_property_prediction(atom): +def atom_property_prediction(atom, **kwargs): """Property prediction atom feature. Features: @@ -190,7 +190,7 @@ def atom_property_prediction(atom): @R.register("features.atom.position") -def atom_position(atom): +def atom_position(atom, **kwargs): """ Atom position. Return 3D position if available, otherwise 2D position is returned. @@ -204,7 +204,7 @@ def atom_position(atom): @R.register("features.atom.pretrain") -def atom_pretrain(atom): +def atom_pretrain(atom, **kwargs): """Atom feature for pretraining. Features: @@ -217,7 +217,7 @@ def atom_pretrain(atom): @R.register("features.bond.default") -def bond_default(bond): +def bond_default(bond, **kwargs): """Default bond feature. 
Features: @@ -239,7 +239,7 @@ def bond_default(bond): @R.register("features.bond.length") -def bond_length(bond): +def bond_length(bond, **kwargs): """Bond length""" mol = bond.GetOwningMol() if mol.GetNumConformers() == 0: @@ -251,7 +251,7 @@ def bond_length(bond): @R.register("features.bond.property_prediction") -def bond_property_prediction(bond): +def bond_property_prediction(bond, **kwargs): """Property prediction bond feature. Features: @@ -266,7 +266,7 @@ def bond_property_prediction(bond): @R.register("features.bond.pretrain") -def bond_pretrain(bond): +def bond_pretrain(bond, **kwargs): """Bond feature for pretraining. Features: @@ -290,9 +290,9 @@ def ExtendedConnectivityFingerprint(mol, radius=2, length=1024): @R.register("features.molecule.default") -def molecule_default(mol): +def molecule_default(mol, **kwargs): """Default molecule feature.""" - return ExtendedConnectivityFingerprint(mol) + return ExtendedConnectivityFingerprint(mol, **kwargs) ECFP = ExtendedConnectivityFingerprint diff --git a/torchdrug/data/molecule.py b/torchdrug/data/molecule.py index dcf5f39d..b0d94fb1 100644 --- a/torchdrug/data/molecule.py +++ b/torchdrug/data/molecule.py @@ -11,6 +11,7 @@ from torchdrug.data import constant, Graph, PackedGraph from torchdrug.core import Registry as R from torchdrug.data.rdkit import draw +from typing import Optional plt.switch_backend("agg") @@ -98,6 +99,14 @@ def _standarize_option(cls, option): option = [option] return option + @classmethod + def _standarize_option_kwargs(cls, option_kwargs): + if option_kwargs is None: + option_kwargs = [{}] + elif isinstance(option_kwargs, dict): + option_kwargs = [option_kwargs] + return option_kwargs + def _check_no_stereo(self): if (self.bond_stereo > 0).any(): warnings.warn("Try to apply masks on molecules with stereo bonds. This may produce invalid molecules. 
" @@ -109,9 +118,20 @@ def _maybe_num_node(self, edge_list): else: return 0 + @classmethod + def _check_features_kwargs(cls, features, feature_kwargs): + if len(features) > 0 and len(features) != len(feature_kwargs): + raise ValueError(""" + The number of features to extract does not match the number of provided feature_kwargs. + If you provide a list of features, provide a list of (empty) kwargs dicts. + """) + @classmethod def from_smiles(cls, smiles, node_feature="default", edge_feature="default", graph_feature=None, - with_hydrogen=False, kekulize=False): + with_hydrogen=False, kekulize=False, + node_feature_kwargs: Optional[dict] = None, + edge_feature_kwargs: Optional[dict] = None, + graph_feature_kwargs: Optional[dict] = None): """ Create a molecule from a SMILES string. @@ -126,16 +146,24 @@ def from_smiles(cls, smiles, node_feature="default", edge_feature="default", gra Note this only affects the relation in ``edge_list``. For ``bond_type``, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored. 
+ node_feature_kwargs (dict or list of dict, optional): (list of) dict with kwargs for each `node_feature` extraction function + edge_feature_kwargs (dict or list of dict, optional): (list of) dict with kwargs for each `edge_feature` extraction function + graph_feature_kwargs (dict or list of dict, optional): (list of) dict with kwargs for each `graph_feature` extraction function """ mol = Chem.MolFromSmiles(smiles) if mol is None: raise ValueError("Invalid SMILES `%s`" % smiles) - return cls.from_molecule(mol, node_feature, edge_feature, graph_feature, with_hydrogen, kekulize) + return cls.from_molecule(mol, node_feature, edge_feature, graph_feature, with_hydrogen, kekulize, + node_feature_kwargs, edge_feature_kwargs, graph_feature_kwargs) @classmethod def from_molecule(cls, mol, node_feature="default", edge_feature="default", graph_feature=None, - with_hydrogen=False, kekulize=False): + with_hydrogen=False, kekulize=False, + node_feature_kwargs: Optional[dict] = None, + edge_feature_kwargs: Optional[dict] = None, + graph_feature_kwargs: Optional[dict] = None + ): """ Create a molecule from a RDKit object. @@ -150,6 +178,9 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap Note this only affects the relation in ``edge_list``. For ``bond_type``, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored. 
+ node_feature_kwargs (dict or list of dict, optional): (list of) dict with kwargs for each `node_feature` extraction function + edge_feature_kwargs (dict or list of dict, optional): (list of) dict with kwargs for each `edge_feature` extraction function + graph_feature_kwargs (dict or list of dict, optional): (list of) dict with kwargs for each `graph_feature` extraction function """ if mol is None: mol = cls.empty_mol @@ -163,6 +194,14 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap edge_feature = cls._standarize_option(edge_feature) graph_feature = cls._standarize_option(graph_feature) + node_feature_kwargs = cls._standarize_option_kwargs(node_feature_kwargs) + edge_feature_kwargs = cls._standarize_option_kwargs(edge_feature_kwargs) + graph_feature_kwargs = cls._standarize_option_kwargs(graph_feature_kwargs) + + for feat, feat_kwargs in zip([node_feature, edge_feature, graph_feature], + [node_feature_kwargs, edge_feature_kwargs, graph_feature_kwargs]): + cls._check_features_kwargs(feat, feat_kwargs) + atom_type = [] formal_charge = [] explicit_hs = [] @@ -179,9 +218,9 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap radical_electrons.append(atom.GetNumRadicalElectrons()) atom_map.append(atom.GetAtomMapNum()) feature = [] - for name in node_feature: + for name, kwargs in zip(node_feature, node_feature_kwargs): func = R.get("features.atom.%s" % name) - feature += func(atom) + feature += func(atom, **kwargs) _node_feature.append(feature) atom_type = torch.tensor(atom_type)[:-1] atom_map = torch.tensor(atom_map)[:-1] @@ -219,9 +258,9 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap bond_stereo += [stereo, stereo] stereo_atoms += [_atoms, _atoms] feature = [] - for name in edge_feature: + for name, kwargs in zip(edge_feature, edge_feature_kwargs): func = R.get("features.bond.%s" % name) - feature += func(bond) + feature += func(bond, **kwargs) _edge_feature += 
[feature, feature] edge_list = edge_list[:-2] bond_type = torch.tensor(bond_type)[:-2] @@ -233,9 +272,9 @@ def from_molecule(cls, mol, node_feature="default", edge_feature="default", grap _edge_feature = None _graph_feature = [] - for name in graph_feature: + for name, kwargs in zip(graph_feature, graph_feature_kwargs): func = R.get("features.molecule.%s" % name) - _graph_feature += func(mol) + _graph_feature += func(mol, **kwargs) if len(graph_feature) > 0: _graph_feature = torch.tensor(_graph_feature) else: