forked from pytorch/torcharrow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow random dataloader to have ids per feature that vary between features (pytorch#182)
Summary: Pull Request resolved: pytorch/torchrec#182 Currently the random dataloader is inflexible in ids_per_feature because every table will have the same pooling factor. Adding a bit of a refactor to change that. Also I found the generation script a bit hard to follow, so changed it - but not sure if this is less efficient. As a follow-up, we can allow a distribution that we draw from to generate ids_per_feature (instead of a fixed value). Reviewed By: bigning Differential Revision: D35148294 fbshipit-source-id: d8a9b0f0472c4a0aa0977dedcbb9636bed61f6dc
- Loading branch information
1 parent
60f3036
commit 0410c7d
Showing
2 changed files
with
202 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from testslide import TestCase | ||
from torchrec.datasets.random import RandomRecDataset | ||
|
||
|
||
class RandomDataLoader(TestCase):
    """Unit tests for ``RandomRecDataset``.

    Covers the supported combinations of hash-size and ids-per-feature
    configuration (per-feature lists vs. single shared values) plus
    on-the-fly batch generation (``num_generated_batches=-1``).
    """

    def _assert_batch(self, example, ids_per_feature) -> None:
        """Assert the common invariants of one generated batch.

        Args:
            example: a batch yielded by ``RandomRecDataset``; expected to
                expose ``dense_features``, ``labels`` and ``sparse_features``.
            ids_per_feature: mapping of feature name -> expected number of
                sparse ids per sample for that feature.
        """
        # All tests in this class use the same batch/dense configuration.
        batch_size = 16
        num_dense = 5

        self.assertEqual(example.dense_features.shape, (batch_size, num_dense))
        self.assertEqual(example.labels.shape, (batch_size,))

        sparse = example.sparse_features
        # NOTE(review): stride() is presumably the KJT batch size — matches
        # the batch_size passed to the dataset constructor in every test.
        self.assertEqual(sparse.stride(), batch_size)

        for key, expected_len in ids_per_feature.items():
            feat = sparse[key].to_dense()
            self.assertEqual(len(feat), batch_size)
            for sample in feat:
                self.assertEqual(len(sample), expected_len)

    def test_hash_per_feature_ids_per_feature(self) -> None:
        # Both hash size and ids-per-feature vary per table.
        dataset = RandomRecDataset(
            keys=["feat1", "feat2"],
            batch_size=16,
            hash_sizes=[100, 200],
            ids_per_features=[100, 200],
            num_dense=5,
        )
        self._assert_batch(next(iter(dataset)), {"feat1": 100, "feat2": 200})

    def test_hash_ids_per_feature(self) -> None:
        # Single shared hash size; ids-per-feature still varies per table.
        dataset = RandomRecDataset(
            keys=["feat1", "feat2"],
            batch_size=16,
            hash_size=100,
            ids_per_features=[100, 200],
            num_dense=5,
        )
        self._assert_batch(next(iter(dataset)), {"feat1": 100, "feat2": 200})

    def test_hash_ids(self) -> None:
        # Single shared hash size and a single shared ids-per-feature.
        dataset = RandomRecDataset(
            keys=["feat1", "feat2"],
            batch_size=16,
            hash_size=100,
            ids_per_feature=50,
            num_dense=5,
        )
        self._assert_batch(next(iter(dataset)), {"feat1": 50, "feat2": 50})

    def test_on_fly_batch_generation(self) -> None:
        # num_generated_batches=-1 requests batches generated lazily rather
        # than pre-materialized; drawing several batches exercises that path.
        dataset = RandomRecDataset(
            keys=["feat1", "feat2"],
            batch_size=16,
            hash_size=100,
            ids_per_feature=50,
            num_dense=5,
            num_generated_batches=-1,
        )

        it = iter(dataset)
        example = next(it)
        example = next(it)
        example = next(it)
        example = next(it)

        self._assert_batch(example, {"feat1": 50, "feat2": 50})