diff --git a/sharktank/tests/export_test.py b/sharktank/tests/export_test.py index 20b7de734..44f16c113 100644 --- a/sharktank/tests/export_test.py +++ b/sharktank/tests/export_test.py @@ -4,6 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + from sharktank.types import ( ReplicatedTensor, SplitPrimitiveTensor, @@ -70,6 +72,7 @@ def testGetFlatArgumentDeviceAffinities(self): } assert_dicts_equal(affinities, expected_affinities) + @pytest.mark.xfail(torch.__version__ >= (2,4), reason="https://github.com/nod-ai/shark-ai/issues/685") def testExportWithArgumentDeviceAffinities(self): args = (ReplicatedTensor(ts=[torch.tensor([1])]), torch.tensor([[2]]))