Skip to content

Commit

Permalink
[TensorV2] Add tensor gauss error function
Browse files Browse the repository at this point in the history
This commit adds an error function, which computes the gauss error function of the input tensor.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 authored and myungjoo committed Feb 14, 2024
1 parent 1ea1513 commit 351a8c9
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 0 deletions.
6 changes: 6 additions & 0 deletions nntrainer/tensor/float_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,12 @@ TensorV2 &FloatTensor::pow(float exponent, TensorV2 &output) const {
return output;
}

TensorV2 &FloatTensor::erf(TensorV2 &output) const {
auto f = [](float in) { return std::erf(in); };
apply(f, output);
return output;
}

void FloatTensor::print(std::ostream &out) const {
printInstance(out, this);
const float *data = (float *)getData();
Expand Down
5 changes: 5 additions & 0 deletions nntrainer/tensor/float_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,11 @@ class FloatTensor : public TensorBase {
*/
TensorV2 &pow(float exponent, TensorV2 &output) const override;

/**
* @copydoc TensorV2::erf(TensorV2 &output)
*/
TensorV2 &erf(TensorV2 &output) const override;

/**
* @copydoc TensorV2::copy(const TensorV2 &from)
*/
Expand Down
8 changes: 8 additions & 0 deletions nntrainer/tensor/half_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,14 @@ TensorV2 &HalfTensor::pow(float exponent, TensorV2 &output) const {
return output;
}

TensorV2 &HalfTensor::erf(TensorV2 &output) const {
auto f = [](_FP16 in) {
return static_cast<_FP16>(std::erf(static_cast<float>(in)));
};
apply(f, output);
return output;
}

void HalfTensor::print(std::ostream &out) const {
printInstance(out, this);
const _FP16 *data = (_FP16 *)getData();
Expand Down
5 changes: 5 additions & 0 deletions nntrainer/tensor/half_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ class HalfTensor : public TensorBase {
*/
TensorV2 &pow(float exponent, TensorV2 &output) const override;

/**
* @copydoc TensorV2::erf(TensorV2 &output)
*/
TensorV2 &erf(TensorV2 &output) const override;

/**
* @copydoc TensorV2::copy(const TensorV2 &from)
*/
Expand Down
5 changes: 5 additions & 0 deletions nntrainer/tensor/tensor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ class TensorBase {
*/
virtual TensorV2 &pow(float exponent, TensorV2 &output) const = 0;

/**
* @copydoc TensorV2::erf(TensorV2 &output)
*/
virtual TensorV2 &erf(TensorV2 &output) const = 0;

/**
* @copydoc TensorV2::print(std::ostream &out)
*/
Expand Down
15 changes: 15 additions & 0 deletions nntrainer/tensor/tensor_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,21 @@ TensorV2 &TensorV2::pow(float exponent, TensorV2 &output) const {
return output;
}

int TensorV2::erf_i() {
erf(*this);
return ML_ERROR_NONE;
}

TensorV2 TensorV2::erf() const {
TensorV2 output("", getFormat(), getDataType());
return erf(output);
}

TensorV2 &TensorV2::erf(TensorV2 &output) const {
itensor->erf(output);
return output;
}

void TensorV2::print(std::ostream &out) const { itensor->print(out); }

void TensorV2::putData() const { itensor->putData(); }
Expand Down
19 changes: 19 additions & 0 deletions nntrainer/tensor/tensor_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,25 @@ class TensorV2 {
*/
TensorV2 &pow(float exponent, TensorV2 &output) const;

/**
* @brief Gauss error function
* @retval #ML_ERROR_NONE Successful
*/
int erf_i();

/**
* @brief Gauss error function
* @retval Calculated Tensor
*/
TensorV2 erf() const;

/**
* @brief Gauss error function
* @param[out] output out to store the result
* @retval Calculated Tensor
*/
TensorV2 &erf(TensorV2 &output) const;

/**
* @brief Print element
* @param[in] out out stream
Expand Down

0 comments on commit 351a8c9

Please sign in to comment.