5x speedup of benchmarks with input generation (pytorch#1813)
Summary:

Massive speedup in benchmarking time from generating inputs once instead of regenerating them for every variant.

Previously, benchmark_inference took ~10 minutes to run; it now takes ~2.5 minutes.

Differential Revision: D55134385
PaulZhang12 authored and facebook-github-bot committed Mar 20, 2024
1 parent 71cc224 commit dc882f5
Showing 1 changed file with 67 additions and 30 deletions.

torchrec/distributed/benchmark/benchmark_utils.py
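
At a high level, the change hoists input generation out of the per-variant loop: benchmark_module now generates all inputs once and passes them down to init_module_and_run_benchmark, which previously regenerated them for every (sharding_type, compile_mode) variant. A minimal sketch of the pattern, where make_inputs and run_variant are hypothetical stand-ins for get_inputs and init_module_and_run_benchmark:

from typing import List

def make_inputs(num_inputs: int) -> List[int]:
    # Hypothetical stand-in for get_inputs(); the real code builds
    # KeyedJaggedTensors, which are expensive to generate.
    return list(range(num_inputs))

def run_variant(variant: str, inputs: List[int]) -> None:
    # Hypothetical stand-in for init_module_and_run_benchmark().
    pass

# Before: every variant regenerated its own inputs.
def benchmark_all_before(variants: List[str], num_inputs: int) -> None:
    for variant in variants:
        run_variant(variant, make_inputs(num_inputs))  # repeated work

# After: inputs are generated once and shared across variants.
def benchmark_all_after(variants: List[str], num_inputs: int) -> None:
    inputs = make_inputs(num_inputs)  # generated once
    for variant in variants:
        run_variant(variant, inputs)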
@@ -266,27 +266,38 @@ def get_inputs(
     batch_size: int,
     world_size: int,
     num_inputs: int,
-    rank: int = -1,
+    train: bool,
     pooling_configs: Optional[List[int]] = None,
-) -> List[KeyedJaggedTensor]:
-    inputs: List[KeyedJaggedTensor] = []
+) -> List[List[KeyedJaggedTensor]]:
+    inputs_batch: List[List[KeyedJaggedTensor]] = []
 
     for _ in range(num_inputs):
-        model_input = ModelInput.generate(
+        _, model_input_by_rank = ModelInput.generate(
             batch_size=batch_size,
             world_size=world_size,
             num_float_features=0,
             tables=tables,
             weighted_tables=[],
             long_indices=False,
             tables_pooling=pooling_configs,
-        )[1][0]
+        )
 
-        # If ProcessGroup, place input on correct device. Otherwise, place on cuda:0
-        device = torch.device(f"cuda:{rank}") if rank >= 0 else torch.device("cuda:0")
-        inputs.append(model_input.idlist_features.to(device))
+        if train:
+            sparse_features_by_rank = [
+                model_input.idlist_features for model_input in model_input_by_rank
+            ]
+            inputs_batch.append(sparse_features_by_rank)
+        else:
+            sparse_features = model_input_by_rank[0].idlist_features
+            inputs_batch.append([sparse_features])
 
-    return inputs
+    # Transpose if train, as inputs_by_rank is currently in [B X R] format
+    inputs_by_rank = [
+        [sparse_features for sparse_features in sparse_features_rank]
+        for sparse_features_rank in zip(*inputs_batch)
+    ]
+
+    return inputs_by_rank
 
 
 def write_report(
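
The zip(*inputs_batch) transpose in get_inputs turns the per-batch list of per-rank features ([B x R]) into a per-rank list of per-batch features ([R x B]). A standalone illustration, with strings standing in for KeyedJaggedTensors:

# B=3 generated inputs, R=2 ranks; strings stand in for KeyedJaggedTensors.
inputs_batch = [
    ["b0_r0", "b0_r1"],
    ["b1_r0", "b1_r1"],
    ["b2_r0", "b2_r1"],
]

# zip(*...) transposes [B x R] into [R x B]: one list per rank.
inputs_by_rank = [list(per_rank) for per_rank in zip(*inputs_batch)]
assert inputs_by_rank == [
    ["b0_r0", "b1_r0", "b2_r0"],
    ["b0_r1", "b1_r1", "b2_r1"],
]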
@@ -589,9 +600,9 @@ def init_module_and_run_benchmark(
     compile_mode: CompileMode,
     world_size: int,
     batch_size: int,
-    warmup_iters: int,
-    bench_iters: int,
-    prof_iters: int,
+    warmup_inputs: List[List[KeyedJaggedTensor]],
+    bench_inputs: List[List[KeyedJaggedTensor]],
+    prof_inputs: List[List[KeyedJaggedTensor]],
     tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]],
     output_dir: str,
     num_benchmarks: int,
@@ -615,14 +626,28 @@ def init_module_and_run_benchmark(
         existing in the loop
     """
 
-    num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
-    inputs = get_inputs(
-        tables, batch_size, world_size, num_inputs_to_gen, rank, pooling_configs
-    )
-
-    warmup_inputs = inputs[:warmup_iters]
-    bench_inputs = inputs[warmup_iters : (warmup_iters + bench_iters)]
-    prof_inputs = inputs[-prof_iters:]
+    if rank >= 0:
+        warmup_inputs_cuda = [
+            warmup_input[0].to(torch.device(f"cuda:{rank}"))
+            for warmup_input in warmup_inputs
+        ]
+        bench_inputs_cuda = [
+            bench_input[0].to(torch.device(f"cuda:{rank}"))
+            for bench_input in bench_inputs
+        ]
+        prof_inputs_cuda = [
+            prof_input[0].to(torch.device(f"cuda:{rank}")) for prof_input in prof_inputs
+        ]
+    else:
+        warmup_inputs_cuda = [
+            warmup_input[0].to(torch.device("cuda:0")) for warmup_input in warmup_inputs
+        ]
+        bench_inputs_cuda = [
+            bench_input[0].to(torch.device("cuda:0")) for bench_input in bench_inputs
+        ]
+        prof_inputs_cuda = [
+            prof_input[0].to(torch.device("cuda:0")) for prof_input in prof_inputs
+        ]
 
     with (
         MultiProcessContext(rank, world_size, "nccl", None)
@@ -632,7 +657,7 @@ def init_module_and_run_benchmark(
         module = transform_module(
             module=module,
             device=device,
-            inputs=warmup_inputs,
+            inputs=warmup_inputs_cuda,
             sharder=sharder,
             sharding_type=sharding_type,
             compile_mode=compile_mode,
@@ -647,9 +672,9 @@ def init_module_and_run_benchmark(
         res = benchmark(
             name,
             module,
-            warmup_inputs,
-            bench_inputs,
-            prof_inputs,
+            warmup_inputs_cuda,
+            bench_inputs_cuda,
+            prof_inputs_cuda,
             world_size=world_size,
             output_dir=output_dir,
             num_benchmarks=num_benchmarks,
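
A side note on the device-placement branches added above: the two arms differ only in the device string, so they could be collapsed into one helper. A sketch of an equivalent formulation (place_on_device is hypothetical, not part of the diff):

from typing import List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

def place_on_device(
    inputs: List[List[KeyedJaggedTensor]], rank: int
) -> List[KeyedJaggedTensor]:
    # rank >= 0 means a ProcessGroup rank; otherwise default to cuda:0,
    # mirroring the if/else in the diff.
    device = torch.device(f"cuda:{rank}" if rank >= 0 else "cuda:0")
    # inputs[i][0] mirrors the warmup_input[0] indexing used in the diff.
    return [inp[0].to(device) for inp in inputs]

# Usage: warmup_inputs_cuda = place_on_device(warmup_inputs, rank), and
# likewise for bench_inputs and prof_inputs.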
@@ -788,6 +813,18 @@ def benchmark_module(
     else:
         wrapped_module = ECWrapper(module)
 
+    num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
+    inputs = get_inputs(
+        tables, batch_size, world_size, num_inputs_to_gen, train, pooling_configs
+    )
+
+    warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs]
+    bench_inputs = [
+        rank_inputs[warmup_iters : (warmup_iters + bench_iters)]
+        for rank_inputs in inputs
+    ]
+    prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs]
+
     for sharding_type in sharding_types:
         for compile_mode in compile_modes:
             # Test sharders should have a singular sharding_type
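
The slicing added above partitions each rank's input list by iteration budget. A quick shape check with toy numbers (ints standing in for KeyedJaggedTensors):

warmup_iters, bench_iters, prof_iters = 2, 3, 1
num_inputs_to_gen = warmup_iters + bench_iters + prof_iters  # 6

# Two ranks, six generated inputs per rank.
inputs = [[0, 1, 2, 3, 4, 5], [10, 11, 12, 13, 14, 15]]

warmup = [rank_inputs[:warmup_iters] for rank_inputs in inputs]
bench = [
    rank_inputs[warmup_iters : warmup_iters + bench_iters] for rank_inputs in inputs
]
prof = [rank_inputs[-prof_iters:] for rank_inputs in inputs]

assert warmup[0] == [0, 1]
assert bench[0] == [2, 3, 4]
assert prof[0] == [5]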
@@ -810,9 +847,9 @@ def benchmark_module(
                 compile_mode=compile_mode,
                 world_size=world_size,
                 batch_size=batch_size,
-                warmup_iters=warmup_iters,
-                bench_iters=bench_iters,
-                prof_iters=prof_iters,
+                warmup_inputs=warmup_inputs,
+                bench_inputs=bench_inputs,
+                prof_inputs=prof_inputs,
                 tables=tables,
                 num_benchmarks=num_benchmarks,
                 output_dir=output_dir,
@@ -830,9 +867,9 @@ def benchmark_module(
                 compile_mode=compile_mode,
                 world_size=world_size,
                 batch_size=batch_size,
-                warmup_iters=warmup_iters,
-                bench_iters=bench_iters,
-                prof_iters=prof_iters,
+                warmup_inputs=warmup_inputs,
+                bench_inputs=bench_inputs,
+                prof_inputs=prof_inputs,
                 tables=tables,
                 num_benchmarks=num_benchmarks,
                 output_dir=output_dir,