From 5789469d857457d5223af9b528be8ffc467cb4bd Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 11 Dec 2023 00:06:49 +0000 Subject: [PATCH 01/10] moe arch --- mlc_llm/core.py | 13 ++ mlc_llm/relax_model/llama.py | 196 +++++++++++++++++++++++++-- mlc_llm/relax_model/param_manager.py | 7 + mlc_llm/utils.py | 21 +-- 4 files changed, 218 insertions(+), 19 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 1bd0d0266a..208b8ffd5b 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -476,6 +476,14 @@ def _parse_args(parsed) -> argparse.Namespace: def _setup_model_path(args: argparse.Namespace): # pylint: disable=too-many-branches + if 'mixtral' in args.model: + if os.path.isdir(args.model): + args.model = os.path.normpath(args.model) # Remove potential trailing `/` + args.model_path = args.model + args.model = os.path.basename(args.model) + else: + args.model_path = os.path.join(args.artifact_path, "models", args.model) + return args if args.hf_path: if args.model != "auto": assert args.model == os.path.basename(args.hf_path), ( @@ -838,6 +846,9 @@ def build_model_from_args(args: argparse.Namespace): if args.model_category == "minigpt": # Special case for minigpt, which neither provides nor requires a configuration. config = {} + elif "mixtral" in args.model: + with open(os.path.join(args.model_path, "params.json"), encoding="utf-8") as i_f: + config = json.load(i_f) else: with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f) @@ -854,11 +865,13 @@ def build_model_from_args(args: argparse.Namespace): "rwkv": rwkv, "rwkv_world": rwkv, "chatglm": chatglm, + "mixtral": llama, } if args.use_vllm_attention: model_generators["llama"] = llama_batched_vllm model_generators["mistral"] = llama_batched_vllm + model_generators["mixtral"] = llama_batched_vllm assert args.model_category in model_generators, f"Model {args.model} not supported" diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 88fde9509a..225e889080 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -73,6 +73,26 @@ def get_num_key_value_heads(self): return self.num_key_value_heads +class MixtralConfig(LlamaConfig): + num_experts_per_tok: int + num_experts: int + def __init__( + self, + **kwargs, + ): + kwargs["num_attention_heads"] = kwargs["n_heads"] + kwargs["num_key_value_heads"] = kwargs["n_kv_heads"] + kwargs["rms_norm_eps"] = kwargs["norm_eps"] + kwargs["num_hidden_layers"] = kwargs["n_layers"] + kwargs["intermediate_size"] = kwargs["hidden_dim"] + kwargs["hidden_size"] = kwargs["dim"] # n heads * head_size + + super().__init__(**kwargs) + moe_config = kwargs["moe"] + self.num_experts_per_tok = moe_config["num_experts_per_tok"] + self.num_experts = moe_config["num_experts"] + + class Linear(nn.Module): def __init__(self, in_features, out_features, dtype: str, bias=True): self.in_features = in_features @@ -556,12 +576,148 @@ def attention_fwd( return attn_output, past_key_values +class MoELinear(nn.Module): + def __init__(self, num_experts, in_features, out_features, bias=False): + assert not bias, "bias not supported" + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + + # weight is row major + self.weight = nn.Parameter( + (num_experts, in_features, out_features), + dtype="float16", + name="expert_weight", + ) + + def forward(self, x, rows_before): + assert len(x.struct_info.shape) == 2 + total_rows = x.struct_info.shape[0] + return nn.emit( + relax.call_dps_packed( + "cutlass.moe_gemm_f16f16", + [ + x, + self.weight, + rows_before, + total_rows, + self.out_features, # gemm_n + self.in_features, # gemm_k + self.num_experts, + ], + out_sinfo=x.struct_info, + ) + ) + +class MoEMLP(nn.Module): + def __init__(self, config: MixtralConfig): + self.num_experts_per_tok = config.num_experts_per_tok + self.num_experts = config.num_experts + self.linear1 = MoELinear(self.num_experts, config.hidden_size, config.intermediate_size, bias=False) + self.linear2 = MoELinear(self.num_experts, config.intermediate_size, config.hidden_size, bias=False) + self.linear3 = MoELinear(self.num_experts, config.hidden_size, config.intermediate_size, bias=False) + + def forward(self, hidden_states: relax.Expr, rows_before: relax.Expr): + # TODO: combine matmul + # TODO: disco + gate_result = self.linear1(hidden_states, rows_before) + up_result = self.linear3(hidden_states, rows_before) + result = self.linear2(nn.emit(relax.op.nn.silu(gate_result) * up_result), rows_before) + return result + +class MoE(nn.Module): + def __init__(self, config: MixtralConfig): + self.experts = MoEMLP(config) + self.gate = Linear(in_features=config.hidden_size, out_features=config.num_experts, bias=False, dtype=config.dtype) + self.num_experts_per_tok = config.num_experts_per_tok + self.num_experts = config.num_experts + + def topk(self, x, is_ascend, index_dtype, k = -1): + # topk along axis -1 + result = nn.emit( + relax.call_dps_packed( + "tvm.contrib.thrust.sort_dps", + [x, is_ascend], + out_sinfo= + [ + x.struct_info, + relax.TensorStructInfo(x.struct_info.shape, index_dtype), + ] + ) + ) + sorted_x = relax.TupleGetItem(result, 0) + indices = relax.TupleGetItem(result, 1) + if k != -1: + ndim = len(x.struct_info.shape) + beg = [0] * ndim + end = [x.struct_info.shape[i] for i in range(ndim - 1)] + [k] + axes = list(range(ndim)) + sorted_x = nn.emit(relax.op.strided_slice(sorted_x, axes, beg, end)) + indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end)) + return sorted_x, indices + + def compute_rows_before(self, sorted_expert_ids): + return nn.emit( + relax.call_dps_packed( + "moe_compute_rows_before", + [ sorted_expert_ids], + out_sinfo=relax.TensorStructInfo([self.num_experts], "int64") + ) + ) + + def scatter(self, linear_out, indices): + return nn.emit( + relax.call_dps_packed( + "scatter", + [linear_out, indices], + out_sinfo=linear_out.struct_info, + ) + ) + + def get_token_indices(self, indices): + def te_compute(x): + return tvm.te.compute(x.shape, lambda *idx: tvm.tir.indexdiv(x(*idx), tvm.runtime.const(self.num_experts_per_tok, dtype="int32")).astype("int32")) + return nn.emit_te(te_compute, indices) + + def forward(self, hidden_states): + hidden_states_shape = hidden_states.struct_info.shape + hidden_size = hidden_states_shape[-1] + # reshape to 2D + hidden_states = nn.emit(relax.op.reshape(hidden_states, (-1, hidden_size))) + + # TODO: switch topk softmax + gate = self.gate(hidden_states) + scores = nn.emit(relax.op.nn.softmax(gate, axis=-1)) + + expert_weights, expert_indices = self.topk(scores, is_ascend=False, k=self.num_experts_per_tok, index_dtype="int32") # (num_tokens, top_k), (num_tokens, top_k) + flattened_indices = nn.emit(relax.op.flatten(expert_indices)) + sorted_expert_ids, indices = self.topk(flattened_indices, is_ascend=True, index_dtype="int32") + + rows_before = self.compute_rows_before(sorted_expert_ids) + token_indices = self.get_token_indices(indices) + gathered_x = nn.emit(relax.op.take(hidden_states, token_indices, axis=0)) + linear_out = self.experts(gathered_x, rows_before) + unpermuted = self.scatter(linear_out, indices) + unflattened = nn.emit(relax.op.reshape(unpermuted, (-1, self.num_experts_per_tok, hidden_size))) + expert_weights = nn.emit(relax.op.reshape(expert_weights, (-1, self.num_experts_per_tok, 1))) + weighted_sum = nn.emit(relax.op.sum(unflattened * expert_weights, axis=1)) + + # reshape back to 3D + weighted_sum = nn.emit(relax.op.reshape(weighted_sum, hidden_states_shape)) + return weighted_sum + + class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, enable_batching: bool): attn_class = LlamaPagedAttention if enable_batching else LlamaAttention self.hidden_size = config.hidden_size self.self_attn = attn_class(config) - self.mlp = LlamaMLP(config) + if isinstance(config, MixtralConfig): + self.use_moe = True + self.feed_forward = MoE(config) + else: + self.use_moe = False + self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm( config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps ) @@ -578,17 +734,23 @@ def post_self_attn(self, hidden_states, residual): if self.self_attn.num_shards > 1: hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if self.mlp.num_shards > 1: - residual = nn.emit( - residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.mlp.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + if not self.use_moe: + # Fully Connected + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit( + residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + else: + # TODO: disco integration + hidden_states = self.feed_forward(hidden_states) + hidden_states = nn.emit(residual + hidden_states) return hidden_states @@ -1310,7 +1472,19 @@ def get_model(args, hf_config): # while Llama-1 variants use `max_sequence_length`. # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. # If none of them is defined, throw an error. - if "max_sequence_length" in hf_config: + if 'mixtral' in args.model: + # FIXME + hf_config['max_sequence_length'] = 4096 + # hf_config['num_attention_heads'] = + config = MixtralConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + elif "max_sequence_length" in hf_config: config = LlamaConfig( **hf_config, dtype=dtype, diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index 69a25ccb73..ae43f8d644 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -904,6 +904,13 @@ def load_torch_pname2binname_map( for relax_pname in relax_pnames for torch_pname in f_convert_pname_fwd(relax_pname) } + elif "mixtral" in model_path: + ckpt_path = os.path.join(model_path, "consolidated.00.pth") + torch_pname2binname = { + torch_pname: ckpt_path + for relax_pname in relax_pnames + for torch_pname in f_convert_pname_fwd(relax_pname) + } else: suffix = ".safetensors" if use_safetensors else ".bin" shard_names = [] diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 7d34505da9..57ffa69143 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -69,15 +69,20 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: "rwkv_world": "rwkv_world", "minigpt": "minigpt", } - try: - with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: - config = json.load(i_f) - args.model_category = config["model_type"] + if "mixtral" in args.model_path: + args.model_category = "mixtral" + config = open(os.path.join(args.model_path, "params.json"), encoding="utf-8") model_path_lower = args.model_path.lower() - if "rwkv" in model_path_lower and "world" in model_path_lower: - args.model_category = "rwkv_world" - except Exception: - args.model_category = "" + else: + try: + with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: + config = json.load(i_f) + args.model_category = config["model_type"] + model_path_lower = args.model_path.lower() + if "rwkv" in model_path_lower and "world" in model_path_lower: + args.model_category = "rwkv_world" + except Exception: + args.model_category = "" model = args.model.lower() if "rwkv" in model and "world" in model: model = "rwkv_world" From 2643016a4558fae6efebfa5080be02d587ab3eb4 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 11 Dec 2023 00:42:27 +0000 Subject: [PATCH 02/10] param mapping --- mlc_llm/relax_model/commons.py | 30 ++++ mlc_llm/relax_model/llama.py | 302 ++++++++++++++++++++++++++------- 2 files changed, 271 insertions(+), 61 deletions(-) diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py index 3eb67b5b1c..26441c411b 100644 --- a/mlc_llm/relax_model/commons.py +++ b/mlc_llm/relax_model/commons.py @@ -81,6 +81,34 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo): func = te.create_prim_func([a, w]) return func + def moe_shard_k_weight_scale(weight: relax.TensorStructInfo): + (num_experts, red, spatial), dtype = weight.shape, weight.dtype + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + red *= num_shards + a = te.placeholder((num_experts, red, spatial), dtype=dtype) + w = topi.reshape(a, (num_experts, num_shards, red // num_shards, spatial)) + w = topi.transpose(w, (1, 0, 2, 3)) + func = te.create_prim_func([a, w]) + return func + + def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo): + (num_experts, red, spatial), dtype = weight.shape, weight.dtype + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + spatial *= num_shards + a = te.placeholder((num_experts, red, spatial), dtype=dtype) + g = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, j]) + u = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, spatial // 2 + j]) + g = topi.reshape(g, (num_experts, red, num_shards, spatial // 2 // num_shards)) + u = topi.reshape(u, (num_experts, red, num_shards, spatial // 2 // num_shards)) + w = topi.concatenate((g, u), axis=3) + w = topi.reshape(w, (num_experts, red, num_shards, spatial // num_shards)) + w = topi.transpose(w, (2, 0, 1, 3)) + func = te.create_prim_func([a, w]) + return func + + # pylint: enable=invalid-name return { @@ -88,6 +116,8 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo): "shard_mlp_k": shard_k_weight_scale, "shard_o_proj_k": shard_k_weight_scale, "shard_gate_up": shard_gate_up_weight_scale, + "moe_shard_mlp_k": moe_shard_k_weight_scale, + "moe_shard_gate_up": moe_shard_gate_up_weight_scale, } diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 225e889080..3a2502c27a 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -76,6 +76,8 @@ def get_num_key_value_heads(self): class MixtralConfig(LlamaConfig): num_experts_per_tok: int num_experts: int + quantization_scheme: QuantizationScheme + def __init__( self, **kwargs, @@ -92,6 +94,9 @@ def __init__( self.num_experts_per_tok = moe_config["num_experts_per_tok"] self.num_experts = moe_config["num_experts"] + # FIXME: remove this + self.quantization_scheme = kwargs["quantization_scheme"] + class Linear(nn.Module): def __init__(self, in_features, out_features, dtype: str, bias=True): @@ -577,72 +582,135 @@ def attention_fwd( class MoELinear(nn.Module): - def __init__(self, num_experts, in_features, out_features, bias=False): + def __init__(self, config: MixtralConfig, num_experts, in_features, out_features, bias=False): assert not bias, "bias not supported" self.num_experts = num_experts self.in_features = in_features self.out_features = out_features - - # weight is row major - self.weight = nn.Parameter( - (num_experts, in_features, out_features), - dtype="float16", - name="expert_weight", - ) + self.quantization_scheme = config.quantization_scheme + + if config.quantization_scheme.name == "q0f16": + # weight is row major + self.weight = nn.Parameter( + (num_experts, in_features, out_features), + dtype="float16", + name="expert_weight", + ) + elif config.quantization_scheme.name == "q4f16_ft": + self.weight = nn.Parameter( + (num_experts, in_features, out_features), + dtype="int8", + name="expert_weight", + ) + self.scales = nn.Parameter( + (num_experts, out_features), + dtype="float16", + name="expert_scales", + ) + else: + assert False, "unsupported quantization scheme" def forward(self, x, rows_before): assert len(x.struct_info.shape) == 2 total_rows = x.struct_info.shape[0] - return nn.emit( - relax.call_dps_packed( - "cutlass.moe_gemm_f16f16", - [ - x, - self.weight, - rows_before, - total_rows, - self.out_features, # gemm_n - self.in_features, # gemm_k - self.num_experts, - ], - out_sinfo=x.struct_info, + if self.quantization_scheme.name == "q0f16": + return nn.emit( + relax.call_dps_packed( + "cutlass.moe_gemm_f16f16", + [ + x, + self.weight, + rows_before, + total_rows, + self.out_features, # gemm_n + self.in_features, # gemm_k + self.num_experts, + ], + out_sinfo=relax.TensorStructInfo( + (total_rows, self.out_features), + x.struct_info.dtype, + ), + ) ) - ) + class MoEMLP(nn.Module): def __init__(self, config: MixtralConfig): self.num_experts_per_tok = config.num_experts_per_tok self.num_experts = config.num_experts - self.linear1 = MoELinear(self.num_experts, config.hidden_size, config.intermediate_size, bias=False) - self.linear2 = MoELinear(self.num_experts, config.intermediate_size, config.hidden_size, bias=False) - self.linear3 = MoELinear(self.num_experts, config.hidden_size, config.intermediate_size, bias=False) + self.combine_matmul = config.combine_matmul + + self.num_shards = config.num_shards + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size // self.num_shards + + self.down_proj = MoELinear( + config, self.num_experts, intermediate_size, hidden_size, bias=False + ) + if config.combine_matmul: + self.gate_up_combined_proj = MoELinear( + config, + self.num_experts, + hidden_size, + 2 * intermediate_size, + bias=False, + ) + # FIXME: rename to 'gate_up_proj' that's consistent with llama. using this name for now to avoid conflicting pname str replacing rules + # TODO: check sharding is correct, note that the weight is row major + self.gate_up_combined_proj.weight.shard_dim = 2 + self.gate_up_combined_proj.weight.shard_strategy = "moe_shard_gate_up" + self.down_proj.weight.shard_dim = 1 + self.down_proj.weight.shard_strategy = "moe_shard_mlp_k" + else: + self.gate_proj = MoELinear( + config, self.num_experts, config.hidden_size, config.intermediate_size, bias=False + ) + self.up_proj = MoELinear( + config, self.num_experts, config.hidden_size, config.intermediate_size, bias=False + ) def forward(self, hidden_states: relax.Expr, rows_before: relax.Expr): - # TODO: combine matmul # TODO: disco - gate_result = self.linear1(hidden_states, rows_before) - up_result = self.linear3(hidden_states, rows_before) - result = self.linear2(nn.emit(relax.op.nn.silu(gate_result) * up_result), rows_before) + if self.combine_matmul: + gate_up_results = nn.emit( + relax.op.split( + self.gate_up_combined_proj(hidden_states, rows_before), + indices_or_sections=2, + axis=-1, + ) + ) + gate_result = relax.TupleGetItem(gate_up_results, 0) + up_result = relax.TupleGetItem(gate_up_results, 1) + else: + gate_result = self.gate_proj(hidden_states, rows_before) + up_result = self.up_proj(hidden_states, rows_before) + result = self.down_proj(nn.emit(relax.op.nn.silu(gate_result) * up_result), rows_before) return result + class MoE(nn.Module): def __init__(self, config: MixtralConfig): self.experts = MoEMLP(config) - self.gate = Linear(in_features=config.hidden_size, out_features=config.num_experts, bias=False, dtype=config.dtype) + self.num_shards = config.num_shards + self.gate = Linear( + in_features=config.hidden_size, + out_features=config.num_experts, + bias=False, + dtype=config.dtype, + ) self.num_experts_per_tok = config.num_experts_per_tok self.num_experts = config.num_experts - def topk(self, x, is_ascend, index_dtype, k = -1): + def topk(self, x, is_ascend, index_dtype, k=-1): # topk along axis -1 result = nn.emit( relax.call_dps_packed( "tvm.contrib.thrust.sort_dps", [x, is_ascend], - out_sinfo= - [ - x.struct_info, - relax.TensorStructInfo(x.struct_info.shape, index_dtype), - ] + out_sinfo=[ + x.struct_info, + relax.TensorStructInfo(x.struct_info.shape, index_dtype), + ], ) ) sorted_x = relax.TupleGetItem(result, 0) @@ -660,8 +728,8 @@ def compute_rows_before(self, sorted_expert_ids): return nn.emit( relax.call_dps_packed( "moe_compute_rows_before", - [ sorted_expert_ids], - out_sinfo=relax.TensorStructInfo([self.num_experts], "int64") + [sorted_expert_ids], + out_sinfo=relax.TensorStructInfo([self.num_experts], "int64"), ) ) @@ -676,7 +744,13 @@ def scatter(self, linear_out, indices): def get_token_indices(self, indices): def te_compute(x): - return tvm.te.compute(x.shape, lambda *idx: tvm.tir.indexdiv(x(*idx), tvm.runtime.const(self.num_experts_per_tok, dtype="int32")).astype("int32")) + return tvm.te.compute( + x.shape, + lambda *idx: tvm.tir.indexdiv( + x(*idx), tvm.runtime.const(self.num_experts_per_tok, dtype="int32") + ).astype("int32"), + ) + return nn.emit_te(te_compute, indices) def forward(self, hidden_states): @@ -689,17 +763,25 @@ def forward(self, hidden_states): gate = self.gate(hidden_states) scores = nn.emit(relax.op.nn.softmax(gate, axis=-1)) - expert_weights, expert_indices = self.topk(scores, is_ascend=False, k=self.num_experts_per_tok, index_dtype="int32") # (num_tokens, top_k), (num_tokens, top_k) + expert_weights, expert_indices = self.topk( + scores, is_ascend=False, k=self.num_experts_per_tok, index_dtype="int32" + ) # (num_tokens, top_k), (num_tokens, top_k) flattened_indices = nn.emit(relax.op.flatten(expert_indices)) - sorted_expert_ids, indices = self.topk(flattened_indices, is_ascend=True, index_dtype="int32") + sorted_expert_ids, indices = self.topk( + flattened_indices, is_ascend=True, index_dtype="int32" + ) rows_before = self.compute_rows_before(sorted_expert_ids) token_indices = self.get_token_indices(indices) gathered_x = nn.emit(relax.op.take(hidden_states, token_indices, axis=0)) linear_out = self.experts(gathered_x, rows_before) unpermuted = self.scatter(linear_out, indices) - unflattened = nn.emit(relax.op.reshape(unpermuted, (-1, self.num_experts_per_tok, hidden_size))) - expert_weights = nn.emit(relax.op.reshape(expert_weights, (-1, self.num_experts_per_tok, 1))) + unflattened = nn.emit( + relax.op.reshape(unpermuted, (-1, self.num_experts_per_tok, hidden_size)) + ) + expert_weights = nn.emit( + relax.op.reshape(expert_weights, (-1, self.num_experts_per_tok, 1)) + ) weighted_sum = nn.emit(relax.op.sum(unflattened * expert_weights, axis=1)) # reshape back to 3D @@ -737,20 +819,16 @@ def post_self_attn(self, hidden_states, residual): residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - if not self.use_moe: - # Fully Connected - hidden_states = self.mlp(hidden_states) - if self.mlp.num_shards > 1: - residual = nn.emit( - residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.mlp.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - else: - # TODO: disco integration - hidden_states = self.feed_forward(hidden_states) - hidden_states = nn.emit(residual + hidden_states) + model = self.feed_forward if self.use_moe else self.mlp + + hidden_states = model(hidden_states) + if model.num_shards > 1: + residual = nn.emit( + residual / R.const(model.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if model.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) return hidden_states @@ -1370,12 +1448,56 @@ def kv_cache_transpose_append( def setup_params(mod, param_manager, dtype, config, args): + mappings = [ + ("embed_tokens", "tok_embeddings"), + ("lm_head", "output"), + ("input_layernorm", "attention_norm"), + ("self_attn", "attention"), + ("post_attention_layernorm", "ffn_norm"), + ("o_proj", "wo"), + ("q_proj", "wq"), + ("k_proj", "wk"), + ("v_proj", "wv"), + ("gate_proj", "w1"), + ("down_proj", "w2"), + ("up_proj", "w3"), + ] + + assert isinstance(config, MixtralConfig) + def f_convert_pname_fwd(pname: str) -> List[str]: + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + + assert isinstance(config, MixtralConfig) + if isinstance(config, MixtralConfig): + for k, v in mappings: + pname = pname.replace(k, v) + pname = pname.replace("model.", "") + + if config.combine_matmul: + if qkv_str in pname: + return [ + pname.replace(qkv_str, "wq"), + pname.replace(qkv_str, "wk"), + pname.replace(qkv_str, "wv"), + ] + if "experts.gate_up_combined_proj" in pname: + return [ + pname.replace("experts.gate_up_combined_proj", f"experts.{i}.w1") + for i in range(config.num_experts) + ] + [ + pname.replace("experts.gate_up_combined_proj", f"experts.{i}.w3") + for i in range(config.num_experts) + ] + + if "experts" in pname: + # not needed if using combine_matmul + return [pname.replace("experts", f"experts.{i}") for i in range(config.num_experts)] + if not config.combine_matmul: return [pname] - qkv_str = "query_key_value_proj" - gate_up_str = "gate_up_proj" if qkv_str in pname: return [ pname.replace(qkv_str, "q_proj"), @@ -1391,10 +1513,18 @@ def f_convert_pname_fwd(pname: str) -> List[str]: return [pname] def f_convert_param_bkwd(torch_pname: str, torch_param): + if isinstance(config, MixtralConfig): + if "experts" in torch_pname: + return None + for v, k in mappings: + torch_pname = torch_pname.replace(k, v) + if "lm_head" not in torch_pname: + torch_pname = "model." + torch_pname if not config.combine_matmul: return [(torch_pname, torch_param.astype(dtype))] combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + # print('bkwd pname ', torch_pname) if any([name in torch_pname for name in combined_layers]): return None return [(torch_pname, torch_param.astype(dtype))] @@ -1403,6 +1533,29 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): # Expected to enter this function only for the combined linear matmul weights. # Other weights are supposed to be loaded in `f_convert_param_bkwd` since # each other relax param has a unique corresponding torch param. + if isinstance(config, MixtralConfig): + if "gate_up_combined_proj" in relax_pname: + # combine along out_features dimension and then experts dimension + experts = [] + assert len(torch_params) == 2 * config.num_experts + + for i in range(config.num_experts): + gate, up = ( + torch_params[i], + torch_params[i + config.num_experts], + ) # torch weight in col major + gate_up = np.concatenate([gate, up], axis=0).astype(dtype) + experts.append(gate_up.transpose()) + result = np.stack(experts) + return result + if "experts" in relax_pname: + experts = [expert.astype(dtype).transpose() for expert in torch_params] + result = np.stack(experts) + # torch_params = [torch.from_numpy(param).cuda() for param in torch_params] + # experts = [expert.type(dtype).transpose(1, 0) for expert in torch_params] + # result = torch.stack(experts).detach().numpy() + return result + if not config.combine_matmul: # When matmul combination is not turned on, each relax param has a unique # corresponding torch param, and this function is not expected to be entered. @@ -1454,6 +1607,30 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): return mod, param_manager, param_list, config +def get_scatter_func(dtype): + import tvm.script.tir as T + + @T.prim_func + def scatter_func( + x_handle: T.handle, + indices_handle: T.handle, + out_handle: T.handle, + ) -> None: + total_rows = T.int64() + hidden_size = T.int64() + x = T.match_buffer(x_handle, (total_rows, hidden_size), dtype) + indices = T.match_buffer(indices_handle, (total_rows,), "int32") + out = T.match_buffer(out_handle, (total_rows, hidden_size), dtype) + T.func_attr({"global_symbol": "scatter", "tir.noalias": True}) + for i in range(total_rows): + for j in range(hidden_size): + with T.block("scatter"): + vi, vj = T.axis.remap("SS", [i, j]) + out[indices[vi], vj] = x[vi, vj] + + return scatter_func + + def get_model(args, hf_config): model_name = args.model dtype = args.quantization.model_dtype @@ -1472,9 +1649,9 @@ def get_model(args, hf_config): # while Llama-1 variants use `max_sequence_length`. # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. # If none of them is defined, throw an error. - if 'mixtral' in args.model: + if "mixtral" in args.model: # FIXME - hf_config['max_sequence_length'] = 4096 + hf_config["max_sequence_length"] = 4096 # hf_config['num_attention_heads'] = config = MixtralConfig( **hf_config, @@ -1483,6 +1660,7 @@ def get_model(args, hf_config): combine_matmul=True, num_shards=args.num_shards, build_model_only=args.build_model_only, + quantization_scheme=args.quantization, ) elif "max_sequence_length" in hf_config: config = LlamaConfig( @@ -1515,6 +1693,8 @@ def get_model(args, hf_config): param_manager = ParamManager() bb = relax.BlockBuilder() + bb.add_func(get_scatter_func(dtype), "scatter") + if sep_embed: create_embed_func(bb, param_manager, config, args.quantization) From c0722775c626eb8efd50da2f1a487fc431372863 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 11 Dec 2023 16:39:18 +0000 Subject: [PATCH 03/10] fix --- CMakeLists.txt | 2 +- cpp/conv_templates.cc | 1 + mlc_llm/quantization/ft_quantization.py | 8 ++- mlc_llm/relax_model/llama.py | 86 +++++++++++++++++++------ mlc_llm/relax_model/param_manager.py | 4 +- 5 files changed, 78 insertions(+), 23 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index eb09469b48..5f9aea0b79 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.18) -project(mlc_llm C CXX) +project(mlc_llm C CXX CUDA) include(CheckCXXCompilerFlag) if(NOT MSVC) diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index f1e74173bd..1ed5297ecc 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -674,6 +674,7 @@ Conversation Conversation::FromTemplate(const std::string& name) { {"wizardlm_7b", WizardLM7B}, {"wizard_coder_or_math", WizardCoderOrMATH}, {"glm", GLM}, + {"mixtral_default", Llama2} }; auto it = factory.find(name); if (it == factory.end()) { diff --git a/mlc_llm/quantization/ft_quantization.py b/mlc_llm/quantization/ft_quantization.py index 39a35d8547..c05c296178 100644 --- a/mlc_llm/quantization/ft_quantization.py +++ b/mlc_llm/quantization/ft_quantization.py @@ -8,11 +8,12 @@ from tvm.relax.expr_functor import visitor from . import tir_utils -from .quantization import QuantizationSpec, QuantSpecUpdater +from .quantization import QuantizationSpec, QuantSpecUpdater, NoQuantizationSpec from .quantization import FQuantize, convert_TE_func from .group_quantization import GroupQuantizationSpec + @dataclass class FTQuantizationSpec(QuantizationSpec): """The quantization specification for the FasterTransformer kernel.""" @@ -203,7 +204,10 @@ def visit_call_(self, call: relax.Call): param = self.param_map[rhs.args[0]] - if call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0: + if isinstance(rhs.struct_info.shape[-1], tvm.tir.IntImm) and int(rhs.struct_info.shape[-1]) <= 8: + # gate in MoE + param.quant_spec = NoQuantizationSpec("float16") + elif call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0: # FT requires N to be a multiple of 8 # FT does not support fp32 output dtype # TODO(masahi): If `matmul(..., out_dtype="float32")` is immediately followed diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 3a2502c27a..a5858e65e3 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -7,7 +7,7 @@ from tvm import relax, te, tir from tvm.relax.op import ccl from tvm.relax.testing import nn -from tvm.script import relax as R +from tvm.script import relax as R, tir as T from ..quantization import ParamQuantKind, QuantizationScheme from .commons import create_metadata_func @@ -594,18 +594,16 @@ def __init__(self, config: MixtralConfig, num_experts, in_features, out_features self.weight = nn.Parameter( (num_experts, in_features, out_features), dtype="float16", - name="expert_weight", ) elif config.quantization_scheme.name == "q4f16_ft": + assert out_features % 8 == 0 self.weight = nn.Parameter( - (num_experts, in_features, out_features), + (num_experts, in_features, out_features // 2), dtype="int8", - name="expert_weight", ) self.scales = nn.Parameter( (num_experts, out_features), dtype="float16", - name="expert_scales", ) else: assert False, "unsupported quantization scheme" @@ -632,6 +630,26 @@ def forward(self, x, rows_before): ), ) ) + else: + return nn.emit( + relax.call_dps_packed( + "cutlass.moe_gemm_s4f16", + [ + x, + self.weight, + self.scales, + rows_before, + total_rows, + self.out_features, # gemm_n + self.in_features, # gemm_k + self.num_experts, + ], + out_sinfo=relax.TensorStructInfo( + (total_rows, self.out_features), + x.struct_info.dtype, + ), + ) + ) class MoEMLP(nn.Module): @@ -720,8 +738,8 @@ def topk(self, x, is_ascend, index_dtype, k=-1): beg = [0] * ndim end = [x.struct_info.shape[i] for i in range(ndim - 1)] + [k] axes = list(range(ndim)) - sorted_x = nn.emit(relax.op.strided_slice(sorted_x, axes, beg, end)) - indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end)) + sorted_x = nn.emit(relax.op.strided_slice(sorted_x, axes, beg, end, assume_inbound=True)) + indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end, assume_inbound=True)) return sorted_x, indices def compute_rows_before(self, sorted_expert_ids): @@ -1142,8 +1160,8 @@ def create_prefill_func_for_single_seq( func_name = "prefill_with_embed" if sep_embed else "prefill" bsz = 1 - seq_len = tvm.tir.Var("n", "int64") - all_seq_len = tvm.tir.Var("m", "int64") + seq_len = tvm.tir.SizeVar("n", "int64") + all_seq_len = tvm.tir.SizeVar("m", "int64") hidden_size = config.hidden_size with bb.function(func_name): model = LlamaForCausalLM( @@ -1474,6 +1492,9 @@ def f_convert_pname_fwd(pname: str) -> List[str]: for k, v in mappings: pname = pname.replace(k, v) pname = pname.replace("model.", "") + if config.quantization_scheme.name == "q4f16_ft": + if pname.endswith("scales"): + pname = pname.replace("scales", "weight") if config.combine_matmul: if qkv_str in pname: @@ -1529,6 +1550,17 @@ def f_convert_param_bkwd(torch_pname: str, torch_param): return None return [(torch_pname, torch_param.astype(dtype))] + def quantize(experts, relax_pname): + print("quantizing experts", relax_pname) + func = tvm.get_global_func("cutlass.symmetric_quantize") + nd_experts = tvm.nd.array(experts) + qweight, qscale = func(nd_experts, True) + if relax_pname.endswith("weight"): + return qweight + else: + assert relax_pname.endswith("scales") + return qscale + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): # Expected to enter this function only for the combined linear matmul weights. # Other weights are supposed to be loaded in `f_convert_param_bkwd` since @@ -1539,14 +1571,31 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): experts = [] assert len(torch_params) == 2 * config.num_experts - for i in range(config.num_experts): - gate, up = ( - torch_params[i], - torch_params[i + config.num_experts], - ) # torch weight in col major - gate_up = np.concatenate([gate, up], axis=0).astype(dtype) - experts.append(gate_up.transpose()) - result = np.stack(experts) + use_pytorch = True + if use_pytorch and dtype=='float16': + import torch + torch_params = [torch.from_numpy(param).cuda() for param in torch_params] + for i in range(config.num_experts): + gate, up = ( + torch_params[i], + torch_params[i + config.num_experts], + ) # torch weight in col major + gate_up = torch.concatenate([gate, up], axis=0).type(torch.float16) + experts.append(gate_up.transpose(1, 0)) + result = torch.stack(experts) + result = result.cpu().numpy() + else: + for i in range(config.num_experts): + gate, up = ( + torch_params[i], + torch_params[i + config.num_experts], + ) # torch weight in col major + gate_up = np.concatenate([gate, up], axis=0).astype(dtype) + experts.append(gate_up.transpose()) + result = np.stack(experts) + # print(config.quantization_scheme.name) + if config.quantization_scheme.name == "q4f16_ft" and 'experts' in relax_pname: + result = quantize(result, relax_pname) return result if "experts" in relax_pname: experts = [expert.astype(dtype).transpose() for expert in torch_params] @@ -1554,6 +1603,8 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): # torch_params = [torch.from_numpy(param).cuda() for param in torch_params] # experts = [expert.type(dtype).transpose(1, 0) for expert in torch_params] # result = torch.stack(experts).detach().numpy() + if config.quantization_scheme.name == "q4f16_ft" and 'experts' in relax_pname: + result = quantize(result, relax_pname) return result if not config.combine_matmul: @@ -1608,7 +1659,6 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): def get_scatter_func(dtype): - import tvm.script.tir as T @T.prim_func def scatter_func( diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index ae43f8d644..3f824ccae4 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -606,8 +606,8 @@ def get_item(i): relax_pname, [cached_torch_params[torch_pname] for torch_pname in torch_pnames], ) - for torch_pname in torch_pnames: - del cached_torch_params[torch_pname] + # for torch_pname in torch_pnames: + # del cached_torch_params[torch_pname] assert i in cached_relax_params assert i not in loaded_idx_set From 8e25c562a301dcbaccbb5a677f106cdf758d6fd6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 11 Dec 2023 18:32:32 +0000 Subject: [PATCH 04/10] adopt hf weights --- mlc_llm/core.py | 22 +++---- mlc_llm/relax_model/llama.py | 119 ++++++++++++++++++++--------------- mlc_llm/utils.py | 1 + 3 files changed, 82 insertions(+), 60 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 208b8ffd5b..e930ebd58a 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -476,14 +476,14 @@ def _parse_args(parsed) -> argparse.Namespace: def _setup_model_path(args: argparse.Namespace): # pylint: disable=too-many-branches - if 'mixtral' in args.model: - if os.path.isdir(args.model): - args.model = os.path.normpath(args.model) # Remove potential trailing `/` - args.model_path = args.model - args.model = os.path.basename(args.model) - else: - args.model_path = os.path.join(args.artifact_path, "models", args.model) - return args + # if 'mixtral' in args.model: + # if os.path.isdir(args.model): + # args.model = os.path.normpath(args.model) # Remove potential trailing `/` + # args.model_path = args.model + # args.model = os.path.basename(args.model) + # else: + # args.model_path = os.path.join(args.artifact_path, "models", args.model) + # return args if args.hf_path: if args.model != "auto": assert args.model == os.path.basename(args.hf_path), ( @@ -846,9 +846,9 @@ def build_model_from_args(args: argparse.Namespace): if args.model_category == "minigpt": # Special case for minigpt, which neither provides nor requires a configuration. config = {} - elif "mixtral" in args.model: - with open(os.path.join(args.model_path, "params.json"), encoding="utf-8") as i_f: - config = json.load(i_f) + # elif "mixtral" in args.model: + # with open(os.path.join(args.model_path, "params.json"), encoding="utf-8") as i_f: + # config = json.load(i_f) else: with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f) diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index a5858e65e3..20253c27e5 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -75,24 +75,27 @@ def get_num_key_value_heads(self): class MixtralConfig(LlamaConfig): num_experts_per_tok: int - num_experts: int + num_local_experts: int + sliding_window: int + router_aux_loss_coef: float # not sure if needed quantization_scheme: QuantizationScheme def __init__( self, **kwargs, ): - kwargs["num_attention_heads"] = kwargs["n_heads"] - kwargs["num_key_value_heads"] = kwargs["n_kv_heads"] - kwargs["rms_norm_eps"] = kwargs["norm_eps"] - kwargs["num_hidden_layers"] = kwargs["n_layers"] - kwargs["intermediate_size"] = kwargs["hidden_dim"] - kwargs["hidden_size"] = kwargs["dim"] # n heads * head_size + # kwargs["num_attention_heads"] = kwargs["n_heads"] + # kwargs["num_key_value_heads"] = kwargs["n_kv_heads"] + # kwargs["rms_norm_eps"] = kwargs["norm_eps"] + # kwargs["num_hidden_layers"] = kwargs["n_layers"] + # kwargs["intermediate_size"] = kwargs["hidden_dim"] + # kwargs["hidden_size"] = kwargs["dim"] # n heads * head_size super().__init__(**kwargs) - moe_config = kwargs["moe"] - self.num_experts_per_tok = moe_config["num_experts_per_tok"] - self.num_experts = moe_config["num_experts"] + self.num_experts_per_tok = kwargs["num_experts_per_tok"] + self.num_local_experts = kwargs["num_local_experts"] + self.sliding_window = kwargs["sliding_window"] + self.router_aux_loss_coef = kwargs["router_aux_loss_coef"] # FIXME: remove this self.quantization_scheme = kwargs["quantization_scheme"] @@ -655,7 +658,7 @@ def forward(self, x, rows_before): class MoEMLP(nn.Module): def __init__(self, config: MixtralConfig): self.num_experts_per_tok = config.num_experts_per_tok - self.num_experts = config.num_experts + self.num_experts = config.num_local_experts self.combine_matmul = config.combine_matmul self.num_shards = config.num_shards @@ -712,12 +715,12 @@ def __init__(self, config: MixtralConfig): self.num_shards = config.num_shards self.gate = Linear( in_features=config.hidden_size, - out_features=config.num_experts, + out_features=config.num_local_experts, bias=False, dtype=config.dtype, ) self.num_experts_per_tok = config.num_experts_per_tok - self.num_experts = config.num_experts + self.num_experts = config.num_local_experts def topk(self, x, is_ascend, index_dtype, k=-1): # topk along axis -1 @@ -738,7 +741,9 @@ def topk(self, x, is_ascend, index_dtype, k=-1): beg = [0] * ndim end = [x.struct_info.shape[i] for i in range(ndim - 1)] + [k] axes = list(range(ndim)) - sorted_x = nn.emit(relax.op.strided_slice(sorted_x, axes, beg, end, assume_inbound=True)) + sorted_x = nn.emit( + relax.op.strided_slice(sorted_x, axes, beg, end, assume_inbound=True) + ) indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end, assume_inbound=True)) return sorted_x, indices @@ -784,6 +789,7 @@ def forward(self, hidden_states): expert_weights, expert_indices = self.topk( scores, is_ascend=False, k=self.num_experts_per_tok, index_dtype="int32" ) # (num_tokens, top_k), (num_tokens, top_k) + expert_weights = nn.emit(expert_weights / R.sum(expert_weights, axis=-1, keepdims=True)) flattened_indices = nn.emit(relax.op.flatten(expert_indices)) sorted_expert_ids, indices = self.topk( flattened_indices, is_ascend=True, index_dtype="int32" @@ -814,7 +820,7 @@ def __init__(self, config: LlamaConfig, enable_batching: bool): self.self_attn = attn_class(config) if isinstance(config, MixtralConfig): self.use_moe = True - self.feed_forward = MoE(config) + self.block_sparse_moe = MoE(config) else: self.use_moe = False self.mlp = LlamaMLP(config) @@ -837,7 +843,7 @@ def post_self_attn(self, hidden_states, residual): residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - model = self.feed_forward if self.use_moe else self.mlp + model = self.block_sparse_moe if self.use_moe else self.mlp hidden_states = model(hidden_states) if model.num_shards > 1: @@ -1467,20 +1473,21 @@ def kv_cache_transpose_append( def setup_params(mod, param_manager, dtype, config, args): mappings = [ - ("embed_tokens", "tok_embeddings"), - ("lm_head", "output"), - ("input_layernorm", "attention_norm"), - ("self_attn", "attention"), - ("post_attention_layernorm", "ffn_norm"), - ("o_proj", "wo"), - ("q_proj", "wq"), - ("k_proj", "wk"), - ("v_proj", "wv"), - ("gate_proj", "w1"), - ("down_proj", "w2"), - ("up_proj", "w3"), + # ("embed_tokens", "tok_embeddings"), + # ("lm_head", "output"), + # ("input_layernorm", "attention_norm"), + # ("self_attn", "attention"), + # ("post_attention_layernorm", "ffn_norm"), + # ("o_proj", "wo"), + # ("q_proj", "wq"), + # ("k_proj", "wk"), + # ("v_proj", "wv"), + ("gate_proj", "w1"), + ("down_proj", "w2"), + ("up_proj", "w3"), ] + print(config) assert isinstance(config, MixtralConfig) def f_convert_pname_fwd(pname: str) -> List[str]: @@ -1491,30 +1498,44 @@ def f_convert_pname_fwd(pname: str) -> List[str]: if isinstance(config, MixtralConfig): for k, v in mappings: pname = pname.replace(k, v) - pname = pname.replace("model.", "") + # pname = pname.replace("model.", "") if config.quantization_scheme.name == "q4f16_ft": if pname.endswith("scales"): + # TODO: remove after quantization integarted pname = pname.replace("scales", "weight") if config.combine_matmul: if qkv_str in pname: return [ - pname.replace(qkv_str, "wq"), - pname.replace(qkv_str, "wk"), - pname.replace(qkv_str, "wv"), + # pname.replace(qkv_str, "wq"), + # pname.replace(qkv_str, "wk"), + # pname.replace(qkv_str, "wv"), + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, "v_proj"), ] if "experts.gate_up_combined_proj" in pname: return [ pname.replace("experts.gate_up_combined_proj", f"experts.{i}.w1") - for i in range(config.num_experts) + for i in range(config.num_local_experts) ] + [ pname.replace("experts.gate_up_combined_proj", f"experts.{i}.w3") - for i in range(config.num_experts) + for i in range(config.num_local_experts) ] + # return [ + # pname.replace("experts.gate_up_combined_proj", f"experts.{i}.gate_proj") + # for i in range(config.num_local_experts) + # ] + [ + # pname.replace("experts.gate_up_combined_proj", f"experts.{i}.up_proj") + # for i in range(config.num_local_experts) + # ] if "experts" in pname: # not needed if using combine_matmul - return [pname.replace("experts", f"experts.{i}") for i in range(config.num_experts)] + return [ + pname.replace("experts", f"experts.{i}") + for i in range(config.num_local_experts) + ] if not config.combine_matmul: return [pname] @@ -1539,8 +1560,8 @@ def f_convert_param_bkwd(torch_pname: str, torch_param): return None for v, k in mappings: torch_pname = torch_pname.replace(k, v) - if "lm_head" not in torch_pname: - torch_pname = "model." + torch_pname + # if "lm_head" not in torch_pname: + # torch_pname = "model." + torch_pname if not config.combine_matmul: return [(torch_pname, torch_param.astype(dtype))] @@ -1569,32 +1590,33 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): if "gate_up_combined_proj" in relax_pname: # combine along out_features dimension and then experts dimension experts = [] - assert len(torch_params) == 2 * config.num_experts + assert len(torch_params) == 2 * config.num_local_experts use_pytorch = True - if use_pytorch and dtype=='float16': + if use_pytorch and dtype == "float16": import torch + torch_params = [torch.from_numpy(param).cuda() for param in torch_params] - for i in range(config.num_experts): + for i in range(config.num_local_experts): gate, up = ( torch_params[i], - torch_params[i + config.num_experts], + torch_params[i + config.num_local_experts], ) # torch weight in col major gate_up = torch.concatenate([gate, up], axis=0).type(torch.float16) experts.append(gate_up.transpose(1, 0)) result = torch.stack(experts) result = result.cpu().numpy() else: - for i in range(config.num_experts): + for i in range(config.num_local_experts): gate, up = ( torch_params[i], - torch_params[i + config.num_experts], + torch_params[i + config.num_local_experts], ) # torch weight in col major gate_up = np.concatenate([gate, up], axis=0).astype(dtype) experts.append(gate_up.transpose()) result = np.stack(experts) # print(config.quantization_scheme.name) - if config.quantization_scheme.name == "q4f16_ft" and 'experts' in relax_pname: + if config.quantization_scheme.name == "q4f16_ft" and "experts" in relax_pname: result = quantize(result, relax_pname) return result if "experts" in relax_pname: @@ -1603,7 +1625,7 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): # torch_params = [torch.from_numpy(param).cuda() for param in torch_params] # experts = [expert.type(dtype).transpose(1, 0) for expert in torch_params] # result = torch.stack(experts).detach().numpy() - if config.quantization_scheme.name == "q4f16_ft" and 'experts' in relax_pname: + if config.quantization_scheme.name == "q4f16_ft" and "experts" in relax_pname: result = quantize(result, relax_pname) return result @@ -1659,7 +1681,6 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): def get_scatter_func(dtype): - @T.prim_func def scatter_func( x_handle: T.handle, @@ -1699,13 +1720,13 @@ def get_model(args, hf_config): # while Llama-1 variants use `max_sequence_length`. # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. # If none of them is defined, throw an error. - if "mixtral" in args.model: + print(args.model) + if "mixtral" in args.model.lower(): # FIXME - hf_config["max_sequence_length"] = 4096 - # hf_config['num_attention_heads'] = config = MixtralConfig( **hf_config, dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], position_embedding_base=position_embedding_base, combine_matmul=True, num_shards=args.num_shards, diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 57ffa69143..b3acdb3f1c 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -25,6 +25,7 @@ "gptj", "chatglm", "mistral", + "mixtral", "stablelm_epoch", ] ) From f279349c6224ce5e27a716278c9eb614021ff117 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 11 Dec 2023 19:36:44 +0000 Subject: [PATCH 05/10] fix --- mlc_llm/utils.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index b3acdb3f1c..4df059b370 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -70,20 +70,20 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: "rwkv_world": "rwkv_world", "minigpt": "minigpt", } - if "mixtral" in args.model_path: - args.model_category = "mixtral" - config = open(os.path.join(args.model_path, "params.json"), encoding="utf-8") + # if "mixtral" in args.model_path: + # args.model_category = "mixtral" + # config = open(os.path.join(args.model_path, "params.json"), encoding="utf-8") + # model_path_lower = args.model_path.lower() + # else: + try: + with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: + config = json.load(i_f) + args.model_category = config["model_type"] model_path_lower = args.model_path.lower() - else: - try: - with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: - config = json.load(i_f) - args.model_category = config["model_type"] - model_path_lower = args.model_path.lower() - if "rwkv" in model_path_lower and "world" in model_path_lower: - args.model_category = "rwkv_world" - except Exception: - args.model_category = "" + if "rwkv" in model_path_lower and "world" in model_path_lower: + args.model_category = "rwkv_world" + except Exception: + args.model_category = "" model = args.model.lower() if "rwkv" in model and "world" in model: model = "rwkv_world" From a39ca0160ef3dc75e693b1265edc52cf3a5d2311 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 11 Dec 2023 19:57:57 +0000 Subject: [PATCH 06/10] speedup weight preprocess --- mlc_llm/relax_model/llama.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 20253c27e5..8efe5c77d1 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -1620,8 +1620,15 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): result = quantize(result, relax_pname) return result if "experts" in relax_pname: - experts = [expert.astype(dtype).transpose() for expert in torch_params] - result = np.stack(experts) + use_pytorch = True + if use_pytorch and dtype == "float16": + import torch + torch_params = [torch.from_numpy(param).cuda() for param in torch_params] + experts = torch.stack([expert.type(torch.float16).transpose(1, 0) for expert in torch_params]) + result = experts.cpu().numpy() + else: + experts = [expert.astype(dtype).transpose() for expert in torch_params] + result = np.stack(experts) # torch_params = [torch.from_numpy(param).cuda() for param in torch_params] # experts = [expert.type(dtype).transpose(1, 0) for expert in torch_params] # result = torch.stack(experts).detach().numpy() From 1f267eaefc5624d406ccbfbb9049dc3b76d932d8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 11 Dec 2023 20:20:30 +0000 Subject: [PATCH 07/10] fix --- cpp/conv_templates.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index 1ed5297ecc..a88e1bd39b 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -674,7 +674,7 @@ Conversation Conversation::FromTemplate(const std::string& name) { {"wizardlm_7b", WizardLM7B}, {"wizard_coder_or_math", WizardCoderOrMATH}, {"glm", GLM}, - {"mixtral_default", Llama2} + {"mixtral_default", MistralDefault} }; auto it = factory.find(name); if (it == factory.end()) { From fca1b397e017206645c0c36cff7a6203d4030f9b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 11 Dec 2023 23:17:51 +0000 Subject: [PATCH 08/10] batch --- mlc_llm/core.py | 2 +- mlc_llm/relax_model/llama_batched_vllm.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index e930ebd58a..4f53d1a7ee 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -897,7 +897,7 @@ def build_model_from_args(args: argparse.Namespace): # Run pre-quantization if provided. args.model_path = param_manager.run_pre_quantize(args.model_path) param_manager.init_torch_pname_to_bin_name(args.use_safetensors) - parameter_transforms.append(param_manager.create_parameter_transformation()) + parameter_transforms.append(param_manager.create_parameter_transformation(optimize_parameter_order=False)) # disable to prevent errors # Run pre-sharding if required if args.num_shards > 1 and args.use_presharded_weights: diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 2309bdd92e..b9303c2978 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -14,7 +14,9 @@ from .modules import ModuleList from .param_manager import ParamManager from .llama import ( + get_scatter_func, LlamaConfig, + MixtralConfig, Linear, Embedding, LlamaRMSNorm, @@ -613,7 +615,19 @@ def get_model(args, hf_config): # while Llama-1 variants use `max_sequence_length`. # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. # If none of them is defined, throw an error. - if "max_sequence_length" in hf_config: + if "mixtral" in args.model.lower(): + # FIXME + config = MixtralConfig( + **hf_config, + dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + quantization_scheme=args.quantization, + ) + elif "max_sequence_length" in hf_config: config = LlamaConfig( **hf_config, dtype=dtype, @@ -647,6 +661,8 @@ def get_model(args, hf_config): # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. cpu_dev = VDevice("llvm", 0, "global") + bb.add_func(get_scatter_func(dtype), "scatter") + create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) From 4618de1600999d3f0ebdac0dd6fc6cdbe26d0ca7 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 12 Dec 2023 01:27:22 +0000 Subject: [PATCH 09/10] remove deadcode --- mlc_llm/relax_model/llama.py | 33 ++------------------------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 8efe5c77d1..e6c438501c 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -77,25 +77,18 @@ class MixtralConfig(LlamaConfig): num_experts_per_tok: int num_local_experts: int sliding_window: int - router_aux_loss_coef: float # not sure if needed + # router_aux_loss_coef: float # not sure if needed quantization_scheme: QuantizationScheme def __init__( self, **kwargs, ): - # kwargs["num_attention_heads"] = kwargs["n_heads"] - # kwargs["num_key_value_heads"] = kwargs["n_kv_heads"] - # kwargs["rms_norm_eps"] = kwargs["norm_eps"] - # kwargs["num_hidden_layers"] = kwargs["n_layers"] - # kwargs["intermediate_size"] = kwargs["hidden_dim"] - # kwargs["hidden_size"] = kwargs["dim"] # n heads * head_size - super().__init__(**kwargs) self.num_experts_per_tok = kwargs["num_experts_per_tok"] self.num_local_experts = kwargs["num_local_experts"] self.sliding_window = kwargs["sliding_window"] - self.router_aux_loss_coef = kwargs["router_aux_loss_coef"] + # self.router_aux_loss_coef = kwargs["router_aux_loss_coef"] # FIXME: remove this self.quantization_scheme = kwargs["quantization_scheme"] @@ -782,7 +775,6 @@ def forward(self, hidden_states): # reshape to 2D hidden_states = nn.emit(relax.op.reshape(hidden_states, (-1, hidden_size))) - # TODO: switch topk softmax gate = self.gate(hidden_states) scores = nn.emit(relax.op.nn.softmax(gate, axis=-1)) @@ -1473,15 +1465,6 @@ def kv_cache_transpose_append( def setup_params(mod, param_manager, dtype, config, args): mappings = [ - # ("embed_tokens", "tok_embeddings"), - # ("lm_head", "output"), - # ("input_layernorm", "attention_norm"), - # ("self_attn", "attention"), - # ("post_attention_layernorm", "ffn_norm"), - # ("o_proj", "wo"), - # ("q_proj", "wq"), - # ("k_proj", "wk"), - # ("v_proj", "wv"), ("gate_proj", "w1"), ("down_proj", "w2"), ("up_proj", "w3"), @@ -1507,9 +1490,6 @@ def f_convert_pname_fwd(pname: str) -> List[str]: if config.combine_matmul: if qkv_str in pname: return [ - # pname.replace(qkv_str, "wq"), - # pname.replace(qkv_str, "wk"), - # pname.replace(qkv_str, "wv"), pname.replace(qkv_str, "q_proj"), pname.replace(qkv_str, "k_proj"), pname.replace(qkv_str, "v_proj"), @@ -1522,13 +1502,6 @@ def f_convert_pname_fwd(pname: str) -> List[str]: pname.replace("experts.gate_up_combined_proj", f"experts.{i}.w3") for i in range(config.num_local_experts) ] - # return [ - # pname.replace("experts.gate_up_combined_proj", f"experts.{i}.gate_proj") - # for i in range(config.num_local_experts) - # ] + [ - # pname.replace("experts.gate_up_combined_proj", f"experts.{i}.up_proj") - # for i in range(config.num_local_experts) - # ] if "experts" in pname: # not needed if using combine_matmul @@ -1560,8 +1533,6 @@ def f_convert_param_bkwd(torch_pname: str, torch_param): return None for v, k in mappings: torch_pname = torch_pname.replace(k, v) - # if "lm_head" not in torch_pname: - # torch_pname = "model." + torch_pname if not config.combine_matmul: return [(torch_pname, torch_param.astype(dtype))] From 113bd1873cb563151ed5675730be0e53560c7ab2 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 12 Dec 2023 03:08:33 +0000 Subject: [PATCH 10/10] cleanup --- mlc_llm/core.py | 11 ----------- mlc_llm/relax_model/llama.py | 8 ++------ mlc_llm/utils.py | 5 ----- 3 files changed, 2 insertions(+), 22 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 4f53d1a7ee..f2b8073192 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -476,14 +476,6 @@ def _parse_args(parsed) -> argparse.Namespace: def _setup_model_path(args: argparse.Namespace): # pylint: disable=too-many-branches - # if 'mixtral' in args.model: - # if os.path.isdir(args.model): - # args.model = os.path.normpath(args.model) # Remove potential trailing `/` - # args.model_path = args.model - # args.model = os.path.basename(args.model) - # else: - # args.model_path = os.path.join(args.artifact_path, "models", args.model) - # return args if args.hf_path: if args.model != "auto": assert args.model == os.path.basename(args.hf_path), ( @@ -846,9 +838,6 @@ def build_model_from_args(args: argparse.Namespace): if args.model_category == "minigpt": # Special case for minigpt, which neither provides nor requires a configuration. config = {} - # elif "mixtral" in args.model: - # with open(os.path.join(args.model_path, "params.json"), encoding="utf-8") as i_f: - # config = json.load(i_f) else: with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f) diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index e6c438501c..57058edab5 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -1158,8 +1158,8 @@ def create_prefill_func_for_single_seq( func_name = "prefill_with_embed" if sep_embed else "prefill" bsz = 1 - seq_len = tvm.tir.SizeVar("n", "int64") - all_seq_len = tvm.tir.SizeVar("m", "int64") + seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") hidden_size = config.hidden_size with bb.function(func_name): model = LlamaForCausalLM( @@ -1470,14 +1470,11 @@ def setup_params(mod, param_manager, dtype, config, args): ("up_proj", "w3"), ] - print(config) - assert isinstance(config, MixtralConfig) def f_convert_pname_fwd(pname: str) -> List[str]: qkv_str = "query_key_value_proj" gate_up_str = "gate_up_proj" - assert isinstance(config, MixtralConfig) if isinstance(config, MixtralConfig): for k, v in mappings: pname = pname.replace(k, v) @@ -1698,7 +1695,6 @@ def get_model(args, hf_config): # while Llama-1 variants use `max_sequence_length`. # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. # If none of them is defined, throw an error. - print(args.model) if "mixtral" in args.model.lower(): # FIXME config = MixtralConfig( diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 4df059b370..9217d0b19e 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -70,11 +70,6 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: "rwkv_world": "rwkv_world", "minigpt": "minigpt", } - # if "mixtral" in args.model_path: - # args.model_category = "mixtral" - # config = open(os.path.join(args.model_path, "params.json"), encoding="utf-8") - # model_path_lower = args.model_path.lower() - # else: try: with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f)