diff --git a/README.md b/README.md index 9fba93f..71cec7e 100644 --- a/README.md +++ b/README.md @@ -185,6 +185,7 @@ The `batch` variable above is a `collections.namedtuple` that has the following | `batch.pids` | Tuple of ProteinNet/SidechainNet IDs for proteins in this batch | | `batch.seqs` | Tensor of sequences, either as integers or as one-hot vectors depending on value of `scn.load(... seq_as_onehot)` | | `batch.int_seqs` | Tensor of sequences in integer sequence format | + | `batch.str_seqs` | Tuple of sequences as strings (unpadded) | | `batch.msks` | Tensor of missing residue masks, (redundant with padding in data) | | `batch.evos` | Tensor of Position Specific Scoring Matrix + Information Content | | `batch.secs` | Tensor of secondary structure, either as integers or one-hot vectors depending on value of `scn.load(... seq_as_onehot)` | diff --git a/sidechainnet/dataloaders/ProteinDataset.py b/sidechainnet/dataloaders/ProteinDataset.py index b3a6b09..285d38d 100644 --- a/sidechainnet/dataloaders/ProteinDataset.py +++ b/sidechainnet/dataloaders/ProteinDataset.py @@ -19,6 +19,7 @@ def __init__(self, # Organize data self.seqs = [VOCAB.str2ints(s, add_sos_eos) for s in scn_data_split['seq']] + self.str_seqs = scn_data_split['seq'] self.angs = scn_data_split['ang'] self.crds = scn_data_split['crd'] self.msks = [ @@ -50,6 +51,7 @@ def _sort_by_length(self, reverse_sort): enumerate(self.angs), key=lambda x: x[1].shape[0], reverse=reverse_sort) ] self.seqs = [self.seqs[i] for i in sorted_len_indices] + self.str_seqs = [self.str_seqs[i] for i in sorted_len_indices] self.angs = [self.angs[i] for i in sorted_len_indices] self.crds = [self.crds[i] for i in sorted_len_indices] self.msks = [self.msks[i] for i in sorted_len_indices] @@ -65,7 +67,7 @@ def __len__(self): def __getitem__(self, idx): return (self.ids[idx], self.seqs[idx], self.msks[idx], self.evos[idx], self.secs[idx], self.angs[idx], self.crds[idx], self.resolutions[idx], - self.mods[idx]) + self.mods[idx], self.str_seqs[idx]) def __str__(self): """Describe this dataset to the user.""" diff --git a/sidechainnet/dataloaders/collate.py b/sidechainnet/dataloaders/collate.py index 62052f0..b3b9980 100644 --- a/sidechainnet/dataloaders/collate.py +++ b/sidechainnet/dataloaders/collate.py @@ -15,7 +15,7 @@ Batch = collections.namedtuple("Batch", "pids seqs msks evos secs angs " "crds int_seqs seq_evo_sec resolutions is_modified " - "lengths") + "lengths str_seqs") def get_collate_fn(aggregate_input, seqs_as_onehot=None): @@ -60,7 +60,7 @@ def collate_fn(insts): """ # Instead of working with a list of tuples, we extract out each category of info # so it can be padded and re-provided to the user. - pnids, sequences, masks, pssms, secs, angles, coords, resolutions, mods = list(zip(*insts)) + pnids, sequences, masks, pssms, secs, angles, coords, resolutions, mods, str_seqs = list(zip(*insts)) lengths = tuple(len(s) for s in sequences) max_batch_len = max(lengths) @@ -98,7 +98,8 @@ def collate_fn(insts): seq_evo_sec=None, resolutions=resolutions, is_modified=padded_mods, - lengths=lengths) + lengths=lengths, + str_seqs=str_seqs) # Aggregated model input elif aggregate_input: @@ -117,7 +118,8 @@ def collate_fn(insts): seq_evo_sec=seq_evo_sec, resolutions=resolutions, is_modified=padded_mods, - lengths=lengths) + lengths=lengths, + str_seqs=str_seqs) return collate_fn