Skip to content

Commit

Permalink
change weight loading schema (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
arogozhnikov authored Nov 5, 2024
1 parent d6ccd41 commit 7ade6f9
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,13 @@ def forward(self, crop_size: int, **kw):
def load_exported(comp_key: str, device: torch.device) -> ModuleWrapper:
torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)])
local_path = chai1_component(comp_key)
# specifying map_location=... doesn't load weights properly
return ModuleWrapper(torch.jit.load(local_path).to(device))
assert isinstance(device, torch.device)
if device != torch.device("cuda:0"):
# load on cpu first, then move to device
return ModuleWrapper(torch.jit.load(local_path, map_location="cpu").to(device))
else:
# skip loading on CPU.
return ModuleWrapper(torch.jit.load(local_path).to(device))


# %%
Expand Down

0 comments on commit 7ade6f9

Please sign in to comment.