Skip to content

Commit

Permalink
Fix the non-exist json module when orjson is used (#419)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaogaotiantian authored Mar 27, 2024
1 parent 82b9638 commit bbedd36
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
20 changes: 8 additions & 12 deletions src/viztracer/report_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
# For details: https://github.com/gaogaotiantian/viztracer/blob/master/NOTICE.txt

try:
import orjson # type: ignore
import orjson as json # type: ignore
except ImportError:
import json

import gzip
import os
import re
import sys
from string import Template
from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union

Expand All @@ -27,10 +26,7 @@ def get_json(data: Union[Dict, str]) -> Dict[str, Any]:
with open(data, encoding="utf-8") as f:
json_str = f.read()

if "orjson" in sys.modules:
return orjson.loads(json_str)
else:
return json.loads(json_str)
return json.loads(json_str)


class ReportBuilder:
Expand Down Expand Up @@ -161,18 +157,18 @@ def generate_report(
tmpl = f.read()
with open(os.path.join(os.path.dirname(__file__), "html/trace_viewer_full.html"), encoding="utf-8") as f:
sub["trace_viewer_full"] = f.read()
if "orjson" in sys.modules:
sub["json_data"] = orjson.dumps(self.combined_json) \
.decode("utf-8") \
.replace("</script>", "<\\/script>")
if json.__name__ == "orjson":
sub["json_data"] = json.dumps(self.combined_json) \
.decode("utf-8") \
.replace("</script>", "<\\/script>")
else:
sub["json_data"] = json.dumps(self.combined_json) \
.replace("</script>", "<\\/script>")
output_file.write(Template(tmpl).substitute(sub))
elif output_format == "json":
self.prepare_json(file_info=file_info)
if "orjson" in sys.modules:
output_file.write(orjson.dumps(self.combined_json).decode("utf-8"))
if json.__name__ == "orjson":
output_file.write(json.dumps(self.combined_json).decode("utf-8"))
else:
if self.minimize_memory:
json.dump(self.combined_json, output_file) # type: ignore
Expand Down
31 changes: 31 additions & 0 deletions tests/test_report_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
import os
import shutil
import tempfile
import textwrap
from unittest.mock import patch

import viztracer
from viztracer.report_builder import ReportBuilder

from .base_tmpl import BaseTmpl
from .cmdline_tmpl import CmdlineTmpl
from .package_env import package_matrix


class TestReportBuilder(BaseTmpl):
Expand Down Expand Up @@ -114,3 +117,31 @@ def test_combine(self):
rb.save(output_file=s)
data = json.loads(s.getvalue())
self.assertTrue(data["viztracer_metadata"]["overflow"])


class TestReportBuilderCmdline(CmdlineTmpl):
@package_matrix(["~orjson", "orjson"])
def test_package_matrix(self):
"""
The module will be imported only once so flipping the package matrix will only
work when we start a new script
"""

with tempfile.TemporaryDirectory() as tmpdir:
invalid_json_path = os.path.join(os.path.dirname(__file__), "data", "fib.py")
invalid_json_file = shutil.copy(invalid_json_path, os.path.join(tmpdir, "invalid.json"))

script = textwrap.dedent(f"""
import io
from viztracer.report_builder import ReportBuilder
rb = ReportBuilder([{repr(invalid_json_file)}], verbose=1)
try:
with io.StringIO() as s:
rb.save(s)
except Exception as e:
assert str(e) == "No valid json files found"
else:
assert False
""")

self.template(["python", "cmdline_test.py"], script=script, expected_output_file=None)

0 comments on commit bbedd36

Please sign in to comment.