From 5b75ea1e4b62d50d7f8d9d6455879bae831a5c51 Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:45:07 -0800 Subject: [PATCH] [CI] Add punet export test (#623) Adds back punet export test to the CI to run on pre-commit. Takes about ~15 minutes for the tests to run --- .github/workflows/ci-sharktank.yml | 40 +++++++++++++++++++ .../models/punet/integration_test.py | 2 + 2 files changed, 42 insertions(+) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 8aabc8082..38de6eb22 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -130,3 +130,43 @@ jobs: --with-t5-data \ sharktank/tests/models/t5/t5_test.py \ --durations=0 + + + test_integration: + name: "Model Integration Tests" + runs-on: ubuntu-24.04 + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: 3.11 + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + # Update to the latest iree packages. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler iree-base-runtime --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + - name: Run punet tests + run: | + pytest -v sharktank/ -m punet_quick \ + --durations=0 diff --git a/sharktank/integration/models/punet/integration_test.py b/sharktank/integration/models/punet/integration_test.py index 45af24004..754a54311 100644 --- a/sharktank/integration/models/punet/integration_test.py +++ b/sharktank/integration/models/punet/integration_test.py @@ -143,6 +143,7 @@ def sdxl_fp16_export_mlir(sdxl_fp16_dataset, temp_dir): return output_path +@pytest.mark.punet_quick @pytest.mark.model_punet @pytest.mark.export def test_sdxl_export_fp16_mlir(sdxl_fp16_export_mlir): @@ -166,6 +167,7 @@ def sdxl_int8_export_mlir(sdxl_int8_dataset, temp_dir): return output_path +@pytest.mark.punet_quick @pytest.mark.model_punet @pytest.mark.export def test_sdxl_export_int8_mlir(sdxl_int8_export_mlir):