Skip to content

Commit

Permalink
feat: Prevent classes with ambiguous choices
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Feb 27, 2024
1 parent b8ddec2 commit f3f1ca3
Show file tree
Hide file tree
Showing 14 changed files with 129 additions and 130 deletions.
21 changes: 13 additions & 8 deletions tests/fixtures/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,10 @@ class ChoiceType:
{"name": "a", "type": TypeA},
{"name": "b", "type": TypeB},
{"name": "int", "type": int},
{"name": "int2", "type": int, "nillable": True},
{"name": "float", "type": float},
{"name": "qname", "type": QName},
{
"name": "tokens",
"type": List[int],
"tokens": True,
"default_factory": return_true
},
{"name": "union", "type": Type["UnionType"], "namespace": "foo"},
{"name": "p", "type": float, "fixed": True, "default": 1.1},
{"name": "tokens", "type": List[str], "tokens": True},
{
"wildcard": True,
"type": object,
Expand All @@ -128,6 +121,18 @@ class OptionalChoiceType:
)


@dataclass
class AmbiguousChoiceType:
choice: int = field(
metadata={
"type": "Elements",
"choices": (
{"name": "a", "type": int},
{"name": "b", "type": int},
),
}
)


@dataclass
class UnionType:
Expand Down
104 changes: 61 additions & 43 deletions tests/formats/dataclass/models/test_builders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import sys
import uuid
from dataclasses import dataclass, field, fields, make_dataclass
Expand All @@ -7,7 +8,15 @@

from tests.fixtures.artists import Artist
from tests.fixtures.books import BookForm
from tests.fixtures.models import ChoiceType, Parent, TypeA, TypeB, TypeNS1, UnionType
from tests.fixtures.models import (
AmbiguousChoiceType,
ChoiceType,
Parent,
TypeA,
TypeB,
TypeNS1,
UnionType,
)
from tests.fixtures.series import Country
from tests.fixtures.submodels import ChoiceTypeChild
from xsdata.exceptions import XmlContextError
Expand All @@ -16,7 +25,7 @@
from xsdata.formats.dataclass.models.elements import XmlMeta, XmlType
from xsdata.models.datatype import XmlDate
from xsdata.utils import text
from xsdata.utils.constants import return_input, return_true
from xsdata.utils.constants import return_input
from xsdata.utils.namespaces import build_qname
from xsdata.utils.testing import FactoryTestCase, XmlMetaFactory, XmlVarFactory

Expand Down Expand Up @@ -132,7 +141,7 @@ def test_build_with_no_dataclass_raises_exception(self, *args):
def test_build_locates_globalns_per_field(self):
actual = self.builder.build(ChoiceTypeChild, None)
self.assertEqual(1, len(actual.choices))
self.assertEqual(9, len(actual.choices[0].elements))
self.assertEqual(7, len(actual.choices[0].elements))

with self.assertRaises(XmlContextError):
self.builder.find_declared_class(object, "foo")
Expand Down Expand Up @@ -276,6 +285,8 @@ def test_default_xml_type(self):


class XmlVarBuilderTests(TestCase):
maxDiff = None

def setUp(self) -> None:
self.builder = XmlVarBuilder(
class_type=class_types.get_type("dataclasses"),
Expand All @@ -285,15 +296,14 @@ def setUp(self) -> None:
)

super().setUp()
self.maxDiff = None

def test_build_with_choice_field(self):
globalns = sys.modules[ChoiceType.__module__].__dict__
type_hints = get_type_hints(ChoiceType)
class_field = fields(ChoiceType)[0]

self.maxDiff = None
actual = self.builder.build(
ChoiceType,
"choice",
type_hints["choice"],
class_field.metadata,
Expand Down Expand Up @@ -337,66 +347,45 @@ def test_build_with_choice_field(self):
factory=list,
namespaces=("bar",),
),
"{bar}int2": XmlVarFactory.create(
index=5,
name="choice",
qname="{bar}int2",
types=(int,),
derived=True,
nillable=True,
factory=list,
namespaces=("bar",),
),
"{bar}float": XmlVarFactory.create(
index=6,
index=5,
name="choice",
qname="{bar}float",
types=(float,),
factory=list,
namespaces=("bar",),
),
"{bar}qname": XmlVarFactory.create(
index=7,
index=6,
name="choice",
qname="{bar}qname",
types=(QName,),
factory=list,
namespaces=("bar",),
),
"{bar}tokens": XmlVarFactory.create(
index=8,
name="choice",
qname="{bar}tokens",
types=(int,),
tokens_factory=list,
derived=True,
factory=list,
default=return_true,
namespaces=("bar",),
),
"{foo}union": XmlVarFactory.create(
index=9,
index=7,
name="choice",
qname="{foo}union",
types=(UnionType,),
clazz=UnionType,
factory=list,
namespaces=("foo",),
),
"{bar}p": XmlVarFactory.create(
index=10,
"{bar}tokens": XmlVarFactory.create(
index=8,
name="choice",
qname="{bar}p",
types=(float,),
qname="{bar}tokens",
types=(str,),
tokens_factory=list,
derived=True,
factory=list,
default=1.1,
namespaces=("bar",),
),
},
wildcards=[
XmlVarFactory.create(
index=11,
index=9,
name="choice",
xml_type=XmlType.WILDCARD,
qname="{http://www.w3.org/1999/xhtml}any",
Expand All @@ -408,17 +397,44 @@ def test_build_with_choice_field(self):
],
)

self.maxDiff = None
self.assertEqual(expected, actual)

def test_build_with_ambiguous_choices(self):
type_hints = get_type_hints(AmbiguousChoiceType)
class_field = fields(AmbiguousChoiceType)[0]

with self.assertRaises(XmlContextError) as cm:
self.builder.build(
AmbiguousChoiceType,
"choice",
type_hints["choice"],
class_field.metadata,
True,
None,
None,
{},
)

self.assertEqual(
"Error on AmbiguousChoiceType::choice: Compound field contains ambiguous types",
str(cm.exception),
)

def test_build_validates_result(self):
with self.assertRaises(XmlContextError) as cm:
self.builder.build(
"foo", List[int], {"type": "Attributes"}, True, None, None, None
BookForm,
"foo",
List[int],
{"type": "Attributes"},
True,
None,
None,
None,
)

self.assertEqual(
"Xml type 'Attributes' does not support typing: typing.List[int]",
"Error on BookForm::foo: Xml Attributes does not support typing `typing.List[int]`",
str(cm.exception),
)

Expand Down Expand Up @@ -465,20 +481,22 @@ def test_resolve_namespaces(self):
self.assertEqual(("foo", "p"), tuple(sorted(actual)))

def test_analyze_types(self):
actual = self.builder.analyze_types(List[List[Union[str, int]]], None)
func = functools.partial(self.builder.analyze_types, BookForm, "foo")

actual = func(List[List[Union[str, int]]], None)
self.assertEqual((list, list, (int, str)), actual)

actual = self.builder.analyze_types(Union[str, int], None)
actual = func(Union[str, int], None)
self.assertEqual((None, None, (int, str)), actual)

actual = self.builder.analyze_types(Dict[str, int], None)
actual = func(Dict[str, int], None)
self.assertEqual((dict, None, (int, str)), actual)

with self.assertRaises(XmlContextError) as cm:
self.builder.analyze_types(List[List[List[int]]], None)
func(List[List[List[int]]], None)

self.assertEqual(
"Unsupported typing: typing.List[typing.List[typing.List[int]]]",
"Error on BookForm::foo: Unsupported field typing `typing.List[typing.List[typing.List[int]]]`",
str(cm.exception),
)

Expand Down
2 changes: 0 additions & 2 deletions tests/formats/dataclass/parsers/nodes/test_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,6 @@ def test_build_node_with_dataclass_var(self, mock_ctx_fetch, mock_xsi_type):
name="a",
qname="a",
types=(TypeC,),
derived=True,
)
xsi_type = "foo"
namespace = self.meta.namespace
Expand All @@ -384,7 +383,6 @@ def test_build_node_with_dataclass_var(self, mock_ctx_fetch, mock_xsi_type):

self.assertIsInstance(actual, ElementNode)
self.assertEqual(10, actual.position)
self.assertEqual(DerivedElement, actual.derived_factory)
self.assertIs(mock_ctx_fetch.return_value, actual.meta)

mock_xsi_type.assert_called_once_with(attrs, ns_map)
Expand Down
36 changes: 10 additions & 26 deletions tests/formats/dataclass/parsers/nodes/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from xsdata.exceptions import XmlContextError
from xsdata.formats.dataclass.models.elements import XmlType
from xsdata.formats.dataclass.models.generics import DerivedElement
from xsdata.formats.dataclass.parsers.nodes import PrimitiveNode
from xsdata.formats.dataclass.parsers.utils import ParserUtils
from xsdata.utils.testing import XmlVarFactory
Expand All @@ -16,7 +15,7 @@ def test_bind(self, mock_parse_value):
xml_type=XmlType.TEXT, name="foo", qname="foo", types=(int,), format="Nope"
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(var, ns_map, False, DerivedElement)
node = PrimitiveNode(var, ns_map, False)
objects = []

self.assertTrue(node.bind("foo", "13", "Impossible", objects))
Expand All @@ -31,23 +30,12 @@ def test_bind(self, mock_parse_value):
format=var.format,
)

def test_bind_derived_mode(self):
var = XmlVarFactory.create(
xml_type=XmlType.TEXT, name="foo", qname="foo", types=(int,), derived=True
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(var, ns_map, False, DerivedElement)
objects = []

self.assertTrue(node.bind("foo", "13", "Impossible", objects))
self.assertEqual(DerivedElement("foo", 13), objects[-1][1])

def test_bind_nillable_content(self):
var = XmlVarFactory.create(
xml_type=XmlType.TEXT, name="foo", qname="foo", types=(str,), nillable=False
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(var, ns_map, False, DerivedElement)
node = PrimitiveNode(var, ns_map, False)
objects = []

self.assertTrue(node.bind("foo", None, None, objects))
Expand All @@ -66,7 +54,7 @@ def test_bind_nillable_bytes_content(self):
nillable=False,
)
ns_map = {"foo": "bar"}
node = PrimitiveNode(var, ns_map, False, DerivedElement)
node = PrimitiveNode(var, ns_map, False)
objects = []

self.assertTrue(node.bind("foo", None, None, objects))
Expand All @@ -77,29 +65,25 @@ def test_bind_nillable_bytes_content(self):
self.assertIsNone(objects[-1][1])

def test_bind_mixed_with_tail_content(self):
var = XmlVarFactory.create(
xml_type=XmlType.TEXT, name="foo", types=(int,), derived=True
)
node = PrimitiveNode(var, {}, True, DerivedElement)
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(int,))
node = PrimitiveNode(var, {}, True)
objects = []

self.assertTrue(node.bind("foo", "13", "tail", objects))
self.assertEqual((None, "tail"), objects[-1])
self.assertEqual(DerivedElement("foo", 13), objects[-2][1])
self.assertEqual(13, objects[-2][1])

def test_bind_mixed_without_tail_content(self):
var = XmlVarFactory.create(
xml_type=XmlType.TEXT, name="foo", types=(int,), derived=True
)
node = PrimitiveNode(var, {}, True, DerivedElement)
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", types=(int,))
node = PrimitiveNode(var, {}, True)
objects = []

self.assertTrue(node.bind("foo", "13", "", objects))
self.assertEqual(DerivedElement("foo", 13), objects[-1][1])
self.assertEqual(13, objects[-1][1])

def test_child(self):
var = XmlVarFactory.create(xml_type=XmlType.TEXT, name="foo", qname="foo")
node = PrimitiveNode(var, {}, False, DerivedElement)
node = PrimitiveNode(var, {}, False)

with self.assertRaises(XmlContextError):
node.child("foo", {}, {}, 0)
16 changes: 4 additions & 12 deletions tests/formats/dataclass/parsers/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,17 +232,16 @@ def test_bind_simple_type_with_wildcard_var(self):
self.assertEqual(2, actual.wildcard)

def test_bind_simple_type_with_elements_var(self):
data = {"choice": ["1.0", 1, ["1"], "a", "{a}b"]}
data = {"choice": ["1.0", 1, "a", "{a}b"]}

actual = self.decoder.bind_dataclass(data, ChoiceType)

self.assertEqual(1.0, actual.choice[0])
self.assertEqual(1, actual.choice[1])
self.assertEqual([1], actual.choice[2])
self.assertEqual(QName("a"), actual.choice[3])
self.assertEqual(QName("a"), actual.choice[2])
self.assertIsInstance(actual.choice[2], QName)
self.assertEqual(QName("{a}b"), actual.choice[3])
self.assertIsInstance(actual.choice[3], QName)
self.assertEqual(QName("{a}b"), actual.choice[4])
self.assertIsInstance(actual.choice[4], QName)

data = {"choice": ["!NotAQname"]}
with self.assertRaises(ParserError) as cm:
Expand Down Expand Up @@ -278,13 +277,6 @@ def test_bind_choice_dataclass(self):
expected = ChoiceType(choice=[TypeA(x=1), TypeB(x=1, y="a")])
self.assertEqual(expected, self.decoder.bind_dataclass(data, ChoiceType))

def test_bind_derived_value_with_simple_type(self):
data = {"choice": [{"qname": "int2", "value": 1, "type": None}]}

actual = self.decoder.bind_dataclass(data, ChoiceType)
expected = ChoiceType(choice=[DerivedElement(qname="int2", value=1)])
self.assertEqual(expected, actual)

def test_bind_derived_value_with_choice_var(self):
data = {
"choice": [
Expand Down
Loading

0 comments on commit f3f1ca3

Please sign in to comment.