From abfceb9acbc17794b729ae7ef99c8745abfa8bc0 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Wed, 28 Feb 2024 15:49:55 +0800 Subject: [PATCH 1/2] complete lstm_cell --- src/lib/validate-input.js | 67 ++++++++++++++++++ src/lstm_cell.js | 143 ++++++++++++++++++++++++++++++++++++++ test/lstm_cell_test.js | 57 +++++++++++++++ 3 files changed, 267 insertions(+) create mode 100644 src/lstm_cell.js create mode 100644 test/lstm_cell_test.js diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index ca1e92c..448f525 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -133,6 +133,73 @@ export function validateGemmParams(a, b) { } } +export function validateLstmCellParams(input, weight, recurrentWeight, + hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, layout = 'iofg'} = {}) { + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 2) { + throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`); + } + const batchSize = input.shape[0]; + const inputSize = input.shape[1]; + if (weight.rank !== 2) { + throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`); + } + if (weight.shape[0] !== 4 * hiddenSize || weight.shape[1] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`); + } + if (recurrentWeight.rank !== 2) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`); + } + if (recurrentWeight.shape[0] !== 4 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`); + } + if (hiddenState.rank !== 2) { + throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`); + } + if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) { + throw new Error(`The shape of hiddenState + [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`); + } + if (cellState.rank !== 2) { + throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`); + } + if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) { + throw new Error(`The shape of cellState + [${cellState.shape[0]}, ${cellState.shape[1]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 1) { + throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (bias.shape[0] !== 4 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`); + } + } + if (recurrentBias) { + if (recurrentBias.rank !== 1) { + throw new Error(`The recurrentBias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (recurrentBias.shape[0] !== 4 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`); + } + } + if (peepholeWeight) { + if (peepholeWeight.rank !== 1) { + throw new Error(`The peepholeWeight (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (peepholeWeight.shape[0] !== 3 * hiddenSize) { + throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}] is invalid.`); + } + } + if (layout !== 'iofg' && layout !== 'ifgo') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize, {bias, recurrentBias, layout = 'zrn'} = {}) { if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { diff --git a/src/lstm_cell.js b/src/lstm_cell.js new file mode 100644 index 0000000..bfafd46 --- /dev/null +++ b/src/lstm_cell.js @@ -0,0 +1,143 @@ +'use strict'; + +import {add, mul} from './binary.js'; +import {matmul} from './matmul.js'; +import {Scalar} from './lib/tensor.js'; +import {sigmoid} from './sigmoid.js'; +import {slice} from './slice.js'; +import {tanh} from './tanh.js'; +import {transpose} from './transpose.js'; +import {validateLstmCellParams} from './lib/validate-input.js'; + +/** + * + * @param {Tensor} input + * @param {Tensor} weight + * @param {Tensor} recurrentWeight + * @param {Tensor} hiddenState + * @param {Tensor} cellState + * @param {Number} hiddenSize + * @param {MLLstmCellOptions} options + * return {Tensor} + */ + +export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, + layout = 'iofg', activations = [sigmoid, tanh, tanh]}={}) { + validateLstmCellParams(...arguments); + const zero = new Scalar(0); + const inputSize = input.shape[1]; + const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2* hiddenSize, g: 3*hiddenSize} : + {i: 0, f: hiddenSize, g: 2* hiddenSize, o: 3*hiddenSize}; + const activation0 = activations[0]; + const activation1 = activations[1]; + const activation2 = activations[2]; + + // input gate (i) + const i = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.i], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.i], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.i], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.i, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.i, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ), + ); + + // forget gate (f) + const f = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.f], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.f], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.f], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.f, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose( + slice(recurrentWeight, [starts.f, 0], [hiddenSize, hiddenSize]), + ), + ), + ), + ), + ), + ); + + // cell gate (g) + const g = activation1( + add( + add( + (bias ? slice(bias, [starts.g], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.g], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.g, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.g, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ); + + // output gate (o) + const o = activation0( + add( + mul( + cellState, + (peepholeWeight ? slice(peepholeWeight, [starts.o], [hiddenSize]) : zero), + ), + add( + add( + (bias ? slice(bias, [starts.o], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.o], [hiddenSize]) : zero), + ), + add( + matmul( + input, + transpose(slice(weight, [starts.o, 0], [hiddenSize, inputSize])), + ), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.o, 0], [hiddenSize, hiddenSize])), + ), + ), + ), + ), + ); + + // output cell state (ct) + const ct = add(mul(f, cellState), mul(i, g)); + + // output hidden state (ht) + const ht = mul(o, activation2(ct)); + + return [ht, ct]; +} diff --git a/test/lstm_cell_test.js b/test/lstm_cell_test.js new file mode 100644 index 0000000..f84cde0 --- /dev/null +++ b/test/lstm_cell_test.js @@ -0,0 +1,57 @@ +'use strict'; + +import {lstmCell} from '../src/lstm_cell.js'; +import {relu} from '../src/relu.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test lstmCell', function() { + it.only('lstmCell lstmCell activations=[relu, relu, relu]', function() { + const batchSize = 2; + const inputSize = 2; + const hiddenSize = 2; + const input = new Tensor([batchSize, inputSize], [1, 2, 2, 1]); + const weight = new Tensor([4 * hiddenSize, inputSize], + new Float32Array([ + 1, -1, 2, -2, 1, -1, 2, -2, + 1, -1, 2, -2, 1, -1, 2, -2, + ])); + const recurrentWeight = new Tensor([4 * hiddenSize, hiddenSize], + new Float32Array(4 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const cellState = new Tensor([batchSize, hiddenSize], + new Float32Array(batchSize * hiddenSize).fill(0)); + const bias = new Tensor([4* hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const recurrentBias = new Tensor([4* hiddenSize], + new Float32Array([ + 1, 2, 1, 2, 1, 2, 1, 2, + ])); + const peepholeWeight = new Tensor([3* hiddenSize], + new Float32Array(3 * hiddenSize).fill(0)); + const activations = [ + relu, + relu, + relu, + ]; + const outputs = lstmCell( + input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, + {bias, recurrentBias, peepholeWeight, activations}); + utils.checkShape(outputs[0], [batchSize, hiddenSize]); + utils.checkShape(outputs[1], [batchSize, hiddenSize]); + const expected = [ + [ + 1, 8, 27, 216, + ], + [ + 1, 4, 9, 36, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); +}); From 9a7e0a486fe07165dbb91d9f4702734817a89d28 Mon Sep 17 00:00:00 2001 From: mei1127 Date: Wed, 28 Feb 2024 16:17:05 +0800 Subject: [PATCH 2/2] revised lstm_cell.js and lstm_cell_test.js --- src/lstm_cell.js | 12 +++++++----- test/lstm_cell_test.js | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/lstm_cell.js b/src/lstm_cell.js index bfafd46..3d07d11 100644 --- a/src/lstm_cell.js +++ b/src/lstm_cell.js @@ -10,7 +10,10 @@ import {transpose} from './transpose.js'; import {validateLstmCellParams} from './lib/validate-input.js'; /** - * + *A single time step of the Long Short-Term Memory [LSTM] recurrent network + *using a cell state, an input, output, and forget gate to compute the cell + *state and the hidden state of the next time step that rolls into the output + *across the temporal sequence of the network. * @param {Tensor} input * @param {Tensor} weight * @param {Tensor} recurrentWeight @@ -18,17 +21,16 @@ import {validateLstmCellParams} from './lib/validate-input.js'; * @param {Tensor} cellState * @param {Number} hiddenSize * @param {MLLstmCellOptions} options - * return {Tensor} + * @return {Tensor} */ - export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize, {bias, recurrentBias, peepholeWeight, layout = 'iofg', activations = [sigmoid, tanh, tanh]}={}) { validateLstmCellParams(...arguments); const zero = new Scalar(0); const inputSize = input.shape[1]; - const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2* hiddenSize, g: 3*hiddenSize} : - {i: 0, f: hiddenSize, g: 2* hiddenSize, o: 3*hiddenSize}; + const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} : + {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize}; const activation0 = activations[0]; const activation1 = activations[1]; const activation2 = activations[2]; diff --git a/test/lstm_cell_test.js b/test/lstm_cell_test.js index f84cde0..b57ffa8 100644 --- a/test/lstm_cell_test.js +++ b/test/lstm_cell_test.js @@ -6,7 +6,7 @@ import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test lstmCell', function() { - it.only('lstmCell lstmCell activations=[relu, relu, relu]', function() { + it('lstmCell lstmCell activations=[relu, relu, relu]', function() { const batchSize = 2; const inputSize = 2; const hiddenSize = 2;