from transformers import AutoTokenizer
from dolomite_engine.arguments import DatasetArgs, DistributedArgs, RandomArgs, TrainingParameters
from dolomite_engine.data import get_megatron_gpt_dataloaders
from dolomite_engine.utils import ProcessGroupManager, set_logger
# this is needed to print logs during the preparation of the dataloader
set_logger()
# use the appropriate tokenizer here
# the tokenizer is required for using FIM during training
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3b-code-base")
# adjust these hyperparameters
num_training_steps = 10000
# batch size per GPU
micro_batch_size = 8
sequence_length = 2048
# whether or not to evaluate during training
eval_during_training = False
# interval at which eval is done if eval_during_training is True
eval_interval = 1000
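
# rough token-budget arithmetic for the settings above (illustrative only; num_gpus and
# gradient_accumulation_steps are hypothetical placeholders for your launch setup, nothing below reads them)
num_gpus = 8
gradient_accumulation_steps = 1
total_training_tokens = (
    num_training_steps * micro_batch_size * gradient_accumulation_steps * num_gpus * sequence_length
)  # 10000 * 8 * 1 * 8 * 2048 ≈ 1.31B tokens
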
class DummyArgs:
    # seed for shuffling data
    random_args = RandomArgs(seed=42)

    datasets = [
        DatasetArgs(
            # don't change these values
            class_name="MegatronDataset",
            data_name="Megatron",
            data_sampling_ratio=1,
            # you are allowed to change these values
            class_args={
                # number of eval steps (batches) to run every time we do evaluation
                # 2 means (2 * micro_batch_size * sequence_length) tokens are used to calculate perplexity
                "eval_steps": 2,
                # path of the cache used by the megatron dataset
                "data_cache_path": "./cache",
                # sampling proportions, each followed by its file path
                "data_path": [
                    0.2,
                    "/proj-3500/datasets/slim_pajama_gptneox_megatron/train/chunk1",
                    0.5,
                    "/proj-3500/datasets/slim_pajama_gptneox_megatron/train/chunk2",
                    0.3,
                    "/proj-3500/datasets/slim_pajama_gptneox_megatron/train/chunk3",
                ],
                # 100% train, 0% val and 0% test
                "split": "100,0,0",
                # the megatron dataloader returns sequence_length + 1 tokens (2049 here) since the loss
                # computation is done outside the model class. to train with HF's loss computation, we pass
                # sequence_length - 1 so that 2048 tokens are returned. note that HF will then itself train
                # on 2047 tokens, which is not a big problem.
                # don't change this directly; if using the HF trainer or something similar, adjust
                # sequence_length above
                "sequence_length": sequence_length - 1,
            },
        )
    ]

    training_parameters = TrainingParameters(
        num_training_steps=num_training_steps,
        micro_batch_size=micro_batch_size,
        eval_interval=eval_interval,
        eval_during_training=eval_during_training,
    )

    # don't play with this value
    distributed_args = DistributedArgs()
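
# torch distributed must already be initialized at this point (see the assert below). a minimal sketch of
# the plain PyTorch initialization under torchrun (an assumption: dolomite_engine's ProcessGroupManager may
# need its own setup on top of, or instead of, this call; check its documentation):
#
#     import torch.distributed
#
#     if not torch.distributed.is_initialized():
#         torch.distributed.init_process_group(backend="nccl")
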
assert (
    ProcessGroupManager.is_initialized()
), "torch distributed needs to be initialized before the following logic, otherwise everything is incorrect"
# the following logic is for FSDP/ZeRO without any model parallelism.
# since the function is model-parallelism aware, we need to emulate no model parallelism.
# if you are using model parallelism, either train directly on dolomite or try modifying this logic.
with (
    ProcessGroupManager.set_dummy_data_parallel_world_size(ProcessGroupManager.get_world_size()),
    ProcessGroupManager.set_dummy_data_parallel_rank(ProcessGroupManager.get_global_rank()),
    ProcessGroupManager.set_dummy_tensor_parallel_world_size(1),
    ProcessGroupManager.set_dummy_tensor_parallel_rank(0),
    ProcessGroupManager.set_dummy_tensor_parallel_first_rank(ProcessGroupManager.get_global_rank()),
):
    # running this function outside of the above context might throw an error or give incorrect behaviour.
    # consumed_samples can be changed: 0 means the dataloader starts from the beginning, and a different
    # number can be used to resume training from a later point.
    # for FSDP, consumed_samples = global_step * micro_batch_size * gradient_accumulation_steps * num_GPUs
    train_dataloader, val_dataloaders, test_dataloaders = get_megatron_gpt_dataloaders(
        DummyArgs(), tokenizer, consumed_samples=0
    )
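
# worked resume example (numbers are hypothetical): with the formula above, resuming from global step 500
# with micro_batch_size=8, gradient_accumulation_steps=4 and 8 GPUs means passing
# consumed_samples = 500 * 8 * 4 * 8 = 128000 instead of 0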
# filter out any dataloaders in val and test that are None
# this happens when these dataloaders have no elements
val_dataloaders = list(filter(lambda x: x is not None, val_dataloaders))
test_dataloaders = list(filter(lambda x: x is not None, test_dataloaders))

# use the dataloaders :)
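
# minimal consumption sketch (an assumption, not part of the original script: the exact batch structure is
# whatever dolomite_engine's collate function produces, so the loop below only assumes the train dataloader
# is iterable)
for step, batch in enumerate(train_dataloader):
    if step >= num_training_steps:
        break
    ...  # forward/backward on `batch` here, or hand the dataloaders to your training framework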