Llama tutorial fixes (#730)
Llama tutorial fixes - all

Signed-off-by: Pawel Gadzinski <[email protected]>
Co-authored-by: Pawel Gadzinski <[email protected]>
2 people authored and ptrendx committed Apr 1, 2024
1 parent 2c14d68 commit 297459b
Showing 3 changed files with 34 additions and 22 deletions.
46 changes: 27 additions & 19 deletions docs/examples/te_llama/te_llama.py
100644 → 100755
@@ -56,7 +56,7 @@ def __init__(self, config, *args, **kwargs):
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads,
num_gqa_groups=config.num_key_value_heads
)
te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
@@ -121,53 +121,61 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
assert not isinstance(resolved_archive_file, list)
resolved_archive_file = [resolved_archive_file]

error_msgs = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
replaced_layers = replace_params(state_dict, vanilla_model.state_dict())

error_msgs += _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
# replace_params copies parameters relevant only to TransformerEngine
replace_params(state_dict, vanilla_model.state_dict(), config)
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")

# Force mem release. Taken from huggingface code
del state_dict
gc.collect()

return vanilla_model

def replace_params(hf_state_dict, te_state_dict):
def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = 'model.layers.\d+.'
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())



for layer_prefix in all_layer_prefixes:
# When loading weights into models with less number of layers, skip the
# copy if the corresponding layer doesn't exist in TE model
if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:
# copy if the corresponding layer doesn't exist in HF model
if layer_prefix + 'input_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]

if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:
if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]

if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:
if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]

if layer_prefix + 'self_attention.layernorm_qkv.value_weight' in te_state_dict:
if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]

if layer_prefix + 'self_attention.proj.weight' in te_state_dict:
if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]

if layer_prefix + 'layernorm_mlp.layer_norm_weight' in te_state_dict:
if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]

if layer_prefix + 'layernorm_mlp.fc1_weight' in te_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0)

if layer_prefix + 'layernorm_mlp.fc2_weight' in te_state_dict:

# It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
# load them separately.
if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \
hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data

if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \
hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data

if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]

return all_layer_prefixes
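
The slicing above reflects the fused swiglu layout used by TransformerEngine's LayerNormMLP: layernorm_mlp.fc1_weight stacks the HF gate_proj rows first and the up_proj rows second, which is why gate_proj fills [:config.intermediate_size] and up_proj fills [config.intermediate_size:]. A minimal sketch of that layout, with illustrative shapes rather than real checkpoint sizes:

import torch

hidden_size, intermediate_size = 8, 16
gate_proj = torch.randn(intermediate_size, hidden_size)  # stands in for mlp.gate_proj.weight
up_proj = torch.randn(intermediate_size, hidden_size)    # stands in for mlp.up_proj.weight

# fc1_weight has shape (2 * intermediate_size, hidden_size): gate rows, then up rows,
# i.e. the same two slices that replace_params fills above.
fc1_weight = torch.empty(2 * intermediate_size, hidden_size)
fc1_weight[:intermediate_size] = gate_proj
fc1_weight[intermediate_size:] = up_proj

assert torch.equal(fc1_weight, torch.cat((gate_proj, up_proj), dim=0))

Filling the two slices independently is what lets gate_proj and up_proj arrive from different checkpoint shards, as the comment in the diff notes.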
9 changes: 6 additions & 3 deletions docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
100644 → 100755
@@ -231,7 +231,8 @@
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
@@ -556,7 +557,8 @@
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
@@ -635,7 +637,8 @@
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"fp8\"\n",
"\n",
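
For the conversion step mentioned in the comments above, Hugging Face transformers ships a Llama conversion script. A hedged sketch of invoking it from Python (the module path and flags are assumptions and can differ across transformers versions; paths are placeholders):

import subprocess

subprocess.run([
    "python", "-m", "transformers.models.llama.convert_llama_weights_to_hf",
    "--input_dir", "/path/to/downloaded/llama/weights",   # Meta-format checkpoint
    "--model_size", "7B",
    "--output_dir", "/path/to/hf/llama/weights",          # use this path as hyperparams.model_name
], check=True)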
1 change: 1 addition & 0 deletions docs/examples/te_llama/utils.py
100644 → 100755
@@ -91,6 +91,7 @@ def init_te_llama_model(hyperparams):
# Init the model
from te_llama import TELlamaForCausalLM
config = AutoConfig.from_pretrained(hyperparams.model_name)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local(
hyperparams.model_name,
config=config,
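
A rough usage sketch of the patched init_te_llama_model, assuming utils.py exposes a module-level hyperparams instance of its Hyperparameters class and that the function returns the constructed model; the path is a placeholder:

from utils import hyperparams, init_te_llama_model

hyperparams.model_name = "/path/to/hf/llama/weights"  # HF-format Llama checkpoint
model = init_te_llama_model(hyperparams)  # TELlamaForCausalLM built with _attn_implementation="flash_attention_2"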
