Skip to content

Commit

Permalink
fix parse device
Browse files Browse the repository at this point in the history
  • Loading branch information
fostiropoulos committed Nov 29, 2023
1 parent 28d013c commit 5a31a00
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tests/utils/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@ def test_set_seed():
def test_parse_device():
assert base.parse_device("cpu") == "cpu"
if torch.cuda.is_available():
assert base.parse_device(0) == "cuda:0"
assert base.parse_device(1) == "cuda:1"
assert base.parse_device("cuda") == "cuda"
assert base.parse_device(["cuda", "cpu"]) == "cuda"
device_number = min(torch.cuda.device_count() - 1, 0)
assert base.parse_device(f"cuda:{device_number}") == f"cuda:{device_number}"

gpu_number = torch.cuda.device_count()
for i in range(gpu_number):
assert base.parse_device(i) == f"cuda:{i}"
assert base.parse_device(f"cuda:{i}") == f"cuda:{i}"

with pytest.raises(
AssertionError,
match=f"gpu cuda:{gpu_number + 2} does not exist on this machine",
match=f"gpu cuda:{gpu_number + 1} does not exist on this machine",
):
base.parse_device(f"cuda:{gpu_number + 2}"),
base.parse_device(gpu_number + 1)

assert base.parse_device("cuda") == "cuda"
assert base.parse_device(["cuda", "cpu"]) == "cuda"

with pytest.raises(ValueError):
base.parse_device("invalid")
Expand Down

0 comments on commit 5a31a00

Please sign in to comment.