diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py
index c7fec32a..9c6510e7 100644
--- a/grain/_src/python/dataset/stats.py
+++ b/grain/_src/python/dataset/stats.py
@@ -411,6 +411,11 @@ def _build_visualization_str(self, has_multiple_parents: bool = False):
     )
 
 
+def _running_in_colab() -> bool:
+  """Returns True if `google.colab` is loaded, i.e. we are likely in Colab."""
+  return "google.colab" in sys.modules
+
+
 class _NoopStats(Stats):
   """Default implementation for statistics collection that does nothing."""
 
@@ -454,7 +459,10 @@ def record_output_spec(self, element: T) -> T:
     return element
 
   def report(self):
-    logging.info("Grain Dataset graph:\n\n%s", self._visualize_dataset_graph())
+    msg = f"Grain Dataset graph:\n\n{self._visualize_dataset_graph()}"
+    logging.info(msg)
+    if _running_in_colab():
+      print(msg)
 
 
 class _ExecutionStats(_VisualizationStats):
@@ -507,15 +515,17 @@ def _logging_execution_summary_loop(self):
       if self._last_update_time > self._last_report_time:
         self._last_report_time = time.time()
         summary = self._get_execution_summary()
-        logging.info(
+        msg = (
             "Grain Dataset Execution Summary:\n\nNOTE: Before analyzing the"
             " `MapDataset` nodes, ensure that the `total_processing_time` of"
             " the `PrefetchDatasetIterator` node indicates it is a bottleneck."
             " The `MapDataset` nodes are executed in multiple threads and thus,"
             " should not be compared to the `total_processing_time` of"
-            " `DatasetIterator` nodes.\n\n%s",
-            _pretty_format_summary(summary),
+            f" `DatasetIterator` nodes.\n\n{_pretty_format_summary(summary)}"
         )
+        logging.info(msg)
+        if _running_in_colab():
+          print(msg)
 
   def _build_execution_summary(
       self,