forked from meta-llama/llama-cookbook
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprofile_utils.py
65 lines (52 loc) · 1.94 KB
/
profile_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import contextlib
import os
import torch
@contextlib.contextmanager
def maybe_run_profiler(config, *pos_args, **kwargs):
    """Yield an active ``torch.profiler.profile`` when profiling is enabled.

    When ``config.enable_profiler`` is true, runs the wrapped region under a
    torch profiler with a wait/warmup/active schedule and exports Chrome
    traces to ``profile_traces/<config.model_name>/iteration_<N>/`` (one
    trace file per rank). Otherwise yields ``None``.

    Args:
        config: object exposing ``enable_profiler`` (bool) and
            ``model_name`` (str) attributes.
        *pos_args, **kwargs: accepted for call-site compatibility; unused.

    Yields:
        The live ``torch.profiler.profile`` object, or ``None`` when
        profiling is disabled.
    """
    # get user defined profiler settings
    print(f"inside profiler utils...{config=}")

    # Determine rank once, for BOTH branches. The original only defined
    # `rank` on the enabled path, so the disabled path crashed with
    # NameError at `if rank == 0`. Also guard the lookup itself:
    # torch.distributed.get_rank() raises when no process group is
    # initialized (e.g. single-process runs), so fall back to rank 0.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0

    if config.enable_profiler:
        # Traces land under profile_traces/<model_name>/iteration_<N>/.
        dump_dir = "profile_traces"
        save_trace_dir = config.model_name
        trace_dir = os.path.join(dump_dir, save_trace_dir)
        iter_frequency = 5  # length of one profiler schedule cycle, in steps
        _global_iter_count = 0

        def trace_handler(prof):
            # Invoked by the profiler at the end of each schedule cycle;
            # exports one Chrome trace per rank into a per-iteration dir.
            nonlocal _global_iter_count
            _global_iter_count += iter_frequency
            curr_trace_dir_name = "iteration_" + str(_global_iter_count)
            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
            # exist_ok avoids the check-then-create race of the original
            # `if not os.path.exists: os.makedirs` across ranks.
            os.makedirs(curr_trace_dir, exist_ok=True)
            if rank == 0:
                print(f"exporting profile traces to {curr_trace_dir}")
            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")

        if rank == 0:
            print(f"Profiling active. Traces will be saved at {trace_dir}")
        os.makedirs(trace_dir, exist_ok=True)

        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            # Each cycle: (iter_frequency - 2) idle steps, 1 warmup step,
            # 1 recorded step; repeated twice.
            schedule=torch.profiler.schedule(
                wait=iter_frequency - 2,
                warmup=1,
                active=1,
                repeat=2,
            ),
            on_trace_ready=trace_handler,
            profile_memory=True,
            with_stack=False,
            record_shapes=True,
        ) as torch_profiler:
            yield torch_profiler
    else:
        if rank == 0:
            print("Profiling disabled.")
        # NOTE: the original also assigned an unused
        # `torch_profiler = contextlib.nullcontext()` here (dead code,
        # removed). Callers receive None when profiling is off.
        yield None