[onert] Apply ReLUGrad to ElementwiseActivationLayer (#11251)
This commit applies ReLUGrad to ElementwiseActivationLayer.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani authored Aug 17, 2023
Parent: 2aeb743 · Commit: 015e3e9
Showing 3 changed files with 33 additions and 6 deletions.
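
For context: ReLU's backward pass propagates the incoming gradient wherever the activation was positive and zeroes it elsewhere. Since ReLU(0, inf) outputs y > 0 exactly where the input x > 0, the mask can be read off the forward output alone, which is why the kernel in this commit consumes the output tensor. Below is a minimal sketch of what a ReLUGrad-style kernel computes; the name relu_grad and the std::vector interface are illustrative assumptions, while the actual nnfw::cker::train::ReLUGrad operates on Shape and raw buffer arguments, as the diff further down shows.

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative sketch, not the cker API: ReLU backward keeps the incoming
// gradient where the forward output was positive and blocks it elsewhere.
std::vector<float> relu_grad(const std::vector<float> &output,   // forward result y = max(x, 0)
                             const std::vector<float> &incoming) // dL/dy from downstream
{
  assert(output.size() == incoming.size());
  std::vector<float> outgoing(output.size()); // dL/dx, propagated upstream
  for (std::size_t i = 0; i < output.size(); ++i)
    outgoing[i] = output[i] > 0.f ? incoming[i] : 0.f;
  return outgoing;
}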
runtime/onert/backend/cpu/ops/ElementwiseActivationLayer.h (1 addition, 1 deletion)

@@ -54,7 +54,7 @@ class ElementwiseActivationLayer : public ::onert::exec::IFunction

   void EvalUsingLookupTable(const IPortableTensor *input, IPortableTensor *output);

-private:
+protected:
   const IPortableTensor *_input;
   IPortableTensor *_output;
   uint8_t _table[256];
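
The access change from private to protected is what lets the train backend's layer, which derives from the cpu implementation (note the calls to cpu::ops::ElementwiseActivationLayer::configure and ::run in the next file), read _input and _output directly instead of duplicating them.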
runtime/onert/backend/train/ops/ElementwiseActivationLayer.cc (29 additions, 5 deletions)

@@ -16,7 +16,9 @@

 #include "ElementwiseActivationLayer.h"

-#include <ops/OperationUtils.h>
+#include "OperationUtils.h"
+
+#include <cker/train/operation/ReLU.h>

 namespace onert
 {

@@ -50,19 +52,41 @@ void ElementwiseActivationLayer::configure(const IPortableTensor *input, IPortab
   switch (op_type)
   {
     case ElementwiseActivationType::kReLU:
-      cpu::ops::ElementwiseActivationLayer::configure(input, output, alpha, beta,
-                                                      cpu::ops::ElementwiseActivationType::kReLU);
+      if (_input->data_type() == OperandType::FLOAT32)
+      {
+        if (alpha == std::numeric_limits<float>::infinity() && beta == 0.f)
+        {
+          cpu::ops::ElementwiseActivationLayer::configure(
+            input, output, alpha, beta, cpu::ops::ElementwiseActivationType::kReLU);
+
+          _backward_kernel = [](const IPortableTensor *output, const IPortableTensor *incoming,
+                                IPortableTensor *outgoing) {
+            nnfw::cker::train::ReLUGrad(getShape(output), getBuffer<float>(output),
+                                        getShape(incoming), getBuffer<float>(incoming),
+                                        getShape(outgoing), getBuffer<float>(outgoing));
+          };
+        }
+        else
+        {
+          throw std::runtime_error("train ElementwiseActivationLayer: This layer supports "
+                                   "only ReLU(0, inf); other ReLU variants are unsupported");
+        }
+      }
+      else
+      {
+        throw std::runtime_error("train ElementwiseActivationLayer: Unsupported datatype");
+      }
       break;
     default:
-      throw std::runtime_error("ElementwiseActivationLayer: unsupported op type");
+      throw std::runtime_error("train ElementwiseActivationLayer: Activation type not supported yet");
   }
 }

 void ElementwiseActivationLayer::forward(bool) { cpu::ops::ElementwiseActivationLayer::run(); }

 void ElementwiseActivationLayer::backward(uint32_t)
 {
-  // TODO Implement details
+  _backward_kernel(_output, _deriv_output, _deriv_input);
 }

 } // namespace ops
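
Two details are worth noting here. First, the backward kernel reads the forward output rather than the input: for plain ReLU, y > 0 exactly where x > 0, so the output already carries the mask that backpropagation needs. Second, because configure() throws for every unsupported combination (non-FLOAT32 tensors, clipped or leaky ReLU variants, other activation types), _backward_kernel is always populated before backward() invokes it.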
runtime/onert/backend/train/ops/ElementwiseActivationLayer.h (3 additions)

@@ -53,6 +53,9 @@ class ElementwiseActivationLayer : public ::onert::exec::train::ITrainableFuncti
   const IPortableTensor *_deriv_output;

   ElementwiseActivationType _op_type;
+  std::function<void(const IPortableTensor *output, const IPortableTensor *incoming,
+                     IPortableTensor *outgoing)>
+    _backward_kernel;
 };

 } // namespace ops
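
The new member follows a small dispatch pattern: configure() selects and captures the appropriate gradient routine once, so backward() is a single indirect call with no per-step branching on activation type. A self-contained sketch of the idea under assumed names (Tensor and TrainableActivation are stand-ins for onert's tensor and layer machinery, not its API):

#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

using Tensor = std::vector<float>; // stand-in for IPortableTensor

class TrainableActivation
{
public:
  void configure(bool is_plain_relu)
  {
    if (!is_plain_relu)
      throw std::runtime_error("only ReLU(0, inf) is supported");
    // Bind the backward kernel once; backward() stays branch-free.
    _backward_kernel = [](const Tensor &output, const Tensor &incoming, Tensor &outgoing) {
      outgoing.resize(output.size());
      for (std::size_t i = 0; i < output.size(); ++i)
        outgoing[i] = output[i] > 0.f ? incoming[i] : 0.f;
    };
  }

  void backward(const Tensor &output, const Tensor &incoming, Tensor &outgoing)
  {
    _backward_kernel(output, incoming, outgoing); // single indirect call
  }

private:
  std::function<void(const Tensor &, const Tensor &, Tensor &)> _backward_kernel;
};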
