diff --git a/sipyco/pyon.py b/sipyco/pyon.py index 3f312eb..e691ece 100644 --- a/sipyco/pyon.py +++ b/sipyco/pyon.py @@ -72,9 +72,9 @@ class _Encoder: - def __init__(self, pretty): + def __init__(self, pretty, indent_level=0): self.pretty = pretty - self.indent_level = 0 + self.indent_level = indent_level def indent(self): return " " * self.indent_level @@ -100,22 +100,22 @@ def encode_bytes(self, x): def encode_tuple(self, x): if len(x) == 1: - return "(" + self.encode(x[0]) + ", )" + return "(" + encode(x[0], self.pretty, self.indent_level) + ", )" else: r = "(" - r += ", ".join([self.encode(item) for item in x]) + r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x]) r += ")" return r def encode_list(self, x): r = "[" - r += ", ".join([self.encode(item) for item in x]) + r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x]) r += "]" return r def encode_set(self, x): r = "{" - r += ", ".join([self.encode(item) for item in x]) + r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x]) r += "}" return r @@ -127,8 +127,14 @@ def encode_dict(self, x): r = "{" if not self.pretty or len(x) < 2: - r += ", ".join([self.encode(k) + ": " + self.encode(v) - for k, v in items()]) + r += ", ".join( + [ + encode(k, self.pretty, self.indent_level) + + ": " + + encode(v, self.pretty, self.indent_level) + for k, v in items() + ] + ) else: self.indent_level += 1 r += "\n" @@ -137,7 +143,12 @@ def encode_dict(self, x): if not first: r += ",\n" first = False - r += self.indent() + self.encode(k) + ": " + self.encode(v) + r += ( + self.indent() + + encode(k, self.pretty, self.indent_level) + + ": " + + encode(v, self.pretty, self.indent_level) + ) r += "\n" # no ',' self.indent_level -= 1 r += self.indent() @@ -148,24 +159,30 @@ def encode_slice(self, x): return repr(x) def encode_fraction(self, x): - return "Fraction({}, {})".format(self.encode(x.numerator), - self.encode(x.denominator)) + return "Fraction({}, {})".format( + encode(x.numerator, self.pretty, self.indent_level), + encode(x.denominator, self.pretty, self.indent_level), + ) def encode_ordereddict(self, x): - return "OrderedDict(" + self.encode(list(x.items())) + ")" + return ( + "OrderedDict(" + + encode(list(x.items()), self.pretty, self.indent_level) + + ")" + ) def encode_nparray(self, x): x = numpy.ascontiguousarray(x) r = "nparray(" - r += self.encode(x.shape) + ", " - r += self.encode(x.dtype.str) + ", b\"" + r += encode(x.shape, self.pretty, self.indent_level) + ", " + r += encode(x.dtype.str, self.pretty, self.indent_level) + ", b\"" r += base64.b64encode(x.data).decode() r += "\")" return r def encode_npscalar(self, x): r = "npscalar(" - r += self.encode(x.dtype.str) + ", b\"" + r += encode(x.dtype.str, self.pretty, self.indent_level) + ", b\"" r += base64.b64encode(x.data).decode() r += "\")" return r diff --git a/sipyco/test/test_pyon_plugin.py b/sipyco/test/test_pyon_plugin.py index 720f52c..e90a10e 100644 --- a/sipyco/test/test_pyon_plugin.py +++ b/sipyco/test/test_pyon_plugin.py @@ -50,3 +50,10 @@ def test_pyon_plugin_encode_decode(monkeypatch): monkeypatch.setattr(plugin, "get_plugin_manager", pyon_extra_plugin) test_value = Point(2.5, 3.4) assert pyon.decode(pyon.encode(test_value)) == test_value + + +def test_pyon_nested_encode(monkeypatch): + """Tests that nested items will be properly encoded.""" + monkeypatch.setattr(plugin, "get_plugin_manager", pyon_extra_plugin) + test_value = {"first": Point(2.5, {"nothing": 0})} + assert pyon.decode(pyon.encode(test_value)) == test_value