diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py index 9c6510e7..b0ae0372 100644 --- a/grain/_src/python/dataset/stats.py +++ b/grain/_src/python/dataset/stats.py @@ -114,6 +114,10 @@ def _pretty_format_summary( # the visualization graph. col_names.remove("output_spec") col_names.remove("is_output") + # TODO: Add a column for `is_prefetch` in the logged execution + # summary. + col_names.remove("wait_time_ratio") + col_names.remove("is_prefetch") # Insert the average processing time column after the max processing time # column. index = col_names.index("max_processing_time_ns") @@ -273,6 +277,8 @@ class StatsConfig: # Whether this transformation mutates the element spec. This is used to # determine element spec of the current transformation. transform_mutates_spec: bool = True + # Whether this transformation is a prefetch transformation. + is_prefetch: bool = False # Whether to log the execution summary. log_summary: bool = False @@ -539,6 +545,7 @@ def _build_execution_summary( self._summary.name = self._config.name self._summary.output_spec = str(self.output_spec) self._summary.is_output = self._is_output + self._summary.is_prefetch = self._config.is_prefetch execution_summary.nodes.get_or_create(node_id) execution_summary.nodes[node_id].CopyFrom(self._summary) current_node_id = node_id diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 7762e0a1..b66bebb5 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -119,6 +119,7 @@ def _stats(self): dataset_stats.StatsConfig( name=str(self), transform_mutates_spec=self._MUTATES_ELEMENT_SPEC, + is_prefetch=True, ), (parent_stats,), execution_tracking_mode, @@ -452,6 +453,21 @@ def __init__( _LAST_WORKER_INDEX: -1, } + @functools.cached_property + def _stats(self): + config = dataset_stats.StatsConfig( + name=str(self), + transform_mutates_spec=self._MUTATES_ELEMENT_SPEC, + is_prefetch=True, + ) + return dataset_stats.make_stats( + config, + [p._stats for p in self._parents], # pylint: disable=protected-access + execution_tracking_mode=( + self._ctx.dataset_options.execution_tracking_mode + ), + ) + def __iter__(self) -> dataset.DatasetIterator[T]: return self