-
Notifications
You must be signed in to change notification settings - Fork 158
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
- Loading branch information
Showing
30 changed files
with
772 additions
and
201 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* | ||
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
// From tensorflow/core/kernels/training_ops.h and | ||
|
||
#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ | ||
#define TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ | ||
|
||
// TODO Enable EIGEN_USE_THREADS only for training_ops | ||
#define EIGEN_USE_THREADS | ||
|
||
#include "unsupported/Eigen/CXX11/Tensor" | ||
#include "cker/operation/Helper/Tensor.h" | ||
|
||
namespace nnfw | ||
{ | ||
namespace cker | ||
{ | ||
namespace training_ops | ||
{ | ||
namespace functor | ||
{ | ||
|
||
// From tensorflow/core/kernels/training_ops.h | ||
|
||
// Each training algorithm has a ApplyXYZ functor struct declared in | ||
// this header file. They are specialized for different devices | ||
// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc). | ||
|
||
template <typename Device, typename T> struct ApplyGradientDescent | ||
{ | ||
void operator()(const Device &d, typename TTypes<T>::Flat var, | ||
typename TTypes<T>::ConstScalar alpha, typename TTypes<T>::ConstFlat delta); | ||
}; | ||
|
||
template <typename Device, typename T> struct ApplyAdam | ||
{ | ||
void operator()(const Device &d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat m, | ||
typename TTypes<T>::Flat v, typename TTypes<T>::ConstScalar beta1_power, | ||
typename TTypes<T>::ConstScalar beta2_power, typename TTypes<T>::ConstScalar lr, | ||
typename TTypes<T>::ConstScalar beta1, typename TTypes<T>::ConstScalar beta2, | ||
typename TTypes<T>::ConstScalar epsilon, typename TTypes<T>::ConstFlat grad, | ||
bool use_nesterov); | ||
}; | ||
|
||
} // namespace functor | ||
} // namespace training_ops | ||
} // namespace cker | ||
} // namespace nnfw | ||
|
||
// From tensorflow/core/kernels/training_ops.cc | ||
|
||
namespace nnfw | ||
{ | ||
namespace cker | ||
{ | ||
namespace training_ops | ||
{ | ||
|
||
// Enable CPUDevice only for training_ops | ||
using CPUDevice = Eigen::ThreadPoolDevice; | ||
using Index = Eigen::Index; | ||
|
||
namespace functor | ||
{ | ||
|
||
template <typename T> struct ApplyGradientDescent<CPUDevice, T> | ||
{ | ||
void operator()(const CPUDevice &d, typename TTypes<T>::Flat var, | ||
typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstFlat grad) | ||
{ | ||
var.device(d) -= grad * lr(); | ||
} | ||
}; | ||
|
||
template <typename Device, typename T> struct ApplyAdamNonCuda | ||
{ | ||
void operator()(const Device &d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat m, | ||
typename TTypes<T>::Flat v, typename TTypes<T>::ConstScalar beta1_power, | ||
typename TTypes<T>::ConstScalar beta2_power, typename TTypes<T>::ConstScalar lr, | ||
typename TTypes<T>::ConstScalar beta1, typename TTypes<T>::ConstScalar beta2, | ||
typename TTypes<T>::ConstScalar epsilon, typename TTypes<T>::ConstFlat grad, | ||
bool use_nesterov) | ||
{ | ||
// Get params length and check if they can be vectorized by packet size. | ||
Index length = var.size(); | ||
Index packet_size = Eigen::internal::packet_traits<T>::size; | ||
if (length % packet_size == 0) | ||
{ | ||
length = length / packet_size; | ||
} | ||
else | ||
{ | ||
packet_size = 1; | ||
} | ||
|
||
T *var_ptr = var.data(); | ||
T *m_ptr = m.data(); | ||
T *v_ptr = v.data(); | ||
const T *g_ptr = grad.data(); | ||
const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) / (T(1) - beta1_power()); | ||
// beta1 == μ | ||
// beta2 == ν | ||
// v == n | ||
// var == θ | ||
|
||
auto shard = [var_ptr, m_ptr, v_ptr, g_ptr, alpha, beta1, beta2, epsilon, use_nesterov, | ||
packet_size](int begin, int end) { | ||
int t_size = (end - begin) * packet_size; | ||
begin = begin * packet_size; | ||
auto var = typename TTypes<T>::UnalignedTensor(var_ptr + begin, t_size); | ||
auto m = typename TTypes<T>::UnalignedTensor(m_ptr + begin, t_size); | ||
auto v = typename TTypes<T>::UnalignedTensor(v_ptr + begin, t_size); | ||
auto g = typename TTypes<T>::UnalignedConstTensor(g_ptr + begin, t_size); | ||
|
||
if (use_nesterov) | ||
{ | ||
m += (g - m) * (T(1) - beta1()); | ||
v += (g.square() - v) * (T(1) - beta2()); | ||
var -= ((g * (T(1) - beta1()) + beta1() * m) * alpha) / (v.sqrt() + epsilon()); | ||
} | ||
else | ||
{ | ||
m += (g - m) * (T(1) - beta1()); | ||
v += (g.square() - v) * (T(1) - beta2()); | ||
var -= (m * alpha) / (v.sqrt() + epsilon()); | ||
} | ||
}; | ||
|
||
// Input data: var, v, m, grad. | ||
// Output data: var, v, m. | ||
const int input_bytes = length * packet_size * sizeof(T) * 4; | ||
const int output_bytes = length * packet_size * sizeof(T) * 3; | ||
const int compute_cycles = | ||
// Consider Sub as Add | ||
(Eigen::TensorOpCost::AddCost<int>() * 5 + Eigen::TensorOpCost::MulCost<int>() * 2 + | ||
Eigen::TensorOpCost::AddCost<T>() * 10 + Eigen::TensorOpCost::MulCost<T>() * 6 + | ||
Eigen::TensorOpCost::DivCost<T>()) * | ||
length; | ||
const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles); | ||
|
||
// Eigen device must update 3 variables with 3 different expressions, | ||
// which is bad for cache locality on CPU. Here use ParallelFor instead of | ||
// "regular" tensor expressions to get better performance. | ||
d.parallelFor(length, cost, shard); | ||
} | ||
}; | ||
|
||
template <typename T> struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> | ||
{ | ||
}; | ||
|
||
} // namespace functor | ||
} // namespace training_ops | ||
} // namespace cker | ||
} // namespace nnfw | ||
|
||
#endif // TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef __NNFW_CKER_TRAIN_OPTIMIZER_ADAM_H__ | ||
|
||
// #include "OptimizerHelpers.h" | ||
#include "cker/eigen/training_ops.h" | ||
#include "cker/eigen/EigenSupport.h" | ||
|
||
#include <vector> | ||
|
||
namespace nnfw | ||
{ | ||
namespace cker | ||
{ | ||
namespace train | ||
{ | ||
|
||
inline void Adam(const Shape &output_shape, float *output_data, const Shape &grad_shape, | ||
const float *grad_data, const Shape &m_shape, float *m_data, | ||
const Shape &v_shape, float *v_data, float beta1_power, float beta2_power, | ||
float learning_rate, float beta1, float beta2, float epsilon, bool use_nesterov) | ||
{ | ||
// const bool sparse = false; | ||
// const bool use_exclusive_lock = false; | ||
// auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( | ||
// ctx, use_exclusive_lock, sparse, {0, 1, 2}); | ||
|
||
Tensor output_tensor; | ||
Tensor grad_tensor; | ||
Tensor m_tensor; | ||
Tensor v_tensor; | ||
Tensor beta1_power_tensor; | ||
Tensor beta2_power_tensor; | ||
Tensor lr_tensor; | ||
Tensor beta1_tensor; | ||
Tensor beta2_tensor; | ||
Tensor epsilon_tensor; | ||
|
||
output_tensor.shape.ReplaceWith(output_shape.DimensionsCount(), output_shape.DimsData()); | ||
output_tensor.buffer = output_data; | ||
|
||
grad_tensor.shape.ReplaceWith(grad_shape.DimensionsCount(), grad_shape.DimsData()); | ||
grad_tensor.buffer = const_cast<float *>(grad_data); | ||
|
||
m_tensor.shape.ReplaceWith(m_shape.DimensionsCount(), m_shape.DimsData()); | ||
m_tensor.buffer = const_cast<float *>(m_data); | ||
|
||
v_tensor.shape.ReplaceWith(v_shape.DimensionsCount(), v_shape.DimsData()); | ||
v_tensor.buffer = const_cast<float *>(v_data); | ||
|
||
std::vector<float> beta1_power_vec{beta1_power}; | ||
beta1_power_tensor.buffer = beta1_power_vec.data(); | ||
|
||
std::vector<float> beta2_power_vec{beta2_power}; | ||
beta2_power_tensor.buffer = beta2_power_vec.data(); | ||
|
||
std::vector<float> lr_vec{learning_rate}; | ||
lr_tensor.buffer = lr_vec.data(); | ||
|
||
std::vector<float> beta1_vec{beta1}; | ||
beta1_tensor.buffer = beta1_vec.data(); | ||
|
||
std::vector<float> beta2_vec{beta2}; | ||
beta2_tensor.buffer = beta2_vec.data(); | ||
|
||
std::vector<float> epsilon_vec{epsilon}; | ||
epsilon_tensor.buffer = epsilon_vec.data(); | ||
|
||
if (output_shape != m_shape) | ||
throw std::runtime_error( | ||
"cker::Adam: output and m do not have the same shape"); | ||
|
||
if (output_shape != v_shape) | ||
throw std::runtime_error( | ||
"cker::Adam: output and v do not have the same shape"); | ||
|
||
if (output_shape != grad_shape) | ||
throw std::runtime_error( | ||
"cker::Adam: output and gradient do not have the same shape"); | ||
|
||
const training_ops::CPUDevice &device = *eigen_support::GetThreadPoolDevice(); | ||
training_ops::functor::ApplyAdam<training_ops::CPUDevice, float>()( | ||
device, output_tensor.flat<float>(), m_tensor.flat<float>(), v_tensor.flat<float>(), | ||
beta1_power_tensor.scalar<float>(), beta2_power_tensor.scalar<float>(), lr_tensor.scalar<float>(), | ||
beta1_tensor.scalar<float>(), beta2_tensor.scalar<float>(), epsilon_tensor.scalar<float>(), | ||
static_cast<const Tensor &>(grad_tensor).flat<float>(), use_nesterov); | ||
|
||
// MaybeForwardRefInputToRefOutput(ctx, 0, 0); | ||
} | ||
|
||
} // namespace train | ||
} // namespace cker | ||
} // namespace nnfw | ||
|
||
#endif // __NNFW_CKER_TRAIN_OPTIMIZER_ADAM_H__ |
Oops, something went wrong.