Merge branch 'master' into hetatm_parsing
a-r-j committed Aug 4, 2024
2 parents 8c713d1 + 4d8dc64 commit ee9c8ae
Showing 9 changed files with 139 additions and 36 deletions.
4 changes: 3 additions & 1 deletion .requirements/base.in
@@ -2,13 +2,15 @@ pandas<2.0.0
biopandas>=0.5.1
biopython
bioservices>=1.10.0
cpdb-protein==0.2.0
cython
deepdiff
loguru
looseversion
matplotlib>=3.4.3
multipledispatch
networkx
numpy<1.24.0
numpy
pandas
plotly
pydantic
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,8 @@
* Fixes progress bar for `download_pdb_multiprocessing`. [#394](https://github.com/a-r-j/graphein/pull/394)
* Add support for DSSP >4. Backwards compatibility is still supported. [#355](https://github.com/a-r-j/graphein/pull/355). Fixes [#353](https://github.com/a-r-j/graphein/issues/353).
* Fixes bug where RSA features are missing from nodes with insertion codes. [#355](https://github.com/a-r-j/graphein/pull/355). Fixes [#354](https://github.com/a-r-j/graphein/issues/354).
* Fix bug where the `deprotonate` argument is not wired up to `graphein.protein.graphs.construct_graphs`. [#375](https://github.com/a-r-j/graphein/pull/375)
* Add missing modified residue `AYA` to constants. [#390](https://github.com/a-r-j/graphein/pull/390)
* Fix cluster file loading bug in `pdb_data.py`. [#396](https://github.com/a-r-j/graphein/pull/396)

@@ -80,6 +82,7 @@ https://github.com/a-r-j/graphein/pull/334

#### Other Changes

- Uses [`cpdb`](https://github.com/a-r-j/CPDB) as default PDB file parser for improved performance. [#323](https://github.com/a-r-j/graphein/pull/323).
- Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)
- Adds entry point for biopandas dataframes in `graphein.protein.tensor.io.protein_to_pyg`. [#310](https://github.com/a-r-j/graphein/pull/310)
- Adds support for `.ent` files to `graphein.protein.graphs.read_pdb_to_dataframe`. [#310](https://github.com/a-r-j/graphein/pull/310)
32 changes: 21 additions & 11 deletions graphein/protein/graphs.py
@@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cpdb
import networkx as nx
import numpy as np
import pandas as pd
@@ -109,32 +110,41 @@ def read_pdb_to_dataframe(
or path.endswith(".pdb.gz")
or path.endswith(".ent")
):
atomic_df = PandasPdb().read_pdb(path)
atomic_df = cpdb.parse(path)
elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"):
atomic_df = PandasMmtf().read_mmtf(path)
atomic_df = atomic_df.get_model(model_index)
atomic_df = pd.concat(
[atomic_df.df["ATOM"], atomic_df.df["HETATM"]]
)
elif (
path.endswith(".cif")
or path.endswith(".cif.gz")
or path.endswith(".mmcif")
or path.endswith(".mmcif.gz")
):
atomic_df = PandasMmcif().read_mmcif(path)
atomic_df = atomic_df.get_model(model_index)
atomic_df = atomic_df.convert_to_pandas_pdb()
atomic_df = pd.concat(
[atomic_df.df["ATOM"], atomic_df.df["HETATM"]]
)
else:
raise ValueError(
f"File {path} must be either .pdb(.gz), .mmtf(.gz), .(mm)cif(.gz) or .ent, not {path.split('.')[-1]}"
)
elif uniprot_id is not None:
atomic_df = PandasPdb().fetch_pdb(
uniprot_id=uniprot_id, source="alphafold2-v3"
)
atomic_df = cpdb.parse(uniprot_id=uniprot_id)
else:
atomic_df = PandasPdb().fetch_pdb(pdb_code)
atomic_df = atomic_df.get_model(model_index)
if len(atomic_df.df["ATOM"]) == 0:
atomic_df = cpdb.parse(pdb_code=pdb_code)

if "model_idx" in atomic_df.columns:
atomic_df = atomic_df.loc[atomic_df["model_idx"] == model_index]

if len(atomic_df) == 0:
raise ValueError(f"No model found for index: {model_index}")
if isinstance(atomic_df, PandasMmcif):
atomic_df = atomic_df.convert_to_pandas_pdb()
return pd.concat([atomic_df.df["ATOM"], atomic_df.df["HETATM"]])

return atomic_df


def label_node_id(
@@ -285,7 +295,7 @@ def remove_alt_locs(
# Unsort
if keep in ["max_occupancy", "min_occupancy"]:
df = df.sort_index()

df = df.reset_index(drop=True)
return df


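The net effect of this change is that all three entry points (`path`, `uniprot_id`, `pdb_code`) now return a single flat DataFrame from `cpdb.parse`, with model selection done by filtering a `model_idx` column instead of calling `get_model`. A minimal sketch of the new flow, using only what the diff shows (the file name is a placeholder):

import cpdb

# cpdb.parse returns one flat DataFrame containing ATOM and HETATM records.
atomic_df = cpdb.parse("1abc.pdb")  # or cpdb.parse(pdb_code=...) / cpdb.parse(uniprot_id=...)

model_index = 1
if "model_idx" in atomic_df.columns:
    atomic_df = atomic_df.loc[atomic_df["model_idx"] == model_index]

if len(atomic_df) == 0:
    raise ValueError(f"No model found for index: {model_index}")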
4 changes: 4 additions & 0 deletions graphein/protein/resi_atoms.py
@@ -463,6 +463,7 @@
"ABA",
"ACE",
"AIB",
"AYA",
"BMT",
"BOC",
"CBX",
@@ -535,6 +536,7 @@
"ABA",
"ACE",
"AIB",
"AYA",
"ALA",
"ARG",
"ASN",
@@ -639,6 +641,7 @@
"ASN": "N",
"ASP": "D",
"ASX": "B",
"AYA": "A",
"BMT": "T",
"BOC": "X",
"CBX": "X",
@@ -795,6 +798,7 @@
"ABA": "ALA",
"ACE": "-",
"AIB": "ALA",
"AYA": "ALA",
"BMT": "THR",
"BOC": "-",
"CBX": "-",
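A quick sanity check of the new `AYA` (N-acetylalanine) entries — hedged, since the dictionary names are not visible in the hunk headers. `RESI_THREE_TO_1` is confirmed by the import in `graphein/protein/utils.py` below, and the three-to-one hunk above matches its shape:

from graphein.protein.resi_atoms import RESI_THREE_TO_1

# Per the diff, AYA now maps to alanine's one-letter code in the
# three-to-one table, and to parent residue ALA in the
# non-standard-to-standard mapping (that dict's name is not shown here).
assert RESI_THREE_TO_1["AYA"] == "A"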
86 changes: 74 additions & 12 deletions graphein/protein/utils.py
@@ -16,6 +16,7 @@
from urllib.request import urlopen

import networkx as nx
import numpy as np
import pandas as pd
import requests
import wget
@@ -25,6 +26,30 @@

from .resi_atoms import BACKBONE_ATOMS, RESI_THREE_TO_1

pdb_df_columns = [
"record_name",
"atom_number",
"blank_1",
"atom_name",
"alt_loc",
"residue_name",
"blank_2",
"chain_id",
"residue_number",
"insertion",
"blank_3",
"x_coord",
"y_coord",
"z_coord",
"occupancy",
"b_factor",
"blank_4",
"segment_id",
"element_symbol",
"charge",
"line_idx",
]


class ProteinGraphConfigurationError(Exception):
"""
@@ -418,12 +443,27 @@ def save_graph_to_pdb(
:type gz: bool
"""
ppd = PandasPdb()
atom_df = filter_dataframe(
g.graph["pdb_df"], "record_name", ["ATOM"], boolean=True
)
hetatm_df = filter_dataframe(
g.graph["pdb_df"], "record_name", ["HETATM"], boolean=True
)

df = g.graph["pdb_df"].copy()
# format charge correctly
df.charge = pd.to_numeric(df.charge, errors="coerce")

# Add blank columns
blank_cols = [
"blank_1",
"blank_2",
"blank_3",
"blank_4",
"segment_id",
]
for col in blank_cols:
if col not in df.columns:
df[col] = ""
df["line_idx"] = list(range(1, len(df) + 1))
df = df[pdb_df_columns]
atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)

if atoms:
ppd.df["ATOM"] = atom_df
if hetatms:
@@ -448,9 +488,22 @@ def save_pdb_df_to_pdb(
:param gz: Whether to gzip the file. Defaults to ``False``.
:type gz: bool
"""
df = df.copy()
# format charge correctly
df.charge = pd.to_numeric(df.charge, errors="coerce")
df.alt_loc = df.alt_loc.fillna(" ")
blank_cols = ["blank_1", "blank_2", "blank_3", "blank_4", "segment_id"]
for col in blank_cols:
if col not in df.columns:
df[col] = ""
df["line_idx"] = list(range(1, len(df) + 1))
df = df[pdb_df_columns]

atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)

ppd = PandasPdb()

if atoms:
ppd.df["ATOM"] = atom_df
if hetatms:
@@ -481,12 +534,21 @@ def save_rgroup_df_to_pdb(
:type gz: bool
"""
ppd = PandasPdb()
atom_df = filter_dataframe(
g.graph["rgroup_df"], "record_name", ["ATOM"], boolean=True
)
hetatm_df = filter_dataframe(
g.graph["rgroup_df"], "record_name", ["HETATM"], boolean=True
)
df = g.graph["rgroup_df"].copy()

# format charge correctly
df.charge = pd.to_numeric(df.charge, errors="coerce")

blank_cols = ["blank_1", "blank_2", "blank_3", "blank_4", "segment_id"]
for col in blank_cols:
if col not in df.columns:
df[col] = [""] * len(df)
df["line_idx"] = list(range(1, len(df) + 1))
df = df[pdb_df_columns]

atom_df = filter_dataframe(df, "record_name", ["ATOM"], boolean=True)
hetatm_df = filter_dataframe(df, "record_name", ["HETATM"], boolean=True)

if atoms:
ppd.df["ATOM"] = atom_df
if hetatms:
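The same preparation block (coerce `charge` to numeric, backfill the blank/segment columns, renumber `line_idx`, and reorder to `pdb_df_columns`) is now repeated in all three save functions. A hypothetical helper, not part of this diff, could factor it out:

import pandas as pd

def _prepare_pdb_df(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper (not in this commit) for the shared save logic."""
    df = df.copy()
    # Format charge correctly for PDB writing.
    df.charge = pd.to_numeric(df.charge, errors="coerce")
    # Backfill columns PandasPdb expects but graph DataFrames may lack.
    for col in ["blank_1", "blank_2", "blank_3", "blank_4", "segment_id"]:
        if col not in df.columns:
            df[col] = ""
    df["line_idx"] = list(range(1, len(df) + 1))
    return df[pdb_df_columns]  # pdb_df_columns as defined at module level above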
2 changes: 1 addition & 1 deletion tests/protein/tensor/test_data.py
@@ -16,5 +16,5 @@
def test_save_and_load_protein():
a = Protein().from_pdb_code("4hhb")
torch.save(a, "4hhb.pt")
b = torch.load("4hhb.pt")
b = torch.load("4hhb.pt", weights_only=False)
assert a == b
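Context for the `weights_only=False` change, offered as an assumption since the diff gives no rationale: PyTorch 2.6 switched the `torch.load` default to `weights_only=True`, which refuses to unpickle custom classes such as `Protein`. On recent PyTorch an alternative is to allow-list the class rather than disable the check (import path assumed; nested classes may also need allow-listing):

import torch
from graphein.protein.tensor.data import Protein  # import path assumed

# Allow-list Protein so the safer weights_only=True default can load it.
torch.serialization.add_safe_globals([Protein])
b = torch.load("4hhb.pt")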
1 change: 0 additions & 1 deletion tests/protein/tensor/test_reconstruction.py
@@ -30,4 +30,3 @@ def test_dist_mat_to_coords():
assert torch.allclose(d, torch.cdist(X, X), atol=1e-4)
X_aligned = kabsch(X, coords)
assert torch.allclose(coords, X_aligned, atol=1e-4)
return coords, X, X_aligned
5 changes: 4 additions & 1 deletion tests/protein/test_graphs.py
@@ -439,7 +439,10 @@ def test_alt_loc_exclusion():
):
config.alt_locs = opt
g = construct_graph(config=config, pdb_code="2VVI")
assert np.array_equal(g.nodes[node_id]["coords"], expected_coords)
assert np.array_equal(
g.nodes[node_id]["coords"],
np.array(expected_coords, dtype=np.float32),
)


def test_alt_loc_inclusion():
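The stricter comparison above presumably guards against a dtype mismatch: node coordinates are stored as `float32`, while a bare Python list of floats becomes `float64`, and values like 0.1 are not equal across the two widths. A minimal illustration:

import numpy as np

coords = np.array([0.1, 0.2, 0.3], dtype=np.float32)
np.array_equal(coords, [0.1, 0.2, 0.3])                               # False: float64 literals
np.array_equal(coords, np.array([0.1, 0.2, 0.3], dtype=np.float32))  # True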
38 changes: 29 additions & 9 deletions tests/protein/test_utils.py
@@ -27,11 +27,24 @@ def test_save_graph_to_pdb():
# Check file exists
assert os.path.isfile("/tmp/test_graph.pdb")

# Check for equivalence between saved and existing DFs.
# We drop the line_idx columns as these will be renumbered
graph_df = (
g.graph["pdb_df"]
.drop(
[
"node_id",
"residue_id",
],
axis=1,
)
.reset_index(drop=True)
)

a.reset_index(drop=True, inplace=True)
a = a[graph_df.columns] # Reorder columns

assert_frame_equal(
a.drop(["line_idx"], axis=1),
g.graph["pdb_df"].drop(["line_idx", "node_id", "residue_id"], axis=1),
a,
graph_df,
)
h = construct_graph(path="/tmp/test_graph.pdb")

@@ -48,10 +61,17 @@ def test_save_pdb_df_to_pdb():
# Check file exists
assert os.path.isfile("/tmp/test_graph.pdb")

# We drop the line_idx columns as these will be renumbered
assert_frame_equal(
a.drop(["line_idx"], axis=1),
g.graph["pdb_df"].drop(["line_idx", "node_id", "residue_id"], axis=1),
a,
g.graph["pdb_df"]
.drop(
[
"node_id",
"residue_id",
],
axis=1,
)
.reset_index(drop=True),
)

# Now check for raw, unprocessed DF
@@ -73,10 +93,10 @@ def test_save_rgroup_df_to_pdb():

# We drop the line_idx columns as these will be renumbered
assert_frame_equal(
a.drop(["line_idx"], axis=1),
a,
filter_dataframe(
g.graph["rgroup_df"], "record_name", ["HETATM"], False
).drop(["line_idx", "node_id", "residue_id"], axis=1),
).drop(["node_id", "residue_id"], axis=1),
)


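The reworked assertions all follow one pattern: drop the graph-only columns (`node_id`, `residue_id`), reset both indices, and align column order before comparing. A hypothetical helper capturing that pattern (not part of this commit):

import pandas as pd
from pandas.testing import assert_frame_equal

def assert_pdb_frames_equal(saved: pd.DataFrame, graph_df: pd.DataFrame) -> None:
    """Hypothetical test helper mirroring the comparison pattern above."""
    graph_df = graph_df.drop(["node_id", "residue_id"], axis=1).reset_index(drop=True)
    saved = saved.reset_index(drop=True)[graph_df.columns]  # align column order
    assert_frame_equal(saved, graph_df)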
