Add cpdb support to speed up parsing (#323)
* add cpdb to speed up parsing

* reorder requirements

* reorder requirements

* Add blank columns to write PDBs

* Update changelog

* update range indexing

* add blank segment_id column if necessary

* pin cpdb version

* update test

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* updates to save_pdb function

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* fix broken pdb writer

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* unpin numpy dependency

* add missing numpy import

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* resolve test dtype

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* fix test dtype

* fix type error in charge writing

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* fix test syntax error

* format charge correctly

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* modify tests to use CPDB

* fix syntax error

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* fix column drops

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* fix remaining tests after adding CPDB parser backend

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: Arian Jamasb <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Aug 4, 2024
1 parent 90be006, commit 4d8dc64
Showing 8 changed files with 132 additions and 35 deletions.
.requirements/base.in: 3 additions, 1 deletion

@@ -2,13 +2,15 @@ pandas<2.0.0
 biopandas>=0.5.1
 biopython
 bioservices>=1.10.0
+cpdb-protein==0.2.0
+cython
 deepdiff
 loguru
 looseversion
 matplotlib>=3.4.3
 multipledispatch
 networkx
-numpy<1.24.0
+numpy
 pandas
 plotly
 pydantic

CHANGELOG.md: 1 addition, 0 deletions

@@ -82,6 +82,7 @@ https://github.com/a-r-j/graphein/pull/334

 #### Other Changes

+- Uses [`cpdb`](https://github.com/a-r-j/CPDB) as default PDB file parser for improved performance. [#323](https://github.com/a-r-j/graphein/pull/323).
 - Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)
 - Adds entry point for biopandas dataframes in `graphein.protein.tensor.io.protein_to_pyg`. [#310](https://github.com/a-r-j/graphein/pull/310)
 - Adds support for `.ent` files to `graphein.protein.graphs.read_pdb_to_dataframe`. [#310](https://github.com/a-r-j/graphein/pull/310)

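For orientation, a minimal sketch of the new backend call in isolation. It assumes the cpdb-protein package pinned in base.in above; the file name is a placeholder:

    import cpdb

    # cpdb.parse reads a PDB file directly into a pandas DataFrame in a
    # single call, which is where the parsing speed-up comes from.
    df = cpdb.parse("protein.pdb")
    print(len(df), "atom records")
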
graphein/protein/graphs.py: 20 additions, 10 deletions

@@ -15,6 +15,7 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import cpdb
 import networkx as nx
 import numpy as np
 import pandas as pd

@@ -109,32 +110,41 @@ def read_pdb_to_dataframe(
             or path.endswith(".pdb.gz")
             or path.endswith(".ent")
         ):
-            atomic_df = PandasPdb().read_pdb(path)
+            atomic_df = cpdb.parse(path)
         elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"):
             atomic_df = PandasMmtf().read_mmtf(path)
+            atomic_df = atomic_df.get_model(model_index)
+            atomic_df = pd.concat(
+                [atomic_df.df["ATOM"], atomic_df.df["HETATM"]]
+            )
         elif (
             path.endswith(".cif")
             or path.endswith(".cif.gz")
             or path.endswith(".mmcif")
             or path.endswith(".mmcif.gz")
         ):
             atomic_df = PandasMmcif().read_mmcif(path)
+            atomic_df = atomic_df.get_model(model_index)
+            atomic_df = atomic_df.convert_to_pandas_pdb()
+            atomic_df = pd.concat(
+                [atomic_df.df["ATOM"], atomic_df.df["HETATM"]]
+            )
         else:
             raise ValueError(
                 f"File {path} must be either .pdb(.gz), .mmtf(.gz), .(mm)cif(.gz) or .ent, not {path.split('.')[-1]}"
             )
     elif uniprot_id is not None:
-        atomic_df = PandasPdb().fetch_pdb(
-            uniprot_id=uniprot_id, source="alphafold2-v3"
-        )
+        atomic_df = cpdb.parse(uniprot_id=uniprot_id)
     else:
-        atomic_df = PandasPdb().fetch_pdb(pdb_code)
-    atomic_df = atomic_df.get_model(model_index)
-    if len(atomic_df.df["ATOM"]) == 0:
+        atomic_df = cpdb.parse(pdb_code=pdb_code)
+
+    if "model_idx" in atomic_df.columns:
+        atomic_df = atomic_df.loc[atomic_df["model_idx"] == model_index]
+
+    if len(atomic_df) == 0:
         raise ValueError(f"No model found for index: {model_index}")
-    if isinstance(atomic_df, PandasMmcif):
-        atomic_df = atomic_df.convert_to_pandas_pdb()
-    return pd.concat([atomic_df.df["ATOM"], atomic_df.df["HETATM"]])
+
+    return atomic_df


 def label_node_id(

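A minimal usage sketch of the rerouted function, assuming the signature shown in the hunk above (the local path is a placeholder; the column names come from the pdb_df_columns list added to utils.py below):

    from graphein.protein.graphs import read_pdb_to_dataframe

    # Local .pdb / .pdb.gz / .ent files now go through cpdb.parse and come
    # back as one flat DataFrame of ATOM and HETATM rows rather than a
    # PandasPdb object.
    df = read_pdb_to_dataframe(path="protein.pdb")

    # Fetching by PDB code or UniProt ID routes through cpdb.parse too;
    # multi-model structures are filtered on the model_idx column.
    df = read_pdb_to_dataframe(pdb_code="4hhb", model_index=1)
    print(df[["record_name", "atom_name", "x_coord", "y_coord", "z_coord"]].head())
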
graphein/protein/utils.py: 74 additions, 12 deletions

@@ -16,6 +16,7 @@
 from urllib.request import urlopen

 import networkx as nx
+import numpy as np
 import pandas as pd
 import requests
 import wget

@@ -25,6 +26,30 @@

 from .resi_atoms import BACKBONE_ATOMS, RESI_THREE_TO_1

+pdb_df_columns = [
+    "record_name",
+    "atom_number",
+    "blank_1",
+    "atom_name",
+    "alt_loc",
+    "residue_name",
+    "blank_2",
+    "chain_id",
+    "residue_number",
+    "insertion",
+    "blank_3",
+    "x_coord",
+    "y_coord",
+    "z_coord",
+    "occupancy",
+    "b_factor",
+    "blank_4",
+    "segment_id",
+    "element_symbol",
+    "charge",
+    "line_idx",
+]
+

 class ProteinGraphConfigurationError(Exception):
     """

@@ -418,12 +443,27 @@ def save_graph_to_pdb(
     :type gz: bool
     """
     ppd = PandasPdb()
-    atom_df = filter_dataframe(
-        g.graph["pdb_df"], "record_name", ["ATOM"], boolean=True
-    )
-    hetatm_df = filter_dataframe(
-        g.graph["pdb_df"], "record_name", ["HETATM"], boolean=True
-    )
+
+    df = g.graph["pdb_df"].copy()
+    # format charge correctly
+    df.charge = pd.to_numeric(df.charge, errors="coerce")
+
+    # Add blank columns
+    blank_cols = [
+        "blank_1",
+        "blank_2",
+        "blank_3",
+        "blank_4",
+        "segment_id",
+    ]
+    for col in blank_cols:
+        if col not in df.columns:
+            df[col] = ""
+    df["line_idx"] = list(range(1, len(df) + 1))
+    df = df[pdb_df_columns]
+    atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
+    hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)
+
     if atoms:
         ppd.df["ATOM"] = atom_df
     if hetatms:

@@ -448,9 +488,22 @@ def save_pdb_df_to_pdb(
     :param gz: Whether to gzip the file. Defaults to ``False``.
     :type gz: bool
     """
+    df = df.copy()
+    # format charge correctly
+    df.charge = pd.to_numeric(df.charge, errors="coerce")
+    df.alt_loc = df.alt_loc.fillna(" ")
+    blank_cols = ["blank_1", "blank_2", "blank_3", "blank_4", "segment_id"]
+    for col in blank_cols:
+        if col not in df.columns:
+            df[col] = ""
+    df["line_idx"] = list(range(1, len(df) + 1))
+    df = df[pdb_df_columns]
+
     atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
     hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)
+
     ppd = PandasPdb()
+
     if atoms:
         ppd.df["ATOM"] = atom_df
     if hetatms:

@@ -481,12 +534,21 @@ def save_rgroup_df_to_pdb(
     :type gz: bool
     """
     ppd = PandasPdb()
-    atom_df = filter_dataframe(
-        g.graph["rgroup_df"], "record_name", ["ATOM"], boolean=True
-    )
-    hetatm_df = filter_dataframe(
-        g.graph["rgroup_df"], "record_name", ["HETATM"], boolean=True
-    )
+    df = g.graph["rgroup_df"].copy()
+
+    # format charge correctly
+    df.charge = pd.to_numeric(df.charge, errors="coerce")
+
+    blank_cols = ["blank_1", "blank_2", "blank_3", "blank_4", "segment_id"]
+    for col in blank_cols:
+        if col not in df.columns:
+            df[col] = [""] * len(df)
+    df["line_idx"] = list(range(1, len(df) + 1))
+    df = df[pdb_df_columns]
+
+    atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
+    hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)
+
     if atoms:
         ppd.df["ATOM"] = atom_df
     if hetatms:

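The three save_* functions above repeat the same preparation steps. As a reading aid, a sketch of that shared logic as a standalone helper; the name prepare_pdb_df is hypothetical, not part of graphein, and it assumes the pdb_df_columns list defined at the top of this file:

    import pandas as pd

    def prepare_pdb_df(df: pd.DataFrame) -> pd.DataFrame:
        """Coerce a cpdb-parsed DataFrame into the layout biopandas writes."""
        df = df.copy()
        # cpdb keeps charge as text; the biopandas writer needs it numeric.
        df["charge"] = pd.to_numeric(df["charge"], errors="coerce")
        # Spacer columns biopandas expects but cpdb does not produce.
        for col in ["blank_1", "blank_2", "blank_3", "blank_4", "segment_id"]:
            if col not in df.columns:
                df[col] = ""
        # Renumber lines and restore the expected column order.
        df["line_idx"] = list(range(1, len(df) + 1))
        return df[pdb_df_columns]
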
tests/protein/tensor/test_data.py: 1 addition, 1 deletion

@@ -16,5 +16,5 @@
 def test_save_and_load_protein():
     a = Protein().from_pdb_code("4hhb")
     torch.save(a, "4hhb.pt")
-    b = torch.load("4hhb.pt")
+    b = torch.load("4hhb.pt", weights_only=False)
     assert a == b

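The weights_only=False argument is needed because recent PyTorch releases move torch.load toward weights-only deserialization, which refuses to unpickle custom classes such as Protein. A short sketch of the trade-off, assuming Protein's import path:

    import torch
    from graphein.protein.tensor.data import Protein

    # weights_only=False runs the full unpickler so custom classes like
    # Protein round-trip; only use it on files you trust.
    a = Protein().from_pdb_code("4hhb")
    torch.save(a, "4hhb.pt")
    b = torch.load("4hhb.pt", weights_only=False)
    assert a == b
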
tests/protein/tensor/test_reconstruction.py: 0 additions, 1 deletion

@@ -30,4 +30,3 @@ def test_dist_mat_to_coords():
     assert torch.allclose(d, torch.cdist(X, X), atol=1e-4)
     X_aligned = kabsch(X, coords)
     assert torch.allclose(coords, X_aligned, atol=1e-4)
-    return coords, X, X_aligned

tests/protein/test_graphs.py: 4 additions, 1 deletion

@@ -439,7 +439,10 @@ def test_alt_loc_exclusion():
     ):
         config.alt_locs = opt
         g = construct_graph(config=config, pdb_code="2VVI")
-        assert np.array_equal(g.nodes[node_id]["coords"], expected_coords)
+        assert np.array_equal(
+            g.nodes[node_id]["coords"],
+            np.array(expected_coords, dtype=np.float32),
+        )


 def test_alt_loc_inclusion():

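The explicit float32 cast matters because the cpdb backend appears to produce 32-bit coordinates, and comparing float32 against float64 renderings of the same decimal literal can fail. A small illustration:

    import numpy as np

    # 39.722 has no exact binary representation, so the float32 and
    # float64 roundings differ and exact equality fails.
    print(np.float32(39.722) == np.float64(39.722))  # False
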
tests/protein/test_utils.py: 29 additions, 9 deletions

@@ -27,11 +27,24 @@ def test_save_graph_to_pdb():
     # Check file exists
     assert os.path.isfile("/tmp/test_graph.pdb")

-    # Check for equivalence between saved and existing DFs.
-    # We drop the line_idx columns as these will be renumbered
+    graph_df = (
+        g.graph["pdb_df"]
+        .drop(
+            [
+                "node_id",
+                "residue_id",
+            ],
+            axis=1,
+        )
+        .reset_index(drop=True)
+    )
+
+    a.reset_index(drop=True, inplace=True)
+    a = a[graph_df.columns]  # Reorder columns
+
     assert_frame_equal(
-        a.drop(["line_idx"], axis=1),
-        g.graph["pdb_df"].drop(["line_idx", "node_id", "residue_id"], axis=1),
+        a,
+        graph_df,
     )
     h = construct_graph(path="/tmp/test_graph.pdb")

@@ -48,10 +61,17 @@ def test_save_pdb_df_to_pdb():
     # Check file exists
     assert os.path.isfile("/tmp/test_graph.pdb")

-    # We drop the line_idx columns as these will be renumbered
     assert_frame_equal(
-        a.drop(["line_idx"], axis=1),
-        g.graph["pdb_df"].drop(["line_idx", "node_id", "residue_id"], axis=1),
+        a,
+        g.graph["pdb_df"]
+        .drop(
+            [
+                "node_id",
+                "residue_id",
+            ],
+            axis=1,
+        )
+        .reset_index(drop=True),
     )

     # Now check for raw, unprocessed DF

@@ -73,10 +93,10 @@ def test_save_rgroup_df_to_pdb():

     # We drop the line_idx columns as these will be renumbered
     assert_frame_equal(
-        a.drop(["line_idx"], axis=1),
+        a,
         filter_dataframe(
             g.graph["rgroup_df"], "record_name", ["HETATM"], False
-        ).drop(["line_idx", "node_id", "residue_id"], axis=1),
+        ).drop(["node_id", "residue_id"], axis=1),
     )

