From 0410c7d2382335042d72ee9c8db9c238bb7badd6 Mon Sep 17 00:00:00 2001
From: Ying Liu
Date: Mon, 28 Mar 2022 18:21:23 -0700
Subject: [PATCH] Allow random dataloader ids per feature to vary between
 features (#182)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/182

Currently the random dataloader is inflexible about ids_per_feature: every
table gets the same pooling factor. This diff refactors batch generation so
that the pooling factor can vary per feature. I also found the generation
script hard to follow, so I restructured it, though I have not verified
whether the new version is less efficient. (A brief usage sketch of the new
arguments appears after the diff.)

As a follow-up, we can draw ids_per_features from a distribution instead of
using fixed values.

Reviewed By: bigning

Differential Revision: D35148294

fbshipit-source-id: d8a9b0f0472c4a0aa0977dedcbb9636bed61f6dc
---
 torchrec/datasets/random.py            | 125 ++++++++++++-----------
 torchrec/datasets/tests/test_random.py | 134 +++++++++++++++++++++++++
 2 files changed, 202 insertions(+), 57 deletions(-)
 create mode 100644 torchrec/datasets/tests/test_random.py

diff --git a/torchrec/datasets/random.py b/torchrec/datasets/random.py
index 56ce37212..b56280881 100644
--- a/torchrec/datasets/random.py
+++ b/torchrec/datasets/random.py
@@ -8,7 +8,6 @@
 from typing import Iterator, List, Optional

 import torch
-from pyre_extensions import none_throws
 from torch.utils.data.dataset import IterableDataset
 from torchrec.datasets.utils import Batch
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -21,26 +20,22 @@ def __init__(
         self,
         keys: List[str],
         batch_size: int,
-        hash_size: Optional[int],
-        hash_sizes: Optional[List[int]],
-        ids_per_feature: int,
+        hash_sizes: List[int],
+        ids_per_features: List[int],
         num_dense: int,
         manual_seed: Optional[int] = None,
+        num_generated_batches: int = 10,
+        num_batches: Optional[int] = None,
     ) -> None:
-        if (hash_size is None and hash_sizes is None) or (
-            hash_size is not None and hash_sizes is not None
-        ):
-            raise ValueError(
-                "One - and only one - of hash_size or hash_sizes must be set."
-            )
         self.keys = keys
         self.keys_length: int = len(keys)
         self.batch_size = batch_size
-        self.hash_size = hash_size
         self.hash_sizes = hash_sizes
-        self.ids_per_feature = ids_per_feature
+        self.ids_per_features = ids_per_features
         self.num_dense = num_dense
+        self.num_batches = num_batches
+        self.num_generated_batches = num_generated_batches

         if manual_seed is not None:
             self.generator = torch.Generator()
@@ -49,60 +44,49 @@ def __init__(
         else:
             self.generator = None

-        self.iter_num = 0
-        self._num_ids_in_batch: int = (
-            self.ids_per_feature * self.keys_length * self.batch_size
-        )
-        self.max_values: Optional[torch.Tensor] = None
-        if hash_sizes is not None:
-            self.max_values: torch.Tensor = torch.tensor(
-                [
-                    hash_size
-                    for hash_size in hash_sizes
-                    for b in range(batch_size)
-                    for i in range(ids_per_feature)
-                ]
-            )
-        self._generated_batches: List[Batch] = [self._generate_batch()] * 10
+        self._generated_batches: List[Batch] = [
+            self._generate_batch()
+        ] * num_generated_batches
        self.batch_index = 0

     def __iter__(self) -> "_RandomRecBatch":
+        self.batch_index = 0
         return self

     def __next__(self) -> Batch:
-        batch = self._generated_batches[self.batch_index % len(self._generated_batches)]
+        if self.batch_index == self.num_batches:
+            raise StopIteration
+        if self.num_generated_batches >= 0:
+            batch = self._generated_batches[
+                self.batch_index % len(self._generated_batches)
+            ]
+        else:
+            batch = self._generate_batch()
         self.batch_index += 1
         return batch

     def _generate_batch(self) -> Batch:
-        if self.hash_sizes is None:
-            # pyre-ignore[28]
-            values = torch.randint(
-                high=self.hash_size,
-                size=(self._num_ids_in_batch,),
-                generator=self.generator,
-            )
-        else:
-            values = (
-                torch.rand(
-                    self._num_ids_in_batch,
+
+        values = []
+        lengths = []
+        for key_idx, _ in enumerate(self.keys):
+            hash_size = self.hash_sizes[key_idx]
+            num_ids_in_batch = self.ids_per_features[key_idx]
+
+            values.append(
+                # pyre-ignore
+                torch.randint(
+                    high=hash_size,
+                    size=(num_ids_in_batch * self.batch_size,),
                     generator=self.generator,
                 )
-                * none_throws(self.max_values)
-            ).type(torch.LongTensor)
-        sparse_features = KeyedJaggedTensor.from_offsets_sync(
+            )
+            lengths.extend([num_ids_in_batch] * self.batch_size)
+
+        sparse_features = KeyedJaggedTensor.from_lengths_sync(
             keys=self.keys,
-            values=values,
-            offsets=torch.tensor(
-                list(
-                    range(
-                        0,
-                        self._num_ids_in_batch + 1,
-                        self.ids_per_feature,
-                    )
-                ),
-                dtype=torch.int32,
-            ),
+            values=torch.cat(values),
+            lengths=torch.tensor(lengths, dtype=torch.int32),
         )

         dense_features = torch.randn(
@@ -138,10 +122,14 @@ class RandomRecDataset(IterableDataset[Batch]):
         hash_size (Optional[int]): Max sparse id value. All sparse IDs will be taken
             modulo this value.
         hash_sizes (Optional[List[int]]): Max sparse id value per feature in keys. Each
-            sparse ID will be taken modulo the corresponding value from this argument.
+            sparse ID will be taken modulo the corresponding value from this argument. Note: if this is used, hash_size will be ignored.
         ids_per_feature (int): Number of IDs per sparse feature.
+        ids_per_features (Optional[List[int]]): Number of IDs per sparse feature in each key. Note: if this is used, ids_per_feature will be ignored.
         num_dense (int): Number of dense features.
         manual_seed (int): Seed for deterministic behavior.
+        num_batches (Optional[int]): Number of batches to generate before raising StopIteration. If None, batches are generated indefinitely.
+        num_generated_batches (int): Number of batches to cache. If num_batches > num_generated_batches, the cached batches are cycled through.
+ If this value is negative, batches will be generated on the fly. Example:: @@ -161,19 +149,42 @@ def __init__( batch_size: int, hash_size: Optional[int] = 100, hash_sizes: Optional[List[int]] = None, - ids_per_feature: int = 2, + ids_per_feature: Optional[int] = 2, + ids_per_features: Optional[List[int]] = None, num_dense: int = 50, manual_seed: Optional[int] = None, + num_batches: Optional[int] = None, + num_generated_batches: int = 10, ) -> None: super().__init__() + + if hash_sizes is None: + hash_size = hash_size or 100 + hash_sizes = [hash_size] * len(keys) + + assert hash_sizes is not None + assert len(hash_sizes) == len( + keys + ), "length of hash_sizes must be equal to the number of keys" + + if ids_per_features is None: + ids_per_feature = ids_per_feature or 2 + ids_per_features = [ids_per_feature] * len(keys) + + assert ids_per_features is not None + assert len(ids_per_features) == len( + keys + ), "length of ids_per_features must be equal to the number of keys" + self.batch_generator = _RandomRecBatch( keys=keys, batch_size=batch_size, - hash_size=hash_size, hash_sizes=hash_sizes, - ids_per_feature=ids_per_feature, + ids_per_features=ids_per_features, num_dense=num_dense, manual_seed=manual_seed, + num_batches=num_batches, + num_generated_batches=num_generated_batches, ) def __iter__(self) -> Iterator[Batch]: diff --git a/torchrec/datasets/tests/test_random.py b/torchrec/datasets/tests/test_random.py new file mode 100644 index 000000000..3be6fee88 --- /dev/null +++ b/torchrec/datasets/tests/test_random.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from testslide import TestCase +from torchrec.datasets.random import RandomRecDataset + + +class RandomDataLoader(TestCase): + def test_hash_per_feature_ids_per_feature(self) -> None: + dataset = RandomRecDataset( + keys=["feat1", "feat2"], + batch_size=16, + hash_sizes=[100, 200], + ids_per_features=[100, 200], + num_dense=5, + ) + + example = next(iter(dataset)) + dense = example.dense_features + self.assertEqual(dense.shape, (16, 5)) + + labels = example.labels + self.assertEqual(labels.shape, (16,)) + + sparse = example.sparse_features + self.assertEqual(sparse.stride(), 16) + + feat1 = sparse["feat1"].to_dense() + self.assertEqual(len(feat1), 16) + for batch in feat1: + self.assertEqual(len(batch), 100) + + feat2 = sparse["feat2"].to_dense() + self.assertEqual(len(feat2), 16) + for batch in feat2: + self.assertEqual(len(batch), 200) + + def test_hash_ids_per_feature(self) -> None: + dataset = RandomRecDataset( + keys=["feat1", "feat2"], + batch_size=16, + hash_size=100, + ids_per_features=[100, 200], + num_dense=5, + ) + + example = next(iter(dataset)) + dense = example.dense_features + self.assertEqual(dense.shape, (16, 5)) + + labels = example.labels + self.assertEqual(labels.shape, (16,)) + + sparse = example.sparse_features + self.assertEqual(sparse.stride(), 16) + + feat1 = sparse["feat1"].to_dense() + self.assertEqual(len(feat1), 16) + for batch in feat1: + self.assertEqual(len(batch), 100) + + feat2 = sparse["feat2"].to_dense() + self.assertEqual(len(feat2), 16) + for batch in feat2: + self.assertEqual(len(batch), 200) + + def test_hash_ids(self) -> None: + dataset = RandomRecDataset( + keys=["feat1", "feat2"], + batch_size=16, + hash_size=100, + ids_per_feature=50, + num_dense=5, + ) + + example = next(iter(dataset)) + dense = example.dense_features + self.assertEqual(dense.shape, (16, 5)) + + labels = example.labels + self.assertEqual(labels.shape, (16,)) + + sparse = example.sparse_features + self.assertEqual(sparse.stride(), 16) + + feat1 = sparse["feat1"].to_dense() + self.assertEqual(len(feat1), 16) + for batch in feat1: + self.assertEqual(len(batch), 50) + + feat2 = sparse["feat2"].to_dense() + self.assertEqual(len(feat2), 16) + for batch in feat2: + self.assertEqual(len(batch), 50) + + def test_on_fly_batch_generation(self) -> None: + dataset = RandomRecDataset( + keys=["feat1", "feat2"], + batch_size=16, + hash_size=100, + ids_per_feature=50, + num_dense=5, + num_generated_batches=-1, + ) + + it = iter(dataset) + + example = next(it) + example = next(it) + example = next(it) + example = next(it) + + dense = example.dense_features + self.assertEqual(dense.shape, (16, 5)) + + labels = example.labels + self.assertEqual(labels.shape, (16,)) + + sparse = example.sparse_features + self.assertEqual(sparse.stride(), 16) + + feat1 = sparse["feat1"].to_dense() + self.assertEqual(len(feat1), 16) + for batch in feat1: + self.assertEqual(len(batch), 50) + + feat2 = sparse["feat2"].to_dense() + self.assertEqual(len(feat2), 16) + for batch in feat2: + self.assertEqual(len(batch), 50)
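
Usage sketch (not part of the patch; the keys, sizes, and counts below are
made-up values): this shows how the new per-feature arguments introduced in
this diff fit together, assuming the patched RandomRecDataset above.

    from torchrec.datasets.random import RandomRecDataset

    # Each key gets its own hash size and pooling factor. When hash_sizes and
    # ids_per_features are passed, the scalar hash_size and ids_per_feature
    # fallbacks are ignored.
    dataset = RandomRecDataset(
        keys=["feat1", "feat2"],
        batch_size=16,
        hash_sizes=[100, 200],
        ids_per_features=[5, 10],
        num_dense=8,
        manual_seed=42,            # deterministic batch generation
        num_batches=100,           # raise StopIteration after 100 batches
        num_generated_batches=10,  # size of the cached-batch list that is
                                   # cycled; -1 generates every batch on the fly
    )

    for batch in dataset:
        # batch.dense_features: float tensor of shape (16, 8)
        # batch.sparse_features: KeyedJaggedTensor with 5 ids per sample for
        #     "feat1" and 10 ids per sample for "feat2"
        # batch.labels: tensor of shape (16,)
        pass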