diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 93a3b504..98266025 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -51,7 +51,8 @@ def _cfg(**kwargs): **kwargs } -prithvi_default_cfgs = { + +prithvi_cfgs = { "prithvi_eo_tiny": _cfg(num_frames=1, embed_dim=256, depth=4, num_heads=4, decoder_embed_dim=128, decoder_depth=4, decoder_num_heads=4), "prithvi_eo_v1_100": _cfg(num_frames=3, mean=PRITHVI_V1_MEAN, std=PRITHVI_V1_STD), @@ -64,8 +65,8 @@ def _cfg(**kwargs): coords_encoding=["time", "location"], coords_scale_learn=True), } - -pretrained_cfgs = generate_default_cfgs( +# Timm pretrained configs +default_cfgs = generate_default_cfgs( { "prithvi_eo_v1_100": { "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M", @@ -203,19 +204,20 @@ def checkpoint_filter_wrapper_fn(state_dict, model): return checkpoint_filter_fn_mae(state_dict, model, pretrained_bands, model_bands) if pretrained: - assert variant in pretrained_cfgs, (f"No pre-trained model found for variant {variant} " - f"(pretrained models: {pretrained_cfgs.keys()})") + assert variant in default_cfgs, (f"No pre-trained model found for variant {variant} " + f"(pretrained models: {default_cfgs.keys()})") # Load pre-trained config from hf try: - model_args, _ = load_model_config_from_hf(pretrained_cfgs[variant].default.hf_hub_id) + model_args, _ = load_model_config_from_hf(default_cfgs[variant].default.hf_hub_id) model_args.update(kwargs) except: - logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}.") - model_args = prithvi_default_cfgs[variant].copy() + logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}." + f"Using random initialization.") + model_args = prithvi_cfgs[variant].copy() model_args.update(kwargs) else: # Load default config - model_args = prithvi_default_cfgs[variant].copy() + model_args = prithvi_cfgs[variant].copy() model_args.update(kwargs) # When the pretrained configuration is not available in HF, we shift to pretrained=False @@ -229,7 +231,7 @@ def checkpoint_filter_wrapper_fn(state_dict, model): **model_args, ) except RuntimeError: - logger.warning(f"No pretrained configuration was found for the model {variant}.") + logger.warning(f"No pretrained configuration was found for the model {variant}. Using random initialization.") model = build_model_with_cfg( prithvi_model_class, variant,