Add mean.dtype_out op for internal model
Differential Revision: D67453766

Pull Request resolved: #7404
davidlin54 authored Jan 10, 2025
1 parent 73f50df commit 07d6f24
Showing 5 changed files with 90 additions and 1 deletion.
2 changes: 2 additions & 0 deletions kernels/aten/functions.yaml
@@ -257,6 +257,8 @@

- op: mean.out

- op: mean.dtype_out

- op: min.dim_min

- op: min.unary_out
8 changes: 8 additions & 0 deletions kernels/portable/cpu/op_mean.cpp
@@ -66,6 +66,14 @@ Tensor& mean_dim_out(
  return out;
}

Tensor& mean_dtype_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    optional<ScalarType> dtype,
    Tensor& out) {
  return mean_dim_out(ctx, in, ArrayRef<int64_t>(), false, dtype, out);
}

} // namespace native
} // namespace executor
} // namespace torch
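
For context, a minimal standalone sketch of what the forwarding above amounts to (illustrative only, not ExecuTorch code; the helper name mean_all is an assumption): reducing over an empty dim list with keepdim=false averages every element of the input into a single scalar, which can then be cast to the requested output dtype.

    #include <vector>

    // Illustrative sketch only: an empty dim list means "reduce everything",
    // so the mean of all elements is computed and cast to OUT_T.
    template <typename IN_T, typename OUT_T>
    OUT_T mean_all(const std::vector<IN_T>& in) {
      double acc = 0.0;
      for (const IN_T v : in) {
        acc += static_cast<double>(v);
      }
      return static_cast<OUT_T>(acc / static_cast<double>(in.size()));
    }

    int main() {
      // Mirrors the 10x10 tensor of ones used in the tests below: mean == 1.0.
      std::vector<float> x(100, 1.0f);
      return mean_all<float, float>(x) == 1.0f ? 0 : 1;
    }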
1 change: 1 addition & 0 deletions kernels/portable/cpu/util/reduce_util.cpp
@@ -386,6 +386,7 @@ bool check_mean_dim_args(
      check_reduction_args(in, dim_list, keepdim, dtype, out));

  if (dtype) {
    ET_LOG(Info, "dtype is %hhd", static_cast<int8_t>(dtype.value()));
    ET_LOG_AND_RETURN_IF_FALSE(torch::executor::isFloatingType(dtype.value()));
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
  } else {
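
A minimal sketch of the rule this check enforces (simplified, assumed types; not the real ExecuTorch helper): when a dtype is supplied it must be a floating-point type and must match the out tensor's dtype, which is why the Float-to-Bool test further down expects a kernel failure.

    #include <cstdio>

    // Toy enum standing in for the real ScalarType (assumption, for illustration).
    enum class ToyScalarType { Bool, Int, Float, Double };

    bool is_floating(ToyScalarType t) {
      return t == ToyScalarType::Float || t == ToyScalarType::Double;
    }

    // A supplied dtype must be floating-point and equal to the out dtype.
    bool check_mean_dtype(ToyScalarType dtype, ToyScalarType out_type) {
      return is_floating(dtype) && dtype == out_type;
    }

    int main() {
      std::printf("%d\n", check_mean_dtype(ToyScalarType::Float, ToyScalarType::Float)); // 1: accepted
      std::printf("%d\n", check_mean_dtype(ToyScalarType::Bool, ToyScalarType::Float));  // 0: rejected
      return 0;
    }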
5 changes: 5 additions & 0 deletions kernels/portable/functions.yaml
@@ -577,6 +577,11 @@
    - arg_meta: null
      kernel_name: torch::executor::mean_dim_out

- op: mean.dtype_out
  kernels:
    - arg_meta: null
      kernel_name: torch::executor::mean_dtype_out

- op: min.dim_min
  kernels:
    - arg_meta: null
75 changes: 74 additions & 1 deletion kernels/test/op_mean_test.cpp
@@ -9,7 +9,7 @@
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -22,6 +22,7 @@ using exec_aten::ArrayRef;
using exec_aten::optional;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::Error;
using torch::executor::testing::TensorFactory;

class OpMeanOutTest : public OperatorTest {
@@ -36,6 +37,13 @@ class OpMeanOutTest : public OperatorTest {
        context_, self, dim, keepdim, dtype, out);
  }

  Tensor& op_mean_dtype_out(
      const Tensor& self,
      optional<ScalarType> dtype,
      Tensor& out) {
    return torch::executor::aten::mean_outf(context_, self, dtype, out);
  }

  template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
  void test_mean_dim_out_invalid_dimensions() {
    TensorFactory<IN_DTYPE> tf_in;
@@ -466,3 +474,68 @@ TEST_F(OpMeanOutTest, DynamicShapeUnbound) {
  op_mean_out(x, ArrayRef<int64_t>{1}, false, ScalarType::Float, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}

TEST_F(OpMeanOutTest, DTypeOutFloatValid) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {10, 10},
      {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
  Tensor expected_result = tf.make({}, {1.0});

  Tensor out = tf.zeros({});
  Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}

TEST_F(OpMeanOutTest, DTypeOutFloatToBoolInvalid) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {10, 10},
      {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
  Tensor expected_result = tf.make({}, {1.0});

  Tensor out = tf.zeros({});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_mean_dtype_out(x, ScalarType::Bool, out));
}

TEST_F(OpMeanOutTest, DTypeOutFloatInfinity) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make({2, 1}, {INFINITY, INFINITY});
  Tensor expected_result = tf.make({}, {INFINITY});

  Tensor out = tf.zeros({});

  Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}

TEST_F(OpMeanOutTest, DTypeOutFloatNAN) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make({2, 1}, {NAN, INFINITY});
  Tensor expected_result = tf.make({}, {NAN});

  Tensor out = tf.zeros({});

  Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
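
The last two tests lean on ordinary IEEE-754 float arithmetic rather than anything specific to this kernel; a tiny sketch (illustrative only) of the behavior they expect:

    #include <cmath>
    #include <cstdio>

    int main() {
      // Averaging two infinities stays infinite; a NaN anywhere in the sum
      // makes the mean NaN. These match the expected_result tensors above.
      float inf_mean = (INFINITY + INFINITY) / 2.0f;
      float nan_mean = (NAN + INFINITY) / 2.0f;
      std::printf("%d %d\n", std::isinf(inf_mean), std::isnan(nan_mean)); // prints "1 1"
      return 0;
    }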
