diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 98266025..136c6513 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -220,7 +220,6 @@ def checkpoint_filter_wrapper_fn(state_dict, model): model_args = prithvi_cfgs[variant].copy() model_args.update(kwargs) - # When the pretrained configuration is not available in HF, we shift to pretrained=False try: model = build_model_with_cfg( prithvi_model_class, @@ -230,16 +229,13 @@ def checkpoint_filter_wrapper_fn(state_dict, model): pretrained_strict=True, **model_args, ) - except RuntimeError: - logger.warning(f"No pretrained configuration was found for the model {variant}. Using random initialization.") - model = build_model_with_cfg( - prithvi_model_class, - variant, - False, - pretrained_filter_fn=checkpoint_filter_wrapper_fn, - pretrained_strict=True, - **model_args, - ) + except RuntimeError as e: + if pretrained: + logger.error(f"Failed to initialize the pre-trained model {variant} via timm, " + f"consider running the code with pretrained=False.") + else: + logger.error(f"Failed to initialize the model {variant} via timm.") + raise e if encoder_only: default_out_indices = list(range(len(model.blocks)))