From b73574aa46662d8efba182f81425a0c8a1c0f4ec Mon Sep 17 00:00:00 2001 From: BruceDai Date: Tue, 9 Apr 2024 16:48:08 +0800 Subject: [PATCH 1/2] Implement gelu --- src/gelu.js | 13 +++++++++++++ src/unary.js | 30 ++++++++++++++++-------------- test/gelu_test.js | 20 ++++++++++++++++++++ 3 files changed, 49 insertions(+), 14 deletions(-) create mode 100644 src/gelu.js create mode 100644 test/gelu_test.js diff --git a/src/gelu.js b/src/gelu.js new file mode 100644 index 0000000..f5ed343 --- /dev/null +++ b/src/gelu.js @@ -0,0 +1,13 @@ +'use strict'; + +import {erfKernal, unary} from './unary.js'; + +/** + * Compute the gaussian error linear unit function (GELU) of the input tensor. + * The calculation follows the expression 0.5 * x * (1 + erf(x / sqrt(2))). + * @param {Tensor} input + * @return {Tensor} + */ +export function gelu(input) { + return unary(input, (x) => 0.5 * x * (1 + erfKernal(x / Math.sqrt(2)))); +} diff --git a/src/unary.js b/src/unary.js index 54e4172..04e64f1 100644 --- a/src/unary.js +++ b/src/unary.js @@ -18,19 +18,7 @@ export function unary(input, unaryFunc) { return output; } -export const abs = (input) => unary(input, Math.abs); -export const ceil = (input) => unary(input, Math.ceil); -export const cos = (input) => unary(input, Math.cos); -export const exp = (input) => unary(input, Math.exp); -export const floor = (input) => unary(input, Math.floor); -export const log = (input) => unary(input, Math.log); -export const neg = (input) => unary(input, (x) => -1 * x); -export const sin = (input) => unary(input, Math.sin); -export const tan = (input) => unary(input, Math.tan); -export const copy = (input) => unary(input, (x) => x); -export const reciprocal = (input) => unary(input, (x) => 1 / x); -export const sqrt = (input) => unary(input, Math.sqrt); -export const erf = (input) => unary(input, (x) => { +export function erfKernal(x) { // reference 1: https://en.wikipedia.org/wiki/Error_function // reference 2: https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-cpu/src/kernels/Erf.ts const a1 = 0.254829592; @@ -46,4 +34,18 @@ export const erf = (input) => unary(input, (x) => { (1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * Math.exp(-v * v)); -}); +} + +export const abs = (input) => unary(input, Math.abs); +export const ceil = (input) => unary(input, Math.ceil); +export const cos = (input) => unary(input, Math.cos); +export const exp = (input) => unary(input, Math.exp); +export const floor = (input) => unary(input, Math.floor); +export const log = (input) => unary(input, Math.log); +export const neg = (input) => unary(input, (x) => -1 * x); +export const sin = (input) => unary(input, Math.sin); +export const tan = (input) => unary(input, Math.tan); +export const copy = (input) => unary(input, (x) => x); +export const reciprocal = (input) => unary(input, (x) => 1 / x); +export const sqrt = (input) => unary(input, Math.sqrt); +export const erf = (input) => unary(input, erfKernal); diff --git a/test/gelu_test.js b/test/gelu_test.js new file mode 100644 index 0000000..5a755a5 --- /dev/null +++ b/test/gelu_test.js @@ -0,0 +1,20 @@ +'use strict'; + +import {gelu} from '../src/gelu.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test gelu', function() { + function testGelu(inputShape, inputValue, expected) { + const inputTensor = new Tensor(inputShape, inputValue); + const outputTensor = gelu(inputTensor); + utils.checkValue(outputTensor, expected); + } + + it('gelu', function() { + // Refer to ONNX gelu_default test: + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#gelu + testGelu([3], [-1, 0, 1], [-0.15865526383236372, 0, 0.8413447361676363]); + testGelu([1, 1, 1, 3], [-1, 0, 1], [-0.15865526383236372, 0, 0.8413447361676363]); + }); +}); From f2f928beaf6da3b6ecee2f73c176f17cbc39d2df Mon Sep 17 00:00:00 2001 From: BruceDai Date: Thu, 11 Apr 2024 09:43:02 +0800 Subject: [PATCH 2/2] Fixed typos --- src/gelu.js | 4 ++-- src/unary.js | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gelu.js b/src/gelu.js index f5ed343..db5be87 100644 --- a/src/gelu.js +++ b/src/gelu.js @@ -1,6 +1,6 @@ 'use strict'; -import {erfKernal, unary} from './unary.js'; +import {erfKernel, unary} from './unary.js'; /** * Compute the gaussian error linear unit function (GELU) of the input tensor. @@ -9,5 +9,5 @@ import {erfKernal, unary} from './unary.js'; * @return {Tensor} */ export function gelu(input) { - return unary(input, (x) => 0.5 * x * (1 + erfKernal(x / Math.sqrt(2)))); + return unary(input, (x) => 0.5 * x * (1 + erfKernel(x / Math.sqrt(2)))); } diff --git a/src/unary.js b/src/unary.js index 04e64f1..67fa418 100644 --- a/src/unary.js +++ b/src/unary.js @@ -18,7 +18,7 @@ export function unary(input, unaryFunc) { return output; } -export function erfKernal(x) { +export function erfKernel(x) { // reference 1: https://en.wikipedia.org/wiki/Error_function // reference 2: https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-cpu/src/kernels/Erf.ts const a1 = 0.254829592; @@ -48,4 +48,4 @@ export const tan = (input) => unary(input, Math.tan); export const copy = (input) => unary(input, (x) => x); export const reciprocal = (input) => unary(input, (x) => 1 / x); export const sqrt = (input) => unary(input, Math.sqrt); -export const erf = (input) => unary(input, erfKernal); +export const erf = (input) => unary(input, erfKernel);