Skip to content

Commit

Permalink
add sharding_type argument to pipeline benchmark
Browse files Browse the repository at this point in the history
Summary:
# context
* add sharding_type argument to the pipeline benchmark
* better control of different sharding types

Differential Revision: D64676132
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Oct 21, 2024
1 parent dbca437 commit f767886
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def _gen_pipelines(
default=8192,
help="Batch size.",
)
@click.option(
"--sharding_type",
type=ShardingType,
default=ShardingType.TABLE_WISE,
help="ShardingType.",
)
@click.option(
"--pooling_factor",
type=int,
Expand Down Expand Up @@ -129,6 +135,7 @@ def main(
dim_emb: int,
n_batches: int,
batch_size: int,
sharding_type: ShardingType,
pooling_factor: int,
input_type: str,
pipeline: str,
Expand Down Expand Up @@ -178,7 +185,7 @@ def main(
callable=runner,
tables=tables,
weighted_tables=weighted_tables,
sharding_type=ShardingType.TABLE_WISE.value,
sharding_type=sharding_type.value,
kernel_type=EmbeddingComputeKernel.FUSED.value,
batches=batches,
fused_params={},
Expand All @@ -190,7 +197,7 @@ def main(
single_runner(
tables=tables,
weighted_tables=weighted_tables,
sharding_type=ShardingType.TABLE_WISE.value,
sharding_type=sharding_type.value,
kernel_type=EmbeddingComputeKernel.FUSED.value,
batches=batches,
fused_params={},
Expand Down

0 comments on commit f767886

Please sign in to comment.