Skip to content

Commit

Permalink
PYON: fix nested encoding
Browse files Browse the repository at this point in the history
Forces recursive calls of pyon.encode(), which is compatible with the plugin architecture.

Previously, values similar to {'entry': CustomType()} would not encode

because _Encoder didn't recognze CustomType
  • Loading branch information
drewrisinger committed Sep 15, 2021
1 parent 0d53280 commit 4e41e61
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 15 deletions.
47 changes: 32 additions & 15 deletions sipyco/pyon.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@


class _Encoder:
def __init__(self, pretty):
def __init__(self, pretty, indent_level=0):
self.pretty = pretty
self.indent_level = 0
self.indent_level = indent_level

def indent(self):
return " " * self.indent_level
Expand All @@ -100,22 +100,22 @@ def encode_bytes(self, x):

def encode_tuple(self, x):
if len(x) == 1:
return "(" + self.encode(x[0]) + ", )"
return "(" + encode(x[0], self.pretty, self.indent_level) + ", )"
else:
r = "("
r += ", ".join([self.encode(item) for item in x])
r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x])
r += ")"
return r

def encode_list(self, x):
r = "["
r += ", ".join([self.encode(item) for item in x])
r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x])
r += "]"
return r

def encode_set(self, x):
r = "{"
r += ", ".join([self.encode(item) for item in x])
r += ", ".join([encode(item, self.pretty, self.indent_level) for item in x])
r += "}"
return r

Expand All @@ -127,8 +127,14 @@ def encode_dict(self, x):

r = "{"
if not self.pretty or len(x) < 2:
r += ", ".join([self.encode(k) + ": " + self.encode(v)
for k, v in items()])
r += ", ".join(
[
encode(k, self.pretty, self.indent_level)
+ ": "
+ encode(v, self.pretty, self.indent_level)
for k, v in items()
]
)
else:
self.indent_level += 1
r += "\n"
Expand All @@ -137,7 +143,12 @@ def encode_dict(self, x):
if not first:
r += ",\n"
first = False
r += self.indent() + self.encode(k) + ": " + self.encode(v)
r += (
self.indent()
+ encode(k, self.pretty, self.indent_level)
+ ": "
+ encode(v, self.pretty, self.indent_level)
)
r += "\n" # no ','
self.indent_level -= 1
r += self.indent()
Expand All @@ -148,24 +159,30 @@ def encode_slice(self, x):
return repr(x)

def encode_fraction(self, x):
return "Fraction({}, {})".format(self.encode(x.numerator),
self.encode(x.denominator))
return "Fraction({}, {})".format(
encode(x.numerator, self.pretty, self.indent_level),
encode(x.denominator, self.pretty, self.indent_level),
)

def encode_ordereddict(self, x):
return "OrderedDict(" + self.encode(list(x.items())) + ")"
return (
"OrderedDict("
+ encode(list(x.items()), self.pretty, self.indent_level)
+ ")"
)

def encode_nparray(self, x):
x = numpy.ascontiguousarray(x)
r = "nparray("
r += self.encode(x.shape) + ", "
r += self.encode(x.dtype.str) + ", b\""
r += encode(x.shape, self.pretty, self.indent_level) + ", "
r += encode(x.dtype.str, self.pretty, self.indent_level) + ", b\""
r += base64.b64encode(x.data).decode()
r += "\")"
return r

def encode_npscalar(self, x):
r = "npscalar("
r += self.encode(x.dtype.str) + ", b\""
r += encode(x.dtype.str, self.pretty, self.indent_level) + ", b\""
r += base64.b64encode(x.data).decode()
r += "\")"
return r
Expand Down
7 changes: 7 additions & 0 deletions sipyco/test/test_pyon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,10 @@ def test_pyon_plugin_encode_decode(monkeypatch):
monkeypatch.setattr(plugin, "get_plugin_manager", pyon_extra_plugin)
test_value = Point(2.5, 3.4)
assert pyon.decode(pyon.encode(test_value)) == test_value


def test_pyon_nested_encode(monkeypatch):
"""Tests that nested items will be properly encoded."""
monkeypatch.setattr(plugin, "get_plugin_manager", pyon_extra_plugin)
test_value = {"first": Point(2.5, {"nothing": 0})}
assert pyon.decode(pyon.encode(test_value)) == test_value

0 comments on commit 4e41e61

Please sign in to comment.