Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VBE training benchmarks (Manual) #1855

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions torchrec/distributed/benchmark/benchmark_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def benchmark_ebc(
args: argparse.Namespace,
output_dir: str,
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
) -> List[BenchmarkResult]:
table_configs = get_tables(tables, data_type=DataType.FP32)
sharder = TestEBCSharder(
Expand Down Expand Up @@ -104,6 +105,9 @@ def benchmark_ebc(
if pooling_configs:
args_kwargs["pooling_configs"] = pooling_configs

if variable_batch_embeddings:
args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings

return benchmark_module(
module=module,
sharder=sharder,
Expand Down Expand Up @@ -153,6 +157,7 @@ def main() -> None:
mb = int(float(num * dim) / 1024 / 1024) * 4
tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] {mb:6}Mb"

### Benchmark no VBE
report: str = (
f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
)
Expand All @@ -176,6 +181,27 @@ def main() -> None:
)
)

### Benchmark with VBE
report: str = (
f"REPORT BENCHMARK (VBE) {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
)
report += f"Module: {module_name} (VBE)\n"
report += tables_info
report += "\n"
report_file = f"{output_dir}/run_vbe.report"

benchmark_results_per_module.append(
benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs, True)
)
write_report_funcs_per_module.append(
partial(
write_report,
report_file=report_file,
report_str=report,
num_requests=num_requests,
)
)

for i, write_report_func in enumerate(write_report_funcs_per_module):
write_report_func(benchmark_results_per_module[i])

Expand Down
40 changes: 30 additions & 10 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,19 +269,32 @@ def get_inputs(
num_inputs: int,
train: bool,
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
) -> List[List[KeyedJaggedTensor]]:
inputs_batch: List[List[KeyedJaggedTensor]] = []

if variable_batch_embeddings and not train:
raise RuntimeError("Variable batch size is only supported in training mode")

for _ in range(num_inputs):
_, model_input_by_rank = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=0,
tables=tables,
weighted_tables=[],
long_indices=False,
tables_pooling=pooling_configs,
)
if variable_batch_embeddings:
_, model_input_by_rank = ModelInput.generate_variable_batch_input(
average_batch_size=batch_size,
world_size=world_size,
num_float_features=0,
# pyre-ignore
tables=tables,
)
else:
_, model_input_by_rank = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=0,
tables=tables,
weighted_tables=[],
long_indices=False,
tables_pooling=pooling_configs,
)

if train:
sparse_features_by_rank = [
Expand Down Expand Up @@ -770,6 +783,7 @@ def benchmark_module(
func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
pooling_configs: Optional[List[int]] = None,
variable_batch_embeddings: bool = False,
) -> List[BenchmarkResult]:
"""
Args:
Expand Down Expand Up @@ -820,7 +834,13 @@ def benchmark_module(

num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
inputs = get_inputs(
tables, batch_size, world_size, num_inputs_to_gen, train, pooling_configs
tables,
batch_size,
world_size,
num_inputs_to_gen,
train,
pooling_configs,
variable_batch_embeddings,
)

warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs]
Expand Down
10 changes: 5 additions & 5 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def generate_variable_batch_input(
strides_per_rank_per_feature = {}
inverse_indices_per_rank_per_feature = {}
label_per_rank = []

for rank in range(world_size):
# keys, values, lengths, strides
lengths_per_rank_per_feature[rank] = {}
Expand Down Expand Up @@ -375,12 +376,11 @@ def generate_variable_batch_input(
accum_batch_size = 0
inverse_indices = []
for rank in range(world_size):
inverse_indices += [
index + accum_batch_size
for index in inverse_indices_per_rank_per_feature[rank][key]
]
inverse_indices.append(
inverse_indices_per_rank_per_feature[rank][key] + accum_batch_size
)
accum_batch_size += strides_per_rank_per_feature[rank][key]
inverse_indices_list.append(torch.IntTensor(inverse_indices))
inverse_indices_list.append(torch.cat(inverse_indices))
global_inverse_indices = (list(keys.keys()), torch.stack(inverse_indices_list))
if global_constant_batch:
global_offsets = []
Expand Down
Loading