Skip to content

Commit

Permalink
Update mm_group_quant.py template for MLIR upstream change (#23)
Browse files Browse the repository at this point in the history
The output shape must now be passed to the tensor.expand_shape op after an
upstream MLIR change. This allows us to patch IR for quantized llama2.
  • Loading branch information
monorimet authored May 29, 2024
1 parent c4db712 commit 4b451f8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ jobs:
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-cache-dir -r iree-requirements.txt -f https://iree.dev/pip-release-links.html
pip install -r requirements.txt -e .
- name: Run unit tests
Expand Down
11 changes: 2 additions & 9 deletions .github/workflows/test_shark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,10 @@ jobs:
path: SHARK
ref: "iree-turbine-switch"

# TODO: Replace with a sh script from shark repo
- name: "Install SHARK"
run: |
cd $GITHUB_WORKSPACE/SHARK
python${{ matrix.version }} -m venv shark.venv
source shark.venv/bin/activate
sed -i 's/iree-turbine#/iree-turbine.git@${{github.sha}}#/g' requirements.txt
pip install -r requirements.txt --no-cache-dir
pip install -e .
pip uninstall -y torch
pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip uninstall -y mpmath
pip install mpmath==1.3.0
./setup_venv.sh
source shark.venv/bin/activate
python apps/shark_studio/tests/api_test.py
8 changes: 6 additions & 2 deletions shark_turbine/transforms/quantization/mm_group_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,15 @@ def match(self, op: Operation):
%weight_raw = util.global.load @{param_name} : tensor<{k}x{n_div}xi8>
%m = tensor.dim %a, %c0 : tensor<{m}x{n}x{element_type}>
%k = tensor.dim %weight_raw, %c0 : tensor<{k}x{n_div}xi8>
%m_1 = tensor.dim %a, %c0 : tensor<{m}x{n}x{element_type}>
%k_1 = tensor.dim %weight_raw, %c0 : tensor<{k}x{n_div}xi8>
%group0_dim = arith.constant {group0} : index
%group1_dim = arith.constant {group1} : index
%scale = util.global.load @{param_name}.quant.scale : tensor<{k}x{group0}x{element_type}>
%zp = util.global.load @{param_name}.quant.zero_point : tensor<{k}x{group0}x{element_type}>
%weight = flow.tensor.bitcast %weight_raw : tensor<{k}x{n_div}xi8> -> tensor<{k}x{n}x{lowp_type}>
%a_exp = tensor.expand_shape %a [[0], [1, 2]] : tensor<{m}x{n}x{element_type}> into tensor<{m}x{group0}x{group1}x{element_type}>
%weight_exp = tensor.expand_shape %weight [[0], [1, 2]] : tensor<{k}x{n}x{lowp_type}> into tensor<{k}x{group0}x{group1}x{lowp_type}>
%a_exp = tensor.expand_shape %a [[0], [1, 2]] output_shape [%m_1, %group0_dim, %group1_dim]: tensor<{m}x{n}x{element_type}> into tensor<{m}x{group0}x{group1}x{element_type}>
%weight_exp = tensor.expand_shape %weight [[0], [1, 2]] output_shape [%k_1, %group0_dim, %group1_dim]: tensor<{k}x{n}x{lowp_type}> into tensor<{k}x{group0}x{group1}x{lowp_type}>
%empty_0 = tensor.empty() : tensor<{k}x{group0}x{group1}x{element_type}>
%weight_cast = linalg.generic {{
indexing_maps = [
Expand Down

0 comments on commit 4b451f8

Please sign in to comment.