Skip to content

Commit

Permalink
add cumulativeSum
Browse files Browse the repository at this point in the history
  • Loading branch information
mei1127 committed Sep 25, 2024
1 parent 9eeadad commit 353be4c
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/cumulativeSum.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
'use strict';

Check failure on line 1 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'

Check failure on line 2 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'
import {Tensor, sizeOfShape} from './lib/tensor.js';

Check failure on line 3 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'
import {validateCumulativeSumParams} from './lib/validate-input.js';

Check failure on line 4 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'

Check failure on line 5 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'
/**

Check failure on line 6 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'
* Computes the cumulative sum of the input tensor along the specified axis.

Check failure on line 7 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'
* @param {Tensor} input

Check failure on line 8 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'
* @param {number} axis

Check failure on line 9 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'
* @param {MLCumulativeSumOptions} options

Check failure on line 10 in src/cumulativeSum.js

View workflow job for this annotation

GitHub Actions / job (macos-latest)

Expected linebreaks to be 'LF' but found 'CRLF'
* @return {Tensor}
*/
export function cumulativeSum(input, axis, {exclusive = 0, reverse = 0} = {}) {
validateCumulativeSumParams(...arguments);
const inputShape = input.shape;
const outputShape = [...inputShape];
const output = new Tensor(outputShape);
const numElementsAlongAxis = inputShape[axis];

const totalElements = sizeOfShape(outputShape);

for (let outputIndex = 0; outputIndex < totalElements; outputIndex++) {
const loc = output.locationFromIndex(outputIndex);
let cumulativeSumValue = 0;

const start = reverse ? numElementsAlongAxis - 1 : 0;
const step = reverse ? -1 : 1;
const end = reverse ? -1 : numElementsAlongAxis;

for (let i = start; reverse ? i > end : i < end; i += step) {
const inputLoc = [...loc];
inputLoc[axis] = exclusive ? (reverse ? i + 1 : i - 1) : i;

if (!exclusive || (exclusive && inputLoc[axis] >= 0 &&
inputLoc[axis] < numElementsAlongAxis)) {
cumulativeSumValue += input.getValueByLocation(inputLoc);
}

const outputLoc = [...loc];
outputLoc[axis] = i;
output.setValueByLocation(outputLoc, cumulativeSumValue);
}
}

return output;
}
12 changes: 12 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,18 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) {
}
}

export function validateCumulativeSumParams(input, axis) {
if (axis !== undefined) {
const rank = input.rank;
if (!Number.isInteger(axis) || axis < -rank || axis >= rank) {
throw new Error(`The axis ${axis} should be in the range [-rank(input), rank(input)-1].`);
}
if (axis >= rank) {
throw new Error(`The axis ${axis} should be in the interval [0, ${rank}).`);
}
}
}

export function validateTriangularParams(input, {diagonal = 0} = {}) {
const inputRank = input.rank;
if (inputRank < 2) {
Expand Down
143 changes: 143 additions & 0 deletions test/cumulativeSum_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
'use strict';

import {cumulativeSum} from '../src/cumulativeSum.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test cumulativeSum', function() {
function testCumulativeSum(input, axis, options={}, expected) {
const tensor = new Tensor(input.shape, input.data);
const outputTensor = cumulativeSum(tensor, axis, options);
console.log('outputTensor', outputTensor);
utils.checkShape(outputTensor, expected.shape);
utils.checkValue(outputTensor, expected.data);
}

it('test cumulativeSum 1d', function() {
const input = {
shape: [5],
data: [
1, 2, 3, 4, 5,
],
};
const axis=0;
const options = {exclusive: 0, reverse: 0};
const expected = {
shape: [5],
data: [
1, 3, 6, 10, 15,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 1d exclusive', function() {
const input = {
shape: [5],
data: [
1, 2, 3, 4, 5,
],
};
const axis=0;
const options = {exclusive: 1, reverse: 0};
const expected = {
shape: [5],
data: [
0, 1, 3, 6, 10,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 1d reverse', function() {
const input = {
shape: [5],
data: [
1, 2, 3, 4, 5,
],
};
const axis=0;
const options = {exclusive: 0, reverse: 1};
const expected = {
shape: [5],
data: [
15, 14, 12, 9, 5,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 1d reverse exclusive', function() {
const input = {
shape: [5],
data: [
1, 2, 3, 4, 5,
],
};
const axis=0;
const options = {exclusive: 1, reverse: 1};
const expected = {
shape: [5],
data: [
14, 12, 9, 5, 0,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 2d', function() {
const input = {
shape: [2, 3],
data: [
1, 2, 3, 4, 5, 6,
],
};
const axis=0;
const options = {exclusive: 0, reverse: 0};
const expected = {
shape: [2, 3],
data: [
1, 2, 3, 5, 7, 9,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 2d axis=1', function() {
const input = {
shape: [2, 3],
data: [
1, 2, 3, 4, 5, 6,
],
};
const axis=1;
const options = {exclusive: 0, reverse: 0};
const expected = {
shape: [2, 3],
data: [
1, 3, 6, 4, 9, 15,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 2d negtive axis', function() {
const input = {
shape: [2, 3],
data: [
1, 2, 3, 4, 5, 6,
],
};
const axis=1;
const options = {exclusive: 0, reverse: 0};
const expected = {
shape: [2, 3],
data: [
1, 3, 6, 4, 9, 15,
],
};
testCumulativeSum(input, axis, options, expected);
});
});


0 comments on commit 353be4c

Please sign in to comment.