Skip to content

Commit

Permalink
fix bugs in model restore.
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyaofo committed Apr 11, 2021
1 parent 19f1077 commit 948fc63
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion codebase/torchutils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@ def save(self, minitor: str, metrics: dict, states: dict):
def restore(self, metrics: dict, states: dict, device="cuda:0"):
checkpoint_path = self.output_directory / "checkpoint.pt"
if checkpoint_path.exists():
checkpoint: dict = torch.load(checkpoint_path, map_location=device)
map_location= f"cuda:{device}" if isinstance(device, int) else device
checkpoint: dict = torch.load(checkpoint_path, map_location=map_location)
metrics.update(checkpoint.pop("metrics", dict()))
for name, module in states.items():
module.load_state_dict(checkpoint[name])
Expand Down

0 comments on commit 948fc63

Please sign in to comment.