diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py
index bffff16fc..b07ade0b5 100644
--- a/test/prototype/test_quantized_training.py
+++ b/test/prototype/test_quantized_training.py
@@ -161,44 +161,31 @@ def test_int8_weight_only_training(self, compile, device):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_int8_mixed_precision_training(self, compile, config):
         _reset()
-        bsize = 4
-        embed_dim = 32
+        bsize = 64
+        embed_dim = 64
         device = "cuda"
 
-        # only use 1 matmul shape to reduce triton autotune time
-        model_ref = nn.Sequential(
-            nn.Linear(embed_dim, embed_dim, bias=False),
-            nn.GELU(),
-            nn.Linear(embed_dim, embed_dim),
-        ).to(device)
-        model_int8mp = copy.deepcopy(model_ref)
-        quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
+        linear = nn.Linear(embed_dim, embed_dim).cuda()
+        linear_int8mp = copy.deepcopy(linear)
+        quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
 
         if compile:
-            model_ref.compile()
-            model_int8mp.compile()
+            linear.compile()
+            linear_int8mp.compile()
 
-        optim_ref = torch.optim.AdamW(model_ref.parameters())
-        optim_int8mp = torch.optim.AdamW(model_int8mp.parameters())
+        inputs = torch.randn(bsize, embed_dim, device=device)
+        grad_outputs = torch.randn(bsize, embed_dim, device=device)
 
-        for i in range(5):
-            inputs = torch.randn(bsize, embed_dim, device=device)
-            labels = torch.randint(embed_dim, size=(bsize,), device=device)
-            loss_ref = F.cross_entropy(model_ref(inputs), labels)
-            loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels)
-
-            rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item())
-            assert rel_error < 3e-3, (i, rel_error)
-
-            loss_ref.backward()
-            optim_ref.step()
-            optim_ref.zero_grad()
-
-            loss_int8mp.backward()
-            for p in model_int8mp.parameters():
-                assert p.grad is not None
-            optim_int8mp.step()
-            optim_int8mp.zero_grad()
+        inputs_ref, outputs_ref = self._forward_and_backward(linear, inputs, grad_outputs)
+        inputs_int8mp, outputs_int8mp = self._forward_and_backward(linear_int8mp, inputs, grad_outputs)
+
+        def snr(ref, actual):
+            error = actual - ref
+            return 20 * torch.log10(ref.norm() / error.norm())
+
+        assert snr(outputs_ref, outputs_int8mp) > 20
+        assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
+        assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20
 
 
 _FSDP_WORLD_SIZE = 2
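
Note: the rewritten test calls a `_forward_and_backward` helper on the test class that is not part of this hunk. A minimal sketch of what it is assumed to do, given how the return values are used (clone the inputs so each model backprops through its own leaf tensor with `.grad` populated, run the forward pass, then backprop a fixed upstream gradient):

    def _forward_and_backward(self, module, inputs, grad_outputs):
        # assumed helper body; not shown in this diff
        # clone so each call gets an independent leaf tensor that accumulates .grad
        inputs = inputs.clone().requires_grad_(True)
        outputs = module(inputs)
        # backprop a fixed upstream gradient instead of computing a loss
        outputs.backward(grad_outputs)
        return inputs, outputs

With this, the SNR checks compare outputs, input gradients, and weight gradients between the unquantized reference and the int8 mixed-precision copy; since snr = 20 * log10(ref.norm() / error.norm()), the > 20 dB threshold corresponds to a relative error below 10%.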