diff --git a/runtime/onert/backend/train/ops/FullyConnectedLayer.cc b/runtime/onert/backend/train/ops/FullyConnectedLayer.cc
index 9d35655b26f..1c79e796d55 100644
--- a/runtime/onert/backend/train/ops/FullyConnectedLayer.cc
+++ b/runtime/onert/backend/train/ops/FullyConnectedLayer.cc
@@ -186,13 +186,11 @@ void FullyConnectedLayer::backwardFloat32()
                                            getShape(_grad_weights), getBuffer(_grad_weights));
 
   // Compute gradient for bias
-  if (_bias)
-  {
-    assert(_grad_bias);
-    nnfw::cker::train::FullyConnectedBiasGrad(getShape(backprop_act),
-                                              getBuffer(backprop_act), getShape(_grad_bias),
-                                              getBuffer(_grad_bias));
-  }
+  // The bias tensor must exist during training even if bias is an optional input
+  assert(_bias);
+  assert(_grad_bias);
+  nnfw::cker::train::FullyConnectedBiasGrad(getShape(backprop_act), getBuffer(backprop_act),
+                                            getShape(_grad_bias), getBuffer(_grad_bias));
 }
 
 } // namespace ops
diff --git a/runtime/onert/core/include/ir/IOperation.h b/runtime/onert/core/include/ir/IOperation.h
index be0dd939da6..c340d62e7cb 100644
--- a/runtime/onert/core/include/ir/IOperation.h
+++ b/runtime/onert/core/include/ir/IOperation.h
@@ -40,6 +40,7 @@ struct IOperation
 
   virtual void replaceInputs(const OperandIndex &from, const OperandIndex &to) = 0;
   virtual void replaceOutputs(const OperandIndex &from, const OperandIndex &to) = 0;
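+  // Replace the input at the given position. Unlike replaceInputs(from, to),
+  // which matches by value, this can target a single input slot even when the
+  // same operand index occurs more than once (e.g. an unset optional input).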
+  virtual void replaceInput(size_t pos, const OperandIndex &index) = 0;
   virtual const OperandIndexSequence &getInputs() const = 0;
   virtual const OperandIndexSequence &getOutputs() const = 0;
 };
diff --git a/runtime/onert/core/include/ir/OperandIndexSequence.h b/runtime/onert/core/include/ir/OperandIndexSequence.h
index 66d00761ba9..3f0e448282b 100644
--- a/runtime/onert/core/include/ir/OperandIndexSequence.h
+++ b/runtime/onert/core/include/ir/OperandIndexSequence.h
@@ -51,6 +51,7 @@ class OperandIndexSequence
   const OperandIndex &at(uint32_t index) const { return _vec.at(index); }
   bool contains(const OperandIndex &index) const;
   void replace(const OperandIndex &from, const OperandIndex &to);
+  void replace(size_t pos, const OperandIndex &index);
   OperandIndexSequence operator|(ir::Remove filter) const
   {
     switch (filter)
diff --git a/runtime/onert/core/include/ir/Operation.h b/runtime/onert/core/include/ir/Operation.h
index 06ab29ecb19..be0d2243b5e 100644
--- a/runtime/onert/core/include/ir/Operation.h
+++ b/runtime/onert/core/include/ir/Operation.h
@@ -50,6 +50,7 @@ class Operation : virtual public IOperation
 public:
   void replaceInputs(const OperandIndex &from, const OperandIndex &to) override;
   void replaceOutputs(const OperandIndex &from, const OperandIndex &to) override;
+  void replaceInput(size_t pos, const OperandIndex &index) override;
   OperandIndexSequence &getInputs() { return _inputs; }
   const OperandIndexSequence &getInputs() const override { return _inputs; }
   const OperandIndexSequence &getOutputs() const override { return _outputs; }
diff --git a/runtime/onert/core/include/ir/train/TrainableGraph.h b/runtime/onert/core/include/ir/train/TrainableGraph.h
index 4952a6b513f..61be9d63384 100644
--- a/runtime/onert/core/include/ir/train/TrainableGraph.h
+++ b/runtime/onert/core/include/ir/train/TrainableGraph.h
@@ -123,6 +123,7 @@ class TrainableGraph : public IGraph
   const Operands &operands() const override { return _graph.operands(); }
   Operands &operands() { return _graph.operands(); } // TODO Remove this non-const accessor
   const Operations &operations() const override { return _graph.operations(); }
+  Operations &operations() { return _graph.operations(); }
   const Operands &backward_operands() const { return _backward_operands; }
   OperandIndex getLossIndex(const IOIndex &pred_io_ind) const;
   const Graph &graph() const { return _graph; }
diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc
index aee23a6d992..d6996620447 100644
--- a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc
+++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc
@@ -18,6 +18,7 @@
 
 #include "StaticBackwardShapeInferer.h"
 #include "TrainableOperationConverter.h"
+#include "pass/BiasInsertionPass.h"
 #include "pass/LossInsertionPass.h"
 #include "../CompilerHelpers.h"
 #include "../ExecutorFactory.h"
@@ -86,6 +87,7 @@ std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
     compiler::pass::PassRunner{}
       .append(std::make_unique<compiler::pass::ConstantOutputPass>(subg))
       .append(std::make_unique<compiler::pass::OddOutputPass>(subg))
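+      // Materialize missing optional biases before lowering; training always
+      // computes a bias gradient (see FullyConnectedLayer::backwardFloat32)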
+      .append(std::make_unique<pass::BiasInsertionPass>(subg))
       .run();
 
     // Optimizations
diff --git a/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.cc b/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.cc
new file mode 100644
index 00000000000..0c0bdb8a0c6
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BiasInsertionPass.h"
+
+#include "ir/Graph.h"
+#include "util/logging.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+namespace pass
+{
+
+void BiasInsertionPass::run()
+{
+  _graph.operations().iterate([&](const ir::OperationIndex &op_index, const ir::IOperation &node) {
+    _current_op_index = op_index;
+    node.accept(*this);
+  });
+}
+
+void BiasInsertionPass::visit(const ir::operation::Conv2D &) {}
+
+void BiasInsertionPass::visit(const ir::operation::DepthwiseConv2D &) {}
+
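+// Insert a zero-filled constant bias operand when the model leaves the
+// optional bias input unset: training always produces a bias gradient, so a
+// real bias operand must exist in the graph.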
+void BiasInsertionPass::visit(const ir::operation::FullyConnected &node)
+{
+  const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::Input::BIAS);
+
+  // Insert bias if it is optional
+  if (!bias_index.valid())
+  {
+    const auto &output_index = node.getOutputs().at(0);
+    const auto &output = _graph.operands().at(output_index);
+    const auto &output_shape = output.shape();
+    const auto bias_shape = ir::Shape{output_shape.dim(output_shape.rank() - 1)};
+
+    auto bias_typeinfo = output.typeInfo();
+    if (bias_typeinfo.type() != ir::DataType::FLOAT32)
+      throw std::runtime_error("BiasInsertionPass: Only FLOAT32 is supported for now");
+
+    const auto new_bias_index = _graph.addOperand(bias_shape, bias_typeinfo);
+
+    // TODO Replace data with sparse data to reduce memory usage
+    const auto bias_size = bias_shape.num_elements() * ir::sizeOfDataType(bias_typeinfo.type());
+    std::vector<uint8_t> data_vec(bias_size, 0);
+    auto data_obj = std::make_shared<ir::CachedData>(data_vec.data(), bias_size);
+    _graph.setOperandValue(new_bias_index, std::move(data_obj));
+
+    auto &bias = _graph.operands().at(new_bias_index);
+    bias.insertUse(_current_op_index);
+    bias.info().setAsConstant();
+
+    _graph.operations()
+      .at(_current_op_index)
+      .replaceInput(ir::operation::FullyConnected::Input::BIAS, new_bias_index);
+
+    VERBOSE(BiasInsertionPass) << "Optional bias is inserted for training, bias index : "
+                               << new_bias_index << std::endl;
+  }
+}
+
+} // namespace pass
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.h b/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.h
new file mode 100644
index 00000000000..96801edb592
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__
+#define __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__
+
+#include "../../pass/Pass.h"
+#include "ir/OperationVisitor.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+namespace pass
+{
+
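+// Compiler pass that replaces an unset optional bias input with a
+// zero-initialized constant operand so that training can compute and store a
+// bias gradient. Currently only FullyConnected is handled; the Conv2D and
+// DepthwiseConv2D visitors are placeholders.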
+class BiasInsertionPass final : public compiler::pass::Pass, public ir::OperationVisitor
+{
+public:
+  BiasInsertionPass(ir::Graph &graph) : compiler::pass::Pass{graph} {}
+
+public:
+  std::string id() final { return "BiasInsertionPass"; }
+  void run() final;
+
+public:
+  void visit(const ir::operation::Conv2D &node) override;
+  void visit(const ir::operation::DepthwiseConv2D &node) override;
+  void visit(const ir::operation::FullyConnected &node) override;
+
+private:
+  ir::OperationIndex _current_op_index;
+};
+
+} // namespace pass
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__
diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.cc b/runtime/onert/core/src/ir/OperandIndexSequence.cc
index a15b6d0d69f..2d00725e5af 100644
--- a/runtime/onert/core/src/ir/OperandIndexSequence.cc
+++ b/runtime/onert/core/src/ir/OperandIndexSequence.cc
@@ -17,6 +17,7 @@
 #include "ir/OperandIndexSequence.h"
 
 #include <algorithm>
+#include <cassert>
 #include <sstream>
 
 namespace onert
@@ -55,6 +56,12 @@ void OperandIndexSequence::replace(const OperandIndex &from, const OperandIndex
   std::replace(_vec.begin(), _vec.end(), from, to);
 }
 
+void OperandIndexSequence::replace(size_t pos, const OperandIndex &index)
+{
+  assert(pos < _vec.size() && "OperandIndexSequence: Out of range");
+  _vec.at(pos) = index;
+}
+
 bool OperandIndexSequence::operator==(const OperandIndexSequence &other) const
 {
   return _vec == other._vec;
diff --git a/runtime/onert/core/src/ir/Operation.cc b/runtime/onert/core/src/ir/Operation.cc
index 64792525dd7..10fb0296bb5 100644
--- a/runtime/onert/core/src/ir/Operation.cc
+++ b/runtime/onert/core/src/ir/Operation.cc
@@ -62,5 +62,7 @@ void Operation::replaceOutputs(const OperandIndex &from, const OperandIndex &to)
   _outputs.replace(from, to);
 }
 
+void Operation::replaceInput(size_t pos, const OperandIndex &index) { _inputs.replace(pos, index); }
+
 } // namespace ir
 } // namespace onert
diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc
index af6f75003ff..066ff375e8b 100644
--- a/runtime/onert/core/src/ir/train/UseDefGenerator.cc
+++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc
@@ -220,18 +220,19 @@ void UseDefGenerator::visit(const train::operation::FullyConnected &node)
   }
 
   // Set def of backwarding inputs
+  // The uses of backwarding inputs have already been inserted before
   const auto outgoing_index = TrainingOperandIndex{in_index, false};
   insertBackPropDef(outgoing_index, backwarding_op_index);
 
   const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
   insertDef(weights_gradient_index, backwarding_op_index);
 
+  // Set def for the bias gradient. The bias exists during training even if it
+  // was an optional input, because BiasInsertionPass has inserted one.
   const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::Input::BIAS);
-  if (bias_index.valid())
-  {
-    const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
-    insertDef(bias_gradient_index, backwarding_op_index);
-  }
+  assert(bias_index.valid());
+  const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
+  insertDef(bias_gradient_index, backwarding_op_index);
 }
 
 void UseDefGenerator::visit(const train::operation::Loss &node)
diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.test.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.test.cc
index a38ce0ac397..852d519616c 100644
--- a/runtime/onert/core/src/ir/train/UseDefGenerator.test.cc
+++ b/runtime/onert/core/src/ir/train/UseDefGenerator.test.cc
@@ -106,6 +106,7 @@ TEST(UseDefGenerator, one_op)
   Shape shape{2, 2};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> data(4, 0.f);
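+  // The FC test graphs below now provide a real zero-filled bias instead of
+  // an invalid OperandIndex, mirroring what BiasInsertionPass inserts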
+  std::vector<float> bias_data(2, 0.f);
 
   /*
    (input) ⎼[FC]⎼> (ba_input1) ⎼[BA]⎼> (y_pred)
@@ -117,6 +118,7 @@
 
   const auto input = tgraph.addOperand(shape, type);
   const auto weights = tgraph.addOperand(shape, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto ba_input1 = tgraph.addOperand(shape, type);
   const auto ba_input2 = tgraph.addOperand(shape, type);
   const auto y_pred = tgraph.addOperand(shape, type);
@@ -125,6 +127,8 @@
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(data.data()), data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
   tgraph.operands().at(ba_input2).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(data.data()), data.size() * sizeof(float)));
 
@@ -132,8 +136,7 @@
   tgraph.addInput({input});
   tgraph.addInput({y_true});
   tgraph.addOutput({output});
 
-  const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {ba_input1});
+  const auto fc_index = addFullyConnectedOperation(tgraph, {input, weights, bias}, {ba_input1});
 
   operation::BinaryArithmetic::Param param;
   param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
@@ -575,6 +578,7 @@ TEST(UseDefGenerator, one_op)
   Shape shape{2, 2};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> data(4, 0.f);
+  std::vector<float> bias_data(2, 0.f);
 
   /*
    (input) ⎼[FC]⎼> (ea_input) ⎼[EA]⎼> (y_pred)
@@ -586,6 +590,7 @@
 
   const auto input = tgraph.addOperand(shape, type);
   const auto weights = tgraph.addOperand(shape, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto ea_input = tgraph.addOperand(shape, type);
   const auto y_pred = tgraph.addOperand(shape, type);
   const auto y_true = tgraph.addOperand(shape, type);
@@ -593,13 +598,14 @@
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(data.data()), data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
 
   tgraph.addInput({input});
   tgraph.addInput({y_true});
   tgraph.addOutput({output});
 
-  const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {ea_input});
+  const auto fc_index = addFullyConnectedOperation(tgraph, {input, weights, bias}, {ea_input});
 
   operation::ElementwiseActivation::Param param;
   param.op_type = operation::ElementwiseActivation::Type::RELU;
@@ -745,6 +751,7 @@ TEST(UseDefGenerator, one_op)
   Shape shape{2, 2};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> data(4, 0.f);
+  std::vector<float> bias_data(2, 0.f);
 
   /*
    (input) ⎼[FC]⎼> (y_pred)
@@ -756,19 +763,21 @@
 
   const auto input = tgraph.addOperand(shape, type);
   const auto weights = tgraph.addOperand(shape, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto y_pred = tgraph.addOperand(shape, type);
   const auto y_true = tgraph.addOperand(shape, type);
   const auto output = tgraph.addOperand(shape, type);
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(data.data()), data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
 
   tgraph.addInput({input});
   tgraph.addInput({y_true});
   tgraph.addOutput({output});
 
-  const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {y_pred});
+  const auto fc_index = addFullyConnectedOperation(tgraph, {input, weights, bias}, {y_pred});
   const auto loss_index = addLossOperation(tgraph, {y_pred, y_true}, {output});
 
   enableAllBackwarding(tgraph);
@@ -883,6 +892,7 @@ TEST(UseDefGenerator, one_op)
   Shape shape{2, 2};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> weights_data(4, 0.f);
+  std::vector<float> bias_data(2, 0.f);
   std::vector<int32_t> padding_data(4, 0);
 
   /*
    (input) ⎼[FC]⎼> (pad_input) ⎼[Pad]⎼> (y_pred)
@@ -895,6 +905,7 @@
 
   const auto input = tgraph.addOperand(shape, type);
   const auto weights = tgraph.addOperand(shape, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto pad_input = tgraph.addOperand(shape, type);
   const auto padding = tgraph.addOperand(shape, TypeInfo{DataType::INT32});
   const auto y_pred = tgraph.addOperand(shape, type);
@@ -903,6 +914,8 @@
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(weights_data.data()), weights_data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
   tgraph.operands().at(padding).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(padding_data.data()), padding_data.size() * sizeof(int32_t)));
 
@@ -910,8 +923,7 @@
   tgraph.addInput({input});
   tgraph.addInput({y_true});
   tgraph.addOutput({output});
 
-  const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {pad_input});
+  const auto fc_index = addFullyConnectedOperation(tgraph, {input, weights, bias}, {pad_input});
   const auto pad_op = operation::Pad({pad_input, padding}, {y_pred});
   const auto pad_index = tgraph.addOperation(std::make_unique<operation::Pad>(pad_op));
@@ -1069,6 +1081,7 @@ TEST(UseDefGenerator, one_op)
   Shape shape{2, 2};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> data(4, 0.f);
+  std::vector<float> bias_data(2, 0.f);
 
   /*
    (input) ⎼[FC]⎼> (pool_input) ⎼[MaxPool2D]⎼> (y_pred)
@@ -1080,6 +1093,7 @@
 
   const auto input = tgraph.addOperand(shape, type);
   const auto weights = tgraph.addOperand(shape, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto pool_input = tgraph.addOperand(shape, type);
   const auto y_pred = tgraph.addOperand(shape, type);
   const auto y_true = tgraph.addOperand(shape, type);
@@ -1087,13 +1101,14 @@
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(data.data()), data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
 
   tgraph.addInput({input});
   tgraph.addInput({y_true});
   tgraph.addOutput({output});
 
-  const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {pool_input});
+  const auto fc_index = addFullyConnectedOperation(tgraph, {input, weights, bias}, {pool_input});
 
   operation::Pool2D::Param param;
   param.op_type = operation::Pool2D::PoolType::MAX;
@@ -1243,6 +1258,7 @@ TEST(UseDefGenerator, one_op)
   Shape shape{2, 1};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> weights_data(4, 0.f);
+  std::vector<float> bias_data(2, 0.f);
   std::vector<int32_t> axis_data{-1};
 
   /*
    (input) ⎼[FC]⎼> (mean_input) ⎼[Mean]⎼> (y_pred)
@@ -1255,6 +1271,7 @@
 
   const auto input = tgraph.addOperand(shape, type);
   const auto weights = tgraph.addOperand(shape, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto mean_input = tgraph.addOperand(shape, type);
   const auto axis = tgraph.addOperand(Shape{1}, TypeInfo{DataType::INT32});
   const auto y_pred = tgraph.addOperand(shape, type);
@@ -1263,6 +1280,8 @@
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(weights_data.data()), weights_data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
   tgraph.operands().at(axis).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(axis_data.data()), axis_data.size() * sizeof(int32_t)));
 
@@ -1270,8 +1289,7 @@
   tgraph.addInput({input});
   tgraph.addInput({y_true});
   tgraph.addOutput({output});
 
-  const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {mean_input});
+  const auto fc_index = addFullyConnectedOperation(tgraph, {input, weights, bias}, {mean_input});
 
   operation::Reduce::Param param;
   param.reduce_type = operation::Reduce::ReduceType::MEAN;
@@ -1431,6 +1449,7 @@ TEST(UseDefGenerator, one_op)
   Shape s{2, 2};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> weights_data(4, 0.f);
+  std::vector<float> bias_data(2, 0.f);
   std::vector<int32_t> shape_data{2, 2};
 
   /*
    (input) ⎼[FC]⎼> (reshape_input) ⎼[Reshape]⎼> (y_pred)
@@ -1443,6 +1462,7 @@
 
   const auto input = tgraph.addOperand(s, type);
   const auto weights = tgraph.addOperand(s, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto reshape_input = tgraph.addOperand(s, type);
   const auto shape = tgraph.addOperand(Shape{2}, TypeInfo{DataType::INT32});
   const auto y_pred = tgraph.addOperand(s, type);
@@ -1451,6 +1471,8 @@
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(weights_data.data()), weights_data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
   tgraph.operands().at(shape).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(shape_data.data()), shape_data.size() * sizeof(int32_t)));
 
@@ -1459,7 +1481,7 @@
   tgraph.addOutput({output});
 
   const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {reshape_input});
+    addFullyConnectedOperation(tgraph, {input, weights, bias}, {reshape_input});
 
   operation::Reshape::Param param;
   param.new_shape = shape_data;
@@ -1619,6 +1641,7 @@ TEST(UseDefGenerator, one_op)
   Shape shape{2, 2};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> data(4, 0.f);
+  std::vector<float> bias_data(2, 0.f);
 
   /*
    (input) ⎼[FC]⎼> (softmax_input) ⎼[Softmax]⎼> (y_pred)
@@ -1630,6 +1653,7 @@
 
   const auto input = tgraph.addOperand(shape, type);
   const auto weights = tgraph.addOperand(shape, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto softmax_input = tgraph.addOperand(shape, type);
   const auto y_pred = tgraph.addOperand(shape, type);
   const auto y_true = tgraph.addOperand(shape, type);
@@ -1637,13 +1661,15 @@
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(data.data()), data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
 
   tgraph.addInput({input});
   tgraph.addInput({y_true});
   tgraph.addOutput({output});
 
   const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {softmax_input});
+    addFullyConnectedOperation(tgraph, {input, weights, bias}, {softmax_input});
 
   operation::Softmax::Param param;
   param.beta = 1.0f;
@@ -1788,6 +1814,7 @@ TEST(UseDefGenerator, one_op)
   Shape shape{2, 2};
   TypeInfo type{DataType::FLOAT32};
   std::vector<float> data(4, 0.f);
+  std::vector<float> bias_data(2, 0.f);
 
   /*
    (input) ⎼[FC]⎼> (fc_out) ⎼⎼⎼⎼⎼⎼⎼⎼⎼⎼⎼⎼⎼⎼⎼⎼[BA]⎼> (y_pred)
@@ -1799,6 +1826,7 @@
 
   const auto input = tgraph.addOperand(shape, type);
   const auto weights = tgraph.addOperand(shape, type);
+  const auto bias = tgraph.addOperand(Shape{2}, type);
   const auto fc_out = tgraph.addOperand(shape, type);
   const auto ea_out = tgraph.addOperand(shape, type);
   const auto y_pred = tgraph.addOperand(shape, type);
@@ -1807,13 +1835,14 @@
 
   tgraph.operands().at(weights).data(std::make_unique<CachedData>(
     reinterpret_cast<const uint8_t *>(data.data()), data.size() * sizeof(float)));
+  tgraph.operands().at(bias).data(std::make_unique<CachedData>(
+    reinterpret_cast<const uint8_t *>(bias_data.data()), bias_data.size() * sizeof(float)));
 
   tgraph.addInput({input});
   tgraph.addInput({y_true});
   tgraph.addOutput({output});
 
-  const auto fc_index =
-    addFullyConnectedOperation(tgraph, {input, weights, OperandIndex{}}, {fc_out});
+  const auto fc_index = addFullyConnectedOperation(tgraph, {input, weights, bias}, {fc_out});
 
   operation::ElementwiseActivation::Param ea_param;
   ea_param.op_type = operation::ElementwiseActivation::Type::RELU;
diff --git a/tests/nnfw_api/src/GenModelTests/one_op_trains/FullyConnected.test.cc b/tests/nnfw_api/src/GenModelTests/one_op_trains/FullyConnected.test.cc
index 9f5ba6632ad..bf572a97742 100644
--- a/tests/nnfw_api/src/GenModelTests/one_op_trains/FullyConnected.test.cc
+++ b/tests/nnfw_api/src/GenModelTests/one_op_trains/FullyConnected.test.cc
@@ -74,7 +74,7 @@ TEST_F(GenModelTrain, OneOp_FullyConnected_OptionalBias)
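+  // Expected losses after the first epoch change because the inserted zero
+  // bias is now trained along with the weights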
   _context->addTrainCase(
     uniformTCD<float>({{{1, 3, 2, 1}}},                                              // inputs
                       {{{2, 1, 5, 5, 2, 1, 5, 5, 2, 1, 5, 5, 2, 1, 5, 6}}},          // expected
-                      {{14.4375f}, {13.9950f}, {13.5668f}, {13.1523f}, {12.7512f}}   // loss
+                      {{14.4375f}, {13.9242f}, {13.4299f}, {12.9538f}, {12.4952f}}   // loss
                       ));
 
   _context->setBackends({"train"});