Skip to content

Commit

Permalink
* Public class members are put first
Browse files Browse the repository at this point in the history
  • Loading branch information
ElenaTyuleneva committed Aug 27, 2024
1 parent 50c586c commit 864d32e
Showing 1 changed file with 81 additions and 78 deletions.
159 changes: 81 additions & 78 deletions include/oneapi/dpl/internal/random_impl/philox_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,97 @@ template<typename UIntType, ::std::size_t w, ::std::size_t n, ::std::size_t r,
internal::element_type_t<UIntType> ...consts>
class philox_engine
{
/* The size of the consts arrays */
static constexpr ::std::size_t array_size = n / 2;

public:
/* types */
/* Types */
using result_type = UIntType;
using scalar_type = internal::element_type_t<result_type>;

/* engine characteristics */
/* Engine characteristics */
static constexpr ::std::size_t word_size = w;
static constexpr ::std::size_t word_count = n;
static constexpr ::std::size_t round_count = r;

private:
static_assert(n == 2 || n == 4, "n must be 2 or 4");
static_assert(sizeof...(consts) == n, "the amount of consts must be equal to n");
static_assert(r > 0, "r must be more than 0");
static_assert(w > 0 && w <= ::std::numeric_limits<scalar_type>::digits, "w must satisfy 0 < w < ::std::numeric_limits<UIntType>::digits");
static_assert(::std::numeric_limits<scalar_type>::digits <= 64, "size of the scalar UIntType (in case of sycl::vec<T, N> the size of T) must be less than 64 bits");
static_assert(::std::is_unsigned_v<scalar_type>, "UIntType must be unsigned type or vector of unsigned types");

static constexpr ::std::array<scalar_type, array_size> multipliers =
internal::get_even_array_from_tuple<scalar_type>(::std::make_tuple(consts...),
::std::make_index_sequence<array_size>{});
static constexpr ::std::array<scalar_type, array_size> round_consts =
internal::get_odd_array_from_tuple<scalar_type>(::std::make_tuple(consts...),
::std::make_index_sequence<array_size>{});
static constexpr scalar_type min() { return 0; }
static constexpr scalar_type max() { return ::std::numeric_limits<scalar_type>::max() & in_mask; }
static constexpr scalar_type default_seed = 20111115u;

/* Constructors and seeding functions */
philox_engine() : philox_engine(default_seed) {}
explicit philox_engine(scalar_type value) { seed(value); }
void seed(scalar_type value = default_seed) { seed_internal({ value & in_mask }); }

/* Set the state to arbitrary position */
void set_counter(const ::std::array<scalar_type, word_count>& counter) {
auto start = counter.begin();
auto end = counter.end();
for (::std::size_t i = 0; i < word_count; i++) {
// all counters are set in everse order
state_.X[i] = (*--end) & in_mask;
}
}

/* Generating functions */
result_type operator()() {
result_type ret = generate_internal<internal::type_traits_t<result_type>::num_elems>();
return ret;
}
/* operator () overload for result portion generation */
result_type operator()(unsigned int __random_nums) {
result_type ret = generate_internal<internal::type_traits_t<result_type>::num_elems>(__random_nums);
return ret;
}

/* Shift the counter only forward relative to its current position */
void discard(unsigned long long z) {
discard_internal(z);
}

/* Equality operators */
friend bool operator==(const philox_engine& x, const philox_engine& y) {
if(!::std::equal(x.state_.X.begin(), x.state_.X.end(), y.state_.X.begin()) ||
!::std::equal(x.state_.K.begin(), x.state_.K.end(), y.state_.K.begin()) ||
!::std::equal(x.state_.Y.begin(), x.state_.Y.end(), y.state_.Y.begin()) ||
x.state_.idx != y.state_.idx) {
return false;
}
return true;
}
friend bool
operator!=(const philox_engine& __x, const philox_engine& __y)
{
return !(__x == __y);
}

/* Inserters and extractors */
template<class CharT, class Traits, typename UIntType_, ::std::size_t w_, ::std::size_t n_, ::std::size_t r_, UIntType_... consts_>
friend ::std::basic_ostream<CharT, Traits>&
operator<<(::std::basic_ostream<CharT, Traits>&, const philox_engine<UIntType_, w_, n_, r_, consts_...>&);

template<typename UIntType_, ::std::size_t w_, ::std::size_t n_, ::std::size_t r_, UIntType_... consts_>
friend const sycl::stream&
operator<<(const sycl::stream&, const philox_engine<UIntType_, w_, n_, r_, consts_...>&);

template<class CharT, class Traits, typename UIntType_, ::std::size_t w_, ::std::size_t n_, ::std::size_t r_, UIntType_... consts_>
friend ::std::basic_istream<CharT, Traits>&
operator>>(::std::basic_istream<CharT, Traits>&, philox_engine<UIntType_, w_, n_, r_, consts_...>&);

private:
/* Internal generator state */
struct state {
::std::array<scalar_type, word_count> X; // counters
Expand All @@ -75,9 +148,7 @@ class philox_engine

/* Processing mask */
static constexpr auto in_mask = internal::word_mask<scalar_type, word_size>;
/* The size of the consts arrays */
static constexpr ::std::size_t array_size = word_count / 2;


void seed_internal(::std::initializer_list<scalar_type> seed) {
auto start = seed.begin();
auto end = seed.end();
Expand Down Expand Up @@ -123,7 +194,7 @@ class philox_engine
carry = 1;
}

// select high chunk shift for addition with the next counter chunk
// select high chunk shift for addition with the next counter chunk
ctr_inc = (ctr_inc & (~in_mask)) >> (std::numeric_limits<unsigned long long>::digits - word_size);
}
}
Expand Down Expand Up @@ -181,8 +252,8 @@ class philox_engine
::std::enable_if_t<(_N == 0), result_type>
generate_internal() {
result_type loc_result;

scalar_type curr_idx = state_.idx;

if(curr_idx == word_count) { // empty buffer
philox_kernel();
increase_counter_internal();
Expand All @@ -196,45 +267,7 @@ class philox_engine
return loc_result;
}

public:
static constexpr ::std::array<scalar_type, array_size> multipliers =
internal::get_even_array_from_tuple<scalar_type>(::std::make_tuple(consts...),
::std::make_index_sequence<array_size>{});
static constexpr ::std::array<scalar_type, array_size> round_consts =
internal::get_odd_array_from_tuple<scalar_type>(::std::make_tuple(consts...),
::std::make_index_sequence<array_size>{});
static constexpr scalar_type min() { return 0; }
static constexpr scalar_type max() { return ::std::numeric_limits<scalar_type>::max() & in_mask; }
static constexpr scalar_type default_seed = 20111115u;

/* Constructors and seeding functions */
philox_engine() : philox_engine(default_seed) {}
explicit philox_engine(scalar_type value) { seed(value); }
void seed(scalar_type value = default_seed) { seed_internal({ value & in_mask }); }

/* Set the state to arbitrary position */
void set_counter(const ::std::array<scalar_type, word_count>& counter) {
auto start = counter.begin();
auto end = counter.end();
for (::std::size_t i = 0; i < word_count; i++) {
// all counters are set in everse order
state_.X[i] = (*--end) & in_mask;
}
}

/* Generating functions */
result_type operator()() {
result_type ret = generate_internal<internal::type_traits_t<result_type>::num_elems>();
return ret;
}
/* operator () overload for result portion generation */
result_type operator()(unsigned int __random_nums) {
result_type ret = generate_internal<internal::type_traits_t<result_type>::num_elems>(__random_nums);
return ret;
}

/* Shift the counter only forward relative to its current position */
void discard(unsigned long long z) {
void discard_internal(unsigned long long z) {
scalar_type curr_idx = state_.idx % word_count;
unsigned long long newridx = (curr_idx + z) % word_count;
if(newridx == 0) {
Expand All @@ -261,36 +294,6 @@ class philox_engine
state_.idx = newridx;
}

/* Equality operators */
friend bool operator==(const philox_engine& x, const philox_engine& y) {
if(!::std::equal(x.state_.X.begin(), x.state_.X.end(), y.state_.X.begin()) ||
!::std::equal(x.state_.K.begin(), x.state_.K.end(), y.state_.K.begin()) ||
!::std::equal(x.state_.Y.begin(), x.state_.Y.end(), y.state_.Y.begin()) ||
x.state_.idx != y.state_.idx) {
return false;
}
return true;
}
friend bool
operator!=(const philox_engine& __x, const philox_engine& __y)
{
return !(__x == __y);
}

/* inserters and extractors */
template<class CharT, class Traits, typename UIntType_, ::std::size_t w_, ::std::size_t n_, ::std::size_t r_, UIntType_... consts_>
friend ::std::basic_ostream<CharT, Traits>&
operator<<(::std::basic_ostream<CharT, Traits>&, const philox_engine<UIntType_, w_, n_, r_, consts_...>&);

template<typename UIntType_, ::std::size_t w_, ::std::size_t n_, ::std::size_t r_, UIntType_... consts_>
friend const sycl::stream&
operator<<(const sycl::stream&, const philox_engine<UIntType_, w_, n_, r_, consts_...>&);

template<class CharT, class Traits, typename UIntType_, ::std::size_t w_, ::std::size_t n_, ::std::size_t r_, UIntType_... consts_>
friend ::std::basic_istream<CharT, Traits>&
operator>>(::std::basic_istream<CharT, Traits>&, philox_engine<UIntType_, w_, n_, r_, consts_...>&);

private:
/* Internal generation Philox kernel */
void philox_kernel() {
if constexpr (word_count == 2) {
Expand Down

0 comments on commit 864d32e

Please sign in to comment.