tests passing
willdumm committed Jan 16, 2025
1 parent 8492e99 commit f03e5d3
Showing 5 changed files with 150 additions and 41 deletions.
29 changes: 18 additions & 11 deletions netam/dxsm.py
@@ -28,6 +28,7 @@
apply_aa_mask_to_nt_sequence,
nt_mutation_frequency,
strip_unrecognized_tokens_from_series,
dataset_inputs_of_pcp_df,
MAX_AA_TOKEN_IDX,
RESERVED_TOKEN_REGEX,
AA_AMBIG_IDX,
@@ -46,14 +47,13 @@ def __init__(
branch_lengths: torch.Tensor,
model_embedding_dim: int,
multihit_model=None,
# TODO For debugging:
succeed=False,
):
assert succeed, "Dataset should be created through another constructor"
# This is no longer needed here, but it seems like we should be able to verify what model version an instance is built for anyway:
self.model_embedding_dim = model_embedding_dim
nt_parents = strip_unrecognized_tokens_from_series(
nt_parents, self.model_embedding_dim
)
nt_children = strip_unrecognized_tokens_from_series(
nt_children, self.model_embedding_dim
)

# We will replace reserved tokens with Ns but use the unmodified
# originals for translation and mask creation.
self.nt_parents = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True)
@@ -122,6 +122,7 @@ def of_seriess(
model_embedding_dim,
branch_length_multiplier=5.0,
multihit_model=None,
succeed=False,
):
"""Alternative constructor that takes the raw data and calculates the initial
branch lengths.
@@ -143,6 +144,7 @@ def of_seriess(
initial_branch_lengths,
model_embedding_dim,
multihit_model=multihit_model,
succeed=succeed,
)

@classmethod
@@ -156,16 +158,19 @@ def of_pcp_df(
"""Alternative constructor that takes in a pcp_df and calculates the initial
branch lengths."""
assert (
"nt_rates" in pcp_df.columns
"nt_rates_l" in pcp_df.columns
), "pcp_df must have a neutral nt_rates column"
# use sequences.prepare_heavy_light_pair and the resulting
# added_indices to get the parent and child sequences and neutral model
# outputs


return cls.of_seriess(
pcp_df["parent"],
pcp_df["child"],
pcp_df["nt_rates"],
pcp_df["nt_csps"],
*dataset_inputs_of_pcp_df(pcp_df, model_embedding_dim),
model_embedding_dim,
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
succeed=True,
)

@classmethod
Expand Down Expand Up @@ -212,6 +217,7 @@ def clone(self):
self._branch_lengths.copy(),
self.model_embedding_dim,
multihit_model=self.multihit_model,
succeed=True,
)
return new_dataset

@@ -230,6 +236,7 @@ def subset_via_indices(self, indices):
self._branch_lengths[indices],
self.model_embedding_dim,
multihit_model=self.multihit_model,
succeed=True,
)
return new_dataset

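
The net effect of the dxsm.py changes above is that datasets are now built from per-chain pcp_df columns via dataset_inputs_of_pcp_df rather than from pre-joined parent/child series, with the temporary `succeed` flag asserting that callers go through of_pcp_df or of_seriess instead of the raw constructor. A minimal sketch of the intended call path, not part of this commit: ConcreteDXSMDataset is a hypothetical stand-in for whichever Dataset subclass in netam/dxsm.py is in use, and the argument order follows the of_seriess call shown in this hunk, so check the actual signature.

# Sketch only. Assumes pcp_df already carries parent_h/parent_l, child_h/child_l,
# nt_rates_h/nt_rates_l, and nt_csps_h/nt_csps_l columns (see framework.py below).
dataset = ConcreteDXSMDataset.of_pcp_df(
    pcp_df,
    model_embedding_dim,          # embedding dim of the model this dataset will feed
    branch_length_multiplier=5.0,
    multihit_model=None,
)
# Calling ConcreteDXSMDataset(...) directly now trips the `assert succeed`
# debugging guard unless succeed=True is passed, which of_pcp_df and
# of_seriess do internally.
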
59 changes: 32 additions & 27 deletions netam/framework.py
@@ -357,14 +357,11 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
return trimmed_rates, trimmed_csps


def join_chains(pcp_df):
"""Join the parent and child chains in the pcp_df.
Make a parent column that is the parent_h + "^^^" + parent_l, and same for child.
def standardize_heavy_light_columns(pcp_df):
"""Ensure that heavy and light chain columns are present, and fill missing ones
with placeholder values.
If parent_h and parent_l are not present, then we assume that the parent is the
heavy chain. If only one of parent_h or parent_l is present, then we place the ^^^
padding to the right of heavy, or to the left of light.
If only `parent` and `child` columns are present, we assume these are heavy chain sequences.
"""
cols = pcp_df.columns
# Look for heavy chain
@@ -380,22 +377,44 @@ def join_chains(pcp_df):
else:
pcp_df["parent_h"] = ""
pcp_df["child_h"] = ""
pcp_df["v_gene_h"] = "N/A"
pcp_df["v_gene_h"] = ""
# Look for light chain
if "parent_l" in cols:
assert "child_l" in cols, "child_l column missing!"
assert "v_gene_l" in cols, "v_gene_l column missing!"
else:
pcp_df["parent_l"] = ""
pcp_df["child_l"] = ""
pcp_df["v_gene_l"] = "N/A"
pcp_df["v_gene_l"] = ""

if (pcp_df["parent_h"].str.len() + pcp_df["parent_l"].str.len()).min() < 3:
raise ValueError("At least one PCP has fewer than three nucleotides.")

pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]

pcp_df["v_family_h"] = pcp_df["v_gene_h"].str.split("-").str[0]
pcp_df["v_family_l"] = pcp_df["v_gene_l"].str.split("-").str[0]

pcp_df.drop(
columns=["parent", "child", "v_gene"],
inplace=True,
errors="ignore",
)
return pcp_df
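
To make the new split concrete: standardize_heavy_light_columns only guarantees that the per-chain columns exist, filling absent ones with empty strings and deriving v_family_h / v_family_l, while joining into a single "^^^"-separated sequence is left to later steps. A small illustration, a sketch only and not part of this diff, using a made-up heavy-chain-only frame:

import pandas as pd
from netam.framework import standardize_heavy_light_columns

# Hypothetical input with no light-chain columns.
pcp_df = pd.DataFrame(
    {
        "parent_h": ["ATGAGTCAA"],
        "child_h": ["ATGACTCAA"],
        "v_gene_h": ["IGHV1-69*01"],
    }
)
pcp_df = standardize_heavy_light_columns(pcp_df)
# parent_l, child_l, and v_gene_l are now "", v_family_h is "IGHV1", and any
# leftover joined "parent"/"child"/"v_gene" columns would have been dropped.
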


# TODO maybe call this from the dataset constructor now?
def join_chains(pcp_df):
"""Join the parent and child chains in the pcp_df.
Make a parent column that is the parent_h + "^^^" + parent_l, and same for child.
"""
pcp_df = pcp_df.copy()

pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]

pcp_df.drop(
columns=["parent_h", "parent_l", "child_h", "child_l", "v_gene"],
inplace=True,
@@ -419,12 +438,11 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
.reset_index()
.rename(columns={"index": "orig_pcp_idx"})
)
pcp_df = join_chains(pcp_df)

pcp_df["v_family_h"] = pcp_df["v_gene_h"].str.split("-").str[0]
pcp_df["v_family_l"] = pcp_df["v_gene_l"].str.split("-").str[0]
pcp_df = standardize_heavy_light_columns(pcp_df)
if chosen_v_families is not None:
chosen_v_families = set(chosen_v_families)
# TODO is this the right way to handle this? Or should it be OR?
pcp_df = pcp_df[
pcp_df["v_family_h"].isin(chosen_v_families)
& pcp_df["v_family_l"].isin(chosen_v_families)
@@ -436,21 +454,8 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):


def add_shm_model_outputs_to_pcp_df(pcp_df, crepe):
# Split parent heavy and light chains to apply neutral model separately
split_parents = pcp_df["parent"].str.split(pat="^^^", expand=True, regex=False)
# To keep prediction aligned to joined h/l sequence, pad parent
h_parents = split_parents[0] + "NNN"
l_parents = split_parents[1]

h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents)
l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents)
# Join predictions
pcp_df["nt_rates"] = [
torch.cat([h_rate, l_rate], dim=0) for h_rate, l_rate in zip(h_rates, l_rates)
]
pcp_df["nt_csps"] = [
torch.cat([h_csp, l_csp], dim=0) for h_csp, l_csp in zip(h_csps, l_csps)
]
pcp_df["nt_rates_h"], pcp_df["nt_csps_h"] = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent_h"])
pcp_df["nt_rates_l"], pcp_df["nt_csps_l"] = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent_l"])
return pcp_df
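
With this rewrite the neutral model is applied to each chain separately, so the dataframe gains per-chain columns (nt_rates_h, nt_csps_h, nt_rates_l, nt_csps_l) instead of the single joined nt_rates / nt_csps columns; stitching them back together along the "^^^"-joined parent now happens in dataset_inputs_of_pcp_df (netam/sequences.py, below). A hedged sanity-check sketch, assuming trimmed_shm_model_outputs_of_crepe returns one tensor per sequence, trimmed to that sequence's length, as its earlier use suggests:

row = pcp_df.iloc[0]
# Each cell holds a tensor aligned to that chain's parent sequence.
assert len(row["nt_rates_h"]) == len(row["parent_h"])
assert len(row["nt_csps_l"]) == len(row["parent_l"])
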


66 changes: 66 additions & 0 deletions netam/sequences.py
@@ -38,6 +38,7 @@
RESERVED_TOKEN_REGEX = f"[{''.join(map(re.escape, list(RESERVED_TOKENS)))}]"


# TODO maybe remove now?
def token_regex_from_embedding_dim(embedding_dim: int) -> str:
"""Return a regex pattern that matches any token which cannot be handled by a model
with the provided embedding dimension."""
@@ -49,6 +50,7 @@ def token_regex_from_embedding_dim(embedding_dim: int) -> str:
return f"[{''.join(map(re.escape, list(unsupported_tokens)))}]"


# TODO maybe remove now?
def strip_unrecognized_tokens_from_series(
series: pd.Series, embedding_dim: int
) -> pd.Series:
@@ -64,6 +66,70 @@ def strip_unrecognized_tokens_from_series(
return series


def prepare_heavy_light_pair(heavy_seq, light_seq, known_token_count, is_nt=True):
"""Prepare a pair of heavy and light chain sequences for model input.
Args:
heavy_seq (str): The heavy chain sequence.
light_seq (str): The light chain sequence.
known_token_count (int): The number of tokens recognized by the model which will take the result as input.
is_nt (bool): Whether the sequences are nucleotide sequences. Otherwise, they
are assumed to be amino acid sequences.
Returns:
The prepared sequence, and a tuple of indices indicating positions where tokens were added to the prepared sequence."""
# In the future, we'll define a list of functions that will be applied in
# order, up to the maximum number of accepted tokens.
if known_token_count > AA_AMBIG_IDX + 1:
if is_nt:
heavy_light_separator = "^^^"
else:
heavy_light_separator = "^"

prepared_seq = heavy_seq + heavy_light_separator + light_seq
added_indices = tuple(range(len(heavy_seq), len(heavy_seq) + len(heavy_light_separator)))
else:
prepared_seq = heavy_seq
added_indices = tuple()

return prepared_seq, added_indices

def combine_and_pad_tensors(first, second, padding_idxs, fill=float("nan")):
res = torch.full((first.shape[0] + second.shape[0] + len(padding_idxs),) + first.shape[1:], fill)
mask = torch.full((res.shape[0],), True, dtype=torch.bool)
mask[torch.tensor(padding_idxs)] = False
res[mask] = torch.concat([first, second], dim=0)
return res


# TODO test
def dataset_inputs_of_pcp_df(pcp_df, known_token_count):
parents = []
children = []
nt_ratess = []
nt_cspss = []
for row in pcp_df.itertuples():
parent, parent_token_idxs = prepare_heavy_light_pair(
row.parent_h, row.parent_l, known_token_count, is_nt=True
)
child = prepare_heavy_light_pair(row.child_h, row.child_l, known_token_count, is_nt=True)[0]
# TODO It would be nice for these fill values to be nan, but there's
# lots of checking that would be made more difficult by that. These are
# the values that the neutral model returns when given N's.
nt_rates = combine_and_pad_tensors(row.nt_rates_h, row.nt_rates_l, parent_token_idxs, fill=1.0)
nt_csps = combine_and_pad_tensors(row.nt_csps_h, row.nt_csps_l, parent_token_idxs, fill=0.0)
parents.append(parent)
children.append(child)
nt_ratess.append(nt_rates)
nt_cspss.append(nt_csps)

return tuple(map(pd.Series, (
parents,
children,
nt_ratess,
nt_cspss,
)))
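
dataset_inputs_of_pcp_df is the piece that recombines those per-chain columns into series aligned with the joined parents. A short usage sketch, not part of this diff, assuming pcp_df has been through standardize_heavy_light_columns and add_shm_model_outputs_to_pcp_df, and that known_token_count is large enough for the heavy/light separator to be added:

parents, children, nt_ratess, nt_cspss = dataset_inputs_of_pcp_df(
    pcp_df, known_token_count
)
# Each rates/csps tensor lines up position-for-position with its "^^^"-joined
# parent; the separator positions carry the neutral fill values (1.0 / 0.0).
for parent, nt_rates in zip(parents, nt_ratess):
    assert len(nt_rates) == len(parent)
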


def nt_idx_array_of_str(nt_str):
"""Return the indices of the nucleotides in a string."""
try:
5 changes: 2 additions & 3 deletions tests/test_ambiguous.py
@@ -122,10 +122,9 @@ def ambig_pcp_df():
"data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz",
)
# Apply the random N adding function to each row
df[["parent", "child"]] = df.apply(
df[["parent_h", "child_h"]] = df.apply(
lambda row: tuple(
seq + "^^^"
for seq in randomize_with_ns(row["parent"][:-3], row["child"][:-3])
randomize_with_ns(row["parent_h"][:-3], row["child_h"][:-3])
),
axis=1,
result_type="expand",
32 changes: 32 additions & 0 deletions tests/test_sequences.py
@@ -11,13 +11,16 @@
TOKEN_STR_SORTED,
CODONS,
CODON_AA_INDICATOR_MATRIX,
MAX_EMBEDDING_DIM,
aa_onehot_tensor_of_str,
nt_idx_array_of_str,
nt_subs_indicator_tensor_of,
translate_sequences,
token_mask_of_aa_idxs,
aa_idx_tensor_of_str,
strip_unrecognized_tokens_from_series,
prepare_heavy_light_pair,
combine_and_pad_tensors,
)


@@ -44,6 +47,35 @@ def test_strip_unrecognized_tokens_from_series():
seq = seq.replace(token, "")
assert nseq == seq

def test_prepare_heavy_light_pair():
heavy = "AGCGTC"
light = "AGCGTC"
for heavy, light in [
("AGCGTC", "AGCGTC"),
("AGCGTC", ""),
("", "AGCGTC"),
]:
assert prepare_heavy_light_pair(heavy, light, MAX_EMBEDDING_DIM) == (heavy + "^^^" + light, tuple(range(len(heavy), len(heavy) + 3)))

heavy = "QVQ"
light = "QVQ"
for heavy, light in [
("QVQ", "QVQ"),
("QVQ", ""),
("", "QVQ"),
]:
assert prepare_heavy_light_pair(heavy, light, MAX_EMBEDDING_DIM, is_nt=False) == (heavy + "^" + light, tuple(range(len(heavy), len(heavy) + 1)))


def test_combine_and_pad_tensors():
# Test that function works with 1d tensors:
t1 = torch.tensor([1, 2, 3], dtype=torch.float)
t2 = torch.tensor([4, 5, 6], dtype=torch.float)
idxs = (0, 4, 5)
result = combine_and_pad_tensors(t1, t2, idxs)
mask = result.isnan()
assert torch.equal(result[~mask], torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float))
assert all(mask[torch.tensor(idxs)])

def test_token_mask():
sample_aa_seq = "QYX^QC"
