Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU EP] Batch Norm Implementation #23525

Merged
merged 7 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/batch_norm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/nn/batch_norm.h"
#include "core/providers/cpu/nn/batch_norm_helper.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

// Registers a WebGPU BatchNormalization kernel for the versioned opset range
// [start, end] in the given domain. "T" is constrained to the float types the
// WebGPU EP supports.
#define WEBGPU_BATCH_NORM_VERSIONED_KERNEL(start, end, domain) \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX(                           \
      BatchNormalization,                                      \
      domain,                                                  \
      start,                                                   \
      end,                                                     \
      kWebGpuExecutionProvider,                                \
      (*KernelDefBuilder::Create())                            \
          .TypeConstraint("T", WebGpuSupportedFloatTypes()),   \
      BatchNormalization);

// Same as above but for a single (current) opset version rather than a range.
#define WEBGPU_BATCH_NORM_KERNEL(version, domain)            \
  ONNX_OPERATOR_KERNEL_EX(                                   \
      BatchNormalization,                                    \
      domain,                                                \
      version,                                               \
      kWebGpuExecutionProvider,                              \
      (*KernelDefBuilder::Create())                          \
          .TypeConstraint("T", WebGpuSupportedFloatTypes()), \
      BatchNormalization);

// Registrations for the standard ONNX domain.
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kOnnxDomain)
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kOnnxDomain)
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kOnnxDomain)
WEBGPU_BATCH_NORM_KERNEL(15, kOnnxDomain)

// Registrations for the com.ms.internal.nhwc (NHWC layout) domain.
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(7, 8, kMSInternalNHWCDomain)
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(9, 13, kMSInternalNHWCDomain)
WEBGPU_BATCH_NORM_VERSIONED_KERNEL(14, 14, kMSInternalNHWCDomain)
WEBGPU_BATCH_NORM_KERNEL(15, kMSInternalNHWCDomain)

// Emits the WGSL for inference-mode batch normalization:
//   value = (x - input_mean) * inverseSqrt(input_var + epsilon) * scale + B
// The per-channel parameters (scale, B, input_mean, input_var) are all read at
// a single offset "cOffset", whose derivation depends on the `spatial`
// attribute and the NCHW/NHWC format.
Status BatchNormalizationProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& input_tensor = shader.AddInput("input_tensor", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const ShaderVariableHelper& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const ShaderVariableHelper& input_mean = shader.AddInput("input_mean", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const ShaderVariableHelper& input_var = shader.AddInput("input_var", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);

  // idx is the scalar element offset of this invocation's first element;
  // global_idx counts components_-wide (vectorized) items.
  // NOTE(review): x below is read with global_idx while index math uses idx —
  // this assumes input/output are packed components_-wide; confirm against the
  // ShaderVariableHelper vectorization conventions.
  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << " let idx = global_idx * " << components_ << ";\n"
                            << " var outputIndices = " << output.OffsetToIndices("idx") << ";\n";
  if (spatial_) {
    // spatial == 1: statistics are per channel C.
    if (input_tensor.Rank() == 1) {
      // Rank-1 input: only one channel, offset is always 0.
      shader.MainFunctionBody() << " let cOffset = 0u;\n";
    } else {
      if (format_.compare("NHWC") == 0) {
        // NHWC: channel is the innermost dimension; divide by components_
        // because that axis is the vectorized one.
        shader.MainFunctionBody() << " let cOffset = outputIndices[" << input_tensor.Rank() - 1 << "] / " << components_ << ";\n";
      } else {
        // NCHW: channel is dimension 1.
        shader.MainFunctionBody() << " let cOffset = outputIndices[1];\n";
      }
    }
  } else {
    // spatial == 0: statistics cover C x D1 x ... x Dn.
    if (format_.compare("NCHW") == 0) {
      // Zero the batch index so flattening the remaining output indices
      // yields the C x D1 x ... x Dn offset directly.
      shader.MainFunctionBody() << " " << output.IndicesSet("outputIndices", "0", "0") << "\n"
                                << " let cOffset = " << output.IndicesToOffset("outputIndices") << ";\n";
    } else {
      // update C channel
      shader.MainFunctionBody() << " var cIndices = scale_indices_t(0);\n"
                                << " cIndices[0] = outputIndices[" << input_tensor.Rank() - 1 << "];\n";
      // update D1 x ... x Dn channels
      for (int i = 1; i < scale.Rank(); i++) {
        shader.MainFunctionBody() << " cIndices[" << i << "] = outputIndices[" << i << "];\n";
      }
      shader.MainFunctionBody() << " let cOffset = " << scale.IndicesToOffset("cIndices") << ";\n";
    }
  }

  // NOTE(review): epsilon_ is streamed with default ostream float formatting;
  // an integral-valued epsilon would print without a decimal point and parse
  // as an AbstractInt in WGSL — confirm the intended literal form.
  shader.MainFunctionBody() << " let scale = " << scale.GetByOffset("cOffset") << ";\n"
                            << " let B = " << B.GetByOffset("cOffset") << ";\n"
                            << " let input_mean = " << input_mean.GetByOffset("cOffset") << ";\n"
                            << " let input_var = " << input_var.GetByOffset("cOffset") << ";\n"
                            << " let x = " << input_tensor.GetByOffset("global_idx") << ";\n"
                            << " let value = (x - input_mean) * inverseSqrt(input_var + " << epsilon_ << ") * scale + B;\n"
                            << " " << output.SetByOffset("global_idx", "value") << "\n";

  return Status::OK();
}

// Runs inference-mode BatchNormalization on the WebGPU EP:
//   y = (x - input_mean) * inverseSqrt(input_var + epsilon) * scale + B
// Returns INVALID_ARGUMENT for training mode or a wrong input count.
Status BatchNormalization::ComputeInternal(ComputeContext& context) const {
  if (training_mode_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization trainingMode is not supported yet.");
  }

  if (context.InputCount() != 5) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BatchNormalization requires 5 inputs.");
  }

  const auto* input_tensor = context.Input<Tensor>(0);
  const auto* scale = context.Input<Tensor>(1);
  const auto* B = context.Input<Tensor>(2);
  const auto* input_mean = context.Input<Tensor>(3);
  const auto* input_var = context.Input<Tensor>(4);

  // Validate before any shape indexing: an invalid input (e.g. a scalar X)
  // would otherwise make input_shape[input_rank - 1] below index out of range.
  ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(input_tensor, scale, B, input_mean, input_var, spatial_ == 1, format_.compare("NHWC") == 0));

  const TensorShape& input_shape = input_tensor->Shape();
  size_t input_rank = input_shape.NumDimensions();
  // Vectorize along the innermost dimension when it is divisible by 4 or 2.
  const int components = spatial_ ? ((input_shape[input_rank - 1] % 4 == 0) ? 4 : ((input_shape[input_rank - 1] % 2 == 0) ? 2 : 1)) : 1;

  // Output shape is identical to the input shape.
  TensorShape output_shape(input_shape);
  auto* output_tensor = context.Output(0, output_shape);
  // Each shader invocation handles one `components`-wide item.
  int64_t output_size = output_tensor->Shape().Size() / static_cast<int64_t>(components);

  if (output_size == 0) {
    return Status::OK();  // empty tensor: nothing to dispatch
  }

  BatchNormalizationProgram program{epsilon_, spatial_, format_, static_cast<int64_t>(components)};
  program
      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
                  {scale, ProgramTensorMetadataDependency::TypeAndRank},
                  {B, ProgramTensorMetadataDependency::TypeAndRank},
                  {input_mean, ProgramTensorMetadataDependency::TypeAndRank},
                  {input_var, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutputs({output_tensor})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({{static_cast<uint32_t>(output_size)}});
  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace onnxruntime
53 changes: 53 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/batch_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Fixed Show fixed Hide fixed
// Licensed under the MIT License.

#pragma once

#include <string>
#include <utility>

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

// Shader program generating WGSL for inference-mode BatchNormalization:
//   y = (x - input_mean) * inverseSqrt(input_var + epsilon) * scale + B
class BatchNormalizationProgram final : public Program<BatchNormalizationProgram> {
 public:
  BatchNormalizationProgram(float epsilon, int64_t spatial, std::string format, int64_t components)
      : Program{"BatchNormalization"},
        epsilon_{epsilon},
        spatial_{spatial},
        format_{std::move(format)},  // sink parameter: move instead of copying
        components_{components} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});

 private:
  float epsilon_;       // added to the variance before the inverse sqrt
  int64_t spatial_;     // ONNX "spatial" attribute (1 = per-channel statistics)
  std::string format_;  // "NHWC" or "NCHW"
  int64_t components_;  // vectorization width along the innermost dimension
};

// WebGPU kernel for BatchNormalization (opsets 7-15, NCHW and NHWC domains).
// Only inference mode is implemented; training_mode is rejected at compute
// time (see ComputeInternal).
class BatchNormalization final : public WebGpuKernel {
 public:
  BatchNormalization(const OpKernelInfo& info) : WebGpuKernel(info) {
    epsilon_ = info.GetAttrOrDefault<float>("epsilon", 1e-5f);
    // momentum only affects training-mode running statistics; stored but
    // unused while training mode is unsupported.
    momentum_ = info.GetAttrOrDefault<float>("momentum", 0.9f);
    spatial_ = info.GetAttrOrDefault<int64_t>("spatial", 1);
    training_mode_ = info.GetAttrOrDefault<int64_t>("training_mode", 0);
    // NCHW for ai.onnx domain, NHWC for com.ms.internal.nhwc domain
    // NOTE(review): the default is "NHWC" although the comment above implies
    // NCHW for the ai.onnx domain — confirm the attribute is actually set for
    // kOnnxDomain nodes, otherwise they would fall back to NHWC indexing.
    format_ = info.GetAttrOrDefault<std::string>("format", "NHWC");
  }

  Status ComputeInternal(ComputeContext& context) const override;

 private:
  float epsilon_;         // variance stabilizer
  float momentum_;        // training-only attribute (unused in inference)
  int64_t spatial_;       // 1 = per-channel stats, 0 = C x D1 x ... x Dn stats
  int64_t training_mode_; // non-zero is rejected
  std::string format_;    // data layout: "NHWC" or "NCHW"
};

} // namespace webgpu
} // namespace onnxruntime
16 changes: 8 additions & 8 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -696,14 +696,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,
Expand Down
9 changes: 6 additions & 3 deletions onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,8 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
// TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
kWebGpuExecutionProvider});
}

TEST(BatchNormTest, ForwardTrainingTestOpset14) {
Expand Down Expand Up @@ -953,7 +954,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
kWebGpuExecutionProvider});
}

TEST(BatchNormTest, ForwardTrainingTestOpset15) {
Expand Down Expand Up @@ -982,7 +984,8 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) {
// Same exclusions as the opset 14 test
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider,
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
kWebGpuExecutionProvider});
}
#endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT

Expand Down
Loading