Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing histogramdd #2143

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpnp/backend/extensions/statistics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
)
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/statistics/bincount.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@

#pragma once

#include <dpctl4pybind11.hpp>
#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

#include "dispatch_table.hpp"
#include "dpctl4pybind11.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

Expand Down
77 changes: 32 additions & 45 deletions dpnp/backend/extensions/statistics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@
#include <complex>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

// clang-format off
// math_utils.hpp doesn't include sycl header but uses sycl types
// so sycl.hpp must be included before math_utils.hpp
#include <sycl/sycl.hpp>

#include "utils/math_utils.hpp"
// clang-format on
#include "utils/type_utils.hpp"

namespace type_utils = dpctl::tensor::type_utils;

namespace statistics
{
Expand All @@ -56,24 +55,20 @@ constexpr auto Align(N n, D d)
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
struct AtomicOp
{
static void add(T &lhs, const T value)
static void add(T &lhs, const T &value)
{
sycl::atomic_ref<T, Order, Scope> lh(lhs);
lh += value;
}
};
if constexpr (type_utils::is_complex_v<T>) {
using vT = typename T::value_type;
vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
const vT *_val = reinterpret_cast<const vT(&)[2]>(value);

template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
struct AtomicOp<std::complex<T>, Order, Scope>
{
static void add(std::complex<T> &lhs, const std::complex<T> value)
{
T *_lhs = reinterpret_cast<T(&)[2]>(lhs);
const T *_val = reinterpret_cast<const T(&)[2]>(value);
sycl::atomic_ref<T, Order, Scope> lh0(_lhs[0]);
lh0 += _val[0];
sycl::atomic_ref<T, Order, Scope> lh1(_lhs[1]);
lh1 += _val[1];
AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
}
else {
sycl::atomic_ref<T, Order, Scope> lh(lhs);
lh += value;
}
}
};

Expand All @@ -82,17 +77,12 @@ struct Less
{
bool operator()(const T &lhs, const T &rhs) const
{
return std::less{}(lhs, rhs);
}
};

template <typename T>
struct Less<std::complex<T>>
{
bool operator()(const std::complex<T> &lhs,
const std::complex<T> &rhs) const
{
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
if constexpr (type_utils::is_complex_v<T>) {
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
}
else {
return std::less{}(lhs, rhs);
}
}
};

Expand All @@ -101,26 +91,23 @@ struct IsNan
{
static bool isnan(const T &v)
{
if constexpr (std::is_floating_point_v<T> ||
std::is_same_v<T, sycl::half>) {
if constexpr (type_utils::is_complex_v<T>) {
using vT = typename T::value_type;

const vT real1 = std::real(v);
const vT imag1 = std::imag(v);

return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
}
else if constexpr (std::is_floating_point_v<T> ||
std::is_same_v<T, sycl::half>) {
return sycl::isnan(v);
}

return false;
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
}
};

template <typename T>
struct IsNan<std::complex<T>>
{
static bool isnan(const std::complex<T> &v)
{
T real1 = std::real(v);
T imag1 = std::imag(v);
return sycl::isnan(real1) || sycl::isnan(imag1);
}
};

size_t get_max_local_size(const sycl::device &device);
size_t get_max_local_size(const sycl::device &device,
int cpu_local_size_limit,
Expand Down
2 changes: 2 additions & 0 deletions dpnp/backend/extensions/statistics/histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@

#pragma once

#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

#include "dispatch_table.hpp"
#include "dpctl4pybind11.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

Expand Down
55 changes: 22 additions & 33 deletions dpnp/backend/extensions/statistics/histogram_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,16 @@ void validate(const usm_ndarray &sample,
" parameter must have at least 1 element");
}

if (histogram.get_ndim() != 1) {
throw py::value_error(get_name(&histogram) +
" parameter must be 1d. Actual " +
std::to_string(histogram.get_ndim()) + "d");
}

if (weights_ptr) {
if (weights_ptr->get_ndim() != 1) {
throw py::value_error(
get_name(weights_ptr) + " parameter must be 1d. Actual " +
std::to_string(weights_ptr->get_ndim()) + "d");
}

auto sample_size = sample.get_size();
auto sample_size = sample.get_shape(0);
auto weights_size = weights_ptr->get_size();
if (sample.get_size() != weights_ptr->get_size()) {
if (sample_size != weights_ptr->get_size()) {
throw py::value_error(
get_name(&sample) + " size (" + std::to_string(sample_size) +
") and " + get_name(weights_ptr) + " size (" +
Expand All @@ -168,42 +162,37 @@ void validate(const usm_ndarray &sample,
}

if (sample.get_ndim() == 1) {
if (bins_ptr != nullptr && bins_ptr->get_ndim() != 1) {
if (histogram.get_ndim() != 1) {
throw py::value_error(get_name(&sample) + " parameter is 1d, but " +
get_name(bins_ptr) + " is " +
std::to_string(bins_ptr->get_ndim()) + "d");
get_name(&histogram) + " is " +
std::to_string(histogram.get_ndim()) + "d");
}

if (bins_ptr && histogram.get_size() != bins_ptr->get_size() - 1) {
auto hist_size = histogram.get_size();
auto bins_size = bins_ptr->get_size();
throw py::value_error(
get_name(&histogram) + " parameter and " + get_name(bins_ptr) +
" parameters shape mismatch. " + get_name(&histogram) +
" size is " + std::to_string(hist_size) + get_name(bins_ptr) +
" must have size " + std::to_string(hist_size + 1) +
" but have " + std::to_string(bins_size));
}
}
else if (sample.get_ndim() == 2) {
auto sample_count = sample.get_shape(0);
auto expected_dims = sample.get_shape(1);

if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) {
throw py::value_error(get_name(&sample) + " parameter has shape {" +
std::to_string(sample_count) + "x" +
std::to_string(expected_dims) + "}" +
", so " + get_name(bins_ptr) +
if (histogram.get_ndim() != expected_dims) {
throw py::value_error(get_name(&sample) + " parameter has shape (" +
std::to_string(sample_count) + ", " +
std::to_string(expected_dims) + ")" +
", so " + get_name(&histogram) +
" parameter expected to be " +
std::to_string(expected_dims) +
"d. "
"Actual " +
std::to_string(bins->get_ndim()) + "d");
}
}

if (bins_ptr != nullptr) {
py::ssize_t expected_hist_size = 1;
for (int i = 0; i < bins_ptr->get_ndim(); ++i) {
expected_hist_size *= (bins_ptr->get_shape(i) - 1);
}

if (histogram.get_size() != expected_hist_size) {
throw py::value_error(
get_name(&histogram) + " and " + get_name(bins_ptr) +
" shape mismatch. " + get_name(&histogram) +
" expected to have size = " +
std::to_string(expected_hist_size) + ". Actual " +
std::to_string(histogram.get_size()));
std::to_string(histogram.get_ndim()) + "d");
}
}

Expand Down
63 changes: 45 additions & 18 deletions dpnp/backend/extensions/statistics/histogram_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ template <typename T, int Dims>
struct CachedData
{
static constexpr bool const sync_after_init = true;
using pointer_type = T *;
using Shape = sycl::range<Dims>;
using value_type = T;
using pointer_type = value_type *;
static constexpr auto dims = Dims;
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved

using ncT = typename std::remove_const<T>::type;
using ncT = typename std::remove_const<value_type>::type;
using LocalData = sycl::local_accessor<ncT, Dims>;

CachedData(T *global_data, sycl::range<Dims> shape, sycl::handler &cgh)
CachedData(T *global_data, Shape shape, sycl::handler &cgh)
{
this->global_data = global_data;
local_data = LocalData(shape, cgh);
Expand All @@ -71,13 +74,13 @@ struct CachedData
template <int _Dims>
void init(const sycl::nd_item<_Dims> &item) const
{
int32_t llid = item.get_local_linear_id();
uint32_t llid = item.get_local_linear_id();
auto local_ptr = &local_data[0];
int32_t size = local_data.size();
uint32_t size = local_data.size();
auto group = item.get_group();
int32_t local_size = group.get_local_linear_range();
uint32_t local_size = group.get_local_linear_range();

for (int32_t i = llid; i < size; i += local_size) {
for (uint32_t i = llid; i < size; i += local_size) {
local_ptr[i] = global_data[i];
}
}
Expand All @@ -87,17 +90,30 @@ struct CachedData
return local_data.size();
}

T &operator[](const sycl::id<Dims> &id) const
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
{
return local_data[id];
}

template <typename = std::enable_if_t<Dims == 1>>
T &operator[](const size_t id) const
{
return local_data[id];
}

private:
LocalData local_data;
T *global_data = nullptr;
value_type *global_data = nullptr;
};

template <typename T, int Dims>
struct UncachedData
{
static constexpr bool const sync_after_init = false;
using Shape = sycl::range<Dims>;
using pointer_type = T *;
using value_type = T;
using pointer_type = value_type *;
static constexpr auto dims = Dims;

UncachedData(T *global_data, const Shape &shape, sycl::handler &)
{
Expand All @@ -120,6 +136,17 @@ struct UncachedData
return _shape.size();
}

T &operator[](const sycl::id<Dims> &id) const
{
return global_data[id];
}

template <typename = std::enable_if_t<Dims == 1>>
T &operator[](const size_t id) const
{
return global_data[id];
}

private:
T *global_data = nullptr;
Shape _shape;
Expand Down Expand Up @@ -191,15 +218,15 @@ struct HistWithLocalCopies
template <int _Dims>
void finalize(const sycl::nd_item<_Dims> &item) const
{
int32_t llid = item.get_local_linear_id();
int32_t bins_count = local_hist.get_range().get(1);
int32_t local_hist_count = local_hist.get_range().get(0);
uint32_t llid = item.get_local_linear_id();
uint32_t bins_count = local_hist.get_range().get(1);
uint32_t local_hist_count = local_hist.get_range().get(0);
auto group = item.get_group();
int32_t local_size = group.get_local_linear_range();
uint32_t local_size = group.get_local_linear_range();

for (int32_t i = llid; i < bins_count; i += local_size) {
for (uint32_t i = llid; i < bins_count; i += local_size) {
auto value = local_hist[0][i];
for (int32_t lhc = 1; lhc < local_hist_count; ++lhc) {
for (uint32_t lhc = 1; lhc < local_hist_count; ++lhc) {
value += local_hist[lhc][i];
}
if (value != T(0)) {
Expand Down Expand Up @@ -290,9 +317,9 @@ class histogram_kernel;

template <typename T, typename HistImpl, typename Edges, typename Weights>
void submit_histogram(const T *in,
size_t size,
size_t dims,
uint32_t WorkPI,
const size_t size,
const size_t dims,
const uint32_t WorkPI,
const HistImpl &hist,
const Edges &edges,
const Weights &weights,
Expand Down
Loading
Loading