diff --git a/graphein/protein/tensor/io.py b/graphein/protein/tensor/io.py index 4bd139c5..a9e96473 100644 --- a/graphein/protein/tensor/io.py +++ b/graphein/protein/tensor/io.py @@ -108,6 +108,7 @@ def protein_to_pyg( atom_types: List[str] = PROTEIN_ATOMS, remove_nonstandard: bool = True, store_het: bool = False, + store_bfactor: bool = False, fill_value_coords: float = 1e-5, return_coord_mask: bool = False, ) -> Data: @@ -163,6 +164,12 @@ def protein_to_pyg( :param return_coord_mask: Whether to include the coordinate mask as a feature. Default is ``False``. :type keep_insertions: bool + :param store_bfactor: Whether or not to store bfactors in the ``Data`` + object. Default is ``False``. + :type store_bfactor: bool + :param return_coord_mask: Whether to include the coordinate mask as a feature. Default is + ``False``. + :type return_coord_mask: bool :returns: ``Data`` object with attributes: ``x`` (AtomTensor), ``residues`` (list of 3-letter residue codes), id (ID of protein), residue_id (E.g. ``"A:SER:1"``), residue_type (torch.Tensor), ``chains`` (torch.Tensor). @@ -253,17 +260,24 @@ def protein_to_pyg( residue_type=residue_type_tensor(df), chains=protein_df_to_chain_tensor(df), ) + if return_coord_mask: - coords,coord_mask=protein_df_to_tensor( + out.coords,out.coord_mask=protein_df_to_tensor( df, atoms_to_keep=atom_types, fill_value=fill_value_coords, return_coord_mask=return_coord_mask ) else: - coords=protein_df_to_tensor( + out.coords=protein_df_to_tensor( df, atoms_to_keep=atom_types, fill_value=fill_value_coords ) if store_het: out.hetatms = [het_coords] + + if store_bfactor: + # group by residue_id and average b_factor per residue + residue_bfactors = df.groupby("residue_id")["b_factor"].mean() + out.bfactor = torch.from_numpy(residue_bfactors.values) + return out