Skip to content

Commit

Permalink
fix: snapshots with pydantic models can now be compared multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Jan 12, 2025
1 parent 22865dc commit 793ad2f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
13 changes: 13 additions & 0 deletions changelog.d/20250112_121000_15r10nk-git_pydantic_ai_fixes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
### Fixed

- snapshots with pydantic models can now be compared multiple times

``` python
class A(BaseModel):
a: int


def test_something():
for _ in [1, 2]:
assert A(a=1) == snapshot(A(a=1))
```
22 changes: 16 additions & 6 deletions src/inline_snapshot/_adapter/generic_call_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,24 @@ def map(cls, value, map_function):
)

def items(self, value, node):
assert isinstance(node, ast.Call)
assert not node.args
assert all(kw.arg for kw in node.keywords)
new_args, new_kwargs = self.arguments(value)

if node is not None:
assert isinstance(node, ast.Call)
assert not node.args
assert all(kw.arg for kw in node.keywords)
kw_arg_node = {kw.arg: kw.value for kw in node.keywords if kw.arg}.get
pos_arg_node = lambda pos: node.args[pos]
else:
kw_arg_node = lambda _: None
pos_arg_node = lambda _: None

return [
Item(value=self.argument(value, kw.arg), node=kw.value)
for kw in node.keywords
if kw.arg
Item(value=arg.value, node=pos_arg_node(i))
for i, arg in enumerate(new_args)
] + [
Item(value=kw.value, node=kw_arg_node(name))
for name, kw in new_kwargs.items()
]

def assign(self, old_value, old_node, new_value):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,21 @@ def test_something():
}
),
)


def test_pydantic_evaluate_twice():
Example(
"""\
from inline_snapshot import snapshot
from pydantic import BaseModel
class A(BaseModel):
a:int
def test_something():
for _ in [1,2]:
assert A(a=1) == snapshot(A(a=1))
"""
).run_pytest(
changed_files=snapshot({}),
)

0 comments on commit 793ad2f

Please sign in to comment.