Skip to content

Commit

Permalink
Merge pull request #67 from ROCm/shreya/polyval_tolarance
Browse files Browse the repository at this point in the history
Polyval test Assertion Error Fix
  • Loading branch information
pnunna93 authored Nov 7, 2024
2 parents 123e5e4 + 9c6dff7 commit 25b5b33
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 22 deletions.
23 changes: 2 additions & 21 deletions cupy/testing/_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from cupy.testing import _parameterized
import cupyx
import cupyx.scipy.sparse
from cupy_backends.cuda.api import runtime

from cupy.testing._pytest_impl import is_available

Expand Down Expand Up @@ -511,26 +510,8 @@ def numpy_cupy_allclose(rtol=1e-7, atol=0, err_msg='', verbose=True,

def check_func(c, n):
rtol1, atol1 = _resolve_tolerance(type_check, c, rtol, atol)
try:
_array.assert_allclose(c, n, rtol1, atol1, err_msg, verbose)
except AssertionError as e:
import numbers
if runtime.is_hip and type(n) is numpy.ndarray and \
(issubclass(n.dtype.type, numbers.Real) or
issubclass(n.dtype.type, numbers.Complex)):
npc = cupy.asnumpy(c)
diff = numpy.linalg.norm(npc - n, numpy.inf)
norm = numpy.linalg.norm(npc, numpy.inf)
if numpy.any(npc) and numpy.any(n):
min_positive = numpy.finfo(n.dtype).tiny
if (abs(npc) < min_positive).any() or \
(abs(n) < min_positive).any():
# Denormal case handling
assert diff <= (atol1 + rtol1 * norm)
else:
raise e
else:
raise e
_array.assert_allclose(c, n, rtol1, atol1, err_msg, verbose)

return _make_decorator(check_func, name, type_check, contiguous_check,
accept_error, sp_name, scipy_name,
_check_sparse_format)
Expand Down
9 changes: 8 additions & 1 deletion tests/cupy_tests/lib_tests/test_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from cupy.cuda import runtime
import cupyx
from cupy import testing
from cupy_backends.cuda.api import runtime as _runtime


@testing.parameterize(
Expand Down Expand Up @@ -694,7 +695,13 @@ def test_polyfit_weighted_diff_types(self, xp, dtype1, dtype2, dtype3):
class TestPolyval(Poly1dTestBase):

@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol={numpy.float16: 1e-2, 'default': 1e-3})
@testing.numpy_cupy_allclose(
rtol=(
{numpy.float16: 1e-2, 'default': 2e-3}
if _runtime.is_hip
else {numpy.float16: 1e-2, 'default': 1e-3}
)
)
def test_polyval(self, xp, dtype):
a1 = self._get_input(xp, self.type_l, dtype, size=5)
a2 = self._get_input(xp, self.type_r, dtype, size=5)
Expand Down

0 comments on commit 25b5b33

Please sign in to comment.