Skip to content

Commit

Permalink
fix: handle dicts with mulitple insertions and deletions
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Mar 22, 2024
1 parent de7377d commit f263f62
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 125 deletions.
211 changes: 105 additions & 106 deletions inline_snapshot/_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Tuple
from typing import Union

from asttokens import ASTTokens
from asttokens.util import Token
from executing.executing import EnhancedAST
from executing.executing import Source
Expand All @@ -32,40 +31,11 @@ def apply(self):
raise NotImplementedError()


def extend_comma(atok: ASTTokens, start: Token, end: Token) -> Tuple[Token, Token]:
# prev = atok.prev_token(start)
# if prev.string == ",":
# return prev, end

next = atok.next_token(end)
if next.string == ",":
return start, next

return start, end


@dataclass()
class Delete(Change):
node: ast.AST
old_value: Any

def apply(self):
change = ChangeRecorder.current.new_change()
parent = self.node.parent
atok = self.source.asttokens()
if isinstance(parent, ast.Dict):
index = parent.values.index(self.node)
key = parent.keys[index]

start, *_ = atok.get_tokens(key)
*_, end = atok.get_tokens(self.node)

start, end = extend_comma(atok, start, end)

change.replace((start, end), "", filename=self.filename)
else:
assert False


@dataclass()
class AddArgument(Change):
Expand Down Expand Up @@ -95,21 +65,6 @@ class DictInsert(Change):
new_code: List[Tuple[str, str]]
new_values: List[Tuple[Any, Any]]

def apply(self):
change = ChangeRecorder.current.new_change()
atok = self.source.asttokens()
code = ",".join(f"{k}:{v}" for k, v in self.new_code)

if self.position == len(self.node.keys):
*_, token = atok.get_tokens(self.node.values[-1])
token = atok.next_token(token)
code = ", " + code
else:
token, *_ = atok.get_tokens(self.node.keys[self.position])
code = code + ", "

change.insert(token, code, filename=self.filename)


@dataclass()
class Replace(Change):
Expand All @@ -125,11 +80,80 @@ def apply(self):
change.replace(range, self.new_code, filename=self.filename)


TokenRange = Tuple[Token, Token]


def generic_sequence_update(
source: Source,
parent: Union[ast.List, ast.Tuple, ast.Dict],
parent_elements: List[Union[TokenRange, None]],
to_insert: Dict[int, List[str]],
):
rec = ChangeRecorder.current.new_change()

new_code = []
deleted = False
last_token, *_, end_token = source.asttokens().get_tokens(parent)
is_start = True
elements = 0

for index, entry in enumerate(parent_elements):
if index in to_insert:
new_code += to_insert[index]
print("insert", entry, new_code)
if entry is None:
deleted = True
print("delete1", entry)
else:
first_token, new_last_token = entry
elements += len(new_code) + 1

if deleted or new_code:
print("change", deleted, new_code)

code = ""
if new_code:
code = ", ".join(new_code) + ", "
if not is_start:
code = ", " + code
print("code", code)

rec.replace(
(end_of(last_token), start_of(first_token)),
code,
filename=source.filename,
)
print("keep", entry)
new_code = []
deleted = False
last_token = new_last_token
is_start = False

if len(parent_elements) in to_insert:
new_code += to_insert[len(parent_elements)]
elements += len(new_code)

if new_code or deleted or elements == 1 or len(parent_elements) <= 1:
code = ", ".join(new_code)
if not is_start and code:
code = ", " + code

if elements == 1 and isinstance(parent, ast.Tuple):
# trailing comma for tuples (1,)i
code += ","

rec.replace(
(end_of(last_token), start_of(end_token)),
code,
filename=source.filename,
)


def apply_all(all_changes: List[Change]):
by_parent: Dict[EnhancedAST, List[Union[Delete, DictInsert, ListInsert]]] = (
defaultdict(list)
)
sources = {}
sources: Dict[EnhancedAST, Source] = {}

for change in all_changes:
if isinstance(change, Delete):
Expand All @@ -146,78 +170,53 @@ def apply_all(all_changes: List[Change]):

for parent, changes in by_parent.items():
source = sources[parent]
print(parent, changes)

rec = ChangeRecorder.current.new_change()
if isinstance(parent, (ast.List, ast.Tuple)):
to_delete = {
change.node for change in changes if isinstance(change, Delete)
}
to_insert = {
change.position: change
change.position: change.new_code
for change in changes
if isinstance(change, ListInsert)
}

new_code = []
deleted = False
last_token, *_, end_token = source.asttokens().get_tokens(parent)
is_start = True
elements = 0

for index, entry in enumerate(parent.elts):
if index in to_insert:
new_code += to_insert[index].new_code
print("insert", entry, new_code)
if entry in to_delete:
deleted = True
print("delete1", entry)
else:
entry_tokens = list(source.asttokens().get_tokens(entry))
first_token = entry_tokens[0]
new_last_token = entry_tokens[-1]
elements += len(new_code) + 1

if deleted or new_code:
print("change", deleted, new_code)

code = ""
if new_code:
code = ", ".join(new_code) + ", "
if not is_start:
code = ", " + code
print("code", code)

rec.replace(
(end_of(last_token), start_of(first_token)),
code,
filename=source.filename,
)
print("keep", entry)
new_code = []
deleted = False
last_token = new_last_token
is_start = False

if len(parent.elts) in to_insert:
new_code += to_insert[len(parent.elts)].new_code
elements += len(new_code)

if new_code or deleted or elements == 1 or len(parent.elts) <= 1:
code = ", ".join(new_code)
if not is_start and code:
code = ", " + code
def list_token_range(entry):
r = list(source.asttokens().get_tokens(entry))
return r[0], r[-1]

if elements == 1 and isinstance(parent, ast.Tuple):
# trailing comma for tuples (1,)i
code += ","
generic_sequence_update(
source,
parent,
[None if e in to_delete else list_token_range(e) for e in parent.elts],
to_insert,
)

rec.replace(
(end_of(last_token), start_of(end_token)),
code,
filename=source.filename,
elif isinstance(parent, (ast.Dict)):
to_delete = {
change.node for change in changes if isinstance(change, Delete)
}
to_insert = {
change.position: [f"{key}: {value}" for key, value in change.new_code]
for change in changes
if isinstance(change, DictInsert)
}

def dict_token_range(key, value):
return (
list(source.asttokens().get_tokens(key))[0],
list(source.asttokens().get_tokens(value))[-1],
)

generic_sequence_update(
source,
parent,
[
None if value in to_delete else dict_token_range(key, value)
for key, value in zip(parent.keys, parent.values)
],
to_insert,
)

else:
for change in changes:
change.apply()
assert False
9 changes: 6 additions & 3 deletions inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,17 @@ def check(old_value, old_node, new_value):

for key, node in zip(old_value.keys(), old_node.values):
if key in new_value:
# check values with same keys
yield from check(old_value[key], node, new_value[key])
else:
# delete entries
yield Delete("fix", self._source, node, old_value[key])

to_insert = []
insert_pos = 0
for key, new_value_element in new_value.items():
if key not in old_value:
# add new values
to_insert.append((key, new_value_element))
else:
if to_insert:
Expand All @@ -302,7 +305,7 @@ def check(old_value, old_node, new_value):
yield DictInsert(
"fix",
self._source,
node.parent,
old_node,
insert_pos,
new_code,
to_insert,
Expand All @@ -318,8 +321,8 @@ def check(old_value, old_node, new_value):
yield DictInsert(
"fix",
self._source,
node.parent,
insert_pos,
old_node,
len(old_node.values),
new_code,
to_insert,
)
Expand Down
38 changes: 22 additions & 16 deletions tests/test_preserve_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_fix_dict_remove(check_update):
"""assert {1:1}==snapshot({0:0, 1:0+1, 2:2})""",
reported_flags="update,fix",
flags="fix",
) == snapshot("assert {1:1}==snapshot({ 1:0+1, })")
) == snapshot("assert {1:1}==snapshot({1:0+1})")

assert check_update(
"""assert {}==snapshot({0:0})""",
Expand All @@ -61,7 +61,7 @@ def test_fix_dict_insert(check_update):
reported_flags="update,fix",
flags="fix",
) == snapshot(
"""assert {0:"before",1:1,2:"after"}==snapshot({0:"before", 1:0+1, 2:"after"})"""
'assert {0:"before",1:1,2:"after"}==snapshot({0: "before", 1:0+1, 2: "after"})'
)


Expand Down Expand Up @@ -166,23 +166,27 @@ def test_preserve_case_from_original_mr(check_update):


def test_generic(source, subtests):
codes = []

for braces in ("[]", "()"):
for s in itertools.product(stuff, repeat=3):
flags = set().union(*[e[3] for e in s])
name = ",".join(e[2] for e in s)
print(flags)
for braces in ("[]", "()", "{}"):
for value_specs in itertools.product(stuff, repeat=3):
flags = set().union(*[e[3] for e in value_specs])
all_flags = {
frozenset(x) - {""}
for x in itertools.combinations_with_replacement(
flags | {""}, len(flags)
)
}
print(all_flags)

def build(l):
values = [x for e in l for x in e]
def build(value_lists):
value_lists = list(value_lists)

if braces == "{}":
values = [
f"{i}: {value_list[0]}"
for i, value_list in enumerate(value_lists)
if value_list
]
else:
values = [x for value_list in value_lists for x in value_list]

code = ", ".join(values)

Expand All @@ -192,13 +196,13 @@ def build(l):

return f"{braces[0]}{code}{comma}{braces[1]}"

c1 = build(e[0] for e in s)
c2 = build(e[1] for e in s)
c1 = build(spec[0] for spec in value_specs)
c2 = build(spec[1] for spec in value_specs)
code = f"assert {c2}==snapshot({c1})"

named_flags = ", ".join(flags)
with subtests.test(f"{c1} -> {c2} <{named_flags}>"):

with subtests.test(f"{c1} -> {c2} <{named_flags}>"):
s1 = source(code)
print("source:", code)

Expand All @@ -207,7 +211,9 @@ def build(l):
assert ("fix" in flags) == s1.error

for f in all_flags:
c3 = build([(e[1] if e[3] & f else e[0]) for e in s])
c3 = build(
[(spec[1] if spec[3] & f else spec[0]) for spec in value_specs]
)
new_code = f"assert {c2}==snapshot({c3})"

print(f"{set(f)}:")
Expand Down

0 comments on commit f263f62

Please sign in to comment.