From 66c9b36bfb0a4a3061f39a9850e9c85304360c53 Mon Sep 17 00:00:00 2001 From: ragmani Date: Wed, 28 Aug 2024 21:04:21 +0900 Subject: [PATCH] [Draft] Enable training with optional bias ONE-DCO-1.0-Signed-off-by: ragmani --- .../backend/train/ops/FullyConnectedLayer.cc | 12 ++- runtime/onert/core/include/ir/IOperation.h | 1 + .../core/include/ir/OperandIndexSequence.h | 1 + runtime/onert/core/include/ir/Operation.h | 1 + .../core/include/ir/train/TrainableGraph.h | 1 + .../src/compiler/train/TrainingCompiler.cc | 2 + .../compiler/train/pass/BiasInsertionPass.cc | 83 +++++++++++++++++++ .../compiler/train/pass/BiasInsertionPass.h | 55 ++++++++++++ .../onert/core/src/ir/OperandIndexSequence.cc | 7 ++ runtime/onert/core/src/ir/Operation.cc | 2 + .../core/src/ir/train/UseDefGenerator.cc | 31 +++++-- .../one_op_trains/FullyConnected.test.cc | 2 +- 12 files changed, 185 insertions(+), 13 deletions(-) create mode 100644 runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.cc create mode 100644 runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.h diff --git a/runtime/onert/backend/train/ops/FullyConnectedLayer.cc b/runtime/onert/backend/train/ops/FullyConnectedLayer.cc index 9d35655b26f..1c79e796d55 100644 --- a/runtime/onert/backend/train/ops/FullyConnectedLayer.cc +++ b/runtime/onert/backend/train/ops/FullyConnectedLayer.cc @@ -186,13 +186,11 @@ void FullyConnectedLayer::backwardFloat32() getShape(_grad_weights), getBuffer(_grad_weights)); // Compute gradient for bias - if (_bias) - { - assert(_grad_bias); - nnfw::cker::train::FullyConnectedBiasGrad(getShape(backprop_act), - getBuffer(backprop_act), getShape(_grad_bias), - getBuffer(_grad_bias)); - } + // Bias tensor must exist on training even if bias is an optional input + assert(_bias); + assert(_grad_bias); + nnfw::cker::train::FullyConnectedBiasGrad(getShape(backprop_act), getBuffer(backprop_act), + getShape(_grad_bias), getBuffer(_grad_bias)); } } // namespace ops diff --git 
a/runtime/onert/core/include/ir/IOperation.h b/runtime/onert/core/include/ir/IOperation.h index be0dd939da6..c340d62e7cb 100644 --- a/runtime/onert/core/include/ir/IOperation.h +++ b/runtime/onert/core/include/ir/IOperation.h @@ -40,6 +40,7 @@ struct IOperation virtual void replaceInputs(const OperandIndex &from, const OperandIndex &to) = 0; virtual void replaceOutputs(const OperandIndex &from, const OperandIndex &to) = 0; + virtual void replaceInput(size_t pos, const OperandIndex &index) = 0; virtual const OperandIndexSequence &getInputs() const = 0; virtual const OperandIndexSequence &getOutputs() const = 0; }; diff --git a/runtime/onert/core/include/ir/OperandIndexSequence.h b/runtime/onert/core/include/ir/OperandIndexSequence.h index 66d00761ba9..3f0e448282b 100644 --- a/runtime/onert/core/include/ir/OperandIndexSequence.h +++ b/runtime/onert/core/include/ir/OperandIndexSequence.h @@ -51,6 +51,7 @@ class OperandIndexSequence const OperandIndex &at(uint32_t index) const { return _vec.at(index); } bool contains(const OperandIndex &index) const; void replace(const OperandIndex &from, const OperandIndex &to); + void replace(size_t pos, const OperandIndex &index); OperandIndexSequence operator|(ir::Remove filter) const { switch (filter) diff --git a/runtime/onert/core/include/ir/Operation.h b/runtime/onert/core/include/ir/Operation.h index 06ab29ecb19..be0d2243b5e 100644 --- a/runtime/onert/core/include/ir/Operation.h +++ b/runtime/onert/core/include/ir/Operation.h @@ -50,6 +50,7 @@ class Operation : virtual public IOperation public: void replaceInputs(const OperandIndex &from, const OperandIndex &to) override; void replaceOutputs(const OperandIndex &from, const OperandIndex &to) override; + void replaceInput(size_t pos, const OperandIndex &index) override; OperandIndexSequence &getInputs() { return _inputs; } const OperandIndexSequence &getInputs() const override { return _inputs; } const OperandIndexSequence &getOutputs() const override { return _outputs; } diff 
--git a/runtime/onert/core/include/ir/train/TrainableGraph.h b/runtime/onert/core/include/ir/train/TrainableGraph.h index 4952a6b513f..61be9d63384 100644 --- a/runtime/onert/core/include/ir/train/TrainableGraph.h +++ b/runtime/onert/core/include/ir/train/TrainableGraph.h @@ -123,6 +123,7 @@ class TrainableGraph : public IGraph const Operands &operands() const override { return _graph.operands(); } Operands &operands() { return _graph.operands(); } // TODO Remove this non-const accessor const Operations &operations() const override { return _graph.operations(); } + Operations &operations() { return _graph.operations(); } const Operands &backward_operands() const { return _backward_operands; } OperandIndex getLossIndex(const IOIndex &pred_io_ind) const; const Graph &graph() const { return _graph; } diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc index aee23a6d992..d6996620447 100644 --- a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc +++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc @@ -18,6 +18,7 @@ #include "StaticBackwardShapeInferer.h" #include "TrainableOperationConverter.h" +#include "pass/BiasInsertionPass.h" #include "pass/LossInsertionPass.h" #include "../CompilerHelpers.h" #include "../ExecutorFactory.h" @@ -86,6 +87,7 @@ std::shared_ptr TrainingCompiler::compile(void) compiler::pass::PassRunner{} .append(std::make_unique(subg)) .append(std::make_unique(subg)) + .append(std::make_unique(subg)) .run(); // Optimizations diff --git a/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.cc b/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.cc new file mode 100644 index 00000000000..0c0bdb8a0c6 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. 
All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "BiasInsertionPass.h" + +#include "ir/Graph.h" +#include "util/logging.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ +namespace pass +{ + +void BiasInsertionPass::run() +{ + _graph.operations().iterate([&](const ir::OperationIndex &op_index, const ir::IOperation &node) { + _current_op_index = op_index; + node.accept(*this); + }); +} + +void BiasInsertionPass::visit(const ir::operation::Conv2D &) {} + +void BiasInsertionPass::visit(const ir::operation::DepthwiseConv2D &) {} + +void BiasInsertionPass::visit(const ir::operation::FullyConnected &node) +{ + const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS); + + // Insert bias if it is optional + if (!bias_index.valid()) + { + const auto &output_index = node.getOutputs().at(0); + const auto &output = _graph.operands().at(output_index); + const auto &output_shape = output.shape(); + const auto bias_shape = ir::Shape{output_shape.dim(output_shape.rank() - 1)}; + + auto bias_typeinfo = output.typeInfo(); + if (bias_typeinfo.type() != ir::DataType::FLOAT32) + throw std::runtime_error("BiasInsertionPass: Only FLOAT32 is supported for now"); + + const auto new_bias_index = _graph.addOperand(bias_shape, output.typeInfo()); + + // TODO Replace data with sparse data to reduce memory usage + const auto bias_size = bias_shape.num_elements() * ir::sizeOfDataType(bias_typeinfo.type()); + 
std::vector data_vec(bias_size, 0); + auto data_obj = std::make_shared(data_vec.data(), bias_size); + _graph.setOperandValue(new_bias_index, std::move(data_obj)); + + auto &bias = _graph.operands().at(new_bias_index); + bias.insertUse(_current_op_index); + bias.isConstant(); + + _graph.operations() + .at(_current_op_index) + .replaceInput(ir::operation::Conv2D::Input::BIAS, new_bias_index); + + VERBOSE(BiasInsertionPass) << "Optional bias is inserted for training, bias index : " + << bias_index << std::endl; + } +} + +} // namespace pass +} // namespace train +} // namespace compiler +} // namespace onert diff --git a/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.h b/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.h new file mode 100644 index 00000000000..96801edb592 --- /dev/null +++ b/runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__ +#define __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__ + +#include "../../pass/Pass.h" +#include "ir/OperationVisitor.h" + +namespace onert +{ +namespace compiler +{ +namespace train +{ +namespace pass +{ + +class BiasInsertionPass final : public compiler::pass::Pass, public ir::OperationVisitor +{ +public: + BiasInsertionPass(ir::Graph &graph) : compiler::pass::Pass{graph} {} + +public: + std::string id() final { return "BiasInsertionPass"; } + void run() final; + +public: + void visit(const ir::operation::Conv2D &node) override; + void visit(const ir::operation::DepthwiseConv2D &node) override; + void visit(const ir::operation::FullyConnected &node) override; + +private: + ir::OperationIndex _current_op_index; +}; + +} // namespace pass +} // namespace train +} // namespace compiler +} // namespace onert + +#endif // __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__ diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.cc b/runtime/onert/core/src/ir/OperandIndexSequence.cc index a15b6d0d69f..2d00725e5af 100644 --- a/runtime/onert/core/src/ir/OperandIndexSequence.cc +++ b/runtime/onert/core/src/ir/OperandIndexSequence.cc @@ -17,6 +17,7 @@ #include "ir/OperandIndexSequence.h" #include +#include #include namespace onert @@ -55,6 +56,12 @@ void OperandIndexSequence::replace(const OperandIndex &from, const OperandIndex std::replace(_vec.begin(), _vec.end(), from, to); } +void OperandIndexSequence::replace(size_t pos, const OperandIndex &index) +{ + assert(pos < _vec.size() && "OperandIndexSequence: Out of range"); + _vec.at(pos) = index; +} + bool OperandIndexSequence::operator==(const OperandIndexSequence &other) const { return _vec == other._vec; diff --git a/runtime/onert/core/src/ir/Operation.cc b/runtime/onert/core/src/ir/Operation.cc index 64792525dd7..10fb0296bb5 100644 --- a/runtime/onert/core/src/ir/Operation.cc +++ b/runtime/onert/core/src/ir/Operation.cc @@ -62,5 +62,7 @@ void 
Operation::replaceOutputs(const OperandIndex &from, const OperandIndex &to) _outputs.replace(from, to); } +void Operation::replaceInput(size_t pos, const OperandIndex &index) { _inputs.replace(pos, index); } + } // namespace ir } // namespace onert diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc index af6f75003ff..066ff375e8b 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.cc +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc @@ -211,6 +211,23 @@ void UseDefGenerator::visit(const train::operation::FullyConnected &node) const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true}; insertUse(weights_forwarding_index, backwarding_op_index); + // // Set def and use for bias if bias is optional since bias is required on training + // const auto forwarding_op_index = TrainingOperationIndex{op_index, true}; + // const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS); + // const auto bias_forwarding_index = TrainingOperandIndex{bias_index, true}; + // if (!bias_index.valid()) + // { + // // If bias is not optional, the use for bias is already inserted before + // insertUse(bias_forwarding_index, forwarding_op_index); + + // // The forwarding bias is not used in backwarding + // } + // else + // { + // [[maybe_unused]] const auto &usedef_chain = _training_usedefs.at(bias_forwarding_index); + // assert(usedef_chain.getTrainingUses().count(forwarding_op_index) == 0); + // } + // Insert uses of forwarding output if (node.param().activation != ir::Activation::NONE) { @@ -220,18 +237,19 @@ void UseDefGenerator::visit(const train::operation::FullyConnected &node) } // Set def of backwarding inputs + // The uses of backwarding inputs has already been inserted before const auto outgoing_index = TrainingOperandIndex{in_index, false}; insertBackPropDef(outgoing_index, backwarding_op_index); const auto weights_gradient_index = 
TrainingOperandIndex{weights_index, false}; insertDef(weights_gradient_index, backwarding_op_index); + // Set def for gradient bias even if it is optional bias since bias is required + // on training. const auto &bias_index = node.getInputs().at(ir::operation::Conv2D::Input::BIAS); - if (bias_index.valid()) - { - const auto bias_gradient_index = TrainingOperandIndex{bias_index, false}; - insertDef(bias_gradient_index, backwarding_op_index); - } + assert(bias_index.valid()); + const auto bias_gradient_index = TrainingOperandIndex{bias_index, false}; + insertDef(bias_gradient_index, backwarding_op_index); } void UseDefGenerator::visit(const train::operation::Loss &node) @@ -427,7 +445,10 @@ void UseDefGenerator::initForForwardingNodes() assert(_training_usedefs.at(forwarding_operand_index).getTrainingUses().size() == 0); const auto uses = operand.getUses(); for (const auto &use : uses) + { + // if (use.valid()) insertUse(forwarding_operand_index, TrainingOperationIndex{use, is_forward}); + } }); } diff --git a/tests/nnfw_api/src/GenModelTests/one_op_trains/FullyConnected.test.cc b/tests/nnfw_api/src/GenModelTests/one_op_trains/FullyConnected.test.cc index 9f5ba6632ad..bf572a97742 100644 --- a/tests/nnfw_api/src/GenModelTests/one_op_trains/FullyConnected.test.cc +++ b/tests/nnfw_api/src/GenModelTests/one_op_trains/FullyConnected.test.cc @@ -74,7 +74,7 @@ TEST_F(GenModelTrain, OneOp_FullyConnected_OptionalBias) _context->addTrainCase( uniformTCD({{{1, 3, 2, 1}}}, // inputs {{{2, 1, 5, 5, 2, 1, 5, 5, 2, 1, 5, 5, 2, 1, 5, 6}}}, // expected - {{14.4375f}, {13.9950f}, {13.5668f}, {13.1523f}, {12.7512f}} // loss + {{14.4375f}, {13.9242f}, {13.4299f}, {12.9538f}, {12.4952f}} // loss )); _context->setBackends({"train"});