Allow random dataloader to have ids per feature to vary between features (pytorch#182)

Summary:
Pull Request resolved: pytorch/torchrec#182

Currently the random dataloader is inflexible with respect to ids_per_feature: every table gets the same pooling factor.
This change refactors the batch generation so that the pooling factor can vary per feature.

I also found the generation script a bit hard to follow, so I restructured it, though I am not sure whether the new version is less efficient.
As a follow-up, we can draw ids_per_feature from a distribution instead of using a fixed value.
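
For illustration (this mirrors the added unit test), the new per-feature arguments can be exercised like this:

    dataset = RandomRecDataset(
        keys=["feat1", "feat2"],
        batch_size=16,
        hash_sizes=[100, 200],        # per-feature max id value
        ids_per_features=[100, 200],  # per-feature pooling factor
        num_dense=5,
    )
    batch = next(iter(dataset))
    # batch.sparse_features["feat1"] holds 100 ids per sample, "feat2" holds 200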

Reviewed By: bigning

Differential Revision: D35148294

fbshipit-source-id: d8a9b0f0472c4a0aa0977dedcbb9636bed61f6dc
YLGH authored and facebook-github-bot committed Mar 29, 2022
1 parent 60f3036 commit 0410c7d
Showing 2 changed files with 202 additions and 57 deletions.
125 changes: 68 additions & 57 deletions torchrec/datasets/random.py
@@ -8,7 +8,6 @@
from typing import Iterator, List, Optional

import torch
from pyre_extensions import none_throws
from torch.utils.data.dataset import IterableDataset
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -21,26 +20,22 @@ def __init__(
self,
keys: List[str],
batch_size: int,
hash_size: Optional[int],
hash_sizes: Optional[List[int]],
ids_per_feature: int,
hash_sizes: List[int],
ids_per_features: List[int],
num_dense: int,
manual_seed: Optional[int] = None,
num_generated_batches: int = 10,
num_batches: Optional[int] = None,
) -> None:
if (hash_size is None and hash_sizes is None) or (
hash_size is not None and hash_sizes is not None
):
raise ValueError(
"One - and only one - of hash_size or hash_sizes must be set."
)

self.keys = keys
self.keys_length: int = len(keys)
self.batch_size = batch_size
self.hash_size = hash_size
self.hash_sizes = hash_sizes
self.ids_per_feature = ids_per_feature
self.ids_per_features = ids_per_features
self.num_dense = num_dense
self.num_batches = num_batches
self.num_generated_batches = num_generated_batches

if manual_seed is not None:
self.generator = torch.Generator()
@@ -49,60 +44,49 @@ else:
else:
self.generator = None

self.iter_num = 0
self._num_ids_in_batch: int = (
self.ids_per_feature * self.keys_length * self.batch_size
)
self.max_values: Optional[torch.Tensor] = None
if hash_sizes is not None:
self.max_values: torch.Tensor = torch.tensor(
[
hash_size
for hash_size in hash_sizes
for b in range(batch_size)
for i in range(ids_per_feature)
]
)
self._generated_batches: List[Batch] = [self._generate_batch()] * 10
self._generated_batches: List[Batch] = [
self._generate_batch()
] * num_generated_batches
self.batch_index = 0

def __iter__(self) -> "_RandomRecBatch":
self.batch_index = 0
return self

def __next__(self) -> Batch:
batch = self._generated_batches[self.batch_index % len(self._generated_batches)]
if self.batch_index == self.num_batches:
raise StopIteration
if self.num_generated_batches >= 0:
batch = self._generated_batches[
self.batch_index % len(self._generated_batches)
]
else:
batch = self._generate_batch()
self.batch_index += 1
return batch

def _generate_batch(self) -> Batch:
if self.hash_sizes is None:
# pyre-ignore[28]
values = torch.randint(
high=self.hash_size,
size=(self._num_ids_in_batch,),
generator=self.generator,
)
else:
values = (
torch.rand(
self._num_ids_in_batch,

values = []
lengths = []
for key_idx, _ in enumerate(self.keys):
hash_size = self.hash_sizes[key_idx]
num_ids_in_batch = self.ids_per_features[key_idx]

values.append(
# pyre-ignore
torch.randint(
high=hash_size,
size=(num_ids_in_batch * self.batch_size,),
generator=self.generator,
)
* none_throws(self.max_values)
).type(torch.LongTensor)
sparse_features = KeyedJaggedTensor.from_offsets_sync(
)
lengths.extend([num_ids_in_batch] * self.batch_size)

sparse_features = KeyedJaggedTensor.from_lengths_sync(
keys=self.keys,
values=values,
offsets=torch.tensor(
list(
range(
0,
self._num_ids_in_batch + 1,
self.ids_per_feature,
)
),
dtype=torch.int32,
),
values=torch.cat(values),
lengths=torch.tensor(lengths, dtype=torch.int32),
)

dense_features = torch.randn(
@@ -138,10 +122,14 @@ class RandomRecDataset(IterableDataset[Batch]):
hash_size (Optional[int]): Max sparse id value. All sparse IDs will be taken
modulo this value.
hash_sizes (Optional[List[int]]): Max sparse id value per feature in keys. Each
sparse ID will be taken modulo the corresponding value from this argument.
sparse ID will be taken modulo the corresponding value from this argument. Note, if this is used, hash_size will be ignored.
ids_per_feature (int): Number of IDs per sparse feature.
ids_per_features (Optional[List[int]]): Number of IDs per sparse feature, per key. Note: if this is used, ids_per_feature will be ignored.
num_dense (int): Number of dense features.
manual_seed (int): Seed for deterministic behavior.
num_batches (Optional[int]): Number of batches to generate before raising StopIteration.
num_generated_batches (int): Number of batches to cache. If num_batches > num_generated_batches, we cycle back through the cached batches.
If this value is negative, batches will be generated on the fly.
Example::
@@ -161,19 +149,42 @@ def __init__(
batch_size: int,
hash_size: Optional[int] = 100,
hash_sizes: Optional[List[int]] = None,
ids_per_feature: int = 2,
ids_per_feature: Optional[int] = 2,
ids_per_features: Optional[List[int]] = None,
num_dense: int = 50,
manual_seed: Optional[int] = None,
num_batches: Optional[int] = None,
num_generated_batches: int = 10,
) -> None:
super().__init__()

if hash_sizes is None:
hash_size = hash_size or 100
hash_sizes = [hash_size] * len(keys)

assert hash_sizes is not None
assert len(hash_sizes) == len(
keys
), "length of hash_sizes must be equal to the number of keys"

if ids_per_features is None:
ids_per_feature = ids_per_feature or 2
ids_per_features = [ids_per_feature] * len(keys)

assert ids_per_features is not None
assert len(ids_per_features) == len(
keys
), "length of ids_per_features must be equal to the number of keys"

self.batch_generator = _RandomRecBatch(
keys=keys,
batch_size=batch_size,
hash_size=hash_size,
hash_sizes=hash_sizes,
ids_per_feature=ids_per_feature,
ids_per_features=ids_per_features,
num_dense=num_dense,
manual_seed=manual_seed,
num_batches=num_batches,
num_generated_batches=num_generated_batches,
)

def __iter__(self) -> Iterator[Batch]:
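
Below is a minimal sketch, assuming torch and torchrec are installed, of how the refactored _generate_batch assembles per-feature id counts into a KeyedJaggedTensor via from_lengths_sync (the values and variable names are illustrative):

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    keys = ["feat1", "feat2"]
    batch_size = 4
    hash_sizes = [100, 200]       # per-feature max id value
    ids_per_features = [2, 3]     # per-feature pooling factor

    values = []
    lengths = []
    for key_idx, _ in enumerate(keys):
        # each feature draws its own number of ids for every sample in the batch
        values.append(
            torch.randint(
                high=hash_sizes[key_idx],
                size=(ids_per_features[key_idx] * batch_size,),
            )
        )
        lengths.extend([ids_per_features[key_idx]] * batch_size)

    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=keys,
        values=torch.cat(values),
        lengths=torch.tensor(lengths, dtype=torch.int32),
    )
    # kjt["feat1"].to_dense() -> 4 rows of 2 ids; kjt["feat2"].to_dense() -> 4 rows of 3 ids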
134 changes: 134 additions & 0 deletions torchrec/datasets/tests/test_random.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from testslide import TestCase
from torchrec.datasets.random import RandomRecDataset


class RandomDataLoader(TestCase):
def test_hash_per_feature_ids_per_feature(self) -> None:
dataset = RandomRecDataset(
keys=["feat1", "feat2"],
batch_size=16,
hash_sizes=[100, 200],
ids_per_features=[100, 200],
num_dense=5,
)

example = next(iter(dataset))
dense = example.dense_features
self.assertEqual(dense.shape, (16, 5))

labels = example.labels
self.assertEqual(labels.shape, (16,))

sparse = example.sparse_features
self.assertEqual(sparse.stride(), 16)

feat1 = sparse["feat1"].to_dense()
self.assertEqual(len(feat1), 16)
for batch in feat1:
self.assertEqual(len(batch), 100)

feat2 = sparse["feat2"].to_dense()
self.assertEqual(len(feat2), 16)
for batch in feat2:
self.assertEqual(len(batch), 200)

def test_hash_ids_per_feature(self) -> None:
dataset = RandomRecDataset(
keys=["feat1", "feat2"],
batch_size=16,
hash_size=100,
ids_per_features=[100, 200],
num_dense=5,
)

example = next(iter(dataset))
dense = example.dense_features
self.assertEqual(dense.shape, (16, 5))

labels = example.labels
self.assertEqual(labels.shape, (16,))

sparse = example.sparse_features
self.assertEqual(sparse.stride(), 16)

feat1 = sparse["feat1"].to_dense()
self.assertEqual(len(feat1), 16)
for batch in feat1:
self.assertEqual(len(batch), 100)

feat2 = sparse["feat2"].to_dense()
self.assertEqual(len(feat2), 16)
for batch in feat2:
self.assertEqual(len(batch), 200)

def test_hash_ids(self) -> None:
dataset = RandomRecDataset(
keys=["feat1", "feat2"],
batch_size=16,
hash_size=100,
ids_per_feature=50,
num_dense=5,
)

example = next(iter(dataset))
dense = example.dense_features
self.assertEqual(dense.shape, (16, 5))

labels = example.labels
self.assertEqual(labels.shape, (16,))

sparse = example.sparse_features
self.assertEqual(sparse.stride(), 16)

feat1 = sparse["feat1"].to_dense()
self.assertEqual(len(feat1), 16)
for batch in feat1:
self.assertEqual(len(batch), 50)

feat2 = sparse["feat2"].to_dense()
self.assertEqual(len(feat2), 16)
for batch in feat2:
self.assertEqual(len(batch), 50)

def test_on_fly_batch_generation(self) -> None:
dataset = RandomRecDataset(
keys=["feat1", "feat2"],
batch_size=16,
hash_size=100,
ids_per_feature=50,
num_dense=5,
num_generated_batches=-1,
)

it = iter(dataset)

example = next(it)
example = next(it)
example = next(it)
example = next(it)

dense = example.dense_features
self.assertEqual(dense.shape, (16, 5))

labels = example.labels
self.assertEqual(labels.shape, (16,))

sparse = example.sparse_features
self.assertEqual(sparse.stride(), 16)

feat1 = sparse["feat1"].to_dense()
self.assertEqual(len(feat1), 16)
for batch in feat1:
self.assertEqual(len(batch), 50)

feat2 = sparse["feat2"].to_dense()
self.assertEqual(len(feat2), 16)
for batch in feat2:
self.assertEqual(len(batch), 50)
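
A hypothetical usage sketch combining the new options: num_generated_batches=-1 generates every batch on the fly instead of caching, and num_batches bounds iteration so the loop below stops after five batches (the argument values here are illustrative):

    from torchrec.datasets.random import RandomRecDataset

    dataset = RandomRecDataset(
        keys=["feat1", "feat2"],
        batch_size=16,
        hash_sizes=[100, 200],
        ids_per_features=[20, 40],
        num_dense=5,
        num_batches=5,
        num_generated_batches=-1,
    )
    for batch in dataset:
        # batch.dense_features has shape (16, 5); batch.sparse_features is a KeyedJaggedTensor
        pass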
