Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tile #100

Merged
merged 2 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,16 @@ export function validateTriangularParams(input, {diagonal = 0} = {}) {
throw new Error(`The diagonal should be an integer.`);
}
}

export function validateTileParams(input, repetitions) {
const inputRank = input.rank;
const repetitionsLength = repetitions.length;
if (repetitionsLength != inputRank ) {
throw new Error(
`The repetitions length ${repetitionsLength} is not equal to rank ${inputRank}.`);
}
if (!repetitions.every((v) => Number.isInteger(v) && v > 0)) {
throw new Error(
`Invalid repetitions ${repetitions} - it should be an Array of positive integers.`);
Copy link

@fdwr fdwr Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reminds me - I want to try adding 0-dimension support into DirectML.dll so these cases can cleanly work without checks in the Chromium DML backend or the ORT DML EP. (resolve me)

}
}
28 changes: 28 additions & 0 deletions src/tile.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
'use strict';

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateTileParams} from './lib/validate-input.js';

/**
* Represents the tile operation that repeats a tensor the given number of times along each axis.
* @param {Tensor} input
* @param {Array} repetitions
* @return {Tensor}
*/
export function tile(input, repetitions) {
validateTileParams(...arguments);
const outputShape = input.shape.map((size, index) => {
return size * repetitions[index];
});
const output = new Tensor(outputShape);
for (let outputIndex = 0; outputIndex < sizeOfShape(outputShape); ++outputIndex) {
const loc = output.locationFromIndex(outputIndex);
const selectedInputLoc = loc.slice();
for (let i = 0; i < loc.length; ++i) {
selectedInputLoc[i] = loc[i] % input.shape[i];
}
const inputValue = input.getValueByLocation(selectedInputLoc);
output.setValueByIndex(outputIndex, inputValue);
}
return output;
}
52 changes: 52 additions & 0 deletions test/tile_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
'use strict';

import {tile} from '../src/tile.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test tile', function() {
function testTile(input, repetitions, expected) {
const tensor = new Tensor(input.shape, input.data);
const outputTensor = tile(tensor, repetitions);
utils.checkShape(outputTensor, expected.shape);
utils.checkValue(outputTensor, expected.data);
}

it('tile 1D', function() {
const input = {
shape: [4],
data: [
1, 2, 3, 4,
],
};
const repetitions = [2];
const expected = {
shape: [8],
data: [
1, 2, 3, 4, 1, 2, 3, 4,
],
};
testTile(input, repetitions, expected);
});

it('tile 2D', function() {
const input = {
shape: [2, 2],
data: [
1, 2,
3, 4,
],
};
const repetitions = [2, 3];
const expected = {
shape: [4, 6],
data: [
1, 2, 1, 2, 1, 2,
3, 4, 3, 4, 3, 4,
1, 2, 1, 2, 1, 2,
3, 4, 3, 4, 3, 4,
],
};
testTile(input, repetitions, expected);
});
});
Loading