Skip to content

Commit

Permalink
VBE training benchmarks (Manual) (#1855)
Browse files Browse the repository at this point in the history
Summary:


Set TorchRec's distributed training benchmarks to include VBE (variable batch embeddings).

Differential Revision: D55882022
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Apr 8, 2024
1 parent e02c9e5 commit 08fcd6f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 10 deletions.
26 changes: 26 additions & 0 deletions torchrec/distributed/benchmark/benchmark_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def benchmark_ebc(
args: argparse.Namespace,
output_dir: str,
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
) -> List[BenchmarkResult]:
table_configs = get_tables(tables, data_type=DataType.FP32)
sharder = TestEBCSharder(
Expand Down Expand Up @@ -104,6 +105,9 @@ def benchmark_ebc(
if pooling_configs:
args_kwargs["pooling_configs"] = pooling_configs

if variable_batch_embeddings:
args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings

return benchmark_module(
module=module,
sharder=sharder,
Expand Down Expand Up @@ -153,6 +157,7 @@ def main() -> None:
mb = int(float(num * dim) / 1024 / 1024) * 4
tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] {mb:6}Mb"

### Benchmark no VBE
report: str = (
f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
)
Expand All @@ -176,6 +181,27 @@ def main() -> None:
)
)

### Benchmark with VBE
report: str = (
f"REPORT BENCHMARK (VBE) {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
)
report += f"Module: {module_name} (VBE)\n"
report += tables_info
report += "\n"
report_file = f"{output_dir}/run_vbe.report"

benchmark_results_per_module.append(
benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs, True)
)
write_report_funcs_per_module.append(
partial(
write_report,
report_file=report_file,
report_str=report,
num_requests=num_requests,
)
)

for i, write_report_func in enumerate(write_report_funcs_per_module):
write_report_func(benchmark_results_per_module[i])

Expand Down
40 changes: 30 additions & 10 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,19 +269,32 @@ def get_inputs(
num_inputs: int,
train: bool,
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
) -> List[List[KeyedJaggedTensor]]:
inputs_batch: List[List[KeyedJaggedTensor]] = []

if variable_batch_embeddings and not train:
raise RuntimeError("Variable batch size is only supported in training mode")

for _ in range(num_inputs):
_, model_input_by_rank = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=0,
tables=tables,
weighted_tables=[],
long_indices=False,
tables_pooling=pooling_configs,
)
if variable_batch_embeddings:
_, model_input_by_rank = ModelInput.generate_variable_batch_input(
average_batch_size=batch_size,
world_size=world_size,
num_float_features=0,
# pyre-ignore
tables=tables,
)
else:
_, model_input_by_rank = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=0,
tables=tables,
weighted_tables=[],
long_indices=False,
tables_pooling=pooling_configs,
)

if train:
sparse_features_by_rank = [
Expand Down Expand Up @@ -770,6 +783,7 @@ def benchmark_module(
func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
) -> List[BenchmarkResult]:
"""
Args:
Expand Down Expand Up @@ -820,7 +834,13 @@ def benchmark_module(

num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
inputs = get_inputs(
tables, batch_size, world_size, num_inputs_to_gen, train, pooling_configs
tables,
batch_size,
world_size,
num_inputs_to_gen,
train,
pooling_configs,
variable_batch_embeddings,
)

warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs]
Expand Down

0 comments on commit 08fcd6f

Please sign in to comment.