Skip to content

Commit

Permalink
Apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Dec 8, 2024
1 parent 9130a53 commit 943726f
Show file tree
Hide file tree
Showing 15 changed files with 40 additions and 98 deletions.
7 changes: 2 additions & 5 deletions dpnp/backend/extensions/statistics/bincount.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
{
namespace histogram
namespace statistics::histogram
{
struct Bincount
{
Expand Down Expand Up @@ -62,5 +60,4 @@ struct Bincount
};

void populate_bincount(py::module_ m);
} // namespace histogram
} // namespace statistics
} // namespace statistics::histogram
8 changes: 2 additions & 6 deletions dpnp/backend/extensions/statistics/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
namespace statistics::common
{
namespace common
{

size_t get_max_local_size(const sycl::device &device)
{
constexpr const int default_max_cpu_local_size = 256;
Expand Down Expand Up @@ -120,5 +117,4 @@ pybind11::dtype dtype_from_typenum(int dst_typenum)
}
}

} // namespace common
} // namespace statistics
} // namespace statistics::common
7 changes: 2 additions & 5 deletions dpnp/backend/extensions/statistics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@
#include "utils/math_utils.hpp"
// clang-format on

namespace statistics
{
namespace common
namespace statistics::common
{

template <typename N, typename D>
Expand Down Expand Up @@ -200,5 +198,4 @@ sycl::nd_range<1>
// headers of dpctl.
pybind11::dtype dtype_from_typenum(int dst_typenum);

} // namespace common
} // namespace statistics
} // namespace statistics::common
8 changes: 2 additions & 6 deletions dpnp/backend/extensions/statistics/dispatch_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,8 @@
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
namespace py = pybind11;

namespace statistics
namespace statistics::common
{
namespace common
{

template <typename T, typename Rest>
struct one_of
{
Expand Down Expand Up @@ -386,5 +383,4 @@ class DispatchTable2
Table2<FnT> table;
};

} // namespace common
} // namespace statistics
} // namespace statistics::common
4 changes: 1 addition & 3 deletions dpnp/backend/extensions/statistics/histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
#include <algorithm>
#include <complex>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <tuple>
#include <vector>

#include <pybind11/pybind11.h>
Expand Down
9 changes: 3 additions & 6 deletions dpnp/backend/extensions/statistics/histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@

#include "dispatch_table.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
// namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
{
namespace histogram
namespace statistics::histogram
{
struct Histogram
{
Expand All @@ -59,5 +57,4 @@ struct Histogram
};

void populate_histogram(py::module_ m);
} // namespace histogram
} // namespace statistics
} // namespace statistics::histogram
3 changes: 0 additions & 3 deletions dpnp/backend/extensions/statistics/histogram_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@
#include <algorithm>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

#include "dpctl4pybind11.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/output_validation.hpp"
#include "utils/type_dispatch.hpp"

#include <pybind11/pybind11.h>
Expand Down
8 changes: 0 additions & 8 deletions dpnp/backend/extensions/statistics/histogram_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@

#include "common.hpp"

namespace dpctl
{
namespace tensor
{
class usm_ndarray;
}
} // namespace dpctl

using dpctl::tensor::usm_ndarray;

namespace statistics
Expand Down
10 changes: 4 additions & 6 deletions dpnp/backend/extensions/statistics/sliding_dot_product1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <algorithm>
#include <complex>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include <pybind11/pybind11.h>
Expand All @@ -42,7 +38,7 @@
#include "sliding_dot_product1d.hpp"
#include "sliding_window1d.hpp"

#include <iostream>
// #include <iostream>

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
using dpctl::tensor::usm_ndarray;
Expand Down Expand Up @@ -101,7 +97,9 @@ struct SlidingDotProductF
}
};

using SupportedTypes = std::tuple<uint64_t,
using SupportedTypes = std::tuple<uint32_t,
int32_t,
uint64_t,
int64_t,
float,
double,
Expand Down
10 changes: 2 additions & 8 deletions dpnp/backend/extensions/statistics/sliding_dot_product1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,10 @@
#pragma once

#include "dispatch_table.hpp"
#include "utils/type_dispatch.hpp"
#include <pybind11/pybind11.h>
#include <sycl/sycl.hpp>

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;

namespace statistics
{
namespace sliding_window1d
namespace statistics::sliding_window1d
{
struct SlidingDotProduct1d
{
Expand Down Expand Up @@ -62,5 +57,4 @@ struct SlidingDotProduct1d
};

void populate_sliding_dot_product1d(py::module_ m);
} // namespace sliding_window1d
} // namespace statistics
} // namespace statistics::sliding_window1d
3 changes: 0 additions & 3 deletions dpnp/backend/extensions/statistics/sliding_window1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <algorithm>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

#include "dpctl4pybind11.hpp"
Expand Down
10 changes: 0 additions & 10 deletions dpnp/backend/extensions/statistics/sliding_window1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,13 @@
#pragma once

#include "utils/math_utils.hpp"
#include <complex>
#include <sycl/sycl.hpp>
#include <tuple>
#include <type_traits>

#include <stdio.h>

#include "common.hpp"

namespace dpctl
{
namespace tensor
{
class usm_ndarray;
}
} // namespace dpctl

using dpctl::tensor::usm_ndarray;

namespace statistics
Expand Down
7 changes: 2 additions & 5 deletions dpnp/backend/extensions/statistics/validation_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ sycl::queue get_queue(const std::vector<array_ptr> &inputs,
}
} // namespace

namespace statistics
{
namespace validation
namespace statistics::validation
{
std::string name_of(const array_ptr &arr, const array_names &names)
{
Expand Down Expand Up @@ -189,5 +187,4 @@ void common_checks(const std::vector<array_ptr> &inputs,
check_no_overlap(inputs, outputs, names);
}

} // namespace validation
} // namespace statistics
} // namespace statistics::validation
7 changes: 2 additions & 5 deletions dpnp/backend/extensions/statistics/validation_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@

#include "dpctl4pybind11.hpp"

namespace statistics
{
namespace validation
namespace statistics::validation
{
using array_ptr = const dpctl::tensor::usm_ndarray *;
using array_names = std::unordered_map<array_ptr, std::string>;
Expand Down Expand Up @@ -69,5 +67,4 @@ void check_size_at_least(const array_ptr &arr,
void common_checks(const std::vector<array_ptr> &inputs,
const std::vector<array_ptr> &outputs,
const array_names &names);
} // namespace validation
} // namespace statistics
} // namespace statistics::validation
37 changes: 18 additions & 19 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,7 @@ def corrcoef(x, y=None, rowvar=True, *, dtype=None):


def _get_padding(a_size, v_size, mode):
if v_size > a_size:
a_size, v_size = v_size, a_size
assert v_size <= a_size

if mode == "valid":
l_pad, r_pad = 0, 0
Expand All @@ -463,9 +462,8 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):

usm_type = dpu.get_coerced_usm_type([a.usm_type, v.usm_type])
out_size = l_pad + r_pad + a.size - v.size + 1
out = dpnp.empty(
shape=out_size, sycl_queue=queue, dtype=a.dtype, usm_type=usm_type
)
# out type is the same as input type
out = dpnp.empty_like(a, shape=out_size, usm_type=usm_type)

a_usm = dpnp.get_usm_ndarray(a)
v_usm = dpnp.get_usm_ndarray(v)
Expand All @@ -491,11 +489,11 @@ def correlate(a, v, mode="valid"):
Cross-correlation of two 1-dimensional sequences.
This function computes the correlation as generally defined in signal
processing texts [1]:
processing texts [1]_:
.. math:: c_k = \sum_n a_{n+k} \cdot \overline{v}_n
with a and v sequences being zero-padded where necessary and
with `a` and `v` sequences being zero-padded where necessary and
:math:`\overline v` denoting complex conjugation.
For full documentation refer to :obj:`numpy.correlate`.
Expand All @@ -506,16 +504,16 @@ def correlate(a, v, mode="valid"):
First input array.
v : {dpnp.ndarray, usm_ndarray}
Second input array.
mode : {'valid', 'same', 'full'}, optional
mode : {"valid", "same", "full"}, optional
Refer to the :obj:`dpnp.convolve` docstring. Note that the default
is ``'valid'``, unlike :obj:`dpnp.convolve`, which uses ``'full'``.
is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
Default: ``'valid'``.
Default: ``"valid"``.
Notes
-----
The definition of correlation above is not unique and sometimes
correlation may be defined differently. Another common definition is [1]:
correlation may be defined differently. Another common definition is [1]_:
.. math:: c'_k = \sum_n a_{n} \cdot \overline{v_{n+k}}
Expand All @@ -533,8 +531,8 @@ def correlate(a, v, mode="valid"):
See Also
--------
:obj:`dpnp.convolve` : Discrete, linear convolution of two
one-dimensional sequences.
:obj:`dpnp.convolve` : Discrete, linear convolution of two one-dimensional
sequences.
Examples
Expand All @@ -546,7 +544,7 @@ def correlate(a, v, mode="valid"):
array([3.5], dtype=float32)
>>> np.correlate(a, v, "same")
array([2. , 3.5, 3. ], dtype=float32)
>>> np.correlate([1, 2, 3], [0, 1, 0.5], "full")
>>> np.correlate([a, v, "full")
array([0.5, 2. , 3.5, 3. , 0. ], dtype=float32)
Using complex sequences:
Expand All @@ -557,10 +555,10 @@ def correlate(a, v, mode="valid"):
array([0.5-0.5j, 1. +0.j , 1.5-1.5j, 3. -1.j , 0. +0.j ], dtype=complex64)
Note that you get the time reversed, complex conjugated result
(:math:`\overline{c_{-k}}`) when the two input sequences a and v change
(:math:`\overline{c_{-k}}`) when the two input sequences `a` and `v` change
places:
>>> np.correlate([0, 1, 0.5j], [1+1j, 2, 3-1j], 'full')
>>> np.correlate(vc, ac, 'full')
array([0. +0.j , 3. +1.j , 1.5+1.5j, 1. +0.j , 0.5+0.5j], dtype=complex64)
"""
Expand All @@ -586,10 +584,11 @@ def correlate(a, v, mode="valid"):

if supported_dtype is None:
raise ValueError(
f"function '{correlate}' does not support input types "
f"({a.dtype}, {v.dtype}), "
f"function does not support input types "
f"({a.dtype.name}, {v.dtype.name}), "
"and the inputs could not be coerced to any "
f"supported types. List of supported types: {supported_types}"
f"supported types. List of supported types: "
f"{[st.name for st in supported_types]}"
)

if dpnp.issubdtype(v.dtype, dpnp.complexfloating):
Expand Down

0 comments on commit 943726f

Please sign in to comment.