diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py
index 3926767f8b..8b474192dd 100644
--- a/python/tests/test_ld_matrix.py
+++ b/python/tests/test_ld_matrix.py
@@ -132,15 +132,28 @@ def add(self: "BitSet", row: int, bit: int) -> None:
 
     def get_items(self: "BitSet", row: int) -> Generator[int, None, None]:
         """Get the items stored in the row of a bitset
+        Uses a de Bruijn sequence lookup table to determine the lowest bit set.
+        See the wikipedia article for more info: https://w.wiki/BYiF
 
         :param row: Row from the array to list from.
         :returns: A generator of integers stored in the array.
         """
+        lookup = [0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, 31, 27,
+                  13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9]  # fmt: skip
+        m = np.uint32(125613361)
         offset = row * self.row_len
         for i in range(self.row_len):
-            for item in range(self.CHUNK_SIZE):
-                if self.data[i + offset] & (self.DTYPE(1) << item):
-                    yield item + (i * self.CHUNK_SIZE)
+            v = self.data[i + offset]
+            if v == 0:
+                continue
+            else:
+                # v & -v operations rely on integer overflow
+                with np.errstate(over="ignore"):
+                    lsb = v & -v  # isolate the least significant bit
+                    while lsb:  # while there are bits remaining
+                        yield lookup[(lsb * m) >> 27] + (i * self.CHUNK_SIZE)
+                        v ^= lsb  # unset the lsb
+                        lsb = v & -v
 
     def contains(self: "BitSet", row: int, bit: int) -> bool:
         """Test if a bit is contained within a bit array row
@@ -1561,7 +1574,6 @@ def advance(self, index):
 
 def compute_branch_stat_update(
     c,
-    child_samples,
     A_state,
     B_state,
     state_dim,
@@ -1572,19 +1584,11 @@ def compute_branch_stat_update(
     params,
 ):
     """Compute an update to the two-locus statistic for a single subset of the
-    tree being modified, relative to all subsets of the fixed tree. We perform
-    this operation for all samples edge being modified. For subsequent parent
-    nodes, we update the statistic by removing the existing contribution after
-    adding in the update contribution.
-
-    i.e. if we're adding two samples ({3, 4}) to a node, if the parent node
-    contains {1, 2}, we first add the statistic for {1, 2, 3, 4}, then
-    subtract the stat for {1, 2}.
+    tree being modified, relative to all subsets of the fixed tree.
 
     :param c: Child node of the edge we're modifying
-    :param child_samples: Samples under the edge being added/removed
     :param A_state: State for the tree contributing to the A samples (fixed)
-    :param A_state: State for the tree contributing to the B samples (modified)
+    :param B_state: State for the tree contributing to the B samples (modified)
     :param state_dim: Number of sample sets.
     :param sign: The sign of the update
     :param stat_func: Function used to compute the two-locus statistic
@@ -1597,7 +1601,6 @@ def compute_branch_stat_update(
         return result
 
     AB_samples = BitSet(num_samples, 1)
-    node_samples_tmp = BitSet(num_samples, 1)
     weights = np.zeros((3, state_dim), dtype=np.int64)
     result_tmp = np.zeros(state_dim, np.float64)
 
@@ -1621,32 +1624,16 @@ def compute_branch_stat_update(
         for k in range(state_dim):
             result[k] += result_tmp[k] * a_len * b_len
 
-        # If we've begun our walk up the parents of the current edge removal, we
-        # must adjust the statistic for samples that were already present before
-        # addition or that remain after removal.
-        if child_samples is not None:
-            for k in range(state_dim):
-                row = (state_dim * n) + k
-                c_row = (state_dim * c) + k
-                node_samples_tmp.union(0, B_state.node_samples, c_row)
-                node_samples_tmp.difference(0, child_samples, k)
-                AB_samples.data[:] = 0  # Zero out the bitset so that we can reuse it
-                A_state.node_samples.intersect(row, node_samples_tmp, 0, AB_samples)
-
-                w_AB = AB_samples.count(0)
-                w_A = A_state.node_samples.count(row)
-                w_B = node_samples_tmp.count(0)
-
-                weights[0, k] = w_AB
-                weights[1, k] = w_A - w_AB  # w_Ab
-                weights[2, k] = w_B - w_AB  # w_aB
-
-            stat_func(state_dim, weights, result_tmp, params)
-            for k in range(state_dim):
-                result[k] -= result_tmp[k] * a_len * b_len
-
 
-def compute_branch_stat(ts, stat_func, stat, params, state_dim, l_state, r_state):
+def compute_branch_stat(
+    ts: tskit.TreeSequence,
+    stat_func,
+    stat,
+    params,
+    state_dim,
+    l_state: TreeState,
+    r_state: TreeState,
+):
     """Step between trees in a tree sequence, updating our two-locus statistic
     as we add or remove edges. Since we're computing statistics for two loci, we
     have a focal tree that remains constant, and a tree that is updated to
@@ -1673,89 +1660,63 @@ def compute_branch_stat(ts, stat_func, stat, params, state_dim, l_state, r_state
     :returns: A tuple containing the statistic between the two trees after
               branch updates and the righthand tree state.
     """
+    num_samples = ts.num_samples
     time = ts.tables.nodes.time
+    updates = BitSet(ts.num_nodes, 1)
 
-    child_samples = BitSet(ts.num_samples, state_dim)
-    for e in r_state.edges_out:
+    # Identify modified nodes
+    for e in r_state.edges_out + r_state.edges_in:
         p = ts.edges_parent[e]
         c = ts.edges_child[e]
-        child_samples.data[:] = 0
-        for k in range(state_dim):
-            c_row = (state_dim * c) + k
-            child_samples.union(k, r_state.node_samples, c_row)
-
-        # Remove the LD contributed by the samples under removed edges. When
-        # we walk up the tree to propagate these changes to parents of the
-        # removed edge, we need to add back in the LD contributed by samples
-        # that aren't removed. We remove samples from the parents of the removed
-        # branch as we propagate changes upward
-        in_parent = None
+        # identify affected nodes above child
         while p != tskit.NULL:
-            compute_branch_stat_update(
-                c,
-                in_parent,
-                l_state,
-                r_state,
-                state_dim,
-                -1,
-                stat_func,
-                ts.num_samples,
-                stat,
-                params,
-            )
-            if in_parent is not None:
-                # remove samples from the parents of the branch being removed
-                # we remove the child node after the first iteration
-                for k in range(state_dim):
-                    c_row = (state_dim * c) + k
-                    r_state.node_samples.difference(c_row, child_samples, k)
-            in_parent = child_samples
+            updates.add(0, c)
             c = p
             p = r_state.parent[p]
 
-        for k in range(state_dim):
-            c_row = (state_dim * c) + k
-            r_state.node_samples.difference(c_row, child_samples, k)
-        # reset to the child of the edge being removed.
-        c = ts.edges_child[e]
-        r_state.branch_len[c] = 0
-        r_state.parent[c] = tskit.NULL
+    # Subtract the whole contribution from child node
+    for c in updates.get_items(0):
+        compute_branch_stat_update(
+            c, l_state, r_state, state_dim, -1, stat_func, num_samples, stat, params
+        )
+
+    # Sample Removal
+    for e in r_state.edges_out:
+        p = ts.edges_parent[e]
+        ec = ts.edges_child[e]
+        # update samples under nodes, propagate upwards
+        while p != tskit.NULL:
+            for k in range(state_dim):
+                r_state.node_samples.difference(
+                    state_dim * p + k, r_state.node_samples, state_dim * ec + k
+                )
+            p = r_state.parent[p]
+        # set the parent to prevent upwards iteration
+        r_state.branch_len[ec] = 0
+        r_state.parent[ec] = tskit.NULL
 
+    # Sample Addition
     for e in r_state.edges_in:
         p = ts.edges_parent[e]
-        c = ts.edges_child[e]
-        child_samples.data[:] = 0
-        for k in range(state_dim):
-            c_row = (state_dim * c) + k
-            child_samples.union(k, r_state.node_samples, c_row)
+        ec = c = ts.edges_child[e]
         r_state.branch_len[c] = time[p] - time[c]
         r_state.parent[c] = p
-
-        # Add the LD contributed by the samples under added edges. When we walk
-        # up the tree to propagate these changes to parents of the removed edge,
-        # we need to remove the LD contributed by samples that were already
-        # there
-        in_parent = None
+        # update samples under nodes, store modified node, propagate upwards
         while p != tskit.NULL:
+            updates.add(0, c)
             for k in range(state_dim):
-                p_row = (state_dim * p) + k
-                r_state.node_samples.union(p_row, child_samples, k)
-            compute_branch_stat_update(
-                c,
-                in_parent,
-                l_state,
-                r_state,
-                state_dim,
-                +1,
-                stat_func,
-                ts.num_samples,
-                stat,
-                params,
-            )
-            in_parent = child_samples
+                r_state.node_samples.union(
+                    state_dim * p + k, r_state.node_samples, state_dim * ec + k
+                )
             c = p
             p = r_state.parent[p]
 
+    # Update all affected child nodes (fully subtracted, deferred from addition)
+    for c in updates.get_items(0):
+        compute_branch_stat_update(
+            c, l_state, r_state, state_dim, +1, stat_func, num_samples, stat, params
+        )
+
     return stat, r_state
 
 