Skip to content

Commit

Permalink
Merge branch 'main' into removed-pretrained-fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
romeokienzler authored Dec 11, 2024
2 parents ef3d055 + 1cdc2d4 commit d191e29
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def _cfg(**kwargs):
**kwargs
}

prithvi_default_cfgs = {

prithvi_cfgs = {
"prithvi_eo_tiny": _cfg(num_frames=1, embed_dim=256, depth=4, num_heads=4,
decoder_embed_dim=128, decoder_depth=4, decoder_num_heads=4),
"prithvi_eo_v1_100": _cfg(num_frames=3, mean=PRITHVI_V1_MEAN, std=PRITHVI_V1_STD),
Expand All @@ -64,8 +65,8 @@ def _cfg(**kwargs):
coords_encoding=["time", "location"], coords_scale_learn=True),
}


pretrained_cfgs = generate_default_cfgs(
# Timm pretrained configs
default_cfgs = generate_default_cfgs(
{
"prithvi_eo_v1_100": {
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
Expand Down Expand Up @@ -203,19 +204,20 @@ def checkpoint_filter_wrapper_fn(state_dict, model):
return checkpoint_filter_fn_mae(state_dict, model, pretrained_bands, model_bands)

if pretrained:
assert variant in pretrained_cfgs, (f"No pre-trained model found for variant {variant} "
f"(pretrained models: {pretrained_cfgs.keys()})")
assert variant in default_cfgs, (f"No pre-trained model found for variant {variant} "
f"(pretrained models: {default_cfgs.keys()})")
# Load pre-trained config from hf
try:
model_args, _ = load_model_config_from_hf(pretrained_cfgs[variant].default.hf_hub_id)
model_args, _ = load_model_config_from_hf(default_cfgs[variant].default.hf_hub_id)
model_args.update(kwargs)
except:
logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}.")
model_args = prithvi_default_cfgs[variant].copy()
logger.warning(f"No pretrained configuration was found on HuggingFace for the model {variant}."
f"Using random initialization.")
model_args = prithvi_cfgs[variant].copy()
model_args.update(kwargs)
else:
# Load default config
model_args = prithvi_default_cfgs[variant].copy()
model_args = prithvi_cfgs[variant].copy()
model_args.update(kwargs)

try:
Expand Down

0 comments on commit d191e29

Please sign in to comment.