From 9ca3923f39fe4ad7dba053f0b479f93afb6bf624 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 8 Apr 2024 16:33:26 -0700 Subject: [PATCH] Add docs to planner related classes (#1853) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1853 add docs Reviewed By: ge0405 Differential Revision: D55881753 fbshipit-source-id: 6cb4e4eb5b94f7247622952ada4ee1706458aff7 --- torchrec/distributed/planner/types.py | 78 ++++++++++++++++++++++++++- torchrec/distributed/types.py | 19 +++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 8bc071425..a626f6f8d 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -271,7 +271,44 @@ def __hash__(self) -> int: class ShardingOption: """ - One way of sharding an embedding table. + One way of sharding an embedding table. In the enumerator, we generate + multiple sharding options per table, but in the planner output, there + should only be one sharding option per table. + + Attributes: + name (str): name of the sharding option. + tensor (torch.Tensor): tensor of the sharding option. Usually on meta + device. + module (Tuple[str, nn.Module]): module and its fqn that contains the + table. + input_lengths (List[float]): list of pooling factors of the feature for + the table. + batch_size (int): batch size of training / eval job. + sharding_type (str): sharding type of the table. Value of enum ShardingType. + compute_kernel (str): compute kernel of the table. Value of enum + EmbeddingComputeKernel. + shards (List[Shard]): list of shards of the table. + cache_params (Optional[CacheParams]): cache parameters to be used by this table. + These are passed to FBGEMM's Split TBE kernel. + enforce_hbm (Optional[bool]): whether to place all weights/momentums in HBM when + using cache. + stochastic_rounding (Optional[bool]): whether to do stochastic rounding. 
This is + passed to FBGEMM's Split TBE kernel. Stochastic rounding is + non-deterministic, but important to maintain accuracy in longer + term with FP16 embedding tables. + bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode to be used by + FBGEMM's Split TBE kernel. Bounds check means checking if values + (i.e. row ids) are within the table size. If row id exceeds table + size, it will be set to 0. + dependency (Optional[str]): dependency of the table. Related to + Embedding tower. + is_pooled (Optional[bool]): whether the table is pooled. Pooling can be + sum pooling or mean pooling. Unpooled tables are also known as + sequence embeddings. + feature_names (Optional[List[str]]): list of feature names for this table. + output_dtype (Optional[DataType]): output dtype to be used by this table. + The default is FP32. If not None, the output dtype will also be used + by the planner to produce a more balanced plan. """ def __init__( @@ -428,6 +465,45 @@ class ParameterConstraints: If provided, `pooling_factors`, `num_poolings`, and `batch_sizes` must match in length, as per sample. + + Attributes: + sharding_types (Optional[List[str]]): sharding types allowed for the table. + Values of enum ShardingType. + compute_kernels (Optional[List[str]]): compute kernels allowed for the table. + Values of enum EmbeddingComputeKernel. + min_partition (Optional[int]): lower bound for dimension of column wise shards. + Planner will search for the column wise shard dimension in the + range of [min_partition, embedding_dim], as long as the column wise + shard dimension divides embedding_dim and is divisible by 4. Used + for column wise sharding only. + pooling_factors (Optional[List[float]]): pooling factors for each feature of the + table. This is the average number of values each sample has for + the feature. Length of pooling_factors should match the number of + features. + num_poolings (Optional[List[float]]): number of poolings for each feature of the + table. 
Length of num_poolings should match the number of features. + batch_sizes (Optional[List[int]]): batch sizes for each feature of the table. Length + of batch_sizes should match the number of features. + is_weighted (Optional[bool]): whether the table is weighted. + cache_params (Optional[CacheParams]): cache parameters to be used by this table. + These are passed to FBGEMM's Split TBE kernel. + enforce_hbm (Optional[bool]): whether to place all weights/momentums in HBM when + using cache. + stochastic_rounding (Optional[bool]): whether to do stochastic rounding. This is + passed to FBGEMM's Split TBE kernel. Stochastic rounding is + non-deterministic, but important to maintain accuracy in longer + term with FP16 embedding tables. + bounds_check_mode (Optional[BoundsCheckMode]): bounds check mode to be used by + FBGEMM's Split TBE kernel. Bounds check means checking if values + (i.e. row ids) are within the table size. If row id exceeds table + size, it will be set to 0. + feature_names (Optional[List[str]]): list of feature names for this table. + output_dtype (Optional[DataType]): output dtype to be used by this table. + The default is FP32. If not None, the output dtype will also be used + by the planner to produce a more balanced plan. + device_group (Optional[str]): device group to be used by this table. It can be cpu + or cuda. This specifies if the table should be placed on a cpu device + or a gpu device. """ sharding_types: Optional[List[str]] = None diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 00515814a..2d649bfa5 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -520,6 +520,25 @@ def cacheability(self) -> float: @dataclass class CacheParams: + """Caching related fused params for an embedding table. Most of these are + passed to FBGEMM's Split TBE. These are useful for when uvm caching is used. + + Attributes: + algorithm (Optional[CacheAlgorithm]): cache algorithm to use. 
Options + include LRU and LFU. + load_factor (Optional[float]): cache load factor per table. This decides + the size of the cache space for the table, and is crucial for + performance when using uvm caching. + reserved_memory (Optional[float]): reserved memory for the cache. + precision (Optional[DataType]): precision of the cache. Ideally this + should be the same as the data type of the weights (aka table). + prefetch_pipeline (Optional[bool]): whether the prefetch pipeline is + used. + stats (Optional[CacheStatistics]): cache statistics which has table + related metadata. Used to create a better plan and tune the load + factor. + """ + algorithm: Optional[CacheAlgorithm] = None load_factor: Optional[float] = None reserved_memory: Optional[float] = None