-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce Distillation with a Chunked, Fused Linear JS-divergence Loss #408
Conversation
Signed-off-by: Austin Liu <[email protected]> Add Testing Naive Distillation Base Signed-off-by: Austin Liu <[email protected]> Add Chunked JSD Tests and Benchmarks Signed-off-by: Austin Liu <[email protected]> Fix call Signed-off-by: Austin Liu <[email protected]> Fix Test Usage Signed-off-by: Austin Liu <[email protected]> Remove beta Signed-off-by: Austin Liu <[email protected]> Fix test params Signed-off-by: Austin Liu <[email protected]> Fix call Signed-off-by: Austin Liu <[email protected]> Fix ignore_index Signed-off-by: Austin Liu <[email protected]> Fix weights dimension Signed-off-by: Austin Liu <[email protected]> Fix assign dimension Signed-off-by: Austin Liu <[email protected]> Fix assign dimension Signed-off-by: Austin Liu <[email protected]> Fix teacher bias Signed-off-by: Austin Liu <[email protected]> Reshape input Signed-off-by: Austin Liu <[email protected]> Fix mean Signed-off-by: Austin Liu <[email protected]> Remove alpha Signed-off-by: Austin Liu <[email protected]> Fix t Signed-off-by: Austin Liu <[email protected]> Fix t Signed-off-by: Austin Liu <[email protected]> Fix t scaling Signed-off-by: Austin Liu <[email protected]> Remove teacher tests Signed-off-by: Austin Liu <[email protected]> Fix t scaling Signed-off-by: Austin Liu <[email protected]> Fix beta Signed-off-by: Austin Liu <[email protected]> Fix beta Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]> WIP Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]> Clean up Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]> Format Signed-off-by: Austin Liu <[email protected]> Fix Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]> Fix tol Signed-off-by: Austin Liu <[email protected]> Fix tol Signed-off-by: Austin Liu <[email protected]> Fix tol Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
if valid_mask.any(): | ||
student_average_log_prob[valid_mask] = ( | ||
student_per_token_logps * loss_mask | ||
).sum(-1)[valid_mask] / loss_mask_sum[valid_mask] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not quite understand what loss_mask_sum
and valid_mask
do. Is it just a way to avoid ZeroDivisionError?
loss_fn=None, | ||
chunk_size=1, | ||
ignore_index=-100, | ||
beta=0.5, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we need another variable name for the weight between soft and hard loss, since some loss functions have 'beta' parameter, such as generalized jsd we've implemented in #278.
Since lambda
is a reserved keyword, maybe weight_hard_loss
and weight_soft_loss
?
If sum of both weights is 1, you can just pick one of them and also consider torch.lerp() for combining 2 losses
labels.view(-1), | ||
) | ||
|
||
student_logps = self.get_batch_logps( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to calculate the probability per token for knowledge distillation? I might be wrong but don't we just pass teacher_logits and student_logits directly to divergence loss function, such as kldiv (normally with reduction="batchmean") or jsd?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Tcc0403 Thank you for the review!!
Actually, you're right—I’m aware of that. I was just trying to align the interface with the preference-based design and reuse the value of student_log_probs
calculated during ce_loss
in DistillBase. However, if it's not necessary to maintain the same interface, I prefer your suggestion.
As shown in the distillation calculation function in this PR, it essentially undoes the operations. This redundant computation could be avoided by directly passing the raw logits to the divergence function, instead of first converting them to log probabilities and then reversing them back to the original values.
label_chunk = torch.where(loss_mask, target_chunk, 0) | ||
|
||
student_average_log_prob = torch.zeros_like(loss_mask, dtype=torch.float) | ||
student_per_token_logps = student_log_probs_chunk.gather( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same question as above
If you're referring to current LigerFusedLinearJSD, there're some benchmark in #300. When comparing forward pass only, fljsd kernel is supposed to be slower since it does gradient calculations in forward pass as well, and it isn't purely written in triton so it might also suffer from kernel launching overhead. But it's true that it doesn't perform well in low BT scenario. |
student_logps (torch.Tensor): Avg log probabilities of student inputs. Shape: (batch_size, hidden_size,). | ||
teacher_logps (torch.Tensor): Avg log probabilities of teacher inputs. Shape: (batch_size, hidden_size,). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for the general distillation loss, the student and teacher logps should be per-token instead of being averaged in the sequence length dimension. I.e., both tensors should be of shape (bathc_size, sequence_size, vocab_size) or (flattended_batch_sequence_size, vocab_size).
distillation_loss = distillation_loss_fn( | ||
student_logps, teacher_logps, temperature | ||
) | ||
distillation_loss = distillation_loss / (full_target.shape[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After we made the distillation loss per_token, we may normalize the distillation_loss
with full_target != ignore_index).sum
similar to the ce_loss.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for drafting the distillation base class, left some comments on the fused_linear_distillation.py
, mainly discussing the loss should be computed per token or averaged from the sequence level first.
Anther major question that I am having is on the chunking
dimensions. Current implementation of this PR is just chunking from the batch_size
dimension, which is similar to the implementation of fused_linear_preference.py
. However, I think it would be better if we can chunk from the flattened dim[0] of (B*T, vocab_size), which is also the way of chunking described in the paper for CE_loss.
For preference_base class, I think the chunking only happens on the batch_size dimension because the sequence dimension is reduced when calculating the average logps (link). . But for distillation, we may prefer to follow the patten of CE loss to chunk on the joint dimension of B*T, so that this kernel can work for very long sequence/ context scenario. Happy to help refine this base class @austin362667
cc @shivam15s what do you think on this?
@austin362667 nit: A side note is to split this PR into two stacked PRs: first for the distillation base class and second for the JSDloss based from it. We can prioritize to polish and merge the first PR so that other distillation losses can be based on it and it's non-blocking 😄 |
@hongpeng-guo Thanks for review~
That makes perfect sense to me; I'll proceed with this approach.
Absolutely! I'll split this into two separate PRs. |
Thanks all nice comments! @Tcc0403 and @hongpeng-guo |
## Summary Made #417 from the main repo. Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the s first split from #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment. #### Code Changes 1. Refactor `beta` into two weights: `weight_hard_loss` and `weight_soft_loss`, as coefficients between `hard_loss` and `soft_loss`. @Tcc0403 also pointed out that we could use `torch.lerp` if applicable. 2. Pass `teacher_logits` and `student_logits` directly to the divergence loss function. This avoids redundant computations of converting logits to log probabilities and then reverting them to raw logits. However note that we are not reusing the `student_log_probs` value calculated during `ce_loss` in distillation base. 1. Remove the unnecessary `get_batch_logps` in `test/utils.py`. 3. Modify `chunking` dimensions from `B` to `B * T`. Thanks to @hongpeng-guo's great advice. 1. Fix the loss calculation to use per-token values instead of averaging across the sequence length dimension. 4. Normalize the `distillation_loss` using `(full_target != ignore_index).sum()`. #### TODO 1. [X] Although a slightly slowdown is reasonable, we need to investigate why this PR's implementation is **significantly slower** compared to the naive approach. Thanks to @Tcc0403 's clarification. The issue arises because we are not properly configuring the `chunk_size` for the `B * T` dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks. In contrast, this problem does not occur with the preference loss, as chunking is performed on the `B` dimension. This produces fewer than 10 chunks, which is efficient and works as expected. In conclusion, I set `chunk_size` to `1024` works pretty well in new benchmark results as shown in #425 2. [ ] #417 (comment) #### Knowledge Distillation Knowledge Distillation (KD; [Hinton et al. 2015](https://arxiv.org/abs/1503.02531), [Gou et al. 2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student. In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let `z_t` and `z_s` represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature `T`. When ground truth labels `y` are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth. The combined loss function is defined as: ```math \mathcal{L}_{\text{knowledge distillation}} = \mathcal{w}_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + \mathcal{w}_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}), ``` Here, we directly pass in `logits` rather than `logpbs`. @Tcc0403 #### Shared `DistillationBase` To support various distillation learning objectives, this PR aims to add a `LigerFusedLinearDistillationBase` which is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this. ## Testing Done I'll post JSD tests and benchmarks results in next PR: #425 - Hardware Type: L40S - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu <[email protected]> Co-authored-by: shivam15s <[email protected]>
Summary
Knowledge Distillation
Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.
In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let
z_t
andz_s
represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperatureT
. When ground truth labelsy
are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.The combined loss function is defined as:
Here,
lambda
is a hyperparameter that balances the distillation loss and the supervised objective.Shared
DistillationBase
To support various distillation learning objectives, this PR aims to add a
LigerFusedLinearDistillationBase
which is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this.Jensen-Shannon Divergence Loss
In addition to adding the base class, this PR implements Jensen-Shannon Divergence (JSD) loss as the soft learning objective in the distillation setting. This component can be replaced with other losses (e.g., KL divergence) as
distillation_loss_fn
.JSD is defined as the average of the KL divergences between each distribution and the mean distribution:
Here,
P
andQ
are the two probability distributions, andM
is their average.TODO
Testing Done
Yes.
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence