Skip to content

Commit

Permalink
Rename parameter
Browse files Browse the repository at this point in the history
Signed-off-by: Francesc Marti Escofet <[email protected]>
  • Loading branch information
fmartiescofet committed Dec 18, 2024
1 parent 7ecbbb2 commit 2fb9383
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 15 deletions.
8 changes: 4 additions & 4 deletions terratorch/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ def configure_optimizers(
optimizer = "Adam"

parameters: Iterable
if self.hparams.get("reduce_lr", None) is not None and len(self.hparams["reduce_lr"]) > 0:
if self.hparams.get("lr_overrides", None) is not None and len(self.hparams["lr_overrides"]) > 0:
parameters = []
for param_name, reduce_factor in self.hparams["reduce_lr"]:
for param_name, custom_lr in self.hparams["lr_overrides"]:
p = [p for n, p in self.model.named_parameters() if param_name in n]
parameters.append({"params": p, "lr": self.hparams["lr"] / reduce_factor})
parameters.append({"params": p, "lr": custom_lr})
rest_p = [
p
for n, p in self.model.named_parameters()
if all(param_name not in n for param_name, _ in self.hparams["reduce_lr"])
if all(param_name not in n for param_name, _ in self.hparams["lr_overrides"])
]
parameters.append({"params": rest_p})
else:
Expand Down
10 changes: 6 additions & 4 deletions terratorch/tasks/classification_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from terratorch.tasks.optimizer_factory import optimizer_factory
from terratorch.tasks.base_task import TerraTorchTask

logger = logging.getLogger('terratorch')
logger = logging.getLogger("terratorch")


def to_class_prediction(y: ModelOutput) -> Tensor:
y_hat = y.output
Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT002, FBT001
class_names: list[str] | None = None,
reduce_lr: list[tuple[str, float]] | None = None,
lr_overrides: list[tuple[str, float]] | None = None,
) -> None:
"""Constructor
Expand Down Expand Up @@ -98,8 +99,9 @@ def __init__(
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
Defaults to numeric ordering.
reduce_lr (list[tuple[str, float]] | None, optional): List of tuples with a substring of the parameter names
to reduce the learning rate and the factor to reduce it by. Defaults to None.
lr_overrides (list[tuple[str, float]] | None, optional): List of tuples with a substring of the parameter
names (it will check the substring is contained in the parameter name) to override the learning rate and
the new lr. Defaults to None.
"""
self.aux_loss = aux_loss
self.aux_heads = aux_heads
Expand Down
10 changes: 6 additions & 4 deletions terratorch/tasks/regression_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

logger = logging.getLogger('terratorch')
logger = logging.getLogger("terratorch")


class RootLossWrapper(nn.Module):
def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None:
Expand Down Expand Up @@ -152,7 +153,7 @@ def __init__(
freeze_decoder: bool = False, # noqa: FBT001, FBT002
plot_on_val: bool | int = 10,
tiled_inference_parameters: TiledInferenceParameters | None = None,
reduce_lr: list[tuple[str, float]] | None = None,
lr_overrides: list[tuple[str, float]] | None = None,
) -> None:
"""Constructor
Expand Down Expand Up @@ -187,8 +188,9 @@ def __init__(
If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
used to determine if inference is done on the whole image or through tiling.
reduce_lr (list[tuple[str, float]] | None, optional): List of tuples with a substring of the parameter names
to reduce the learning rate and the factor to reduce it by. Defaults to None.
lr_overrides (list[tuple[str, float]] | None, optional): List of tuples with a substring of the parameter
names (it will check the substring is contained in the parameter name) to override the learning rate and
the new lr. Defaults to None.
"""
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
Expand Down
7 changes: 4 additions & 3 deletions terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
class_names: list[str] | None = None,
tiled_inference_parameters: TiledInferenceParameters = None,
test_dataloaders_names: list[str] | None = None,
reduce_lr: list[tuple[str, float]] | None = None,
lr_overrides: list[tuple[str, float]] | None = None,
) -> None:
"""Constructor
Expand Down Expand Up @@ -107,8 +107,9 @@ def __init__(
test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
which assumes only one test dataloader is used.
reduce_lr (list[tuple[str, float]] | None, optional): List of tuples with a substring of the parameter names
to reduce the learning rate and the factor to reduce it by. Defaults to None.
lr_overrides (list[tuple[str, float]] | None, optional): List of tuples with a substring of the parameter
names (it will check the substring is contained in the parameter name) to override the learning rate and
the new lr. Defaults to None.
"""
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
Expand Down

0 comments on commit 2fb9383

Please sign in to comment.