Support custom embedding config
Differential Revision: D55055970
xw285cornell authored and facebook-github-bot committed Mar 19, 2024
1 parent 5a2d11b commit 2af5313
Showing 3 changed files with 131 additions and 27 deletions.
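
The new --embedding_config_json flag points at a JSON file mapping each feature name to its table shape; set_embedding_config (added in benchmark_utils.py below) reads num_embeddings and embedding_dim per feature, plus an optional pooling_factor that should be present for either all features or none. A minimal sketch of producing such a config, with made-up feature names and sizes:

    import json

    # Hypothetical feature names and sizes, purely for illustration; the
    # {feature: {num_embeddings, embedding_dim}} shape comes from the
    # set_embedding_config docstring, and "pooling_factor" is optional.
    config = {
        "feature_a": {"num_embeddings": 4000000, "embedding_dim": 128, "pooling_factor": 20},
        "feature_b": {"num_embeddings": 100000, "embedding_dim": 128, "pooling_factor": 5},
    }
    with open("embedding_config.json", "w") as f:
        json.dump(config, f, indent=2)

The benchmark then takes --embedding_config_json=embedding_config.json; each table's row count is clamped to --max_num_embeddings (default 1000000), and a missing or unreadable file falls back to the built-in DLRM table sizes.
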
42 changes: 25 additions & 17 deletions torchrec/distributed/benchmark/benchmark_train.py
@@ -23,10 +23,9 @@
     benchmark_module,
     BenchmarkResult,
     CompileMode,
-    DLRM_NUM_EMBEDDINGS_PER_FEATURE,
-    EMBEDDING_DIM,
     get_tables,
     init_argparse_and_args,
+    set_embedding_config,
     write_report,
 )
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
@@ -50,11 +49,6 @@
 # CompileMode.FX_SCRIPT,
 ]

-TABLE_SIZES: List[Tuple[int, int]] = [
-    (num_embeddings, EMBEDDING_DIM)
-    for num_embeddings in DLRM_NUM_EMBEDDINGS_PER_FEATURE
-]
-

 def training_func_to_benchmark(
     model: torch.nn.Module,
@@ -73,29 +67,43 @@ def training_func_to_benchmark(


 def benchmark_ebc(
-    tables: List[Tuple[int, int]], args: argparse.Namespace, output_dir: str
+    tables: List[Tuple[int, int]],
+    args: argparse.Namespace,
+    output_dir: str,
+    pooling_configs: Optional[List[int]] = None,
 ) -> List[BenchmarkResult]:
     table_configs = get_tables(tables, data_type=DataType.FP32)
     sharder = TestEBCSharder(
         sharding_type="",  # sharding_type gets populated during benchmarking
         kernel_type=EmbeddingComputeKernel.DENSE.value,
     )

+    # we initialize the embedding tables using CUDA, because when the table is large,
+    # CPU initialization will be prohibitively long. We then copy the module back
+    # to CPU because this module will be sent over subprocesses via multiprocessing,
+    # and we don't want to create an extra CUDA context on GPU0 for each subprocess.
+    # we also need to release the memory in the parent process (empty_cache)
     module = EmbeddingBagCollection(
         # pyre-ignore [6]
         tables=table_configs,
         is_weighted=False,
-        device=torch.device("cpu"),
-    )
+        device=torch.device("cuda"),
+    ).cpu()

+    torch.cuda.empty_cache()
+
+    IGNORE_ARGNAME = ["output_dir", "embedding_config_json", "max_num_embeddings"]
     optimizer = torch.optim.SGD(module.parameters(), lr=0.02)
     args_kwargs = {
         argname: getattr(args, argname)
         for argname in dir(args)
         # Don't include output_dir since output_dir was modified
-        if not argname.startswith("_") and argname != "output_dir"
+        if not argname.startswith("_") and argname not in IGNORE_ARGNAME
     }

+    if pooling_configs:
+        args_kwargs["pooling_configs"] = pooling_configs
+
     return benchmark_module(
         module=module,
         sharder=sharder,
@@ -124,11 +132,11 @@ def main() -> None:
     write_report_funcs_per_module = []
     shrunk_table_sizes = []

-    for i in range(len(TABLE_SIZES)):
-        if TABLE_SIZES[i][0] > 1000000:
-            shrunk_table_sizes.append((1000000, TABLE_SIZES[i][1]))
-        else:
-            shrunk_table_sizes.append(TABLE_SIZES[i])
+    embedding_configs, pooling_configs = set_embedding_config(
+        args.embedding_config_json
+    )
+    for config in embedding_configs:
+        shrunk_table_sizes.append((min(args.max_num_embeddings, config[0]), config[1]))

     for module_name in ["EmbeddingBagCollection"]:
         output_dir = args.output_dir + f"/run_{datetime_sfx}"
@@ -157,7 +165,7 @@

         # Save results to output them once benchmarking is all done
         benchmark_results_per_module.append(
-            benchmark_func(shrunk_table_sizes, args, output_dir)
+            benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs)
         )
         write_report_funcs_per_module.append(
             partial(
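
The comment added to benchmark_ebc captures a reusable pattern: initialize large embedding tables on GPU (CPU init is prohibitively slow), move the module back to CPU before it is pickled out to benchmark subprocesses so each child does not create its own CUDA context on GPU0, then release the parent's cached GPU memory. A standalone sketch of the same pattern, using a plain torch.nn.EmbeddingBag instead of TorchRec's EmbeddingBagCollection and made-up sizes:

    import torch

    def init_large_module(num_embeddings: int = 1000000, dim: int = 128) -> torch.nn.Module:
        # Parameter initialization is far faster on GPU for large tables.
        module = torch.nn.EmbeddingBag(num_embeddings, dim, device="cuda")
        # Move to host so subprocesses receive only CPU tensors.
        module = module.cpu()
        # Return the parent's cached CUDA memory to the driver.
        torch.cuda.empty_cache()
        return module
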
56 changes: 55 additions & 1 deletion torchrec/distributed/benchmark/benchmark_utils.py
@@ -13,6 +13,7 @@
 import contextlib
 import copy
 import gc
+import json
 import logging
 import os
 from dataclasses import dataclass
@@ -266,6 +267,7 @@ def get_inputs(
     world_size: int,
     num_inputs: int,
     rank: int = -1,
+    pooling_configs: Optional[List[int]] = None,
 ) -> List[KeyedJaggedTensor]:
     inputs: List[KeyedJaggedTensor] = []

@@ -277,6 +279,7 @@
             tables=tables,
             weighted_tables=[],
             long_indices=False,
+            tables_pooling=pooling_configs,
         )[1][0]

         # If ProcessGroup, place input on correct device. Otherwise, place on cuda:0
@@ -314,6 +317,48 @@ def write_report(
     logger.info(f"Report written to {report_file}:\n{report_str}")


+def set_embedding_config(
+    embedding_config_json: str,
+) -> Tuple[List[Tuple[int, int]], List[int]]:
+    """
+    the config file should follow this pattern: {feature: {num_embeddings: int, embedding_dim: int}}
+    """
+    embedding_configs = []
+    pooling_configs = []
+    has_pooling_config = False
+    try:
+        if os.path.exists(embedding_config_json):
+            with open(embedding_config_json, "r") as f:
+                embedding_config_json = json.load(f)
+
+            for _, config in embedding_config_json.items():
+                embedding_configs.append(
+                    (config["num_embeddings"], config["embedding_dim"])
+                )
+                if "pooling_factor" in config:
+                    pooling_configs.append(config["pooling_factor"])
+                    has_pooling_config = True
+                else:
+                    if has_pooling_config:
+                        raise RuntimeError(
+                            "We cannot handle some features have pooling factor and others don't."
+                        )
+        else:
+            raise RuntimeError(
+                f"Could not find embedding config json at path {embedding_config_json}"
+            )
+    except BaseException as e:
+        logger.warning(
+            f"Failed to load embedding config because {e}, fallback to DLRM config"
+        )
+        embedding_configs = [
+            (num_embeddings, EMBEDDING_DIM)
+            for num_embeddings in DLRM_NUM_EMBEDDINGS_PER_FEATURE
+        ]
+
+    return embedding_configs, pooling_configs
+
+
 def init_argparse_and_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()

@@ -322,10 +367,13 @@ def init_argparse_and_args() -> argparse.Namespace:
     parser.add_argument("--prof_iters", type=int, default=20)
     parser.add_argument("--batch_size", type=int, default=2048)
     parser.add_argument("--world_size", type=int, default=2)
+    parser.add_argument("--max_num_embeddings", type=int, default=1000000)
     parser.add_argument("--output_dir", type=str, default="/var/tmp/torchrec-bench")
     parser.add_argument("--num_benchmarks", type=int, default=5)
+    parser.add_argument("--embedding_config_json", type=str, default="")

     args = parser.parse_args()

     return args


Expand Down Expand Up @@ -551,6 +599,7 @@ def init_module_and_run_benchmark(
benchmark_func_kwargs: Optional[Dict[str, Any]],
rank: int = -1,
queue: Optional[mp.Queue] = None,
pooling_configs: Optional[List[int]] = None,
) -> BenchmarkResult:
"""
There are a couple of caveats here as to why the module has to be initialized
@@ -566,7 +615,9 @@
     """

     num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
-    inputs = get_inputs(tables, batch_size, world_size, num_inputs_to_gen, rank)
+    inputs = get_inputs(
+        tables, batch_size, world_size, num_inputs_to_gen, rank, pooling_configs
+    )

     warmup_inputs = inputs[:warmup_iters]
     bench_inputs = inputs[warmup_iters : (warmup_iters + bench_iters)]
@@ -687,6 +738,7 @@ def benchmark_module(
     output_dir: str = "",
     func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
+    pooling_configs: Optional[List[int]] = None,
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -761,6 +813,7 @@
             output_dir=output_dir,
             func_to_benchmark=func_to_benchmark,
             benchmark_func_kwargs=benchmark_func_kwargs,
+            pooling_configs=pooling_configs,
         )
     else:
         res = init_module_and_run_benchmark(
@@ -780,6 +833,7 @@
             output_dir=output_dir,
             func_to_benchmark=func_to_benchmark,
             benchmark_func_kwargs=benchmark_func_kwargs,
+            pooling_configs=pooling_configs,
         )

     gc.collect()
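
End to end, the new plumbing is: set_embedding_config parses the JSON into (num_embeddings, embedding_dim) pairs plus a pooling_configs list; benchmark_module threads pooling_configs through init_module_and_run_benchmark into get_inputs, which forwards it to ModelInput.generate as tables_pooling. A sketch of the parsing step in isolation, assuming the import path follows the file location and using a temp file to stay self-contained:

    import json
    import tempfile

    from torchrec.distributed.benchmark.benchmark_utils import set_embedding_config

    # Every feature carries pooling_factor here; mixing features with and
    # without it trips the parser's consistency check and triggers the
    # DLRM fallback.
    cfg = {
        "f0": {"num_embeddings": 2000000, "embedding_dim": 128, "pooling_factor": 10},
        "f1": {"num_embeddings": 50000, "embedding_dim": 128, "pooling_factor": 3},
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(cfg, f)

    embedding_configs, pooling_configs = set_embedding_config(f.name)
    # Expected: [(2000000, 128), (50000, 128)] and [10, 3]; a bad path would
    # instead log a warning and fall back to the DLRM table sizes.
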
60 changes: 51 additions & 9 deletions torchrec/distributed/test_utils/test_model.py
@@ -66,6 +66,8 @@ def generate(
         ] = None,
         variable_batch_size: bool = False,
         long_indices: bool = True,
+        tables_pooling: Optional[List[int]] = None,
+        weighted_tables_pooling: Optional[List[int]] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
@@ -78,10 +80,30 @@
             for r in range(world_size)
         ]

+        def _validate_pooling_factor(
+            tables: Union[
+                List[EmbeddingTableConfig],
+                List[EmbeddingBagConfig],
+                List[EmbeddingConfig],
+            ],
+            pooling_factor: Optional[List[int]],
+        ) -> None:
+            if pooling_factor and len(pooling_factor) != len(tables):
+                raise ValueError(
+                    "tables_pooling and tables must have the same length. "
+                    f"Got {len(pooling_factor)} and {len(tables)}."
+                )
+
+        _validate_pooling_factor(tables, tables_pooling)
+        _validate_pooling_factor(weighted_tables, weighted_tables_pooling)
+
         idlist_features_to_num_embeddings = {}
-        for table in tables:
-            for feature in table.feature_names:
-                idlist_features_to_num_embeddings[feature] = table.num_embeddings
+        idlist_features_to_pooling_factor = {}
+        for idx in range(len(tables)):
+            for feature in tables[idx].feature_names:
+                idlist_features_to_num_embeddings[feature] = tables[idx].num_embeddings
+                if tables_pooling is not None:
+                    idlist_features_to_pooling_factor[feature] = tables_pooling[idx]

         idlist_features = list(idlist_features_to_num_embeddings.keys())
         idscore_features = [
@@ -91,17 +113,31 @@
         idlist_ind_ranges = list(idlist_features_to_num_embeddings.values())
         idscore_ind_ranges = [table.num_embeddings for table in weighted_tables]

+        idlist_pooling_factor = list(idlist_features_to_pooling_factor.values())
+        idscore_pooling_factor = weighted_tables_pooling
+
         # Generate global batch.
         global_idlist_lengths = []
         global_idlist_indices = []
         global_idscore_lengths = []
         global_idscore_indices = []
         global_idscore_weights = []

-        for ind_range in idlist_ind_ranges:
-            lengths_ = torch.abs(
-                torch.randn(batch_size * world_size) + pooling_avg
-            ).int()
+        for idx in range(len(idlist_ind_ranges)):
+            ind_range = idlist_ind_ranges[idx]
+            if idlist_pooling_factor:
+                lengths_ = torch.max(
+                    torch.normal(
+                        idlist_pooling_factor[idx],
+                        idlist_pooling_factor[idx] / 10,
+                        [batch_size * world_size],
+                    ),
+                    torch.tensor(1.0),
+                ).int()
+            else:
+                lengths_ = torch.abs(
+                    torch.randn(batch_size * world_size) + pooling_avg
+                ).int()
             if variable_batch_size:
                 lengths = torch.zeros(batch_size * world_size).int()
                 for r in range(world_size):
@@ -127,9 +163,15 @@
             lengths=torch.cat(global_idlist_lengths),
         )

-        for ind_range in idscore_ind_ranges:
+        for idx in range(len(idscore_ind_ranges)):
+            ind_range = idscore_ind_ranges[idx]
             lengths_ = torch.abs(
-                torch.randn(batch_size * world_size) + pooling_avg
+                torch.randn(batch_size * world_size)
+                + (
+                    idscore_pooling_factor[idx]
+                    if idscore_pooling_factor
+                    else pooling_avg
+                )
             ).int()
             if variable_batch_size:
                 lengths = torch.zeros(batch_size * world_size).int()
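
The sampling change in ModelInput.generate is the heart of per-table pooling: rather than drawing every feature's lengths around the shared pooling_avg, each table's lengths are drawn from a normal distribution whose mean is that table's pooling factor (std is a tenth of the mean), floored at 1. An isolated sketch of that draw, with illustrative numbers:

    import torch

    def sample_lengths(pooling_factor: float, n: int) -> torch.Tensor:
        # lengths ~ N(pooling_factor, pooling_factor / 10), clamped to >= 1
        # so every sample keeps at least one index.
        return torch.max(
            torch.normal(pooling_factor, pooling_factor / 10, [n]),
            torch.tensor(1.0),
        ).int()

    lengths = sample_lengths(10.0, 2048 * 2)  # batch_size * world_size
    print(lengths.float().mean())  # roughly 10
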
