
Add SamplerWrapper node #1357

Merged: 48 commits merged into main from andrewkh/add-sampler-wrapper on Nov 8, 2024

Commits (48)
f8b59b5  half-working commit (andrewkho, Nov 2, 2024)
c1dabe2  half-working commit (andrewkho, Nov 2, 2024)
a04521a  basic fast-forward in iterablewrapper (andrewkho, Nov 4, 2024)
0853a3a  partial: update adapters (andrewkho, Nov 4, 2024)
5d701bf  partial: update batcher (andrewkho, Nov 4, 2024)
d3723c6  simplify, prefetcher (andrewkho, Nov 5, 2024)
a4cf475  simplify, prefetcher (andrewkho, Nov 5, 2024)
d781f6d  finish map tests (andrewkho, Nov 5, 2024)
376660b  fix pin_memory (andrewkho, Nov 5, 2024)
8e1f7dc  fix mypy (andrewkho, Nov 5, 2024)
336859e  remove lock and rely on deque consistency (andrewkho, Nov 5, 2024)
1ed42eb  test for hang (andrewkho, Nov 5, 2024)
8d2cd9a  test for hang (andrewkho, Nov 5, 2024)
d871e7f  test for hang (andrewkho, Nov 5, 2024)
f5f65ae  test for hang (andrewkho, Nov 5, 2024)
ae4062d  test for hang (andrewkho, Nov 5, 2024)
95d7e59  test for hang (andrewkho, Nov 5, 2024)
b7f0e68  test for hang (andrewkho, Nov 5, 2024)
83fe9a1  test for hang (andrewkho, Nov 5, 2024)
6511669  test for hang (andrewkho, Nov 5, 2024)
13a8266  test for hang (andrewkho, Nov 5, 2024)
005bbbf  test for hang (andrewkho, Nov 6, 2024)
8280a11  test for hang (andrewkho, Nov 6, 2024)
6583dec  test for hang (andrewkho, Nov 6, 2024)
0096298  test for hang (andrewkho, Nov 6, 2024)
2bd099b  test for hang (andrewkho, Nov 6, 2024)
45a2796  test for hang (andrewkho, Nov 6, 2024)
c080d27  test for hang (andrewkho, Nov 6, 2024)
6692e5a  test for hang (andrewkho, Nov 6, 2024)
bff18db  test for hang (andrewkho, Nov 6, 2024)
149e56a  test for hang (andrewkho, Nov 6, 2024)
1e44014  test for hang (andrewkho, Nov 6, 2024)
270f99e  test for hang (andrewkho, Nov 6, 2024)
2a9af36  test for hang (andrewkho, Nov 6, 2024)
e89aef1  revert ci changes (andrewkho, Nov 6, 2024)
b9fb830  format (andrewkho, Nov 6, 2024)
2dfc975  add logging and disable daemon (andrewkho, Nov 6, 2024)
5b3f09a  Revert "add logging and disable daemon" (andrewkho, Nov 6, 2024)
6850b60  switch to pytorch TestCase and add test for snapshot_store (andrewkho, Nov 6, 2024)
7a1278f  update state test to handle end of epoch (andrewkho, Nov 7, 2024)
b87147a  add sampler adapter with epoch wrapping (andrewkho, Nov 7, 2024)
7337555  fix adapter test (andrewkho, Nov 7, 2024)
7c64b46  Merge branch 'andrewkh/add-state-management' into andrewkh/add-sample… (andrewkho, Nov 7, 2024)
b5eec07  fix mypy (andrewkho, Nov 7, 2024)
0c14215  Merge branch 'andrewkh/add-state-management' into andrewkh/add-sample… (andrewkho, Nov 7, 2024)
6b1775e  fix mypy (andrewkho, Nov 7, 2024)
f4c5a1e  add initial_epoch arg (andrewkho, Nov 7, 2024)
7e6e948  Merge branch 'main' into andrewkh/add-sampler-wrapper (andrewkho, Nov 7, 2024)
test/nodes/test_adapters.py (60 changes: 58 additions, 2 deletions)

@@ -9,8 +9,8 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase

-from torch.utils.data import RandomSampler
-from torchdata.nodes.adapters import IterableWrapper, MapStyleWrapper
+from torch.utils.data import DistributedSampler, RandomSampler
+from torchdata.nodes.adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper

 from torchdata.nodes.types import Stateful

@@ -134,3 +134,59 @@ def test_save_load_state_stateful(self, midpoint: int):
         n = 20
         node = MapStyleWrapper(DummyMapDataset(n), sampler=_StatefulRange(n))
         run_test_save_load_state(self, node, midpoint)


+class TestSamplerWrapper(TestCase):
+    def test_sampler_wrapper(self):
+        n = 20
+        ds = DummyMapDataset(n)
+
+        node = SamplerWrapper(sampler=RandomSampler(ds))
+
+        results = []
+        for epoch in range(2):
+            result = list(node)
+            results.append(result)
+            self.assertEqual(node._epoch, epoch)
+            self.assertEqual(len(result), n)
+            self.assertEqual(set(result), set(range(n)))
+
+        self.assertNotEqual(results[0], results[1])
+
+    def test_distributed_sampler(self):
+        # DistributedSampler has a set_epoch method
+        n = 40
+        ds = DummyMapDataset(n)
+
+        sampler = DistributedSampler(ds, rank=1, num_replicas=2)
+        exp = []
+        for epoch in range(4):
+            sampler.set_epoch(epoch)
+            exp.append(list(sampler))
+
+        node = SamplerWrapper(sampler=sampler)
+
+        for epoch in range(4):
+            result = list(node)
+            self.assertEqual(result, exp[epoch])
+
+    @parameterized.expand([0, 7])
+    def test_save_load_state(self, midpoint: int):
+        n = 20
+        ds = DummyMapDataset(n)
+        sampler = DistributedSampler(ds, rank=1, num_replicas=2)
+        node = SamplerWrapper(sampler=sampler)
+        run_test_save_load_state(self, node, midpoint)
+
+    @parameterized.expand([0, 7])
+    def test_save_load_state_with_updater(self, midpoint: int):
+        n = 20
+        ds = DummyMapDataset(n)
+        initial_epoch = 2
+
+        def epoch_updater(epoch):
+            return epoch + 5
+
+        sampler = DistributedSampler(ds, rank=1, num_replicas=2)
+        node = SamplerWrapper(sampler=sampler, initial_epoch=initial_epoch, epoch_updater=epoch_updater)
+        run_test_save_load_state(self, node, midpoint)
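
The tests above sketch the intended usage. As a standalone illustration (not taken from the PR; the toy dict `data` is hypothetical, while `MapStyleWrapper`, `SamplerWrapper`, and iterating a node with `list()` are exactly what the tests exercise), a minimal pipeline might look like:

    from torch.utils.data import RandomSampler
    from torchdata.nodes.adapters import MapStyleWrapper, SamplerWrapper

    data = {i: i * i for i in range(10)}  # any Mapping[K, T] with __getitem__

    # SamplerWrapper drives the indices; MapStyleWrapper composes it with
    # the mapping's __getitem__ (via Mapper) into a single node.
    node = MapStyleWrapper(map_dataset=data, sampler=RandomSampler(range(10)))

    first_epoch = list(node)   # 10 squared values, in sampler order
    second_epoch = list(node)  # re-iterating advances the wrapped epoch counter

Re-iterating the node is what triggers the epoch_updater hook, which is why the first test asserts node._epoch increments across epochs.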
torchdata/nodes/adapters.py (83 changes: 78 additions, 5 deletions)

@@ -5,7 +5,9 @@
 # LICENSE file in the root directory of this source tree.


-from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, TypeVar
+from typing import Any, Callable, Dict, Iterable, Iterator, Mapping, Optional, TypeVar
+
+from torch.utils.data import Sampler

 from torchdata.nodes.base_node import BaseNode, T

@@ -60,13 +62,84 @@ def get_state(self) -> Dict[str, Any]:
         return state_dict


-def MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Iterable[K]) -> BaseNode[T]:
+def MapStyleWrapper(map_dataset: Mapping[K, T], sampler: Sampler[K]) -> BaseNode[T]:
     """Thin wrapper that converts any MapDataset into a torchdata.node.
     If you want parallelism, copy this and replace Mapper with ParallelMapper.

-    :param map_dataset: Mapping to wrap.
-    :param sampler: Optional[Iterable].
+    :param map_dataset: Mapping[K, T] - map_dataset.__getitem__ is applied to the outputs of sampler.
+    :param sampler: Sampler[K]
     """
-    sampler_node = IterableWrapper(sampler)
+    sampler_node: SamplerWrapper[K] = SamplerWrapper(sampler)
     mapper_node = Mapper(sampler_node, map_dataset.__getitem__)
     return mapper_node


+class SamplerWrapper(BaseNode[T]):
+    """
+    Convert a sampler into a BaseNode. This is nearly identical to
+    IterableWrapper except it includes a hook to call set_epoch on the
+    sampler, if it supports it.
+
+    :param sampler: Sampler - the sampler to wrap.
+    :param initial_epoch: int - the initial epoch to set on the sampler.
+    :param epoch_updater: Optional[Callable[[int], int]] - callback to update the
+        epoch at the start of a new iteration. It is called at the beginning of
+        each iterator request, except the first one.
+    """
+
+    NUM_YIELDED_KEY = "_num_yielded"
+    SAMPLER_KEY = "sampler"
+    EPOCH_KEY = "_epoch"
+    STARTED_KEY = "_started"
+
+    @classmethod
+    def _default_epoch_updater(cls, epoch: int) -> int:
+        return epoch + 1
+
+    def __init__(
+        self,
+        sampler: Sampler[T],
+        initial_epoch: int = 0,
+        epoch_updater: Optional[Callable[[int], int]] = None,
+    ):
+        self.sampler = sampler
+        self.epoch_updater = epoch_updater or self._default_epoch_updater
+        self._num_yielded = 0
+        self._epoch = initial_epoch
+        self._started = False
+
+    def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]:
+        it: Iterator[T]
+        if initial_state is not None:
+            # Resuming from a checkpoint: restore the counters first.
+            self._num_yielded = initial_state[self.NUM_YIELDED_KEY]
+            self._epoch = initial_state[self.EPOCH_KEY]
+            self._started = initial_state[self.STARTED_KEY]
+
+            if isinstance(self.sampler, Stateful):
+                self.sampler.load_state_dict(initial_state[self.SAMPLER_KEY])
+                it = iter(self.sampler)
+            else:
+                # Non-stateful sampler: replay the saved epoch and fast-forward
+                # past the indices that were already yielded.
+                if hasattr(self.sampler, "set_epoch"):
+                    self.sampler.set_epoch(self._epoch)
+                it = iter(self.sampler)
+                for _ in range(self._num_yielded):
+                    next(it)
+        else:
+            if self._started:  # don't update the epoch on the first iteration
+                self._epoch = self.epoch_updater(self._epoch)
+            self._num_yielded = 0  # reset the position counter for the fresh epoch
+            if hasattr(self.sampler, "set_epoch"):
+                self.sampler.set_epoch(self._epoch)
+            it = iter(self.sampler)
+
+        self._started = True
+        for item in it:
+            self._num_yielded += 1
+            yield item
+
+    def get_state(self) -> Dict[str, Any]:
+        state_dict: Dict[str, Any] = {
+            self.NUM_YIELDED_KEY: self._num_yielded,
+            self.EPOCH_KEY: self._epoch,
+            self.STARTED_KEY: self._started,
+        }
+        if isinstance(self.sampler, Stateful):
+            state_dict[self.SAMPLER_KEY] = self.sampler.state_dict()
+        return state_dict
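
To make the fast-forward path concrete, here is a small sketch of a save/resume round trip. It is illustrative only and not from the PR: it calls iterator() and get_state() directly for clarity, whereas in a real pipeline the surrounding node machinery (exercised by run_test_save_load_state in the tests) drives these hooks.

    from torch.utils.data import DistributedSampler

    from torchdata.nodes.adapters import SamplerWrapper

    data = list(range(40))
    # No process group needed when rank/num_replicas are passed explicitly.
    sampler = DistributedSampler(data, rank=0, num_replicas=2)

    node = SamplerWrapper(sampler=sampler)
    it = node.iterator(initial_state=None)
    consumed = [next(it) for _ in range(3)]  # partially consume epoch 0

    # Snapshot holds only the counters, since DistributedSampler is not Stateful.
    state = node.get_state()

    # Resume on a fresh wrapper: the non-Stateful branch re-applies
    # set_epoch(0) and skips the 3 already-yielded indices.
    resumed = SamplerWrapper(sampler=DistributedSampler(data, rank=0, num_replicas=2))
    rest = list(resumed.iterator(initial_state=state))

    sampler.set_epoch(0)
    assert consumed + rest == list(sampler)  # the two streams line up

This is also why SAMPLER_KEY only appears in the state dict when the sampler implements the Stateful protocol: for everything else, epoch plus _num_yielded is enough to reconstruct the position.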