Skip to content

Commit

Permalink
add label for data loader (#1907)
Browse files Browse the repository at this point in the history
Summary:

The data loader batch fetching logic isn't regular kernel related so profiler wouldn't show the actual e2e time.

This would show a label instead of just leaving it blank.

Differential Revision: D55858614
  • Loading branch information
xunnanxu authored and facebook-github-bot committed Apr 20, 2024
1 parent 303e852 commit 01746e5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
if self._dataloader_exhausted:
batch = None
else:
batch = next(dataloader_iter, None)
with record_function("## next_batch ##"):
batch = next(dataloader_iter, None)
if batch is None:
self._dataloader_exhausted = True
return batch
Expand Down Expand Up @@ -700,7 +701,8 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
if self._dataloader_exhausted:
batch = None
else:
batch = next(dataloader_iter, None)
with record_function("## next_batch ##"):
batch = next(dataloader_iter, None)
if batch is None:
self._dataloader_exhausted = True
return batch
Expand Down

0 comments on commit 01746e5

Please sign in to comment.