Skip to content

Commit

Permalink
Warp Perspective CPU Impl (#5829)
Browse files Browse the repository at this point in the history
Add CPU impl of warp perspective. Parameterize warp perspective tests with device type.

Signed-off-by: Bryce Ferenczi <[email protected]>
  • Loading branch information
5had3z authored Feb 26, 2025
1 parent 09c25d9 commit ac181b6
Show file tree
Hide file tree
Showing 2 changed files with 411 additions and 184 deletions.
323 changes: 266 additions & 57 deletions dali/operators/image/remap/cvcuda/warp_perspective.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <nvcv/Image.hpp>
#include <nvcv/ImageBatch.hpp>
#include <nvcv/Tensor.hpp>
#include <opencv2/imgproc.hpp>
#include <optional>
#include "dali/core/dev_buffer.h"
#include "dali/core/static_switch.h"
Expand All @@ -25,8 +26,8 @@
#include "dali/pipeline/operator/checkpointing/stateless_operator.h"
#include "dali/pipeline/operator/operator.h"

#include "dali/operators/nvcvop/nvcvop.h"
#include "dali/operators/image/remap/cvcuda/matrix_adjust.h"
#include "dali/operators/nvcvop/nvcvop.h"

namespace dali {

Expand All @@ -38,9 +39,8 @@ Performs a perspective transform on the images.
.NumInput(1, 2)
.InputDox(0, "input", "TensorList of uint8, uint16, int16 or float",
"Input data. Must be images in HWC or CHW layout, or a sequence of those.")
.InputDox(1, "matrix_gpu", "2D TensorList of float",
"3x3 Perspective transform matrix. Should be used to pass the GPU data. "
"For CPU data, the `matrix` argument should be used.")
.InputDox(1, "matrix", "2D TensorList of float",
"3x3 Perspective transform matrix for per sample homography, same device as input.")
.NumOutput(1)
.InputLayout(0, {"HW", "HWC", "FHWC", "CHW", "FCHW"})
.AddOptionalArg<float>("size",
Expand Down Expand Up @@ -73,7 +73,8 @@ analog to the ``WARP_INVERSE_MAP`` flag.
Determines the meaning of (0, 0) coordinates - "corner" places the origin at the top-left corner of
the top-left pixel (like in OpenGL); "center" places (0, 0) in the center of
the top-left pixel (like in OpenCV).))doc", "corner")
the top-left pixel (like in OpenCV).))doc",
"corner")
.AddOptionalArg<float>("fill_value",
"Value used to fill areas that are outside the source image when the "
"\"constant\" border_mode is chosen.",
Expand All @@ -82,7 +83,63 @@ the top-left pixel (like in OpenCV).))doc", "corner")
"If set to true (default), the matrix is interpreted as "
"destination to source coordinates mapping. "
"Otherwise it's interpreted as source to destination "
"coordinates mapping.", true);
"coordinates mapping.",
true);

/**
 * @brief Translates the `pixel_origin` argument into an OpenCV-compatibility flag.
 *
 * @param arg value of the `pixel_origin` argument; either "corner" or "center"
 * @return true for "center" (the OpenCV-like convention), false for "corner"
 *
 * Any other value is reported through DALI_FAIL.
 */
bool OCVCompatArg(std::string_view arg) {
  const bool is_center = (arg == "center");
  if (!is_center && arg != "corner") {
    DALI_FAIL(make_string("Invalid pixel_origin argument: ", arg));
  }
  return is_center;
}

/**
 * @brief Builds a 4-component border fill value from the `fill_value` argument.
 *
 * @tparam T             result type; `float4` or `cv::Scalar`
 * @param fill_value_arg values passed by the user as `fill_value`
 * @param channels       number of image channels, or a non-positive value when the data
 *                       has no channel dimension
 * @return - all zeros when no value was provided,
 *         - the single provided value broadcast to all 4 components,
 *         - otherwise the per-channel values, with the unused components set to zero.
 *
 * Multiple values are accepted only when their count matches `channels`; a mismatch,
 * data without a channel dimension, or more than 4 channels is reported through DALI_FAIL.
 */
template <typename T>
T GetFillValue(const std::vector<float> &fill_value_arg, int channels) {
  if (fill_value_arg.empty()) {
    return T{0, 0, 0, 0};
  }

  if (fill_value_arg.size() == 1) {
    const auto fv = fill_value_arg[0];
    return T{fv, fv, fv, fv};
  }

  // From here on there is more than one value, so it has to be a per-channel specification.
  if (channels <= 0) {
    DALI_FAIL("Only scalar fill_value can be provided when processing data in planar layout.");
  }
  if (channels != static_cast<int>(fill_value_arg.size())) {
    DALI_FAIL(make_string(
        "Number of values provided as a fill_value should match the number of channels.\n"
        "Number of channels: ",
        channels, ". Number of values provided: ", fill_value_arg.size(), "."));
  }
  // The destination holds exactly 4 components. This has to be a runtime check - an assert
  // is compiled out in release builds and the copy below would then overflow the buffer.
  if (channels > 4) {
    DALI_FAIL("Images with more than 4 channels are not supported.");
  }

  T fill_value{0, 0, 0, 0};
  if constexpr (std::is_same<T, float4>::value) {
    std::memcpy(&fill_value, fill_value_arg.data(), fill_value_arg.size() * sizeof(float));
  } else {
    static_assert(std::is_same<T, cv::Scalar>::value, "Unsupported fill value type.");
    std::copy(fill_value_arg.begin(), fill_value_arg.end(), fill_value.val);
  }
  return fill_value;
}

/**
 * @brief Checks the data types of the operator inputs for the given backend.
 *
 * Input 0 (the images) must be uint8, int16, uint16 or float. Input 1 (the transformation
 * matrix), when present, must be float. Violations are reported through DALI_ENFORCE.
 *
 * @tparam Backend backend used to access the workspace inputs
 * @param ws       workspace holding the inputs to validate
 */
template <typename Backend>
void ValidateTypes(const Workspace &ws) {
  const auto inp_type = ws.Input<Backend>(0).type();
  DALI_ENFORCE(inp_type == DALI_UINT8 || inp_type == DALI_INT16 || inp_type == DALI_UINT16 ||
                   inp_type == DALI_FLOAT,
               "The operator accepts the following input types: "
               "uint8, int16, uint16, float.");

  // The matrix input is optional - without it there is nothing more to verify.
  if (ws.NumInput() <= 1) {
    return;
  }

  const auto mat_type = ws.Input<Backend>(1).type();
  DALI_ENFORCE(mat_type == DALI_FLOAT,
               "Transformation matrix can be provided only as float32 values.");
}


class WarpPerspective : public nvcvop::NVCVSequenceOperator<StatelessOperator> {
Expand All @@ -101,56 +158,8 @@ class WarpPerspective : public nvcvop::NVCVSequenceOperator<StatelessOperator> {
return true;
}

/**
 * @brief Builds a 4-component border fill value from the `fill_value` argument.
 *
 * @param channels number of image channels, or a non-positive value when the data has no
 *                 channel dimension
 * @return - all zeros when no value was provided,
 *         - the single provided value broadcast to all 4 components,
 *         - otherwise the per-channel values, with the unused components set to zero.
 *
 * Multiple values are accepted only when their count matches `channels`; a mismatch,
 * data without a channel dimension, or more than 4 channels is reported through DALI_FAIL.
 */
float4 GetFillValue(int channels) const {
  if (fill_value_arg_.empty()) {
    return float4{0, 0, 0, 0};
  }

  if (fill_value_arg_.size() == 1) {
    const auto fv = fill_value_arg_[0];
    return float4{fv, fv, fv, fv};
  }

  // From here on there is more than one value, so it has to be a per-channel specification.
  if (channels <= 0) {
    DALI_FAIL("Only scalar fill_value can be provided when processing data in planar layout.");
  }
  if (channels != static_cast<int>(fill_value_arg_.size())) {
    DALI_FAIL(make_string(
        "Number of values provided as a fill_value should match the number of channels.\n"
        "Number of channels: ",
        channels, ". Number of values provided: ", fill_value_arg_.size(), "."));
  }
  // float4 holds exactly 4 components. This has to be a runtime check - an assert is
  // compiled out in release builds and the memcpy below would then overflow the buffer.
  if (channels > 4) {
    DALI_FAIL("Images with more than 4 channels are not supported.");
  }

  float4 fill_value{0, 0, 0, 0};
  memcpy(&fill_value, fill_value_arg_.data(), channels * sizeof(float));
  return fill_value;
}

/**
 * @brief Checks the data types of the GPU inputs of the operator.
 *
 * Input 0 (the images) must be uint8, int16, uint16 or float. Input 1 (the transformation
 * matrix), when present, must be float. Violations are reported through DALI_ENFORCE.
 *
 * @param ws workspace holding the inputs to validate
 */
void ValidateTypes(const Workspace &ws) const {
  const auto inp_type = ws.Input<GPUBackend>(0).type();
  DALI_ENFORCE(inp_type == DALI_UINT8 || inp_type == DALI_INT16 || inp_type == DALI_UINT16 ||
                   inp_type == DALI_FLOAT,
               "The operator accepts the following input types: "
               "uint8, int16, uint16, float.");

  // The matrix input is optional - without it there is nothing more to verify.
  if (ws.NumInput() <= 1) {
    return;
  }

  const auto mat_type = ws.Input<GPUBackend>(1).type();
  DALI_ENFORCE(mat_type == DALI_FLOAT,
               "Transformation matrix can be provided only as float32 values.");
}

/**
 * @brief Translates the `pixel_origin` argument into an OpenCV-compatibility flag.
 *
 * @param arg value of the `pixel_origin` argument; either "corner" or "center"
 * @return true for "center" (the OpenCV-like convention), false for "corner"
 *
 * Any other value is reported through DALI_FAIL.
 */
bool OCVCompatArg(const std::string &arg) {
  const bool is_center = (arg == "center");
  if (!is_center && arg != "corner") {
    DALI_FAIL(make_string("Invalid pixel_origin argument: ", arg));
  }
  return is_center;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
ValidateTypes(ws);
ValidateTypes<GPUBackend>(ws);
const auto &input = ws.Input<GPUBackend>(0);
auto input_shape = input.shape();
auto input_layout = input.GetLayout();
Expand All @@ -160,7 +169,7 @@ class WarpPerspective : public nvcvop::NVCVSequenceOperator<StatelessOperator> {
int channels = (input_layout.find('C') != -1) ? input_shape[0][input_layout.find('C')] : -1;
if (channels > 4)
DALI_FAIL("Images with more than 4 channels are not supported.");
fill_value_ = GetFillValue(channels);
fill_value_ = GetFillValue<float4>(fill_value_arg_, channels);
if (size_arg_.HasExplicitValue()) {
size_arg_.Acquire(spec_, ws, input_shape.size(), TensorShape<1>(2));
for (int i = 0; i < input_shape.size(); i++) {
Expand Down Expand Up @@ -189,7 +198,7 @@ class WarpPerspective : public nvcvop::NVCVSequenceOperator<StatelessOperator> {
"Matrix input and `matrix` argument should not be provided at the same time.");
auto &matrix_input = ws.Input<GPUBackend>(1);
DALI_ENFORCE(matrix_input.shape() ==
uniform_list_shape(matrix_input.num_samples(), TensorShape<2>(3, 3)),
uniform_list_shape(matrix_input.num_samples(), TensorShape<2>(3, 3)),
make_string("Expected a uniform list of 3x3 matrices. "
"Instead got data with shape: ",
matrix_input.shape()));
Expand Down Expand Up @@ -236,4 +245,204 @@ class WarpPerspective : public nvcvop::NVCVSequenceOperator<StatelessOperator> {

// Registers the WarpPerspective class as the GPU implementation of the
// experimental__WarpPerspective operator.
DALI_REGISTER_OPERATOR(experimental__WarpPerspective, WarpPerspective, GPU);


/**
 * @brief CPU implementation of the perspective warp operator, built on cv::warpPerspective.
 *
 * Each sample is wrapped (without copying) in a cv::Mat and warped with a per-sample 3x3
 * matrix, taken either from the optional second input or from the `matrix` argument.
 * Samples are processed in parallel on the workspace thread pool.
 */
class WarpPerspectiveCPU : public SequenceOperator<CPUBackend, StatelessOperator> {
 public:
  explicit WarpPerspectiveCPU(const OpSpec &spec)
      : SequenceOperator(spec),
        border_mode_(GetBorderMode(spec.GetArgument<std::string>("border_mode"))),
        interp_type_(GetInterpolationType(spec.GetArgument<DALIInterpType>("interp_type"))),
        fill_value_arg_(spec.GetArgument<std::vector<float>>("fill_value")),
        inverse_map_(spec.GetArgument<bool>("inverse_map")),
        ocv_pixel_(OCVCompatArg(spec.GetArgument<std::string>("pixel_origin"))) {}

 private:
  /**
   * @brief Maps the `border_mode` argument to the corresponding OpenCV border type.
   *
   * Unknown names are reported through DALI_FAIL.
   */
  static cv::BorderTypes GetBorderMode(std::string_view border_mode) {
    if (border_mode == "constant") {
      return cv::BorderTypes::BORDER_CONSTANT;
    } else if (border_mode == "replicate") {
      return cv::BorderTypes::BORDER_REPLICATE;
    } else if (border_mode == "reflect") {
      return cv::BorderTypes::BORDER_REFLECT;
    } else if (border_mode == "reflect_101") {
      return cv::BorderTypes::BORDER_REFLECT_101;
    } else if (border_mode == "wrap") {
      return cv::BorderTypes::BORDER_WRAP;
    } else {
      DALI_FAIL(make_string("Invalid border_mode argument: ", border_mode));
    }
  }

  /**
   * @brief Maps a DALI interpolation type to the corresponding OpenCV interpolation flag.
   *
   * Only nearest-neighbour, linear and cubic interpolation are handled; anything else is
   * reported through DALI_FAIL.
   */
  static cv::InterpolationFlags GetInterpolationType(DALIInterpType interpolation_type) {
    switch (interpolation_type) {
      case DALIInterpType::DALI_INTERP_NN:
        return cv::InterpolationFlags::INTER_NEAREST;
      case DALIInterpType::DALI_INTERP_LINEAR:
        return cv::InterpolationFlags::INTER_LINEAR;
      case DALIInterpType::DALI_INTERP_CUBIC:
        return cv::InterpolationFlags::INTER_CUBIC;
      default:
        DALI_FAIL(
            make_string("Unknown interpolation type: ", static_cast<int>(interpolation_type)));
    }
  }

  /**
   * @brief Validates the inputs and describes the output.
   *
   * The output keeps the input type. Its shape equals the input shape unless the `size`
   * argument is given, in which case height and width are replaced by the (rounded, at
   * least 1) requested values while the channel count of each sample is preserved.
   *
   * The channel dimension, if present, must be the last one and must not exceed 4 in any
   * sample. The length of a multi-valued `fill_value` is validated against the channel
   * count of the first sample.
   */
  bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
    ValidateTypes<CPUBackend>(ws);
    const auto &input = ws.Input<CPUBackend>(0);
    auto input_shape = input.shape();
    auto input_layout = input.GetLayout();
    output_desc.resize(1);

    auto output_shape = input_shape;
    const auto chIdx = input_layout.find('C');
    if (chIdx == -1 && ws.GetInputDim(0) > 2) {
      DALI_FAIL("Layout not specified and number of dims > 2, can't determine channel count.");
    } else if (chIdx != -1 && chIdx != input_layout.size() - 1) {
      DALI_FAIL("Channel dimension must be the last one.");
    }

    const int num_samples = input_shape.num_samples();

    // Channel count of the first sample; -1 when there is no channel dimension (or no data).
    int channels = -1;
    if (chIdx != -1) {
      // Every sample is checked - the channel count may differ between samples.
      for (int i = 0; i < num_samples; i++) {
        if (input_shape[i][chIdx] > 4) {
          DALI_FAIL("Images with more than 4 channels are not supported.");
        }
      }
      if (num_samples > 0) {
        channels = input_shape[0][chIdx];
      }
    }

    fill_value_ = GetFillValue<cv::Scalar>(fill_value_arg_, channels);
    if (size_arg_.HasExplicitValue()) {
      size_arg_.Acquire(spec_, ws, num_samples, TensorShape<1>(2));
      for (int i = 0; i < num_samples; i++) {
        auto height = std::max<int>(std::roundf(size_arg_[i].data[0]), 1);
        auto width = std::max<int>(std::roundf(size_arg_[i].data[1]), 1);
        if (chIdx != -1) {
          // Keep the channel count of this particular sample, not the one of sample 0.
          const int sample_channels = input_shape[i][chIdx];
          output_shape.set_tensor_shape(i, TensorShape<>({height, width, sample_channels}));
        } else {
          output_shape.set_tensor_shape(i, TensorShape<>({height, width}));
        }
      }
    }

    output_desc[0] = {output_shape, input.type()};
    return true;
  }

  bool ShouldExpandChannels(int input_idx) const override {
    return true;
  }

  /**
   * @brief Converts OpenGL perspective warp format to OpenCV format
   *
   * The matrix is conjugated with a half-pixel translation: coordinates are shifted by
   * +0.5 before the transform is applied and by -0.5 afterwards.
   *
   * @param matrix 3x3 matrix to convert inplace
   */
  static void ConvertOpenGLtoOpenCVFormat(cv::Matx33f &matrix) {
    // clang-format off
    const cv::Matx33f shift = {
      1, 0, 0.5f,
      0, 1, 0.5f,
      0, 0, 1,
    };
    const cv::Matx33f shiftBack = {
      1, 0, -0.5f,
      0, 1, -0.5f,
      0, 0, 1,
    };
    // clang-format on
    matrix = shiftBack * (matrix * shift);
  }

  /**
   * @brief Convert DALI data type to OpenCV matrix type
   *
   * Supports uint8, int16, uint16 and float; other types are reported through DALI_FAIL.
   */
  static int matTypeFromDALI(DALIDataType dtype) {
    switch (dtype) {
      case DALI_UINT8:
        return CV_8U;
      case DALI_INT16:
        return CV_16S;
      case DALI_UINT16:
        return CV_16U;
      case DALI_FLOAT:
        return CV_32F;
      default:
        DALI_FAIL("Unsupported input type");
    }
  }

  /**
   * @brief Construct a full OpenCV matrix type from DALI data type and number of channels
   */
  static int fullMatTypeFromDALI(DALIDataType dtype, int channels) {
    return CV_MAKETYPE(matTypeFromDALI(dtype), channels);
  }

  /**
   * @brief Warps every sample of the batch with its own 3x3 matrix.
   *
   * The matrices come from input 1 when it is present (providing the `matrix` argument at
   * the same time is an error), otherwise from the `matrix` argument. With the "corner"
   * pixel origin the matrices are first converted to the OpenCV convention.
   */
  void RunImpl(Workspace &ws) override {
    const auto &input = ws.Input<CPUBackend>(0);
    auto &output = ws.Output<CPUBackend>(0);
    output.SetLayout(input.GetLayout());

    const int num_samples = ws.GetInputBatchSize(0);
    std::vector<cv::Matx33f> matrices(num_samples);
    if (ws.NumInput() > 1) {
      DALI_ENFORCE(!matrix_arg_.HasExplicitValue(),
                   "Matrix input and `matrix` argument should not be provided at the same time.");
      auto &matrix_input = ws.Input<CPUBackend>(1);
      DALI_ENFORCE(matrix_input.shape() == uniform_list_shape(num_samples, TensorShape<2>(3, 3)),
                   make_string("Expected a uniform list of 3x3 matrices. "
                               "Instead got data with shape: ",
                               matrix_input.shape()));

      for (int i = 0; i < num_samples; i++) {
        std::memcpy(matrices[i].val, matrix_input.raw_tensor(i), sizeof(cv::Matx33f));
      }
    } else {
      matrix_arg_.Acquire(spec_, ws, num_samples, TensorShape<2>(3, 3));
      for (int i = 0; i < num_samples; ++i) {
        std::memcpy(matrices[i].val, matrix_arg_[i].data, sizeof(cv::Matx33f));
      }
    }
    if (!ocv_pixel_) {
      for (auto &matrix : matrices) {
        ConvertOpenGLtoOpenCVFormat(matrix);
      }
    }

    // Index of the channel dimension, -1 if the data has none (single-channel images).
    const int chIdx = input.GetLayout().find('C');

    auto &tPool = ws.GetThreadPool();
    int warpFlags = interp_type_;
    if (inverse_map_) {
      warpFlags |= cv::WARP_INVERSE_MAP;
    }
    for (int i = 0; i < num_samples; ++i) {
      // Locals are captured by reference - they outlive the work, which completes in RunAll().
      tPool.AddWork([&, i](int) {
        const auto inImage = input[i];
        auto outImage = output[i];

        const auto &inShape = inImage.shape();
        // The channel count is read per sample, so that the cv::Mat element type always
        // matches the actual memory layout of this sample.
        const int channels = (chIdx != -1) ? static_cast<int>(inShape[chIdx]) : 1;
        const int dtype = fullMatTypeFromDALI(inImage.type(), channels);

        // cv::Mat only wraps the DALI buffers here; the const_cast is needed because the
        // constructor takes a non-const pointer, the input is not modified.
        const cv::Mat inMat(static_cast<int>(inShape[0]), static_cast<int>(inShape[1]), dtype,
                            const_cast<void *>(inImage.raw_data()));

        const auto &outShape = outImage.shape();
        cv::Mat outMat(static_cast<int>(outShape[0]), static_cast<int>(outShape[1]), dtype,
                       outImage.raw_mutable_data());

        cv::warpPerspective(inMat, outMat, matrices[i], cv::Size(outMat.cols, outMat.rows),
                            warpFlags, border_mode_, fill_value_);
      });
    }
    tPool.RunAll();
  }

  USE_OPERATOR_MEMBERS();
  ArgValue<float, 2> matrix_arg_{"matrix", spec_};
  ArgValue<float, 1> size_arg_{"size", spec_};
  cv::BorderTypes border_mode_ = cv::BorderTypes::BORDER_CONSTANT;
  cv::InterpolationFlags interp_type_ = cv::InterpolationFlags::INTER_LINEAR;
  std::vector<float> fill_value_arg_{0, 0, 0, 0};
  cv::Scalar fill_value_{};
  bool inverse_map_ = false;
  bool ocv_pixel_ = true;
};

// Registers the WarpPerspectiveCPU class as the CPU implementation of the
// experimental__WarpPerspective operator.
DALI_REGISTER_OPERATOR(experimental__WarpPerspective, WarpPerspectiveCPU, CPU);

} // namespace dali
Loading

0 comments on commit ac181b6

Please sign in to comment.