Skip to content

Commit

Permalink
more cuda support for bitvec, vecvec (supports span now)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixguendling committed Jan 16, 2025
1 parent b4b8bc7 commit 6b460ae
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 8 deletions.
37 changes: 37 additions & 0 deletions include/cista/const_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include <type_traits>

namespace cista {

template <class T, class R = void>
struct enable_if_type {
using type = R;
};

template <typename T, typename = void>
struct has_const_iterator : std::false_type {};

template <typename T>
struct has_const_iterator<
T, typename enable_if_type<typename T::const_iterator>::type>
: std::true_type {};

template <typename T>
inline constexpr bool has_const_iterator_v = has_const_iterator<T>::value;

template <typename Container, typename Enable = void>
struct const_iterator {
using type = typename Container::iterator;
};

template <typename Container>
struct const_iterator<Container,
std::enable_if_t<has_const_iterator_v<Container>>> {
using type = typename Container::const_iterator;
};

template <typename T>
using const_iterator_t = typename const_iterator<T>::type;

} // namespace cista
5 changes: 4 additions & 1 deletion include/cista/containers/bitvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "cista/bit_counting.h"
#include "cista/containers/vector.h"
#include "cista/cuda_check.h"
#include "cista/strong.h"

namespace cista {
Expand Down Expand Up @@ -125,7 +126,9 @@ struct basic_bitvec {

void reset() noexcept { blocks_ = {}; }

bool operator[](Key const i) const noexcept { return test(i); }
CISTA_CUDA_COMPAT bool operator[](Key const i) const noexcept {
return test(i);
}

std::size_t count() const noexcept {
if (empty()) {
Expand Down
22 changes: 15 additions & 7 deletions include/cista/containers/vecvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
#include <type_traits>
#include <vector>

#include "cista/const_iterator.h"
#include "cista/containers/vector.h"
#include "cista/cuda_check.h"
#include "cista/verify.h"

namespace cista {
Expand Down Expand Up @@ -177,8 +179,8 @@ struct basic_vecvec {

struct const_bucket final {
using value_type = data_value_type;
using iterator = typename DataVec::const_iterator;
using const_iterator = typename DataVec::const_iterator;
using iterator = const_iterator_t<DataVec>;
using const_iterator = iterator;

using iterator_category = std::random_access_iterator_tag;
using difference_type = std::ptrdiff_t;
Expand Down Expand Up @@ -224,14 +226,18 @@ struct basic_vecvec {
index_value_type size() const {
return bucket_end_idx() - bucket_begin_idx();
}
const_iterator begin() const {
CISTA_CUDA_COMPAT const_iterator begin() const {
return map_->data_.begin() + bucket_begin_idx();
}
const_iterator end() const {
CISTA_CUDA_COMPAT const_iterator end() const {
return map_->data_.begin() + bucket_end_idx();
}
friend const_iterator begin(const_bucket const& b) { return b.begin(); }
friend const_iterator end(const_bucket const& b) { return b.end(); }
friend CISTA_CUDA_COMPAT const_iterator begin(const_bucket const& b) {
return b.begin();
}
friend CISTA_CUDA_COMPAT const_iterator end(const_bucket const& b) {
return b.end();
}

std::reverse_iterator<const_iterator> rbegin() const {
return std::reverse_iterator{begin() + size()};
Expand Down Expand Up @@ -301,7 +307,9 @@ struct basic_vecvec {
using const_iterator = const_bucket;

bucket operator[](Key const i) { return {this, to_idx(i)}; }
const_bucket operator[](Key const i) const { return {this, to_idx(i)}; }
CISTA_CUDA_COMPAT const_bucket operator[](Key const i) const {
return {this, to_idx(i)};
}

const_bucket at(Key const i) const {
verify(to_idx(i) < bucket_starts_.size(),
Expand Down

0 comments on commit 6b460ae

Please sign in to comment.