From f03e5d3f139fa5fc42ec517dfa4240b13e613d53 Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Thu, 16 Jan 2025 15:38:13 -0800
Subject: [PATCH] tests passing

---
 netam/dxsm.py           | 29 +++++++++++-------
 netam/framework.py      | 59 +++++++++++++++++++-----------------
 netam/sequences.py      | 66 +++++++++++++++++++++++++++++++++++++++++
 tests/test_ambiguous.py |  5 ++--
 tests/test_sequences.py | 32 ++++++++++++++++++++
 5 files changed, 150 insertions(+), 41 deletions(-)

diff --git a/netam/dxsm.py b/netam/dxsm.py
index 74117690..0d50aaf1 100644
--- a/netam/dxsm.py
+++ b/netam/dxsm.py
@@ -28,6 +28,7 @@
     apply_aa_mask_to_nt_sequence,
     nt_mutation_frequency,
     strip_unrecognized_tokens_from_series,
+    dataset_inputs_of_pcp_df,
     MAX_AA_TOKEN_IDX,
     RESERVED_TOKEN_REGEX,
     AA_AMBIG_IDX,
@@ -46,14 +47,13 @@ def __init__(
         branch_lengths: torch.Tensor,
         model_embedding_dim: int,
         multihit_model=None,
+        # TODO For debugging:
+        succeed=False,
     ):
+        assert succeed, "Dataset should be created through other constructor"
+        # This is no longer needed here, but it seems like we should still be
+        # able to verify which model version an instance is built for:
         self.model_embedding_dim = model_embedding_dim
-        nt_parents = strip_unrecognized_tokens_from_series(
-            nt_parents, self.model_embedding_dim
-        )
-        nt_children = strip_unrecognized_tokens_from_series(
-            nt_children, self.model_embedding_dim
-        )
+
         # We will replace reserved tokens with Ns but use the unmodified
         # originals for translation and mask creation.
         self.nt_parents = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True)
@@ -122,6 +122,7 @@ def of_seriess(
         model_embedding_dim,
         branch_length_multiplier=5.0,
         multihit_model=None,
+        succeed=False,
     ):
         """Alternative constructor that takes the raw data and calculates the initial
         branch lengths.
@@ -143,6 +144,7 @@ def of_seriess(
             initial_branch_lengths,
             model_embedding_dim,
             multihit_model=multihit_model,
+            succeed=succeed,
         )
 
     @classmethod
@@ -156,16 +158,19 @@ def of_pcp_df(
         """Alternative constructor that takes in a pcp_df and calculates the initial
         branch lengths."""
         assert (
-            "nt_rates" in pcp_df.columns
-        ), "pcp_df must have a neutral nt_rates column"
+            "nt_rates_l" in pcp_df.columns
+        ), "pcp_df must have a neutral nt_rates_l column"
+        # Use sequences.prepare_heavy_light_pair and the resulting
+        # added_indices to get the parent and child sequences and the neutral
+        # model outputs.
+
         return cls.of_seriess(
-            pcp_df["parent"],
-            pcp_df["child"],
-            pcp_df["nt_rates"],
-            pcp_df["nt_csps"],
+            *dataset_inputs_of_pcp_df(pcp_df, model_embedding_dim),
             model_embedding_dim,
             branch_length_multiplier=branch_length_multiplier,
            multihit_model=multihit_model,
+            succeed=True,
         )
 
     @classmethod
@@ -212,6 +217,7 @@ def clone(self):
             self._branch_lengths.copy(),
             self.model_embedding_dim,
             multihit_model=self.multihit_model,
+            succeed=True,
         )
         return new_dataset
@@ -230,6 +236,7 @@ def subset_via_indices(self, indices):
             self._branch_lengths[indices],
             self.model_embedding_dim,
             multihit_model=self.multihit_model,
+            succeed=True,
         )
         return new_dataset
 
diff --git a/netam/framework.py b/netam/framework.py
index b4db2d7e..7add749b 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -357,14 +357,11 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     return trimmed_rates, trimmed_csps
 
 
-def join_chains(pcp_df):
-    """Join the parent and child chains in the pcp_df.
-
-    Make a parent column that is the parent_h + "^^^" + parent_l, and same for child.
+def standardize_heavy_light_columns(pcp_df):
+    """Ensure that heavy and light chain columns are present, and fill missing ones
+    with placeholder values.
-    If parent_h and parent_l are not present, then we assume that the parent is the
-    heavy chain. If only one of parent_h or parent_l is present, then we place the ^^^
-    padding to the right of heavy, or to the left of light.
+    If only `parent` and `child` columns are present, we assume they are heavy chain
+    sequences.
     """
     cols = pcp_df.columns
     # Look for heavy chain
@@ -380,7 +377,7 @@
     else:
         pcp_df["parent_h"] = ""
         pcp_df["child_h"] = ""
-        pcp_df["v_gene_h"] = "N/A"
+        pcp_df["v_gene_h"] = ""
     # Look for light chain
     if "parent_l" in cols:
         assert "child_l" in cols, "child_l column missing!"
@@ -388,7 +385,7 @@
     else:
         pcp_df["parent_l"] = ""
         pcp_df["child_l"] = ""
-        pcp_df["v_gene_l"] = "N/A"
+        pcp_df["v_gene_l"] = ""
 
     if (pcp_df["parent_h"].str.len() + pcp_df["parent_l"].str.len()).min() < 3:
         raise ValueError("At least one PCP has fewer than three nucleotides.")
@@ -396,6 +393,28 @@
 
-    pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
-    pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]
-
+    pcp_df["v_family_h"] = pcp_df["v_gene_h"].str.split("-").str[0]
+    pcp_df["v_family_l"] = pcp_df["v_gene_l"].str.split("-").str[0]
+
+    pcp_df.drop(
+        columns=["parent", "child", "v_gene"],
+        inplace=True,
+        errors="ignore",
+    )
+    return pcp_df
+
+
+# TODO maybe call this from the dataset constructor now?
+def join_chains(pcp_df):
+    """Join the parent and child chains in the pcp_df.
+
+    Make a parent column that is parent_h + "^^^" + parent_l, and the same for child.
+    """
+    pcp_df = pcp_df.copy()
+
+    pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
+    pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]
+
     pcp_df.drop(
         columns=["parent_h", "parent_l", "child_h", "child_l", "v_gene"],
         inplace=True,
         errors="ignore",
@@ -419,12 +438,11 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
         .reset_index()
         .rename(columns={"index": "orig_pcp_idx"})
     )
-    pcp_df = join_chains(pcp_df)
-    pcp_df["v_family_h"] = pcp_df["v_gene_h"].str.split("-").str[0]
-    pcp_df["v_family_l"] = pcp_df["v_gene_l"].str.split("-").str[0]
+    pcp_df = standardize_heavy_light_columns(pcp_df)
 
     if chosen_v_families is not None:
         chosen_v_families = set(chosen_v_families)
+        # TODO Is this the right way to handle this, or should it be OR?
        pcp_df = pcp_df[
            pcp_df["v_family_h"].isin(chosen_v_families)
            & pcp_df["v_family_l"].isin(chosen_v_families)
        ]
@@ -436,21 +454,8 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
 
 
 def add_shm_model_outputs_to_pcp_df(pcp_df, crepe):
-    # Split parent heavy and light chains to apply neutral model separately
-    split_parents = pcp_df["parent"].str.split(pat="^^^", expand=True, regex=False)
-    # To keep prediction aligned to joined h/l sequence, pad parent
-    h_parents = split_parents[0] + "NNN"
-    l_parents = split_parents[1]
-
-    h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents)
-    l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents)
-    # Join predictions
-    pcp_df["nt_rates"] = [
-        torch.cat([h_rate, l_rate], dim=0) for h_rate, l_rate in zip(h_rates, l_rates)
-    ]
-    pcp_df["nt_csps"] = [
-        torch.cat([h_csp, l_csp], dim=0) for h_csp, l_csp in zip(h_csps, l_csps)
-    ]
+    pcp_df["nt_rates_h"], pcp_df["nt_csps_h"] = trimmed_shm_model_outputs_of_crepe(
+        crepe, pcp_df["parent_h"]
+    )
+    pcp_df["nt_rates_l"], pcp_df["nt_csps_l"] = trimmed_shm_model_outputs_of_crepe(
+        crepe, pcp_df["parent_l"]
+    )
     return pcp_df
 
 
diff --git a/netam/sequences.py b/netam/sequences.py
index 39dc71ad..6b1883ba 100644
--- a/netam/sequences.py
+++ b/netam/sequences.py
@@ -38,6 +38,7 @@
 RESERVED_TOKEN_REGEX = f"[{''.join(map(re.escape, list(RESERVED_TOKENS)))}]"
 
 
+# TODO maybe remove now?
 def token_regex_from_embedding_dim(embedding_dim: int) -> str:
     """Return a regex pattern that matches any token which cannot be handled by a model
     with the provided embedding dimension."""
@@ -49,6 +50,7 @@ def token_regex_from_embedding_dim(embedding_dim: int) -> str:
     return f"[{''.join(map(re.escape, list(unsupported_tokens)))}]"
 
 
+# TODO maybe remove now?
 def strip_unrecognized_tokens_from_series(
     series: pd.Series, embedding_dim: int
 ) -> pd.Series:
@@ -64,6 +66,70 @@ def strip_unrecognized_tokens_from_series(
     return series
 
 
+def prepare_heavy_light_pair(heavy_seq, light_seq, known_token_count, is_nt=True):
+    """Prepare a pair of heavy and light chain sequences for model input.
+
+    Args:
+        heavy_seq (str): The heavy chain sequence.
+        light_seq (str): The light chain sequence.
+        known_token_count (int): The number of tokens recognized by the model that
+            will take the result as input.
+        is_nt (bool): Whether the sequences are nucleotide sequences. Otherwise, they
+            are assumed to be amino acid sequences.
+
+    Returns:
+        The prepared sequence, and a tuple of indices indicating positions where
+        tokens were added to the prepared sequence.
+    """
+    # In the future, we'll define a list of functions that will be applied in
+    # order, up to the maximum number of accepted tokens.
+    if known_token_count > AA_AMBIG_IDX + 1:
+        if is_nt:
+            heavy_light_separator = "^^^"
+        else:
+            heavy_light_separator = "^"
+
+        prepared_seq = heavy_seq + heavy_light_separator + light_seq
+        added_indices = tuple(
+            range(len(heavy_seq), len(heavy_seq) + len(heavy_light_separator))
+        )
+    else:
+        prepared_seq = heavy_seq
+        added_indices = tuple()
+
+    return prepared_seq, added_indices
+
+
+def combine_and_pad_tensors(first, second, padding_idxs, fill=float("nan")):
+    """Concatenate two tensors along dimension 0, inserting rows of `fill` values at
+    the indices in `padding_idxs` of the result."""
+    res = torch.full(
+        (first.shape[0] + second.shape[0] + len(padding_idxs),) + first.shape[1:],
+        fill,
+    )
+    mask = torch.full((res.shape[0],), True, dtype=torch.bool)
+    # An explicit long dtype keeps the indexing valid when padding_idxs is empty.
+    mask[torch.tensor(padding_idxs, dtype=torch.long)] = False
+    res[mask] = torch.concat([first, second], dim=0)
+    return res
+
+
+# TODO test
+def dataset_inputs_of_pcp_df(pcp_df, known_token_count):
+    parents = []
+    children = []
+    nt_ratess = []
+    nt_cspss = []
+    for row in pcp_df.itertuples():
+        parent, parent_token_idxs = prepare_heavy_light_pair(
+            row.parent_h, row.parent_l, known_token_count, is_nt=True
+        )
+        child = prepare_heavy_light_pair(
+            row.child_h, row.child_l, known_token_count, is_nt=True
+        )[0]
+        # TODO It would be nice for these fill values to be nan, but there's
+        # lots of checking that would be made more difficult by that. These are
+        # the values that the neutral model returns when given N's.
+        nt_rates = combine_and_pad_tensors(
+            row.nt_rates_h, row.nt_rates_l, parent_token_idxs, fill=1.0
+        )
+        nt_csps = combine_and_pad_tensors(
+            row.nt_csps_h, row.nt_csps_l, parent_token_idxs, fill=0.0
+        )
+        parents.append(parent)
+        children.append(child)
+        nt_ratess.append(nt_rates)
+        nt_cspss.append(nt_csps)
+
+    return tuple(
+        map(
+            pd.Series,
+            (parents, children, nt_ratess, nt_cspss),
+        )
+    )
+
+
 def nt_idx_array_of_str(nt_str):
     """Return the indices of the nucleotides in a string."""
     try:
diff --git a/tests/test_ambiguous.py b/tests/test_ambiguous.py
index 1de0eb01..411d388a 100644
--- a/tests/test_ambiguous.py
+++ b/tests/test_ambiguous.py
@@ -122,10 +122,9 @@ def ambig_pcp_df():
         "data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz",
     )
     # Apply the random N adding function to each row
-    df[["parent", "child"]] = df.apply(
+    df[["parent_h", "child_h"]] = df.apply(
         lambda row: tuple(
-            seq + "^^^"
-            for seq in randomize_with_ns(row["parent"][:-3], row["child"][:-3])
+            randomize_with_ns(row["parent_h"][:-3], row["child_h"][:-3])
         ),
         axis=1,
         result_type="expand",
diff --git a/tests/test_sequences.py b/tests/test_sequences.py
index 6251b067..61534f32 100644
--- a/tests/test_sequences.py
+++ b/tests/test_sequences.py
@@ -11,6 +11,7 @@
     TOKEN_STR_SORTED,
     CODONS,
     CODON_AA_INDICATOR_MATRIX,
+    MAX_EMBEDDING_DIM,
     aa_onehot_tensor_of_str,
     nt_idx_array_of_str,
     nt_subs_indicator_tensor_of,
@@ -18,6 +19,8 @@
     token_mask_of_aa_idxs,
     aa_idx_tensor_of_str,
     strip_unrecognized_tokens_from_series,
+    prepare_heavy_light_pair,
+    combine_and_pad_tensors,
 )
 
 
@@ -44,6 +47,35 @@ def test_strip_unrecognized_tokens_from_series():
         seq = seq.replace(token, "")
         assert nseq == seq
 
+
+def test_prepare_heavy_light_pair():
+    for heavy, light in [
+        ("AGCGTC", "AGCGTC"),
+        ("AGCGTC", ""),
+        ("", "AGCGTC"),
+    ]:
+        assert prepare_heavy_light_pair(heavy, light, MAX_EMBEDDING_DIM) == (
+            heavy + "^^^" + light,
+            tuple(range(len(heavy), len(heavy) + 3)),
+        )
+
+    for heavy, light in [
+        ("QVQ", "QVQ"),
+        ("QVQ", ""),
+        ("", "QVQ"),
+    ]:
+        assert prepare_heavy_light_pair(
+            heavy, light, MAX_EMBEDDING_DIM, is_nt=False
+        ) == (heavy + "^" + light, tuple(range(len(heavy), len(heavy) + 1)))
+
+
+def test_combine_and_pad_tensors():
+    # Check that the function works with 1D tensors:
+    t1 = torch.tensor([1, 2, 3], dtype=torch.float)
+    t2 = torch.tensor([4, 5, 6], dtype=torch.float)
+    idxs = (0, 4, 5)
+    result = combine_and_pad_tensors(t1, t2, idxs)
+    mask = result.isnan()
+    assert torch.equal(
+        result[~mask], torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)
+    )
+    assert mask[torch.tensor(idxs)].all()
+
 
 def test_token_mask():
     sample_aa_seq = "QYX^QC"
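Reviewer note (illustrative only, not part of the patch): the sketch below shows how the pieces added here are meant to compose, from per-chain neutral-model outputs to joined dataset inputs. The column names, function signatures, and the MAX_EMBEDDING_DIM import mirror the diff above; the toy sequences and tensor values are invented, and the assertions assume MAX_EMBEDDING_DIM is large enough that the heavy/light separator is inserted.

import pandas as pd
import torch

from netam.sequences import dataset_inputs_of_pcp_df, MAX_EMBEDDING_DIM

# A one-row stand-in for a pcp_df, carrying per-chain sequences plus the
# per-chain neutral model outputs that add_shm_model_outputs_to_pcp_df now
# attaches (one rate per site, one CSP row of 4 bases per site).
pcp_df = pd.DataFrame(
    {
        "parent_h": ["AGCGTC"],
        "child_h": ["AGCGTG"],
        "parent_l": ["TTTGGG"],
        "child_l": ["TTTGGC"],
        "nt_rates_h": [torch.ones(6)],
        "nt_csps_h": [torch.zeros(6, 4)],
        "nt_rates_l": [torch.ones(6)],
        "nt_csps_l": [torch.zeros(6, 4)],
    }
)

parents, children, nt_rates, nt_csps = dataset_inputs_of_pcp_df(
    pcp_df, MAX_EMBEDDING_DIM
)

# The "^^^" separator is inserted into the joined sequence, and the tensors
# are padded at the separator positions with the neutral fill values
# (1.0 for rates, 0.0 for CSPs), so positions stay aligned.
assert parents[0] == "AGCGTC^^^TTTGGG"
assert nt_rates[0].shape == (15,)
assert nt_csps[0].shape == (15, 4)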