From 3827301e06c1902beb536268400fdc676e6b8f4e Mon Sep 17 00:00:00 2001 From: Li Xiaohong <35672632+LydiaXiaohongLi@users.noreply.github.com> Date: Sat, 1 Jul 2023 00:45:43 +0800 Subject: [PATCH] interface update for dot(), trans() for tl 2.01 1. added tl.trans() 2. removed trans_a, trans_b and allow_tf32 for tl.dot(), since removed by triton 2.01 --- src/kernl/debugger/tl_lang.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/kernl/debugger/tl_lang.py b/src/kernl/debugger/tl_lang.py index 3787b5fd..d7774d5d 100644 --- a/src/kernl/debugger/tl_lang.py +++ b/src/kernl/debugger/tl_lang.py @@ -403,13 +403,13 @@ def reshape(self, input, shape): raise NotImplementedError() @_tensor_operation - def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True): + def dot(self, input, other): assert input.dtype == other.dtype - if trans_a: - input = input.T - if trans_b: - other = other.T return torch.matmul(input=input, other=other) + + @_tensor_operation + def trans(self, input): + return input.T @_tensor_operation def atomic_cas(self, pointer, cmp, val):