Skip to content

Commit

Permalink
[Loss] Add finalize to KLD Loss
Browse files Browse the repository at this point in the history
Implement Finalize on KLD Loss

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>
  • Loading branch information
DonghakPark authored and jijoongmoon committed Nov 7, 2024
1 parent bf9b8e8 commit fec6b7d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
21 changes: 21 additions & 0 deletions nntrainer/layers/loss/kld_loss_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ KLDLossLayer::KLDLossLayer() {}

KLDLossLayer::~KLDLossLayer() {}

void KLDLossLayer::finalize(nntrainer::InitLayerContext &context) {
if (context.getNumInputs() != 2) {
throw std::invalid_argument("kld loss requires two input");
}
const auto &input_dims = context.getInputDimensions();

if (input_dims.front() != input_dims.back()) {
throw std::invalid_argument("dimension of mu and log_var is different");
}

auto &input_dim = input_dims.front();

temp_idx = context.requestTensor(input_dim, "temp");
before_sum_idx = context.requestTensor(
input_dim, "before_sum", nntrainer::Initializer::NONE, false,
nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN);

/// output is a scaler-like tensor
context.setOutputDimensions({{input_dim.batch(), 1, 1, 1}});
}

void KLDLossLayer::setProperty(const std::vector<std::string> &values) {
if (values.size()) {
throw std::invalid_argument(
Expand Down
5 changes: 5 additions & 0 deletions nntrainer/layers/loss/kld_loss_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class KLDLossLayer final : public LossLayer {
*/
~KLDLossLayer();

/**
* @copydoc Layer::finalize(InitLayerContext &context)
*/
void finalize(nntrainer::InitLayerContext &context) override;

/**
* @copydoc Layer::setProperty(const std::vector<std::string> &values)
*/
Expand Down

0 comments on commit fec6b7d

Please sign in to comment.