Skip to content

Commit

Permalink
Fix OCPNeb trainer loading (#995)
Browse files Browse the repository at this point in the history
* Add an option to specify the trainer class for OCPNeb; pass `loss_functions` from the config so that `update_config` is not broken

* Fix the affected test
  • Loading branch information
misko authored Feb 6, 2025
1 parent beba557 commit daa8acc
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/fairchem/applications/cattsunami/core/ocpneb.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
precon=None,
cpu=False,
batch_size=4,
trainer=None,
):
"""
Subclass of NEB that allows for scaled and dynamic optimizations of
Expand Down Expand Up @@ -97,11 +98,14 @@ def __init__(
if "relax_dataset" in config["task"]:
del config["task"]["relax_dataset"]

self.trainer = registry.get_trainer_class(config.get("trainer", "ocp"))(
if trainer is None:
trainer = config.get("trainer", "ocp")

self.trainer = registry.get_trainer_class(trainer)(
task=config.get("task", {}),
model=config["model"],
outputs={},
loss_functions={},
loss_functions=config.get("loss_functions", {}),
evaluation_metrics={},
dataset=[config["dataset"]],
optimizer=config["optim"],
Expand Down

0 comments on commit daa8acc

Please sign in to comment.