completed lstm #71

Merged
merged 13 commits into from
Apr 16, 2024

Conversation

@mei1127 mei1127 commented Mar 1, 2024

@BruceDai BruceDai left a comment

LGTM, thanks @mei1127.

src/lstm.js Outdated
import {sizeOfShape, Tensor} from './lib/tensor.js';
import {sigmoid} from './sigmoid.js';
import {slice} from './slice.js';
import {squeeze} from './squeeze.js';

Now that the squeeze op has been removed, would you please also remove it from this WebNN Baseline?
You can refer to the squeeze method given in the spec, thanks.

function squeeze(builder, op) {
  return builder.reshape(op, op.shape().remove(0));
}
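
The helper quoted above is spec-style pseudocode (`remove(0)` is not a built-in JavaScript Array method). A minimal sketch of the same idea on a plain shape array might look like the following; `squeezeLeadingDimension` is a hypothetical name for illustration, not a function from this PR or from the spec.

```javascript
// Hypothetical helper: works on a plain shape array, not on an MLOperand.
// Drops the leading dimension, which is expected to have size 1.
function squeezeLeadingDimension(shape) {
  if (shape.length === 0 || shape[0] !== 1) {
    throw new Error(`Cannot squeeze the leading dimension of shape [${shape}].`);
  }
  // slice(1) returns a copy without the first entry; the input is not mutated.
  return shape.slice(1);
}

const squeezedShape = squeezeLeadingDimension([1, 2, 3]);
```

The resulting shape could then be passed to a reshape call, as in the quoted helper.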

The definition is:

function squeeze(input, axes) {
  // Default to every axis when none are given.
  if (!axes || !axes.length)
    axes = input.shape().map((_, i) => i);
  const shape = Array.from(input.shape());
  // Visit axes in descending numeric order so a splice never shifts a later axis.
  for (const axis of [...axes].sort((a, b) => b - a))
    if (axis < shape.length && shape[axis] === 1)
      shape.splice(axis, 1);
  return builder.reshape(input, shape);
}

Contributor Author


OK, I will revise it next week:)

src/lstm.js Outdated
* @param {MLLstmOptions} options
* @return {Array.<Tensor>}
*/


Please delete this blank line.

src/lstm.js Outdated
export function lstm(input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, returnSequence = false, direction = 'forward', layout = 'iofg',
activations = [sigmoid, tanh, tanh]}={}) {

Suggested change
- activations = [sigmoid, tanh, tanh]}={}) {
+ activations = [sigmoid, tanh, tanh]} = {}) {

throw new Error(`The cellState (rank ${cellState.rank}) is not a 2-D tensor.`);
}
if (cellState.shape[0] !== batchSize || cellState.shape[1] !== hiddenSize) {
throw new Error(`The shape of cellState

@fdwr fdwr Mar 1, 2024

The error message will be split :( :

The shape of cellState
  [2, 3] is invalid.
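
One way to avoid the split, sketched here under the assumption that the message is built with a template literal that wraps across source lines: concatenate two single-line template literals instead. The `cellState` object below is a hypothetical stand-in, not the Tensor class from this PR.

```javascript
// Hypothetical stand-in for the tensor being validated.
const cellState = {shape: [2, 3]};

// A template literal that wraps across source lines keeps the newline
// (and any indentation) as part of the message.
const splitMessage = `The shape of cellState
    [${cellState.shape}] is invalid.`;

// Concatenating single-line literals keeps the source lines short
// and the message on one line.
const joinedMessage = `The shape of cellState ` +
    `[${cellState.shape}] is invalid.`;
```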

src/lstm_cell.js Outdated
Comment on lines 32 to 33
const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} :
{i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};

It could be easier for future readers to visually parse if the indentation were aligned:

Suggested change
- const starts = layout === 'iofg' ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 *hiddenSize} :
- {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};
+ const starts = (layout === 'iofg') ? {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 * hiddenSize} :
+                                      {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};
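
As a rough illustration of what these offsets are for, the sketch below uses them to pick one gate's slice out of a concatenated per-gate array. `gateStarts`, the sample `bias` values, and the use of a plain array are assumptions made for the example; the PR itself operates on tensors.

```javascript
// Hypothetical helper wrapping the expression under review:
// returns the start offset of each gate for the given weight layout.
function gateStarts(layout, hiddenSize) {
  return (layout === 'iofg') ?
      {i: 0, o: hiddenSize, f: 2 * hiddenSize, g: 3 * hiddenSize} :
      {i: 0, f: hiddenSize, g: 2 * hiddenSize, o: 3 * hiddenSize};
}

const hiddenSize = 2;
// Made-up bias values, laid out as four gates of hiddenSize entries each.
const bias = [10, 11, 20, 21, 30, 31, 40, 41];

const starts = gateStarts('iofg', hiddenSize);
// In the 'iofg' layout the forget gate is the third block.
const forgetBias = bias.slice(starts.f, starts.f + hiddenSize);
```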

);

// forget gate (f)
const f = activation0(

It's cool how easy it is to compose these from existing operators :).
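
To make the composition concrete, here is a scalar sketch assuming the usual LSTM forget-gate formula (an activation applied to the sum of the input, recurrent, bias, and peephole terms). All parameter names are hypothetical, and the real code works on tensors with matmul and slice rather than on numbers.

```javascript
// Scalar stand-in for the sigmoid activation.
const sigmoidScalar = (x) => 1 / (1 + Math.exp(-x));

// Hypothetical scalar forget gate:
//   f = activation(w * x + r * h + b + rb + p * c)
// where x is the input, h the hidden state, c the cell state,
// w / r the input and recurrent weights, b / rb the biases,
// and p the peephole weight.
function forgetGateScalar({x, h, c, w, r, b = 0, rb = 0, p = 0},
    activation = sigmoidScalar) {
  return activation(w * x + r * h + b + rb + p * c);
}

const gateValue = forgetGateScalar({x: 0, h: 0, c: 0, w: 1, r: 1});
```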

@fdwr fdwr left a comment

Looks good after some small comments. This is a complex operator! Thanks for adding it. 🙂

@fdwr fdwr left a comment

Just minor comments. Functionally looks correct. TY Mei.

src/reshape.js Outdated
*/
export function squeeze(input, {axes} = {}) {
validateSqueezeParams(...arguments);
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);

Suggested change
- const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
+ const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);

Whole word identifiers are more readable for others later (and it's consistent with input.rank rather than inp.rank).

src/reshape.js Outdated
*/
export function squeeze(input, {axes} = {}) {
validateSqueezeParams(...arguments);
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);

Suggested change
- const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
+ const inputAxes = axes ?? Array.from({length: input.rank}, (_, i) => i);

Hmm, kinda confusing way of initializing a sequence by chaining multiple methods together (took me a bit to figure out), rather than just saying Array.from({length: input.rank}, (_, i) => i), which is longer character-wise, but clearer intention-wise (and probably more performant than a fill that is then overwritten by a map). This is one surprising case where C++ is shorter :b std::ranges::iota(inputAxes, 0).
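
A small standalone check, using a made-up `rank` value, suggests the two spellings produce the same index sequence. Note that `Array.from` is a static method and is called without `new`.

```javascript
const rank = 4; // made-up rank for illustration

// Chained form: allocate, fill with placeholders, then overwrite with indices.
const viaFillMap = new Array(rank).fill(0).map((_, i) => i);

// Direct form: build the index sequence in one call.
const viaArrayFrom = Array.from({length: rank}, (_, i) => i);

// Array.from is not a constructor, so prefixing it with `new` throws.
let newArrayFromThrows = false;
try {
  new Array.from({length: rank}, (_, i) => i);
} catch (e) {
  newArrayFromThrows = e instanceof TypeError;
}
```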

@huningxin huningxin merged commit d01c2fc into webmachinelearning:main Apr 16, 2024
3 checks passed
4 participants