Skip to content

Commit

Permalink
Merge pull request #55 from BruceDai/add_logical
Browse files Browse the repository at this point in the history
Implement element-wise logical operations
  • Loading branch information
huningxin authored Nov 17, 2023
2 parents 4ca1381 + b5c0e9f commit 00e73d1
Show file tree
Hide file tree
Showing 4 changed files with 425 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/binary.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {Tensor, sizeOfShape} from './lib/tensor.js';
* @param {Function} binaryFunc
* @return {Tensor}
*/
function binary(inputA, inputB, binaryFunc) {
export function binary(inputA, inputB, binaryFunc) {
const outputShape = getBroadcastShape(inputA.shape, inputB.shape);
const inputABroadcast = broadcast(inputA, outputShape);
const inputBBroadcast = broadcast(inputB, outputShape);
Expand Down
30 changes: 20 additions & 10 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
'use strict';

import {sizeOfShape} from './tensor.js';

/**
* Check the tensor whether it is a 1-D tensor and its length is equal to `expectedSize`.
* @param {Tensor} a
Expand All @@ -22,7 +24,7 @@ function check1DTensorWithSize(a, expectedSize, name) {
export function validateBatchNormalizationParams(input, mean, variance,
{axis=1, scale, bias} = {}) {
if (!Number.isInteger(axis)) {
throw new Error(`Invalid axis ${axis}, axis should be an integer.`);
throw new Error(`Invalid axis ${axis} - axis should be an integer.`);
}
const dim = input.shape[axis];
check1DTensorWithSize(mean, dim, 'mean');
Expand Down Expand Up @@ -57,10 +59,10 @@ export function validateInstanceNormalizationParams(
export function validateConcatParams(inputs, axis) {
const rank = inputs[0].rank;
if (!Number.isInteger(axis)) {
throw new Error(`Invalid axis ${axis}, axis should be an integer.`);
throw new Error(`Invalid axis ${axis} - axis should be an integer.`);
} else {
if (axis < 0 || axis >= rank) {
throw new Error(`Invalid axis ${axis}, axis should be in the interval [0, ${rank}).`);
throw new Error(`Invalid axis ${axis} - axis should be in the interval [0, ${rank}).`);
}
}
const inputShape = inputs[0].shape;
Expand Down Expand Up @@ -299,18 +301,18 @@ export function validateSliceParams(input, starts, sizes) {
const size = input.shape[i];
const start = starts[i];
if (!Number.isInteger(start) || start < 0 ) {
throw new Error(`Invalid starts value ${start}, it should be an unsigned integer.`);
throw new Error(`Invalid starts value ${start} - it should be an unsigned integer.`);
}
if (start >= size) {
throw new Error(`Invalid starts value ${start}, it shoule be in the interval ` +
throw new Error(`Invalid starts value ${start} - it shoule be in the interval ` +
`[0, ${size}).`);
} else {
const sliceSize = sizes[i];
if (!Number.isInteger(sliceSize) || sliceSize <= 0) {
throw new Error(`Invalid sizes value ${sliceSize}, it should be an unsigned integer.`);
throw new Error(`Invalid sizes value ${sliceSize} - it should be an unsigned integer.`);
}
if (start + sliceSize > size) {
throw new Error(`Invalid sizes value ${sliceSize}, the sum of the start ${start} ` +
throw new Error(`Invalid sizes value ${sliceSize} - the sum of the start ${start} ` +
`plus the size ${sliceSize} is greater than the dimensional size ${size}`);
}
}
Expand All @@ -335,19 +337,19 @@ export function validateSplitParams(input, splits, {axis = 0} = {}) {
}
if (typeof splits === 'number') {
if (!Number.isInteger(splits) || splits <= 0) {
throw new Error(`Invalid splits ${splits}, it should be a positive integer.`);
throw new Error(`Invalid splits ${splits} - it should be a positive integer.`);
}
if (input.shape[axis] % splits !== 0) {
throw new Error(`The splits ${splits} must evenly divide the dimension size ` +
`${input.shape[axis]} of input along options.axis ${axis}.`);
}
} else if (splits instanceof Array) {
if (!splits.every((v) => Number.isInteger(v) && v > 0)) {
throw new Error(`Invalid splits ${splits}, it should be an Array of positive integers.`);
throw new Error(`Invalid splits ${splits} - it should be an Array of positive integers.`);
}
const sum = splits.reduce((a, b) => a + b);
if (sum !== input.shape[axis]) {
throw new Error(`Invalid [${splits}], the sum of sizes ${sum} must equal ` +
throw new Error(`Invalid [${splits}] - the sum of sizes ${sum} must equal ` +
`to the dimension size ${input.shape[axis]} of input` +
` along options.axis ${axis}`);
}
Expand Down Expand Up @@ -380,3 +382,11 @@ export function validateTranposeParams(input, {permutation}) {
}
}

export function validateNotParams(input) {
for (let i = 0; i < sizeOfShape(input.shape); ++i) {
const a = input.getValueByIndex(i);
if (!Number.isInteger(a) || a < 0 || a > 255) {
throw new Error('Invalid input value - it should be an integer in the interval [0, 255]');
}
}
}
31 changes: 31 additions & 0 deletions src/logical.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
'use strict';
import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateNotParams} from './lib/validate-input.js';
import {binary} from './binary.js';

/**
* Compute the element-wise logical not operation of input tensors.
* @param {Tensor} input
* @return {Tensor}
*/
function logicalNot(input) {
validateNotParams(input);
const outputShape = input.shape;
const outputSize = sizeOfShape(outputShape);
const output = new Tensor(outputShape);
for (let i = 0; i < outputSize; ++i) {
const a = input.getValueByIndex(i);
const b = !a ? 1 : 0;
output.setValueByIndex(i, b);
}
return output;
}

export const equal = (inputA, inputB) => binary(inputA, inputB, (a, b) => (a == b ? 1 : 0));
export const greater = (inputA, inputB) => binary(inputA, inputB, (a, b) => (a > b ? 1 : 0));
export const greaterOrEqual =
(inputA, inputB) => binary(inputA, inputB, (a, b) => (a >= b ? 1 : 0));
export const lesser = (inputA, inputB) => binary(inputA, inputB, (a, b) => (a < b ? 1 : 0));
export const lesserOrEqual =
(inputA, inputB) => binary(inputA, inputB, (a, b) => (a <= b ? 1 : 0));
export const not = (input) => logicalNot(input);
Loading

0 comments on commit 00e73d1

Please sign in to comment.