diff --git a/CMakeLists.txt b/CMakeLists.txt
index eb09469b48..5f9aea0b79 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
 cmake_minimum_required(VERSION 3.18)
-project(mlc_llm C CXX)
+project(mlc_llm C CXX CUDA)

 include(CheckCXXCompilerFlag)
 if(NOT MSVC)
diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc
index f1e74173bd..a88e1bd39b 100644
--- a/cpp/conv_templates.cc
+++ b/cpp/conv_templates.cc
@@ -674,6 +674,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
       {"wizardlm_7b", WizardLM7B},
       {"wizard_coder_or_math", WizardCoderOrMATH},
       {"glm", GLM},
+      {"mixtral_default", MistralDefault}
   };
   auto it = factory.find(name);
   if (it == factory.end()) {
diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 1bd0d0266a..f2b8073192 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -854,11 +854,13 @@ def build_model_from_args(args: argparse.Namespace):
         "rwkv": rwkv,
         "rwkv_world": rwkv,
         "chatglm": chatglm,
+        "mixtral": llama,
     }

     if args.use_vllm_attention:
         model_generators["llama"] = llama_batched_vllm
         model_generators["mistral"] = llama_batched_vllm
+        model_generators["mixtral"] = llama_batched_vllm

     assert args.model_category in model_generators, f"Model {args.model} not supported"
@@ -884,7 +886,7 @@ def build_model_from_args(args: argparse.Namespace):
         # Run pre-quantization if provided.
         args.model_path = param_manager.run_pre_quantize(args.model_path)
         param_manager.init_torch_pname_to_bin_name(args.use_safetensors)
-        parameter_transforms.append(param_manager.create_parameter_transformation())
+        parameter_transforms.append(param_manager.create_parameter_transformation(optimize_parameter_order=False))  # disable to prevent errors

         # Run pre-sharding if required
         if args.num_shards > 1 and args.use_presharded_weights:
diff --git a/mlc_llm/quantization/ft_quantization.py b/mlc_llm/quantization/ft_quantization.py
index 39a35d8547..c05c296178 100644
--- a/mlc_llm/quantization/ft_quantization.py
+++ b/mlc_llm/quantization/ft_quantization.py
@@ -8,11 +8,12 @@
 from tvm.relax.expr_functor import visitor

 from . import tir_utils
-from .quantization import QuantizationSpec, QuantSpecUpdater
+from .quantization import QuantizationSpec, QuantSpecUpdater, NoQuantizationSpec
 from .quantization import FQuantize, convert_TE_func
 from .group_quantization import GroupQuantizationSpec

+
 @dataclass
 class FTQuantizationSpec(QuantizationSpec):
     """The quantization specification for the FasterTransformer kernel."""
@@ -203,7 +204,10 @@ def visit_call_(self, call: relax.Call):

         param = self.param_map[rhs.args[0]]

-        if call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0:
+        if isinstance(rhs.struct_info.shape[-1], tvm.tir.IntImm) and int(rhs.struct_info.shape[-1]) <= 8:
+            # gate in MoE
+            param.quant_spec = NoQuantizationSpec("float16")
+        elif call.struct_info.dtype == "float32" or rhs.struct_info.shape[-1] % 8 != 0:
             # FT requires N to be a multiple of 8
             # FT does not support fp32 output dtype
             # TODO(masahi): If `matmul(..., out_dtype="float32")` is immediately followed
diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py
index 3eb67b5b1c..26441c411b 100644
--- a/mlc_llm/relax_model/commons.py
+++ b/mlc_llm/relax_model/commons.py
@@ -81,6 +81,34 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
         func = te.create_prim_func([a, w])
         return func

+    def moe_shard_k_weight_scale(weight: relax.TensorStructInfo):
+        (num_experts, red, spatial), dtype = weight.shape, weight.dtype
+        spatial, red = int(spatial), int(red)
+        if param_shape_is_already_sharded:
+            red *= num_shards
+        a = te.placeholder((num_experts, red, spatial), dtype=dtype)
+        w = topi.reshape(a, (num_experts, num_shards, red // num_shards, spatial))
+        w = topi.transpose(w, (1, 0, 2, 3))
+        func = te.create_prim_func([a, w])
+        return func
+
+    def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
+        (num_experts, red, spatial), dtype = weight.shape, weight.dtype
+        spatial, red = int(spatial), int(red)
+        if param_shape_is_already_sharded:
+            spatial *= num_shards
+        a = te.placeholder((num_experts, red, spatial), dtype=dtype)
+        g = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, j])
+        u = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, spatial // 2 + j])
+        g = topi.reshape(g, (num_experts, red, num_shards, spatial // 2 // num_shards))
+        u = topi.reshape(u, (num_experts, red, num_shards, spatial // 2 // num_shards))
+        w = topi.concatenate((g, u), axis=3)
+        w = topi.reshape(w, (num_experts, red, num_shards, spatial // num_shards))
+        w = topi.transpose(w, (2, 0, 1, 3))
+        func = te.create_prim_func([a, w])
+        return func
+
+
     # pylint: enable=invalid-name

     return {
@@ -88,6 +116,8 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
         "shard_mlp_k": shard_k_weight_scale,
         "shard_o_proj_k": shard_k_weight_scale,
         "shard_gate_up": shard_gate_up_weight_scale,
+        "moe_shard_mlp_k": moe_shard_k_weight_scale,
+        "moe_shard_gate_up": moe_shard_gate_up_weight_scale,
     }
diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 88fde9509a..57058edab5 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -7,7 +7,7 @@
 from tvm import relax, te, tir
 from tvm.relax.op import ccl
 from tvm.relax.testing import nn
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T

 from ..quantization import ParamQuantKind, QuantizationScheme
 from .commons import create_metadata_func
@@ -73,6 +73,27 @@ def get_num_key_value_heads(self):
         return self.num_key_value_heads


+class MixtralConfig(LlamaConfig):
+    num_experts_per_tok: int
+    num_local_experts: int
+    sliding_window: int
+    # router_aux_loss_coef: float  # not sure if needed
+    quantization_scheme: QuantizationScheme
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_experts_per_tok = kwargs["num_experts_per_tok"]
+        self.num_local_experts = kwargs["num_local_experts"]
+        self.sliding_window = kwargs["sliding_window"]
+        # self.router_aux_loss_coef = kwargs["router_aux_loss_coef"]
+
+        # FIXME: remove this
+        self.quantization_scheme = kwargs["quantization_scheme"]
+
+
 class Linear(nn.Module):
     def __init__(self, in_features, out_features, dtype: str, bias=True):
         self.in_features = in_features
@@ -556,12 +577,245 @@ def attention_fwd(
         return attn_output, past_key_values


+class MoELinear(nn.Module):
+    def __init__(self, config: MixtralConfig, num_experts, in_features, out_features, bias=False):
+        assert not bias, "bias not supported"
+        self.num_experts = num_experts
+        self.in_features = in_features
+        self.out_features = out_features
+        self.quantization_scheme = config.quantization_scheme
+
+        if config.quantization_scheme.name == "q0f16":
+            # weight is row major
+            self.weight = nn.Parameter(
+                (num_experts, in_features, out_features),
+                dtype="float16",
+            )
+        elif config.quantization_scheme.name == "q4f16_ft":
+            assert out_features % 8 == 0
+            self.weight = nn.Parameter(
+                (num_experts, in_features, out_features // 2),
+                dtype="int8",
+            )
+            self.scales = nn.Parameter(
+                (num_experts, out_features),
+                dtype="float16",
+            )
+        else:
+            assert False, "unsupported quantization scheme"
+
+    def forward(self, x, rows_before):
+        assert len(x.struct_info.shape) == 2
+        total_rows = x.struct_info.shape[0]
+        if self.quantization_scheme.name == "q0f16":
+            return nn.emit(
+                relax.call_dps_packed(
+                    "cutlass.moe_gemm_f16f16",
+                    [
+                        x,
+                        self.weight,
+                        rows_before,
+                        total_rows,
+                        self.out_features,  # gemm_n
+                        self.in_features,  # gemm_k
+                        self.num_experts,
+                    ],
+                    out_sinfo=relax.TensorStructInfo(
+                        (total_rows, self.out_features),
+                        x.struct_info.dtype,
+                    ),
+                )
+            )
+        else:
+            return nn.emit(
+                relax.call_dps_packed(
+                    "cutlass.moe_gemm_s4f16",
+                    [
+                        x,
+                        self.weight,
+                        self.scales,
+                        rows_before,
+                        total_rows,
+                        self.out_features,  # gemm_n
+                        self.in_features,  # gemm_k
+                        self.num_experts,
+                    ],
+                    out_sinfo=relax.TensorStructInfo(
+                        (total_rows, self.out_features),
+                        x.struct_info.dtype,
+                    ),
+                )
+            )
+
+
+class MoEMLP(nn.Module):
+    def __init__(self, config: MixtralConfig):
+        self.num_experts_per_tok = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+        self.combine_matmul = config.combine_matmul
+
+        self.num_shards = config.num_shards
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size // self.num_shards
+
+        self.down_proj = MoELinear(
+            config, self.num_experts, intermediate_size, hidden_size, bias=False
+        )
+        if config.combine_matmul:
+            self.gate_up_combined_proj = MoELinear(
+                config,
+                self.num_experts,
+                hidden_size,
+                2 * intermediate_size,
+                bias=False,
+            )
+            # FIXME: rename to 'gate_up_proj' to be consistent with llama; using this name for now to avoid conflicting with the pname string-replacement rules
+            # TODO: check sharding is correct, note that the weight is row major
+            self.gate_up_combined_proj.weight.shard_dim = 2
+            self.gate_up_combined_proj.weight.shard_strategy = "moe_shard_gate_up"
+            self.down_proj.weight.shard_dim = 1
+            self.down_proj.weight.shard_strategy = "moe_shard_mlp_k"
+        else:
+            self.gate_proj = MoELinear(
+                config, self.num_experts, config.hidden_size, config.intermediate_size, bias=False
+            )
+            self.up_proj = MoELinear(
+                config, self.num_experts, config.hidden_size, config.intermediate_size, bias=False
+            )
+
+    def forward(self, hidden_states: relax.Expr, rows_before: relax.Expr):
+        # TODO: disco
+        if self.combine_matmul:
+            gate_up_results = nn.emit(
+                relax.op.split(
+                    self.gate_up_combined_proj(hidden_states, rows_before),
+                    indices_or_sections=2,
+                    axis=-1,
+                )
+            )
+            gate_result = relax.TupleGetItem(gate_up_results, 0)
+            up_result = relax.TupleGetItem(gate_up_results, 1)
+        else:
+            gate_result = self.gate_proj(hidden_states, rows_before)
+            up_result = self.up_proj(hidden_states, rows_before)
+        result = self.down_proj(nn.emit(relax.op.nn.silu(gate_result) * up_result), rows_before)
+        return result
+
+
+class MoE(nn.Module):
+    def __init__(self, config: MixtralConfig):
+        self.experts = MoEMLP(config)
+        self.num_shards = config.num_shards
+        self.gate = Linear(
+            in_features=config.hidden_size,
+            out_features=config.num_local_experts,
+            bias=False,
+            dtype=config.dtype,
+        )
+        self.num_experts_per_tok = config.num_experts_per_tok
+        self.num_experts = config.num_local_experts
+
+    def topk(self, x, is_ascend, index_dtype, k=-1):
+        # topk along axis -1
+        result = nn.emit(
+            relax.call_dps_packed(
+                "tvm.contrib.thrust.sort_dps",
+                [x, is_ascend],
+                out_sinfo=[
+                    x.struct_info,
+                    relax.TensorStructInfo(x.struct_info.shape, index_dtype),
+                ],
+            )
+        )
+        sorted_x = relax.TupleGetItem(result, 0)
+        indices = relax.TupleGetItem(result, 1)
+        if k != -1:
+            ndim = len(x.struct_info.shape)
+            beg = [0] * ndim
+            end = [x.struct_info.shape[i] for i in range(ndim - 1)] + [k]
+            axes = list(range(ndim))
+            sorted_x = nn.emit(
+                relax.op.strided_slice(sorted_x, axes, beg, end, assume_inbound=True)
+            )
+            indices = nn.emit(relax.op.strided_slice(indices, axes, beg, end, assume_inbound=True))
+        return sorted_x, indices
+
+    def compute_rows_before(self, sorted_expert_ids):
+        return nn.emit(
+            relax.call_dps_packed(
+                "moe_compute_rows_before",
+                [sorted_expert_ids],
+                out_sinfo=relax.TensorStructInfo([self.num_experts], "int64"),
+            )
+        )
+
+    def scatter(self, linear_out, indices):
+        return nn.emit(
+            relax.call_dps_packed(
+                "scatter",
+                [linear_out, indices],
+                out_sinfo=linear_out.struct_info,
+            )
+        )
+
+    def get_token_indices(self, indices):
+        def te_compute(x):
+            return tvm.te.compute(
+                x.shape,
+                lambda *idx: tvm.tir.indexdiv(
+                    x(*idx), tvm.runtime.const(self.num_experts_per_tok, dtype="int32")
+                ).astype("int32"),
+            )
+
+        return nn.emit_te(te_compute, indices)
+
+    def forward(self, hidden_states):
+        hidden_states_shape = hidden_states.struct_info.shape
+        hidden_size = hidden_states_shape[-1]
+        # reshape to 2D
+        hidden_states = nn.emit(relax.op.reshape(hidden_states, (-1, hidden_size)))
+
+        gate = self.gate(hidden_states)
+        scores = nn.emit(relax.op.nn.softmax(gate, axis=-1))
+
+        expert_weights, expert_indices = self.topk(
+            scores, is_ascend=False, k=self.num_experts_per_tok, index_dtype="int32"
+        )  # (num_tokens, top_k), (num_tokens, top_k)
+        expert_weights = nn.emit(expert_weights / R.sum(expert_weights, axis=-1, keepdims=True))
+        flattened_indices = nn.emit(relax.op.flatten(expert_indices))
+        sorted_expert_ids, indices = self.topk(
+            flattened_indices, is_ascend=True, index_dtype="int32"
+        )
+
+        rows_before = self.compute_rows_before(sorted_expert_ids)
+        token_indices = self.get_token_indices(indices)
+        gathered_x = nn.emit(relax.op.take(hidden_states, token_indices, axis=0))
+        linear_out = self.experts(gathered_x, rows_before)
+        unpermuted = self.scatter(linear_out, indices)
+        unflattened = nn.emit(
+            relax.op.reshape(unpermuted, (-1, self.num_experts_per_tok, hidden_size))
+        )
+        expert_weights = nn.emit(
+            relax.op.reshape(expert_weights, (-1, self.num_experts_per_tok, 1))
+        )
+        weighted_sum = nn.emit(relax.op.sum(unflattened * expert_weights, axis=1))
+
+        # reshape back to 3D
+        weighted_sum = nn.emit(relax.op.reshape(weighted_sum, hidden_states_shape))
+        return weighted_sum
+
+
 class LlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig, enable_batching: bool):
         attn_class = LlamaPagedAttention if enable_batching else LlamaAttention
         self.hidden_size = config.hidden_size
         self.self_attn = attn_class(config)
-        self.mlp = LlamaMLP(config)
+        if isinstance(config, MixtralConfig):
+            self.use_moe = True
+            self.block_sparse_moe = MoE(config)
+        else:
+            self.use_moe = False
+            self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(
             config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps
         )
@@ -578,16 +832,18 @@ def post_self_attn(self, hidden_states, residual):
         if self.self_attn.num_shards > 1:
             hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum"))

-        # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        if self.mlp.num_shards > 1:
+
+        model = self.block_sparse_moe if self.use_moe else self.mlp
+
+        hidden_states = model(hidden_states)
+        if model.num_shards > 1:
             residual = nn.emit(
-                residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype)
+                residual / R.const(model.num_shards, dtype=residual.struct_info.dtype)
             )
         hidden_states = nn.emit(residual + hidden_states)
-        if self.mlp.num_shards > 1:
+        if model.num_shards > 1:
             hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum"))

         return hidden_states
@@ -1208,12 +1464,52 @@ def kv_cache_transpose_append(


 def setup_params(mod, param_manager, dtype, config, args):
+    mappings = [
+        ("gate_proj", "w1"),
+        ("down_proj", "w2"),
+        ("up_proj", "w3"),
+    ]
+
     def f_convert_pname_fwd(pname: str) -> List[str]:
+        qkv_str = "query_key_value_proj"
+        gate_up_str = "gate_up_proj"
+
+        if isinstance(config, MixtralConfig):
+            for k, v in mappings:
+                pname = pname.replace(k, v)
+            # pname = pname.replace("model.", "")
+            if config.quantization_scheme.name == "q4f16_ft":
+                if pname.endswith("scales"):
+                    # TODO: remove after quantization is integrated
+                    pname = pname.replace("scales", "weight")
+
+            if config.combine_matmul:
+                if qkv_str in pname:
+                    return [
+                        pname.replace(qkv_str, "q_proj"),
+                        pname.replace(qkv_str, "k_proj"),
+                        pname.replace(qkv_str, "v_proj"),
+                    ]
+                if "experts.gate_up_combined_proj" in pname:
+                    return [
+                        pname.replace("experts.gate_up_combined_proj", f"experts.{i}.w1")
+                        for i in range(config.num_local_experts)
+                    ] + [
+                        pname.replace("experts.gate_up_combined_proj", f"experts.{i}.w3")
+                        for i in range(config.num_local_experts)
+                    ]
+
+            if "experts" in pname:
+                # not needed if using combine_matmul
+                return [
+                    pname.replace("experts", f"experts.{i}")
+                    for i in range(config.num_local_experts)
+                ]
+
         if not config.combine_matmul:
             return [pname]

-        qkv_str = "query_key_value_proj"
-        gate_up_str = "gate_up_proj"
         if qkv_str in pname:
             return [
                 pname.replace(qkv_str, "q_proj"),
@@ -1229,18 +1525,85 @@ def f_convert_pname_fwd(pname: str) -> List[str]:
             return [pname]

     def f_convert_param_bkwd(torch_pname: str, torch_param):
+        if isinstance(config, MixtralConfig):
+            if "experts" in torch_pname:
+                return None
+            for v, k in mappings:
+                torch_pname = torch_pname.replace(k, v)
         if not config.combine_matmul:
             return [(torch_pname, torch_param.astype(dtype))]

         combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]
+        # print('bkwd pname ', torch_pname)
         if any([name in torch_pname for name in combined_layers]):
             return None
         return [(torch_pname, torch_param.astype(dtype))]

+    def quantize(experts, relax_pname):
+        print("quantizing experts", relax_pname)
+        func = tvm.get_global_func("cutlass.symmetric_quantize")
+        nd_experts = tvm.nd.array(experts)
+        qweight, qscale = func(nd_experts, True)
+        if relax_pname.endswith("weight"):
+            return qweight
+        else:
+            assert relax_pname.endswith("scales")
+            return qscale
+
     def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
         # Expected to enter this function only for the combined linear matmul weights.
         # Other weights are supposed to be loaded in `f_convert_param_bkwd` since
         # each other relax param has a unique corresponding torch param.
+        if isinstance(config, MixtralConfig):
+            if "gate_up_combined_proj" in relax_pname:
+                # combine along out_features dimension and then experts dimension
+                experts = []
+                assert len(torch_params) == 2 * config.num_local_experts
+
+                use_pytorch = True
+                if use_pytorch and dtype == "float16":
+                    import torch
+
+                    torch_params = [torch.from_numpy(param).cuda() for param in torch_params]
+                    for i in range(config.num_local_experts):
+                        gate, up = (
+                            torch_params[i],
+                            torch_params[i + config.num_local_experts],
+                        )  # torch weight in col major
+                        gate_up = torch.concatenate([gate, up], axis=0).type(torch.float16)
+                        experts.append(gate_up.transpose(1, 0))
+                    result = torch.stack(experts)
+                    result = result.cpu().numpy()
+                else:
+                    for i in range(config.num_local_experts):
+                        gate, up = (
+                            torch_params[i],
+                            torch_params[i + config.num_local_experts],
+                        )  # torch weight in col major
+                        gate_up = np.concatenate([gate, up], axis=0).astype(dtype)
+                        experts.append(gate_up.transpose())
+                    result = np.stack(experts)
+                # print(config.quantization_scheme.name)
+                if config.quantization_scheme.name == "q4f16_ft" and "experts" in relax_pname:
+                    result = quantize(result, relax_pname)
+                return result
+            if "experts" in relax_pname:
+                use_pytorch = True
+                if use_pytorch and dtype == "float16":
+                    import torch
+                    torch_params = [torch.from_numpy(param).cuda() for param in torch_params]
+                    experts = torch.stack([expert.type(torch.float16).transpose(1, 0) for expert in torch_params])
+                    result = experts.cpu().numpy()
+                else:
+                    experts = [expert.astype(dtype).transpose() for expert in torch_params]
+                    result = np.stack(experts)
+                # torch_params = [torch.from_numpy(param).cuda() for param in torch_params]
+                # experts = [expert.type(dtype).transpose(1, 0) for expert in torch_params]
+                # result = torch.stack(experts).detach().numpy()
+                if config.quantization_scheme.name == "q4f16_ft" and "experts" in relax_pname:
+                    result = quantize(result, relax_pname)
+                return result
+
         if not config.combine_matmul:
             # When matmul combination is not turned on, each relax param has a unique
             # corresponding torch param, and this function is not expected to be entered.
@@ -1292,6 +1655,28 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
     return mod, param_manager, param_list, config


+def get_scatter_func(dtype):
+    @T.prim_func
+    def scatter_func(
+        x_handle: T.handle,
+        indices_handle: T.handle,
+        out_handle: T.handle,
+    ) -> None:
+        total_rows = T.int64()
+        hidden_size = T.int64()
+        x = T.match_buffer(x_handle, (total_rows, hidden_size), dtype)
+        indices = T.match_buffer(indices_handle, (total_rows,), "int32")
+        out = T.match_buffer(out_handle, (total_rows, hidden_size), dtype)
+        T.func_attr({"global_symbol": "scatter", "tir.noalias": True})
+        for i in range(total_rows):
+            for j in range(hidden_size):
+                with T.block("scatter"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    out[indices[vi], vj] = x[vi, vj]
+
+    return scatter_func
+
+
 def get_model(args, hf_config):
     model_name = args.model
     dtype = args.quantization.model_dtype
@@ -1310,7 +1695,19 @@ def get_model(args, hf_config):
     # while Llama-1 variants use `max_sequence_length`.
     # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`.
     # If none of them is defined, throw an error.
-    if "max_sequence_length" in hf_config:
+    if "mixtral" in args.model.lower():
+        # FIXME
+        config = MixtralConfig(
+            **hf_config,
+            dtype=dtype,
+            max_sequence_length=hf_config["max_position_embeddings"],
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+            quantization_scheme=args.quantization,
+        )
+    elif "max_sequence_length" in hf_config:
         config = LlamaConfig(
             **hf_config,
             dtype=dtype,
@@ -1341,6 +1738,8 @@ def get_model(args, hf_config):
     param_manager = ParamManager()
     bb = relax.BlockBuilder()

+    bb.add_func(get_scatter_func(dtype), "scatter")
+
     if sep_embed:
         create_embed_func(bb, param_manager, config, args.quantization)

diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py
index 2309bdd92e..b9303c2978 100644
--- a/mlc_llm/relax_model/llama_batched_vllm.py
+++ b/mlc_llm/relax_model/llama_batched_vllm.py
@@ -14,7 +14,9 @@
 from .modules import ModuleList
 from .param_manager import ParamManager
 from .llama import (
+    get_scatter_func,
     LlamaConfig,
+    MixtralConfig,
     Linear,
     Embedding,
     LlamaRMSNorm,
@@ -613,7 +615,19 @@ def get_model(args, hf_config):
     # while Llama-1 variants use `max_sequence_length`.
     # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`.
     # If none of them is defined, throw an error.
-    if "max_sequence_length" in hf_config:
+    if "mixtral" in args.model.lower():
+        # FIXME
+        config = MixtralConfig(
+            **hf_config,
+            dtype=dtype,
+            max_sequence_length=hf_config["max_position_embeddings"],
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+            quantization_scheme=args.quantization,
+        )
+    elif "max_sequence_length" in hf_config:
         config = LlamaConfig(
             **hf_config,
             dtype=dtype,
@@ -647,6 +661,8 @@ def get_model(args, hf_config):
     # The CPU device to copy the result of relax.op.max(seq_lens) to CPU.
     cpu_dev = VDevice("llvm", 0, "global")

+    bb.add_func(get_scatter_func(dtype), "scatter")
+
     create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
     create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
     create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization)
diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py
index 69a25ccb73..3f824ccae4 100644
--- a/mlc_llm/relax_model/param_manager.py
+++ b/mlc_llm/relax_model/param_manager.py
@@ -606,8 +606,8 @@ def get_item(i):
                     relax_pname,
                     [cached_torch_params[torch_pname] for torch_pname in torch_pnames],
                 )
-                for torch_pname in torch_pnames:
-                    del cached_torch_params[torch_pname]
+                # for torch_pname in torch_pnames:
+                #     del cached_torch_params[torch_pname]

             assert i in cached_relax_params
             assert i not in loaded_idx_set
@@ -904,6 +904,13 @@ def load_torch_pname2binname_map(
             for relax_pname in relax_pnames
             for torch_pname in f_convert_pname_fwd(relax_pname)
         }
+    elif "mixtral" in model_path:
+        ckpt_path = os.path.join(model_path, "consolidated.00.pth")
+        torch_pname2binname = {
+            torch_pname: ckpt_path
+            for relax_pname in relax_pnames
+            for torch_pname in f_convert_pname_fwd(relax_pname)
+        }
     else:
         suffix = ".safetensors" if use_safetensors else ".bin"
         shard_names = []
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index 7d34505da9..9217d0b19e 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -25,6 +25,7 @@
         "gptj",
         "chatglm",
         "mistral",
+        "mixtral",
         "stablelm_epoch",
     ]
 )
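
Note (reference only, not part of the patch): the routing performed by MoE.forward above is spread across several Relax emit calls, so a plain NumPy sketch of the same steps may help when reviewing — gate softmax, top-k expert selection, sorting the (token, slot) pairs by expert id, per-expert row offsets for the grouped GEMM, scattering results back, and the weighted sum. The names moe_forward_reference, gate_w, and expert_fn are illustrative, and the offset convention assumed for rows_before here (exclusive start offsets) may differ from what moe_compute_rows_before and the cutlass kernels actually use.

import numpy as np

def moe_forward_reference(x, gate_w, expert_fn, num_experts, top_k):
    """x: (num_tokens, hidden); gate_w: (hidden, num_experts); expert_fn(e, rows) -> rows."""
    logits = x @ gate_w
    scores = np.exp(logits - logits.max(-1, keepdims=True))
    scores /= scores.sum(-1, keepdims=True)               # softmax over experts
    expert_ids = np.argsort(-scores, axis=-1)[:, :top_k]  # top-k experts per token
    weights = np.take_along_axis(scores, expert_ids, axis=-1)
    weights /= weights.sum(-1, keepdims=True)             # renormalize the top-k weights

    flat_ids = expert_ids.reshape(-1)                     # (num_tokens * top_k,)
    order = np.argsort(flat_ids, kind="stable")           # group (token, slot) pairs by expert id
    token_of = order // top_k                             # token index feeding each sorted row
    gathered = x[token_of]                                # rows handed to the grouped GEMM
    rows_before = np.searchsorted(flat_ids[order], np.arange(num_experts))

    out = np.empty((len(gathered), x.shape[-1]), dtype=x.dtype)
    for e in range(num_experts):                          # one GEMM per expert in the grouped kernel
        lo = rows_before[e]
        hi = rows_before[e + 1] if e + 1 < num_experts else len(gathered)
        out[lo:hi] = expert_fn(e, gathered[lo:hi])

    unpermuted = np.empty_like(out)
    unpermuted[order] = out                               # role of the "scatter" TIR kernel
    unpermuted = unpermuted.reshape(-1, top_k, x.shape[-1])
    return (unpermuted * weights[..., None]).sum(axis=1)  # weighted sum over the k selected experts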