Commit

cleanup
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Jan 8, 2025
1 parent 58bd173 commit 829c79b
Showing 3 changed files with 107 additions and 119 deletions.
94 changes: 44 additions & 50 deletions dolomite_engine/finetune.py
@@ -63,67 +63,61 @@ def train(
     backward_context = loss_parallel if args.distributed_args.tensor_parallel_word_embeddings else nullcontext
 
     torch_profiler = get_torch_profiler(args.logging_args.torch_profiler_trace_path)
 
-    if torch_profiler is not None:
-        torch_profiler.__enter__()
-
     metrics_tracker = MetricsTrackingDict({})
 
     global_step = starting_iteration
-    while global_step < num_training_steps:
-        global_step += 1
-
-        loss_step_dict = train_step(
-            model_container=model_container,
-            pipeline_schedule=pipeline_schedule,
-            optimizer_container=optimizer_container,
-            lr_scheduler_container=lr_scheduler_container,
-            train_dataloader=train_dataloader_infinite,
-            gradient_accumulation_steps=gradient_accumulation_steps,
-            gradient_clipping=gradient_clipping,
-            forward_context=forward_context,
-            backward_context=backward_context,
-            sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
-            is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1,
-            sequence_length=None,
-        )
-
-        metrics_tracker = metrics_tracker + loss_step_dict
-
-        if torch_profiler is not None:
-            torch_profiler.step()
-
-        if global_step % log_interval == 0:
-            metrics_tracker = metrics_tracker / log_interval
-            metrics_tracker["learning_rate"] = lr_scheduler_container[0].get_lr()[0]
-
-            track_metrics(
-                global_step=global_step,
-                experiments_tracker=experiments_tracker,
-                metrics_tracker=metrics_tracker,
-                context="train",
-            )
-
-            metrics_tracker = MetricsTrackingDict({})
-
-        if eval_during_training and (global_step % eval_interval == 0 or global_step == num_training_steps):
-            evaluate(val_dataloader, model_container, global_step, experiments_tracker)
-
-        if global_step % save_interval == 0 or global_step == num_training_steps:
-            save_checkpoint(
-                args=args,
-                model_container=model_container,
-                optimizer_container=optimizer_container,
-                lr_scheduler_container=lr_scheduler_container,
-                train_dataloader=train_dataloader,
-                experiments_tracker=experiments_tracker,
-                iteration=global_step,
-            )
-
-    ensure_last_checkpoint_is_saved()
-
-    if torch_profiler is not None:
-        torch_profiler.__exit__()
+
+    with torch_profiler:
+        while global_step < num_training_steps:
+            global_step += 1
+
+            loss_step_dict = train_step(
+                model_container=model_container,
+                pipeline_schedule=pipeline_schedule,
+                optimizer_container=optimizer_container,
+                lr_scheduler_container=lr_scheduler_container,
+                train_dataloader=train_dataloader_infinite,
+                gradient_accumulation_steps=gradient_accumulation_steps,
+                gradient_clipping=gradient_clipping,
+                forward_context=forward_context,
+                backward_context=backward_context,
+                sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
+                is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1,
+                sequence_length=None,
+            )
+
+            metrics_tracker = metrics_tracker + loss_step_dict
+
+            if torch_profiler is not None:
+                torch_profiler.step()
+
+            if global_step % log_interval == 0:
+                metrics_tracker = metrics_tracker / log_interval
+                metrics_tracker["learning_rate"] = lr_scheduler_container[0].get_lr()[0]
+
+                track_metrics(
+                    global_step=global_step,
+                    experiments_tracker=experiments_tracker,
+                    metrics_tracker=metrics_tracker,
+                    context="train",
+                )
+
+                metrics_tracker = MetricsTrackingDict({})
+
+            if eval_during_training and (global_step % eval_interval == 0 or global_step == num_training_steps):
+                evaluate(val_dataloader, model_container, global_step, experiments_tracker)
+
+            if global_step % save_interval == 0 or global_step == num_training_steps:
+                save_checkpoint(
+                    args=args,
+                    model_container=model_container,
+                    optimizer_container=optimizer_container,
+                    lr_scheduler_container=lr_scheduler_container,
+                    train_dataloader=train_dataloader,
+                    experiments_tracker=experiments_tracker,
+                    iteration=global_step,
+                )
+
+    ensure_last_checkpoint_is_saved()
 
 
 @torch.no_grad()
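Note on the finetune.py change above: the commit replaces the hand-written torch_profiler.__enter__() / torch_profiler.__exit__() calls with a single `with torch_profiler:` block wrapping the training loop. The sketch below is a minimal illustration of why the two shapes are equivalent on the happy path and why the `with` form is safer when the loop raises; profiler_region and enabled are illustrative stand-ins, not code from this repository.

# Minimal sketch of the refactor pattern; names are illustrative only.
from contextlib import contextmanager

@contextmanager
def profiler_region(enabled: bool):
    if enabled:
        print("profiler started")
    try:
        yield
    finally:
        # runs even if the loop body raises, unlike a manual __exit__ call
        if enabled:
            print("profiler stopped, trace written")

# old shape: explicit enter/exit calls bracketing the loop
ctx = profiler_region(enabled=True)
ctx.__enter__()
for step in range(2):
    pass  # train_step(...) would run here
ctx.__exit__(None, None, None)

# new shape: the loop is simply wrapped in a `with` block
with profiler_region(enabled=True):
    for step in range(2):
        pass  # train_step(...) would run here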
129 changes: 61 additions & 68 deletions dolomite_engine/pretrain.py
@@ -112,17 +112,18 @@ def train(
     eval_steps = args.datasets[0].class_args.get("eval_steps")
     evaluate(val_dataloaders, model_container, starting_iteration, experiments_tracker, eval_steps, group_names)
 
-    dp_world_size = ProcessGroupManager.get_data_parallel_world_size()
+    ProcessGroupManager.get_data_parallel_world_size()
 
     micro_batch_size = args.training_parameters.micro_batch_size
     sequence_length = args.datasets[0].class_args.get("sequence_length")
-    tokens_per_batch = StepTracker.get_global_batch_size() * sequence_length
+    global_batch_size = StepTracker.get_global_batch_size()
+    tokens_per_batch = global_batch_size * sequence_length
 
     # model flops per GPU
     model_flops = (
         get_model_tflops(
             config=model_container[0].config,
-            batch_size=StepTracker.get_global_batch_size(),
+            batch_size=global_batch_size,
             sequence_length=sequence_length,
             gradient_checkpointing_method=args.distributed_args.gradient_checkpointing_method,
             gradient_checkpointing_args=args.distributed_args.gradient_checkpointing_args,
@@ -134,93 +135,85 @@ def train(
     backward_context = loss_parallel if args.distributed_args.tensor_parallel_word_embeddings else nullcontext
 
     torch_profiler = get_torch_profiler(args.logging_args.torch_profiler_trace_path)
 
-    if torch_profiler is not None:
-        torch_profiler.__enter__()
-
     start_time = time.perf_counter()
     steps_since_start_time = 0
     metrics_tracker = MetricsTrackingDict({})
 
     global_step = starting_iteration
-    while global_step < num_training_steps:
-        global_step += 1
-        steps_since_start_time += 1
-
-        loss_step_dict = train_step(
-            model_container=model_container,
-            pipeline_schedule=pipeline_schedule,
-            optimizer_container=optimizer_container,
-            lr_scheduler_container=lr_scheduler_container,
-            train_dataloader=train_dataloader,
-            gradient_accumulation_steps=gradient_accumulation_steps,
-            gradient_clipping=gradient_clipping,
-            forward_context=forward_context,
-            backward_context=backward_context,
-            sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
-            is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1,
-            micro_batch_size=micro_batch_size,
-            sequence_length=sequence_length,
-        )
-
-        metrics_tracker = metrics_tracker + loss_step_dict
-
-        if torch_profiler is not None:
-            torch_profiler.step()
-
-        if global_step % log_interval == 0:
-            metrics_tracker = metrics_tracker / log_interval
-
-            time_elapsed = time.perf_counter() - start_time
-            step_time = time_elapsed / steps_since_start_time
-
-            metrics_tracker["learning_rate"] = lr_scheduler_container[0].get_lr()[0]
-
-            if model_flops is not None:
-                metrics_tracker["FLOPs"] = model_flops * steps_since_start_time / time_elapsed
-
-            metrics_tracker["billion_tokens_per_day"] = tokens_per_batch * 86400 / step_time / 1e9
-            metrics_tracker["step_time (sec)"] = step_time
-
-            track_metrics(
-                global_step=global_step,
-                experiments_tracker=experiments_tracker,
-                metrics_tracker=metrics_tracker,
-                context="train",
-            )
-
-            start_time = time.perf_counter()
-            steps_since_start_time = 0
-            metrics_tracker = MetricsTrackingDict({})
-
-        if eval_during_training and (global_step % eval_interval == 0 or global_step == num_training_steps):
-            evaluate(val_dataloaders, model_container, global_step, experiments_tracker, eval_steps, group_names)
-
-        if global_step % save_interval == 0 or global_step == num_training_steps:
-            save_checkpoint(
-                args=args,
-                model_container=model_container,
-                optimizer_container=optimizer_container,
-                lr_scheduler_container=lr_scheduler_container,
-                train_dataloader=None,
-                experiments_tracker=experiments_tracker,
-                iteration=global_step,
-                metadata={
-                    "consumed_samples": global_step * micro_batch_size * gradient_accumulation_steps * dp_world_size
-                },
-            )
-
-            start_time = time.perf_counter()
-            steps_since_start_time = 0
+
+    with torch_profiler:
+        while global_step < num_training_steps:
+            global_step += 1
+            steps_since_start_time += 1
+
+            loss_step_dict = train_step(
+                model_container=model_container,
+                pipeline_schedule=pipeline_schedule,
+                optimizer_container=optimizer_container,
+                lr_scheduler_container=lr_scheduler_container,
+                train_dataloader=train_dataloader,
+                gradient_accumulation_steps=gradient_accumulation_steps,
+                gradient_clipping=gradient_clipping,
+                forward_context=forward_context,
+                backward_context=backward_context,
+                sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
+                is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1,
+                micro_batch_size=micro_batch_size,
+                sequence_length=sequence_length,
+            )
+
+            metrics_tracker = metrics_tracker + loss_step_dict
+
+            if torch_profiler is not None:
+                torch_profiler.step()
+
+            if global_step % log_interval == 0:
+                metrics_tracker = metrics_tracker / log_interval
+
+                time_elapsed = time.perf_counter() - start_time
+                step_time = time_elapsed / steps_since_start_time
+
+                metrics_tracker["learning_rate"] = lr_scheduler_container[0].get_lr()[0]
+
+                if model_flops is not None:
+                    metrics_tracker["FLOPs"] = model_flops * steps_since_start_time / time_elapsed
+
+                metrics_tracker["billion_tokens_per_day"] = tokens_per_batch * 86400 / step_time / 1e9
+                metrics_tracker["step_time (sec)"] = step_time
+
+                track_metrics(
+                    global_step=global_step,
+                    experiments_tracker=experiments_tracker,
+                    metrics_tracker=metrics_tracker,
+                    context="train",
+                )
+
+                start_time = time.perf_counter()
+                steps_since_start_time = 0
+                metrics_tracker = MetricsTrackingDict({})
+
+            if eval_during_training and (global_step % eval_interval == 0 or global_step == num_training_steps):
+                evaluate(val_dataloaders, model_container, global_step, experiments_tracker, eval_steps, group_names)
+
+            if global_step % save_interval == 0 or global_step == num_training_steps:
+                save_checkpoint(
+                    args=args,
+                    model_container=model_container,
+                    optimizer_container=optimizer_container,
+                    lr_scheduler_container=lr_scheduler_container,
+                    train_dataloader=None,
+                    experiments_tracker=experiments_tracker,
+                    iteration=global_step,
+                    metadata={"consumed_samples": global_step * global_batch_size},
+                )
+
+                start_time = time.perf_counter()
+                steps_since_start_time = 0
 
     if eval_during_training:
         evaluate(test_dataloaders, model_container, global_step, experiments_tracker, eval_steps, group_names)
 
     ensure_last_checkpoint_is_saved()
-
-    if torch_profiler is not None:
-        torch_profiler.__exit__()
 
 
 @torch.no_grad()
 def evaluate(
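Note on the pretrain.py change above: the checkpoint metadata now records consumed_samples as global_step * global_batch_size instead of spelling out global_step * micro_batch_size * gradient_accumulation_steps * dp_world_size. The two are equal only if StepTracker.get_global_batch_size() already folds micro batch size, gradient accumulation and the data-parallel world size into one number; that relationship is an assumption here, not something visible in the diff. The toy check below uses made-up values to illustrate the equivalence.

# Toy equivalence check with made-up values; the definition of
# global_batch_size below is an assumption about StepTracker.
micro_batch_size = 4
gradient_accumulation_steps = 8
dp_world_size = 16

# assumed definition of the global batch size
global_batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size  # 512

global_step = 1000
old_consumed_samples = global_step * micro_batch_size * gradient_accumulation_steps * dp_world_size
new_consumed_samples = global_step * global_batch_size
assert old_consumed_samples == new_consumed_samples == 512_000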
3 changes: 2 additions & 1 deletion dolomite_engine/train_utils.py
@@ -341,7 +341,8 @@ def track_metrics(


 def get_torch_profiler(torch_profiler_trace_path: str) -> torch.profiler.profile:
-    torch_profiler = None
+    torch_profiler = nullcontext()
+
     if torch_profiler_trace_path is not None:
         torch_profiler = torch.profiler.profile(
             activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
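Note on the train_utils.py change above: get_torch_profiler now falls back to nullcontext() instead of None, so callers can always write `with torch_profiler:` whether or not a trace path is configured. The sketch below approximates the shape of the factory after this commit; only the lines visible in the hunk are certain, and the trace-export wiring plus the final return are assumptions added for illustration.

# Approximate shape of the factory; only the hunk's visible lines are certain.
from contextlib import nullcontext

import torch

def get_torch_profiler(torch_profiler_trace_path: str) -> torch.profiler.profile:
    torch_profiler = nullcontext()

    if torch_profiler_trace_path is not None:
        torch_profiler = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            # trace export to torch_profiler_trace_path is configured in the
            # rest of the function, which is truncated in the diff above
        )

    return torch_profiler  # assumed: the function hands back the context manager

# usage: valid whether profiling is enabled or not
with get_torch_profiler(None):
    pass  # the training loop runs here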
