From 5a31a00cb7961f0435cd348011a881da3d863cc0 Mon Sep 17 00:00:00 2001 From: Iordanis Fostiropoulos Date: Wed, 29 Nov 2023 00:36:54 +0000 Subject: [PATCH] fix parse device --- tests/utils/test_base.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/utils/test_base.py b/tests/utils/test_base.py index d62732fe..003e83e4 100644 --- a/tests/utils/test_base.py +++ b/tests/utils/test_base.py @@ -24,19 +24,20 @@ def test_set_seed(): def test_parse_device(): assert base.parse_device("cpu") == "cpu" if torch.cuda.is_available(): - assert base.parse_device(0) == "cuda:0" - assert base.parse_device(1) == "cuda:1" - assert base.parse_device("cuda") == "cuda" - assert base.parse_device(["cuda", "cpu"]) == "cuda" - device_number = min(torch.cuda.device_count() - 1, 0) - assert base.parse_device(f"cuda:{device_number}") == f"cuda:{device_number}" + gpu_number = torch.cuda.device_count() + for i in range(gpu_number): + assert base.parse_device(i) == f"cuda:{i}" + assert base.parse_device(f"cuda:{i}") == f"cuda:{i}" with pytest.raises( AssertionError, - match=f"gpu cuda:{gpu_number + 2} does not exist on this machine", + match=f"gpu cuda:{gpu_number + 1} does not exist on this machine", ): - base.parse_device(f"cuda:{gpu_number + 2}"), + base.parse_device(gpu_number + 1) + + assert base.parse_device("cuda") == "cuda" + assert base.parse_device(["cuda", "cpu"]) == "cuda" with pytest.raises(ValueError): base.parse_device("invalid")