Set z_loss_1d=None when return_z_loss=False in cross_entropy_loss to avoid tl.store failing when TRITON_INTERPRET=1 (for tl.device_print etc.) (#508)

For [issue-507](#507)

## Summary

In the cross_entropy_loss kernel, `tl.store(loss_ptr, loss)` fails when `return_z_loss=False` and `TRITON_INTERPRET=1`, because `loss_1d` was assigned to `z_loss_1d` as a dummy. Setting `z_loss_1d = None` in this situation resolves the failure (a rough sketch of the idea appears after the testing notes below).

## Testing Done

I tested it on my own code and on [this most simplified example](#507 (comment)); both work as expected.

Hardware Type: T4 GPU

- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
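A minimal sketch of the idea, not the actual Liger-Kernel code: the kernel name, signature, and the toy "loss" below are illustrative. The point is that the caller passes `z_loss_1d=None` instead of aliasing it to `loss_1d` when `return_z_loss=False`, and the compile-time `RETURN_Z_LOSS` flag guards the corresponding `tl.store`, so the interpreter (`TRITON_INTERPRET=1`) never sees a store into the dummy buffer.

```python
# Hypothetical simplified example; real names and math differ.
import torch
import triton
import triton.language as tl


@triton.jit
def ce_kernel(loss_ptr, z_loss_ptr, logits_ptr, n_cols,
              RETURN_Z_LOSS: tl.constexpr, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    logits = tl.load(logits_ptr + row * n_cols + offs,
                     mask=mask, other=-float("inf"))
    # Toy per-row "loss": logsumexp, standing in for the real cross entropy.
    m = tl.max(logits, 0)
    loss = m + tl.log(tl.sum(tl.exp(logits - m), 0))
    tl.store(loss_ptr + row, loss)
    if RETURN_Z_LOSS:
        # Only store when the caller actually allocated z_loss_1d;
        # this branch is removed at compile time when RETURN_Z_LOSS is False.
        tl.store(z_loss_ptr + row, loss * loss)


def ce_forward(logits, return_z_loss=False):
    n_rows, n_cols = logits.shape
    loss_1d = torch.empty(n_rows, device=logits.device, dtype=torch.float32)
    # The fix: keep z_loss_1d as None instead of a dummy alias of loss_1d
    # when the z-loss is not requested.
    z_loss_1d = (torch.empty(n_rows, device=logits.device, dtype=torch.float32)
                 if return_z_loss else None)
    ce_kernel[(n_rows,)](loss_1d, z_loss_1d, logits, n_cols,
                         RETURN_Z_LOSS=return_z_loss,
                         BLOCK=triton.next_power_of_2(n_cols))
    return (loss_1d, z_loss_1d) if return_z_loss else loss_1d
```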