diff --git a/da-clip/src/open_clip/transformer.py b/da-clip/src/open_clip/transformer.py index c0fd2c8..da8caeb 100644 --- a/da-clip/src/open_clip/transformer.py +++ b/da-clip/src/open_clip/transformer.py @@ -294,7 +294,7 @@ def __init__(self, transformer): self.zero_modules = nn.ModuleList([ self.zero_module(nn.Linear(self.width, self.width, 1)) - for _ in range(self.layers)]).cuda() + for _ in range(self.layers)]) self.grad_checkpointing = transformer.grad_checkpointing def zero_module(self, module):