Refactor Temperature Scaling in Distillation Loss (#444)
## Summary

Addresses part of the issue raised in
#441

Moving the temperature scaling outside `distillation_loss_fn` keeps the `loss_fn` simpler; the scaling is handled in the `forward` function beforehand. Thanks to @Tcc0403 for the advice.
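
The scaling commutes with chunking because softmax(z / T) depends only on the scaled logits z / T, so dividing `student_input` and `teacher_input` by `temperature` once in `forward` is equivalent to passing `temperature` into every per-chunk loss call. A minimal sketch of that equivalence (illustrative tensors only, not the actual Liger Kernel code):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 32000)  # (batch_size * seq_len, vocab_size)
temperature = 2.0

# Scaling inside the loss ...
probs_inside = torch.softmax(logits / temperature, dim=-1)

# ... matches scaling the logits once, up front.
scaled_logits = logits / temperature
probs_outside = torch.softmax(scaled_logits, dim=-1)

assert torch.allclose(probs_inside, probs_outside)
```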


## Testing Done


- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
austin362667 authored Jan 1, 2025
1 parent 42ff02a commit c46d951
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -8,12 +8,15 @@
 
 class LigerFusedLinearDistillationBase(torch.autograd.Function):
     @abstractmethod
-    def distillation_loss_fn(student_logits, teacher_logits, temperature):
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+    ):
         """
         Compute distillation loss.
         Args:
-            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
-            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
         """
         raise NotImplementedError("Distillation loss function must be implemented.")

@@ -65,7 +68,6 @@ def _compute_loss(
         distillation_loss_fn=None,
         full_target=None,
         ignore_index=-100,
-        temperature=1.0,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
@@ -107,7 +109,7 @@ def _compute_loss(
 
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -147,10 +149,11 @@ def forward(
             teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             chunk_size (int): Size of a chunk.
-            compute_ce_loss (bool): Whether to compute CE loss.
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard/task loss.
             weight_soft_loss (float): Weight for soft/distillation loss.
+            compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
@@ -168,7 +171,6 @@ def forward(
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
-           temperature=temperature,
            **loss_kwargs,
        )

@@ -223,6 +225,9 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)
 
+        student_input /= temperature
+        teacher_input /= temperature
+
         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
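For illustration only (not part of this commit): a concrete loss written against the new signature could look like the sketch below. `KLDistillationLoss`, the KL-divergence choice, and the import path are assumptions; the point is that the logits arrive already divided by `temperature` in `forward`, so the loss needs no temperature argument.

```python
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import (
    LigerFusedLinearDistillationBase,
)


class KLDistillationLoss(LigerFusedLinearDistillationBase):  # hypothetical subclass
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits):
        # Both tensors are already temperature-scaled by `forward`.
        return F.kl_div(
            F.log_softmax(student_logits, dim=-1),
            F.softmax(teacher_logits, dim=-1),
            reduction="sum",
        )
```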
