Adds support for knowledge distillation #380

Open · wants to merge 2 commits into main
Conversation

RobotSail (Member) commented:
There are many different forms of model training which exist. One popular form of training is knowledge distillation, where a student model learns the output distributions from a teacher model.  This commit introduces support for knowledge distillation in the training library.

This commit also exposes the `weight_decay` hyperparameter which is often used to help deep learning models generalize.

Lastly, this commit changes the usage from `torch.distributed` to just `dist`, as it is a common module used throughout the codebase.

Signed-off-by: Oleg S <[email protected]>

temperature: float = Field(1.0, gt=0.0)
alpha: float = Field(1.0, le=1.0, ge=0.0)
teacher_path: str
Contributor:

If possible, I would love to standardize on using `pathlib.Path` rather than `str` paths.

Member Author:

@JamesKunstle I see your point. Would it make sense for it to be a path when it can also take an HF reference? I understand that references can technically still be paths, but to a consumer reading it, it might sound like only local models are accepted. Would `str | Path` be satisfactory?
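To illustrate, a minimal sketch of how that could look, assuming the config class is a pydantic `BaseModel` (the fields mirror the diff above; the `str | Path` annotation needs Python 3.10+ or `typing.Union`):

```python
from pathlib import Path

from pydantic import BaseModel, Field


class DistillationConfig(BaseModel):
    """Settings for knowledge distillation (sketch)."""

    temperature: float = Field(1.0, gt=0.0)
    alpha: float = Field(1.0, le=1.0, ge=0.0)
    # Accepts either a local checkpoint directory or a Hugging Face model reference.
    teacher_path: str | Path
```

With a `str | Path` union, both call styles stay valid, and consumers passing an HF reference are not misled into thinking only local paths are accepted.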

teacher_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16
).to(device)
model_dev = next(teacher_model.parameters()).device
Contributor:

If you're calling `.to(device)` just above, could you make a note of why you also need to confirm the device placement below?

Member Author:

Yes of course.

weight_decay: float = Field(0.0, ge=0.0)

# settings for knowledge distillation
distillation_options: Optional[DistillationConfig] = None
Contributor:

I've seen the `Optional[DistillationConfig]` syntax replaced by `DistillationConfig | None` in recent Python code, since the `X | None` union syntax was added to the language (Python 3.10, PEP 604). This is a nit, not required to change.

Member Author:

Let's go with the proposed syntax to stay consistent with how Python expects optionals to be written going forward.
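For reference, a sketch of the updated annotation (the enclosing class name `TrainingArgs` is hypothetical; `DistillationConfig` is abbreviated from the sketch earlier in this thread; `X | None` annotations need Python 3.10+ or `from __future__ import annotations`):

```python
from __future__ import annotations

from pydantic import BaseModel, Field


class DistillationConfig(BaseModel):
    temperature: float = Field(1.0, gt=0.0)
    alpha: float = Field(1.0, le=1.0, ge=0.0)
    teacher_path: str


class TrainingArgs(BaseModel):  # hypothetical name for the enclosing config class
    weight_decay: float = Field(0.0, ge=0.0)

    # settings for knowledge distillation
    distillation_options: DistillationConfig | None = None
```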

loss = None
if args.distill:
    # teacher_model should always be provided when `args.distill` is enabled
    if TYPE_CHECKING:
Contributor:

I think this is supposed to be a runtime check, but `TYPE_CHECKING` is always `False` at runtime.
https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING

Contributor:

I think we should fail much earlier if distillation is set but no `teacher_model` is provided, like before we do any data preprocessing or fire up the GPUs.

Member Author:

Yeah it is. I believe I ran into type-checking errors here, though, with the checker not knowing that `teacher_model` is properly set.
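One way to get both the runtime guarantee and the type narrowing is an explicit `None` check that raises, which most type checkers understand. A sketch using the names from the snippet above (`args`, `teacher_model`, and `batch` are assumed to be in scope in the training loop; this is not necessarily the exact code in this PR):

```python
loss = None
if args.distill:
    # An explicit raise fails loudly at runtime and also lets mypy/pyright
    # narrow `teacher_model` to a non-None type for the rest of the block.
    if teacher_model is None:
        raise ValueError("teacher model cannot be None when `distill` is enabled")

    with torch.no_grad():
        teacher_output = teacher_model(**batch)
```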

), "teacher model cannot be None when `distill` is enabled"

with torch.no_grad():
    teacher_output: CausalLMOutput = teacher_model(
Contributor:

You turn off `requires_grad` on all the params in the `teacher_model`. You could just be doing this instead; I think it gives the same output.

Member Author:

So they're not fully the same. `requires_grad=False` ensures that a tensor never needs its gradient computed when `.backward()` is called somewhere in the computation graph, and therefore no additional data has to be stored for it. `torch.no_grad()`, on the other hand, ensures that the tensor computations within the given context do not count towards the gradient calculation during backprop.

The reason we're doing both is:

  1. `requires_grad=False` --> The teacher model never gets updated, so we don't need to store any additional gradient state for its parameters.
  2. `with torch.no_grad()` --> If any other tensors happen to participate in the computation for whatever reason (say someone updates this code later and includes them), their gradients are also unaffected by participating in this calculation.

Having the explicit context also communicates to future developers that this code is not intended to participate in backprop, which comes at essentially no extra cost.

You can probably get away without `torch.no_grad` here, but it's good practice to do both.
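As a concrete illustration of the two mechanisms working together (a sketch, not the exact code in this PR; `model_name_or_path`, `device`, and `batch` are assumed from the surrounding setup):

```python
import torch
from transformers import AutoModelForCausalLM

teacher_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16
).to(device)
teacher_model.eval()

# 1. requires_grad=False: the teacher is never updated, so autograd never
#    needs to track or store gradient state for its parameters.
for param in teacher_model.parameters():
    param.requires_grad = False

# 2. torch.no_grad(): nothing computed inside this context is recorded on the
#    autograd graph, regardless of which tensors participate.
with torch.no_grad():
    teacher_output = teacher_model(**batch)
```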

else:
    loss = output.loss

assert loss is not None, "loss cannot be equal to None!"
Contributor:

asserts are typically not preferred in comparison to runtime exceptions.

Suggested change
- assert loss is not None, "loss cannot be equal to None!"
+ if loss is None:
+     raise ValueError("loss was None during distillation training. Something unrecoverable went wrong.")

Contributor:

Mostly because asserts are removed when the interpreter is invoked with `-O`. But we want this non-null check to run all the time.

Member Author:

Sure I can change this. I was using them as scaffolding when writing this.

Member Author:

I'm going to make it not specific to distillation training, though, since it's really about how we branch. I suspect that as we add other loss calculations (contrastive loss, preference-tuning loss, etc.), we will keep starting out by setting it to None and rely on this final check to ensure it was set to something; see the sketch below.
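Roughly this shape, where the final check guards the branching itself rather than distillation specifically (a sketch; `compute_distillation_loss` is a hypothetical helper, not a function from this PR):

```python
loss = None
if args.distill:
    # hypothetical helper combining the student loss and the KL term weighted by alpha
    loss = compute_distillation_loss(output, teacher_output, args)
else:
    loss = output.loss

if loss is None:
    raise RuntimeError("loss was never set; no known training mode matched")
```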

@@ -511,6 +609,9 @@ def main(args):
    # Third Party
    import yaml

    if args.distill and not args.teacher_model_name_or_path:
Contributor:

Yeah this early check seems right.
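For completeness, a sketch of how that early validation might read (the message text is illustrative, not taken from the PR):

```python
if args.distill and not args.teacher_model_name_or_path:
    raise ValueError(
        "`teacher_model_name_or_path` must be provided when `distill` is enabled"
    )
```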

Member Author:

Sweet :party-cat:

@JamesKunstle (Contributor) left a comment:

Nearly ready to go. Just a couple of API questions.
