change path for scattermoe #21

Merged: 5 commits, Sep 24, 2024

@@ -42,7 +42,6 @@ model_args:
   attention_softmax_in_fp32: true
   add_bias: true
   position_embedding_type: learned_absolute
-  rope_theta: 10000
   attention_implementation: flash_attention_2
   use_padding_free_transformer: true

@@ -47,7 +47,6 @@ model_args:
   attention_softmax_in_fp32: true
   add_bias: true
   position_embedding_type: learned_absolute
-  rope_theta: 10000
   attention_implementation: flash_attention_2
   use_padding_free_transformer: true

@@ -60,7 +60,6 @@ model_args:
   attention_softmax_in_fp32: true
   add_bias: true
   position_embedding_type: learned_absolute
-  rope_theta: 10000
   attention_implementation: flash_attention_2
   use_padding_free_transformer: true

@@ -99,7 +99,6 @@ model_args:
   attention_softmax_in_fp32: true
   add_bias: true
   position_embedding_type: learned_absolute
-  rope_theta: 10000
   attention_implementation: flash_attention_2
   use_padding_free_transformer: true

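Aside (not part of the diff): rope_theta only has an effect when position_embedding_type is rope, where it sets the base of the rotary frequency ladder, so removing it from these learned_absolute configs is a no-op. A minimal sketch of the standard relationship, for reference:

```python
# Standard RoPE inverse frequencies derived from the theta base
# (illustrative only; not code from this repository).
import torch

def rope_inv_freq(head_dim: int, theta: float = 10000.0) -> torch.Tensor:
    # One frequency per channel pair: theta ** (-2i / head_dim)
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (theta ** exponents)

print(rope_inv_freq(64).shape)  # torch.Size([32])
```
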
configs/pretraining-examples/moe/moe.yml (new file, 92 additions, 0 deletions)
@@ -0,0 +1,92 @@
datasets:
  # class_name, data_name & data_sampling_ratio are not used but need to be passed to avoid errors
  - class_name: MegatronDataset
    data_name: Megatron
    data_sampling_ratio: 1
    class_args:
      eval_steps: 2
      data_cache_path: /proj/checkpoints/mayank/cache
      # Option 1: data loading using --data-path with single file
      data_path:
        - data/lang=Matlab
      split: 100,0,0
      sequence_length: 4096

tokenizer_args:
  tokenizer_name: bigcode/starcoder

model_args:
  model_class: AutoModelForCausalLM
  pretrained_config:
    activation_function: swiglu
    add_bias: false
    attn_pdrop: 0
    embd_pdrop: 0
    resid_pdrop: 0
    initializer_range: 0.1
    layer_norm_epsilon: 1e-05
    model_type: moe_dolomite
    n_embd: 1024
    n_head: 16
    n_inner: 512
    n_layer: 24
    n_positions: 4096
    num_experts: 32
    num_experts_per_tok: 8
    num_key_value_heads: 8
    normalization_function: rmsnorm
    position_embedding_type: rope
    rope_theta: 10000
    attention_head_type: gqa
    scale_attn_weights: true
    vocab_size: 49152
    tie_word_embeddings: true
    upcast_logits_for_loss: true
    bos_token_id: 0
    eos_token_id: 0
    pad_token_id: 0
    router_aux_loss_coef: 0.01
  moe_implementation: scattermoe
  attention_implementation: sdpa

tuning_args:
  tuning_method: pretraining

save_args:
  save_path: /proj/checkpoints/mayank/test/sdpa-stage-0-1b-moe-compile
  save_interval: 5000

logging_args:
  log_interval: 10

training_parameters:
  num_training_steps: 25000
  eval_interval: 10000000
  micro_batch_size: 2
  gradient_accumulation_steps: 8
  eval_during_training: false

optimizer_args:
  class_name: TorchAdamW
  class_args:
    lr: 3e-4
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.95
    eps: 1e-10

lr_scheduler_args:
  lr_decay_style: cosine
  num_warmup_steps: 2500
  num_constant_steps: 0
  num_decay_steps: 22500

mixed_precision_args:
  dtype: bf16

distributed_args:
  distributed_backend: torch
  communication_dtype: fp32
  torch_compile: true
  stage: 0
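
Two of the numbers in this config are easy to sanity-check by hand: the token budget per optimizer step and the learning-rate schedule lengths. A small standalone sketch (plain arithmetic, illustrative only, not part of the PR or the dolomite-engine API):

```python
# Back-of-the-envelope checks on the values in moe.yml (illustrative only).
sequence_length = 4096
micro_batch_size = 2
gradient_accumulation_steps = 8
num_training_steps = 25_000
num_warmup_steps = 2_500
num_decay_steps = 22_500

# Tokens consumed per optimizer step on one data-parallel rank;
# multiply by the data-parallel world size for the global figure.
tokens_per_step = micro_batch_size * gradient_accumulation_steps * sequence_length
print(tokens_per_step)  # 65536

# Warmup plus cosine decay should cover the whole run.
assert num_warmup_steps + num_decay_steps == num_training_steps
```
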
@@ -11,7 +11,7 @@


 if is_kernel_hyperdrive_available():
-    from khd.scattermoe.triton_implementation import padded_block_indices, scattered_experts
+    from khd.kernels.scattermoe.triton_implementation import padded_block_indices, scattered_experts


 class ParameterizedScatteredExperts(ParameterizedExperts):
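
The import above is the path change the PR title refers to: the scattermoe Triton kernels in kernel-hyperdrive now live under a kernels subpackage. A hedged sketch of how a consumer could tolerate both layouts during the transition (illustrative only; the file itself simply switches to the new path):

```python
# Illustrative only: accept both the old and new kernel-hyperdrive layouts.
# The PR itself just updates the import to the new path.
try:
    from khd.kernels.scattermoe.triton_implementation import (
        padded_block_indices,
        scattered_experts,
    )
except ImportError:  # older kernel-hyperdrive releases
    from khd.scattermoe.triton_implementation import (
        padded_block_indices,
        scattered_experts,
    )
```
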
dolomite_engine/train_utils.py (1 addition, 1 deletion)
@@ -134,7 +134,7 @@ def train_step(
 def all_reduce_metrics_tracker(metrics_tracker: MetricsTrackingDict) -> MetricsTrackingDict:
     tensor = [metrics_tracker[key] for key in metrics_tracker]
     tensor = torch.stack(tensor)
-    torch.distributed.all_reduce(tensor, op=ReduceOp.AVG, group=ProcessGroupManager.get_data_parallel_group())
+    torch.distributed.all_reduce(tensor.cpu(), group=ProcessGroupManager.get_data_parallel_group())
     tensor = tensor.tolist()

     for i, key in enumerate(metrics_tracker):
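
For context, this helper gathers the per-rank metric values into one tensor and all-reduces it across the data-parallel group (the removed line averaged them with ReduceOp.AVG). A minimal standalone sketch of mean-reduction over ranks, with placeholder names rather than the dolomite_engine API:

```python
# Minimal sketch of averaging scalar metrics across data-parallel ranks
# (illustrative only; assumes the process group is already initialized,
# e.g. via torch.distributed.init_process_group).
import torch
import torch.distributed as dist

def all_reduce_mean(metrics: dict[str, float]) -> dict[str, float]:
    keys = list(metrics)
    tensor = torch.tensor([metrics[k] for k in keys], dtype=torch.float32)
    dist.all_reduce(tensor)          # default op is SUM
    tensor /= dist.get_world_size()  # turn the sum into a mean
    return dict(zip(keys, tensor.tolist()))
```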