From 77b9aa9c821331e0d756d709c212133903f450e5 Mon Sep 17 00:00:00 2001 From: David Lougheed Date: Sat, 30 Jul 2022 10:01:05 -0400 Subject: [PATCH] feat: more compact JSON call representation --- strkit/call/allele.py | 2 +- strkit/call/caller.py | 32 ++++++++++++++++--------------- strkit/mi/strkit.py | 8 ++++---- strkit/viz/templates/browser.html | 18 ++++++++++------- 4 files changed, 33 insertions(+), 27 deletions(-) diff --git a/strkit/call/allele.py b/strkit/call/allele.py index 4d516eb..e2eb07b 100644 --- a/strkit/call/allele.py +++ b/strkit/call/allele.py @@ -227,7 +227,7 @@ def _na_length_list(): medians_of_means_final = np.rint(medians_of_means).astype(np.int32) medians_of_weights = np.percentile(allele_weight_samples, 50, axis=1, interpolation="nearest") medians_of_stdevs = np.percentile(allele_stdev_samples, 50, axis=1, interpolation="nearest") - modal_n_peaks = statistics.mode(sample_peaks) + modal_n_peaks = statistics.mode(sample_peaks).item() return { "call": medians_of_means_final.flatten(), diff --git a/strkit/call/caller.py b/strkit/call/caller.py index 66a58f5..72fcfaf 100644 --- a/strkit/call/caller.py +++ b/strkit/call/caller.py @@ -395,8 +395,7 @@ def call_locus( if r_offset > 0: right_coord += max(0, r_offset) - read_cn_dict = {} - read_weight_dict = {} + read_dict = {} overlapping_segments = [ segment @@ -494,8 +493,6 @@ def call_locus( log_debug(f"Skipping read {segment.query_name} (scored {read_adj_score} < {min_read_score})") continue - read_cn_dict[segment.query_name] = read_cn - # When we don't have targeted sequencing, the probability of a read containing the TR region, given that it # overlaps the region, is P(read is large enough to contain) * P( # TODO: complete this.. partition_idx = np.searchsorted(sorted_read_lengths, tr_len_w_flank, side="right") @@ -506,18 +503,23 @@ def call_locus( f"Something strange happened; could not find an encompassing read where one should be guaranteed. " f"TRF row: {t}; TR length with flank: {tr_len_w_flank}; read lengths: {sorted_read_lengths}") exit(1) + mean_containing_size = read_len if targeted else np.mean(sorted_read_lengths[partition_idx:]) # TODO: re-examine weighting to possibly incorporate chance of drawing read large enough - read_weight_dict[segment.query_name] = ( - (mean_containing_size + tr_len_w_flank - 2) / (mean_containing_size - tr_len_w_flank + 1)) + read_weight = (mean_containing_size + tr_len_w_flank - 2) / (mean_containing_size - tr_len_w_flank + 1) + + read_dict[segment.query_name] = { + "cn": read_cn, + "weight": read_weight, + } n_alleles = get_n_alleles(2, sex_chroms, contig) if n_alleles is None: return None # Dicts are ordered in Python; very nice :) - read_cns = np.fromiter(read_cn_dict.values(), dtype=np.float if fractional else np.int) - read_weights = np.fromiter(read_weight_dict.values(), dtype=np.float) + read_cns = np.fromiter((r["cn"] for r in read_dict.values()), dtype=np.float if fractional else np.int) + read_weights = np.fromiter((r["weight"] for r in read_dict.values()), dtype=np.float) read_weights = read_weights / np.sum(read_weights) # Normalize to probabilities call = call_alleles( @@ -538,10 +540,10 @@ def call_locus( call_stdevs = call.get("peak_stdevs") call_modal_n = call.get("modal_n_peaks") - read_peak_labels = None # We cannot call read-level cluster labels with >2 peaks; # don't know how re-sampling has occurred. - if call_modal_n and call_modal_n <= 2: + read_peaks_called = call_modal_n and call_modal_n <= 2 + if read_peaks_called: ws = call_weights[:call_modal_n] final_model = GaussianMixture( n_components=call_modal_n, @@ -552,7 +554,8 @@ def call_locus( precisions_init=1 / (call_stdevs[:call_modal_n] ** 2), # TODO: Check, this looks wrong ) res = final_model.fit_predict(read_cns.reshape(-1, 1)) - read_peak_labels = {k: int(v) for k, v in zip(read_cn_dict.keys(), res)} + for k, v in zip(read_dict.keys(), res): + read_dict[k]["peak"] = v.item() def _round_to_base_pos(x) -> float: return round(float(x) * motif_size) / motif_size @@ -577,11 +580,10 @@ def _nested_ndarray_serialize(x: Iterable) -> List[List[Union[int, float, np.int "means": call_peaks.tolist(), # from np.ndarray "weights": call_weights.tolist(), # from np.ndarray "stdevs": call_stdevs.tolist(), # from np.ndarray - "modal_n": call_modal_n.item(), # from np.int64 + "modal_n": call_modal_n, } if call else None, - "read_cns": read_cn_dict, - "read_weights": read_weight_dict, - "read_peak_labels": read_peak_labels, + "reads": read_dict, + "read_peaks_called": read_peaks_called, } diff --git a/strkit/mi/strkit.py b/strkit/mi/strkit.py index 3319f05..d36b1ab 100644 --- a/strkit/mi/strkit.py +++ b/strkit/mi/strkit.py @@ -126,13 +126,13 @@ def _get_sample_contigs(self, include_sex_chromosomes: bool = False) -> Tuple[se def get_read_counts(res: dict, dtype=int): # TODO: This only works with diploids... - read_cns = res["read_cns"] - read_peaks = res["read_peak_labels"] + read_cns = [r["cn"] for r in res["reads"].values()] + read_peaks = [r["peak"] for r in res["reads"].values()] n = res["peaks"]["modal_n"] if n < 2 or len(set(res["call"])) == 1: - rcs = np.fromiter(read_cns.values(), dtype=dtype) + rcs = np.array(read_cns, dtype=dtype) np.random.shuffle(rcs) # TODO: seed shuffle part = rcs.shape[0] // 2 return tuple(rcs[:part].tolist()), tuple(rcs[part:].tolist()) @@ -140,7 +140,7 @@ def get_read_counts(res: dict, dtype=int): rc = [] for _ in range(n): rc.append([]) - for r, cn in read_cns.items(): + for r, cn in enumerate(read_cns): rc[read_peaks[r]].append(cn) return tuple(map(tuple, rc)) diff --git a/strkit/viz/templates/browser.html b/strkit/viz/templates/browser.html index 5bfcd4d..112880e 100644 --- a/strkit/viz/templates/browser.html +++ b/strkit/viz/templates/browser.html @@ -221,8 +221,12 @@ * ref_cn: number, * call: number[], * call_95_cis: Array., - * read_cns: Object., - * read_peak_labels: Object., + * reads: { + * cn: number, + * weight: number, + * peak: number, + * }, + * read_peaks_called: boolean, * peaks: { * means: number[], * weights: number[], @@ -235,7 +239,7 @@ console.log("call data", callData); /** @type number[] */ - const cns = Object.values(callData.read_cns); + const cns = Object.values(callData.reads).map(read => read.cn); const thresholds = Math.min(100, Math.max(...cns) - Math.min(...cns)); histogramContainer.innerHTML = ""; @@ -261,15 +265,15 @@ call95Display.innerText = callData.call_95_cis .map(ci => "(" + ci.map(c => c.toFixed(params.fractional ? 1 : 0)).join("-") + ")").join(" | "); - const getReadsForPeak = q => Object.entries(callData.read_peak_labels) - .filter(e => e[1] === q) - .map(e => callData.read_cns[e[0]]); + const getReadsForPeak = q => Object.entries(callData.reads) + .filter(e => e[1]["peak"] === q) + .map(e => e[1]["cn"]); scalesContainer.innerHTML = ""; igvContainer.innerHTML = ""; // TODO: properly deconstruct old browser const readColours = (() => { - if (!callData.read_peak_labels || !Object.entries(callData.read_peak_labels).length) return {}; + if (!callData.read_peaks_called || !Object.entries(callData.reads).length) return {}; const bounds = (() => { if (callData.peaks.means.length > 1) {