Add conversion for StableHLO scatter op to TTIR and TTNN dialect #1279
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
As part of the effort to run stableHLO models on TT silicon, lowering scatter op through TTIR and TTNN dialects.
This op currently has limited functionality, but enough to pass the current models, primarily due to the current limitations in tt-metal:
The op only supports one index, mimicking Torch's
select_scatter
op https://pytorch.org/docs/stable/generated/torch.select_scatter.htmlFurthermore, since we do not currently support any indexing on the tt-metal side (Limitations in TT-Scatter Op tt-metal#4294), it will assume that the index is 0, and run tt-metal accordingly. Since we cannot check for in-memory values, it will currently silently fail (produce wrong results).
The StableHLO scatter op also supports adding a custom function for merging the two tensor operands of the scatter. Since this is not supported in tt-metal, I have added the check to see if the function is just mapping from one tensor to another (the default
select_scatter
behaviour) and assert on that. I have opened a related tt-mlir issue (Support of custom functions for the scatter op in TTNN dialect. #1278), but not a tt-metal one.Relevant TTIR dialect issue: Add ttir.scatter op to the TTIR dialect #1325