-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Testing all the backbones #31
Conversation
Signed-off-by: João Lucas de Sousa Almeida <[email protected]> Testing to save and load checkpoints Signed-off-by: João Lucas de Sousa Almeida <[email protected]> Testing finetuning for Swin Signed-off-by: João Lucas de Sousa Almeida <[email protected]> More config files used for executing the manufactured tests Signed-off-by: João Lucas de Sousa Almeida <[email protected]> More input/target files to perform the manufactured tests Signed-off-by: João Lucas de Sousa Almeida <[email protected]> Automatically testing fine-tuning Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the work! I left a few comments, but I think having these tests will really help us going forward :)
tests/test_backbones.py
Outdated
def test_all_backbone_instantiation(model_name): | ||
|
||
if "vit" in model_name : | ||
module_str = "terratorch.models.backbones.prithvi_vit" | ||
elif "swin" in model_name: | ||
module_str = "terratorch.models.backbones.prithvi_swin" | ||
|
||
module_instance = importlib.import_module(module_str) | ||
|
||
model_template = getattr(module_instance, model_name) | ||
|
||
model_instance = model_template() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this will be hard to extend in order to really cover all backbones, because of the way we have to specify all the import paths.
Instead I wonder if we could adopt something like doctest where in the docstring of each class we can write code that gets executed in the tests. There we can instantiate that specific class. It also acts as extra documentation for users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the current structure we have on terratorch it seems to be enough , but we could refactor it to use doctest or to create a list in full, specifying all the models we have and performing a loop to test the instantiation all of them. This list should be updated each time we introduce a new backbone.
tests/test_backbones.py
Outdated
torch.save(model_instance.state_dict(), os.path.join("tests/", str(id(model_instance)) + ".pth")) | ||
|
||
model_restored = torch.load(os.path.join("tests/", str(id(model_instance)) + ".pth")) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this part every fail? I think we this part just tests the torch.save
and torch.load
methods themselves
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My original intention was to test save/restoring and execution in the same test, but as it was already done in another part it can be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was removed.
tests/test_finetune.py
Outdated
if "vit" in model_name : | ||
module_str = "terratorch.models.backbones.prithvi_vit" | ||
ckpt_filter = checkpoint_filter_fn_vit | ||
elif "swin" in model_name: | ||
module_str = "terratorch.models.backbones.prithvi_swin" | ||
ckpt_filter = checkpoint_filter_fn_swin | ||
|
||
module_instance = importlib.import_module(module_str) | ||
|
||
model_template = getattr(module_instance, model_name) | ||
|
||
model_instance = model_template() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of this, can we leverage timm.create_model
to directly create the backbone from the string?
This wont work for all models, as not all of them support being created from timm, but at least for this I think its a better solution. I worry about the if statement of import module paths growing and becoming unwieldy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Considering the way the code is currently structured and the models available on terratorch, it seems to be enough for now, but we could do it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since its a relatively quick code change and results in cleaner code I think we should
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's done now.
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
This commit is a pool of the sub-commits:
Testing to save and load checkpoints
Testing finetuning for Swin
More config files used for executing the manufactured tests
More input/target files to perform the manufactured tests
Automatically testing fine-tuning
These modifications enable some automatic CLI tests for fine-tuning tasks using terratorch for the main backbones.