Mixtral batching support #108

Merged · 10 commits · Dec 12, 2023
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.18)
project(mlc_llm C CXX)
project(mlc_llm C CXX CUDA)

include(CheckCXXCompilerFlag)
if(NOT MSVC)
1 change: 1 addition & 0 deletions cpp/conv_templates.cc
@@ -674,6 +674,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
{"wizardlm_7b", WizardLM7B},
{"wizard_coder_or_math", WizardCoderOrMATH},
{"glm", GLM},
{"mixtral_default", MistralDefault}
};
auto it = factory.find(name);
if (it == factory.end()) {
4 changes: 3 additions & 1 deletion mlc_llm/core.py
@@ -854,11 +854,13 @@ def build_model_from_args(args: argparse.Namespace):
"rwkv": rwkv,
"rwkv_world": rwkv,
"chatglm": chatglm,
"mixtral": llama,
}

if args.use_vllm_attention:
model_generators["llama"] = llama_batched_vllm
model_generators["mistral"] = llama_batched_vllm
model_generators["mixtral"] = llama_batched_vllm

assert args.model_category in model_generators, f"Model {args.model} not supported"

@@ -884,7 +886,7 @@ def build_model_from_args(args: argparse.Namespace):
# Run pre-quantization if provided.
args.model_path = param_manager.run_pre_quantize(args.model_path)
param_manager.init_torch_pname_to_bin_name(args.use_safetensors)
parameter_transforms.append(param_manager.create_parameter_transformation())
parameter_transforms.append(param_manager.create_parameter_transformation(optimize_parameter_order=False)) # disable to prevent errors

# Run pre-sharding if required
if args.num_shards > 1 and args.use_presharded_weights:
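For readers skimming the dispatch change above, a minimal standalone sketch of the behaviour it produces (the `pick_mixtral_generator` helper and the stub generators are hypothetical; only the mapping mirrors the diff): Mixtral reuses the Llama model builder and switches to the batched vLLM generator when vLLM attention is enabled.

```python
# Illustration only: stubs standing in for the generator modules imported in core.py.
def llama(*args, **kwargs): ...
def llama_batched_vllm(*args, **kwargs): ...

def pick_mixtral_generator(use_vllm_attention: bool):
    """Hypothetical helper mirroring the dispatch in build_model_from_args."""
    # Mixtral reuses the Llama model definition ...
    generator = llama
    if use_vllm_attention:
        # ... and switches to the batched vLLM generator when vLLM attention is on.
        generator = llama_batched_vllm
    return generator

assert pick_mixtral_generator(False) is llama
assert pick_mixtral_generator(True) is llama_batched_vllm
```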
8 changes: 6 additions & 2 deletions mlc_llm/quantization/ft_quantization.py
@@ -8,11 +8,12 @@
from tvm.relax.expr_functor import visitor

from . import tir_utils
from .quantization import QuantizationSpec, QuantSpecUpdater
from .quantization import QuantizationSpec, QuantSpecUpdater, NoQuantizationSpec
from .quantization import FQuantize, convert_TE_func
from .group_quantization import GroupQuantizationSpec



@dataclass
class FTQuantizationSpec(QuantizationSpec):
"""The quantization specification for the FasterTransformer kernel."""
@@ -203,7 +204,10 @@ def visit_call_(self, call: relax.Call):

param = self.param_map[rhs.args[0]]

if call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0:
if isinstance(rhs.struct_info.shape[-1], tvm.tir.IntImm) and int(rhs.struct_info.shape[-1]) <= 8:
# gate in MoE
param.quant_spec = NoQuantizationSpec("float16")
elif call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0:
# FT requires N to be a multiple of 8
# FT does not support fp32 output dtype
# TODO(masahi): If `matmul(..., out_dtype="float32")` is immediately followed
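To make the new branch above easier to follow: the Mixtral MoE router (gate) projects the hidden state onto one logit per expert, and Mixtral has 8 experts, so that weight's last dimension is at most 8; the diff keeps this tiny matrix in float16 rather than FT-quantizing it. A minimal sketch of the decision, using plain strings instead of the real QuantizationSpec classes (labels are illustrative only):

```python
def choose_quant_spec(out_features: int, out_dtype: str) -> str:
    # Labels are illustrative; the real code assigns QuantizationSpec objects.
    if out_features <= 8:
        # MoE gate/router weight (one column per expert): keep unquantized float16.
        return "no_quant_float16"
    if out_dtype == "float32" or out_features % 8 != 0:
        # FT kernels require N % 8 == 0 and do not produce fp32 output,
        # so such weights take the non-FT fallback path.
        return "non_ft_fallback"
    return "ft_quant"

assert choose_quant_spec(8, "float16") == "no_quant_float16"   # Mixtral router (8 experts)
assert choose_quant_spec(14336, "float16") == "ft_quant"        # expert FFN projection
```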
30 changes: 30 additions & 0 deletions mlc_llm/relax_model/commons.py
@@ -81,13 +81,43 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
func = te.create_prim_func([a, w])
return func

def moe_shard_k_weight_scale(weight: relax.TensorStructInfo):
(num_experts, red, spatial), dtype = weight.shape, weight.dtype
spatial, red = int(spatial), int(red)
if param_shape_is_already_sharded:
red *= num_shards
a = te.placeholder((num_experts, red, spatial), dtype=dtype)
w = topi.reshape(a, (num_experts, num_shards, red // num_shards, spatial))
w = topi.transpose(w, (1, 0, 2, 3))
func = te.create_prim_func([a, w])
return func

def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
(num_experts, red, spatial), dtype = weight.shape, weight.dtype
spatial, red = int(spatial), int(red)
if param_shape_is_already_sharded:
spatial *= num_shards
a = te.placeholder((num_experts, red, spatial), dtype=dtype)
g = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, j])
u = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, spatial // 2 + j])
g = topi.reshape(g, (num_experts, red, num_shards, spatial // 2 // num_shards))
u = topi.reshape(u, (num_experts, red, num_shards, spatial // 2 // num_shards))
w = topi.concatenate((g, u), axis=3)
w = topi.reshape(w, (num_experts, red, num_shards, spatial // num_shards))
w = topi.transpose(w, (2, 0, 1, 3))
func = te.create_prim_func([a, w])
return func
Member:

I think these functions are not used for FT + multi-GPU. Can you confirm?

Author:

Will double check. Doesn't it depend on disco sharding?

Member:

For FT quantization + disco, we need to use https://github.com/vinx13/mlc-llm/blob/113bd1873cb563151ed5675730be0e53560c7ab2/mlc_llm/relax_model/commons.py#L124

It hasn't been upstreamed yet (we should).

Author:

I see, q4 multi-GPU is probably broken. This is needed for q0f16 (we are using the FT kernel for MoE), though.



# pylint: enable=invalid-name

return {
"shard_qkv": shard_qkv_weight_scale,
"shard_mlp_k": shard_k_weight_scale,
"shard_o_proj_k": shard_k_weight_scale,
"shard_gate_up": shard_gate_up_weight_scale,
"moe_shard_mlp_k": moe_shard_k_weight_scale,
"moe_shard_gate_up": moe_shard_gate_up_weight_scale,
}
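For the sharding discussion above, a NumPy sketch (illustration only, not the TE functions themselves) of the layouts the two new transforms produce. `num_experts`, `red` (reduction dim), and `spatial` (output dim) follow the names in the diff, `num_shards` is the tensor-parallel degree, and the already-presharded-parameter case is ignored here:

```python
import numpy as np

def moe_shard_k(a: np.ndarray, num_shards: int) -> np.ndarray:
    # (experts, red, spatial) -> (shards, experts, red/shards, spatial):
    # each shard gets a contiguous slice of the reduction dimension.
    num_experts, red, spatial = a.shape
    w = a.reshape(num_experts, num_shards, red // num_shards, spatial)
    return w.transpose(1, 0, 2, 3)

def moe_shard_gate_up(a: np.ndarray, num_shards: int) -> np.ndarray:
    # The fused weight is [gate | up] along the last axis. Split the halves,
    # shard each along its output dimension, then re-fuse per shard so every
    # shard holds a matching [gate_slice | up_slice] pair.
    num_experts, red, spatial = a.shape
    half = spatial // 2
    g = a[..., :half].reshape(num_experts, red, num_shards, half // num_shards)
    u = a[..., half:].reshape(num_experts, red, num_shards, half // num_shards)
    w = np.concatenate((g, u), axis=3)   # (experts, red, shards, spatial/shards)
    return w.transpose(2, 0, 1, 3)       # (shards, experts, red, spatial/shards)
```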

