Skip to content

Commit

Permalink
fix: improve robustness of and add a test #141
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-j committed Aug 1, 2024
1 parent d9a2878 commit 5473c6b
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 69 deletions.
163 changes: 100 additions & 63 deletions biopandas/pdb/pandas_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import gzip
import sys
import warnings
import textwrap
import warnings
from copy import deepcopy
from io import StringIO
from typing import List, Optional
Expand All @@ -22,9 +22,10 @@
import pandas as pd
from looseversion import LooseVersion

from .engines import amino3to1dict, pdb_df_columns, pdb_records
from biopandas.constants import ATOMIC_MASSES

from .engines import amino3to1dict, pdb_df_columns, pdb_records

pd_version = LooseVersion(pd.__version__)


Expand Down Expand Up @@ -115,45 +116,56 @@ def read_pdb_from_list(self, pdb_lines):
self.header, self.code = self._parse_header_code()
return self

def fetch_pdb(self, pdb_code: Optional[str] = None, uniprot_id: Optional[str] = None, source: str = "pdb"):
def fetch_pdb(
self,
pdb_code: Optional[str] = None,
uniprot_id: Optional[str] = None,
source: str = "pdb",
):
"""Fetches PDB file contents from the Protein Databank at rcsb.org or AlphaFold database
at https://alphafold.ebi.ac.uk/.
.
at https://alphafold.ebi.ac.uk/.
.
Parameters
----------
pdb_code : str, optional
A 4-letter PDB code, e.g., `"3eiy"` to retrieve structures from the PDB.
Defaults to `None`.
Parameters
----------
pdb_code : str, optional
A 4-letter PDB code, e.g., `"3eiy"` to retrieve structures from the PDB.
Defaults to `None`.
uniprot_id : str, optional
A UniProt Identifier, e.g., `"Q5VSL9"` to retrieve structures from the AF2 database.
Defaults to `None`.
uniprot_id : str, optional
A UniProt Identifier, e.g., `"Q5VSL9"` to retrieve structures from the AF2 database.
Defaults to `None`.
source : str
The source to retrieve the structure from
(`"pdb"`, `"alphafold2-v3"`, `"alphafold2-v4"`(latest)).
Defaults to `"pdb"`.
source : str
The source to retrieve the structure from
(`"pdb"`, `"alphafold2-v3"`, `"alphafold2-v4"`(latest)).
Defaults to `"pdb"`.
Returns
---------
self
Returns
---------
self
"""
# Sanitize input
invalid_input_identifier_1 = pdb_code is None and uniprot_id is None
invalid_input_identifier_2 = pdb_code is not None and uniprot_id is not None
invalid_input_combination_1 = uniprot_id is not None and source == "pdb"
invalid_input_combination_2 = pdb_code is not None and source in {
"alphafold2-v3", "alphafold2-v4"}
"alphafold2-v3",
"alphafold2-v4",
}

if invalid_input_identifier_1 or invalid_input_identifier_2:
raise ValueError("Please provide either a PDB code or a UniProt ID.")

if invalid_input_combination_1:
raise ValueError("Please use a 'pdb_code' instead of 'uniprot_id' for source='pdb'.")
raise ValueError(
"Please use a 'pdb_code' instead of 'uniprot_id' for source='pdb'."
)
elif invalid_input_combination_2:
raise ValueError(f"Please use a 'uniprot_id' instead of 'pdb_code' for source={source}.")
raise ValueError(
f"Please use a 'uniprot_id' instead of 'pdb_code' for source={source}."
)

if source == "alphafold2-v3":
af2_version = 3
Expand All @@ -164,8 +176,10 @@ def fetch_pdb(self, pdb_code: Optional[str] = None, uniprot_id: Optional[str] =
elif source == "pdb":
self.pdb_path, self.pdb_text = self._fetch_pdb(pdb_code)
else:
raise ValueError(f"Invalid source: {source}."
" Please use one of 'pdb' or 'alphafold2-v3' or 'alphafold2-v4'.")
raise ValueError(
f"Invalid source: {source}."
" Please use one of 'pdb' or 'alphafold2-v3' or 'alphafold2-v4'."
)

self._df = self._construct_df(pdb_lines=self.pdb_text.splitlines(True))
return self
Expand Down Expand Up @@ -248,7 +262,7 @@ def impute_element(self, records=("ATOM", "HETATM"), inplace=False):
)
return t

def add_remark(self, code, text='', indent=0):
def add_remark(self, code, text="", indent=0):
"""Add custom REMARK entry.
The remark will be inserted to preserve the ordering of REMARK codes, i.e. if the code is
Expand All @@ -275,57 +289,65 @@ def add_remark(self, code, text='', indent=0):
"""
# Prepare info from self
if 'OTHERS' in self.df:
df_others = self.df['OTHERS']
if "OTHERS" in self.df:
df_others = self.df["OTHERS"]
else:
df_others = pd.DataFrame(columns=['record_name', 'entry', 'line_idx'])
record_types = list(filter(lambda x: x in self.df, ['ATOM', 'HETATM', 'ANISOU']))
remarks = df_others[df_others['record_name'] == 'REMARK']['entry']
df_others = pd.DataFrame(columns=["record_name", "entry", "line_idx"])
record_types = list(
filter(lambda x: x in self.df, ["ATOM", "HETATM", "ANISOU"])
)
remarks = df_others[df_others["record_name"] == "REMARK"]["entry"]

# Find index and line_idx where to insert the remark to preserve remark code order
if len(remarks):
remark_codes = remarks.apply(lambda x: x.split(maxsplit=1)[0]).astype(int)
insertion_pos = remark_codes.searchsorted(code, side='right')
insertion_pos = remark_codes.searchsorted(code, side="right")
if insertion_pos < len(remark_codes): # Remark in the middle
insertion_idx = remark_codes.index[insertion_pos]
insertion_line_idx = df_others.loc[insertion_idx]['line_idx']
insertion_line_idx = df_others.loc[insertion_idx]["line_idx"]
else: # Last remark
insertion_idx = len(remark_codes)
insertion_line_idx = df_others['line_idx'].iloc[-1] + 1
insertion_line_idx = df_others["line_idx"].iloc[-1] + 1
else: # First remark
insertion_idx = 0
insertion_line_idx = min([self.df[r]['line_idx'].min() for r in record_types])
insertion_line_idx = min(
[self.df[r]["line_idx"].min() for r in record_types]
)

# Wrap remark to fit into 80 characters per line and add indentation
wrapper = textwrap.TextWrapper(width=80 - (11 + indent))
lines = sum([wrapper.wrap(line.strip()) or [' '] for line in text.split('\n')], [])
lines = list(map(lambda x: f'{code:4} ' + indent*' ' + x, lines))
lines = sum(
[wrapper.wrap(line.strip()) or [" "] for line in text.split("\n")], []
)
lines = list(map(lambda x: f"{code:4} " + indent * " " + x, lines))

# Shift data frame indices and row indices to create space for the remark
# Create space in OTHERS
line_idx = df_others['line_idx'].copy()
line_idx = df_others["line_idx"].copy()
line_idx[line_idx >= insertion_line_idx] += len(lines)
df_others['line_idx'] = line_idx
df_others["line_idx"] = line_idx
index = pd.Series(df_others.index.copy())
index[index >= insertion_idx] += len(lines)
df_others.index = index
# Shift all other record types that follow inserted remark
for records in record_types:
df_records = self.df[records]
if not insertion_line_idx > df_records['line_idx'].max():
df_records['line_idx'] += len(lines)
if not insertion_line_idx > df_records["line_idx"].max():
df_records["line_idx"] += len(lines)

# Put remark into 'OTHERS' data frame
df_remark = {
idx: ['REMARK', line, line_idx]
idx: ["REMARK", line, line_idx]
for idx, line, line_idx in zip(
range(insertion_idx, insertion_idx + len(lines)),
lines,
range(insertion_line_idx, insertion_line_idx + len(lines)),
)
}
df_remark = pd.DataFrame.from_dict(df_remark, orient='index', columns=df_others.columns)
self.df['OTHERS'] = pd.concat([df_others, df_remark]).sort_index()
df_remark = pd.DataFrame.from_dict(
df_remark, orient="index", columns=df_others.columns
)
self.df["OTHERS"] = pd.concat([df_others, df_remark]).sort_index()

@staticmethod
def rmsd(df1, df2, s=None, invert=False, decimals=4):
Expand Down Expand Up @@ -435,11 +457,13 @@ def _fetch_af2(uniprot_id: str, af2_version: int = 3):
try:
response = urlopen(url)
txt = response.read()
txt = txt.decode('utf-8') if sys.version_info[0] >= 3 else txt.encode('ascii')
txt = (
txt.decode("utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii")
)
except HTTPError as e:
print(f'HTTP Error {e.code}')
print(f"HTTP Error {e.code}")
except URLError as e:
print(f'URL Error {e.args}')
print(f"URL Error {e.args}")
return url, txt

def _parse_header_code(self):
Expand Down Expand Up @@ -518,7 +542,7 @@ def _construct_df(pdb_lines):
record = line[:6].rstrip()
line_ele = ["" for _ in range(len(pdb_records[record]) + 1)]
for idx, ele in enumerate(pdb_records[record]):
line_ele[idx] = line[ele["line"][0]: ele["line"][1]].strip()
line_ele[idx] = line[ele["line"][0] : ele["line"][1]].strip()
line_ele[-1] = line_num
line_lists[record].append(line_ele)
else:
Expand Down Expand Up @@ -847,7 +871,9 @@ def get_model(self, model_index: int) -> PandasPdb:
biopandas_structure.label_models()

if "ATOM" in biopandas_structure.df.keys():
biopandas_structure.df["ATOM"] = biopandas_structure.df["ATOM"].loc[biopandas_structure.df["ATOM"]["model_id"] == model_index]
biopandas_structure.df["ATOM"] = biopandas_structure.df["ATOM"].loc[
biopandas_structure.df["ATOM"]["model_id"] == model_index
]
if "HETATM" in biopandas_structure.df.keys():
biopandas_structure.df["HETATM"] = biopandas_structure.df["HETATM"].loc[
biopandas_structure.df["HETATM"]["model_id"] == model_index
Expand Down Expand Up @@ -877,15 +903,24 @@ def get_models(self, model_indices: List[int]) -> PandasPdb:

if "ATOM" in biopandas_structure.df.keys():
biopandas_structure.df["ATOM"] = biopandas_structure.df["ATOM"].loc[
[x in model_indices for x in biopandas_structure.df["ATOM"]["model_id"].tolist()]
[
x in model_indices
for x in biopandas_structure.df["ATOM"]["model_id"].tolist()
]
]
if "HETATM" in biopandas_structure.df.keys():
biopandas_structure.df["HETATM"] = biopandas_structure.df["HETATM"].loc[
[x in model_indices for x in biopandas_structure.df["HETATM"]["model_id"].tolist()]
[
x in model_indices
for x in biopandas_structure.df["HETATM"]["model_id"].tolist()
]
]
if "ANISOU" in biopandas_structure.df.keys():
biopandas_structure.df["ANISOU"] = biopandas_structure.df["ANISOU"].loc[
[x in model_indices for x in biopandas_structure.df["ANISOU"]["model_id"].tolist()]
[
x in model_indices
for x in biopandas_structure.df["ANISOU"]["model_id"].tolist()
]
]
return biopandas_structure

Expand All @@ -906,7 +941,7 @@ def to_pdb_stream(self, records: tuple[str] = ("ATOM", "HETATM")) -> StringIO:
df = pd.concat([df[a] for a in records])
if "model_id" in df.columns:
df = df.drop(columns=["model_id"])
df.residue_number = df.residue_number.astype(int)
df["residue_number"] = pd.to_numeric(df.residue_number, errors="coerce")
records = [r.strip() for r in list(set(df.record_name))]
dfs = {r: df.loc[df.record_name == r] for r in records}

Expand All @@ -921,8 +956,7 @@ def to_pdb_stream(self, records: tuple[str] = ("ATOM", "HETATM")) -> StringIO:
if c in {"x_coord", "y_coord", "z_coord"}:
for idx in range(dfs[r][c].values.shape[0]):
if len(dfs[r][c].values[idx]) > 8:
dfs[r][c].values[idx] = str(
dfs[r][c].values[idx]).strip()
dfs[r][c].values[idx] = str(dfs[r][c].values[idx]).strip()

if c not in {"line_idx", "OUT"}:
dfs[r]["OUT"] = dfs[r]["OUT"] + dfs[r][c]
Expand All @@ -941,7 +975,7 @@ def to_pdb_stream(self, records: tuple[str] = ("ATOM", "HETATM")) -> StringIO:
output.seek(0)
return output

def gyradius(self, records: tuple[str] = ("ATOM",), decimals: int = 4) -> float:
def gyradius(self, records: tuple[str] = ("ATOM",), decimals: int = 4) -> float:
"""Compute the Radius of Gyration of a molecule
Parameters
Expand All @@ -958,7 +992,7 @@ def gyradius(self, records: tuple[str] = ("ATOM",), decimals: int = 4) -> float
rg : float
Radius of Gyration of df in Angstrom
"""
"""
if isinstance(records, str):
warnings.warn(
"Using a string as `records` argument is "
Expand All @@ -970,16 +1004,19 @@ def gyradius(self, records: tuple[str] = ("ATOM",), decimals: int = 4) -> float
records = (records,)

if len(records) > 1:
df = pd.concat(objs=[self.df[record][["x_coord",
"y_coord",
"z_coord",
"element_symbol"]]
for record in records])
df = pd.concat(
objs=[
self.df[record][["x_coord", "y_coord", "z_coord", "element_symbol"]]
for record in records
]
)
else:
df = self.df[records[0]]

coords = df[["x_coord", "y_coord", "z_coord"]].to_numpy()
masses = df["element_symbol"].map(lambda atom: ATOMIC_MASSES.get(atom, 0)).to_numpy()
masses = (
df["element_symbol"].map(lambda atom: ATOMIC_MASSES.get(atom, 0)).to_numpy()
)
total_mass = masses.sum()
center_of_mass = (masses[:, None] * coords).sum(axis=0) / total_mass
distances = np.linalg.norm(coords - center_of_mass, axis=1)
Expand Down
25 changes: 19 additions & 6 deletions tests/pdb/test_write_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import importlib.resources as pkg_resources
else:
import importlib_resources as pkg_resources

import os
import warnings

Expand Down Expand Up @@ -51,9 +52,7 @@ def test_defaults():
def test_nonexpected_column():
ppdb = PandasPdb()
ppdb.read_pdb(TESTDATA_FILENAME)
ppdb.df["HETATM"]["test"] = pd.Series(
"test", index=ppdb.df["HETATM"].index
)
ppdb.df["HETATM"]["test"] = pd.Series("test", index=ppdb.df["HETATM"].index)
with warnings.catch_warnings(record=True) as w:
ppdb.to_pdb(path=OUTFILE, records=["HETATM"])
with open(OUTFILE, "r") as f:
Expand Down Expand Up @@ -159,6 +158,20 @@ def test_b_factor_shift():
assert tmp_df[
tmp_df["element_symbol"].isnull() | (tmp_df["element_symbol"] == "")
].empty
assert not tmp_df[
tmp_df["blank_4"].isnull() | (tmp_df["blank_4"] == "")
].empty
assert not tmp_df[tmp_df["blank_4"].isnull() | (tmp_df["blank_4"] == "")].empty


def test_to_pdb_stream():
"""Test public write_pdb_stream"""
ppdb = PandasPdb()
ppdb.read_pdb(TESTDATA_FILENAME)
stream = ppdb.to_pdb_stream()

lines_to_check = open(TESTDATA_FILENAME).read().split("\n")
lines_to_check = [
line for line in lines_to_check if line.startswith(("ATOM", "HETATM"))
]
lines_to_check.append("")

source_pdb = "\n".join(lines_to_check)
assert stream.read() == source_pdb

0 comments on commit 5473c6b

Please sign in to comment.