Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

completed lstm_cell #70

Merged
merged 2 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,73 @@ export function validateGemmParams(a, b) {
}
}

/**
 * Validate the parameters of an lstmCell operation.
 * @param {Tensor} input - 2-D input tensor of shape [batchSize, inputSize].
 * @param {Tensor} weight - 2-D tensor of shape [4 * hiddenSize, inputSize].
 * @param {Tensor} recurrentWeight - 2-D tensor of shape [4 * hiddenSize, hiddenSize].
 * @param {Tensor} hiddenState - 2-D tensor of shape [batchSize, hiddenSize].
 * @param {Tensor} cellState - 2-D tensor of shape [batchSize, hiddenSize].
 * @param {Number} hiddenSize - A positive integer, the number of hidden units.
 * @param {Object} [options] - Optional bias, recurrentBias, peepholeWeight
 *     and layout ('iofg' or 'ifgo').
 * @throws {Error} If any parameter is inconsistent with the expected shapes.
 */
export function validateLstmCellParams(input, weight, recurrentWeight,
    hiddenState, cellState, hiddenSize,
    {bias, recurrentBias, peepholeWeight, layout = 'iofg'} = {}) {
  if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
    throw new Error(`The hiddenSize ${hiddenSize} is invalid.`);
  }
  if (input.rank !== 2) {
    throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`);
  }
  const batchSize = input.shape[0];
  const inputSize = input.shape[1];
  if (weight.rank !== 2) {
    throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`);
  }
  if (weight.shape[0] !== 4 * hiddenSize || weight.shape[1] !== inputSize) {
    throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`);
  }
  if (recurrentWeight.rank !== 2) {
    throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`);
  }
  if (recurrentWeight.shape[0] !== 4 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) {
    throw new Error(`The shape of recurrentWeight ` +
        `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`);
  }
  if (hiddenState.rank !== 2) {
    throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`);
  }
  if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) {
    // Concatenate the two fragments so the message stays on a single line
    // (a multi-line template literal would embed a newline and indentation).
    throw new Error(`The shape of hiddenState ` +
        `[${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`);
  }
  if (cellState.rank !== 2) {
    throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`);
  }
  if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) {
    throw new Error(`The shape of cellState ` +
        `[${cellState.shape[0]}, ${cellState.shape[1]}] is invalid.`);
  }
  if (bias) {
    if (bias.rank !== 1) {
      throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`);
    }
    if (bias.shape[0] !== 4 * hiddenSize) {
      throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`);
    }
  }
  if (recurrentBias) {
    // Report recurrentBias's own rank (previously interpolated bias.rank,
    // which was wrong and threw a TypeError when bias was undefined).
    if (recurrentBias.rank !== 1) {
      throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 1-D tensor.`);
    }
    if (recurrentBias.shape[0] !== 4 * hiddenSize) {
      throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`);
    }
  }
  if (peepholeWeight) {
    // Same copy-paste fix as above: use peepholeWeight.rank, not bias.rank.
    if (peepholeWeight.rank !== 1) {
      throw new Error(`The peepholeWeight (rank ${peepholeWeight.rank}) is not a 1-D tensor.`);
    }
    // Peephole weights cover only the i, f and o gates, hence 3 * hiddenSize.
    if (peepholeWeight.shape[0] !== 3 * hiddenSize) {
      throw new Error(`The shape of peepholeWeight [${peepholeWeight.shape[0]}] is invalid.`);
    }
  }
  if (layout !== 'iofg' && layout !== 'ifgo') {
    throw new Error(`The layout ${layout} is invalid.`);
  }
}

export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize,
{bias, recurrentBias, layout = 'zrn'} = {}) {
if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) {
Expand Down
143 changes: 143 additions & 0 deletions src/lstm_cell.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
'use strict';

import {add, mul} from './binary.js';
import {matmul} from './matmul.js';
import {Scalar} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {tanh} from './tanh.js';
import {transpose} from './transpose.js';
import {validateLstmCellParams} from './lib/validate-input.js';

/**
 * Computes a single time step of a Long Short-Term Memory (LSTM) cell.
 * @param {Tensor} input - 2-D tensor of shape [batchSize, inputSize].
 * @param {Tensor} weight - 2-D tensor of shape [4 * hiddenSize, inputSize].
 * @param {Tensor} recurrentWeight - 2-D tensor of shape [4 * hiddenSize, hiddenSize].
 * @param {Tensor} hiddenState - 2-D tensor of shape [batchSize, hiddenSize].
 * @param {Tensor} cellState - 2-D tensor of shape [batchSize, hiddenSize].
 * @param {Number} hiddenSize - The number of hidden units.
 * @param {MLLstmCellOptions} [options] - Optional bias, recurrentBias,
 *     peepholeWeight, gate layout ('iofg' or 'ifgo') and the three gate
 *     activation functions.
 * @return {Array.<Tensor>} [newHiddenState, newCellState].
 */
export function lstmCell(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
    {bias, recurrentBias, peepholeWeight,
      layout = 'iofg', activations = [sigmoid, tanh, tanh]} = {}) {
  // Pass the validated arguments explicitly instead of `...arguments` so the
  // call stays correct if either signature evolves.
  validateLstmCellParams(input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
      {bias, recurrentBias, peepholeWeight, layout});
  const zero = new Scalar(0);
  const inputSize = input.shape[1];
  // Row offset of each gate's slice inside the packed weight/bias tensors.
  const starts = layout === 'iofg' ?
      {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 * hiddenSize} :
      {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};
  const [activation0, activation1, activation2] = activations;

  // Shared pre-activation of a gate starting at row `start`:
  // bias + recurrentBias + input x W^T + hiddenState x R^T.
  const gatePreActivation = (start) => add(
      add(
          (bias ? slice(bias, [start], [hiddenSize]) : zero),
          (recurrentBias ? slice(recurrentBias, [start], [hiddenSize]) : zero),
      ),
      add(
          matmul(
              input,
              transpose(slice(weight, [start, 0], [hiddenSize, inputSize])),
          ),
          matmul(
              hiddenState,
              transpose(slice(recurrentWeight, [start, 0], [hiddenSize, hiddenSize])),
          ),
      ),
  );
  // Adds the peephole contribution cellState * p to a gate's pre-activation
  // (the i, f and o gates only; g has no peephole connection).
  const withPeephole = (start, preActivation) => add(
      mul(
          cellState,
          (peepholeWeight ? slice(peepholeWeight, [start], [hiddenSize]) : zero),
      ),
      preActivation,
  );

  // input gate (i)
  const i = activation0(withPeephole(starts.i, gatePreActivation(starts.i)));
  // forget gate (f)
  const f = activation0(withPeephole(starts.f, gatePreActivation(starts.f)));
  // cell gate (g)
  const g = activation1(gatePreActivation(starts.g));
  // output gate (o)
  const o = activation0(withPeephole(starts.o, gatePreActivation(starts.o)));

  // New cell state: forget part of the old state, add the gated candidate.
  const cellStateNew = add(mul(f, cellState), mul(i, g));

  // New hidden state: output gate applied to the activated new cell state.
  const hiddenStateNew = mul(o, activation2(cellStateNew));

  return [hiddenStateNew, cellStateNew];
}
57 changes: 57 additions & 0 deletions test/lstm_cell_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
'use strict';

import {lstmCell} from '../src/lstm_cell.js';
import {relu} from '../src/relu.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test lstmCell', function() {
  // Use it() rather than it.only(): .only would silently skip every other
  // test in the whole suite.
  it('lstmCell activations=[relu, relu, relu]', function() {
    const batchSize = 2;
    const inputSize = 2;
    const hiddenSize = 2;
    const input = new Tensor([batchSize, inputSize], [1, 2, 2, 1]);
    const weight = new Tensor([4 * hiddenSize, inputSize],
        new Float32Array([
          1, -1, 2, -2, 1, -1, 2, -2,
          1, -1, 2, -2, 1, -1, 2, -2,
        ]));
    const recurrentWeight = new Tensor([4 * hiddenSize, hiddenSize],
        new Float32Array(4 * hiddenSize * hiddenSize).fill(0.1));
    // Both states start at zero.
    const hiddenState = new Tensor([batchSize, hiddenSize],
        new Float32Array(batchSize * hiddenSize).fill(0));
    const cellState = new Tensor([batchSize, hiddenSize],
        new Float32Array(batchSize * hiddenSize).fill(0));
    const bias = new Tensor([4 * hiddenSize],
        new Float32Array([1, 2, 1, 2, 1, 2, 1, 2]));
    const recurrentBias = new Tensor([4 * hiddenSize],
        new Float32Array([1, 2, 1, 2, 1, 2, 1, 2]));
    // Zero peephole weights: peephole terms contribute nothing.
    const peepholeWeight = new Tensor([3 * hiddenSize],
        new Float32Array(3 * hiddenSize).fill(0));
    const activations = [relu, relu, relu];
    const outputs = lstmCell(
        input, weight, recurrentWeight, hiddenState, cellState, hiddenSize,
        {bias, recurrentBias, peepholeWeight, activations});
    // outputs = [newHiddenState, newCellState], each [batchSize, hiddenSize].
    utils.checkShape(outputs[0], [batchSize, hiddenSize]);
    utils.checkShape(outputs[1], [batchSize, hiddenSize]);
    const expected = [
      [1, 8, 27, 216],
      [1, 4, 9, 36],
    ];
    for (let i = 0; i < expected.length; ++i) {
      utils.checkValue(outputs[i], expected[i]);
    }
  });
});
Loading