Skip to content

Commit

Permalink
Skip more hipblas unsupported tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pnunna93 committed Feb 14, 2024
1 parent a3a30b1 commit 3b35d94
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/cupy_tests/linalg_tests/test_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def check_singular(self, shape, xp, dtype):
return result

@_condition.repeat(3, 10)
@pytest.mark.skipif(cupy.cuda.runtime.is_hip, reason="hipblasSgemmEx not implemented")
def test_svd_rank2(self):
self.check_usv((3, 7))
self.check_usv((2, 2))
Expand All @@ -267,6 +268,7 @@ def test_svd_rank2_no_uv(self):
self.check_singular((7, 3))

@testing.with_requires('numpy>=1.16')
@pytest.mark.skipif(cupy.cuda.runtime.is_hip, reason="hipblasSgemmEx not implemented")
def test_svd_rank2_empty_array(self):
self.check_usv((0, 3))
self.check_usv((3, 0))
Expand Down
2 changes: 2 additions & 0 deletions tests/cupy_tests/linalg_tests/test_eigenvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class TestEigenvalue:

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-4, contiguous_check=False)
@pytest.mark.skipif(cupy.cuda.runtime.is_hip, reason="hipblasSgemmEx not implemented")
def test_eigh(self, xp, dtype):
if xp == numpy and dtype == numpy.float16:
# NumPy's eigh does not support float16
Expand Down Expand Up @@ -55,6 +56,7 @@ def test_eigh(self, xp, dtype):

@testing.for_all_dtypes(no_bool=True, no_float16=True, no_complex=True)
@testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-4, contiguous_check=False)
@pytest.mark.skipif(cupy.cuda.runtime.is_hip, reason="hipblasSgemmEx not implemented")
def test_eigh_batched(self, xp, dtype):
a = xp.array([[[1, 0, 3], [0, 5, 0], [7, 0, 9]],
[[3, 0, 3], [0, 7, 0], [7, 0, 11]]], dtype)
Expand Down
2 changes: 2 additions & 0 deletions tests/cupyx_tests/test_lapack.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ def setUp(self):
self.x_ref = x.reshape(b_shape)
self.tol = self._tol[self.dtype.char.lower()]

@pytest.mark.skipif(cupy.cuda.runtime.is_hip, reason="hipblasSgemmEx not implemented")
def test_gesv(self):
lapack.gesv(self.a, self.b)
cupy.testing.assert_allclose(self.b, self.x_ref,
rtol=self.tol, atol=self.tol)

@pytest.mark.skipif(cupy.cuda.runtime.is_hip, reason="hipblasSgemmEx not implemented")
def test_invalid_cases(self):
if self.nrhs is None or self.nrhs == 1:
raise unittest.SkipTest()
Expand Down

0 comments on commit 3b35d94

Please sign in to comment.