diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index ae1d29073..9757ac3f1 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -283,7 +283,8 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]: if self._dataloader_exhausted: batch = None else: - batch = next(dataloader_iter, None) + with record_function("## next_batch ##"): + batch = next(dataloader_iter, None) if batch is None: self._dataloader_exhausted = True return batch @@ -700,7 +701,8 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]: if self._dataloader_exhausted: batch = None else: - batch = next(dataloader_iter, None) + with record_function("## next_batch ##"): + batch = next(dataloader_iter, None) if batch is None: self._dataloader_exhausted = True return batch