From 9eae7f9afe3765b1020519e5f60a7477b3cf9682 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Sat, 9 Mar 2024 20:50:16 +0200 Subject: [PATCH 1/6] feat: Add config to generate wrapper fields --- tests/models/test_config.py | 2 ++ xsdata/models/config.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/models/test_config.py b/tests/models/test_config.py index 734d3bb86..656ae7b76 100644 --- a/tests/models/test_config.py +++ b/tests/models/test_config.py @@ -36,6 +36,7 @@ def test_create(self): " reStructuredText\n" " false\n" ' false\n' + " false\n" " false\n" " false\n" " false\n" @@ -96,6 +97,7 @@ def test_read(self): " reStructuredText\n" " false\n" ' false\n' + " false\n" " false\n" " false\n" " false\n" diff --git a/xsdata/models/config.py b/xsdata/models/config.py index 52364a055..08e2f934f 100644 --- a/xsdata/models/config.py +++ b/xsdata/models/config.py @@ -221,6 +221,7 @@ class GeneratorOutput: docstring_style: Docstring style relative_imports: Use relative imports compound_fields: Use compound fields for repeatable elements + wrapper_fields: Generate wrapper fields for element lists max_line_length: Adjust the maximum line length subscriptable_types: Use PEP-585 generics for standard collections, python>=3.9 Only @@ -239,6 +240,7 @@ class GeneratorOutput: docstring_style: DocstringStyle = element(default=DocstringStyle.RST) relative_imports: bool = element(default=False) compound_fields: CompoundFields = element(default_factory=CompoundFields) + wrapper_fields: bool = element(default=False) max_line_length: int = attribute(default=79) subscriptable_types: bool = attribute(default=False) union_type: bool = attribute(default=False) From 7cfc3afaf9b8b4abf16a16933063a44c60651b79 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Sat, 16 Mar 2024 17:28:34 +0200 Subject: [PATCH 2/6] feat: Add handler to generate wrapper fields --- .../handlers/test_create_wrapper_fields.py | 112 ++++++++++++++++ .../handlers/test_disambiguate_choices.py | 4 +- tests/codegen/test_container.py | 1 + tests/formats/dataclass/test_filters.py | 7 + xsdata/codegen/container.py | 2 + xsdata/codegen/handlers/__init__.py | 2 + .../handlers/create_compound_fields.py | 26 ++-- .../codegen/handlers/create_wrapper_fields.py | 124 ++++++++++++++++++ .../codegen/handlers/disambiguate_choices.py | 5 +- xsdata/codegen/models.py | 1 + xsdata/codegen/utils.py | 9 +- xsdata/formats/dataclass/filters.py | 1 + xsdata/utils/testing.py | 24 +++- 13 files changed, 293 insertions(+), 25 deletions(-) create mode 100644 tests/codegen/handlers/test_create_wrapper_fields.py create mode 100644 xsdata/codegen/handlers/create_wrapper_fields.py diff --git a/tests/codegen/handlers/test_create_wrapper_fields.py b/tests/codegen/handlers/test_create_wrapper_fields.py new file mode 100644 index 000000000..81222faf2 --- /dev/null +++ b/tests/codegen/handlers/test_create_wrapper_fields.py @@ -0,0 +1,112 @@ +from xsdata.codegen.container import ClassContainer +from xsdata.codegen.handlers import CreateWrapperFields +from xsdata.models.config import GeneratorConfig +from xsdata.models.enums import DataType, Tag +from xsdata.utils.testing import ( + AttrFactory, + AttrTypeFactory, + ClassFactory, + ExtensionFactory, + FactoryTestCase, +) + + +class CreateWrapperFieldsTests(FactoryTestCase): + def setUp(self): + super().setUp() + + self.config = GeneratorConfig() + self.config.output.wrapper_fields = True + self.container = ClassContainer(config=self.config) + self.processor = CreateWrapperFields(container=self.container) + + self.target = ClassFactory.create() + self.target.attrs.append( + AttrFactory.reference("foo", name="items", tag=Tag.ELEMENT) + ) + + self.source = ClassFactory.create(qname="foo") + self.source.attrs.append(AttrFactory.native(DataType.STRING, name="item")) + self.container.extend([self.target, self.source]) + + def test_process_skip_with_config_disabled(self): + self.config.output.wrapper_fields = False + self.processor.process(self.target) + self.assertIsNone(self.target.attrs[0].wrapper) + + def test_process_with_valid_attr_wrapper(self): + self.processor.process(self.target) + self.assertEqual("items", self.target.attrs[0].wrapper) + + def test_process_with_invalid_attr(self): + self.target.attrs[0].tag = Tag.EXTENSION + self.processor.process(self.target) + + self.assertIsNone(self.target.attrs[0].wrapper) + + def test_process_with_invalid_source(self): + self.source.extensions.append(ExtensionFactory.create()) + self.processor.process(self.target) + + self.assertIsNone(self.target.attrs[0].wrapper) + + def test_wrap_field(self): + source = AttrFactory.create() + attr = AttrFactory.create() + wrapper = attr.local_name + + self.processor.wrap_field(source, attr) + self.assertEqual(source.name, attr.name) + self.assertEqual(source.local_name, attr.local_name) + self.assertEqual(wrapper, attr.wrapper) + + def test_find_source_with_forward_reference(self): + tp = self.target.attrs[0].types[0] + tp.forward = True + self.target.inner.append(self.source) + + actual = self.processor.find_source(self.target, tp) + self.assertEqual(self.source, actual) + + def test_validate_attr(self): + # Not an element + attr = AttrFactory.create(tag=Tag.EXTENSION) + self.assertFalse(self.processor.validate_attr(attr)) + + # Multiple types + attr.tag = Tag.ELEMENT + attr.types = AttrTypeFactory.list(2) + self.assertFalse(self.processor.validate_attr(attr)) + + # Native type + attr.types = [AttrTypeFactory.native(DataType.STRING)] + self.assertFalse(self.processor.validate_attr(attr)) + + # Not any of the above issues + attr.types = [AttrTypeFactory.create()] + self.assertTrue(self.processor.validate_attr(attr)) + + def test_validate_source(self): + source = ClassFactory.create() + + # Has extensions + source.extensions = ExtensionFactory.list(1) + self.assertFalse(self.processor.validate_source(source, None)) + + # Has multiple attrs + source.extensions.clear() + source.attrs = AttrFactory.list(2) + self.assertFalse(self.processor.validate_source(source, None)) + + # Has forwarded references + source.attrs.pop(0) + source.attrs[0].types[0].forward = True + self.assertFalse(self.processor.validate_source(source, None)) + + # Namespace doesn't match + source.attrs[0].types[0].forward = False + self.assertFalse(self.processor.validate_source(source, "bar")) + + # Not any of the above issues + source.attrs[0].namespace = "bar" + self.assertTrue(self.processor.validate_source(source, "bar")) diff --git a/tests/codegen/handlers/test_disambiguate_choices.py b/tests/codegen/handlers/test_disambiguate_choices.py index 1a99e0afe..e88759230 100644 --- a/tests/codegen/handlers/test_disambiguate_choices.py +++ b/tests/codegen/handlers/test_disambiguate_choices.py @@ -96,8 +96,8 @@ def test_process_with_duplicate_complex_types(self): compound = AttrFactory.any() target = ClassFactory.create() target.attrs.append(compound) - compound.choices.append(AttrFactory.reference(name="a", qname="myint")) - compound.choices.append(AttrFactory.reference(name="b", qname="myint")) + compound.choices.append(AttrFactory.reference(qname="myint")) + compound.choices.append(AttrFactory.reference(qname="myint")) self.container.add(target) self.handler.process(target) diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index 0fe89c264..a4804d80f 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -57,6 +57,7 @@ def test_initialize(self): 60: [ "DetectCircularReferences", "CreateCompoundFields", + "CreateWrapperFields", "DisambiguateChoices", "ResetAttributeSequenceNumbers", ], diff --git a/tests/formats/dataclass/test_filters.py b/tests/formats/dataclass/test_filters.py index 4b54920ac..3569a852f 100644 --- a/tests/formats/dataclass/test_filters.py +++ b/tests/formats/dataclass/test_filters.py @@ -467,6 +467,13 @@ def test_field_metadata_name(self): actual = self.filters.field_metadata(attr, None, ["cls"]) self.assertNotIn("name", actual) + def test_field_metadata_wrapper(self): + attr = AttrFactory.element(wrapper="foo") + expected = {"name": "attr_B", "wrapper": "foo", "type": "Element"} + + actual = self.filters.field_metadata(attr, None, ["cls"]) + self.assertEqual(expected, actual) + def test_field_metadata_restrictions(self): attr = AttrFactory.create(tag=Tag.RESTRICTION) attr.types.append(AttrTypeFactory.native(DataType.INT)) diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index afaba2e98..2a4a8bb5b 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -4,6 +4,7 @@ AddAttributeSubstitutions, CalculateAttributePaths, CreateCompoundFields, + CreateWrapperFields, DesignateClassPackages, DetectCircularReferences, DisambiguateChoices, @@ -99,6 +100,7 @@ def __init__(self, config: GeneratorConfig): Steps.FINALIZE: [ DetectCircularReferences(self), CreateCompoundFields(self), + CreateWrapperFields(self), DisambiguateChoices(self), ResetAttributeSequenceNumbers(self), ], diff --git a/xsdata/codegen/handlers/__init__.py b/xsdata/codegen/handlers/__init__.py index b732a6643..eae3879b7 100644 --- a/xsdata/codegen/handlers/__init__.py +++ b/xsdata/codegen/handlers/__init__.py @@ -1,6 +1,7 @@ from .add_attribute_substitutions import AddAttributeSubstitutions from .calculate_attribute_paths import CalculateAttributePaths from .create_compound_fields import CreateCompoundFields +from .create_wrapper_fields import CreateWrapperFields from .designate_class_packages import DesignateClassPackages from .detect_circular_references import DetectCircularReferences from .disambiguate_choices import DisambiguateChoices @@ -26,6 +27,7 @@ "AddAttributeSubstitutions", "CalculateAttributePaths", "CreateCompoundFields", + "CreateWrapperFields", "DesignateClassPackages", "DetectCircularReferences", "DisambiguateChoices", diff --git a/xsdata/codegen/handlers/create_compound_fields.py b/xsdata/codegen/handlers/create_compound_fields.py index 558e96a20..2f1eb2ec5 100644 --- a/xsdata/codegen/handlers/create_compound_fields.py +++ b/xsdata/codegen/handlers/create_compound_fields.py @@ -133,24 +133,20 @@ def group_fields(self, target: Class, attrs: List[Attr]): min_occurs, max_occurs = self.sum_counters(counters) name = self.choose_name(target, names, list(filter(None, substitutions))) - types = collections.unique_sequence( - t.clone() for attr in attrs for t in attr.types - ) - target.attrs.insert( - pos, - Attr( - name=name, - index=0, - types=types, - tag=Tag.CHOICE, - restrictions=Restrictions( - min_occurs=sum(min_occurs), - max_occurs=max(max_occurs) if choice > 0 else sum(max_occurs), - ), - choices=choices, + compound_attr = Attr( + name=name, + index=0, + types=[], + tag=Tag.CHOICE, + restrictions=Restrictions( + min_occurs=sum(min_occurs), + max_occurs=max(max_occurs) if choice > 0 else sum(max_occurs), ), + choices=choices, ) + ClassUtils.reset_choice_types(compound_attr) + target.attrs.insert(pos, compound_attr) def sum_counters(self, counters: Dict) -> Tuple[List[int], List[int]]: """Sum the min/max occurrences for the compound attr. diff --git a/xsdata/codegen/handlers/create_wrapper_fields.py b/xsdata/codegen/handlers/create_wrapper_fields.py new file mode 100644 index 000000000..7cce5827e --- /dev/null +++ b/xsdata/codegen/handlers/create_wrapper_fields.py @@ -0,0 +1,124 @@ +from typing import Optional + +from xsdata.codegen.mixins import RelativeHandlerInterface +from xsdata.codegen.models import Attr, AttrType, Class + + +class CreateWrapperFields(RelativeHandlerInterface): + """Create wrapper fields. + + Args: + container: The class container instance + """ + + def process(self, target: Class): + """Process the given class attrs and choices. + + Args: + target: The target class instance + """ + if not self.container.config.output.wrapper_fields: + return + + for attr in target.attrs: + if self.validate_attr(attr): + self.process_attr(target, attr) + + def process_attr(self, target: Class, attr: Attr): + """Process the given attr instance. + + Args: + target: The parent class instance + attr: The attr instance to process + """ + source = self.find_source(target, attr.types[0]) + if self.validate_source(source, attr.namespace): + self.wrap_field(source.attrs[0], attr) + + @classmethod + def wrap_field(cls, source: Attr, attr: Attr): + """Create a wrapper field. + + Clone the source attr and update its name, local name and wrapper + attributes. + + Args: + source: The source attr instance + attr: The attr instance to wrap + """ + wrapper = attr.local_name + + attr.swap(source) + attr.wrapper = wrapper + + def find_source(self, parent: Class, tp: AttrType) -> Class: + """Find the source type for the given attr type instance. + + If it's a forward reference, look up the source in + the parent class inners. + + Args: + parent: The parent class instance + tp: The attr type instance to look up + + Returns: + The source class instance that matches the attr type. + """ + if tp.forward: + return self.container.find_inner(parent, tp.qname) + + return self.container.first(tp.qname) + + @classmethod + def validate_attr(cls, attr: Attr) -> bool: + """Validate if the attr can be converted to a wrapper field. + + Rules: + 1. Must be an element + 2. Must have only one type + 3. It has to be a user type + 4. The element can't be optional + + + Args: + attr: The attr instance to validate + + Returns: + Whether the attr can be converted to a wrapper. + + """ + return ( + attr.is_element + and len(attr.types) == 1 + and not attr.types[0].native + and not attr.is_optional + ) + + @classmethod + def validate_source(cls, source: Class, namespace: Optional[str]) -> bool: + """Validate if the source class can be converted to a wrapper field. + + Rules: + 1. It must not have any extensions + 2. It must contain exactly one type + 3. It must not be a forward reference + 4. The source attr namespace must match the namespace + + + Args: + source: The source class instance to validate + namespace: The processing attr namespace + + Returns: + Whether the source class can be converted to a wrapper. + """ + + def ns_equal(a: Optional[str], b: Optional[str]): + return (a or "") == (b or "") + + return ( + not source.extensions + and len(source.attrs) == 1 + and not source.attrs[0].is_forward_ref + and ns_equal(source.attrs[0].namespace, namespace) + ) diff --git a/xsdata/codegen/handlers/disambiguate_choices.py b/xsdata/codegen/handlers/disambiguate_choices.py index 12dfe77d4..b9acaf901 100644 --- a/xsdata/codegen/handlers/disambiguate_choices.py +++ b/xsdata/codegen/handlers/disambiguate_choices.py @@ -3,6 +3,7 @@ from xsdata.codegen.mixins import ContainerInterface, RelativeHandlerInterface from xsdata.codegen.models import Attr, AttrType, Class, Extension, Restrictions +from xsdata.codegen.utils import ClassUtils from xsdata.models.enums import DataType, Tag from xsdata.utils import collections, text from xsdata.utils.constants import DEFAULT_ATTR_NAME @@ -62,9 +63,7 @@ def process_compound_field(self, target: Class, attr: Attr): for choice in self.find_ambiguous_choices(attr): self.disambiguate_choice(target, choice) - if attr.tag == Tag.CHOICE: - types = (tp for choice in attr.choices for tp in choice.types) - attr.types = collections.unique_sequence(x.clone() for x in types) + ClassUtils.reset_choice_types(attr) @classmethod def merge_wildcard_choices(cls, attr: Attr): diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 90b406fd6..83035595e 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -287,6 +287,7 @@ class Attr(CodegenModel): tag: str name: str = field(compare=False) local_name: str = field(init=False) + wrapper: Optional[str] = field(default=None) index: int = field(compare=False, default_factory=int) default: Optional[str] = field(default=None, compare=False) fixed: bool = field(default=False, compare=False) diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py index 91815ee7f..535b03afc 100644 --- a/xsdata/codegen/utils.py +++ b/xsdata/codegen/utils.py @@ -12,7 +12,7 @@ get_qname, get_slug, ) -from xsdata.models.enums import DataType +from xsdata.models.enums import DataType, Tag from xsdata.utils import collections, namespaces, text @@ -451,6 +451,13 @@ def unique_name(cls, name: str, reserved: Set[str]) -> str: return name + @classmethod + def reset_choice_types(cls, attr: Attr): + """Reset the choice types.""" + if attr.tag == Tag.CHOICE: + types = (tp for choice in attr.choices for tp in choice.types) + attr.types = collections.unique_sequence(x.clone() for x in types) + @classmethod def cleanup_class(cls, target: Class): """Go through the target class attrs and filter their types. diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index f51110c20..7b84db6cd 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -441,6 +441,7 @@ def field_metadata( restrictions = attr.restrictions.asdict(attr.native_types) metadata = { + "wrapper": attr.wrapper, "name": name, "type": attr.xml_type, "namespace": namespace, diff --git a/xsdata/utils/testing.py b/xsdata/utils/testing.py index 11e6bc882..6c272846f 100644 --- a/xsdata/utils/testing.py +++ b/xsdata/utils/testing.py @@ -267,6 +267,7 @@ def create( choices: Optional[List[Attr]] = None, tag: Optional[str] = None, namespace: Optional[str] = None, + wrapper: Optional[str] = None, default: Optional[Any] = None, fixed: bool = False, mixed: bool = False, @@ -282,6 +283,7 @@ def create( choices=choices or [], tag=tag or random.choice(cls.tags), namespace=namespace, + wrapper=wrapper, default=default, fixed=fixed, mixed=mixed, @@ -290,14 +292,28 @@ def create( ) @classmethod - def reference(cls, qname: str, tag: str = Tag.ELEMENT, **kwargs: Any) -> Attr: + def reference( + cls, + qname: str, + tag: str = Tag.ELEMENT, + name: Optional[str] = None, + **kwargs: Any, + ) -> Attr: return cls.create( - tag=tag, types=[AttrTypeFactory.create(qname=qname, **kwargs)] + name=name, tag=tag, types=[AttrTypeFactory.create(qname=qname, **kwargs)] ) @classmethod - def native(cls, datatype: DataType, tag: str = Tag.ELEMENT, **kwargs: Any) -> Attr: - return cls.create(tag=tag, types=[AttrTypeFactory.native(datatype)], **kwargs) + def native( + cls, + datatype: DataType, + tag: str = Tag.ELEMENT, + name: Optional[str] = None, + **kwargs: Any, + ) -> Attr: + return cls.create( + name=name, tag=tag, types=[AttrTypeFactory.native(datatype)], **kwargs + ) @classmethod def enumeration(cls, **kwargs: Any) -> Attr: From 68308b7731fa7d2d833110db2863896fd66d4e24 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Wed, 20 Mar 2024 15:20:59 +0200 Subject: [PATCH 3/6] feat: Support wrapper fields in json parser/serializer --- .../formats/dataclass/models/test_builders.py | 41 +------------------ tests/formats/dataclass/parsers/test_dict.py | 16 ++++++-- .../dataclass/serializers/test_dict.py | 17 +++++--- xsdata/formats/dataclass/models/builders.py | 14 ++----- xsdata/formats/dataclass/models/elements.py | 7 +++- xsdata/formats/dataclass/parsers/dict.py | 27 ++++++++---- xsdata/formats/dataclass/serializers/dict.py | 25 ++++++++--- .../formats/dataclass/serializers/mixins.py | 16 ++++---- 8 files changed, 81 insertions(+), 82 deletions(-) diff --git a/tests/formats/dataclass/models/test_builders.py b/tests/formats/dataclass/models/test_builders.py index c8129ffe1..a1547270a 100644 --- a/tests/formats/dataclass/models/test_builders.py +++ b/tests/formats/dataclass/models/test_builders.py @@ -3,7 +3,7 @@ import uuid from dataclasses import dataclass, field, fields, make_dataclass from decimal import Decimal -from typing import Dict, Iterator, List, Tuple, Union, get_type_hints +from typing import Dict, Iterator, List, Union, get_type_hints from unittest import TestCase, mock from xml.etree.ElementTree import QName @@ -23,7 +23,7 @@ from xsdata.exceptions import XmlContextError from xsdata.formats.dataclass.compat import class_types from xsdata.formats.dataclass.models.builders import XmlMetaBuilder, XmlVarBuilder -from xsdata.formats.dataclass.models.elements import XmlMeta, XmlType +from xsdata.formats.dataclass.models.elements import XmlType from xsdata.models.datatype import XmlDate from xsdata.utils import text from xsdata.utils.constants import return_input @@ -96,43 +96,6 @@ class Meta: result = self.builder.build(Thug, None) self.assertEqual("thug", result.qname) - def test_wrapper(self): - @dataclass - class PrimitiveType: - attr: str = field(metadata={"wrapper": "Items"}) - - @dataclass - class UnionType: - attr: Union[str, int] = field(metadata={"wrapper": "Items"}) - - @dataclass - class UnionCollection: - union_collection: List[Union[str, int]] = field( - metadata={"wrapper": "Items"} - ) - - @dataclass - class ListType: - attr: List[str] = field(metadata={"wrapper": "Items"}) - - @dataclass - class TupleType: - attr: Tuple[str, ...] = field(metadata={"wrapper": "Items"}) - - # @dataclass - # class SetType: - # attr: Set[str] = field(metadata={"wrapper": "Items"}) - - with self.assertRaises(XmlContextError): - self.builder.build(PrimitiveType, None) - with self.assertRaises(XmlContextError): - self.builder.build(UnionType, None) - - self.assertIsInstance(self.builder.build(ListType, None), XmlMeta) - self.assertIsInstance(self.builder.build(TupleType, None), XmlMeta) - # not supported by analyze_types - # self.assertIsInstance(self.builder.build(SetType, None), XmlMeta) - def test_build_with_no_dataclass_raises_exception(self, *args): with self.assertRaises(XmlContextError) as cm: self.builder.build(int, None) diff --git a/tests/formats/dataclass/parsers/test_dict.py b/tests/formats/dataclass/parsers/test_dict.py index 61a687237..548a86edc 100644 --- a/tests/formats/dataclass/parsers/test_dict.py +++ b/tests/formats/dataclass/parsers/test_dict.py @@ -19,6 +19,7 @@ TypeD, UnionType, ) +from tests.fixtures.wrapper import Wrapper from xsdata.exceptions import ParserError from xsdata.formats.dataclass.models.generics import AnyElement, DerivedElement from xsdata.formats.dataclass.parsers import DictDecoder @@ -104,6 +105,13 @@ def test_decode_with_fail_on_converter_warnings(self): str(cm.exception), ) + def test_decode_wrapper(self): + data = {"alphas": {"alpha": "value"}} + + actual = self.decoder.decode(data, Wrapper) + expected = Wrapper(alpha="value") + self.assertEqual(expected, actual) + def test_verify_type(self): invalid_cases = [ ( @@ -368,10 +376,10 @@ def test_find_var(self): meta = self.decoder.context.build(TypeB) xml_vars = meta.get_all_vars() - self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "x")) - self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "x", True)) + self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "x", 1)) + self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "x", [1, 2])) meta = self.decoder.context.build(ExtendedType) xml_vars = meta.get_all_vars() - self.assertIsNone(self.decoder.find_var(xml_vars, "a", True)) - self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "a")) + self.assertIsNone(self.decoder.find_var(xml_vars, "a", [1, 2])) + self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "a", {"x": 1})) diff --git a/tests/formats/dataclass/serializers/test_dict.py b/tests/formats/dataclass/serializers/test_dict.py index 5f5f95c61..d87674065 100644 --- a/tests/formats/dataclass/serializers/test_dict.py +++ b/tests/formats/dataclass/serializers/test_dict.py @@ -2,7 +2,7 @@ from tests.fixtures.books import BookForm, Books from tests.fixtures.datatypes import Telephone -from xsdata.exceptions import XmlContextError +from tests.fixtures.wrapper import Wrapper from xsdata.formats.dataclass.serializers import DictEncoder, DictFactory from xsdata.models.datatype import XmlDate from xsdata.models.xsd import Attribute @@ -61,10 +61,6 @@ def test_encode(self): actual = self.encoder.encode(self.books) self.assertEqual(self.expected, actual) - def test_encode_a_none_dataclass_object(self): - with self.assertRaises(XmlContextError): - DictEncoder().encode(1) - def test_encode_list_of_objects(self): actual = self.encoder.encode(self.books.book) self.assertEqual(self.expected["book"], actual) @@ -80,6 +76,17 @@ def test_convert_namedtuple(self): actual = self.encoder.encode(Telephone(30, 234, 56783), var) self.assertEqual("30-234-56783", actual) + def test_convert_wrapper(self): + obj = Wrapper(alpha=["value"]) + value = self.encoder.encode(obj) + expected = { + "alphas": {"alpha": ["value"]}, + "bravos": {"bravo": []}, + "charlies": {"charlie": []}, + } + + self.assertEqual(expected, value) + def test_next_value(self): book = self.books.book[0] diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index 85a7e18bf..6d4a9a58f 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -120,19 +120,19 @@ def build(self, clazz: Type, parent_namespace: Optional[str]) -> XmlMeta: attributes = {} elements: Dict[str, List[XmlVar]] = defaultdict(list) + wrappers: Dict[str, str] = {} choices = [] any_attributes = [] wildcards = [] - wrappers: Dict[str, List[XmlVar]] = defaultdict(list) text = None for var in class_vars: - if var.wrapper is not None: - wrappers[var.wrapper].append(var) if var.is_attribute: attributes[var.qname] = var elif var.is_element: elements[var.qname].append(var) + if var.wrapper: + wrappers[var.wrapper] = var.qname elif var.is_elements: choices.append(var) elif var.is_attributes: @@ -391,14 +391,6 @@ def build( f"Xml {xml_type} does not support typing `{type_hint}`" ) - if wrapper is not None and ( - not isinstance(origin, type) or not issubclass(origin, (list, set, tuple)) - ): - raise XmlContextError( - f"Error on {model.__qualname__}::{name}: " - f"A wrapper field requires a collection type" - ) - local_name = local_name or self.build_local_name(xml_type, name) if tokens and sub_origin is None: diff --git a/xsdata/formats/dataclass/models/elements.py b/xsdata/formats/dataclass/models/elements.py index c5d9fb006..43919396e 100644 --- a/xsdata/formats/dataclass/models/elements.py +++ b/xsdata/formats/dataclass/models/elements.py @@ -128,6 +128,7 @@ class XmlVar(MetaMixin): "namespace_matches", "is_clazz_union", "local_name", + "wrapper_local_name", ) def __init__( @@ -184,6 +185,10 @@ def __init__( self.is_clazz_union = self.clazz and len(types) > 1 self.local_name = local_name(qname) + self.wrapper_local_name = None + if wrapper: + self.wrapper_local_name = local_name(wrapper) + self.is_text = False self.is_element = False self.is_elements = False @@ -417,7 +422,7 @@ def __init__( wildcards: Sequence[XmlVar], attributes: Mapping[str, XmlVar], any_attributes: Sequence[XmlVar], - wrappers: Mapping[str, Sequence[XmlVar]], + wrappers: Mapping[str, str], **kwargs: Any, ): self.clazz = clazz diff --git a/xsdata/formats/dataclass/parsers/dict.py b/xsdata/formats/dataclass/parsers/dict.py index 45edcc8b6..cfe9c07ba 100644 --- a/xsdata/formats/dataclass/parsers/dict.py +++ b/xsdata/formats/dataclass/parsers/dict.py @@ -130,13 +130,15 @@ def bind_dataclass(self, data: Dict, clazz: Type[T]) -> T: params = {} for key, value in data.items(): - is_array = collections.is_array(value) - var = self.find_var(xml_vars, key, is_array) + var = self.find_var(xml_vars, key, value) if var is None and self.config.fail_on_unknown_properties: raise ParserError(f"Unknown property {clazz.__qualname__}.{key}") if var and var.init: + if var.wrapper: + value = value[var.local_name] + params[var.name] = self.bind_value(meta, var, value) try: @@ -410,23 +412,30 @@ def bind_derived_value(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any: def find_var( cls, xml_vars: List[XmlVar], - local_name: str, - is_list: bool = False, + key: str, + value: Any, ) -> Optional[XmlVar]: """Match the name to a xml variable. Args: xml_vars: A list of xml vars - local_name: A key from the loaded data - is_list: Whether the data value is an array + key: A key from the loaded data + value: The data assigned to the key Returns: One of the xml vars, if all search attributes match, None otherwise. """ for var in xml_vars: - if var.local_name == local_name: + if var.local_name == key: var_is_list = var.list_element or var.tokens - if is_list == var_is_list or var.clazz is None: + is_array = collections.is_array(value) + if is_array == var_is_list or var.clazz is None: return var - + elif var.wrapper_local_name == key: + if isinstance(value, dict) and var.local_name in value: + val = value[var.local_name] + var_is_list = var.list_element or var.tokens + is_array = collections.is_array(val) + if is_array == var_is_list or var.clazz is None: + return var return None diff --git a/xsdata/formats/dataclass/serializers/dict.py b/xsdata/formats/dataclass/serializers/dict.py index 743a71b98..b453f1f2d 100644 --- a/xsdata/formats/dataclass/serializers/dict.py +++ b/xsdata/formats/dataclass/serializers/dict.py @@ -41,7 +41,9 @@ class DictEncoder: context: XmlContext = field(default_factory=XmlContext) dict_factory: Callable = field(default=dict) - def encode(self, value: Any, var: Optional[XmlVar] = None) -> Any: + def encode( + self, value: Any, var: Optional[XmlVar] = None, wrapped: bool = False + ) -> Any: """Convert a value to a dictionary object. Args: @@ -51,20 +53,30 @@ def encode(self, value: Any, var: Optional[XmlVar] = None) -> Any: Returns: The converted json serializable value. """ - if var is None or self.context.class_type.is_model(value): + + if value is None: + return None + + if var is None: if collections.is_array(value): return list(map(self.encode, value)) return self.dict_factory(self.next_value(value)) + if var and var.wrapper and not wrapped: + return self.dict_factory(((var.local_name, self.encode(value, var, True)),)) + + if self.context.class_type.is_model(value): + return self.dict_factory(self.next_value(value)) + if collections.is_array(value): - return type(value)(self.encode(val, var) for val in value) + return type(value)(self.encode(val, var, wrapped) for val in value) if isinstance(value, (dict, int, float, str, bool)): return value if isinstance(value, Enum): - return self.encode(value.value, var) + return self.encode(value.value, var, wrapped) return converter.serialize(value, format=var.format) @@ -87,4 +99,7 @@ def next_value(self, obj: Any) -> Iterator[Tuple[str, Any]]: or not ignore_optionals or not var.is_optional(value) ): - yield var.local_name, self.encode(value, var) + if var.wrapper: + yield var.wrapper_local_name, self.encode(value, var) + else: + yield var.local_name, self.encode(value, var) diff --git a/xsdata/formats/dataclass/serializers/mixins.py b/xsdata/formats/dataclass/serializers/mixins.py index ca82f2d7f..bbdeb7c76 100644 --- a/xsdata/formats/dataclass/serializers/mixins.py +++ b/xsdata/formats/dataclass/serializers/mixins.py @@ -413,8 +413,14 @@ def convert_dataclass( yield XmlWriterEvent.ATTR, key, value for var, value in self.next_value(obj, meta): + if var.wrapper: + yield XmlWriterEvent.START, var.wrapper + yield from self.convert_value(value, var, namespace) + if var.wrapper: + yield XmlWriterEvent.END, var.wrapper + yield XmlWriterEvent.END, qname def convert_xsi_type( @@ -500,14 +506,8 @@ def convert_list( Yields: An iterator of sax events. """ - if var.wrapper is not None: - yield XmlWriterEvent.START, var.wrapper - for value in values: - yield from self.convert_value(value, var, namespace) - yield XmlWriterEvent.END, var.wrapper - else: - for value in values: - yield from self.convert_value(value, var, namespace) + for value in values: + yield from self.convert_value(value, var, namespace) def convert_tokens( self, value: Any, var: XmlVar, namespace: Optional[str] From 1587fbbf1bba168243afe79812aa700055540e76 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Wed, 20 Mar 2024 15:37:54 +0200 Subject: [PATCH 4/6] test: Add wrapper field integration tests --- tests/conftest.py | 29 +++++-- tests/fixtures/compound/sample.py | 33 ++++++++ tests/fixtures/primer/sample.py | 48 +++++++++++ tests/fixtures/wrapper/__init__.py | 15 ++++ tests/fixtures/wrapper/models.py | 101 +++++++++++++++++++++++ tests/fixtures/wrapper/sample.json | 23 ++++++ tests/fixtures/wrapper/sample.py | 21 +++++ tests/fixtures/wrapper/sample.xml | 14 ++++ tests/fixtures/wrapper/sample.xsdata.xml | 14 ++++ tests/fixtures/wrapper/schema.xsd | 41 +++++++++ tests/integration/test_wrapper.py | 34 ++++++++ tests/test_cli.py | 2 +- 12 files changed, 367 insertions(+), 8 deletions(-) create mode 100644 tests/fixtures/compound/sample.py create mode 100644 tests/fixtures/primer/sample.py create mode 100644 tests/fixtures/wrapper/__init__.py create mode 100644 tests/fixtures/wrapper/models.py create mode 100644 tests/fixtures/wrapper/sample.json create mode 100644 tests/fixtures/wrapper/sample.py create mode 100644 tests/fixtures/wrapper/sample.xml create mode 100644 tests/fixtures/wrapper/sample.xsdata.xml create mode 100644 tests/fixtures/wrapper/schema.xsd create mode 100644 tests/integration/test_wrapper.py diff --git a/tests/conftest.py b/tests/conftest.py index f9fbfa873..00fc3679e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,13 @@ from lxml import etree +from xsdata.formats.dataclass.context import XmlContext from xsdata.formats.dataclass.parsers import JsonParser, XmlParser -from xsdata.formats.dataclass.serializers import JsonSerializer, XmlSerializer +from xsdata.formats.dataclass.serializers import ( + JsonSerializer, + PycodeSerializer, + XmlSerializer, +) from xsdata.formats.dataclass.serializers.config import SerializerConfig @@ -12,9 +17,20 @@ def validate_bindings(schema: Path, clazz: Type): __tracebackhide__ = True sample = schema.parent.joinpath("sample.xml") - obj = XmlParser().from_path(sample, clazz) + context = XmlContext() + config = SerializerConfig(indent=" ") - actual = JsonSerializer(config=config).render(obj) + xml_parser = XmlParser(context=context) + xml_serializer = XmlSerializer(context=context, config=config) + json_serializer = JsonSerializer(context=context, config=config) + pycode_serializer = PycodeSerializer(context=context) + + obj = xml_parser.from_path(sample, clazz) + + code = pycode_serializer.render(obj) + sample.with_suffix(".py").write_text(code) + + actual = json_serializer.render(obj) expected = sample.with_suffix(".json") if expected.exists(): @@ -23,10 +39,9 @@ def validate_bindings(schema: Path, clazz: Type): else: expected.write_text(actual, encoding="utf-8") - config = SerializerConfig(indent=" ") - xml = XmlSerializer(config=config).render(obj) + xml = xml_serializer.render(obj) + + expected.with_suffix(".xsdata.xml").write_text(xml, encoding="utf-8") validator = etree.XMLSchema(etree.parse(str(schema))) assert validator.validate(etree.fromstring(xml.encode())), validator.error_log - - expected.with_suffix(".xsdata.xml").write_text(xml, encoding="utf-8") diff --git a/tests/fixtures/compound/sample.py b/tests/fixtures/compound/sample.py new file mode 100644 index 000000000..43f108bab --- /dev/null +++ b/tests/fixtures/compound/sample.py @@ -0,0 +1,33 @@ +from tests.fixtures.compound.models import Alpha +from tests.fixtures.compound.models import Bravo +from tests.fixtures.compound.models import Root + + +obj = Root( + alpha_or_bravo=[ + Alpha( + + ), + Alpha( + + ), + Bravo( + + ), + Bravo( + + ), + Alpha( + + ), + Bravo( + + ), + Alpha( + + ), + Bravo( + + ), + ] +) diff --git a/tests/fixtures/primer/sample.py b/tests/fixtures/primer/sample.py new file mode 100644 index 000000000..df2643cf9 --- /dev/null +++ b/tests/fixtures/primer/sample.py @@ -0,0 +1,48 @@ +from decimal import Decimal +from tests.fixtures.primer.order import Comment +from tests.fixtures.primer.order import Items +from tests.fixtures.primer.order import PurchaseOrder +from tests.fixtures.primer.order import Usaddress +from xsdata.models.datatype import XmlDate + + +obj = PurchaseOrder( + ship_to=Usaddress( + name='Alice Smith', + street='123 Maple Street', + city='Mill Valley', + state='CA', + zip=Decimal('90952') + ), + bill_to=Usaddress( + name='Robert Smith', + street='8 Oak Avenue', + city='Old Town', + state='PA', + zip=Decimal('95819') + ), + comment=Comment( + value='Hurry, my lawn is going wild!' + ), + items=Items( + item=[ + Items.Item( + product_name='Lawnmower', + quantity=1, + usprice=Decimal('148.95'), + comment=Comment( + value='Confirm this is electric' + ), + part_num='872-AA' + ), + Items.Item( + product_name='Baby Monitor', + quantity=1, + usprice=Decimal('39.98'), + ship_date=XmlDate(1999, 5, 21), + part_num='926-AA' + ), + ] + ), + order_date=XmlDate(1999, 10, 20) +) diff --git a/tests/fixtures/wrapper/__init__.py b/tests/fixtures/wrapper/__init__.py new file mode 100644 index 000000000..460e77aa5 --- /dev/null +++ b/tests/fixtures/wrapper/__init__.py @@ -0,0 +1,15 @@ +from tests.fixtures.wrapper.models import ( + Alphas, + Bravos, + Charlie, + Charlies, + Wrapper, +) + +__all__ = [ + "Alphas", + "Bravos", + "Charlie", + "Charlies", + "Wrapper", +] diff --git a/tests/fixtures/wrapper/models.py b/tests/fixtures/wrapper/models.py new file mode 100644 index 000000000..0ec44b800 --- /dev/null +++ b/tests/fixtures/wrapper/models.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +__NAMESPACE__ = "xsdata" + + +@dataclass +class Alphas: + class Meta: + name = "alphas" + namespace = "xsdata" + + alpha: Optional[str] = field( + default=None, + metadata={ + "type": "Element", + "required": True, + }, + ) + + +@dataclass +class Bravos: + class Meta: + name = "bravos" + namespace = "xsdata" + + bravo: List[int] = field( + default_factory=list, + metadata={ + "type": "Element", + "min_occurs": 1, + }, + ) + + +@dataclass +class Charlie: + class Meta: + name = "charlie" + namespace = "xsdata" + + value: str = field( + default="", + metadata={ + "required": True, + }, + ) + lang: Optional[object] = field( + default=None, + metadata={ + "type": "Attribute", + }, + ) + + +@dataclass +class Charlies: + class Meta: + name = "charlies" + namespace = "xsdata" + + charlie: List[Charlie] = field( + default_factory=list, + metadata={ + "type": "Element", + "min_occurs": 1, + }, + ) + + +@dataclass +class Wrapper: + class Meta: + name = "wrapper" + namespace = "xsdata" + + alpha: Optional[str] = field( + default=None, + metadata={ + "wrapper": "alphas", + "type": "Element", + "required": True, + }, + ) + bravo: List[int] = field( + default_factory=list, + metadata={ + "wrapper": "bravos", + "type": "Element", + "min_occurs": 1, + }, + ) + charlie: List[Charlie] = field( + default_factory=list, + metadata={ + "wrapper": "charlies", + "type": "Element", + "min_occurs": 1, + }, + ) diff --git a/tests/fixtures/wrapper/sample.json b/tests/fixtures/wrapper/sample.json new file mode 100644 index 000000000..43fdb7436 --- /dev/null +++ b/tests/fixtures/wrapper/sample.json @@ -0,0 +1,23 @@ +{ + "alphas": { + "alpha": "\u03b1\u03b1\u03b1" + }, + "bravos": { + "bravo": [ + 1, + 2 + ] + }, + "charlies": { + "charlie": [ + { + "value": "\u03b4\u03b4\u03b4", + "lang": "en" + }, + { + "value": "eee", + "lang": "en" + } + ] + } +} \ No newline at end of file diff --git a/tests/fixtures/wrapper/sample.py b/tests/fixtures/wrapper/sample.py new file mode 100644 index 000000000..e9ff6726a --- /dev/null +++ b/tests/fixtures/wrapper/sample.py @@ -0,0 +1,21 @@ +from tests.fixtures.wrapper.models import Charlie +from tests.fixtures.wrapper.models import Wrapper + + +obj = Wrapper( + alpha='ααα', + bravo=[ + 1, + 2, + ], + charlie=[ + Charlie( + value='δδδ', + lang='en' + ), + Charlie( + value='eee', + lang='en' + ), + ] +) diff --git a/tests/fixtures/wrapper/sample.xml b/tests/fixtures/wrapper/sample.xml new file mode 100644 index 000000000..99db8c14a --- /dev/null +++ b/tests/fixtures/wrapper/sample.xml @@ -0,0 +1,14 @@ + + + + ααα + + + 1 + 2 + + + δδδ + eee + + diff --git a/tests/fixtures/wrapper/sample.xsdata.xml b/tests/fixtures/wrapper/sample.xsdata.xml new file mode 100644 index 000000000..99db8c14a --- /dev/null +++ b/tests/fixtures/wrapper/sample.xsdata.xml @@ -0,0 +1,14 @@ + + + + ααα + + + 1 + 2 + + + δδδ + eee + + diff --git a/tests/fixtures/wrapper/schema.xsd b/tests/fixtures/wrapper/schema.xsd new file mode 100644 index 000000000..3b915e43d --- /dev/null +++ b/tests/fixtures/wrapper/schema.xsd @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/integration/test_wrapper.py b/tests/integration/test_wrapper.py new file mode 100644 index 000000000..dcae13505 --- /dev/null +++ b/tests/integration/test_wrapper.py @@ -0,0 +1,34 @@ +import os + +from click.testing import CliRunner + +from tests import fixtures_dir, root +from tests.conftest import validate_bindings +from xsdata.cli import cli +from xsdata.utils.testing import load_class + +os.chdir(root) + + +def test_xml_documents(): + schema = fixtures_dir.joinpath("wrapper/schema.xsd") + package = "tests.fixtures.wrapper.models" + runner = CliRunner() + result = runner.invoke( + cli, + [ + str(schema), + "-p", + package, + "-ss", + "single-package", + "--wrapper-fields", + "--compound-fields", + ], + ) + + if result.exception: + raise result.exception + + clazz = load_class(result.output, "Wrapper") + validate_bindings(schema, clazz) diff --git a/tests/test_cli.py b/tests/test_cli.py index 450203e80..9bfb269c7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -192,4 +192,4 @@ def test_resolve_source(self): self.assertEqual(3, len(list(resolve_source(str(def_xml_path), False)))) actual = list(resolve_source(str(fixtures_dir), True)) - self.assertEqual(39, len(actual)) + self.assertEqual(43, len(actual)) From 45f8a06c55500cd860ea33be54d700205ad78798 Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Thu, 21 Mar 2024 15:06:25 +0200 Subject: [PATCH 5/6] feat: Generate wrapper fields for forward references --- .../handlers/test_create_wrapper_fields.py | 31 +++++--- .../codegen/handlers/create_wrapper_fields.py | 72 ++++++++++++------- xsdata/formats/dataclass/serializers/dict.py | 1 + 3 files changed, 68 insertions(+), 36 deletions(-) diff --git a/tests/codegen/handlers/test_create_wrapper_fields.py b/tests/codegen/handlers/test_create_wrapper_fields.py index 81222faf2..22c7a07fe 100644 --- a/tests/codegen/handlers/test_create_wrapper_fields.py +++ b/tests/codegen/handlers/test_create_wrapper_fields.py @@ -38,6 +38,16 @@ def test_process_with_valid_attr_wrapper(self): self.processor.process(self.target) self.assertEqual("items", self.target.attrs[0].wrapper) + def test_process_with_forward_reference(self): + self.container.remove(self.source) + self.target.inner.append(self.source) + self.target.attrs[0].types[0].forward = True + + self.processor.process(self.target) + self.assertEqual("items", self.target.attrs[0].wrapper) + self.assertFalse(self.target.attrs[0].types[0].forward) + self.assertEqual(0, len(self.target.inner)) + def test_process_with_invalid_attr(self): self.target.attrs[0].tag = Tag.EXTENSION self.processor.process(self.target) @@ -55,19 +65,11 @@ def test_wrap_field(self): attr = AttrFactory.create() wrapper = attr.local_name - self.processor.wrap_field(source, attr) + self.processor.wrap_field(source, attr, False) self.assertEqual(source.name, attr.name) self.assertEqual(source.local_name, attr.local_name) self.assertEqual(wrapper, attr.wrapper) - def test_find_source_with_forward_reference(self): - tp = self.target.attrs[0].types[0] - tp.forward = True - self.target.inner.append(self.source) - - actual = self.processor.find_source(self.target, tp) - self.assertEqual(self.source, actual) - def test_validate_attr(self): # Not an element attr = AttrFactory.create(tag=Tag.EXTENSION) @@ -100,6 +102,7 @@ def test_validate_source(self): # Has forwarded references source.attrs.pop(0) + source.attrs[0].tag = Tag.EXTENSION source.attrs[0].types[0].forward = True self.assertFalse(self.processor.validate_source(source, None)) @@ -107,6 +110,14 @@ def test_validate_source(self): source.attrs[0].types[0].forward = False self.assertFalse(self.processor.validate_source(source, "bar")) - # Not any of the above issues + # Optional source.attrs[0].namespace = "bar" + self.assertFalse(self.processor.validate_source(source, "bar")) + + # Not Element + source.attrs[0].restrictions.min_occurs = 1 + self.assertFalse(self.processor.validate_source(source, "bar")) + + # All rules pass + source.attrs[0].tag = Tag.ELEMENT self.assertTrue(self.processor.validate_source(source, "bar")) diff --git a/xsdata/codegen/handlers/create_wrapper_fields.py b/xsdata/codegen/handlers/create_wrapper_fields.py index 7cce5827e..b9dee7442 100644 --- a/xsdata/codegen/handlers/create_wrapper_fields.py +++ b/xsdata/codegen/handlers/create_wrapper_fields.py @@ -1,7 +1,8 @@ -from typing import Optional +from typing import Optional, Tuple from xsdata.codegen.mixins import RelativeHandlerInterface -from xsdata.codegen.models import Attr, AttrType, Class +from xsdata.codegen.models import Attr, Class +from xsdata.codegen.utils import ClassUtils class CreateWrapperFields(RelativeHandlerInterface): @@ -20,23 +21,28 @@ def process(self, target: Class): if not self.container.config.output.wrapper_fields: return + wrapped = False + wrapped_inner = False for attr in target.attrs: - if self.validate_attr(attr): - self.process_attr(target, attr) + if not self.validate_attr(attr): + continue - def process_attr(self, target: Class, attr: Attr): - """Process the given attr instance. + inner, source = self.find_source_attr(target, attr) + if not source: + continue - Args: - target: The parent class instance - attr: The attr instance to process - """ - source = self.find_source(target, attr.types[0]) - if self.validate_source(source, attr.namespace): - self.wrap_field(source.attrs[0], attr) + self.wrap_field(source, attr, inner) + wrapped = True + wrapped_inner = wrapped_inner or inner + + if wrapped: + ClassUtils.rename_duplicate_attributes(target) + + if inner: + ClassUtils.clean_inner_classes(target) @classmethod - def wrap_field(cls, source: Attr, attr: Attr): + def wrap_field(cls, source: Attr, attr: Attr, inner: bool): """Create a wrapper field. Clone the source attr and update its name, local name and wrapper @@ -45,13 +51,17 @@ def wrap_field(cls, source: Attr, attr: Attr): Args: source: The source attr instance attr: The attr instance to wrap + inner: Specify if the source is from an inner class """ wrapper = attr.local_name attr.swap(source) attr.wrapper = wrapper + attr.types[0].forward = False - def find_source(self, parent: Class, tp: AttrType) -> Class: + def find_source_attr( + self, parent: Class, attr: Attr + ) -> Tuple[bool, Optional[Attr]]: """Find the source type for the given attr type instance. If it's a forward reference, look up the source in @@ -59,15 +69,23 @@ def find_source(self, parent: Class, tp: AttrType) -> Class: Args: parent: The parent class instance - tp: The attr type instance to look up + attr: The attr instance to find a valid source attr Returns: - The source class instance that matches the attr type. + A tuple of whether the source attr is inner and the source attr. """ + tp = attr.types[0] + inner = False if tp.forward: - return self.container.find_inner(parent, tp.qname) + source = self.container.find_inner(parent, tp.qname) + inner = True + else: + source = self.container.first(tp.qname) - return self.container.first(tp.qname) + if self.validate_source(source, attr.namespace): + return inner, source.attrs[0] + + return inner, None @classmethod def validate_attr(cls, attr: Attr) -> bool: @@ -75,22 +93,21 @@ def validate_attr(cls, attr: Attr) -> bool: Rules: 1. Must be an element - 2. Must have only one type - 3. It has to be a user type + 2. Must have only one user type 4. The element can't be optional - + 5. The element can't be a list element Args: attr: The attr instance to validate Returns: Whether the attr can be converted to a wrapper. - """ return ( attr.is_element and len(attr.types) == 1 and not attr.types[0].native + and not attr.is_list and not attr.is_optional ) @@ -101,9 +118,10 @@ def validate_source(cls, source: Class, namespace: Optional[str]) -> bool: Rules: 1. It must not have any extensions 2. It must contain exactly one type - 3. It must not be a forward reference - 4. The source attr namespace must match the namespace - + 3. It must be derived from a xs:element + 4. It must not be optional + 5. It must not be a forward reference + 6. The source attr namespace must match the namespace Args: source: The source class instance to validate @@ -119,6 +137,8 @@ def ns_equal(a: Optional[str], b: Optional[str]): return ( not source.extensions and len(source.attrs) == 1 + and source.attrs[0].is_element + and not source.attrs[0].is_optional and not source.attrs[0].is_forward_ref and ns_equal(source.attrs[0].namespace, namespace) ) diff --git a/xsdata/formats/dataclass/serializers/dict.py b/xsdata/formats/dataclass/serializers/dict.py index b453f1f2d..ec611dc99 100644 --- a/xsdata/formats/dataclass/serializers/dict.py +++ b/xsdata/formats/dataclass/serializers/dict.py @@ -100,6 +100,7 @@ def next_value(self, obj: Any) -> Iterator[Tuple[str, Any]]: or not var.is_optional(value) ): if var.wrapper: + assert var.wrapper_local_name is not None yield var.wrapper_local_name, self.encode(value, var) else: yield var.local_name, self.encode(value, var) From 15db7b56d329564f2eadf670a2929361a8bdf3dd Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Thu, 21 Mar 2024 15:14:23 +0200 Subject: [PATCH 6/6] docs: Update wrapper documentation, increase coverage --- docs/codegen/architecture.md | 1 + docs/codegen/config.md | 41 +++++++++++++++++++ docs/models/fields.md | 4 +- tests/fixtures/models.py | 2 +- .../formats/dataclass/models/test_builders.py | 1 + tests/formats/dataclass/parsers/test_dict.py | 32 +++++++++++++-- xsdata/formats/dataclass/parsers/dict.py | 4 +- xsdata/models/config.py | 2 +- 8 files changed, 77 insertions(+), 10 deletions(-) diff --git a/docs/codegen/architecture.md b/docs/codegen/architecture.md index 4f7471ef9..dfccb0996 100644 --- a/docs/codegen/architecture.md +++ b/docs/codegen/architecture.md @@ -141,6 +141,7 @@ pass through each step before next one starts. The order of the steps is very im - [DetectCircularReferences][xsdata.codegen.handlers.DetectCircularReferences] - [CreateCompoundFields][xsdata.codegen.handlers.CreateCompoundFields] +- [CreateWrapperFields][xsdata.codegen.handlers.CreateWrapperFields] - [DisambiguateChoices][xsdata.codegen.handlers.DisambiguateChoices] - [ResetAttributeSequenceNumbers][xsdata.codegen.handlers.ResetAttributeSequenceNumbers] diff --git a/docs/codegen/config.md b/docs/codegen/config.md index 4e9490f00..e2685eb01 100644 --- a/docs/codegen/config.md +++ b/docs/codegen/config.md @@ -184,6 +184,47 @@ hat_or_bat_cat: list[str | int | float] = field(...) product: list[Shoe | Shirt | Hat] = field(...) ``` +### WrapperFields + +Generate wrapper fields whenever possible for single or collections of simple and +complex elements. + +The wrapper and wrapped elements can't be optional. If the wrapped value is a list it +must have minimum `occurs >= 1`. + +```xml show_lines="2:17" +--8<-- "tests/fixtures/wrapper/schema.xsd" +``` + +**Default Value:** `False` + +**CLI Option:** `--wrapper-fields / --no-wrapper-fields` + +**Examples:** + +```python +alpha: str = field( + metadata={ + "wrapper": "alphas", + "type": "Element", + }, +) +bravo: List[int] = field( + default_factory=list, + metadata={ + "wrapper": "bravos", + "type": "Element", + }, +) +charlie: List[Charlie] = field( + default_factory=list, + metadata={ + "wrapper": "charlies", + "type": "Element", + }, +) +``` + ### PostponedAnnotations Use [PEP-563](https://peps.python.org/pep-0563/), postponed evaluation of annotations. diff --git a/docs/models/fields.md b/docs/models/fields.md index fd7271580..952e3f614 100644 --- a/docs/models/fields.md +++ b/docs/models/fields.md @@ -332,8 +332,8 @@ declared with `init=False` or with a default value otherwise data binding will f ### `wrapper` -The element name to wrap a collection of elements or primitives, in order to avoid -having a dedicated wrapper class. +The element name to wrap a single or a collection of elements or primitives, in order to +avoid having a dedicated wrapper class. ```python >>> from dataclasses import dataclass, field diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 1c2e47972..6987ccad2 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -99,7 +99,7 @@ class ChoiceType: {"name": "float", "type": float}, {"name": "qname", "type": QName}, {"name": "union", "type": Type["UnionType"], "namespace": "foo"}, - {"name": "tokens", "type": List[Decimal], "tokens": True}, + {"name": "tokens", "type": List[Decimal], "tokens": True, "default_factory": list}, { "wildcard": True, "type": object, diff --git a/tests/formats/dataclass/models/test_builders.py b/tests/formats/dataclass/models/test_builders.py index a1547270a..bcdba9bc1 100644 --- a/tests/formats/dataclass/models/test_builders.py +++ b/tests/formats/dataclass/models/test_builders.py @@ -343,6 +343,7 @@ def test_build_with_choice_field(self): types=(Decimal,), tokens_factory=list, derived=True, + default=list, factory=list, namespaces=("bar",), ), diff --git a/tests/formats/dataclass/parsers/test_dict.py b/tests/formats/dataclass/parsers/test_dict.py index 548a86edc..dcb363bf4 100644 --- a/tests/formats/dataclass/parsers/test_dict.py +++ b/tests/formats/dataclass/parsers/test_dict.py @@ -19,7 +19,7 @@ TypeD, UnionType, ) -from tests.fixtures.wrapper import Wrapper +from tests.fixtures.wrapper import Charlie, Wrapper from xsdata.exceptions import ParserError from xsdata.formats.dataclass.models.generics import AnyElement, DerivedElement from xsdata.formats.dataclass.parsers import DictDecoder @@ -106,10 +106,26 @@ def test_decode_with_fail_on_converter_warnings(self): ) def test_decode_wrapper(self): - data = {"alphas": {"alpha": "value"}} + data = { + "alphas": {"alpha": "value"}, + "bravos": {"bravo": [1, 2]}, + "charlies": { + "charlie": [ + {"value": "first", "lang": "en"}, + {"value": "second", "lang": "en"}, + ] + }, + } actual = self.decoder.decode(data, Wrapper) - expected = Wrapper(alpha="value") + expected = Wrapper( + alpha="value", + bravo=[1, 2], + charlie=[ + Charlie(value="first", lang="en"), + Charlie(value="second", lang="en"), + ], + ) self.assertEqual(expected, actual) def test_verify_type(self): @@ -377,9 +393,17 @@ def test_find_var(self): xml_vars = meta.get_all_vars() self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "x", 1)) - self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "x", [1, 2])) + self.assertIsNone(self.decoder.find_var(xml_vars, "x", [1, 2])) meta = self.decoder.context.build(ExtendedType) xml_vars = meta.get_all_vars() self.assertIsNone(self.decoder.find_var(xml_vars, "a", [1, 2])) self.assertEqual(xml_vars[0], self.decoder.find_var(xml_vars, "a", {"x": 1})) + + meta = self.decoder.context.build(Wrapper) + xml_vars = meta.get_all_vars() + self.assertIsNone(self.decoder.find_var(xml_vars, "charlies", {})) + self.assertIsNone(self.decoder.find_var(xml_vars, "bravos", {"bravo": 1})) + self.assertEqual( + xml_vars[0], self.decoder.find_var(xml_vars, "alphas", {"alpha": "foo"}) + ) diff --git a/xsdata/formats/dataclass/parsers/dict.py b/xsdata/formats/dataclass/parsers/dict.py index cfe9c07ba..3ef83b1fa 100644 --- a/xsdata/formats/dataclass/parsers/dict.py +++ b/xsdata/formats/dataclass/parsers/dict.py @@ -429,13 +429,13 @@ def find_var( if var.local_name == key: var_is_list = var.list_element or var.tokens is_array = collections.is_array(value) - if is_array == var_is_list or var.clazz is None: + if is_array == var_is_list: return var elif var.wrapper_local_name == key: if isinstance(value, dict) and var.local_name in value: val = value[var.local_name] var_is_list = var.list_element or var.tokens is_array = collections.is_array(val) - if is_array == var_is_list or var.clazz is None: + if is_array == var_is_list: return var return None diff --git a/xsdata/models/config.py b/xsdata/models/config.py index 08e2f934f..cb3054d9d 100644 --- a/xsdata/models/config.py +++ b/xsdata/models/config.py @@ -221,7 +221,7 @@ class GeneratorOutput: docstring_style: Docstring style relative_imports: Use relative imports compound_fields: Use compound fields for repeatable elements - wrapper_fields: Generate wrapper fields for element lists + wrapper_fields: Generate wrapper fields max_line_length: Adjust the maximum line length subscriptable_types: Use PEP-585 generics for standard collections, python>=3.9 Only