Add forward to training servicer #227
Conversation
Branch force-pushed from d482661 to 32cd26d (Compare)
Branch force-pushed from 3dc2864 to 7744f91 (Compare)
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #227      +/-   ##
==========================================
- Coverage   64.48%   63.67%   -0.81%
==========================================
  Files          47       47
  Lines        2689     2742      +53
==========================================
+ Hits         1734     1746      +12
- Misses        955      996      +41

☔ View full report in Codecov by Sentry.
Branch force-pushed from 7744f91 to 50b0944 (Compare)
If training is running or paused, the forward pass will retain the training state after completion. However, it requires training to pause, so that we can release memory before doing the forward pass.
Since both the inference and training servicers share the concept of an id, the training session id was replaced with the model session id used for inference. The model session protobuf interface was moved to a separate utils proto file. Since PredictRequest is common to both, it can be leveraged for abstraction.
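The session-id unification described above can be sketched roughly as follows. This is an illustrative stand-in, not the actual tiktorch code: `ModelSession`, `SessionManager`, and `get_model_session` here are simplified Python analogues of the shared protobuf message and lookup that both servicers would use.

```python
import uuid
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelSession:
    """Hypothetical stand-in for the ModelSession protobuf message
    moved to the common utils proto file."""
    id: str


class SessionManager:
    """Single lookup shared by both the inference and training servicers,
    so training no longer needs its own session-id type."""

    def __init__(self):
        self._sessions = {}

    def create_session(self, payload) -> ModelSession:
        session = ModelSession(id=uuid.uuid4().hex)
        self._sessions[session.id] = payload
        return session

    def get_model_session(self, session: ModelSession):
        try:
            return self._sessions[session.id]
        except KeyError:
            raise ValueError(f"unknown session {session.id}")


manager = SessionManager()
s = manager.create_session("trainer-backend")
assert manager.get_model_session(s) == "trainer-backend"
```

With one session type, a common `PredictRequest` carrying a `ModelSession` can be routed to either servicer through the same resolution step.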
Branch force-pushed from 4798dbe to d3d6702 (Compare)
Nice, focused PR with straightforward-to-follow changes, including tests! It's great that you unified get_model_session, ModelsessionID, and PredictResponse/Request between services. Very cool that the training state is retained when a forward happens during training.
You've decided to pull in pytorch already at the servicer level, which sprinkles pytorch through all the layers. I would prefer to keep tensors as numpy until reaching the actual backend, to at least leave open the option of changing the framework to something else (e.g. keras) in the future. The inference servicer uses the bioimageio Sample abstraction; would that be an option here as well?
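The reviewer's suggestion of keeping tensors framework-agnostic until the backend boundary could look roughly like this. It is a minimal sketch under assumptions: `servicer_forward` and `NumpyDoublingBackend` are hypothetical names, and a real backend would convert to `torch.Tensor` (or keras, etc.) internally instead of doubling values.

```python
import numpy as np


def servicer_forward(backend, raw_bytes, shape, dtype="float32"):
    """Servicer layer: decode the request into numpy only.
    No framework-specific tensor types leak into this layer."""
    array = np.frombuffer(raw_bytes, dtype=dtype).reshape(shape)
    result = backend.forward(array)  # the backend chooses the framework
    return np.asarray(result)


class NumpyDoublingBackend:
    """Stand-in backend; a real one would convert the numpy array to its
    framework's tensor type and convert the result back before returning."""

    def forward(self, array: np.ndarray) -> np.ndarray:
        return array * 2


backend = NumpyDoublingBackend()
data = np.ones((2, 3), dtype="float32")
out = servicer_forward(backend, data.tobytes(), (2, 3))
print(out.sum())  # → 12.0
```

Swapping pytorch for another framework would then only touch the backend class, leaving the servicer and transport layers unchanged.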
assert len(predicted_tensors) == 1
predicted_tensor = predicted_tensors[0]
assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)
It would be nice if predicted_tensor values could also be tested somehow, just to be on the safer side...
You are right. Currently, with the configuration, we can't really mock the model. We would need to bypass the initialization and use the approach of tests such as test_start_training_success, where we create a mocked trainer object, MockedNominalTrainer, that could set a mocked model as an attribute.
That was a general concern I had: currently a few tests use the pytorch 3d unet config to create a model, but we don't have a test-controlled mocked model.
By bypassing the configuration's init phase we lose the end-to-end nature of the test, but maybe we should somehow have both.
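The mocked-trainer approach mentioned above could be sketched like this. This is only an illustration in the spirit of the existing MockedNominalTrainer; the class body and the `forward` signature here are assumptions, not the actual test fixture.

```python
from unittest.mock import MagicMock

import numpy as np


class MockedNominalTrainer:
    """Hypothetical test double: bypasses config-driven initialization
    and carries a mocked model as an attribute."""

    def __init__(self):
        self.model = MagicMock()
        # A deterministic output lets the test assert on values,
        # not only on dims and shape.
        self.model.forward.side_effect = lambda x: x + 1

    def forward(self, array):
        return self.model.forward(array)


trainer = MockedNominalTrainer()
out = trainer.forward(np.zeros((1, 2, 2)))
assert out.shape == (1, 2, 2)
assert np.all(out == 1)  # value-level check on the prediction
assert trainer.model.forward.call_count == 1
```

A test built this way trades the end-to-end path through the pytorch 3d unet config for full control over the model's outputs, which is why keeping both kinds of test is attractive.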
but this is in fact doing a proper forward pass?
Currently it is doing a forward pass with a model defined by the config; I am not sure what you are referring to here.
Thanks a lot @k-dominik for the review! I have implemented the suggestion of decoupling from pytorch tensors and used the bioimageio Sample abstraction :) I just have one concern regarding testing; could you please have a look at this comment #227 (comment)
Hi Theo, conda-build changed their default package format... now the extension is
This PR looks good now @thodkatz (except for the previous comment about conda-build's changed package format); it can be merged once tests pass :)
It builds on top of #225.
I have implemented the forward method. The forward method intervenes in the training loop, and we retain the loop's initial state. It requires training to pause, so that we are sure we won't cause memory issues if training and inference are attempted at the same time.
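The pause-then-forward flow described above can be sketched as a small state machine. The names (`Trainer`, `pause`, `resume`, `State`) are illustrative, not the actual tiktorch API; the point is only that the pre-forward state is remembered and restored.

```python
from enum import Enum


class State(Enum):
    IDLE = "idle"
    RUNNING = "running"
    PAUSED = "paused"


class Trainer:
    """Illustrative trainer: forward() pauses training to free resources,
    runs inference, then restores whatever state training was in before."""

    def __init__(self):
        self.state = State.IDLE

    def start(self):
        self.state = State.RUNNING

    def pause(self):
        if self.state == State.RUNNING:
            self.state = State.PAUSED

    def resume(self):
        if self.state == State.PAUSED:
            self.state = State.RUNNING

    def forward(self, tensor):
        previous = self.state              # remember the pre-forward state
        self.pause()                       # release training resources first
        result = [x * 2 for x in tensor]   # stand-in for the model's forward pass
        if previous == State.RUNNING:      # retain the initial state
            self.resume()
        return result


trainer = Trainer()
trainer.start()
out = trainer.forward([1, 2, 3])
print(out, trainer.state.value)  # → [2, 4, 6] running
```

If training was paused when forward was requested, it simply stays paused afterwards, which matches the "retain the state after completion" behavior.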