Showing 274 changed files with 44,169 additions and 1 deletion.
@@ -1 +1,11 @@
-# k2
+# k2_training

### Launch Training
```
bash scripts/pretrain_65b.sh
```

### Converting Megatron Checkpoints to HuggingFace Format
```
python convert_ckpt_to_hf.py --load_path <megatron_ckpt_dir> --save_path <huggingface_ckpt_dir>
```
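A converted directory can then be loaded like any other HuggingFace checkpoint. The sketch below is only a smoke test, not part of this repository: the checkpoint path is a placeholder, and the tokenizer source is assumed to be `huggyllama/llama-65b` (the `HF_MODEL_NAME` used by the conversion script below).

```
# Sketch: load a converted checkpoint and run one forward pass.
# "<huggingface_ckpt_dir>" is a placeholder; the tokenizer source is an assumption.
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-65b")
model = LlamaForCausalLM.from_pretrained("<huggingface_ckpt_dir>", torch_dtype=torch.float16)
model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)  # (batch, sequence_length, vocab_size)
```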
@@ -0,0 +1,193 @@
import os
import fire
import json
import torch
import copy
from transformers import AutoConfig, LlamaForCausalLM
from transformers.modeling_utils import no_init_weights


HF_MODEL_NAME = "huggyllama/llama-65b"
TENSOR_PARALLEL = 8      # tensor-parallel degree of the Megatron checkpoint
PIPELINE_PARALLEL = 4    # pipeline-parallel degree of the Megatron checkpoint
N_HEADS = 64             # attention heads in LLaMA-65B
VOCAB_SIZE = 32032       # vocab size of the trained model (the base LLaMA tokenizer has 32000)
# TEST_DATA_FILENAME = '/mount/data/shuffled_data_chunks/chunk_0.jsonl'


def load_from_chunks(models, param_name, src, dim, target_param_name, sd2load):
    """Merge the tensor-parallel shards of `param_name` from `models` and write
    the result into `sd2load[target_param_name]`."""
    if param_name == 'embedding':
        cat_weights = torch.cat([models[t]['embedding']['word_embeddings']['weight']
                                 for t in range(TENSOR_PARALLEL)], dim=dim)
        sd2load[target_param_name] = cat_weights
    elif param_name == 'output_layer':
        cat_weights = torch.cat([models[t]['output_layer']['weight']
                                 for t in range(TENSOR_PARALLEL)], dim=dim)
        sd2load[target_param_name] = cat_weights
    else:
        if isinstance(src, list):
            # chunks=2 for the two components (gate / up) of the gated MLP layer
            chunks = [torch.chunk(models[t]['encoder'][param_name],
                                  chunks=2, dim=dim) for t in range(TENSOR_PARALLEL)]
            chunks = [torch.cat(c, dim=dim) for c in zip(*chunks)]
            for tpn, _c in zip(target_param_name, chunks):
                assert sd2load[tpn].size() == _c.size()
                sd2load[tpn] = _c
        elif dim == -1:
            # Layer norms are replicated across tensor-parallel ranks; take rank 0's copy.
            sd2load[target_param_name] = models[0]['encoder'][param_name]
        else:
            if isinstance(target_param_name, list):
                # handle the fused qkv weight: regroup per head, then split into q/k/v
                h1, h2 = src.shape
                reshaped_weights = torch.cat([
                    models[t]['encoder'][param_name]
                    for t in range(TENSOR_PARALLEL)
                ], dim=0).view(N_HEADS, -1, h2)
                chunked_reshaped_weights = torch.chunk(
                    reshaped_weights, chunks=3, dim=1)  # 3 for q, k, v
                for tpn, crw in zip(
                        target_param_name, chunked_reshaped_weights):
                    crw = crw.contiguous().view(-1, h2)
                    assert sd2load[tpn].size() == crw.size()
                    sd2load[tpn] = crw
            else:
                # handle row-parallel weights (attn.o_proj, mlp.down_proj):
                # concatenate shards along the input dimension
                cat_weights = torch.cat([models[t]['encoder'][param_name]
                                         for t in range(TENSOR_PARALLEL)], dim=dim)
                sd2load[target_param_name] = cat_weights


def main(load_path='/mount/ckpts/llama-65b-mp/iter_0096923',
         save_path='/mount/ckpts/65b_ckpts_hf/iter_0096923'):
    with no_init_weights():
        model = LlamaForCausalLM(
            config=AutoConfig.from_pretrained(HF_MODEL_NAME, vocab_size=VOCAB_SIZE))
    hf_state_dict = model.state_dict()

    ret = [
        [
            {
                'model': {'language_model': {'encoder': {}}},
                'checkpoint_version': 2
            } for _ in range(TENSOR_PARALLEL)
        ] for _ in range(PIPELINE_PARALLEL)
    ]

    # Load every (pipeline rank, tensor rank) shard's language_model state dict on CPU.
    for i in range(PIPELINE_PARALLEL):
        for j in range(TENSOR_PARALLEL):
            shard_name = f'mp_rank_{j:02d}_{i:03d}/'
            print(f'loading {os.path.join(load_path, shard_name)}')

            # os.makedirs(os.path.join(load_path, shard_name), exist_ok=True)
            ret[i][j] = torch.load(
                os.path.join(load_path, shard_name, 'model_optim_rng.pt'),
                map_location=torch.device('cpu')
            )['model']['language_model']

    new_state_dict = copy.deepcopy(model.state_dict())

    # Infer the number of decoder layers: the HF state dict is assumed to hold
    # 9 tensors per decoder layer plus 3 outside the layers
    # (embed_tokens, final norm, lm_head).
    total = (len(hf_state_dict) - 3) // 9
    step = total // PIPELINE_PARALLEL
    # i: PP dim index
    # j: encoder block index
    # k: encoder block index per PP dim
    for i in range(PIPELINE_PARALLEL):
        end = total if i == PIPELINE_PARALLEL - 1 else (i + 1) * step
        for j in range(i * step, end):
            k = j - i * step

            load_from_chunks(
                ret[i],
                param_name=f'layers.{k}.input_layernorm.weight',
                src=hf_state_dict[f'model.layers.{j}.input_layernorm.weight'],
                dim=-1,
                target_param_name=f'model.layers.{j}.input_layernorm.weight',
                sd2load=new_state_dict)

            load_from_chunks(
                ret[i],
                param_name=f'layers.{k}.self_attention.query_key_value.weight',
                src=hf_state_dict[f'model.layers.{j}.self_attn.q_proj.weight'],
                dim=0,
                target_param_name=[
                    f'model.layers.{j}.self_attn.q_proj.weight',
                    f'model.layers.{j}.self_attn.k_proj.weight',
                    f'model.layers.{j}.self_attn.v_proj.weight'],
                sd2load=new_state_dict)

            load_from_chunks(
                ret[i],
                param_name=f'layers.{k}.self_attention.dense.weight',
                src=hf_state_dict[f'model.layers.{j}.self_attn.o_proj.weight'],
                dim=1,
                target_param_name=f'model.layers.{j}.self_attn.o_proj.weight',
                sd2load=new_state_dict)

            load_from_chunks(
                ret[i],
                param_name=f'layers.{k}.post_attention_layernorm.weight',
                src=hf_state_dict[f'model.layers.{j}.post_attention_layernorm.weight'],
                dim=-1,
                target_param_name=f'model.layers.{j}.post_attention_layernorm.weight',
                sd2load=new_state_dict)

            load_from_chunks(
                ret[i],
                param_name=f'layers.{k}.mlp.dense_h_to_4h.weight',
                src=[
                    hf_state_dict[f'model.layers.{j}.mlp.gate_proj.weight'],
                    hf_state_dict[f'model.layers.{j}.mlp.up_proj.weight'],
                ],
                dim=0,
                target_param_name=[
                    f'model.layers.{j}.mlp.gate_proj.weight',
                    f'model.layers.{j}.mlp.up_proj.weight'],
                sd2load=new_state_dict)

            load_from_chunks(
                ret[i],
                param_name=f'layers.{k}.mlp.dense_4h_to_h.weight',
                src=hf_state_dict[f'model.layers.{j}.mlp.down_proj.weight'],
                dim=1,
                target_param_name=f'model.layers.{j}.mlp.down_proj.weight',
                sd2load=new_state_dict)

    # Parameters outside the decoder layers: embeddings come from the first PP stage,
    # the final norm and lm_head from the last PP stage.
    load_from_chunks(
        ret[0],
        param_name='embedding',
        src=hf_state_dict['model.embed_tokens.weight'],
        dim=0,
        target_param_name='model.embed_tokens.weight',
        sd2load=new_state_dict)

    load_from_chunks(
        ret[-1],
        param_name='final_layernorm.weight',
        src=hf_state_dict['model.norm.weight'],
        dim=-1,
        target_param_name='model.norm.weight',
        sd2load=new_state_dict)

    load_from_chunks(
        ret[-1],
        param_name='output_layer',
        src=hf_state_dict['lm_head.weight'],
        dim=0,
        target_param_name='lm_head.weight',
        sd2load=new_state_dict)

    model.load_state_dict(new_state_dict)
    model.save_pretrained(save_path, safe_serialization=False)
    print("Converting to HF done!")

    # token_ids = json.loads(open(TEST_DATA_FILENAME).readline())['token_ids']
    # input_ids = torch.tensor([token_ids])
    # labels = torch.tensor([token_ids])

    # model.eval()
    # output_recons = model(input_ids, labels=labels, output_hidden_states=True)
    # print("### recons loss: {}".format(output_recons.loss))


if __name__ == '__main__':
    fire.Fire(main)
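The subtlest step in the conversion script above is merging the fused `query_key_value` weight: the tensor-parallel shards are concatenated, regrouped per attention head, and only then split back into separate q/k/v projections. The toy sketch below is not taken from the repository; it uses made-up sizes (TP=2 and 4 heads instead of 8 and 64) to reproduce that cat / view / chunk pattern on random tensors and show that it recovers the per-projection weights.

```
# Toy check of the QKV merge pattern used in load_from_chunks (made-up sizes).
import torch

hidden, n_heads, tp = 16, 4, 2
head_dim = hidden // n_heads
heads_per_rank = n_heads // tp

q, k, v = (torch.randn(hidden, hidden) for _ in range(3))

# Build per-rank fused QKV shards: for each head, its q, k, v rows are stacked together.
shards = []
for r in range(tp):
    blocks = []
    for h in range(r * heads_per_rank, (r + 1) * heads_per_rank):
        rows = slice(h * head_dim, (h + 1) * head_dim)
        blocks += [q[rows], k[rows], v[rows]]
    shards.append(torch.cat(blocks, dim=0))  # (heads_per_rank * 3 * head_dim, hidden)

# Same merge as the conversion script: concat shards, view per head, split into q/k/v.
merged = torch.cat(shards, dim=0).view(n_heads, -1, hidden)
q2, k2, v2 = (c.contiguous().view(-1, hidden) for c in torch.chunk(merged, 3, dim=1))

assert torch.equal(q2, q) and torch.equal(k2, k) and torch.equal(v2, v)
print("QKV merge round-trip OK")
```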
@@ -0,0 +1,155 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT"""
import os.path

import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
from megatron.arguments import core_transformer_config_from_args
import datasets

N_CHUNKS = 360


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    config = core_transformer_config_from_args(get_args())
    model = GPTModel(
        config,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)

    return model


def get_batch(data_iterator):
    """Generate a batch."""
    args = get_args()

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(
        keys=['token_ids'], data=data, datatype=torch.int64)

    # Labels are the tokens shifted left by one position.
    tokens = data_b['token_ids'].long()
    labels = torch.ones_like(tokens)
    labels[..., :-1] = tokens[..., 1:]
    tokens, labels = tokens.contiguous(), labels.contiguous()

    assert not any([
        args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss])

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=None,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss)

    # The last position has no next-token target, so drop it from the loss.
    loss_mask[..., -1] = 0

    return tokens, labels, loss_mask, attention_mask, position_ids


def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = (
        get_batch(data_iterator))
    timers('batch-generator').stop()

    output_tensor = model(
        tokens, position_ids, attention_mask, labels=labels)

    return output_tensor, partial(loss_func, loss_mask)


# TODO: use path and split portion provided in the data args.
def my_train_valid_test_datasets_provider(num_samples):
    args = get_args()

    print_rank_0('building datasets using huggingface datasets...')

    # Resume point: map the latest checkpointed iteration to a chunk index
    # (one chunk per save interval).
    latest_ckpt_iter = int(
        open(f'{args.save}/latest_checkpointed_iteration.txt').read().strip())
    chunk_begin_idx = latest_ckpt_iter // args.save_interval
    print_rank_0(f'chunk_begin_idx = {chunk_begin_idx}')

    # while True:
    #     ckpt_dir = \
    #         f'{args.save}/iter_{(chunk_begin_idx + 1) * args.save_interval:07d}'
    #     if not os.path.exists(ckpt_dir):
    #         print_rank_0(f'chunk_begin_idx = {chunk_begin_idx}')
    #         break
    #     else:
    #         chunk_begin_idx += 1
    #         print_rank_0(
    #             f'{ckpt_dir} exists. chunk_idx changed to {chunk_begin_idx}.')

    # Fixed chunk ordering: a hand-picked prefix, then the remaining chunks in order;
    # two chunks are then moved to the end of the ordering.
    chunk_idxes = (
        list(range(16)) + list(range(32, 32 + 6)) + list(range(44, 44 + 8)))
    chunk_idxes.extend(
        [i for i in range(360) if i not in chunk_idxes])
    chunk_idxes = chunk_idxes[:160] + chunk_idxes[161:] + chunk_idxes[160:161]
    chunk_idxes = chunk_idxes[:185] + chunk_idxes[186:] + chunk_idxes[185:186]

    chunk_idxes = chunk_idxes[:args.n_chunks]
    print_rank_0(f'chunk_idxes: {chunk_idxes}')
    print_rank_0(f'actual chunk_idxes this run: {chunk_idxes[chunk_begin_idx:]}')

    data_files = [
        f"{args.data_base_path}/chunk_{chunk_idx}.jsonl"
        for chunk_idx in chunk_idxes
    ]
    train_ds = datasets.load_dataset(
        "json",
        data_files=data_files,
        split='train',
        num_proc=min(args.n_chunks, os.cpu_count()),
        cache_dir='/mount/data/train_cache')
    train_ds = train_ds.with_format("np")

    print_rank_0(f'huggingface dataset built, size = {len(train_ds)}')

    valid_ds, test_ds = None, None
    print_rank_0("> finished creating pretrain datasets ...")
    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(
        train_valid_test_dataset_provider=my_train_valid_test_datasets_provider,
        model_provider=model_provider,
        model_type=ModelType.encoder_or_decoder,
        forward_step_func=forward_step,
        args_defaults={'tokenizer_type': 'LLaMATokenizer'})
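For reference, the label and loss-mask construction in `get_batch` amounts to shifting the tokens left by one position and masking the final position, which has no next-token target. The standalone toy illustration below is not from the repository: the token ids are made up, and the attention mask and position ids from `get_ltor_masks_and_position_ids` are omitted.

```
# Toy illustration of the label shift and final-position masking in get_batch.
import torch

tokens = torch.tensor([[5, 6, 7, 8]])

# Next-token labels: position t is trained to predict token t+1.
labels = torch.ones_like(tokens)
labels[..., :-1] = tokens[..., 1:]

# The final position has no real next token, so it is excluded from the loss.
loss_mask = torch.ones_like(tokens, dtype=torch.float)
loss_mask[..., -1] = 0

print(labels)     # tensor([[6, 7, 8, 1]])
print(loss_mask)  # tensor([[1., 1., 1., 0.]])
```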