diff --git a/graphein/protein/tensor/io.py b/graphein/protein/tensor/io.py index 58089158..4bd139c5 100644 --- a/graphein/protein/tensor/io.py +++ b/graphein/protein/tensor/io.py @@ -49,7 +49,7 @@ conda_channel="pyg", pip_install=True, ) - log.debug(message) + log.warning(message) try: import torch @@ -60,7 +60,7 @@ conda_channel="pytorch", pip_install=True, ) - log.debug(message) + log.warning(message) def get_protein_length(df: pd.DataFrame, insertions: bool = True) -> int: @@ -108,8 +108,8 @@ def protein_to_pyg( atom_types: List[str] = PROTEIN_ATOMS, remove_nonstandard: bool = True, store_het: bool = False, - store_bfactor: bool = False, fill_value_coords: float = 1e-5, + return_coord_mask: bool = False, ) -> Data: """ Parses a protein (from either: a PDB code, PDB file or a UniProt ID @@ -160,12 +160,9 @@ def protein_to_pyg( :param store_het: Whether or not to store heteroatoms in the ``Data`` object. Default is ``False``. :type store_het: bool - :param store_bfactor: Whether or not to store bfactors in the ``Data`` - object. Default is ``False. - :type store_bfactor: bool - :param fill_value_coords: Fill value to use for positions in atom37 - representation that are not filled. Defaults to 1e-5 - :type fill_value_coords: float + :param return_coord_mask: Whether to include the coordinate mask as a feature. Default is + ``False``. + :type keep_insertions: bool :returns: ``Data`` object with attributes: ``x`` (AtomTensor), ``residues`` (list of 3-letter residue codes), id (ID of protein), residue_id (E.g. ``"A:SER:1"``), residue_type (torch.Tensor), ``chains`` (torch.Tensor). @@ -245,9 +242,6 @@ def protein_to_pyg( df["residue_id"] = df.residue_id + ":" + df.insertion out = Data( - coords=protein_df_to_tensor( - df, atoms_to_keep=atom_types, fill_value=fill_value_coords - ), residues=get_sequence( df, chains=chain_selection, @@ -259,14 +253,17 @@ def protein_to_pyg( residue_type=residue_type_tensor(df), chains=protein_df_to_chain_tensor(df), ) + if return_coord_mask: + coords,coord_mask=protein_df_to_tensor( + df, atoms_to_keep=atom_types, fill_value=fill_value_coords, return_coord_mask=return_coord_mask + ) + else: + coords=protein_df_to_tensor( + df, atoms_to_keep=atom_types, fill_value=fill_value_coords + ) + if store_het: out.hetatms = [het_coords] - - if store_bfactor: - # group by residue_id and average b_factor per residue - residue_bfactors = df.groupby("residue_id")["b_factor"].mean() - out.bfactor = torch.from_numpy(residue_bfactors.values) - return out @@ -330,6 +327,7 @@ def protein_df_to_tensor( atoms_to_keep: List[str] = PROTEIN_ATOMS, insertions: bool = True, fill_value: float = 1e-5, + return_coord_mask: bool = False ) -> AtomTensor: """ Transforms a DataFrame of a protein structure into a @@ -344,14 +342,14 @@ def protein_df_to_tensor( :type insertions: bool :param fill_value: Value to fill missing entries with. Defaults to ``1e-5``. :type fill_value: float - :returns: ``Length x Num_Atoms (default 37) x 3`` tensor. + :param return_coord_mask: Whether to return the coord mask created. Defaults to ``False``. + :type insertions: bool + :returns: ``Length x Num_Atoms (default 37) x 3`` tensor and, if return_coord_mask==True, also the coord_mask tensor. :rtype: graphein.protein.tensor.types.AtomTensor """ num_residues = get_protein_length(df, insertions=insertions) df = df.loc[df["atom_name"].isin(atoms_to_keep)] - residue_indices = pd.factorize( - pd.Series(get_residue_id(df, unique=False)) - )[0] + residue_indices = pd.factorize(get_residue_id(df, unique=False))[0] atom_indices = df["atom_name"].map(lambda x: atoms_to_keep.index(x)).values positions: AtomTensor = ( @@ -360,7 +358,13 @@ def protein_df_to_tensor( positions[residue_indices, atom_indices] = torch.tensor( df[["x_coord", "y_coord", "z_coord"]].values ).float() - return positions + + if return_coord_mask: + coord_mask = torch.zeros((num_residues, len(atoms_to_keep))) + coord_mask[residue_indices, atom_indices] = 1 + return coords, coord_mask + else: + return positions def to_dataframe(