Skip to content

Commit

Permalink
feat: add option to return coord mask
Browse files Browse the repository at this point in the history
  • Loading branch information
kdidiNVIDIA committed Jun 12, 2024
1 parent 27463a5 commit b02f1d7
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions graphein/protein/tensor/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
conda_channel="pyg",
pip_install=True,
)
log.debug(message)
log.warning(message)

try:
import torch
Expand All @@ -60,7 +60,7 @@
conda_channel="pytorch",
pip_install=True,
)
log.debug(message)
log.warning(message)


def get_protein_length(df: pd.DataFrame, insertions: bool = True) -> int:
Expand Down Expand Up @@ -108,8 +108,8 @@ def protein_to_pyg(
atom_types: List[str] = PROTEIN_ATOMS,
remove_nonstandard: bool = True,
store_het: bool = False,
store_bfactor: bool = False,
fill_value_coords: float = 1e-5,
return_coord_mask: bool = False,
) -> Data:
"""
Parses a protein (from either: a PDB code, PDB file or a UniProt ID
Expand Down Expand Up @@ -160,12 +160,9 @@ def protein_to_pyg(
:param store_het: Whether or not to store heteroatoms in the ``Data``
object. Default is ``False``.
:type store_het: bool
:param store_bfactor: Whether or not to store bfactors in the ``Data``
object. Default is ``False``.
:type store_bfactor: bool
:param fill_value_coords: Fill value to use for positions in atom37
representation that are not filled. Defaults to 1e-5
:type fill_value_coords: float
:param return_coord_mask: Whether to include the coordinate mask as a feature. Default is
``False``.
:type return_coord_mask: bool
:returns: ``Data`` object with attributes: ``x`` (AtomTensor), ``residues``
(list of 3-letter residue codes), id (ID of protein), residue_id (E.g.
``"A:SER:1"``), residue_type (torch.Tensor), ``chains`` (torch.Tensor).
Expand Down Expand Up @@ -245,9 +242,6 @@ def protein_to_pyg(
df["residue_id"] = df.residue_id + ":" + df.insertion

out = Data(
coords=protein_df_to_tensor(
df, atoms_to_keep=atom_types, fill_value=fill_value_coords
),
residues=get_sequence(
df,
chains=chain_selection,
Expand All @@ -259,14 +253,17 @@ def protein_to_pyg(
residue_type=residue_type_tensor(df),
chains=protein_df_to_chain_tensor(df),
)
if return_coord_mask:
coords,coord_mask=protein_df_to_tensor(
df, atoms_to_keep=atom_types, fill_value=fill_value_coords, return_coord_mask=return_coord_mask
)
else:
coords=protein_df_to_tensor(
df, atoms_to_keep=atom_types, fill_value=fill_value_coords
)

if store_het:
out.hetatms = [het_coords]

if store_bfactor:
# group by residue_id and average b_factor per residue
residue_bfactors = df.groupby("residue_id")["b_factor"].mean()
out.bfactor = torch.from_numpy(residue_bfactors.values)

return out


Expand Down Expand Up @@ -330,6 +327,7 @@ def protein_df_to_tensor(
atoms_to_keep: List[str] = PROTEIN_ATOMS,
insertions: bool = True,
fill_value: float = 1e-5,
return_coord_mask: bool = False
) -> AtomTensor:
"""
Transforms a DataFrame of a protein structure into a
Expand All @@ -344,14 +342,14 @@ def protein_df_to_tensor(
:type insertions: bool
:param fill_value: Value to fill missing entries with. Defaults to ``1e-5``.
:type fill_value: float
:returns: ``Length x Num_Atoms (default 37) x 3`` tensor.
:param return_coord_mask: Whether to return the coord mask created. Defaults to ``False``.
:type return_coord_mask: bool
:returns: ``Length x Num_Atoms (default 37) x 3`` tensor and, if return_coord_mask==True, also the coord_mask tensor.
:rtype: graphein.protein.tensor.types.AtomTensor
"""
num_residues = get_protein_length(df, insertions=insertions)
df = df.loc[df["atom_name"].isin(atoms_to_keep)]
residue_indices = pd.factorize(
pd.Series(get_residue_id(df, unique=False))
)[0]
residue_indices = pd.factorize(get_residue_id(df, unique=False))[0]
atom_indices = df["atom_name"].map(lambda x: atoms_to_keep.index(x)).values

positions: AtomTensor = (
Expand All @@ -360,7 +358,13 @@ def protein_df_to_tensor(
positions[residue_indices, atom_indices] = torch.tensor(
df[["x_coord", "y_coord", "z_coord"]].values
).float()
return positions

if return_coord_mask:
coord_mask = torch.zeros((num_residues, len(atoms_to_keep)))
coord_mask[residue_indices, atom_indices] = 1
return coords, coord_mask
else:
return positions


def to_dataframe(
Expand Down

0 comments on commit b02f1d7

Please sign in to comment.