Skip to content

Commit

Permalink
Merge pull request #73 from BruceDai/add_gelu
Browse files Browse the repository at this point in the history
Implement gelu
  • Loading branch information
huningxin committed Apr 16, 2024
2 parents d01c2fc + f2f928b commit c419a51
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 14 deletions.
13 changes: 13 additions & 0 deletions src/gelu.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
'use strict';

import {erfKernel, unary} from './unary.js';

/**
* Compute the gaussian error linear unit function (GELU) of the input tensor.
* The calculation follows the expression 0.5 * x * (1 + erf(x / sqrt(2))).
* @param {Tensor} input
* @return {Tensor}
*/
export function gelu(input) {
return unary(input, (x) => 0.5 * x * (1 + erfKernel(x / Math.sqrt(2))));
}
30 changes: 16 additions & 14 deletions src/unary.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,7 @@ export function unary(input, unaryFunc) {
return output;
}

export const abs = (input) => unary(input, Math.abs);
export const ceil = (input) => unary(input, Math.ceil);
export const cos = (input) => unary(input, Math.cos);
export const exp = (input) => unary(input, Math.exp);
export const floor = (input) => unary(input, Math.floor);
export const log = (input) => unary(input, Math.log);
export const neg = (input) => unary(input, (x) => -1 * x);
export const sin = (input) => unary(input, Math.sin);
export const tan = (input) => unary(input, Math.tan);
export const copy = (input) => unary(input, (x) => x);
export const reciprocal = (input) => unary(input, (x) => 1 / x);
export const sqrt = (input) => unary(input, Math.sqrt);
export const erf = (input) => unary(input, (x) => {
export function erfKernel(x) {
// reference 1: https://en.wikipedia.org/wiki/Error_function
// reference 2: https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-cpu/src/kernels/Erf.ts
const a1 = 0.254829592;
Expand All @@ -46,4 +34,18 @@ export const erf = (input) => unary(input, (x) => {
(1.0 -
(((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
Math.exp(-v * v));
});
}

export const abs = (input) => unary(input, Math.abs);
export const ceil = (input) => unary(input, Math.ceil);
export const cos = (input) => unary(input, Math.cos);
export const exp = (input) => unary(input, Math.exp);
export const floor = (input) => unary(input, Math.floor);
export const log = (input) => unary(input, Math.log);
export const neg = (input) => unary(input, (x) => -1 * x);
export const sin = (input) => unary(input, Math.sin);
export const tan = (input) => unary(input, Math.tan);
export const copy = (input) => unary(input, (x) => x);
export const reciprocal = (input) => unary(input, (x) => 1 / x);
export const sqrt = (input) => unary(input, Math.sqrt);
export const erf = (input) => unary(input, erfKernel);
20 changes: 20 additions & 0 deletions test/gelu_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
'use strict';

import {gelu} from '../src/gelu.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test gelu', function() {
function testGelu(inputShape, inputValue, expected) {
const inputTensor = new Tensor(inputShape, inputValue);
const outputTensor = gelu(inputTensor);
utils.checkValue(outputTensor, expected);
}

it('gelu', function() {
// Refer to ONNX gelu_default test:
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#gelu
testGelu([3], [-1, 0, 1], [-0.15865526383236372, 0, 0.8413447361676363]);
testGelu([1, 1, 1, 3], [-1, 0, 1], [-0.15865526383236372, 0, 0.8413447361676363]);
});
});

0 comments on commit c419a51

Please sign in to comment.