From 7ade6f98711acc33df673d02a843e266d07fa1c0 Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Tue, 5 Nov 2024 10:07:57 -0800 Subject: [PATCH] change weight loading schema (#145) --- chai_lab/chai1.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index f3305ce..04d9df9 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -104,8 +104,13 @@ def forward(self, crop_size: int, **kw): def load_exported(comp_key: str, device: torch.device) -> ModuleWrapper: torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)]) local_path = chai1_component(comp_key) - # specifying map_location=... doesn't load weights properly - return ModuleWrapper(torch.jit.load(local_path).to(device)) + assert isinstance(device, torch.device) + if device != torch.device("cuda:0"): + # load on cpu first, then move to device + return ModuleWrapper(torch.jit.load(local_path, map_location="cpu").to(device)) + else: + # skip loading on CPU. + return ModuleWrapper(torch.jit.load(local_path).to(device)) # %%