Skip to content

Commit

Permalink
add NJT/TD support in test data generator (pytorch#2528)
Browse files Browse the repository at this point in the history
Summary:

# context
* add NJT/TD support in test data generator
* add NJT/TD input option in pipeline benchmark
* resolve pyre/typing errors in multiple places
* should be safe to land, no production impact

NOTE: This diff was split from the next one (D65103519) to resolve pyre/typing errors

Differential Revision: D65120889
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Oct 30, 2024
1 parent 37631df commit ffe02f8
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 44 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/unittest_ci_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ jobs:
conda run -n build_binary \
python -c "import fbgemm_gpu"
echo "fbgemm_gpu succeeded"
conda run -n build_binary \
pip install tensordict
conda run -n build_binary \
pip install -r requirements.txt
conda run -n build_binary \
Expand Down
5 changes: 4 additions & 1 deletion torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,14 @@ def get_inputs(

if train:
sparse_features_by_rank = [
model_input.idlist_features for model_input in model_input_by_rank
model_input.idlist_features
for model_input in model_input_by_rank
if isinstance(model_input.idlist_features, KeyedJaggedTensor)
]
inputs_batch.append(sparse_features_by_rank)
else:
sparse_features = model_input_by_rank[0].idlist_features
assert isinstance(sparse_features, KeyedJaggedTensor)
inputs_batch.append([sparse_features])

# Transpose if train, as inputs_by_rank is currently in [B X R] format
Expand Down
4 changes: 3 additions & 1 deletion torchrec/distributed/test_utils/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def model_input_to_forward_args_kjt(
Optional[torch.Tensor],
]:
kjt = mi.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
return (
kjt._keys,
kjt._values,
Expand Down Expand Up @@ -289,7 +290,8 @@ def model_input_to_forward_args(
]:
idlist_kjt = mi.idlist_features
idscore_kjt = mi.idscore_features
assert idscore_kjt is not None
assert isinstance(idlist_kjt, KeyedJaggedTensor)
assert isinstance(idscore_kjt, KeyedJaggedTensor)
return (
mi.float_features,
idlist_kjt._keys,
Expand Down
131 changes: 92 additions & 39 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
import torch.nn as nn
from tensordict import TensorDict
from torchrec.distributed.embedding_tower_sharding import (
EmbeddingTowerCollectionSharder,
EmbeddingTowerSharder,
Expand Down Expand Up @@ -46,8 +47,8 @@
@dataclass
class ModelInput(Pipelineable):
float_features: torch.Tensor
idlist_features: KeyedJaggedTensor
idscore_features: Optional[KeyedJaggedTensor]
idlist_features: Union[KeyedJaggedTensor, TensorDict]
idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]]
label: torch.Tensor

@staticmethod
Expand Down Expand Up @@ -76,11 +77,13 @@ def generate(
randomize_indices: bool = True,
device: Optional[torch.device] = None,
max_feature_lengths: Optional[List[int]] = None,
input_type: str = "kjt",
) -> Tuple["ModelInput", List["ModelInput"]]:
"""
Returns a global (single-rank training) batch
and a list of local (multi-rank training) batches of world_size.
"""

batch_size_by_rank = [batch_size] * world_size
if variable_batch_size:
batch_size_by_rank = [
Expand Down Expand Up @@ -199,11 +202,26 @@ def _validate_pooling_factor(
)
global_idlist_lengths.append(lengths)
global_idlist_indices.append(indices)
global_idlist_kjt = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(global_idlist_indices),
lengths=torch.cat(global_idlist_lengths),
)

if input_type == "kjt":
global_idlist_input = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(global_idlist_indices),
lengths=torch.cat(global_idlist_lengths),
)
elif input_type == "td":
dict_of_nt = {
k: torch.nested.nested_tensor_from_jagged(
values=values,
lengths=lengths,
)
for k, values, lengths in zip(
idlist_features, global_idlist_indices, global_idlist_lengths
)
}
global_idlist_input = TensorDict(source=dict_of_nt)
else:
raise ValueError(f"For IdList features, unknown input type {input_type}")

for idx in range(len(idscore_ind_ranges)):
ind_range = idscore_ind_ranges[idx]
Expand Down Expand Up @@ -245,16 +263,25 @@ def _validate_pooling_factor(
global_idscore_lengths.append(lengths)
global_idscore_indices.append(indices)
global_idscore_weights.append(weights)
global_idscore_kjt = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(global_idscore_indices),
lengths=torch.cat(global_idscore_lengths),
weights=torch.cat(global_idscore_weights),

if input_type == "kjt":
global_idscore_input = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(global_idscore_indices),
lengths=torch.cat(global_idscore_lengths),
weights=torch.cat(global_idscore_weights),
)
if global_idscore_indices
else None
)
if global_idscore_indices
else None
)
elif input_type == "td":
assert (
len(idscore_features) == 0
), "TensorDict does not support weighted features"
global_idscore_input = None
else:
raise ValueError(f"For weighted features, unknown input type {input_type}")

if randomize_indices:
global_float = torch.rand(
Expand Down Expand Up @@ -303,36 +330,57 @@ def _validate_pooling_factor(
weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
)

local_idlist_kjt = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(local_idlist_indices),
lengths=torch.cat(local_idlist_lengths),
)
if input_type == "kjt":
local_idlist_input = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(local_idlist_indices),
lengths=torch.cat(local_idlist_lengths),
)

local_idscore_kjt = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(local_idscore_indices),
lengths=torch.cat(local_idscore_lengths),
weights=torch.cat(local_idscore_weights),
local_idscore_input = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(local_idscore_indices),
lengths=torch.cat(local_idscore_lengths),
weights=torch.cat(local_idscore_weights),
)
if local_idscore_indices
else None
)
elif input_type == "td":
dict_of_nt = {
k: torch.nested.nested_tensor_from_jagged(
values=values,
lengths=lengths,
)
for k, values, lengths in zip(
idlist_features, local_idlist_indices, local_idlist_lengths
)
}
local_idlist_input = TensorDict(source=dict_of_nt)
assert (
len(idscore_features) == 0
), "TensorDict does not support weighted features"
local_idscore_input = None

else:
raise ValueError(
f"For weighted features, unknown input type {input_type}"
)
if local_idscore_indices
else None
)

local_input = ModelInput(
float_features=global_float[r * batch_size : (r + 1) * batch_size],
idlist_features=local_idlist_kjt,
idscore_features=local_idscore_kjt,
idlist_features=local_idlist_input,
idscore_features=local_idscore_input,
label=global_label[r * batch_size : (r + 1) * batch_size],
)
local_inputs.append(local_input)

return (
ModelInput(
float_features=global_float,
idlist_features=global_idlist_kjt,
idscore_features=global_idscore_kjt,
idlist_features=global_idlist_input,
idscore_features=global_idscore_input,
label=global_label,
),
local_inputs,
Expand Down Expand Up @@ -623,8 +671,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":

def record_stream(self, stream: torch.Stream) -> None:
self.float_features.record_stream(stream)
self.idlist_features.record_stream(stream)
if self.idscore_features is not None:
if isinstance(self.idlist_features, KeyedJaggedTensor):
self.idlist_features.record_stream(stream)
if isinstance(self.idscore_features, KeyedJaggedTensor):
self.idscore_features.record_stream(stream)
self.label.record_stream(stream)

Expand Down Expand Up @@ -1753,10 +1802,12 @@ def forward(
if self._preproc_module is not None:
modified_input = self._preproc_module(modified_input)
elif self._run_preproc_inline:
idlist_features = modified_input.idlist_features
assert isinstance(idlist_features, KeyedJaggedTensor)
modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
modified_input.idlist_features.keys(),
modified_input.idlist_features.values(),
modified_input.idlist_features.lengths(),
idlist_features.keys(),
idlist_features.values(),
idlist_features.lengths(),
)

modified_idlist_features = self.preproc_nonweighted(
Expand Down Expand Up @@ -1817,6 +1868,8 @@ def forward(self, input: ModelInput) -> ModelInput:
)

# stride will be same but features will be joined
assert isinstance(modified_input.idlist_features, KeyedJaggedTensor)
assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor)
modified_input.idlist_features = KeyedJaggedTensor.concat(
[modified_input.idlist_features, self._extra_input.idlist_features]
)
Expand Down
3 changes: 3 additions & 0 deletions torchrec/distributed/tests/test_infer_shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,7 @@ def test_sharded_quant_fp_ebc_tw(
inputs = []
for model_input in model_inputs:
kjt = model_input.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = torch.rand(
kjt._values.size(0), dtype=torch.float, device=local_device
Expand Down Expand Up @@ -2149,6 +2150,7 @@ def test_sharded_quant_mc_ec_rw(
inputs = []
for model_input in model_inputs:
kjt = model_input.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = None
inputs.append(
Expand Down Expand Up @@ -2285,6 +2287,7 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
)
inputs = []
kjt = model_inputs[0].idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = torch.rand(
kjt._values.size(0), dtype=torch.float, device=local_device
Expand Down
12 changes: 10 additions & 2 deletions torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def _gen_pipelines(
default=100,
help="Total number of sparse embeddings to be used.",
)
@click.option(
"--ratio_features_weighted",
default=0.4,
help="percentage of features weighted vs unweighted",
)
@click.option(
"--dim_emb",
type=int,
Expand Down Expand Up @@ -132,6 +137,7 @@ def _gen_pipelines(
def main(
world_size: int,
n_features: int,
ratio_features_weighted: float,
dim_emb: int,
n_batches: int,
batch_size: int,
Expand All @@ -149,8 +155,9 @@ def main(
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())

num_features = n_features // 2
num_weighted_features = n_features // 2
num_weighted_features = int(n_features * ratio_features_weighted)
num_features = n_features - num_weighted_features

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 1000,
Expand Down Expand Up @@ -257,6 +264,7 @@ def _generate_data(
world_size=world_size,
num_float_features=num_float_features,
pooling_avg=pooling_factor,
input_type=input_type,
)[1]
for i in range(num_batches)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,11 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)

data = [i.idlist_features for i in local_model_inputs]
data = [
i.idlist_features
for i in local_model_inputs
if isinstance(i.idlist_features, KeyedJaggedTensor)
]
dataloader = iter(data)
pipeline = TrainPipelinePT2(
model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
Expand Down
1 change: 1 addition & 0 deletions torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def generate_kjt(
randomize_indices=True,
device=device,
)[0]
assert isinstance(global_input.idlist_features, KeyedJaggedTensor)
return global_input.idlist_features


Expand Down

0 comments on commit ffe02f8

Please sign in to comment.