diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index ae1d29073..ecdb12424 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -264,7 +264,8 @@ def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         """
         with record_function("## copy_batch_to_gpu ##"):
             with torch.cuda.stream(self._memcpy_stream):
-                batch = self._next_batch(dataloader_iter)
+                with record_function("## next_batch ##"):
+                    batch = self._next_batch(dataloader_iter)
                 if batch is not None:
                     batch = _to_device(batch, self._device, non_blocking=True)
                 elif not self._execute_all_batches:
@@ -700,7 +701,8 @@ def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         if self._dataloader_exhausted:
             batch = None
         else:
-            batch = next(dataloader_iter, None)
+            with record_function("## next_batch ##"):
+                batch = next(dataloader_iter, None)
             if batch is None:
                 self._dataloader_exhausted = True
         return batch