[Draft] Revisit categorical crossentropy
ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani committed Sep 5, 2024
1 parent 67df4c5 commit 93b5d13
Showing 13 changed files with 424 additions and 17 deletions.
151 changes: 151 additions & 0 deletions compute/cker/include/cker/eigen/xent_op.h
@@ -0,0 +1,151 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __NNFW_CKER_EIGEN_XENT_OPS_H__
#define __NNFW_CKER_EIGEN_XENT_OPS_H__

// From tensorflow/core/kernels/xent_op.cc
#define EIGEN_USE_THREADS

#include "unsupported/Eigen/CXX11/Tensor"
#include "cker/operation/Helper/Tensor.h"

// From tensorflow/core/kernels/xent_op.h
namespace nnfw
{
namespace cker
{
namespace xent_ops
{
namespace functor
{

// Functor used by XentOp to do the computations.
template <typename Device, typename T> struct XentFunctor
{
// Computes Cross Entropy loss and backprop.
//
// logits: batch_size, num_classes.
// labels: batch_size, num_classes.
// scratch: temporary tensor, dims: batch_size, 1
// loss: output tensor for the loss, dims: batch_size.
// backprop: output tensor for the backprop, dims: batch_size, num_classes.
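//
// A sketch of the per-row computation (assuming row-major 2-D views, with
// max_i = max_j logits(i, j) subtracted for numerical stability):
//   loss(i)        = sum_j labels(i, j) * (log(sum_k exp(logits(i, k) - max_i)) - (logits(i, j) - max_i))
//   backprop(i, j) = exp(logits(i, j) - max_i) / sum_k exp(logits(i, k) - max_i) - labels(i, j)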
void operator()(const Device &d, const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop);
};

} // namespace functor
} // namespace xent_ops
} // namespace cker
} // namespace nnfw

// From tensorflow/core/kernels/xent_op.cc
namespace nnfw
{
namespace cker
{
namespace xent_ops
{

// Enable CPUDevice only for xent_ops
using CPUDevice = Eigen::ThreadPoolDevice;
using Index = Eigen::Index;

// Partial specialization for a CPUDevice that uses the Eigen-based
// implementation in XentFunctorBase below.
namespace functor
{
template <typename Device, typename T> struct XentFunctorBase
{
void operator()(const Device &d, const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop)
{
T *scratch_ptr = scratch.data();
T *backprop_ptr = backprop.data();

T *loss_ptr = loss.data();

int row_size = shape[1];

if (shape[0] > 0)
{
backprop.device(d) = logits.broadcast(logits_bcast);
scratch.device(d) = labels.broadcast(labels_bcast);
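// The broadcast inputs are staged into the output buffers first: logits into
// backprop and labels into scratch. The worker below reads each row back from
// those buffers and finishes the computation in place: subtract the row max,
// accumulate the exp-sum, then write the per-row loss and the gradient
// (softmax - labels).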
auto reductionWorker = [&](int64_t begin, int64_t end) -> void {
for (int i = begin; i < end; i++)
{
T *this_backprop = backprop_ptr + (i * row_size);
T *this_logits = backprop_ptr + (i * row_size);
T *this_labels = scratch_ptr + (i * row_size);
T max_logits = this_logits[0];

// Find the row maximum (subtracted below for numerical stability)
for (int j = 1; j < row_size; j++)
{
max_logits = std::max(max_logits, this_logits[j]);
}

T sum = T(0);
T loss_sum = T(0);

for (int j = 0; j < row_size; j++)
{
// Note that if the input is reused, this_logits and this_backprop point to
// the same buffer, so after this calculation this_logits should no longer
// be trusted
this_backprop[j] = this_logits[j] - max_logits;
sum = sum + exp(this_backprop[j]);
}

// Loss and gradient calculation
T log_sum = log(sum);
for (int j = 0; j < row_size; j++)
{
loss_sum += this_labels[j] * (log_sum - this_backprop[j]);
this_backprop[j] = (exp(this_backprop[j]) / sum) - this_labels[j];
}
loss_ptr[i] = loss_sum;
}
};
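// Rough per-row cost estimate (bytes loaded, bytes stored, compute cycles)
// that lets Eigen's parallelFor decide how to shard the batch across threads.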
const int64_t compute_cycles = 50 * row_size;
const int64_t input_bytes = sizeof(T) * row_size;
const int64_t output_bytes = sizeof(T) * row_size;
const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);

d.parallelFor(shape[0], cost, reductionWorker);
}
}
};

template <typename T> struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T>
{
};

} // namespace functor
} // namespace xent_ops
} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_EIGEN_XENT_OPS_H__
4 changes: 4 additions & 0 deletions compute/cker/include/cker/operation/Helper/Tensor.h
@@ -157,6 +157,10 @@ struct Tensor
{
return typename TTypes<T>::ConstScalar(base<T>());
}
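
// Convenience views: vec() and matrix() below reinterpret the tensor's buffer
// as rank-1 and rank-2 Eigen tensor maps (via shaped<T, 1> / shaped<T, 2>).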

template <typename T> typename TTypes<T>::Vec vec() { return shaped<T, 1>(); }

template <typename T> typename TTypes<T>::Matrix matrix() { return shaped<T, 2>(); }
}; // Tensor

template <typename DSizes> Eigen::DSizes<Index32, DSizes::count> To32BitDims(const DSizes &in)
130 changes: 126 additions & 4 deletions compute/cker/include/cker/train/operation/Loss.h
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,10 +18,13 @@
#ifndef __NNFW_CKER_TRAIN_OPERATION_LOSS_H__
#define __NNFW_CKER_TRAIN_OPERATION_LOSS_H__

#include <numeric>
// #include <numeric>

#include "cker/Shape.h"
#include "cker/eigen/EigenSupport.h"
#include "cker/eigen/Utils.h"
#include "cker/eigen/xent_op.h"
#include "cker/operation/Helper/BCast.h"

namespace nnfw
{
@@ -74,9 +78,9 @@ inline void MSEGrad(const Shape &y_pred_shape, const T *y_pred_data, const Shape
}

template <typename T>
inline void CategoricalCrossEntropy(const Shape &y_pred_shape, const T *y_pred_data,
const Shape &y_true_shape, const T *y_true_data,
const Shape &output_shape, T *output_data)
void CategoricalCrossEntropy(const Shape &y_pred_shape, const T *y_pred_data,
const Shape &y_true_shape, const T *y_true_data,
const Shape &output_shape, T *output_data)
{
if (output_shape.DimensionsCount() != 1)
throw std::runtime_error("cker::CategoricalCrossEntropy: output dimension count should be 1");
@@ -94,6 +98,124 @@ inline void CategoricalCrossEntropy(const Shape &y_pred_shape, const T *y_pred_d
output = -(y_true.array() * y_pred.array().cwiseMax(log_threshold<T>()).log()).colwise().sum();
}

// TODO Rename
template <typename T>
void CategoricalCrossEntropyWithLogits(const Shape &logits_shape, const T *logits_data,
const Shape &y_true_shape, const T *y_true_data,
const Shape &loss_out_shape, T *loss_out_data,
const Shape &grad_shape, T *grad_data)
{
// TODO Enable sparse shapes
if (loss_out_shape.DimensionsCount() != 1)
throw std::runtime_error(
"cker::CategoricalCrossEntropyWithLogits: loss output dimension count should be 1");
if (logits_shape != y_true_shape)
throw std::runtime_error(
"cker::CategoricalCrossEntropyWithLogits: logits and y_true do not have the same shape");
if (loss_out_shape.Dims(0) != logits_shape.Dims(0))
throw std::runtime_error(
"cker::CategoricalCrossEntropyWithLogits: loss_out and logits do not have the same batch");
if (logits_shape != grad_shape)
throw std::runtime_error(
"cker::CategoricalCrossEntropyWithLogits: logits and grad do not have the same shape");

auto shape_in = logits_shape;

BCast bcast(BCast::FromShape(shape_in), BCast::FromShape(y_true_shape),
/*fewer_dims_optimization=*/false);
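// NOTE Since identical shapes are enforced above, the resulting x_bcast/y_bcast
// arrays are all ones; the BCast machinery is kept to mirror the TensorFlow
// kernel and for the sparse/broadcast TODO above.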
// if (!y_pred_shape.IsSameSize(y_true_shape)) {
// OP_REQUIRES(context, bcast.IsValid(),
// errors::InvalidArgument(
// "logits and labels must be broadcastable: logits_size=",
// logits_in.shape().DebugString(),
// " labels_size=", labels_in.shape().DebugString()));
// shape_in = BCast::ToShape(bcast.output_shape());
// }
// OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in),
// errors::InvalidArgument("logits and labels must be either "
// "2-dimensional, or broadcasted to be "
// "2-dimensional"));

// if (std::is_same<Device, GPUDevice>::value) {
// OP_REQUIRES(context, !OpDeterminismRequired(),
// errors::Unimplemented(
// "The GPU implementation of SoftmaxCrossEntropyWithLogits"
// " that would have been executed is not deterministic."
// " Note that the Python API uses an alternative,"
// " deterministic, GPU-accelerated path when determinism is"
// " enabled."));
// }

// loss is 1-D (one per example), and size is batch_size.
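// Wrap the raw buffers in cker helper Tensors so the functor can view them as
// Eigen maps; scratch gets its own temporary storage, into which the functor
// stages the broadcast labels.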

Tensor logits_in;
Tensor labels_in;
Tensor scratch;
Tensor loss_out;
Tensor back_out;

logits_in.shape.ReplaceWith(shape_in.DimensionsCount(), shape_in.DimsData());
logits_in.buffer = const_cast<T *>(logits_data);

labels_in.shape.ReplaceWith(y_true_shape.DimensionsCount(), y_true_shape.DimsData());
labels_in.buffer = const_cast<T *>(y_true_data);

scratch.shape.ReplaceWith(shape_in.DimensionsCount(), shape_in.DimsData());
std::vector<T> scratch_vec(shape_in.Dims(0) * shape_in.Dims(1), static_cast<T>(0));
scratch.buffer = scratch_vec.data();

Shape shape_loss_out{shape_in.Dims(0)};
loss_out.shape.ReplaceWith(shape_loss_out.DimensionsCount(), shape_loss_out.DimsData());
loss_out.buffer = loss_out_data;

back_out.shape.ReplaceWith(shape_in.DimensionsCount(), shape_in.DimsData());
back_out.buffer = grad_data;

//
// Tensor scratch;
// if (std::is_same<Device, CPUDevice>::value) {
// // OP_REQUIRES_OK(context,
// // context->allocate_temp(DataTypeToEnum<T>::value,
// // TensorShape({shape_in.dim_size(0),
// // shape_in.dim_size(1)}),
// // &scratch));
// scratch.shape.ReplaceWith(shape_in.DimensionsCount(), shape_in.DimsData());
// std::vector<T> scratch_vec(shape_in.Dims(0) * shape_in.Dims(1), static_cast<T>(0));
// scratch.buffer = scratch_vec.data();
// } else {
// // OP_REQUIRES_OK(context,
// // context->allocate_temp(
// // DataTypeToEnum<T>::value,
// // TensorShape({shape_in.dim_size(0), 1}), &scratch));
// Shape shape_sc{shape_in.Dims(0), 1};
// scratch.shape.ReplaceWith(shape_sc.DimensionsCount(), shape_sc.DimsData());
// std::vector<T> scratch_vec(shape_in.Dims(0), static_cast<T>(0));
// scratch.buffer = scratch_vec.data();
// }

// Tensor* loss_out = nullptr;
// OP_REQUIRES_OK(context,
// context->allocate_output(
// 0, TensorShape({shape_in.dim_size(0)}), &loss_out));
// Tensor* back_out = nullptr;
// Try to reuse the logits_in buffer for the backprop output.
// OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
// {0}, 1, shape_in, &back_out));

if (shape_in.Dims(0) > 0)
{
const xent_ops::CPUDevice &device = *eigen_support::GetThreadPoolDevice();
xent_ops::functor::XentFunctor<xent_ops::CPUDevice, T> functor;
const Eigen::DSizes<Eigen::DenseIndex, 2> shape{shape_in.Dims(0), shape_in.Dims(1)};

functor(device, shape, BCast::ToIndexArray<2>(bcast.x_bcast()),
BCast::ToIndexArray<2>(bcast.y_bcast()),
logits_in.template shaped<const T, 2>(bcast.x_reshape()),
labels_in.template shaped<const T, 2>(bcast.y_reshape()), scratch.matrix<T>(),
loss_out.vec<T>(), back_out.matrix<T>());
}
}
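
// Illustrative usage sketch (not part of the API; Shape construction and the
// surrounding namespace are assumed): two examples, three classes, one-hot
// labels, flat row-major buffers.
//
//   const Shape in_shape{2, 3};
//   const Shape loss_shape{2};
//   std::vector<float> logits{1.f, 2.f, 3.f, 0.f, 0.f, 0.f};
//   std::vector<float> labels{0.f, 0.f, 1.f, 1.f, 0.f, 0.f};
//   std::vector<float> loss(2), grad(6);
//   CategoricalCrossEntropyWithLogits(in_shape, logits.data(), in_shape,
//                                     labels.data(), loss_shape, loss.data(),
//                                     in_shape, grad.data());
//   // Expected: loss ~ {0.4076f, 1.0986f}; each grad row is softmax(logits) - labels.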

template <typename T>
inline void CategoricalCrossEntropyGrad(const Shape &y_pred_shape, const T *y_pred_data,
const Shape &y_true_shape, const T *y_true_data,