diff --git a/nntrainer/tensor/float_tensor.cpp b/nntrainer/tensor/float_tensor.cpp index 89f7520ccc..69c382f691 100644 --- a/nntrainer/tensor/float_tensor.cpp +++ b/nntrainer/tensor/float_tensor.cpp @@ -513,6 +513,12 @@ TensorV2 &FloatTensor::pow(float exponent, TensorV2 &output) const { return output; } +TensorV2 &FloatTensor::erf(TensorV2 &output) const { + auto f = [](float in) { return std::erf(in); }; + apply(f, output); + return output; +} + void FloatTensor::print(std::ostream &out) const { printInstance(out, this); const float *data = (float *)getData(); diff --git a/nntrainer/tensor/float_tensor.h b/nntrainer/tensor/float_tensor.h index 4d99fbdaab..b6821525f6 100644 --- a/nntrainer/tensor/float_tensor.h +++ b/nntrainer/tensor/float_tensor.h @@ -256,6 +256,11 @@ class FloatTensor : public TensorBase { */ TensorV2 &pow(float exponent, TensorV2 &output) const override; + /** + * @copydoc TensorV2::erf(TensorV2 &output) + */ + TensorV2 &erf(TensorV2 &output) const override; + /** * @copydoc TensorV2::copy(const TensorV2 &from) */ diff --git a/nntrainer/tensor/half_tensor.cpp b/nntrainer/tensor/half_tensor.cpp index 748bb0f189..d0026daf1d 100644 --- a/nntrainer/tensor/half_tensor.cpp +++ b/nntrainer/tensor/half_tensor.cpp @@ -480,6 +480,14 @@ TensorV2 &HalfTensor::pow(float exponent, TensorV2 &output) const { return output; } +TensorV2 &HalfTensor::erf(TensorV2 &output) const { + auto f = [](_FP16 in) { + return static_cast<_FP16>(std::erf(static_cast(in))); + }; + apply(f, output); + return output; +} + void HalfTensor::print(std::ostream &out) const { printInstance(out, this); const _FP16 *data = (_FP16 *)getData(); diff --git a/nntrainer/tensor/half_tensor.h b/nntrainer/tensor/half_tensor.h index 7e96db11ca..6afbe64e02 100644 --- a/nntrainer/tensor/half_tensor.h +++ b/nntrainer/tensor/half_tensor.h @@ -255,6 +255,11 @@ class HalfTensor : public TensorBase { */ TensorV2 &pow(float exponent, TensorV2 &output) const override; + /** + * @copydoc TensorV2::erf(TensorV2 &output) + */ + TensorV2 &erf(TensorV2 &output) const override; + /** * @copydoc TensorV2::copy(const TensorV2 &from) */ diff --git a/nntrainer/tensor/tensor_base.h b/nntrainer/tensor/tensor_base.h index f905083f10..7974ea47a0 100644 --- a/nntrainer/tensor/tensor_base.h +++ b/nntrainer/tensor/tensor_base.h @@ -244,6 +244,11 @@ class TensorBase { */ virtual TensorV2 &pow(float exponent, TensorV2 &output) const = 0; + /** + * @copydoc TensorV2::erf(TensorV2 &output) + */ + virtual TensorV2 &erf(TensorV2 &output) const = 0; + /** * @copydoc TensorV2::print(std::ostream &out) */ diff --git a/nntrainer/tensor/tensor_v2.cpp b/nntrainer/tensor/tensor_v2.cpp index 3359de03de..1732dc9aca 100644 --- a/nntrainer/tensor/tensor_v2.cpp +++ b/nntrainer/tensor/tensor_v2.cpp @@ -361,6 +361,21 @@ TensorV2 &TensorV2::pow(float exponent, TensorV2 &output) const { return output; } +int TensorV2::erf_i() { + erf(*this); + return ML_ERROR_NONE; +} + +TensorV2 TensorV2::erf() const { + TensorV2 output("", getFormat(), getDataType()); + return erf(output); +} + +TensorV2 &TensorV2::erf(TensorV2 &output) const { + itensor->erf(output); + return output; +} + void TensorV2::print(std::ostream &out) const { itensor->print(out); } void TensorV2::putData() const { itensor->putData(); } diff --git a/nntrainer/tensor/tensor_v2.h b/nntrainer/tensor/tensor_v2.h index 0876632269..663a122ff6 100644 --- a/nntrainer/tensor/tensor_v2.h +++ b/nntrainer/tensor/tensor_v2.h @@ -732,6 +732,25 @@ class TensorV2 { */ TensorV2 &pow(float exponent, TensorV2 &output) const; + /** + * @brief Gauss error function + * @retval #ML_ERROR_NONE Successful + */ + int erf_i(); + + /** + * @brief Gauss error function + * @retval Calculated Tensor + */ + TensorV2 erf() const; + + /** + * @brief Gauss error function + * @param[out] output out to store the result + * @retval Calculated Tensor + */ + TensorV2 &erf(TensorV2 &output) const; + /** * @brief Print element * @param[in] out out stream