diff --git a/CMakeLists.txt b/CMakeLists.txt index ba37925..c722255 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,9 +23,10 @@ if(SCRAN_AGGREGATE_FETCH_EXTERN) add_subdirectory(extern) else() find_package(tatami_tatami 3.0.0 CONFIG REQUIRED) + find_package(tatami_tatami_stats 1.1.0 CONFIG REQUIRED) endif() -target_link_libraries(scran_aggregate INTERFACE tatami::tatami) +target_link_libraries(scran_aggregate INTERFACE tatami::tatami tatami::tatami_stats) # Tests if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) diff --git a/cmake/Config.cmake.in b/cmake/Config.cmake.in index 8397be9..a236203 100644 --- a/cmake/Config.cmake.in +++ b/cmake/Config.cmake.in @@ -2,5 +2,6 @@ include(CMakeFindDependencyMacro) find_dependency(tatami_tatami 3.0.0 CONFIG REQUIRED) +find_dependency(tatami_tatami_stats 1.1.0 CONFIG REQUIRED) include("${CMAKE_CURRENT_LIST_DIR}/libscran_scran_aggregateTargets.cmake") diff --git a/extern/CMakeLists.txt b/extern/CMakeLists.txt index c7828fe..a34208d 100644 --- a/extern/CMakeLists.txt +++ b/extern/CMakeLists.txt @@ -6,4 +6,11 @@ FetchContent_Declare( GIT_TAG master # ^3.0.0 ) +FetchContent_Declare( + tatami_stats + GIT_REPOSITORY https://github.com/tatami-inc/tatami_stats + GIT_TAG master # ^1.1.0 +) + FetchContent_MakeAvailable(tatami) +FetchContent_MakeAvailable(tatami_stats) diff --git a/include/scran_aggregate/aggregate_across_cells.hpp b/include/scran_aggregate/aggregate_across_cells.hpp index 749a9c8..99ac4d6 100644 --- a/include/scran_aggregate/aggregate_across_cells.hpp +++ b/include/scran_aggregate/aggregate_across_cells.hpp @@ -3,7 +3,9 @@ #include #include + #include "tatami/tatami.hpp" +#include "tatami_stats/tatami_stats.hpp" /** * @file aggregate_across_cells.hpp @@ -173,49 +175,68 @@ void compute_aggregate_by_column( tatami::Options opt; opt.sparse_ordered_index = false; - tatami::parallelize([&](size_t, Index_ s, Index_ l) { + tatami::parallelize([&](size_t t, Index_ s, Index_ l) { auto NC = p.ncol(); auto ext = 
tatami::consecutive_extractor(&p, false, static_cast(0), NC, s, l, opt); std::vector vbuffer(l); typename std::conditional, Index_>::type ibuffer(l); + size_t num_sums = buffers.sums.size(); + std::vector > local_sums; + local_sums.reserve(num_sums); + for (auto ptr : buffers.sums) { + local_sums.emplace_back(t, s, l, ptr); + } + size_t num_detected = buffers.detected.size(); + std::vector > local_detected; + local_detected.reserve(num_detected); + for (auto ptr : buffers.detected) { + local_detected.emplace_back(t, s, l, ptr); + } + for (Index_ x = 0; x < NC; ++x) { auto current = factor[x]; if constexpr(sparse_) { auto col = ext->fetch(vbuffer.data(), ibuffer.data()); - if (buffers.sums.size()) { - auto& cursum = buffers.sums[current]; + if (num_sums) { + auto cursum = local_sums[current].data(); for (Index_ i = 0; i < col.number; ++i) { - cursum[col.index[i]] += col.value[i]; + cursum[col.index[i] - s] += col.value[i]; } } - if (buffers.detected.size()) { - auto& curdetected = buffers.detected[current]; + if (num_detected) { + auto curdetected = local_detected[current].data(); for (Index_ i = 0; i < col.number; ++i) { - curdetected[col.index[i]] += (col.value[i] > 0); + curdetected[col.index[i] - s] += (col.value[i] > 0); } } } else { auto col = ext->fetch(vbuffer.data()); - - if (buffers.sums.size()) { - auto cursum = buffers.sums[current] + s; + if (num_sums) { + auto cursum = local_sums[current].data(); for (Index_ i = 0; i < l; ++i) { cursum[i] += col[i]; } } - if (buffers.detected.size()) { - auto curdetected = buffers.detected[current] + s; + if (num_detected) { + auto curdetected = local_detected[current].data(); for (Index_ i = 0; i < l; ++i) { curdetected[i] += (col[i] > 0); } } } } + + for (auto& lsums : local_sums) { + lsums.transfer(); + } + for (auto& ldetected : local_detected) { + ldetected.transfer(); + } }, p.nrow(), options.num_threads); } diff --git a/tests/src/aggregate_across_cells.cpp b/tests/src/aggregate_across_cells.cpp index 
f5a2d88..9296eed 100644 --- a/tests/src/aggregate_across_cells.cpp +++ b/tests/src/aggregate_across_cells.cpp @@ -2,10 +2,11 @@ #include "scran_aggregate/aggregate_across_cells.hpp" #include "scran_tests/scran_tests.hpp" + #include #include -std::vector create_groupings(size_t n, int ngroups) { +static std::vector create_groupings(size_t n, int ngroups) { std::vector groupings(n); for (size_t g = 0; g < groupings.size(); ++g) { groupings[g] = g % ngroups; @@ -115,3 +116,51 @@ TEST(AggregateAcrossCells, Skipping) { EXPECT_EQ(skipped.sums.size(), 0); EXPECT_EQ(skipped.detected.size(), 0); } + +TEST(AggregateAcrossCells, DirtyBuffers) { + int nr = 88, nc = 126; + auto vec = scran_tests::simulate_vector(nr * nc, []{ + scran_tests::SimulationParameters sparams; + sparams.density = 0.1; + sparams.seed = 69; + return sparams; + }()); + + tatami::DenseRowMatrix dense_row(nr, nc, std::move(vec)); + auto sparse_column = tatami::convert_to_compressed_sparse(&dense_row, false); + size_t ngroups = 3; + auto grouping = create_groupings(dense_row.ncol(), ngroups); + + // Point the output buffers at storage pre-filled with -1, so that any stale value left behind would make the comparison with 'ref' fail.
+ scran_aggregate::AggregateAcrossCellsResults store; + scran_aggregate::AggregateAcrossCellsBuffers buffers; + store.sums.resize(ngroups, std::vector(nr, -1)); + store.detected.resize(ngroups, std::vector(nr, -1)); + buffers.sums.resize(ngroups); + buffers.detected.resize(ngroups); + for (size_t l = 0; l < ngroups; ++l) { + buffers.sums[l] = store.sums[l].data(); + buffers.detected[l] = store.detected[l].data(); + } + + scran_aggregate::AggregateAcrossCellsOptions opt; + auto ref = scran_aggregate::aggregate_across_cells(dense_row, grouping.data(), opt); + scran_aggregate::aggregate_across_cells(dense_row, grouping.data(), buffers, opt); + EXPECT_EQ(ref.sums.size(), store.sums.size()); + EXPECT_EQ(ref.detected.size(), store.detected.size()); + for (size_t l = 0; l < ngroups; ++l) { + EXPECT_EQ(ref.sums[l], store.sums[l]); + EXPECT_EQ(ref.detected[l], store.detected[l]); + } + + // Re-dirty the buffers with -1 and repeat the check for column-major iteration on 'sparse_column'. + for (size_t l = 0; l < ngroups; ++l) { + std::fill_n(buffers.sums[l], nr, -1); + std::fill_n(buffers.detected[l], nr, -1); + } + scran_aggregate::aggregate_across_cells(*sparse_column, grouping.data(), buffers, opt); + for (size_t l = 0; l < ngroups; ++l) { + EXPECT_EQ(ref.sums[l], store.sums[l]); + EXPECT_EQ(ref.detected[l], store.detected[l]); + } +}