Commit

Refactor temperature calculation
Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Dec 9, 2024
1 parent 515b491 commit d66f4d1
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -8,12 +8,15 @@
 class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
     @abstractmethod
-    def distillation_loss_fn(student_logits, teacher_logits, temperature):
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+    ):
         """
         Compute distillation loss.
         Args:
-            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
-            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
         """
         raise NotImplementedError("Distillation loss function must be implemented.")
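Under the new signature, a concrete override receives logits that have already been divided by the temperature, so the loss body only applies the probability transforms. A minimal sketch of such an override, assuming a KL-divergence soft loss (one common choice; the base class does not prescribe it) with a sum reduction so the caller can normalize afterwards:

import torch.nn.functional as F

def kl_distillation_loss_fn(student_logits, teacher_logits):
    # Logits arrive pre-scaled by 1/temperature, per the docstring above.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    # Sum rather than mean: _compute_loss divides by full_target.shape[0].
    return F.kl_div(student_log_probs, teacher_probs, reduction="sum")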

@@ -65,7 +68,6 @@ def _compute_loss(
         distillation_loss_fn=None,
         full_target=None,
         ignore_index=-100,
-        temperature=1.0,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
@@ -105,7 +107,7 @@ def _compute_loss(
 
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
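Each chunk's loss is divided by full_target.shape[0], the length of the full (unchunked) target, rather than by the chunk length, so accumulating the per-chunk terms reproduces one global mean. A quick sanity check with made-up numbers, assuming the per-chunk values are sums over the chunk:

import torch

per_sample = torch.tensor([1.0, 2.0, 3.0, 4.0])  # hypothetical per-sample losses
accumulated = sum(chunk.sum() / per_sample.shape[0]
                  for chunk in torch.chunk(per_sample, chunks=2, dim=0))
assert torch.isclose(accumulated, per_sample.mean())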
@@ -166,7 +168,6 @@ def forward(
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             compute_ce_loss=compute_ce_loss,
-            temperature=temperature,
             **loss_kwargs,
         )

@@ -219,6 +220,10 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)
 
+        # Scale inputs (and hence logits) by 1/temperature; the default of 1.0 applies no scaling.
+        student_input /= temperature
+        teacher_input /= temperature
+
         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
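Scaling the inputs once here, before chunking, is what lets temperature disappear from _compute_loss and distillation_loss_fn: since the projection to logits is linear, dividing the hidden states by the temperature divides the resulting logits by the same factor, assuming a bias-free projection. A quick check of that identity with hypothetical shapes:

import torch

torch.manual_seed(0)
temperature = 2.0
hidden = torch.randn(4, 8)   # (batch_size * seq_len, hidden_size)
weight = torch.randn(8, 16)  # (hidden_size, vocab_size)
# (x / T) @ W == (x @ W) / T when the projection has no bias.
assert torch.allclose((hidden / temperature) @ weight,
                      (hidden @ weight) / temperature,
                      atol=1e-6)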
