From 5473c6b6c9061d280898e8326e776211ac196770 Mon Sep 17 00:00:00 2001 From: "Arian R. Jamasb" Date: Thu, 1 Aug 2024 18:16:44 +0200 Subject: [PATCH] fix: improve robustness of and add a test #141 --- biopandas/pdb/pandas_pdb.py | 163 ++++++++++++++++++++++-------------- tests/pdb/test_write_pdb.py | 25 ++++-- 2 files changed, 119 insertions(+), 69 deletions(-) diff --git a/biopandas/pdb/pandas_pdb.py b/biopandas/pdb/pandas_pdb.py index 6eb53bf..94c2ca0 100644 --- a/biopandas/pdb/pandas_pdb.py +++ b/biopandas/pdb/pandas_pdb.py @@ -9,8 +9,8 @@ import gzip import sys -import warnings import textwrap +import warnings from copy import deepcopy from io import StringIO from typing import List, Optional @@ -22,9 +22,10 @@ import pandas as pd from looseversion import LooseVersion -from .engines import amino3to1dict, pdb_df_columns, pdb_records from biopandas.constants import ATOMIC_MASSES +from .engines import amino3to1dict, pdb_df_columns, pdb_records + pd_version = LooseVersion(pd.__version__) @@ -115,29 +116,34 @@ def read_pdb_from_list(self, pdb_lines): self.header, self.code = self._parse_header_code() return self - def fetch_pdb(self, pdb_code: Optional[str] = None, uniprot_id: Optional[str] = None, source: str = "pdb"): + def fetch_pdb( + self, + pdb_code: Optional[str] = None, + uniprot_id: Optional[str] = None, + source: str = "pdb", + ): """Fetches PDB file contents from the Protein Databank at rcsb.org or AlphaFold database - at https://alphafold.ebi.ac.uk/. -. + at https://alphafold.ebi.ac.uk/. + . - Parameters - ---------- - pdb_code : str, optional - A 4-letter PDB code, e.g., `"3eiy"` to retrieve structures from the PDB. - Defaults to `None`. + Parameters + ---------- + pdb_code : str, optional + A 4-letter PDB code, e.g., `"3eiy"` to retrieve structures from the PDB. + Defaults to `None`. - uniprot_id : str, optional - A UniProt Identifier, e.g., `"Q5VSL9"` to retrieve structures from the AF2 database. - Defaults to `None`. + uniprot_id : str, optional + A UniProt Identifier, e.g., `"Q5VSL9"` to retrieve structures from the AF2 database. + Defaults to `None`. - source : str - The source to retrieve the structure from - (`"pdb"`, `"alphafold2-v3"`, `"alphafold2-v4"`(latest)). - Defaults to `"pdb"`. + source : str + The source to retrieve the structure from + (`"pdb"`, `"alphafold2-v3"`, `"alphafold2-v4"`(latest)). + Defaults to `"pdb"`. - Returns - --------- - self + Returns + --------- + self """ # Sanitize input @@ -145,15 +151,21 @@ def fetch_pdb(self, pdb_code: Optional[str] = None, uniprot_id: Optional[str] = invalid_input_identifier_2 = pdb_code is not None and uniprot_id is not None invalid_input_combination_1 = uniprot_id is not None and source == "pdb" invalid_input_combination_2 = pdb_code is not None and source in { - "alphafold2-v3", "alphafold2-v4"} + "alphafold2-v3", + "alphafold2-v4", + } if invalid_input_identifier_1 or invalid_input_identifier_2: raise ValueError("Please provide either a PDB code or a UniProt ID.") if invalid_input_combination_1: - raise ValueError("Please use a 'pdb_code' instead of 'uniprot_id' for source='pdb'.") + raise ValueError( + "Please use a 'pdb_code' instead of 'uniprot_id' for source='pdb'." + ) elif invalid_input_combination_2: - raise ValueError(f"Please use a 'uniprot_id' instead of 'pdb_code' for source={source}.") + raise ValueError( + f"Please use a 'uniprot_id' instead of 'pdb_code' for source={source}." + ) if source == "alphafold2-v3": af2_version = 3 @@ -164,8 +176,10 @@ def fetch_pdb(self, pdb_code: Optional[str] = None, uniprot_id: Optional[str] = elif source == "pdb": self.pdb_path, self.pdb_text = self._fetch_pdb(pdb_code) else: - raise ValueError(f"Invalid source: {source}." - " Please use one of 'pdb' or 'alphafold2-v3' or 'alphafold2-v4'.") + raise ValueError( + f"Invalid source: {source}." + " Please use one of 'pdb' or 'alphafold2-v3' or 'alphafold2-v4'." + ) self._df = self._construct_df(pdb_lines=self.pdb_text.splitlines(True)) return self @@ -248,7 +262,7 @@ def impute_element(self, records=("ATOM", "HETATM"), inplace=False): ) return t - def add_remark(self, code, text='', indent=0): + def add_remark(self, code, text="", indent=0): """Add custom REMARK entry. The remark will be inserted to preserve the ordering of REMARK codes, i.e. if the code is @@ -275,57 +289,65 @@ def add_remark(self, code, text='', indent=0): """ # Prepare info from self - if 'OTHERS' in self.df: - df_others = self.df['OTHERS'] + if "OTHERS" in self.df: + df_others = self.df["OTHERS"] else: - df_others = pd.DataFrame(columns=['record_name', 'entry', 'line_idx']) - record_types = list(filter(lambda x: x in self.df, ['ATOM', 'HETATM', 'ANISOU'])) - remarks = df_others[df_others['record_name'] == 'REMARK']['entry'] + df_others = pd.DataFrame(columns=["record_name", "entry", "line_idx"]) + record_types = list( + filter(lambda x: x in self.df, ["ATOM", "HETATM", "ANISOU"]) + ) + remarks = df_others[df_others["record_name"] == "REMARK"]["entry"] # Find index and line_idx where to insert the remark to preserve remark code order if len(remarks): remark_codes = remarks.apply(lambda x: x.split(maxsplit=1)[0]).astype(int) - insertion_pos = remark_codes.searchsorted(code, side='right') + insertion_pos = remark_codes.searchsorted(code, side="right") if insertion_pos < len(remark_codes): # Remark in the middle insertion_idx = remark_codes.index[insertion_pos] - insertion_line_idx = df_others.loc[insertion_idx]['line_idx'] + insertion_line_idx = df_others.loc[insertion_idx]["line_idx"] else: # Last remark insertion_idx = len(remark_codes) - insertion_line_idx = df_others['line_idx'].iloc[-1] + 1 + insertion_line_idx = df_others["line_idx"].iloc[-1] + 1 else: # First remark insertion_idx = 0 - insertion_line_idx = min([self.df[r]['line_idx'].min() for r in record_types]) + insertion_line_idx = min( + [self.df[r]["line_idx"].min() for r in record_types] + ) # Wrap remark to fit into 80 characters per line and add indentation wrapper = textwrap.TextWrapper(width=80 - (11 + indent)) - lines = sum([wrapper.wrap(line.strip()) or [' '] for line in text.split('\n')], []) - lines = list(map(lambda x: f'{code:4} ' + indent*' ' + x, lines)) + lines = sum( + [wrapper.wrap(line.strip()) or [" "] for line in text.split("\n")], [] + ) + lines = list(map(lambda x: f"{code:4} " + indent * " " + x, lines)) # Shift data frame indices and row indices to create space for the remark # Create space in OTHERS - line_idx = df_others['line_idx'].copy() + line_idx = df_others["line_idx"].copy() line_idx[line_idx >= insertion_line_idx] += len(lines) - df_others['line_idx'] = line_idx + df_others["line_idx"] = line_idx index = pd.Series(df_others.index.copy()) index[index >= insertion_idx] += len(lines) df_others.index = index # Shift all other record types that follow inserted remark for records in record_types: df_records = self.df[records] - if not insertion_line_idx > df_records['line_idx'].max(): - df_records['line_idx'] += len(lines) + if not insertion_line_idx > df_records["line_idx"].max(): + df_records["line_idx"] += len(lines) # Put remark into 'OTHERS' data frame df_remark = { - idx: ['REMARK', line, line_idx] + idx: ["REMARK", line, line_idx] for idx, line, line_idx in zip( range(insertion_idx, insertion_idx + len(lines)), lines, range(insertion_line_idx, insertion_line_idx + len(lines)), ) } - df_remark = pd.DataFrame.from_dict(df_remark, orient='index', columns=df_others.columns) - self.df['OTHERS'] = pd.concat([df_others, df_remark]).sort_index() + df_remark = pd.DataFrame.from_dict( + df_remark, orient="index", columns=df_others.columns + ) + self.df["OTHERS"] = pd.concat([df_others, df_remark]).sort_index() @staticmethod def rmsd(df1, df2, s=None, invert=False, decimals=4): @@ -435,11 +457,13 @@ def _fetch_af2(uniprot_id: str, af2_version: int = 3): try: response = urlopen(url) txt = response.read() - txt = txt.decode('utf-8') if sys.version_info[0] >= 3 else txt.encode('ascii') + txt = ( + txt.decode("utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii") + ) except HTTPError as e: - print(f'HTTP Error {e.code}') + print(f"HTTP Error {e.code}") except URLError as e: - print(f'URL Error {e.args}') + print(f"URL Error {e.args}") return url, txt def _parse_header_code(self): @@ -518,7 +542,7 @@ def _construct_df(pdb_lines): record = line[:6].rstrip() line_ele = ["" for _ in range(len(pdb_records[record]) + 1)] for idx, ele in enumerate(pdb_records[record]): - line_ele[idx] = line[ele["line"][0]: ele["line"][1]].strip() + line_ele[idx] = line[ele["line"][0] : ele["line"][1]].strip() line_ele[-1] = line_num line_lists[record].append(line_ele) else: @@ -847,7 +871,9 @@ def get_model(self, model_index: int) -> PandasPdb: biopandas_structure.label_models() if "ATOM" in biopandas_structure.df.keys(): - biopandas_structure.df["ATOM"] = biopandas_structure.df["ATOM"].loc[biopandas_structure.df["ATOM"]["model_id"] == model_index] + biopandas_structure.df["ATOM"] = biopandas_structure.df["ATOM"].loc[ + biopandas_structure.df["ATOM"]["model_id"] == model_index + ] if "HETATM" in biopandas_structure.df.keys(): biopandas_structure.df["HETATM"] = biopandas_structure.df["HETATM"].loc[ biopandas_structure.df["HETATM"]["model_id"] == model_index @@ -877,15 +903,24 @@ def get_models(self, model_indices: List[int]) -> PandasPdb: if "ATOM" in biopandas_structure.df.keys(): biopandas_structure.df["ATOM"] = biopandas_structure.df["ATOM"].loc[ - [x in model_indices for x in biopandas_structure.df["ATOM"]["model_id"].tolist()] + [ + x in model_indices + for x in biopandas_structure.df["ATOM"]["model_id"].tolist() + ] ] if "HETATM" in biopandas_structure.df.keys(): biopandas_structure.df["HETATM"] = biopandas_structure.df["HETATM"].loc[ - [x in model_indices for x in biopandas_structure.df["HETATM"]["model_id"].tolist()] + [ + x in model_indices + for x in biopandas_structure.df["HETATM"]["model_id"].tolist() + ] ] if "ANISOU" in biopandas_structure.df.keys(): biopandas_structure.df["ANISOU"] = biopandas_structure.df["ANISOU"].loc[ - [x in model_indices for x in biopandas_structure.df["ANISOU"]["model_id"].tolist()] + [ + x in model_indices + for x in biopandas_structure.df["ANISOU"]["model_id"].tolist() + ] ] return biopandas_structure @@ -906,7 +941,7 @@ def to_pdb_stream(self, records: tuple[str] = ("ATOM", "HETATM")) -> StringIO: df = pd.concat([df[a] for a in records]) if "model_id" in df.columns: df = df.drop(columns=["model_id"]) - df.residue_number = df.residue_number.astype(int) + df["residue_number"] = pd.to_numeric(df.residue_number, errors="coerce") records = [r.strip() for r in list(set(df.record_name))] dfs = {r: df.loc[df.record_name == r] for r in records} @@ -921,8 +956,7 @@ def to_pdb_stream(self, records: tuple[str] = ("ATOM", "HETATM")) -> StringIO: if c in {"x_coord", "y_coord", "z_coord"}: for idx in range(dfs[r][c].values.shape[0]): if len(dfs[r][c].values[idx]) > 8: - dfs[r][c].values[idx] = str( - dfs[r][c].values[idx]).strip() + dfs[r][c].values[idx] = str(dfs[r][c].values[idx]).strip() if c not in {"line_idx", "OUT"}: dfs[r]["OUT"] = dfs[r]["OUT"] + dfs[r][c] @@ -941,7 +975,7 @@ def to_pdb_stream(self, records: tuple[str] = ("ATOM", "HETATM")) -> StringIO: output.seek(0) return output - def gyradius(self, records: tuple[str] = ("ATOM",), decimals: int = 4) -> float: + def gyradius(self, records: tuple[str] = ("ATOM",), decimals: int = 4) -> float: """Compute the Radius of Gyration of a molecule Parameters @@ -958,7 +992,7 @@ def gyradius(self, records: tuple[str] = ("ATOM",), decimals: int = 4) -> float rg : float Radius of Gyration of df in Angstrom - """ + """ if isinstance(records, str): warnings.warn( "Using a string as `records` argument is " @@ -970,16 +1004,19 @@ def gyradius(self, records: tuple[str] = ("ATOM",), decimals: int = 4) -> float records = (records,) if len(records) > 1: - df = pd.concat(objs=[self.df[record][["x_coord", - "y_coord", - "z_coord", - "element_symbol"]] - for record in records]) + df = pd.concat( + objs=[ + self.df[record][["x_coord", "y_coord", "z_coord", "element_symbol"]] + for record in records + ] + ) else: df = self.df[records[0]] coords = df[["x_coord", "y_coord", "z_coord"]].to_numpy() - masses = df["element_symbol"].map(lambda atom: ATOMIC_MASSES.get(atom, 0)).to_numpy() + masses = ( + df["element_symbol"].map(lambda atom: ATOMIC_MASSES.get(atom, 0)).to_numpy() + ) total_mass = masses.sum() center_of_mass = (masses[:, None] * coords).sum(axis=0) / total_mass distances = np.linalg.norm(coords - center_of_mass, axis=1) diff --git a/tests/pdb/test_write_pdb.py b/tests/pdb/test_write_pdb.py index f16e5ad..7fcc39b 100644 --- a/tests/pdb/test_write_pdb.py +++ b/tests/pdb/test_write_pdb.py @@ -10,6 +10,7 @@ import importlib.resources as pkg_resources else: import importlib_resources as pkg_resources + import os import warnings @@ -51,9 +52,7 @@ def test_defaults(): def test_nonexpected_column(): ppdb = PandasPdb() ppdb.read_pdb(TESTDATA_FILENAME) - ppdb.df["HETATM"]["test"] = pd.Series( - "test", index=ppdb.df["HETATM"].index - ) + ppdb.df["HETATM"]["test"] = pd.Series("test", index=ppdb.df["HETATM"].index) with warnings.catch_warnings(record=True) as w: ppdb.to_pdb(path=OUTFILE, records=["HETATM"]) with open(OUTFILE, "r") as f: @@ -159,6 +158,20 @@ def test_b_factor_shift(): assert tmp_df[ tmp_df["element_symbol"].isnull() | (tmp_df["element_symbol"] == "") ].empty - assert not tmp_df[ - tmp_df["blank_4"].isnull() | (tmp_df["blank_4"] == "") - ].empty + assert not tmp_df[tmp_df["blank_4"].isnull() | (tmp_df["blank_4"] == "")].empty + + +def test_to_pdb_stream(): + """Test public write_pdb_stream""" + ppdb = PandasPdb() + ppdb.read_pdb(TESTDATA_FILENAME) + stream = ppdb.to_pdb_stream() + + lines_to_check = open(TESTDATA_FILENAME).read().split("\n") + lines_to_check = [ + line for line in lines_to_check if line.startswith(("ATOM", "HETATM")) + ] + lines_to_check.append("") + + source_pdb = "\n".join(lines_to_check) + assert stream.read() == source_pdb