Skip to content

Commit

Permalink
perf(call): minor call optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Jun 6, 2024
1 parent 3d4817e commit 429481b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
13 changes: 7 additions & 6 deletions strkit/call/allele.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
expansion_ratio = 5
N_GM_INIT = 3

WEIGHT_1_0 = np.array([[1.0]])
FLOAT_32_EPSILON = np.finfo(np.float32).eps

CI_PERCENTILE_RANGES = {
Expand Down Expand Up @@ -101,7 +102,7 @@ def fit_gmm(
# I've confirmed this gives an ~identical result to fitting a GMM with one parameter.
fake_g: object = type("", (), {})()
fake_g.means_ = np.array([[np.mean(sample_rs)]])
fake_g.weights_ = np.array([[1.0]])
fake_g.weights_ = WEIGHT_1_0
fake_g.covariances_ = np.array([[np.var(sample_rs)]])
return fake_g

Expand Down Expand Up @@ -198,18 +199,18 @@ def call_alleles(
logger_.debug(f"{debug_str} - skipping bootstrap / GMM fitting for allele(s) (single value)")
cn = combined_reads[0]

call = _array_as_int(np.array([cn] * n_alleles))
call_cis = _array_as_int(np.array([[cn, cn] for _ in range(n_alleles)]))
call = _array_as_int(np.full(n_alleles, cn))
call_cis = _array_as_int(np.full((n_alleles, 2), cn))

peaks: NDArray[np.float_] = np.array([cn] * n_alleles, dtype=np.float_)
peaks: NDArray[np.float_] = call.astype(np.float_)

return {
"call": call,
"call_95_cis": call_cis,
"call_99_cis": call_cis,
"peaks": peaks,
"peak_weights": np.array([1.0] * n_alleles) / n_alleles,
"peak_stdevs": np.array([0.0] * n_alleles),
"peak_weights": np.full(n_alleles, 1.0 / n_alleles),
"peak_stdevs": np.full(n_alleles, 0.0),
"modal_n_peaks": 1, # 1 peak, since we have 1 value
}

Expand Down
20 changes: 10 additions & 10 deletions strkit/call/call_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,13 @@ def debug_log_flanking_seq(logger_: logging.Logger, locus_log_str: str, rn: str,
f"{' (post-realignment)' if realigned else ''}")


def _ndarray_serialize(x: Iterable) -> list[Union[int, np.int_]]:
return list(map(round, x))

def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, np.int_]]]:
return list(map(_ndarray_serialize, x))


def call_locus(
t_idx: int,
t: tuple,
Expand Down Expand Up @@ -1028,8 +1035,7 @@ def call_locus(

if count_kmers != "none":
read_kmers.clear()
for i in range(0, tr_len - motif_size + 1):
read_kmers.update((tr_read_seq_wc[i:i+motif_size],))
read_kmers.update(tr_read_seq_wc[i:i+motif_size] for i in range(0, tr_len - motif_size + 1))

read_cn, read_cn_score = get_repeat_count(
start_count=round(tr_len / motif_size), # Set initial integer copy number based on aligned TR size
Expand Down Expand Up @@ -1172,8 +1178,8 @@ def call_locus(
have_rare_realigns: bool = False
for rn, read in read_dict_items:
read_cn = read["cn"]
n_same_cn_no_realign = sum(1 for _, r2 in read_dict_items if not r2.get("realn") and r2["cn"] == read_cn)
if read.get("realn") and n_same_cn_no_realign == 0:
if (read.get("realn") and
sum(1 for _, r2 in read_dict_items if not r2.get("realn") and r2["cn"] == read_cn) == 0):
have_rare_realigns = True
break

Expand Down Expand Up @@ -1390,12 +1396,6 @@ def call_locus(

# Compile the call into a dictionary with all information to return ------------------------------------------------

def _ndarray_serialize(x: Iterable) -> list[Union[int, np.int_]]:
return list(map(round, x))

def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, np.int_]]]:
return list(map(_ndarray_serialize, x))

call_val = apply_or_none(_ndarray_serialize, call)
call_95_cis_val = apply_or_none(_nested_ndarray_serialize, call_95_cis)

Expand Down

0 comments on commit 429481b

Please sign in to comment.