Introduce Training benchmarks (pytorch#1705)
Summary: Pull Request resolved: pytorch#1705

Differential Revision: D53481057
PaulZhang12 authored and facebook-github-bot committed Feb 15, 2024
1 parent 7b4b1dd commit 8e975d2
Showing 3 changed files with 531 additions and 116 deletions.
55 changes: 7 additions & 48 deletions torchrec/distributed/benchmark/benchmark_inference.py
@@ -20,7 +20,11 @@
     benchmark_module,
     BenchmarkResult,
     CompileMode,
+    DLRM_NUM_EMBEDDINGS_PER_FEATURE,
+    EMBEDDING_DIM,
     get_tables,
+    init_argparse_and_args,
+    write_report,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
 from torchrec.distributed.test_utils.infer_utils import (
@@ -35,21 +39,6 @@
 logger: logging.Logger = logging.getLogger()
 
 
-def init_argparse_and_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument("--warmup_iters", type=int, default=20)
-    parser.add_argument("--bench_iters", type=int, default=500)
-    parser.add_argument("--prof_iters", type=int, default=20)
-    parser.add_argument("--batch_size", type=int, default=2048)
-    parser.add_argument("--world_size", type=int, default=2)
-    parser.add_argument("--output_dir", type=str, default="/var/tmp/torchrec-bench")
-    parser.add_argument("--num_benchmarks", type=int, default=9)
-
-    args = parser.parse_args()
-    return args
-
-
 BENCH_SHARDING_TYPES: List[ShardingType] = [
     ShardingType.TABLE_WISE,
     ShardingType.ROW_WISE,
@@ -61,42 +50,13 @@ def init_argparse_and_args() -> argparse.Namespace:
     CompileMode.FX_SCRIPT,
 ]
 
 
 TABLE_SIZES: List[Tuple[int, int]] = [
-    (40_000_000, 256),
-    (4_000_000, 256),
-    (1_000_000, 256),
+    (num_embeddings, EMBEDDING_DIM)
+    for num_embeddings in DLRM_NUM_EMBEDDINGS_PER_FEATURE
 ]
 
 
-def write_report(
-    benchmark_results: List[BenchmarkResult],
-    report_file: str,
-    report_str: str,
-    num_requests: int,
-) -> None:
-
-    for benchmark_res in benchmark_results:
-        avg_dur_s = benchmark_res.elapsed_time.mean().item() * 1e-3  # time in seconds
-        std_dur_s = benchmark_res.elapsed_time.std().item() * 1e-3  # time in seconds
-
-        qps = int(num_requests / avg_dur_s)
-
-        mem_allocated_by_rank = benchmark_res.max_mem_allocated
-
-        mem_str = ""
-        for i, mem_mb in enumerate(mem_allocated_by_rank):
-            mem_str += f"Rank {i}: {mem_mb:7} "
-
-        report_str += f"{benchmark_res.short_name:40} Avg QPS:{qps:10} Avg Duration: {int(1000*avg_dur_s):5}"
-        report_str += f"ms Standard Dev Duration: {(1000*std_dur_s):.2f}ms\n"
-        report_str += f"\tMemory Allocated Per Rank:\n\t{mem_str}\n"
-
-    with open(report_file, "w") as f:
-        f.write(report_str)
-
-    logger.info(f"Report written to {report_file}:\n{report_str}")
-
-
 def benchmark_qec(args: argparse.Namespace, output_dir: str) -> List[BenchmarkResult]:
     tables = get_tables(TABLE_SIZES, is_pooled=False)
     sharder = TestQuantECSharder(
@@ -201,7 +161,6 @@ def main() -> None:
         report += tables_info
         report += "\n"
 
-        num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
         report += f"num_requests:{num_requests:8}\n"
         report_file: str = f"{output_dir}/run.report"
 
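The helpers deleted above (init_argparse_and_args, write_report, and the hard-coded TABLE_SIZES entries) now come from torchrec.distributed.benchmark.benchmark_utils, as the new imports in the first hunk show. A minimal usage sketch, assuming the relocated helpers keep the signatures visible in the deleted code (benchmark_utils itself is not shown in this commit):

# Sketch only: assumes init_argparse_and_args() and write_report() keep the
# signatures shown in the deleted code above after moving to benchmark_utils.
from typing import List

from torchrec.distributed.benchmark.benchmark_utils import (
    BenchmarkResult,
    init_argparse_and_args,
    write_report,
)

args = init_argparse_and_args()  # --warmup_iters, --bench_iters, --batch_size, ...
num_requests = args.bench_iters * args.batch_size * args.num_benchmarks

results: List[BenchmarkResult] = []  # would be filled by a benchmark run such as benchmark_qec
write_report(results, report_file="run.report", report_str="", num_requests=num_requests)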
179 changes: 179 additions & 0 deletions torchrec/distributed/benchmark/benchmark_train.py
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#!/usr/bin/env python3

import argparse
import copy
import logging
import os
import time
from functools import partial
from typing import List, Optional, Tuple

import torch

from torchrec.distributed.benchmark.benchmark_utils import (
    benchmark_module,
    BenchmarkResult,
    CompileMode,
    DLRM_NUM_EMBEDDINGS_PER_FEATURE,
    EMBEDDING_DIM,
    get_tables,
    init_argparse_and_args,
    write_report,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
from torchrec.distributed.test_utils.test_model import TestEBCSharder
from torchrec.distributed.types import DataType
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


logger: logging.Logger = logging.getLogger()


BENCH_SHARDING_TYPES: List[ShardingType] = [
    ShardingType.TABLE_WISE,
    ShardingType.ROW_WISE,
    ShardingType.COLUMN_WISE,
]

BENCH_COMPILE_MODES: List[CompileMode] = [
    CompileMode.EAGER,
    # CompileMode.FX_SCRIPT,
]

TABLE_SIZES: List[Tuple[int, int]] = [
    (num_embeddings, EMBEDDING_DIM)
    for num_embeddings in DLRM_NUM_EMBEDDINGS_PER_FEATURE
]


def training_func_to_benchmark(
    model: torch.nn.Module,
    bench_inputs: List[KeyedJaggedTensor],
    optimizer: Optional[torch.optim.Optimizer],
) -> None:
    for bench_input in bench_inputs:
        pooled_embeddings = model(bench_input)
        vals = []
        for _name, param in pooled_embeddings.to_dict().items():
            vals.append(param)
        torch.cat(vals, dim=1).sum().backward()
        if optimizer:
            optimizer.step()
            optimizer.zero_grad()
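
The benchmarked training step above runs one forward pass per input batch, concatenates the pooled embeddings of all features into a single tensor, backpropagates through their sum, and applies an optimizer step when one is supplied. A hypothetical standalone smoke test (not part of this commit, and assuming the module path implied by the file location) showing the kind of inputs it expects:

# Hypothetical smoke test: run the benchmarked training step once on an
# unsharded EmbeddingBagCollection with a tiny synthetic batch.
import torch

from torchrec.distributed.benchmark.benchmark_train import training_func_to_benchmark
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            num_embeddings=100, embedding_dim=8, name="t0", feature_names=["f0"]
        )
    ],
    device=torch.device("cpu"),
)
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f0"],
    values=torch.tensor([1, 2, 3]),  # three ids in total
    lengths=torch.tensor([1, 2]),  # batch of two samples: 1 id, then 2 ids
)
optimizer = torch.optim.SGD(ebc.parameters(), lr=0.02)
training_func_to_benchmark(ebc, [kjt], optimizer)  # one forward/backward/step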


def benchmark_ebc(
    tables: List[Tuple[int, int]], args: argparse.Namespace, output_dir: str
) -> List[BenchmarkResult]:
    table_configs = get_tables(tables, data_type=DataType.FP32)
    sharder = TestEBCSharder(
        sharding_type="",  # sharding_type gets populated during benchmarking
        kernel_type=EmbeddingComputeKernel.DENSE.value,
    )

    module = EmbeddingBagCollection(
        # pyre-ignore [6]
        tables=table_configs,
        is_weighted=False,
        device=torch.device("cpu"),
    )

    optimizer = torch.optim.SGD(module.parameters(), lr=0.02)
    args_kwargs = {
        argname: getattr(args, argname)
        for argname in dir(args)
        # Don't include output_dir since output_dir was modified
        if not argname.startswith("_") and argname != "output_dir"
    }

    return benchmark_module(
        module=module,
        sharder=sharder,
        sharding_types=BENCH_SHARDING_TYPES,
        compile_modes=BENCH_COMPILE_MODES,
        tables=table_configs,
        output_dir=output_dir,
        func_to_benchmark=training_func_to_benchmark,
        benchmark_func_kwargs={"optimizer": optimizer},
        **args_kwargs,
    )
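
get_tables comes from benchmark_utils and is not part of this diff. Judging from the call sites (get_tables(tables, data_type=DataType.FP32) here and get_tables(TABLE_SIZES, is_pooled=False) in benchmark_inference.py), it presumably expands each (num_embeddings, embedding_dim) pair into one embedding table config with a single feature, roughly along these lines (a hypothetical stand-in, not the actual implementation; the table and feature names are made up):

# Hypothetical stand-in for benchmark_utils.get_tables, inferred from its
# call sites in this commit.
from typing import List, Tuple

from torchrec.distributed.types import DataType
from torchrec.modules.embedding_configs import EmbeddingBagConfig


def get_tables_sketch(
    table_sizes: List[Tuple[int, int]],
    data_type: DataType = DataType.FP32,
) -> List[EmbeddingBagConfig]:
    return [
        EmbeddingBagConfig(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            name=f"table_{i}",
            feature_names=[f"feature_{i}"],
            data_type=data_type,
        )
        for i, (num_embeddings, embedding_dim) in enumerate(table_sizes)
    ]

The args_kwargs comprehension above forwards every public attribute of the parsed argparse namespace, except output_dir, to benchmark_module as keyword arguments. Assuming the default flag values of the init_argparse_and_args removed from benchmark_inference.py above (the relocated helper in benchmark_utils may differ), the forwarded dict amounts to:

# Illustrative only: values are the defaults of the removed init_argparse_and_args.
args_kwargs = {
    "warmup_iters": 20,
    "bench_iters": 500,
    "prof_iters": 20,
    "batch_size": 2048,
    "world_size": 2,
    "num_benchmarks": 9,
}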


def main() -> None:
    args: argparse.Namespace = init_argparse_and_args()

    num_requests = args.bench_iters * args.batch_size * args.num_benchmarks
    datetime_sfx: str = time.strftime("%Y%m%dT%H%M%S")

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        # Create output directory if not exist
        os.mkdir(output_dir)

    benchmark_results_per_module = []
    write_report_funcs_per_module = []
    shrunk_table_sizes = []

    for i in range(len(TABLE_SIZES)):
        if TABLE_SIZES[i][0] > 1000000:
            shrunk_table_sizes.append((1000000, TABLE_SIZES[i][1]))
        else:
            shrunk_table_sizes.append(TABLE_SIZES[i])

    for module_name in ["EmbeddingBagCollection"]:
        output_dir = args.output_dir + f"/run_{datetime_sfx}"
        output_dir += "_ebc"
        benchmark_func = benchmark_ebc

        if not os.path.exists(output_dir):
            # Place all outputs under the datetime folder
            os.mkdir(output_dir)

        tables_info = "\nTABLE SIZES:"
        for i, (num, dim) in enumerate(shrunk_table_sizes):
            # FP32 is 4 bytes
            mb = int(float(num * dim) / 1024 / 1024) * 4
            tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] {mb:6}Mb"

        report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
        report += f"Module: {module_name}\n"
        report += tables_info
        report += "\n"

        report += f"num_requests:{num_requests:8}\n"
        report_file: str = f"{output_dir}/run.report"

        # Save results to output them once benchmarking is all done
        benchmark_results_per_module.append(
            benchmark_func(shrunk_table_sizes, args, output_dir)
        )
        write_report_funcs_per_module.append(
            partial(
                write_report,
                report_file=report_file,
                report_str=report,
                num_requests=num_requests,
            )
        )

    for i, write_report_func in enumerate(write_report_funcs_per_module):
        write_report_func(benchmark_results_per_module[i])
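
The tables_info block in main() above estimates each table's footprint by counting num * dim FP32 values at 4 bytes each. A quick worked example of that estimate for one of the capped 1,000,000-row tables, using a hypothetical embedding_dim of 128 (EMBEDDING_DIM itself lives in benchmark_utils and is not shown in this diff):

# Worked example of the FP32 size estimate used for tables_info in main().
# A 1_000_000 x 128 table holds 1_000_000 * 128 * 4 bytes = 512_000_000 bytes,
# i.e. roughly 488 MiB. The dim of 128 is illustrative; the script uses EMBEDDING_DIM.
num, dim = 1_000_000, 128
mb = int(float(num * dim) / 1024 / 1024) * 4  # same arithmetic as in main()
print(mb)  # 488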


def invoke_main() -> None:
    logging.basicConfig()
    logging.getLogger().setLevel(logging.DEBUG)

    main()


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
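
Assuming the module path implied by the file location, the benchmark can presumably be launched as a script (for example, python -m torchrec.distributed.benchmark.benchmark_train) with the CLI flags defined by the shared init_argparse_and_args helper (--warmup_iters, --bench_iters, --prof_iters, --batch_size, --world_size, --output_dir, --num_benchmarks, per the defaults visible in the benchmark_inference.py diff above); the report is written to <output_dir>/run_<timestamp>_ebc/run.report.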