Commit
Add docs to planner related classes (#1853)
Summary:
Pull Request resolved: #1853

add docs

Reviewed By: ge0405

Differential Revision: D55881753

fbshipit-source-id: 6cb4e4eb5b94f7247622952ada4ee1706458aff7
henrylhtsang authored and facebook-github-bot committed Apr 8, 2024
1 parent 4981be5 commit 9ca3923
Showing 2 changed files with 96 additions and 1 deletion.
78 changes: 77 additions & 1 deletion torchrec/distributed/planner/types.py
@@ -271,7 +271,44 @@ def __hash__(self) -> int:

class ShardingOption:
"""
One way of sharding an embedding table.
One way of sharding an embedding table. In the enumerator, we generate
multiple sharding options per table, but in the planner output, there
should only be one sharding option per table.
Attributes:
name (str): name of the sharding option.
tensor (torch.Tensor): tensor of the sharding option. Usually on meta
device.
module (Tuple[str, nn.Module]): module and its fqn that contains the
table.
input_lengths (List[float]): list of pooling factors of the feature for
the table.
batch_size (int): batch size of training / eval job.
sharding_type (str): sharding type of the table. Value of enum ShardingType.
compute_kernel (str): compute kernel of the table. Value of enum
EmbeddingComputeKernel.
shards (List[Shard]): list of shards of the table.
cache_params (Optional[CacheParams]): cache parameters to be used by this table.
These are passed to FBGEMM's Split TBE kernel.
enforce_hbm (Optional[bool]): whether to place all weights/momentums in HBM when
using cache.
stochastic_rounding (Optional[bool]): whether to do stochastic rounding. This is
passed to FBGEMM's Split TBE kernel. Stochastic rounding is
non-deterministic, but important for maintaining accuracy over the
longer term with FP16 embedding tables.
bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode to be used by
FBGEMM's Split TBE kernel. A bounds check verifies that values
(i.e. row ids) are within the table size. If a row id exceeds the
table size, it is set to 0.
dependency (Optional[str]): dependency of the table. Related to
Embedding tower.
is_pooled (Optional[bool]): whether the table is pooled. Pooling can be
sum pooling or mean pooling. Unpooled tables are also known as
sequence embeddings.
feature_names (Optional[List[str]]): list of feature names for this table.
output_dtype (Optional[DataType]): output dtype to be used by this table.
The default is FP32. If not None, the output dtype will also be used
by the planner to produce a more balanced plan.
"""

def __init__(
@@ -428,6 +465,45 @@ class ParameterConstraints:
If provided, `pooling_factors`, `num_poolings`, and `batch_sizes` must match in
length, as per sample.
Attributes:
sharding_types (Optional[List[str]]): sharding types allowed for the table.
Values of enum ShardingType.
compute_kernels (Optional[List[str]]): compute kernels allowed for the table.
Values of enum EmbeddingComputeKernel.
min_partition (Optional[int]): lower bound for dimension of column wise shards.
Planner will search for the column wise shard dimension in the
range of [min_partition, embedding_dim], as long as the column wise
shard dimension divides embedding_dim and is divisible by 4. Used
for column wise sharding only.
pooling_factors (Optional[List[float]]): pooling factors for each feature of the
table. This is the average number of values each sample has for
the feature. Length of pooling_factors should match the number of
features.
num_poolings (Optional[List[float]]): number of poolings for each feature of the
table. Length of num_poolings should match the number of features.
batch_sizes (Optional[List[int]]): batch sizes for each feature of the table. Length
of batch_sizes should match the number of features.
is_weighted (Optional[bool]): whether the table is weighted.
cache_params (Optional[CacheParams]): cache parameters to be used by this table.
These are passed to FBGEMM's Split TBE kernel.
enforce_hbm (Optional[bool]): whether to place all weights/momentums in HBM when
using cache.
stochastic_rounding (Optional[bool]): whether to do stochastic rounding. This is
passed to FBGEMM's Split TBE kernel. Stochastic rounding is
non-deterministic, but important for maintaining accuracy over the
longer term with FP16 embedding tables.
bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode to be used by
FBGEMM's Split TBE kernel. A bounds check verifies that values
(i.e. row ids) are within the table size. If a row id exceeds the
table size, it is set to 0.
feature_names (Optional[List[str]]): list of feature names for this table.
output_dtype (Optional[DataType]): output dtype to be used by this table.
The default is FP32. If not None, the output dtype will also be used
by the planner to produce a more balanced plan.
device_group (Optional[str]): device group to be used by this table. It can be
"cpu" or "cuda", specifying whether the table should be placed on a
CPU or a GPU device.
"""

sharding_types: Optional[List[str]] = None
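The `min_partition` search described in the `ParameterConstraints` docstring can be illustrated with a short sketch. This is not torchrec's actual planner code; `candidate_shard_dims` is a hypothetical helper that just enumerates the column-wise shard dimensions in `[min_partition, embedding_dim]` that divide `embedding_dim` and are divisible by 4:

```python
from typing import List


def candidate_shard_dims(embedding_dim: int, min_partition: int) -> List[int]:
    """Enumerate column-wise shard dimensions a planner could consider:
    divisors of embedding_dim in [min_partition, embedding_dim] that are
    also divisible by 4, per the min_partition docstring."""
    return [
        dim
        for dim in range(min_partition, embedding_dim + 1)
        if embedding_dim % dim == 0 and dim % 4 == 0
    ]


# e.g. a 128-dim table constrained to shards of at least 16 columns
print(candidate_shard_dims(128, 16))  # [16, 32, 64, 128]
```

Each candidate dimension yields a different number of column-wise shards (8, 4, 2, or 1 here), which is the search space the docstring describes.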
19 changes: 19 additions & 0 deletions torchrec/distributed/types.py
@@ -520,6 +520,25 @@ def cacheability(self) -> float:

@dataclass
class CacheParams:
"""Caching related fused params for an embedding table. Most of these are
passed to FBGEMM's Split TBE kernel. These are useful when UVM caching is used.
Attributes:
algorithm (Optional[CacheAlgorithm]): cache algorithm to use. Options
include LRU and LFU.
load_factor (Optional[float]): cache load factor per table. This decides
the size of the cache space for the table, and is crucial for
performance when using uvm caching.
reserved_memory (Optional[float]): reserved memory for the cache.
precision (Optional[DataType]): precision of the cache. Ideally this
should be the same as the data type of the weights (aka table).
prefetch_pipeline (Optional[bool]): whether the prefetch pipeline is
used.
stats (Optional[CacheStatistics]): cache statistics that hold table-related
metadata. Used to create a better plan and to tune the load
factor.
"""

algorithm: Optional[CacheAlgorithm] = None
load_factor: Optional[float] = None
reserved_memory: Optional[float] = None
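As the `load_factor` docstring notes, the load factor decides the size of the cache space for a table. A minimal sketch of that relationship, using a plain dataclass that mirrors the documented fields rather than torchrec's real `CacheParams` (the 0.2 fallback is an assumption for illustration only):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheParamsSketch:
    # Illustrative mirror of the CacheParams fields documented above,
    # not the real torchrec class.
    algorithm: Optional[str] = None      # e.g. "LRU" or "LFU"
    load_factor: Optional[float] = None  # fraction of rows held in the cache
    reserved_memory: Optional[float] = None


def cached_rows(num_embeddings: int, params: CacheParamsSketch) -> int:
    """load_factor * table rows approximates the number of rows resident
    in the UVM cache; 0.2 is an assumed fallback, not torchrec's."""
    load_factor = params.load_factor if params.load_factor is not None else 0.2
    return int(num_embeddings * load_factor)


# A 1M-row table with a 0.1 load factor keeps roughly 100k rows in cache.
print(cached_rows(1_000_000, CacheParamsSketch(algorithm="LRU", load_factor=0.1)))  # 100000
```

This is why the docstring calls the load factor crucial for performance: a larger value trades HBM capacity for a higher cache hit rate.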
