diff --git a/neuron/APUWareUtilsApi.h b/neuron/APUWareUtilsApi.h
index aa65709..904c408 100644
--- a/neuron/APUWareUtilsApi.h
+++ b/neuron/APUWareUtilsApi.h
@@ -1,30 +1,29 @@
 /*
-* Copyright (C) 2021 MediaTek Inc., this file is modified on 02/26/2021
-* by MediaTek Inc. based on MIT License .
-* Permission is hereby granted, free of charge, to any person obtaining a copy
-* of this software and associated documentation files (the "Software"), to
-* deal in the Software without restriction, including without limitation the
-* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-* sell copies of the Software, and to permit persons to whom the Software is
-* furnished to do so, subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be included in all
-* copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-* SOFTWARE.
-*/
+ * Copyright (C) 2021 MediaTek Inc., this file is modified on 02/26/2021
+ * by MediaTek Inc. based on MIT License .
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
 
 #pragma once
 
 #include
 #include
-#include
 #include
 #include
 
@@ -40,8 +39,6 @@ typedef enum {
   PERFORMANCE_MODE_MAX,
 } PERFORMANCE_MODE_E;
 
-#define ABORT_ON_DLOPEN_ERROR
-
 //------------------------------------- -------------------------------------
 #define APUWARE_LOG_D(format, ...)                                            \
   __android_log_print(ANDROID_LOG_DEBUG, "APUWARELIB", format "\n",          \
@@ -86,7 +83,12 @@ inline void* loadApuWareUtilsLibrary(const char* name) {
 
 inline void* getApuWareUtilsLibraryHandle() {
   if (sAPUWareUtilsLibHandle == nullptr) {
-    sAPUWareUtilsLibHandle = loadApuWareUtilsLibrary("libapuwareutils.mtk.so");
+    sAPUWareUtilsLibHandle =
+        loadApuWareUtilsLibrary("libapuwareutils_v2.mtk.so");
+    if (sAPUWareUtilsLibHandle == nullptr) {
+      sAPUWareUtilsLibHandle =
+          loadApuWareUtilsLibrary("libapuwareutils.mtk.so");
+    }
   }
   return sAPUWareUtilsLibHandle;
 }
diff --git a/neuron/neuron_delegate.cc b/neuron/neuron_delegate.cc
index ade6504..a2afc3f 100644
--- a/neuron/neuron_delegate.cc
+++ b/neuron/neuron_delegate.cc
@@ -27,9 +27,11 @@
 #include "neuron/neuron_delegate_kernel.h"
 #include "neuron/neuron_delegate_validation.h"
 #include "neuron/neuron_implementation.h"
+#include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/context_util.h"
 #include "tensorflow/lite/delegates/utils/simple_delegate.h"
 #include "tensorflow/lite/minimal_logging.h"
+#include "tensorflow/lite/schema/schema_generated.h"
 
 namespace tflite {
 namespace neuron {
@@ -50,8 +52,11 @@ class NeuronDelegate : public SimpleDelegateInterface {
     std::vector<NeuronValidationFailure> failure;
     bool supported = Validate(registration, node, context, &failure);
     if (!supported) {
-      TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, "OP %d is not supported(%s)",
-                      registration->builtin_code, failure[0].message.c_str());
+      TFLITE_LOG_PROD(
+          tflite::TFLITE_LOG_ERROR, "OP %s (v%d) is not supported (%s)",
+          tflite::EnumNameBuiltinOperator(
+              static_cast<BuiltinOperator>(registration->builtin_code)),
+          registration->version, failure[0].message.c_str());
     }
     return supported;
   }
diff --git a/neuron/neuron_delegate_builder.h b/neuron/neuron_delegate_builder.h
index cea4ca0..9a8ef71 100644
--- a/neuron/neuron_delegate_builder.h
+++ b/neuron/neuron_delegate_builder.h
@@ -43,6 +43,7 @@ enum {
   NN_TENSOR_FLAG_SCALAR_AS_TENSOR = 1U << 0,
   NN_TENSOR_FLAG_INT8_CONVERSION = 1U << 1,
   NN_TENSOR_FLAG_USE_INT8_ASYMM_SIGNED = 1U << 2,
+  NN_TENSOR_FLAG_FORCE_PER_CHANNEL = 1U << 3,
 };
 
 class DequantizeMapping {
@@ -406,6 +407,65 @@ class NeuronOpBuilder {
     return kTfLiteOk;
   }
 
+  // Add a RESHAPE op which reshapes an NNAPI intermediate output to the
+  // dimensions of the TFLite output tensor.
+  TfLiteStatus AppendReshape(int nn_input_index, int lite_out_tensor_index) {
+    augmented_inputs_.push_back(nn_input_index);
+    auto& output_tensor = context_->tensors[lite_out_tensor_index];
+    TF_LITE_ENSURE_STATUS(
+        AddVectorInt32Operand(output_tensor.dims->data,
+                              static_cast<uint32_t>(output_tensor.dims->size)));
+    TF_LITE_ENSURE_OK(context_,
+                      AddTensorOutput(lite_out_tensor_index,
+                                      NN_TENSOR_FLAG_USE_INT8_ASYMM_SIGNED));
+    TF_LITE_ENSURE_STATUS(
+        FinalizeAddOperation(NEURON_RESHAPE));
+    return kTfLiteOk;
+  }
+
+  // Lower PACK into CONCAT + RESHAPE when possible
+  TfLiteStatus TransformPackIntoSupportedOps(TfLiteNode* node,
+                                             TfLiteRegistration* reg) {
+    // Add input tensors for CONCAT, and calculate the dimensions for the
+    // output.
+    int concat_output_ann_index = -1;
+    TfLitePackParams* builtin =
+        reinterpret_cast<TfLitePackParams*>(node->builtin_data);
+    auto& input_tensor = context_->tensors[node->inputs->data[0]];
+    int axis = builtin->axis < 0 ?
+                   input_tensor.dims->size + builtin->axis + 1
+                   : builtin->axis;
+    TF_LITE_ENSURE(context_, axis < input_tensor.dims->size);
+    uint32_t concat_dim_size = 0;
+    for (int input_pos = 0; input_pos < node->inputs->size; ++input_pos) {
+      const auto input_index = node->inputs->data[input_pos];
+      concat_dim_size +=
+          context_->tensors[node->inputs->data[input_pos]].dims->data[axis];
+      TF_LITE_ENSURE_STATUS(
+          AddTensorInput(input_index, /*hybrid_op=*/false,
+                         NN_TENSOR_FLAG_USE_INT8_ASYMM_SIGNED));
+    }
+    TF_LITE_ENSURE_STATUS(AddScalarInt32Operand(axis));
+    std::vector<uint32_t> concat_output_shape(input_tensor.dims->size, 0);
+    for (int i = 0; i < concat_output_shape.size(); i++) {
+      if (i == axis) {
+        concat_output_shape[i] = concat_dim_size;
+      } else {
+        concat_output_shape[i] = input_tensor.dims->data[i];
+      }
+    }
+    TF_LITE_ENSURE_STATUS(AddIntermediateOutputTensor(
+        input_tensor.type, concat_output_shape.size(),
+        concat_output_shape.data(), input_tensor.params.scale,
+        input_tensor.params.zero_point, &concat_output_ann_index));
+    TF_LITE_ENSURE_STATUS(
+        FinalizeAddOperation(NEURON_CONCATENATION));
+
+    // Reshape the output tensor
+    TF_LITE_ENSURE_STATUS(AppendReshape(
+        concat_output_ann_index, node->outputs->data[0]));
+    return kTfLiteOk;
+  }
+
   // Finish emitting the op (of type `type`) into the Neuron.
   TfLiteStatus FinalizeAddOperation(NeuronOperationType type) {
     // Actually add a Neuron operation
@@ -517,6 +577,31 @@ class NeuronOpBuilder {
     return result;
   }
 
+  TfLiteStatus AddIntermediateOutputTensor(TfLiteType tfl_type,
+                                           uint32_t dimension_count,
+                                           const uint32_t* dimension_data,
+                                           float scale, int32_t zero_point,
+                                           int* ann_index_out) {
+    int32_t nn_type;
+    switch (tfl_type) {
+      case kTfLiteFloat32:
+        nn_type = NEURON_TENSOR_FLOAT32;
+        break;
+      case kTfLiteInt8:
+        nn_type = NEURON_TENSOR_QUANT8_ASYMM_SIGNED;
+        break;
+      case kTfLiteUInt8:
+        nn_type = NEURON_TENSOR_QUANT8_ASYMM;
+        break;
+      default:
+        return kTfLiteError;
+    }
+    TF_LITE_ENSURE_STATUS(
+        AddAdditionalOutputTensor(dimension_count, dimension_data, nn_type,
+                                  scale, zero_point, ann_index_out));
+    return kTfLiteOk;
+  }
+
  private:
   // Returns a TF Lite type which has the same memory representation as a
   // provided Neuron type.
@@ -626,6 +711,8 @@ class NeuronOpBuilder {
         tensor_flags & NN_TENSOR_FLAG_INT8_CONVERSION;
     const bool use_int8_asymm_signed =
         tensor_flags & NN_TENSOR_FLAG_USE_INT8_ASYMM_SIGNED;
+    const bool force_per_channel =
+        tensor_flags & NN_TENSOR_FLAG_FORCE_PER_CHANNEL;
     int neuron_tensor_index =
         operand_mapping_->lite_index_to_neuron(tensor_index);
     if (neuron_tensor_index != -1) {
@@ -684,7 +771,7 @@ class NeuronOpBuilder {
         TfLiteAffineQuantization* quantization_params =
             static_cast<TfLiteAffineQuantization*>(
                 tensor->quantization.params);
-        if (quantization_params->scale->size > 1) {
+        if (quantization_params->scale->size > 1 || force_per_channel) {
           // Set up per-channel quantization.
           ann_perchannel_params = {
               .channelDim = static_cast<uint32_t>(
diff --git a/neuron/neuron_delegate_kernel.cc b/neuron/neuron_delegate_kernel.cc
index 77ae6dc..577da27 100644
--- a/neuron/neuron_delegate_kernel.cc
+++ b/neuron/neuron_delegate_kernel.cc
@@ -1,24 +1,24 @@
 /*
-* Copyright (C) 2021 MediaTek Inc., this file is modified on 02/26/2021
-* by MediaTek Inc. based on MIT License .
-* Permission is hereby granted, free of charge, to any person obtaining a copy
-* of this software and associated documentation files (the "Software"), to
-* deal in the Software without restriction, including without limitation the
-* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-* sell copies of the Software, and to permit persons to whom the Software is
-* furnished to do so, subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be included in all
-* copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-* SOFTWARE.
-*/
+ * Copyright (C) 2021 MediaTek Inc., this file is modified on 02/26/2021
+ * by MediaTek Inc. based on MIT License .
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
 
 #include "neuron/neuron_delegate_kernel.h"
 
@@ -55,6 +55,8 @@ bool IsScalarInputSupported(int builtin_code) {
     case kTfLiteBuiltinPow:
     case kTfLiteBuiltinMaximum:
     case kTfLiteBuiltinMinimum:
+    case kTfLiteBuiltinPrelu:
+    case kTfLiteBuiltinLeakyRelu:
     case kTfLiteBuiltinReduceMax:
     case kTfLiteBuiltinSum:
       return true;
@@ -83,6 +85,16 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
       }
       return false;
    }
+    case kTfLiteBuiltinTransposeConv: {
+      // Transpose convolution has a different order of inputs:
+      // 0: output_shape, 1: filter, 2: input, 3: bias.
+      const int input_id = node->inputs->data[/*kDataInputTensor*/ 2];
+      const TfLiteType input_type = context->tensors[input_id].type;
+      if (input_type == kTfLiteInt8) {
+        return true;
+      }
+      return false;
+    }
     case kTfLiteBuiltinSelect: {
       const auto value_type = context->tensors[node->inputs->data[1]].type;
       return value_type == kTfLiteInt8;
@@ -100,6 +112,7 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
     case kTfLiteBuiltinGreaterEqual:
     case kTfLiteBuiltinHardSwish:
     case kTfLiteBuiltinL2Normalization:
+    case kTfLiteBuiltinLeakyRelu:
     case kTfLiteBuiltinLess:
     case kTfLiteBuiltinLessEqual:
     case kTfLiteBuiltinLogistic:
@@ -111,6 +124,7 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
     case kTfLiteBuiltinNotEqual:
     case kTfLiteBuiltinPad:
     case kTfLiteBuiltinPadv2:
+    case kTfLiteBuiltinPrelu:
     case kTfLiteBuiltinReduceMax:
     case kTfLiteBuiltinReduceMin:
     case kTfLiteBuiltinRelu:
@@ -231,6 +245,50 @@ TfLiteStatus NeuronDelegateKernel::Map(TfLiteContext* context, int builtin_code,
                                        int version, int neuron_sdk_version,
                                        const NeuronOpMappingArgs& mapping_args,
                                        NeuronOperationType* nn_op_type) {
+  auto add_zero_bias = [mapping_args](int input_id, int filter_id,
+                                      int num_elements) -> void {
+    // Neuron requires a bias tensor, so we allocate a new tensor to fill
+    // it with zeroes. It is deleted with other tensors in the context
+    // during subgraph destructor call.
+    int bias_index = -1;
+    mapping_args.context->AddTensors(mapping_args.context, 1, &bias_index);
+    TfLiteTensor* bias_tensor = &mapping_args.context->tensors[bias_index];
+    const auto input_type = mapping_args.context->tensors[input_id].type;
+    if (input_type == kTfLiteFloat32) {
+      bias_tensor->type = kTfLiteFloat32;
+    } else {
+      bias_tensor->type = kTfLiteInt32;
+    }
+    // Create an array with a required bias shape and resize the bias
+    // tensor.
+    TfLiteIntArray* bias_shape = TfLiteIntArrayCreate(1);
+    bias_shape->data[0] = num_elements;
+    bias_tensor->allocation_type = kTfLiteDynamic;
+    mapping_args.context->ResizeTensor(mapping_args.context, bias_tensor,
+                                       bias_shape);
+    // Set tensor's values to zeroes and add it using AddVector*, so
+    // that the values are copied to Neuron. We don't use the AddTensor
+    // function because it doesn't copy values and the tensor we just
+    // created is not in the node->inputs.
+    if (input_type == kTfLiteFloat32) {
+      memset(bias_tensor->data.f, 0, num_elements * sizeof(float));
+      mapping_args.builder->AddVectorFloat32Operand(bias_tensor->data.f,
+                                                    num_elements);
+    } else {
+      memset(bias_tensor->data.i32, 0, num_elements * sizeof(int));
+      const TfLiteTensor& input_tensor =
+          mapping_args.context->tensors[input_id];
+      const TfLiteTensor& filter_tensor =
+          mapping_args.context->tensors[filter_id];
+      // Neuron requires bias scale to be a product of an input scale and
+      // a filter scale.
+      bias_tensor->params.scale =
+          input_tensor.params.scale * filter_tensor.params.scale;
+      mapping_args.builder->AddVectorInt32Operand(
+          bias_tensor->data.i32, num_elements, bias_tensor->params.scale,
+          /*zero_point=*/0);
+    }
+  };
   switch (builtin_code) {
     case kTfLiteBuiltinAdd: {
       auto builtin =
@@ -309,6 +367,18 @@ TfLiteStatus NeuronDelegateKernel::Map(TfLiteContext* context, int builtin_code,
       *nn_op_type = NEURON_SOFTMAX;
     } break;
     case kTfLiteBuiltinReshape: {
+      if (mapping_args.node->inputs->size == 1) {
+        // if no new_shape tensor, construct the new shape from params.
+        auto* params = reinterpret_cast<TfLiteReshapeParams*>(
+            mapping_args.node->builtin_data);
+        int num_dimensions = params->num_dimensions;
+        std::vector<int32_t> output_shape(num_dimensions);
+        for (int i = 0; i < num_dimensions; ++i) {
+          output_shape[i] = params->shape[i];
+        }
+        mapping_args.builder->AddVectorInt32Operand(
+            output_shape.data(), static_cast<uint32_t>(num_dimensions));
+      }
       *nn_op_type = NEURON_RESHAPE;
     } break;
     case kTfLiteBuiltinResizeBilinear: {
@@ -460,67 +530,42 @@ TfLiteStatus NeuronDelegateKernel::Map(TfLiteContext* context, int builtin_code,
       *nn_op_type = NEURON_SLICE;
     } break;
     case kTfLiteBuiltinTransposeConv: {
-      const bool hybrid_op = IsHybridOperator(
-          mapping_args.context, kTfLiteBuiltinTransposeConv, mapping_args.node);
+      int input_tensor_flags = 0;
+      const int input_tensor_id =
+          mapping_args.node->inputs->data[/*kDataInputTensor*/ 2];
+      const int weight_tensor_id =
+          mapping_args.node->inputs->data[/*kWeightsTensor*/ 1];
+
+      // Transpose convolution doesn't have hybrid variation.
+      const bool hybrid_op = false;
+
       mapping_args.builder->AddTensorInput(
-          mapping_args.node->inputs->data[/*kDataInputTensor*/ 2], hybrid_op);
+          input_tensor_id, hybrid_op,
+          input_tensor_flags | NN_TENSOR_FLAG_USE_INT8_ASYMM_SIGNED);
+
+      // Transpose convolution uses per-channel quantization with int8 inputs
+      // even if the number of channels in quantization parameters is equal to 1
+      // (as opposed to conv2d, which uses per-tensor quantization in this
+      // case).
       mapping_args.builder->AddTensorInput(
-          mapping_args.node->inputs->data[/*kWeightsTensor*/ 1], hybrid_op);
-
-      // Neuron requires a bias tensor, so we allocate a new tensor to fill
-      // it with zeroes. It is deleted with other tensors in the context
-      // during subgraph destructor call.
-      int bias_index = -1;
-      mapping_args.context->AddTensors(mapping_args.context, 1, &bias_index);
-      TfLiteTensor* bias_tensor = &mapping_args.context->tensors[bias_index];
-      const auto input_type =
-          mapping_args.context
-              ->tensors[mapping_args.node->inputs->data[/*kDataInputTensor*/ 2]]
-              .type;
-      if (input_type == kTfLiteFloat32) {
-        bias_tensor->type = kTfLiteFloat32;
-      } else {
-        bias_tensor->type = kTfLiteInt32;
-      }
+          weight_tensor_id, hybrid_op,
+          input_tensor_flags | NN_TENSOR_FLAG_FORCE_PER_CHANNEL);
 
-      // Create an array with a required bias shape and resize the bias
-      // tensor.
-      TfLiteIntArray* bias_shape = TfLiteIntArrayCreate(1);
-      const TfLiteTensor& output_shape =
-          mapping_args.context->tensors[mapping_args.node->inputs
-                                            ->data[/*kOutputShapeTensor*/ 0]];
-      const int output_depth = output_shape.data.i32[3];
-      bias_shape->data[0] = output_depth;
-      bias_tensor->allocation_type = kTfLiteDynamic;
-      mapping_args.context->ResizeTensor(mapping_args.context, bias_tensor,
-                                         bias_shape);
-
-      // Set tensor's values to zeroes and add it using AddVector*, so
-      // that the values are copied to Neuron. We don't use the AddTensor
-      // function because it doesn't copy values and the tensor we just
-      // created is not in the node->inputs.
-      if (input_type == kTfLiteFloat32) {
-        memset(bias_tensor->data.f, 0, output_depth * sizeof(float));
-        mapping_args.builder->AddVectorFloat32Operand(bias_tensor->data.f,
-                                                      output_depth);
+      const bool is_bias_present =
+          mapping_args.node->inputs->size == 4 &&
+          mapping_args.node->inputs->data[/*kBiasTensor*/ 3] !=
+              kTfLiteOptionalTensor;
+
+      if (is_bias_present) {
+        mapping_args.builder->AddTensorInput(
+            mapping_args.node->inputs->data[/*kBiasTensor*/ 3], hybrid_op);
       } else {
-        memset(bias_tensor->data.i32, 0, output_depth * sizeof(int));
-        const TfLiteTensor& input_tensor =
-            mapping_args.context->tensors[mapping_args.node->inputs
-                                              ->data[/*kDataInputTensor*/ 2]];
-        const TfLiteTensor& filter_tensor =
+        const TfLiteTensor& output_shape =
            mapping_args.context->tensors[mapping_args.node->inputs
-                                              ->data[/*kWeightsTensor*/ 1]];
-        // Neuron requires bias scale to be a product of an input scale and
-        // a filter scale.
-        bias_tensor->params.scale =
-            input_tensor.params.scale * filter_tensor.params.scale;
-        mapping_args.builder->AddVectorInt32Operand(
-            bias_tensor->data.i32, output_depth,
-            input_tensor.params.scale * filter_tensor.params.scale,
-            /*zero_point=*/0);
+                                              ->data[/*kOutputShapeTensor*/ 0]];
+        const int output_depth = output_shape.data.i32[3];
+        add_zero_bias(input_tensor_id, weight_tensor_id, output_depth);
       }
-
       mapping_args.builder->AddTensorInput(
           mapping_args.node->inputs->data[/*kOutputShapeTensor*/ 0], hybrid_op);
@@ -568,24 +613,40 @@ TfLiteStatus NeuronDelegateKernel::Map(TfLiteContext* context, int builtin_code,
       *nn_op_type = NEURON_CAST;
     } break;
     case kTfLiteBuiltinLeakyRelu: {
+      const auto input_type =
+          mapping_args.context->tensors[mapping_args.node->inputs->data[0]]
+              .type;
       auto builtin = reinterpret_cast<TfLiteLeakyReluParams*>(
           mapping_args.node->builtin_data);
-      TfLiteTensor t0;
-      t0.type = kTfLiteFloat32;
-      t0.allocation_type = kTfLiteDynamic;
-      t0.dims = TfLiteIntArrayCreate(1);
-      t0.dims->data[0] = 1;
-      t0.params.scale = 0;
-      t0.params.zero_point = 0;
+      TfLiteTensor alpha_tensor;
+      alpha_tensor.type = input_type;
+      alpha_tensor.allocation_type = kTfLiteDynamic;
+      alpha_tensor.dims = TfLiteIntArrayCreate(1);
+      alpha_tensor.dims->data[0] = 1;
+      alpha_tensor.params.zero_point = 0;
 
-      std::vector<float> dims_float(1);
-      std::fill(dims_float.begin(), dims_float.end(), builtin->alpha);
       int new_tensor_index = -1;
-      mapping_args.builder->AddNewInputConstantTensor(
-          NEURON_TENSOR_FLOAT32, kTfLiteFloat32, t0.dims, dims_float, t0.params,
-          &new_tensor_index);
-      TfLiteIntArrayFree(t0.dims);
+      if (input_type == kTfLiteFloat32) {
+        alpha_tensor.params.scale = 0;
+        std::vector<float> alpha_value = {builtin->alpha};
+        mapping_args.builder->AddNewInputConstantTensor(
+            NEURON_TENSOR_FLOAT32, kTfLiteFloat32, alpha_tensor.dims,
+            alpha_value, alpha_tensor.params, &new_tensor_index);
+      } else if (input_type == kTfLiteInt8) {
+        alpha_tensor.params.scale = builtin->alpha;
+        std::vector<int8_t> alpha_value = {1};
+        mapping_args.builder->AddNewInputConstantTensor(
+            NEURON_TENSOR_QUANT8_ASYMM_SIGNED, kTfLiteInt8, alpha_tensor.dims,
+            alpha_value, alpha_tensor.params, &new_tensor_index);
+      } else {
+        alpha_tensor.params.scale = builtin->alpha;
+        std::vector<uint8_t> alpha_value = {1};
+        mapping_args.builder->AddNewInputConstantTensor(
+            NEURON_TENSOR_QUANT8_ASYMM, kTfLiteUInt8, alpha_tensor.dims,
+            alpha_value, alpha_tensor.params, &new_tensor_index);
+      }
+
       *nn_op_type = NEURON_PRELU;
     } break;
     case kTfLiteBuiltinPrelu: {
@@ -1036,6 +1097,13 @@ TfLiteStatus NeuronDelegateKernel::AddOpsAndTensors(TfLiteContext* context) {
     TF_LITE_ENSURE_STATUS(
         context->GetNodeAndRegistration(context, node_index, &node, &reg));
 
+    // Delegate PACK by lowering it into CONCAT + RESHAPE.
+    if (reg->builtin_code == kTfLiteBuiltinPack) {
+      TF_LITE_ENSURE_STATUS(
+          builder.TransformPackIntoSupportedOps(node, reg));
+      continue;
+    }
+
     const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node);
     const bool scalar_as_tensor = IsScalarInputSupported(reg->builtin_code);
     const bool need_int8_conversion =
@@ -1059,6 +1127,11 @@ TfLiteStatus NeuronDelegateKernel::AddOpsAndTensors(TfLiteContext* context) {
 
     // Map inputs to Neuron tensor indices.
     for (int input_pos = 0; input_pos < node->inputs->size; ++input_pos) {
+      if (reg->builtin_code == kTfLiteBuiltinTransposeConv) {
+        // Everything is added during Map since input tensors
+        // have different order.
+        continue;
+      }
       const auto input_index = node->inputs->data[input_pos];
       if (need_int8_conversion &&
           (input_pos == 0 ||
@@ -1071,8 +1144,10 @@ TfLiteStatus NeuronDelegateKernel::AddOpsAndTensors(TfLiteContext* context) {
            reg->builtin_code == kTfLiteBuiltinConcatenation ||
           reg->builtin_code == kTfLiteBuiltinMaximum ||
            reg->builtin_code == kTfLiteBuiltinMinimum ||
+           reg->builtin_code == kTfLiteBuiltinLeakyRelu ||
            reg->builtin_code == kTfLiteBuiltinLess ||
            reg->builtin_code == kTfLiteBuiltinLessEqual ||
+           reg->builtin_code == kTfLiteBuiltinPrelu ||
            reg->builtin_code == kTfLiteBuiltinGreater ||
            reg->builtin_code == kTfLiteBuiltinGreaterEqual ||
            reg->builtin_code == kTfLiteBuiltinEqual ||
@@ -1115,11 +1190,6 @@ TfLiteStatus NeuronDelegateKernel::AddOpsAndTensors(TfLiteContext* context) {
         // by the Map() mapping.
         continue;
       }
-      if (reg->builtin_code == kTfLiteBuiltinTransposeConv) {
-        // Everything is added during Map since input tensors
-        // have different order.
-        continue;
-      }
 
       // Pad and Padv2 have an optional parameter for a pad value which has
       // to be converted to a scalar type in Neuron.
@@ -1258,23 +1328,45 @@ TfLiteStatus NeuronDelegateKernel::AddOpsAndTensors(TfLiteContext* context) {
     if (use_int8_asymm_signed) {
       output_tensor_flags |= NN_TENSOR_FLAG_USE_INT8_ASYMM_SIGNED;
     }
+    // fc_nn_intermediate_output_index is used to indicate whether additional
+    // RESHAPE op is needed.
+    int fc_nn_intermediate_output_index = -1;
     for (int output_pos = 0; output_pos < node->outputs->size; ++output_pos) {
-      const auto output_index = node->outputs->data[output_pos];
+      auto output_index = node->outputs->data[output_pos];
 
       // Outputs for basic LSTM cell are set in the Map function since
       if (reg->builtin_code == kTfLiteBuiltinLstm && isLstmBasicKernel(node)) {
         continue;
       }
-
-      TF_LITE_ENSURE_STATUS(
-          builder.AddTensorOutput(output_index, output_tensor_flags));
+      // Handle FC with keep_num_dims==true.
+      if (reg->builtin_code == kTfLiteBuiltinFullyConnected &&
+          reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data)
+              ->keep_num_dims) {
+        auto& output_tensor = context->tensors[output_index];
+
+        int num_units = output_tensor.dims->data[output_tensor.dims->size - 1];
+        std::vector<uint32_t> output_dims(2);
+        output_dims[0] = NumElements(output_tensor.dims) / num_units;
+        output_dims[1] = num_units;
+        TF_LITE_ENSURE_STATUS(builder.AddIntermediateOutputTensor(
+            output_tensor.type, output_dims.size(), output_dims.data(),
+            output_tensor.params.scale, output_tensor.params.zero_point,
+            &fc_nn_intermediate_output_index));
+      } else {
+        TF_LITE_ENSURE_STATUS(
+            builder.AddTensorOutput(output_index, output_tensor_flags));
+      }
     }
 
     // Dequantize operators may have to be added in case inputs are to be
     // floating-point.
     AddDequantizeOperatorsWhereNeeded(context, reg->builtin_code, node,
                                       &builder);
 
-    builder.FinalizeAddOperation(nn_op_type);
+    TF_LITE_ENSURE_STATUS(builder.FinalizeAddOperation(nn_op_type));
+    if (fc_nn_intermediate_output_index > -1) {
+      TF_LITE_ENSURE_STATUS(builder.AppendReshape(
+          fc_nn_intermediate_output_index, node->outputs->data[0]));
+    }
   }
 
   TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO, "AddOpsAndTensors done");
   return kTfLiteOk;
diff --git a/neuron/neuron_delegate_validation.cc b/neuron/neuron_delegate_validation.cc
index 3848cd5..0c0c720 100644
--- a/neuron/neuron_delegate_validation.cc
+++ b/neuron/neuron_delegate_validation.cc
@@ -232,11 +232,6 @@ bool Validate(const TfLiteRegistration* registration, const TfLiteNode* node,
     } break;
     case kTfLiteBuiltinFullyConnected: {
       ExpectMaxOpVersion(version, 5, &val_ctx);
-      // TODO(Code): Add support for FullyConnected with no bias.
-      Expect(node->inputs->size == 3 &&
-                 node->inputs->data[2] != kTfLiteOptionalTensor,
-             NeuronValidationFailureType::kMissingRequiredOperand,
-             "FullyConnected with no bias not supported", &val_ctx);
       const TfLiteType input_type =
           context->tensors[node->inputs->data[0]].type;
       if (input_type == kTfLiteUInt8) {
@@ -248,8 +243,14 @@ bool Validate(const TfLiteRegistration* registration, const TfLiteNode* node,
     } break;
     case kTfLiteBuiltinSoftmax: {
       ExpectOpVersion(version, 2, &val_ctx);
-      const auto& input = context->tensors[node->outputs->data[0]];
       ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
+      const auto& output = context->tensors[node->outputs->data[0]];
+      ExpectTypeIn(output.type, {kTfLiteFloat32, kTfLiteUInt8, kTfLiteInt8},
+                   NeuronValidationFailureType::kUnsupportedOutputType,
+                   "Output type should be one of kTfLiteFloat32, kTfLiteUInt8, "
+                   "kTfLiteInt8.",
+                   &val_ctx);
+      const auto& input = context->tensors[node->inputs->data[0]];
       const int input_rank = input.dims->size;
       Expect(input_rank <= 4,
             NeuronValidationFailureType::kUnsupportedOperandRank,
@@ -258,15 +259,26 @@ bool Validate(const TfLiteRegistration* registration, const TfLiteNode* node,
     case kTfLiteBuiltinReshape: {
       ExpectOpVersion(version, 1, &val_ctx);
       ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
-      Expect(node->inputs->size >= 2,
-             NeuronValidationFailureType::kMissingRequiredOperand,
-             "Expected at least 2 inputs", &val_ctx);
       if (node->inputs->size >= 2) {
         Expect(context->tensors[node->inputs->data[1]].allocation_type ==
                    kTfLiteMmapRo,
               NeuronValidationFailureType::kInputTensorShouldHaveConstantShape,
               "The shape input tensor must be constant.", &val_ctx);
       }
+      if (node->inputs->size == 1) {
+        // reject scalar reshaping
+        auto* params =
+            reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
+        int num_dimensions = params->num_dimensions;
+        if (num_dimensions == 1 && params->shape[0] == 0) {
+          // Legacy tflite models use a shape parameter of [0] to indicate
+          // scalars.
+          num_dimensions = 0;
+        }
+        Expect(num_dimensions > 0,
+               NeuronValidationFailureType::kUnsupportedOperandRank,
+               "New shape rank should be > 0", &val_ctx);
+      }
     } break;
     case kTfLiteBuiltinResizeBilinear: {
       ExpectMaxOpVersion(version, 3, &val_ctx);
@@ -306,6 +318,10 @@ bool Validate(const TfLiteRegistration* registration, const TfLiteNode* node,
       Expect(context->tensors[node->inputs->data[0]].dims->size <= 4,
              NeuronValidationFailureType::kUnsupportedOperandRank,
             "Input rank should be less than 4", &val_ctx);
+
+      const auto& input_type = context->tensors[node->inputs->data[0]].type;
+      EXPECT_INPUT_TYPE_IN(input_type, kTfLiteFloat16, kTfLiteFloat32,
+                           kTfLiteUInt8, kTfLiteInt8);
     } break;
     case kTfLiteBuiltinDequantize: {
       ExpectOpVersion(version, 2, &val_ctx);
@@ -403,7 +419,7 @@ bool Validate(const TfLiteRegistration* registration, const TfLiteNode* node,
       ExpectOpVersion(version, 1, &val_ctx);
     } break;
     case kTfLiteBuiltinTransposeConv: {
-      ExpectOpVersion(version, 1, &val_ctx);
+      ExpectMaxOpVersion(version, 3, &val_ctx);
       Expect((node->inputs->size > 1) &&
                  (context->tensors[node->inputs->data[0]].allocation_type ==
                   kTfLiteMmapRo) &&
@@ -537,6 +553,9 @@ bool Validate(const TfLiteRegistration* registration, const TfLiteNode* node,
       EXPECT_INPUT_TYPE_IN(input_type, kTfLiteFloat32, kTfLiteFloat16,
                            kTfLiteInt32, kTfLiteUInt8, kTfLiteInt8);
 
+      Expect(positions.allocation_type == kTfLiteMmapRo,
+             NeuronValidationFailureType::kUnsupportedInputType,
+             "Neuron only supports constant int32 positions tensor", &val_ctx);
       Expect(positions.type == kTfLiteInt32,
             NeuronValidationFailureType::kUnsupportedInputType,
             "Positions type should be one of kTfLiteInt32", &val_ctx);
@@ -640,6 +659,18 @@ bool Validate(const TfLiteRegistration* registration, const TfLiteNode* node,
              &val_ctx);
       }
     } break;
+    case kTfLiteBuiltinPack: {
+      ExpectOpVersion(version, 2, &val_ctx);
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      EXPECT_INPUT_TYPE_IN(input_type, kTfLiteInt32, kTfLiteFloat32,
+                           kTfLiteInt8);
+      auto builtin = reinterpret_cast<TfLitePackParams*>(node->builtin_data);
+      Expect(builtin->axis != -1 &&
+                 builtin->axis !=
+                     context->tensors[node->inputs->data[0]].dims->size,
+             NeuronValidationFailureType::kUnsupportedOperandValue,
+             "Neuron does not support axis being the last dimension", &val_ctx);
+    } break;
     default:
       // All other operators are not mapped.
       TFLITE_LOG_PROD(tflite::TFLITE_LOG_WARNING,