From d66f4d1d1f8ca049991cb4075e3d654853d9411a Mon Sep 17 00:00:00 2001
From: Austin Liu
Date: Mon, 9 Dec 2024 19:17:05 +0800
Subject: [PATCH] Refactor temperature calculation

Signed-off-by: Austin Liu
---
 .../chunked_loss/fused_linear_distillation.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
index 11ae767f6..23adf94ba 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -8,12 +8,15 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
     @abstractmethod
-    def distillation_loss_fn(student_logits, teacher_logits, temperature):
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+    ):
         """
         Compute distillation loss.
         Args:
-            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
-            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            student_logits (torch.Tensor): Temperature-scaled logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Temperature-scaled logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
         """
         raise NotImplementedError("Distillation loss function must be implemented.")
 
@@ -65,7 +68,6 @@ def _compute_loss(
         distillation_loss_fn=None,
         full_target=None,
         ignore_index=-100,
-        temperature=1.0,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
@@ -105,7 +107,7 @@ def _compute_loss(
 
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -166,7 +168,6 @@ def forward(
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             compute_ce_loss=compute_ce_loss,
-            temperature=temperature,
             **loss_kwargs,
         )
 
@@ -219,6 +220,10 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)
 
+        # Pre-scale logits by temperature once; the default temperature of 1.0 is a no-op.
+        student_input /= temperature
+        teacher_input /= temperature
+
         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
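
Reviewer note: with this change, temperature scaling happens once on the full
student/teacher input tensors before chunking, so implementations of
distillation_loss_fn no longer receive a temperature argument and must assume
their inputs are already scaled. A minimal sketch of a loss function conforming
to the new signature, assuming a plain KL-divergence soft loss; the function
name and the choice of KL here are illustrative, not Liger's actual
implementation:

    import torch.nn.functional as F

    def kl_div_distillation_loss(student_logits, teacher_logits):
        # Logits arrive already divided by `temperature` (done once in
        # forward), so no scaling happens here.
        student_log_probs = F.log_softmax(student_logits, dim=-1)  # (B*T, V)
        teacher_probs = F.softmax(teacher_logits, dim=-1)          # (B*T, V)
        # Sum over batch and vocab; _compute_loss divides the result by
        # full_target.shape[0] afterwards.
        return F.kl_div(student_log_probs, teacher_probs, reduction="sum")

Scaling once up front also avoids re-dividing every chunk inside the
accumulation loop and keeps the loss-function interface minimal.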