add option to finetune backbone from tnt/dcp
misko committed Mar 2, 2025
1 parent ab9f004 commit 0fb5575
Showing 1 changed file with 14 additions and 0 deletions.
src/fairchem/core/models/base.py (14 additions, 0 deletions)
@@ -14,6 +14,7 @@
 from typing import TYPE_CHECKING
 
 import torch
+import torch.distributed.checkpoint as dcp
 from torch import nn
 from torch_geometric.nn import radius_graph
 
@@ -239,6 +240,7 @@ def __init__(
         otf_graph: bool = True,
         pass_through_head_outputs: bool = False,
         freeze_backbone: bool = False,
+        tnt_finetune_config: dict | None = None,
     ):
         super().__init__()
         self.device = None
@@ -292,6 +294,18 @@ def __init__(
             for param in self.backbone.parameters():
                 param.requires_grad = False
 
+        # load after the backbone is loaded but before heads
+        if tnt_finetune_config is not None:
+            dcp.load(
+                state_dict={
+                    f"unit_state.model.{k}": v for k, v in self.state_dict().items()
+                },
+                checkpoint_id=tnt_finetune_config["starting_checkpoint"],
+            )
+            if "override" in tnt_finetune_config:
+                for key, value in tnt_finetune_config["override"].items():
+                    setattr(self.backbone, key, value)
+
         if heads is not None:
             heads = copy.deepcopy(heads)
             # Iterate through outputs_cfg and create heads
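The new tnt_finetune_config argument is consumed at construction time, after the backbone exists but before the heads are built. Below is a minimal usage sketch, not taken from this commit: the class name HydraModel, the backbone/heads config dicts, and the checkpoint path are assumptions for illustration; only the keys "starting_checkpoint" and "override" come from the code added above.

from fairchem.core.models.base import HydraModel  # assumed class/import path

tnt_finetune_config = {
    # directory produced by a torch.distributed.checkpoint (DCP) save
    # from a TorchTNT training run (path is illustrative)
    "starting_checkpoint": "/checkpoints/pretrain/step_100000",
    # optional: attributes to overwrite on the backbone after loading
    "override": {"otf_graph": True},
}

model = HydraModel(
    backbone=backbone_config,  # hypothetical backbone config dict
    heads=heads_config,        # hypothetical heads config dict
    tnt_finetune_config=tnt_finetune_config,
)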

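A note on the load pattern: dcp.load writes checkpoint data directly into the tensors of the state_dict it is given, and because the tensors returned by state_dict() share storage with the module's parameters, no follow-up load_state_dict call is needed. The "unit_state.model." prefix rewrites local keys to match how a TorchTNT unit namespaces its model attribute when a checkpoint is saved. A standalone sketch of the same pattern, assuming an illustrative module and checkpoint path:

import torch
import torch.distributed.checkpoint as dcp

module = torch.nn.Linear(4, 4)
# remap local keys to the TorchTNT checkpoint layout (prefix taken from the diff)
state = {f"unit_state.model.{k}": v for k, v in module.state_dict().items()}
# loads in place into the tensors of `state`, which alias module's parameters
dcp.load(state_dict=state, checkpoint_id="/checkpoints/pretrain/step_100000")
# module's weights now hold the checkpointed values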