
Complex Loss Functions Inside TrainingLoop #720

Open
xanderdunn opened this issue Dec 12, 2020 · 3 comments

Comments

@xanderdunn

xanderdunn commented Dec 12, 2020

Thanks to @xihui-wu's talk earlier today, I learned about the TrainingLoop struct. I had essentially replicated this functionality in a messier way in my code, so I'm looking at it to see if I could replace my train loop with this cleaner implementation. I believe the only issue I might face is with respect to the loss function.

The current loss function takes only the model's output and the target as parameters: `public typealias F = @differentiable(Output, @noDerivative Target) -> Tensor<Float>` (from here). This covers the vast majority of supervised training situations, but some situations call for more complicated loss functions. For example, how would we mask the output and the target for each sample when calculating the loss, as done in this paper:

> at each time step the model tries to predict the full, uncorrupted input vectors x_t; however, only the predictions on the masked values are considered in the Mean Squared Error loss.
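To make the quoted loss concrete, here is a minimal sketch of a masked MSE for one sample. It uses plain Swift arrays instead of S4TF `Tensor`s so it runs standalone; the function name and the Boolean-mask representation are assumptions for illustration, not anything from TrainingLoop or the paper's code.

```swift
// Masked mean squared error for a single sample (plain arrays, not Tensors).
// mask[i] == true marks positions that were corrupted in the input and
// therefore count toward the loss; all other positions are ignored.
func maskedMeanSquaredError(
    predicted: [Float],
    expected: [Float],
    mask: [Bool]
) -> Float {
    precondition(predicted.count == expected.count && expected.count == mask.count,
                 "predicted, expected, and mask must have the same length")
    var sumOfSquares: Float = 0
    var maskedCount = 0
    for i in predicted.indices where mask[i] {
        let diff = predicted[i] - expected[i]
        sumOfSquares += diff * diff
        maskedCount += 1
    }
    // Avoid dividing by zero when no position is masked.
    return maskedCount == 0 ? 0 : sumOfSquares / Float(maskedCount)
}

let loss = maskedMeanSquaredError(
    predicted: [1.0, 2.0, 3.0, 4.0],
    expected:  [1.5, 2.0, 0.0, 4.0],
    mask:      [true, false, true, false])
print(loss) // prints "4.625"
```

Note that the mask is a third input alongside output and target, which is exactly what the two-argument `F` signature has no slot for.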

Another situation that comes to mind is a loss that requires a third, external set of values. Perhaps this is an RL agent whose current loss is a function of recent past losses. Another example is a risk-adjusted metric where "risk" depends on some external, non-static value.
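One way to picture such a loss is a function that still has the two-argument `(output, target)` shape but reads extra state it captured when it was created. The sketch below uses plain Swift arrays; `RiskState` and `makeRiskAdjustedLoss` are hypothetical names, and whether a capturing closure like this would be accepted where TrainingLoop expects its `@differentiable` `F` is an untested assumption.

```swift
// Hypothetical external state that changes during training (e.g. a risk level).
final class RiskState {
    var riskLevel: Float
    init(riskLevel: Float) { self.riskLevel = riskLevel }
}

// Builds a loss with the same (output, target) shape as a plain supervised loss,
// but the returned closure captures `state`, so each call sees the current risk.
func makeRiskAdjustedLoss(state: RiskState) -> ([Float], [Float]) -> Float {
    return { output, target in
        guard !output.isEmpty else { return 0 }
        var sumAbs: Float = 0
        for (o, t) in zip(output, target) { sumAbs += abs(o - t) }
        let meanAbsError = sumAbs / Float(output.count)
        // Errors are penalized more heavily when the external risk is high.
        return meanAbsError * (1 + state.riskLevel)
    }
}

let state = RiskState(riskLevel: 0)
let riskAdjustedLoss = makeRiskAdjustedLoss(state: state)
let calmLoss = riskAdjustedLoss([1, 2], [0, 2])   // mean |error| 0.5, risk 0
state.riskLevel = 1                                // external value changes
let riskyLoss = riskAdjustedLoss([1, 2], [0, 2])   // same error, doubled penalty
```

The trade-off is that the dependency on external state is hidden inside the closure rather than visible in the loss signature.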

Is my reading of the code correct that these types of loss functions are not currently supported? If so, could the protocol be reasonably modified to optionally support such complex loss functions?

Many thanks!

@xihui-wu
Contributor

@xanderdunn thanks for your post! TrainingLoop is still being iterated on so that it covers more and more use cases.

To answer your first question, in the supervised learning scenario: generally speaking, if you want a custom loss function that manipulates the output and label, you can write it the way BERT-CoLA does here. For your mask case, you might be able to do something similar, or do it inside the model the way BERT attention does here?

Secondly, yes: when the loss function depends on some value derived from the training state, some work is needed to allow for it. RL in particular is a different world, and we are considering providing different TrainingLoop variations for it.

cc @BradLarson

@xanderdunn
Author

Thanks @xihui-wu. Writing custom loss functions is straightforward; the challenging part is a training loop that is compatible with loss functions that take more than just logits and labels as parameters. In the linked example the labels have type `Tensor<Int32>`. Could this instead be a struct that contains multiple tensors? Such a struct could contain both the labels and the masks, for example.
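The struct idea can be sketched as follows: if the loop is generic over `Target` and only passes the target through to the loss, a struct bundling labels and a mask fits the existing two-argument shape. This uses plain arrays instead of S4TF `Tensor`s; `MaskedTarget`, `averageLoss`, and `maskedSquaredError` are hypothetical, and the real TrainingLoop may need further changes (for example around metrics that expect plain labels).

```swift
// Hypothetical target bundling labels with a 0/1 weight per position.
struct MaskedTarget {
    var labels: [Float]
    var mask: [Float]   // 1 = include this position in the loss, 0 = ignore it
}

// A loop that is generic over Target never inspects the target itself; it only
// hands it to the loss, so a struct works as well as a bare label array.
func averageLoss<Output, Target>(
    batches: [(output: Output, target: Target)],
    loss: (Output, Target) -> Float
) -> Float {
    guard !batches.isEmpty else { return 0 }
    var total: Float = 0
    for batch in batches { total += loss(batch.output, batch.target) }
    return total / Float(batches.count)
}

// Mask-weighted squared error for one sample.
func maskedSquaredError(_ output: [Float], _ target: MaskedTarget) -> Float {
    var weightedSum: Float = 0
    var weightTotal: Float = 0
    for i in output.indices {
        let diff = output[i] - target.labels[i]
        weightedSum += target.mask[i] * diff * diff
        weightTotal += target.mask[i]
    }
    return weightTotal == 0 ? 0 : weightedSum / weightTotal
}

let batches: [(output: [Float], target: MaskedTarget)] = [
    ([2, 0, 5], MaskedTarget(labels: [1, 0, 0], mask: [1, 1, 0])),
    ([1, 1], MaskedTarget(labels: [1, 3], mask: [0, 1])),
]
let result = averageLoss(batches: batches, loss: maskedSquaredError)
```

Here the first sample's loss is 0.5 and the second's is 4, so `result` is 2.25.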

Your BERT attention link points to the same link as the BERT-CoLA example, but I think you're referring to this attention mask that is applied here to the attention scores? This masks the inputs into the model, but it doesn't mask the outputs and the targets for calculating the loss as done in the Transformer representation learning paper I linked above.

@xihui-wu
Contributor

Thanks for correcting the link. You're welcome to try whether turning it into a struct, together with some changes to TrainingLoop, works. Again, our current TrainingLoop implementation isn't in its final form, and more flexibility may be needed for RL and other applications! At some point we might look at composable training loop pieces so that loops can be more flexible, but that may be some ways off. We're open to expanding the current training loop to fit other use cases.
