Improve Random Batch generation speed (#1805)
Summary:
Pull Request resolved: #1805

Remove the per-sample for loop from batch generation; it was very slow at larger test batch sizes (4K+). Prior to this change, generation could take over two minutes.

Reviewed By: henrylhtsang

Differential Revision: D55052074

Privacy Context Container: 1203980333745195

fbshipit-source-id: 858246959832360d08063e02b95370703dd12dd7
dstaay-fb authored and facebook-github-bot committed Mar 20, 2024
1 parent 347d760 commit d9fbac9
Showing 2 changed files with 45 additions and 31 deletions.
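
For context, a minimal standalone sketch of the idea behind the change (not part of the commit; the batch size, id range, and hash size here are made up) contrasting the removed per-sample loop with the vectorized draw:

import torch

generator = torch.Generator().manual_seed(0)
batch_size, min_ids, max_ids, hash_size = 4096, 0, 20, 100

# Before: one torch.randint call per sample, so O(batch_size) tiny RNG
# calls plus a Python-level .item() on every iteration.
lengths_loop, values_loop = [], []
for _ in range(batch_size):
    n = int(torch.randint(min_ids, max_ids + 1, (), generator=generator).item())
    values_loop.append(torch.randint(hash_size, (n,), generator=generator))
    lengths_loop.append(n)

# After: two torch.randint calls per feature key, independent of batch size.
lengths = torch.randint(
    min_ids, max_ids + 1, (batch_size,), dtype=torch.int32, generator=generator
)
values = torch.randint(0, hash_size, (int(lengths.sum()),), generator=generator)

Both variants produce a per-sample length vector and a flat id tensor; only the number of RNG calls changes.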
57 changes: 26 additions & 31 deletions torchrec/datasets/random.py
@@ -81,30 +81,25 @@ def _generate_batch(self) -> Batch:
         lengths = []
         for key_idx, _ in enumerate(self.keys):
             hash_size = self.hash_sizes[key_idx]
-            min_num_ids_in_batch = self.min_ids_per_features[key_idx]
-            max_num_ids_in_batch = self.ids_per_features[key_idx]
-            for _ in range(self.batch_size):
-                num_ids_in_batch = int(
-                    torch.randint(
-                        low=min_num_ids_in_batch,
-                        high=max_num_ids_in_batch + 1,
-                        size=(),
-                        generator=self.generator,
-                    ).item()
-                )
-                values.append(
-                    torch.randint(
-                        high=hash_size,
-                        size=(num_ids_in_batch,),
-                        generator=self.generator,
-                    )
-                )
-                lengths.extend([num_ids_in_batch])
+            min_num_ids = self.min_ids_per_features[key_idx]
+            max_num_ids = self.ids_per_features[key_idx]
+            length = torch.randint(
+                min_num_ids,
+                max_num_ids + 1,
+                (self.batch_size,),
+                dtype=torch.int32,
+                generator=self.generator,
+            )
+            value = torch.randint(
+                0, hash_size, (cast(int, length.sum()),), generator=self.generator
+            )
+            lengths.append(length)
+            values.append(value)

         sparse_features = KeyedJaggedTensor.from_lengths_sync(
             keys=self.keys,
             values=torch.cat(values),
-            lengths=torch.tensor(lengths, dtype=torch.int32),
+            lengths=torch.cat(lengths),
         )

         dense_features = torch.randn(
@@ -140,14 +135,15 @@ class RandomRecDataset(IterableDataset[Batch]):
             modulo this value.
         hash_sizes (Optional[List[int]]): Max sparse id value per feature in keys. Each
             sparse ID will be taken modulo the corresponding value from this argument. Note, if this is used, hash_size will be ignored.
-        ids_per_feature (int): Number of IDs per sparse feature.
-        ids_per_features (int): Number of IDs per sparse feature in each key. Note, if this is used, ids_per_feature will be ignored.
+        ids_per_feature (Optional[int]): Number of IDs per sparse feature per sample.
+        ids_per_features (Optional[List[int]]): Number of IDs per sparse feature per sample in each key. Note, if this is used, ids_per_feature will be ignored.
         num_dense (int): Number of dense features.
         manual_seed (int): Seed for deterministic behavior.
         num_batches: (Optional[int]): Num batches to generate before raising StopIteration
         num_generated_batches int: Num batches to cache. If num_batches > num_generated batches, then we will cycle to the first generated batch.
             If this value is negative, batches will be generated on the fly.
-        min_ids_per_feature (int): Minimum number of IDs per features.
+        min_ids_per_feature (Optional[int]): Minimum number of IDs per features.
+        min_ids_per_features (Optional[List[int]]): Minimum number of IDs per sparse feature per sample in each key. Note, if this is used, min_ids_per_feature will be ignored.

     Example::
@@ -165,9 +161,9 @@ def __init__(
         self,
         keys: List[str],
         batch_size: int,
-        hash_size: Optional[int] = 100,
+        hash_size: Optional[int] = None,
         hash_sizes: Optional[List[int]] = None,
-        ids_per_feature: Optional[int] = 2,
+        ids_per_feature: Optional[int] = None,
         ids_per_features: Optional[List[int]] = None,
         num_dense: int = 50,
         manual_seed: Optional[int] = None,
@@ -179,17 +175,15 @@ def __init__(
         super().__init__()

         if hash_sizes is None:
-            hash_size = hash_size or 100
-            hash_sizes = [hash_size] * len(keys)
+            hash_sizes = [hash_size if hash_size else 100] * len(keys)

         assert hash_sizes is not None
         assert len(hash_sizes) == len(
             keys
         ), "length of hash_sizes must be equal to the number of keys"

         if ids_per_features is None:
-            ids_per_feature = ids_per_feature or 2
-            ids_per_features = [ids_per_feature] * len(keys)
+            ids_per_features = [ids_per_feature if ids_per_feature else 2] * len(keys)

         assert ids_per_features is not None

@@ -199,8 +193,9 @@ def __init__(
                 if min_ids_per_feature is not None
                 else ids_per_feature
             )
-            assert min_ids_per_feature is not None
-            min_ids_per_features = [min_ids_per_feature] * len(keys)
+            min_ids_per_features = [
+                min_ids_per_feature if min_ids_per_feature else 0
+            ] * len(keys)

         assert len(ids_per_features) == len(
             keys
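
For reference, a small hypothetical example (keys and sizes invented) of how per-key length tensors like the ones above feed KeyedJaggedTensor.from_lengths_sync; lengths are laid out key-major, i.e. all of feat1's per-sample lengths come before feat2's:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

keys = ["feat1", "feat2"]  # batch size 3 per key
# First 3 entries are feat1's lengths, next 3 are feat2's.
lengths = torch.tensor([2, 0, 1, 1, 1, 3], dtype=torch.int32)
values = torch.randint(0, 100, (int(lengths.sum()),))

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=keys, values=values, lengths=lengths
)
print(kjt["feat1"].lengths())  # tensor([2, 0, 1], dtype=torch.int32)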
19 changes: 19 additions & 0 deletions torchrec/datasets/tests/test_random.py
Expand Up @@ -10,6 +10,8 @@
 import itertools
 import unittest

+from hypothesis import given, settings, strategies as st
+
 from torchrec.datasets.random import RandomRecDataset


@@ -74,6 +76,23 @@ def test_hash_ids_per_feature(self) -> None:
         for batch in feat2:
             self.assertEqual(len(batch), 200)

+    # pyre-ignore
+    @given(
+        batch_size=st.sampled_from([2048, 4096, 8192]),
+    )
+    @settings(max_examples=3, deadline=5000)  # expected runtime <=500ms
+    def test_large_batch_size_deadline(self, batch_size: int) -> None:
+        dataset = RandomRecDataset(
+            keys=["feat1", "feat2"],
+            batch_size=batch_size,
+            ids_per_features=[10, 20],
+            hash_size=100,
+            num_dense=5,
+        )
+        iterator = iter(dataset)
+        for _ in range(5):
+            next(iterator)
+
     def test_hash_ids(self) -> None:
         dataset = RandomRecDataset(
             keys=["feat1", "feat2"],
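
Outside the test suite, a rough (hypothetical) way to eyeball the speedup; per the docstring above, a negative num_generated_batches makes every batch generate on the fly rather than serve from cache:

import time
from torchrec.datasets.random import RandomRecDataset

dataset = RandomRecDataset(
    keys=["feat1", "feat2"],
    batch_size=8192,
    ids_per_features=[10, 20],
    hash_size=100,
    num_dense=5,
    num_generated_batches=-1,  # generate on the fly instead of caching
)
iterator = iter(dataset)
start = time.perf_counter()
for _ in range(5):
    next(iterator)
print(f"5 batches of 8192 in {time.perf_counter() - start:.3f}s")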
