add parsing hf model config for using gated linear in mlp
cli99 committed Jan 2, 2024
1 parent 5d2d2fc commit a8d968f
Showing 5 changed files with 49 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -93,7 +93,7 @@ The pre-defined mappings are populated at the runtime from the model, GPU, and d

llm-analysis also supports retrieving `ModelConfig` from a model config json file path, or from Hugging Face with the model name.
- From a local model config json file, e.g., `python -m llm_analysis.analysis train --model_name=local_example_model.json`. Check the model configurations under the [model_configs](llm_analysis/model_configs) folder.
- From Hugging Face, e.g., use [`EleutherAI/gpt-neox-20b`](https://huggingface.co/EleutherAI/gpt-neox-20b) as `model_name` when calling the `train` or `infer` entry functions. `python -m llm_analysis.analysis train --model_name=EleutherAI/gpt-neox-20b --total_num_gpus 32 --ds_zero 3`
- From Hugging Face, e.g., use [`EleutherAI/gpt-neox-20b`](https://huggingface.co/EleutherAI/gpt-neox-20b) as `model_name` when calling the `train` or `infer` entry functions, e.g., `python -m llm_analysis.analysis train --model_name=EleutherAI/gpt-neox-20b --total_num_gpus 32 --ds_zero 3`. With this method, llm-analysis relies on `transformers` to find the corresponding model configuration on [huggingface.co/models](https://huggingface.co/models), so configuration information for newer models is only available in sufficiently recent versions of the `transformers` library. To access the latest models by name, update the installed `transformers` package. A Python-level sketch of this path is shown below.
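For programmatic use, a minimal sketch of the same Hugging Face path, assuming `transformers` is installed and the model id is downloadable; it goes through `get_model_config_from_hf` in `llm_analysis/config.py` (shown further down in this diff):

```python
# Minimal sketch: build a ModelConfig from a Hugging Face model name.
# Assumes `transformers` is installed and the model id below is accessible.
from llm_analysis.config import get_model_config_from_hf

model_config = get_model_config_from_hf("EleutherAI/gpt-neox-20b")
print(model_config.num_layers, model_config.hidden_dim, model_config.ffn_embed_dim)
```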

A list of handy commands is provided to query against the pre-defined mappings as well as Hugging Face, or to dump configurations. Run ```python -m llm_analysis.config --help``` for details.

2 changes: 1 addition & 1 deletion examples/llama2/run_infer.sh
@@ -37,5 +37,5 @@ python -m llm_analysis.analysis infer --model_name=${model_name} --gpu_name=${gp
--seq_len=${seq_len} --num_tokens_to_generate=${num_tokens_to_generate} --batch_size_per_gpu=${batch_size_per_gpu} \
--tp_size=${tp_size} \
--cost_per_gpu_hour=${cost_per_gpu_hour} \
--flops_efficiency=${flops_efficiency} --hbm_memory_efficiency=${hbm_memory_efficiency}
--flops_efficiency=${flops_efficiency} --hbm_memory_efficiency=${hbm_memory_efficiency} --log_level DEBUG
# --achieved_tflops=${achieved_tflops} --achieved_memory_bandwidth_GBs=${achieved_memory_bandwidth_GBs}
27 changes: 17 additions & 10 deletions llm_analysis/analysis.py
@@ -292,7 +292,9 @@ def get_num_params_per_layer_mlp(self) -> int:
Returns:
int: the number of parameters in the MLP linear layers (two, or three with gated linear units)
"""
return (3 if self.model_config.mlp_gated_linear_units else 2) * self.model_config.hidden_dim * self.model_config.ffn_embed_dim * self.model_config.moe_num_experts
return (
3 if self.model_config.mlp_gated_linear_units else 2
) * self.model_config.hidden_dim * self.model_config.ffn_embed_dim * self.model_config.moe_num_experts

def get_num_params_per_layer_router(self) -> int:
if self.model_config.moe_num_experts > 1:
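To make the gated-linear branch in `get_num_params_per_layer_mlp` concrete, here is a back-of-the-envelope check. The dimensions are LLaMA 2 70B's published ones (hidden 8192, FFN 28672, dense model) and are an assumption here, not output from llm-analysis:

```python
# Back-of-the-envelope check of the per-layer MLP parameter count with gated linear
# units (gate, up, and down projections instead of two matrices). Dims assume LLaMA 2 70B.
hidden_dim = 8192
ffn_embed_dim = 28672
moe_num_experts = 1  # dense model
mlp_gated_linear_units = True

num_linear = 3 if mlp_gated_linear_units else 2
num_params_per_layer_mlp = num_linear * hidden_dim * ffn_embed_dim * moe_num_experts
print(f"{num_params_per_layer_mlp:,}")  # 704,643,072, roughly 0.7B per transformer layer
```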
@@ -1662,8 +1664,8 @@ def inference(
)

if use_kv_cache:
if (batch_size_per_gpu * (seq_len + num_tokens_to_generate)
< self.get_pivot()):
if (batch_size_per_gpu *
(seq_len + num_tokens_to_generate) < self.get_pivot()):
logger.warning(
"kv_cache is only useful when batch_size *"
" (seq+num_tokens_to_generate)"
@@ -1771,6 +1773,11 @@ def inference(
"ep_size": self.parallelism_config.ep_size,
"pp_size": self.parallelism_config.pp_size,
"num_tokens_to_generate": num_tokens_to_generate,
"num_params_total": self.total_num_params,
"num_params_total_mlp": self.total_num_params_mlp,
"num_params_total_embedding": self.total_num_params_embedding,
"num_params_total_others": self.total_num_params_others,
"num_active_params_total": self.total_num_active_params,
"flops_efficiency": self.flops_efficiency,
"hbm_memory_efficiency": self.hbm_memory_efficiency,
"layernorm_dtype_bytes": layernorm_dtype_bytes,
@@ -1876,16 +1883,16 @@ def config_batch_size_and_gradient_accumulation_steps(
gradient_accumulation_steps = global_batch_size // (
batch_size_per_gpu * dp_size)
assert (global_batch_size % (batch_size_per_gpu * dp_size) == 0
and gradient_accumulation_steps
> 0), "no valid gradient_accumulation_steps, {assert_msg}"
and gradient_accumulation_steps > 0
), "no valid gradient_accumulation_steps, {assert_msg}"
elif global_batch_size and gradient_accumulation_steps:
# batch_size_per_gpu is None, the other two are not None
batch_size_per_gpu = global_batch_size // (
gradient_accumulation_steps * dp_size)
assert (global_batch_size %
(gradient_accumulation_steps * dp_size) == 0
and batch_size_per_gpu
> 0), "no valid batch_size_per_gpu, {assert_msg}"
and batch_size_per_gpu > 0
), "no valid batch_size_per_gpu, {assert_msg}"
elif batch_size_per_gpu and gradient_accumulation_steps or batch_size_per_gpu:
# batch_size_per_gpu is not None
if batch_size_per_gpu > max_batch_size_per_gpu:
Expand Down Expand Up @@ -1920,9 +1927,9 @@ def config_batch_size_and_gradient_accumulation_steps(
else:
# (global_batch_size and batch_size_per_gpu are None) or (all are None)
batch_size_per_gpu = max_batch_size_per_gpu
gradient_accumulation_steps = (1 if gradient_accumulation_steps
is None else
gradient_accumulation_steps)
gradient_accumulation_steps = (1 if
gradient_accumulation_steps is None
else gradient_accumulation_steps)
global_batch_size = (batch_size_per_gpu *
gradient_accumulation_steps *
self.parallelism_config.dp_size)
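All the branches above enforce the same identity, `global_batch_size = batch_size_per_gpu * gradient_accumulation_steps * dp_size`. A small illustrative check with hypothetical numbers (not defaults from the repo):

```python
# Hypothetical numbers: a global batch of 1024 sequences across 32 data-parallel ranks
# with 4 sequences per GPU implies 8 gradient accumulation steps.
global_batch_size, batch_size_per_gpu, dp_size = 1024, 4, 32
assert global_batch_size % (batch_size_per_gpu * dp_size) == 0
gradient_accumulation_steps = global_batch_size // (batch_size_per_gpu * dp_size)
print(gradient_accumulation_steps)  # 8
```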
29 changes: 27 additions & 2 deletions llm_analysis/config.py
@@ -73,7 +73,6 @@ def __post_init__(self):
elif self.ffn_embed_dim is None:
self.ffn_embed_dim = self.hidden_dim * self.expansion_ratio
elif self.expansion_ratio is None:
assert self.ffn_embed_dim % self.hidden_dim == 0, f"ffn_embed_dim ({self.ffn_embed_dim}) must be divisible by hidden_dim ({self.hidden_dim})"
self.expansion_ratio = self.ffn_embed_dim / self.hidden_dim

if self.num_key_value_heads is None:
@@ -199,19 +198,45 @@ def get_model_config_from_hf(name: str, ) -> ModelConfig:
"hf config does not have hidden_size or d_model, check the config.json file"
)

if hasattr(hf_config, "moe_num_experts"):
moe_num_experts = hf_config.moe_num_experts
elif hasattr(hf_config, "num_local_experts"):
moe_num_experts = hf_config.num_local_experts
else:
moe_num_experts = 1
logger.info(
"hf config does not have moe_num_experts or num_local_experts, setting moe_num_experts = 1 (not an MoE model)"
)

if hasattr(hf_config, "ffn_embed_dim"):
ffn_embed_dim = hf_config.ffn_embed_dim
elif hasattr(hf_config, "intermediate_size"):
ffn_embed_dim = hf_config.intermediate_size
else:
ffn_embed_dim = None

# default to no gated linear units; an expansion ratio of exactly 3.5 is treated as a gated MLP
mlp_gated_linear_units = False
if ffn_embed_dim:
expansion_ratio = ffn_embed_dim / hidden_dim
if expansion_ratio == 3.5:
mlp_gated_linear_units = True

config = ModelConfig(
name=canonical_model_name(name),
max_seq_len=hf_config.max_position_embeddings if hasattr(
hf_config, "max_position_embeddings") else None,
num_layers=num_layers,
n_head=n_head,
hidden_dim=hidden_dim,
ffn_embed_dim=ffn_embed_dim,
vocab_size=hf_config.vocab_size,
model_type=hf_config.model_type
if hasattr(hf_config, "model_type") else None,
num_key_value_heads=hf_config.num_key_value_heads if hasattr(
hf_config, "num_key_value_heads") else None,
)
moe_num_experts=moe_num_experts,
mlp_gated_linear_units=mlp_gated_linear_units)
return config
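A self-contained sketch of the new fallback logic, run against a mock config object; `SimpleNamespace` stands in for a `transformers` config, and the dimensions are assumed to be LLaMA 2 70B's (that model's config only exposes `hidden_size` and `intermediate_size`):

```python
# Mock hf_config exercising the fallback chain above (no network or transformers needed).
from types import SimpleNamespace

hf_config = SimpleNamespace(hidden_size=8192, intermediate_size=28672)

hidden_dim = hf_config.hidden_size
ffn_embed_dim = getattr(hf_config, "ffn_embed_dim",
                        getattr(hf_config, "intermediate_size", None))
moe_num_experts = getattr(hf_config, "moe_num_experts",
                          getattr(hf_config, "num_local_experts", 1))

# Heuristic from the parsing above: an expansion ratio of exactly 3.5 implies gated linear units.
mlp_gated_linear_units = bool(ffn_embed_dim) and ffn_embed_dim / hidden_dim == 3.5
print(ffn_embed_dim, moe_num_experts, mlp_gated_linear_units)  # 28672 1 True
```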


@@ -6,6 +6,7 @@
"vocab_size": 32000,
"max_seq_len": 4096,
"num_key_value_heads": 8,
"ffn_embed_dim": 32768,
"model_type": "llama"
"ffn_embed_dim": 28672,
"model_type": "llama",
"mlp_gated_linear_units": true
}
