

Update INT8 mixed-precision training test to be less flaky (#950)
gau-nernst authored Sep 26, 2024
1 parent 637ed13 commit ceec750
Showing 1 changed file with 19 additions and 32 deletions.
test/prototype/test_quantized_training.py (51 changes: 19 additions & 32 deletions)
@@ -161,44 +161,31 @@ def test_int8_weight_only_training(self, compile, device):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_int8_mixed_precision_training(self, compile, config):
         _reset()
-        bsize = 4
-        embed_dim = 32
+        bsize = 64
+        embed_dim = 64
         device = "cuda"
 
-        # only use 1 matmul shape to reduce triton autotune time
-        model_ref = nn.Sequential(
-            nn.Linear(embed_dim, embed_dim, bias=False),
-            nn.GELU(),
-            nn.Linear(embed_dim, embed_dim),
-        ).to(device)
-        model_int8mp = copy.deepcopy(model_ref)
-        quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
+        linear = nn.Linear(embed_dim, embed_dim).cuda()
+        linear_int8mp = copy.deepcopy(linear)
+        quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
 
         if compile:
-            model_ref.compile()
-            model_int8mp.compile()
+            linear.compile()
+            linear_int8mp.compile()
 
-        optim_ref = torch.optim.AdamW(model_ref.parameters())
-        optim_int8mp = torch.optim.AdamW(model_int8mp.parameters())
+        inputs = torch.randn(bsize, embed_dim, device=device)
+        grad_outputs = torch.randn(bsize, embed_dim, device=device)
 
-        for i in range(5):
-            inputs = torch.randn(bsize, embed_dim, device=device)
-            labels = torch.randint(embed_dim, size=(bsize,), device=device)
-            loss_ref = F.cross_entropy(model_ref(inputs), labels)
-            loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels)
-
-            rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item())
-            assert rel_error < 3e-3, (i, rel_error)
-
-            loss_ref.backward()
-            optim_ref.step()
-            optim_ref.zero_grad()
-
-            loss_int8mp.backward()
-            for p in model_int8mp.parameters():
-                assert p.grad is not None
-            optim_int8mp.step()
-            optim_int8mp.zero_grad()
+        inputs_ref, outputs_ref = self._forward_and_backward(linear, inputs, grad_outputs)
+        inputs_int8mp, outputs_int8mp = self._forward_and_backward(linear_int8mp, inputs, grad_outputs)
+
+        def snr(ref, actual):
+            error = actual - ref
+            return 20 * torch.log10(ref.norm() / error.norm())
+
+        assert snr(outputs_ref, outputs_int8mp) > 20
+        assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
+        assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20
 
 
 _FSDP_WORLD_SIZE = 2
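Note on the new test body: it calls a _forward_and_backward helper that lives elsewhere in test_quantized_training.py and is not part of this hunk. Judging only from how its return values are used above (the cloned input's .grad and the layer output are compared across the two models), a minimal sketch of such a helper might look like the following; it is written here as a free function for brevity, and the actual signature in the repository may differ.

    import torch

    def _forward_and_backward(model, inputs, grad_outputs):
        # Clone the shared input so each model gets its own leaf tensor
        # whose .grad can be inspected independently afterwards.
        inputs = inputs.clone().requires_grad_(True)
        outputs = model(inputs)
        # Backpropagate a fixed upstream gradient instead of a loss, so the
        # reference run and the INT8 run receive identical grad_outputs.
        outputs.backward(grad_outputs)
        return inputs, outputs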

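On the choice of pass criterion: snr computes a signal-to-noise ratio in decibels, so snr(...) > 20 requires the error norm to stay below 10% of the reference norm (20 * log10(10) = 20 dB). That is a much looser tolerance than the old 3e-3 relative-error bound on a single loss value, which fits the goal of making the test less flaky. A quick self-contained check of the 20 dB relationship, using the same snr definition as the diff (the tensor shapes are illustrative):

    import torch

    def snr(ref, actual):
        error = actual - ref
        return 20 * torch.log10(ref.norm() / error.norm())

    ref = torch.randn(64, 64)
    noise = torch.randn_like(ref)
    # Scale the perturbation to exactly 10% of the reference norm:
    # the SNR then comes out to exactly 20 dB.
    actual = ref + noise * (0.1 * ref.norm() / noise.norm())
    print(snr(ref, actual))  # ~20.0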
