updated.
tanyuqian committed Apr 30, 2024
1 parent d0979e9 commit 351e8c3
Showing 274 changed files with 44,169 additions and 1 deletion.
12 changes: 11 additions & 1 deletion README.md
@@ -1 +1,11 @@
# k2
# k2_training

### Launch Training
```
bash scripts/pretrain_65b.sh
```

### Converting Megatron Checkpoints to HuggingFace Format
```
python convert_ckpt_to_hf.py --load_path <megatron_ckpt_dir> --save_path <huggingface_ckpt_dir>
```
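
To sanity-check the conversion (a minimal sketch mirroring the commented-out test at the bottom of `convert_ckpt_to_hf.py`; the checkpoint path and token ids below are placeholders), the exported directory can be loaded back with `transformers`:
```
import torch
from transformers import LlamaForCausalLM

# Placeholder path: use the <huggingface_ckpt_dir> passed to convert_ckpt_to_hf.py.
model = LlamaForCausalLM.from_pretrained(
    "<huggingface_ckpt_dir>",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,  # the 65B model is large; avoid a second full copy in RAM
)
model.eval()

# A few arbitrary token ids; a finite loss indicates the weights were stitched together sensibly.
token_ids = torch.tensor([[1, 306, 4966, 445, 1736]])
with torch.no_grad():
    out = model(input_ids=token_ids, labels=token_ids)
print(out.loss)
```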
193 changes: 193 additions & 0 deletions convert_ckpt_to_hf.py
@@ -0,0 +1,193 @@
import os
import fire
import json
import torch
import copy
from transformers import AutoConfig, LlamaForCausalLM
from transformers.modeling_utils import no_init_weights


HF_MODEL_NAME = "huggyllama/llama-65b"
TENSOR_PARALLEL = 8
PIPELINE_PARALLEL = 4
N_HEADS = 64
VOCAB_SIZE = 32032
# TEST_DATA_FILENAME = '/mount/data/shuffled_data_chunks/chunk_0.jsonl'


def load_from_chunks(models, param_name, src, dim, target_param_name, sd2load):
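    """Reassemble one parameter from the TENSOR_PARALLEL Megatron shards and write it
    into the HuggingFace state dict.

    models            : per-TP-rank 'language_model' dicts for one pipeline stage
    param_name        : key of the parameter inside the Megatron shard
    src               : matching HF tensor(s); used to detect fused params and for shapes
    dim               : dimension the parameter was split along across TP ranks
                        (-1 means the parameter is replicated, not split)
    target_param_name : HF state-dict key(s) to write into
    sd2load           : HF state dict being filled in
    """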
if param_name == 'embedding':
cat_weights = torch.cat([models[t]['embedding']['word_embeddings']['weight']
for t in range(TENSOR_PARALLEL)], dim=dim)
sd2load[target_param_name] = cat_weights
elif param_name == 'output_layer':
cat_weights = torch.cat([models[t]['output_layer']['weight']
for t in range(TENSOR_PARALLEL)], dim=dim)
sd2load[target_param_name] = cat_weights
else:
if isinstance(src, list):
# chunks=2 for two components in the gated MLP layer
chunks = [torch.chunk(models[t]['encoder'][param_name],
chunks=2, dim=dim) for t in range(TENSOR_PARALLEL)]
chunks = [torch.cat(c, dim=dim) for c in zip(*chunks)]
for tpn, _c in zip(target_param_name, chunks):
assert sd2load[tpn].size() == _c.size()
sd2load[tpn] = _c
elif dim == -1:
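            # Replicated (non-sharded) parameter such as a layernorm weight:
            # every TP rank holds an identical copy, so take rank 0's.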
sd2load[target_param_name] = models[0]['encoder'][param_name]
else:
if isinstance(target_param_name, list):
                # handle QKV: the fused query_key_value weight is reshaped per head
                # and split into separate q/k/v projections
h1, h2 = src.shape
reshaped_weights = torch.cat([
models[t]['encoder'][param_name]
for t in range(TENSOR_PARALLEL)
], dim=0).view(N_HEADS, -1, h2)
chunked_reshaped_weights = torch.chunk(
reshaped_weights, chunks=3, dim=1) # 3 for qkv
for tpn, crw in zip(
target_param_name, chunked_reshaped_weights):
crw = crw.contiguous().view(-1, h2)
assert sd2load[tpn].size() == crw.size()
sd2load[tpn] = crw
else:
                # handle row-parallel weights (attn.o_proj, mlp.down_proj):
                # concatenate the TP shards along the given dim
cat_weights = torch.cat([models[t]['encoder'][param_name]
for t in range(TENSOR_PARALLEL)], dim=dim)
sd2load[target_param_name] = cat_weights


def main(load_path='/mount/ckpts/llama-65b-mp/iter_0096923',
save_path='/mount/ckpts/65b_ckpts_hf/iter_0096923'):
with no_init_weights():
model = LlamaForCausalLM(
config=AutoConfig.from_pretrained(HF_MODEL_NAME, vocab_size=VOCAB_SIZE))
hf_state_dict = model.state_dict()

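    # ret[pp][tp] will hold the 'language_model' section of the Megatron shard for
    # pipeline-parallel rank pp and tensor-parallel rank tp.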
ret = [
[
{
'model': {'language_model': {'encoder': {}}},
'checkpoint_version': 2
} for _ in range(TENSOR_PARALLEL)
] for _ in range(PIPELINE_PARALLEL)
]

for i in range(PIPELINE_PARALLEL):
for j in range(TENSOR_PARALLEL):
shard_name = f'mp_rank_{j:02d}_{i:03d}/'
print(f'loading {os.path.join(load_path, shard_name)}')

# os.makedirs(os.path.join(load_path, shard_name), exist_ok=True)
ret[i][j] = torch.load(
os.path.join(load_path, shard_name, 'model_optim_rng.pt'),
map_location=torch.device('cpu')
)['model']['language_model']

new_state_dict = copy.deepcopy(model.state_dict())

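    # The HF state dict has 3 non-layer tensors (embed_tokens, final norm, lm_head) and
    # 9 tensors per decoder layer (q/k/v/o projections, gate/up/down MLP weights, and the
    # two layernorms), so `total` is the number of decoder layers and `step` the number
    # of layers handled by each pipeline stage (the last stage takes any remainder).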
total = (len(hf_state_dict) - 3) // 9
step = total // PIPELINE_PARALLEL
# i: PP dim index
# j: encoder block index
# k: encoder block index per PP dim
for i in range(PIPELINE_PARALLEL):
end = total if i == PIPELINE_PARALLEL - 1 else (i + 1) * step
for j in range(i * step, end):
k = j - i * step

load_from_chunks(
ret[i],
param_name=f'layers.{k}.input_layernorm.weight',
src=hf_state_dict[f'model.layers.{j}.input_layernorm.weight'],
dim=-1,
target_param_name=f'model.layers.{j}.input_layernorm.weight',
sd2load=new_state_dict)

load_from_chunks(
ret[i],
param_name=f'layers.{k}.self_attention.query_key_value.weight',
src=hf_state_dict[f'model.layers.{j}.self_attn.q_proj.weight'],
dim=0,
target_param_name=[
f'model.layers.{j}.self_attn.q_proj.weight',
f'model.layers.{j}.self_attn.k_proj.weight',
f'model.layers.{j}.self_attn.v_proj.weight'],
sd2load=new_state_dict)

load_from_chunks(
ret[i],
param_name=f'layers.{k}.self_attention.dense.weight',
src=hf_state_dict[f'model.layers.{j}.self_attn.o_proj.weight'],
dim=1,
target_param_name=f'model.layers.{j}.self_attn.o_proj.weight',
sd2load=new_state_dict)

load_from_chunks(
ret[i],
param_name=f'layers.{k}.post_attention_layernorm.weight',
src=hf_state_dict[f'model.layers.{j}.post_attention_layernorm.weight'],
dim=-1,
target_param_name=f'model.layers.{j}.post_attention_layernorm.weight',
sd2load=new_state_dict)

load_from_chunks(
ret[i],
param_name=f'layers.{k}.mlp.dense_h_to_4h.weight',
src=[
hf_state_dict[f'model.layers.{j}.mlp.gate_proj.weight'],
hf_state_dict[f'model.layers.{j}.mlp.up_proj.weight'],
],
dim=0,
target_param_name=[
f'model.layers.{j}.mlp.gate_proj.weight',
f'model.layers.{j}.mlp.up_proj.weight'],
sd2load=new_state_dict)

load_from_chunks(
ret[i],
param_name=f'layers.{k}.mlp.dense_4h_to_h.weight',
src=hf_state_dict[f'model.layers.{j}.mlp.down_proj.weight'],
dim=1,
target_param_name=f'model.layers.{j}.mlp.down_proj.weight',
sd2load=new_state_dict)

load_from_chunks(
ret[0],
param_name='embedding',
src=hf_state_dict['model.embed_tokens.weight'],
dim=0,
target_param_name='model.embed_tokens.weight',
sd2load=new_state_dict)

load_from_chunks(
ret[-1],
param_name='final_layernorm.weight',
src=hf_state_dict['model.norm.weight'],
dim=-1,
target_param_name='model.norm.weight',
sd2load=new_state_dict)

load_from_chunks(
ret[-1],
param_name='output_layer',
src=hf_state_dict['lm_head.weight'],
dim=0,
target_param_name='lm_head.weight',
sd2load=new_state_dict)

model.load_state_dict(new_state_dict)
model.save_pretrained(save_path, safe_serialization=False)
    print("Converting to HF format done!")

# token_ids = json.loads(open(TEST_DATA_FILENAME).readline())['token_ids']
# input_ids = torch.tensor([token_ids])
# labels = torch.tensor([token_ids])

# model.eval()
# output_recons = model(input_ids, labels=labels, output_hidden_states=True)
# print("### recons loss: {}".format(output_recons.loss))


if __name__ == '__main__':
fire.Fire(main)
155 changes: 155 additions & 0 deletions main.py
@@ -0,0 +1,155 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT"""
import os.path

import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
from megatron.arguments import core_transformer_config_from_args
import datasets

N_CHUNKS = 360


def model_provider(pre_process=True, post_process=True):
"""Build the model."""

print_rank_0('building GPT model ...')
config = core_transformer_config_from_args(get_args())
model = GPTModel(
config,
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process)

return model


def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()

# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(
keys=['token_ids'], data=data, datatype=torch.int64)

tokens = data_b['token_ids'].long()
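    # Next-token-prediction labels: shift tokens left by one; the final position has no
    # target and its loss is zeroed out via loss_mask below.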
labels = torch.ones_like(tokens)
labels[..., :-1] = tokens[..., 1:]
tokens, labels = tokens.contiguous(), labels.contiguous()

assert not any([
args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss])

    # Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=None,
reset_position_ids=args.reset_position_ids,
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss)

loss_mask[..., -1] = 0

return tokens, labels, loss_mask, attention_mask, position_ids


def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])

return loss, {'lm loss': averaged_loss[0]}


def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = (
get_batch(data_iterator))
timers('batch-generator').stop()

output_tensor = model(
tokens, position_ids, attention_mask, labels=labels)

return output_tensor, partial(loss_func, loss_mask)


# TODO: use path and split portion provided in the data args.
def my_train_valid_test_datasets_provider(num_samples):
args = get_args()

print_rank_0('building datasets using huggingface datasets...')

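    # Resume bookkeeping: estimate how many data chunks earlier runs already consumed
    # (this assumes one chunk is covered per save interval).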
latest_ckpt_iter = int(open(f'{args.save}/latest_checkpointed_iteration.txt').read().strip())
chunk_begin_idx = latest_ckpt_iter // args.save_interval
print_rank_0(f'chunk_begin_idx = {chunk_begin_idx}')

# while True:
# ckpt_dir = \
# f'{args.save}/iter_{(chunk_begin_idx + 1) * args.save_interval:07d}'
# if not os.path.exists(ckpt_dir):
# print_rank_0(f'chunk_begin_idx = {chunk_begin_idx}')
# break
# else:
# chunk_begin_idx += 1
# print_rank_0(
    #             f'{ckpt_dir} exists. chunk_idx changed to {chunk_begin_idx}.')

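    # Fixed data-chunk ordering: start from hand-picked chunk ranges, append all
    # remaining chunk ids, then move the chunks sitting at positions 160 and 185
    # to the end.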
chunk_idxes = (
list(range(16)) + list(range(32, 32 + 6)) + list(range(44, 44 + 8)))
chunk_idxes.extend(
[i for i in range(360) if i not in chunk_idxes])
chunk_idxes = chunk_idxes[:160] + chunk_idxes[161:] + chunk_idxes[160:161]
chunk_idxes = chunk_idxes[:185] + chunk_idxes[186:] + chunk_idxes[185:186]

chunk_idxes = chunk_idxes[:args.n_chunks]
print_rank_0(f'chunk_idxes: {chunk_idxes}')
print_rank_0(f'actual chunk_idxes this run: {chunk_idxes[chunk_begin_idx:]}')

data_files = [
f"{args.data_base_path}/chunk_{chunk_idx}.jsonl"
for chunk_idx in chunk_idxes
]
train_ds = datasets.load_dataset(
"json",
data_files=data_files,
split='train',
num_proc=min(args.n_chunks, os.cpu_count()),
cache_dir='/mount/data/train_cache')
train_ds = train_ds.with_format("np")

print_rank_0(f'huggingface dataset built, size = {len(train_ds)}')

valid_ds, test_ds = None, None
print_rank_0("> finished creating pretrain datasets ...")
return train_ds, valid_ds, test_ds


if __name__ == "__main__":
pretrain(
train_valid_test_dataset_provider=my_train_valid_test_datasets_provider,
model_provider=model_provider,
model_type=ModelType.encoder_or_decoder,
forward_step_func=forward_step,
args_defaults={'tokenizer_type': 'LLaMATokenizer'})