Add docs to planner related classes #1853

Closed
wants to merge 1 commit into from
78 changes: 77 additions & 1 deletion torchrec/distributed/planner/types.py
@@ -271,7 +271,44 @@ def __hash__(self) -> int:

class ShardingOption:
"""
One way of sharding an embedding table.
One way of sharding an embedding table. In the enumerator, we generate
multiple sharding options per table, but in the planner output, there
should only be one sharding option per table.

Attributes:
name (str): name of the sharding option.
tensor (torch.Tensor): tensor of the sharding option. Usually on meta
device.
module (Tuple[str, nn.Module]): FQN of the module that contains the
table, and the module itself.
input_lengths (List[float]): list of pooling factors of the feature for
the table.
batch_size (int): batch size of training / eval job.
sharding_type (str): sharding type of the table. Value of enum ShardingType.
compute_kernel (str): compute kernel of the table. Value of enum
EmbeddingComputeKernel.
shards (List[Shard]): list of shards of the table.
cache_params (Optional[CacheParams]): cache parameters to be used by this table.
These are passed to FBGEMM's Split TBE kernel.
enforce_hbm (Optional[bool]): whether to place all weights/momentums in HBM when
using cache.
stochastic_rounding (Optional[bool]): whether to do stochastic rounding. This is
passed to FBGEMM's Split TBE kernel. Stochastic rounding is
non-deterministic, but important for maintaining accuracy over the
longer term with FP16 embedding tables.
bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode to be used by
FBGEMM's Split TBE kernel. A bounds check verifies that each value
(i.e. row id) is within the table size. If a row id exceeds the table
size, it is set to 0.
dependency (Optional[str]): dependency of the table. Related to
Embedding tower.
is_pooled (Optional[bool]): whether the table is pooled. Pooling can be
sum pooling or mean pooling. Unpooled tables are also known as
sequence embeddings.
feature_names (Optional[List[str]]): list of feature names for this table.
output_dtype (Optional[DataType]): output dtype to be used by this table.
The default is FP32. If not None, the output dtype will also be used
by the planner to produce a more balanced plan.
"""

def __init__(
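The `bounds_check_mode` wording in the docstring above (an out-of-range row id is set to 0) can be sketched in plain Python. This only illustrates the docstring's statement: `sanitize_row_ids` is a hypothetical helper, not part of TorchRec or FBGEMM, and the real kernel's handling may depend on the selected mode.

```python
from typing import List


def sanitize_row_ids(row_ids: List[int], num_rows: int) -> List[int]:
    """Hypothetical helper: reset row ids outside the table to 0."""
    # Assumption of this sketch: negative ids are also treated as out of range.
    return [rid if 0 <= rid < num_rows else 0 for rid in row_ids]


# A table with 10 rows accepts ids 0..9; 99 and 10 fall outside and become 0.
print(sanitize_row_ids([3, 99, 0, 10], num_rows=10))  # [3, 0, 0, 0]
```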
@@ -428,6 +465,45 @@ class ParameterConstraints:

If provided, `pooling_factors`, `num_poolings`, and `batch_sizes` must match in
length, as per sample.

Attributes:
sharding_types (Optional[List[str]]): sharding types allowed for the table.
Values of enum ShardingType.
compute_kernels (Optional[List[str]]): compute kernels allowed for the table.
Values of enum EmbeddingComputeKernel.
min_partition (Optional[int]): lower bound for dimension of column wise shards.
Planner will search for the column wise shard dimension in the
range of [min_partition, embedding_dim], as long as the column wise
shard dimension divides embedding_dim and is divisible by 4. Used
for column wise sharding only.
pooling_factors (Optional[List[float]]): pooling factors for each feature of the
table. This is the average number of values each sample has for
the feature. Length of pooling_factors should match the number of
features.
num_poolings (Optional[List[float]]): number of poolings for each feature of the
table. Length of num_poolings should match the number of features.
batch_sizes (Optional[List[int]]): batch sizes for each feature of the table. Length
of batch_sizes should match the number of features.
is_weighted (Optional[bool]): whether the table is weighted.
cache_params (Optional[CacheParams]): cache parameters to be used by this table.
These are passed to FBGEMM's Split TBE kernel.
enforce_hbm (Optional[bool]): whether to place all weights/momentums in HBM when
using cache.
stochastic_rounding (Optional[bool]): whether to do stochastic rounding. This is
passed to FBGEMM's Split TBE kernel. Stochastic rounding is
non-deterministic, but important for maintaining accuracy over the
longer term with FP16 embedding tables.
bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode to be used by
FBGEMM's Split TBE kernel. A bounds check verifies that each value
(i.e. row id) is within the table size. If a row id exceeds the table
size, it is set to 0.
feature_names (Optional[List[str]]): list of feature names for this table.
output_dtype (Optional[DataType]): output dtype to be used by this table.
The default is FP32. If not None, the output dtype will also be used
by the planner to produce a more balanced plan.
device_group (Optional[str]): device group to be used by this table. It can be cpu
or cuda. This specifies if the table should be placed on a cpu device
or a gpu device.
"""

sharding_types: Optional[List[str]] = None
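As a reading aid for the `min_partition` rule and the length requirement in the `ParameterConstraints` docstring, here is a minimal sketch in plain Python. It restates the rules as documented (shard dimension in `[min_partition, embedding_dim]`, divides `embedding_dim`, divisible by 4; per-feature lists of equal length). Both function names are hypothetical and this is not the planner's actual enumeration code.

```python
from typing import List, Optional, Sequence


def candidate_shard_dims(embedding_dim: int, min_partition: int) -> List[int]:
    """Hypothetical helper: column-wise shard dims allowed by the documented rule."""
    start = max(1, min_partition)  # guard against a zero or negative lower bound
    return [
        dim
        for dim in range(start, embedding_dim + 1)
        if embedding_dim % dim == 0 and dim % 4 == 0
    ]


def lengths_match(*per_feature: Optional[Sequence[float]]) -> bool:
    """Hypothetical check: all provided per-feature lists have the same length."""
    lengths = {len(values) for values in per_feature if values is not None}
    return len(lengths) <= 1


print(candidate_shard_dims(128, 32))  # [32, 64, 128]
print(lengths_match([1.0, 2.0], None, [512, 512]))  # True
```

Under this reading, a larger `min_partition` shrinks the search space: with `embedding_dim=128`, raising it from 32 to 64 leaves only 64 and 128 as candidates.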
19 changes: 19 additions & 0 deletions torchrec/distributed/types.py
@@ -520,6 +520,25 @@ def cacheability(self) -> float:

@dataclass
class CacheParams:
"""Caching related fused params for an embedding table. Most of these are
passed to FBGEMM's Split TBE. These are useful when uvm caching is used.

Attributes:
algorithm (Optional[CacheAlgorithm]): cache algorithm to use. Options
include LRU and LFU.
load_factor (Optional[float]): cache load factor per table. This decides
the size of the cache space for the table, and is crucial for
performance when using uvm caching.
reserved_memory (Optional[float]): reserved memory for the cache.
precision (Optional[DataType]): precision of the cache. Ideally this
should be the same as the data type of the weights (aka table).
prefetch_pipeline (Optional[bool]): whether the prefetch pipeline is
used.
stats (Optional[CacheStatistics]): cache statistics which has table
related metadata. Used to create a better plan and tune the load
factor.
"""

algorithm: Optional[CacheAlgorithm] = None
load_factor: Optional[float] = None
reserved_memory: Optional[float] = None
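To make the `load_factor` description concrete, the sketch below assumes the load factor is read as the fraction of a table's rows kept in the cache. That interpretation is an assumption of this example, `estimated_cache_rows` is a hypothetical helper, and the actual cache sizing in FBGEMM may differ in detail.

```python
import math


def estimated_cache_rows(num_rows: int, load_factor: float) -> int:
    """Hypothetical estimate: rows cached if load_factor is a fraction of the table."""
    if not 0.0 < load_factor <= 1.0:
        # Assumption of this sketch: valid load factors lie in (0, 1].
        raise ValueError("load_factor is assumed to lie in (0, 1]")
    return math.ceil(num_rows * load_factor)


print(estimated_cache_rows(1000, 0.25))  # 250
```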