Skip to content

Commit

Permalink
[WebNN EP] Allow 0D input/output for Reshape and Expand (microsoft#22344
Browse files Browse the repository at this point in the history
)

- Allows Expand input be a scalar
- Allows Reshape input be a scalar
- Allows Reshape to a scalar

Fixed microsoft#22215

---------

Co-authored-by: Dwayne Robinson <[email protected]>
  • Loading branch information
Honry and fdwr authored Oct 24, 2024
1 parent 70be2eb commit cd60af0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return false;
}

if (input_shape.empty()) {
LOGS(logger, VERBOSE) << "Expand does not support empty input's shape.";
return false;
}

std::vector<int64_t> output_shape;
if (!GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape)) {
LOGS(logger, VERBOSE) << "The input cannot expand to shape " << GetShapeString(new_shape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,25 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto& input_defs = node.InputDefs();
const auto& initializers(model_builder.GetInitializerTensors());
const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name());
const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty()
? reinterpret_cast<const int64_t*>(target_shape_tensor.raw_data().data())
: target_shape_tensor.int64_data().data();
const auto& target_shape_tensor_dims = target_shape_tensor.dims();
std::vector<uint32_t> new_shape;
// Do nothing if target shape is an empty shape, which means converting to a scalar.
if (!target_shape_tensor_dims.empty()) {
const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty()
? reinterpret_cast<const int64_t*>(target_shape_tensor.raw_data().data())
: target_shape_tensor.int64_data().data();

const auto size = target_shape_tensor_dims[0];
TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size};
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
ReshapeHelper helper(TensorShape(input_shape), target_shape);
std::transform(target_shape.cbegin(), target_shape.cend(),
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
}

const auto size = target_shape_tensor.dims()[0];
TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size};
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
ReshapeHelper helper(TensorShape(input_shape), target_shape);
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
std::vector<int32_t> new_shape;
std::transform(target_shape.cbegin(), target_shape.cend(),
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<int32_t>(dim); });

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("reshape",
Expand All @@ -76,6 +80,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

const auto& perm_name = input_defs[1]->Name();
if (!Contains(initializers, perm_name)) {
LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer";
Expand All @@ -92,24 +101,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer

const int64_t* raw_new_shape = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
const auto& perm_dims = perm_tensor.dims();
if (perm_dims.empty() || perm_dims[0] == 0) {
LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty";
return false;
}

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

if (input_shape.empty()) {
LOGS(logger, VERBOSE) << "Reshape does not support empty input shape";
return false;
}

// WebNN reshape does not support 0 as dimension.
NodeAttrHelper helper(node);
const bool allow_zero = helper.Get("allowzero ", 0) == 1;
if (allow_zero) {
const bool allow_zero = helper.Get("allowzero", 0) == 1;
if (allow_zero && !perm_dims.empty()) {
for (int64_t i = 0; i < perm_dims[0]; i++) {
if (raw_new_shape[i] == 0) {
LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled";
Expand Down

0 comments on commit cd60af0

Please sign in to comment.