Remove ScratchpadAllocator and ScratchpadEstimator (#5810)
Removes the deprecated ScratchpadAllocator and ScratchpadEstimator; kernel Setup no longer reports scratch_sizes, and call sites use DynamicScratchpad instead.

Signed-off-by: Joaquin Anton Guirao <[email protected]>
Signed-off-by: Michał Zientkiewicz <[email protected]>
Co-authored-by: Michał Zientkiewicz <[email protected]>
jantonguirao and mzient authored Feb 26, 2025
1 parent c06fdc2 commit 09c25d9
Showing 116 changed files with 409 additions and 1,485 deletions.
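For orientation, the pattern applied across these files replaces the reserve-then-retrieve scratchpad flow with a DynamicScratchpad bound to the kernel context. A minimal sketch of the before/after, with illustrative names and assumed header paths (only dynamic_scratchpad.h appears in this diff; the KernelContext include is an assumption):

#include "dali/kernels/dynamic_scratchpad.h"
#include "dali/kernels/kernel.h"  // assumed location of KernelContext

void RunWithDynamicScratchpad(dali::kernels::KernelContext &ctx) {
  // Previously: ScratchpadAllocator scratch_alloc; scratch_alloc.Reserve(req.scratch_sizes);
  //             auto scratchpad = scratch_alloc.GetScratchpad(); ctx.scratchpad = &scratchpad;
  // Now a stream-ordered scratchpad is created on the stack and bound to the context:
  dali::kernels::DynamicScratchpad dyn_scratchpad(dali::AccessOrder(ctx.gpu.stream));
  ctx.scratchpad = &dyn_scratchpad;
  // kernel.Run(ctx, ...);  // scratch memory is allocated on demand during Run
}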
1 change: 0 additions & 1 deletion dali/benchmark/slice_kernel_bench.cc
@@ -18,7 +18,6 @@
#include "dali/kernels/slice/slice_cpu.h"
#include "dali/test/tensor_test_utils.h"
#include "dali/test/test_tensors.h"
#include "dali/kernels/scratch.h"

namespace dali {

8 changes: 3 additions & 5 deletions dali/benchmark/slice_kernel_bench.cu
@@ -18,7 +18,7 @@
#include "dali/kernels/slice/slice_gpu.cuh"
#include "dali/test/tensor_test_utils.h"
#include "dali/test/test_tensors.h"
#include "dali/kernels/scratch.h"
#include "dali/kernels/dynamic_scratchpad.h"

namespace dali {

@@ -76,10 +76,8 @@ class SliceBenchGPU : public DALIBenchmark {

auto req = kernel.Setup(ctx, in_tv, args_vec);

-kernels::ScratchpadAllocator scratch_alloc;
-scratch_alloc.Reserve(req.scratch_sizes);
-auto scratchpad = scratch_alloc.GetScratchpad();
-ctx.scratchpad = &scratchpad;
+kernels::DynamicScratchpad dyn_scratchpad(ctx.gpu.stream);
+ctx.scratchpad = &dyn_scratchpad;

kernel.Run(ctx, out_tv, in_tv, args_vec);
CUDA_CALL(cudaStreamSynchronize(ctx.gpu.stream));
1 change: 0 additions & 1 deletion dali/kernels/audio/mel_scale/mel_filter_bank_cpu_test.cc
@@ -18,7 +18,6 @@
#include <complex>
#include <cmath>
#include "dali/kernels/audio/mel_scale/mel_filter_bank_test.h"
#include "dali/kernels/scratch.h"
#include "dali/kernels/audio/mel_scale/mel_scale.h"
#include "dali/kernels/audio/mel_scale/mel_filter_bank_cpu.h"
#include "dali/kernels/common/utils.h"
23 changes: 6 additions & 17 deletions dali/kernels/audio/mel_scale/mel_filter_bank_gpu.cu
@@ -304,24 +304,17 @@ class MelFilterBankGpu<T>::Impl : public MelFilterImplBase<T> {
}
}

-void Setup(ScratchpadEstimator &se, const TensorListShape<> &in_shape) {
+void Setup(const TensorListShape<> &in_shape) {
inner_fft_ = true;
for (int s = 0; s < in_shape.size(); s++) {
inner_fft_ &= volume(in_shape.tensor_shape_span(s).begin() + args_.axis + 1,
in_shape.tensor_shape_span(s).end()) == 1;
}
nfft_ = in_shape.tensor_shape_span(0)[args_.axis];
if (inner_fft_) {
-SetupBlockDescsInnerFft(se, in_shape);
-se.add<mm::memory_kind::device, T>(weights_down_norm_.size());
-se.add<mm::memory_kind::device, T>(weights_up_norm_.size());
-se.add<mm::memory_kind::device, int>(interval_ends_.size());
+SetupBlockDescsInnerFft(in_shape);
} else {
-SetupBlockDescsOuterFft(se, in_shape);
-se.add<mm::memory_kind::device, int>(interval_ends_.size());
-se.add<mm::memory_kind::device, T>(weights_down_.size());
-if (args_.normalize)
-se.add<mm::memory_kind::device, T>(norm_factors_.size());
+SetupBlockDescsOuterFft(in_shape);
}
}

@@ -392,7 +385,7 @@ class MelFilterBankGpu<T>::Impl : public MelFilterImplBase<T> {
using MelFilterImplBase<T>::Args;

private:
-void SetupBlockDescsOuterFft(ScratchpadEstimator &se, const TensorListShape<> &in_shape) {
+void SetupBlockDescsOuterFft(const TensorListShape<> &in_shape) {
nframes_.clear();
nwindows_.clear();
block_descs_outer_.clear();
@@ -408,10 +401,9 @@ class MelFilterBankGpu<T>::Impl : public MelFilterImplBase<T> {
}
}
}
-se.add<mm::memory_kind::device, BlockDescOuter<T>>(block_descs_outer_.size());
}

-void SetupBlockDescsInnerFft(ScratchpadEstimator &se, const TensorListShape<> &in_shape) {
+void SetupBlockDescsInnerFft(const TensorListShape<> &in_shape) {
nframes_.clear();
nwindows_.clear();
block_descs_inner_.clear();
@@ -486,7 +478,6 @@ class MelFilterBankGpu<T>::Impl : public MelFilterImplBase<T> {
block_descs_inner_.emplace_back(start, start + count);
}
}
-se.add<mm::memory_kind::device, BlockDescInner<T>>(block_descs_inner_.size());
}

void FillBlockDescsOuterFft(T *const *out_list, const T *const *in_list) {
@@ -547,7 +538,6 @@ KernelRequirements MelFilterBankGpu<T>::Setup(KernelContext &context,
DALI_ENFORCE(in.shape.tensor_shape_span(s)[args.axis] == nfft,
"All samples should have the same FFT dimension");
}
-ScratchpadEstimator se;
args.nfft = args.nfft > 0 ? args.nfft : 2 * (in.shape[0][args.axis] - 1);
args.freq_high = args.freq_high > 0 ? args.freq_high : args.sample_rate / 2;
if (!impl_ || impl_->Args() != args) {
@@ -562,8 +552,7 @@ KernelRequirements MelFilterBankGpu<T>::Setup(KernelContext &context,
break;
}
}
-impl_->Setup(se, in.shape);
-req.scratch_sizes = se.sizes;
+impl_->Setup(in.shape);
return req;
}

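For context, the se.add<...> reservations deleted above only pre-sized device buffers; after this change the implementation is expected to draw equivalent allocations from ctx.scratchpad during Run. A rough sketch of that idea, assuming the Scratchpad ToGPU/AllocateGPU helpers (the corresponding Run changes are not part of this excerpt, so treat this as illustrative only):

// Hypothetical Run-time staging; member names are taken from the Setup code above.
const T *weights_dev = ctx.scratchpad->ToGPU(ctx.gpu.stream, weights_down_norm_);   // copy host table to device scratch
int *intervals_dev = ctx.scratchpad->AllocateGPU<int>(interval_ends_.size());       // uninitialized device scratch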
5 changes: 1 addition & 4 deletions dali/kernels/common/join/tensor_join_gpu.h
@@ -61,12 +61,9 @@ class TensorJoinGPU : public tensor_join::TensorJoinImplGPU<type_of_size<sizeof(
const std::function<const TensorListShape<> *(int)> &get_input_shape,
int num_inputs,
int axis) {
-ScratchpadEstimator se;
KernelRequirements req;
req.output_shapes.resize(1);
-se.add<mm::memory_kind::pinned, const InListU *>(num_inputs);
-Base::Setup(req.output_shapes[0], se, get_input_shape, num_inputs, axis);
-req.scratch_sizes = se.sizes;
+Base::Setup(req.output_shapes[0], get_input_shape, num_inputs, axis);
return req;
}

5 changes: 0 additions & 5 deletions dali/kernels/common/join/tensor_join_gpu_impl.cu
@@ -23,16 +23,11 @@ namespace tensor_join {
template <typename T, bool new_axis>
void TensorJoinImplGPU<T, new_axis>::Setup(
TensorListShape<> &output_shape,
-ScratchpadEstimator &se,
const std::function<const TensorListShape<> *(int)> &get_input_shape,
int num_inputs,
int axis) {
JoinedShape(output_shape, get_input_shape, num_inputs, axis, new_axis);
int N = output_shape.num_samples();
-se.add<mm::memory_kind::device, OutputDesc<T>>(N);
-se.add<mm::memory_kind::device, InputDesc<T>>(num_inputs * N);
-se.add<mm::memory_kind::pinned, OutputDesc<T>>(N);
-se.add<mm::memory_kind::pinned, InputDesc<T>>(num_inputs * N);
axis_ = axis;
}

2 changes: 0 additions & 2 deletions dali/kernels/common/join/tensor_join_gpu_impl.h
@@ -35,7 +35,6 @@ class DLL_PUBLIC TensorJoinImplGPU {

/**
* @param output_shape shape of the result of joining the inputs
-* @param se scratchpad requirements are added here
* @param get_input_shape a function called with an input index; returns a reference to a shape
* of the input at given index
* @param num_inputs number of joined tensors
@@ -52,7 +51,6 @@
* differ at index `axis` (if new_axis == `false`).
*/
void Setup(TensorListShape<> &output_shape,
-ScratchpadEstimator &se,
const std::function<const TensorListShape<> *(int)> &get_input_shape,
int num_inputs,
int axis);
2 changes: 1 addition & 1 deletion dali/kernels/common/scatter_gather.cu
@@ -174,7 +174,7 @@ void ScatterGatherGPU::Run(cudaStream_t stream, bool reset, ScatterGatherGPU::Me
size_t num_blocks, size_per_block;
std::tie(num_blocks, size_per_block) = BlockCountAndSize(ranges_);
if (num_blocks > kMaxRangesByVal) {
-kernels::DynamicScratchpad scratchpad({}, stream);
+kernels::DynamicScratchpad scratchpad(stream);
auto *blocks_pinned = scratchpad.Allocate<mm::memory_kind::pinned, CopyRange>(num_blocks);
auto blocks = make_span(blocks_pinned, num_blocks);
ScatterGatherBase::MakeBlocks(blocks, ranges_, size_per_block);
15 changes: 8 additions & 7 deletions dali/kernels/dynamic_scratchpad.h
@@ -137,19 +137,20 @@ class DynamicScratchpad
/**
* @brief Constructs a dynamically allocated scratchpad
*
-* @param initial_sizes Sizes, in bytes, of the initial buffers. Note that these buffers
-* are allocated lazily, so nothing is allocated if there's no request
-* for memory of any given kind.
-* @param device_order Allocation and deallocation order for device memory.
+* @param device_order Allocation and deallocation order for device memory.
* @param pinned_dealloc_order Deallocation order for pinned memory. Allocation is always
* host-ordered. If not set, device_order is used.
* @param managed_dealloc_order Deallocation order for managed memory. Allocation is always
* host-ordered. If not set, device_order is used.
+* @param initial_sizes Sizes, in bytes, of the initial buffers. Note that these buffers
+* are allocated lazily, so nothing is allocated if there's no request
+* for memory of any given kind.
*/
-explicit DynamicScratchpad(scratch_sizes_t initial_sizes = {},
-AccessOrder device_order = cudaStream_t(0),
+using scratch_sizes_t = std::array<size_t, static_cast<size_t>(mm::memory_kind_id::count)>;
+explicit DynamicScratchpad(AccessOrder device_order = cudaStream_t(0),
AccessOrder pinned_dealloc_order = {},
-AccessOrder managed_dealloc_order = {}) {
+AccessOrder managed_dealloc_order = {},
+scratch_sizes_t initial_sizes = {}) {
initial_sizes_ = initial_sizes;
for (auto &s : initial_sizes_) {
if (s == 0)
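With the reordered signature the access order is now the first argument and the (still lazily allocated) initial sizes come last, which is why most call sites in this commit pass only a stream. A minimal usage sketch consistent with the test below; the wrapper function and buffer size are illustrative:

#include <cstring>
#include "dali/kernels/dynamic_scratchpad.h"

void FillPinnedScratch(cudaStream_t stream, const char *src, size_t n) {
  // Device allocations follow `stream`; pinned/managed deallocation defaults to the same order.
  dali::kernels::DynamicScratchpad scratch(dali::AccessOrder(stream));
  char *pinned = scratch.Allocate<dali::mm::memory_kind::pinned, char>(n);
  std::memcpy(pinned, src, n);  // host-side fill, as in the BasicTest case below
  // pinned memory is released in stream order when `scratch` goes out of scope
}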
2 changes: 1 addition & 1 deletion dali/kernels/dynamic_scratchpad_test.cc
@@ -51,7 +51,7 @@ TEST(DynamicScratchpad, BasicTest) {
for (int attempt = 0; attempt < max_attempts; attempt++) {
char *pinned;
{
-DynamicScratchpad scratch({}, AccessOrder(stream.get()));
+DynamicScratchpad scratch(AccessOrder(stream.get()));
pinned = scratch.Allocate<mm::memory_kind::pinned, char>(N);
memcpy(pinned, in.data(), N);
CUDA_CALL(cudaMemcpyAsync(dev.get(), pinned, N, cudaMemcpyHostToDevice, stream));
1 change: 0 additions & 1 deletion dali/kernels/erase/erase_cpu_test.cc
@@ -17,7 +17,6 @@
#include <vector>
#include <complex>
#include <cmath>
#include "dali/kernels/scratch.h"
#include "dali/kernels/erase/erase_cpu.h"
#include "dali/kernels/common/utils.h"
#include "dali/test/test_tensors.h"
3 changes: 1 addition & 2 deletions dali/kernels/erase/erase_gpu_test.cu
@@ -21,7 +21,6 @@

#include "dali/kernels/common/utils.h"
#include "dali/kernels/erase/erase_gpu.h"
#include "dali/kernels/scratch.h"
#include "dali/kernels/dynamic_scratchpad.h"
#include "dali/pipeline/data/tensor_list.h"
#include "dali/test/tensor_test_utils.h"
@@ -203,7 +202,7 @@ struct EraseGpuKernelTest :
EraseGpu<T, ndim, channel_dim> kernel;
KernelContext ctx;
ctx.gpu.stream = 0;
-DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream));
+DynamicScratchpad dyn_scratchpad(AccessOrder(ctx.gpu.stream));
ctx.scratchpad = &dyn_scratchpad;

CreateRegions();
@@ -22,7 +22,6 @@
#include "dali/kernels/dynamic_scratchpad.h"
#include "dali/kernels/imgproc/color_manipulation/debayer/debayer.h"
#include "dali/kernels/imgproc/color_manipulation/debayer/debayer_npp.h"
#include "dali/kernels/scratch.h"
#include "dali/pipeline/data/tensor_list.h"
#include "dali/pipeline/data/views.h"
#include "dali/test/tensor_test_utils.h"
@@ -134,7 +133,7 @@ class DebayerGpuTest : public ::testing::Test {
Kernel kernel{0};
KernelContext ctx;
ctx.gpu.stream = cuda_stream;
-DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream));
+DynamicScratchpad dyn_scratchpad(AccessOrder(ctx.gpu.stream));
ctx.scratchpad = &dyn_scratchpad;
auto in_view = in_.gpu(cuda_stream);
auto out_view = out_.gpu(cuda_stream);
4 changes: 2 additions & 2 deletions dali/kernels/imgproc/color_manipulation/equalize/hist_test.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -39,7 +39,7 @@ class EqualizeHistGpuTest : public ::testing::Test {
HistogramKernelGpu kernel;
KernelContext ctx;
ctx.gpu.stream = cuda_stream;
-DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream));
+DynamicScratchpad dyn_scratchpad(AccessOrder(ctx.gpu.stream));
ctx.scratchpad = &dyn_scratchpad;
auto out_view = out_.gpu(cuda_stream);
auto in_view = in_.gpu(cuda_stream);
@@ -81,7 +81,7 @@ class EqualizeLookupGpuTest : public ::testing::Test {
LookupKernelGpu kernel;
KernelContext ctx;
ctx.gpu.stream = cuda_stream;
-DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream));
+DynamicScratchpad dyn_scratchpad(AccessOrder(ctx.gpu.stream));
ctx.scratchpad = &dyn_scratchpad;
auto out_view = out_.gpu(cuda_stream);
auto in_view = in_.gpu(cuda_stream);
@@ -41,7 +41,7 @@ class EqualizeLutGpuTest : public ::testing::Test {
LutKernelGpu kernel;
KernelContext ctx;
ctx.gpu.stream = cuda_stream;
-DynamicScratchpad dyn_scratchpad({}, AccessOrder(ctx.gpu.stream));
+DynamicScratchpad dyn_scratchpad(AccessOrder(ctx.gpu.stream));
ctx.scratchpad = &dyn_scratchpad;
auto out_view = out_.gpu(cuda_stream);
auto in_view = in_.gpu(cuda_stream);
3 changes: 0 additions & 3 deletions dali/kernels/imgproc/convolution/convolution_cpu.h
@@ -379,11 +379,8 @@ struct ConvolutionCpu {

KernelRequirements Setup(KernelContext& ctx, const TensorShape<ndim>& in_shape, int window_size) {
KernelRequirements req;
-ScratchpadEstimator se;
DALI_ENFORCE(window_size % 2 == 1,
make_string("Kernel window should have odd length, got: ", window_size, "."));
-se.add<mm::memory_kind::host, In>(GetInputWindowBufSize(in_shape, window_size));
-req.scratch_sizes = se.sizes;
req.output_shapes.push_back(uniform_list_shape<ndim>(1, in_shape));
return req;
}
9 changes: 3 additions & 6 deletions dali/kernels/imgproc/convolution/convolution_cpu_test.cc
@@ -22,10 +22,10 @@
#include "dali/core/convert.h"
#include "dali/kernels/common/utils.h"
#include "dali/kernels/imgproc/convolution/convolution_cpu.h"
#include "dali/kernels/scratch.h"
#include "dali/test/tensor_test_utils.h"
#include "dali/test/test_tensors.h"
#include "dali/kernels/imgproc/convolution/baseline_convolution.h"
#include "dali/kernels/dynamic_scratchpad.h"

namespace dali {
namespace kernels {
@@ -302,11 +302,8 @@ struct ConvolutionCpuKernelTest : public ::testing::Test {
Kernel kernel;

auto req = kernel.Setup(ctx, in_.shape, k_win_.num_elements());
-// this is painful
-ScratchpadAllocator scratch_alloc;
-scratch_alloc.Reserve(req.scratch_sizes);
-auto scratchpad = scratch_alloc.GetScratchpad();
-ctx.scratchpad = &scratchpad;
+DynamicScratchpad dyn_scratchpad(AccessOrder::host());
+ctx.scratchpad = &dyn_scratchpad;

testing::BaselineConvolve(baseline_out_, baseline_in_, k_win_, T::axis, T::window_size / 2);
TransformCase tranform(out_, baseline_out_, T::in_place);
5 changes: 0 additions & 5 deletions dali/kernels/imgproc/convolution/convolution_gpu.h
@@ -84,7 +84,6 @@ struct ConvolutionGpu {
KernelRequirements Setup(KernelContext& ctx, const TensorListShape<ndim>& in_shape,
const TensorListShape<1>& window_size) {
KernelRequirements req;
-ScratchpadEstimator se;
DALI_ENFORCE(
in_shape.size() == window_size.size(),
make_string(
@@ -103,10 +102,6 @@ struct ConvolutionGpu {
make_string("Window is too big for sample ", i, ", got: ", window_size[i][0],
", expected at most: ", kMaxWindowSize / num_channels, "."));
}
-se.add<mm::memory_kind::host, W>(num_samples * kWindowCopyBufferSize);
-se.add<mm::memory_kind::device, W>(num_samples * kWindowCopyBufferSize);
-se.add<mm::memory_kind::device, typename CutlassConv::SampleParams>(num_samples);
-req.scratch_sizes = se.sizes;
req.output_shapes.push_back(in_shape);
return req;
}
19 changes: 9 additions & 10 deletions dali/kernels/imgproc/convolution/convolution_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,9 +24,9 @@
#include "dali/kernels/imgproc/convolution/baseline_convolution.h"
#include "dali/kernels/imgproc/convolution/convolution_cpu.h"
#include "dali/kernels/imgproc/convolution/convolution_gpu.h"
#include "dali/kernels/scratch.h"
#include "dali/test/tensor_test_utils.h"
#include "dali/test/test_tensors.h"
#include "dali/kernels/dynamic_scratchpad.h"

namespace dali {
namespace kernels {
@@ -159,6 +159,7 @@ struct ConvolutionGpuKernelTest : public ::testing::Test {
void RunTest() {
KernelContext ctx_cpu, ctx_gpu;
ctx_gpu.gpu.stream = 0;

KernelCpu kernel_cpu;
KernelGpu kernel_gpu;

@@ -175,22 +176,20 @@
int window_size = shape_window[sample][0];
auto req = kernel_cpu.Setup(ctx_cpu, data_shape[sample], window_size);

-ScratchpadAllocator scratch_alloc;
-scratch_alloc.Reserve(req.scratch_sizes);
-auto scratchpad = scratch_alloc.GetScratchpad();
-ctx_cpu.scratchpad = &scratchpad;
+DynamicScratchpad dyn_scratchpad_cpu(AccessOrder::host());
+ctx_cpu.scratchpad = &dyn_scratchpad_cpu;

kernel_cpu.Run(ctx_cpu, baseline_out_[sample], baseline_in_[sample], k_win_[sample],
transform.GetCpuTransform(sample));
}

auto req = kernel_gpu.Setup(ctx_gpu, in_.shape, shape_window);

-ScratchpadAllocator scratch_alloc;
-scratch_alloc.Reserve(req.scratch_sizes);
-auto scratchpad = scratch_alloc.GetScratchpad();
-ctx_gpu.scratchpad = &scratchpad;
auto gpu_epilogue = transform.GetGpuEpilogue();

+DynamicScratchpad dyn_scratchpad_gpu(AccessOrder(ctx_gpu.gpu.stream));
+ctx_gpu.scratchpad = &dyn_scratchpad_gpu;

kernel_gpu.Run(ctx_gpu, out_, in_, k_win_, span<const int>{}, gpu_epilogue);

output_.invalidate_cpu();