Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mark atoms inside rings #12

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 93 additions & 10 deletions src/header_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile

from rdkit import Chem
from rdkit import Geometry

HEADER_FILE = 'template_smiles.h'
TEMPLATE_FILE = 'templates.smi'
Expand All @@ -30,19 +31,101 @@
const std::vector<std::string> TEMPLATE_SMILES = {
"""


def clean_smiles(template_smiles):
    """
    Translate all atoms into dummy atoms so that templates are not atom-specific.

    :param template_smiles: SMILES / CXSMILES string describing one template.
    :return: CXSMILES of the template with every atom's atomic number set to 0.
    :raises ValueError: if RDKit cannot parse ``template_smiles``.
    """
    template = Chem.MolFromSmiles(template_smiles)
    if template is None:
        # MolFromSmiles reports a parse failure by returning None; fail here,
        # naming the offending input, rather than with an AttributeError below.
        raise ValueError(
            f"could not parse template SMILES: {template_smiles!r}")

    for atom in template.GetAtoms():
        atom.SetAtomicNum(0)

    # TO_DO: replace bonds with query bonds

    return Chem.MolToCXSmiles(template)

def crossing(v1pt1, v1pt2, v2pt1, v2pt2):
    """
    Check whether two 2D line segments touch or cross each other.

    Every argument is a point exposing ``x`` and ``y`` attributes. Segment 1
    runs from ``v1pt1`` to ``v1pt2``, segment 2 from ``v2pt1`` to ``v2pt2``.

    :return: True if the segments share at least one point (an endpoint lying
        on the other segment counts), False otherwise.
    """
    # Convert segment 1 to a line (line 1) of infinite length.
    # We want the line in linear equation standard form: A*x + B*y + C = 0
    # See: http://en.wikipedia.org/wiki/Linear_equation
    a1 = v1pt2.y - v1pt1.y
    b1 = v1pt1.x - v1pt2.x
    c1 = (v1pt2.x * v1pt1.y) - (v1pt1.x * v1pt2.y)

    # Insert both end points of segment 2 into the equation above.
    d1 = (a1 * v2pt1.x) + (b1 * v2pt1.y) + c1
    d2 = (a1 * v2pt2.x) + (b1 * v2pt2.y) + c1

    # If d1 and d2 both have the same sign, segment 2 lies entirely on one
    # side of line 1 and no intersection is possible.
    if (d1 > 0 and d2 > 0) or (d1 < 0 and d2 < 0):
        return False

    # Calculate the infinite line 2 in linear equation standard form.
    a2 = v2pt2.y - v2pt1.y
    b2 = v2pt1.x - v2pt2.x
    c2 = (v2pt2.x * v2pt1.y) - (v2pt1.x * v2pt2.y)

    # Calculate d1 and d2 again, this time using the points of segment 1.
    d1 = (a2 * v1pt1.x) + (b2 * v1pt1.y) + c2
    d2 = (a2 * v1pt2.x) + (b2 * v1pt2.y) + c2

    # If both have the same sign, no intersection is possible.
    if (d1 > 0 and d2 > 0) or (d1 < 0 and d2 < 0):
        return False

    if (a1 * b2) - (a2 * b1) == 0:
        # Parallel (or zero-length) segments that passed both side tests lie
        # on one common line; they only meet if their extents overlap on
        # both axes.
        return bool(
            min(v1pt1.x, v1pt2.x) <= max(v2pt1.x, v2pt2.x)
            and min(v2pt1.x, v2pt2.x) <= max(v1pt1.x, v1pt2.x)
            and min(v1pt1.y, v1pt2.y) <= max(v2pt1.y, v2pt2.y)
            and min(v2pt1.y, v2pt2.y) <= max(v1pt1.y, v1pt2.y)
        )

    # The segments are not collinear, so they intersect in exactly one point.
    return True

def point_inside_ring(pt, coords, ring):
    """
    Tell whether a 2D point lies inside the polygon traced by a ring.

    :param pt: point exposing ``x`` and ``y`` attributes.
    :param coords: per-atom coordinates, read as ``coords[idx][0]`` (x) and
        ``coords[idx][1]`` (y).
    :param ring: sequence of atom indices; consecutive entries, plus the last
        back to the first, are used as the polygon's edges.
    :return: True if the point is inside the ring, False otherwise.
    """
    # Bounding box of the ring, seeded with sentinels far outside any
    # expected coordinate.
    low_x = low_y = 1e8
    high_x = high_y = -1e8
    for atom_idx in ring:
        x = float(coords[atom_idx][0])
        y = float(coords[atom_idx][1])
        low_x = min(low_x, x)
        low_y = min(low_y, y)
        high_x = max(high_x, x)
        high_y = max(high_y, y)

    # Cheap rejection: a point outside the bounding box cannot be inside.
    outside_box = (pt.x < low_x or pt.x > high_x
                   or pt.y < low_y or pt.y > high_y)
    if outside_box:
        return False

    # Ray casting: shoot a horizontal ray from the point to a spot just left
    # of the bounding box and count the ring edges it crosses. An odd count
    # means the point is inside the ring.
    ray_end = Geometry.Point2D(low_x - 0.1, pt.y)
    ring_size = len(ring)
    crossings = 0
    for pos, atom_idx in enumerate(ring):
        next_idx = ring[(pos + 1) % ring_size]
        edge_start = Geometry.Point2D(coords[atom_idx][0], coords[atom_idx][1])
        edge_end = Geometry.Point2D(coords[next_idx][0], coords[next_idx][1])
        # Edges whose second end point sits exactly at the ray's height are
        # not counted.
        if crossing(edge_start, edge_end, pt, ray_end) and pt.y != edge_end.y:
            crossings += 1
    return crossings % 2 == 1

def _find_inner_atoms(mol):
    """Return the indices of atoms whose probe point falls inside a ring."""
    inner_atoms = set()
    if not any(atom.GetDegree() for atom in mol.GetAtoms()):
        # Only atoms with neighbors are probed, so there is nothing to do
        # (and no need to require a conformer).
        return inner_atoms

    # These do not change while atoms are inspected; fetch them once instead
    # of once per atom / per ring.
    conformer = mol.GetConformer()
    positions = conformer.GetPositions()
    rings = mol.GetRingInfo().AtomRings()

    for atom in mol.GetAtoms():
        degree = atom.GetDegree()
        if degree < 1:
            continue
        coordinates = conformer.GetAtomPosition(atom.GetIdx())

        # average the position of the neighbors
        point = Geometry.Point3D(0, 0, 0)
        for nbr in atom.GetNeighbors():
            point += conformer.GetAtomPosition(nbr.GetIdx())
        point /= degree
        # Probe a short step away from the atom, on the side opposite to the
        # neighbors' average position.
        point = coordinates + (point - coordinates) * (-0.3)

        # if the point is inside a ring, avoid attaching substituents to atom
        if any(point_inside_ring(point, positions, ring) for ring in rings):
            inner_atoms.add(atom.GetIdx())
    return inner_atoms


def mark_inner_atoms(smiles):
    """
    Turn a template into CXSMARTS whose inner atoms cannot take substituents.

    Every atom is replaced by a query atom matching anything but the dummy
    element. Atoms whose probe point (see ``_find_inner_atoms``) lies inside
    a ring additionally get their current degree fixed in the query.

    :param smiles: template SMILES / CXSMILES; atoms with neighbors require
        the parsed molecule to carry a conformer with coordinates.
    :return: CXSMARTS string of the rewritten template.
    :raises ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; name the
        # offending template instead of failing inside RWMol.
        raise ValueError(f"could not parse template SMILES: {smiles!r}")
    mol = Chem.RWMol(mol)
    inner_atoms = _find_inner_atoms(mol)

    DUMMY_ATOMIC_NUM = 200
    for aidx in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(aidx)
        query = f"[!#{DUMMY_ATOMIC_NUM}]"  # query for any atom except the dummy atom
        if aidx in inner_atoms:
            # this atom cannot have substituents, so we also add a fixed degree to the query
            query = f"[!#{DUMMY_ATOMIC_NUM}&D{atom.GetDegree()}]"
        query_atom = Chem.AtomFromSmarts(query)
        mol.ReplaceAtom(aidx, query_atom)

    return Chem.MolToCXSmarts(mol)

def generate_header(generated_header_path):
with open(generated_header_path, 'w') as f_out:
Expand All @@ -52,8 +135,8 @@ def generate_header(generated_header_path):
if not (cxsmiles := line.strip()):
continue

# TO_DO: Clean smiles to make them atom-type and bond-type agnostic
# cxsmiles = clean_smiles(cxsmiles)
# cxsmiles = clean_smiles(cxsmiles)
cxsmiles = mark_inner_atoms(cxsmiles)

f_out.write(f' "{cxsmiles}",\n')
f_out.write('};\n// clang-format on\n')
Expand Down