Skip to content

Commit

Permalink
feat: more compact JSON call representation
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Jul 30, 2022
1 parent ed3728f commit 77b9aa9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 27 deletions.
2 changes: 1 addition & 1 deletion strkit/call/allele.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _na_length_list():
medians_of_means_final = np.rint(medians_of_means).astype(np.int32)
medians_of_weights = np.percentile(allele_weight_samples, 50, axis=1, interpolation="nearest")
medians_of_stdevs = np.percentile(allele_stdev_samples, 50, axis=1, interpolation="nearest")
modal_n_peaks = statistics.mode(sample_peaks)
modal_n_peaks = statistics.mode(sample_peaks).item()

return {
"call": medians_of_means_final.flatten(),
Expand Down
32 changes: 17 additions & 15 deletions strkit/call/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,7 @@ def call_locus(
if r_offset > 0:
right_coord += max(0, r_offset)

read_cn_dict = {}
read_weight_dict = {}
read_dict = {}

overlapping_segments = [
segment
Expand Down Expand Up @@ -494,8 +493,6 @@ def call_locus(
log_debug(f"Skipping read {segment.query_name} (scored {read_adj_score} < {min_read_score})")
continue

read_cn_dict[segment.query_name] = read_cn

# When we don't have targeted sequencing, the probability of a read containing the TR region, given that it
# overlaps the region, is P(read is large enough to contain) * P( # TODO: complete this..
partition_idx = np.searchsorted(sorted_read_lengths, tr_len_w_flank, side="right")
Expand All @@ -506,18 +503,23 @@ def call_locus(
f"Something strange happened; could not find an encompassing read where one should be guaranteed. "
f"TRF row: {t}; TR length with flank: {tr_len_w_flank}; read lengths: {sorted_read_lengths}")
exit(1)

mean_containing_size = read_len if targeted else np.mean(sorted_read_lengths[partition_idx:])
# TODO: re-examine weighting to possibly incorporate chance of drawing read large enough
read_weight_dict[segment.query_name] = (
(mean_containing_size + tr_len_w_flank - 2) / (mean_containing_size - tr_len_w_flank + 1))
read_weight = (mean_containing_size + tr_len_w_flank - 2) / (mean_containing_size - tr_len_w_flank + 1)

read_dict[segment.query_name] = {
"cn": read_cn,
"weight": read_weight,
}

n_alleles = get_n_alleles(2, sex_chroms, contig)
if n_alleles is None:
return None

# Dicts are ordered in Python; very nice :)
read_cns = np.fromiter(read_cn_dict.values(), dtype=np.float if fractional else np.int)
read_weights = np.fromiter(read_weight_dict.values(), dtype=np.float)
read_cns = np.fromiter((r["cn"] for r in read_dict.values()), dtype=np.float if fractional else np.int)
read_weights = np.fromiter((r["weight"] for r in read_dict.values()), dtype=np.float)
read_weights = read_weights / np.sum(read_weights) # Normalize to probabilities

call = call_alleles(
Expand All @@ -538,10 +540,10 @@ def call_locus(
call_stdevs = call.get("peak_stdevs")
call_modal_n = call.get("modal_n_peaks")

read_peak_labels = None
# We cannot call read-level cluster labels with >2 peaks;
# don't know how re-sampling has occurred.
if call_modal_n and call_modal_n <= 2:
read_peaks_called = call_modal_n and call_modal_n <= 2
if read_peaks_called:
ws = call_weights[:call_modal_n]
final_model = GaussianMixture(
n_components=call_modal_n,
Expand All @@ -552,7 +554,8 @@ def call_locus(
precisions_init=1 / (call_stdevs[:call_modal_n] ** 2), # TODO: Check, this looks wrong
)
res = final_model.fit_predict(read_cns.reshape(-1, 1))
read_peak_labels = {k: int(v) for k, v in zip(read_cn_dict.keys(), res)}
for k, v in zip(read_dict.keys(), res):
read_dict[k]["peak"] = v.item()

def _round_to_base_pos(x) -> float:
return round(float(x) * motif_size) / motif_size
Expand All @@ -577,11 +580,10 @@ def _nested_ndarray_serialize(x: Iterable) -> List[List[Union[int, float, np.int
"means": call_peaks.tolist(), # from np.ndarray
"weights": call_weights.tolist(), # from np.ndarray
"stdevs": call_stdevs.tolist(), # from np.ndarray
"modal_n": call_modal_n.item(), # from np.int64
"modal_n": call_modal_n,
} if call else None,
"read_cns": read_cn_dict,
"read_weights": read_weight_dict,
"read_peak_labels": read_peak_labels,
"reads": read_dict,
"read_peaks_called": read_peaks_called,
}


Expand Down
8 changes: 4 additions & 4 deletions strkit/mi/strkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,21 +126,21 @@ def _get_sample_contigs(self, include_sex_chromosomes: bool = False) -> Tuple[se
def get_read_counts(res: dict, dtype=int):
# TODO: This only works with diploids...

read_cns = res["read_cns"]
read_peaks = res["read_peak_labels"]
read_cns = [r["cn"] for r in res["reads"].values()]
read_peaks = [r["peak"] for r in res["reads"].values()]

n = res["peaks"]["modal_n"]

if n < 2 or len(set(res["call"])) == 1:
rcs = np.fromiter(read_cns.values(), dtype=dtype)
rcs = np.array(read_cns, dtype=dtype)
np.random.shuffle(rcs) # TODO: seed shuffle
part = rcs.shape[0] // 2
return tuple(rcs[:part].tolist()), tuple(rcs[part:].tolist())

rc = []
for _ in range(n):
rc.append([])
for r, cn in read_cns.items():
for r, cn in enumerate(read_cns):
rc[read_peaks[r]].append(cn)
return tuple(map(tuple, rc))

Expand Down
18 changes: 11 additions & 7 deletions strkit/viz/templates/browser.html
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,12 @@
* ref_cn: number,
* call: number[],
* call_95_cis: Array.<number[]>,
* read_cns: Object.<string, number>,
* read_peak_labels: Object.<string, number>,
* reads: {
* cn: number,
* weight: number,
* peak: number,
* },
* read_peaks_called: boolean,
* peaks: {
* means: number[],
* weights: number[],
Expand All @@ -235,7 +239,7 @@
console.log("call data", callData);

/** @type number[] */
const cns = Object.values(callData.read_cns);
const cns = Object.values(callData.reads).map(read => read.cn);
const thresholds = Math.min(100, Math.max(...cns) - Math.min(...cns));

histogramContainer.innerHTML = "";
Expand All @@ -261,15 +265,15 @@
call95Display.innerText = callData.call_95_cis
.map(ci => "(" + ci.map(c => c.toFixed(params.fractional ? 1 : 0)).join("-") + ")").join(" | ");

const getReadsForPeak = q => Object.entries(callData.read_peak_labels)
.filter(e => e[1] === q)
.map(e => callData.read_cns[e[0]]);
const getReadsForPeak = q => Object.entries(callData.reads)
.filter(e => e[1]["peak"] === q)
.map(e => e[1]["cn"]);

scalesContainer.innerHTML = "";
igvContainer.innerHTML = ""; // TODO: properly deconstruct old browser

const readColours = (() => {
if (!callData.read_peak_labels || !Object.entries(callData.read_peak_labels).length) return {};
if (!callData.read_peaks_called || !Object.entries(callData.reads).length) return {};

const bounds = (() => {
if (callData.peaks.means.length > 1) {
Expand Down

0 comments on commit 77b9aa9

Please sign in to comment.