From d9fbac92da7b4e642e555d3eea5c317e2fea5b01 Mon Sep 17 00:00:00 2001
From: Dennis van der Staay
Date: Wed, 20 Mar 2024 12:34:31 -0700
Subject: [PATCH] Improve Random Batch generation speed (#1805)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1805

Remove for loop from batch generation; very slow at larger test batch sizes (4K+) - prior to changes, would take 2mins+.

Reviewed By: henrylhtsang

Differential Revision: D55052074

Privacy Context Container: 1203980333745195

fbshipit-source-id: 858246959832360d08063e02b95370703dd12dd7
---
 torchrec/datasets/random.py            | 57 ++++++++++++--------------
 torchrec/datasets/tests/test_random.py | 19 +++++++++
 2 files changed, 45 insertions(+), 31 deletions(-)

diff --git a/torchrec/datasets/random.py b/torchrec/datasets/random.py
index db4804a2b..9008622e5 100644
--- a/torchrec/datasets/random.py
+++ b/torchrec/datasets/random.py
@@ -81,30 +81,25 @@ def _generate_batch(self) -> Batch:
         lengths = []
         for key_idx, _ in enumerate(self.keys):
             hash_size = self.hash_sizes[key_idx]
-            min_num_ids_in_batch = self.min_ids_per_features[key_idx]
-            max_num_ids_in_batch = self.ids_per_features[key_idx]
-            for _ in range(self.batch_size):
-                num_ids_in_batch = int(
-                    torch.randint(
-                        low=min_num_ids_in_batch,
-                        high=max_num_ids_in_batch + 1,
-                        size=(),
-                        generator=self.generator,
-                    ).item()
-                )
-                values.append(
-                    torch.randint(
-                        high=hash_size,
-                        size=(num_ids_in_batch,),
-                        generator=self.generator,
-                    )
-                )
-                lengths.extend([num_ids_in_batch])
+            min_num_ids = self.min_ids_per_features[key_idx]
+            max_num_ids = self.ids_per_features[key_idx]
+            length = torch.randint(
+                min_num_ids,
+                max_num_ids + 1,
+                (self.batch_size,),
+                dtype=torch.int32,
+                generator=self.generator,
+            )
+            value = torch.randint(
+                0, hash_size, (cast(int, length.sum()),), generator=self.generator
+            )
+            lengths.append(length)
+            values.append(value)
 
         sparse_features = KeyedJaggedTensor.from_lengths_sync(
             keys=self.keys,
             values=torch.cat(values),
-            lengths=torch.tensor(lengths, dtype=torch.int32),
+            lengths=torch.cat(lengths),
         )
 
         dense_features = torch.randn(
@@ -140,14 +135,15 @@ class RandomRecDataset(IterableDataset[Batch]):
             modulo this value.
         hash_sizes (Optional[List[int]]): Max sparse id value per feature in keys. Each
             sparse ID will be taken modulo the corresponding value from this argument. Note, if this is used, hash_size will be ignored.
-        ids_per_feature (int): Number of IDs per sparse feature.
-        ids_per_features (int): Number of IDs per sparse feature in each key. Note, if this is used, ids_per_feature will be ignored.
+        ids_per_feature (Optional[int]): Number of IDs per sparse feature per sample.
+        ids_per_features (Optional[List[int]]): Number of IDs per sparse feature per sample in each key. Note, if this is used, ids_per_feature will be ignored.
         num_dense (int): Number of dense features.
         manual_seed (int): Seed for deterministic behavior.
         num_batches: (Optional[int]): Num batches to generate before raising StopIteration
         num_generated_batches int: Num batches to cache. If num_batches > num_generated batches, then we will cycle to the first generated batch.
             If this value is negative, batches will be generated on the fly.
-        min_ids_per_feature (int): Minimum number of IDs per features.
+        min_ids_per_feature (Optional[int]): Minimum number of IDs per features.
+        min_ids_per_features (Optional[List[int]]): Minimum number of IDs per sparse feature per sample in each key. Note, if this is used, min_ids_per_feature will be ignored.
 
     Example::
 
@@ -165,9 +161,9 @@ def __init__(
         self,
         keys: List[str],
         batch_size: int,
-        hash_size: Optional[int] = 100,
+        hash_size: Optional[int] = None,
         hash_sizes: Optional[List[int]] = None,
-        ids_per_feature: Optional[int] = 2,
+        ids_per_feature: Optional[int] = None,
         ids_per_features: Optional[List[int]] = None,
         num_dense: int = 50,
         manual_seed: Optional[int] = None,
@@ -179,8 +175,7 @@ def __init__(
         super().__init__()
 
         if hash_sizes is None:
-            hash_size = hash_size or 100
-            hash_sizes = [hash_size] * len(keys)
+            hash_sizes = [hash_size if hash_size else 100] * len(keys)
 
         assert hash_sizes is not None
         assert len(hash_sizes) == len(
@@ -188,8 +183,7 @@ def __init__(
         ), "length of hash_sizes must be equal to the number of keys"
 
         if ids_per_features is None:
-            ids_per_feature = ids_per_feature or 2
-            ids_per_features = [ids_per_feature] * len(keys)
+            ids_per_features = [ids_per_feature if ids_per_feature else 2] * len(keys)
 
         assert ids_per_features is not None
 
@@ -199,8 +193,9 @@ def __init__(
                 if min_ids_per_feature is not None
                 else ids_per_feature
             )
-            assert min_ids_per_feature is not None
-            min_ids_per_features = [min_ids_per_feature] * len(keys)
+            min_ids_per_features = [
+                min_ids_per_feature if min_ids_per_feature else 0
+            ] * len(keys)
 
         assert len(ids_per_features) == len(
             keys
diff --git a/torchrec/datasets/tests/test_random.py b/torchrec/datasets/tests/test_random.py
index bafe78d75..f30d8e0ad 100644
--- a/torchrec/datasets/tests/test_random.py
+++ b/torchrec/datasets/tests/test_random.py
@@ -10,6 +10,8 @@
 import itertools
 import unittest
 
+from hypothesis import given, settings, strategies as st
+
 from torchrec.datasets.random import RandomRecDataset
 
 
@@ -74,6 +76,23 @@ def test_hash_ids_per_feature(self) -> None:
         for batch in feat2:
             self.assertEqual(len(batch), 200)
 
+    # pyre-ignore
+    @given(
+        batch_size=st.sampled_from([2048, 4096, 8192]),
+    )
+    @settings(max_examples=3, deadline=5000)  # expected runtime <=500ms
+    def test_large_batch_size_deadline(self, batch_size: int) -> None:
+        dataset = RandomRecDataset(
+            keys=["feat1", "feat2"],
+            batch_size=batch_size,
+            ids_per_features=[10, 20],
+            hash_size=100,
+            num_dense=5,
+        )
+        iterator = iter(dataset)
+        for _ in range(5):
+            next(iterator)
+
     def test_hash_ids(self) -> None:
         dataset = RandomRecDataset(
             keys=["feat1", "feat2"],
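
Note (illustrative, not part of the patch): the sketch below shows the vectorization
pattern applied in _generate_batch above in isolation, assuming only torch. One
torch.randint call draws the per-sample lengths for a feature and a second call draws
all of its sparse IDs, instead of looping over the batch. The helper name draw_feature
and the concrete sizes are hypothetical, chosen only for this example::

    from typing import Tuple

    import torch


    def draw_feature(
        batch_size: int,
        min_ids: int,
        max_ids: int,
        hash_size: int,
        generator: torch.Generator,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Draw one length per sample in a single call, replacing a
        # per-sample Python loop.
        lengths = torch.randint(
            min_ids,
            max_ids + 1,
            (batch_size,),
            dtype=torch.int32,
            generator=generator,
        )
        # Draw every sparse ID for the feature at once; the total count is
        # the sum of the per-sample lengths.
        values = torch.randint(
            0,
            hash_size,
            (int(lengths.sum().item()),),
            generator=generator,
        )
        return lengths, values


    gen = torch.Generator().manual_seed(0)
    lengths, values = draw_feature(
        batch_size=4096, min_ids=0, max_ids=20, hash_size=100, generator=gen
    )
    assert lengths.numel() == 4096
    assert values.numel() == int(lengths.sum())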