diff --git a/tests/pipe_test.py b/tests/pipe_test.py index ed1b34b..a78e4fb 100644 --- a/tests/pipe_test.py +++ b/tests/pipe_test.py @@ -114,6 +114,11 @@ def test_loader(self): loader = Loader() model_id = self.get_model() model_type = self.model_type() + device = torch.device('cpu') + if torch.cuda.is_available(): + device = torch.device('cuda', 0) + if 'device' not in self.device_args: + self.device_args['device'] = device classes = self.get_cls_by_type(MaskedIm2ImPipe) # load inpainting pipe cls = classes[model_type] @@ -232,7 +237,7 @@ def test_lpw_turned_off(self): seed = 49045438434843 params = dict(prompt=prompt, negative_prompt="spherical", - generator=torch.Generator(pipe.pipe.device).manual_seed(seed)) + generator=torch.Generator().manual_seed(seed)) image = pipe.gen(params) image.save("cube_test_no_lpw.png") params = dict(prompt=prompt + " , best quality, famous photo",