diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py
index f923949fc..19cd983bd 100644
--- a/torchrec/distributed/benchmark/benchmark_train.py
+++ b/torchrec/distributed/benchmark/benchmark_train.py
@@ -71,6 +71,7 @@ def benchmark_ebc(
     args: argparse.Namespace,
     output_dir: str,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[BenchmarkResult]:
     table_configs = get_tables(tables, data_type=DataType.FP32)
     sharder = TestEBCSharder(
@@ -104,6 +105,9 @@ def benchmark_ebc(
     if pooling_configs:
         args_kwargs["pooling_configs"] = pooling_configs

+    if variable_batch_embeddings:
+        args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings
+
     return benchmark_module(
         module=module,
         sharder=sharder,
@@ -153,6 +157,7 @@ def main() -> None:
         mb = int(float(num * dim) / 1024 / 1024) * 4
         tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] {mb:6}Mb"

+    ### Benchmark no VBE
     report: str = (
         f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
     )
@@ -176,6 +181,27 @@ def main() -> None:
         )
     )

+    ### Benchmark with VBE
+    report: str = (
+        f"REPORT BENCHMARK (VBE) {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    )
+    report += f"Module: {module_name} (VBE)\n"
+    report += tables_info
+    report += "\n"
+    report_file = f"{output_dir}/run_vbe.report"
+
+    benchmark_results_per_module.append(
+        benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs, True)
+    )
+    write_report_funcs_per_module.append(
+        partial(
+            write_report,
+            report_file=report_file,
+            report_str=report,
+            num_requests=num_requests,
+        )
+    )
+
     for i, write_report_func in enumerate(write_report_funcs_per_module):
         write_report_func(benchmark_results_per_module[i])

diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 2d221ba82..996b34cb7 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -269,19 +269,32 @@ def get_inputs(
     num_inputs: int,
     train: bool,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[List[KeyedJaggedTensor]]:
     inputs_batch: List[List[KeyedJaggedTensor]] = []

+    if variable_batch_embeddings and not train:
+        raise RuntimeError("Variable batch size is only supported in training mode")
+
     for _ in range(num_inputs):
-        _, model_input_by_rank = ModelInput.generate(
-            batch_size=batch_size,
-            world_size=world_size,
-            num_float_features=0,
-            tables=tables,
-            weighted_tables=[],
-            long_indices=False,
-            tables_pooling=pooling_configs,
-        )
+        if variable_batch_embeddings:
+            _, model_input_by_rank = ModelInput.generate_variable_batch_input(
+                average_batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=0,
+                # pyre-ignore
+                tables=tables,
+            )
+        else:
+            _, model_input_by_rank = ModelInput.generate(
+                batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=0,
+                tables=tables,
+                weighted_tables=[],
+                long_indices=False,
+                tables_pooling=pooling_configs,
+            )

         if train:
             sparse_features_by_rank = [
@@ -770,6 +783,7 @@ def benchmark_module(
     func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -820,7 +834,13 @@ def benchmark_module(
     num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters

     inputs = get_inputs(
-        tables, batch_size, world_size, num_inputs_to_gen, train, pooling_configs
+        tables,
+        batch_size,
+        world_size,
+        num_inputs_to_gen,
+        train,
+        pooling_configs,
+        variable_batch_embeddings,
     )

     warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs]
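
Usage sketch (illustrative, not part of the diff): how the new variable_batch_embeddings flag can be exercised through get_inputs. Import paths, table sizes, and argument values are assumptions chosen for the example; the call shape mirrors the updated benchmark_module call above.

from torchrec.distributed.benchmark.benchmark_utils import get_inputs, get_tables
from torchrec.modules.embedding_configs import DataType

# Build embedding-table configs the same way benchmark_ebc does above
# (table sizes here are arbitrary (num_embeddings, embedding_dim) pairs).
table_configs = get_tables([(100_000, 128), (100_000, 128)], data_type=DataType.FP32)

# Variable-batch input generation is only supported in training mode;
# passing train=False with the flag set raises RuntimeError per the diff.
vbe_inputs = get_inputs(
    table_configs,  # tables
    512,            # batch_size (used as the average per-rank batch size under VBE)
    2,              # world_size
    10,             # num_inputs
    True,           # train
    None,           # pooling_configs (unused by the VBE branch)
    True,           # variable_batch_embeddings
)
# vbe_inputs is a List[List[KeyedJaggedTensor]], indexed by rank first,
# matching the warmup_inputs slicing in benchmark_module above.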