[acti_func] implement quick gelu
 - Implemented quick gelu function.
   Please note that quickGeluPrime, which computes the derivative of the quickGelu function, is not yet implemented.

Signed-off-by: hyeonseok <[email protected]>
lhs8928 committed Jun 7, 2024
1 parent ea26fdf commit b8dc545
Showing 2 changed files with 37 additions and 2 deletions.
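
For context, quick GELU is a sigmoid-based approximation of GELU; the constant 1.702 used in the code below is the standard choice (popularized by models such as CLIP):

    QuickGELU(x) = x * sigmoid(1.702 * x)

Its analytic derivative, shown here only for reference since this commit leaves quickGeluPrime unimplemented, is:

    QuickGELU'(x) = sigmoid(1.702 * x) + 1.702 * x * sigmoid(1.702 * x) * (1 - sigmoid(1.702 * x))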
32 changes: 32 additions & 0 deletions nntrainer/layers/acti_func.h
@@ -78,6 +78,10 @@ class ActiFunc {
       in_place = false;
       this->setActivation<Tensor>(gelu<T>, geluPrime<T>);
       break;
+    case ActivationType::ACT_QUICK_GELU:
+      in_place = false;
+      this->setActivation<Tensor>(quickGelu<T>, quickGeluPrime<T>);
+      break;
     case ActivationType::ACT_ELU:
       this->setActivation<T>(elu<T>, eluPrime<T>);
       break;
@@ -457,6 +461,34 @@ class ActiFunc {
     return outgoing_derivative;
   }

+  /**
+   * @brief quick gelu activation function (gelu approximation)
+   * @param[in] t_in input tensor
+   * @param[in] t_out output tensor
+   */
+  template <typename T = float>
+  static Tensor &quickGelu(Tensor const &t_in, Tensor &t_out) {
+    t_in.apply<T>(
+      [&](T x) { return static_cast<T>(x * (sigmoid<T>(static_cast<T>(1.702 * x)))); }, t_out);
+    return t_out;
+  }
+
+  /**
+   * @brief derivative of quick gelu function
+   * @param[in] t_in input tensor
+   * @param[in] t_out output tensor
+   * @param[in] outgoing_derivative outgoing derivative
+   * @param[in] incoming_derivative incoming derivative
+   */
+  template <typename T = float>
+  static Tensor &quickGeluPrime(Tensor const &t_in, Tensor const &t_out,
+                                Tensor &outgoing_derivative,
+                                Tensor const &incoming_derivative = Tensor()) {
+    // NYI
+    ml_logw("quickGeluPrime, which computes the derivative of quickGelu, is not yet implemented");
+    return outgoing_derivative;
+  }
+
   /**
    * @brief elu function
    * @note alpha parameter is needed for elu, but supporting property on
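
As a quick numerical sanity check outside of nntrainer's Tensor API, here is a minimal self-contained C++ sketch of the same math on scalars. Everything below is illustrative and not part of the commit: quick_gelu mirrors the quickGelu formula, quick_gelu_prime is one hypothetical way the NYI quickGeluPrime could be completed (the analytic derivative of x * sigmoid(1.702 * x)), and gelu_exact is the exact GELU for comparison.

// Standalone illustration; none of these names exist in nntrainer.
#include <cmath>
#include <cstdio>

static float sigmoid_f(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Same formula as quickGelu in the diff above: x * sigmoid(1.702 * x).
static float quick_gelu(float x) { return x * sigmoid_f(1.702f * x); }

// Hypothetical derivative for the NYI quickGeluPrime:
// d/dx [x * s(ax)] = s(ax) + a * x * s(ax) * (1 - s(ax)), with a = 1.702.
static float quick_gelu_prime(float x) {
  float s = sigmoid_f(1.702f * x);
  return s + 1.702f * x * s * (1.0f - s);
}

// Exact GELU, 0.5 * x * (1 + erf(x / sqrt(2))), for comparison.
static float gelu_exact(float x) {
  return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));
}

int main() {
  for (int i = -3; i <= 3; ++i) {
    float x = static_cast<float>(i);
    std::printf("x=%+d  quick=%+.4f  exact=%+.4f  quick'=%+.4f\n", i,
                quick_gelu(x), gelu_exact(x), quick_gelu_prime(x));
  }
  return 0;
}

Compiling this with any C++11 compiler and comparing the columns shows how closely the approximation tracks exact GELU across the range.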
7 changes: 5 additions & 2 deletions nntrainer/layers/common_properties.h
@@ -35,6 +35,7 @@ enum class ActivationType {
   ACT_RELU,       /**< ReLU */
   ACT_SWISH,      /**< Swish */
   ACT_GELU,       /**< GELU */
+  ACT_QUICK_GELU, /**< Quick GELU */
   ACT_SOFTMAX,    /**< softmax */
   ACT_SOFTPLUS,   /**< softplus */
   ACT_LEAKY_RELU, /**< Leaky ReLU */
@@ -865,11 +866,13 @@ struct ActivationTypeInfo {
   static constexpr std::initializer_list<Enum> EnumList = {
     Enum::ACT_TANH, Enum::ACT_SIGMOID, Enum::ACT_RELU,
     Enum::ACT_SOFTMAX, Enum::ACT_LEAKY_RELU, Enum::ACT_SWISH,
-    Enum::ACT_GELU, Enum::ACT_NONE, Enum::ACT_UNKNOWN};
+    Enum::ACT_GELU, Enum::ACT_QUICK_GELU, Enum::ACT_NONE,
+    Enum::ACT_UNKNOWN};

   static constexpr const char *EnumStr[] = {"tanh", "sigmoid", "relu",
                                             "softmax", "leaky_relu", "swish",
-                                            "gelu", "none", "unknown"};
+                                            "gelu", "quick_gelu", "none",
+                                            "unknown"};
 };

 /**
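
For completeness, EnumList and EnumStr above form a parallel table: the enum at index i corresponds to the string at index i, which is how a property value like "quick_gelu" resolves to ACT_QUICK_GELU. Here is a minimal sketch of that kind of positional lookup; it is illustrative only, since nntrainer's real parsing goes through its property machinery, and fromString below is a hypothetical name.

// Illustrative only; mirrors the EnumList / EnumStr layout in the diff above.
#include <cstddef>
#include <cstring>
#include <stdexcept>

enum class ActivationType { ACT_GELU, ACT_QUICK_GELU, ACT_NONE, ACT_UNKNOWN };

// Parallel tables: entry i of kEnumStr names entry i of kEnumList.
static constexpr ActivationType kEnumList[] = {
  ActivationType::ACT_GELU, ActivationType::ACT_QUICK_GELU,
  ActivationType::ACT_NONE, ActivationType::ACT_UNKNOWN};
static constexpr const char *kEnumStr[] = {"gelu", "quick_gelu", "none",
                                           "unknown"};

// Hypothetical lookup: the index whose string matches selects the enum.
ActivationType fromString(const char *name) {
  for (std::size_t i = 0; i < sizeof(kEnumStr) / sizeof(kEnumStr[0]); ++i) {
    if (std::strcmp(kEnumStr[i], name) == 0)
      return kEnumList[i];
  }
  throw std::invalid_argument("unknown activation name");
}

With the table extended by this commit, fromString("quick_gelu") would yield ActivationType::ACT_QUICK_GELU.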
