From 5b75ea1e4b62d50d7f8d9d6455879bae831a5c51 Mon Sep 17 00:00:00 2001
From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com>
Date: Mon, 2 Dec 2024 17:45:07 -0800
Subject: [PATCH] [CI] Add punet export test (#623)

Adds back punet export test to the CI to run on pre-commit. Takes about
~15 minutes for the tests to run
---
 .github/workflows/ci-sharktank.yml            | 40 +++++++++++++++++++
 .../models/punet/integration_test.py          |  2 +
 2 files changed, 42 insertions(+)

diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml
index 8aabc8082..38de6eb22 100644
--- a/.github/workflows/ci-sharktank.yml
+++ b/.github/workflows/ci-sharktank.yml
@@ -130,3 +130,43 @@ jobs:
             --with-t5-data \
             sharktank/tests/models/t5/t5_test.py \
             --durations=0
+
+
+  test_integration:
+    name: "Model Integration Tests"
+    runs-on: ubuntu-24.04
+    env:
+      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+    steps:
+      - name: "Setting up Python"
+        id: setup_python
+        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        with:
+          python-version: 3.11
+
+      - name: "Checkout Code"
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Cache Pip Packages
+        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+        id: cache-pip
+        with:
+          path: ${{ env.PIP_CACHE_DIR }}
+          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
+
+      - name: Install pip deps
+        run: |
+          python -m pip install --no-compile --upgrade pip
+          # Note: We install in three steps in order to satisfy requirements
+          # from non default locations first. Installing the PyTorch CPU
+          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+          pip install --no-compile -r pytorch-cpu-requirements.txt
+          pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+          # Update to the latest iree packages.
+          pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+            iree-base-compiler iree-base-runtime --src deps \
+            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+      - name: Run punet tests
+        run: |
+          pytest -v sharktank/ -m punet_quick \
+            --durations=0
diff --git a/sharktank/integration/models/punet/integration_test.py b/sharktank/integration/models/punet/integration_test.py
index 45af24004..754a54311 100644
--- a/sharktank/integration/models/punet/integration_test.py
+++ b/sharktank/integration/models/punet/integration_test.py
@@ -143,6 +143,7 @@ def sdxl_fp16_export_mlir(sdxl_fp16_dataset, temp_dir):
     return output_path
 
 
+@pytest.mark.punet_quick
 @pytest.mark.model_punet
 @pytest.mark.export
 def test_sdxl_export_fp16_mlir(sdxl_fp16_export_mlir):
@@ -166,6 +167,7 @@ def sdxl_int8_export_mlir(sdxl_int8_dataset, temp_dir):
     return output_path
 
 
+@pytest.mark.punet_quick
 @pytest.mark.model_punet
 @pytest.mark.export
 def test_sdxl_export_int8_mlir(sdxl_int8_export_mlir):