Reshape output tensor for average pooling 2d
PiperOrigin-RevId: 598397636
alankelly authored and xnnpack-bot committed Jan 14, 2024
1 parent edd71bc commit 53695fe
Showing 4 changed files with 210 additions and 11 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -1523,6 +1523,11 @@ IF(XNNPACK_BUILD_TESTS)
TARGET_LINK_LIBRARIES(average-pooling-2d-test PRIVATE XNNPACK fp16 GTest::gtest GTest::gtest_main subgraph)
ADD_TEST(NAME average-pooling-2d-test COMMAND average-pooling-2d-test)

ADD_EXECUTABLE(average-pooling-2d-reshape-test test/average-pooling-2d-reshape.cc)
TARGET_INCLUDE_DIRECTORIES(average-pooling-2d-reshape-test PRIVATE src test)
TARGET_LINK_LIBRARIES(average-pooling-2d-reshape-test PRIVATE XNNPACK fp16 GTest::gtest GTest::gtest_main subgraph)
ADD_TEST(NAME average-pooling-2d-reshape-test COMMAND average-pooling-2d-reshape-test)

ADD_EXECUTABLE(bankers-rounding-test test/bankers-rounding.cc)
TARGET_INCLUDE_DIRECTORIES(bankers-rounding-test PRIVATE src test)
TARGET_LINK_LIBRARIES(bankers-rounding-test PRIVATE XNNPACK fp16 GTest::gtest GTest::gtest_main subgraph)
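Since the new binary is registered with ADD_TEST, it can presumably be run on its own from a test-enabled build (XNNPACK_BUILD_TESTS=ON, as in the surrounding block) with:

  ctest -R average-pooling-2d-reshape-test
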
48 changes: 37 additions & 11 deletions src/subgraph/average-pooling-2d.c
@@ -73,40 +73,66 @@ static enum xnn_status reshape_average_pooling_operator(
 {
   const uint32_t input_id = opdata->inputs[0];
   assert(input_id < num_values);
-  const size_t batch_size = values[input_id].shape.dim[0];
-  const size_t input_height = values[input_id].shape.dim[1];
-  const size_t input_width = values[input_id].shape.dim[2];
-  const size_t channel_dim = values[input_id].shape.dim[3];
-  assert(channel_dim == values[opdata->outputs[0]].shape.dim[3]);
 
+  const uint32_t output_id = opdata->outputs[0];
+  assert(output_id < num_values);
+
+  const struct xnn_value* input_value = values + input_id;
+  struct xnn_value* output_value = values + output_id;
+
+  const size_t batch_size = input_value->shape.dim[0];
+  const size_t input_height = input_value->shape.dim[1];
+  const size_t input_width = input_value->shape.dim[2];
+  const size_t channel_dim = input_value->shape.dim[3];
+
+  enum xnn_status status = xnn_status_invalid_state;
+  const size_t old_workspace_size = opdata->workspace_size;
+  size_t output_height, output_width;
   switch (opdata->operator_objects[0]->type) {
     case xnn_operator_type_average_pooling_nhwc_f16:
-      return xnn_reshape_average_pooling2d_nhwc_f16(
+      status = xnn_reshape_average_pooling2d_nhwc_f16(
         opdata->operator_objects[0],
         batch_size,
         input_height,
         input_width,
         /*channels=*/channel_dim, /*input_pixel_stride=*/channel_dim, /*output_pixel_stride=*/channel_dim,
         &opdata->workspace_size,
         &opdata->workspace_alignment,
-        /*output_height_out=*/NULL,
-        /*output_width_out=*/NULL,
+        &output_height,
+        &output_width,
         threadpool);
+      break;
     case xnn_operator_type_average_pooling_nhwc_f32:
-      return xnn_reshape_average_pooling2d_nhwc_f32(
+      status = xnn_reshape_average_pooling2d_nhwc_f32(
         opdata->operator_objects[0],
         batch_size,
         input_height,
         input_width,
         /*channels=*/channel_dim, /*input_pixel_stride=*/channel_dim, /*output_pixel_stride=*/channel_dim,
         &opdata->workspace_size,
         &opdata->workspace_alignment,
-        /*output_height_out=*/NULL,
-        /*output_width_out=*/NULL,
+        &output_height,
+        &output_width,
         threadpool);
+      break;
     default:
       XNN_UNREACHABLE;
   }
+  if (status != xnn_status_success) {
+    return status;
+  }
+
+  output_value->shape.dim[0] = batch_size;
+  output_value->shape.dim[1] = output_height;
+  output_value->shape.dim[2] = output_width;
+  output_value->shape.dim[3] = channel_dim;
+
+  const size_t new_size = xnn_tensor_get_size(output_value);
+  if (new_size > output_value->size || old_workspace_size > opdata->workspace_size) {
+    output_value->size = new_size;
+    return xnn_status_reallocation_required;
+  }
+  return xnn_status_success;
 }
 
 static enum xnn_status setup_average_pooling_operator(
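For reference, the reshape callback above now receives the pooled spatial extent through output_height and output_width and writes it into the output value's shape before deciding whether reallocation is required. A minimal sketch of the expected output-size arithmetic, not part of this diff and using a hypothetical helper name, consistent with the assertions in the new tests below:

// Hypothetical helper (illustration only): number of valid pooling-window
// positions over the padded input, stepped by the stride.
static size_t pooling_output_dim(
  size_t input_dim, size_t padding_before, size_t padding_after,
  size_t pooling_dim, size_t stride)
{
  return (input_dim + padding_before + padding_after - pooling_dim) / stride + 1;
}

Both new tests assert output shapes consistent with this formula.
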
16 changes: 16 additions & 0 deletions test/BUILD.bazel
@@ -3364,6 +3364,22 @@ xnnpack_unit_test(
],
)

xnnpack_unit_test(
name = "average_pooling_2d_reshape_test",
srcs = [
"average-pooling-2d-reshape.cc",
],
deps = [
"//:XNNPACK_test_mode",
"//:aligned_allocator",
"//:common",
"//:node_type",
"//:operator_utils",
"//:operators_test_mode",
"//:subgraph_test_mode",
],
)

xnnpack_unit_test(
name = "bankers_rounding_test",
srcs = [
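The Bazel registration mirrors the CMake one; with the usual workspace layout the target can presumably be run as:

  bazel test //test:average_pooling_2d_reshape_test
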
152 changes: 152 additions & 0 deletions test/average-pooling-2d-reshape.cc
@@ -0,0 +1,152 @@
// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <gtest/gtest.h>

#include <xnnpack.h>
#include <xnnpack/aligned-allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/node-type.h>
#include <xnnpack/operator-utils.h>
#include <xnnpack/operator.h>
#include <xnnpack/subgraph.h>

TEST(AveragePooling2DTestF32, Reshape)
{
ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

xnn_subgraph_t subgraph = nullptr;
ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph));
std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

std::vector<size_t> dims{2, 3, 4, 5};
uint32_t input_id = XNN_INVALID_NODE_ID;
ASSERT_EQ(
xnn_status_success, xnn_define_tensor_value(
subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, 0,
/*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

uint32_t output_id = XNN_INVALID_NODE_ID;
ASSERT_EQ(
xnn_status_success, xnn_define_tensor_value(
subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, 1,
/*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

const size_t pooling_height = 2;
const size_t pooling_width = 2;
const size_t stride_height = 2;
const size_t stride_width = 2;
const float output_min = -std::numeric_limits<float>::infinity();
const float output_max = std::numeric_limits<float>::infinity();
ASSERT_EQ(xnn_status_success, xnn_define_average_pooling_2d(
subgraph, /*input_padding_top=*/0, /*input_padding_right=*/0, /*input_padding_bottom=*/0, /*input_padding_left=*/0, pooling_height,
pooling_width, stride_height, stride_width, output_min, output_max, input_id, output_id,
/*flags=*/0));

ASSERT_EQ(subgraph->num_nodes, 1);
struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_average_pooling_2d);
ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
ASSERT_EQ(node->outputs[0], output_id);
ASSERT_EQ(node->flags, 0);

xnn_runtime_t runtime = nullptr;
ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
ASSERT_NE(nullptr, runtime);
std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);

ASSERT_EQ(node->reshape(&runtime->opdata[0], subgraph->values, subgraph->num_values, /*threadpool=*/nullptr), xnn_status_success);

dims[0] = 7;
dims[3] = 9;
ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, 0, dims.size(), dims.data()));

ASSERT_EQ(node->reshape(&runtime->opdata[0], runtime->values, runtime->num_values, /*threadpool=*/nullptr), xnn_status_reallocation_required);
const xnn_shape* output_shape = &runtime->values[node->outputs[0]].shape;
ASSERT_EQ(output_shape->dim[0], dims[0]);
ASSERT_EQ(output_shape->dim[1], dims[1] - 2);
ASSERT_EQ(output_shape->dim[2], dims[2] - 2);
ASSERT_EQ(output_shape->dim[3], dims[3]);
}

TEST(AveragePooling2DTestF32, ReshapeWithPadding)
{
ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

xnn_subgraph_t subgraph = nullptr;
ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph));
std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

std::vector<size_t> dims{2, 3, 4, 5};
std::vector<size_t> output_dims{2, 3, 5, 5};
uint32_t input_id = XNN_INVALID_NODE_ID;
ASSERT_EQ(
xnn_status_success, xnn_define_tensor_value(
subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, 0,
/*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

uint32_t output_id = XNN_INVALID_NODE_ID;
ASSERT_EQ(
xnn_status_success, xnn_define_tensor_value(
subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, 1,
/*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

const size_t pooling_height = 2;
const size_t pooling_width = 2;
const size_t stride_height = 2;
const size_t stride_width = 2;
const float output_min = -std::numeric_limits<float>::infinity();
const float output_max = std::numeric_limits<float>::infinity();
ASSERT_EQ(xnn_status_success, xnn_define_average_pooling_2d(
subgraph, /*input_padding_top=*/3, /*input_padding_right=*/2, /*input_padding_bottom=*/1, /*input_padding_left=*/4, pooling_height,
pooling_width, stride_height, stride_width, output_min, output_max, input_id, output_id,
/*flags=*/0));

ASSERT_EQ(subgraph->num_nodes, 1);
struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_average_pooling_2d);
ASSERT_EQ(node->compute_type, xnn_compute_type_fp32);
ASSERT_EQ(node->num_inputs, 1);
ASSERT_EQ(node->inputs[0], input_id);
ASSERT_EQ(node->num_outputs, 1);
ASSERT_EQ(node->outputs[0], output_id);
ASSERT_EQ(node->flags, 0);

xnn_runtime_t runtime = nullptr;
ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
ASSERT_NE(nullptr, runtime);
std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);

ASSERT_EQ(node->reshape(&runtime->opdata[0], subgraph->values, subgraph->num_values, /*threadpool=*/nullptr), xnn_status_success);

dims[0] = 2;
dims[1] = 2;
dims[2] = 8;
dims[3] = 17;
ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, 0, dims.size(), dims.data()));

ASSERT_EQ(node->reshape(&runtime->opdata[0], runtime->values, runtime->num_values, /*threadpool=*/nullptr), xnn_status_reallocation_required);
const xnn_shape* output_shape = &runtime->values[node->outputs[0]].shape;
ASSERT_EQ(output_shape->dim[0], dims[0]);
ASSERT_EQ(output_shape->dim[1], 3);
ASSERT_EQ(output_shape->dim[2], 7);
ASSERT_EQ(output_shape->dim[3], dims[3]);
}
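
For reference, the asserted output shapes above follow from the pooling arithmetic sketched earlier: in Reshape, the 7x3x4x9 input with a 2x2 window, stride 2, and no padding yields a spatial extent of (3 - 2)/2 + 1 = 1 by (4 - 2)/2 + 1 = 2, which is what the dims[1] - 2 and dims[2] - 2 checks encode; in ReshapeWithPadding, the 2x2x8x17 input is padded to 6x14, giving (6 - 2)/2 + 1 = 3 by (14 - 2)/2 + 1 = 7. The batch and channel dimensions pass through unchanged.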
