Skip to content

Commit

Permalink
fix dataset split
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Jan 22, 2025
1 parent b598353 commit 9111b8b
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 20 deletions.
8 changes: 3 additions & 5 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,9 @@ def build_selection_matrix_from_parent_aa(
per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
).squeeze(0)

# TODO I'm not sure if this is still used anywhere. It would be best to
# just have it take aa strings, but that's not what it did before, so I'm
# keeping the original behavior for now. (although, the old docstring
# claimed incorrectly that it took an aa sequence)
def build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
# This is not used anywhere, except for in a few tests. Keeping it around
# for that reason.
def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
"""Build a selection matrix from a parent nucleotide sequence, a heavy-chain,
light-chain pair.
Expand Down
2 changes: 1 addition & 1 deletion netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def build_selection_matrix_from_parent_aa(
selection_factors, aa_parent_idxs
)

def build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
"""Build a selection matrix from a nucleotide sequence.
Values at ambiguous sites are meaningless.
Expand Down
6 changes: 0 additions & 6 deletions netam/dxsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,9 @@ def selection_factors_of_aa_idxs(self, aa_idxs, aa_mask):
aa_idxs and aa_mask are expected to be as prepared in the Dataset constructor.
"""

# We need the model to see special tokens here. For every other purpose
# they are masked out.
# TODO for testing
keep_token_mask = aa_mask | token_mask_of_aa_idxs(aa_idxs)
# keep_token_mask = aa_mask
return self.model(aa_idxs, keep_token_mask)

def _find_optimal_branch_length(
Expand All @@ -322,9 +319,6 @@ def _find_optimal_branch_length(
# Masks may be padded at end to account for sequences of different
# lengths. The first part of the mask up to parent length should be
# all the valid bits for the sequence.

# TODO Why does aa_mask length not match nt_rates length? Shouldn't
# they be padded by the same amount?
trimmed_aa_mask = aa_mask[: len(parent)]
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
sel_matrix[aa_mask],
Expand Down
1 change: 0 additions & 1 deletion netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
pcp_df = standardize_heavy_light_columns(pcp_df)
if chosen_v_families is not None:
chosen_v_families = set(chosen_v_families)
# TODO is this the right way to handle this? Or should it be OR?
pcp_df = pcp_df[
pcp_df["v_family_h"].isin(chosen_v_families)
| pcp_df["v_family_l"].isin(chosen_v_families)
Expand Down
4 changes: 1 addition & 3 deletions netam/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def dataset_inputs_of_pcp_df(pcp_df, known_token_count):
child = prepare_heavy_light_pair(
row.child_h, row.child_l, known_token_count, is_nt=True
)[0]
# TODO It would be nice for these fill values to be nan, but there's
# lots of checking that would be made more difficult by that. These are
# the values that the neutral model returns when given N's.
# These are the fill values that the neutral model returns when given N's:
nt_rates = combine_and_pad_tensors(
row.nt_rates_h, row.nt_rates_l, parent_token_idxs, fill=1.0
)[: len(parent)]
Expand Down
4 changes: 1 addition & 3 deletions tests/test_dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
torch.set_printoptions(precision=10)


# TODO verify that this loops through both pcp_dfs, even though one is named
# the same as the argument. If not, remember to fix in test_dnsm.py too.
@pytest.fixture(scope="module", params=["pcp_df", "pcp_df_paired"])
def dasm_burrito(pcp_df):
force_spawn()
Expand Down Expand Up @@ -146,7 +144,7 @@ def test_build_selection_matrix_from_parent(dasm_burrito):
parent_aa_idxs, aa_mask
)

indirect_val = dasm_burrito.build_selection_matrix_from_parent(
indirect_val = dasm_burrito._build_selection_matrix_from_parent(
(light_chain_seq, heavy_chain_seq)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_build_selection_matrix_from_parent(dnsm_burrito):
parent_aa_idxs, aa_mask
)

indirect_val = dnsm_burrito.build_selection_matrix_from_parent(
indirect_val = dnsm_burrito._build_selection_matrix_from_parent(
(light_chain_seq, heavy_chain_seq)
)

Expand Down

0 comments on commit 9111b8b

Please sign in to comment.