Skip to content

Commit

Permalink
Switch to LocalOutputBuffers for column-based aggregate_across_cells.
Browse files Browse the repository at this point in the history
This avoids false sharing between threads and also zeros the
user-supplied buffers before performing any additions. The latter fixes
a bug where values were incorrectly added to uninitialized buffers. It
comes at a cost of adding the tatami_stats dependency, though.
  • Loading branch information
LTLA committed Jan 8, 2025
1 parent 800b776 commit 6d4253d
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 14 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ if(SCRAN_AGGREGATE_FETCH_EXTERN)
add_subdirectory(extern)
else()
find_package(tatami_tatami 3.0.0 CONFIG REQUIRED)
find_package(tatami_tatami_stats 1.1.0 CONFIG REQUIRED)
endif()

target_link_libraries(scran_aggregate INTERFACE tatami::tatami)
target_link_libraries(scran_aggregate INTERFACE tatami::tatami tatami::tatami_stats)

# Tests
if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
Expand Down
1 change: 1 addition & 0 deletions cmake/Config.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

include(CMakeFindDependencyMacro)
find_dependency(tatami_tatami 3.0.0 CONFIG REQUIRED)
find_dependency(tatami_tatami_stats 1.1.0 CONFIG REQUIRED)

include("${CMAKE_CURRENT_LIST_DIR}/libscran_scran_aggregateTargets.cmake")
7 changes: 7 additions & 0 deletions extern/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,11 @@ FetchContent_Declare(
GIT_TAG master # ^3.0.0
)

FetchContent_Declare(
tatami_stats
GIT_REPOSITORY https://github.com/tatami-inc/tatami_stats
GIT_TAG master # ^1.1.0
)

FetchContent_MakeAvailable(tatami)
FetchContent_MakeAvailable(tatami_stats)
45 changes: 33 additions & 12 deletions include/scran_aggregate/aggregate_across_cells.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

#include <algorithm>
#include <vector>

#include "tatami/tatami.hpp"
#include "tatami_stats/tatami_stats.hpp"

/**
* @file aggregate_across_cells.hpp
Expand Down Expand Up @@ -173,49 +175,68 @@ void compute_aggregate_by_column(
tatami::Options opt;
opt.sparse_ordered_index = false;

tatami::parallelize([&](size_t, Index_ s, Index_ l) {
tatami::parallelize([&](size_t t, Index_ s, Index_ l) {
auto NC = p.ncol();
auto ext = tatami::consecutive_extractor<sparse_>(&p, false, static_cast<Index_>(0), NC, s, l, opt);
std::vector<Data_> vbuffer(l);
typename std::conditional<sparse_, std::vector<Index_>, Index_>::type ibuffer(l);

size_t num_sums = buffers.sums.size();
std::vector<tatami_stats::LocalOutputBuffer<Sum_> > local_sums;
local_sums.reserve(num_sums);
for (auto ptr : buffers.sums) {
local_sums.emplace_back(t, s, l, ptr);
}
size_t num_detected = buffers.detected.size();
std::vector<tatami_stats::LocalOutputBuffer<Detected_> > local_detected;
local_detected.reserve(num_detected);
for (auto ptr : buffers.detected) {
local_detected.emplace_back(t, s, l, ptr);
}

for (Index_ x = 0; x < NC; ++x) {
auto current = factor[x];

if constexpr(sparse_) {
auto col = ext->fetch(vbuffer.data(), ibuffer.data());
if (buffers.sums.size()) {
auto& cursum = buffers.sums[current];
if (num_sums) {
auto cursum = local_sums[current].data();
for (Index_ i = 0; i < col.number; ++i) {
cursum[col.index[i]] += col.value[i];
cursum[col.index[i] - s] += col.value[i];
}
}

if (buffers.detected.size()) {
auto& curdetected = buffers.detected[current];
if (num_detected) {
auto curdetected = local_detected[current].data();
for (Index_ i = 0; i < col.number; ++i) {
curdetected[col.index[i]] += (col.value[i] > 0);
curdetected[col.index[i] - s] += (col.value[i] > 0);
}
}

} else {
auto col = ext->fetch(vbuffer.data());

if (buffers.sums.size()) {
auto cursum = buffers.sums[current] + s;
if (num_sums) {
auto cursum = local_sums[current].data();
for (Index_ i = 0; i < l; ++i) {
cursum[i] += col[i];
}
}

if (buffers.detected.size()) {
auto curdetected = buffers.detected[current] + s;
if (num_detected) {
auto curdetected = local_detected[current].data();
for (Index_ i = 0; i < l; ++i) {
curdetected[i] += (col[i] > 0);
}
}
}
}

for (auto& lsums : local_sums) {
lsums.transfer();
}
for (auto& ldetected : local_detected) {
ldetected.transfer();
}
}, p.nrow(), options.num_threads);
}

Expand Down
51 changes: 50 additions & 1 deletion tests/src/aggregate_across_cells.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

#include "scran_aggregate/aggregate_across_cells.hpp"
#include "scran_tests/scran_tests.hpp"

#include <map>
#include <random>

std::vector<int> create_groupings(size_t n, int ngroups) {
static std::vector<int> create_groupings(size_t n, int ngroups) {
std::vector<int> groupings(n);
for (size_t g = 0; g < groupings.size(); ++g) {
groupings[g] = g % ngroups;
Expand Down Expand Up @@ -115,3 +116,51 @@ TEST(AggregateAcrossCells, Skipping) {
EXPECT_EQ(skipped.sums.size(), 0);
EXPECT_EQ(skipped.detected.size(), 0);
}

TEST(AggregateAcrossCells, DirtyBuffers) {
int nr = 88, nc = 126;
auto vec = scran_tests::simulate_vector(nr * nc, []{
scran_tests::SimulationParameters sparams;
sparams.density = 0.1;
sparams.seed = 69;
return sparams;
}());

tatami::DenseRowMatrix<double, int> dense_row(nr, nc, std::move(vec));
auto sparse_column = tatami::convert_to_compressed_sparse(&dense_row, false);
size_t ngroups = 3;
auto grouping = create_groupings(dense_row.ncol(), ngroups);

// Setting up some dirty buffers.
scran_aggregate::AggregateAcrossCellsResults<double, int> store;
scran_aggregate::AggregateAcrossCellsBuffers<double, int> buffers;
store.sums.resize(ngroups, std::vector<double>(nr, -1));
store.detected.resize(ngroups, std::vector<int>(nr, -1));
buffers.sums.resize(ngroups);
buffers.detected.resize(ngroups);
for (size_t l = 0; l < ngroups; ++l) {
buffers.sums[l] = store.sums[l].data();
buffers.detected[l] = store.detected[l].data();
}

scran_aggregate::AggregateAcrossCellsOptions opt;
auto ref = scran_aggregate::aggregate_across_cells(dense_row, grouping.data(), opt);
scran_aggregate::aggregate_across_cells(dense_row, grouping.data(), buffers, opt);
EXPECT_EQ(ref.sums.size(), store.sums.size());
EXPECT_EQ(ref.detected.size(), store.detected.size());
for (size_t l = 0; l < ngroups; ++l) {
EXPECT_EQ(ref.sums[l], store.sums[l]);
EXPECT_EQ(ref.detected[l], store.detected[l]);
}

// Same for column-major iteration.
for (size_t l = 0; l < ngroups; ++l) {
std::fill_n(buffers.sums[l], nr, -1);
std::fill_n(buffers.detected[l], nr, -1);
}
scran_aggregate::aggregate_across_cells(*sparse_column, grouping.data(), buffers, opt);
for (size_t l = 0; l < ngroups; ++l) {
EXPECT_EQ(ref.sums[l], store.sums[l]);
EXPECT_EQ(ref.detected[l], store.detected[l]);
}
}

0 comments on commit 6d4253d

Please sign in to comment.