Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Validate min < max occurs #979

Merged
merged 1 commit into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions tests/codegen/parsers/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,15 +416,11 @@ def test_end_schema(
schema.elements.append(Element())
schema.elements.append(Element())

for el in schema.elements:
self.assertEqual(1, el.min_occurs)
self.assertEqual(1, el.max_occurs)

self.parser.end_schema(schema)

for el in schema.elements:
self.assertIsNone(el.min_occurs)
self.assertIsNone(el.max_occurs)
self.assertEqual(1, el.min_occurs)
self.assertEqual(1, el.max_occurs)

self.parser.end_schema(ComplexType())

Expand Down
10 changes: 10 additions & 0 deletions tests/models/xsd/test_all.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import sys
from unittest import TestCase

from xsdata.models.xsd import All


class AllTests(TestCase):
def test_normalize_max_occurs(self):
obj = All(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = All(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_get_restrictions(self):
obj = All(min_occurs=1, max_occurs=2)
self.assertEqual({"path": [("a", id(obj), 1, 2)]}, obj.get_restrictions())
10 changes: 10 additions & 0 deletions tests/models/xsd/test_any.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import sys
from unittest import TestCase

from xsdata.models.enums import Namespace, NamespaceType
from xsdata.models.xsd import Any


class AnyTests(TestCase):
def test_normalize_max_occurs(self):
obj = Any(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Any(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_property_is_property(self):
self.assertTrue(Any().is_property)

Expand Down
14 changes: 9 additions & 5 deletions tests/models/xsd/test_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@


class ChoiceTests(TestCase):
def test_normalize_max_occurs(self):
obj = Choice(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Choice(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_get_restrictions(self):
obj = Choice(min_occurs=1, max_occurs=2)
self.assertEqual({"path": [("c", id(obj), 1, 2)]}, obj.get_restrictions())

obj = Choice(max_occurs="unbounded")
self.assertEqual(
{"path": [("c", id(obj), 1, sys.maxsize)]}, obj.get_restrictions()
)
10 changes: 10 additions & 0 deletions tests/models/xsd/test_element.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from unittest import TestCase

from xsdata.codegen.exceptions import CodegenError
Expand All @@ -13,6 +14,15 @@


class ElementTests(TestCase):
def test_normalize_max_occurs(self):
obj = Element(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Element(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_property_is_property(self):
obj = Element()
self.assertTrue(obj)
Expand Down
10 changes: 10 additions & 0 deletions tests/models/xsd/test_group.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import sys
from unittest import TestCase

from xsdata.models.xsd import Group


class GroupTests(TestCase):
def test_normalize_max_occurs(self):
obj = Group(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Group(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_property_is_property(self):
obj = Group()
self.assertTrue(obj.is_property)
Expand Down
14 changes: 9 additions & 5 deletions tests/models/xsd/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@


class SequenceTests(TestCase):
def test_normalize_max_occurs(self):
obj = Sequence(min_occurs=3, max_occurs=2)
self.assertEqual(3, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

obj = Sequence(min_occurs=3, max_occurs="unbounded")
self.assertEqual(sys.maxsize, obj.max_occurs)
self.assertEqual(3, obj.min_occurs)

def test_get_restrictions(self):
obj = Sequence(min_occurs=1, max_occurs=2)
self.assertEqual({"path": [("s", id(obj), 1, 2)]}, obj.get_restrictions())

obj = Sequence(min_occurs=1, max_occurs="unbounded")
self.assertEqual(
{"path": [("s", id(obj), 1, sys.maxsize)]}, obj.get_restrictions()
)
6 changes: 5 additions & 1 deletion xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ def asdict(self, types: Optional[List[Type]] = None) -> Dict:
result["min_occurs"] = self.min_occurs
if self.max_occurs is not None and self.max_occurs < sys.maxsize:
result["max_occurs"] = self.max_occurs
elif self.min_occurs == self.max_occurs == 1 and not self.nillable:
elif (
self.min_occurs == self.max_occurs == 1
and not self.nillable
and not self.tokens
):
result["required"] = True

for key, value in asdict(self).items():
Expand Down
14 changes: 0 additions & 14 deletions xsdata/codegen/parsers/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def end_schema(self, obj: T):
self.set_schema_namespaces(obj)
self.add_default_imports(obj)
self.resolve_schemas_locations(obj)
self.reset_element_occurs(obj)

def end_attribute(self, obj: T):
"""End attribute element entrypoint.
Expand Down Expand Up @@ -411,16 +410,3 @@ def add_default_imports(cls, obj: xsd.Schema):
xsi_ns = Namespace.XSI.uri
if xsi_ns in obj.ns_map.values() and xsi_ns not in imp_namespaces:
obj.imports.insert(0, xsd.Import(namespace=xsi_ns))

@classmethod
def reset_element_occurs(cls, obj: xsd.Schema):
"""Reset the root elements occurs restrictions.

The root elements don't get those.

Args:
obj: The xsd schema instance
"""
for element in obj.elements:
element.min_occurs = None
element.max_occurs = None
69 changes: 45 additions & 24 deletions xsdata/models/xsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@
)


def validate_max_occurs(min_occurs: int, max_occurs: UnionType[str, int]) -> int:
"""Validate max occurs."""
if max_occurs == "unbounded":
max_occurs = sys.maxsize

assert isinstance(max_occurs, int)

return max(max_occurs, min_occurs)


@dataclass(frozen=True)
class Docstring:
"""Docstring model representation.
Expand Down Expand Up @@ -121,7 +131,7 @@ class AnyAttribute(AnnotationBase):
)

def __post_init__(self):
"""Clean the namespace value."""
"""Post initialization validations."""
self.namespace = " ".join(unique_sequence(self.namespace.split()))

@property
Expand Down Expand Up @@ -351,14 +361,15 @@ class Any(AnnotationBase):

namespace: str = attribute(default="##any")
min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
process_contents: ProcessType = attribute(
default=ProcessType.STRICT, name="processContents"
)

def __post_init__(self):
"""Clean the namespace value."""
"""Post initialization validations."""
self.namespace = " ".join(unique_sequence(self.namespace.split()))
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

@property
def is_property(self) -> bool:
Expand Down Expand Up @@ -397,17 +408,19 @@ class All(AnnotationBase):
"""XSD All model representation."""

min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
any: Array[Any] = array_element(name="any")
elements: Array["Element"] = array_element(name="element")
groups: Array["Group"] = array_element(name="group")

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

return {
"path": [("a", id(self), self.min_occurs, max_occurs)],
"path": [("a", id(self), self.min_occurs, self.max_occurs)],
}


Expand All @@ -416,19 +429,21 @@ class Sequence(AnnotationBase):
"""XSD Sequence model representation."""

min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
elements: Array["Element"] = array_element(name="element")
groups: Array["Group"] = array_element(name="group")
choices: Array["Choice"] = array_element(name="choice")
sequences: Array["Sequence"] = array_element(name="sequence")
any: Array["Any"] = array_element()

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

return {
"path": [("s", id(self), self.min_occurs, max_occurs)],
"path": [("s", id(self), self.min_occurs, self.max_occurs)],
}


Expand All @@ -437,19 +452,21 @@ class Choice(AnnotationBase):
"""XSD Choice model representation."""

min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
elements: Array["Element"] = array_element(name="element")
groups: Array["Group"] = array_element(name="group")
choices: Array["Choice"] = array_element(name="choice")
sequences: Array[Sequence] = array_element(name="sequence")
any: Array["Any"] = array_element()

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

return {
"path": [("c", id(self), self.min_occurs, max_occurs)],
"path": [("c", id(self), self.min_occurs, self.max_occurs)],
}


Expand All @@ -460,11 +477,15 @@ class Group(AnnotationBase):
name: Optional[str] = attribute()
ref: str = attribute(default="")
min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
all: Optional[All] = element()
choice: Optional[Choice] = element()
sequence: Optional[Sequence] = element()

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

@property
def is_property(self) -> bool:
"""Specify it is qualified to be a class property."""
Expand All @@ -478,10 +499,8 @@ def attr_types(self) -> Iterator[str]:

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

return {
"path": [("g", id(self), self.min_occurs, max_occurs)],
"path": [("g", id(self), self.min_occurs, self.max_occurs)],
}


Expand Down Expand Up @@ -862,11 +881,15 @@ class Element(AnnotationBase):
uniques: Array[Unique] = array_element(name="unique")
keys: Array[Key] = array_element(name="key")
keyrefs: Array[Keyref] = array_element(name="keyref")
min_occurs: Optional[int] = attribute(default=1, name="minOccurs")
max_occurs: UnionType[None, int, str] = attribute(default=1, name="maxOccurs")
min_occurs: int = attribute(default=1, name="minOccurs")
max_occurs: UnionType[str, int] = attribute(default=1, name="maxOccurs")
nillable: bool = attribute(default=False)
abstract: bool = attribute(default=False)

def __post_init__(self):
"""Post initialization validations."""
self.max_occurs = validate_max_occurs(self.min_occurs, self.max_occurs)

@property
def bases(self) -> Iterator[str]:
"""Return an iterator of all the base types."""
Expand Down Expand Up @@ -910,11 +933,9 @@ def substitutions(self) -> Array[str]:

def get_restrictions(self) -> Dict[str, Anything]:
"""Return the restrictions dictionary of this element."""
max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs

restrictions = {
"min_occurs": self.min_occurs,
"max_occurs": max_occurs,
"max_occurs": self.max_occurs,
}

if self.simple_type:
Expand Down
Loading