[Draft] Enable training with optional bias
ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani committed Aug 29, 2024
1 parent f963350 commit 8e69e6a
Showing 12 changed files with 185 additions and 13 deletions.
12 changes: 5 additions & 7 deletions runtime/onert/backend/train/ops/FullyConnectedLayer.cc
@@ -186,13 +186,11 @@ void FullyConnectedLayer::backwardFloat32()
getShape(_grad_weights), getBuffer<float>(_grad_weights));

// Compute gradient for bias
if (_bias)
{
assert(_grad_bias);
nnfw::cker::train::FullyConnectedBiasGrad(getShape(backprop_act),
getBuffer<float>(backprop_act), getShape(_grad_bias),
getBuffer<float>(_grad_bias));
}
// Bias tensor must exist for training even if the bias is an optional input
assert(_bias);
assert(_grad_bias);
nnfw::cker::train::FullyConnectedBiasGrad(getShape(backprop_act), getBuffer<float>(backprop_act),
getShape(_grad_bias), getBuffer<float>(_grad_bias));
}

} // namespace ops
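For context, a minimal sketch of the reduction a bias-gradient kernel such as
FullyConnectedBiasGrad performs (an assumption based on the standard derivation,
not the actual cker implementation): the bias is broadcast across the batch in
the forward pass, so its gradient is the incoming output gradient summed over
the batch dimension.

#include <cstdint>

// Sketch: "incoming" is the gradient w.r.t. the FC output, shape [batch, units];
// "grad_bias" has shape [units]. Each bias element accumulates the gradient
// contribution of every batch row, mirroring the forward broadcast.
void BiasGradSketch(const float *incoming, int32_t batch, int32_t units, float *grad_bias)
{
  for (int32_t j = 0; j < units; ++j)
    grad_bias[j] = 0.f;
  for (int32_t i = 0; i < batch; ++i)
    for (int32_t j = 0; j < units; ++j)
      grad_bias[j] += incoming[i * units + j];
}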
1 change: 1 addition & 0 deletions runtime/onert/core/include/ir/IOperation.h
@@ -40,6 +40,7 @@ struct IOperation

virtual void replaceInputs(const OperandIndex &from, const OperandIndex &to) = 0;
virtual void replaceOutputs(const OperandIndex &from, const OperandIndex &to) = 0;
virtual void replaceInput(size_t pos, const OperandIndex &index) = 0;
virtual const OperandIndexSequence &getInputs() const = 0;
virtual const OperandIndexSequence &getOutputs() const = 0;
};
1 change: 1 addition & 0 deletions runtime/onert/core/include/ir/OperandIndexSequence.h
@@ -51,6 +51,7 @@ class OperandIndexSequence
const OperandIndex &at(uint32_t index) const { return _vec.at(index); }
bool contains(const OperandIndex &index) const;
void replace(const OperandIndex &from, const OperandIndex &to);
void replace(size_t pos, const OperandIndex &index);
OperandIndexSequence operator|(ir::Remove filter) const
{
switch (filter)
1 change: 1 addition & 0 deletions runtime/onert/core/include/ir/Operation.h
@@ -50,6 +50,7 @@ class Operation : virtual public IOperation
public:
void replaceInputs(const OperandIndex &from, const OperandIndex &to) override;
void replaceOutputs(const OperandIndex &from, const OperandIndex &to) override;
void replaceInput(size_t pos, const OperandIndex &index) override;
OperandIndexSequence &getInputs() { return _inputs; }
const OperandIndexSequence &getInputs() const override { return _inputs; }
const OperandIndexSequence &getOutputs() const override { return _outputs; }
1 change: 1 addition & 0 deletions runtime/onert/core/include/ir/train/TrainableGraph.h
@@ -123,6 +123,7 @@ class TrainableGraph : public IGraph
const Operands &operands() const override { return _graph.operands(); }
Operands &operands() { return _graph.operands(); } // TODO Remove this non-const accessor
const Operations &operations() const override { return _graph.operations(); }
Operations &operations() { return _graph.operations(); }
const Operands &backward_operands() const { return _backward_operands; }
OperandIndex getLossIndex(const IOIndex &pred_io_ind) const;
const Graph &graph() const { return _graph; }
2 changes: 2 additions & 0 deletions runtime/onert/core/src/compiler/train/TrainingCompiler.cc
@@ -18,6 +18,7 @@

#include "StaticBackwardShapeInferer.h"
#include "TrainableOperationConverter.h"
#include "pass/BiasInsertionPass.h"
#include "pass/LossInsertionPass.h"
#include "../CompilerHelpers.h"
#include "../ExecutorFactory.h"
@@ -86,6 +87,7 @@ std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
compiler::pass::PassRunner{}
.append(std::make_unique<compiler::pass::ConstantOutputPass>(subg))
.append(std::make_unique<compiler::pass::OddOutputPass>(subg))
.append(std::make_unique<train::pass::BiasInsertionPass>(subg))
.run();

// Optimizations
83 changes: 83 additions & 0 deletions runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.cc
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "BiasInsertionPass.h"

#include "ir/Graph.h"
#include "util/logging.h"

namespace onert
{
namespace compiler
{
namespace train
{
namespace pass
{

void BiasInsertionPass::run()
{
_graph.operations().iterate([&](const ir::OperationIndex &op_index, const ir::IOperation &node) {
_current_op_index = op_index;
node.accept(*this);
});
}

// Bias insertion for Conv2D is not handled yet in this draft
void BiasInsertionPass::visit(const ir::operation::Conv2D &) {}

// Bias insertion for DepthwiseConv2D is not handled yet in this draft
void BiasInsertionPass::visit(const ir::operation::DepthwiseConv2D &) {}

void BiasInsertionPass::visit(const ir::operation::FullyConnected &node)
{
const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::Input::BIAS);

// Insert a zero bias if the bias input is optional (i.e. not set)
if (!bias_index.valid())
{
const auto &output_index = node.getOutputs().at(0);
const auto &output = _graph.operands().at(output_index);
const auto &output_shape = output.shape();
const auto bias_shape = ir::Shape{output_shape.dim(output_shape.rank() - 1)};

const auto bias_typeinfo = output.typeInfo();
if (bias_typeinfo.type() != ir::DataType::FLOAT32)
throw std::runtime_error("BiasInsertionPass: Only FLOAT32 is supported for now");

const auto new_bias_index = _graph.addOperand(bias_shape, bias_typeinfo);

// TODO Replace data with sparse data to reduce memory usage
const auto bias_size = bias_shape.num_elements() * ir::sizeOfDataType(bias_typeinfo.type());
std::vector<uint8_t> data_vec(bias_size, 0);
auto data_obj = std::make_shared<ir::CachedData>(data_vec.data(), bias_size);
_graph.setOperandValue(new_bias_index, std::move(data_obj));

auto &bias = _graph.operands().at(new_bias_index);
bias.insertUse(_current_op_index);
assert(bias.isConstant()); // setting its data above made the operand constant

_graph.operations()
.at(_current_op_index)
.replaceInput(ir::operation::FullyConnected::Input::BIAS, new_bias_index);

VERBOSE(BiasInsertionPass) << "Optional bias is inserted for training, new bias index: "
<< new_bias_index << std::endl;
}
}

} // namespace pass
} // namespace train
} // namespace compiler
} // namespace onert
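As a standalone illustration of the shape logic in the pass above (helper name
hypothetical, not part of the commit): the inserted bias takes the last
dimension of the FullyConnected output as its only dimension and is
zero-filled, so the forward result is unchanged while the new operand becomes
a trainable constant.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: for an output of shape [batch, units], build the raw
// buffer of a zero bias of shape [units] (FLOAT32 only, as in the pass above).
std::vector<uint8_t> MakeZeroBiasData(const std::vector<int32_t> &output_shape)
{
  const int32_t units = output_shape.back();
  const size_t byte_size = static_cast<size_t>(units) * sizeof(float);
  return std::vector<uint8_t>(byte_size, 0); // all-zero bytes are 0.0f in IEEE 754
}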
55 changes: 55 additions & 0 deletions runtime/onert/core/src/compiler/train/pass/BiasInsertionPass.h
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__
#define __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__

#include "../../pass/Pass.h"
#include "ir/OperationVisitor.h"

namespace onert
{
namespace compiler
{
namespace train
{
namespace pass
{

class BiasInsertionPass final : public compiler::pass::Pass, public ir::OperationVisitor
{
public:
BiasInsertionPass(ir::Graph &graph) : compiler::pass::Pass{graph} {}

public:
std::string id() final { return "BiasInsertionPass"; }
void run() final;

public:
void visit(const ir::operation::Conv2D &node) override;
void visit(const ir::operation::DepthwiseConv2D &node) override;
void visit(const ir::operation::FullyConnected &node) override;

private:
ir::OperationIndex _current_op_index;
};

} // namespace pass
} // namespace train
} // namespace compiler
} // namespace onert

#endif // __ONERT_COMPILER_TRAIN_PASS_BIAS_INSERTION_PASS_H__
7 changes: 7 additions & 0 deletions runtime/onert/core/src/ir/OperandIndexSequence.cc
@@ -17,6 +17,7 @@
#include "ir/OperandIndexSequence.h"

#include <algorithm>
#include <cassert>
#include <sstream>

namespace onert
@@ -55,6 +56,12 @@ void OperandIndexSequence::replace(const OperandIndex &from, const OperandIndex &to)
std::replace(_vec.begin(), _vec.end(), from, to);
}

void OperandIndexSequence::replace(size_t pos, const OperandIndex &index)
{
assert(pos < _vec.size() && "OperandIndexSequence: Out of range");
_vec.at(pos) = index;
}

bool OperandIndexSequence::operator==(const OperandIndexSequence &other) const
{
return _vec == other._vec;
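The new positional overload matters when the same operand index appears at
more than one input position: the value-based replace(from, to) rewrites every
occurrence, while replace(pos, index) touches exactly one slot. A minimal
sketch of the two semantics over a plain vector (illustrative only):

#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
  std::vector<int> inputs{7, 7, 3}; // stand-ins for OperandIndex values

  std::replace(inputs.begin(), inputs.end(), 7, 9); // value-based: {9, 9, 3}
  assert(inputs[0] == 9 && inputs[1] == 9);

  inputs.at(2) = 5; // positional: only slot 2 changes -> {9, 9, 5}
  assert(inputs[2] == 5);
  return 0;
}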
2 changes: 2 additions & 0 deletions runtime/onert/core/src/ir/Operation.cc
@@ -62,5 +62,7 @@ void Operation::replaceOutputs(const OperandIndex &from, const OperandIndex &to)
_outputs.replace(from, to);
}

void Operation::replaceInput(size_t pos, const OperandIndex &index) { _inputs.replace(pos, index); }

} // namespace ir
} // namespace onert
31 changes: 26 additions & 5 deletions runtime/onert/core/src/ir/train/UseDefGenerator.cc
@@ -211,6 +211,23 @@ void UseDefGenerator::visit(const train::operation::FullyConnected &node)
const auto weights_forwarding_index = TrainingOperandIndex{weights_index, true};
insertUse(weights_forwarding_index, backwarding_op_index);

// // Set def and use for the bias if it is optional, since a bias is required during training
// const auto forwarding_op_index = TrainingOperationIndex{op_index, true};
// const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::Input::BIAS);
// const auto bias_forwarding_index = TrainingOperandIndex{bias_index, true};
// if (!bias_index.valid())
// {
// // If the bias is not optional, the use for the bias has already been inserted before
// insertUse(bias_forwarding_index, forwarding_op_index);

// // The forwarding bias is not used in backwarding
// }
// else
// {
// [[maybe_unused]] const auto &usedef_chain = _training_usedefs.at(bias_forwarding_index);
// assert(usedef_chain.getTrainingUses().count(forwarding_op_index) == 0);
// }

// Insert uses of forwarding output
if (node.param().activation != ir::Activation::NONE)
{
@@ -220,18 +237,19 @@
}
}

// Set def of backwarding inputs
// The uses of backwarding inputs have already been inserted before
const auto outgoing_index = TrainingOperandIndex{in_index, false};
insertBackPropDef(outgoing_index, backwarding_op_index);

const auto weights_gradient_index = TrainingOperandIndex{weights_index, false};
insertDef(weights_gradient_index, backwarding_op_index);

// Set def for the bias gradient even if the bias was an optional input, since a bias is
// required during training.
const auto &bias_index = node.getInputs().at(ir::operation::FullyConnected::Input::BIAS);
if (bias_index.valid())
{
const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
insertDef(bias_gradient_index, backwarding_op_index);
}
assert(bias_index.valid());
const auto bias_gradient_index = TrainingOperandIndex{bias_index, false};
insertDef(bias_gradient_index, backwarding_op_index);
}

void UseDefGenerator::visit(const train::operation::Loss &node)
@@ -427,7 +445,10 @@ void UseDefGenerator::initForForwardingNodes()
assert(_training_usedefs.at(forwarding_operand_index).getTrainingUses().size() == 0);
const auto uses = operand.getUses();
for (const auto &use : uses)
insertUse(forwarding_operand_index, TrainingOperationIndex{use, is_forward});
});
}

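To make the def/use bookkeeping above easier to follow, here is a minimal
sketch of the underlying idea with hypothetical stand-in types (the real
TrainingOperandIndex and use-def chain types are richer): during training every
operand exists twice, once for the forward pass and once for its
back-propagated counterpart, and each version records the training operations
that define and use it. Requiring a valid bias lets the generator
unconditionally register a def for the bias gradient.

#include <map>
#include <set>
#include <utility>

using OpIdx = int;
using OperandIdx = int;
using TrainingOperand = std::pair<OperandIdx, bool>; // {operand, is_forward}

struct UseDef
{
  std::set<OpIdx> defs; // ops that produce this training operand
  std::set<OpIdx> uses; // ops that consume this training operand
};

std::map<TrainingOperand, UseDef> training_usedefs;

void insertDef(TrainingOperand operand, OpIdx op) { training_usedefs[operand].defs.insert(op); }
void insertUse(TrainingOperand operand, OpIdx op) { training_usedefs[operand].uses.insert(op); }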
@@ -74,7 +74,7 @@ TEST_F(GenModelTrain, OneOp_FullyConnected_OptionalBias)
_context->addTrainCase(
uniformTCD<float>({{{1, 3, 2, 1}}}, // inputs
{{{2, 1, 5, 5, 2, 1, 5, 5, 2, 1, 5, 5, 2, 1, 5, 6}}}, // expected
{{14.4375f}, {13.9950f}, {13.5668f}, {13.1523f}, {12.7512f}} // loss
{{14.4375f}, {13.9242f}, {13.4298f}, {12.9537f}, {12.4952f}} // loss
));

_context->setBackends({"train"});