diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index f3305ce..04d9df9 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -104,8 +104,13 @@ def forward(self, crop_size: int, **kw): def load_exported(comp_key: str, device: torch.device) -> ModuleWrapper: torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)]) local_path = chai1_component(comp_key) - # specifying map_location=... doesn't load weights properly - return ModuleWrapper(torch.jit.load(local_path).to(device)) + assert isinstance(device, torch.device) + if device != torch.device("cuda:0"): + # load on cpu first, then move to device + return ModuleWrapper(torch.jit.load(local_path, map_location="cpu").to(device)) + else: + # skip loading on CPU. + return ModuleWrapper(torch.jit.load(local_path).to(device)) # %%