Enable patchwise training and prediction #135
base: main
Sliding window patching
Refactor `sample_sliding_window`
Co-authored-by: David Wilby <[email protected]>
…erge Replace combine_by_coords with np.where() to stitch patched predictions
Tidy up patchwise prediction arguments
Thanks @davidwilby, looks much improved and nearly ready to LGTM. Two high-level comments:
- Can you fix the failing unit tests?
- For readability and maintenance, could you move the additions to `TaskLoader` and `DeepSensor` model into subclasses to improve encapsulation of the patching functionality? I'm realising this PR makes the standard API more complicated, which may be offputting for users who don't require patching, and also makes the code for those classes trickier to parse. How about a `PatchTaskLoader` and `DeepSensorPatchwiseModel`? With no other changes to the class hierarchy I think we'd need `ConvNP` to subclass `DeepSensorPatchwiseModel` instead of the standard `DeepSensorModel`, though there may be a more elegant solution where the user decides whether to set up a model that predicts with patches through the `ConvNP` interface. WDYT?
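A rough sketch of how such a split could look, using stand-in classes only. The real `TaskLoader` and `DeepSensorModel` have much richer interfaces; the constructor arguments, method bodies, and the dict-based "task" below are all hypothetical and only illustrate where the patching logic would live.

```python
class TaskLoader:
    """Stand-in for the standard loader: one task per date (hypothetical)."""

    def __call__(self, date):
        return {"date": date, "bbox": None}


class PatchTaskLoader(TaskLoader):
    """Hypothetical subclass that owns every patching-specific argument."""

    def __init__(self, bboxes):
        self.bboxes = bboxes

    def __call__(self, date):
        # One task per bounding box, so a list of tasks per date.
        tasks = []
        for bbox in self.bboxes:
            task = super().__call__(date)
            task["bbox"] = bbox
            tasks.append(task)
        return tasks


class DeepSensorModel:
    """Stand-in for the standard model; predict() is left untouched."""

    def predict(self, task):
        return f"prediction for {task['date']} in {task['bbox']}"


class DeepSensorPatchwiseModel(DeepSensorModel):
    """Hypothetical subclass that adds patchwise prediction on top."""

    def predict_patchwise(self, tasks):
        # Reuse the unchanged predict() once per patched task.
        return [self.predict(task) for task in tasks]
```

Users who never patch would only ever touch the two base classes, which keeps the standard API unchanged.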
deepsensor/model/model.py
Outdated
)
## Cast prediction into DeepSensor.Prediction object.
# TODO make this into seperate method.
`pred.assign` isn't really what you want for stitching predictions together because that method just assigns data in bulk to the xarray or pandas objects.
TBH I've stared at this for a while and I'm not sure what you mean by 'copying one of the patch predictions and extending it' or where I should be looking, partly because it's difficult to follow `predict_patchwise` because of all the nested functions. In particular I don't understand what's being done with `prediction` below and why we can't just return `stitched_prediction`
--
As typing I just realised that `stitch_clipped_predictions` is returning a dict not a `Prediction`, and the code below has a bunch of redundant lines and is simply overwriting the entries of a copy of the first patch `Prediction` with the xarray objects in `stitched_predictions`. I think you can simplify and resolve this if you resolve my comment above about making `stitch_clipped_predictions` return a `Prediction` directly.
Simplify stitching process
…iction_objects Simplify stitching by retaining prediction objects
Hey @tom-andersson - at long last, the long-awaited patchwise training and prediction feature that @nilsleh and @MartinSJRogers have been working on.
This PR adds patching capabilities to DeepSensor during training and inference.
Training
Optional args `patching_strategy`, `patch_size`, `stride` and `num_samples_per_date` are added to `TaskLoader.__call__`.
There are two available patching strategies: `random_window` and `sliding_window`. The `random_window` option randomly selects points in the `x1` and `x2` extent as the centroid of the patch. The number of patches is defined by the `num_samples_per_date` argument. The `sliding_window` function starts in the top left of the dataset and convolves from left to right and top to bottom over the data using the user-defined `patch_size` and `stride`.
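The two strategies can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code in this PR: the function names are hypothetical, boxes are assumed to be `[x1_min, x1_max, x2_min, x2_max]`, the last sliding window on each axis is assumed to be shifted back so it ends at the edge of the extent, and random centroids are assumed to be drawn so that the whole patch fits inside the extent.

```python
import random


def axis_starts(lo, hi, size, step):
    """Window origins along one axis, never overshooting the upper bound."""
    if step <= 0:
        raise ValueError("stride must be positive")
    if size >= hi - lo:
        return [lo]  # a single window already spans the axis
    starts = []
    pos = lo
    while pos + size < hi:
        starts.append(pos)
        pos += step
    # Shift the final window back so it ends exactly at the upper bound;
    # its overlap with the previous window may differ from the others.
    starts.append(hi - size)
    return starts


def sliding_window_bboxes(x1_extent, x2_extent, patch_size, stride):
    """Boxes as [x1_min, x1_max, x2_min, x2_max], stepping x2 within each x1 row."""
    bboxes = []
    for x1 in axis_starts(*x1_extent, patch_size[0], stride[0]):
        for x2 in axis_starts(*x2_extent, patch_size[1], stride[1]):
            bboxes.append([x1, x1 + patch_size[0], x2, x2 + patch_size[1]])
    return bboxes


def random_window_bboxes(x1_extent, x2_extent, patch_size, num_samples, seed=None):
    """Boxes centred on random centroids, kept fully inside the extent."""
    rng = random.Random(seed)
    half1, half2 = patch_size[0] / 2, patch_size[1] / 2
    bboxes = []
    for _ in range(num_samples):
        c1 = rng.uniform(x1_extent[0] + half1, x1_extent[1] - half1)
        c2 = rng.uniform(x2_extent[0] + half2, x2_extent[1] - half2)
        bboxes.append([c1 - half1, c1 + half1, c2 - half2, c2 + half2])
    return bboxes
```

For an extent of 0 to 1 on both axes, `patch_size=(0.5, 0.5)` and `stride=(0.25, 0.25)` give a 3 x 3 grid of nine sliding-window boxes, while `random_window_bboxes` returns as many boxes as requested.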
`TaskLoader.__call__` now contains additional conditional logic depending upon the patching strategy selected. If no patching strategy is selected, `task_generator()` runs exactly as before. If `random_window` (`sliding_window`) is selected, the bounding boxes for the patches are generated using the `sample_random_window()` (`sample_sliding_window()`) method. The bounding boxes are appended to the list `bboxes`, and passed to `task_generator()`.
Within `task_generator()`, after the sampling strategies are applied, the data is spatially sliced with each bbox in `bboxes` using the `self.spatial_slice_variable()` function.
When using a patching strategy, `TaskLoader` produces a list of tasks per date, rather than an individual task per date. A small change has been made to `Task`'s `summarise_str` method to avoid an error when `print`ing patched `Task`s and to output more meaningful information.
Inference
To run patchwise predictions, a new method has been created in `model.py` called `predict_patch()`. This method iterates through and applies the pre-existing `predict()` method to each patched task. The `predict()` method has not been changed. Within each iteration, prior to running `predict()` for each patch, the bounding box of each patch is unnormalised, so the `X_t` of each patch can be passed to the `predict()` function. The patchwise predictions are stored in the list `preds` for subsequent stitching.
It is only possible to use the `sliding_window` patching function during inference, and the stride and patch size are defined when the user generates the test tasks within the `task_loader()` call. The `data_processor` must also be passed to the `predict_patch()` method to enable unnormalisation of the coordinates of the bboxes in `model.py`.
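The per-patch loop can be pictured roughly as below. This is a sketch under assumptions rather than the code in `model.py`: it assumes a simple linear min-max mapping between normalised and raw coordinates on each axis, and `unnormalise_bbox`, `predict_over_patches` and `predict_fn` are hypothetical stand-ins for the `data_processor` and `predict()` calls.

```python
def unnormalise_bbox(bbox, x1_range, x2_range):
    """Map [x1_min, x1_max, x2_min, x2_max] from [0, 1] back to raw coordinates.

    Assumes a linear min-max mapping per axis (an assumption of this sketch).
    """
    x1_min, x1_max, x2_min, x2_max = bbox

    def to_raw(value, lo, hi):
        return lo + value * (hi - lo)

    return [
        to_raw(x1_min, *x1_range),
        to_raw(x1_max, *x1_range),
        to_raw(x2_min, *x2_range),
        to_raw(x2_max, *x2_range),
    ]


def predict_over_patches(patch_bboxes, predict_fn, x1_range, x2_range):
    """Call predict_fn once per patch, passing the bbox in raw coordinates."""
    preds = []
    for bbox in patch_bboxes:
        raw_bbox = unnormalise_bbox(bbox, x1_range, x2_range)
        # predict_fn stands in for restricting X_t to the patch and calling predict().
        preds.append(predict_fn(raw_bbox))
    return preds
```

The resulting list plays the role of `preds`: one entry per patch, in the same order as the patches, ready for stitching.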
Once the list of patchwise predictions is generated, `stitch_clipped_predictions()` is used to form a prediction at the original `X_t` extent. Currently, functionality is provided to subset or clip each patchwise prediction so there is no overlap between adjacent patches and then merge the patches using `xr.combine_by_coords()`. The modular nature of the code means there is scope for additional stitching strategies to be added after this PR, for example applying a weighting function to overlapping predictions. To ensure the patches are clipped by the correct amount, `get_patch_overlap()` calculates the overlap between adjacent patches. `stitch_clipped_predictions()` also contains code to handle patches at the edge or bottom of the dataset, where the overlap may be different.
The output from `predict_patch()` is the same kind of DeepSensor object as that produced by `model.predict()`, hence DeepSensor’s plotting functionality can subsequently be used in the same way.
Documentation and Testing
New notebook(s) are added illustrating the usage of both patchwise training and prediction.
New tests are added to verify the new behaviour.
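As a one-dimensional illustration of the clip-and-merge stitching described under Inference, and of the kind of shape check such a test might make, consider the sketch below. It is not the PR's `stitch_clipped_predictions()`, which works on xarray-backed prediction objects; `stitch_clipped` is a hypothetical helper, patches are plain lists indexed on an integer grid, and splitting each overlap evenly between the two neighbours is an assumption of the sketch.

```python
def stitch_clipped(patches):
    """Stitch 1D patches given as (start_index, values) into one list.

    Each overlap between neighbouring patches is split between them, so
    every grid index is taken from exactly one patch.
    """
    patches = sorted(patches, key=lambda p: p[0])
    stitched = []
    for i, (start, values) in enumerate(patches):
        stop = start + len(values)
        lo, hi = start, stop
        if i > 0:
            prev_start, prev_values = patches[i - 1]
            prev_stop = prev_start + len(prev_values)
            # Drop the first half of the overlap with the previous patch.
            lo = start + max(prev_stop - start, 0) // 2
        if i < len(patches) - 1:
            next_start = patches[i + 1][0]
            # Keep values only up to the same cut point the next patch starts from.
            hi = next_start + max(stop - next_start, 0) // 2
        stitched.extend(values[lo - start:hi - start])
    return stitched


# Grid of 9 points, patch size 4: the last patch is shifted back to fit,
# so its overlap with its neighbour (2) differs from the first overlap (1).
patches = [(s, list(range(s, s + 4))) for s in (0, 3, 5)]
stitched = stitch_clipped(patches)
assert len(stitched) == 9  # stitched output matches the full grid length
```

Because each cut point is computed per pair of neighbours, the edge patch with its larger overlap needs no special case in this simplified setting.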
Limitations
- `predict_patch` with more than one date raises a `NotImplementedError`.
- `predict_patch` is a new, distinct function because of all the pre-processing it needs to do. The patchwise behaviour may be better served as an option in `predict` - let me know what you think.
- `patch_size`: for a 'square' patch, e.g. `patch_size=(0.5,0.5)`, the resulting dimensions won't be exactly square. This is accounted for in the stitching of patches, but is slightly inelegant at the moment, so we may want to come back and find a more refined solution in the future.
- `test_model.test_patchwise_prediction`: I've temporarily commented out the asserts checking for correct prediction shape. These fail with test datasets for now, but with real datasets the shapes are correct; see the `patchwise_training_and_prediction.ipynb` notebook.