diff --git a/src/viztracer/report_builder.py b/src/viztracer/report_builder.py index c98aeaed..72460a2c 100644 --- a/src/viztracer/report_builder.py +++ b/src/viztracer/report_builder.py @@ -47,6 +47,7 @@ def __init__( self.align = align self.minimize_memory = minimize_memory self.jsons: List[Dict] = [] + self.invalid_json_paths: List[str] = [] self.json_loaded = False self.final_messages: List[Tuple[str, Dict]] = [] if not isinstance(data, (dict, list, tuple)): @@ -67,10 +68,16 @@ def load_jsons(self) -> None: self.jsons = [get_json(self.data)] elif isinstance(self.data, (list, tuple)): self.jsons = [] + self.invalid_json_paths = [] for idx, j in enumerate(self.data): if self.verbose > 0: same_line_print(f"Loading trace data from processes {idx}/{len(self.data)}") - self.jsons.append(get_json(j)) + try: + self.jsons.append(get_json(j)) + except json.JSONDecodeError: + self.invalid_json_paths.append(j) + if len(self.invalid_json_paths) > 0: + self.final_messages.append(("invalid_json", {"paths": self.invalid_json_paths})) def combine_json(self) -> None: if self.verbose > 0: @@ -78,7 +85,10 @@ def combine_json(self) -> None: if self.combined_json: return if not self.jsons: - raise ValueError("Can't get report of nothing") + if self.invalid_json_paths: + raise ValueError("No valid json files found") + else: + raise ValueError("Can't get report of nothing") if self.align: for one in self.jsons: self.align_events(one["traceEvents"]) @@ -213,3 +223,10 @@ def print_messages(self): color_print("OKGREEN", f"vizviewer \"{report_abspath}\"") else: color_print("OKGREEN", f"vizviewer {report_abspath}") + elif msg_type == "invalid_json": + print("") + color_print("WARNING", "Found and ignored invalid json file, you may lost some process data.") + color_print("WARNING", "Invalid json file:") + for msg in msg_args["paths"]: + color_print("WARNING", f" {msg}") + print("") diff --git a/tests/test_report_builder.py b/tests/test_report_builder.py index 799cb4d4..43f3c24b 100644 --- a/tests/test_report_builder.py +++ b/tests/test_report_builder.py @@ -5,7 +5,9 @@ import io import json import os +import shutil import tempfile +from unittest.mock import patch import viztracer from viztracer.report_builder import ReportBuilder @@ -70,6 +72,29 @@ def test_invalid_json(self): with self.assertRaises(Exception): ReportBuilder([invalid_json_path], verbose=1) + @patch('sys.stdout', new_callable=io.StringIO) + def test_invalid_json_file(self, mock_stdout): + with tempfile.TemporaryDirectory() as tmpdir: + invalid_json_path = os.path.join(os.path.dirname(__file__), "data", "fib.py") + valid_json_path = os.path.join(os.path.dirname(__file__), "data", "multithread.json") + invalid_json_file = shutil.copy(invalid_json_path, os.path.join(tmpdir, "invalid.json")) + valid_json_file = shutil.copy(valid_json_path, os.path.join(tmpdir, "valid.json")) + rb = ReportBuilder([invalid_json_file, valid_json_file], verbose=1) + with io.StringIO() as s: + rb.save(s) + self.assertIn("Invalid json file", mock_stdout.getvalue()) + + @patch('sys.stdout', new_callable=io.StringIO) + def test_all_invalid_json(self, mock_stdout): + with tempfile.TemporaryDirectory() as tmpdir: + invalid_json_path = os.path.join(os.path.dirname(__file__), "data", "fib.py") + invalid_json_file = shutil.copy(invalid_json_path, os.path.join(tmpdir, "invalid.json")) + rb = ReportBuilder([invalid_json_file], verbose=1) + with self.assertRaises(Exception) as context: + with io.StringIO() as s: + rb.save(s) + self.assertEqual(str(context.exception), "No valid json files found") + def test_combine(self): with tempfile.TemporaryDirectory() as tmpdir: file_path1 = os.path.join(tmpdir, "result1.json")