Mixtral batching support (#108)
* moe arch

* param mapping

* fix

* adopt hf weights

* fix

* speedup weight preprocess

* fix

* batch

* remove deadcode

* cleanup
vinx13 authored Dec 12, 2023
1 parent 73926cb commit 87eef11
Showing 9 changed files with 477 additions and 17 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -1,5 +1,5 @@
 cmake_minimum_required(VERSION 3.18)
-project(mlc_llm C CXX)
+project(mlc_llm C CXX CUDA)
 
 include(CheckCXXCompilerFlag)
 if(NOT MSVC)
1 change: 1 addition & 0 deletions cpp/conv_templates.cc
@@ -674,6 +674,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
       {"wizardlm_7b", WizardLM7B},
       {"wizard_coder_or_math", WizardCoderOrMATH},
       {"glm", GLM},
+      {"mixtral_default", MistralDefault}
   };
   auto it = factory.find(name);
   if (it == factory.end()) {
4 changes: 3 additions & 1 deletion mlc_llm/core.py
@@ -854,11 +854,13 @@ def build_model_from_args(args: argparse.Namespace):
         "rwkv": rwkv,
         "rwkv_world": rwkv,
         "chatglm": chatglm,
+        "mixtral": llama,
     }
 
     if args.use_vllm_attention:
         model_generators["llama"] = llama_batched_vllm
         model_generators["mistral"] = llama_batched_vllm
+        model_generators["mixtral"] = llama_batched_vllm
 
     assert args.model_category in model_generators, f"Model {args.model} not supported"
 
@@ -884,7 +886,7 @@ def build_model_from_args(args: argparse.Namespace):
         # Run pre-quantization if provided.
         args.model_path = param_manager.run_pre_quantize(args.model_path)
         param_manager.init_torch_pname_to_bin_name(args.use_safetensors)
-        parameter_transforms.append(param_manager.create_parameter_transformation())
+        parameter_transforms.append(param_manager.create_parameter_transformation(optimize_parameter_order=False))  # disable to prevent errors
 
         # Run pre-sharding if required
         if args.num_shards > 1 and args.use_presharded_weights:
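
Note on the registration above: Mixtral reuses the existing Llama relax model, and when vLLM attention is enabled it is routed to the batched generator instead. A minimal sketch of that dispatch pattern; plain strings stand in for the real module objects (`llama`, `llama_batched_vllm`) so the snippet is self-contained.

```python
# Sketch of the generator dispatch added in core.py; strings stand in for the
# real relax_model modules so this runs on its own.
def pick_model_generator(model_category: str, use_vllm_attention: bool) -> str:
    model_generators = {
        "llama": "llama",
        "mixtral": "llama",  # Mixtral shares the Llama architecture code path
    }
    if use_vllm_attention:
        # The batched vLLM-attention build replaces the single-sequence path.
        for name in ("llama", "mistral", "mixtral"):
            model_generators[name] = "llama_batched_vllm"
    assert model_category in model_generators, f"{model_category} not supported"
    return model_generators[model_category]


print(pick_model_generator("mixtral", use_vllm_attention=True))  # llama_batched_vllm
```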
8 changes: 6 additions & 2 deletions mlc_llm/quantization/ft_quantization.py
@@ -8,11 +8,12 @@
 from tvm.relax.expr_functor import visitor
 
 from . import tir_utils
-from .quantization import QuantizationSpec, QuantSpecUpdater
+from .quantization import QuantizationSpec, QuantSpecUpdater, NoQuantizationSpec
 from .quantization import FQuantize, convert_TE_func
 from .group_quantization import GroupQuantizationSpec
 
 
+
 @dataclass
 class FTQuantizationSpec(QuantizationSpec):
     """The quantization specification for the FasterTransformer kernel."""
@@ -203,7 +204,10 @@ def visit_call_(self, call: relax.Call):
 
         param = self.param_map[rhs.args[0]]
 
-        if call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0:
+        if isinstance(rhs.struct_info.shape[-1], tvm.tir.IntImm) and int(rhs.struct_info.shape[-1]) <= 8:
+            # gate in MoE
+            param.quant_spec = NoQuantizationSpec("float16")
+        elif call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0:
             # FT requires N to be a multiple of 8
             # FT does not support fp32 output dtype
             # TODO(masahi): If `matmul(..., out_dtype="float32")` is immediately followed
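
The new branch above leaves the MoE router ("gate") projection unquantized: its last dimension is only the number of experts (8 for Mixtral), so it stays in float16 instead of going through the FasterTransformer path. A standalone sketch of that decision, using plain integers in place of relax/TIR shapes and assuming an expert count of 8:

```python
# Illustrative reimplementation of the shape check above with plain Python
# values; the real code inspects relax struct_info and TIR IntImms.
NUM_EXPERTS = 8  # assumed, matches Mixtral 8x7B


def choose_quant_spec(out_dtype: str, weight_last_dim: int) -> str:
    if weight_last_dim <= NUM_EXPERTS:
        # MoE router ("gate") weight: output dim == number of experts,
        # keep it unquantized in float16.
        return "NoQuantizationSpec(float16)"
    if out_dtype == "float32" or weight_last_dim % 8 != 0:
        # FT kernels require N % 8 == 0 and cannot emit fp32 output,
        # so fall back to the non-FT path (the elif branch above).
        return "non-FT fallback"
    return "FT quantization"


print(choose_quant_spec("float16", 8))     # NoQuantizationSpec(float16)
print(choose_quant_spec("float16", 4096))  # FT quantization
```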
30 changes: 30 additions & 0 deletions mlc_llm/relax_model/commons.py
@@ -81,13 +81,43 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
         func = te.create_prim_func([a, w])
         return func
 
+    def moe_shard_k_weight_scale(weight: relax.TensorStructInfo):
+        (num_experts, red, spatial), dtype = weight.shape, weight.dtype
+        spatial, red = int(spatial), int(red)
+        if param_shape_is_already_sharded:
+            red *= num_shards
+        a = te.placeholder((num_experts, red, spatial), dtype=dtype)
+        w = topi.reshape(a, (num_experts, num_shards, red // num_shards, spatial))
+        w = topi.transpose(w, (1, 0, 2, 3))
+        func = te.create_prim_func([a, w])
+        return func
+
+    def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
+        (num_experts, red, spatial), dtype = weight.shape, weight.dtype
+        spatial, red = int(spatial), int(red)
+        if param_shape_is_already_sharded:
+            spatial *= num_shards
+        a = te.placeholder((num_experts, red, spatial), dtype=dtype)
+        g = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, j])
+        u = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, spatial // 2 + j])
+        g = topi.reshape(g, (num_experts, red, num_shards, spatial // 2 // num_shards))
+        u = topi.reshape(u, (num_experts, red, num_shards, spatial // 2 // num_shards))
+        w = topi.concatenate((g, u), axis=3)
+        w = topi.reshape(w, (num_experts, red, num_shards, spatial // num_shards))
+        w = topi.transpose(w, (2, 0, 1, 3))
+        func = te.create_prim_func([a, w])
+        return func
+
+
     # pylint: enable=invalid-name
 
     return {
         "shard_qkv": shard_qkv_weight_scale,
         "shard_mlp_k": shard_k_weight_scale,
         "shard_o_proj_k": shard_k_weight_scale,
         "shard_gate_up": shard_gate_up_weight_scale,
+        "moe_shard_mlp_k": moe_shard_k_weight_scale,
+        "moe_shard_gate_up": moe_shard_gate_up_weight_scale,
     }
 
 
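
A NumPy rendering of what `moe_shard_gate_up_weight_scale` computes may make the layout easier to follow: the fused weight stores the gate projection in the first `spatial // 2` columns and the up projection in the rest, and each shard must receive matching slices of both halves. All sizes below are made-up small values, and the real function emits a TE PrimFunc rather than manipulating arrays.

```python
# NumPy sketch of the moe_shard_gate_up transform, assuming a fused gate/up
# expert weight of shape (num_experts, red, spatial) with gate in the first
# spatial // 2 columns and up in the remaining columns.
import numpy as np

num_experts, red, spatial, num_shards = 8, 6, 16, 2
a = np.arange(num_experts * red * spatial, dtype=np.float16).reshape(
    num_experts, red, spatial
)

# Split the fused weight into its gate and up halves.
gate, up = a[..., : spatial // 2], a[..., spatial // 2 :]

# Slice each half per shard, then re-fuse so every shard keeps a matching
# gate slice and up slice side by side.
gate = gate.reshape(num_experts, red, num_shards, spatial // 2 // num_shards)
up = up.reshape(num_experts, red, num_shards, spatial // 2 // num_shards)
w = np.concatenate((gate, up), axis=3)  # (E, red, shards, spatial // shards)
w = w.transpose(2, 0, 1, 3)             # shard axis first: (shards, E, red, ...)

assert w.shape == (num_shards, num_experts, red, spatial // num_shards)
```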