diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 7f9cd760aee5..e910e8723b9b 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -59,5 +59,5 @@ jobs: - name: Run tests run: | cd $WORKSPACE_DIR - python3 build/rocm/ci_build test $TEST_IMAGE + python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py" diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 2556633482ff..86faa5e42fdd 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -202,7 +202,7 @@ def dist_docker( subprocess.check_call(cmd) -def test(image_name): +def test(image_name, test_cmd): """Run unit tests like CI would inside a JAX image.""" gpu_args = [ @@ -236,7 +236,7 @@ def test(image_name): cmd.extend(mounts) cmd.extend(gpu_args) - container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh" + container_cmd = "cd /jax && " + test_cmd cmd.append(image_name) cmd.extend( [ @@ -299,6 +299,7 @@ def parse_args(): testp = subp.add_parser("test") testp.add_argument("image_name") + testp.add_argument("--test-cmd", default="./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh") ddp = subp.add_parser("dist_docker") ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms") @@ -322,7 +323,7 @@ def main(): ) elif args.action == "test": - test(args.image_name) + test(args.image_name, args.test_cmd) elif args.action == "dist_docker": dist_wheels(