Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement gather #61

Merged
merged 4 commits into from
Dec 16, 2023
Merged

implement gather #61

merged 4 commits into from
Dec 16, 2023

Conversation

BruceDai
Copy link
Contributor

@BruceDai BruceDai commented Dec 8, 2023

src/gather.js Outdated
const loc = output.locationFromIndex(i);
const indicesLoc = loc.slice(axis, axis + indices.rank);
const selectedInputLoc = loc.slice(0, axis)
.concat(indices.getValueByLocation(indicesLoc), loc.slice(axis + indices.rank));
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a FYI, the current implementation in DML will not throw if the values in indices are out of bounds.

The open spec issue for out-of-bounds handling: webmachinelearning/webnn#486

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointer. webnn-baseline implementation aims at "doing" right thing ((non out-of-bounds access)) and generating the expected test data.

There is a separate task that explicitly checks the out-of-bounds access tracking by issue: webmachinelearning/webnn#244 (comment).

Copy link

@fdwr fdwr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation looks right from what I remember (gather always confuses me more than other ops 😅). Let's add a scalar case, and then LGTM. TY.

src/gather.js Show resolved Hide resolved
src/gather.js Outdated
// if (dimCount <= axis) {
// continue;
// } else {
// here index shoud be set to (rankOutput + dimCount - axis - 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please set appropriate indent for this comment.

src/gather.js Outdated Show resolved Hide resolved
src/gather.js Outdated
const shapeOutput = shapeInput.slice(0, axis).concat(indices.shape, shapeInput.slice(axis + 1));
const output = new Tensor(shapeOutput);

for (let i = 0; i < sizeOfShape(shapeOutput); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using another variable name? I suppose the 'i' being used in the following comment is different one.

src/gather.js Outdated Show resolved Hide resolved
src/gather.js Outdated
const loc = output.locationFromIndex(i);
const indicesLoc = loc.slice(axis, axis + indices.rank);
const selectedInputLoc = loc.slice(0, axis)
.concat(indices.getValueByLocation(indicesLoc), loc.slice(axis + indices.rank));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointer. webnn-baseline implementation aims at "doing" right thing ((non out-of-bounds access)) and generating the expected test data.

There is a separate task that explicitly checks the out-of-bounds access tracking by issue: webmachinelearning/webnn#244 (comment).

const axisSize = input.shape[axis];
for (let i = 0; i < sizeOfShape(indices.shape); ++i) {
const index = indices.getValueByIndex(i);
if (!Number.isInteger(index) || index < 0 || index >= axisSize) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current WebNN DirectML backend allows negative index value. @fdwr

Copy link

@fdwr fdwr Dec 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://www.w3.org/TR/webnn/#api-mlgraphbuilder-gather

Appears the spec settled on not supporting negative indices, 8. If index is greater than or equal to axisSize, then throw, and that is tracked via webmachinelearning/webnn#484.

I will look through the models I have for any negative indices to see if this will actually matter in practice...

[update] I need some better tools, as this is way too tedious... 🤔⏳

@BruceDai
Copy link
Contributor Author

@fdwr @huningxin @shiyi9801 I've updated commits to address comments, please take another look, thanks.

Copy link
Contributor

@huningxin huningxin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thanks!

Copy link

@fdwr fdwr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Thanks for the new test cases.

@huningxin huningxin merged commit baaf15d into webmachinelearning:main Dec 16, 2023
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants