Skip to content

Commit

Permalink
fix: flatten_annotation behaviour for Optional (#440)
Browse files Browse the repository at this point in the history
  • Loading branch information
sam-or authored Nov 15, 2023
1 parent 1407f07 commit b479e4a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
4 changes: 2 additions & 2 deletions polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def flatten_annotation(annotation: Any) -> list[Any]:
if is_new_type(annotation):
flat.extend(flatten_annotation(unwrap_new_type(annotation)))
elif is_optional(annotation):
flat.append(NoneType)
flat.extend(flatten_annotation(arg) for arg in get_args(annotation) if arg not in (NoneType, None))
for a in get_args(annotation):
flat.extend(flatten_annotation(a))
elif is_annotated(annotation):
flat.extend(flatten_annotation(get_args(annotation)[0]))
elif is_union(annotation):
Expand Down
17 changes: 16 additions & 1 deletion tests/test_type_coverage_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dataclasses import dataclass, make_dataclass
from datetime import date
from typing import Dict, FrozenSet, List, Literal, Set, Tuple, Union
from typing import Dict, FrozenSet, List, Literal, Optional, Set, Tuple, Union
from uuid import UUID

import pytest
Expand Down Expand Up @@ -212,3 +212,18 @@ class Factory(DataclassFactory[Model]):

with pytest.raises(ParameterException):
list(Factory.coverage())


def test_coverage_optional_field() -> None:
@dataclass
class OptionalInt:
i: Optional[int]

class OptionalIntFactory(DataclassFactory[OptionalInt]):
__model__ = OptionalInt

results = list(OptionalIntFactory.coverage())
assert len(results) == 2

assert isinstance(results[0].i, int)
assert results[1].i is None

0 comments on commit b479e4a

Please sign in to comment.