add label for data loader
Summary:
The data loader's batch-fetching logic is not regular kernel work, so the profiler would not attribute its actual end-to-end time to anything.

This change adds a label for that region instead of leaving the span blank.
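For context, `torch.profiler.record_function` is a context manager that labels a code region so it shows up as a named span in profiler traces. A minimal pure-Python analogue of that labeling pattern (the `TRACE` list and this `record_function` implementation are illustrative stand-ins, not torch's real profiler):

```python
from contextlib import contextmanager
import time

TRACE = []  # collected (label, duration) pairs, standing in for a profiler trace


@contextmanager
def record_function(label):
    """Label a code region, analogous to torch.profiler.record_function."""
    start = time.perf_counter()
    try:
        yield
    finally:
        TRACE.append((label, time.perf_counter() - start))


with record_function("## next_batch ##"):
    batch = list(range(4))  # stand-in for fetching a batch from a dataloader

print(TRACE[0][0])  # the region is recorded under its label
```

Without such a label, the time spent inside the region would appear in a trace only as an anonymous gap between kernel spans.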

Differential Revision: D55858614
xunnanxu authored and facebook-github-bot committed Apr 19, 2024
1 parent d0f7729 commit d00ddb0
Showing 1 changed file with 4 additions and 2 deletions: torchrec/distributed/train_pipeline/train_pipelines.py
```diff
@@ -264,7 +264,8 @@ def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         """
         with record_function("## copy_batch_to_gpu ##"):
             with torch.cuda.stream(self._memcpy_stream):
-                batch = self._next_batch(dataloader_iter)
+                with record_function("## next_batch ##"):
+                    batch = self._next_batch(dataloader_iter)
                 if batch is not None:
                     batch = _to_device(batch, self._device, non_blocking=True)
                 elif not self._execute_all_batches:
```
```diff
@@ -700,7 +701,8 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         if self._dataloader_exhausted:
             batch = None
         else:
-            batch = next(dataloader_iter, None)
+            with record_function("## next_batch ##"):
+                batch = next(dataloader_iter, None)
             if batch is None:
                 self._dataloader_exhausted = True
         return batch
```
