diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c626d57..dec68e5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,13 +39,13 @@ jobs: include: - label: earliest os: ubuntu-latest - python-version: 3.8 - rdkit-version: "rdkit=2021.03.1" + python-version: "3.9" + rdkit-version: "rdkit==2022.09.1" coverage: false - label: baseline os: ubuntu-latest python-version: "3.10" - rdkit-version: "rdkit~=2022.09" + rdkit-version: "rdkit~=2023.03.1" coverage: true - label: latest os: ubuntu-latest @@ -58,14 +58,10 @@ jobs: uses: actions/checkout@v4 - name: Setup Conda - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: python-version: ${{ matrix.python-version }} auto-update-conda: true - channel-priority: flexible - channels: conda-forge, defaults - add-pip-as-python-dependency: true - architecture: x64 use-mamba: true miniforge-variant: Mambaforge @@ -79,13 +75,13 @@ jobs: - name: Install conda dependencies run: | - mamba install ${{ matrix.rdkit-version }} + mamba install uv ${{ matrix.rdkit-version }} mamba list - name: Install package through pip run: | - pip install .[dev] - pip list + uv pip install .[dev] + uv pip list - name: Run tests run: | diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting_linting.yml similarity index 67% rename from .github/workflows/formatting.yml rename to .github/workflows/formatting_linting.yml index 58e4e3d..af3d722 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting_linting.yml @@ -1,4 +1,4 @@ -name: "Black" +name: "Ruff/Black" on: push: @@ -19,6 +19,12 @@ jobs: uses: actions/checkout@v4 - name: Code formatting + uses: chartboost/ruff-action@v1 + with: + args: "check --preview" + + - name: Notebook formatting uses: psf/black@stable with: + src: "docs/notebooks/" jupyter: true \ No newline at end of file diff --git a/.readthedocs.yml b/.readthedocs.yml index b6c81c8..90822b9 100644 --- a/.readthedocs.yml 
+++ b/.readthedocs.yml @@ -1,13 +1,15 @@ version: 2 build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: mambaforge-4.10 + python: mambaforge-22.9 sphinx: configuration: docs/conf.py python: install: - - method: pip - path: . + - method: pip + path: . + extra_requirements: + - tutorials conda: environment: environment.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index c750e9c..da69aa9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- `Complex3D` and `fp.plot_3d` now have access to `only_interacting` and + `remove_hydrogens` parameters to control which residues and hydrogen atoms are + displayed. Non-polar hydrogen atoms that aren't involved in interactions are now + hidden by default. +- `LigNetwork.save_png` to save the displayed plot to a PNG file through JavaScript + (Issue #163). +- `Complex3D.save_png` to save the displayed plot to a PNG file through JavaScript. +- `fp.plot_3d` and `fp.plot_lignetwork` now return the underlying `LigNetwork` or + `Complex3D` object which has been enhanced with rich display functionality. From the + user's perspective, nothing changes apart from being able to do + `view = fp.plot_*(...); view.save_png()` to display a popup window for saving the + image. +- `cleanup_substructures` parameter now accessible in `mol2_supplier` to skip + sanitization based on atom types. +- `sanitize` parameter now available for the `mol2_supplier` and `sdf_supplier` classes. +- `ruff` linter and formatter. + +### Fixed + +- `display_residues` was sanitizing each residue while preparing them for display, which + could make debugging faulty molecules difficult. This is now disabled. +- Deprecation warnings + ## [2.0.3] - 2024-03-10 ### Fixed diff --git a/docs/conf.py b/docs/conf.py index bc7ad12..b62f6c2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,6 +10,7 @@ # add these directories to sys.path here. 
If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +# ruff: noqa: PTH100 import os import sys @@ -24,7 +25,7 @@ # -- Project information ----------------------------------------------------- project = "ProLIF" -copyright = f"2017-{datetime.now().year}, Cédric Bouysset" +copyright = f"2017-{datetime.now().year}, Cédric Bouysset" # noqa: A001 author = "Cédric Bouysset" @@ -99,7 +100,7 @@ "url": "https://github.com/chemosim-lab/ProLIF", "icon": "fa-brands fa-square-github", "type": "fontawesome", - } + }, ], } @@ -123,7 +124,7 @@ def setup(app): app.add_config_value( "recommonmark_config", { - #'url_resolver': lambda url: github_doc_root + url, + # 'url_resolver': lambda url: github_doc_root + url, "auto_toc_tree_section": "Contents", "enable_math": False, "enable_inline_math": False, diff --git a/docs/notebooks/docking.ipynb b/docs/notebooks/docking.ipynb index 0881fa1..dde8721 100644 --- a/docs/notebooks/docking.ipynb +++ b/docs/notebooks/docking.ipynb @@ -699,7 +699,8 @@ "metadata": {}, "outputs": [], "source": [ - "fp.plot_lignetwork(pose_iterable[0])" + "view = fp.plot_lignetwork(pose_iterable[0])\n", + "view" ] }, { @@ -713,7 +714,11 @@ "- hover an interaction line to display the distance.\n", "\n", ":::{note}\n", - "It is not possible to export it as an image, but you can always take a screenshot.\n", + "After arranging the residues to your liking, you can save the plot as a PNG image with:\n", + "```python\n", + "view.save_png()\n", + "```\n", + "Note that this only works in notebooks and cannot be used in regular Python scripts.\n", ":::\n", "\n", "You can generate 2 types of diagram with this function, controlled by the `kind` argument:\n", @@ -735,7 +740,8 @@ "outputs": [], "source": [ "pose_index = 0\n", - "fp.plot_lignetwork(pose_iterable[pose_index], kind=\"frame\", frame=pose_index)" + "view = fp.plot_lignetwork(pose_iterable[pose_index], kind=\"frame\", frame=pose_index)\n", + "view" ] 
}, { @@ -753,9 +759,10 @@ "source": [ "fp_count = plf.Fingerprint(count=True)\n", "fp_count.run_from_iterable(pose_iterable, protein_mol)\n", - "fp_count.plot_lignetwork(\n", + "view = fp_count.plot_lignetwork(\n", " pose_iterable[pose_index], kind=\"frame\", frame=pose_index, display_all=True\n", - ")" + ")\n", + "view" ] }, { @@ -785,22 +792,11 @@ "\n", "The advantage of using a count fingerprint in that case is that it will automatically select the interaction occurence with the shortest distance for a more intuitive visualization.\n", "\n", - "Once you're satisfied with the orientation, you can export the view as a PNG image with the following snippet:\n", + "Once you're satisfied with the orientation, you can export the view as a PNG image with:\n", "\n", "```python\n", - "from IPython.display import Javascript\n", - "\n", - "Javascript(\n", - " \"\"\"\n", - " var png = viewer_%s.pngURI()\n", - " var a = document.createElement('a')\n", - " a.href = png\n", - " a.download = \"prolif-3d.png\"\n", - " a.click()\n", - " a.remove()\n", - " \"\"\"\n", - " % view.uniqueid\n", - ")\n", - "```" + "view.save_png()\n", + "```\n", + "Note that this only works in notebooks and cannot be used in regular Python scripts." ] }, diff --git a/docs/notebooks/md-ligand-protein.ipynb b/docs/notebooks/md-ligand-protein.ipynb index b995a6c..abb75e2 100644 --- a/docs/notebooks/md-ligand-protein.ipynb +++ b/docs/notebooks/md-ligand-protein.ipynb @@ -550,7 +550,8 @@ "metadata": {}, "outputs": [], "source": [ - "fp.plot_lignetwork(ligand_mol)" + "view = fp.plot_lignetwork(ligand_mol)\n", + "view" ] }, { @@ -564,7 +565,11 @@ "- hover an interaction line to display the distance.\n", "\n", ":::{note}\n", - "It is not possible to export it as an image, but you can always take a screenshot.\n", + "After arranging the residues to your liking, you can save the plot as a PNG image with:\n", + "```python\n", + "view.save_png()\n", + "```\n", + "Note that this only works in notebooks and cannot 
be used in regular Python scripts.\n", ":::\n", "\n", "You can generate 2 types of diagram with this function, controlled by the `kind` argument:\n", @@ -585,7 +590,8 @@ "metadata": {}, "outputs": [], "source": [ - "fp.plot_lignetwork(ligand_mol, threshold=0.0)" + "view = fp.plot_lignetwork(ligand_mol, threshold=0.0)\n", + "view" ] }, { @@ -603,7 +609,8 @@ "source": [ "fp_count = plf.Fingerprint(count=True)\n", "fp_count.run(u.trajectory[0:1], ligand_selection, protein_selection)\n", - "fp_count.plot_lignetwork(ligand_mol, kind=\"frame\", frame=0, display_all=True)" + "view = fp_count.plot_lignetwork(ligand_mol, kind=\"frame\", frame=0, display_all=True)\n", + "view" ] }, { @@ -637,23 +644,12 @@ "\n", "The advantage of using a count fingerprint in that case is that it will automatically select the interaction occurence with the shortest distance for a more intuitive visualization.\n", "\n", - "Once you're satisfied with the orientation, you can export the view as a PNG image with the following snippet:\n", + "Once you're satisfied with the orientation, you can export the view as a PNG image with:\n", "\n", "```python\n", - "from IPython.display import Javascript\n", - "\n", - "Javascript(\n", - " \"\"\"\n", - " var png = viewer_%s.pngURI()\n", - " var a = document.createElement('a')\n", - " a.href = png\n", - " a.download = \"prolif-3d.png\"\n", - " a.click()\n", - " a.remove()\n", - " \"\"\"\n", - " % view.uniqueid\n", - ")\n", - "```" + "view.save_png()\n", + "```\n", + "Note that this only works in notebooks and cannot be used in regular Python scripts." 
] }, { @@ -708,7 +704,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs/notebooks/md-protein-protein.ipynb b/docs/notebooks/md-protein-protein.ipynb index f6817d6..1505530 100644 --- a/docs/notebooks/md-protein-protein.ipynb +++ b/docs/notebooks/md-protein-protein.ipynb @@ -578,7 +578,8 @@ "metadata": {}, "outputs": [], "source": [ - "fp.plot_lignetwork(small_protein_mol)" + "view = fp.plot_lignetwork(small_protein_mol)\n", + "view" ] }, { @@ -592,7 +593,11 @@ "- hover an interaction line to display the distance.\n", "\n", ":::{note}\n", - "It is not possible to export it as an image, but you can always take a screenshot.\n", + "After arranging the residues to your liking, you can save the plot as a PNG image with:\n", + "```python\n", + "view.save_png()\n", + "```\n", + "Note that this only works in notebooks and cannot be used in regular Python scripts.\n", ":::\n", "\n", "You can generate 2 types of diagram with this function, controlled by the `kind` argument:\n", @@ -647,23 +652,12 @@ "source": [ "As in the lignetwork plot, you can hover atoms and interactions to display more information.\n", "\n", - "Once you're satisfied with the orientation, you can export the view as a PNG image with the following snippet:\n", + "Once you're satisfied with the orientation, you can export the view as a PNG image with:\n", "\n", "```python\n", - "from IPython.display import Javascript\n", - "\n", - "Javascript(\n", - " \"\"\"\n", - " var png = viewer_%s.pngURI()\n", - " var a = document.createElement('a')\n", - " a.href = png\n", - " a.download = \"prolif-3d.png\"\n", - " a.click()\n", - " a.remove()\n", - " \"\"\"\n", - " % view.uniqueid\n", - ")\n", - "```" + "view.save_png()\n", + "```\n", + "Note that this only works in notebooks and cannot be used in regular Python scripts." 
] }, { diff --git a/docs/notebooks/pdb.ipynb b/docs/notebooks/pdb.ipynb index 61424e4..54a65e9 100644 --- a/docs/notebooks/pdb.ipynb +++ b/docs/notebooks/pdb.ipynb @@ -403,7 +403,8 @@ "metadata": {}, "outputs": [], "source": [ - "fp.plot_lignetwork(ligand_mol, kind=\"frame\", frame=0)" + "view = fp.plot_lignetwork(ligand_mol, kind=\"frame\", frame=0)\n", + "view" ] }, { @@ -417,7 +418,11 @@ "- hover an interaction line to display the distance.\n", "\n", ":::{note}\n", - "It is not possible to export it as an image, but you can always take a screenshot.\n", + "After arranging the residues to your liking, you can save the plot as a PNG image with:\n", + "```python\n", + "view.save_png()\n", + "```\n", + "Note that this only works in notebooks and cannot be used in regular Python scripts.\n", ":::" ] }, @@ -436,7 +441,8 @@ "source": [ "fp_count = plf.Fingerprint(count=True)\n", "fp_count.run_from_iterable([ligand_mol], protein_mol)\n", - "fp_count.plot_lignetwork(ligand_mol, kind=\"frame\", frame=0, display_all=True)" + "view = fp_count.plot_lignetwork(ligand_mol, kind=\"frame\", frame=0, display_all=True)\n", + "view" ] }, { @@ -464,23 +470,12 @@ "\n", "The advantage of using a count fingerprint in that case is that it will automatically select the interaction occurence with the shortest distance for a more intuitive visualization.\n", "\n", - "Once you're satisfied with the orientation, you can export the view as a PNG image with the following snippet:\n", + "Once you're satisfied with the orientation, you can export the view as a PNG image with:\n", "\n", "```python\n", - "from IPython.display import Javascript\n", - "\n", - "Javascript(\n", - " \"\"\"\n", - " var png = viewer_%s.pngURI()\n", - " var a = document.createElement('a')\n", - " a.href = png\n", - " a.download = \"prolif-3d.png\"\n", - " a.click()\n", - " a.remove()\n", - " \"\"\"\n", - " % view.uniqueid\n", - ")\n", - "```" + "view.save_png()\n", + "```\n", + "Note that this only works in notebooks and 
cannot be used in regular Python scripts." ] } ], diff --git a/environment.yml b/environment.yml index 253f86f..45cdd4a 100644 --- a/environment.yml +++ b/environment.yml @@ -1,7 +1,6 @@ name: prolif channels: - conda-forge - - anaconda - defaults dependencies: - ipykernel @@ -24,6 +23,4 @@ dependencies: - dill - multiprocess - pip: - - pyvis - - py3Dmol - sphinx-book-theme diff --git a/prolif/__init__.py b/prolif/__init__.py index 0e69964..0f30d2e 100644 --- a/prolif/__init__.py +++ b/prolif/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 from prolif import datafiles from prolif._version import __version__ from prolif.fingerprint import Fingerprint diff --git a/prolif/datafiles.py b/prolif/datafiles.py index 81ca192..b519540 100644 --- a/prolif/datafiles.py +++ b/prolif/datafiles.py @@ -1,8 +1,11 @@ -from pathlib import Path +import atexit +from contextlib import ExitStack +from importlib import resources -from pkg_resources import resource_filename - -datapath = Path(resource_filename("prolif", "data/")) +_file_manager = ExitStack() +atexit.register(_file_manager.close) +_data_resource = resources.files("prolif") / "data/" +datapath = _file_manager.enter_context(resources.as_file(_data_resource)) TOP = str(datapath / "top.pdb") TRAJ = str(datapath / "traj.xtc") diff --git a/prolif/fingerprint.py b/prolif/fingerprint.py index 6a043e9..d66ccbf 100644 --- a/prolif/fingerprint.py +++ b/prolif/fingerprint.py @@ -29,7 +29,7 @@ import warnings from collections.abc import Sized from functools import wraps -from typing import Literal, Optional, Tuple +from typing import Literal, Optional, Tuple, Union import dill import multiprocess as mp @@ -77,7 +77,8 @@ class Fingerprint: ---------- interactions : list List of names (str) of interaction classes as found in the - :mod:`prolif.interactions` module. + :mod:`prolif.interactions` module. Defaults to Hydrophobic, HBDonor, HBAcceptor, + PiStacking, Anionic, Cationic, CationPi, PiCation, VdWContact. 
parameters : dict, optional New parameters for the interactions. Mapping between an interaction name and a dict of parameters as they appear in the interaction class. @@ -184,21 +185,24 @@ class for more information. def __init__( self, - interactions=[ - "Hydrophobic", - "HBDonor", - "HBAcceptor", - "PiStacking", - "Anionic", - "Cationic", - "CationPi", - "PiCation", - "VdWContact", - ], + interactions=None, parameters=None, count=False, vicinity_cutoff=6.0, ): + if interactions is None: + interactions = [ + "Hydrophobic", + "HBDonor", + "HBAcceptor", + "PiStacking", + "Anionic", + "Cationic", + "CationPi", + "PiCation", + "VdWContact", + ] + self.interactions = interactions self.count = count self._set_interactions(interactions, parameters) self.vicinity_cutoff = vicinity_cutoff @@ -227,7 +231,7 @@ def _check_valid_interactions(self, interactions_iterable, varname): unknown = unsafe.symmetric_difference(_INTERACTIONS.keys()) & unsafe if unknown: raise NameError( - f"Unknown interaction(s) in {varname!r}: {', '.join(unknown)}" + f"Unknown interaction(s) in {varname!r}: {', '.join(unknown)}", ) def __repr__(self): # pragma: no cover @@ -365,7 +369,9 @@ def generate(self, lig, prot, residues=None, metadata=False): for lresid, lres in lig.residues.items(): if residues is None: prot_residues = get_residues_near_ligand( - lres, prot, self.vicinity_cutoff + lres, + prot, + self.vicinity_cutoff, ) for prot_key in prot_residues: pres = prot[prot_key] @@ -457,7 +463,7 @@ def run( dictionary containing more complete interaction metadata instead of just atom indices. 
- """ + """ # noqa: E501 if n_jobs is not None and n_jobs < 1: raise ValueError("n_jobs must be > 0 or None") if converter_kwargs is not None and len(converter_kwargs) != 2: @@ -485,7 +491,10 @@ def run( lig_mol = Molecule.from_mda(lig, **converter_kwargs[0]) prot_mol = Molecule.from_mda(prot, **converter_kwargs[1]) ifp[int(ts.frame)] = self.generate( - lig_mol, prot_mol, residues=residues, metadata=True + lig_mol, + prot_mol, + residues=residues, + metadata=True, ) self.ifp = ifp return self @@ -501,7 +510,7 @@ def _run_parallel( n_jobs=None, ): """Parallel implementation of :meth:`~Fingerprint.run`""" - n_chunks = n_jobs if n_jobs else mp.cpu_count() + n_chunks = n_jobs or mp.cpu_count() try: n_frames = traj.n_frames except AttributeError: @@ -528,7 +537,13 @@ def _run_parallel( return self def run_from_iterable( - self, lig_iterable, prot_mol, *, residues=None, progress=True, n_jobs=None + self, + lig_iterable, + prot_mol, + *, + residues=None, + progress=True, + n_jobs=None, ): """Generates the fingerprint between a list of ligands and a protein @@ -613,7 +628,12 @@ def run_from_iterable( return self def _run_iter_parallel( - self, lig_iterable, prot_mol, residues=None, progress=True, n_jobs=None + self, + lig_iterable, + prot_mol, + residues=None, + progress=True, + n_jobs=None, ): """Parallel implementation of :meth:`~Fingerprint.run_from_iterable`""" total = ( @@ -637,7 +657,12 @@ def _run_iter_parallel( return self def to_dataframe( - self, *, count=None, dtype=None, drop_empty=True, index_col="Frame" + self, + *, + count=None, + dtype=None, + drop_empty=True, + index_col="Frame", ): """Converts fingerprints to a pandas DataFrame @@ -817,7 +842,7 @@ def from_pickle(cls, path_or_bytes): with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - r"The .+ interaction has been superseded by a new class", # pragma: no cover + r"The .+ interaction has been superseded by a new class", # pragma: no cover # noqa: E501 ) if isinstance(path_or_bytes, bytes): 
return dill.loads(path_or_bytes) @@ -861,11 +886,12 @@ def plot_lignetwork( Frequency threshold, between 0 and 1. Only applicable for ``kind="aggregate"``. use_coordinates : bool - If ``True``, uses the coordinates of the molecule directly, otherwise generates - 2D coordinates from scratch. See also ``flatten_coordinates``. + If ``True``, uses the coordinates of the molecule directly, otherwise + generates 2D coordinates from scratch. See also ``flatten_coordinates``. flatten_coordinates : bool - If this is ``True`` and ``use_coordinates=True``, generates 2D coordinates that - are constrained to fit the 3D conformation of the ligand as best as possible. + If this is ``True`` and ``use_coordinates=True``, generates 2D coordinates + that are constrained to fit the 3D conformation of the ligand as best as + possible. kekulize : bool Kekulize the ligand. molsize : int @@ -978,6 +1004,8 @@ def plot_3d( frame: int, size: Tuple[int, int] = (650, 600), display_all: bool = False, + only_interacting: bool = True, + remove_hydrogens: Union[bool, Literal["ligand", "protein"]] = True, ): """Generate and display the complex in 3D with py3Dmol from a fingerprint object that has been used to run an analysis. @@ -997,16 +1025,36 @@ def plot_3d( Display all occurences for a given pair of residues and interaction, or only the shortest one. Not relevant if ``count=False`` in the ``Fingerprint`` object. + only_interacting : bool = True + Whether to show all protein residues in the vicinity of the ligand, or + only the ones participating in an interaction. + remove_hydrogens : Union[bool, Literal["ligand", "protein"]] = True + Whether to remove non-polar hydrogens (unless they are involved in an + interaction). See Also -------- :class:`prolif.plotting.complex3d.Complex3D` .. versionadded:: 2.0.0 + + .. versionchanged:: 2.1.0 + Added ``only_interacting=True`` and ``remove_hydrogens=True`` parameters. + Non-polar hydrogen atoms that aren't involved in interactions are now + hidden. 
+ """ from prolif.plotting.complex3d import Complex3D plot3d = Complex3D.from_fingerprint( - self, frame=frame, lig_mol=ligand_mol, prot_mol=protein_mol + self, + frame=frame, + lig_mol=ligand_mol, + prot_mol=protein_mol, + ) + return plot3d.display( + size=size, + display_all=display_all, + only_interacting=only_interacting, + remove_hydrogens=remove_hydrogens, ) - return plot3d.display(size=size, display_all=display_all) diff --git a/prolif/ifp.py b/prolif/ifp.py index 4f586ee..8754d13 100644 --- a/prolif/ifp.py +++ b/prolif/ifp.py @@ -60,10 +60,10 @@ def __getitem__(self, key): residue_tuple: interactions for residue_tuple, interactions in self.data.items() if key in residue_tuple - } + }, ) raise KeyError( f"{key} does not correspond to a valid IFP key: it must be a tuple of " "either ResidueId or residue string. If you need to filter the IFP, a " - "single ResidueId or residue string can also be used." + "single ResidueId or residue string can also be used.", ) diff --git a/prolif/interactions/__init__.py b/prolif/interactions/__init__.py index 463d7b9..e3a3c81 100644 --- a/prolif/interactions/__init__.py +++ b/prolif/interactions/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 from prolif.interactions.base import ( BasePiStacking, Distance, @@ -5,4 +6,20 @@ Interaction, SingleAngle, ) -from prolif.interactions.interactions import * +from prolif.interactions.interactions import ( + Anionic, + Cationic, + CationPi, + EdgeToFace, + FaceToFace, + HBAcceptor, + HBDonor, + Hydrophobic, + MetalAcceptor, + MetalDonor, + PiCation, + PiStacking, + VdWContact, + XBAcceptor, + XBDonor, +) diff --git a/prolif/interactions/base.py b/prolif/interactions/base.py index 0e2df25..00d9736 100644 --- a/prolif/interactions/base.py +++ b/prolif/interactions/base.py @@ -36,12 +36,14 @@ def __init_subclass__(cls, is_abstract=False): register = _BASE_INTERACTIONS if is_abstract else _INTERACTIONS if not hasattr(cls, "detect"): raise TypeError( - f"Can't instantiate interaction class 
{name} without a `detect` method." + f"Can't instantiate interaction class {name} without a `detect`" + " method.", ) if name in register: warnings.warn( f"The {name!r} interaction has been superseded by a " - f"new class with id {id(cls):#x}" + f"new class with id {id(cls):#x}", + stacklevel=2, ) register[name] = cls @@ -66,10 +68,10 @@ def metadata(self, lig_res, prot_res, lig_indices, prot_indices, **data): }, "parent_indices": { "ligand": tuple( - [get_mapindex(lig_res, index) for index in lig_indices] + [get_mapindex(lig_res, index) for index in lig_indices], ), "protein": tuple( - [get_mapindex(prot_res, index) for index in prot_indices] + [get_mapindex(prot_res, index) for index in prot_indices], ), }, **data, @@ -101,8 +103,7 @@ def invert_role(cls, name, docstring): """ cls_docstring = cls.__doc__ or "\n" parameters_doc = cls_docstring.split("\n", maxsplit=1)[1] - __doc__ = f"{docstring}\n{parameters_doc}" - inverted = type(name, (cls,), {"__doc__": __doc__}) + inverted = type(name, (cls,), {"__doc__": f"{docstring}\n{parameters_doc}"}) def detect(self, ligand, residue): for metadata in super(inverted, self).detect(residue, ligand): @@ -140,7 +141,11 @@ def detect(self, lig_res, prot_res): dist = alig.Distance(aprot) if dist <= self.distance: yield self.metadata( - lig_res, prot_res, lig_match, prot_match, distance=dist + lig_res, + prot_res, + lig_match, + prot_match, + distance=dist, ) @@ -366,11 +371,15 @@ def detect(self, ligand, residue): n2c2c1 = res_normal.AngleTo(c2c1) ncc_angle = None if angle_between_limits( - n1c1c2, *self.normal_to_centroid_angle, ring=True + n1c1c2, + *self.normal_to_centroid_angle, + ring=True, ): ncc_angle = n1c1c2 elif angle_between_limits( - n2c2c1, *self.normal_to_centroid_angle, ring=True + n2c2c1, + *self.normal_to_centroid_angle, + ring=True, ): ncc_angle = n2c2c1 if ncc_angle is None: @@ -378,7 +387,10 @@ def detect(self, ligand, residue): if self.intersect: # look for point of intersection between both ring planes 
intersect = self._get_intersect_point( - lig_normal, lig_centroid, res_normal, res_centroid + lig_normal, + lig_centroid, + res_normal, + res_centroid, ) if intersect is None: continue @@ -420,7 +432,7 @@ def _get_intersect_point( intersect_direction = plane_normal.CrossProduct(tilted_normal) # setup system of linear equations to solve A = np.array( - [list(plane_normal), list(tilted_normal), list(intersect_direction)] + [list(plane_normal), list(tilted_normal), list(intersect_direction)], ) if np.linalg.det(A) == 0: return None diff --git a/prolif/interactions/interactions.py b/prolif/interactions/interactions.py index a88dabd..9404274 100644 --- a/prolif/interactions/interactions.py +++ b/prolif/interactions/interactions.py @@ -27,21 +27,21 @@ from prolif.utils import angle_between_limits, get_centroid, get_ring_normal_vector __all__ = [ - "Hydrophobic", - "HBAcceptor", - "HBDonor", - "XBAcceptor", - "XBDonor", - "Cationic", "Anionic", "CationPi", - "PiCation", - "FaceToFace", + "Cationic", "EdgeToFace", - "PiStacking", - "MetalDonor", + "FaceToFace", + "HBAcceptor", + "HBDonor", + "Hydrophobic", "MetalAcceptor", + "MetalDonor", + "PiCation", + "PiStacking", "VdWContact", + "XBAcceptor", + "XBDonor", ] VDWRADII = {symbol.capitalize(): radius for symbol, radius in vdwradii.items()} @@ -68,7 +68,9 @@ def __init__( distance=4.5, ): super().__init__( - lig_pattern=hydrophobic, prot_pattern=hydrophobic, distance=distance + lig_pattern=hydrophobic, + prot_pattern=hydrophobic, + distance=distance, ) @@ -113,7 +115,8 @@ def __init__( HBDonor = HBAcceptor.invert_role( - "HBDonor", "Hbond interaction between a ligand (donor) and a residue (acceptor)" + "HBDonor", + "Hbond interaction between a ligand (donor) and a residue (acceptor)", ) @@ -171,7 +174,8 @@ def __init__( XBDonor = XBAcceptor.invert_role( - "XBDonor", "Halogen bonding between a ligand (donor) and a residue (acceptor)" + "XBDonor", + "Halogen bonding between a ligand (donor) and a residue (acceptor)", ) @@ 
-193,7 +197,8 @@ def __init__( Anionic = Cationic.invert_role( - "Anionic", "Ionic interaction between a ligand (anion) and a residue (cation)" + "Anionic", + "Ionic interaction between a ligand (anion) and a residue (cation)", ) @@ -380,8 +385,9 @@ class PiStacking(Interaction): `shortest_distance` has been replaced by `angle_normal_centroid` .. versionchanged:: 1.1.0 - The implementation now directly calls :class:`EdgeToFace` and :class:`FaceToFace` - instead of overwriting the default parameters with more generic ones. + The implementation now directly calls :class:`EdgeToFace` and + :class:`FaceToFace` instead of overwriting the default parameters with more + generic ones. """ @@ -472,9 +478,13 @@ def detect(self, ligand, residue): vdw = self.vdwradii[lig] + self.vdwradii[res] + self.tolerance self._vdw_cache[elements] = vdw dist = lxyz.GetAtomPosition(la.GetIdx()).Distance( - rxyz.GetAtomPosition(ra.GetIdx()) + rxyz.GetAtomPosition(ra.GetIdx()), ) if dist <= vdw: yield self.metadata( - ligand, residue, (la.GetIdx(),), (ra.GetIdx(),), distance=dist + ligand, + residue, + (la.GetIdx(),), + (ra.GetIdx(),), + distance=dist, ) diff --git a/prolif/interactions/utils.py b/prolif/interactions/utils.py index d19581a..2ca4891 100644 --- a/prolif/interactions/utils.py +++ b/prolif/interactions/utils.py @@ -24,27 +24,27 @@ def get_mapindex(res, index): return res.GetAtomWithIdx(index).GetUnsignedProp("mapindex") -def _distance_3args_l1_p1(l1, p1, p2): +def _distance_3args_l1_p1(l1, p1, p2): # noqa: ARG001 return l1.Distance(p1) -def _distance_3args_l1_p2(l1, p1, p2): +def _distance_3args_l1_p2(l1, p1, p2): # noqa: ARG001 return l1.Distance(p2) -def _distance_4args_l1_p1(l1, l2, p1, p2): +def _distance_4args_l1_p1(l1, l2, p1, p2): # noqa: ARG001 return l1.Distance(p1) -def _distance_4args_l1_p2(l1, l2, p1, p2): +def _distance_4args_l1_p2(l1, l2, p1, p2): # noqa: ARG001 return l1.Distance(p2) -def _distance_4args_l2_p1(l1, l2, p1, p2): +def _distance_4args_l2_p1(l1, l2, p1, 
p2): # noqa: ARG001 return l2.Distance(p1) -def _distance_4args_l2_p2(l1, l2, p1, p2): +def _distance_4args_l2_p2(l1, l2, p1, p2): # noqa: ARG001 return l2.Distance(p2) diff --git a/prolif/molecule.py b/prolif/molecule.py index 3d28d69..1c5fe3d 100644 --- a/prolif/molecule.py +++ b/prolif/molecule.py @@ -94,8 +94,8 @@ def from_mda(cls, obj, selection=None, **kwargs): Apply a selection to `obj` to create an AtomGroup. Uses all atoms in `obj` if ``selection=None`` **kwargs : object - Other arguments passed to the :class:`~MDAnalysis.converters.RDKit.RDKitConverter` - of MDAnalysis + Other arguments passed to the + :class:`~MDAnalysis.converters.RDKit.RDKitConverter` of MDAnalysis Example ------- @@ -160,8 +160,7 @@ def from_rdkit(cls, mol, resname="UNL", resnumber=1, chain=""): return cls(mol) def __iter__(self): - for residue in self.residues.values(): - yield residue + yield from self.residues.values() def __getitem__(self, key): return self.residues[key] @@ -253,12 +252,16 @@ def pdbqt_to_mol(self, pdbqt_path): pdbqt.add_TopologyAttr("chainIDs", pdbqt.atoms.segids) pdbqt.atoms.types = pdbqt.atoms.elements # convert without infering bond orders and charges - with catch_rdkit_logs(), catch_warning( - message=r"^(Could not sanitize molecule)|" - r"(No `bonds` attribute in this AtomGroup)" + with ( + catch_rdkit_logs(), + catch_warning( + message=r"^(Could not sanitize molecule)|" + r"(No `bonds` attribute in this AtomGroup)", + ), ): pdbqt_mol = pdbqt.atoms.convert_to.rdkit( - NoImplicit=False, **self.converter_kwargs + NoImplicit=False, + **self.converter_kwargs, ) mol = self._adjust_hydrogens(self.template, pdbqt_mol) return Molecule.from_rdkit(mol, **self._kwargs) @@ -276,7 +279,8 @@ def _adjust_hydrogens(template, pdbqt_mol): atoms_with_hydrogens[ atom.GetNeighbors()[0].GetIntProp("_MDAnalysis_index") ].append(atom) - # mapping between atom that should be bearing a H in RWMol and corresponding H(s) + # mapping between atom that should be bearing a H in RWMol and 
+ # corresponding hydrogens reverse_mapping = {} for atom in mol.GetAtoms(): if (idx := atom.GetIntProp("_MDAnalysis_index")) in atoms_with_hydrogens: @@ -309,6 +313,8 @@ class sdf_supplier(Sequence): ---------- path : str A path to the .sdf file + sanitize : bool + Whether to sanitize each molecule or not. resname : str Residue name for every ligand resnumber : int @@ -333,11 +339,14 @@ class sdf_supplier(Sequence): Molecule suppliers are now sequences that can be reused, indexed, and can return their length, instead of single-use generators. + .. versionchanged:: 2.1.0 + Added ``sanitize`` parameter (defaults to ``True``, same behavior as before). + """ - def __init__(self, path, **kwargs): + def __init__(self, path, sanitize=True, **kwargs): self.path = path - self._suppl = Chem.SDMolSupplier(path, removeHs=False) + self._suppl = Chem.SDMolSupplier(path, removeHs=False, sanitize=sanitize) self._kwargs = kwargs def __iter__(self): @@ -359,6 +368,11 @@ class mol2_supplier(Sequence): ---------- path : str A path to the .mol2 file + sanitize : bool + Whether to sanitize each molecule or not. + cleanup_substructures : bool + Toggles standardizing some substructures found in mol2 files, based on atom + types. resname : str Residue name for every ligand resnumber : int @@ -383,15 +397,21 @@ class mol2_supplier(Sequence): Molecule suppliers are now sequences that can be reused, indexed, and can return their length, instead of single-use generators. + .. versionchanged:: 2.1.0 + Added ``cleanup_substructures`` and ``sanitize`` parameters + (default to ``True``, same behavior as before). 
+ """ - def __init__(self, path, **kwargs): + def __init__(self, path, cleanup_substructures=True, sanitize=True, **kwargs): self.path = path + self.cleanup_substructures = cleanup_substructures + self.sanitize = sanitize self._kwargs = kwargs def __iter__(self): block = [] - with open(self.path, "r") as f: + with open(self.path) as f: for line in f: if line.startswith("#"): continue @@ -402,7 +422,12 @@ def __iter__(self): yield self.block_to_mol(block) def block_to_mol(self, block): - mol = Chem.MolFromMol2Block("".join(block), removeHs=False) + mol = Chem.MolFromMol2Block( + "".join(block), + removeHs=False, + cleanupSubstructures=self.cleanup_substructures, + sanitize=self.sanitize, + ) return Molecule.from_rdkit(mol, **self._kwargs) def __getitem__(self, index): @@ -411,7 +436,7 @@ def __getitem__(self, index): mol_index = -1 molblock_started = False block = [] - with open(self.path, "r") as f: + with open(self.path) as f: for line in f: if line.startswith("@MOLECULE"): mol_index += 1 @@ -426,6 +451,5 @@ def __getitem__(self, index): raise ValueError(f"Could not parse molecule with index {index}") def __len__(self): - with open(self.path, "r") as f: - n_mols = sum(line.startswith("@MOLECULE") for line in f) - return n_mols + with open(self.path) as f: + return sum(line.startswith("@MOLECULE") for line in f) diff --git a/prolif/parallel.py b/prolif/parallel.py index c35dd51..09ba357 100644 --- a/prolif/parallel.py +++ b/prolif/parallel.py @@ -103,7 +103,7 @@ class TrajectoryPool: """ def __init__( - self, n_processes, fingerprint, residues, tqdm_kwargs, rdkitconverter_kwargs + self, n_processes, fingerprint, residues, tqdm_kwargs, rdkitconverter_kwargs, ): self.tqdm_kwargs = tqdm_kwargs self.tracker = Value(c_uint32, lock=True) @@ -142,7 +142,7 @@ def executor(cls, args): lig_mol = Molecule.from_mda(lig, **cls.converter_kwargs[0]) prot_mol = Molecule.from_mda(prot, **cls.converter_kwargs[1]) data = cls.fp.generate( - lig_mol, prot_mol, residues=cls.residues, 
metadata=True + lig_mol, prot_mol, residues=cls.residues, metadata=True, ) ifp[int(ts.frame)] = data with cls.tracker.get_lock(): diff --git a/prolif/plotting/barcode.py b/prolif/plotting/barcode.py index 046c61e..3b24501 100644 --- a/prolif/plotting/barcode.py +++ b/prolif/plotting/barcode.py @@ -11,6 +11,7 @@ from __future__ import annotations +from contextlib import suppress from typing import TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, Tuple import numpy as np @@ -76,7 +77,7 @@ def _bit_to_color_value(s: pd.Series) -> pd.Series: return s.apply( lambda v: ( self.color_mapper[interaction] if v else self.color_mapper[None] - ) + ), ) self.df = df.astype(np.uint8).T.apply(_bit_to_color_value, axis=1) @@ -87,7 +88,7 @@ def from_fingerprint(cls, fp: Fingerprint) -> Barcode: if not hasattr(fp, "ifp"): raise RunRequiredError( "Please run the fingerprint analysis before attempting to display" - " results." + " results.", ) return cls(fp.to_dataframe()) @@ -182,11 +183,9 @@ def display( # legend values: List[int] = np.unique(self.df.values).tolist() - try: - values.pop(values.index(0)) # remove None color - except ValueError: + with suppress(ValueError): # 0 not in values (e.g. 
plotting a single frame) - pass + values.pop(values.index(0)) # remove None color legend_colors = { self.inv_color_mapper[value]: im.cmap(value) for value in values } @@ -229,7 +228,7 @@ def hover_callback(event): and event.ydata is not None ): x, y = round(event.xdata), round(event.ydata) - if self.df.values[y, x]: + if self.df.to_numpy()[y, x]: annot.xy = (x, y) frame = frames[x] interaction = interactions[y] diff --git a/prolif/plotting/complex3d.py b/prolif/plotting/complex3d.py index 26022e4..b3ee9f4 100644 --- a/prolif/plotting/complex3d.py +++ b/prolif/plotting/complex3d.py @@ -11,8 +11,9 @@ from __future__ import annotations +from contextlib import suppress from copy import deepcopy -from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, ClassVar, Dict, Literal, Optional, Set, Tuple, Union import py3Dmol from rdkit import Chem @@ -20,13 +21,17 @@ from prolif.exceptions import RunRequiredError from prolif.plotting.utils import separated_interaction_colors -from prolif.utils import get_centroid +from prolif.utils import get_centroid, get_residues_near_ligand, requires + +with suppress(ModuleNotFoundError): + from IPython.display import Javascript, display + if TYPE_CHECKING: from prolif.fingerprint import Fingerprint from prolif.ifp import IFP from prolif.molecule import Molecule - from prolif.residue import ResidueId + from prolif.residue import Residue, ResidueId class Complex3D: @@ -75,21 +80,21 @@ class Complex3D: JavaScript callback executed when hovering an interaction line. DISABLE_HOVER_CALLBACK : str JavaScript callback executed when the hovering event is finished. 
- """ + """ # noqa: E501 COLORS: ClassVar[Dict[str, str]] = {**separated_interaction_colors} LIGAND_STYLE: ClassVar[Dict] = {"stick": {"colorscheme": "cyanCarbon"}} RESIDUES_STYLE: ClassVar[Dict] = {"stick": {}} PROTEIN_STYLE: ClassVar[Dict] = {"cartoon": {"style": "edged"}} PEPTIDE_STYLE: ClassVar[Dict] = { - "cartoon": {"style": "edged", "colorscheme": "cyanCarbon"} + "cartoon": {"style": "edged", "colorscheme": "cyanCarbon"}, } PEPTIDE_THRESHOLD: ClassVar[int] = 5 - LIGAND_DISPLAYED_ATOM = { + LIGAND_DISPLAYED_ATOM: ClassVar[Dict] = { "HBDonor": 1, "XBDonor": 1, } - PROTEIN_DISPLAYED_ATOM = { + PROTEIN_DISPLAYED_ATOM: ClassVar[Dict] = { "HBAcceptor": 1, "XBAcceptor": 1, } @@ -100,28 +105,22 @@ class Complex3D: } LIGAND_RING_INTERACTIONS: ClassVar[Set[str]] = {*RING_SYSTEMS, "PiCation"} PROTEIN_RING_INTERACTIONS: ClassVar[Set[str]] = {*RING_SYSTEMS, "CationPi"} - RESIDUE_HOVER_CALLBACK: ClassVar[ - str - ] = """ + RESIDUE_HOVER_CALLBACK: ClassVar[str] = """ function(atom,viewer) { if(!atom.label) { atom.label = viewer.addLabel('%s:'+atom.atom+atom.serial, {position: atom, backgroundColor: 'mintcream', fontColor:'black'}); } }""" - INTERACTION_HOVER_CALLBACK: ClassVar[ - str - ] = """ + INTERACTION_HOVER_CALLBACK: ClassVar[str] = """ function(shape,viewer) { if(!shape.label) { shape.label = viewer.addLabel(shape.interaction, {position: shape, backgroundColor: 'black', fontColor:'white'}); } }""" - DISABLE_HOVER_CALLBACK: ClassVar[ - str - ] = """ - function(obj,viewer) { + DISABLE_HOVER_CALLBACK: ClassVar[str] = """ + function(obj,viewer) { if(obj.label) { viewer.removeLabel(obj.label); delete obj.label; @@ -132,6 +131,7 @@ def __init__(self, ifp: IFP, lig_mol: Molecule, prot_mol: Molecule) -> None: self.ifp = ifp self.lig_mol = lig_mol self.prot_mol = prot_mol + self._view: Optional[py3Dmol.view] = None @classmethod def from_fingerprint( @@ -160,7 +160,7 @@ def from_fingerprint( if not hasattr(fp, "ifp"): raise RunRequiredError( "Please run the fingerprint 
analysis before attempting to display" - " results." + " results.", ) ifp = fp.ifp[frame] return cls(ifp, lig_mol, prot_mol) @@ -171,8 +171,12 @@ def get_ring_centroid(mol: Molecule, indices: Tuple[int, ...]) -> Point3D: return Point3D(*centroid) def display( - self, size: Tuple[int, int] = (650, 600), display_all: bool = False - ) -> py3Dmol.view: + self, + size: Tuple[int, int] = (650, 600), + display_all: bool = False, + only_interacting: bool = True, + remove_hydrogens: Union[bool, Literal["ligand", "protein"]] = True, + ) -> Complex3D: """Display as a py3Dmol widget view. Parameters @@ -183,11 +187,30 @@ def display( Display all occurences for a given pair of residues and interaction, or only the shortest one. Not relevant if ``count=False`` in the ``Fingerprint`` object. + only_interacting : bool = True + Whether to show all protein residues in the vicinity of the ligand, or + only the ones participating in an interaction. + remove_hydrogens: Union[bool, Literal["ligand", "protein"]] = True + Whether to remove non-polar hydrogens (unless they are involved in an + interaction). + + .. versionchanged:: 2.1.0 + Added ``only_interacting=True`` and ``remove_hydrogens=True`` parameters. + Non-polar hydrogen atoms that aren't involved in interactions are now + hidden. 
+ """ v = py3Dmol.view(width=size[0], height=size[1], viewergrid=(1, 1), linked=False) v.removeAllModels() - self._populate_view(v, position=(0, 0), display_all=display_all) - return v + self._populate_view( + v, + position=(0, 0), + display_all=display_all, + only_interacting=only_interacting, + remove_hydrogens=remove_hydrogens, + ) + self._view = v + return self def compare( self, @@ -197,7 +220,9 @@ def compare( display_all: bool = False, linked: bool = True, color_unique: Optional[str] = "magentaCarbon", - ) -> py3Dmol.view: + only_interacting: bool = True, + remove_hydrogens: Union[bool, Literal["ligand", "protein"]] = True, + ) -> Complex3D: """Displays the initial complex side-by-side with a second one for easier comparison. @@ -216,9 +241,20 @@ def compare( color_unique: Optional[str] = "magentaCarbon", Which color to use for residues that have interactions that are found in one complex but not the other. Use ``None`` to disable the color override. + only_interacting : bool = True + Whether to show all protein residues in the vicinity of the ligand, or + only the ones participating in an interaction. + remove_hydrogens: Union[bool, Literal["ligand", "protein"]] = True + Whether to remove non-polar hydrogens (unless they are involved in an + interaction). .. versionadded:: 2.0.1 + .. versionchanged:: 2.1.0 + Added ``only_interacting=True`` and ``remove_hydrogens=True`` parameters. + Non-polar hydrogen atoms that aren't involved in interactions are now + hidden. 
+ """ v = py3Dmol.view( width=size[0], @@ -229,16 +265,16 @@ def compare( v.removeAllModels() # get set of interactions for both poses - interactions1 = set( + interactions1 = { (resid[1], i) for resid, interactions in self.ifp.items() for i in interactions - ) - interactions2 = set( + } + interactions2 = { (resid[1], i) for resid, interactions in other.ifp.items() for i in interactions - ) + } # get residues with interactions specific to pose 1 highlights = ( @@ -247,7 +283,12 @@ def compare( else {} ) self._populate_view( - v, position=(0, 0), display_all=display_all, colormap=highlights + v, + position=(0, 0), + display_all=display_all, + colormap=highlights, + only_interacting=only_interacting, + remove_hydrogens=remove_hydrogens, ) # get residues with interactions specific to pose 2 @@ -257,22 +298,33 @@ def compare( else {} ) other._populate_view( - v, position=(0, 1), display_all=display_all, colormap=highlights + v, + position=(0, 1), + display_all=display_all, + colormap=highlights, + only_interacting=only_interacting, + remove_hydrogens=remove_hydrogens, ) - return v + self._view = v + return self - def _populate_view( + def _populate_view( # noqa: PLR0912 self, - v: py3Dmol.view, + v: Union[py3Dmol.view, Complex3D], position: Tuple[int, int] = (0, 0), display_all: bool = False, colormap: Optional[Dict[ResidueId, str]] = None, + only_interacting: bool = True, + remove_hydrogens: Union[bool, Literal["ligand", "protein"]] = True, ) -> None: - if colormap is None: - colormap = {} + if isinstance(v, Complex3D) and v._view: + v = v._view + self._colormap = {} if colormap is None else colormap + self._models = {} + self._mid = -1 + self._interacting_atoms = {"ligand": set(), "protein": set()} - models = {} - mid = -1 + # show all interacting residues for (lresid, presid), interactions in self.ifp.items(): lres = self.lig_mol[lresid] pres = self.prot_mol[presid] @@ -281,23 +333,8 @@ def _populate_view( (lresid, lres, self.LIGAND_STYLE), (presid, pres, 
self.RESIDUES_STYLE), ]: - if resid not in models: - mid += 1 - v.addModel(Chem.MolToMolBlock(res), "sdf", viewer=position) - model = v.getModel(viewer=position) - if resid in colormap: - style = deepcopy(style) - for key in style: - style[key]["colorscheme"] = colormap[resid] - model.setStyle({}, style) - # add residue label - model.setHoverable( - {}, - True, - self.RESIDUE_HOVER_CALLBACK % resid, - self.DISABLE_HOVER_CALLBACK, - ) - models[resid] = mid + if resid not in self._models: + self._add_residue_to_view(v, position, res, style) for interaction, metadata_tuple in interactions.items(): # whether to display all interactions or only the one with the shortest # distance @@ -312,6 +349,14 @@ def _populate_view( ) ) for metadata in metadata_iterator: + # record indices of atoms interacting + self._interacting_atoms["ligand"].update( + metadata["parent_indices"]["ligand"] + ) + self._interacting_atoms["protein"].update( + metadata["parent_indices"]["protein"] + ) + # get coordinates for both points of the interaction if interaction in self.LIGAND_RING_INTERACTIONS: p1 = self.get_ring_centroid(lres, metadata["indices"]["ligand"]) @@ -319,23 +364,24 @@ def _populate_view( p1 = lres.GetConformer().GetAtomPosition( metadata["indices"]["ligand"][ self.LIGAND_DISPLAYED_ATOM.get(interaction, 0) - ] + ], ) if interaction in self.PROTEIN_RING_INTERACTIONS: p2 = self.get_ring_centroid( - pres, metadata["indices"]["protein"] + pres, + metadata["indices"]["protein"], ) else: p2 = pres.GetConformer().GetAtomPosition( metadata["indices"]["protein"][ self.PROTEIN_DISPLAYED_ATOM.get(interaction, 0) - ] + ], ) # add interaction line v.addCylinder( { - "start": dict(x=p1.x, y=p1.y, z=p1.z), - "end": dict(x=p2.x, y=p2.y, z=p2.z), + "start": {"x": p1.x, "y": p1.y, "z": p1.z}, + "end": {"x": p2.x, "y": p2.y, "z": p2.z}, "color": self.COLORS.get(interaction, "grey"), "radius": 0.15, "dashed": True, @@ -344,9 +390,10 @@ def _populate_view( }, viewer=position, ) - # add label when 
hovering the middle of the dashed line by adding a dummy atom + # add label when hovering the middle of the dashed line by adding a + # dummy atom c = Point3D(*get_centroid([p1, p2])) - modelID = models[lresid] + modelID = self._models[lresid] model = v.getModel(modelID, viewer=position) interaction_label = f"{interaction}: {metadata['distance']:.2f}Å" model.addAtoms( @@ -357,8 +404,8 @@ def _populate_view( "y": c.y, "z": c.z, "interaction": interaction_label, - } - ] + }, + ], ) model.setStyle( {"interaction": interaction_label}, @@ -371,8 +418,42 @@ def _populate_view( self.DISABLE_HOVER_CALLBACK, ) + # show "protein" residues that are close to the "ligand" + if not only_interacting: + pocket_residues = get_residues_near_ligand(self.lig_mol, self.prot_mol) + pocket_residues = set(pocket_residues).difference(self._models) + for resid in pocket_residues: + res = self.prot_mol[resid] + self._add_residue_to_view(v, position, res, self.RESIDUES_STYLE) + + # hide non-polar hydrogens (except if they are involved in an interaction) + if remove_hydrogens: + to_remove = [] + if remove_hydrogens in {"ligand", True}: + to_remove.append(("ligand", self.lig_mol)) + if remove_hydrogens in {"protein", True}: + to_remove.append(("protein", self.prot_mol)) + + for resid in self._models: + for moltype, mol in to_remove: + try: + modelID = self._models[resid] + res = mol[resid] + except KeyError: + continue + model = v.getModel(modelID, viewer=position) + int_atoms = self._interacting_atoms[moltype] + hide = [ + a.GetIdx() + for a in res.GetAtoms() + if a.GetAtomicNum() == 1 + and a.GetUnsignedProp("mapindex") not in int_atoms + and all(n.GetAtomicNum() in {1, 6} for n in a.GetNeighbors()) + ] + model.setStyle({"index": hide}, {"stick": {"hidden": True}}) + # show protein - mol = Chem.RemoveAllHs(self.prot_mol) + mol = Chem.RemoveAllHs(self.prot_mol, sanitize=False) pdb = Chem.MolToPDBBlock(mol, flavor=0x20 | 0x10) v.addModel(pdb, "pdb", viewer=position) model = 
v.getModel(viewer=position) @@ -380,10 +461,65 @@ def _populate_view( # do the same for ligand if multiple residues if self.lig_mol.n_residues >= self.PEPTIDE_THRESHOLD: - mol = Chem.RemoveAllHs(self.lig_mol) + mol = Chem.RemoveAllHs(self.lig_mol, sanitize=False) pdb = Chem.MolToPDBBlock(mol, flavor=0x20 | 0x10) v.addModel(pdb, "pdb", viewer=position) model = v.getModel(viewer=position) model.setStyle({}, self.PEPTIDE_STYLE) - v.zoomTo({"model": list(models.values())}, viewer=position) + v.zoomTo({"model": list(self._models.values())}, viewer=position) + + def _add_residue_to_view( + self, + v: py3Dmol.view, + position: Tuple[int, int], + res: Residue, + style: Dict, + ) -> None: + self._mid += 1 + resid = res.resid + v.addModel(Chem.MolToMolBlock(res), "sdf", viewer=position) + model = v.getModel(viewer=position) + if resid in self._colormap: + resid_style = deepcopy(style) + for key in resid_style: + resid_style[key]["colorscheme"] = self._colormap[resid] + else: + resid_style = style + model.setStyle({}, resid_style) + # add residue label + model.setHoverable( + {}, + True, + self.RESIDUE_HOVER_CALLBACK % resid, + self.DISABLE_HOVER_CALLBACK, + ) + self._models[resid] = self._mid + + @requires("IPython.display") + def save_png(self) -> None: + """Saves the current state of the 3D viewer to a PNG. Not available outside of a + notebook. + + .. 
versionadded:: 2.1.0 + """ + if self._view is None: + raise ValueError( + "View not initialized, did you call `display`/`compare` first?", + ) + uid = self._view.uniqueid + display( + Javascript(f""" + var png = viewer_{uid}.pngURI() + var a = document.createElement('a') + a.href = png + a.download = "prolif-3d.png" + a.click() + a.remove() + """), + ) + + def _repr_html_(self): # noqa: PLW3201 + if self._view: + return self._view._repr_html_() + return None diff --git a/prolif/plotting/network.py b/prolif/plotting/network.py index c9e2e84..6dfc28a 100644 --- a/prolif/plotting/network.py +++ b/prolif/plotting/network.py @@ -14,12 +14,15 @@ """ import json +import operator import re import warnings from collections import defaultdict from copy import deepcopy from html import escape from pathlib import Path +from typing import ClassVar +from uuid import uuid4 import numpy as np import pandas as pd @@ -32,12 +35,13 @@ from prolif.utils import requires try: - from IPython.display import HTML + from IPython.display import Javascript, display except ModuleNotFoundError: pass else: warnings.filterwarnings( - "ignore", "Consider using IPython.display.IFrame instead" # pragma: no cover + "ignore", + "Consider using IPython.display.IFrame instead", # pragma: no cover ) @@ -97,7 +101,7 @@ class LigNetwork: VanDerWaals. 
""" - COLORS = { + COLORS: ClassVar = { "interactions": {**grouped_interaction_colors}, "atoms": { "C": "black", @@ -119,7 +123,7 @@ class LigNetwork: "Sulfur": "#e3ce59", }, } - RESIDUE_TYPES = { + RESIDUE_TYPES: ClassVar = { "ALA": "Aliphatic", "GLY": "Aliphatic", "ILE": "Aliphatic", @@ -149,8 +153,13 @@ class LigNetwork: "CYX": "Sulfur", "MET": "Sulfur", } - _LIG_PI_INTERACTIONS = ["EdgeToFace", "FaceToFace", "PiStacking", "PiCation"] - _DISPLAYED_ATOM = { # index 0 in indices tuple by default + _LIG_PI_INTERACTIONS: ClassVar = [ + "EdgeToFace", + "FaceToFace", + "PiStacking", + "PiCation", + ] + _DISPLAYED_ATOM: ClassVar = { # index 0 in indices tuple by default "HBDonor": 1, "XBDonor": 1, } @@ -200,7 +209,7 @@ class LigNetwork: - """ + """ # noqa: E501 def __init__( self, @@ -214,9 +223,9 @@ def __init__( carbon=0.16, ): self.df = df - self._interacting_atoms = set( - [atom for atoms in df.index.get_level_values("atoms") for atom in atoms] - ) + self._interacting_atoms = { + atom for atoms in df.index.get_level_values("atoms") for atom in atoms + } mol = deepcopy(lig_mol) if kekulize: Chem.Kekulize(mol) @@ -254,12 +263,15 @@ def __init__( self._default_atom_color = "grey" self._default_residue_color = "#dbdbdb" self._default_interaction_color = "#dbdbdb" + self._non_single_bond_spacing = 0.06 + self._dash = [10] # regroup interactions of the same color temp = defaultdict(list) interactions = set(df.index.get_level_values("interaction").unique()) for interaction in interactions: color = self.COLORS["interactions"].get( - interaction, self._default_interaction_color + interaction, + self._default_interaction_color, ) temp[color].append(interaction) self._interaction_types = { @@ -267,6 +279,9 @@ def __init__( for interaction_group in temp.values() for interaction in interaction_group } + # ID for saving to PNG with JS + self.uuid = uuid4().hex + self._iframe = None @classmethod def from_fingerprint( @@ -319,7 +334,7 @@ def from_fingerprint( if not hasattr(fp, 
"ifp"): raise RunRequiredError( "Please run the fingerprint analysis before attempting to display" - " results." + " results.", ) if kind == "frame": df = cls._make_frame_df_from_fp(fp, frame=frame, display_all=display_all) @@ -346,12 +361,13 @@ def _get_records(ifp, all_metadata): **entry, "atoms": metadata["parent_indices"]["ligand"], "distance": metadata.get("distance", 0), - } + }, ) else: # extract interaction with shortest distance metadata = min( - metadata_tuple, key=lambda m: m.get("distance", np.nan) + metadata_tuple, + key=lambda m: m.get("distance", np.nan), ) entry["atoms"] = metadata["parent_indices"]["ligand"] entry["distance"] = metadata.get("distance", 0) @@ -367,18 +383,19 @@ def _make_agg_df_from_fp(cls, fp, threshold=0.3): # add weight for each atoms, and average distance df["weight"] = 1 df = df.groupby(["ligand", "protein", "interaction", "atoms"]).agg( - weight=("weight", "sum"), distance=("distance", "mean") + weight=("weight", "sum"), + distance=("distance", "mean"), ) - df["weight"] = df["weight"] / len(fp.ifp) + df["weight"] /= len(fp.ifp) # merge different ligand atoms of the same residue/interaction group before # applying the threshold df = df.join( df.groupby(level=["ligand", "protein", "interaction"]).agg( - weight_total=("weight", "sum") + weight_total=("weight", "sum"), ), ) # threshold and keep most occuring ligand atom - df = ( + return ( df.loc[df["weight_total"] >= threshold] .drop(columns="weight_total") .sort_values("weight", ascending=False) @@ -386,7 +403,6 @@ def _make_agg_df_from_fp(cls, fp, threshold=0.3): .head(1) .sort_index() ) - return df @classmethod def _make_frame_df_from_fp(cls, fp, frame=0, display_all=False): @@ -394,10 +410,9 @@ def _make_frame_df_from_fp(cls, fp, frame=0, display_all=False): data = cls._get_records(ifp, all_metadata=display_all) df = pd.DataFrame(data) df["weight"] = 1 - df = df.set_index(["ligand", "protein", "interaction", "atoms"]).reindex( - columns=["weight", "distance"] + return 
df.set_index(["ligand", "protein", "interaction", "atoms"]).reindex( + columns=["weight", "distance"], ) - return df def _make_carbon(self): return deepcopy(self._carbon) @@ -412,7 +427,8 @@ def _make_lig_node(self, atom): charge = atom.GetFormalCharge() if charge != 0: charge = "{}{}".format( - "" if abs(charge) == 1 else str(charge), "+" if charge > 0 else "-" + "" if abs(charge) == 1 else str(charge), + "+" if charge > 0 else "-", ) label = f"{elem}{charge}" shape = "ellipse" @@ -427,7 +443,7 @@ def _make_lig_node(self, atom): "shape": shape, "color": "white", "font": { - "color": self.COLORS["atoms"].get(elem, self._default_atom_color) + "color": self.COLORS["atoms"].get(elem, self._default_atom_color), }, } node.update( @@ -438,7 +454,7 @@ def _make_lig_node(self, atom): "fixed": True, "group": "ligand", "borderWidth": 0, - } + }, ) self.nodes[idx] = node @@ -457,12 +473,12 @@ def _make_lig_edge(self, bond): "physics": False, "group": "ligand", "width": 4, - } + }, ) else: self._make_non_single_bond(idx, btype) - def _make_non_single_bond(self, ids, btype, bdist=0.06, dash=[10]): + def _make_non_single_bond(self, ids, btype): """Prepare double, triple and aromatic bonds""" xyz = self.xyz[ids] d = xyz[1, :2] - xyz[0, :2] @@ -470,8 +486,8 @@ def _make_non_single_bond(self, ids, btype, bdist=0.06, dash=[10]): u = d / length p = np.array([-u[1], u[0]]) nodes = [] - dist = bdist * self._multiplier * np.ceil(btype) - dashes = False if btype in [2, 3] else dash + dist = self._non_single_bond_spacing * self._multiplier * np.ceil(btype) + dashes = False if btype in {2, 3} else self._dash for perp in (p, -p): for point in xyz: xy = point[:2] + perp * dist @@ -507,7 +523,7 @@ def _make_non_single_bond(self, ids, btype, bdist=0.06, dash=[10]): "group": "ligand", "width": 4, }, - ] + ], ) if btype == 3: self.edges.append( @@ -518,7 +534,7 @@ def _make_non_single_bond(self, ids, btype, bdist=0.06, dash=[10]): "physics": False, "group": "ligand", "width": 4, - } + }, ) def 
_make_interactions(self, mass=2): @@ -566,10 +582,12 @@ def _make_interactions(self, mass=2): "to": prot_res, "title": f"{interaction}: {distance:.2f}Å", "interaction_type": self._interaction_types.get( - interaction, interaction + interaction, + interaction, ), "color": self.COLORS["interactions"].get( - interaction, self._default_interaction_color + interaction, + self._default_interaction_color, ), "smooth": {"type": "cubicBezier", "roundness": 0.2}, "dashes": [10], @@ -632,22 +650,22 @@ def _get_js(self, width="100%", height="500px", div_id="mynetwork", fontsize=20) "barnesHut": { "avoidOverlap": self._avoidOverlap, "springConstant": self._springConstant, - } + }, }, } options.update(self.options) - js = self._JS_TEMPLATE % dict( - div_id=div_id, - nodes=json.dumps(self.nodes), - edges=json.dumps(self.edges), - options=json.dumps(options), - ) + js = self._JS_TEMPLATE % { + "div_id": div_id, + "nodes": json.dumps(self.nodes), + "edges": json.dumps(self.edges), + "options": json.dumps(options), + } js += self._get_legend() return js def _get_html(self, **kwargs): """Returns the HTML code to draw the network""" - return self._HTML_TEMPLATE % dict(js=self._get_js(**kwargs)) + return self._HTML_TEMPLATE % {"js": self._get_js(**kwargs)} def _get_legend(self, height="90px"): available = {} @@ -662,12 +680,10 @@ def _get_legend(self, height="90px"): if node.get("group", "") == "protein": color = node["color"] available[color] = map_color_restype.get(color, "Unknown") - available = { - k: v for k, v in sorted(available.items(), key=lambda item: item[1]) - } + available = dict(sorted(available.items(), key=operator.itemgetter(1))) for i, (color, restype) in enumerate(available.items()): buttons.append( - {"index": i, "label": restype, "color": color, "group": "residues"} + {"index": i, "label": restype, "color": color, "group": "residues"}, ) # interactions available.clear() @@ -675,9 +691,7 @@ def _get_legend(self, height="90px"): if edge.get("group", "") == 
"interaction": color = edge["color"] available[color] = map_color_interactions[color] - available = { - k: v for k, v in sorted(available.items(), key=lambda item: item[1]) - } + available = dict(sorted(available.items(), key=operator.itemgetter(1))) for i, (color, interaction) in enumerate(available.items()): buttons.append( { @@ -685,128 +699,131 @@ def _get_legend(self, height="90px"): "label": interaction, "color": color, "group": "interactions", - } + }, ) # JS code if all("px" in h for h in [self.height, height]): h1 = int(re.findall(r"(\d+)\w+", self.height)[0]) h2 = int(re.findall(r"(\d+)\w+", height)[0]) - self.height = f"{h1+h2}px" - return """ - legend_buttons = %(buttons)s; - legend = document.getElementById('%(div_id)s'); - var div_residues = document.createElement('div'); - var div_interactions = document.createElement('div'); - var disabled = []; - var legend_callback = function() { - this.classList.toggle("disabled"); - var hide = this.classList.contains("disabled"); - var show = !hide; - var btn_label = this.innerHTML; - if (hide) { - disabled.push(btn_label); - } else { - disabled = disabled.filter(x => x !== btn_label); - } - var node_update = [], - edge_update = []; - // click on residue type - if (this.classList.contains("residues")) { - nodes.forEach((node) => { - // find nodes corresponding to this type - if (node.residue_type === btn_label) { - // if hiding this type and residue isn't already hidden - if (hide && !node.hidden) { - node.hidden = true; - node_update.push(node); - // if showing this type and residue isn't already visible - } else if (show && node.hidden) { - // display if there's at least one of its edge that isn't hidden - num_edges_active = edges.filter(x => x.to === node.id) - .map(x => Boolean(x.hidden)) - .filter(x => !x) - .length; - if (num_edges_active > 0) { - node.hidden = false; + self.height = f"{h1 + h2}px" + return ( + """ + legend_buttons = %(buttons)s; + legend = document.getElementById('%(div_id)s'); + var 
div_residues = document.createElement('div'); + var div_interactions = document.createElement('div'); + var disabled = []; + var legend_callback = function() { + this.classList.toggle("disabled"); + var hide = this.classList.contains("disabled"); + var show = !hide; + var btn_label = this.innerHTML; + if (hide) { + disabled.push(btn_label); + } else { + disabled = disabled.filter(x => x !== btn_label); + } + var node_update = [], + edge_update = []; + // click on residue type + if (this.classList.contains("residues")) { + nodes.forEach((node) => { + // find nodes corresponding to this type + if (node.residue_type === btn_label) { + // if hiding this type and residue isn't already hidden + if (hide && !node.hidden) { + node.hidden = true; node_update.push(node); + // if showing this type and residue isn't already visible + } else if (show && node.hidden) { + // display if there's at least one of its edge that isn't hidden + num_edges_active = edges.filter(x => x.to === node.id) + .map(x => Boolean(x.hidden)) + .filter(x => !x) + .length; + if (num_edges_active > 0) { + node.hidden = false; + node_update.push(node); + } } } - } - }); - ifp.body.data.nodes.update(node_update); - // click on interaction type - } else { - edges.forEach((edge) => { - // find edges corresponding to this type - if (edge.interaction_type === btn_label) { - edge.hidden = !edge.hidden; - edge_update.push(edge); - // number of active edges for the corresponding residue - var num_edges_active = edges.filter(x => x.to === edge.to) - .map(x => Boolean(x.hidden)) - .filter(x => !x) - .length; - // find corresponding residue - var ix = nodes.findIndex(x => x.id === edge.to); - // only change visibility if residue_type not being hidden - if (!(disabled.includes(nodes[ix].residue_type))) { - // hide if no edge being shown for this residue - if (hide && (num_edges_active === 0)) { - nodes[ix].hidden = true; - node_update.push(nodes[ix]); - // show if edges are being shown - } else if (show && 
(num_edges_active > 0)) { - nodes[ix].hidden = false; - node_update.push(nodes[ix]); + }); + ifp.body.data.nodes.update(node_update); + // click on interaction type + } else { + edges.forEach((edge) => { + // find edges corresponding to this type + if (edge.interaction_type === btn_label) { + edge.hidden = !edge.hidden; + edge_update.push(edge); + // number of active edges for the corresponding residue + var num_edges_active = edges.filter(x => x.to === edge.to) + .map(x => Boolean(x.hidden)) + .filter(x => !x) + .length; + // find corresponding residue + var ix = nodes.findIndex(x => x.id === edge.to); + // only change visibility if residue_type not being hidden + if (!(disabled.includes(nodes[ix].residue_type))) { + // hide if no edge being shown for this residue + if (hide && (num_edges_active === 0)) { + nodes[ix].hidden = true; + node_update.push(nodes[ix]); + // show if edges are being shown + } else if (show && (num_edges_active > 0)) { + nodes[ix].hidden = false; + node_update.push(nodes[ix]); + } } } - } + }); + ifp.body.data.nodes.update(node_update); + ifp.body.data.edges.update(edge_update); + } + }; + legend_buttons.forEach(function(v,i) { + if (v.group === "residues") { + var div = div_residues; + var border = "none"; + var color = v.color; + } else { + var div = div_interactions; + var border = "3px dashed " + v.color; + var color = "white"; + } + var button = div.appendChild(document.createElement('button')); + button.classList.add("legend-btn", v.group); + button.innerHTML = v.label; + Object.assign(button.style, { + "cursor": "pointer", + "background-color": color, + "border": border, + "border-radius": "5px", + "padding": "5px", + "margin": "5px", + "font": "14px 'Arial', sans-serif", }); - ifp.body.data.nodes.update(node_update); - ifp.body.data.edges.update(edge_update); - } - }; - legend_buttons.forEach(function(v,i) { - if (v.group === "residues") { - var div = div_residues; - var border = "none"; - var color = v.color; - } else { - var div = 
div_interactions; - var border = "3px dashed " + v.color; - var color = "white"; - } - var button = div.appendChild(document.createElement('button')); - button.classList.add("legend-btn", v.group); - button.innerHTML = v.label; - Object.assign(button.style, { - "cursor": "pointer", - "background-color": color, - "border": border, - "border-radius": "5px", - "padding": "5px", - "margin": "5px", - "font": "14px 'Arial', sans-serif", + button.onclick = legend_callback; }); - button.onclick = legend_callback; - }); - legend.appendChild(div_residues); - legend.appendChild(div_interactions); - """ % dict( - div_id="networklegend", buttons=json.dumps(buttons) + legend.appendChild(div_residues); + legend.appendChild(div_interactions); + """ # noqa: E501, UP031 + % { + "div_id": "networklegend", + "buttons": json.dumps(buttons), + } ) @requires("IPython.display") def display(self, **kwargs): """Prepare and display the network""" html = self._get_html(**kwargs) - iframe = ( - '' - ) - return HTML( - iframe.format(width=self.width, height=self.height, doc=escape(html)) + doc = escape(html) + self._iframe = ( + f'' ) + return self @requires("IPython.display") def show(self, filename, **kwargs): @@ -814,13 +831,11 @@ def show(self, filename, **kwargs): html = self._get_html(**kwargs) with open(filename, "w") as f: f.write(html) - iframe = ( - '' - ) - return HTML( - iframe.format(width=self.width, height=self.height, filename=filename) + self._iframe = ( + f'' ) + return self def save(self, fp, **kwargs): """Save the network to an HTML file @@ -836,3 +851,31 @@ def save(self, fp, **kwargs): f.write(html) elif hasattr(fp, "write") and callable(fp.write): fp.write(html) + + @requires("IPython.display") + def save_png(self): + """Saves the current state of the ligplot to a PNG. Not available outside of a + notebook. + + Notes + ----- + Requires calling ``display`` or ``show`` first. The legend won't be exported. + + .. 
versionadded:: 2.1.0 + """ + return display( + Javascript(f""" + var iframe = document.getElementById("{self.uuid}"); + var iframe_doc = iframe.contentWindow.document; + var canvas = iframe_doc.getElementsByTagName("canvas")[0]; + var link = document.createElement("a"); + link.href = canvas.toDataURL(); + link.download = "prolif-lignetwork.png" + link.click(); + """), + ) + + def _repr_html_(self): # noqa: PLW3201 + if self._iframe: + return self._iframe + return None diff --git a/prolif/plotting/residues.py b/prolif/plotting/residues.py index ec3545e..044b718 100644 --- a/prolif/plotting/residues.py +++ b/prolif/plotting/residues.py @@ -23,6 +23,7 @@ def display_residues( size: Tuple[int, int] = (200, 140), mols_per_row: int = 4, use_svg: bool = True, + sanitize: bool = False, ) -> Any: """Display a grid image of the residues in the molecule. The hydrogens are stripped and the 3D coordinates removed for a clearer visualisation. @@ -40,6 +41,11 @@ def display_residues( Number of residues displayed per row. use_svg: bool = True Generate an SVG or PNG image. + sanitize: bool = False + Sanitize the residues before displaying. + + .. versionchanged:: 2.1.0 + Added ``sanitize`` parameter that defaults to False for easier debugging. 
""" frags = [] residues_iterable = ( @@ -50,7 +56,7 @@ def display_residues( ) for residue in residues_iterable: - resmol = Chem.RemoveHs(residue) + resmol = Chem.RemoveHs(residue, sanitize=sanitize) resmol.RemoveAllConformers() resmol.SetProp("_Name", str(residue.resid)) frags.append(resmol) diff --git a/prolif/residue.py b/prolif/residue.py index d92334e..5fd6c9f 100644 --- a/prolif/residue.py +++ b/prolif/residue.py @@ -183,17 +183,18 @@ def __getitem__(self, key): # bool is a subclass of int but shouldn't be used here if isinstance(key, bool): raise KeyError( - f"Expected a ResidueId, int, or str, got {type(key).__name__!r} instead" + f"Expected a ResidueId, int, or str, got {type(key).__name__!r}" + " instead", ) if isinstance(key, int): return self._residues[key] - elif isinstance(key, str): + if isinstance(key, str): key = ResidueId.from_string(key) return self.data[key] - elif isinstance(key, ResidueId): + if isinstance(key, ResidueId): return self.data[key] raise KeyError( - f"Expected a ResidueId, int, or str, got {type(key).__name__!r} instead" + f"Expected a ResidueId, int, or str, got {type(key).__name__!r} instead", ) def select(self, mask): diff --git a/prolif/utils.py b/prolif/utils.py index 65fa1d9..c34da04 100644 --- a/prolif/utils.py +++ b/prolif/utils.py @@ -32,7 +32,7 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) raise ModuleNotFoundError( f"The module {module!r} is required to use {func.__name__!r} " - "but it is not installed!" 
+ "but it is not installed!", ) return wrapper @@ -46,7 +46,7 @@ def catch_rdkit_logs(): rdBase.DisableLog("rdApp.*") yield log_status = {st.split(":")[0]: st.split(":")[1] for st in log_status.split("\n")} - log_status = {k: True if v == "enabled" else False for k, v in log_status.items()} + log_status = {k: v == "enabled" for k, v in log_status.items()} for k, v in log_status.items(): if v is True: rdBase.EnableLog(k) @@ -75,9 +75,8 @@ def get_ring_normal_vector(centroid, coordinates): ca = centroid.DirectionVector(a) cb = centroid.DirectionVector(b) # cross product between these two vectors - normal = ca.CrossProduct(cb) # cb.CrossProduct(ca) is the normal vector in the opposite direction - return normal + return ca.CrossProduct(cb) def angle_between_limits(angle, min_angle, max_angle, ring=False): @@ -129,7 +128,7 @@ def get_residues_near_ligand(lig, prot, cutoff=6.0): """ tree = cKDTree(prot.xyz) ix = tree.query_ball_point(lig.xyz, cutoff) - ix = set([i for lst in ix for i in lst]) + ix = {i for lst in ix for i in lst} resids = [ResidueId.from_atom(prot.GetAtomWithIdx(i)) for i in ix] return list(set(resids)) @@ -250,9 +249,7 @@ def to_dataframe( empty_arr = np.array([empty_value for _ in range(n_interactions)], dtype=dtype) # residue pairs residue_pairs = sorted( - set( - [residue_tuple for frame_ifp in ifp.values() for residue_tuple in frame_ifp] - ) + {residue_tuple for frame_ifp in ifp.values() for residue_tuple in frame_ifp}, ) # sparse to dense data = defaultdict(list) @@ -267,23 +264,25 @@ def to_dataframe( else: if count: bitvector = np.array( - [len(ifp_dict.get(i, ())) for i in interactions], dtype=dtype + [len(ifp_dict.get(i, ())) for i in interactions], + dtype=dtype, ) else: bitvector = np.array( - [i in ifp_dict for i in interactions], dtype=bool + [i in ifp_dict for i in interactions], + dtype=bool, ) data[residue_tuple].append(bitvector) index = pd.Series(index, name=index_col) # create dataframe if not data: - warnings.warn("No interaction 
detected") + warnings.warn("No interaction detected", stacklevel=2) return pd.DataFrame([], index=index) values = np.array( [ np.hstack([bitvector_list[frame] for bitvector_list in data.values()]) for frame in range(len(index)) - ] + ], ) columns = pd.MultiIndex.from_tuples( [ @@ -340,7 +339,7 @@ def pandas_series_to_countvector(s): size = len(s) cv = UIntSparseIntVect(size) for i in range(size): - cv[i] = int(s[i]) + cv[i] = int(s.iloc[i]) return cv diff --git a/pyproject.toml b/pyproject.toml index 955bf5a..238aef8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [ { name = "Cédric Bouysset", email = "cedric@bouysset.net" }, ] readme = "README.rst" -requires-python = ">=3.8" +requires-python = ">=3.9" dynamic = ["version"] classifiers = [ "Development Status :: 5 - Production/Stable", @@ -13,7 +13,6 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -31,7 +30,7 @@ keywords = [ "interaction-fingerprint", ] dependencies = [ - "pandas>=1.0.0", + "pandas>=1.1.0", "numpy>=1.13.3", "scipy>=1.3.0", "mdanalysis>=2.2.0", @@ -89,8 +88,69 @@ include = ["prolif*"] [tool.setuptools.dynamic] version = { attr = "prolif._version.__version__" } -[tool.black] +[tool.ruff] line-length = 88 +target-version = "py39" +exclude = [ + ".git", + ".git-rewrite", + ".ipynb_checkpoints", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".vscode", + "__pypackages__", + "_build", + "build", + "dist", + "site-packages", +] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # Warning + "W", + # pyupgrade + "UP", + # flake8 + "B", # bugbear + "SIM", # simplify + "A", # builtins + "COM", # commas + "C4", # comprehensions + "ISC", # implicit-str-concat + "ICN", # 
import-conventions + "PIE", # pie + "T20", # print + "PT", # pytest-style + "Q", # quotes + "RET", # return + "ARG", # unused-arguments + "PTH", # use-pathlib + # pandas-vet + "PD", + # pylint + "PLR", # refactor + "PLW", # warning + # numpy + "NPY", + # refurb + "FURB", + # ruff + "RUF", + # isort + "I", +] +ignore = [ + "PLW1514", "PTH123", "PLR0904", "PLR0911", "PLR0913", "PLR0914", "PLR0915", + "PLR0916", "PLR0917", "PLR2004", "PLR6301", "PD901", "PT018", "FURB103", "COM812", + # typing incompatible with 3.9 + "UP035", "UP006", "UP007", +] [tool.isort] profile = "black" diff --git a/scripts/test_build.py b/scripts/test_build.py index 860808c..6e0e696 100644 --- a/scripts/test_build.py +++ b/scripts/test_build.py @@ -2,9 +2,8 @@ from pathlib import Path import prolif -from prolif.plotting.network import LigNetwork -print(prolif.__version__) +print(prolif.__version__) # noqa: T201 assert Path(prolif.datafiles.TOP).is_file() diff --git a/tests/conftest.py b/tests/conftest.py index 5769802..73ed41d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,19 +10,19 @@ from prolif.molecule import Molecule, sdf_supplier -def pytest_sessionstart(session): +def pytest_sessionstart(session): # noqa: ARG001 if not datapath.exists(): pytest.exit( - f"Example data files are not accessible: {datapath!s} does not exist" + f"Example data files are not accessible: {datapath!s} does not exist", ) vina_path = datapath / "vina" if not vina_path.exists(): pytest.exit( - f"Example Vina data files are not accessible: {vina_path!s} does not exist" + f"Example Vina data files are not accessible: {vina_path!s} does not exist", ) # ugly patch to add Mixin class as attribute to pytest so that we don't have to # worry about relative imports in the test codebase - setattr(pytest, "BaseTestMixinRDKitMol", BaseTestMixinRDKitMol) + pytest.BaseTestMixinRDKitMol = BaseTestMixinRDKitMol @pytest.fixture(scope="session") diff --git a/tests/plotting/test_barcode.py 
b/tests/plotting/test_barcode.py index 8fb2683..8e1b59a 100644 --- a/tests/plotting/test_barcode.py +++ b/tests/plotting/test_barcode.py @@ -22,7 +22,11 @@ def fp(self, request: pytest.FixtureRequest) -> plf.Fingerprint: @pytest.fixture(scope="class") def fp_run( - self, u: mda.Universe, ligand_ag, protein_ag, fp: plf.Fingerprint + self, + u: mda.Universe, + ligand_ag, + protein_ag, + fp: plf.Fingerprint, ) -> plf.Fingerprint: fp.run(u.trajectory[0:2], ligand_ag, protein_ag) return fp @@ -65,6 +69,7 @@ def test_from_fingerprint_raises_not_executed(self) -> None: fp = plf.Fingerprint() with pytest.raises( RunRequiredError, - match="Please run the fingerprint analysis before attempting to display results", + match="Please run the fingerprint analysis before attempting to display" + " results", ): Barcode.from_fingerprint(fp) diff --git a/tests/plotting/test_complex3d.py b/tests/plotting/test_complex3d.py index bb348b6..4b14f76 100644 --- a/tests/plotting/test_complex3d.py +++ b/tests/plotting/test_complex3d.py @@ -36,24 +36,25 @@ def plot_3d(self, fp_mols): def test_integration_display_single(self, plot_3d): view = plot_3d.display(display_all=False) - html = view._make_html() + html = view._view._make_html() assert "Hydrophobic" in html def test_integration_display_all(self, plot_3d): view = plot_3d.display(display_all=True) - html = view._make_html() + html = view._view._make_html() assert "Hydrophobic" in html def test_integration_compare(self, plot_3d): view = plot_3d.compare(plot_3d) - html = view._make_html() + html = view._view._make_html() assert "Hydrophobic" in html def test_from_fingerprint_raises_not_executed(self, ligand_mol, protein_mol): fp = plf.Fingerprint() with pytest.raises( RunRequiredError, - match="Please run the fingerprint analysis before attempting to display results", + match="Please run the fingerprint analysis before attempting to display" + " results", ): Complex3D.from_fingerprint( fp, @@ -65,5 +66,5 @@ def 
test_from_fingerprint_raises_not_executed(self, ligand_mol, protein_mol): def test_fp_plot_3d(self, fp_mols): fp, lig_mol, prot_mol = fp_mols view = fp.plot_3d(lig_mol, prot_mol, frame=0, display_all=fp.count) - html = view._make_html() + html = view._view._make_html() assert "Hydrophobic" in html diff --git a/tests/plotting/test_network.py b/tests/plotting/test_network.py index b29c9c2..4bcf8e7 100644 --- a/tests/plotting/test_network.py +++ b/tests/plotting/test_network.py @@ -39,7 +39,11 @@ def get_ligplot(self, fp_mol): def test_integration_frame(self, fp_mol): fp, lig_mol = fp_mol net = LigNetwork.from_fingerprint( - fp, lig_mol, kind="frame", frame=0, display_all=fp.count + fp, + lig_mol, + kind="frame", + frame=0, + display_all=fp.count, ) with StringIO() as buffer: net.save(buffer) @@ -75,8 +79,7 @@ def test_save_file(self, get_ligplot, tmp_path): net = get_ligplot() output = tmp_path / "lignetwork.html" net.save(output) - with open(output, "r") as f: - assert "PHE331.B" in f.read() + assert "PHE331.B" in output.read_text() def test_from_fingerprint_raises_kind(self, get_ligplot): with pytest.raises(ValueError, match='must be "aggregate" or "frame"'): @@ -86,11 +89,12 @@ def test_from_fingerprint_raises_not_executed(self, ligand_mol): fp = plf.Fingerprint() with pytest.raises( RunRequiredError, - match="Please run the fingerprint analysis before attempting to display results", + match="Please run the fingerprint analysis before attempting to display" + " results", ): LigNetwork.from_fingerprint(fp, ligand_mol) def test_fp_plot_lignetwork(self, fp_mol): fp, lig_mol = fp_mol - html = fp.plot_lignetwork(lig_mol, kind="frame", frame=0, display_all=fp.count) - assert " None: img = plf.display_residues(protein_mol, **kwargs) assert " 1 res = ResidueId.from_string("ALA216.A") - assert (lig_id, res) in fp_simple.ifp[0].keys() + assert (lig_id, res) in fp_simple.ifp[0] u.trajectory[0] def test_generate(self, fp_simple, ligand_mol, protein_mol): @@ -137,7 +141,11 @@ 
def test_generate_metadata(self, fp_simple, ligand_mol, protein_mol): def test_run(self, fp_simple, u, ligand_ag, protein_ag): fp_simple.run( - u.trajectory[0:1], ligand_ag, protein_ag, residues=None, progress=False + u.trajectory[0:1], + ligand_ag, + protein_ag, + residues=None, + progress=False, ) assert hasattr(fp_simple, "ifp") ifp = fp_simple.ifp[0] @@ -145,10 +153,8 @@ def test_run(self, fp_simple, u, ligand_ag, protein_ag): assert isinstance(interactions, dict) metadata_tuple = next(iter(interactions.values())) assert all( - [ - key in metadata_tuple[0] - for key in ["indices", "parent_indices", "distance"] - ] + key in metadata_tuple[0] + for key in ["indices", "parent_indices", "distance"] ) def test_run_from_iterable(self, fp_simple, protein_mol): @@ -161,7 +167,11 @@ def test_to_df(self, fp_simple, u, ligand_ag, protein_ag): with pytest.raises(AttributeError, match="use the `run` method"): Fingerprint().to_dataframe() fp_simple.run( - u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False + u.trajectory[:3], + ligand_ag, + protein_ag, + residues=None, + progress=False, ) df = fp_simple.to_dataframe() assert isinstance(df, DataFrame) @@ -169,19 +179,27 @@ def test_to_df(self, fp_simple, u, ligand_ag, protein_ag): def test_to_df_kwargs(self, fp_simple, u, ligand_ag, protein_ag): fp_simple.run( - u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False + u.trajectory[:3], + ligand_ag, + protein_ag, + residues=None, + progress=False, ) df = fp_simple.to_dataframe(dtype=np.uint8) - assert df.dtypes[0].type is np.uint8 + assert df.dtypes.iloc[0].type is np.uint8 df = fp_simple.to_dataframe(drop_empty=False) - resids = set([key for d in fp_simple.ifp.values() for key in d.keys()]) + resids = {key for d in fp_simple.ifp.values() for key in d} assert df.shape == (3, len(resids)) def test_to_bitvector(self, fp_simple, u, ligand_ag, protein_ag): with pytest.raises(AttributeError, match="use the `run` method"): 
Fingerprint().to_bitvectors() fp_simple.run( - u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False + u.trajectory[:3], + ligand_ag, + protein_ag, + residues=None, + progress=False, ) bvs = fp_simple.to_bitvectors() assert isinstance(bvs[0], ExplicitBitVect) @@ -191,7 +209,11 @@ def test_to_countvectors(self, fp_count, u, ligand_ag, protein_ag): with pytest.raises(AttributeError, match="use the `run` method"): Fingerprint().to_countvectors() fp_count.run( - u.trajectory[:3], ligand_ag, protein_ag, residues=None, progress=False + u.trajectory[:3], + ligand_ag, + protein_ag, + residues=None, + progress=False, ) cvs = fp_count.to_countvectors() assert isinstance(cvs[0], UIntSparseIntVect) @@ -212,27 +234,29 @@ def test_list_avail(self): def test_unknown_interaction(self): with pytest.raises( - NameError, match=r"Unknown interaction\(s\) in 'interactions': foo" + NameError, + match=r"Unknown interaction\(s\) in 'interactions': foo", ): Fingerprint(["Cationic", "foo"]) with pytest.raises( - NameError, match=r"Unknown interaction\(s\) in 'parameters': bar" + NameError, + match=r"Unknown interaction\(s\) in 'parameters': bar", ): Fingerprint(["Cationic"], parameters={"bar": {}}) - @pytest.fixture + @pytest.fixture() def fp_to_pickle(self, fp, protein_mol): path = str(datapath / "vina" / "vina_output.sdf") lig_suppl = list(sdf_supplier(path)) fp.run_from_iterable(lig_suppl[:2], protein_mol, progress=False) return fp - @pytest.fixture + @pytest.fixture() def fp_unpkl(self, fp_to_pickle): pkl = fp_to_pickle.to_pickle() return Fingerprint.from_pickle(pkl) - @pytest.fixture + @pytest.fixture() def fp_unpkl_file(self, fp_to_pickle, tmp_path): pkl_path = tmp_path / "fp.pkl" fp_to_pickle.to_pickle(pkl_path) @@ -256,7 +280,11 @@ def test_run_multiproc_serial_same(self, fp, u, ligand_ag, protein_ag): fp.run(u.trajectory[0:100:10], ligand_ag, protein_ag, n_jobs=1, progress=False) serial = fp.to_dataframe() fp.run( - u.trajectory[0:100:10], ligand_ag, protein_ag, 
n_jobs=None, progress=False + u.trajectory[0:100:10], + ligand_ag, + protein_ag, + n_jobs=None, + progress=False, ) multi = fp.to_dataframe() assert serial.equals(multi) @@ -273,7 +301,8 @@ def test_run_iter_multiproc_serial_same(self, fp, protein_mol): def test_converter_kwargs_raises_error(self, fp, u, ligand_ag, protein_ag): with pytest.raises( - ValueError, match="converter_kwargs must be a list of 2 dicts" + ValueError, + match="converter_kwargs must be a list of 2 dicts", ): fp.run( u.trajectory[0:5], @@ -281,7 +310,7 @@ def test_converter_kwargs_raises_error(self, fp, u, ligand_ag, protein_ag): protein_ag, n_jobs=1, progress=False, - converter_kwargs=[dict(force=True)], + converter_kwargs=[{"force": True}], ) @pytest.mark.parametrize("n_jobs", [1, 2]) @@ -293,7 +322,7 @@ def test_converter_kwargs(self, fp, n_jobs): lig, prot, n_jobs=n_jobs, - converter_kwargs=[dict(force=True), dict(force=True)], + converter_kwargs=[{"force": True}, {"force": True}], ) assert fp.ifp diff --git a/tests/test_interactions.py b/tests/test_interactions.py index 122396a..0b0067c 100644 --- a/tests/test_interactions.py +++ b/tests/test_interactions.py @@ -55,7 +55,7 @@ def fingerprint(self): return Fingerprint() @pytest.mark.parametrize( - "func_name, any_mol, any_other_mol, expected", + ("func_name", "any_mol", "any_other_mol", "expected"), [ ("cationic", "cation", "anion", True), ("cationic", "anion", "cation", False), @@ -117,7 +117,12 @@ def fingerprint(self): indirect=["any_mol", "any_other_mol"], ) def test_interaction( - self, fingerprint, func_name, any_mol, any_other_mol, expected + self, + fingerprint, + func_name, + any_mol, + any_other_mol, + expected, ): interaction = getattr(fingerprint, func_name) assert next(interaction(any_mol[0], any_other_mol[0]), False) is expected @@ -134,12 +139,13 @@ def detect(self): assert old != new # fix dummy Hydrophobic class being reused in later unrelated tests - class Hydrophobic(prolif.interactions.Hydrophobic): + class 
Hydrophobic(prolif.interactions.Hydrophobic): # noqa: F811 __doc__ = prolif.interactions.Hydrophobic.__doc__ def test_error_no_detect(self): with pytest.raises( - TypeError, match="Can't instantiate interaction class _Dummy" + TypeError, + match="Can't instantiate interaction class _Dummy", ): class _Dummy(Interaction): @@ -155,7 +161,7 @@ def test_vdwcontact_tolerance_error(self): VdWContact(tolerance=-1) @pytest.mark.parametrize( - "any_mol, any_other_mol", + ("any_mol", "any_other_mol"), [("benzene", "cation")], indirect=["any_mol", "any_other_mol"], ) @@ -168,7 +174,7 @@ def test_vdwcontact_cache(self, any_mol, any_other_mol): assert vdw_dist == value @pytest.mark.parametrize( - "any_mol, any_other_mol", + ("any_mol", "any_other_mol"), [("benzene", "cation")], indirect=["any_mol", "any_other_mol"], ) @@ -178,7 +184,7 @@ def test_vdwcontact_vdwradii_update(self, any_mol, any_other_mol): assert next(metadata, None) is None @pytest.mark.parametrize( - ["interaction_qmol", "smiles", "expected"], + ("interaction_qmol", "smiles", "expected"), [ ("Hydrophobic.lig_pattern", "C", 1), ("Hydrophobic.lig_pattern", "C=[SH2]", 1), @@ -262,7 +268,7 @@ def test_smarts_matches(self, interaction_qmol, smiles, expected): assert n_matches == expected @pytest.mark.parametrize( - ["xyz", "rotation", "pi_type", "expected"], + ("xyz", "rotation", "pi_type", "expected"), [ ([0, 2.5, 4.0], [0, 0, 0], "facetoface", True), ([0, 3, 4.5], [0, 0, 0], "facetoface", False), @@ -283,12 +289,19 @@ def test_smarts_matches(self, interaction_qmol, smiles, expected): ], ) def test_pi_stacking( - self, benzene_universe, xyz, rotation, pi_type, expected, fingerprint + self, + benzene_universe, + xyz, + rotation, + pi_type, + expected, + fingerprint, ): r1, r2 = self.create_rings(benzene_universe, xyz, rotation) - evaluate = lambda pistacking_type, r1, r2: next( - getattr(fingerprint, pistacking_type)(r1, r2), False - ) + + def evaluate(pistacking_type, r1, r2): + return next(getattr(fingerprint, 
pistacking_type)(r1, r2), False) + assert evaluate(pi_type, r1, r2) is expected if expected is True: other = "edgetoface" if pi_type == "facetoface" else "facetoface" diff --git a/tests/test_molecule.py b/tests/test_molecule.py index 9bf8ab2..514794d 100644 --- a/tests/test_molecule.py +++ b/tests/test_molecule.py @@ -5,7 +5,7 @@ from prolif.datafiles import datapath from prolif.molecule import Molecule, mol2_supplier, pdbqt_supplier, sdf_supplier -from prolif.residue import ResidueId +from prolif.residue import Residue, ResidueId class TestMolecule(pytest.BaseTestMixinRDKitMol): @@ -22,7 +22,7 @@ def test_from_mda(self, u, ligand_rdkit): mda_mol = Molecule.from_mda(u, "resname LIG") assert rdkit_mol[0].resid == mda_mol[0].resid assert rdkit_mol.HasSubstructMatch(mda_mol) and mda_mol.HasSubstructMatch( - rdkit_mol + rdkit_mol, ) def test_from_mda_empty_ag(self, u): @@ -50,8 +50,8 @@ def test_getitem(self, mol, key): assert mol[key].resid is mol.residues[key].resid def test_iter(self, mol): - for i, r in enumerate(mol): - assert r.resid == mol[i].resid + for r in mol: + assert isinstance(r, Residue) def test_n_residues(self, mol): assert mol.n_residues == mol.residues.n_residues @@ -82,7 +82,7 @@ def test_index(self, suppl): class TestPDBQTSupplier(SupplierBase): resid = ResidueId("LIG", 1, "G") - @pytest.fixture + @pytest.fixture() def suppl(self): path = datapath / "vina" pdbqts = sorted(path.glob("*.pdbqt")) @@ -122,14 +122,20 @@ def test_pdbqt_hydrogens_stay_in_mol(self, ligand_rdkit): class TestSDFSupplier(SupplierBase): - @pytest.fixture + @pytest.fixture() def suppl(self): path = str(datapath / "vina" / "vina_output.sdf") return sdf_supplier(path) + def test_sanitize(self): + path = str(datapath / "vina" / "vina_output.sdf") + suppl = sdf_supplier(path, sanitize=False) + mol = next(iter(suppl)) + assert isinstance(mol, Molecule) + class TestMOL2Supplier(SupplierBase): - @pytest.fixture + @pytest.fixture() def suppl(self): path = str(datapath / "vina" / 
"vina_output.mol2") return mol2_supplier(path) @@ -139,3 +145,15 @@ def test_mol2_starting_with_comment(self): suppl = mol2_supplier(path) mol = next(iter(suppl)) assert mol is not None + + def test_sanitize(self): + path = str(datapath / "vina" / "vina_output.mol2") + suppl = mol2_supplier(path, sanitize=False) + mol = next(iter(suppl)) + assert isinstance(mol, Molecule) + + def test_cleanup_substructures(self): + path = str(datapath / "vina" / "vina_output.mol2") + suppl = mol2_supplier(path, cleanup_substructures=False) + mol = next(iter(suppl)) + assert isinstance(mol, Molecule) diff --git a/tests/test_pickling.py b/tests/test_pickling.py index 860ff42..4903f30 100644 --- a/tests/test_pickling.py +++ b/tests/test_pickling.py @@ -7,7 +7,7 @@ @pytest.fixture(autouse=True) -def reset_default_pickle_properties(): +def _reset_default_pickle_properties(): default = Chem.GetDefaultPickleProperties() yield Chem.SetDefaultPickleProperties(default) diff --git a/tests/test_residues.py b/tests/test_residues.py index 3d22149..b792a75 100644 --- a/tests/test_residues.py +++ b/tests/test_residues.py @@ -8,7 +8,7 @@ class TestResidueId: @pytest.mark.parametrize( - "name, number, chain", + ("name", "number", "chain"), [ ("ALA", None, None), ("ALA", 1, None), @@ -34,7 +34,7 @@ def test_init(self, name, number, chain): assert resid.chain == chain @pytest.mark.parametrize( - "name, number, chain", + ("name", "number", "chain"), [ ("ALA", None, None), ("ALA", 1, None), @@ -77,7 +77,7 @@ def test_from_atom_no_mi(self): assert resid.chain is None @pytest.mark.parametrize( - "resid_str, expected", + ("resid_str", "expected"), [ ("ALA", ("ALA", 0, None)), ("ALA1", ("ALA", 1, None)), @@ -112,7 +112,7 @@ def test_eq(self): assert res1 == res2 @pytest.mark.parametrize( - "res1, res2", + ("res1", "res2"), [ ("ALA1.A", "ALA1.B"), ("ALA2.A", "ALA3.A"), @@ -153,10 +153,7 @@ class TestResidueGroup: def residues(self): sequence = "ARNDCQEGHILKMFPSTWYV" protein = Chem.MolFromSequence(sequence) - 
residues = [ - Residue(res) for res in Chem.SplitMolByPDBResidues(protein).values() - ] - return residues + return [Residue(res) for res in Chem.SplitMolByPDBResidues(protein).values()] def test_init(self, residues): rg = ResidueGroup(residues) @@ -185,7 +182,7 @@ def test_n_residues(self, residues): assert rg.n_residues == 20 @pytest.mark.parametrize( - "ix, resid, resid_str", + ("ix", "resid", "resid_str"), [ (0, ("ALA", 1, "A"), "ALA1.A"), (4, ("CYS", 5, "A"), "CYS5.A"), diff --git a/tests/test_utils.py b/tests/test_utils.py index 8f54b5e..e465e4c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -19,7 +19,7 @@ ) -@pytest.fixture +@pytest.fixture() def ifp_single(): return { 0: { @@ -36,7 +36,7 @@ def ifp_single(): } -@pytest.fixture +@pytest.fixture() def ifp_count(): return { 0: { @@ -44,7 +44,7 @@ def ifp_count(): "A": ( {"indices": {"ligand": (0,), "protein": (1,)}}, {"indices": {"ligand": (1,), "protein": (1,)}}, - ) + ), }, ("LIG", "GLU2"): {"B": ({"indices": {"ligand": (1,), "protein": (3,)}},)}, }, @@ -77,7 +77,7 @@ def test_centroid(): @pytest.mark.parametrize( - "angle, mina, maxa, ring, expected", + ("angle", "mina", "maxa", "ring", "expected"), [ (0, 0, 30, False, True), (30, 0, 30, False, True), @@ -161,7 +161,7 @@ def test_split_residues(): sequence = "ARNDCQEGHILKMFPSTWYV" prot = Chem.MolFromSequence(sequence) rg = ResidueGroup( - [Residue(res) for res in Chem.SplitMolByPDBResidues(prot).values()] + [Residue(res) for res in Chem.SplitMolByPDBResidues(prot).values()], ) residues = [Residue(mol) for mol in split_mol_by_residues(prot)] residues.sort(key=lambda x: x.resid) @@ -172,7 +172,7 @@ def test_split_residues(): def test_is_peptide_bond(): mol = Chem.RWMol() - for i in range(3): + for _ in range(3): a = Chem.Atom(6) mol.AddAtom(a) mol.AddBond(0, 1) @@ -196,7 +196,7 @@ def test_series_to_bv(): def test_to_df(ifp): df = to_dataframe(ifp, ["A", "B", "C"]) assert df.shape == (2, 4) - assert df.dtypes[0].type is np.bool_ + assert 
df.dtypes.iloc[0].type is np.bool_ assert df.index.name == "Frame" assert ("LIG", "ALA1", "A") in df.columns assert df[("LIG", "ALA1", "A")][0] is np.bool_(True) @@ -218,7 +218,7 @@ def test_to_df(ifp): ) def test_to_df_dtype(dtype, ifp): df = to_dataframe(ifp, ["A", "B", "C"], dtype=dtype) - assert df.dtypes[0].type is dtype + assert df.dtypes.iloc[0].type is dtype assert df[("LIG", "ALA1", "A")][0] == dtype(True) assert df[("LIG", "ALA1", "B")][0] == dtype(False) assert df[("LIG", "ASP3", "B")][0] == dtype(False)