From 99d3a50698502a32aacb8371bb7bb228e69b77e8 Mon Sep 17 00:00:00 2001 From: dan Date: Tue, 28 Jan 2025 15:09:20 -0800 Subject: [PATCH] fix export --- sharktank/sharktank/layers/linear.py | 2 ++ sharktank/sharktank/ops/qlinear_impls.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index a1f1366ab..951c1701f 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -73,6 +73,8 @@ def forward(self, x): x = q_input.quantize(x) if self.fake_quant: x = x.unpack().dequant() + else: + x = x.unpack().qs elif qdq_input is not None: x = qdq_input.quantize(x).unpack().dequant() diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index b66d3be1d..f88684273 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -52,9 +52,7 @@ def qlinear_tensor_scaled( if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: if x_layout.qs.dtype == torch.float8_e4m3fnuz: # assume quark - return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to( - torch.float16 - ) + return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True) else: return NotImplemented