diff --git a/src/viztracer/report_builder.py b/src/viztracer/report_builder.py index 72460a2c..62fd5338 100644 --- a/src/viztracer/report_builder.py +++ b/src/viztracer/report_builder.py @@ -2,14 +2,13 @@ # For details: https://github.com/gaogaotiantian/viztracer/blob/master/NOTICE.txt try: - import orjson # type: ignore + import orjson as json # type: ignore except ImportError: import json import gzip import os import re -import sys from string import Template from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union @@ -27,10 +26,7 @@ def get_json(data: Union[Dict, str]) -> Dict[str, Any]: with open(data, encoding="utf-8") as f: json_str = f.read() - if "orjson" in sys.modules: - return orjson.loads(json_str) - else: - return json.loads(json_str) + return json.loads(json_str) class ReportBuilder: @@ -161,18 +157,18 @@ def generate_report( tmpl = f.read() with open(os.path.join(os.path.dirname(__file__), "html/trace_viewer_full.html"), encoding="utf-8") as f: sub["trace_viewer_full"] = f.read() - if "orjson" in sys.modules: - sub["json_data"] = orjson.dumps(self.combined_json) \ - .decode("utf-8") \ - .replace("", "<\\/script>") + if json.__name__ == "orjson": + sub["json_data"] = json.dumps(self.combined_json) \ + .decode("utf-8") \ + .replace("", "<\\/script>") else: sub["json_data"] = json.dumps(self.combined_json) \ .replace("", "<\\/script>") output_file.write(Template(tmpl).substitute(sub)) elif output_format == "json": self.prepare_json(file_info=file_info) - if "orjson" in sys.modules: - output_file.write(orjson.dumps(self.combined_json).decode("utf-8")) + if json.__name__ == "orjson": + output_file.write(json.dumps(self.combined_json).decode("utf-8")) else: if self.minimize_memory: json.dump(self.combined_json, output_file) # type: ignore diff --git a/tests/test_report_builder.py b/tests/test_report_builder.py index 43f3c24b..4afed792 100644 --- a/tests/test_report_builder.py +++ b/tests/test_report_builder.py @@ -7,12 +7,15 @@ import os import shutil import tempfile +import textwrap from unittest.mock import patch import viztracer from viztracer.report_builder import ReportBuilder from .base_tmpl import BaseTmpl +from .cmdline_tmpl import CmdlineTmpl +from .package_env import package_matrix class TestReportBuilder(BaseTmpl): @@ -114,3 +117,31 @@ def test_combine(self): rb.save(output_file=s) data = json.loads(s.getvalue()) self.assertTrue(data["viztracer_metadata"]["overflow"]) + + +class TestReportBuilderCmdline(CmdlineTmpl): + @package_matrix(["~orjson", "orjson"]) + def test_package_matrix(self): + """ + The module will be imported only once so flipping the package matrix will only + work when we start a new script + """ + + with tempfile.TemporaryDirectory() as tmpdir: + invalid_json_path = os.path.join(os.path.dirname(__file__), "data", "fib.py") + invalid_json_file = shutil.copy(invalid_json_path, os.path.join(tmpdir, "invalid.json")) + + script = textwrap.dedent(f""" + import io + from viztracer.report_builder import ReportBuilder + rb = ReportBuilder([{repr(invalid_json_file)}], verbose=1) + try: + with io.StringIO() as s: + rb.save(s) + except Exception as e: + assert str(e) == "No valid json files found" + else: + assert False + """) + + self.template(["python", "cmdline_test.py"], script=script, expected_output_file=None)