Skip to content

Commit

Permalink
feat: fix lists by calculating the alignment of the changed values
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Mar 11, 2024
1 parent 6476154 commit ef1fef2
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 16 deletions.
64 changes: 64 additions & 0 deletions inline_snapshot/_align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from itertools import groupby


def align(seq_a, seq_b) -> str:

matrix: list = [[(0, "e")] + [(0, "i")] * len(seq_b)]

for a in seq_a:
last = matrix[-1]

new_line = [(0, "d")]
for bi, b in enumerate(seq_b, 1):
la, lc, lb = new_line[-1], last[bi - 1], last[bi]
values = [(la[0], "i"), (lb[0], "d")]
if a == b:
values.append((lc[0] + 1, "m"))

new_line.append(max(values))
matrix.append(new_line)

# backtrack

ai = len(seq_a)
bi = len(seq_b)
d = ""
track = ""

while d != "e":
_, d = matrix[ai][bi]
if d == "m":
ai -= 1
bi -= 1
elif d == "i":
bi -= 1
elif d == "d":
ai -= 1
if d != "e":
track += d

return track[::-1]


def add_x(track):
"""Replaces an `id` with the same number of insertions and deletions with
x."""
groups = [(c, len(list(v))) for c, v in groupby(track)]
i = 0
result = ""
while i < len(groups):
g = groups[i]
if i == len(groups) - 1:
result += g[0] * g[1]
break

ng = groups[i + 1]
if g[0] == "d" and ng[0] == "i" and g[1] == ng[1]:
result += "x" * g[1]
i += 1
else:
result += g[0] * g[1]

i += 1

return result
32 changes: 29 additions & 3 deletions inline_snapshot/_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def apply(self):


def extend_comma(atok: ASTTokens, start: Token, end: Token) -> Tuple[Token, Token]:
prev = atok.prev_token(start)
if prev.string == ",":
return prev, end
# prev = atok.prev_token(start)
# if prev.string == ",":
# return prev, end

next = atok.next_token(end)
if next.string == ",":
Expand All @@ -55,6 +55,13 @@ def apply(self):

start, end = extend_comma(atok, start, end)

change.replace((start, end), "", filename=self.filename)
elif isinstance(parent, ast.List):
tokens = list(atok.get_tokens(self.node))
start, end = tokens[0], tokens[-1]

start, end = extend_comma(atok, start, end)

change.replace((start, end), "", filename=self.filename)
else:
assert False
Expand All @@ -79,6 +86,25 @@ class ListInsert(Change):
new_code: List[str]
new_values: List[Any]

def apply(self):
change = ChangeRecorder.current.new_change()
atok = self.source.asttokens()

code = ", ".join(self.new_code)

assert self.position <= len(self.node.elts)

if self.position == len(self.node.elts):
*_, token = atok.get_tokens(self.node)
assert token.string == "]"
if self.position != 0:
code = ", " + code
else:
token, *_ = atok.get_tokens(self.node.elts[self.position])
code = code + ", "

change.insert(token, code, filename=self.filename)


@dataclass()
class DictInsert(Change):
Expand Down
41 changes: 36 additions & 5 deletions inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import sys
import tokenize
from collections import defaultdict
from pathlib import Path
from typing import Any
from typing import Dict # noqa
Expand All @@ -13,9 +14,12 @@

from executing import Source

from ._align import add_x
from ._align import align
from ._change import Change
from ._change import Delete
from ._change import DictInsert
from ._change import ListInsert
from ._change import Replace
from ._format import format_code
from ._rewrite_code import ChangeRecorder
Expand Down Expand Up @@ -220,15 +224,42 @@ def check(old_value, old_node, new_value):
isinstance(old_node, ast.List)
and isinstance(new_value, list)
and isinstance(old_value, list)
or isinstance(old_node, ast.Tuple)
and isinstance(new_value, tuple)
and isinstance(old_value, tuple)
):
if len(old_value) == len(new_value) == len(old_node.elts):
for old_value_element, old_node_element, new_value_element in zip(
old_value, old_node.elts, new_value
):
diff = add_x(align(old_value, new_value))
old = zip(old_value, old_node.elts)
new = iter(new_value)
old_position = 0
to_insert = defaultdict(list)
for c in diff:
if c in "mx":
old_value_element, old_node_element = next(old)
new_value_element = next(new)
yield from check(
old_value_element, old_node_element, new_value_element
)
return
old_position += 1
elif c == "i":
new_value_element = next(new)
new_code = self._value_to_code(new_value_element)
to_insert[old_position].append((new_code, new_value_element))
elif c == "d":
old_value_element, old_node_element = next(old)
yield Delete(
"fix", self._source, old_node_element, old_value_element
)
old_position += 1
else:
assert False

for position, code_values in to_insert.items():
yield ListInsert(
"fix", self._source, old_node, position, *zip(*code_values)
)

return

elif (
isinstance(old_node, ast.Dict)
Expand Down
20 changes: 14 additions & 6 deletions inline_snapshot/_rewrite_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ def insert(self, node, new_content, *, filename):
self.replace(start_of(node), new_content, filename=filename)

def _replace(self, filename, range, new_contend):
self.change_recorder.get_source(filename).replacements.append(
source = self.change_recorder.get_source(filename)
source.replacements.append(
Replacement(range=range, text=new_contend, change_id=self.change_id)
)
source._check()


class SourceFile:
Expand All @@ -133,17 +135,23 @@ def rewrite(self):
with open(self.filename, "bw") as code:
code.write(new_code.encode())

def new_code(self):
"""Returns the new file contend or None if there are no replacepents to
apply."""
def _check(self):
replacements = list(self.replacements)
replacements.sort()

for r in replacements:
assert r.range.start <= r.range.end
assert r.range.start <= r.range.end, r

for lhs, rhs in pairwise(replacements):
assert lhs.range.end <= rhs.range.start
assert lhs.range.end <= rhs.range.start, (lhs, rhs)

def new_code(self):
"""Returns the new file contend or None if there are no replacepents to
apply."""
replacements = list(self.replacements)
replacements.sort()

self._check()

code = self.filename.read_text("utf-8")

Expand Down
10 changes: 10 additions & 0 deletions tests/test_align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from inline_snapshot import snapshot
from inline_snapshot._align import add_x
from inline_snapshot._align import align


def test_align():
assert align("iabc", "abcd") == snapshot("dmmmi")

assert align("abbc", "axyc") == snapshot("mddiim")
assert add_x(align("abbc", "axyc")) == snapshot("mxxm")
20 changes: 18 additions & 2 deletions tests/test_preserve_values.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
from inline_snapshot import snapshot


def test_fix_list(check_update):
def test_fix_list_fix(check_update):
assert check_update(
"""assert [1,2]==snapshot([0+1,3])""", reported_flags="update,fix", flags="fix"
) == snapshot("""assert [1,2]==snapshot([0+1,2])""")


def test_fix_list_insert(check_update):
assert check_update(
"""assert [1,2,3,4,5,6]==snapshot([0+1,3])""",
reported_flags="update,fix",
flags="fix",
) == snapshot("assert [1,2,3,4,5,6]==snapshot([0+1,2, 3, 4, 5, 6])")


def test_fix_list_delete(check_update):
assert check_update(
"""assert [1,5]==snapshot([0+1,2,3,4,5])""",
reported_flags="update,fix",
flags="fix",
) == snapshot("assert [1,5]==snapshot([0+1,5])")


def test_fix_dict_change(check_update):
assert check_update(
"""assert {1:1, 2:2}==snapshot({1:0+1, 2:3})""",
Expand All @@ -20,7 +36,7 @@ def test_fix_dict_remove(check_update):
"""assert {1:1}==snapshot({0:0, 1:0+1, 2:2})""",
reported_flags="update,fix",
flags="fix",
) == snapshot("assert {1:1}==snapshot({ 1:0+1})")
) == snapshot("assert {1:1}==snapshot({ 1:0+1, })")

assert check_update(
"""assert {}==snapshot({0:0})""",
Expand Down

0 comments on commit ef1fef2

Please sign in to comment.