Skip to content

Commit

Permalink
Rename Layout -> TensorImpl (#1028)
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva authored Oct 8, 2024
1 parent 35ea27b commit cc8bf85
Show file tree
Hide file tree
Showing 20 changed files with 223 additions and 223 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
gpu-arch-version: "12.1"
- name: CUDA 2.4
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: 'torch==2.4.0'
torch-spec: 'torch==2.4.1'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CUDA Nightly (Oct 1)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/benchmark_fp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
import torch.nn.functional as F
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType
from torchao.utils import benchmark_torch_function_in_microseconds
from tqdm import tqdm

Expand Down
14 changes: 7 additions & 7 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,18 @@ def test_serialization(self, mode: str):

# Compare weights
if mode == "weight-only":
original_weight = original_layer.weight.layout_tensor.float8_data.to(
torch.float32
)
new_weight = new_layer.weight.layout_tensor.float8_data.to(
original_weight = original_layer.weight.tensor_impl.float8_data.to(
torch.float32
)
new_weight = new_layer.weight.tensor_impl.float8_data.to(torch.float32)
else:
original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
original_weight = original_layer.weight.original_weight_tensor.tensor_impl.float8_data.to(
torch.float32
)
new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to(
torch.float32
new_weight = (
new_layer.weight.original_weight_tensor.tensor_impl.float8_data.to(
torch.float32
)
)

assert torch.allclose(
Expand Down
14 changes: 7 additions & 7 deletions test/dtypes/test_floatx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
run_tests,
)
from torchao.dtypes.floatx import (
FloatxTensorCoreAQTLayout,
FloatxTensorCoreAQTTensorImpl,
FloatxTensorCoreLayoutType,
to_scaled_tc_floatx,
from_scaled_tc_floatx,
Expand All @@ -28,7 +28,7 @@
_Floatx_DTYPES = [(3, 2), (2, 2)]


class TestFloatxTensorCoreAQTLayout(TestCase):
class TestFloatxTensorCoreAQTTensorImpl(TestCase):
@parametrize("device", _DEVICES)
def test_pack_tc_fp6_correctness(self, device):
x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)
Expand Down Expand Up @@ -82,10 +82,10 @@ def test_to_copy_device(self, ebits, mbits):
scale = choose_qparams_affine_floatx(x, ebits, mbits)
x = quantize_affine_floatx(x, scale, ebits, mbits)
layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
assert floatx_layout_tensor.device.type == "cuda"
floatx_layout_tensor = floatx_layout_tensor.cpu()
assert floatx_layout_tensor.device.type == "cpu"
floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda()
assert floatx_tensor_impl.device.type == "cuda"
floatx_tensor_impl = floatx_tensor_impl.cpu()
assert floatx_tensor_impl.device.type == "cpu"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
Expand All @@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias):
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout)
instantiate_parametrized_tests(TestFloatxTensorCoreAQTTensorImpl)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions test/hqq/test_hqq_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from torchao.dtypes.affine_quantized_tensor import (
to_affine_quantized_intx,
ZeroPointDomain,
PlainAQTLayout,
PlainAQTTensorImpl,
PlainLayoutType,
TensorCoreTiledAQTLayout,
TensorCoreTiledAQTTensorImpl,
TensorCoreTiledLayoutType,
MappingType,
)
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ def forward(self, x):
self.assertTrue(torch.equal(ref_q, test))

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'")
@unittest.skipIf(is_fbcode(), "'PlainAQTTensorImpl' object has no attribute 'int_data'")
@torch.no_grad()
def test_save_load_dqtensors(self, device, dtype):
if device == "cpu":
Expand Down
4 changes: 2 additions & 2 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SemiSparseLayoutType,
TensorCoreTiledLayoutType,
Float8LayoutType,
Float8AQTLayout,
Float8AQTTensorImpl,
MarlinSparseLayoutType,
)

Expand All @@ -33,6 +33,6 @@
"SemiSparseLayoutType",
"TensorCoreTiledLayoutType",
"Float8LayoutType",
"Float8AQTLayout",
"Float8AQTTensorImpl",
"MarlinSparseLayoutType",
]
Loading

0 comments on commit cc8bf85

Please sign in to comment.