diff --git a/docs/conf.py b/docs/conf.py index 71b8c022b..ef5f086d3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,6 +30,7 @@ # ones. extensions = [ "xsdatadocs", + "sphinx.ext.napoleon", "sphinx.ext.doctest", "sphinx.ext.autodoc", "sphinx.ext.intersphinx", @@ -86,3 +87,4 @@ autosummary_generate = True set_type_checking_flag = True always_document_param_types = False +napoleon_google_docstring = True diff --git a/pyproject.toml b/pyproject.toml index a656544d0..a048efef2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,13 +117,17 @@ ignore = [ "B028", "B904", "D100", + "D104", "D107" + ] isort = { known-first-party = ['xsdata', 'tests'] } [tool.ruff.lint.per-file-ignores] "**/{tests}/*" = ["ANN001", "ANN002", "ANN003", "E501", "B018", "D"] +"**/utils/testing.py" = ["D"] +"docs/*" = ["D"] [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/tests/codegen/handlers/test_process_attributes_types.py b/tests/codegen/handlers/test_process_attributes_types.py index b2f0d915d..6027150c0 100644 --- a/tests/codegen/handlers/test_process_attributes_types.py +++ b/tests/codegen/handlers/test_process_attributes_types.py @@ -295,8 +295,8 @@ def test_copy_attribute_properties(self, mock_copy_inner_class): ) mock_copy_inner_class.assert_has_calls( [ - mock.call(source, target, attr, source.attrs[0].types[0]), - mock.call(source, target, attr, source.attrs[0].types[1]), + mock.call(source, target, source.attrs[0].types[0]), + mock.call(source, target, source.attrs[0].types[1]), ] ) diff --git a/tests/codegen/handlers/test_sanitize_attributes_default_value.py b/tests/codegen/handlers/test_sanitize_attributes_default_value.py index afb67c2a5..ad99172dc 100644 --- a/tests/codegen/handlers/test_sanitize_attributes_default_value.py +++ b/tests/codegen/handlers/test_sanitize_attributes_default_value.py @@ -287,18 +287,18 @@ def test_is_valid_enum_type(self): self.assertFalse(self.processor.is_valid_enum_type(enumeration, attr)) self.assertEqual("3", attr.default) - def test_find_type(self): + def test_find_inner_type(self): target = ClassFactory.create() attr_type = AttrTypeFactory.create("foo") foo = ClassFactory.create(qname="foo") self.processor.container.add(foo) - self.assertIs(foo, self.processor.find_type(target, attr_type)) + self.assertIs(foo, self.processor.find_inner_type(target, attr_type)) attr_type = AttrTypeFactory.create("bar", forward=True) bar = ClassFactory.create(qname="bar") target.inner.append(bar) - self.assertIs(bar, self.processor.find_type(target, attr_type)) + self.assertIs(bar, self.processor.find_inner_type(target, attr_type)) def test_reset_attribute_types(self): attr = AttrFactory.create( diff --git a/tests/codegen/mappers/test_definitions.py b/tests/codegen/mappers/test_definitions.py index 458eb288b..123059be2 100644 --- a/tests/codegen/mappers/test_definitions.py +++ b/tests/codegen/mappers/test_definitions.py @@ -1,7 +1,7 @@ from typing import Generator from unittest import mock -from xsdata.codegen.mappers.definitions import DefinitionsMapper +from xsdata.codegen.mappers import DefinitionsMapper from xsdata.codegen.models import Class, Status from xsdata.formats.dataclass.models.generics import AnyElement from xsdata.models.enums import DataType, Namespace, Tag diff --git a/tests/codegen/mappers/test_dict.py b/tests/codegen/mappers/test_dict.py index 7ec652106..f0d5fda4a 100644 --- a/tests/codegen/mappers/test_dict.py +++ b/tests/codegen/mappers/test_dict.py @@ -1,7 +1,7 @@ import sys from unittest import mock -from xsdata.codegen.mappers.dict import 
DictMapper +from xsdata.codegen.mappers import DictMapper from xsdata.codegen.models import Restrictions from xsdata.codegen.utils import ClassUtils from xsdata.models.enums import DataType, Tag diff --git a/tests/codegen/mappers/test_dtd.py b/tests/codegen/mappers/test_dtd.py index 6c13fe22d..e437d9287 100644 --- a/tests/codegen/mappers/test_dtd.py +++ b/tests/codegen/mappers/test_dtd.py @@ -2,7 +2,7 @@ from typing import Iterator from unittest import mock -from xsdata.codegen.mappers.dtd import DtdMapper +from xsdata.codegen.mappers import DtdMapper from xsdata.codegen.models import Class, Restrictions from xsdata.models.dtd import ( DtdAttributeDefault, diff --git a/tests/codegen/mappers/test_element.py b/tests/codegen/mappers/test_element.py index ed441f267..ee6fb6ac9 100644 --- a/tests/codegen/mappers/test_element.py +++ b/tests/codegen/mappers/test_element.py @@ -1,7 +1,7 @@ import sys from unittest import mock -from xsdata.codegen.mappers.element import ElementMapper +from xsdata.codegen.mappers import ElementMapper from xsdata.codegen.models import Restrictions from xsdata.codegen.utils import ClassUtils from xsdata.formats.dataclass.models.generics import AnyElement @@ -267,28 +267,28 @@ def test_build_class_ignore_invalid(self): actual = ElementMapper.build_class(element, None) self.assertEqual(0, len(actual.attrs)) - def test_build_attribute_type(self): - actual = ElementMapper.build_attribute_type(QNames.XSI_TYPE, "") + def test_build_attr_type(self): + actual = ElementMapper.build_attr_type(QNames.XSI_TYPE, "") self.assertEqual(str(DataType.QNAME), actual.qname) self.assertTrue(actual.native) - actual = ElementMapper.build_attribute_type("name", "foo") + actual = ElementMapper.build_attr_type("name", "foo") self.assertEqual(str(DataType.STRING), actual.qname) self.assertTrue(actual.native) - actual = ElementMapper.build_attribute_type("name", "") + actual = ElementMapper.build_attr_type("name", "") self.assertEqual(str(DataType.ANY_SIMPLE_TYPE), actual.qname) self.assertTrue(actual.native) - actual = ElementMapper.build_attribute_type("name", None) + actual = ElementMapper.build_attr_type("name", None) self.assertEqual(str(DataType.ANY_SIMPLE_TYPE), actual.qname) self.assertTrue(actual.native) - actual = ElementMapper.build_attribute_type("name", 1) + actual = ElementMapper.build_attr_type("name", 1) self.assertEqual(str(DataType.SHORT), actual.qname) self.assertTrue(actual.native) - actual = ElementMapper.build_attribute_type("name", "1.9") + actual = ElementMapper.build_attr_type("name", "1.9") self.assertEqual(str(DataType.FLOAT), actual.qname) self.assertTrue(actual.native) diff --git a/tests/codegen/mappers/test_schema.py b/tests/codegen/mappers/test_schema.py index 3bbbeacb9..fc514993b 100644 --- a/tests/codegen/mappers/test_schema.py +++ b/tests/codegen/mappers/test_schema.py @@ -2,7 +2,7 @@ from typing import Iterator from unittest import mock -from xsdata.codegen.mappers.schema import SchemaMapper +from xsdata.codegen.mappers import SchemaMapper from xsdata.codegen.models import Class, Restrictions from xsdata.models.enums import DataType, FormType, Tag from xsdata.models.xsd import ( @@ -268,7 +268,7 @@ def test_children_extensions(self): self.assertIsInstance(children, GeneratorType) self.assertEqual(expected, list(children)) - @mock.patch.object(SchemaMapper, "build_class_attribute_types") + @mock.patch.object(SchemaMapper, "build_attr_types") @mock.patch.object(SchemaMapper, "element_namespace") @mock.patch.object(Attribute, "get_restrictions") 
@mock.patch.object(Attribute, "is_fixed", new_callable=mock.PropertyMock) @@ -285,13 +285,11 @@ def test_build_class_attribute( mock_is_fixed, mock_get_restrictions, mock_element_namespace, - mock_build_class_attribute_types, + mock_build_attr_types, ): item = ClassFactory.create(ns_map={"bar": "foo"}) - mock_build_class_attribute_types.return_value = AttrTypeFactory.list( - 1, qname="int" - ) + mock_build_attr_types.return_value = AttrTypeFactory.list(1, qname="int") mock_real_name.return_value = item.name mock_display_help.return_value = "sos" mock_prefix.return_value = "com" @@ -307,7 +305,7 @@ def test_build_class_attribute( SchemaMapper.build_class_attribute(item, attribute, Restrictions()) expected = AttrFactory.create( name=mock_real_name.return_value, - types=mock_build_class_attribute_types.return_value, + types=mock_build_attr_types.return_value, tag=Tag.ATTRIBUTE, namespace=mock_element_namespace.return_value, help=mock_display_help.return_value, @@ -318,20 +316,18 @@ def test_build_class_attribute( ) self.assertEqual(expected, item.attrs[0]) self.assertEqual({"bar": "foo", "foo": "bar"}, item.ns_map) - mock_build_class_attribute_types.assert_called_once_with(item, attribute) + mock_build_attr_types.assert_called_once_with(item, attribute) mock_element_namespace.assert_called_once_with(attribute, item.target_namespace) @mock.patch.object(Attribute, "attr_types", new_callable=mock.PropertyMock) @mock.patch.object(SchemaMapper, "build_inner_classes") - def test_build_class_attribute_types( - self, mock_build_inner_classes, mock_attr_types - ): + def test_build_attr_types(self, mock_build_inner_classes, mock_attr_types): mock_attr_types.return_value = ["xs:integer", "xs:string"] mock_build_inner_classes.return_value = [] item = ClassFactory.create() attribute = Attribute(default="false") - actual = SchemaMapper.build_class_attribute_types(item, attribute) + actual = SchemaMapper.build_attr_types(item, attribute) expected = [ AttrTypeFactory.native(DataType.INTEGER), @@ -342,7 +338,7 @@ def test_build_class_attribute_types( @mock.patch.object(Attribute, "attr_types", new_callable=mock.PropertyMock) @mock.patch.object(SchemaMapper, "build_inner_classes") - def test_build_class_attribute_types_when_obj_has_inner_class( + def test_build_attr_types_when_obj_has_inner_class( self, mock_build_inner_classes, mock_attr_types ): inner_class = ClassFactory.create(qname="foo") @@ -351,7 +347,7 @@ def test_build_class_attribute_types_when_obj_has_inner_class( item = ClassFactory.create() attribute = Attribute(default="false") - actual = SchemaMapper.build_class_attribute_types(item, attribute) + actual = SchemaMapper.build_attr_types(item, attribute) expected = [ AttrTypeFactory.native(DataType.INTEGER), @@ -365,7 +361,7 @@ def test_build_class_attribute_types_when_obj_has_inner_class( @mock.patch.object(Attribute, "default_type", new_callable=mock.PropertyMock) @mock.patch.object(Attribute, "attr_types", new_callable=mock.PropertyMock) @mock.patch.object(SchemaMapper, "build_inner_classes") - def test_build_class_attribute_types_when_obj_has_no_types( + def test_build_attr_types_when_obj_has_no_types( self, mock_build_inner_classes, mock_attr_types, mock_default_type ): mock_attr_types.return_value = "" @@ -374,7 +370,7 @@ def test_build_class_attribute_types_when_obj_has_no_types( item = ClassFactory.create() attribute = Attribute(default="false", name="attr") - actual = SchemaMapper.build_class_attribute_types(item, attribute) + actual = SchemaMapper.build_attr_types(item, attribute) 
self.assertEqual(1, len(actual)) self.assertEqual(AttrTypeFactory.native(DataType.STRING), actual[0]) diff --git a/tests/codegen/parsers/test_dtd.py b/tests/codegen/parsers/test_dtd.py index 1d8870d3e..a26b991aa 100644 --- a/tests/codegen/parsers/test_dtd.py +++ b/tests/codegen/parsers/test_dtd.py @@ -2,7 +2,7 @@ from unittest import TestCase, mock from tests import fixtures_dir -from xsdata.codegen.parsers.dtd import DtdParser +from xsdata.codegen.parsers import DtdParser from xsdata.exceptions import ParserError from xsdata.models.dtd import ( DtdAttributeDefault, diff --git a/tests/codegen/test_resolver.py b/tests/codegen/test_resolver.py index 0a7539a40..57d1ec3d3 100644 --- a/tests/codegen/test_resolver.py +++ b/tests/codegen/test_resolver.py @@ -131,12 +131,12 @@ def test_apply_aliases(self): @mock.patch.object(DependenciesResolver, "set_aliases") @mock.patch.object(DependenciesResolver, "resolve_conflicts") - @mock.patch.object(DependenciesResolver, "find_package") + @mock.patch.object(DependenciesResolver, "get_class_module") @mock.patch.object(DependenciesResolver, "import_classes") def test_resolve_imports( self, mock_import_classes, - mock_find_package, + mock_get_class_module, mock_resolve_conflicts, mock_set_aliases, ): @@ -150,7 +150,13 @@ def test_resolve_imports( ] self.resolver.class_map = {class_life.qname: class_life} mock_import_classes.return_value = import_names - mock_find_package.side_effect = ["first", "second", "third", "forth", "fifth"] + mock_get_class_module.side_effect = [ + "first", + "second", + "third", + "forth", + "fifth", + ] self.resolver.resolve_imports() mock_resolve_conflicts.assert_called_once_with( @@ -183,13 +189,13 @@ def test_set_aliases(self): self.resolver.set_aliases() self.assertEqual({"{a}a": "aa", "{b}a": "ba"}, self.resolver.aliases) - def test_find_package(self): + def test_get_class_module(self): class_a = ClassFactory.create() - self.resolver.packages[class_a.qname] = "foo.bar" + self.resolver.registry[class_a.qname] = "foo.bar" - self.assertEqual("foo.bar", self.resolver.find_package(class_a.qname)) + self.assertEqual("foo.bar", self.resolver.get_class_module(class_a.qname)) with self.assertRaises(ResolverValueError): - self.resolver.find_package("nope") + self.resolver.get_class_module("nope") def test_import_classes(self): self.resolver.class_list = list("abcdefg") diff --git a/tests/codegen/test_transformer.py b/tests/codegen/test_transformer.py index 00e19c275..9bff013b6 100644 --- a/tests/codegen/test_transformer.py +++ b/tests/codegen/test_transformer.py @@ -5,14 +5,15 @@ from xsdata.codegen.analyzer import ClassAnalyzer from xsdata.codegen.container import ClassContainer -from xsdata.codegen.mappers.definitions import DefinitionsMapper -from xsdata.codegen.mappers.dict import DictMapper -from xsdata.codegen.mappers.dtd import DtdMapper -from xsdata.codegen.mappers.element import ElementMapper -from xsdata.codegen.mappers.schema import SchemaMapper -from xsdata.codegen.parsers import DefinitionsParser -from xsdata.codegen.parsers.dtd import DtdParser -from xsdata.codegen.transformer import SchemaTransformer +from xsdata.codegen.mappers import ( + DefinitionsMapper, + DictMapper, + DtdMapper, + ElementMapper, + SchemaMapper, +) +from xsdata.codegen.parsers import DefinitionsParser, DtdParser +from xsdata.codegen.transformer import ResourceTransformer from xsdata.codegen.utils import ClassUtils from xsdata.codegen.writer import CodeWriter from xsdata.exceptions import CodeGenerationError @@ -24,18 +25,18 @@ from 
xsdata.utils.testing import ClassFactory, DtdFactory, FactoryTestCase -class SchemaTransformerTests(FactoryTestCase): +class ResourceTransformerTests(FactoryTestCase): def setUp(self): config = GeneratorConfig() - self.transformer = SchemaTransformer(print=True, config=config) + self.transformer = ResourceTransformer(print=True, config=config) super().setUp() - @mock.patch.object(SchemaTransformer, "process_classes") - @mock.patch.object(SchemaTransformer, "process_dtds") - @mock.patch.object(SchemaTransformer, "process_json_documents") - @mock.patch.object(SchemaTransformer, "process_xml_documents") - @mock.patch.object(SchemaTransformer, "process_schemas") - @mock.patch.object(SchemaTransformer, "process_definitions") + @mock.patch.object(ResourceTransformer, "process_classes") + @mock.patch.object(ResourceTransformer, "process_dtds") + @mock.patch.object(ResourceTransformer, "process_json_documents") + @mock.patch.object(ResourceTransformer, "process_xml_documents") + @mock.patch.object(ResourceTransformer, "process_schemas") + @mock.patch.object(ResourceTransformer, "process_definitions") def test_process( self, mock_process_definitions, @@ -65,9 +66,9 @@ def test_process( mock_process_dtds.assert_called_once_with(uris[8:]) mock_process_classes.assert_called_once_with() - @mock.patch.object(SchemaTransformer, "process_classes") - @mock.patch.object(SchemaTransformer, "process_sources") - @mock.patch.object(SchemaTransformer, "get_cache_file") + @mock.patch.object(ResourceTransformer, "process_classes") + @mock.patch.object(ResourceTransformer, "process_sources") + @mock.patch.object(ResourceTransformer, "get_cache_file") def test_process_from_cache( self, mock_get_cache_file, mock_process_sources, mock_process_classes ): @@ -85,9 +86,9 @@ def test_process_from_cache( self.assertEqual(0, mock_process_sources.call_count) mock_process_classes.assert_called_once_with() - @mock.patch.object(SchemaTransformer, "process_classes") - @mock.patch.object(SchemaTransformer, "process_sources") - @mock.patch.object(SchemaTransformer, "get_cache_file") + @mock.patch.object(ResourceTransformer, "process_classes") + @mock.patch.object(ResourceTransformer, "process_sources") + @mock.patch.object(ResourceTransformer, "get_cache_file") def test_process_with_cache( self, mock_get_cache_file, mock_process_sources, mock_process_classes ): @@ -104,9 +105,9 @@ def test_process_with_cache( self.assertEqual(classes, pickle.loads(cache.read_bytes())) mock_process_classes.assert_called_once_with() - @mock.patch.object(SchemaTransformer, "convert_schema") - @mock.patch.object(SchemaTransformer, "convert_definitions") - @mock.patch.object(SchemaTransformer, "parse_definitions") + @mock.patch.object(ResourceTransformer, "convert_schema") + @mock.patch.object(ResourceTransformer, "convert_definitions") + @mock.patch.object(ResourceTransformer, "parse_definitions") def test_process_definitions( self, mock_parse_definitions, @@ -129,7 +130,7 @@ def test_process_definitions( ) mock_convert_definitions.assert_called_once_with(fist_def) - @mock.patch.object(SchemaTransformer, "process_schema") + @mock.patch.object(ResourceTransformer, "process_schema") def test_process_schemas(self, mock_process_schema): uris = ["http://xsdata/foo.xsd", "http://xsdata/bar.xsd"] @@ -140,7 +141,7 @@ def test_process_schemas(self, mock_process_schema): @mock.patch.object(ClassUtils, "reduce_classes") @mock.patch.object(ElementMapper, "map") @mock.patch.object(TreeParser, "from_bytes") - @mock.patch.object(SchemaTransformer, "load_resource") + 
@mock.patch.object(ResourceTransformer, "load_resource") def test_process_xml_documents( self, mock_load_resource, mock_from_bytes, mock_map, mock_reduce_classes ): @@ -169,7 +170,7 @@ def test_process_xml_documents( @mock.patch("xsdata.codegen.transformer.logger.warning") @mock.patch.object(ClassUtils, "reduce_classes") @mock.patch.object(DictMapper, "map") - @mock.patch.object(SchemaTransformer, "load_resource") + @mock.patch.object(ResourceTransformer, "load_resource") def test_process_json_documents( self, mock_load_resource, mock_map, mock_reduce_classes, mock_warning ): @@ -201,7 +202,7 @@ def test_process_json_documents( @mock.patch.object(DtdMapper, "map") @mock.patch.object(DtdParser, "parse") - @mock.patch.object(SchemaTransformer, "load_resource") + @mock.patch.object(ResourceTransformer, "load_resource") def test_process_dtds(self, mock_load_resource, mock_parse, mock_map): uris = ["foo/a.dtd", "foo/b.dtd", "foo/c.dtd"] resources = [b"a", None, b"c"] @@ -233,7 +234,7 @@ def test_process_dtds(self, mock_load_resource, mock_parse, mock_map): @mock.patch("xsdata.codegen.transformer.logger.info") @mock.patch.object(CodeWriter, "print") - @mock.patch.object(SchemaTransformer, "analyze_classes") + @mock.patch.object(ResourceTransformer, "analyze_classes") def test_process_classes_with_print_true( self, mock_analyze_classes, @@ -258,7 +259,7 @@ def test_process_classes_with_print_true( @mock.patch("xsdata.codegen.transformer.logger.info") @mock.patch.object(CodeWriter, "write") - @mock.patch.object(SchemaTransformer, "analyze_classes") + @mock.patch.object(ResourceTransformer, "analyze_classes") def test_process_classes_with_print_false( self, mock_analyze_classes, @@ -288,8 +289,8 @@ def test_process_classes_with_zero_classes_after_analyze(self): self.assertEqual("Nothing to generate.", str(cm.exception)) - @mock.patch.object(SchemaTransformer, "convert_schema") - @mock.patch.object(SchemaTransformer, "parse_schema") + @mock.patch.object(ResourceTransformer, "convert_schema") + @mock.patch.object(ResourceTransformer, "parse_schema") def test_process_schema( self, mock_parse_schema, @@ -304,8 +305,8 @@ def test_process_schema( mock_convert_schema.assert_called_once_with(schema) - @mock.patch.object(SchemaTransformer, "convert_schema") - @mock.patch.object(SchemaTransformer, "parse_schema") + @mock.patch.object(ResourceTransformer, "convert_schema") + @mock.patch.object(ResourceTransformer, "parse_schema") def test_process_schema_ignores_empty_schema( self, mock_parse_schema, @@ -318,8 +319,8 @@ def test_process_schema_ignores_empty_schema( self.transformer.process_schema(uri, namespace) self.assertEqual(0, mock_convert_schema.call_count) - @mock.patch.object(SchemaTransformer, "generate_classes") - @mock.patch.object(SchemaTransformer, "process_schema") + @mock.patch.object(ResourceTransformer, "generate_classes") + @mock.patch.object(ResourceTransformer, "process_schema") def test_convert_schema(self, mock_process_schema, mock_generate_classes): schema = Schema(target_namespace="thug", location="main") schema.includes.append(Include(location="foo")) @@ -350,7 +351,7 @@ def test_convert_definitions(self, mock_definitions_map): self.assertEqual(classes, self.transformer.classes) @mock.patch("xsdata.codegen.transformer.logger.info") - @mock.patch.object(SchemaTransformer, "count_classes") + @mock.patch.object(ResourceTransformer, "count_classes") @mock.patch.object(SchemaMapper, "map") def test_generate_classes( self, mock_mapper_map, mock_count_classes, mock_logger_info @@ -377,10 
+378,10 @@ def test_parse_schema(self): self.assertEqual(2, len(schema.complex_types)) self.assertIsNone(self.transformer.parse_schema(uri, None)) # Once - @mock.patch.object(SchemaTransformer, "process_schema") + @mock.patch.object(ResourceTransformer, "process_schema") @mock.patch.object(Definitions, "merge") @mock.patch.object(DefinitionsParser, "from_bytes") - @mock.patch.object(SchemaTransformer, "load_resource") + @mock.patch.object(ResourceTransformer, "load_resource") def test_parse_definitions( self, mock_load_resource, diff --git a/tests/codegen/test_utils.py b/tests/codegen/test_utils.py index 197c599f1..c656209d2 100644 --- a/tests/codegen/test_utils.py +++ b/tests/codegen/test_utils.py @@ -173,9 +173,9 @@ def test_copy_inner_classes(self, mock_copy_inner_class): mock_copy_inner_class.assert_has_calls( [ - mock.call(source, target, attr, attr.types[0]), - mock.call(source, target, attr, attr.types[1]), - mock.call(source, target, attr, attr.types[2]), + mock.call(source, target, attr.types[0]), + mock.call(source, target, attr.types[1]), + mock.call(source, target, attr.types[2]), ] ) @@ -185,11 +185,10 @@ def test_copy_inner_class(self): qname="a", module="b", package="c", status=Status.FLATTENED ) target = ClassFactory.create() - attr = AttrFactory.create() attr_type = AttrTypeFactory.create(forward=True, qname=inner.qname) source.inner.append(inner) - ClassUtils.copy_inner_class(source, target, attr, attr_type) + ClassUtils.copy_inner_class(source, target, attr_type) self.assertEqual(1, len(target.inner)) self.assertIsNot(inner, target.inner[0]) @@ -202,22 +201,20 @@ def test_copy_inner_class(self): def test_copy_inner_class_check_circular_reference(self): source = ClassFactory.create() target = ClassFactory.create() - attr = AttrFactory.create() attr_type = AttrTypeFactory.create(forward=True, qname=target.qname) source.inner.append(target) - ClassUtils.copy_inner_class(source, target, attr, attr_type) + ClassUtils.copy_inner_class(source, target, attr_type) self.assertTrue(attr_type.circular) self.assertEqual(0, len(target.inner)) def test_copy_inner_class_with_missing_inner(self): source = ClassFactory.create() target = ClassFactory.create() - attr = AttrFactory.create() attr_type = AttrTypeFactory.create(forward=True, qname=target.qname) with self.assertRaises(CodeGenerationError): - ClassUtils.copy_inner_class(source, target, attr, attr_type) + ClassUtils.copy_inner_class(source, target, attr_type) def test_find_inner(self): obj = ClassFactory.create(qname="{a}parent") diff --git a/tests/formats/dataclass/parsers/nodes/test_element.py b/tests/formats/dataclass/parsers/nodes/test_element.py index 2d5d16e3e..21f3f7840 100644 --- a/tests/formats/dataclass/parsers/nodes/test_element.py +++ b/tests/formats/dataclass/parsers/nodes/test_element.py @@ -275,27 +275,19 @@ def test_bind_wild_list_var(self): self.assertEqual(expected, params) def test_prepare_generic_value(self): - var = XmlVarFactory.create( - index=2, - xml_type=XmlType.WILDCARD, - qname="a", - types=(object,), - elements={"known": XmlVarFactory.create()}, - ) - - actual = self.node.prepare_generic_value(None, 1, var) + actual = self.node.prepare_generic_value(None, 1) self.assertEqual(1, actual) - actual = self.node.prepare_generic_value("a", 1, var) + actual = self.node.prepare_generic_value("a", 1) expected = AnyElement(qname="a", text="1") self.assertEqual(expected, actual) - actual = self.node.prepare_generic_value("a", "foo", var) + actual = self.node.prepare_generic_value("a", "foo") expected = 
AnyElement(qname="a", text="foo") self.assertEqual(expected, actual) fixture = make_dataclass("Fixture", [("content", str)]) - actual = self.node.prepare_generic_value("a", fixture("foo"), var) + actual = self.node.prepare_generic_value("a", fixture("foo")) self.assertEqual(fixture("foo"), actual) def test_child(self): diff --git a/tests/formats/dataclass/parsers/test_xml.py b/tests/formats/dataclass/parsers/test_xml.py index abf1f0a3c..7bc3f6778 100644 --- a/tests/formats/dataclass/parsers/test_xml.py +++ b/tests/formats/dataclass/parsers/test_xml.py @@ -57,4 +57,4 @@ def test_emit_event(self): self.parser.emit_event("foo", "{tns}BarEl", a=1, b=2) mock_func.assert_called_once_with(a=1, b=2) - self.assertEqual({("foo", "{tns}BarEl"): mock_func}, self.parser.emit_cache) + self.assertEqual({("foo", "{tns}BarEl"): mock_func}, self.parser.hooks_cache) diff --git a/tests/formats/dataclass/serializers/test_code.py b/tests/formats/dataclass/serializers/test_code.py index 96f14be2b..66fe27b73 100644 --- a/tests/formats/dataclass/serializers/test_code.py +++ b/tests/formats/dataclass/serializers/test_code.py @@ -83,24 +83,24 @@ def test_write_string_with_unicode_characters(self): self.assertEqual(expected, result) def test_write_object_with_empty_array(self): - iterator = self.serializer.write_object([], 0, set()) + iterator = self.serializer.repr_object([], 0, set()) self.assertEqual("[]", "".join(iterator)) - iterator = self.serializer.write_object((), 0, set()) + iterator = self.serializer.repr_object((), 0, set()) self.assertEqual("()", "".join(iterator)) - iterator = self.serializer.write_object(set(), 0, set()) + iterator = self.serializer.repr_object(set(), 0, set()) self.assertEqual("set()", "".join(iterator)) def test_write_object_with_mapping(self): - iterator = self.serializer.write_object({}, 0, set()) + iterator = self.serializer.repr_object({}, 0, set()) self.assertEqual("{}", "".join(iterator)) - iterator = self.serializer.write_object({"foo": "bar"}, 0, set()) + iterator = self.serializer.repr_object({"foo": "bar"}, 0, set()) self.assertEqual("{\n 'foo': 'bar',\n}", "".join(iterator)) def test_write_object_with_enum(self): - iterator = self.serializer.write_object(Namespace.SOAP11, 0, set()) + iterator = self.serializer.repr_object(Namespace.SOAP11, 0, set()) self.assertEqual("Namespace.SOAP11", "".join(iterator)) def test_build_imports_with_nested_types(self): diff --git a/tests/formats/dataclass/serializers/test_json.py b/tests/formats/dataclass/serializers/test_json.py index 51757944d..34e143477 100644 --- a/tests/formats/dataclass/serializers/test_json.py +++ b/tests/formats/dataclass/serializers/test_json.py @@ -1,7 +1,5 @@ import json -import warnings from unittest.case import TestCase -from unittest.mock import ANY, Mock, call from tests.fixtures.books import BookForm, Books from tests.fixtures.datatypes import Telephone @@ -87,32 +85,6 @@ def test_convert_namedtuple(self): actual = serializer.convert(Telephone(30, 234, 56783), var) self.assertEqual("30-234-56783", actual) - def test_indent_deprecation(self): - dump_factory = Mock(json.dump) - - with warnings.catch_warnings(record=True) as w: - serializer = JsonSerializer(dump_factory=dump_factory) - serializer.render(self.books) - - serializer.config.pretty_print = True - serializer.render(self.books) - - serializer.indent = 4 - serializer.render(self.books) - - dump_factory.assert_has_calls( - [ - call(ANY, ANY, indent=None), - call(ANY, ANY, indent=2), - call(ANY, ANY, indent=4), - ] - ) - - self.assertEqual( - 
"JsonSerializer indent property is deprecated, use SerializerConfig", - str(w[-1].message), - ) - def test_next_value(self): book = self.books.book[0] serializer = JsonSerializer() diff --git a/tests/formats/dataclass/serializers/test_mixins.py b/tests/formats/dataclass/serializers/test_mixins.py index 43f47936d..c4e007307 100644 --- a/tests/formats/dataclass/serializers/test_mixins.py +++ b/tests/formats/dataclass/serializers/test_mixins.py @@ -1,6 +1,6 @@ from io import StringIO -from typing import Dict, TextIO from unittest import TestCase +from xml.sax import ContentHandler from xml.sax.saxutils import XMLGenerator from xsdata.exceptions import XmlWriterError @@ -10,12 +10,8 @@ class XmlWriterImpl(XmlWriter): - __slots__ = () - - def __init__(self, config: SerializerConfig, output: TextIO, ns_map: Dict): - super().__init__(config, output, ns_map) - - self.handler = XMLGenerator( + def build_handler(self) -> ContentHandler: + return XMLGenerator( self.output, encoding="UTF-8", short_empty_elements=True, diff --git a/tests/models/test_config.py b/tests/models/test_config.py index 5bde12fe5..c8247b112 100644 --- a/tests/models/test_config.py +++ b/tests/models/test_config.py @@ -8,12 +8,9 @@ from xsdata.exceptions import GeneratorConfigError, ParserError from xsdata.models.config import ( ExtensionType, - GeneratorAlias, - GeneratorAliases, GeneratorConfig, GeneratorExtension, GeneratorOutput, - ObjectType, OutputFormat, ) @@ -79,7 +76,6 @@ def test_read(self): " \n" ' \n' " \n" - " \n" " \n" "\n" ) @@ -113,7 +109,6 @@ def test_read(self): ' \n' ' \n' " \n" - " \n" " \n" " \n" "\n" @@ -204,32 +199,6 @@ def test_format_kw_only_requires_310(self): else: self.assertIsNotNone(OutputFormat(kw_only=True)) - def test_init_config_with_aliases(self): - config = GeneratorConfig( - aliases=GeneratorAliases( - class_name=[GeneratorAlias(source="a", target="b")], - field_name=[GeneratorAlias(source="c", target="d")], - package_name=[GeneratorAlias(source="e", target="f")], - module_name=[GeneratorAlias(source="g", target="h")], - ) - ) - - self.assertEqual(4, len(config.substitutions.substitution)) - self.assertEqual(ObjectType.CLASS, config.substitutions.substitution[0].type) - self.assertEqual(ObjectType.FIELD, config.substitutions.substitution[1].type) - self.assertEqual(ObjectType.PACKAGE, config.substitutions.substitution[2].type) - self.assertEqual(ObjectType.MODULE, config.substitutions.substitution[3].type) - - output = tempfile.mktemp() - output_path = Path(output) - config.substitutions = None - with output_path.open("w") as fp: - config.write(fp, config) - - config = GeneratorConfig.read(output_path) - self.assertIsNone(config.aliases) - self.assertEqual(4, len(config.substitutions.substitution)) - def test_extension_with_invalid_import_string(self): cases = [None, "", "bar"] for case in cases: diff --git a/tests/models/test_type_mapping.py b/tests/models/test_type_mapping.py index 5ccf8a580..7607237ee 100644 --- a/tests/models/test_type_mapping.py +++ b/tests/models/test_type_mapping.py @@ -24,7 +24,7 @@ def test_type_mapping(self): json_serializer = JsonSerializer(config=serializer_config) xml_serializer = XmlSerializer(config=serializer_config) - pycode_serializer = PycodeSerializer(config=serializer_config) + pycode_serializer = PycodeSerializer() for model in (city1, street1, house1): json_serializer.render(model) diff --git a/tests/models/xsd/test_annotation_base.py b/tests/models/xsd/test_annotation_base.py index d01dffe3a..8c0e7e5e1 100644 --- 
a/tests/models/xsd/test_annotation_base.py +++ b/tests/models/xsd/test_annotation_base.py @@ -16,7 +16,7 @@ def test_property_display_help(self): Annotation( documentations=[ Documentation( - elements=[ + content=[ " I am a ", AnyElement( qname="{http://www.w3.org/1999/xhtml}p", diff --git a/tests/test_cli.py b/tests/test_cli.py index 6fd02d5d6..604372c0c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -10,7 +10,7 @@ from tests import fixtures_dir from xsdata import __version__ from xsdata.cli import cli, resolve_source -from xsdata.codegen.transformer import SchemaTransformer +from xsdata.codegen.transformer import ResourceTransformer from xsdata.codegen.writer import CodeWriter from xsdata.formats.dataclass.generator import DataclassGenerator from xsdata.logger import logger @@ -33,8 +33,8 @@ def setUpClass(cls): def tearDownClass(cls): CodeWriter.unregister_generator("testing") - @mock.patch.object(SchemaTransformer, "process") - @mock.patch.object(SchemaTransformer, "__init__", return_value=None) + @mock.patch.object(ResourceTransformer, "process") + @mock.patch.object(ResourceTransformer, "__init__", return_value=None) def test_generate(self, mock_init, mock_process): source = fixtures_dir.joinpath("defxmlschema/chapter03.xsd") result = self.runner.invoke(cli, [str(source), "--package", "foo"]) @@ -48,8 +48,8 @@ def test_generate(self, mock_init, mock_process): self.assertEqual(StructureStyle.FILENAMES, config.output.structure_style) self.assertEqual([source.as_uri()], mock_process.call_args[0][0]) - @mock.patch.object(SchemaTransformer, "process") - @mock.patch.object(SchemaTransformer, "__init__", return_value=None) + @mock.patch.object(ResourceTransformer, "process") + @mock.patch.object(ResourceTransformer, "__init__", return_value=None) def test_generate_with_configuration_file(self, mock_init, mock_process): file_path = Path(tempfile.mktemp()) config = GeneratorConfig() @@ -75,8 +75,8 @@ def test_generate_with_configuration_file(self, mock_init, mock_process): self.assertEqual([source.as_uri()], mock_process.call_args[0][0]) file_path.unlink() - @mock.patch.object(SchemaTransformer, "process") - @mock.patch.object(SchemaTransformer, "__init__", return_value=None) + @mock.patch.object(ResourceTransformer, "process") + @mock.patch.object(ResourceTransformer, "__init__", return_value=None) def test_generate_with_print_mode(self, mock_init, mock_process): source = fixtures_dir.joinpath("defxmlschema/chapter03.xsd") result = self.runner.invoke(cli, [str(source), "--package", "foo", "--print"]) @@ -85,8 +85,8 @@ def test_generate_with_print_mode(self, mock_init, mock_process): self.assertEqual([source.as_uri()], mock_process.call_args[0][0]) self.assertTrue(mock_init.call_args[1]["print"]) - @mock.patch.object(SchemaTransformer, "process") - @mock.patch.object(SchemaTransformer, "__init__", return_value=None) + @mock.patch.object(ResourceTransformer, "process") + @mock.patch.object(ResourceTransformer, "__init__", return_value=None) def test_generate_with_debug_mode(self, *args): self.runner.invoke(cli, ["foo.xsd", "--package", "foo", "--debug"]) self.assertEqual(logging.DEBUG, logger.level) diff --git a/tests/utils/test_text.py b/tests/utils/test_text.py index d10530f48..420869586 100644 --- a/tests/utils/test_text.py +++ b/tests/utils/test_text.py @@ -2,7 +2,7 @@ from unittest import TestCase from xsdata.utils.text import ( - StringType, + CharType, alnum, camel_case, capitalize, @@ -155,13 +155,13 @@ def test_variable(self): def test_classify(self): for ltr in 
string.ascii_uppercase: - self.assertEqual(StringType.UPPER, classify(ltr)) + self.assertEqual(CharType.UPPER, classify(ltr)) for ltr in string.ascii_lowercase: - self.assertEqual(StringType.LOWER, classify(ltr)) + self.assertEqual(CharType.LOWER, classify(ltr)) for ltr in string.digits: - self.assertEqual(StringType.NUMERIC, classify(ltr)) + self.assertEqual(CharType.NUMERIC, classify(ltr)) for ltr in "~!@#$%^&*()_+β": - self.assertEqual(StringType.OTHER, classify(ltr)) + self.assertEqual(CharType.OTHER, classify(ltr)) diff --git a/xsdata/__main__.py b/xsdata/__main__.py index a4fb12120..650140786 100644 --- a/xsdata/__main__.py +++ b/xsdata/__main__.py @@ -2,6 +2,7 @@ def main(): + """Cli entry point.""" try: from xsdata.cli import cli diff --git a/xsdata/cli.py b/xsdata/cli.py index b2f9f8b5f..96945e623 100644 --- a/xsdata/cli.py +++ b/xsdata/cli.py @@ -9,7 +9,7 @@ from click_default_group import DefaultGroup from xsdata import __version__ -from xsdata.codegen.transformer import SchemaTransformer +from xsdata.codegen.transformer import ResourceTransformer from xsdata.logger import logger from xsdata.models.config import GeneratorConfig, GeneratorOutput from xsdata.utils.click import LogFormatter, LogHandler, model_options @@ -116,14 +116,11 @@ def download(source: str, output: str): @click.option("--debug", is_flag=True, default=False, help="Show debug messages") @model_options(GeneratorOutput) def generate(**kwargs: Any): - """ - Generate code from xml schemas, webservice definitions and any xml or json - document. + """Generate code from xsd, dtd, wsdl, xml and json files. The input source can be either a filepath, uri or a directory containing xml, json, xsd and wsdl files. """ - debug = kwargs.pop("debug") if debug: logger.setLevel(logging.DEBUG) @@ -138,7 +135,7 @@ def generate(**kwargs: Any): config = GeneratorConfig.read(config_file) config.output.update(**params) - transformer = SchemaTransformer(config=config, print=stdout) + transformer = ResourceTransformer(config=config, print=stdout) uris = sorted(resolve_source(source, recursive=recursive)) transformer.process(uris, cache=cache) @@ -146,6 +143,7 @@ def generate(**kwargs: Any): def resolve_source(source: str, recursive: bool) -> Iterator[str]: + """Yields all supported resource URIs.""" if source.find("://") > -1 and not source.startswith("file://"): yield source else: diff --git a/xsdata/codegen/analyzer.py b/xsdata/codegen/analyzer.py index 97a654b95..f76ffa2fc 100644 --- a/xsdata/codegen/analyzer.py +++ b/xsdata/codegen/analyzer.py @@ -7,12 +7,20 @@ class ClassAnalyzer: - """Validate, analyze, sanitize and select the final class list to be - generated.""" + """Validate, analyze, sanitize and filter the generated classes.""" @classmethod def process(cls, container: ClassContainer) -> List[Class]: - """Run all the processes.""" + """Main entrypoint for the class container instance. + + Orchestrate the class validations and processors. + + Args: + container: The class container instance + + Returns: + The list of classes to be generated. + """ # Run validation checks for duplicate, invalid and redefined types. ClassValidator(container).process() @@ -25,8 +33,17 @@ def process(cls, container: ClassContainer) -> List[Class]: return classes @classmethod - def class_references(cls, target: Class) -> List: - """Produce a list of instance references for the given class.""" + def class_references(cls, target: Class) -> List[int]: + """Produce a list of instance references for the given class. 
+ + Collect the ids of the class, attr, extension and inner instances. + + Args: + target: The target class instance + + Returns: + The list of id references. + """ result = [id(target)] for attr in target.attrs: result.append(id(attr)) @@ -43,7 +60,17 @@ @classmethod def validate_references(cls, classes: List[Class]): - """Validate all code gen objects are not cross referenced.""" + """Validate all codegen objects are not cross-referenced. + + This validation ensures we never share any attr or extension + between classes. + + Args: + classes: The list of classes to be generated. + + Raises: + AnalyzerValueError: If an object is shared between the classes. + """ references = [ref for obj in classes for ref in cls.class_references(obj)] if len(references) != len(set(references)): raise AnalyzerValueError("Cross references detected!") diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index 31a754fa3..72a269ac4 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -31,6 +31,8 @@ class Steps: + """Process steps.""" + UNGROUP = 10 FLATTEN = 20 SANITIZE = 30 @@ -39,11 +41,27 @@ class ClassContainer(ContainerInterface): + """A class list wrapper with an easy access api. + + Args: + config: The generator configuration instance + + Attributes: + processors: A step-processors mapping + step: The current process step + """ + __slots__ = ("data", "processors", "step") def __init__(self, config: GeneratorConfig): - """Initialize a class container instance with its processors based on - the provided configuration.""" + """Initialize the container and all the class processors. + + The order of the steps and the processors is the secret + recipe of the xsdata code generator. + + Args: + config: The generator configuration instance + """ super().__init__(config) self.step: int = 0 @@ -75,19 +93,29 @@ def __init__(self, config: GeneratorConfig): Steps.FINALIZE: [ VacuumInnerClasses(), CreateCompoundFields(self), - # Prettify things!!! ResetAttributeSequenceNumbers(self), ], } def __iter__(self) -> Iterator[Class]: - """Create an iterator for the class map values.""" + """Yield an iterator for the class map values.""" for items in list(self.data.values()): yield from items def find(self, qname: str, condition: Callable = return_true) -> Optional[Class]: - """Search by qualified name for a specific class with an optional - condition callable.""" + """Find class that matches the given qualified name and condition callable. + + Classes are allowed to have the same qualified name, e.g. xsd:Element + extending xsd:ComplexType with the same name; you can provide an additional + callback to filter the classes like the tag. + + Args: + qname: The qualified name of the class + condition: A user callable to filter further + + Returns: + A class instance or None if no match found. + """ for row in self.data.get(qname, []): if condition(row): if row.status < self.step: @@ -98,6 +126,18 @@ def find(self, qname: str, condition: Callable = return_true) -> Optional[Class] return None def find_inner(self, source: Class, qname: str) -> Class: + """Search by qualified name for a specific inner class or fail. + + Args: + source: The source class to search for the inner class + qname: The qualified name of the inner class to look up + + Returns: + The inner class instance + + Raises: + CodeGenerationError: If the inner class is not found.
+ """ inner = ClassUtils.find_inner(source, qname) if inner.status < self.step: self.process_class(inner, self.step) @@ -105,6 +145,17 @@ return inner def first(self, qname: str) -> Class: + """Return the first class that matches the qualified name. + + Args: + qname: The qualified name of the class + + Returns: + The first matching class + + Raises: + KeyError: If no class matches the qualified name + """ classes = self.data.get(qname) if not classes: raise KeyError(f"Class {qname} not found") @@ -112,7 +163,17 @@ return classes[0] def process(self): - """The hidden naive recipe of processing xsd models.""" + """Run the processor and filter steps. + + Steps: + 1. Ungroup xs:groups and xs:attributeGroups + 2. Remove the group classes from the container + 3. Flatten extensions, attrs and attr types + 4. Remove the classes that won't be generated + 5. Resolve attrs overrides + 6. Create compound fields, cleanup classes and attrs + 7. Designate final class names, packages and modules + """ self.process_classes(Steps.UNGROUP) self.remove_groups() self.process_classes(Steps.FLATTEN) @@ -122,13 +183,26 @@ self.process_classes(Steps.FINALIZE) self.designate_classes() - def process_classes(self, step: int) -> None: + def process_classes(self, step: int): + """Run the given step processors for all classes. + + Args: + step: The step reference number + """ self.step = step for obj in self: if obj.status < step: self.process_class(obj, step) def process_class(self, target: Class, step: int): + """Run the step processors for the given class. + + Process recursively any inner classes as well. + + Args: + target: The target class to process + step: The step reference number + """ target.status = Status(step) for processor in self.processors.get(step, []): processor.process(target) @@ -140,6 +214,7 @@ target.status = Status(step + 1) def designate_classes(self): + """Designate the final class names, packages and modules.""" designators = [ RenameDuplicateClasses(self), DesignateClassPackages(self), @@ -153,20 +228,40 @@ def filter_classes(self): FilterClasses(self).run() def remove_groups(self): + """Remove xs:groups and xs:attributeGroups from the container.""" self.set([x for x in iter(self) if not x.is_group]) def add(self, item: Class): - """Add class item to the container.""" + """Add class item to the container. + + Args: + item: The class instance to add + """ self.data.setdefault(item.qname, []).append(item) def reset(self, item: Class, qname: str): + """Update the given class qualified name. + + Args: + item: The target class instance to update + qname: The new qualified name of the class + """ self.data[qname].remove(item) self.add(item) def set(self, items: List[Class]): + """Set the list of classes to the container. + + Args: + items: The list of classes + """ self.data.clear() self.extend(items) def extend(self, items: List[Class]): - """Add a list of classes the container.""" + """Add a list of classes to the container.
+ + Args: + items: The list of class instances to add + """ collections.apply(items, self.add) diff --git a/xsdata/codegen/handlers/add_attribute_substitutions.py b/xsdata/codegen/handlers/add_attribute_substitutions.py index 20840b69d..e380459a7 100644 --- a/xsdata/codegen/handlers/add_attribute_substitutions.py +++ b/xsdata/codegen/handlers/add_attribute_substitutions.py @@ -9,7 +9,14 @@ class AddAttributeSubstitutions(RelativeHandlerInterface): - """Apply substitution attributes to the given class recursively.""" + """Apply substitution attributes to the given class recursively. + + Args: + container: The class container instance + + Attributes: + substitutions: Mapping of type names to attr values + """ __slots__ = "substitutions" @@ -18,11 +25,15 @@ def __init__(self, container: ContainerInterface): self.substitutions: Optional[Dict[str, List[Attr]]] = None def process(self, target: Class): - """ - Search and process attributes not derived from xs:enumeration or - xs:any. + """Process the given class attrs for substitution groups. - Build the substitutions map if it's not initialized yet. + This method will ignore attrs in the class derived from + an xs:enumeration, xs:anyType and xs:any. If this is the + first time we call the method, build the substitution + map. + + Args: + target: The target class instance """ if self.substitutions is None: self.create_substitutions() @@ -32,14 +43,18 @@ def process(self, target: Class): self.process_attribute(target, attr) def process_attribute(self, target: Class, attr: Attr): - """ - Check if the given attribute matches any substitution class in order to - clone its attributes to the target class. + """Add substitution attrs that refer to the attr type. - The cloned attributes are placed below the attribute they are - supposed to substitute. + If the given attr is referenced in substitution groups + clone all substitution attrs and place them below + the original attr. Convert all the attrs of the group + to repeatable choice elements. Guard against multiple substitutions in case of xs:groups. + + Args: + target: The target class instance + attr: The source attr instance to check and process """ index = target.attrs.index(attr) assert self.substitutions is not None @@ -65,9 +80,11 @@ self.process_attribute(target, clone) def create_substitutions(self): - """Create reference attributes for all the classes substitutions and - group them by their fully qualified name.""" + """Build the substitutions mapping of type names to attr values. + + The values are simple reference attrs that we can easily + clone later on demand. + """ self.substitutions = defaultdict(list) for obj in self.container: for qname in obj.substitutions: @@ -76,6 +93,11 @@ @classmethod def prepare_substituted(cls, attr: Attr): + """Prepare the original attr for substitutions. + + Effectively place the attr inside an xs:choice container + with min occurs zero. + """ attr.restrictions.min_occurs = 0 if not attr.restrictions.choice: choice = id(attr) @@ -84,9 +106,14 @@ @classmethod def create_substitution(cls, source: Class) -> Attr: - """Create an attribute with type that refers to the given source class - and namespaced qualified name.""" + """Create a reference attr to the source class qname. + + Args: + source: The source class to reference + + Returns: + The reference to the source class attr.
+ """ return Attr( name=source.name, types=[AttrType(qname=source.qname)], diff --git a/xsdata/codegen/handlers/calculate_attribute_paths.py b/xsdata/codegen/handlers/calculate_attribute_paths.py index 47939b154..39d2c5473 100644 --- a/xsdata/codegen/handlers/calculate_attribute_paths.py +++ b/xsdata/codegen/handlers/calculate_attribute_paths.py @@ -8,13 +8,22 @@ class CalculateAttributePaths(HandlerInterface): - """Calculate min/max occurs and sequence/choice/group from the schema - path.""" + """Calculate min/max occurs and sequence/choice/group from the schema path.""" __slots__ = () @classmethod def process(cls, target: Class): + """Calculate the class attrs restrictions by their schema path. + + For each attr calculate the min/max occurs and set the + sequence/choice/group reference id. Ignore attrs derived + from xs:attribute and xs:enumeration as these are not affected + by the parent element. + + Args: + target: The target class instance + """ for attr in target.attrs: if ( attr.restrictions.path @@ -25,6 +34,22 @@ @classmethod def process_attr_path(cls, attr: Attr): + """Entrypoint for processing a class attr. + + Example path: + ("s", 1, 1, 1), ("s", 2, 1, 2), ("c", 3, 0, 10) -> + sequence:1 with min=1 and max_occurs=1 + sequence:2 with min=1 and max_occurs=2 + choice:3 with min=0 and max_occurs=10 + + Steps: + - Every attr starts with a min/max occurs equal to one. + - For every parent container multiply the min/max occurs + - Set the sequence/choice/group reference ids as you go along + + Args: + attr: The attr of the class to check and process + """ min_occurs = 1 max_occurs = 1 for path in attr.restrictions.path: diff --git a/xsdata/codegen/handlers/create_compound_fields.py b/xsdata/codegen/handlers/create_compound_fields.py index df7271b07..25f9b0ff9 100644 --- a/xsdata/codegen/handlers/create_compound_fields.py +++ b/xsdata/codegen/handlers/create_compound_fields.py @@ -16,17 +16,31 @@ class CreateCompoundFields(RelativeHandlerInterface): - """Group attributes that belong in the same choice and replace them by - compound fields.""" + """Process attrs that belong in the same choice. + + Args: + container: The class container instance + + Attributes: + config: The compound fields configuration + """ __slots__ = "config" def __init__(self, container: ContainerInterface): super().__init__(container) - self.config = container.config.output.compound_fields def process(self, target: Class): + """Process the attrs of a class that belong in the same choice. + + If the compound fields configuration is enabled replace the + attrs with a compound attr, otherwise recalculate the min + occurs restriction for each of them. + + Args: + target: The target class instance + """ groups = group_by(target.attrs, get_restriction_choice) for choice, attrs in groups.items(): if choice and len(attrs) > 1: @@ -37,6 +51,16 @@ @classmethod def calculate_choice_min_occurs(cls, attrs: List[Attr]): + """Calculate the min occurs restriction of the attrs. + + If the attr has a path that includes an xs:choice + with min occurs less than 1, update the attr min + occurs restriction to zero, effectively marking + it as optional. + + Args: + attrs: A list of attrs that belong in the same xs:choice.
+ """ for attr in attrs: for path in attr.restrictions.path: name, index, mi, ma = path @@ -45,6 +69,25 @@ @classmethod def update_counters(cls, attr: Attr, counters: Dict): + """Update the counters dictionary with the attr min/max restrictions. + + This method builds a nested counters mapping per path. + + Example: + { + "min": [0, 1, 2], + "max": [3, 4, 5], + ("c", 1, 0, 1): { + "min": [6, 7, 8], + "max": [9, 10, 11], + ("g", 2, 0, 1): { ... } + } + } + + Args: + attr: The source attr instance + counters: The nested counters instance to update + """ started = False choice = attr.restrictions.choice for path in attr.restrictions.path: @@ -64,7 +107,12 @@ counters["max"].append(attr.restrictions.max_occurs) def group_fields(self, target: Class, attrs: List[Attr]): - """Group attributes into a new compound field.""" + """Group attributes into a new compound field. + + Args: + target: The target class instance + attrs: A list of attrs that belong to the same choice + """ pos = target.attrs.index(attrs[0]) choice = attrs[0].restrictions.choice @@ -103,6 +151,11 @@ ) def sum_counters(self, counters: Dict) -> Tuple[List[int], List[int]]: + """Sum the min/max occurrences for the compound attr. + + Args: + counters: The counters map of all the choice attrs + """ min_occurs = counters.pop("min", []) max_occurs = counters.pop("max", []) @@ -119,8 +172,34 @@ return min_occurs, max_occurs def choose_name( - self, target: Class, names: List[str], substitutions: List[str] + self, + target: Class, + names: List[str], + substitutions: List[str], ) -> str: + """Choose a name for the compound attr. + + If the attrs were placed in the same choice because of a single + substitution group and the configuration `use_substitution_groups` + is enabled, the group name will be used for the compound attr. + + Otherwise, the name will be the concatenation of the names of the + attrs, if the length of the attrs is less than the `max_name_parts` + config and the `force_default_name` is false, + e.g. `hat_Or_bat_Or_bar` + + Otherwise, the name will be the default from the config, + e.g. `choice` + + If there are any name collisions with other class attrs, the system + will add an integer suffix till the name is unique in the class + e.g. `choice_1`, `choice_2`, `choice_3` + + Args: + target: The target class + names: The list of the attr names + substitutions: The list of the substitution group names of the attrs + """ if self.config.use_substitution_groups and len(names) == len(substitutions): names = substitutions @@ -134,6 +213,14 @@ return ClassUtils.unique_name(name, reserved) def build_reserved_names(self, target: Class, names: List[str]) -> Set[str]: + """Build a set of reserved attr names. + + The method will also check parent attrs. + + Args: + target: The target class instance + names: The potential names for the compound attr + """ names_counter = Counter(names) all_attrs = self.base_attrs(target) all_attrs.extend(target.attrs) @@ -147,11 +234,19 @@ @classmethod def build_attr_choice(cls, attr: Attr) -> Attr: - """ - Converts the given attr to a choice. + """Build the choice attr from a normal attr.
+ + Steps: + - Clone the original attr restrictions + - Reset the min/max occurs + - Remove the sequence reference + - Build the new attr and maintain the basic attributes + + Args: + attr: The source attr instance to use - The most important part is the reset of certain restrictions - that don't make sense as choice metadata like occurrences. + Returns: + The new choice attr for the compound attr. """ restrictions = attr.restrictions.clone() restrictions.min_occurs = None diff --git a/xsdata/codegen/handlers/designate_class_packages.py b/xsdata/codegen/handlers/designate_class_packages.py index c5cd65039..160fa8eff 100644 --- a/xsdata/codegen/handlers/designate_class_packages.py +++ b/xsdata/codegen/handlers/designate_class_packages.py @@ -19,12 +19,20 @@ class DesignateClassPackages(ContainerHandlerInterface): - """Designate classes to packages and modules based on the output structure - style.""" + """Designate classes to packages and modules based on the output structure style.""" __slots__ = () def run(self): + """Group classes to packages and modules based on the output structure style. + + Structure Styles: + - Namespaces: classes with the same namespace are grouped together + - Single Package: all classes are grouped together + - Clusters: strongly connected classes are grouped together + - Namespace clusters: A combination of the namespaces and clusters + - Filenames: classes are grouped together by the schema file location + """ structure_style = self.container.config.output.structure_style if structure_style == StructureStyle.NAMESPACES: self.group_by_namespace() @@ -38,8 +46,14 @@ self.group_by_filenames() def group_by_filenames(self): - """Group uris by common path and auto assign package names to all - classes.""" + """Group classes by their schema file location. + + The classes are organized by the common paths of the + file locations. + + Example: + http://xsdata/foo/bar/schema.xsd -> foo.bar.schema.py + """ package = self.container.config.output.package class_map = collections.group_by(self.container, key=get_location) groups = self.group_common_paths(class_map.keys()) @@ -58,7 +72,11 @@ self.assign(items, package_name, module_name(key)) def group_by_namespace(self): - """Group classes by their target namespace.""" + """Group classes by their target namespace. + + Example: + {myNS.tempuri.org}Root -> org.tempuri.myNS.py + """ groups = collections.group_by(self.container, key=get_target_namespace) for namespace, classes in groups.items(): parts = self.combine_ns_package(namespace) @@ -75,16 +93,21 @@ self.assign(self.container, package, module) def group_by_strong_components(self): - """Find circular imports and cluster their classes together.""" + """Find circular imports and cluster their classes together. + + This grouping ideally creates a class per file; if there
are circular imports, the classes will be grouped together. + """ package = self.container.config.output.package for group in self.strongly_connected_classes(): - classes = self.sorted_classes(group) + classes = self.sort_classes(group) module = classes[0].name self.assign(classes, package, module) def group_by_namespace_clusters(self): + """Group strongly connected classes together by namespaces.""" for group in self.strongly_connected_classes(): - classes = self.sorted_classes(group) + classes = self.sort_classes(group) if len(set(map(get_target_namespace, classes))) > 1: raise CodeGenerationError( "Found strongly connected classes from different " @@ -95,7 +118,16 @@ module = classes[0].name self.assign(classes, ".".join(parts), module) - def sorted_classes(self, qnames: Set[str]) -> List[Class]: + def sort_classes(self, qnames: Set[str]) -> List[Class]: + """Sort classes by their dependencies graph. + + Args: + qnames: A set of qualified class names + + Returns: + A class list in a safe to generate order. + + """ edges = { qname: set(self.container.first(qname).dependencies()).intersection(qnames) for qname in qnames @@ -103,11 +135,21 @@ return [self.container.first(qname) for qname in toposort_flatten(edges)] def strongly_connected_classes(self) -> Iterator[Set[str]]: + """Compute strongly connected classes of a directed graph. + + Returns: + A list of sets of qualified class names. + """ edges = {obj.qname: list(set(obj.dependencies(True))) for obj in self.container} return strongly_connected_components(edges) @classmethod def assign(cls, classes: Iterable[Class], package: str, module: str): + """Assign package and module to classes. + + It's important to assign the same for any inner/nested + classes as well. + """ for obj in classes: obj.package = package obj.module = module @@ -115,6 +157,14 @@ @classmethod def group_common_paths(cls, paths: Iterable[str]) -> List[List[str]]: + """Group a list of file paths by their common paths. + + Args: + paths: A list of file paths + + Returns: + A list of file lists that belong to the same common path. + """ prev = "" index = 0 groups = defaultdict(list) @@ -135,6 +185,20 @@ return list(groups.values()) def combine_ns_package(self, namespace: Optional[str]) -> List[str]: + """Combine the output package with a namespace. + + You can add aliases to a namespace uri with the + substitutions configuration. + + Without Alias: + urn:foo-bar:add -> ["generated", "bar", "foo", "add"] + + With Package Alias: urn:foo-bar:add -> add.again + urn:foo-bar:add -> ["generated", "add", "again"] + + Returns: + The package path as a list of strings. + """ result = self.container.config.output.package.split(".") if namespace: diff --git a/xsdata/codegen/handlers/filter_classes.py b/xsdata/codegen/handlers/filter_classes.py index 47bb4df56..4219cd1e3 100644 --- a/xsdata/codegen/handlers/filter_classes.py +++ b/xsdata/codegen/handlers/filter_classes.py @@ -7,12 +7,22 @@ class FilterClasses(ContainerHandlerInterface): - """Filter classes for code generation based on the configuration output - filter strategy.""" + """Filter classes for code generation based on the configuration strategy.""" __slots__ = () def run(self): + """Main entrypoint to filter the class container.
+ + In order for a class to be considered global it has + to be a non-abstract element, a complex type without + simple content or a wsdl binding element. + + Strategies: + - Filter all global classes and the referenced simple types. + - Filter global classes with references to other global classes. + - Filter all classes + """ classes = [] filter_strategy = self.container.config.output.filter_strategy if filter_strategy == ClassFilterStrategy.ALL_GLOBALS: @@ -30,7 +40,14 @@ ) def filter_all_globals(self) -> List[Class]: - """Filter all globals and any referenced types.""" + """Filter all globals and any referenced types. + + This filter is trying to remove unused simple + types. + + Returns: + The list of classes for generation. + """ occurs = set() for obj in self.container: if obj.is_global_type: @@ -40,7 +57,14 @@ return [obj for obj in self.container if obj.ref in occurs] def filter_referred_globals(self) -> List[Class]: - """Filter globals with any references.""" + """Filter globals with any references. + + This filter is trying to remove unused global + types. + + Returns: + The list of classes for generation. + """ occurs = set() for obj in self.container: if obj.is_global_type: diff --git a/xsdata/codegen/handlers/flatten_attribute_groups.py b/xsdata/codegen/handlers/flatten_attribute_groups.py index 291c0027e..af7e9953d 100644 --- a/xsdata/codegen/handlers/flatten_attribute_groups.py +++ b/xsdata/codegen/handlers/flatten_attribute_groups.py @@ -10,11 +10,13 @@ class FlattenAttributeGroups(RelativeHandlerInterface): __slots__ = () def process(self, target: Class): - """ - Iterate over all group attributes and apply handler logic. + """Iterate over all group attributes and apply handler logic. Group attributes can refer to attributes or other group attributes, repeat until there is no group attribute left. + + Args: + target: The target class instance to inspect and process """ repeat = False for attr in list(target.attrs): @@ -26,11 +28,19 @@ self.process(target) def process_attribute(self, target: Class, attr: Attr): - """ - Find the source class the attribute refers to and copy its attributes - to the target class. + """Process a group/attributeGroup attr. + + Steps: + 1. Find the source class by the attr type and tag + 2. If the attr is a circular reference, remove the attr + 3. Otherwise, copy all source attrs to the target class + + Args: + target: The target class instance + attr: The group attr to flatten - :raises AnalyzerValueError: if source class is not found. + Raises: + AnalyzerValueError: If the source class is not found. """ qname = attr.types[0].qname # group attributes have one type only. source = self.container.find(qname, condition=lambda x: x.tag == attr.tag) diff --git a/xsdata/codegen/handlers/flatten_class_extensions.py b/xsdata/codegen/handlers/flatten_class_extensions.py index 68f9b710d..329f79a17 100644 --- a/xsdata/codegen/handlers/flatten_class_extensions.py +++ b/xsdata/codegen/handlers/flatten_class_extensions.py @@ -14,14 +14,24 @@ class FlattenClassExtensions(RelativeHandlerInterface): __slots__ = () def process(self, target: Class): - """Iterate and process the target class's extensions in reverser - order.""" + """Process a class' extensions.
+ + Args: + target: The target class instance + """ for extension in list(target.extensions): self.process_extension(target, extension) def process_extension(self, target: Class, extension: Extension): - """Slit the process of extension into schema data types and user - defined types.""" + """Process a class extension. + + Slit the process to native xsd extensions and user defined + types. + + Args: + target: The target class instance + extension: The class extension instance + """ if extension.type.native: self.process_native_extension(target, extension) else: @@ -29,12 +39,15 @@ def process_extension(self, target: Class, extension: Extension): @classmethod def process_native_extension(cls, target: Class, extension: Extension): - """ - Native type flatten handler. + """Native type flatten handler. In case of enumerations copy the native data type to all enum - members, otherwise create a default text value with the + members, otherwise add a default text attr with the extension attributes. + + Args: + target: The target class instance + extension: The class extension instance """ if target.is_enumeration: cls.replace_attributes_type(target, extension) @@ -42,7 +55,18 @@ def process_native_extension(cls, target: Class, extension: Extension): cls.add_default_attribute(target, extension) def process_dependency_extension(self, target: Class, extension: Extension): - """User defined type flatten handler.""" + """Process user defined extension types. + + Case: + - Extension source is missing + - Target class is an enumeration + - Extension source is a simple type or an enumeration + - Extension source is a complex type + + Args: + target: The target class instance + extension: The class extension instance + """ source = self.find_dependency(extension.type) if not source: logger.warning("Missing extension type: %s", extension.type.name) @@ -55,10 +79,12 @@ def process_dependency_extension(self, target: Class, extension: Extension): self.process_complex_extension(source, target, extension) def process_enum_extension( - self, source: Class, target: Class, ext: Optional[Extension] + self, + source: Class, + target: Class, + extension: Optional[Extension], ): - """ - Process enumeration class extension. + """Process an enumeration class extension. Cases: 1. Source is an enumeration: merge them @@ -66,6 +92,11 @@ def process_enum_extension( 3. Source is a complex type 3.1 Target has a single member: Restrict default value 3.2 Target has multiple members: unsupported reset enumeration + + Args: + source: The source class instance + target: The target class instance + extension: The class extension instance """ if source.is_enumeration: self.merge_enumerations(source, target) @@ -78,11 +109,17 @@ def process_enum_extension( # the target enumeration, mypy doesn't play nicely. target.attrs.clear() - if ext and target.is_enumeration: - target.extensions.remove(ext) + if extension and target.is_enumeration: + target.extensions.remove(extension) @classmethod def merge_enumerations(cls, source: Class, target: Class): + """Merge enumeration members from source to target class. + + Args: + source: The source class instance + target: The target class instance + """ source_attrs = {attr.name: attr for attr in source.attrs} target.attrs = [ source_attrs[attr.name].clone() if attr.name in source_attrs else attr @@ -90,6 +127,12 @@ def merge_enumerations(cls, source: Class, target: Class): ] def merge_enumeration_types(self, source: Class, target: Class): + """Merge the enumeration attr types and restrictions. 
+ + Args: + source: The source class instance + target: The target class instance + """ source_attr = source.attrs[0] for tp in source_attr.types: if tp.native: @@ -106,8 +149,16 @@ def merge_enumeration_types(self, source: Class, target: Class): @classmethod def set_default_value(cls, source: Class, target: Class): - """Restrict the extension source class with the target single - enumeration value.""" + """Set the default value from the source single enumeration. + + When a simple type is a restriction of an enumeration with + only one member, we can safely set its default value + to that member value as fixed. + + Args: + source: The source class instance + target: The target class instance + """ new_attr = ClassUtils.find_value_attr(source).clone() new_attr.types = target.attrs[0].types new_attr.default = target.attrs[0].default @@ -116,16 +167,18 @@ def set_default_value(cls, source: Class, target: Class): @classmethod def process_simple_extension(cls, source: Class, target: Class, ext: Extension): - """ - Simple flatten extension handler for common classes eg SimpleType, - Restriction. + """Process simple type extensions. - Steps: + Cases: 1. If target is source: drop the extension. 2. If source is enumeration and target isn't create default value attribute. 3. If both source and target are enumerations copy all attributes. - 4. If both source and target are not enumerations copy all attributes. - 5. If target is enumeration: drop the extension. + 4. If target is enumeration: drop the extension. + + Args: + source: The source class instance + target: The target class instance + ext: The extension class instance """ if source is target: target.extensions.remove(ext) @@ -138,13 +191,16 @@ def process_simple_extension(cls, source: Class, target: Class, ext: Extension): @classmethod def process_complex_extension(cls, source: Class, target: Class, ext: Extension): - """ - Complex flatten extension handler for primary classes eg ComplexType, - Element. + """Process complex type extensions. Compare source and target classes and either remove the extension completely, copy all source attributes to the target class or leave the extension alone. + + Args: + source: The source class instance + target: The target class instance + ext: The extension class instance """ if cls.should_remove_extension(source, target, ext): target.extensions.remove(ext) @@ -154,10 +210,15 @@ class or leave the extension alone. ext.type.reference = id(source) def find_dependency(self, attr_type: AttrType) -> Optional[Class]: - """ - Find dependency for the given extension type with priority. + """Find dependency for the given extension type with priority. Search priority: xs:SimpleType > xs:ComplexType + + Args: + attr_type: The attr type instance + + Returns: + The class instance or None if it's undefined. """ conditions = ( lambda x: x.tag == Tag.SIMPLE_TYPE, @@ -173,23 +234,28 @@ def find_dependency(self, attr_type: AttrType) -> Optional[Class]: @classmethod def should_remove_extension( - cls, source: Class, target: Class, ext: Extension + cls, + source: Class, + target: Class, + extension: Extension, ) -> bool: - """ - Return whether the extension should be removed because of some - violation. + """Return whether the extension should be removed. 
Violations: - Circular Reference - Forward Reference - Unordered sequences - MRO Violation A(B), C(B) and extensions includes A, B, C + + Args: + source: The source class instance + target: The target class instance + extension: The extension class instance """ - # Circular or Forward reference if ( source is target or target in source.inner - or cls.have_unordered_sequences(source, target, ext) + or cls.have_unordered_sequences(source, target, extension) ): return True @@ -199,16 +265,17 @@ def should_remove_extension( @classmethod def should_flatten_extension(cls, source: Class, target: Class) -> bool: - """ - Return whether the extension should be flattened because of rules. + """Return whether the extension should be flattened. Rules: 1. Source doesn't have a parent class 2. Source class is a simple type 3. Source class has a suffix attr and target has its own attrs 4. Target class has a suffix attr - 5. Target restrictions parent attrs in different sequence order - 6. Target restricts parent attr with a not matching type. + + Args: + source: The source class instance + target: The target class instance """ if not source.extensions and ( source.is_simple_type @@ -221,10 +288,12 @@ def should_flatten_extension(cls, source: Class, target: Class) -> bool: @classmethod def have_unordered_sequences( - cls, source: Class, target: Class, ext: Extension + cls, + source: Class, + target: Class, + extension: Extension, ) -> bool: - """ - Validate sequence attributes are in the same order in the parent class. + """Validate overriding sequence attrs are in order. Dataclasses fields ordering follows the python mro pattern, the parent fields are always first, and they are updated if the @@ -233,9 +302,13 @@ def have_unordered_sequences( @todo This needs a complete rewrite and most likely it needs to @todo move way down in the process chain. - """ - if ext.tag == Tag.EXTENSION or source.extensions: + Args: + source: The source class instance + target: The target class instance + extension: The extension class instance + """ + if extension.tag == Tag.EXTENSION or source.extensions: return False sequence = [ @@ -252,18 +325,30 @@ def have_unordered_sequences( @classmethod def replace_attributes_type(cls, target: Class, extension: Extension): - """Replace all target attributes types with the extension's type and - remove it from the target class extensions.""" + """Replace all attrs types with the extension's type. + + The extension is a native xsd datatype. + Args: + target: The target class instance + extension: The extension class instance + """ + target.extensions.remove(extension) for attr in target.attrs: attr.types.clear() attr.types.append(extension.type.clone()) - target.extensions.remove(extension) @classmethod def add_default_attribute(cls, target: Class, extension: Extension): - """Add a default value field to the given class based on the extension - type.""" + """Convert extension to a value text attr. + + If the extension type is xs:anyType convert the + attr into a wildcard attr to match everything. 
+ + Args: + target: The target class instance + extension: The extension class instance + """ if extension.type.datatype != DataType.ANY_TYPE: tag = Tag.EXTENSION name = DEFAULT_ATTR_NAME @@ -281,9 +366,16 @@ def add_default_attribute(cls, target: Class, extension: Extension): @classmethod def get_or_create_attribute(cls, target: Class, name: str, tag: str) -> Attr: - """Find or create for the given parameters an attribute in the target - class.""" + """Find or create an attr with the given name and tag. + If the attr doesn't exist, create a new required + attr and prepend it in the attrs list. + + Args: + target: The target class instance + name: The attr name + tag: The attr tag name + """ attr = ClassUtils.find_attr(target, name) if attr is None: attr = Attr(name=name, tag=tag) diff --git a/xsdata/codegen/handlers/merge_attributes.py b/xsdata/codegen/handlers/merge_attributes.py index 565c01ae3..a8d531dbe 100644 --- a/xsdata/codegen/handlers/merge_attributes.py +++ b/xsdata/codegen/handlers/merge_attributes.py @@ -7,17 +7,20 @@ class MergeAttributes(HandlerInterface): - """Merge same type attributes and their restrictions.""" + """Merge same type attr and their restrictions.""" __slots__ = () @classmethod def process(cls, target: Class): - """ - Detect and process duplicate attributes. + """Detect and process duplicate attributes. + + Cases: + - Enumeration: remove duplicates + - Otherwise: merge attrs and - - Remove duplicates for enumerations. - - Merge duplicates with restrictions and types. + Args: + target: The target class instance """ if target.is_enumeration: cls.filter_duplicate_attrs(target) @@ -26,11 +29,24 @@ def process(cls, target: Class): @classmethod def filter_duplicate_attrs(cls, target: Class): + """Removes duplicate attrs. + + Args: + target: The target class instance + """ attrs = collections.unique_sequence(target.attrs, key="default") target.attrs = attrs @classmethod - def merge_duplicate_attrs(self, target: Class): + def merge_duplicate_attrs(cls, target: Class): + """Find duplicate attrs and merge them. + + In order for two attrs to be considered duplicates, + they must have the same name, namespace and tag. + + Args: + target: The target class instance + """ result: List[Attr] = [] for attr in target.attrs: pos = collections.find(result, attr) diff --git a/xsdata/codegen/handlers/process_attributes_types.py b/xsdata/codegen/handlers/process_attributes_types.py index 38c2b21a8..2b145b6bd 100644 --- a/xsdata/codegen/handlers/process_attributes_types.py +++ b/xsdata/codegen/handlers/process_attributes_types.py @@ -9,8 +9,14 @@ class ProcessAttributeTypes(RelativeHandlerInterface): - """Minimize class attributes complexity by filtering and flattening - types.""" + """Minimize class attrs complexity by filtering and flattening types. + + Args: + container: The container instance + + Attributes: + dependencies: Class qname dependencies mapping + """ __slots__ = "dependencies" @@ -19,13 +25,24 @@ def __init__(self, container: ContainerInterface): self.dependencies: Dict = {} def process(self, target: Class): - """Process the given class attributes and their types.""" + """Process the given class attrs and their types. + + Cascades class restrictions to class attrs. 
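Editor's note: a rough paraphrase of the cascade described above, using plain dicts instead of xsdata's Class/Attr models (field names are illustrative, not the real API): the class-level default, fixed and nillable properties flow down only to a text-node attr.

# Stand-ins for the class-level properties and a text-node attr.
class_default, class_fixed, class_nillable = "42", True, True
attr = {"xml_type": None, "default": None, "fixed": False, "nillable": False}

if attr["xml_type"] is None:  # only text nodes receive the cascade
    if class_default is not None and attr["default"] is None:
        attr["default"] = class_default
        attr["fixed"] = class_fixed
    if class_nillable:
        attr["nillable"] = True

print(attr)  # {'xml_type': None, 'default': '42', 'fixed': True, 'nillable': True}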
+ + Args: + target: The target class instance + """ for attr in list(target.attrs): self.process_types(target, attr) self.cascade_properties(target, attr) def process_types(self, target: Class, attr: Attr): - """Process every attr type and filter out duplicates.""" + """Process every attr type and filter out duplicates. + + Args: + target: The target class instance + attr: The attr instance + """ if self.container.config.output.ignore_patterns: attr.restrictions.pattern = None @@ -36,8 +53,17 @@ def process_types(self, target: Class, attr: Attr): @classmethod def cascade_properties(cls, target: Class, attr: Attr): - """Cascade target class default/fixed/nillable properties to the given - attr if it's a text node.""" + """Cascade class properties to the attr if it's a text node. + + Properties: + - Default value + - Fixed flag + - Nillable flag + + Args: + target: The target class instance + attr: The attr instance + """ if attr.xml_type is None: if target.default is not None and attr.default is None: attr.default = target.default @@ -47,8 +73,18 @@ def cascade_properties(cls, target: Class, attr: Attr): attr.restrictions.nillable = True def process_type(self, target: Class, attr: Attr, attr_type: AttrType): - """Process attribute type, split process for xml schema and user - defined types.""" + """Process attr type. + + Cases: + - Attr type is a native xsd type + - Attr type is a forward reference (inner class) + - Attr type is a complex user defined type + + Args: + target: The target class instance + attr: The attr instance + attr_type: The attr type instance + """ if attr_type.native: self.process_native_type(attr, attr_type) elif attr_type.forward: @@ -58,11 +94,15 @@ def process_type(self, target: Class, attr: Attr, attr_type: AttrType): @classmethod def process_native_type(cls, attr: Attr, attr_type: AttrType): - """ - Process native attribute types. + """Process native xsd types. + + Cascade the datatype restrictions to the attr and also + resets the type to a simple xsd:string if there is a pattern + restriction. - - Update restrictions from the datatype - - Reset attribute type if there is a pattern restriction + Args: + attr: The attr instance + attr_type: The attr type instance """ datatype = attr_type.datatype @@ -74,14 +114,20 @@ def process_native_type(cls, attr: Attr, attr_type: AttrType): cls.reset_attribute_type(attr_type) def find_dependency(self, attr_type: AttrType, tag: str) -> Optional[Class]: - """ - Find dependency for the given attribute and tag. + """Find the source type from the attr type and tag. Avoid conflicts by selecting any matching type by qname and preferably: 1. Match the candidate object tag 2. Match element again complexType 3. Match non element and complexType 4. Anything + + Args: + attr_type: The attr type instance + tag: The xml tag name, e.g. Element, Attribute, ComplexType + + Returns: + The source class or None if no match is found """ conditions = ( lambda obj: obj.tag == tag, @@ -98,10 +144,14 @@ def find_dependency(self, attr_type: AttrType, tag: str) -> Optional[Class]: return None def process_inner_type(self, target: Class, attr: Attr, attr_type: AttrType): - """ - Process an attributes type that depends on an inner type. + """Process an attr type that depends on a simple inner type. + + Skip If the source class is not simple type, or it's a circular reference. - Ignore inner circular references. 
+ Args: + target: The target class instance + attr: The attr instance + attr_type: The attr type instance """ if attr_type.circular: return @@ -112,14 +162,18 @@ def process_inner_type(self, target: Class, attr: Attr, attr_type: AttrType): target.inner.remove(inner) def process_dependency_type(self, target: Class, attr: Attr, attr_type: AttrType): - """ - Process an attributes type that depends on any global type. + """Process an attr type that depends on any global type. Strategies: 1. Reset absent types with a warning 2. Copy attribute properties from a simple type 3. Copy format restriction from an enumeration 4. Set circular flag for the rest + + Args: + target: The target class instance + attr: The attr instance + attr_type: The attr type instance """ source = self.find_dependency(attr_type, attr.tag) if not source: @@ -147,17 +201,25 @@ def process_dependency_type(self, target: Class, attr: Attr, attr_type: AttrType @classmethod def copy_attribute_properties( - cls, source: Class, target: Class, attr: Attr, attr_type: AttrType + cls, + source: Class, + target: Class, + attr: Attr, + attr_type: AttrType, ): - """ - Replace the given attribute type with the types of the single field - source class. + """Replace the attr type with the types of the first attr in the source class. Ignore enumerations and gracefully handle dump types with no - attributes. + attrs. - :raises: AnalyzerValueError if the source class has more than - one attributes + Args: + source: The source class instance + target: The target class instance + attr: The attr instance + attr_type: The attr type instance + + Raises: + AnalyzerValueError: if the source class has more than one attributes """ source_attr = source.attrs[0] index = attr.types.index(attr_type) @@ -168,7 +230,7 @@ def copy_attribute_properties( attr.types.insert(index, clone_type) index += 1 - ClassUtils.copy_inner_class(source, target, attr, clone_type) + ClassUtils.copy_inner_class(source, target, clone_type) restrictions = source_attr.restrictions.clone() restrictions.merge(attr.restrictions) @@ -186,7 +248,13 @@ def copy_attribute_properties( attr.default = attr.default or source_attr.default def set_circular_flag(self, source: Class, target: Class, attr_type: AttrType): - """Update circular reference flag.""" + """Detect circular references and set the type flag. + + Args: + source: The source class instance + target: The target class instance + attr_type: The attr type instance + """ attr_type.reference = id(source) attr_type.circular = self.is_circular_dependency(source, target, set()) @@ -194,9 +262,16 @@ def set_circular_flag(self, source: Class, target: Class, attr_type: AttrType): logger.debug("Possible circular reference %s, %s", target.name, source.name) def is_circular_dependency(self, source: Class, target: Class, seen: Set) -> bool: - """Check if any source dependencies recursively match the target - class.""" + """Check if any source dependencies recursively match the target class. 
+ + Args: + source: The source class instance + target: The target class instance + seen: A set of qualified names, to guard against recursive runs + Returns: + The bool result + """ if source is target or source.status == Status.FLATTENING: return True @@ -210,8 +285,14 @@ def is_circular_dependency(self, source: Class, target: Class, seen: Set) -> boo return False def cached_dependencies(self, source: Class) -> Tuple[str]: - """Returns from cache the source class dependencies as a collection of - qualified names.""" + """Returns from cache the source class dependencies. + + Args: + source: The source class instance + + Returns: + A tuple containing the qualified names of the class dependencies. + """ cache_key = id(source) if cache_key not in self.dependencies: self.dependencies[cache_key] = tuple(source.dependencies()) @@ -220,7 +301,15 @@ def cached_dependencies(self, source: Class) -> Tuple[str]: @classmethod def reset_attribute_type(cls, attr_type: AttrType, use_str: bool = True): - """Reset the attribute type to string or any simple type.""" + """Reset the attribute type to string or any simple type. + + The method will also unset the circular/forward flags, as native + types only depend on python builtin types. + + Args: + attr_type: The attr type instance to reset + use_str: Whether to use xs:string or xs:anySimpleType + """ attr_type.qname = str(DataType.STRING if use_str else DataType.ANY_SIMPLE_TYPE) attr_type.native = True attr_type.circular = False @@ -228,6 +317,16 @@ def reset_attribute_type(cls, attr_type: AttrType, use_str: bool = True): @classmethod def update_restrictions(cls, attr: Attr, datatype: DataType): + """Helper method to copy the native datatype restriction to the attr. + + Sets: + - The format restriction, e.g. hexBinary, base64Binary + - The tokens flag for xs:NMTOKENS and xs:IDREFS + + Args: + attr: The attr to update + datatype: The datatype to extract the restrictions. + """ attr.restrictions.format = datatype.format if datatype in (DataType.NMTOKENS, DataType.IDREFS): @@ -235,13 +334,17 @@ def update_restrictions(cls, attr: Attr, datatype: DataType): @classmethod def detect_lazy_namespace(cls, source: Class, target: Class, attr: Attr): - """ - Override attr namespace with the source namespace when during the - initial mapping the namespace detection wasn't possible. + """Set the attr namespace if the current is marked as lazy. + + Cases: + WSDL message part type can be an element, complex or + simple type, we can't do the detection during the initial + mapping to class objects. - Case 1: WSDL message part type can be an element, complex or - simple type, we can't do the detection during the initial - mapping to class objects. + Args: + source: The source class instance + target: The target class instance + attr: The target class attr instance """ if attr.namespace == "##lazy": logger.warning( diff --git a/xsdata/codegen/handlers/process_mixed_content_class.py b/xsdata/codegen/handlers/process_mixed_content_class.py index 6e760e4d5..b907364ce 100644 --- a/xsdata/codegen/handlers/process_mixed_content_class.py +++ b/xsdata/codegen/handlers/process_mixed_content_class.py @@ -6,18 +6,18 @@ class ProcessMixedContentClass(HandlerInterface): - """ - Mixed content handler. - - If the target class supports mixed content, a new wildcard attr will - replace the originals except any attributes. All the previous attrs - derived from xs:element will be moved as choices for the new - content attr. 
- """ + """Mixed content handler.""" __slots__ = () def process(self, target: Class): + """Add a wildcard attr if the class supports mixed content. + + All other elements will be moved as the wildcard attr choices. + + Args: + target: Tha target class instance + """ if not target.is_mixed: return diff --git a/xsdata/codegen/handlers/rename_duplicate_attributes.py b/xsdata/codegen/handlers/rename_duplicate_attributes.py index 6060a826c..a5b906ada 100644 --- a/xsdata/codegen/handlers/rename_duplicate_attributes.py +++ b/xsdata/codegen/handlers/rename_duplicate_attributes.py @@ -5,22 +5,25 @@ from xsdata.utils.constants import DEFAULT_ATTR_NAME -def attr_group_name(x: Attr) -> str: - return x.slug or DEFAULT_ATTR_NAME - - class RenameDuplicateAttributes(HandlerInterface): - """Resolve attribute name conflicts defined in the class.""" + """Resolve attr name conflicts defined in the class.""" __slots__ = () def process(self, target: Class): - """Sanitize duplicate attribute names that might exist by applying to - rename strategies.""" - grouped = group_by(target.attrs, key=attr_group_name) + """Detect and resolve naming conflicts. + + Args: + target: The target class instance + """ + grouped = group_by(target.attrs, key=self._attr_unique_slug) for items in grouped.values(): total = len(items) if total == 2 and not items[0].is_enumeration: ClassUtils.rename_attribute_by_preference(*items) elif total > 1: ClassUtils.rename_attributes_by_index(target.attrs, items) + + @staticmethod + def _attr_unique_slug(attr: Attr) -> str: + return attr.slug or DEFAULT_ATTR_NAME diff --git a/xsdata/codegen/handlers/rename_duplicate_classes.py b/xsdata/codegen/handlers/rename_duplicate_classes.py index 4b3887820..fbf7c794c 100644 --- a/xsdata/codegen/handlers/rename_duplicate_classes.py +++ b/xsdata/codegen/handlers/rename_duplicate_classes.py @@ -14,10 +14,7 @@ class RenameDuplicateClasses(ContainerHandlerInterface): __slots__ = () def run(self): - """Search for conflicts either by qualified name or local name - depending on the configuration and start renaming classes and - dependencies.""" - + """Detect and resolve class name conflicts.""" use_name = self.should_use_names() getter = get_name if use_name else get_qname groups = collections.group_by(self.container, lambda x: text.alnum(getter(x))) @@ -27,14 +24,12 @@ def run(self): self.rename_classes(classes, use_name) def should_use_names(self) -> bool: - """ - Determine if we should be using names or qualified names to detect - collisions. + """Determine if names or qualified names should be used for detection. Strict unique names: - - Single package - - Clustered packages - - All classes have the same source location. + - Single package + - Clustered packages + - All classes have the same source location. """ return ( self.container.config.output.structure_style in REQUIRE_UNIQUE_NAMES @@ -42,11 +37,15 @@ def should_use_names(self) -> bool: ) def rename_classes(self, classes: List[Class], use_name: bool): - """ - Rename all the classes in the list. + """Rename all the classes in the list. Protect classes derived from xs:element if there is only one in the list. 
+ + Args: + classes: A list of classes with duplicate names + use_name: Whether simple or qualified names should be used + during renaming """ total_elements = sum(x.is_element for x in classes) for target in sorted(classes, key=get_name): @@ -54,10 +53,17 @@ def rename_classes(self, classes: List[Class], use_name: bool): self.rename_class(target, use_name) def rename_class(self, target: Class, use_name: bool): - """Find the next available class identifier, save the original name in - the class metadata and update the class qualified name and all classes - that depend on the target class.""" + """Find the next available class name. + Save the original name in the class metadata and update + the class qualified name and all classes that depend on + the target class. + + Args: + target: The target class instance to rename + use_name: Whether simple or qualified names should be + used during renaming + """ qname = target.qname namespace, name = namespaces.split_qname(target.qname) target.qname = self.next_qname(namespace, name, use_name) @@ -68,8 +74,14 @@ def rename_class(self, target: Class, use_name: bool): self.rename_class_dependencies(item, id(target), target.qname) def next_qname(self, namespace: str, name: str, use_name: bool) -> str: - """Append the next available index number for the given namespace and - local name.""" + """Use int suffixes to get the next available qualified name. + + Args: + namespace: The class namespace + name: The class name + use_name: Whether simple or qualified names should be + used during renaming + """ index = 0 if use_name: @@ -87,9 +99,13 @@ def next_qname(self, namespace: str, name: str, use_name: bool) -> str: return qname def rename_class_dependencies(self, target: Class, reference: int, replace: str): - """Search and replace the old qualified attribute type name with the - new one if it exists in the target class attributes, extensions and - inner classes.""" + """Search and replace the old qualified class name in all classes. + + Args: + target: The target class instance to inspect + reference: The reference id of the renamed class + replace: The new qualified name of the renamed class + """ for attr in target.attrs: self.rename_attr_dependencies(attr, reference, replace) @@ -101,8 +117,15 @@ def rename_class_dependencies(self, target: Class, reference: int, replace: str) self.rename_class_dependencies(inner, reference, replace) def rename_attr_dependencies(self, attr: Attr, reference: int, replace: str): - """Search and replace the old qualified attribute type name with the - new one in the attr types, choices and default value.""" + """Search and replace the old qualified class name in the attr types. + + This also covers any choices and references to enum values. + + Args: + attr: The target attr instance to inspect + reference: The reference id of the renamed class + replace: The new qualified name of the renamed class + """ for attr_type in attr.types: if attr_type.reference == reference: attr_type.qname = replace diff --git a/xsdata/codegen/handlers/reset_attribute_sequence_numbers.py b/xsdata/codegen/handlers/reset_attribute_sequence_numbers.py index 9a2214acc..bc21727b6 100644 --- a/xsdata/codegen/handlers/reset_attribute_sequence_numbers.py +++ b/xsdata/codegen/handlers/reset_attribute_sequence_numbers.py @@ -5,15 +5,24 @@ class ResetAttributeSequenceNumbers(RelativeHandlerInterface): - """ - Reset attributes sequence numbers. + """Reset attrs sequence numbers. 
+ + The sequence numbers are the ids of xs:sequence elements, because + up until now it was important to determine which child/parent + attrs belong to different sequence numbers. - Until now all sequence numbers point to the id of sequence class!!! + Before we generate the classes let's reset them to simple auto + increment numbers per class. """ __slots__ = () def process(self, target: Class): + """Process entrypoint for classes. + + Args: + target: The target class instance + """ groups = defaultdict(list) for attr in target.attrs: if attr.restrictions.sequence: @@ -28,10 +37,15 @@ def process(self, target: Class): next_sequence_number += 1 def find_next_sequence_number(self, target: Class) -> int: - return ( - max( - (attr.restrictions.sequence or 0 for attr in self.base_attrs(target)), - default=0, - ) - + 1 - ) + """Calculate the next sequence number from the base classes. + + Args: + target: The target class instance + + Returns: + The next sequence number + """ + base_attrs = self.base_attrs(target) + sequences = (attr.restrictions.sequence or 0 for attr in base_attrs) + max_sequence = max(sequences, default=0) + return max_sequence + 1 diff --git a/xsdata/codegen/handlers/reset_attribute_sequences.py b/xsdata/codegen/handlers/reset_attribute_sequences.py index 8b7655636..e7c5e92d8 100644 --- a/xsdata/codegen/handlers/reset_attribute_sequences.py +++ b/xsdata/codegen/handlers/reset_attribute_sequences.py @@ -4,12 +4,20 @@ class ResetAttributeSequences(HandlerInterface): - """Validate if fields are part of a repeatable sequence otherwise reset the - sequence flag.""" + """Inspect a class for non-repeatable choices and unset the sequence number.""" __slots__ = () def process(self, target: Class): + """Process entrypoint for classes. + + Reset Cases: + - A sequence only contains one attr + - The sequence includes attrs with max_occurs==1 + + Args: + target: The target class instance + """ groups = collections.group_by(target.attrs, get_restriction_sequence) for sequence, attrs in groups.items(): if not sequence: @@ -24,6 +32,16 @@ def process(self, target: Class): @classmethod def is_repeatable_sequence(cls, attr: Attr) -> bool: + """Determine if the given attr is repeatable. + + Repeatable means max_occurs > 1 + + Args: + attr: The attr instance + + Returns: + The bool result + """ seq = attr.restrictions.sequence if seq: for path in attr.restrictions.path: diff --git a/xsdata/codegen/handlers/sanitize_attributes_default_value.py b/xsdata/codegen/handlers/sanitize_attributes_default_value.py index aaddb980c..68624a3b4 100644 --- a/xsdata/codegen/handlers/sanitize_attributes_default_value.py +++ b/xsdata/codegen/handlers/sanitize_attributes_default_value.py @@ -6,8 +6,7 @@ class SanitizeAttributesDefaultValue(RelativeHandlerInterface): - """ - Sanitize attributes default values. + """Sanitize attributes default values. Cases: 1. Ignore enumerations. @@ -20,6 +19,13 @@ class SanitizeAttributesDefaultValue(RelativeHandlerInterface): __slots__ = () def process(self, target: Class): + """Process entrypoint for classes. + + Inspect all attrs and attr choices. + + Args: + target: The target class instance. + """ for attr in target.attrs: self.process_attribute(target, attr) @@ -27,6 +33,18 @@ def process(self, target: Class): self.process_attribute(target, choice) def process_attribute(self, target: Class, attr: Attr): + """Process entrypoint for attrs. 
+ + Cases: + - Reset min_occurs + - Reset default value + - Validate default value against types + - Set empty string as default value for string text nodes. + + Args: + target: The target class instance + attr: The attr instance + """ if self.should_reset_required(attr): attr.restrictions.min_occurs = 0 @@ -41,6 +59,12 @@ def process_attribute(self, target: Class, attr: Attr): attr.default = "" def process_types(self, target: Class, attr: Attr): + """Reset attr types if default value doesn't pass validation. + + Args: + target: The target class instance + attr: The attr instance + """ if self.is_valid_external_value(target, attr): return @@ -58,12 +82,17 @@ def process_types(self, target: Class, attr: Attr): self.reset_attribute_types(attr) def is_valid_external_value(self, target: Class, attr: Attr) -> bool: - """Return whether the default value of the given attr can be mapped to - user defined type like an enumeration or an inner complex content - class.""" + """Validate user defined types. + Only enumerations and complex content inner types are supported. + + Args: + target: The target class instance + attr: The attr instance + . + """ for tp in attr.user_types: - source = self.find_type(target, tp) + source = self.find_inner_type(target, tp) if self.is_valid_inner_type(source, attr, tp): return True @@ -72,7 +101,16 @@ def is_valid_external_value(self, target: Class, attr: Attr) -> bool: return False - def find_type(self, target: Class, attr_type: AttrType) -> Class: + def find_inner_type(self, target: Class, attr_type: AttrType) -> Class: + """Find the inner class for the given attr type. + + Args: + target: The target class instance + attr_type: The attr type instance + + Returns: + The inner class instance. + """ if attr_type.forward: return self.container.find_inner(target, attr_type.qname) @@ -80,10 +118,23 @@ def find_type(self, target: Class, attr_type: AttrType) -> Class: @classmethod def is_valid_inner_type( - cls, source: Class, attr: Attr, attr_type: AttrType + cls, + source: Class, + attr: Attr, + attr_type: AttrType, ) -> bool: - """Return whether the inner class can inherit the attr default value - and swap them as well.""" + """Return whether the inner class can inherit the attr default value. + + If it does, then swap the default/fixed values. + + Args: + source: The inner source class instance + attr: The attr instance + attr_type: The attr type instance + + Returns: + The bool result. + """ if attr_type.forward: for src_attr in source.attrs: if src_attr.xml_type is None: @@ -96,15 +147,24 @@ def is_valid_inner_type( @classmethod def is_valid_enum_type(cls, source: Class, attr: Attr) -> bool: - """ - Convert string literal default values to enumeration members - placeholders and return result. + """Return whether the default value matches an enum member. + + If it does convert the string literal default value to + enumeration members placeholders. The placeholder will be + converted to proper references from the generator filters, + because we don't know yet the enumeration class name. + - The placeholders will be converted to proper references from the - generator filters. + Placeholder examples: + Single -> @enum@qname::member_name + Multiple -> @enum@qname::first_member@second_member - Placeholder examples: Single -> @enum@qname::member_name - Multiple -> @enum@qname::first_member@second_member + Args: + source: The inner source class instance + attr: The attr instance + + Returns: + The bool result. 
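Editor's note: the enum placeholder formats described above, shown as plain f-strings; the qname and member names are invented for the example and the real member lookup lives in the handler.

qname = "{http://example.com/ns}ColorType"
members = ["RED_VALUE", "GREEN_VALUE"]

single = f"@enum@{qname}::{members[0]}"
multiple = f"@enum@{qname}::{'@'.join(members)}"
print(single)    # @enum@{http://example.com/ns}ColorType::RED_VALUE
print(multiple)  # @enum@{http://example.com/ns}ColorType::RED_VALUE@GREEN_VALUE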
""" assert attr.default is not None @@ -127,13 +187,18 @@ def is_valid_enum_type(cls, source: Class, attr: Attr) -> bool: @classmethod def is_valid_native_value(cls, target: Class, attr: Attr) -> bool: - """ - Return whether the default value of the given attribute can be - converted successfully to and from xml. + """Return whether the default value can be converted successfully. The test process for enumerations and fixed value fields are strict, meaning the textual representation also needs to match the original. + + Args: + target: The target class instance + attr: The attr instance + + Returns: + The bool result. """ assert attr.default is not None @@ -162,10 +227,19 @@ def is_valid_native_value(cls, target: Class, attr: Attr) -> bool: @classmethod def should_reset_required(cls, attr: Attr) -> bool: - """ - Return whether the min occurrences for the attr needs to be reset. + """Return whether the min_occurs needs to be reset. + + Condition: + - Attr not derived from xs:attribute + - It has no default value + - It's derived from xs:anyType + - It has max_occurs==1 - @Todo figure out if wildcards are supposed to be optional! + Args: + attr: The attr instance + + Returns: + The bool result. """ return ( not attr.is_attribute @@ -176,12 +250,18 @@ def should_reset_required(cls, attr: Attr) -> bool: @classmethod def should_reset_default(cls, attr: Attr) -> bool: - """ - Return whether we should unset the default value of the attribute. + """Return whether the default value needs to be reset. - - Default value is not set - - Attribute is xsi:type (ignorable) - - Attribute is part of a choice + Cases: + - Attr is xsi:type (ignorable) + - Attr has max_occurs > 1 + - Attr is not derived from xs:attribute and has min_occurs=0 + + Args: + attr: The attr instance + + Returns: + The bool result. """ return attr.default is not None and ( attr.is_xsi_type @@ -191,6 +271,11 @@ def should_reset_default(cls, attr: Attr) -> bool: @classmethod def reset_attribute_types(cls, attr: Attr): + """Reset the attribute type to string. + + Args: + attr: The attr instance + """ attr.types.clear() attr.types.append(AttrType(qname=str(DataType.STRING), native=True)) attr.restrictions.format = None diff --git a/xsdata/codegen/handlers/sanitize_enumeration_class.py b/xsdata/codegen/handlers/sanitize_enumeration_class.py index 0342a524d..edc392c25 100644 --- a/xsdata/codegen/handlers/sanitize_enumeration_class.py +++ b/xsdata/codegen/handlers/sanitize_enumeration_class.py @@ -11,30 +11,37 @@ class SanitizeEnumerationClass(RelativeHandlerInterface): __slots__ = () def process(self, target: Class): - """ - Process class receiver. + """Process entrypoint for classes. Steps: 1. Filter attrs not derived from xs:enumeration 2. Flatten attrs derived from xs:union of enumerations + + Args: + target: The target class instance """ self.filter(target) self.flatten(target) @classmethod def filter(cls, target: Class): - """Filter attrs not derived from xs:enumeration if there are any - xs:enumeration attrs.""" + """Remove attrs not derived from xs:enumeration. + + Args: + target: The target class instance + """ enumerations = [attr for attr in target.attrs if attr.is_enumeration] if enumerations: target.attrs = enumerations def flatten(self, target: Class): - """ - Flatten attrs derived from xs:union of enumeration classes. + """Flatten attrs derived from xs:union of enumeration classes. Find the enumeration classes and merge all of their members in the target class. 
+ + Args: + target: The target class instance """ if len(target.attrs) != 1 or target.attrs[0].tag != Tag.UNION: return diff --git a/xsdata/codegen/handlers/unnest_inner_classes.py b/xsdata/codegen/handlers/unnest_inner_classes.py index d9eb30c47..6ec7652ea 100644 --- a/xsdata/codegen/handlers/unnest_inner_classes.py +++ b/xsdata/codegen/handlers/unnest_inner_classes.py @@ -6,23 +6,38 @@ class UnnestInnerClasses(RelativeHandlerInterface): - """Unnest class processor.""" + """Promote inner classes to root classes.""" __slots__ = () def process(self, target: Class): - """ - Promote enumeration classes to root classes. + """Process entrypoint for classes. + + Process the target class inner classes recursively. - Candidates - - Enumerations - - All if config is enabled + All enumerations are promoted by default, otherwise + only if the configuration is disabled the classes + are ignored. + + Args: + target: The target class instance to inspect """ for inner in list(target.inner): if inner.is_enumeration or self.container.config.output.unnest_classes: self.promote(target, inner) def promote(self, target: Class, inner: Class): + """Promote the inner class to root classes. + + Steps: + - Replace forward references to the inner class + - Remove inner class from target class + - Copy the class to the global class container. + + Args: + target: The target class + inner: An inner class + """ target.inner.remove(inner) attr = self.find_forward_attr(target, inner.qname) if attr: @@ -32,6 +47,18 @@ def promote(self, target: Class, inner: Class): @classmethod def clone_class(cls, inner: Class, name: str) -> Class: + """Clone and prepare inner class for promotion. + + Clone the inner class, mark it as promoted and pref + the qualified name with the parent class name. + + Args: + inner: The inner class to clone and prepare + name: The parent class name to use a prefix + + Returns: + The new class instance + """ clone = inner.clone() clone.local_type = True clone.qname = build_qname(inner.target_namespace, f"{name}_{inner.name}") @@ -39,6 +66,13 @@ def clone_class(cls, inner: Class, name: str) -> Class: @classmethod def update_types(cls, attr: Attr, search: str, replace: str): + """Update the references from an inner to a global class. + + Args: + attr: The target attr to inspect and update + search: The current inner class qname + replace: The new global class qname + """ for attr_type in attr.types: if attr_type.qname == search and attr_type.forward: attr_type.qname = replace @@ -46,6 +80,16 @@ def update_types(cls, attr: Attr, search: str, replace: str): @classmethod def find_forward_attr(cls, target: Class, qname: str) -> Optional[Attr]: + """Find the first attr that references the given inner class qname. + + Args: + target: The target class instance + qname: An inner class qualified name + + Returns: + Attr: The first attr that references the given qname + None: If no such attr exists, it can happen! 
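Editor's note: what the promotion rename in clone_class above produces, using a simplified stand-in for build_qname (the real helper lives in xsdata's namespace utilities; its behavior is only approximated here).

def build_qname_sketch(namespace, name):
    """Simplified stand-in for xsdata's build_qname helper."""
    return f"{{{namespace}}}{name}" if namespace else name

parent_name, inner_name = "Order", "Items"  # invented class names
print(build_qname_sketch("urn:books", f"{parent_name}_{inner_name}"))
# {urn:books}Order_Items -- the promoted global class name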
+ """ for attr in target.attrs: for attr_type in attr.types: if attr_type.forward and attr_type.qname == qname: diff --git a/xsdata/codegen/handlers/update_attributes_effective_choice.py b/xsdata/codegen/handlers/update_attributes_effective_choice.py index ee40ce89c..e84ed8b35 100644 --- a/xsdata/codegen/handlers/update_attributes_effective_choice.py +++ b/xsdata/codegen/handlers/update_attributes_effective_choice.py @@ -7,16 +7,22 @@ class UpdateAttributesEffectiveChoice(HandlerInterface): - """ - Look for fields that are repeated and mark them effectively as choices if - they are not part of symmetrical sequences. + """Detect implied repeated choices and update them. - valid eg: symmetrical sequence: + valid eg: + symmetrical sequence: """ __slots__ = () def process(self, target: Class): + """Process entrypoint for classes. + + Ignore enumerations, for performance reasons. + + Args: + target: The target class instance + """ if target.is_enumeration: return @@ -29,6 +35,11 @@ def process(self, target: Class): @classmethod def reset_symmetrical_choices(cls, target: Class): + """Mark symmetrical choices as sequences. + + Args: + target: The target class instance + """ groups = collections.group_by(target.attrs, get_restriction_choice) for choice, attrs in groups.items(): if choice is None or choice > 0: @@ -52,7 +63,6 @@ def reset_symmetrical_choices(cls, target: Class): attr.restrictions.choice = None cls.reset_effective_choice( attr.restrictions.path, - "s", attr.restrictions.sequence, attr.restrictions.max_occurs, ) @@ -61,18 +71,33 @@ def reset_symmetrical_choices(cls, target: Class): def reset_effective_choice( cls, paths: List[Tuple[str, int, int, int]], - name: str, index: int, max_occur: int, ): + """Update an attr path to resemble a repeatable sequence. + + Args: + paths: The paths of an attr + index: The sequence index + max_occur: The new max occurrences + """ for i, path in enumerate(paths): - if path[0] == name and path[1] == index and path[3] == 1: + if path[0] == "s" and path[1] == index and path[3] == 1: new_path = (*path[:-1], max_occur) paths[i] = new_path break @classmethod def merge_attrs(cls, target: Class, groups: List[List[int]]) -> List[Attr]: + """Merge same name/tag/namespace attrs. + + Args: + target: The target class + groups: The list of connected attr indexes + + Returns: + The final list of target class attrs + """ attrs = [] for index, attr in enumerate(target.attrs): @@ -98,6 +123,21 @@ def merge_attrs(cls, target: Class, groups: List[List[int]]) -> List[Attr]: @classmethod def group_repeating_attrs(cls, target: Class) -> List[List[int]]: + """Create a list of indexes of the same attrs. + + Example: [ + [0, 1 ,2], + [3, 4, 6], + [5,] + ] + + Args: + target: The target class instance + + Returns: + The list of indexes + + """ counters = defaultdict(list) for index, attr in enumerate(target.attrs): if not attr.is_attribute: diff --git a/xsdata/codegen/handlers/vacuum_inner_classes.py b/xsdata/codegen/handlers/vacuum_inner_classes.py index 5b87edb56..09edc83f9 100644 --- a/xsdata/codegen/handlers/vacuum_inner_classes.py +++ b/xsdata/codegen/handlers/vacuum_inner_classes.py @@ -8,15 +8,14 @@ class VacuumInnerClasses(HandlerInterface): - """ - Cleanup nested classes. + """Cleanup nested classes. Search and vacuum inner classes with no attributes or a single extension or rename inner classes that have the same name as the outer/parent class. Cases: 1. Filter duplicate inner classes - 2. Removing identical overriding fields can some times leave a class + 2. 
Removing identical overriding fields can sometimes leave a class bare with just an extension. For inner classes we can safely replace the forward reference with the inner extension reference. 3. Empty nested complexContent with no restrictions or extensions, @@ -26,6 +25,11 @@ class VacuumInnerClasses(HandlerInterface): __slots__ = () def process(self, target: Class): + """Process entrypoint for classes. + + Args: + target: The target class instance + """ target.inner = collections.unique_sequence(target.inner, key="qname") for inner in list(target.inner): if not inner.attrs and len(inner.extensions) < 2: @@ -35,6 +39,12 @@ def process(self, target: Class): @classmethod def remove_inner(cls, target: Class, inner: Class): + """Remove inner class and update the target class attrs. + + Args: + target: The target class instance + inner: The nested class instance + """ target.inner.remove(inner) for attr_type in cls.find_attr_types(target, inner.qname): @@ -53,6 +63,14 @@ def remove_inner(cls, target: Class, inner: Class): @classmethod def rename_inner(cls, target: Class, inner: Class): + """Rename the inner class and update the target class attrs. + + The inner class will get the `Inner suffix`. + + Args: + target: The target class instance + inner: The nested class instance + """ namespace = inner.target_namespace old_qname = inner.qname inner.qname = build_qname(namespace, f"{inner.name}_Inner") @@ -62,6 +80,12 @@ def rename_inner(cls, target: Class, inner: Class): @classmethod def find_attr_types(cls, target: Class, qname: str) -> Iterator[AttrType]: + """Find attr and choice types by the qualified name. + + Args: + target: The target class instance + qname: The qualified name + """ for attr in target.attrs: for attr_type in attr.types: if attr_type.forward and attr_type.qname == qname: diff --git a/xsdata/codegen/handlers/validate_attributes_overrides.py b/xsdata/codegen/handlers/validate_attributes_overrides.py index 3df1447fb..20f0ca739 100644 --- a/xsdata/codegen/handlers/validate_attributes_overrides.py +++ b/xsdata/codegen/handlers/validate_attributes_overrides.py @@ -14,32 +14,46 @@ class ValidateAttributesOverrides(RelativeHandlerInterface): __slots__ = () def process(self, target: Class): + """Process entrypoint for classes. + + - Validate override attrs + - Add restricted attrs + + Args: + target: The target class instance + """ base_attrs_map = self.base_attrs_map(target) # We need the original class attrs before validation, in order to # prohibit the rest of the parent attrs later... - restricted_attrs = { + explicit_attrs = { attr.slug for attr in target.attrs if attr.can_be_restricted() } self.validate_attrs(target, base_attrs_map) if target.is_restricted: - self.prohibit_parent_attrs(target, restricted_attrs, base_attrs_map) + self.prohibit_parent_attrs(target, explicit_attrs, base_attrs_map) @classmethod def prohibit_parent_attrs( cls, target: Class, - restricted_attrs: Set[str], + explicit_attrs: Set[str], base_attrs_map: Dict[str, List[Attr]], ): - """ - Prepend prohibited parent attrs to the target class. + """Prepend prohibited parent attrs to the target class. - Reset the types and default value in order to avoid conflicts + Prepend the parent prohibited attrs and reset their + types and default values in order to avoid conflicts later. 
+ + Args: + target: The target class instance + explicit_attrs: The list of explicit attrs in the class + base_attrs_map: A mapping of qualified names to lists of parent attrs + """ for slug, attrs in reversed(base_attrs_map.items()): attr = attrs[0] - if attr.can_be_restricted() and slug not in restricted_attrs: + if attr.can_be_restricted() and slug not in explicit_attrs: attr_restricted = attr.clone() attr_restricted.restrictions.max_occurs = 0 attr_restricted.default = None @@ -48,6 +62,18 @@ def prohibit_parent_attrs( @classmethod def validate_attrs(cls, target: Class, base_attrs_map: Dict[str, List[Attr]]): + """Validate overriding attrs. + + Cases: + - Overriding attr, either remove it or update parent attr + - Duplicate names, resolve conflicts + - Remove prohibited attrs. + + + Args: + target: The target class instance + base_attrs_map: A mapping of qualified names to lists of parent attrs + """ for attr in list(target.attrs): base_attrs = base_attrs_map.get(attr.slug) @@ -62,48 +88,98 @@ def validate_attrs(cls, target: Class, base_attrs_map: Dict[str, List[Attr]]): @classmethod def overrides(cls, a: Attr, b: Attr) -> bool: + """Override attrs must belong to the same xml type and namespace. + + Args: + a: The first attr + b: The second attr + + Returns: + The bool result. + """ return a.xml_type == b.xml_type and a.namespace == b.namespace def base_attrs_map(self, target: Class) -> Dict[str, List[Attr]]: + """Create a mapping of qualified names to lists of parent attrs. + + Args: + target: The target class instance + + Returns: + A mapping of qualified names to lists of parent attrs. + """ base_attrs = self.base_attrs(target) return collections.group_by(base_attrs, key=get_slug) @classmethod - def validate_override(cls, target: Class, attr: Attr, source_attr: Attr): - if source_attr.is_any_type and not attr.is_any_type: + def validate_override(cls, target: Class, child_attr: Attr, parent_attr: Attr): + """Validate the override will not break mypy type checking. + + - Ignore wildcard attrs. + - If child is a list and parent isn't convert parent to list + - If restrictions are the same we can safely remove override attr + + Args: + target: The target class instance + child_attr: The child attr + parent_attr: The parent attr + """ + if parent_attr.is_any_type and not child_attr.is_any_type: return - if attr.is_list and not source_attr.is_list: + if child_attr.is_list and not parent_attr.is_list: # Hack much??? 
idk but Optional[str] can't override List[str] - source_attr.restrictions.max_occurs = sys.maxsize - assert source_attr.parent is not None + parent_attr.restrictions.max_occurs = sys.maxsize + assert parent_attr.parent is not None logger.warning( "Converting parent field `%s::%s` to a list to match child class `%s`", - source_attr.parent.name, - source_attr.name, + parent_attr.parent.name, + parent_attr.name, target.name, ) if ( - attr.default == source_attr.default - and bool_eq(attr.fixed, source_attr.fixed) - and bool_eq(attr.mixed, source_attr.mixed) - and bool_eq(attr.restrictions.tokens, source_attr.restrictions.tokens) - and bool_eq(attr.restrictions.nillable, source_attr.restrictions.nillable) - and bool_eq(attr.is_prohibited, source_attr.is_prohibited) - and bool_eq(attr.is_optional, source_attr.is_optional) + child_attr.default == parent_attr.default + and _bool_eq(child_attr.fixed, parent_attr.fixed) + and _bool_eq(child_attr.mixed, parent_attr.mixed) + and _bool_eq( + child_attr.restrictions.tokens, parent_attr.restrictions.tokens + ) + and _bool_eq( + child_attr.restrictions.nillable, parent_attr.restrictions.nillable + ) + and _bool_eq(child_attr.is_prohibited, parent_attr.is_prohibited) + and _bool_eq(child_attr.is_optional, parent_attr.is_optional) ): - cls.remove_attribute(target, attr) + cls.remove_attribute(target, child_attr) @classmethod def remove_attribute(cls, target: Class, attr: Attr): + """Safely remove attr. + + The search is done with the reference id for safety, + of removing attrs with same name. If the attr has + a forward reference, the inner class will also be removed + if it's unused! + + Args: + target: The target class instance + attr: The attr to remove + + """ ClassUtils.remove_attribute(target, attr) ClassUtils.clean_inner_classes(target) @classmethod - def resolve_conflict(cls, attr: Attr, source_attr: Attr): - ClassUtils.rename_attribute_by_preference(attr, source_attr) + def resolve_conflict(cls, child_attr: Attr, parent_attr: Attr): + """Rename the child or parent attr. + + Args: + child_attr: The child attr instance + parent_attr: The parent attr instance + """ + ClassUtils.rename_attribute_by_preference(child_attr, parent_attr) -def bool_eq(a: Optional[bool], b: Optional[bool]) -> bool: +def _bool_eq(a: Optional[bool], b: Optional[bool]) -> bool: return bool(a) is bool(b) diff --git a/xsdata/codegen/mappers/__init__.py b/xsdata/codegen/mappers/__init__.py index e69de29bb..436c40fc7 100644 --- a/xsdata/codegen/mappers/__init__.py +++ b/xsdata/codegen/mappers/__init__.py @@ -0,0 +1,13 @@ +from xsdata.codegen.mappers.definitions import DefinitionsMapper +from xsdata.codegen.mappers.dict import DictMapper +from xsdata.codegen.mappers.dtd import DtdMapper +from xsdata.codegen.mappers.element import ElementMapper +from xsdata.codegen.mappers.schema import SchemaMapper + +__all__ = [ + "DefinitionsMapper", + "DictMapper", + "DtdMapper", + "ElementMapper", + "SchemaMapper", +] diff --git a/xsdata/codegen/mappers/definitions.py b/xsdata/codegen/mappers/definitions.py index 8d6a26a2b..be323eeda 100644 --- a/xsdata/codegen/mappers/definitions.py +++ b/xsdata/codegen/mappers/definitions.py @@ -20,8 +20,7 @@ class DefinitionsMapper: - """ - Map a definitions instance to message and service classes. + """Map a definitions instance to message and service classes. Currently, only SOAP 1.1 bindings with rpc/document style is supported. 
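For orientation, the renamed helper above and the new package-level exports can be exercised directly; a minimal sketch, assuming a build of xsdata with this patch applied (sample values are illustrative):

    from xsdata.codegen.handlers.validate_attributes_overrides import _bool_eq
    from xsdata.codegen.mappers import DefinitionsMapper, DictMapper, ElementMapper

    # The private helper compares optional booleans by truthiness, so None and False are equivalent.
    print(_bool_eq(None, False))  # True
    print(_bool_eq(True, None))   # False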
@@ -29,7 +28,17 @@ class DefinitionsMapper: @classmethod def map(cls, definitions: Definitions) -> List[Class]: - """Step 1: Main mapper entry point.""" + """Main entrypoint for this mapper. + + Iterates over their services and their ports and build + the binding and service classes. + + Args: + definitions: The definitions instance to map. + + Returns: + The generated class instances + """ return [ obj for service in definitions.services @@ -39,9 +48,18 @@ def map(cls, definitions: Definitions) -> List[Class]: @classmethod def map_port(cls, definitions: Definitions, port: ServicePort) -> Iterator[Class]: - """Step 2: Match a ServicePort to a Binding and PortType object and - delegate the process to the next entry point.""" + """Map a service port into binding and service classes. + + Match a ServicePort to a Binding and PortType object and + delegate the process to the next entry point. + Args: + definitions: The definitions instance + port: The service port instance + + Yields: + An iterator of class instances. + """ binding = definitions.find_binding(text.suffix(port.binding)) port_type = definitions.find_port_type(text.suffix(binding.type)) @@ -58,8 +76,20 @@ def map_binding( port_type: PortType, config: Dict, ) -> Iterator[Class]: - """Step 3: Match every BindingOperation to a PortTypeOperation and - delegate the process for each operation to the next entry point.""" + """Map binding operations into binding and service classes. + + Match every BindingOperation to a PortTypeOperation and + delegate the process for each operation to the next entry point. + + Args: + definitions: The definitions instance + binding: The binding instance + port_type: The port type instance + config: Configuration dictionary + + Yields: + An iterator of class instances. + """ for operation in binding.unique_operations(): cfg = config.copy() cfg.update(cls.attributes(operation.extended_elements)) @@ -78,9 +108,21 @@ def map_binding_operation( config: Dict, name: str, ) -> Iterator[Class]: - """Step 4: Convert a BindingOperation to a service class and delegate - the process of all the message classes to the next entry point.""" + """Map a binding operation to a service and binding classes. + Convert a BindingOperation to a service class and delegate + the process of all the message classes to the next entry point. + + Args: + definitions: The definitions instance + binding_operation: The binding operation instance + port_type_operation: The port type operation instance + config: Configuration dictionary + name: The operation name + + Yields: + An iterator of class instances. 
+ """ attrs = [ cls.build_attr(key, str(DataType.STRING), native=True, default=config[key]) for key in sorted(config.keys(), key=len) @@ -91,7 +133,12 @@ def map_binding_operation( name = f"{name}_{binding_operation.name}" namespace = cls.operation_namespace(config) operation_messages = cls.map_binding_operation_messages( - definitions, binding_operation, port_type_operation, name, style, namespace + definitions, + binding_operation, + port_type_operation, + name, + style, + namespace, ) for message_class in operation_messages: yield message_class @@ -115,24 +162,40 @@ def map_binding_operation( def map_binding_operation_messages( cls, definitions: Definitions, - operation: BindingOperation, + binding_operation: BindingOperation, port_type_operation: PortTypeOperation, name: str, style: str, namespace: Optional[str], ) -> Iterator[Class]: - """Step 5: Map the BindingOperation messages to classes.""" - + """Map the binding operation messages to binding classes. + + Args: + definitions: The definitions instance + binding_operation: The binding operation instance + port_type_operation: The port type operation instance + name: The operation name + style: The operation style + namespace: The operation namespace + + Yields: + An iterator of class instances. + """ messages: List[Tuple[str, BindingMessage, PortTypeMessage, Optional[str]]] = [] - if operation.input: + if binding_operation.input: messages.append( - ("input", operation.input, port_type_operation.input, operation.name) + ( + "input", + binding_operation.input, + port_type_operation.input, + binding_operation.name, + ) ) - if operation.output: + if binding_operation.output: messages.append( - ("output", operation.output, port_type_operation.output, None) + ("output", binding_operation.output, port_type_operation.output, None) ) for suffix, binding_message, port_type_message, operation_name in messages: @@ -161,7 +224,13 @@ def build_envelope_fault( port_type_operation: PortTypeOperation, target: Class, ): - """Build inner fault class with default fields.""" + """Add an inner message fault class with default fields. + + Args: + definitions: The definitions instance + port_type_operation: The port type operation instance + target: The target class instance + """ ns_map: Dict = {} body = next(inner for inner in target.inner if inner.name == "Body") fault_class = cls.build_inner_class(body, "Fault", target.namespace) @@ -197,9 +266,21 @@ def build_envelope_class( namespace: Optional[str], operation: Optional[str], ) -> Class: - """Step 6.1: Build Envelope class for the given binding message with - attributes from the port type message.""" - + """Map the binding message to an envelope class. + + Args: + definitions: The definitions instance + binding_message: The port type message instance + port_type_message: The port type message instance + name: The class name + style: The operation style e.g. rpc + namespace: The operation namespace + operation: The custom operation name, if it's empty + the message name will be used instead + + Returns: + The class instance. + """ assert binding_message.location is not None target = Class( @@ -233,10 +314,19 @@ def build_envelope_class( @classmethod def build_message_class( - cls, definitions: Definitions, port_type_message: PortTypeMessage + cls, + definitions: Definitions, + port_type_message: PortTypeMessage, ) -> Class: - """Step 6.2: Build the input/output message class of an rpc style - operation.""" + """Map the input/output message of a rpc style operation. 
+ + Args: + definitions: The definitions instance + port_type_message: The port type message instance + + Returns: + The class instance. + """ prefix, name = text.split(port_type_message.message) definition_message = definitions.find_message(name) @@ -259,12 +349,18 @@ def build_message_class( def build_inner_class( cls, target: Class, name: str, namespace: Optional[str] = None ) -> Class: - """ - Build or retrieve an inner class for the given target class by the - given name. + """Build or retrieve an inner class. This helper will also create a forward reference attribute for the parent class. + + Args: + target: The parent class instance + name: The inner class name + namespace: The inner class namespace + + Returns: + The inner class instance. """ inner = collections.first(inner for inner in target.inner if inner.name == name) if not inner: @@ -288,7 +384,17 @@ def map_port_type_message( message: PortTypeMessage, namespace: Optional[str], ) -> Iterator[Attr]: - """Build an attribute for the given port type message.""" + """Build an attribute for the given port type message. + + Args: + operation: The operation name, use the message name + if it's empty + message: The port type message instance + namespace: The operation namespace + + Yields: + An iterator of class attrs. + """ prefix, name = text.split(message.message) source_namespace = message.ns_map.get(prefix) @@ -305,8 +411,17 @@ def map_port_type_message( def map_binding_message_parts( cls, definitions: Definitions, message: str, extended: AnyElement, ns_map: Dict ) -> Iterator[Attr]: - """Find a Message instance and map its parts to attributes according to - the extensible element..""" + """Find a Message instance and map its parts to attrs. + + Args: + definitions: The definitions instance + message: The message qualified name + extended: The related extended element + ns_map: The namespace prefix-URI map + + Yields: + An iterator of class attrs. + """ parts = [] if "part" in extended.attributes: parts.append(extended.attributes["part"]) @@ -328,11 +443,14 @@ def map_binding_message_parts( @classmethod def build_parts_attributes(cls, parts: List[Part], ns_map: Dict) -> Iterator[Attr]: - """ - Build attributes for the given list of parts. + """Build attributes for the given list of parts. - :param parts: List of parts - :param ns_map: Namespace prefix-URI map + Args: + parts: A list of part instances + ns_map: The namespace prefix-URI map + + Yields: + An iterator of class attrs. """ for part in parts: if part.element: @@ -357,6 +475,14 @@ def build_parts_attributes(cls, parts: List[Part], ns_map: Dict) -> Iterator[Att @classmethod def operation_namespace(cls, config: Dict) -> Optional[str]: + """Return the operation namespace by the operation transport. + + Args: + config: The operation configuration + + Returns: + The operation namespace string or None if transport is not soap. + """ transport = config.get("transport") namespace = None if transport == "http://schemas.xmlsoap.org/soap/http": @@ -366,7 +492,14 @@ def operation_namespace(cls, config: Dict) -> Optional[str]: @classmethod def attributes(cls, elements: Iterator[AnyElement]) -> Dict: - """Return all attributes from all extended elements as a dictionary.""" + """Return all attributes from all extended elements as a dictionary. + + Args: + elements: An iterator of generic elements + + Returns: + A key-value mapping of the xml attributes. 
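A small sketch of the attributes() helper described above; the element content is made up, and it assumes AnyElement accepts an attributes mapping, as it is used elsewhere in this patch:

    from xsdata.codegen.mappers import DefinitionsMapper
    from xsdata.formats.dataclass.models.generics import AnyElement

    extended = [
        AnyElement(attributes={
            "{http://example.com/wsdl}style": "document",
            "transport": "http://schemas.xmlsoap.org/soap/http",
        })
    ]
    # Qualified attribute names are reduced to their local part:
    # {"style": "document", "transport": "http://schemas.xmlsoap.org/soap/http"}
    config = DefinitionsMapper.attributes(iter(extended))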
+ """ return { namespaces.local_name(qname): value for element in elements @@ -384,7 +517,19 @@ def build_attr( namespace: Optional[str] = None, default: Optional[str] = None, ) -> Attr: - """Builder method for attributes.""" + """Helper method to build an attr instance. + + Args: + name: The attr name + qname: The attr qualified name + native: Whether the type is native + forward: Whether the type is a forward reference + namespace: The attr namespace + default: The attr default value + + Returns: + The new attr instance. + """ occurs = 1 if default is not None else None if native: namespace = "" diff --git a/xsdata/codegen/mappers/dict.py b/xsdata/codegen/mappers/dict.py index 927a91be6..3f5b5cf51 100644 --- a/xsdata/codegen/mappers/dict.py +++ b/xsdata/codegen/mappers/dict.py @@ -1,23 +1,44 @@ import sys from typing import Any, Dict, List -from xsdata.codegen.mappers.element import ElementMapper +from xsdata.codegen.mappers.mixins import RawDocumentMapper from xsdata.codegen.models import AttrType, Class from xsdata.codegen.utils import ClassUtils from xsdata.models.enums import Tag -class DictMapper: - """Map a dictionary to classes, extensions and attributes.""" +class DictMapper(RawDocumentMapper): + """Map a dictionary to classes. + + This mapper is used to build classes from raw json documents. + """ @classmethod def map(cls, data: Dict, name: str, location: str) -> List[Class]: - """Convert a dictionary to a list of codegen classes.""" + """Map a dictionary to classes. + + Args: + data: The json resource data + name: The main resource name + location: The resource location + + Returns: + The list of classes. + """ target = cls.build_class(data, name) return list(ClassUtils.flatten(target, f"{location}/{name}")) @classmethod def build_class(cls, data: Dict, name: str) -> Class: + """Build a class from a data dictionary. + + Args: + data: The json resource data + name: The main resource name + + Returns: + The list of classes. + """ target = Class(qname=name, tag=Tag.ELEMENT, location="") for key, value in data.items(): @@ -27,6 +48,13 @@ def build_class(cls, data: Dict, name: str) -> Class: @classmethod def build_class_attribute(cls, target: Class, name: str, value: Any): + """Build a class attr. + + Args: + target: The target class instance + name: The attr name + value: The data value to extract types and restrictions. + """ if isinstance(value, list): if not value: cls.build_class_attribute(target, name, None) @@ -41,6 +69,6 @@ def build_class_attribute(cls, target: Class, name: str, value: Any): attr_type = AttrType(qname=inner.qname, forward=True) target.inner.append(inner) else: - attr_type = ElementMapper.build_attribute_type(name, value) + attr_type = cls.build_attr_type(name, value) - ElementMapper.build_attribute(target, name, attr_type) + cls.build_attr(target, name, attr_type) diff --git a/xsdata/codegen/mappers/dtd.py b/xsdata/codegen/mappers/dtd.py index 66ba7f0e7..4acc9f361 100644 --- a/xsdata/codegen/mappers/dtd.py +++ b/xsdata/codegen/mappers/dtd.py @@ -18,13 +18,32 @@ class DtdMapper: + """Maps a Dtd instance to a list of class instances.""" + @classmethod def map(cls, dtd: Dtd) -> Iterator[Class]: + """Map the given Dtd instance to a list of classes. + + Args: + dtd: The Dtd instance to be mapped + + Yields: + An iterator of mapped classes. + """ for element in dtd.elements: yield cls.build_class(element, dtd.location) @classmethod def build_class(cls, element: DtdElement, location: str) -> Class: + """Build a clas for the given element. 
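A rough usage sketch for the json mapper documented above; the sample data, name and location are illustrative:

    from xsdata.codegen.mappers import DictMapper

    data = {"id": 1, "name": "foo", "tags": ["a", "b"]}
    classes = DictMapper.map(data, "root", "file:///tmp/sample.json")
    for obj in classes:
        print(obj.qname, [attr.name for attr in obj.attrs])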
+ + Args: + element: The dtd element to be mapped + location: The location of dtd resource + + Returns: + The mapped class instance. + """ target = Class( qname=element.qname, ns_map=element.ns_map, @@ -39,11 +58,23 @@ def build_class(cls, element: DtdElement, location: str) -> Class: @classmethod def build_attributes(cls, target: Class, element: DtdElement): + """Build attributes from the dtd element attributes. + + Args: + target: The target class instance + element: The dtd element containing attributes + """ for attribute in element.attributes: cls.build_attribute(target, attribute) @classmethod def build_attribute(cls, target: Class, attribute: DtdAttribute): + """Build an attr from a dtd attribute. + + Args: + target: The target class instance + attribute: The dtd attribute to be mapped + """ attr_type = cls.build_attribute_type(target, attribute) attr = Attr( name=attribute.name, @@ -61,8 +92,18 @@ def build_attribute(cls, target: Class, attribute: DtdAttribute): @classmethod def build_attribute_restrictions( - cls, attr: Attr, default: DtdAttributeDefault, default_value: Optional[str] + cls, + attr: Attr, + default: DtdAttributeDefault, + default_value: Optional[str], ): + """Build attribute restrictions based on DtdAttributeDefault. + + Args: + attr: The target attr instance + default: The default attribute type + default_value: The default value + """ attr.restrictions.max_occurs = 1 if default == DtdAttributeDefault.REQUIRED: attr.restrictions.min_occurs = 1 @@ -80,6 +121,15 @@ def build_attribute_restrictions( @classmethod def build_attribute_type(cls, target: Class, attribute: DtdAttribute) -> AttrType: + """Build attribute type based on the dtd attribute. + + Args: + target: The target class instance + attribute: The dtd attribute to be mapped + + Returns: + The mapped attr type instance. + """ if attribute.type == DtdAttributeType.ENUMERATION: cls.build_enumeration(target, attribute.name, attribute.values) return AttrType(qname=attribute.name, forward=True) @@ -88,7 +138,12 @@ def build_attribute_type(cls, target: Class, attribute: DtdAttribute) -> AttrTyp @classmethod def build_elements(cls, target: Class, element: DtdElement): - # "undefined", "empty", "any", "mixed", or "element"; + """Build attrs from the dtd element. + + Args: + target: The target class instance + element: The dtd containing elements. + """ if element.type == DtdElementType.ELEMENT and element.content: cls.build_content(target, element.content) elif element.type == DtdElementType.MIXED and element.content: @@ -98,6 +153,12 @@ def build_elements(cls, target: Class, element: DtdElement): @classmethod def build_mixed_content(cls, target: Class, content: DtdContent): + """Mark class to support mixed content. + + Args: + target: The target class instance + content: The dtd content instance + """ if content.left and content.left.type == DtdContentType.PCDATA: target.mixed = True content.left = None @@ -110,6 +171,12 @@ def build_mixed_content(cls, target: Class, content: DtdContent): @classmethod def build_extension(cls, target: Class, data_type: DataType): + """Add a xsd native type as class extension. 
+ + Args: + target: The target class instance + data_type: The data type instance + """ ext_type = AttrType(qname=str(data_type), native=True) extension = Extension( tag=Tag.EXTENSION, type=ext_type, restrictions=Restrictions() @@ -118,6 +185,13 @@ def build_extension(cls, target: Class, data_type: DataType): @classmethod def build_content(cls, target: Class, content: DtdContent, **kwargs: Any): + """Build class content. + + Args: + target: The target class instance + content: The dtd content instance. + **kwargs: Additional restriction arguments. + """ content_type = content.type if content_type == DtdContentType.ELEMENT: restrictions = cls.build_restrictions(content.occur, **kwargs) @@ -140,6 +214,13 @@ def build_content(cls, target: Class, content: DtdContent, **kwargs: Any): @classmethod def build_content_tree(cls, target: Class, content: DtdContent, **kwargs: Any): + """Build the class content tree. + + Args: + target: The target class instance + content: The dtd content instance + **kwargs: Additional restriction arguments + """ if content.left: cls.build_content(target, content.left, **kwargs) @@ -148,6 +229,14 @@ def build_content_tree(cls, target: Class, content: DtdContent, **kwargs: Any): @classmethod def build_occurs(cls, occur: DtdContentOccur) -> Dict: + """Calculate min/max occurs from the dtd content occur instance. + + Args: + occur: The dtd content occur instance. + + Returns: + The min/max occurs restrictions dictionary + """ if occur == DtdContentOccur.ONCE: min_occurs = 1 max_occurs = 1 @@ -168,6 +257,15 @@ def build_occurs(cls, occur: DtdContentOccur) -> Dict: @classmethod def build_restrictions(cls, occur: DtdContentOccur, **kwargs: Any) -> Restrictions: + """Map the dtd content occur instance to a restriction instance. + + Args: + occur: The DtdContentOccur to be mapped + **kwargs: Additional restriction arguments + + Returns: + The mapped restrictions instance. + """ params = cls.build_occurs(occur) params.update(kwargs) @@ -175,6 +273,13 @@ def build_restrictions(cls, occur: DtdContentOccur, **kwargs: Any) -> Restrictio @classmethod def build_element(cls, target: Class, name: str, restrictions: Restrictions): + """Build an element attr for the target class instance. + + Args: + target: The target class instance + name: The attr name + restrictions: The attr restrictions + """ types = AttrType(qname=name, native=False) attr = Attr( name=name, tag=Tag.ELEMENT, types=[types], restrictions=restrictions.clone() @@ -184,6 +289,12 @@ def build_element(cls, target: Class, name: str, restrictions: Restrictions): @classmethod def build_value(cls, target: Class, restrictions: Restrictions): + """Build a value attr for the target class instance. + + Args: + target: The target class instance + restrictions: The attr restrictions + """ types = AttrType(qname=str(DataType.STRING), native=True) attr = Attr( name=DEFAULT_ATTR_NAME, @@ -196,6 +307,13 @@ def build_value(cls, target: Class, restrictions: Restrictions): @classmethod def build_enumeration(cls, target: Class, name: str, values: List[str]): + """Build a nested enumeration class from the given values list. 
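Assuming the elided return statement builds Restrictions(**params), as the docstrings above suggest, the occurs handling can be sketched like this (only the ONCE branch is visible in this diff):

    from xsdata.codegen.mappers import DtdMapper
    from xsdata.models.dtd import DtdContentOccur

    # ONCE maps to min_occurs=1/max_occurs=1; extra kwargs such as nillable are merged on top.
    restrictions = DtdMapper.build_restrictions(DtdContentOccur.ONCE, nillable=True)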
+ + Args: + target: The target class instance + name: The attr/enum class name + values: The enumeration values + """ inner = Class(qname=name, tag=Tag.SIMPLE_TYPE, location=target.location) attr_type = AttrType(qname=str(DataType.STRING), native=True) diff --git a/xsdata/codegen/mappers/element.py b/xsdata/codegen/mappers/element.py index c939edb5a..78ad64f03 100644 --- a/xsdata/codegen/mappers/element.py +++ b/xsdata/codegen/mappers/element.py @@ -1,22 +1,32 @@ -import sys from collections import defaultdict -from typing import Any, List, Optional +from typing import List, Optional -from xsdata.codegen.models import Attr, AttrType, Class +from xsdata.codegen.mappers.mixins import RawDocumentMapper +from xsdata.codegen.models import AttrType, Class from xsdata.codegen.utils import ClassUtils -from xsdata.formats.converter import converter from xsdata.formats.dataclass.models.generics import AnyElement -from xsdata.models.enums import DataType, QNames, Tag +from xsdata.models.enums import QNames, Tag from xsdata.utils import collections from xsdata.utils.namespaces import build_qname, split_qname -class ElementMapper: - """Map a schema instance to classes, extensions and attributes.""" +class ElementMapper(RawDocumentMapper): + """Map a generic element to classes. + + This mapper is used to build classes from raw xml documents. + """ @classmethod def map(cls, element: AnyElement, location: str) -> List[Class]: - """Map schema children elements to classes.""" + """Map schema children elements to classes. + + Args: + element: The root element to be mapped + location: The location of the xml document + + Returns: + The list of mapped class instances. + """ assert element.qname is not None uri, name = split_qname(element.qname) @@ -26,6 +36,15 @@ def map(cls, element: AnyElement, location: str) -> List[Class]: @classmethod def build_class(cls, element: AnyElement, parent_namespace: Optional[str]) -> Class: + """Build a Class instance for the given generic element. + + Args: + element: The generic element to be mapped + parent_namespace: The parent element namespace + + Returns: + The mapped class instance. + """ assert element.qname is not None namespace, name = split_qname(element.qname) @@ -45,19 +64,40 @@ def build_class(cls, element: AnyElement, parent_namespace: Optional[str]) -> Cl @classmethod def build_attributes( - cls, target: Class, element: AnyElement, namespace: Optional[str] + cls, + target: Class, + element: AnyElement, + namespace: Optional[str], ): + """Build attributes for the given Class instance based on AnyElement attributes. + + Args: + target: The target class instance + element: The AnyElement containing attributes. + namespace: The namespace. + + """ for key, value in element.attributes.items(): if key == QNames.XSI_NIL: target.nillable = value.strip() in ("true", "1") else: - attr_type = cls.build_attribute_type(key, value) - cls.build_attribute(target, key, attr_type, namespace, Tag.ATTRIBUTE) + attr_type = cls.build_attr_type(key, value) + cls.build_attr(target, key, attr_type, namespace, Tag.ATTRIBUTE) @classmethod def build_elements( - cls, target: Class, element: AnyElement, namespace: Optional[str] + cls, + target: Class, + element: AnyElement, + namespace: Optional[str], ): + """Build elements for the given Class instance based on AnyElement children. + + Args: + target: The target class instance + element: The AnyElement containing children. + namespace: The namespace. 
+ """ sequences = cls.sequential_groups(element) for index, child in enumerate(element.children): if isinstance(child, AnyElement) and child.qname: @@ -69,10 +109,10 @@ def build_elements( attr_type = AttrType(qname=inner.qname, forward=True) target.inner.append(inner) else: - attr_type = cls.build_attribute_type(child.qname, child.text) + attr_type = cls.build_attr_type(child.qname, child.text) sequence = collections.find_connected_component(sequences, index) - cls.build_attribute( + cls.build_attr( target, child.qname, attr_type, @@ -83,92 +123,42 @@ def build_elements( @classmethod def build_text(cls, target: Class, element: AnyElement): + """Build a text attr from the generic element text value. + + Args: + target: The target class instance + element: The AnyElement containing text content. + """ if element.text: - attr_type = cls.build_attribute_type("value", element.text) - cls.build_attribute(target, "value", attr_type, None, Tag.SIMPLE_TYPE) + attr_type = cls.build_attr_type("value", element.text) + cls.build_attr(target, "value", attr_type, None, Tag.SIMPLE_TYPE) if any(attr.tag == Tag.ELEMENT for attr in target.attrs): target.mixed = True @classmethod - def build_attribute_type(cls, qname: str, value: Any) -> AttrType: - def match_type(val: Any) -> DataType: - if not isinstance(val, str): - return DataType.from_value(val) - - for tp in converter.explicit_types(): - if converter.test(val, [tp], strict=True): - return DataType.from_type(tp) - - return DataType.STRING - - if qname == QNames.XSI_TYPE: - data_type = DataType.QNAME - elif value is None or value == "": - data_type = DataType.ANY_SIMPLE_TYPE - else: - data_type = match_type(value) - - return AttrType(qname=str(data_type), native=True) - - @classmethod - def build_attribute( - cls, - target: Class, - qname: str, - attr_type: AttrType, - parent_namespace: Optional[str] = None, - tag: str = Tag.ELEMENT, - sequence: int = 0, - ): - namespace, name = split_qname(qname) - namespace = cls.select_namespace(namespace, parent_namespace, tag) - index = len(target.attrs) - - attr = Attr(index=index, name=name, tag=tag, namespace=namespace) - attr.types.append(attr_type) - - if sequence: - attr.restrictions.path.append(("s", sequence, 1, sys.maxsize)) - - attr.restrictions.min_occurs = 1 - attr.restrictions.max_occurs = 1 - cls.add_attribute(target, attr) - - @classmethod - def add_attribute(cls, target: Class, attr: Attr): - pos = collections.find(target.attrs, attr) - - if pos > -1: - existing = target.attrs[pos] - existing.restrictions.max_occurs = sys.maxsize - existing.types.extend(attr.types) - existing.types = collections.unique_sequence(existing.types, key="qname") - else: - target.attrs.append(attr) - - @classmethod - def select_namespace( - cls, - namespace: Optional[str], - parent_namespace: Optional[str], - tag: str = Tag.ELEMENT, - ) -> Optional[str]: - if tag == Tag.ATTRIBUTE: - return namespace - - if namespace is None and parent_namespace is not None: - return "" + def sequential_groups(cls, element: AnyElement) -> List[List[int]]: + """Identify sequential groups of repeating attributes. - return namespace + Args: + element: The generic element instance - @classmethod - def sequential_groups(cls, element: AnyElement) -> List[List[int]]: + Returns: + A list of lists of strongly connected children indexes. 
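A usage sketch for the xml mapper; the element tree below is hand-built and purely illustrative, real input normally comes from parsing a document:

    from xsdata.codegen.mappers import ElementMapper
    from xsdata.formats.dataclass.models.generics import AnyElement

    root = AnyElement(
        qname="{http://example.com}root",
        children=[
            AnyElement(qname="{http://example.com}item", text="1"),
            AnyElement(qname="{http://example.com}item", text="2"),
        ],
    )
    # Repeated "item" siblings should collapse into a single list attr on the "root" class.
    classes = ElementMapper.map(root, "file:///tmp/sample.xml")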
+ """ groups = cls.group_repeating_attrs(element) return list(collections.connected_components(groups)) @classmethod def group_repeating_attrs(cls, element: AnyElement) -> List[List[int]]: + """Group repeating children in the given generic element. + + Args: + element: The generic element instance + + Returns: + A list of lists of children indexes. + """ counters = defaultdict(list) for index, child in enumerate(element.children): if isinstance(child, AnyElement) and child.qname: diff --git a/xsdata/codegen/mappers/mixins.py b/xsdata/codegen/mappers/mixins.py new file mode 100644 index 000000000..bc345bc5e --- /dev/null +++ b/xsdata/codegen/mappers/mixins.py @@ -0,0 +1,120 @@ +import sys +from typing import Any, Optional + +from xsdata.codegen.models import Attr, AttrType, Class +from xsdata.formats.converter import converter +from xsdata.models.enums import DataType, QNames, Tag +from xsdata.utils import collections +from xsdata.utils.namespaces import split_qname + + +class RawDocumentMapper: + """Mixin class for raw json/xml documents.""" + + @classmethod + def build_attr( + cls, + target: Class, + qname: str, + attr_type: AttrType, + parent_namespace: Optional[str] = None, + tag: str = Tag.ELEMENT, + sequence: int = 0, + ): + """Build an attr for the given class instance. + + Args: + target: The target class instance + qname: The attr qualified name + attr_type: The attr type instance + parent_namespace: The parent namespace + tag: The attr tag + sequence: The attr sequence number + """ + namespace, name = split_qname(qname) + namespace = cls.select_namespace(namespace, parent_namespace, tag) + index = len(target.attrs) + + attr = Attr(index=index, name=name, tag=tag, namespace=namespace) + attr.types.append(attr_type) + + if sequence: + attr.restrictions.path.append(("s", sequence, 1, sys.maxsize)) + + attr.restrictions.min_occurs = 1 + attr.restrictions.max_occurs = 1 + cls.add_attribute(target, attr) + + @classmethod + def build_attr_type(cls, qname: str, value: Any) -> AttrType: + """Build an attribute type for the given attribute name and value. + + Args: + qname: The attr qualified name + value: The attr value + + Returns: + The new attr type instance. + """ + + def match_type(val: Any) -> DataType: + if not isinstance(val, str): + return DataType.from_value(val) + + for tp in converter.explicit_types(): + if converter.test(val, [tp], strict=True): + return DataType.from_type(tp) + + return DataType.STRING + + if qname == QNames.XSI_TYPE: + data_type = DataType.QNAME + elif value is None or value == "": + data_type = DataType.ANY_SIMPLE_TYPE + else: + data_type = match_type(value) + + return AttrType(qname=str(data_type), native=True) + + @classmethod + def select_namespace( + cls, + namespace: Optional[str], + parent_namespace: Optional[str], + tag: str = Tag.ELEMENT, + ) -> Optional[str]: + """Select the namespace based on the tag and namespace. + + Args: + namespace: The current namespace + parent_namespace: The parent namespace + tag: The tag name + + Returns: + Optional[str]: The selected namespace. + """ + if tag == Tag.ATTRIBUTE: + return namespace + + if namespace is None and parent_namespace is not None: + return "" + + return namespace + + @classmethod + def add_attribute(cls, target: Class, attr: Attr): + """Add an attr to the target class instance. + + Args: + target: The target class instance + attr (Attr): The attribute to be added. 
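The type inference above can be checked directly; these calls follow the branches shown in build_attr_type (values are illustrative):

    from xsdata.codegen.mappers import ElementMapper
    from xsdata.models.enums import DataType, QNames

    assert ElementMapper.build_attr_type(QNames.XSI_TYPE, "ns:foo").qname == str(DataType.QNAME)
    assert ElementMapper.build_attr_type("name", "").qname == str(DataType.ANY_SIMPLE_TYPE)
    assert ElementMapper.build_attr_type("name", None).qname == str(DataType.ANY_SIMPLE_TYPE)
    assert ElementMapper.build_attr_type("name", "hello").qname == str(DataType.STRING)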
+ """ + pos = collections.find(target.attrs, attr) + + if pos > -1: + existing = target.attrs[pos] + existing.restrictions.max_occurs = sys.maxsize + existing.types.extend(attr.types) + existing.types = collections.unique_sequence(existing.types, key="qname") + else: + target.attrs.append(attr) diff --git a/xsdata/codegen/mappers/schema.py b/xsdata/codegen/mappers/schema.py index 269978b02..d5c145e30 100644 --- a/xsdata/codegen/mappers/schema.py +++ b/xsdata/codegen/mappers/schema.py @@ -15,13 +15,25 @@ from xsdata.utils import collections, text from xsdata.utils.namespaces import build_qname, is_default, prefix_exists +ROOT_CLASSES = (SimpleType, ComplexType, Group, AttributeGroup, Element, Attribute) + class SchemaMapper: - """Map a schema instance to classes, extensions and attributes.""" + """Map a schema instance to classes. + + This mapper is used to build classes from xsd documents. + """ @classmethod def map(cls, schema: Schema) -> List[Class]: - """Map schema children elements to classes.""" + """Map schema children elements to classes. + + Args: + schema: The schema instance + + Returns: + A list of classes. + """ assert schema.location is not None location = schema.location @@ -33,19 +45,36 @@ def map(cls, schema: Schema) -> List[Class]: ] @classmethod - def root_elements(cls, schema: Schema): - """Return all valid schema elements that can be converted to - classes.""" + def root_elements(cls, schema: Schema) -> Iterator[Tuple[str, ElementBase]]: + """Return the schema root elements. + + Qualified Elements: + - SimpleType + - ComplexType + - Group + - AttributeGroup + - Element + - Attribute + + Args: + schema: The schema instance + + Yields: + An iterator of element base instances. + """ + + def condition(item: ElementBase) -> bool: + return isinstance(item, ROOT_CLASSES) for override in schema.overrides: - for child in override.children(condition=cls.is_class): + for child in override.children(condition=condition): yield Tag.OVERRIDE, child for redefine in schema.redefines: - for child in redefine.children(condition=cls.is_class): + for child in redefine.children(condition=condition): yield Tag.REDEFINE, child - for child in schema.children(condition=cls.is_class): + for child in schema.children(condition=condition): yield Tag.SCHEMA, child @classmethod @@ -56,7 +85,17 @@ def build_class( location: str, target_namespace: Optional[str], ) -> Class: - """Build and return a class instance.""" + """Build and return a class instance. + + Args: + obj: The element base instance + container: The container name + location: The schema location + target_namespace: The schema target namespace + + Returns: + The new class instance. + """ instance = Class( qname=build_qname(target_namespace, obj.real_name), abstract=obj.is_abstract, @@ -79,8 +118,19 @@ def build_class( @classmethod def build_substitutions( - cls, obj: ElementBase, target_namespace: Optional[str] + cls, + obj: ElementBase, + target_namespace: Optional[str], ) -> List[str]: + """Builds a list of qualified substitution group names. + + Args: + obj: The element base instance + target_namespace: The schema target namespace + + Returns: + A list of qualified substitution group names. 
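The qname building used by build_substitutions right below is the usual Clark-notation helper; for example:

    from xsdata.utils.namespaces import build_qname

    build_qname("http://example.com", "size")  # -> "{http://example.com}size"
    build_qname(None, "size")                  # -> "size", nothing to qualify with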
+ """ return [ build_qname(obj.ns_map.get(prefix, target_namespace), suffix) for prefix, suffix in map(text.split, obj.substitutions) @@ -88,9 +138,12 @@ def build_substitutions( @classmethod def build_class_attributes(cls, obj: ElementBase, target: Class): - """Build the target class attributes from the given ElementBase - children.""" + """Build the target class attrs from the element children. + Args: + obj: The element base instance + target: The target class instance + """ base_restrictions = Restrictions.from_element(obj) for child, restrictions in cls.element_children(obj, base_restrictions): cls.build_class_attribute(target, child, restrictions) @@ -99,9 +152,12 @@ def build_class_attributes(cls, obj: ElementBase, target: Class): @classmethod def build_class_extensions(cls, obj: ElementBase, target: Class): - """Build the item class extensions from the given ElementBase - children.""" + """Build the target class extensions from the element children. + Args: + obj: The element base instance + target: The target class instance + """ restrictions = obj.get_restrictions() extensions = [ cls.build_class_extension(obj.class_name, target, base, restrictions) @@ -111,10 +167,22 @@ def build_class_extensions(cls, obj: ElementBase, target: Class): target.extensions = collections.unique_sequence(extensions) @classmethod - def build_data_type( - cls, target: Class, name: str, forward: bool = False + def build_attr_type( + cls, + target: Class, + name: str, + forward: bool = False, ) -> AttrType: - """Create an attribute type for the target class.""" + """Create a reference attr type for the target class. + + Args: + target: The target class instance + name: the qualified name of the attr + forward: Whether the reference is for an inner class + + Returns: + The new attr type instance. + """ prefix, suffix = text.split(name) namespace = target.ns_map.get(prefix, target.target_namespace) qname = build_qname(namespace, suffix) @@ -128,11 +196,19 @@ def build_data_type( @classmethod def element_children( - cls, obj: ElementBase, parent_restrictions: Restrictions + cls, + obj: ElementBase, + parent_restrictions: Restrictions, ) -> Iterator[Tuple[ElementBase, Restrictions]]: - """Recursively find and return all child elements that are qualified to - be class attributes, with all their restrictions.""" + """Recursively find and return all child elements. + Args: + obj: The element base instance. + parent_restrictions: The parent element restrictions instance + + Yields: + An iterator of elements and their parent restrictions. + """ for child in obj.children(): if child.is_property: yield child, parent_restrictions @@ -143,19 +219,26 @@ def element_children( @classmethod def element_namespace( - cls, obj: ElementBase, target_namespace: Optional[str] + cls, + obj: ElementBase, + target_namespace: Optional[str], ) -> Optional[str]: - """ - Return the target namespace for the given schema element. + """Return the target namespace for the given schema element. 
- In order: + Rules: - elements/attributes with specific target namespace - prefixed elements returns the namespace from schema ns_map - qualified elements returns the schema target namespace - unqualified elements return an empty string - unqualified attributes return None - """ + Args: + obj: The element base instance + target_namespace: The schema target namespace + + Returns: + The element real namespace or None if no namespace + """ raw_namespace = obj.raw_namespace if raw_namespace: return raw_namespace @@ -176,13 +259,18 @@ def element_namespace( @classmethod def children_extensions( - cls, obj: ElementBase, target: Class + cls, + obj: ElementBase, + target: Class, ) -> Iterator[Extension]: - """ - Recursively find and return all target's Extension classes. + """Recursively find and return all target's extension instances. + + Args: + obj: The element base instance + target: The target class instance - If the initial given obj has a type attribute include it in - result. + Yields: + An iterator of extension instances. """ for child in obj.children(): if child.is_property: @@ -197,11 +285,25 @@ def children_extensions( @classmethod def build_class_extension( - cls, tag: str, target: Class, name: str, restrictions: Dict + cls, + tag: str, + target: Class, + name: str, + restrictions: Dict, ) -> Extension: - """Create an extension for the target class.""" + """Create a reference extension for the target class. + + Args: + tag: The tag name + target: The target class instance + name: The qualified name of the extension + restrictions: A key-value restrictions mapping + + Returns: + The new extension instance. + """ return Extension( - type=cls.build_data_type(target, name), + type=cls.build_attr_type(target, name), tag=tag, restrictions=Restrictions(**restrictions), ) @@ -213,10 +315,15 @@ def build_class_attribute( obj: ElementBase, parent_restrictions: Restrictions, ): - """Generate and append an attribute field to the target class.""" + """Build and append a new attr to the target class. + Args: + target: The target class instance + obj: The element base instance to map to an attr + parent_restrictions: The parent element restrictions + """ target.ns_map.update(obj.ns_map) - types = cls.build_class_attribute_types(target, obj) + types = cls.build_attr_types(target, obj) restrictions = Restrictions.from_element(obj) if obj.class_name in (Tag.ELEMENT, Tag.ANY, Tag.GROUP): @@ -238,13 +345,17 @@ def build_class_attribute( ) @classmethod - def build_class_attribute_types( - cls, target: Class, obj: ElementBase - ) -> List[AttrType]: - """Convert real type and anonymous inner types to an attribute type - list.""" + def build_attr_types(cls, target: Class, obj: ElementBase) -> List[AttrType]: + """Convert the element types and inner types to an attr types. - types = [cls.build_data_type(target, tp) for tp in obj.attr_types] + Args: + target: The target class instance + obj: The element base instance to extract the types from + + Returns: + A list of attr type instances. 
+ """ + types = [cls.build_attr_type(target, tp) for tp in obj.attr_types] location = target.location namespace = target.target_namespace @@ -253,15 +364,27 @@ def build_class_attribute_types( types.append(AttrType(qname=inner.qname, forward=True)) if len(types) == 0: - types.append(cls.build_data_type(target, name=obj.default_type)) + types.append(cls.build_attr_type(target, name=obj.default_type)) return collections.unique_sequence(types) @classmethod def build_inner_classes( - cls, obj: ElementBase, location: str, namespace: Optional[str] + cls, + obj: ElementBase, + location: str, + namespace: Optional[str], ) -> Iterator[Class]: - """Find and convert anonymous types to a class instances.""" + """Find and convert anonymous types to a class instances. + + Args: + obj: The element base instance + location: The schema location + namespace: The parent element namespace + + Yields: + An iterator of class instances. + """ if isinstance(obj, SimpleType) and obj.is_enumeration: yield cls.build_class(obj, obj.class_name, location, namespace) else: @@ -273,9 +396,3 @@ def build_inner_classes( yield cls.build_class(child, obj.class_name, location, namespace) else: yield from cls.build_inner_classes(child, location, namespace) - - @classmethod - def is_class(cls, item: ElementBase) -> bool: - return isinstance( - item, (SimpleType, ComplexType, Group, AttributeGroup, Element, Attribute) - ) diff --git a/xsdata/codegen/mixins.py b/xsdata/codegen/mixins.py index b3eaaada4..bfb76588f 100644 --- a/xsdata/codegen/mixins.py +++ b/xsdata/codegen/mixins.py @@ -8,8 +8,11 @@ class ContainerInterface(abc.ABC): - """Wrap a list of classes and expose a simple api for easy access and - process.""" + """A class list wrapper with an easy access api. + + Args: + config: The generator configuration instance + """ __slots__ = ("config",) @@ -18,37 +21,85 @@ def __init__(self, config: GeneratorConfig): @abc.abstractmethod def __iter__(self) -> Iterator[Class]: - """Create an iterator for the class map values.""" + """Yield an iterator for the class map values.""" @abc.abstractmethod def find(self, qname: str, condition: Callable = return_true) -> Optional[Class]: - """Search by qualified name for a specific class with an optional - condition callable.""" + """Find class that matches the given qualified name and condition callable. + + Classes are allowed to have the same qualified name, e.g. xsd:Element + extending xsd:ComplexType with the same name, you can provide and additional + callback to filter the classes like the tag. + + Args: + qname: The qualified name of the class + condition: A user callable to filter further + + Returns: + A class instance or None if no match found. + """ @abc.abstractmethod def find_inner(self, source: Class, qname: str) -> Class: - """Search by qualified name for a specific inner class or fail.""" + """Search by qualified name for a specific inner class or fail. + + Args: + source: The source class to search for the inner class + qname: The qualified name of the inner class to look up + + Returns: + The inner class instance + + Raises: + CodeGenerationError: If the inner class is not found. + """ @abc.abstractmethod def first(self, qname: str) -> Class: - """Search by qualified name for a specific class and return the first - available.""" + """Return the first class that matches the qualified name. 
+ + Args: + qname: The qualified name of the class + + Returns: + The first matching class + + Raises: + KeyError: If no class matches the qualified name + """ @abc.abstractmethod def add(self, item: Class): - """Add class item to the container.""" + """Add class item to the container. + + Args: + item: The class instance to add + """ @abc.abstractmethod def extend(self, items: List[Class]): - """Add a list of classes to the container.""" + """Add a list of classes to the container. + + Args: + items: The list of class instances to add + """ @abc.abstractmethod def reset(self, item: Class, qname: str): - """Update the given class qualified name.""" + """Update the given class qualified name. + + Args: + item: The target class instance to update + qname: The new qualified name of the class + """ @abc.abstractmethod def set(self, items: List[Class]): - """Set the list of classes to the container.""" + """Set the list of classes to the container. + + Args: + items: The list of classes + """ class HandlerInterface(abc.ABC): @@ -58,12 +109,19 @@ class HandlerInterface(abc.ABC): @abc.abstractmethod def process(self, target: Class): - """Process the given target class.""" + """Process the given target class. + + Args: + target: The target class instance + """ class RelativeHandlerInterface(HandlerInterface, metaclass=ABCMeta): - """Class handler interface with access to the complete classes' - container.""" + """An interface for codegen handlers with class container access. + + Args: + container: The container instance + """ __slots__ = "container" @@ -71,6 +129,15 @@ def __init__(self, container: ContainerInterface): self.container = container def base_attrs(self, target: Class) -> List[Attr]: + """Return a list of all parent attrs recursively. + + Args: + target: The target class + + Returns: + A list of attr instances. + + """ attrs: List[Attr] = [] for extension in target.extensions: base = self.container.find(extension.type.qname) @@ -87,11 +154,19 @@ def base_attrs(self, target: Class) -> List[Attr]: @abc.abstractmethod def process(self, target: Class): - """Process class.""" + """Process entrypoint for a class. + + Args: + target: The target class instance + """ class ContainerHandlerInterface(abc.ABC): - """Class container.""" + """A codegen interface for processing the whole class container. + + Args: + container: The class container instance + """ __slots__ = "container" diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 0a7026740..fa1908358 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -25,8 +25,32 @@ @dataclass class Restrictions: - """Model representation of a dataclass field validation and type - metadata.""" + """Class field validation restrictions. 
+ + Args: + min_occurs: The minimum number of occurrences + max_occurs: The maximum number of occurrences + min_exclusive: The lower exclusive bound for numeric values + min_inclusive: The lower inclusive bound for numeric values + min_length: The minimum length of characters or list items allowed + max_exclusive: The upper exclusive bound for numeric values + max_inclusive: The upper inclusive bound for numeric values + max_length: The max length of characters or list items allowed + total_digits: The exact number of digits allowed for numeric values + fraction_digits: The maximum number of decimal places allowed + length: The exact number of characters or list items allowed + white_space: Specifies how white space is handled + pattern: Defines the exact sequence of characters that are acceptable + explicit_timezone: Require or prohibit the time zone offset in date/time + nillable: Specifies whether nil content is allowed + sequence: The sequence reference number of the attr + tokens: Specifies whether the value needs tokenization + format: The output format used for byte and datetime types + choice: The choice reference number of the attr + group: The group reference number of the attr + process_contents: Specifies the content processed mode: strict, lax, skip + path: The coded attr path in the source document + """ min_occurs: Optional[int] = field(default=None) max_occurs: Optional[int] = field(default=None) @@ -53,36 +77,25 @@ class Restrictions: @property def is_list(self) -> bool: - """Return true if max occurs property is larger than one.""" + """Return whether the max occurs larger than one.""" return self.max_occurs is not None and self.max_occurs > 1 @property def is_optional(self) -> bool: - """Return true if min occurs property equals zero.""" + """Return whether the min occurs is zero.""" return self.min_occurs == 0 @property def is_prohibited(self) -> bool: + """Return whether the max occurs is zero.""" return self.max_occurs == 0 def merge(self, source: "Restrictions"): - """Update properties from another instance.""" - self.update(source) - - self.path = source.path + self.path - self.sequence = self.sequence or source.sequence - self.choice = self.choice or source.choice - self.tokens = self.tokens or source.tokens - self.format = self.format or source.format - self.group = self.group or source.group + """Update properties from another instance. - if self.min_occurs is None and source.min_occurs is not None: - self.min_occurs = source.min_occurs - - if self.max_occurs is None and source.max_occurs is not None: - self.max_occurs = source.max_occurs - - def update(self, source: "Restrictions"): + Args: + source: The source instance to merge properties from + """ keys = ( "min_exclusive", "min_inclusive", @@ -104,12 +117,30 @@ def update(self, source: "Restrictions"): if value is not None: setattr(self, key, value) + self.path = source.path + self.path + self.sequence = self.sequence or source.sequence + self.choice = self.choice or source.choice + self.tokens = self.tokens or source.tokens + self.format = self.format or source.format + self.group = self.group or source.group + + if self.min_occurs is None and source.min_occurs is not None: + self.min_occurs = source.min_occurs + + if self.max_occurs is None and source.max_occurs is not None: + self.max_occurs = source.max_occurs + def asdict(self, types: Optional[List[Type]] = None) -> Dict: - """ - Return the initialized only properties as a dictionary. + """Return the initialized only properties as a dictionary. 
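A small sketch of the occurrence helpers and the merge behaviour shown above (field values are illustrative):

    from xsdata.codegen.models import Restrictions

    occurs = Restrictions(min_occurs=0, max_occurs=10)
    print(occurs.is_optional, occurs.is_list, occurs.is_prohibited)  # True True False

    # Non-None facets are copied over, and max_occurs falls back to the source when unset.
    target = Restrictions()
    target.merge(Restrictions(max_occurs=2, min_inclusive="1"))
    print(target.max_occurs, target.min_inclusive)  # 2 1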
- Skip None or implied values, and optionally use the parent + Skip None or implied values, and optionally use the attribute types to convert relevant options. + + Args: + types: An optional list of attr python types + + Returns: + A key-value map of the attr restrictions for generation. """ result = {} sorted_types = converter.sort_types(types) if types else [] @@ -148,20 +179,30 @@ def clone(self) -> "Restrictions": @classmethod def from_element(cls, element: ElementBase) -> "Restrictions": - """Static constructor from a xsd model.""" - return cls(**element.get_restrictions()) + """Static constructor from an xsd model. + Args: + element: An element base instance. -class AttrCategory(IntEnum): - NATIVE = 0 - FORWARD = 1 - EXTERNAL = 2 + Returns: + The new restrictions instance. + """ + return cls(**element.get_restrictions()) @dataclass(unsafe_hash=True) class AttrType: - """Model representation for the typing information for fields and - extensions.""" + """Class field typing information. + + Args: + qname: The namespace qualified name + alias: The type alias + reference: The type reference number + native: Specifies if it's a python native type + forward: Specifies if it's a forward reference + circular: Specifies if it's a circular reference + substituted: Specifies if it has been processed for substitution groups + """ qname: str alias: Optional[str] = field(default=None, compare=False) @@ -173,6 +214,7 @@ class AttrType: @property def datatype(self) -> Optional[DataType]: + """Return the datatype instance if native, None otherwise.""" return DataType.from_qname(self.qname) if self.native else None @property @@ -181,9 +223,17 @@ def name(self) -> str: return namespaces.local_name(self.qname) def is_dependency(self, allow_circular: bool) -> bool: - """Return true if attribute is not a forward/circular references, and - it's not a native python time.""" + """Return whether this type is a dependency. + + The type must be a reference to a user type, not a forward + reference and not circular, unless that's allowed. Args: + allow_circular: Allow circular references as dependencies + + Returns: + The bool result. + """ return not ( self.forward or self.native or (not allow_circular and self.circular) ) @@ -195,7 +245,24 @@ def clone(self) -> "AttrType": @dataclass class Attr: - """Model representation for a dataclass field.""" + """Class field model representation. + + Args: + tag: The xml tag that produced this attr + name: The final attr name + local_name: The original attr name + index: The index position of this attr in the class + default: The default value + fixed: Specifies if the default value is fixed + mixed: Specifies if the attr supports mixed content + types: The attr types list + choices: The attr choice list + namespace: The attr namespace + help: The attr help text + restrictions: The attr restrictions instance + parent: The class reference of the attr + substitution: The substitution group this attr belongs to + """ tag: str name: str = field(compare=False) @@ -213,26 +280,36 @@ class Attr: substitution: Optional[str] = field(default=None, compare=False) def __post_init__(self): + """Set the original attr name on init.""" self.local_name = self.name @property def key(self) -> str: + """Generate a key for this attr. + + Concatenate the tag/namespace/local_name. + This key is used to find duplicates; it is not + guaranteed to be unique. + + Returns: + The key for this attr.
+ + """ return f"{self.tag}.{self.namespace}.{self.local_name}" @property def is_attribute(self) -> bool: - """Return whether this attribute is derived from a xs:attribute or - xs:anyAttribute.""" + """Return whether this attr represents a xml attribute node.""" return self.tag in (Tag.ATTRIBUTE, Tag.ANY_ATTRIBUTE) @property def is_enumeration(self) -> bool: - """Return whether this attribute is derived from a xs:enumeration.""" + """Return whether this attr an enumeration member.""" return self.tag == Tag.ENUMERATION @property def is_dict(self) -> bool: - """Return whether this attribute is a mapping of values.""" + """Return whether this attr is derived from xs:anyAttribute.""" return self.tag == Tag.ANY_ATTRIBUTE @property @@ -242,85 +319,84 @@ def is_factory(self) -> bool: @property def is_forward_ref(self) -> bool: + """Return whether any attr types is a forward or circular reference.""" return any(tp.circular or tp.forward for tp in self.types) @property def is_group(self) -> bool: - """Return whether this attribute is derived from a xs:group or - xs:attributeGroup.""" + """Return whether this attr is a reference to a group class.""" return self.tag in (Tag.ATTRIBUTE_GROUP, Tag.GROUP) @property def is_list(self) -> bool: - """Return whether this attribute is a list of values.""" + """Return whether this attr requires a list of values.""" return self.restrictions.is_list @property def is_prohibited(self) -> bool: - """Return whether this attribute is prohibited.""" + """Return whether this attr is prohibited.""" return self.restrictions.is_prohibited @property def is_nameless(self) -> bool: - """Return whether this attribute has a local name that will be used - during parsing/serialization.""" + """Return whether this attr is a real xml node.""" return self.tag not in (Tag.ATTRIBUTE, Tag.ELEMENT) @property def is_nillable(self) -> bool: + """Return whether this attr supports nil values.""" return self.restrictions.nillable is True @property def is_optional(self) -> bool: - """Return whether this attribute is not required.""" + """Return whether this attr is not required.""" return self.restrictions.is_optional @property def is_suffix(self) -> bool: - """Return whether this attribute is not derived from a xs element with - mode suffix.""" + """Return whether this attr is supposed to be generated last.""" return self.index == sys.maxsize @property def is_xsi_type(self) -> bool: - """Return whether this attribute qualified name is equal to - xsi:type.""" + """Return whether this attr represents a xsi:type attribute.""" return self.namespace == Namespace.XSI.uri and self.name == "type" @property def is_tokens(self) -> bool: - """Return whether this attribute is a list of values.""" + """Return whether this attr supports token values.""" return self.restrictions.tokens is True @property def is_wildcard(self) -> bool: - """Return whether this attribute is derived from xs:anyAttribute or - xs:any.""" + """Return whether this attr supports any content.""" return self.tag in (Tag.ANY_ATTRIBUTE, Tag.ANY) @property def is_any_type(self) -> bool: + """Return whether this attr types support any content.""" return any(tp is object for tp in self.get_native_types()) @property def native_types(self) -> List[Type]: - """Return a list of all builtin data types.""" + """Return a list of all the builtin data types.""" return list(set(self.get_native_types())) @property def user_types(self) -> Iterator[AttrType]: - """Return an iterator of all the user defined types.""" + """Yield an iterator of all the 
user defined types.""" for tp in self.types: if not tp.native: yield tp @property def slug(self) -> str: + """Return the slugified name of the attr.""" return text.alnum(self.name) @property def xml_type(self) -> Optional[str]: - """Return the xml node type this attribute is mapped to.""" + """Return the xml type this attribute is mapped to.""" return xml_type_map.get(self.tag) def clone(self) -> "Attr": @@ -332,19 +408,26 @@ def clone(self) -> "Attr": ) def get_native_types(self) -> Iterator[Type]: + """Yield an iterator of all the native attr types.""" for tp in self.types: datatype = tp.datatype if datatype: yield datatype.type def can_be_restricted(self) -> bool: - """Return whether this attribute can be restricted.""" + """Return whether this attr can be restricted.""" return self.xml_type not in (Tag.ATTRIBUTE, None) @dataclass(unsafe_hash=True) class Extension: - """Model representation of a dataclass base class.""" + """Base class model representation. + + Args: + tag: The xml tag that produced this extension + type: The extension type + restrictions: The extension restrictions instance + """ tag: str type: AttrType @@ -360,6 +443,8 @@ def clone(self) -> "Extension": class Status(IntEnum): + """Class process status enumeration.""" + RAW = 0 UNGROUPING = 10 UNGROUPED = 11 @@ -375,8 +460,31 @@ class Status(IntEnum): @dataclass class Class: - """Model representation of a dataclass with fields, base/inner classes and - additional metadata settings.""" + """Class model representation. + + Args: + qname: The namespace qualified name + tag: The xml tag that produced this class + location: The schema/document location uri + mixed: Specifies whether this class supports mixed content + abstract: Specifies whether this is an abstract class + nillable: Specifies whether this class supports nil content + local_type: Specifies if this class was an inner type at some point + status: The processing status of the class + container: The xml container of the class, schema, override, redefine + package: The designated package of the class + module: The designated module of the class + namespace: The class namespace + help: The help text + meta_name: The xml element name of the class + default: The default value + fixed: Specifies whether the default value is fixed + substitutions: The list of all the substitution groups this class belongs to + extensions: The list of all the extension instances + attrs: The list of all the attr instances + inner: The list of all the inner class instances + ns_map: The namespace prefix-URI map + """ qname: str tag: str @@ -402,60 +510,62 @@ class Class: @property def name(self) -> str: - """Shortcut for qname local name.""" + """Shortcut for the class local name.""" return namespaces.local_name(self.qname) @property def slug(self) -> str: + """Return a slugified version of the class name.""" return text.alnum(self.name) @property def ref(self) -> int: + """Return this id reference of this instance.""" return id(self) @property def target_namespace(self) -> Optional[str]: + """Return the class target namespace.""" return namespaces.target_uri(self.qname) @property def has_suffix_attr(self) -> bool: - """Return whether it includes a suffix attribute.""" + """Return whether it includes a suffix attr.""" return any(attr.is_suffix for attr in self.attrs) @property def has_help_attr(self) -> bool: - """Return whether it includes at least one attr with help content.""" + """Return whether at least one of attrs has help content.""" return any(attr.help and attr.help.strip() 
for attr in self.attrs) @property def is_complex(self) -> bool: - """Return whether this instance is derived from a xs:element or - xs:complexType.""" + """Return whether class represents a xs:element/complex type.""" return self.tag in (Tag.ELEMENT, Tag.COMPLEX_TYPE) @property def is_element(self) -> bool: - """Return whether this instance is derived from a non abstract - xs:element.""" + """Return whether this class represents a xml element.""" return self.tag == Tag.ELEMENT @property def is_enumeration(self) -> bool: - """Return whether all attributes are derived from xs:enumeration.""" + """Return whether all attrs are enumeration members.""" return len(self.attrs) > 0 and all(attr.is_enumeration for attr in self.attrs) @property def is_global_type(self) -> bool: - """Return whether this instance is a non-abstract element, wsdl binding - class or a complex type without simple content.""" + """Return whether this class represents a root/global class. + + Global classes are the only classes that get generated by default. + """ return (not self.abstract and self.tag in GLOBAL_TYPES) or ( self.tag == Tag.COMPLEX_TYPE and not self.is_simple_type ) @property def is_group(self) -> bool: - """Return whether this attribute is derived from a xs:group or - xs:attributeGroup.""" + """Return whether this class is derived from a xs:group/attributeGroup.""" return self.tag in (Tag.ATTRIBUTE_GROUP, Tag.GROUP) @property @@ -470,18 +580,25 @@ def is_mixed(self) -> bool: @property def is_restricted(self) -> bool: + """Return whether this class includes any restriction extensions.""" return any( True for extension in self.extensions if extension.tag == Tag.RESTRICTION ) @property def is_service(self) -> bool: - """Return whether this instance is derived from wsdl:operation.""" + """Return whether this instance is derived from a wsdl:operation.""" return self.tag == Tag.BINDING_OPERATION @property def is_simple_type(self) -> bool: - """Return whether the class represents a simple text type.""" + """Return whether the class represents a simple type. + + Simple Types: + - xs:simpleType/extension/list/union + - have only one attr + - have no extensions. + """ return ( len(self.attrs) == 1 and self.attrs[0].tag in SIMPLE_TYPES @@ -490,6 +607,8 @@ def is_simple_type(self) -> bool: @property def references(self) -> Iterator[int]: + """Yield all class object reference numbers.""" + def all_refs(): for ext in self.extensions: yield ext.type.reference @@ -511,7 +630,12 @@ def all_refs(): @property def target_module(self) -> str: - """Return the target module this class is assigned to.""" + """Return the designated full module path. + + Raises: + CodeGenerationError: if the target was not designated + a package and module. + """ if self.package and self.module: return f"{self.package}.{self.module}" @@ -530,8 +654,9 @@ def clone(self) -> "Class": return replace(self, inner=inners, extensions=extensions, attrs=attrs) def dependencies(self, allow_circular: bool = False) -> Iterator[str]: - """ - Return a set of dependencies for the given class. + """Yields all class dependencies. + + Omit circular and forward references by default. Collect: * base classes @@ -540,6 +665,9 @@ def dependencies(self, allow_circular: bool = False) -> Iterator[str]: * recursively go through the inner classes * Ignore inner class references * Ignore native types. 
+ + Args: + allow_circular: Allow circular references """ types = {ext.type for ext in self.extensions} @@ -568,12 +696,12 @@ def has_forward_ref(self) -> bool: @dataclass class Import: - """ - Model representation of a python import statement. + """Python import statement model representation. - :param qname: - :param source: - :param alias: + Args: + qname: The qualified name of the imported class + source: The absolute module path + alias: Specifies an alias to avoid naming conflicts """ qname: str @@ -582,11 +710,12 @@ class Import: @property def name(self) -> str: - """Shortcut for qname local name.""" + """Return the name of the imported class.""" return namespaces.local_name(self.qname) @property def slug(self) -> str: + """Return a slugified version of the imported class name.""" return text.alnum(self.name) diff --git a/xsdata/codegen/parsers/__init__.py b/xsdata/codegen/parsers/__init__.py index 33c524561..5b0d1afbd 100644 --- a/xsdata/codegen/parsers/__init__.py +++ b/xsdata/codegen/parsers/__init__.py @@ -1,4 +1,9 @@ from xsdata.codegen.parsers.definitions import DefinitionsParser +from xsdata.codegen.parsers.dtd import DtdParser from xsdata.codegen.parsers.schema import SchemaParser -__all__ = ["SchemaParser", "DefinitionsParser"] +__all__ = [ + "DefinitionsParser", + "DtdParser", + "SchemaParser", +] diff --git a/xsdata/codegen/parsers/definitions.py b/xsdata/codegen/parsers/definitions.py index 866c6beba..783945524 100644 --- a/xsdata/codegen/parsers/definitions.py +++ b/xsdata/codegen/parsers/definitions.py @@ -10,8 +10,7 @@ @dataclass class DefinitionsParser(SchemaParser): - """A simple parser to convert a wsdl to an easy to handle data structure - based on dataclasses.""" + """Parse a wsdl document into data models.""" def end( self, @@ -21,7 +20,21 @@ def end( text: Optional[str], tail: Optional[str], ) -> Any: - """Override parent method to set element location.""" + """Parse the last xml node and bind any intermediate objects. + + Override parent method to set source location in every + wsdl element. + + Args: + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + text: The element text content + tail: The element tail content + + Returns: + Whether the binding process was successful. + """ obj = super().end(queue, objects, qname, text, tail) if isinstance(obj, wsdl.WsdlElement): obj.location = self.location @@ -29,5 +42,12 @@ def end( return obj def end_import(self, obj: T): + """End import element entrypoint. + + Resolve the location path of import elements. + + Args: + obj: The wsdl import element. + """ if isinstance(obj, wsdl.Import) and self.location: obj.location = self.resolve_path(obj.location) diff --git a/xsdata/codegen/parsers/dtd.py b/xsdata/codegen/parsers/dtd.py index ee685048f..cfcbbf395 100644 --- a/xsdata/codegen/parsers/dtd.py +++ b/xsdata/codegen/parsers/dtd.py @@ -1,5 +1,5 @@ import io -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from xsdata.exceptions import ParserError from xsdata.models.dtd import ( @@ -17,8 +17,22 @@ class DtdParser: + """Document type definition parser. + + The parser requires lxml package to be installed. + """ + @classmethod - def parse(cls, source: Any, location: str) -> Dtd: + def parse(cls, source: bytes, location: str) -> Dtd: + """Parse the input source bytes object into a dtd instance. 
+ + Args: + source: The source bytes object to parse + location: The source location uri + + Returns: + A dtd instance representing the parsed content. + """ try: from lxml import etree @@ -31,6 +45,14 @@ def parse(cls, source: Any, location: str) -> Dtd: @classmethod def build_element(cls, element: Any) -> DtdElement: + """Build a dtd element from the lxml element. + + Args: + element: The lxml dtd element instance + + Returns: + The converted xsdata dtd element instance. + """ content = cls.build_content(element.content) attributes = list(map(cls.build_attribute, element.iterattributes())) ns_map = cls.build_ns_map(element.prefix, attributes) @@ -45,6 +67,14 @@ def build_element(cls, element: Any) -> DtdElement: @classmethod def build_content(cls, content: Any) -> Optional[DtdContent]: + """Build a dtd content instance from the lxml content. + + Args: + content: The lxml content instance + + Returns: + The converted xsdata dtd content instance, or None if the content is empty. + """ if not content: return None @@ -58,6 +88,14 @@ def build_content(cls, content: Any) -> Optional[DtdContent]: @classmethod def build_attribute(cls, attribute: Any) -> DtdAttribute: + """Build a dtd attribute instance from the lxml instance. + + Args: + attribute: The lxml attribute instance + + Returns: + The converted xsdata dtd attribute instance. + """ return DtdAttribute( prefix=attribute.prefix, name=attribute.name, @@ -68,7 +106,18 @@ def build_attribute(cls, attribute: Any) -> DtdAttribute: ) @classmethod - def build_ns_map(cls, prefix: str, attributes: List[DtdAttribute]) -> dict: + def build_ns_map(cls, prefix: str, attributes: List[DtdAttribute]) -> Dict: + """Build the dtd element namespace prefix-URI map. + + It also adds common namespaces like xs, xsi, xlink and xml. + + Args: + prefix: The element namespace prefix + attributes: Element attributes, to extract any xmlns keys + + Returns: + The element namespace prefix-URI map. + """ ns_map = {ns.prefix: ns.uri for ns in Namespace.common()} for attribute in list(attributes): diff --git a/xsdata/codegen/parsers/schema.py b/xsdata/codegen/parsers/schema.py index b2b8c2a92..9ceb297c8 100644 --- a/xsdata/codegen/parsers/schema.py +++ b/xsdata/codegen/parsers/schema.py @@ -16,29 +16,34 @@ @dataclass class SchemaParser(UserXmlParser): - """ - A simple parser to convert an xsd schema to an easy to handle data - structure based on dataclasses. - - The parser is as a dummy as possible, but it will try to normalize - certain things like apply parent properties to children. - - :param location: - :param element_form: - :param attribute_form: - :param target_namespace: - :param default_attributes: - :param default_open_content: + """Xml schema definition parser. + + Apply implied rules, set indexes, resolve + location paths... 
+ + Args: + location: The schema location uri + target_namespace: The schema target namespace + + Attributes: + index: The current element index + indices: The child element indices + element_form: The schema element form + attribute_form: The schema attribute form + default_attributes: The schema default attributes + default_open_content: The schema default open content """ - index: int = field(default_factory=int) - indices: List[int] = field(default_factory=list) location: Optional[str] = field(default=None) - element_form: Optional[FormType] = field(init=False, default=None) - attribute_form: Optional[FormType] = field(init=False, default=None) target_namespace: Optional[str] = field(default=None) - default_attributes: Optional[str] = field(default=None) - default_open_content: Optional[xsd.DefaultOpenContent] = field(default=None) + index: int = field(default_factory=int, init=False) + indices: List[int] = field(default_factory=list, init=False) + element_form: Optional[str] = field(default=None, init=False) + attribute_form: Optional[str] = field(default=None, init=False) + default_attributes: Optional[str] = field(default=None, init=False) + default_open_content: Optional[xsd.DefaultOpenContent] = field( + default=None, init=False + ) def start( self, @@ -49,6 +54,19 @@ def start( attrs: Dict, ns_map: Dict, ): + """Build and queue the XmlNode for the starting element. + + Override to set the current element index and append it in + child element indices. + + Args: + clazz: The target class type, auto locate if omitted + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + """ self.index += 1 self.indices.append(self.index) super().start(clazz, queue, objects, qname, attrs, ns_map) @@ -61,7 +79,20 @@ def end( text: Optional[str], tail: Optional[str], ) -> Any: - """Override parent method to set element index and namespaces map.""" + """Parse the last xml node and bind any intermediate objects. + + Override to set the xsd model index and ns map. + + Args: + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + text: The element text content + tail: The element tail content + + Returns: + Whether the binding process was successful. + """ item = queue[-1] super().end(queue, objects, qname, text, tail) @@ -71,16 +102,28 @@ def end( return obj - def start_schema(self, attrs: Dict): - """Collect the schema's default form for attributes and elements for - later usage.""" + def start_schema(self, attrs: Dict[str, str]): + """Start schema element entrypoint. + + Store the element/attribute default forms and the + default attributes, for later processing. - self.element_form = attrs.get("elementFormDefault") - self.attribute_form = attrs.get("attributeFormDefault") - self.default_attributes = attrs.get("defaultAttributes") + Args: + attrs: The element attributes + + """ + self.element_form = attrs.get("elementFormDefault", None) + self.attribute_form = attrs.get("attributeFormDefault", None) + self.default_attributes = attrs.get("defaultAttributes", None) def end_schema(self, obj: T): - """Normalize various properties for the schema and it's children.""" + """End schema element entrypoint. + + Normalize various properties for the schema and it's children. + + Args: + obj: The xsd schema instance. 
+ """ if isinstance(obj, xsd.Schema): self.set_schema_forms(obj) self.set_schema_namespaces(obj) @@ -89,19 +132,30 @@ def end_schema(self, obj: T): self.reset_element_occurs(obj) def end_attribute(self, obj: T): - """Assign the schema's default form for attributes if the given - attribute form is None.""" + """End attribute element entrypoint. + + Assign the schema's default form in the attribute instance, + if it doesn't define its own. + + Args: + obj: The xsd attribute instance + + """ if isinstance(obj, xsd.Attribute) and obj.form is None and self.attribute_form: obj.form = FormType(self.attribute_form) def end_complex_type(self, obj: T): - """ + """End complex type element entrypoint. + Post parsing processor to apply default open content and attributes if applicable. Default open content doesn't apply if the current complex type has one of complex content, simple content or has its own open content. + + Args: + obj: The xsd complex type instance """ if not isinstance(obj, xsd.ComplexType): return @@ -122,8 +176,18 @@ def end_complex_type(self, obj: T): obj.open_content = self.default_open_content def end_default_open_content(self, obj: T): - """Set the instance default open content to be used later as a property - for all extensions and restrictions.""" + """End default open content element entrypoint. + + If the open content element mode is suffix, adjust + the index to trick later processors into putting attrs + derived from this open content last in the generated classes. + + Store the obj for later processing. + + Args: + obj: The xsd default open content instance + . + """ if isinstance(obj, xsd.DefaultOpenContent): if obj.any and obj.mode == Mode.SUFFIX: obj.any.index = sys.maxsize @@ -131,33 +195,61 @@ def end_default_open_content(self, obj: T): self.default_open_content = obj def end_element(self, obj: T): - """Assign the schema's default form for elements if the given element - form is None.""" + """End element entrypoint. + + Assign the schema's default form in the element instance, + if it doesn't define its own. + + Args: + obj: The xsd element instance + """ if isinstance(obj, xsd.Element) and obj.form is None and self.element_form: obj.form = FormType(self.element_form) def end_extension(self, obj: T): - """Set the open content if any to the given extension.""" + """End extension element entrypoint. + + Assign the schema's default open content in the extension instance, + if it doesn't define its own. + + Args: + obj: The xsd extension instance + """ if isinstance(obj, xsd.Extension) and not obj.open_content: obj.open_content = self.default_open_content @classmethod def end_open_content(cls, obj: T): - """Adjust the index to trick later processors into putting attributes - derived from this open content last in classes.""" + """End open content element entrypoint. + + If the open content element mode is suffix, adjust + the index to trick later processors into putting attrs + derived from this open content last in the generated classes. + + Args: + obj: The xsd open content instance + + """ if isinstance(obj, xsd.OpenContent) and obj.any and obj.mode == Mode.SUFFIX: obj.any.index = sys.maxsize def end_restriction(self, obj: T): - """Set the open content if any to the given restriction.""" + """End restriction element entrypoint. + + Assign the schema's default open content in the restriction instance, + if it doesn't define its own. 
+ + Args: + obj: The xsd restriction instance + """ if isinstance(obj, xsd.Restriction) and not obj.open_content: obj.open_content = self.default_open_content def set_schema_forms(self, obj: xsd.Schema): - """ - Set the default form type for elements and attributes. + """Cascade schema forms to elements and attributes. - Global elements and attributes are by default qualified. + Args: + obj: The xsd schema instance """ if self.element_form: obj.element_form_default = FormType(self.element_form) @@ -171,13 +263,31 @@ def set_schema_forms(self, obj: xsd.Schema): child_attribute.form = FormType.QUALIFIED def set_schema_namespaces(self, obj: xsd.Schema): - """Set the given schema's target namespace and add the default - namespaces if the are missing xsi, xlink, xml, xs.""" + """Set the schema target namespace. + + If the schema was imported and doesn't have a target namespace, + it automatically inherits the parent schema target namespace. + + Args: + obj: The xsd schema instance + """ obj.target_namespace = obj.target_namespace or self.target_namespace def resolve_schemas_locations(self, obj: xsd.Schema): - """Resolve the locations of the schema overrides, redefines, includes - and imports relatively to the schema location.""" + """Resolve the location attributes of the schema. + + This method covers relative paths and implied schema + locations to common namespaces like xsi, xlink. + + Schema elements with location attribute: + - override + - redefines + - include + - import + + Args: + obj: The xsd schema instance + """ if not self.location: return @@ -195,17 +305,35 @@ def resolve_schemas_locations(self, obj: xsd.Schema): imp.location = self.resolve_local_path(imp.schema_location, imp.namespace) def resolve_path(self, location: Optional[str]) -> Optional[str]: - """Resolve the given location string relatively the schema location - path.""" + """Resolve the given location string. + + Use the parser location attribute as the base uri. + Args: + location: The location uri + + Returns: + The resolved location or None if it was not resolved + """ return urljoin(self.location, location) if self.location and location else None def resolve_local_path( - self, location: Optional[str], namespace: Optional[str] + self, + location: Optional[str], + namespace: Optional[str], ) -> Optional[str]: - """Resolve the given namespace to one of the local standard schemas or - fallback to the external file path.""" + """Resolve the given namespace to one of the local standard schemas. + + w3.org protects against fetching common schemas not from a browser, + instead we use the local xsdata copies. + Args: + location: The schema location + namespace: The schema namespace + + Returns: + The local path or the absolute remote uri. + """ common_ns = Namespace.get_enum(namespace) local_path = common_ns.location if common_ns else None @@ -216,6 +344,17 @@ def resolve_local_path( @classmethod def has_elements(cls, obj: ElementBase) -> bool: + """Helper function to check if instance has children. + + Valid children: xs:element, xs:any, xs:group. + + Args: + obj: The element base instance + + Returns: + The bool result. 
+ + """ accepted_types = (xsd.Element, xsd.Any, xsd.Group) return any( isinstance(child, accepted_types) or cls.has_elements(child) @@ -224,7 +363,14 @@ def has_elements(cls, obj: ElementBase) -> bool: @classmethod def set_namespace_map(cls, obj: Any, ns_map: Optional[Dict]): - """Add common namespaces like xml, xsi, xlink if they are missing.""" + """Add common namespaces like xml, xsi, xlink if they are missing. + + These prefixes are implied and we need to support them. + + Args: + obj: A xsd model instance + ns_map: The namespace prefix-URI map + """ if hasattr(obj, "ns_map"): if ns_map: obj.ns_map.update( @@ -242,13 +388,25 @@ def set_namespace_map(cls, obj: Any, ns_map: Optional[Dict]): @classmethod def set_index(cls, obj: Any, index: int): + """Helper method to set an object's index. + + Args: + obj: A xsd model instance + index: The index number + """ if hasattr(obj, "index"): obj.index = index @classmethod def add_default_imports(cls, obj: xsd.Schema): - """Add missing imports to the standard schemas if the namespace is - declared and.""" + """Add missing imports to the standard schemas. + + We might need to generate the classes from these + common schemas, so add these implied imports. + + Args: + obj: The xsd schema instance + """ imp_namespaces = [imp.namespace for imp in obj.imports] xsi_ns = Namespace.XSI.uri if xsi_ns in obj.ns_map.values() and xsi_ns not in imp_namespaces: @@ -256,6 +414,13 @@ def add_default_imports(cls, obj: xsd.Schema): @classmethod def reset_element_occurs(cls, obj: xsd.Schema): + """Reset the root elements occurs restrictions. + + The root elements don't get those. + + Args: + obj: The xsd schema instance + """ for element in obj.elements: element.min_occurs = None element.max_occurs = None diff --git a/xsdata/codegen/resolver.py b/xsdata/codegen/resolver.py index 44f287137..15ebf77a8 100644 --- a/xsdata/codegen/resolver.py +++ b/xsdata/codegen/resolver.py @@ -12,23 +12,38 @@ class DependenciesResolver: - __slots__ = "packages", "aliases", "imports", "class_list", "class_map", "package" + """The dependencies resolver class. - def __init__(self, packages: Dict[str, str]): - self.packages = packages + Calculate what classes need to be imported + per package, with aliases support. + Args: + registry: The full class qname-module map. + + Attributes: + aliases: The generated aliases dictionary + imports: The list of generated imports + class_list: The topo-sorted list of class qnames + class_map: A qname-class map + + """ + + __slots__ = "registry", "aliases", "imports", "class_list", "class_map" + + def __init__(self, registry: Dict[str, str]): + self.registry = registry self.aliases: Dict[str, str] = {} self.imports: List[Import] = [] self.class_list: List[str] = [] self.class_map: Dict[str, Class] = {} def process(self, classes: List[Class]): - """ - Resolve the dependencies for the given list of classes and the target - package. + """Resolve the dependencies for the given class list. - Reset aliases and imports from any previous runs keep the record - of the processed class names + Reset previously resolved imports and aliases. 
+ + Args: + classes: A list of classes that belong to the same target module """ self.imports.clear() self.aliases.clear() @@ -37,13 +52,11 @@ def process(self, classes: List[Class]): self.resolve_imports() def sorted_imports(self) -> List[Import]: - """Return a new sorted by name list of import packages.""" + """Return a new sorted by name list of import instances.""" return sorted(self.imports, key=lambda x: x.name) def sorted_classes(self) -> List[Class]: - """Return an iterator of classes property sorted for generation and - apply import aliases.""" - + """Apply aliases and return the sorted the generated class list.""" result = [] for name in self.class_list: obj = self.class_map.get(name) @@ -53,8 +66,14 @@ def sorted_classes(self) -> List[Class]: return result def apply_aliases(self, target: Class): - """Iterate over the target class dependencies and set the type - aliases.""" + """Apply import aliases to the target class. + + Update attr and extension types to point to the + new class aliases. Process inner classes too! + + Args: + target: The target class instance to process + """ for attr in target.attrs: for attr_type in attr.types: attr_type.alias = self.aliases.get(attr_type.qname) @@ -69,10 +88,9 @@ def apply_aliases(self, target: Class): collections.apply(target.inner, self.apply_aliases) def resolve_imports(self): - """Walk the import qualified names, check for naming collisions and add - the necessary code generator import instance.""" + """Build the list of class imports and set aliases if necessary.""" self.imports = [ - Import(qname=qname, source=self.find_package(qname)) + Import(qname=qname, source=self.get_class_module(qname)) for qname in self.import_classes() ] protected = {obj.slug for obj in self.class_map.values()} @@ -80,10 +98,21 @@ def resolve_imports(self): self.set_aliases() def set_aliases(self): + """Store generated aliases.""" self.aliases = {imp.qname: imp.alias for imp in self.imports if imp.alias} @classmethod def resolve_conflicts(cls, imports: List[Import], protected: set): + """Find naming conflicts between imports and generate aliases. + + Example: + from foo.bar import MyType as BarMyType + from bar.foo import MyType as FooMyType + + Args: + imports: The list of class import instances + protected: The set of protected class names from the module + """ for slug, group in collections.group_by(imports, key=get_slug).items(): if len(group) == 1: if slug in protected: @@ -100,18 +129,21 @@ def resolve_conflicts(cls, imports: List[Import], protected: set): add = "_".join(part for part in parts if part in diff) cur.alias = f"{add}:{cur.name}" - def find_package(self, qname: str) -> str: - """ - Return the package name for the given qualified class name. + def get_class_module(self, qname: str) -> str: + """Return the module for the given qualified class name. - :raises ResolverValueError: if name doesn't exist. + Args: + qname: The namespace qualified name of the class + + Raises: + ResolverValueError: if name doesn't exist. 
""" - if qname not in self.packages: + if qname not in self.registry: raise ResolverValueError(f"Unknown dependency: {qname}") - return self.packages[qname] + return self.registry[qname] def import_classes(self) -> List[str]: - """Return a list of class that need to be imported.""" + """Return a list of class qnames that need to be imported.""" return [qname for qname in self.class_list if qname not in self.class_map] @staticmethod @@ -121,7 +153,14 @@ def create_class_list(classes: List[Class]) -> List[str]: @staticmethod def create_class_map(classes: List[Class]) -> Dict[str, Class]: - """Index the list of classes by name.""" + """Index the list of classes by their qualified names. + + Raises: + ResolverValueError: If two classes have the same qname. + + Returns: + A qname-class map. + """ result: Dict[str, Class] = {} for obj in classes: if obj.qname in result: diff --git a/xsdata/codegen/transformer.py b/xsdata/codegen/transformer.py index 5114295bd..f68095471 100644 --- a/xsdata/codegen/transformer.py +++ b/xsdata/codegen/transformer.py @@ -11,14 +11,16 @@ from xsdata.codegen import opener from xsdata.codegen.analyzer import ClassAnalyzer from xsdata.codegen.container import ClassContainer -from xsdata.codegen.mappers.definitions import DefinitionsMapper -from xsdata.codegen.mappers.dict import DictMapper -from xsdata.codegen.mappers.dtd import DtdMapper -from xsdata.codegen.mappers.element import ElementMapper -from xsdata.codegen.mappers.schema import SchemaMapper +from xsdata.codegen.mappers import ( + DefinitionsMapper, + DictMapper, + DtdMapper, + ElementMapper, + SchemaMapper, +) from xsdata.codegen.models import Class +from xsdata.codegen.parsers import DtdParser from xsdata.codegen.parsers.definitions import DefinitionsParser -from xsdata.codegen.parsers.dtd import DtdParser from xsdata.codegen.parsers.schema import SchemaParser from xsdata.codegen.utils import ClassUtils from xsdata.codegen.writer import CodeWriter @@ -40,6 +42,15 @@ class SupportedType(NamedTuple): + """A supported resource model representation. + + Args: + id: The integer identifier + name: The name of the resource type + match_uri: A callable to match against URI strings + match_content: A callable to match against the raw file content + """ + id: int name: str match_uri: Callable @@ -80,13 +91,19 @@ class SupportedType(NamedTuple): ] -class SchemaTransformer: - """ - Orchestrate the code generation from a list of sources to the output - format. +class ResourceTransformer: + """Orchestrate the code generation from a list of sources. - :param print: Print to stdout the generated output - :param config: Generator configuration + Supports xsd, wsdl, dtd, xml and json documents. + + Args: + print: Print to stdout the generated output + config: Generator configuration + + Attributes: + classes: A list of class instances + processed: A list of processed uris + preloaded: A uri/content map used as cache """ __slots__ = ("print", "config", "classes", "processed", "preloaded") @@ -99,6 +116,12 @@ def __init__(self, print: bool, config: GeneratorConfig): self.preloaded: Dict = {} def process(self, uris: List[str], cache: bool = False): + """Process a list of resolved URI strings. 
+ + Args: + uris: A list of absolute URI strings to process + cache: Specifies whether to catch the initial parsed classes + """ cache_file = self.get_cache_file(uris) if cache else None if cache_file and cache_file.exists(): logger.info(f"Loading from cache {cache_file}") @@ -113,6 +136,14 @@ def process(self, uris: List[str], cache: bool = False): self.process_classes() def process_sources(self, uris: List[str]): + """Process a list of resolved URI strings. + + Load the source URI strings and map them to codegen + classes for further processing. + + Args: + uris: A list of absolute URI strings to process + """ sources = defaultdict(list) for uri in uris: tp = self.classify_resource(uri) @@ -125,7 +156,11 @@ def process_sources(self, uris: List[str]): self.process_json_documents(sources[TYPE_JSON]) def process_definitions(self, uris: List[str]): - """Process a list of wsdl resources.""" + """Process a list of wsdl resources. + + Args: + uris: A list of wsdl URI strings to process + """ definitions = None for uri in uris: services = self.parse_definitions(uri, namespace=None) @@ -139,12 +174,20 @@ def process_definitions(self, uris: List[str]): self.convert_definitions(definitions) def process_schemas(self, uris: List[str]): - """Process a list of xsd resources.""" + """Process a list of xsd resources. + + Args: + uris: A list of xsd URI strings to process + """ for uri in uris: self.process_schema(uri) def process_dtds(self, uris: List[str]): - """Process a list of dtd resources.""" + """Process a list of dtd resources. + + Args: + uris: A list of dtd URI strings to process + """ classes: List[Class] = [] for uri in uris: @@ -158,13 +201,23 @@ def process_dtds(self, uris: List[str]): self.classes.extend(classes) def process_schema(self, uri: str, namespace: Optional[str] = None): - """Parse and convert schema to codegen models.""" + """Parse and convert schema to codegen models. + + Args: + uri: The schema URI location + namespace: The target namespace, if the URI is + from an inline import + """ schema = self.parse_schema(uri, namespace) if schema: self.convert_schema(schema) def process_xml_documents(self, uris: List[str]): - """Process a list of xml resources.""" + """Process a list of xml resources. + + Args: + uris: A list of xml URI strings to process + """ classes = [] parser = TreeParser() location = os.path.dirname(uris[0]) if uris else "" @@ -178,7 +231,11 @@ def process_xml_documents(self, uris: List[str]): self.classes.extend(ClassUtils.reduce_classes(classes)) def process_json_documents(self, uris: List[str]): - """Process a list of json resources.""" + """Process a list of json resources. + + Args: + uris: A list of json URI strings to process + """ classes = [] name = self.config.output.package.split(".")[-1] dirname = os.path.dirname(uris[0]) if uris else "" @@ -200,8 +257,7 @@ def process_json_documents(self, uris: List[str]): self.classes.extend(ClassUtils.reduce_classes(classes)) def process_classes(self): - """Process the generated classes and write or print the final - output.""" + """Process the generated classes and write or print the output.""" class_num, inner_num = self.count_classes(self.classes) if class_num: logger.info( @@ -223,8 +279,13 @@ def process_classes(self): raise CodeGenerationError("Nothing to generate.") def convert_schema(self, schema: Schema): - """Convert a schema instance to codegen classes and process imports to - other schemas.""" + """Convert a schema instance to codegen classes. + + Process recursively any schema imports. 
+ + Args: + schema: The xsd schema instance + """ for sub in schema.included(): if sub.location: self.process_schema(sub.location, schema.target_namespace) @@ -236,7 +297,7 @@ def convert_definitions(self, definitions: Definitions): self.classes.extend(DefinitionsMapper.map(definitions)) def generate_classes(self, schema: Schema) -> List[Class]: - """Convert the given schema tree to a list of classes.""" + """Convert the given schema instance to a list of classes.""" uri = schema.location logger.info("Compiling schema %s", uri if uri else "...") classes = SchemaMapper.map(schema) @@ -248,7 +309,12 @@ def generate_classes(self, schema: Schema) -> List[Class]: return classes def parse_schema(self, uri: str, namespace: Optional[str]) -> Optional[Schema]: - """Parse the given schema uri and return the schema tree object.""" + """Parse the given URI and return the schema instance. + + Args: + uri: The resource URI + namespace: The target namespace + """ input_stream = self.load_resource(uri) if input_stream is None: return None @@ -258,11 +324,16 @@ def parse_schema(self, uri: str, namespace: Optional[str]) -> Optional[Schema]: return parser.from_bytes(input_stream, Schema) def parse_definitions( - self, uri: str, namespace: Optional[str] + self, + uri: str, + namespace: Optional[str], ) -> Optional[Definitions]: - """Parse recursively the given wsdl uri and return the definitions' - tree object.""" + """Parse recursively the given URI and return the definitions instance. + Args: + uri: The resource URI + namespace: The target namespace + """ input_stream = self.load_resource(uri) if input_stream is None: return None @@ -285,7 +356,14 @@ def parse_definitions( return definitions def load_resource(self, uri: str) -> Optional[bytes]: - """Read and return the contents of the given uri.""" + """Read and return the contents of the given URI. + + Args: + uri: The resource URI + + Returns: + The raw bytes content or None if the resource could not be read + """ if uri not in self.processed: try: self.processed.append(uri) @@ -298,9 +376,14 @@ def load_resource(self, uri: str) -> Optional[bytes]: return None def classify_resource(self, uri: str) -> int: - """Detect the resource type by the uri extension or the file - contents.""" + """Detect the resource type by the URI extension or the contents. + Args: + uri: The resource URI + + Returns: + The resource integer identifier. + """ for supported_type in supported_types: if supported_type.match_uri(uri): return supported_type.id @@ -318,16 +401,21 @@ def classify_resource(self, uri: str) -> int: return TYPE_UNKNOWN def analyze_classes(self, classes: List[Class]) -> List[Class]: - """Analyzer the given class list and simplify attributes and - extensions.""" - + """Analyzer the given class list and return the final list of classes.""" container = ClassContainer(config=self.config) container.extend(classes) return ClassAnalyzer.process(container) def count_classes(self, classes: List[Class]) -> Tuple[int, int]: - """Return a tuple of counters for the main and inner classes.""" + """Return a tuple of counters for the main and inner classes. + + Args: + classes: A list of class instances + + Returns: + A tuple of root, inner counters, e.g. (100, 5) + """ main = len(classes) inner = 0 for cls in classes: @@ -337,6 +425,14 @@ def count_classes(self, classes: List[Class]) -> Tuple[int, int]: @classmethod def get_cache_file(cls, uris: List[str]) -> Path: + """Return the cache path for the raw mapped classes. 
+
+        Args:
+            uris: A list of URI strings
+
+        Returns:
+            A temporary file path instance
+        """
         key = hashlib.md5("".join(uris).encode()).hexdigest()
         tempdir = tempfile.gettempdir()
         return Path(tempdir).joinpath(f"{key}.cache")
diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py
index 51978eb6b..40e958c6c 100644
--- a/xsdata/codegen/utils.py
+++ b/xsdata/codegen/utils.py
@@ -20,35 +20,61 @@ class ClassUtils:
     """General reusable utils methods that didn't fit anywhere else."""

     @classmethod
-    def find_value_attr(cls, target: Class) -> Attr:
-        """
-        Find the text attribute of the class.
+    def find_value_attr(cls, source: Class) -> Attr:
+        """Find the text attribute of the class.
+
+        Args:
+            source: The source class instance
+
+        Returns:
+            The matched attr instance.

-        :raise CodeGenerationError: If no text node/attribute exists
+        Raises:
+            CodeGenerationError: If no text node/attribute exists
         """
-        for attr in target.attrs:
+        for attr in source.attrs:
             if not attr.xml_type:
                 return attr

-        raise CodeGenerationError(f"Class has no value attr {target.qname}")
+        raise CodeGenerationError(f"Class has no value attr {source.qname}")

     @classmethod
     def remove_attribute(cls, target: Class, attr: Attr):
-        """Safely remove the given attr from the target class by check obj
-        ids."""
+        """Safely remove the given attr from the target class.
+
+        Make sure you match the attr by the reference id; a
+        simple comparison might remove a duplicate attr
+        with the same tag/namespace/name.
+
+        Args:
+            target: The target class instance
+            attr: The attr instance to remove
+
+        """
         target.attrs = [at for at in target.attrs if id(at) != id(attr)]

     @classmethod
     def clean_inner_classes(cls, target: Class):
-        """Check if there are orphan inner classes and remove them."""
+        """Check if there are orphan inner classes and remove them.
+
+        Args:
+            target: The target class instance to inspect.
+        """
         for inner in list(target.inner):
             if cls.is_orphan_inner(target, inner):
                 target.inner.remove(inner)

     @classmethod
     def is_orphan_inner(cls, target: Class, inner: Class) -> bool:
-        """Check if there is at least once valid attr reference to the given
-        inner class."""
+        """Check if the inner class is referenced in the target class.
+
+        Args:
+            target: The target class instance
+            inner: The inner class instance
+
+        Returns:
+            The bool result.
+        """
         for attr in target.attrs:
             for attr_type in attr.types:
                 if attr_type.forward and attr_type.qname == inner.qname:
@@ -58,13 +84,16 @@ def is_orphan_inner(cls, target: Class, inner: Class) -> bool:

     @classmethod
     def copy_attributes(cls, source: Class, target: Class, extension: Extension):
-        """
-        Copy the attributes and inner classes from the source class to the
-        target class and remove the extension that links the two classes
-        together.
+        """Copy the attrs from the source to the target class.

-        The new attributes are prepended in the list unless if they are
+        Remove the extension instance that connects the two classes.
+        The new attrs are prepended in the list unless they are
         supposed to be last in a sequence.
+
+        Args:
+            source: The source/parent class instance
+            target: The target/child class instance
+            extension: The extension instance that connects the classes
         """
         target.extensions.remove(extension)
         target_attr_names = {attr.name for attr in target.attrs}
@@ -85,9 +114,16 @@ def copy_attributes(cls, source: Class, target: Class, extension: Extension):

     @classmethod
     def copy_group_attributes(cls, source: Class, target: Class, attr: Attr):
-        """Copy the attributes and inner classes from the source class to the
-        target class and remove the group attribute that links the two classes
-        together."""
+        """Copy the attrs of the source class to the target class.
+
+        The attr represents a reference to the source class which is
+        derived from xs:group or xs:attributeGroup and will be removed.
+
+        Args:
+            source: The source class instance
+            target: The target class instance
+            attr: The group attr instance
+        """
         index = target.attrs.index(attr)
         target.attrs.pop(index)
@@ -100,9 +136,16 @@ def copy_group_attributes(cls, source: Class, target: Class, attr: Attr):
     @classmethod
     def copy_extensions(cls, source: Class, target: Class, extension: Extension):
-        """Copy the extensions from the source class to the target class and
-        merge the restrictions from the extension that linked the two classes
-        together."""
+        """Copy the source class extensions to the target class instance.
+
+        Merge the extension restrictions with the source class extensions'
+        restrictions.
+
+        Args:
+            source: The source class instance
+            target: The target class instance
+            extension: The extension instance that links the two classes together
+        """
         for ext in source.extensions:
             clone = ext.clone()
             clone.restrictions.merge(extension.restrictions)
@@ -110,30 +153,43 @@ def copy_extensions(cls, source: Class, target: Class, extension: Extension):

     @classmethod
     def clone_attribute(cls, attr: Attr, restrictions: Restrictions) -> Attr:
-        """Clone the given attribute and merge its restrictions with the given
-        instance."""
+        """Clone the given attr and merge its restrictions with the given restrictions.
+
+        Args:
+            attr: The source attr instance
+            restrictions: The additional restrictions, originated from
+                a substitution or another attr.
+        """
         clone = attr.clone()
         clone.restrictions.merge(restrictions)
         return clone

     @classmethod
     def copy_inner_classes(cls, source: Class, target: Class, attr: Attr):
-        """Iterate all attr types and copy any inner classes from source to the
-        target class."""
-        for attr_type in attr.types:
-            cls.copy_inner_class(source, target, attr, attr_type)
+        """Copy inner classes from source to the target class instance.

-    @classmethod
-    def copy_inner_class(
-        cls, source: Class, target: Class, attr: Attr, attr_type: AttrType
-    ):
+        Args:
+            source: The source class instance
+            target: The target class instance
+            attr: The attr with the possible forward references
         """
-        Check if the given attr type is a forward reference and copy its inner
-        class from the source to the target class.
+        for attr_type in attr.types:
+            cls.copy_inner_class(source, target, attr_type)

-        Checks:
-            1. Update type if inner class in a circular reference
-            2. Copy inner class, rename it if source is a simple type.
+    @classmethod
+    def copy_inner_class(cls, source: Class, target: Class, attr_type: AttrType):
+        """Find and copy the inner class from source to the target class instance.
+ + Steps: + - Skip If the attr type is not a forward reference + - Validate the inner class is not a circular reference to the target + - Otherwise copy the inner class, and make sure it is re-sent for + processing + + Args: + source: The source class instance + target: The target class instance + attr_type: The attr type with the possible forward reference """ if not attr_type.forward: return @@ -151,6 +207,18 @@ class from the source to the target class. @classmethod def find_inner(cls, source: Class, qname: str) -> Class: + """Find an inner class in the source class by its qualified name. + + Args: + source: The parent class instance + qname: The inner class qualified name + + Returns: + The inner class instance + + Raises: + CodeGenerationError: If no inner class matched. + """ for inner in source.inner: if inner.qname == qname: return inner @@ -159,6 +227,15 @@ def find_inner(cls, source: Class, qname: str) -> Class: @classmethod def find_attr(cls, source: Class, name: str) -> Optional[Attr]: + """Find an attr in the source class by its name. + + Args: + source: The source class instance + name: The attr name to lookup + + Returns: + An attr instance or None if no attr matched. + """ for attr in source.attrs: if attr.name == name: return attr @@ -167,6 +244,17 @@ def find_attr(cls, source: Class, name: str) -> Optional[Attr]: @classmethod def flatten(cls, target: Class, location: str) -> Iterator[Class]: + """Flatten the target class instance and its inner classes. + + The inner classes are removed from target instance! + + Args: + target: The target class instance + location: The source location of the target class + + Yields: + An iterator over all the found classes. + """ target.location = location while target.inner: @@ -181,6 +269,14 @@ def flatten(cls, target: Class, location: str) -> Iterator[Class]: @classmethod def reduce_classes(cls, classes: List[Class]) -> List[Class]: + """Find duplicate classes and attrs and reduce them. + + Args: + classes: A list of classes + + Returns: + A list of unique classes with no duplicate attrs. + """ result = [] for group in collections.group_by(classes, key=get_qname).values(): target = group[0].clone() @@ -194,6 +290,14 @@ def reduce_classes(cls, classes: List[Class]) -> List[Class]: @classmethod def reduce_attributes(cls, classes: List[Class]) -> List[Attr]: + """Find and merge duplicate attrs from the given class list. + + Args: + classes: A list of class instances + + Returns: + A list of unique attr instances. + """ result = [] for attr in cls.sorted_attrs(classes): added = False @@ -215,6 +319,17 @@ def reduce_attributes(cls, classes: List[Class]) -> List[Attr]: @classmethod def sorted_attrs(cls, classes: List[Class]) -> List[Attr]: + """Sort and return the attrs from all the class list. + + The list contains duplicate classes, the method tries + to find all the attrs and sorts them by first occurrence. + + Args: + classes: A list of duplicate class instances. + + Returns: + A list of sorted duplicate attr instances. + """ attrs: List[Attr] = [] classes.sort(key=lambda x: len(x.attrs), reverse=True) @@ -241,6 +356,16 @@ def sorted_attrs(cls, classes: List[Class]) -> List[Attr]: @classmethod def merge_attributes(cls, target: Attr, source: Attr): + """Merge the source attr into the target instance. + + Merge the types, select the min min_occurs and the max max_occurs + from the two instances and copy the source sequence number + to the target if it's currently not set. 
+ + Args: + target: The target attr instance which will be updated + source: The source attr instance + """ target.types.extend(tp for tp in source.types if tp not in target.types) target.restrictions.min_occurs = min( @@ -258,16 +383,19 @@ def merge_attributes(cls, target: Attr, source: Attr): @classmethod def rename_attribute_by_preference(cls, a: Attr, b: Attr): - """ - Decide and rename one of the two given attributes. + """Decide and rename one of the two given attributes. When both attributes are derived from the same xs:tag and one of the two fields has a specific namespace prepend it to the name. Preferable rename the second attribute. - Otherwise append the derived from tag to the name of one of the + Otherwise, append the derived from tag to the name of one of the two attributes. Preferably rename the second field or the field derived from xs:attribute. + + Args: + a: The first attr instance + b: The second attr instance """ if a.tag == b.tag and (a.namespace or b.namespace): change = b if b.namespace else a @@ -279,8 +407,12 @@ def rename_attribute_by_preference(cls, a: Attr, b: Attr): @classmethod def rename_attributes_by_index(cls, attrs: List[Attr], rename: List[Attr]): - """Append the next available index number to all the rename attributes - names.""" + """Append the next available index number to all the rename attr names. + + Args: + attrs: A list of attr instances whose names must be protected + rename: A list of attr instances that need to be renamed + """ for index in range(1, len(rename)): reserved = set(map(get_slug, attrs)) name = rename[index].name @@ -288,6 +420,15 @@ def rename_attributes_by_index(cls, attrs: List[Attr], rename: List[Attr]): @classmethod def unique_name(cls, name: str, reserved: Set[str]) -> str: + """Append the next available index number to the name. + + Args: + name: An object name + reserved: A set of reserved names + + Returns: + The new name with the index suffix + """ if text.alnum(name) in reserved: index = 1 while text.alnum(f"{name}_{index}") in reserved: @@ -299,17 +440,29 @@ def unique_name(cls, name: str, reserved: Set[str]) -> str: @classmethod def cleanup_class(cls, target: Class): + """Go through the target class attrs and filter their types. + + Removes duplicate and invalid types. + + Args: + target: The target class instance to inspect + """ for attr in target.attrs: attr.types = cls.filter_types(attr.types) @classmethod def filter_types(cls, types: List[AttrType]) -> List[AttrType]: - """ - Remove duplicate and invalid types. + """Remove duplicate and invalid types. Invalid: 1. xs:error 2. xs:anyType and xs:anySimpleType when there are other types present + + Args: + types: A list of attr type instances + + Returns: + The new list of unique and valid attr type instances. """ types = collections.unique_sequence(types, key="qname") types = collections.remove(types, lambda x: x.datatype == DataType.ERROR) diff --git a/xsdata/codegen/validator.py b/xsdata/codegen/validator.py index 64867b069..fa345a89f 100644 --- a/xsdata/codegen/validator.py +++ b/xsdata/codegen/validator.py @@ -10,8 +10,11 @@ class ClassValidator: - """Run validations against the class container in order to remove or merge - invalid or redefined types.""" + """Container class validator. + + Args: + container: The class container instance + """ __slots__ = "container" @@ -19,13 +22,15 @@ def __init__(self, container: ClassContainer): self.container = container def process(self): - """ - Remove if possible classes with the same qualified name. 
+ """Main process entrypoint. + + Runs on groups of classes with the same + qualified name. Steps: 1. Remove invalid classes 2. Handle duplicate types - 3. Merge dummy types + 3. Merge global types """ for classes in self.container.data.values(): if len(classes) > 1: @@ -38,8 +43,11 @@ def process(self): self.merge_global_types(classes) def remove_invalid_classes(self, classes: List[Class]): - """Remove from the given class list any class with missing extension - type.""" + """Remove classes with undefined extensions. + + Args: + classes: A list of + """ def is_invalid(ext: Extension) -> bool: """Check if given type declaration is not native and is missing.""" @@ -51,9 +59,20 @@ def is_invalid(ext: Extension) -> bool: @classmethod def handle_duplicate_types(cls, classes: List[Class]): - """Handle classes with same namespace, name that are derived from the - same xs type.""" + """Find and handle duplicate classes. + + If a class is defined more than once, keep either + the one that was in redefines or overrides, or the + last definition. If a class was redefined merge + circular group attrs and extensions. + In order for two classes to be duplicated they must + have the same qualified name and be derived from the + same xsd element. + + Args: + classes: A list of classes with the same qualified name + """ for items in group_by(classes, get_tag).values(): if len(items) == 1: continue @@ -76,12 +95,14 @@ def handle_duplicate_types(cls, classes: List[Class]): @classmethod def merge_redefined_type(cls, source: Class, target: Class): - """ - Copy any attributes and extensions to redefined types from the original - definitions. + """Merge source properties to the target redefined target class instance. - Redefined inheritance is optional search for self references in - extensions and attribute groups. + Redefined classes usually have references to the original + class. We need to copy those. + + Args: + source: The original source class instance + target: The redefined target class instance """ circular_extension = cls.find_circular_extension(target) circular_group = cls.find_circular_group(target) @@ -95,11 +116,16 @@ def merge_redefined_type(cls, source: Class, target: Class): @classmethod def select_winner(cls, candidates: List[Class]) -> int: - """ - Returns the index of the class that will survive the duplicate process. + """From a list of classes select which class index will remain. Classes that were extracted from in xs:override/xs:redefined containers have priority, otherwise pick the last in the list. + + Args: + candidates: A list of duplicate class instances + + Returns: + The index of winner class or -1 if there is no clear winner. """ for index, item in enumerate(candidates): if item.container in (Tag.OVERRIDE, Tag.REDEFINE): @@ -109,8 +135,18 @@ def select_winner(cls, candidates: List[Class]) -> int: @classmethod def find_circular_extension(cls, target: Class) -> Optional[Extension]: - """Search for any target class extensions that is a circular - reference.""" + """Find the first circular reference extension. + + Redefined classes usually have references to the original + class with the same qualified name. We need to locate + those and copy any attrs. + + Args: + target: The target class instance to inspect + + Returns: + An extension instance or None if there is no circular extension. 
+ """ for ext in target.extensions: if ext.type.name == target.name: return ext @@ -119,20 +155,33 @@ def find_circular_extension(cls, target: Class) -> Optional[Extension]: @classmethod def find_circular_group(cls, target: Class) -> Optional[Attr]: - """Search for any target class attributes that is a circular - reference.""" + """Find an attr with the same name as the target class name. + + Redefined classes usually have references to the original + class with the same qualified name. We need to locate + those and copy any attrs. + + Args: + target: The target class instance to inspect + + Returns: + An attr instance or None if there is no circular attr. + """ return ClassUtils.find_attr(target, target.name) @classmethod def merge_global_types(cls, classes: List[Class]): - """ - Merge parent-child global types. + """Merge parent-child global types. Conditions 1. One of them is derived from xs:element 2. One of them is derived from xs:complexType 3. The xs:element is a subclass of the xs:complexType 4. The xs:element has no attributes (This can't happen in a valid schema) + + + Args: + classes: A list of duplicate classes """ el = collections.first(x for x in classes if x.tag == Tag.ELEMENT) ct = collections.first(x for x in classes if x.tag == Tag.COMPLEX_TYPE) diff --git a/xsdata/codegen/writer.py b/xsdata/codegen/writer.py index 0301631f4..8c8af5362 100644 --- a/xsdata/codegen/writer.py +++ b/xsdata/codegen/writer.py @@ -12,10 +12,13 @@ class CodeWriter: - """ - Proxy to format generators and files structure creation. + """Code writer class. + + Args: + generator: The code generator instance - :param generator: Code generator instance + Attributes: + generators: A map of registered code generators """ __slots__ = "generator" @@ -28,9 +31,15 @@ def __init__(self, generator: AbstractGenerator): self.generator = generator def write(self, classes: List[Class]): - """Iterate over the designated generator outputs and create the - necessary directories and files.""" + """Write the classes to the designated modules. + + The classes may be written in the same module or + different ones, the entrypoint must create the + directory structure write the file outputs. + Args: + classes: A list of class instances + """ self.generator.normalize_packages(classes) header = self.generator.render_header() @@ -42,8 +51,11 @@ def write(self, classes: List[Class]): result.path.write_text(src_code, encoding="utf-8") def print(self, classes: List[Class]): - """Iterate over the designated generator outputs and print them to the - console.""" + """Print the generated code for the given classes. + + Args: + classes: A list of class instances + """ self.generator.normalize_packages(classes) header = self.generator.render_header() for result in self.generator.render(classes): @@ -53,6 +65,16 @@ def print(self, classes: List[Class]): @classmethod def from_config(cls, config: GeneratorConfig) -> "CodeWriter": + """Instance the code writer from the generator configuration instance. + + Validates that the output format is registered as a generator. + + Args: + config: The generator configuration instance + + Returns: + A new code writer instance. + """ if config.output.format.value not in cls.generators: raise CodeGenerationError( f"Unknown output format: '{config.output.format.value}'" @@ -63,14 +85,33 @@ def from_config(cls, config: GeneratorConfig) -> "CodeWriter": @classmethod def register_generator(cls, name: str, clazz: Type[AbstractGenerator]): + """Register a generator by name. 
+ + Args: + name: The generator name + clazz: The generator class + """ cls.generators[name] = clazz @classmethod def unregister_generator(cls, name: str): + """Remove a generator by name. + + Args: + name: The generator name + """ cls.generators.pop(name) def ruff_code(self, src_code: str, file_path: Path) -> str: - """Run ruff format on the src code.""" + """Run ruff format on the src code. + + Args: + src_code: The output source code + file_path: The file path the source code will be written to + + Returns: + The formatted output source code + """ commands = [ [ "ruff", diff --git a/xsdata/formats/bindings.py b/xsdata/formats/bindings.py index 9070e6272..69e566790 100644 --- a/xsdata/formats/bindings.py +++ b/xsdata/formats/bindings.py @@ -7,25 +7,79 @@ class AbstractSerializer(abc.ABC): + """Abstract serializer class.""" + @abc.abstractmethod - def render(self, obj: object) -> object: - """Render the given object to the target output format.""" + def render(self, obj: Any) -> str: + """Serialize the input model instance to the output string format. + + Args: + obj: The input model instance to serialize + + Returns: + The serialized string output format. + """ class AbstractParser(abc.ABC): + """Abstract parser class.""" + def from_path(self, path: pathlib.Path, clazz: Optional[Type[T]] = None) -> T: - """Parse the input file path and return the resulting object tree.""" + """Parse the input file into the target class type. + + If no clazz is provided, the binding context will try + to locate it from imported dataclasses. + + Args: + path: The path to the input file + clazz: The target class type to parse the file into + + Returns: + An instance of the specified class representing the parsed content. + """ return self.parse(str(path.resolve()), clazz) def from_string(self, source: str, clazz: Optional[Type[T]] = None) -> T: - """Parse the input string and return the resulting object tree.""" + """Parse the input source string into the target class type. + + If no clazz is provided, the binding context will try + to locate it from imported dataclasses. + + Args: + source: The source string to parse + clazz: The target class type to parse the source string into + + Returns: + An instance of the specified class representing the parsed content. + """ return self.from_bytes(source.encode(), clazz) def from_bytes(self, source: bytes, clazz: Optional[Type[T]] = None) -> T: - """Parse the input bytes array return the resulting object tree.""" + """Parse the input source bytes object into the target class type. + + If no clazz is provided, the binding context will try + to locate it from imported dataclasses. + + Args: + source: The source bytes object to parse + clazz: The target class type to parse the source bytes object + + Returns: + An instance of the specified class representing the parsed content. + """ return self.parse(io.BytesIO(source), clazz) @abc.abstractmethod def parse(self, source: Any, clazz: Optional[Type[T]] = None) -> T: - """Parse the input stream or filename and return the resulting object - tree.""" + """Parse the input file or stream into the target class type. + + If no clazz is provided, the binding context will try + to locate it from imported dataclasses. + + Args: + source: The source stream object to parse + clazz: The target class type to parse the source bytes object + + Returns: + An instance of the specified class representing the parsed content. 
+ """ diff --git a/xsdata/formats/converter.py b/xsdata/formats/converter.py index 1fa3b9165..cf4210024 100644 --- a/xsdata/formats/converter.py +++ b/xsdata/formats/converter.py @@ -38,19 +38,34 @@ class Converter(abc.ABC): @abc.abstractmethod def deserialize(self, value: Any, **kwargs: Any) -> Any: - """ - Convert any type to the converter dedicated type. + """Convert a value to a python type. + + Args: + value: The input value + **kwargs: Additional keyword arguments needed per converter - :raises ConverterError: if converter fails with and expected - ValueError + Returns: + The converted value. + + Raises: + ConverterError: if the value can't be converted. """ @abc.abstractmethod def serialize(self, value: Any, **kwargs: Any) -> str: - """Convert value to string.""" + """Convert value to string for serialization. + + Args: + value: The input value + **kwargs: Additional keyword arguments needed per converter + + Returns: + The converted string value. + """ @classmethod def validate_input_type(cls, value: Any, tp: Type): + """Validate the input value type matches the required type.""" if not isinstance(value, tp): raise ConverterError( f"Input value must be '{tp.__name__}' got '{type(value).__name__}'" @@ -58,19 +73,30 @@ def validate_input_type(cls, value: Any, tp: Type): class ConverterFactory: - __slots__ = ("registry",) + """Converter factory class. + + Attributes: + registry: The registered converters + """ + + __slots__ = "registry" def __init__(self): self.registry: Dict[Type, Converter] = {} def deserialize(self, value: Any, types: Sequence[Type], **kwargs: Any) -> Any: - """ - Attempt to convert a any value to one of the given types. + """Attempt to convert any value to one of the given types. - If all attempts fail return the value input value and issue a + If all attempts fail return the value input value and emit a warning. - :return: The first successful converted value. + Args: + value: The input value + types: The target candidate types + **kwargs: Additional keyword arguments needed per converter + + Returns: + The converted value or the input value. """ for data_type in types: try: @@ -85,10 +111,17 @@ def deserialize(self, value: Any, types: Sequence[Type], **kwargs: Any) -> Any: return value def serialize(self, value: Any, **kwargs: Any) -> Any: - """ - Convert the given value to string, ignore None values. + """Convert the given value to string. If the value is a list assume the value is a list of tokens. + + Args: + value: The input value + **kwargs: Additional keyword arguments needed per converter + + Returns: + The converted string value or None if the input value is None. + """ if value is None: return None @@ -106,14 +139,17 @@ def test( strict: bool = False, **kwargs: Any, ) -> bool: - """ - Test the given string value can be parsed using the given list of types - without warnings. + """Test the given string value can be converted to one of the given types. - If strict flag is enabled validate the textual representation - also matches the original input. - """ + Args: + value: The input value + types: The candidate target types + strict: validate the string output also matches the original input + **kwargs: Additional keyword arguments needed per converter + Returns: + The bool result. + """ if not isinstance(value, str): return False @@ -130,11 +166,11 @@ def test( return True def register_converter(self, data_type: Type, func: Union[Callable, Converter]): - """ - Register a callable or converter for the given data type. 
+ """Register a callable or converter for the given data type. - Callables will be wrapped in a - :class:`xsdata.formats.converter.ProxyConverter` + Args: + data_type: The data type + func: The callable or converter instance """ if isinstance(func, Converter): self.registry[data_type] = func @@ -142,33 +178,41 @@ def register_converter(self, data_type: Type, func: Union[Callable, Converter]): self.registry[data_type] = ProxyConverter(func) def unregister_converter(self, data_type: Type): - """ - Unregister the converter for the given data type. + """Unregister the converter for the given data type. + + Args: + data_type: The data type - :raises KeyError: if the data type is not registered. + Raises: + KeyError: if the data type is not registered. """ self.registry.pop(data_type) - def type_converter(self, datatype: Type) -> Converter: - """ - Find a suitable converter for given data type. + def type_converter(self, data_type: Type) -> Converter: + """Find a suitable converter for given data type. Iterate over all but last mro items and check for registered converters, fall back to str and issue a warning if there are - not matches. + no matches. + + Args: + data_type: The data type + + Returns: + A converter instance """ try: # Quick in and out, without checking the whole mro. - return self.registry[datatype] + return self.registry[data_type] except KeyError: pass # We tested the first, ignore the object - for mro in datatype.__mro__[1:-1]: + for mro in data_type.__mro__[1:-1]: if mro in self.registry: return self.registry[mro] - warnings.warn(f"No converter registered for `{datatype}`", ConverterWarning) + warnings.warn(f"No converter registered for `{data_type}`", ConverterWarning) return self.registry[str] def value_converter(self, value: Any) -> Converter: @@ -184,7 +228,8 @@ def sort_types(cls, types: Sequence[Type]) -> List[Type]: return sorted(types, key=lambda x: __PYTHON_TYPES_SORTED__.get(x, 0)) @classmethod - def explicit_types(cls) -> Tuple: + def explicit_types(cls) -> Tuple[Type, ...]: + """Get a list of types that need strict test.""" return __EXPLICIT_TYPES__ @@ -219,15 +264,34 @@ def explicit_types(cls) -> Tuple: class StringConverter(Converter): + """A str converter.""" + def deserialize(self, value: Any, **kwargs: Any) -> Any: + """Convert a value to string.""" return value if isinstance(value, str) else str(value) def serialize(self, value: Any, **kwargs: Any) -> str: + """Convert a value to string.""" return value if isinstance(value, str) else str(value) class BoolConverter(Converter): + """A bool converter.""" + def deserialize(self, value: Any, **kwargs: Any) -> bool: + """Convert a value to bool. + + Args: + value: The input value + **kwargs: Unused keyword arguments + + Returns: + True if the value is in (True, "true", "1") + False if the value is in (False, "false", "0") + + Raises: + ConverterError: if the value can't be converted to bool. + """ if isinstance(value, str): val = value.strip() @@ -245,30 +309,85 @@ def deserialize(self, value: Any, **kwargs: Any) -> bool: raise ConverterError(f"Invalid bool literal '{value}'") def serialize(self, value: bool, **kwargs: Any) -> str: + """Convert a bool value to string. + + Args: + value: The input bool value + **kwargs: Unused keyword arguments + + Returns: + "true" or "false" + """ return "true" if value else "false" class IntConverter(Converter): + """An int converter.""" + def deserialize(self, value: Any, **kwargs: Any) -> int: + """Convert a value to int. 
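+
+        Illustrative example::
+
+            >>> IntConverter().deserialize("10")
+            10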
+
+        Args:
+            value: The input value
+            **kwargs: Unused keyword arguments
+
+        Returns:
+            The int converted value.
+
+        Raises:
+            ConverterError: on value or type errors.
+        """
         try:
             return int(value)
         except (ValueError, TypeError) as e:
             raise ConverterError(e)
 
     def serialize(self, value: int, **kwargs: Any) -> str:
+        """Convert an int value to string.
+
+        Args:
+            value: The input int value
+            **kwargs: Unused keyword arguments
+
+        Returns:
+            The str converted value.
+        """
         return str(value)
 
 
 class FloatConverter(Converter):
+    """A float converter."""
+
     INF = float("inf")
 
     def deserialize(self, value: Any, **kwargs: Any) -> float:
+        """Convert a value to float.
+
+        Args:
+            value: The input value
+            **kwargs: Unused keyword arguments
+
+        Returns:
+            The float converted value.
+
+        Raises:
+            ConverterError: on value errors.
+        """
         try:
             return float(value)
         except ValueError as e:
             raise ConverterError(e)
 
     def serialize(self, value: float, **kwargs: Any) -> str:
+        """Convert a float value to string.
+
+        Args:
+            value: The input float value
+            **kwargs: Unused keyword arguments
+
+        Returns:
+            The str converted value.
+        """
         if math.isnan(value):
             return "NaN"
 
@@ -282,7 +401,23 @@ def serialize(self, value: float, **kwargs: Any) -> str:
 
 
 class BytesConverter(Converter):
+    """A bytes converter for base16 and base64 formats."""
+
     def deserialize(self, value: Any, **kwargs: Any) -> bytes:
+        """Convert a base16 or base64 string value to bytes.
+
+        Args:
+            value: The input string value
+            **kwargs: Additional keyword arguments
+                format: The input string format (base16|base64)
+
+        Returns:
+            The bytes converted value.
+
+        Raises:
+            ConverterError: If format is empty or not supported or the value
+                contains invalid characters.
+        """
         self.validate_input_type(value, str)
 
         try:
@@ -299,6 +434,20 @@ def deserialize(self, value: Any, **kwargs: Any) -> bytes:
             raise ConverterError(e)
 
     def serialize(self, value: bytes, **kwargs: Any) -> str:
+        """Convert a bytes value to string.
+
+        Args:
+            value: The input bytes value
+            **kwargs: Additional keyword arguments
+                format: The input value format (base16|base64)
+
+        Returns:
+            The str converted value.
+
+        Raises:
+            ConverterError: If format doesn't match the value type or
+                it's not supported.
+        """
         fmt = kwargs.get("format")
 
         if isinstance(value, XmlHexBinary) or fmt == "base16":
@@ -311,13 +460,36 @@ def serialize(self, value: bytes, **kwargs: Any) -> str:
 
 
 class DecimalConverter(Converter):
+    """A decimal converter."""
+
     def deserialize(self, value: Any, **kwargs: Any) -> Decimal:
+        """Convert a value to decimal.
+
+        Args:
+            value: The input value
+            **kwargs: Unused keyword arguments
+
+        Returns:
+            The decimal converted value.
+
+        Raises:
+            ConverterError: on InvalidOperation errors.
+        """
         try:
             return Decimal(value)
         except InvalidOperation:
             raise ConverterError()
 
     def serialize(self, value: Decimal, **kwargs: Any) -> str:
+        """Convert a decimal value to string.
+
+        Args:
+            value: The input decimal value
+            **kwargs: Unused keyword arguments
+
+        Returns:
+            The str converted value.
+        """
         if value.is_infinite():
             return str(value).replace("Infinity", "INF")
 
@@ -325,19 +497,33 @@ def serialize(self, value: Decimal, **kwargs: Any) -> str:
 
 
 class QNameConverter(Converter):
+    """A QName converter."""
+
     def deserialize(
         self,
         value: str,
         ns_map: Optional[Dict] = None,
         **kwargs: Any,
     ) -> QName:
-        """
-        Convert namespace prefixed strings, or fully qualified strings to
-        QNames.
+        """Convert a string value to a QName instance.
+ + The method supports strings with namespace prefixes + or fully namespace qualified strings. - examples: + Examples: - xs:string -> QName("http://www.w3.org/2001/XMLSchema", "string") - {foo}bar -> QName("foo", "bar" + + Args: + value: The input str value + ns_map: A namespace prefix-URI map + **kwargs: Unused keyword arguments + + Returns: + A QName instance + + Raises: + ConverterError: If the prefix can't be resolved. """ self.validate_input_type(value, str) namespace, tag = self.resolve(value, ns_map) @@ -345,18 +531,29 @@ def deserialize( return QName(namespace, tag) if namespace else QName(tag) def serialize( - self, value: QName, ns_map: Optional[Dict] = None, **kwargs: Any + self, + value: QName, + ns_map: Optional[Dict] = None, + **kwargs: Any, ) -> str: - """ + """Convert a QName instance value sto string. + Convert a QName instance to string either with a namespace prefix if a prefix-URI namespaces mapping is provided or to a fully qualified name with the namespace. - examples: + Examples: - QName("http://www.w3.org/2001/XMLSchema", "int") & ns_map -> xs:int - QName("foo, "bar") -> {foo}bar - """ + Args: + value: The qname instance to convert + ns_map: A namespace prefix-URI map, if we want to use prefixes + **kwargs: Unused keyword arguments + + Returns: + The str converted value. + """ if ns_map is None: return value.text @@ -370,7 +567,21 @@ def serialize( return f"{prefix}:{tag}" if prefix else tag @staticmethod - def resolve(value: str, ns_map: Optional[Dict] = None) -> Tuple: + def resolve(value: str, ns_map: Optional[Dict] = None) -> Tuple[str, str]: + """Split a qname or ns prefixed string value or a uri, name pair. + + Args: + value: the input value to resolve + ns_map: A namespace prefix-URI map + + Returns: + A tuple of uri and name strings. + + Raises: + ConverterError: if the uri is not valid, + if the prefix can't be resolved to a URI, + if the name is not a valid NCName + """ value = value.strip() if not value: @@ -394,12 +605,33 @@ def resolve(value: str, ns_map: Optional[Dict] = None) -> Tuple: class EnumConverter(Converter): + """An enum converter.""" + def serialize(self, value: Enum, **kwargs: Any) -> str: + """Convert an enum member to a string.""" return converter.serialize(value.value, **kwargs) def deserialize( - self, value: Any, data_type: Optional[EnumMeta] = None, **kwargs: Any + self, + value: Any, + data_type: Optional[EnumMeta] = None, + **kwargs: Any, ) -> Enum: + """Convert a value to an enum member. + + Args: + value: The input value + data_type: The enumeration class + **kwargs: Additional keyword arguments needed + for parsing the value to a python type. + + Returns: + The enum member. + + Raises: + ConverterError: if the data type is not an enum, or the value + doesn't match any of the enum members. + """ if data_type is None or not isinstance(data_type, EnumMeta): raise ConverterError(f"'{data_type}' is not an enum") @@ -420,29 +652,47 @@ def deserialize( @classmethod def match( - cls, value: Any, values: Sequence, length: int, real: Any, **kwargs: Any + cls, + value: Any, + values: Sequence, + length: int, + real: Any, + **kwargs: Any, ) -> bool: + """Match a value to one of the enumeration values. + + Args: + value: The input value + values: The input value as a sequence, in case of NMTokens + length: The length of the sequence values + real: The enumeration value + **kwargs: Additional keyword arguments needed + for parsing the value to a python type. + + Returns: + Whether the value or values matches the enumeration member value. 
+ """ if isinstance(value, str) and isinstance(real, str): return value == real or " ".join(values) == real if isinstance(real, (tuple, list)) and not hasattr(real, "_fields"): - if len(real) == length and cls.match_list(values, real, **kwargs): + if len(real) == length and cls._match_list(values, real, **kwargs): return True - elif length == 1 and cls.match_atomic(value, real, **kwargs): + elif length == 1 and cls._match_atomic(value, real, **kwargs): return True return False @classmethod - def match_list(cls, raw: Sequence, real: Sequence, **kwargs: Any) -> bool: + def _match_list(cls, raw: Sequence, real: Sequence, **kwargs: Any) -> bool: for index, val in enumerate(real): - if not cls.match_atomic(raw[index], val, **kwargs): + if not cls._match_atomic(raw[index], val, **kwargs): return False return True @classmethod - def match_atomic(cls, raw: Any, real: Any, **kwargs: Any) -> bool: + def _match_atomic(cls, raw: Any, real: Any, **kwargs: Any) -> bool: with warnings.catch_warnings(): warnings.simplefilter("ignore") cmp = converter.deserialize(raw, [type(real)], **kwargs) @@ -454,8 +704,24 @@ def match_atomic(cls, raw: Any, real: Any, **kwargs: Any) -> bool: class DateTimeBase(Converter, metaclass=abc.ABCMeta): + """An abstract datetime converter.""" + @classmethod def parse(cls, value: Any, **kwargs: Any) -> datetime: + """Parse a str into a datetime instance. + + Args: + value: The input string value + **kwargs: Additional keyword argument + format: The datetime format to use + + Returns: + The datetime instance + + Raises: + ConverterError: If no format was provided or the value + could not be converted. + """ try: return datetime.strptime(value, kwargs["format"]) except KeyError: @@ -464,6 +730,20 @@ def parse(cls, value: Any, **kwargs: Any) -> datetime: raise ConverterError(e) def serialize(self, value: Union[date, time], **kwargs: Any) -> str: + """Convert a datetime instance to string. + + Args: + value: The input datetime instance + **kwargs: Additional keyword argument + format: The datetime format to use + + Returns: + The converted str value. + + Raises: + ConverterError: If no format was provided or the value + could not be converted. + """ try: return value.strftime(kwargs["format"]) except KeyError: @@ -477,36 +757,100 @@ def deserialize(self, value: Any, **kwargs: Any) -> Any: class TimeConverter(DateTimeBase): + """A datetime.time converter.""" + def deserialize(self, value: Any, **kwargs: Any) -> time: + """Convert the input str to a time instance. + + Args: + value: The input string value + **kwargs: Additional keyword argument + format: The time format to use + + Returns: + The time instance + + Raises: + ConverterError: If no format was provided or the value + could not be converted. + """ return self.parse(value, **kwargs).time() class DateConverter(DateTimeBase): + """A datetime.date converter.""" + def deserialize(self, value: Any, **kwargs: Any) -> date: + """Convert the input str to a date instance. + + Args: + value: The input string value + **kwargs: Additional keyword argument + format: The time format to use + + Returns: + The date instance + + Raises: + ConverterError: If no format was provided or the value + could not be converted. + """ return self.parse(value, **kwargs).date() class DateTimeConverter(DateTimeBase): + """A datetime.datetime converter.""" + def deserialize(self, value: Any, **kwargs: Any) -> datetime: + """Convert the input str to a datetime instance. 
+ + Args: + value: The input string value + **kwargs: Additional keyword argument + format: The time format to use + + Returns: + The datetime instance + + Raises: + ConverterError: If no format was provided or the value + could not be converted. + """ return self.parse(value, **kwargs) class ProxyConverter(Converter): - __slots__ = ("factory",) + """Proxy wrapper to treat callables as converters. + + Args: + factory: The callable factory + """ + + __slots__ = "factory" def __init__(self, factory: Callable): - """ - :param factory: factory function used to parse string values - """ self.factory = factory def deserialize(self, value: Any, **kwargs: Any) -> Any: + """Call the instance factory and return the result. + + Args: + value: The input value to convert + **kwargs: Unused keyword arguments + + Returns: + The return result of the callable. + + Raises: + ConverterError: on value errors. + """ try: return self.factory(value) except ValueError as e: raise ConverterError(e) def serialize(self, value: Any, **kwargs: Any) -> str: + """Cast value to str.""" return str(value) diff --git a/xsdata/formats/dataclass/client.py b/xsdata/formats/dataclass/client.py index 7b82241a7..ec8325a28 100644 --- a/xsdata/formats/dataclass/client.py +++ b/xsdata/formats/dataclass/client.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, NamedTuple, Optional, Type +from typing import Any, Dict, NamedTuple, Optional, Type, Union from xsdata.exceptions import ClientValueError from xsdata.formats.dataclass.parsers import XmlParser @@ -9,15 +9,15 @@ class Config(NamedTuple): - """ - Service configuration class. - - :param style: binding style - :param location: service endpoint url - :param transport: transport namespace - :param soap_action: soap action - :param input: input object type - :param output: output object type + """Service configuration class. + + Args: + style: The binding style + location: The service endpoint url + transport: The transport namespace + soap_action: The soap action + input: The input class + output: The output class """ style: str @@ -30,7 +30,21 @@ class Config(NamedTuple): @classmethod def from_service(cls, obj: Any, **kwargs: Any) -> "Config": - """Instantiate from a generated service class.""" + """Instantiate from a generated service class. + + Args: + obj: The service class + **kwargs: Override the service class properties + style: The binding style + location: The service endpoint url + transport: The transport namespace + soap_action: The soap action + input: The input class + output: The output class + + Returns: + A new config instance. + """ params = { key: kwargs[key] if key in kwargs else getattr(obj, key, None) for key in cls._fields @@ -40,16 +54,20 @@ def from_service(cls, obj: Any, **kwargs: Any) -> "Config": class TransportTypes: + """Transport types.""" + SOAP = "http://schemas.xmlsoap.org/soap/http" @dataclass class Client: - """ - :param config: service configuration - :param transport: transport instance to handle requests - :param parser: xml parser instance to handle xml response parsing - :param serializer: xml serializer instance to handle xml response parsing + """A wsdl client. 
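+
+    A minimal usage sketch (``CalculatorService`` is a hypothetical
+    generated service class)::
+
+        client = Client.from_service(CalculatorService)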
+ + Args: + config: The service config instance + transport: The transport instance + parser: The xml parser instance + serializer: The xml serializer instance """ config: Config @@ -59,16 +77,26 @@ class Client: dict_converter: DictConverter = field(init=False, default_factory=DictConverter) @classmethod - def from_service(cls, obj: Type, **kwargs: str) -> "Client": - """Instantiate client from a service definition.""" + def from_service(cls, obj: Type, **kwargs: Any) -> "Client": + """Instantiate client from a service class. + + Args: + obj: The service class + **kwargs: Override the service class properties + style: The binding style + location: The service endpoint url + transport: The transport namespace + soap_action: The soap action + input: The input class + output: The output class + + Returns: + A new client instance. + """ return cls(config=Config.from_service(obj, **kwargs)) def send(self, obj: Any, headers: Optional[Dict] = None) -> Any: - """ - Send a request and parse the response according to the service - configuration. - - The input object can be a dictionary, or the input type instance directly + """Build and send a request for the input object. >>> params = {"body": {"add": {"int_a": 3, "int_b": 4}}} >>> res = client.send(params) @@ -79,8 +107,12 @@ def send(self, obj: Any, headers: Optional[Dict] = None) -> Any: >>> body=CalculatorSoapAddInput.Body(add=Add(3, 4))) >>> res = client.send(req) - :param obj: a params dictionary or the input type instance - :param headers: a dictionary of any additional headers. + Args: + obj: The request model instance or a pure dictionary + headers: Additional headers to pass to the transport + + Returns: + The response model instance. """ data = self.prepare_payload(obj) headers = self.prepare_headers(headers or {}) @@ -88,13 +120,13 @@ def send(self, obj: Any, headers: Optional[Dict] = None) -> Any: return self.parser.from_bytes(response, self.config.output) def prepare_headers(self, headers: Dict) -> Dict: - """ - Prepare request headers according to the service configuration. + """Prepare the request headers. - Don't mutate input headers dictionary. + It merges the custom user headers with the necessary headers + to accommodate the service class configuration. - :raises ClientValueError: If the service transport type is - unsupported. + Raises: + ClientValueError: If the service transport type is not supported. """ result = headers.copy() if self.config.transport == TransportTypes.SOAP: @@ -108,12 +140,20 @@ def prepare_headers(self, headers: Dict) -> Dict: return result - def prepare_payload(self, obj: Any) -> Any: - """ - Prepare and serialize payload to be sent. + def prepare_payload(self, obj: Any) -> Union[str, bytes]: + """Prepare and serialize the payload to be sent. + + If the obj is a pure dictionary, it will be converted + first to a request model instance. + + Args: + obj: The request model instance or a pure dictionary + + Returns: + The serialized request body content as string or bytes. - :raises ClientValueError: If the config input type doesn't match - the given input. + Raises: + ClientValueError: If the config input type doesn't match the given object. 
""" if isinstance(obj, Dict): obj = self.dict_converter.convert(obj, self.config.input) diff --git a/xsdata/formats/dataclass/compat.py b/xsdata/formats/dataclass/compat.py index 64c1d51eb..fb7425c3b 100644 --- a/xsdata/formats/dataclass/compat.py +++ b/xsdata/formats/dataclass/compat.py @@ -8,6 +8,8 @@ class FieldInfo(NamedTuple): + """A class field info wrapper.""" + name: str init: bool metadata: Dict[str, Any] @@ -16,6 +18,8 @@ class FieldInfo(NamedTuple): class ClassType(abc.ABC): + """An interface for class types like attrs, pydantic.""" + __slots__ = () @property @@ -44,10 +48,13 @@ def is_model(self, obj: Any) -> bool: @abc.abstractmethod def verify_model(self, obj: Any): - """ - Verify the given value is a binding model. + """Verify the given value is a binding model. - :raises xsdata.exceptions.XmlContextError: if not supported + Args: + obj: The input model instance + + Raises: + XmlContextError: if not supported """ @abc.abstractmethod @@ -60,17 +67,21 @@ def default_value(self, field: FieldInfo, default: Optional[Any] = None) -> Any: @abc.abstractmethod def default_choice_value(self, choice: Dict) -> Any: - """Return the default value or factory of the given model field - choice.""" + """Return the default value or factory of the given model field choice.""" def score_object(self, obj: Any) -> float: - """ - Score a binding model instance by its field values types. + """Score a binding model instance by its field values types. Weights: 1. None: 0 2. str: 1 3. *: 1.5 + + Args: + obj: The input object + + Returns: + The float score value. """ if not obj: return -1.0 @@ -93,40 +104,79 @@ def score(value: Any) -> float: class ClassTypes: + """A class types registry. + + Attributes: + types: A name-instance map of the registered class types + """ + __slots__ = "types" def __init__(self): self.types: Dict[str, ClassType] = {} def register(self, name: str, fmt: ClassType, **_: Any): + """Register a class type instance by name. + + Args: + name: The name of the class type + fmt: The class type instance + **_: No idea :( + """ self.types[name] = fmt def get_type(self, name: str) -> ClassType: + """Get a class type instance by name. + + Args: + name: The class type name + + Returns: + The class type instance + + Raises: + KeyError: If the name is not registed. + """ return self.types[name] class Dataclasses(ClassType): + """The dataclasses class type.""" + __slots__ = () @property def any_element(self) -> Type: + """Return the generic any element class.""" return AnyElement @property def derived_element(self) -> Type: + """Return the generic derived element class.""" return DerivedElement def is_model(self, obj: Any) -> bool: + """Return whether the obj is a dataclass model.""" return is_dataclass(obj) def verify_model(self, obj: Any): + """Validate whether the obj is a dataclass model. + + Args: + obj: The input object to validate. + + Raises: + XmlContextError: If it's not a dataclass model. 
+ """ if not self.is_model(obj): raise XmlContextError(f"Type '{obj}' is not a dataclass.") def get_fields(self, obj: Any) -> Iterator[FieldInfo]: + """Return a dataclass fields iterator.""" yield from cast(List[FieldInfo], fields(obj)) def default_value(self, field: FieldInfo, default: Optional[Any] = None) -> Any: + """Return the default value or factory of the given model field.""" if field.default_factory is not MISSING: return field.default_factory @@ -136,6 +186,7 @@ def default_value(self, field: FieldInfo, default: Optional[Any] = None) -> Any: return default def default_choice_value(self, choice: Dict) -> Any: + """Return the default value or factory of the given model field choice.""" factory = choice.get("default_factory") if callable(factory): return factory diff --git a/xsdata/formats/dataclass/context.py b/xsdata/formats/dataclass/context.py index a0b501440..c54ac88e4 100644 --- a/xsdata/formats/dataclass/context.py +++ b/xsdata/formats/dataclass/context.py @@ -1,6 +1,6 @@ import sys from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Set, Type +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type from xsdata.exceptions import XmlContextError from xsdata.formats.bindings import T @@ -12,23 +12,31 @@ class XmlContext: - """ - The service provider for binding operations' metadata. + """The models context class. + + The context is responsible to provide binding metadata + for models and their fields. + + Args: + element_name_generator: Default element name generator + attribute_name_generator: Default attribute name generator + class_type: Default class type `dataclasses` + models_package: Restrict auto locate to a specific package - :param element_name_generator: Default element name generator - :param attribute_name_generator: Default attribute name generator - :param class_type: Default class type `dataclasses` - :param models_package: Restrict auto locate to a specific package + Attributes: + cache: Internal cache for binding metadata instances + xsi_cache: Internal cache for xsi types to class locations + sys_modules: The number of loaded sys modules """ __slots__ = ( "element_name_generator", "attribute_name_generator", "class_type", + "models_package", "cache", "xsi_cache", "sys_modules", - "models_package", ) def __init__( @@ -48,13 +56,16 @@ def __init__( self.sys_modules = 0 def reset(self): + """Reset all internal caches.""" self.cache.clear() self.xsi_cache.clear() self.sys_modules = 0 def get_builder( - self, globalns: Optional[Dict[str, Callable]] = None + self, + globalns: Optional[Dict[str, Callable]] = None, ) -> XmlMetaBuilder: + """Return a new xml meta builder instance.""" return XmlMetaBuilder( class_type=self.class_type, element_name_generator=self.element_name_generator, @@ -68,15 +79,17 @@ def fetch( parent_ns: Optional[str] = None, xsi_type: Optional[str] = None, ) -> XmlMeta: - """ - Fetch the model metadata of the given dataclass type, namespace and xsi - type. - - :param clazz: The requested dataclass type - :param parent_ns: The inherited parent namespace - :param xsi_type: if present it means that the given clazz is - derived and the lookup procedure needs to check and match a - dataclass model to the qualified name instead + """Build the model metadata for the given class. 
+
+        Args:
+            clazz: The requested dataclass type
+            parent_ns: The inherited parent namespace
+            xsi_type: If present, it means the given clazz is
+                derived and the lookup procedure needs to check and match a
+                dataclass model to the qualified name instead.
+
+        Returns:
+            An xml meta instance
         """
         meta = self.build(clazz, parent_ns)
         subclass = None
@@ -86,7 +99,7 @@
         return self.build(subclass, parent_ns) if subclass else meta
 
     def build_xsi_cache(self):
-        """Index all imported dataclasses by their xsi:type qualified name."""
+        """Index all imported data classes by their xsi:type qualified name."""
         if len(sys.modules) == self.sys_modules:
             return
@@ -102,6 +115,17 @@
         self.sys_modules = len(sys.modules)
 
     def is_binding_model(self, clazz: Type[T]) -> bool:
+        """Return whether the clazz is a binding model.
+
+        If the models package is not empty, also validate that
+        the class is located within that package.
+
+        Args:
+            clazz: The class type to inspect
+
+        Returns:
+            The bool result.
+        """
         if not self.class_type.is_model(clazz):
             return False
 
@@ -112,13 +136,16 @@
         )
 
     def find_types(self, qname: str) -> List[Type[T]]:
-        """
-        Find all classes that match the given xsi:type qname.
+        """Find all classes that match the given xsi:type qname.
 
         - Ignores native schema types, xs:string, xs:float, xs:int, ...
         - Rebuild cache if new modules were imported since last run
 
-        :param qname: Qualified name
+        Args:
+            qname: A namespace qualified name
+
+        Returns:
+            A list of the matched classes.
         """
         if not DataType.from_qname(qname):
             self.build_xsi_cache()
@@ -128,21 +155,27 @@
             return []
 
     def find_type(self, qname: str) -> Optional[Type[T]]:
-        """
-        Return the most recently imported class that matches the given xsi:type
-        qname.
+        """Return the last imported class that matches the given xsi:type qname.
+
+        Args:
+            qname: A namespace qualified name
 
-        :param qname: Qualified name
+        Returns:
+            A class type or None if no matches.
         """
         types: List[Type] = self.find_types(qname)
         return types[-1] if types else None
 
     def find_type_by_fields(self, field_names: Set[str]) -> Optional[Type[T]]:
-        """
-        Find a dataclass from all the imported modules that matches the given
-        list of field names.
+        """Find the data class that best matches the given list of field names.
+
+        Args:
+            field_names: A set of field names
 
-        :param field_names: A unique list of field names
+        Returns:
+            The best matching class or None if no matches. The class must
+            have all the fields. If more than one class has all the given
+            fields, return the one with the fewest extra fields.
         """
 
         def get_field_diff(clazz: Type) -> int:
@@ -162,18 +195,23 @@ def get_field_diff(clazz: Type) -> int:
         return choices[0][0] if len(choices) > 0 else None
 
     def find_subclass(self, clazz: Type, qname: str) -> Optional[Type]:
-        """
+        """Find a subclass for the given clazz and xsi:type qname.
+
         Compare all classes that match the given xsi:type qname and return the
         first one that is either a subclass or shares the same parent class as
         the original class.
 
-        :param clazz: The search dataclass type
-        :param qname: Qualified name
+        Args:
+            clazz: The input class type
+            qname: The xsi:type to lookup from cache
+
+        Returns:
+            The matching class type or None if no matches.
""" types: List[Type] = self.find_types(qname) for tp in types: - # Why would an xml node with have an xsi:type that points - # to parent class is beyond me but it happens, let's protect + # Why would a xml node with have a xsi:type that points + # to parent class is beyond me, but it happens, let's protect # against that scenario if issubclass(clazz, tp): continue @@ -190,12 +228,15 @@ def build( parent_ns: Optional[str] = None, globalns: Optional[Dict[str, Callable]] = None, ) -> XmlMeta: - """ - Fetch from cache or build the binding metadata for the given class and - parent namespace. + """Fetch or build the binding metadata for the given class. + + Args: + clazz: A class type + parent_ns: The inherited parent namespace + globalns: Override the global python namespace - :param clazz: A dataclass type - :param parent_ns: The inherited parent namespace + Returns: + The class binding metadata instance. """ if clazz not in self.cache: builder = self.get_builder(globalns) @@ -203,8 +244,14 @@ def build( return self.cache[clazz] def build_recursive(self, clazz: Type, parent_ns: Optional[str] = None): - """Build the binding metadata for the given class and all of its - dependencies.""" + """Build the binding metadata for the given class and all of its dependencies. + + This method is used in benchmarks! + + Args: + clazz: The class type + parent_ns: The inherited parent namespace + """ if clazz not in self.cache: meta = self.build(clazz, parent_ns) for var in meta.get_all_vars(): @@ -214,6 +261,18 @@ def build_recursive(self, clazz: Type, parent_ns: Optional[str] = None): self.build_recursive(tp, meta.namespace) def local_names_match(self, names: Set[str], clazz: Type) -> bool: + """Check if the given field names match the given class type. + + Silently ignore, typing errors. These classes are from third + party libraries most of them time. + + Args: + names: A set of field names + clazz: The class type to inspect + + Returns: + Whether the class contains all the field names. + """ try: meta = self.build(clazz) local_names = {var.local_name for var in meta.get_all_vars()} @@ -230,12 +289,7 @@ def local_names_match(self, names: Set[str], clazz: Type) -> bool: @classmethod def is_derived(cls, obj: Any, clazz: Type) -> bool: - """ - Return whether the given obj is derived from the given dataclass type. 
- - :param obj: A dataclass instance - :param clazz: A dataclass type - """ + """Return whether the obj is a subclass or a parent of the given class type.""" if obj is None: return False @@ -245,7 +299,8 @@ def is_derived(cls, obj: Any, clazz: Type) -> bool: return any(x is not object and isinstance(obj, x) for x in clazz.__bases__) @classmethod - def get_subclasses(cls, clazz: Type): + def get_subclasses(cls, clazz: Type) -> Iterator[Type]: + """Return an iterator of the given class subclasses.""" try: for subclass in clazz.__subclasses__(): yield from cls.get_subclasses(subclass) diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index 97be9b903..e582023ff 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -2,7 +2,18 @@ import sys import textwrap from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Type, +) from docformatter import configuration, format from jinja2 import Environment @@ -23,6 +34,8 @@ class Filters: + """Jinja filters for code generation.""" + DEFAULT_KEY = "default" FACTORY_KEY = "default_factory" UNESCAPED_DBL_QUOTE_REGEX = re.compile(r"([^\\])\"") @@ -104,6 +117,7 @@ def __init__(self, config: GeneratorConfig): self.default_class_annotation = self.build_class_annotation(self.format) def register(self, env: Environment): + """Register the template filters to the jinja environment.""" env.globals.update( { "docstring_name": self.docstring_style.name.lower(), @@ -137,6 +151,7 @@ def register(self, env: Environment): @classmethod def build_class_annotation(cls, fmt: OutputFormat) -> str: + """Build the class annotations.""" args = [] if not fmt.repr: args.append("repr=False") @@ -155,7 +170,8 @@ def build_class_annotation(cls, fmt: OutputFormat) -> str: return f"@dataclass({', '.join(args)})" if args else "@dataclass" - def class_params(self, obj: Class): + def class_params(self, obj: Class) -> Iterator[Tuple[str, str]]: + """Yield the class variables with their docstring text.""" is_enum = obj.is_enumeration for attr in obj.attrs: name = attr.name @@ -166,8 +182,19 @@ def class_params(self, obj: Class): yield self.field_name(name, obj.name), docstring def class_name(self, name: str) -> str: - """Convert the given string to a class name according to the selected - conventions or use an existing alias.""" + """Class name filter. 
+ + Steps: + - Apply substitutions before naming conventions + - Apply naming convention + - Apply substitutions after naming conventions + + Args: + name: The original class name + + Returns: + The final class name + """ name = self.apply_substitutions(name, ObjectType.CLASS) name = self.safe_name(name, self.class_safe_prefix, self.class_case) return self.apply_substitutions(name, ObjectType.CLASS) @@ -207,6 +234,7 @@ def class_annotations(self, obj: Class, class_name: str) -> List[str]: return collections.unique_sequence(annotations) def apply_substitutions(self, name: str, obj_type: ObjectType) -> str: + """Apply name substitutions by obj type.""" for search, replace in self.substitutions[obj_type].items(): name = re.sub(rf"{search}", rf"{replace}", name) @@ -237,11 +265,19 @@ def field_definition( return f"field({self.format_arguments(kwargs, 4)})" def field_name(self, name: str, class_name: str) -> str: - """ - Convert the given name to a field name according to the selected - conventions or use an existing alias. + """Field name filter. + + Steps: + - Apply substitutions before naming conventions + - Apply naming convention + - Apply substitutions after naming conventions + + Args: + name: The original field name + class_name: The class name, some naming conventions require it - Provide the class name as context for the naming schemes. + Returns: + The final field name """ prefix = self.field_safe_prefix name = self.apply_substitutions(name, ObjectType.FIELD) @@ -249,11 +285,19 @@ def field_name(self, name: str, class_name: str) -> str: return self.apply_substitutions(name, ObjectType.FIELD) def constant_name(self, name: str, class_name: str) -> str: - """ - Convert the given name to a constant name according to the selected - conventions or use an existing alias. + """Constant name filter. + + Steps: + - Apply substitutions before naming conventions + - Apply naming convention + - Apply substitutions after naming conventions - Provide the class name as context for the naming schemes. + Args: + name: The original constant name + class_name: The class name, some naming conventions require it + + Returns: + The final constant name """ prefix = self.field_safe_prefix name = self.apply_substitutions(name, ObjectType.FIELD) @@ -261,17 +305,38 @@ def constant_name(self, name: str, class_name: str) -> str: return self.apply_substitutions(name, ObjectType.FIELD) def module_name(self, name: str) -> str: - """Convert the given string to a module name according to the selected - conventions or use an existing alias.""" + """Module name filter. + + Steps: + - Apply substitutions before naming conventions + - Apply naming convention + - Apply substitutions after naming conventions + + Args: + name: The original module name + + Returns: + The final module name + """ prefix = self.module_safe_prefix name = self.apply_substitutions(name, ObjectType.MODULE) name = self.safe_name(namespaces.clean_uri(name), prefix, self.module_case) return self.apply_substitutions(name, ObjectType.MODULE) def package_name(self, name: str) -> str: - """Convert the given string to a package name according to the selected - conventions or use an existing alias.""" + """Package name filter. 
+ + Steps: + - Apply substitutions before naming conventions + - Apply naming convention + - Apply substitutions after naming conventions + Args: + name: The original package name + + Returns: + The final package name + """ name = self.apply_substitutions(name, ObjectType.PACKAGE) if not name: @@ -287,7 +352,14 @@ def process_sub_package(pck: str) -> str: return self.apply_substitutions(name, ObjectType.PACKAGE) def type_name(self, attr_type: AttrType) -> str: - """Return native python type name or apply class name conventions.""" + """Field type filter. + + Args: + attr_type: The attr type instance. + + Returns: + The python type name or the user type final name + """ datatype = attr_type.datatype if datatype: return datatype.type.__name__ @@ -295,7 +367,11 @@ def type_name(self, attr_type: AttrType) -> str: return self.class_name(attr_type.alias or attr_type.name) def safe_name( - self, name: str, prefix: str, name_case: Callable, **kwargs: Any + self, + name: str, + prefix: str, + name_case: Callable, + **kwargs: Any, ) -> str: """Sanitize names for safe generation.""" if not name: @@ -339,15 +415,16 @@ def import_class(self, name: str, alias: Optional[str]) -> str: return self.class_name(name) def post_meta_hook(self, obj: Class) -> Optional[str]: - """Plugin hook to render additional information after the xsdata meta - class.""" + """Plugin hook to render additional information after the xsdata meta class.""" return None def field_metadata( - self, attr: Attr, parent_namespace: Optional[str], parents: List[str] + self, + attr: Attr, + parent_namespace: Optional[str], + parents: List[str], ) -> Dict: """Return a metadata dictionary for the given attribute.""" - if attr.is_prohibited: return {"type": XmlType.IGNORE} @@ -380,13 +457,7 @@ def field_metadata( def field_choices( self, attr: Attr, parent_namespace: Optional[str], parents: List[str] ) -> Optional[Tuple]: - """ - Return a list of metadata dictionaries for the choices of the given - attribute. - - Return None if attribute has no choices. - """ - + """Return a tuple of field metadata if the attr has choices.""" if not attr.choices: return None @@ -421,6 +492,7 @@ def field_choices( @classmethod def filter_metadata(cls, data: Dict) -> Dict: + """Filter out false,none keys from the given dict.""" return { key: value for key, value in data.items() @@ -473,8 +545,7 @@ def format_iterable(self, data: Iterable, indent: int) -> str: return wrap.format("\n".join(lines), ind) def format_string(self, data: str, indent: int, key: str = "", pad: int = 0) -> str: - """ - Return a pretty string representation of a string. + """Return a pretty string representation of a string. If the total length of the input string plus indent plus the key length and the additional pad is more than the max line length, @@ -538,16 +609,19 @@ def text_wrap( @classmethod def clean_docstring(cls, string: Optional[str], escape: bool = True) -> str: - """ - Prepare string for docstring generation. + """Prepare string for docstring generation. - Strip whitespace from each line - Replace triple double quotes with single quotes - Escape backslashes - :param string: input value - :param escape: skip backslashes escape, if string is going to - pass through formatting. + Args: + string: input value + escape: skip backslashes escape, if string is going to + pass through formatting. + + Returns: + The cleaned docstring text. 
""" if not string: return "" @@ -626,6 +700,7 @@ def field_default_value(self, attr: Attr, ns_map: Optional[Dict] = None) -> Any: ) def field_default_enum(self, attr: Attr) -> str: + """Generate the default value for enum fields.""" assert attr.default is not None qname, reference = attr.default[6:].split("::", 1) @@ -645,6 +720,7 @@ def field_default_enum(self, attr: Attr) -> str: def field_default_tokens( self, attr: Attr, types: List[Type], ns_map: Optional[Dict] ) -> str: + """Generate the default value for tokens fields.""" assert isinstance(attr.default, str) fmt = attr.restrictions.format @@ -660,14 +736,13 @@ def field_default_tokens( return f"lambda: {self.format_metadata(tokens, indent=8)}" def field_type(self, attr: Attr, parents: List[str]) -> str: - """Generate type hints for the given attribute.""" - + """Generate type hints for the given attr.""" if attr.is_prohibited: return "Any" - result = self.field_type_names(attr, parents, choice=False) + result = self._field_type_names(attr, parents, choice=False) - iterable_fmt = self.get_iterable_format() + iterable_fmt = self._get_iterable_format() if attr.is_tokens: result = iterable_fmt.format(result) @@ -688,8 +763,7 @@ def field_type(self, attr: Attr, parents: List[str]) -> str: return result def choice_type(self, choice: Attr, parents: List[str]) -> str: - """ - Generate type hints for the given choice. + """Generate type hints for the given choice. Choices support a subset of features from normal attributes. First of all we don't have a proper type hint but a type @@ -697,12 +771,18 @@ def choice_type(self, choice: Attr, parents: List[str]) -> str: The second big difference is that our choice belongs to a compound field that might be a list, that's why list restriction is also ignored. - """ - result = self.field_type_names(choice, parents, choice=True) + Args: + choice: The choice instance + parents: A list of the parent class names + + Returns: + The string representation of the type hint. 
+ """ + result = self._field_type_names(choice, parents, choice=True) if choice.is_tokens: - iterable_fmt = self.get_iterable_format() + iterable_fmt = self._get_iterable_format() result = iterable_fmt.format(result) if self.subscriptable_types: @@ -710,15 +790,18 @@ def choice_type(self, choice: Attr, parents: List[str]) -> str: return f"Type[{result}]" - def field_type_names( - self, attr: Attr, parents: List[str], choice: bool = False + def _field_type_names( + self, + attr: Attr, + parents: List[str], + choice: bool = False, ) -> str: type_names = [ - self.field_type_name(x, parents, choice=choice) for x in attr.types + self._field_type_name(x, parents, choice=choice) for x in attr.types ] - return self.join_type_names(type_names) + return self._join_type_names(type_names) - def join_type_names(self, type_names: List[str]) -> str: + def _join_type_names(self, type_names: List[str]) -> str: type_names = collections.unique_sequence(type_names) if len(type_names) == 1: return type_names[0] @@ -728,7 +811,7 @@ def join_type_names(self, type_names: List[str]) -> str: return f'Union[{", ".join(type_names)}]' - def field_type_name( + def _field_type_name( self, attr_type: AttrType, parents: List[str], choice: bool = False ) -> str: name = self.type_name(attr_type) @@ -780,12 +863,13 @@ def default_imports(self, output: str) -> str: return "\n".join(imports) - def get_iterable_format(self): + def _get_iterable_format(self): fmt = "Tuple[{}, ...]" if self.format.frozen else "List[{}]" return fmt.lower() if self.subscriptable_types else fmt @classmethod def build_import_patterns(cls) -> Dict[str, Dict]: + """Build import search patterns.""" type_patterns = cls.build_type_patterns return { "dataclasses": {"dataclass": ["@dataclass"], "field": [" = field("]}, @@ -812,6 +896,7 @@ def build_import_patterns(cls) -> Dict[str, Dict]: @classmethod def build_type_patterns(cls, x: str) -> Tuple: + """Return all possible type occurrences in the generated code.""" return ( f": {x} =", f"[{x}]", diff --git a/xsdata/formats/dataclass/generator.py b/xsdata/formats/dataclass/generator.py index 540c1f355..ffd9ea48f 100644 --- a/xsdata/formats/dataclass/generator.py +++ b/xsdata/formats/dataclass/generator.py @@ -11,7 +11,15 @@ class DataclassGenerator(AbstractGenerator): - """Python dataclasses code generator.""" + """Python dataclasses code generator. + + Args: + config: The generator config instance + + Attributes: + env: The jinja2 environment instance + filters: The template filters instance + """ __slots__ = ("env", "filters") @@ -22,9 +30,6 @@ class DataclassGenerator(AbstractGenerator): class_template = "class.jinja2" def __init__(self, config: GeneratorConfig): - """Override generator constructor to set templates directory and - environment filters.""" - super().__init__(config) template_paths = self.get_template_paths() loader = FileSystemLoader(template_paths) @@ -34,17 +39,20 @@ def __init__(self, config: GeneratorConfig): @classmethod def get_template_paths(cls) -> List[str]: + """Return a list of template paths to feed the jinja2 loader.""" return [str(Path(__file__).parent.joinpath("templates"))] def render(self, classes: List[Class]) -> Iterator[GeneratorResult]: - """ - Return an iterator of the generated results. + """Render the given classes to python packages and modules. - Group classes into modules and yield an output per module and - per path __init__.py file. + Args: + classes: A list of class instances + + Yields: + An iterator of generator result instances. 
""" packages = {obj.qname: obj.target_module for obj in classes} - resolver = DependenciesResolver(packages=packages) + resolver = DependenciesResolver(registry=packages) # Generate packages for path, cluster in self.group_by_package(classes).items(): @@ -65,8 +73,15 @@ def render(self, classes: List[Class]) -> Iterator[GeneratorResult]: ) def render_package(self, classes: List[Class], module: str) -> str: - """Render the source code for the __init__.py with all the imports of - the generated class names.""" + """Render the package for the given classes. + + Args: + classes: A list of class instances + module: The target dot notation path + + Returns: + The rendered package output. + """ imports = [ Import(qname=obj.qname, source=obj.target_module) for obj in sorted(classes, key=lambda x: x.name) @@ -80,11 +95,19 @@ def render_package(self, classes: List[Class], module: str) -> str: return f"{output.strip()}\n" def render_module( - self, resolver: DependenciesResolver, classes: List[Class] + self, + resolver: DependenciesResolver, + classes: List[Class], ) -> str: - """Render the source code for the target module of the given class - list.""" + """Render the module for the given classes. + Args: + resolver: The dependencies resolver + classes: A list of class instances + + Returns: + The rendered module output. + """ if len({x.target_namespace for x in classes}) == 1: module_namespace = classes[0].target_namespace else: @@ -105,9 +128,19 @@ def render_module( ) def render_classes( - self, classes: List[Class], module_namespace: Optional[str] + self, + classes: List[Class], + module_namespace: Optional[str], ) -> str: - """Render the source code of the classes.""" + """Render the classes source code in a module. + + Args: + classes: A list of class instances + module_namespace: The module namespace URI + + Returns: + The rendered classes source code output. + """ def render_class(obj: Class) -> str: """Render class or enumeration.""" @@ -139,9 +172,14 @@ def package_name(self, name: str) -> str: @classmethod def ensure_packages(cls, package: Path) -> Iterator[GeneratorResult]: - """Ensure all the __init__ files exists for the target package path, - otherwise yield the necessary filepath, name, source output that needs - to be created.""" + """Ensure __init__.py files exists recursively in the package. + + Args: + package: The package file path + + Yields: + An iterator of generator result instances. + """ cwd = Path.cwd() while cwd < package: init = package.joinpath("__init__.py") @@ -153,4 +191,5 @@ def ensure_packages(cls, package: Path) -> Iterator[GeneratorResult]: @classmethod def init_filters(cls, config: GeneratorConfig) -> Filters: + """Initialize the filters instance by the generator configuration.""" return Filters(config) diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index ddc4955f6..d7e804f05 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -8,7 +8,6 @@ Iterator, List, Mapping, - NamedTuple, Optional, Sequence, Set, @@ -28,17 +27,58 @@ from xsdata.utils.namespaces import build_qname -class ClassMeta(NamedTuple): - element_name_generator: Callable - attribute_name_generator: Callable - qname: str - local_name: str - nillable: bool - namespace: Optional[str] - target_qname: Optional[str] +class ClassMeta: + """The binding model combined metadata. 
+ + Args: + element_name_generator: The element name generator + attribute_name_generator: The attribute name generator + qname: The namespace qualified name of the class + local_name: The name of the element this class represents + nillable: Specifies whether this class supports nillable content + namespace: The class namespace + target_qname: The class target namespace qualified name + """ + + __slots__ = ( + "element_name_generator", + "attribute_name_generator", + "qname", + "local_name", + "nillable", + "namespace", + "target_qname", + ) + + def __init__( + self, + element_name_generator: Callable, + attribute_name_generator: Callable, + qname: str, + local_name: str, + nillable: bool, + namespace: Optional[str], + target_qname: Optional[str], + ): + self.element_name_generator = element_name_generator + self.attribute_name_generator = attribute_name_generator + self.qname = qname + self.local_name = local_name + self.nillable = nillable + self.namespace = namespace + self.target_qname = target_qname class XmlMetaBuilder: + """Binding class metadata builder. + + Args: + class_type: The supported class type, e.g. dataclass, attr, pydantic + element_name_generator: The default element name generator + attribute_name_generator: The default attribute name generator + globalns: The global namespace + """ + __slots__ = ( "class_type", "element_name_generator", @@ -59,7 +99,15 @@ def __init__( self.globalns = globalns def build(self, clazz: Type, parent_namespace: Optional[str]) -> XmlMeta: - """Build the binding metadata for a dataclass and its fields.""" + """Build the binding metadata for a dataclass and its fields. + + Args: + clazz: The target class + parent_namespace: The parent class namespace + + Returns: + The binding metadata instance. + """ self.class_type.verify_model(clazz) meta = self.build_class_meta(clazz, parent_namespace) @@ -114,8 +162,18 @@ def build_vars( namespace: Optional[str], element_name_generator: Callable, attribute_name_generator: Callable, - ): - """Build the binding metadata for the given dataclass fields.""" + ) -> Iterator[XmlVar]: + """Build the binding metadata for the given dataclass fields. + + Args: + clazz: The target class + namespace: The target class namespace + element_name_generator: The class element name generator + attribute_name_generator: The class attribute name generator + + Yields: + An iterator of the field binding metadata instances. + """ type_hints = get_type_hints(clazz, globalns=self.globalns) builder = XmlVarBuilder( class_type=self.class_type, @@ -144,12 +202,20 @@ def build_vars( yield var def build_class_meta( - self, clazz: Type, parent_namespace: Optional[str] = None + self, + clazz: Type, + parent_namespace: Optional[str] = None, ) -> ClassMeta: - """ - Fetch the class meta options and merge defaults. + """Build the class meta options and merge with the defaults. + + The class metaclass is not inheritable. + + Args: + clazz: The target class + parent_namespace: The parent class namespace - Metaclass is not inheritable + Returns: + A class meta instance. """ meta = clazz.Meta if "Meta" in clazz.__dict__ else None element_name_generator = getattr( @@ -184,6 +250,10 @@ def build_class_meta( @classmethod def find_declared_class(cls, clazz: Type, name: str) -> Type: + """Find the user class that matches the name. + + Todo: Honestly I have no idea why we needed this. 
+ """ for base in clazz.__mro__: ann = base.__dict__.get("__annotations__") if ann and name in ann: @@ -210,8 +280,14 @@ def target_namespace(cls, module: Any, meta: Any) -> Optional[str]: return getattr(meta, "namespace", None) def default_xml_type(self, clazz: Type) -> str: - """Return the default xml type for the fields of the given dataclass - with an undefined type.""" + """Return the default xml type for the fields of the given dataclass. + + If a class has fields with no xml type defined, attempt + to figure it from the rest of the fields. It's either + a text or an element field. + + # Todo hacks like this are so unnecessary... + """ counters: Dict[str, int] = defaultdict(int) for var in self.class_type.get_fields(clazz): xml_type = var.metadata.get("type") @@ -229,6 +305,18 @@ def default_xml_type(self, clazz: Type) -> str: class XmlVarBuilder: + """Binding class field metadata builder. + + Args: + class_type: The supported class type, e.g. dataclass, attr, pydantic + default_xml_type: The default xml type of this class fields + element_name_generator: The element name generator + attribute_name_generator: The attribute name generator + + Attributes: + index: The index of the next var + """ + __slots__ = ( "index", "class_type", @@ -261,7 +349,21 @@ def build( globalns: Any, factory: Optional[Callable] = None, ) -> Optional[XmlVar]: - """Build the binding metadata for a dataclass field.""" + """Build the binding metadata for a class field. + + Args: + name: The field name + type_hint: The typing annotations of the field + metadata: The field metadata mapping + init: Specify whether this field can be initialized + parent_namespace: The class namespace + default_value: The field default value or factory + globalns: Python's global namespace + factory: The value factory + + Returns: + The field binding metadata instance. + """ xml_type = metadata.get("type", self.default_xml_type) if xml_type == XmlType.IGNORE: return None @@ -292,7 +394,7 @@ def build( f"a wrapper requires a collection type on attribute {name}" ) - local_name = self.build_local_name(xml_type, local_name, name) + local_name = local_name or self.build_local_name(xml_type, name) if tokens and sub_origin is None: sub_origin = origin @@ -354,7 +456,18 @@ def build_choices( globalns: Any, parent_namespace: Optional[str], ) -> Iterator[XmlVar]: - """Build the binding metadata for a compound dataclass field.""" + """Build the binding metadata for a compound dataclass field. + + Args: + name: The compound field name + choices: The list of choice metadata + factory: The compound field values factory + globalns: Python's global namespace + parent_namespace: The class namespace + + Yields: + An iterator of field choice binding metadata instance. + """ existing_types: Set[type] = set() for choice in choices: @@ -390,18 +503,20 @@ def build_choices( yield var - def build_local_name( - self, xml_type: str, local_name: Optional[str], name: str - ) -> str: - """Build a local name based on the field name and xml type if it's not - set.""" - if not local_name: - if xml_type == XmlType.ATTRIBUTE: - return self.attribute_name_generator(name) + def build_local_name(self, xml_type: str, name: str) -> str: + """Transform the name for serialization by the target xml type. - return self.element_name_generator(name) + Args: + xml_type: The xml type: element, attribute, ... + name: The field name - return local_name + Returns: + The name to use for serialization. 
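+
+        Example:
+            Illustrative only, assuming a camel case element name generator::
+
+                build_local_name(XmlType.ELEMENT, "first_name")  # -> "firstName"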
+ """ + if xml_type == XmlType.ATTRIBUTE: + return self.attribute_name_generator(name) + + return self.element_name_generator(name) @classmethod def resolve_namespaces( @@ -410,9 +525,7 @@ def resolve_namespaces( namespace: Optional[str], parent_namespace: Optional[str], ) -> Tuple[str, ...]: - """ - Resolve the namespace(s) for the given xml type and the parent - namespace. + """Resolve a fields supported namespaces. Only elements and wildcards are allowed to inherit the parent namespace if the given namespace is empty. @@ -420,10 +533,13 @@ def resolve_namespaces( In case of wildcard try to decode the ##any, ##other, ##local, ##target. - :param xml_type: The xml type - (Text|Element(s)|Attribute(s)|Wildcard) - :param namespace: The field namespace - :param parent_namespace: The parent namespace + Args: + xml_type: The xml type (Text|Element(s)|Attribute(s)|Wildcard) + namespace: The field namespace + parent_namespace: The parent namespace + + Returns: + A tuple of supported namespaces. """ if xml_type in (XmlType.ELEMENT, XmlType.WILDCARD) and namespace is None: namespace = parent_namespace @@ -446,12 +562,15 @@ def resolve_namespaces( @classmethod def default_namespace(cls, namespaces: Sequence[str]) -> Optional[str]: - """ - Return the first valid namespace uri or None. + """Return the first valid namespace uri or None. + + Args: + namespaces: A list of namespace options which may include + valid uri(s) or a placeholder e.g. ##any, ##other, + ##targetNamespace, ##local - :param namespaces: A list of namespace options which may include - valid uri(s) or one of the ##any, ##other, - ##targetNamespace, ##local + Returns: + A namespace uri or None if there isn't any. """ for namespace in namespaces: if namespace and not namespace.startswith("#"): @@ -461,7 +580,7 @@ def default_namespace(cls, namespaces: Sequence[str]) -> Optional[str]: @classmethod def is_any_type(cls, types: Sequence[Type], xml_type: str) -> bool: - """Return whether the given xml type supports derived values.""" + """Return whether the given xml type supports generic values.""" if xml_type in (XmlType.ELEMENT, XmlType.ELEMENTS): return object in types @@ -471,15 +590,15 @@ def is_any_type(cls, types: Sequence[Type], xml_type: str) -> bool: def analyze_types( cls, type_hint: Any, globalns: Any ) -> Tuple[Any, Any, Tuple[Type, ...]]: - """ - Analyze a type hint and return the origin, sub origin and the type - args. + """Analyze a type hint and return the origin, sub origin and the type args. The only case we support a sub origin is for fields derived from xs:NMTOKENS! - :raises XmlContextError: if the typing is not supported for - binding + # Todo please rewrite this in a way that makes sense :( + + Raises: + XmlContextError: if the typing is not supported for binding """ try: types = evaluate(type_hint, globalns) @@ -510,7 +629,6 @@ def is_valid( init: bool, ) -> bool: """Validate the given xml type against common unsupported cases.""" - if not init: # Ignore init==false vars return True @@ -531,7 +649,7 @@ def is_valid( return self.is_typing_supported(types) def is_typing_supported(self, types: Sequence[Type]) -> bool: - # Validate all types are registered in the converter. 
+ """Validate all types are registered in the converter.""" for tp in types: if ( not self.class_type.is_model(tp) diff --git a/xsdata/formats/dataclass/models/elements.py b/xsdata/formats/dataclass/models/elements.py index 2aa2a236c..6179a91b9 100644 --- a/xsdata/formats/dataclass/models/elements.py +++ b/xsdata/formats/dataclass/models/elements.py @@ -41,41 +41,59 @@ class MetaMixin: __slots__: Tuple[str, ...] = () def __eq__(self, other: Any) -> bool: + """Implement equality operator.""" return tuple(self) == tuple(other) def __iter__(self) -> Iterator: + """Implement iteration.""" for name in self.__slots__: yield getattr(self, name) def __repr__(self) -> str: + """Implement representation.""" params = (f"{name}={getattr(self, name)!r}" for name in self.__slots__) return f"{self.__class__.__qualname__}({', '.join(params)})" class XmlVar(MetaMixin): - """ - Class field binding metadata. - - :param index: Field ordering - :param name: Field name - :param qname: Qualified name - :param types: List of all the supported data types - :param init: Include field in the constructor - :param mixed: Field supports mixed content type values - :param tokens: Field is derived from xs:list - :param format: Value format information - :param derived: Wrap parsed values with a generic type - :param any_type: Field supports dynamic value types - :param required: Field is mandatory - :param nillable: Field supports nillable content - :param sequence: Render values in sequential mode - :param list_element: Field is a list of elements - :param default: Field default value or factory - :param xml_Type: Field xml type - :param namespaces: List of the supported namespaces - :param elements: Mapping of qname-repeatable elements - :param wildcards: List of repeatable wildcards - :param wrapper: A name for the wrapper. Applies for list types only. + """Class field binding metadata. + + Args: + index: Position index of the variable + name: Name of the variable + qname: Namespace-qualified name of the variable + types: Supported types for the variable + clazz: Target class type + init: Indicates if the field should be included in the constructor + mixed: Indicates if the field supports mixed content type values. + factory: Callable factory for lists + tokens_factory: Callable factory for tokens + format: Information about the value format + derived: Indicates whether parsed values should be wrapped with a generic type + any_type: Indicates if the field supports dynamic value types + process_contents: Information about processing contents + required: Indicates if the field is mandatory + nillable: Indicates if the field supports nillable content + sequence: Specifies rendering values in sequential mode + default: Default value or factory for the field + xml_type: Type of the XML field (element, attribute, etc.) 
+ namespaces: List of supported namespaces + elements: Mapping of qualified name-repeatable elements + wildcards: List of repeatable wildcards + wrapper: Name for the wrapper (applies for list types only) + + Attributes: + tokens: Indicates if the field has associated tokens + list_element: Indicates if the field is a list or tuple element + namespace_matches: Matching namespaces information + is_clazz_union: Indicates if the field is a union of multiple types + local_name: Local name extracted from the qualified name + is_text: Indicates if the field represents text content + is_element: Indicates if the field represents an XML element + is_elements: Indicates if the field represents a sequence of XML elements + is_wildcard: Indicates if the field represents a wildcard + is_attribute: Indicates if the field represents an XML attribute + is_attributes: Indicates if the field represents a sequence of XML attributes """ __slots__ = ( @@ -192,21 +210,35 @@ def __init__( @property def element_types(self) -> Set[Type]: + """Return the unique element types.""" return {tp for element in self.elements.values() for tp in element.types} def find_choice(self, qname: str) -> Optional["XmlVar"]: - """Match and return a choice field by its qualified name.""" + """Match and return a choice field by its qualified name. + + Args: + qname: The qualified name to lookup + + Returns: + The choice xml var instance or None if there are no matches. + """ match = self.elements.get(qname) return match or find_by_namespace(self.wildcards, qname) def find_value_choice(self, value: Any, is_class: bool) -> Optional["XmlVar"]: - """ - Match and return a choice field that matches the given value. + """Match and return a choice field that matches the given value. Cases: - value is none or empty tokens list: look for a nillable choice - value is a dataclass: look for exact type or a subclass - value is primitive: test value against the converter + + Args: + value: The value to match its type to one of the choices + is_class: Whether the value is a binding class + + Returns: + The choice xml var instance or None if there are no matches. """ is_tokens = collections.is_array(value) if value is None or (not value and is_tokens): @@ -218,25 +250,56 @@ def find_value_choice(self, value: Any, is_class: bool) -> Optional["XmlVar"]: return self.find_primitive_choice(value, is_tokens) def find_nillable_choice(self, is_tokens: bool) -> Optional["XmlVar"]: + """Find the first nillable choice. + + Args: + is_tokens: Specify if the choice must support token values + + Returns: + The choice xml var instance or None if there are no matches. + """ return collections.first( element for element in self.elements.values() if element.nillable and is_tokens == element.tokens ) - def find_clazz_choice(self, tp: Type) -> Optional["XmlVar"]: + def find_clazz_choice(self, clazz: Type) -> Optional["XmlVar"]: + """Find the best matching choice for the given class. + + Best Matches: + 1. The class is explicitly defined in a choice types + 2. The class is a subclass of one of the choice types + + Args: + clazz: The class type to match + + Returns: + The choice xml var instance or None if there are no matches. 
+ """ derived = None for element in self.elements.values(): - if element.clazz: - if tp in element.types: - return element + if not element.clazz: + continue - if derived is None and any(issubclass(tp, t) for t in element.types): - derived = element + if clazz in element.types: + return element + + if derived is None and any(issubclass(clazz, t) for t in element.types): + derived = element return derived def find_primitive_choice(self, value: Any, is_tokens: bool) -> Optional["XmlVar"]: + """Match and return a choice field that matches the given primitive value. + + Args: + value: A primitive value, e.g. str, int, float, enum + is_tokens: Specify whether it's a tokens value + + Returns: + The choice xml var instance or None if there are no matches. + """ tp = type(value) if not is_tokens else type(value[0]) for element in self.elements.values(): if (element.any_type or element.clazz) or element.tokens != is_tokens: @@ -254,8 +317,14 @@ def find_primitive_choice(self, value: Any, is_tokens: bool) -> Optional["XmlVar return None def is_optional(self, value: Any) -> bool: - """Return whether this var instance is not required and the given value - matches the default one.""" + """Verify this var is optional and the value matches the default one. + + Args: + value: The value to compare against the default one + + Returns: + The bool result. + """ if self.required: return False @@ -264,7 +333,14 @@ def is_optional(self, value: Any) -> bool: return self.default == value def match_namespace(self, qname: str) -> bool: - """Match the given qname to the wildcard allowed namespaces.""" + """Match the given qname to the wildcard allowed namespaces. + + Args: + qname: The namespace qualified name of an element + + Returns: + The bool result. + """ if self.namespace_matches is None: self.namespace_matches = {} @@ -296,21 +372,24 @@ def _match_namespace(self, qname: str) -> bool: class XmlMeta(MetaMixin): - """ - Class binding metadata. - - :param clazz: The dataclass type - :param qname: The namespace qualified name. - :param target_qname: The target namespace qualified name. - :param nillable: Specifies whether an explicit empty value can be - assigned. - :param mixed_content: Has a wildcard with mixed flag enabled - :param text: Text var - :param choices: List of compound vars - :param elements: Mapping of qname-element vars - :param wildcards: List of wildcard vars - :param attributes: Mapping of qname-attribute vars - :param any_attributes: List of wildcard attributes vars + """Class binding metadata. 
+ + Args: + clazz: The binding model + qname: The namespace-qualified name + target_qname: The target namespace-qualified name + nillable: Specifies whether this class supports nillable content + text: A text variable + choices: A list of compound variables + elements: A mapping of qualified name to sequence of element variables + wildcards: A list of wildcard variables + attributes: A mapping of qualified name to attribute variable + any_attributes: A list of wildcard variables + wrappers: a mapping of wrapper names to sequences of wrapped variables + + Attributes: + namespace: The target namespace extracted from the qualified name + mixed_content: Specifies if the class supports mixed content """ __slots__ = ( @@ -361,6 +440,7 @@ def __init__( @property def element_types(self) -> Set[Type]: + """Return a unique list of all elements types.""" return { tp for elements in self.elements.values() @@ -369,6 +449,7 @@ def element_types(self) -> Set[Type]: } def get_element_vars(self) -> List[XmlVar]: + """Return a sorted list of the class element variables.""" result = list( itertools.chain(self.wildcards, self.choices, *self.elements.values()) ) @@ -378,10 +459,12 @@ def get_element_vars(self) -> List[XmlVar]: return sorted(result, key=get_index) def get_attribute_vars(self) -> List[XmlVar]: + """Return a sorted list of the class attribute variables.""" result = itertools.chain(self.any_attributes, self.attributes.values()) return sorted(result, key=get_index) def get_all_vars(self) -> List[XmlVar]: + """Return a sorted list of all the class variables.""" result = list( itertools.chain( self.wildcards, @@ -397,14 +480,39 @@ def get_all_vars(self) -> List[XmlVar]: return sorted(result, key=get_index) def find_attribute(self, qname: str) -> Optional[XmlVar]: + """Find an attribute var with the given qname. + + Args: + qname: The namespace qualified name + + Returns: + The xml var instance or None if there is no match. + """ return self.attributes.get(qname) def find_any_attributes(self, qname: str) -> Optional[XmlVar]: + """Find a wildcard attribute var that matches the given qname. + + Args: + qname: The namespace qualified name + + Returns: + The xml var instance or None if there is no match. + """ return find_by_namespace(self.any_attributes, qname) def find_wildcard(self, qname: str) -> Optional[XmlVar]: - """Match the given qualified name to a wildcard and optionally to one - of its choice elements.""" + """Find a wildcard var that matches the given qname. + + If the wildcard has choices, attempt to match and return + one of them as well. + + Args: + qname: The namespace qualified name + + Returns: + The xml var instance or None if there is no match. + """ wildcard = find_by_namespace(self.wildcards, qname) if wildcard and wildcard.elements: @@ -415,12 +523,27 @@ def find_wildcard(self, qname: str) -> Optional[XmlVar]: return wildcard def find_any_wildcard(self) -> Optional[XmlVar]: - if self.wildcards: - return self.wildcards[0] + """Return the first declared wildcard var. - return None + Returns: + The xml var instance or None if there are no wildcard vars. + """ + return self.wildcards[0] if self.wildcards else None def find_children(self, qname: str) -> Iterator[XmlVar]: + """Find all class vars that match the given qname. + + Go through the elements, choices and wildcards. Sometimes + a class might contain more than one var with the same + qualified name. The binding process has to check all + of them and see which one to use. 
+ + Args: + qname: The namespace qualified name + + Yields: + An iterator of all the class vars that match the given qname. + """ elements = self.elements.get(qname) if elements: yield from elements @@ -435,8 +558,17 @@ def find_children(self, qname: str) -> Iterator[XmlVar]: yield chd -def find_by_namespace(xml_vars: Sequence[XmlVar], qname: str) -> Optional[XmlVar]: - for xml_var in xml_vars: +def find_by_namespace(vars: Sequence[XmlVar], qname: str) -> Optional[XmlVar]: + """Match the given qname to one of the given vars. + + Args: + vars: The list of vars to match + qname: The namespace qualified name to lookup + + Returns: + The first matching xml var instance or None if there are no matches. + """ + for xml_var in vars: if xml_var.match_namespace(qname): return xml_var diff --git a/xsdata/formats/dataclass/models/generics.py b/xsdata/formats/dataclass/models/generics.py index 089f3da5c..a4092660d 100644 --- a/xsdata/formats/dataclass/models/generics.py +++ b/xsdata/formats/dataclass/models/generics.py @@ -8,14 +8,14 @@ @dataclass class AnyElement: - """ - Generic model to bind xml document data to wildcard fields. - - :param qname: The element's qualified name - :param text: The element's text content - :param tail: The element's tail content - :param children: The element's list of child elements. - :param attributes: The element's key-value attribute mappings. + """Generic model to bind xml document data to wildcard fields. + + Args: + qname: The element's qualified name + text: The element's text content + tail: The element's tail content + children: The element's list of child elements. + attributes: The element's key-value attribute mappings. """ qname: Optional[str] = field(default=None) @@ -31,14 +31,14 @@ class AnyElement: @dataclass class DerivedElement(Generic[T]): - """ - Generic model wrapper for type substituted elements. + """Generic model wrapper for type substituted elements. Example: eg. ... - :param qname: The element's qualified name - :param value: The wrapped value - :param type: The real xsi:type + Args: + qname: The element's qualified name + value: The wrapped value + type: The real xsi:type """ qname: str diff --git a/xsdata/formats/dataclass/parsers/bases.py b/xsdata/formats/dataclass/parsers/bases.py index 0701fc776..a6a822a69 100644 --- a/xsdata/formats/dataclass/parsers/bases.py +++ b/xsdata/formats/dataclass/parsers/bases.py @@ -6,7 +6,6 @@ from xsdata.exceptions import ConverterWarning, ParserError from xsdata.formats.bindings import T from xsdata.formats.dataclass.context import XmlContext -from xsdata.formats.dataclass.parsers.config import ParserConfig from xsdata.formats.dataclass.parsers.mixins import ( EventsHandler, PushParser, @@ -21,23 +20,32 @@ @dataclass class NodeParser(PushParser): - """ - Bind xml nodes to dataclasses. + """Bind xml nodes to data classes. 
+ + Args: + context: The models context instance + handler: The xml handler class - :param config: Parser configuration - :param context: Model context provider - :param handler: Override default XmlHandler - :ivar ms_map: Namespace registry of parsed prefix-URI mappings + Attributes: + ns_map: The parsed namespace prefix-URI map """ - config: ParserConfig = field(default_factory=ParserConfig) context: XmlContext = field(default_factory=XmlContext) handler: Type[XmlHandler] = field(default=EventsHandler) - ns_map: Dict = field(init=False, default_factory=dict) def parse(self, source: Any, clazz: Optional[Type[T]] = None) -> T: - """Parse the input stream or filename and return the resulting object - tree.""" + """Parse the input file or stream into the target class type. + + If no clazz is provided, the binding context will try + to locate it from imported dataclasses. + + Args: + source: The source file or stream object to parse + clazz: The target class type to parse the source bytes object + + Returns: + An instance of the specified class representing the parsed content. + """ handler = self.handler(clazz=clazz, parser=self) with warnings.catch_warnings(): @@ -64,18 +72,15 @@ def start( attrs: Dict, ns_map: Dict, ): - """ - Start element notification receiver. - - Build and queue the XmlNode for the starting element. - - :param clazz: Root class type, if it's missing look for any - suitable models from the current context. - :param queue: The active XmlNode queue - :param objects: The list of all intermediate parsed objects - :param qname: Qualified name - :param attrs: Attribute key-value map - :param ns_map: Namespace prefix-URI map + """Build and queue the XmlNode for the starting element. + + Args: + clazz: The target class type, auto locate if omitted + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map """ from xsdata.formats.dataclass.parsers.nodes import ElementNode, WrapperNode @@ -130,17 +135,17 @@ def end( text: Optional[str], tail: Optional[str], ) -> bool: - """ - End element notification receiver. + """Parse the last xml node and bind any intermediate objects. - Pop the last XmlNode from the queue and use it to build and - return the resulting object tree with its text and tail content. + Args: + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + text: The element text content + tail: The element tail content - :param queue: Xml nodes queue - :param objects: List of parsed objects - :param qname: Qualified name - :param text: Text content - :param tail: Tail content + Returns: + Whether the binding process was successful. """ item = queue.pop() return item.bind(qname, text, tail, objects) @@ -148,10 +153,10 @@ def end( @dataclass class RecordParser(NodeParser): - """ - Bind xml nodes to dataclasses and store the intermediate events. + """Bind xml nodes to dataclasses and store the intermediate events. - :ivar events: List of pushed events + Attributes: + events: The list of recorded events """ events: List = field(init=False, default_factory=list) @@ -165,19 +170,17 @@ def start( attrs: Dict, ns_map: Dict, ): - """ - Start element notification receiver. - - Build and queue the XmlNode for the starting element, append the - event with the attributes and ns map to the events list. 
- - :param clazz: Root class type, if it's missing look for any - suitable models from the current context. - :param queue: The active XmlNode queue - :param objects: The list of all intermediate parsed objects - :param qname: Qualified name - :param attrs: Attributes key-value map - :param ns_map: Namespace prefix-URI map + """Build and queue the XmlNode for the starting element. + + Record the start event for later processing. + + Args: + clazz: The target class type, auto locate if omitted + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map """ self.events.append((EventType.START, qname, copy.deepcopy(attrs), ns_map)) super().start(clazz, queue, objects, qname, attrs, ns_map) @@ -190,23 +193,31 @@ def end( text: Optional[str], tail: Optional[str], ) -> Any: - """ - End element notification receiver. - - Pop the last XmlNode from the queue and use it to build and - return the resulting object tree with its text and tail content. - Append the end event with the text,tail content to the events - list. - - :param queue: Xml nodes queue - :param objects: List of parsed objects - :param qname: Qualified name - :param text: Text content - :param tail: Tail content + """Parse the last xml node and bind any intermediate objects. + + Record the end event for later processing + + Args: + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + text: The element text content + tail: The element tail content + + Returns: + Whether the binding process was successful. """ self.events.append((EventType.END, qname, text, tail)) return super().end(queue, objects, qname, text, tail) def register_namespace(self, prefix: Optional[str], uri: str): + """Register the uri prefix in the namespace registry. + + Record the start-ns event for later processing + + Args: + prefix: Namespace prefix + uri: Namespace uri + """ self.events.append((EventType.START_NS, prefix, uri)) super().register_namespace(prefix, uri) diff --git a/xsdata/formats/dataclass/parsers/config.py b/xsdata/formats/dataclass/parsers/config.py index 09d617e49..3a3d94901 100644 --- a/xsdata/formats/dataclass/parsers/config.py +++ b/xsdata/formats/dataclass/parsers/config.py @@ -1,53 +1,47 @@ -from typing import Callable, Dict, Optional, Type +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional, Type from xsdata.formats.bindings import T -def default_class_factory(cls: Type[T], params: Dict) -> T: +def default_class_factory(cls: Type[T], params: Dict[str, Any]) -> T: + """The default class factory. + + To be used as a hook for plugins. + + Args: + cls: The target class type to instantiate + params: The class keyword arguments + + Returns: + A new class instance with the given params. + """ return cls(**params) # type: ignore +@dataclass class ParserConfig: - """ - Parsing configuration options. - - :param base_url: Specify a base URL when parsing from memory, and - you need support for relative links e.g. 
xinclude - :param load_dtd: Enable loading external dtd (lxml only) - :param process_xinclude: Enable xinclude statements processing - :param class_factory: Override default object instantiation - :param fail_on_unknown_properties: Skip unknown properties or fail - with exception - :param fail_on_unknown_attributes: Skip unknown XML attributes or - fail with exception - :param fail_on_converter_warnings: Turn converter warnings to - exceptions + """Parsing configuration options. + + Not all options are applicable for both xml and json documents. + + Args: + base_url: Specify a base URL when parsing from memory, and + you need support for relative links e.g. xinclude + load_dtd: Enable loading external dtd (lxml only) + process_xinclude: Enable xinclude statements processing + class_factory: Override default object instantiation + fail_on_unknown_properties: Skip unknown properties or fail with exception + fail_on_unknown_attributes: Skip unknown XML attributes or fail with exception + fail_on_converter_warnings: Turn converter warnings to exceptions """ - __slots__ = ( - "base_url", - "load_dtd", - "process_xinclude", - "class_factory", - "fail_on_unknown_properties", - "fail_on_unknown_attributes", - "fail_on_converter_warnings", + base_url: Optional[str] = None + load_dtd: bool = False + process_xinclude: bool = False + class_factory: Callable[[Type[T], Dict[str, Any]], T] = field( + default=default_class_factory ) - - def __init__( - self, - base_url: Optional[str] = None, - load_dtd: bool = False, - process_xinclude: bool = False, - class_factory: Callable[[Type[T], Dict], T] = default_class_factory, - fail_on_unknown_properties: bool = True, - fail_on_unknown_attributes: bool = False, - fail_on_converter_warnings: bool = False, - ): - self.base_url = base_url - self.load_dtd = load_dtd - self.process_xinclude = process_xinclude - self.class_factory = class_factory - self.fail_on_unknown_properties = fail_on_unknown_properties - self.fail_on_unknown_attributes = fail_on_unknown_attributes - self.fail_on_converter_warnings = fail_on_converter_warnings + fail_on_unknown_properties: bool = True + fail_on_unknown_attributes: bool = False + fail_on_converter_warnings: bool = False diff --git a/xsdata/formats/dataclass/parsers/handlers/__init__.py b/xsdata/formats/dataclass/parsers/handlers/__init__.py index 0b0c6913c..cf1bc05af 100644 --- a/xsdata/formats/dataclass/parsers/handlers/__init__.py +++ b/xsdata/formats/dataclass/parsers/handlers/__init__.py @@ -7,11 +7,13 @@ from xsdata.formats.dataclass.parsers.handlers.lxml import LxmlEventHandler def default_handler() -> Type[XmlHandler]: + """Return the default xml handler.""" return LxmlEventHandler except ImportError: # pragma: no cover def default_handler() -> Type[XmlHandler]: + """Return the default xml handler.""" return XmlEventHandler diff --git a/xsdata/formats/dataclass/parsers/handlers/lxml.py b/xsdata/formats/dataclass/parsers/handlers/lxml.py index 050f17358..cc7c986d7 100644 --- a/xsdata/formats/dataclass/parsers/handlers/lxml.py +++ b/xsdata/formats/dataclass/parsers/handlers/lxml.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable +from typing import Any, Iterable, Tuple from lxml import etree @@ -10,28 +10,17 @@ class LxmlEventHandler(XmlHandler): - """ - Event handler based on :class:`lxml.etree.iterparse` api. - - :param parser: The parser instance to feed with events - :param clazz: The target binding model, auto located if omitted. 
- """ - - __slots__ = () + """An lxml event handler.""" def parse(self, source: Any) -> Any: - """ - Parse an XML document from a system identifier or an InputSource or - directly from a lxml Element or Tree. - - When Source is a lxml Element or Tree the handler will switch to - the :class:`lxml.etree.iterwalk` api. + """Parse the source XML document. - When source is a system identifier or an InputSource the parser - will ignore comments and recover from errors. + Args: + source: The xml source, can be a file resource or an input stream, + or a lxml tree/element. - When config process_xinclude is enabled the handler will parse - the whole document and then walk down the element tree. + Returns: + An instance of the class type representing the parsed content. """ if isinstance(source, (etree._ElementTree, etree._Element)): ctx = etree.iterwalk(source, EVENTS) @@ -50,8 +39,15 @@ def parse(self, source: Any) -> Any: return self.process_context(ctx) - def process_context(self, context: Iterable) -> Any: - """Iterate context and push the events to main parser.""" + def process_context(self, context: Iterable[Tuple[str, Any]]) -> Any: + """Iterate context and push events to main parser. + + Args: + context: The iterable lxml context + + Returns: + An instance of the class type representing the parsed content. + """ for event, element in context: if event == EventType.START: self.parser.start( diff --git a/xsdata/formats/dataclass/parsers/handlers/native.py b/xsdata/formats/dataclass/parsers/handlers/native.py index cee8875bc..ea1a83d5a 100644 --- a/xsdata/formats/dataclass/parsers/handlers/native.py +++ b/xsdata/formats/dataclass/parsers/handlers/native.py @@ -13,28 +13,17 @@ class XmlEventHandler(XmlHandler): - """ - Event handler based on :func:`xml.etree.ElementTree.iterparse` api. - - :param parser: The parser instance to feed with events - :param clazz: The target binding model, auto located if omitted. - """ - - __slots__ = () + """A native xml event handler.""" def parse(self, source: Any) -> Any: - """ - Parse an XML document from a system identifier or an InputSource or - directly from an xml Element or ElementTree. + """Parse the source XML document. - When source is an Element or ElementTree the handler will walk - over the objects structure. + Args: + source: The xml source, can be a file resource or an input stream, + or a xml tree/element. - When source is a system identifier or an InputSource the parser - will ignore comments and recover from errors. - - When config process_xinclude is enabled the handler will parse - the whole document and then walk down the element tree. + Returns: + An instance of the class type representing the parsed content. """ if isinstance(source, etree.ElementTree): source = source.getroot() @@ -53,8 +42,15 @@ def parse(self, source: Any) -> Any: return self.process_context(ctx) - def process_context(self, context: Iterable) -> Any: - """Iterate context and push the events to main parser.""" + def process_context(self, context: Iterable[Tuple[str, Any]]) -> Any: + """Iterate context and push events to main parser. + + Args: + context: The iterable xml context + + Returns: + An instance of the class type representing the parsed content. + """ ns_map: Dict = {} for event, element in context: if event == EventType.START: @@ -86,11 +82,17 @@ def process_context(self, context: Iterable) -> Any: def iterwalk(element: etree.Element, ns_map: Dict) -> Iterator[Tuple[str, Any]]: - """ - Walk over the element tree structure and emit start-ns/start/end events. 
+    """Walk over the element tree and emit events.
 
     The ElementTree doesn't preserve the original namespace prefixes, we
     have to generate new ones.
+
+    Args:
+        element: The etree element instance
+        ns_map: The namespace prefix-URI mapping
+
+    Yields:
+        An iterator of events.
     """
     uri = namespaces.target_uri(element.tag)
     if uri is not None:
@@ -106,6 +108,16 @@ def iterwalk(element: etree.Element, ns_map: Dict) -> Iterator[Tuple[str, Any]]:
 
 
 def get_base_url(base_url: Optional[str], source: Any) -> Optional[str]:
+    """Return the base url of the source.
+
+    Args:
+        base_url: The base url from the parser config
+        source: The xml source input
+
+    Returns:
+        A base url str or None, if no base url is provided
+        and the source is not a string path.
+    """
     if base_url:
         return base_url
 
@@ -118,6 +130,8 @@ def xinclude_loader(
     encoding: Optional[str] = None,
     base_url: Optional[str] = None,
 ) -> Any:
-    """Custom loader for xinclude to support base_url argument that doesn't
-    exist for python < 3.9."""
+    """Custom loader for xinclude parsing.
+
+    The base_url argument was added in python >= 3.9.
+    """
     return xinclude.default_loader(urljoin(base_url or "", href), parse, encoding)
diff --git a/xsdata/formats/dataclass/parsers/json.py b/xsdata/formats/dataclass/parsers/json.py
index c7a5d21dc..5b5fcecb3 100644
--- a/xsdata/formats/dataclass/parsers/json.py
+++ b/xsdata/formats/dataclass/parsers/json.py
@@ -1,7 +1,7 @@
 import json
 import warnings
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
 
 from xsdata.exceptions import ConverterWarning, ParserError
 from xsdata.formats.bindings import AbstractParser, T
@@ -17,13 +17,12 @@
 
 @dataclass
 class JsonParser(AbstractParser):
-    """
-    Json parser for dataclasses.
+    """Json parser for data classes.
 
-    :param config: Parser configuration
-    :param context: Model context provider
-    :param load_factory: Replace the default json.load call with another
-        implementation
+    Args:
+        config: Parser configuration
+        context: The models context instance
+        load_factory: Json loader factory
     """
 
     config: ParserConfig = field(default_factory=ParserConfig)
@@ -31,9 +30,18 @@ class JsonParser(AbstractParser):
     load_factory: Callable = field(default=json.load)
 
     def parse(self, source: Any, clazz: Optional[Type[T]] = None) -> T:
-        """Parse the input stream or filename and return the resulting object
-        tree."""
+        """Parse the input stream into the target class type.
+
+        If no clazz is provided, the binding context will try
+        to locate it from imported dataclasses.
+
+        Args:
+            source: The source file name or stream to parse
+            clazz: The target class type to parse the source object
 
+        Returns:
+            An instance of the specified class representing the parsed content.
+        """
         data = self.load_json(source)
         tp = self.verify_type(clazz, data)
 
@@ -50,6 +58,14 @@ def parse(self, source: Any, clazz: Optional[Type[T]] = None) -> T:
         raise ParserError(e)
 
     def load_json(self, source: Any) -> Union[Dict, List]:
+        """Load the given json source filename or stream.
+
+        Args:
+            source: A file name or file stream
+
+        Returns:
+            The loaded dictionary or list of dictionaries.
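+
+        Example:
+            Illustrative only, both call styles are supported::
+
+                parser.load_json("sample.json")       # path, opened in binary mode
+                parser.load_json(io.BytesIO(b"{}"))   # stream, read directly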
+        """
         if not hasattr(source, "read"):
             with open(source, "rb") as fp:
                 return self.load_factory(fp)
@@ -57,6 +73,18 @@ def load_json(self, source: Any) -> Union[Dict, List]:
         return self.load_factory(source)
 
     def verify_type(self, clazz: Optional[Type[T]], data: Union[Dict, List]) -> Type[T]:
+        """Verify the given data matches the given clazz.
+
+        If no clazz is provided, the binding context will try
+        to locate it from imported dataclasses.
+
+        Args:
+            clazz: The target class type to parse the object
+            data: The loaded dictionary or list of dictionaries
+
+        Returns:
+            The clazz type to bind the loaded data.
+        """
         if clazz is None:
             return self.detect_type(data)
 
@@ -84,6 +112,14 @@ def verify_type(self, clazz: Optional[Type[T]], data: Union[Dict, List]) -> Type
         return clazz  # type: ignore
 
     def detect_type(self, data: Union[Dict, List]) -> Type[T]:
+        """Locate the target clazz type from the data keys.
+
+        Args:
+            data: The loaded dictionary or list of dictionaries
+
+        Returns:
+            The clazz type to bind the loaded data.
+        """
         if not data:
             raise ParserError("Document is empty, can not detect type")
 
@@ -96,7 +132,15 @@ def detect_type(self, data: Union[Dict, List]) -> Type[T]:
         raise ParserError(f"Unable to locate model with properties({list(keys)})")
 
     def bind_dataclass(self, data: Dict, clazz: Type[T]) -> T:
-        """Recursively build the given model from the input dict data."""
+        """Create a new instance of the given class type with the given data.
+
+        Args:
+            data: The loaded data
+            clazz: The target class type to bind the input data
+
+        Returns:
+            An instance of the class type representing the parsed content.
+        """
         if set(data.keys()) == self.context.class_type.derived_keys:
             return self.bind_derived_dataclass(data, clazz)
 
@@ -120,6 +164,22 @@ def bind_dataclass(self, data: Dict, clazz: Type[T]) -> T:
         raise ParserError(e)
 
     def bind_derived_dataclass(self, data: Dict, clazz: Type[T]) -> Any:
+        """Bind the input data to the given class type.
+
+        Examples:
+            The expected input structure::
+
+                {
+                    "qname": "foo",
+                    "type": "my:type",
+                    "value": {"prop": "value"}
+                }
+
+        Args:
+            data: The derived element dictionary
+            clazz: The target class type to bind the input data
+
+        Returns:
+            An instance of the class type representing the parsed content.
+        """
         qname = data["qname"]
         xsi_type = data["type"]
         params = data["value"]
@@ -144,8 +204,15 @@ def bind_derived_dataclass(self, data: Dict, clazz: Type[T]) -> Any:
         return generic(qname=qname, type=xsi_type, value=value)
 
     def bind_best_dataclass(self, data: Dict, classes: Iterable[Type[T]]) -> T:
-        """Attempt to bind the given data to one possible models, if more than
-        one is successful return the object with the highest score."""
+        """Bind the input data to all the given classes and return the best match.
+
+        Args:
+            data: The derived element dictionary
+            classes: The target class types to try
+
+        Returns:
+            An instance of one of the class types representing the parsed content.
+        """
         obj = None
         keys = set(data.keys())
         max_score = -1.0
@@ -169,8 +236,20 @@ def bind_best_dataclass(self, data: Dict, classes: Iterable[Type[T]]) -> T:
         )
 
     def bind_optional_dataclass(self, data: Dict, clazz: Type[T]) -> Optional[T]:
-        """Recursively build the given model from the input dict data but fail
-        on any converter warnings."""
+        """Bind the input data to the given class type.
+
+        This is a strict process; if there is any warning the process
+        returns None. This method is used to test if the data fits into
+        the class type.
+ + Args: + data: The derived element dictionary + clazz: The target class type to bind the input data + + Returns: + An instance of the class type representing the parsed content + or None if there is any warning or error. + """ try: with warnings.catch_warnings(): warnings.filterwarnings("error", category=ConverterWarning) @@ -179,10 +258,23 @@ def bind_optional_dataclass(self, data: Dict, clazz: Type[T]) -> Optional[T]: return None def bind_value( - self, meta: XmlMeta, var: XmlVar, value: Any, recursive: bool = False + self, + meta: XmlMeta, + var: XmlVar, + value: Any, + recursive: bool = False, ) -> Any: - """Main entry point for binding values.""" + """Main entry point for binding values. + Args: + meta: The parent xml meta instance + var: The xml var descriptor for the field + value: The data value + recursive: Whether this is a recursive call + + Returns: + The parsed object + """ # xs:anyAttributes get it out of the way, it's the mapping exception! if var.is_attributes: return dict(value) @@ -190,9 +282,11 @@ def bind_value( # Repeating element, recursively bind the values if not recursive and var.list_element and isinstance(value, list): assert var.factory is not None - return var.factory(self.bind_value(meta, var, val, True) for val in value) + return var.factory( + self.bind_value(meta, var, val, recursive=True) for val in value + ) - # If not dict this is an text or tokens value. + # If not dict this is a text or tokens value. if not isinstance(value, dict): return self.bind_text(meta, var, value) @@ -209,7 +303,16 @@ def bind_value( return self.bind_complex_type(meta, var, value) def bind_text(self, meta: XmlMeta, var: XmlVar, value: Any) -> Any: - """Bind text/tokens value entrypoint.""" + """Bind text/tokens value entrypoint. + + Args: + meta: The parent xml meta instance + var: The xml var descriptor for the field + value: The data value + + Returns: + The parsed tokens or text value. + """ if var.is_elements: # Compound field we need to match the value to one of the choice elements check_subclass = self.context.class_type.is_model(value) @@ -242,7 +345,16 @@ def bind_text(self, meta: XmlMeta, var: XmlVar, value: Any) -> Any: ) def bind_complex_type(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any: - """Bind data to a user defined dataclass.""" + """Bind complex values entrypoint. + + Args: + meta: The parent xml meta instance + var: The xml var descriptor for the field + data: The complex data value + + Returns: + The parsed dataclass instance. + """ if var.is_clazz_union: # Union of dataclasses return self.bind_best_dataclass(data, var.types) @@ -264,7 +376,24 @@ def bind_complex_type(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any: return self.bind_dataclass(data, var.clazz) def bind_derived_value(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any: - """Bind derived element entry point.""" + """Bind derived data entrypoint. + + The data is representation of a derived element, e.g. { + "qname": "foo", + "type": "my:type" + "value": Any + } + + The data value can be a primitive value or a complex value. + + Args: + meta: The parent xml meta instance + var: The xml var descriptor for the field + data: The derived element data + + Returns: + The parsed object. 
+ """ qname = data["qname"] xsi_type = data["type"] params = data["value"] @@ -297,8 +426,21 @@ def bind_derived_value(self, meta: XmlMeta, var: XmlVar, data: Dict) -> Any: @classmethod def find_var( - cls, xml_vars: Sequence[XmlVar], local_name: str, is_list: bool = False + cls, + xml_vars: List[XmlVar], + local_name: str, + is_list: bool = False, ) -> Optional[XmlVar]: + """Match the name to a xml variable. + + Args: + xml_vars: A list of xml vars + local_name: A key from the loaded data + is_list: Whether the data value is an array + + Returns: + One of the xml vars, if all search attributes match, None otherwise. + """ for var in xml_vars: if var.local_name == local_name: var_is_list = var.list_element or var.tokens @@ -310,5 +452,20 @@ def find_var( @dataclass class DictConverter(JsonParser): - def convert(self, data: Dict, clazz: Type[T]) -> T: + """Map data to a data class. + + This is not a parser technically, as it doesn't + implement the parser interface correctly. + """ + + def convert(self, data: Dict[str, Any], clazz: Type[T]) -> T: + """Parse the input data into the target class type. + + Args: + data: The input dictionary + clazz: The target class type to parse the input data + + Returns: + An instance of the specified class representing the parsed data. + """ return self.bind_dataclass(data, clazz) diff --git a/xsdata/formats/dataclass/parsers/mixins.py b/xsdata/formats/dataclass/parsers/mixins.py index 86ea40709..ddc5cea2a 100644 --- a/xsdata/formats/dataclass/parsers/mixins.py +++ b/xsdata/formats/dataclass/parsers/mixins.py @@ -1,4 +1,5 @@ import abc +from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Type from xsdata.exceptions import XmlHandlerError @@ -9,27 +10,40 @@ NoneStr = Optional[str] +@dataclass class PushParser(AbstractParser): - """ - A generic interface for event based content handlers like sax. + """A generic interface for event based content handlers like sax. + + Args: + config: The parser configuration instance - :param config: Parser configuration. + Attributes: + ns_map: The parsed namespace prefix-URI map """ - config: ParserConfig - ns_map: Dict + config: ParserConfig = field(default_factory=ParserConfig) + ns_map: Dict[Optional[str], str] = field(init=False, default_factory=dict) @abc.abstractmethod def start( self, clazz: Optional[Type], - queue: List, - objects: List, + queue: List[Any], + objects: List[Any], qname: str, - attrs: Dict, - ns_map: Dict, + attrs: Dict[str, str], + ns_map: Dict[Optional[str], str], ): - """Queue the next xml node for parsing.""" + """Build and queue the XmlNode for the starting element. + + Args: + clazz: The target class type, auto locate if omitted + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + """ @abc.abstractmethod def end( @@ -40,74 +54,92 @@ def end( text: NoneStr, tail: NoneStr, ) -> bool: - """ - Parse the last xml node and bind any intermediate objects. + """Parse the last xml node and bind any intermediate objects. - :return: The result of the binding process. + Args: + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + text: The element text content + tail: The element tail content + + Returns: + Whether the binding process was successful. 
""" def register_namespace(self, prefix: NoneStr, uri: str): - """ - Add the given prefix-URI namespaces mapping if the prefix is new. + """Register the uri prefix in the namespace registry. - :param prefix: Namespace prefix - :param uri: Namespace uri + Args: + prefix: Namespace prefix + uri: Namespace uri """ if prefix not in self.ns_map: self.ns_map[prefix] = uri class XmlNode(abc.ABC): - """ - The xml node interface. + """The xml node interface. The nodes are responsible to find and queue the child nodes when a new element starts and build the resulting object tree when the element ends. The parser needs to maintain a queue for these nodes - and a list of all the intermediate object trees. + and a list of all the intermediate objects. """ @abc.abstractmethod def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> "XmlNode": - """ - Initialize the next child node to be queued, when a new xml element - starts. + """Initialize the next child node to be queued, when an element starts. This entry point is responsible to create the next node type with all the necessary information on how to bind the incoming input data. - :param qname: Qualified name - :param attrs: Attribute key-value map - :param ns_map: Namespace prefix-URI map - :param position: The current objects position, to mark future - objects as children + Args: + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current length of the intermediate objects + + Returns: + The child xml node instance. """ @abc.abstractmethod - def bind(self, qname: str, text: NoneStr, tail: NoneStr, objects: List) -> bool: - """ - Build the object tree for the ending element and return whether the - result was successful or not. + def bind( + self, + qname: str, + text: NoneStr, + tail: NoneStr, + objects: List[Any], + ) -> bool: + """Bind the parsed data into an object for the ending element. - This entry point is called when an xml element ends and is + This entry point is called when a xml element ends and is responsible to parse the current element attributes/text, bind - any children objects and initialize new object. + any children objects and initialize new object. - :param qname: Qualified name - :param text: Text content - :param tail: Tail content - :param objects: The list of intermediate parsed objects, eg - [(qname, object)] + Args: + qname: The element qualified name + text: The element text content + tail: The element tail content + objects: The list of intermediate parsed objects + + Returns: + Whether the binding process was successful or not. """ class XmlHandler: - """ - Abstract content handler. + """Abstract content handler. - :param parser: The parser instance to feed with events - :param clazz: The target binding model, auto located if omitted. + Args: + parser: The parser instance to feed with events + clazz: The target class type, auto locate if omitted + + Attributes: + queue: The XmlNode queue list + objects: The list of intermediate parsed objects """ __slots__ = ("parser", "clazz", "queue", "objects") @@ -119,16 +151,26 @@ def __init__(self, parser: PushParser, clazz: Optional[Type]): self.objects: List = [] def parse(self, source: Any) -> Any: - """Parse an XML document from a system identifier or an InputSource.""" - raise NotImplementedError("This method must be implemented!") + """Parse the source XML document. + + Args: + source: The xml source, can be a file resource or an input stream. 
- def merge_parent_namespaces(self, ns_map: Dict) -> Dict: + Returns: + An instance of the class type representing the parsed content. """ - Merge and return the given prefix-URI map with the parent node. + raise NotImplementedError("This method must be implemented!") + + def merge_parent_namespaces(self, ns_map: Dict[Optional[str], str]) -> Dict: + """Merge the given prefix-URI map with the parent node map. - Register new prefixes with the parser. + This method also registers new prefixes with the parser. - :param ns_map: Namespace prefix-URI map + Args: + ns_map: The current element namespace prefix-URI map + + Returns: + The new merged namespace prefix-URI map. """ if self.queue: parent_ns_map = self.queue[-1].ns_map @@ -150,15 +192,15 @@ def merge_parent_namespaces(self, ns_map: Dict) -> Dict: class EventsHandler(XmlHandler): """Sax content handler for pre-recorded events.""" - __slots__ = ("data_frames", "flush_next") + def parse(self, source: List[Tuple]) -> Any: + """Forward the pre-recorded events to the main parser. - def __init__(self, parser: PushParser, clazz: Optional[Type]): - super().__init__(parser, clazz) - self.data_frames: List = [] - self.flush_next: Optional[str] = None + Args: + source: A list of event data - def parse(self, source: List[Tuple]) -> Any: - """Forward the pre-recorded events to the main parser.""" + Returns: + An instance of the class type representing the parsed content. + """ for event, *args in source: if event == EventType.START: qname, attrs, ns_map = args diff --git a/xsdata/formats/dataclass/parsers/nodes/element.py b/xsdata/formats/dataclass/parsers/nodes/element.py index 9584f44af..88e092e3e 100644 --- a/xsdata/formats/dataclass/parsers/nodes/element.py +++ b/xsdata/formats/dataclass/parsers/nodes/element.py @@ -14,19 +14,24 @@ class ElementNode(XmlNode): - """ - XmlNode for complex elements and dataclasses. - - :param meta: Model xml metadata - :param attrs: Key-value attribute mapping - :param ns_map: Namespace prefix-URI map - :param config: Parser configuration - :param context: Model context provider - :param position: The node position of objects cache - :param mixed: The node supports mixed content - :param derived_factory: Derived element factory - :param xsi_type: The xml type substitution - :param xsi_nil: The xml type substitution + """XmlNode for complex elements. + + Args: + meta: The class binding metadata instance + attrs: The element attributes + ns_map: The element namespace prefix-URI map + config: The parser config instance + context: The models context instance + position: The current objects length, everything after + this position are considered children of this node. + mixed: Specifies whether this node supports mixed content. + derived_factory: Derived element factory + xsi_type: The xml type substitution + xsi_nil: Specifies whether element has the xsi:nil attribute + + Attributes: + assigned: A set to store the processed sub-nodes + tail_processed: Whether the tail process is consumed """ __slots__ = ( @@ -71,8 +76,27 @@ def __init__( self.tail_processed: bool = False def bind( - self, qname: str, text: Optional[str], tail: Optional[str], objects: List + self, + qname: str, + text: Optional[str], + tail: Optional[str], + objects: List[Any], ) -> bool: + """Bind the parsed data into an object for the ending element. + + This entry point is called when a xml element ends and is + responsible to parse the current element attributes/text, bind + any children objects and initialize new object. 
+ + Args: + qname: The element qualified name + text: The element text content + tail: The element tail content + objects: The list of intermediate parsed objects + + Returns: + Whether the binding process was successful or not. + """ obj: Any = None if not self.xsi_nil or self.meta.nillable: params: Dict = {} @@ -93,8 +117,20 @@ def bind( return True def bind_content( - self, params: Dict, text: Optional[str], tail: Optional[str], objects: List[Any] + self, + params: Dict, + text: Optional[str], + tail: Optional[str], + objects: List[Any], ): + """Parse the text and tail content. + + Args: + params: The class parameters + text: The element text content + tail: The element tail content + objects: The list of intermediate parsed objects + """ wild_var = self.meta.find_any_wildcard() if wild_var and wild_var.mixed: self.bind_mixed_objects(params, wild_var, objects) @@ -111,11 +147,20 @@ def bind_content( if isinstance(params[key], PendingCollection): params[key] = params[key].evaluate() - def bind_attrs(self, params: Dict): - """Parse the given element's attributes and any text content and return - a dictionary of field names and values based on the given class - metadata.""" + def bind_attrs(self, params: Dict[str, Any]): + """Parse the element attributes. + Scenarios: + - Each attribute matches a class field + - Class has a wildcard field that sucks everything else + + Args: + params: The class parameters + + Raises: + ParserError: If the document contains an unknown attribute + and the configuration is strict. + """ if not self.attrs: return @@ -137,6 +182,15 @@ def bind_attrs(self, params: Dict): ) def bind_attr(self, params: Dict, var: XmlVar, value: Any): + """Parse an element attribute. + + Ignores fields with init==false! + + Args: + params: The class parameters + var: The xml var instance + value: The attribute value + """ if var.init: params[var.name] = ParserUtils.parse_value( value=value, @@ -148,23 +202,48 @@ def bind_attr(self, params: Dict, var: XmlVar, value: Any): ) def bind_any_attr(self, params: Dict, var: XmlVar, qname: str, value: Any): + """Parse an element attribute to a wildcard field. + + Args: + params: The class parameters + var: The xml var instance + qname: The attribute namespace qualified name + value: The attribute value + """ if var.name not in params: params[var.name] = {} params[var.name][qname] = ParserUtils.parse_any_attribute(value, self.ns_map) def bind_objects(self, params: Dict, objects: List): - """Return a dictionary of qualified object names and their values for - the given queue item.""" + """Bind children objects. + + Emit a warning if an object doesn't fit in any + class parameters. + Args: + params: The class parameters + objects: The list of intermediate parsed objects + """ position = self.position - for qname, value in objects[position:]: - if not self.bind_object(params, qname, value): + for qname, obj in objects[position:]: + if not self.bind_object(params, qname, obj): logger.warning("Unassigned parsed object %s", qname) del objects[position:] def bind_object(self, params: Dict, qname: str, value: Any) -> bool: + """Bind a child object. + + Args: + params: The class parameters + qname: The qualified name of the element + value: The parsed object + + Returns: + Whether the parsed object can fit in one of class + parameters or not. 
+ """ for var in self.meta.find_children(qname): if var.is_wildcard: return self.bind_wild_var(params, var, qname, value) @@ -176,14 +255,16 @@ def bind_object(self, params: Dict, qname: str, value: Any) -> bool: @classmethod def bind_var(cls, params: Dict, var: XmlVar, value: Any) -> bool: - """ - Add the given value to the params dictionary with the var name as key. + """Bind a child object to an element field. - Wrap the value to a list if var is a list. If the var name - already exists it means we have a name conflict and the parser - needs to lookup for any available wildcard fields. + Args: + params: The class parameters + var: The matched xml var instance + value: The parsed object - :return: Whether the binding process was successful or not. + Returns: + Whether the parsed object can fit in one of class + parameters or not. """ if var.init: if var.list_element: @@ -200,15 +281,22 @@ def bind_var(cls, params: Dict, var: XmlVar, value: Any) -> bool: return True def bind_wild_var(self, params: Dict, var: XmlVar, qname: str, value: Any) -> bool: - """ - Add the given value to the params dictionary with the wildcard var name - as key. + """Bind a child object to a wildcard field. + + The wildcard might support one or more values. If it + supports only one the values are nested under a parent + generic element instance. + + Args: + params: The class parameters + var: The wildcard var instance + qname: The qualified name of the element + value: The parsed value - If the key is already present wrap the previous value into a - generic AnyElement instance. If the previous value is already a - generic instance add the current value as a child object. + Returns: + Always true, since wildcard fields can absorb any value. """ - value = self.prepare_generic_value(qname, value, var) + value = self.prepare_generic_value(qname, value) if var.list_element: items = params.get(var.name) @@ -230,21 +318,30 @@ def bind_wild_var(self, params: Dict, var: XmlVar, qname: str, value: Any) -> bo return True def bind_mixed_objects(self, params: Dict, var: XmlVar, objects: List): - """Return a dictionary of qualified object names and their values for - the given mixed content xml var.""" + """Bind children objects to a mixed content wildcard field. + Args: + params: The class parameters + var: The wildcard var instance + objects: The list of intermediate parsed objects + """ pos = self.position params[var.name] = [ - self.prepare_generic_value(qname, value, var) - for qname, value in objects[pos:] + self.prepare_generic_value(qname, value) for qname, value in objects[pos:] ] del objects[pos:] - def prepare_generic_value( - self, qname: Optional[str], value: Any, var: XmlVar - ) -> Any: - """Prepare parsed value before binding to a wildcard field.""" + def prepare_generic_value(self, qname: Optional[str], value: Any) -> Any: + """Wrap primitive text nodes in a generic element. + + Args: + qname: The qualified name of the element + value: The parsed object + Returns: + The original parsed value if it's a data class, or + the wrapped primitive value in a generic element. + """ if qname and not self.context.class_type.is_model(value): any_factory = self.context.class_type.any_element value = any_factory(qname=qname, text=converter.serialize(value)) @@ -252,11 +349,15 @@ def prepare_generic_value( return value def bind_text(self, params: Dict, text: Optional[str]) -> bool: - """ - Add the given element's text content if any to the params dictionary - with the text var name as key. + """Bind the element text content. 
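# --- Editor's illustrative sketch (not part of the patch) ---
# prepare_generic_value leaves data class instances untouched and wraps
# primitive children in the generic element type, roughly like this
# (the qname and text are made up for illustration):
from xsdata.formats.dataclass.models.generics import AnyElement

wrapped = AnyElement(qname="{http://example.com}price", text="19.99")
# --- end sketch ---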
+ + Args: + params: The class parameters + text: The element text content - Return if any data was bound. + Returns: + Whether the text content can fit in one of class + parameters or not. """ var = self.meta.text @@ -278,20 +379,32 @@ def bind_text(self, params: Dict, text: Optional[str]) -> bool: return True def bind_wild_text( - self, params: Dict, var: XmlVar, txt: Optional[str], tail: Optional[str] + self, + params: Dict, + var: XmlVar, + text: Optional[str], + tail: Optional[str], ) -> bool: + """Bind the element text and tail content to a wildcard field. + + If the field is a list, prepend the text and append the tail content. + Otherwise, build a generic element with the text/tail content + and any attributes. If the field is already occupied, then this + means the current node is a child, and we need to nested them. + + Args: + params: The class parameters + var: The wildcard var instance + text: The element text content + tail: The element text content + + Returns: + Whether the text content can fit in one of class + parameters or not. """ - Extract the text and tail content and bind it accordingly in the params - dictionary. Return if any data was bound. - - - var is a list prepend the text and append the tail. - - var is present in the params assign the text and tail to the generic object. - - Otherwise bind the given element to a new generic object. - """ - - txt = ParserUtils.normalize_content(txt) + text = ParserUtils.normalize_content(text) tail = ParserUtils.normalize_content(tail) - if txt is None and tail is None: + if text is None and tail is None: return False if var.list_element: @@ -299,7 +412,7 @@ def bind_wild_text( if items is None: params[var.name] = items = PendingCollection(None, var.factory) - items.insert(0, txt) + items.insert(0, text) if tail: items.append(tail) @@ -307,7 +420,7 @@ def bind_wild_text( previous = params.get(var.name, None) factory = self.context.class_type.any_element generic = factory( - text=txt, + text=text, tail=tail, attributes=ParserUtils.parse_any_attributes(self.attrs, self.ns_map), ) @@ -319,6 +432,21 @@ def bind_wild_text( return True def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: + """Initialize the next child node to be queued, when an element starts. + + This entry point is responsible to create the next node type + with all the necessary information on how to bind the incoming + input data. + + Args: + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current length of the intermediate objects + + Raises: + ParserError: If the child element is unknown + """ for var in self.meta.find_children(qname): unique = 0 if not var.is_element or var.list_element else var.index if not unique or unique not in self.assigned: @@ -336,8 +464,26 @@ def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode return nodes.SkipNode() def build_node( - self, qname: str, var: XmlVar, attrs: Dict, ns_map: Dict, position: int + self, + qname: str, + var: XmlVar, + attrs: Dict, + ns_map: Dict, + position: int, ) -> Optional[XmlNode]: + """Build the next child node based on the xml var instance. + + Args: + qname: The element qualified name + var: The xml var instance + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current length of the intermediate objects + + Returns: + The next child node instance, or None if nothing matched + the starting element. 
+ """ if var.is_clazz_union: return nodes.UnionNode( var=var, @@ -436,7 +582,24 @@ def build_element_node( derived_factory: Type, xsi_type: Optional[str] = None, xsi_nil: Optional[bool] = None, - ) -> Optional[XmlNode]: + ) -> Optional["ElementNode"]: + """Build the next element child node. + + Args: + clazz: The target class + derived: Whether derived elements should wrap the parsed object + nillable: Specifies whether nil content is allowed + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current length of the intermediate objects + derived_factory: The derived factory + xsi_type: The xml type substitution + xsi_nil: Specifies whether the node supports nillable content + + Returns: + The next child element node instance, or None if the + clazz doesn't match the starting element. + """ meta = self.context.fetch(clazz, self.meta.namespace, xsi_type) nillable = nillable or meta.nillable diff --git a/xsdata/formats/dataclass/parsers/nodes/primitive.py b/xsdata/formats/dataclass/parsers/nodes/primitive.py index a226e8003..4d3a41449 100644 --- a/xsdata/formats/dataclass/parsers/nodes/primitive.py +++ b/xsdata/formats/dataclass/parsers/nodes/primitive.py @@ -7,13 +7,13 @@ class PrimitiveNode(XmlNode): - """ - XmlNode for text elements with primitive values like str, int, float. + """XmlNode for text elements with simple type values. - :param var: Class field xml var instance - :param ns_map: Namespace prefix-URI map - :param mixed: The node supports mixed content - :param derived_factory: Derived element factory + Args: + var: The xml var instance + ns_map: The element namespace prefix-URI map + mixed: Specifies if this node supports mixed content + derived_factory: The derived element factory """ __slots__ = "var", "ns_map", "derived_factory" @@ -25,8 +25,27 @@ def __init__(self, var: XmlVar, ns_map: Dict, mixed: bool, derived_factory: Type self.mixed = mixed def bind( - self, qname: str, text: Optional[str], tail: Optional[str], objects: List + self, + qname: str, + text: Optional[str], + tail: Optional[str], + objects: List, ) -> bool: + """Bind the parsed data into an object for the ending element. + + This entry point is called when a xml element ends and is + responsible to parse the current element attributes/text/tail + content. + + Args: + qname: The element qualified name + text: The element text content + tail: The element tail content + objects: The list of intermediate parsed objects + + Returns: + Whether the binding process was successful or not. + """ obj = ParserUtils.parse_value( value=text, types=self.var.types, @@ -52,4 +71,5 @@ def bind( return True def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: + """Raise an exception if there is a child element inside this node.""" raise XmlContextError("Primitive node doesn't support child nodes!") diff --git a/xsdata/formats/dataclass/parsers/nodes/standard.py b/xsdata/formats/dataclass/parsers/nodes/standard.py index 01e6a9e91..362d1dcde 100644 --- a/xsdata/formats/dataclass/parsers/nodes/standard.py +++ b/xsdata/formats/dataclass/parsers/nodes/standard.py @@ -7,13 +7,13 @@ class StandardNode(XmlNode): - """ - XmlNode for any type elements with a standard xsi:type. + """XmlNode for elements with a standard xsi:type. 
- :param datatype: Standard xsi data type - :param ns_map: Namespace prefix-URI map - :param nillable: Specify whether the node supports nillable content - :param derived_factory: Optional derived element factory + Args: + datatype: The element standard xsi data type + ns_map: The element namespace prefix-URI map + nillable: Specifies whether nil content is allowed + derived_factory: The derived element factory """ __slots__ = "datatype", "ns_map", "nillable", "derived_factory" @@ -31,8 +31,27 @@ def __init__( self.derived_factory = derived_factory def bind( - self, qname: str, text: Optional[str], tail: Optional[str], objects: List + self, + qname: str, + text: Optional[str], + tail: Optional[str], + objects: List, ) -> bool: + """Bind the parsed data into an object for the ending element. + + This entry point is called when a xml element ends and is + responsible to parse the current element text content. + + Args: + qname: The element qualified name + text: The element text content + tail: The element tail content + objects: The list of intermediate parsed objects + + Returns: + Always true, it's not possible to fail during parsing + for this node. + """ obj = ParserUtils.parse_value( value=text, types=[self.datatype.type], @@ -53,4 +72,5 @@ def bind( return True def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: - raise XmlContextError("Primitive node doesn't support child nodes!") + """Raise an exception if there is a child element inside this node.""" + raise XmlContextError("StandardNode node doesn't support child nodes!") diff --git a/xsdata/formats/dataclass/parsers/nodes/union.py b/xsdata/formats/dataclass/parsers/nodes/union.py index 9f88b0127..d22604de6 100644 --- a/xsdata/formats/dataclass/parsers/nodes/union.py +++ b/xsdata/formats/dataclass/parsers/nodes/union.py @@ -14,20 +14,20 @@ class UnionNode(XmlNode): - """ - XmlNode for fields with multiple possible types where at least one of them - is a dataclass. + """XmlNode for union fields with at least one data class. The node will record all child events and in the end will replay them and try to build all possible objects and sort them by score before deciding the winner. - :param var: Class field xml var instance - :param attrs: Key-value attribute mapping - :param ns_map: Namespace prefix-URI map - :param position: The node position of objects cache - :param config: Parser configuration - :param context: Model context provider + Args: + var: The xml var instance + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current objects length, everything after + this position are considered children of this node. + config: The parser config instance + context: The xml context instance """ __slots__ = ( @@ -60,13 +60,50 @@ def __init__( self.events: List[Tuple[str, str, Any, Any]] = [] def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: + """Record the event for the child element. + + This entry point records all events, as it's not possible + to detect the target parsed object type just yet. When + this node ends, it will replay all events and attempt + to find the best matching type for the parsed object. 
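# --- Editor's illustrative sketch (not part of the patch) ---
# A hypothetical model that would be routed through UnionNode: the field
# accepts either a data class or a primitive, so the parser records every
# child event and later replays it against each candidate type.
from dataclasses import dataclass, field
from typing import Union


@dataclass
class Price:
    amount: float = 0.0


@dataclass
class Item:
    value: Union[Price, int] = field(default=0, metadata={"type": "Element"})
# --- end sketch ---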
+ + Args: + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current length of the intermediate objects + """ self.level += 1 self.events.append(("start", qname, copy.deepcopy(attrs), ns_map)) return self def bind( - self, qname: str, text: Optional[str], tail: Optional[str], objects: List + self, + qname: str, + text: Optional[str], + tail: Optional[str], + objects: List, ) -> bool: + """Bind the parsed data into an object for the ending element. + + This entry point is called when a xml element ends and is + responsible to replay all xml events and parse/bind all + the children objects. + + Args: + qname: The element qualified name + text: The element text content + tail: The element tail content + objects: The list of intermediate parsed objects + + Returns: + Always returns true, if the binding process fails + it raises an exception. + + Raises: + ParserError: If none of the candidate types matched + the replayed events. + """ self.events.append(("end", qname, text, tail)) if self.level > 0: @@ -98,8 +135,15 @@ def bind( raise ParserError(f"Failed to parse union node: {self.var.qname}") def parse_class(self, clazz: Type[T]) -> Optional[T]: - """Initialize a new XmlParser and try to parse the given element, treat - converter warnings as errors and return None.""" + """Replay the recorded events and attempt to build the target class. + + Args: + clazz: The target class + + Returns: + The target class instance or None if the recorded + xml events didn't fit the class. + """ try: with warnings.catch_warnings(): warnings.filterwarnings("error", category=ConverterWarning) @@ -112,7 +156,16 @@ def parse_class(self, clazz: Type[T]) -> Optional[T]: return None def parse_value(self, value: Any, types: List[Type]) -> Any: - """Parse simple values, treat warnings as errors and return None.""" + """Parse simple values. + + Args: + value: The xml value + types: The list of the candidate simple types + + Returns: + The parsed value or None if value didn't match + with any of the given types. + """ try: with warnings.catch_warnings(): warnings.filterwarnings("error", category=ConverterWarning) diff --git a/xsdata/formats/dataclass/parsers/nodes/wildcard.py b/xsdata/formats/dataclass/parsers/nodes/wildcard.py index 288b83318..4449c9d48 100644 --- a/xsdata/formats/dataclass/parsers/nodes/wildcard.py +++ b/xsdata/formats/dataclass/parsers/nodes/wildcard.py @@ -6,24 +6,30 @@ class WildcardNode(XmlNode): - """ - XmlNode for extensible elements that can hold any attribute and content. + """XmlNode for extensible elements that can hold any attribute and content. The resulting object tree will be a :class:`~xsdata.formats.dataclass.models.generics.AnyElement` instance. - :param var: Class field xml var instance - :param attrs: Key-value attribute mapping - :param ns_map: Namespace prefix-URI map - :param position: The node position of objects cache - :param factory: Wildcard element factory + Args: + var: The xml var instance + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current objects length, everything after + this position are considered children of this node. 
+ factory: The generic element factory """ __slots__ = "var", "attrs", "ns_map", "position", "factory" def __init__( - self, var: XmlVar, attrs: Dict, ns_map: Dict, position: int, factory: Type + self, + var: XmlVar, + attrs: Dict, + ns_map: Dict, + position: int, + factory: Type, ): self.var = var self.attrs = attrs @@ -34,6 +40,22 @@ def __init__( def bind( self, qname: str, text: Optional[str], tail: Optional[str], objects: List ) -> bool: + """Bind the parsed data into a generic element. + + This entry point is called when a xml element ends and is + responsible to parse the current element attributes/text, bind + any children objects and initialize new generic element that + can fit any xml string. + + Args: + qname: The element qualified name + text: The element text content + tail: The element tail content + objects: The list of intermediate parsed objects + + Returns: + Whether the binding process was successful or not. + """ children = self.fetch_any_children(self.position, objects) attributes = ParserUtils.parse_any_attributes(self.attrs, self.ns_map) derived = self.var.derived or qname != self.var.qname @@ -57,7 +79,17 @@ def bind( @classmethod def fetch_any_children(cls, position: int, objects: List) -> List: - """Fetch the children of a wildcard node.""" + """Fetch the children of this node in the objects list. + + The children are removed from the objects list. + + Args: + position: The position of the objects when this node was created. + objects: The list of intermediate parsed objects + + Returns: + A list of parsed objects. + """ children = [value for _, value in objects[position:]] del objects[position:] @@ -65,6 +97,18 @@ def fetch_any_children(cls, position: int, objects: List) -> List: return children def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: + """Initialize the next child wildcard node to be queued, when an element starts. + + This entry point is responsible to create the next node type + with all the necessary information on how to bind the incoming + input data. Wildcard nodes always return wildcard children. + + Args: + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current length of the intermediate objects + """ return WildcardNode( position=position, var=self.var, diff --git a/xsdata/formats/dataclass/parsers/nodes/wrapper.py b/xsdata/formats/dataclass/parsers/nodes/wrapper.py index 1763c1dc1..ec65942d5 100644 --- a/xsdata/formats/dataclass/parsers/nodes/wrapper.py +++ b/xsdata/formats/dataclass/parsers/nodes/wrapper.py @@ -5,10 +5,20 @@ class WrapperNode(XmlNode): - """ - XmlNode to wrap an element or primitive list. + """XmlNode for wrapper class fields. + + This node represents wrap class fields, that + don't actually appear in the serialized document. + + These fields simplify classes and this kind of + node simply proxies the child requests to the parent + node. + + Args: + parent: The parent node - :param parent: The parent node + Attributes: + ns_map: The node namespace prefix-URI map """ def __init__(self, parent: ElementNode): @@ -18,7 +28,29 @@ def __init__(self, parent: ElementNode): def bind( self, qname: str, text: Optional[str], tail: Optional[str], objects: List ) -> bool: + """This node will never appear in the xml, so it never binds any data. 
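# --- Editor's illustrative sketch (not part of the patch) ---
# A hypothetical model using a wrapper field: the "items" element exists only
# in the document, so the parser proxies its children through a WrapperNode.
# The "wrapper" metadata key is assumed from xsdata's field conventions.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Basket:
    item: List[str] = field(
        default_factory=list,
        metadata={"type": "Element", "wrapper": "items"},
    )
# --- end sketch ---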
+ + Args: + qname: The element qualified name + text: The element text content + tail: The element tail content + objects: The list of intermediate parsed objects + + Returns: + Always false because no binding takes place. + """ return False def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: + """Proxy the next child node to the parent node. + + Args: + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + position: The current length of the intermediate objects + + Returns: + The child xml node instance. + """ return self.parent.child(qname, attrs, ns_map, position) diff --git a/xsdata/formats/dataclass/parsers/tree.py b/xsdata/formats/dataclass/parsers/tree.py index 68bdf720c..b8541b7ed 100644 --- a/xsdata/formats/dataclass/parsers/tree.py +++ b/xsdata/formats/dataclass/parsers/tree.py @@ -10,11 +10,7 @@ @dataclass class TreeParser(NodeParser): - """ - Bind xml nodes to a tree of AnyElement objects. - - :param handler: Override default XmlHandler - """ + """Bind xml nodes to a tree of AnyElement objects.""" handler: Type[XmlHandler] = field(default=default_handler()) @@ -27,6 +23,16 @@ def start( attrs: Dict, ns_map: Dict, ): + """Build and queue the XmlNode for the starting element. + + Args: + clazz: The target class type, auto locate if omitted + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + """ try: item = queue[-1] child = item.child(qname, attrs, ns_map, len(objects)) diff --git a/xsdata/formats/dataclass/parsers/utils.py b/xsdata/formats/dataclass/parsers/utils.py index df2bfcca2..295cac11a 100644 --- a/xsdata/formats/dataclass/parsers/utils.py +++ b/xsdata/formats/dataclass/parsers/utils.py @@ -8,18 +8,46 @@ class PendingCollection(UserList): + """An iterable implementation of parsed values. + + The values are parsed individually in the end we + need to convert it to a tuple or a list based on + the mutability setting of the data class. + + Args: + initlist: An initial list of values or None + factory: A callable factory for the values + when the node parse has finished. + + """ + def __init__(self, initlist: Optional[Iterable], factory: Optional[Callable]): super().__init__(initlist) self.factory = factory or list - def evaluate(self) -> Iterable: + def evaluate(self) -> Iterable[Any]: + """Evaluate the values factory and return the result. + + Returns: + A list or tuple or set of values + """ return self.factory(self.data) class ParserUtils: + """Random parser util functions.""" + @classmethod def xsi_type(cls, attrs: Dict, ns_map: Dict) -> Optional[str]: - """Parse the xsi:type attribute if present.""" + """Parse the xsi:type attribute value if present. + + Args: + attrs: The element attributes + ns_map: The element namespace prefix-URI map + + Returns: + The xsi:type attribute value or None + """ xsi_type = attrs.get(QNames.XSI_TYPE) if not xsi_type: return None @@ -29,6 +57,14 @@ def xsi_type(cls, attrs: Dict, ns_map: Dict) -> Optional[str]: @classmethod def xsi_nil(cls, attrs: Dict) -> Optional[bool]: + """Return whether xsi:nil attribute value. + + Args: + attrs: The element attributes + + Returns: + The bool value or None if it doesn't exist. 
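# --- Editor's illustrative sketch (not part of the patch) ---
# TreeParser binds any document into a tree of AnyElement objects; the
# from_string helper is assumed from the shared parser base class.
from xsdata.formats.dataclass.parsers.tree import TreeParser

tree = TreeParser().from_string("<root><child>text</child></root>")
# --- end sketch ---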
+ """ xsi_nil = attrs.get(QNames.XSI_NIL) return xsi_nil == constants.XML_TRUE if xsi_nil else None @@ -42,8 +78,20 @@ def parse_value( tokens_factory: Optional[Callable] = None, format: Optional[str] = None, ) -> Any: - """Convert xml string values to s python primitive type.""" - + """Convert a value to a python primitive type. + + Args: + value: A primitive value or a list of primitive values + types: An iterable of types to try to convert the value + default: The default value/factory if the given is None + ns_map: The element namespace prefix-URI map + tokens_factory: A callable factory for the converted values + if the element is derived from xs:NMTOKENS + format: The format argument for base64/hex values or dates. + + Returns: + The converted value or values. + """ if value is None: if callable(default): return default() if tokens_factory else None @@ -61,11 +109,16 @@ def parse_value( @classmethod def normalize_content(cls, value: Optional[str]) -> Optional[str]: - """ - Normalize element text or tail content. + """Normalize element text or tail content. If content is just whitespace return None, otherwise preserve the original content. + + Args: + value: The element content + + Returns: + The normalized content """ if value and value.strip(): return value @@ -73,14 +126,37 @@ def normalize_content(cls, value: Optional[str]) -> Optional[str]: return None @classmethod - def parse_any_attributes(cls, attrs: Dict, ns_map: Dict) -> Dict: + def parse_any_attributes( + cls, attrs: Dict[str, str], ns_map: Dict[Optional[str], str] + ) -> Dict[str, str]: + """Parse attributes with qname support. + + Example: + {"foo": "bar", "xsi:type": "my:type"} -> + {"foo": "bar", "xsi:type" "{http://someuri.com}type"} + + Args: + attrs: The element attributes + ns_map: The element namespace prefix-URI map + + Returns: + The parsed attributes with expanded namespace prefixes + """ return { key: cls.parse_any_attribute(value, ns_map) for key, value in attrs.items() } @classmethod def parse_any_attribute(cls, value: str, ns_map: Dict) -> str: - """Attempt to parse any attribute.""" + """Expand the value with the full namespace if it has a prefix. + + Args: + value: The attr value + ns_map: The element namespace prefix-URI map + + Returns: + The expanded value. + """ prefix, suffix = text.split(value) if prefix and prefix in ns_map and not suffix.startswith("//"): value = build_qname(ns_map[prefix], suffix) diff --git a/xsdata/formats/dataclass/parsers/xml.py b/xsdata/formats/dataclass/parsers/xml.py index bc3d3a184..1a77fa429 100644 --- a/xsdata/formats/dataclass/parsers/xml.py +++ b/xsdata/formats/dataclass/parsers/xml.py @@ -11,13 +11,15 @@ @dataclass class XmlParser(NodeParser): - """ - Default Xml parser for dataclasses. + """Default Xml parser for data classes. + + Args: + config: The parser config instance + context: The xml context instance + handler: The xml handler class - :param config: Parser configuration - :param context: Model context provider - :param handler: Override default XmlHandler - :ivar ms_map: The prefix-URI map generated during parsing + Attributes: + ns_map: The parsed namespace prefix-URI map """ handler: Type[XmlHandler] = field(default=default_handler()) @@ -25,19 +27,25 @@ class XmlParser(NodeParser): @dataclass class UserXmlParser(NodeParser): - """ - User Xml parser for dataclasses with hooks for emitting events to alter the - behavior when an elements starts or ends. 
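# --- Editor's illustrative sketch (not part of the patch) ---
# parse_any_attribute expands prefixed values with the element ns_map, as in
# the docstring example above; the namespace URI here is made up.
from xsdata.formats.dataclass.parsers.utils import ParserUtils

ns_map = {"my": "http://someuri.com"}
assert ParserUtils.parse_any_attribute("my:type", ns_map) == "{http://someuri.com}type"
# --- end sketch ---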
- - :param config: Parser configuration - :param context: Model context provider - :param handler: Override default XmlHandler - :ivar ms_map: The prefix-URI map generated during parsing - :ivar emit_cache: Qname to event name cache + """Xml parser for dataclasses with hooks to events. + + The event hooks allow custom parsers to inject custom + logic between the start/end element events. + + Args: + config: The parser config instance + context: The xml context instance + handler: The xml handler class + + Attributes: + ns_map: The parsed namespace prefix-URI map + hooks_cache: The hooks cache is used to avoid + inspecting the class for custom methods + on duplicate events. """ handler: Type[XmlHandler] = field(default=default_handler()) - emit_cache: Dict = field(init=False, default_factory=dict) + hooks_cache: Dict = field(init=False, default_factory=dict) def start( self, @@ -48,6 +56,18 @@ def start( attrs: Dict, ns_map: Dict, ): + """Build and queue the XmlNode for the starting element. + + Override to emit the start element event. + + Args: + clazz: The target class type, auto locate if omitted + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + attrs: The element attributes + ns_map: The element namespace prefix-URI map + """ super().start(clazz, queue, objects, qname, attrs, ns_map) self.emit_event(EventType.START, qname, attrs=attrs) @@ -59,31 +79,45 @@ def end( text: Optional[str], tail: Optional[str], ) -> bool: + """Parse the last xml node and bind any intermediate objects. + + Override to emit the end element event if the binding process + is successful. + + Args: + queue: The XmlNode queue list + objects: The list of all intermediate parsed objects + qname: The element qualified name + text: The element text content + tail: The element tail content + + Returns: + Whether the binding process was successful. + """ result = super().end(queue, objects, qname, text, tail) if result: self.emit_event(EventType.END, qname, obj=objects[-1][1]) return result def emit_event(self, event: str, name: str, **kwargs: Any): - """ - Propagate event to subclasses. + """Propagate event to subclasses. Match event and name to a subclass method and trigger it with any input keyword arguments. 
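# --- Editor's illustrative sketch (not part of the patch) ---
# A hypothetical UserXmlParser subclass: hook names combine the event type
# with the snake-cased local element name (e.g. "bookTitle" -> book_title),
# and they receive the keyword arguments emitted by start()/end() above.
from dataclasses import dataclass
from typing import Any, Dict

from xsdata.formats.dataclass.parsers.xml import UserXmlParser


@dataclass
class BookParser(UserXmlParser):
    def start_book_title(self, attrs: Dict):
        print("bookTitle started with", attrs)

    def end_book_title(self, obj: Any):
        print("bookTitle bound to", obj)
# --- end sketch ---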
Example:: - event=start, name={urn}bookTitle -> start_booking_title(**kwargs) - :param event: Event type start|end - :param name: Element qualified name - :param kwargs: Event keyword arguments + Args: + event: The event type start|end + name: The qualified name of the element + kwargs: Additional keyword arguments passed to the hooks """ key = (event, name) - if key not in self.emit_cache: + if key not in self.hooks_cache: method_name = f"{event}_{snake_case(local_name(name))}" - self.emit_cache[key] = getattr(self, method_name, None) + self.hooks_cache[key] = getattr(self, method_name, None) - method = self.emit_cache[key] + method = self.hooks_cache[key] if method: method(**kwargs) diff --git a/xsdata/formats/dataclass/serializers/code.py b/xsdata/formats/dataclass/serializers/code.py index da5d09bcd..34fbd1903 100644 --- a/xsdata/formats/dataclass/serializers/code.py +++ b/xsdata/formats/dataclass/serializers/code.py @@ -1,11 +1,10 @@ from dataclasses import dataclass, field from enum import Enum from io import StringIO -from typing import Any, List, Mapping, Set, TextIO, Type +from typing import Any, List, Mapping, Set, TextIO, Tuple, Type, Union from xsdata.formats.bindings import AbstractSerializer from xsdata.formats.dataclass.context import XmlContext -from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.utils import collections from xsdata.utils.objects import literal_value @@ -17,41 +16,42 @@ @dataclass class PycodeSerializer(AbstractSerializer): - """ - Pycode serializer for dataclasses. + """Pycode serializer for data class instances. - Return a python representation code of a model instance. + Generate python pretty representation code from a model instance. - :param config: Serializer configuration - :param context: Model context provider + Args: + context: The models context instance """ - config: SerializerConfig = field(default_factory=SerializerConfig) context: XmlContext = field(default_factory=XmlContext) def render(self, obj: object, var_name: str = "obj") -> str: - """ - Convert and return the given object tree as python representation code. + """Serialize the input model instance to python representation string. + + Args: + obj: The input model instance to serialize + var_name: The var name to assign the model instance - :param obj: The input dataclass instance - :param var_name: The var name to assign the model instance + Returns: + The serialized representation string. """ output = StringIO() self.write(output, obj, var_name) return output.getvalue() def write(self, out: TextIO, obj: Any, var_name: str): - """ - Write the given object tree to the output text stream. + """Write the given object to the output text stream. - :param out: The output stream - :param obj: The input dataclass instance - :param var_name: The var name to assign the model instance + Args: + out: The output text stream + obj: The input model instance to serialize + var_name: The var name to assign the model instance """ types: Set[Type] = set() tmp = StringIO() - for chunk in self.write_object(obj, 0, types): + for chunk in self.repr_object(obj, 0, types): tmp.write(chunk) imports = self.build_imports(types) @@ -63,6 +63,14 @@ def write(self, out: TextIO, obj: Any, var_name: str): @classmethod def build_imports(cls, types: Set[Type]) -> str: + """Build a list of imports from the given types. + + Args: + types: A set of types + + Returns: + The `from x import y` statements as string. 
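# --- Editor's illustrative sketch (not part of the patch) ---
# PycodeSerializer.render turns an instance into python source that rebuilds
# it, including the import block produced by build_imports; `Book` is a
# made-up model for illustration.
from dataclasses import dataclass

from xsdata.formats.dataclass.serializers.code import PycodeSerializer


@dataclass
class Book:
    title: str = ""


code = PycodeSerializer().render(Book(title="xsdata"), var_name="book")
# --- end sketch ---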
+ """ imports = set() for tp in types: module = tp.__module__ @@ -75,20 +83,40 @@ def build_imports(cls, types: Set[Type]) -> str: return "".join(sorted(set(imports))) - def write_object(self, obj: Any, level: int, types: Set[Type]): + def repr_object(self, obj: Any, level: int, types: Set[Type]): + """Write the given object as repr code. + + Args: + obj: The input object to serialize + level: The current object level + types: The parent object types + + Yields: + An iterator of the representation strings. + """ types.add(type(obj)) if collections.is_array(obj): - yield from self.write_array(obj, level, types) + yield from self.repr_array(obj, level, types) elif isinstance(obj, dict): - yield from self.write_mapping(obj, level, types) + yield from self.repr_mapping(obj, level, types) elif self.context.class_type.is_model(obj): - yield from self.write_class(obj, level, types) + yield from self.repr_model(obj, level, types) elif isinstance(obj, Enum): yield str(obj) else: yield literal_value(obj) - def write_array(self, obj: List, level: int, types: Set[Type]): + def repr_array(self, obj: Union[List, Set, Tuple], level: int, types: Set[Type]): + """Convert an iterable object to repr code. + + Args: + obj: A list, set, tuple instance + level: The current object level + types: The parent object types + + Yields: + An iterator of the representation strings. + """ if not obj: yield str(obj) return @@ -97,12 +125,22 @@ def write_array(self, obj: List, level: int, types: Set[Type]): yield "[\n" for val in obj: yield spaces * next_level - yield from self.write_object(val, next_level, types) + yield from self.repr_object(val, next_level, types) yield ",\n" yield f"{spaces * level}]" - def write_mapping(self, obj: Mapping, level: int, types: Set[Type]): + def repr_mapping(self, obj: Mapping, level: int, types: Set[Type]): + """Convert a map object to repr code. + + Args: + obj: A map instance + level: The current object level + types: The parent object types + + Yields: + An iterator of the representation strings. + """ if not obj: yield str(obj) return @@ -111,14 +149,24 @@ def write_mapping(self, obj: Mapping, level: int, types: Set[Type]): yield "{\n" for key, value in obj.items(): yield spaces * next_level - yield from self.write_object(key, next_level, types) + yield from self.repr_object(key, next_level, types) yield ": " - yield from self.write_object(value, next_level, types) + yield from self.repr_object(value, next_level, types) yield ",\n" yield f"{spaces * level}}}" - def write_class(self, obj: Any, level: int, types: Set[Type]): + def repr_model(self, obj: Any, level: int, types: Set[Type]): + """Convert a data model instance to repr code. + + Args: + obj: A map instance + level: The current object level + types: The parent object types + + Yields: + An iterator of the representation strings. 
+ """ yield f"{obj.__class__.__qualname__}(\n" next_level = level + 1 @@ -139,7 +187,7 @@ def write_class(self, obj: Any, level: int, types: Set[Type]): else: yield f"{spaces * next_level}{f.name}=" - yield from self.write_object(value, next_level, types) + yield from self.repr_object(value, next_level, types) index += 1 diff --git a/xsdata/formats/dataclass/serializers/config.py b/xsdata/formats/dataclass/serializers/config.py index e39f7af51..8642a46ae 100644 --- a/xsdata/formats/dataclass/serializers/config.py +++ b/xsdata/formats/dataclass/serializers/config.py @@ -1,56 +1,32 @@ +from dataclasses import dataclass from typing import Callable, Dict, Optional +@dataclass class SerializerConfig: - """ - Serializer configuration options. + """Serializer configuration options. - Some options are not applicable for both xml or json documents. + Not all options are applicable for both xml and json documents. - :param encoding: Text encoding - :param xml_version: XML Version number (1.0|1.1) - :param xml_declaration: Generate XML declaration - :param pretty_print: Enable pretty output - :param pretty_print_indent: Indentation string for each indent level - :param ignore_default_attributes: Ignore optional attributes with - default values - :param schema_location: xsi:schemaLocation attribute value - :param no_namespace_schema_location: xsi:noNamespaceSchemaLocation - attribute value - :param globalns: Dictionary containing global variables to extend or - overwrite for typing + Args: + encoding: Text encoding + xml_version: XML Version number (1.0|1.1) + xml_declaration: Generate XML declaration + pretty_print: Enable pretty output + pretty_print_indent: Indentation string for each indent level + ignore_default_attributes: Ignore optional attributes with default values + schema_location: xsi:schemaLocation attribute value + no_namespace_schema_location: xsi:noNamespaceSchemaLocation attribute value + globalns: Dictionary containing global variables to extend or + overwrite for typing """ - __slots__ = ( - "encoding", - "xml_version", - "xml_declaration", - "pretty_print", - "pretty_print_indent", - "ignore_default_attributes", - "schema_location", - "no_namespace_schema_location", - "globalns", - ) - - def __init__( - self, - encoding: str = "UTF-8", - xml_version: str = "1.0", - xml_declaration: bool = True, - pretty_print: bool = False, - pretty_print_indent: Optional[str] = None, - ignore_default_attributes: bool = False, - schema_location: Optional[str] = None, - no_namespace_schema_location: Optional[str] = None, - globalns: Optional[Dict[str, Callable]] = None, - ): - self.encoding = encoding - self.xml_version = xml_version - self.xml_declaration = xml_declaration - self.pretty_print = pretty_print - self.pretty_print_indent = pretty_print_indent - self.ignore_default_attributes = ignore_default_attributes - self.schema_location = schema_location - self.no_namespace_schema_location = no_namespace_schema_location - self.globalns = globalns + encoding: str = "UTF-8" + xml_version: str = "1.0" + xml_declaration: bool = True + pretty_print: bool = False + pretty_print_indent: Optional[str] = None + ignore_default_attributes: bool = False + schema_location: Optional[str] = None + no_namespace_schema_location: Optional[str] = None + globalns: Optional[Dict[str, Callable]] = None diff --git a/xsdata/formats/dataclass/serializers/json.py b/xsdata/formats/dataclass/serializers/json.py index cc4c6e5cc..8a355f262 100644 --- a/xsdata/formats/dataclass/serializers/json.py +++ 
b/xsdata/formats/dataclass/serializers/json.py @@ -1,5 +1,4 @@ import json -import warnings from dataclasses import dataclass, field from enum import Enum from io import StringIO @@ -14,6 +13,14 @@ def filter_none(x: Tuple) -> Dict: + """Convert a key-value pairs to dict, ignoring None values. + + Args: + x: Key-value pairs + + Returns: + The filtered dictionary. + """ return {k: v for k, v in x if v is not None} @@ -25,75 +32,90 @@ class DictFactory: @dataclass class JsonSerializer(AbstractSerializer): - """ - Json serializer for dataclasses. - - :param config: Serializer configuration - :param context: Model context provider - :param dict_factory: Override default dict factory to add further - logic - :param dump_factory: Override default json.dump call with another - implementation - :param indent: Output indentation level + """Json serializer for data classes. + + Args: + config: The serializer config instance + context: The models context instance + dict_factory: Dictionary factory + dump_factory: Json dump factory e.g. json.dump """ config: SerializerConfig = field(default_factory=SerializerConfig) context: XmlContext = field(default_factory=XmlContext) dict_factory: Callable = field(default=dict) dump_factory: Callable = field(default=json.dump) - indent: Optional[int] = field(default=None) - def render(self, obj: object) -> str: - """Convert the given object tree to json string.""" + def render(self, obj: Any) -> str: + """Serialize the input model instance to json string. + + Args: + obj: The input model instance + + Returns: + The serialized json string output. + """ output = StringIO() self.write(output, obj) return output.getvalue() def write(self, out: TextIO, obj: Any): - """ - Write the given object tree to the output text stream. + """Serialize the given object to the output text stream. - :param out: The output stream - :param obj: The input dataclass instance + Args: + out: The output text stream + obj: The input model instance to serialize """ indent: Optional[Union[int, str]] = None - if self.indent: - warnings.warn( - "JsonSerializer indent property is deprecated, use SerializerConfig", - DeprecationWarning, - ) - indent = self.indent - elif self.config.pretty_print: + if self.config.pretty_print: indent = self.config.pretty_print_indent or 2 self.dump_factory(self.convert(obj), out, indent=indent) - def convert(self, obj: Any, var: Optional[XmlVar] = None) -> Any: - if var is None or self.context.class_type.is_model(obj): - if collections.is_array(obj): - return [self.convert(o) for o in obj] + def convert(self, value: Any, var: Optional[XmlVar] = None) -> Any: + """Convert a value to json serializable object. + + Args: + value: The input value + var: The xml var instance + + Returns: + The converted json serializable value. 
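# --- Editor's illustrative sketch (not part of the patch) ---
# With SerializerConfig converted to a dataclass earlier in this patch, the
# options become plain keywords; pretty_print switches the json indent to
# pretty_print_indent or 2, and `Note` is a made-up model for illustration.
from dataclasses import dataclass

from xsdata.formats.dataclass.serializers.config import SerializerConfig
from xsdata.formats.dataclass.serializers.json import JsonSerializer


@dataclass
class Note:
    title: str = ""


serializer = JsonSerializer(config=SerializerConfig(pretty_print=True))
print(serializer.render(Note(title="xsdata")))
# --- end sketch ---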
+ """ + if var is None or self.context.class_type.is_model(value): + if collections.is_array(value): + return list(map(self.convert, value)) - return self.dict_factory(self.next_value(obj)) + return self.dict_factory(self.next_value(value)) - if collections.is_array(obj): - return type(obj)(self.convert(v, var) for v in obj) + if collections.is_array(value): + return type(value)(self.convert(val, var) for val in value) - if isinstance(obj, (dict, int, float, str, bool)): - return obj + if isinstance(value, (dict, int, float, str, bool)): + return value - if isinstance(obj, Enum): - return self.convert(obj.value, var) + if isinstance(value, Enum): + return self.convert(value.value, var) - return converter.serialize(obj, format=var.format) + return converter.serialize(value, format=var.format) def next_value(self, obj: Any) -> Iterator[Tuple[str, Any]]: + """Fetch the next value of a model instance to convert. + + Args: + obj: The input model instance + + Yields: + An iterator of field name and value tuples. + """ ignore_optionals = self.config.ignore_default_attributes + meta = self.context.build(obj.__class__, globalns=self.config.globalns) - for var in self.context.build( - obj.__class__, globalns=self.config.globalns - ).get_all_vars(): + for var in meta.get_all_vars(): value = getattr(obj, var.name) - if var.is_attribute and ignore_optionals and var.is_optional(value): - continue - - yield var.local_name, self.convert(value, var) + if ( + not var.is_attribute + or not ignore_optionals + or not var.is_optional(value) + ): + yield var.local_name, self.convert(value, var) diff --git a/xsdata/formats/dataclass/serializers/mixins.py b/xsdata/formats/dataclass/serializers/mixins.py index a7d19200c..b2976a0ec 100644 --- a/xsdata/formats/dataclass/serializers/mixins.py +++ b/xsdata/formats/dataclass/serializers/mixins.py @@ -1,7 +1,21 @@ -from typing import Any, Dict, Generator, List, Optional, TextIO, Tuple +import abc +from typing import ( + Any, + Dict, + Final, + Iterator, + List, + Literal, + Optional, + TextIO, + Tuple, + Union, +) from xml.etree.ElementTree import QName from xml.sax.handler import ContentHandler +from typing_extensions import TypeAlias + from xsdata.exceptions import XmlWriterError from xsdata.formats.converter import converter from xsdata.formats.dataclass.serializers.config import SerializerConfig @@ -13,32 +27,45 @@ class XmlWriterEvent: - START = "start" - ATTR = "attr" - DATA = "data" - END = "end" + """Event names.""" + START: Final = "start" + ATTR: Final = "attr" + DATA: Final = "data" + END: Final = "end" -class XmlWriter: - """ - A consistency wrapper for sax content handlers. - - - Implements a custom sax-like event api with separate start - element/attribute events. - - Buffers events until all content has been received or a child - element is starting in order to build the current element's - namespace context correctly. - - Prepares values for serialization. - - :param config: Configuration instance - :param output: Output text stream - :param ns_map: User defined namespace prefix-URI map + +StartEvent: TypeAlias = Tuple[Literal["start"], str] +AttrEvent: TypeAlias = Tuple[Literal["attr"], str, Any] +DataEvent: TypeAlias = Tuple[Literal["data"], str] +EndEvent: TypeAlias = Tuple[Literal["end"], str] + +EventIterator = Iterator[Union[StartEvent, AttrEvent, DataEvent, EndEvent]] + + +class XmlWriter(abc.ABC): + """A consistency wrapper for sax content handlers. 
+ + Args: + config: The serializer config instance + output: The output stream to write the result + ns_map: A user defined namespace prefix-URI map + + Attributes: + handler: The content handler instance + in_tail: Specifies whether the text content has been written + tail: The current element tail content + attrs: The current element attributes + ns_context: The namespace context queue + pending_tag: The pending element namespace, name tuple + pending_prefixes: The pending element namespace prefixes """ __slots__ = ( "config", "output", "ns_map", + # Instance attributes "handler", "in_tail", "tail", @@ -64,21 +91,27 @@ def __init__( self.ns_context: List[Dict] = [] self.pending_tag: Optional[Tuple] = None self.pending_prefixes: List[List] = [] - self.handler: ContentHandler + self.handler = self.build_handler() + + @abc.abstractmethod + def build_handler(self) -> ContentHandler: + """Build the content handler instance. - def write(self, events: Generator): + Returns: + A content handler instance. """ - Iterate over the generator events and feed the sax content handler with - the information needed to generate the xml output. - Example:: + def write(self, events: EventIterator): + """Feed the sax content handler with events. + + The receiver will also add additional root attributes + like xsi or no namespace location. - (XmlWriterEvent.START, "{http://www.w3.org/1999/xhtml}p"), - (XmlWriterEvent.ATTR, "class", "paragraph"), - (XmlWriterEvent.DATA, "Hello"), - (XmlWriterEvent.END, "{http://www.w3.org/1999/xhtml}p"), + Args: + events: An iterator of sax events - :param events: Events generator + Raises: + XmlWriterError: On unknown events. """ self.start_document() @@ -86,44 +119,48 @@ def write(self, events: Generator): self.add_attribute( QNames.XSI_SCHEMA_LOCATION, self.config.schema_location, - check_pending=False, + root=True, ) if self.config.no_namespace_schema_location: self.add_attribute( QNames.XSI_NO_NAMESPACE_SCHEMA_LOCATION, self.config.no_namespace_schema_location, - check_pending=False, + root=True, ) - for event, *args in events: - if event == XmlWriterEvent.START: + for name, *args in events: + if name == XmlWriterEvent.START: self.start_tag(*args) - elif event == XmlWriterEvent.END: + elif name == XmlWriterEvent.END: self.end_tag(*args) - elif event == XmlWriterEvent.ATTR: + elif name == XmlWriterEvent.ATTR: self.add_attribute(*args) - elif event == XmlWriterEvent.DATA: + elif name == XmlWriterEvent.DATA: self.set_data(*args) else: - raise XmlWriterError(f"Unhandled event: `{event}`") + raise XmlWriterError(f"Unhandled event: `{name}`") self.handler.endDocument() def start_document(self): - """Start document notification receiver.""" + """Start document notification receiver. + + Write the xml version and encoding, if the + configuration is enabled. + """ if self.config.xml_declaration: self.output.write(f'\n') def start_tag(self, qname: str): - """ - Start tag notification receiver. + """Start tag notification receiver. The receiver will flush the start of any pending element, create new namespaces context and queue the current tag for generation. - :param qname: Tag qualified name + Args: + qname: The qualified name of the starting element """ self.flush_start(False) @@ -133,44 +170,44 @@ def start_tag(self, qname: str): self.pending_tag = split_qname(qname) self.add_namespace(self.pending_tag[0]) - def add_attribute(self, key: str, value: Any, check_pending: bool = True): - """ - Add attribute notification receiver. 
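# --- Editor's illustrative sketch (not part of the patch) ---
# The event stream fed to XmlWriter.write, matching the StartEvent/AttrEvent/
# DataEvent/EndEvent aliases above; this restores the example dropped from
# the old write() docstring.
from xsdata.formats.dataclass.serializers.mixins import XmlWriterEvent

events = [
    (XmlWriterEvent.START, "{http://www.w3.org/1999/xhtml}p"),
    (XmlWriterEvent.ATTR, "class", "paragraph"),
    (XmlWriterEvent.DATA, "Hello"),
    (XmlWriterEvent.END, "{http://www.w3.org/1999/xhtml}p"),
]
# --- end sketch ---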
+ def add_attribute(self, qname: str, value: Any, root: bool = False): + """Add attribute notification receiver. The receiver will convert the key to a namespace, name tuple and convert the value to string. Internally the converter will also generate any missing namespace prefixes. - :param key: Attribute name - :param value: Attribute value - :param check_pending: Raise exception if not no element is - pending start + Args: + qname: The qualified name of the attribute + value: The value of the attribute + root: Specifies if attribute is for the root element + + Raises: + XmlWriterError: If it's not a root element attribute + and not no element is pending to start. """ - if not self.pending_tag and check_pending: + if not self.pending_tag and not root: raise XmlWriterError("Empty pending tag.") - if self.is_xsi_type(key, value): + if self.is_xsi_type(qname, value): value = QName(value) - name = split_qname(key) - self.attrs[name] = self.encode_data(value) + name_tuple = split_qname(qname) + self.attrs[name_tuple] = self.encode_data(value) def add_namespace(self, uri: Optional[str]): - """ - Add the given uri to the current namespace context if the uri is valid - and new. + """Add the given uri to the current namespace context. - The prefix will be auto generated if it doesn't exist in the - prefix-URI mappings. + If the uri empty or a prefix already exists, skip silently. - :param uri: Namespace uri + Args: + uri: The namespace URI """ if uri and not prefix_exists(uri, self.ns_map): generate_prefix(uri, self.ns_map) def set_data(self, data: Any): - """ - Set data notification receiver. + """Set data notification receiver. The receiver will convert the data to string, flush any previous pending start element and send it to the handler for generation. @@ -179,7 +216,8 @@ def set_data(self, data: Any): treat the current data as element tail content and queue it to be generated when the tag ends. - :param data: Element text or tail content + Args: + data: The element text or tail content """ value = self.encode_data(data) self.flush_start(is_nil=value is None) @@ -193,14 +231,14 @@ def set_data(self, data: Any): self.in_tail = True def end_tag(self, qname: str): - """ - End tag notification receiver. + """End tag notification receiver. The receiver will flush if pending the start of the element, end the element, its tail content and its namespaces prefix mapping and current context. - :param qname: Tag qualified name + Args: + qname: The qualified name of the element """ self.flush_start(True) self.handler.endElementNS(split_qname(qname), "") @@ -218,16 +256,16 @@ def end_tag(self, qname: str): self.handler.endPrefixMapping(prefix) def flush_start(self, is_nil: bool = True): - """ - Flush start notification receiver. + """Flush start notification receiver. The receiver will pop the xsi:nil attribute if the element is - not empty, prepare and send the namespaces prefix mappings and + not empty, prepare and send the namespace prefix-URI map and the element with its attributes to the content handler for generation. - :param is_nil: If true add ``xsi:nil="true"`` to the element - attributes + Args: + is_nil: Specify if the element requires `xsi:nil="true"` + when content is empty """ if not self.pending_tag: return @@ -247,9 +285,7 @@ def flush_start(self, is_nil: bool = True): self.pending_tag = None def start_namespaces(self): - """ - Send the new prefixes and namespaces added in the current context to - the content handler. + """Send the current namespace prefix-URI map to the content handler. 
Save the list of prefixes to be removed at the end of the current pending tag. @@ -268,27 +304,35 @@ def start_namespaces(self): self.handler.startPrefixMapping(prefix, uri) def reset_default_namespace(self): - """Reset the default namespace if exists and the current pending tag is - not qualified.""" + """Reset the default namespace if the pending element is not qualified.""" if self.pending_tag and not self.pending_tag[0] and None in self.ns_map: self.ns_map[None] = "" @classmethod - def is_xsi_type(cls, key: str, value: Any) -> bool: - """ - Return whether the value is an xsi:type or not based on the given - attribute name/value. + def is_xsi_type(cls, qname: str, value: Any) -> bool: + """Return whether the value is a xsi:type. + + Args: + qname: The attribute qualified name + value: The attribute value - :param key: Attribute name - :param value: Attribute value + Returns: + The bool result. """ if isinstance(value, str) and value.startswith("{"): - return key == QNames.XSI_TYPE or DataType.from_qname(value) is not None + return qname == QNames.XSI_TYPE or DataType.from_qname(value) is not None return False def encode_data(self, data: Any) -> Optional[str]: - """Encode data for xml rendering.""" + """Encode data for xml rendering. + + Args: + data: The content to encode/serialize + + Returns: + The xml encoded data + """ if data is None or isinstance(data, str): return data diff --git a/xsdata/formats/dataclass/serializers/writers/__init__.py b/xsdata/formats/dataclass/serializers/writers/__init__.py index ae1edac85..de79b5142 100644 --- a/xsdata/formats/dataclass/serializers/writers/__init__.py +++ b/xsdata/formats/dataclass/serializers/writers/__init__.py @@ -7,12 +7,18 @@ from xsdata.formats.dataclass.serializers.writers.lxml import LxmlEventWriter def default_writer() -> Type[XmlWriter]: + """Return the default xml writer.""" return LxmlEventWriter except ImportError: # pragma: no cover def default_writer() -> Type[XmlWriter]: + """Return the default xml writer.""" return XmlEventWriter -__all__ = ["LxmlEventWriter", "XmlEventWriter", "default_writer"] +__all__ = [ + "LxmlEventWriter", + "XmlEventWriter", + "default_writer", +] diff --git a/xsdata/formats/dataclass/serializers/writers/lxml.py b/xsdata/formats/dataclass/serializers/writers/lxml.py index af949095e..a9360186c 100644 --- a/xsdata/formats/dataclass/serializers/writers/lxml.py +++ b/xsdata/formats/dataclass/serializers/writers/lxml.py @@ -1,36 +1,53 @@ -from typing import Dict, Generator, TextIO +from typing import Iterator from lxml.etree import indent, tostring from lxml.sax import ElementTreeContentHandler -from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.formats.dataclass.serializers.mixins import XmlWriter class LxmlEventWriter(XmlWriter): + """Xml event writer based on `lxml.sax.ElementTreeContentHandler`. + + The writer converts the events to an lxml tree which is + then converted to string. 
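# --- Editor's illustrative sketch (not part of the patch) ---
# default_writer resolves to LxmlEventWriter when lxml is importable and
# falls back to the native XmlEventWriter otherwise.
from xsdata.formats.dataclass.serializers.writers import default_writer

writer_class = default_writer()
# --- end sketch ---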
+ + Args: + config: The serializer config instance + output: The output stream to write the result + ns_map: A user defined namespace prefix-URI map + + Attributes: + handler: The content handler instance + in_tail: Specifies whether the text content has been written + tail: The current element tail content + attrs: The current element attributes + ns_context: The namespace context queue + pending_tag: The pending element namespace, name tuple + pending_prefixes: The pending element namespace prefixes """ - :class:`~xsdata.formats.dataclass.serializers.mixins.XmlWriter` - implementation based on lxml package. - - Based on the :class:`lxml.sax.ElementTreeContentHandler`, converts - sax events to an lxml ElementTree, serialize and write the result - to the output stream. Despite that since it's lxml it's still - pretty fast and has better support for special characters and - encodings than native python. - - :param config: Configuration instance - :param output: Output text stream - :param ns_map: User defined namespace prefix-URI map - """ - __slots__ = () + def build_handler(self) -> ElementTreeContentHandler: + """Build the content handler instance. + + Returns: + An element tree content handler instance. + """ + return ElementTreeContentHandler() + + def write(self, events: Iterator): + """Feed the sax content handler with events. - def __init__(self, config: SerializerConfig, output: TextIO, ns_map: Dict): - super().__init__(config, output, ns_map) + The receiver will also add additional root attributes + like xsi or no namespace location. In the end convert + the handler etree to string based on the configuration. - self.handler = ElementTreeContentHandler() + Args: + events: An iterator of sax events - def write(self, events: Generator): + Raises: + XmlWriterError: On unknown events. + """ super().write(events) assert isinstance(self.handler, ElementTreeContentHandler) diff --git a/xsdata/formats/dataclass/serializers/writers/native.py b/xsdata/formats/dataclass/serializers/writers/native.py index 31a68ee00..e407b9a7c 100644 --- a/xsdata/formats/dataclass/serializers/writers/native.py +++ b/xsdata/formats/dataclass/serializers/writers/native.py @@ -6,36 +6,58 @@ class XmlEventWriter(XmlWriter): - """ - :class:`~xsdata.formats.dataclass.serializers.mixins.XmlWriter` - implementation based on native python. + """Xml event writer based on `xml.sax.saxutils.XMLGenerator`. + + The writer converts sax events directly to xml output + without storing any intermediate results in memory. - Based on the native python :class:`xml.sax.saxutils.XMLGenerator` - with support for indentation. Converts sax events directly to xml - output without storing intermediate result to memory. 
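A small sketch of choosing a writer explicitly, based on the `default_writer` fallback shown above; the serializer accepts any `XmlWriter` subclass.

```python
from xsdata.formats.dataclass.serializers import XmlSerializer
from xsdata.formats.dataclass.serializers.writers import (
    XmlEventWriter,
    default_writer,
)

# default_writer() resolves to LxmlEventWriter when lxml is importable,
# otherwise to the pure-python XmlEventWriter.
serializer = XmlSerializer(writer=default_writer())

# Forcing the native writer is just a matter of passing the class.
native = XmlSerializer(writer=XmlEventWriter)
```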
+ Args: + config: The serializer config instance + output: The output stream to write the result + ns_map: A user defined namespace prefix-URI map - :param config: Configuration instance - :param output: Output text stream - :param ns_map: User defined namespace prefix-URI map + Attributes: + handler: The content handler instance + in_tail: Specifies whether the text content has been written + tail: The current element tail content + attrs: The current element attributes + ns_context: The namespace context queue + pending_tag: The pending element namespace, name tuple + pending_prefixes: The pending element namespace prefixes """ __slots__ = ("current_level", "pending_end_element") def __init__(self, config: SerializerConfig, output: TextIO, ns_map: Dict): - """ - :param config: Configuration instance - :param output: Output text stream - :param ns_map: User defined namespace prefix-URI map - """ super().__init__(config, output, ns_map) self.current_level = 0 self.pending_end_element = False - self.handler = XMLGenerator( - out=self.output, encoding=self.config.encoding, short_empty_elements=True + + def build_handler(self) -> XMLGenerator: + """Build the content handler instance. + + Returns: + A xml generator content handler instance. + """ + return XMLGenerator( + out=self.output, + encoding=self.config.encoding, + short_empty_elements=True, ) def start_tag(self, qname: str): + """Start tag notification receiver. + + The receiver will flush the start of any pending element, create + new namespaces context and queue the current tag for generation. + + The receiver will also write the necessary whitespace if + pretty print is enabled. + + Args: + qname: The qualified name of the starting element + """ super().start_tag(qname) if self.config.pretty_print: @@ -49,6 +71,18 @@ def start_tag(self, qname: str): self.pending_end_element = False def end_tag(self, qname: str): + """End tag notification receiver. + + The receiver will flush if pending the start of the element, end + the element, its tail content and its namespaces prefix mapping + and current context. + + The receiver will also write the necessary whitespace if + pretty print is enabled. + + Args: + qname: The qualified name of the element + """ if not self.config.pretty_print: super().end_tag(qname) return diff --git a/xsdata/formats/dataclass/serializers/xml.py b/xsdata/formats/dataclass/serializers/xml.py index 7e256c71b..2425b4dc4 100644 --- a/xsdata/formats/dataclass/serializers/xml.py +++ b/xsdata/formats/dataclass/serializers/xml.py @@ -4,7 +4,6 @@ from typing import ( Any, Dict, - Generator, Iterable, Iterator, List, @@ -21,23 +20,25 @@ from xsdata.formats.dataclass.context import XmlContext from xsdata.formats.dataclass.models.elements import XmlMeta, XmlVar from xsdata.formats.dataclass.serializers.config import SerializerConfig -from xsdata.formats.dataclass.serializers.mixins import XmlWriter, XmlWriterEvent +from xsdata.formats.dataclass.serializers.mixins import ( + EventIterator, + XmlWriter, + XmlWriterEvent, +) from xsdata.formats.dataclass.serializers.writers import default_writer from xsdata.models.enums import DataType, QNames from xsdata.utils import collections, namespaces from xsdata.utils.constants import EMPTY_MAP -NoneStr = Optional[str] - @dataclass class XmlSerializer(AbstractSerializer): - """ - Xml serializer for dataclasses. + """Xml serializer for data classes. 
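A hedged usage sketch for the serializer entry point, reusing the hypothetical `Greeting` model from the earlier sketch; `pretty_print` is the config flag the native writer checks above.

```python
from xsdata.formats.dataclass.serializers import XmlSerializer
from xsdata.formats.dataclass.serializers.config import SerializerConfig

# Greeting is the hypothetical model defined in the earlier sketch.
serializer = XmlSerializer(config=SerializerConfig(pretty_print=True))
xml = serializer.render(Greeting(), ns_map={None: "http://example.com/ns"})
```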
- :param config: Serializer configuration - :param context: Model context provider - :param writer: Override default XmlWriter + Args: + config: The serializer config instance + context: The models context instance + writer: The xml writer class """ config: SerializerConfig = field(default_factory=SerializerConfig) @@ -45,23 +46,26 @@ class XmlSerializer(AbstractSerializer): writer: Type[XmlWriter] = field(default=default_writer()) def render(self, obj: Any, ns_map: Optional[Dict] = None) -> str: - """ - Convert and return the given object tree as xml string. + """Serialize the input model instance to xml string. - :param obj: The input dataclass instance - :param ns_map: User defined namespace prefix-URI map + Args: + obj: The input model instance to serialize + ns_map: A user defined namespace prefix-URI map + + Returns: + The serialized xml string output. """ output = StringIO() self.write(output, obj, ns_map) return output.getvalue() def write(self, out: TextIO, obj: Any, ns_map: Optional[Dict] = None): - """ - Write the given object tree to the output text stream. + """Serialize the given object to the output text stream. - :param out: The output stream - :param obj: The input dataclass instance - :param ns_map: User defined namespace prefix-URI map + Args: + out: The output text stream + obj: The input model instance to serialize + ns_map: A user defined namespace prefix-URI map """ events = self.write_object(obj) handler = self.writer( @@ -71,8 +75,15 @@ def write(self, out: TextIO, obj: Any, ns_map: Optional[Dict] = None): ) handler.write(events) - def write_object(self, obj: Any): - """Produce an events stream from a dataclass or a derived element.""" + def write_object(self, obj: Any) -> EventIterator: + """Convert a user model, or derived element instance to sax events. + + Args: + obj: A user model, or derived element instance + + Yields: + An iterator of sax events. + """ qname = xsi_type = None if isinstance(obj, self.context.class_type.derived_element): meta = self.context.build( @@ -80,26 +91,37 @@ def write_object(self, obj: Any): ) qname = obj.qname obj = obj.value - xsi_type = namespaces.real_xsi_type(qname, meta.target_qname) + xsi_type = self.real_xsi_type(qname, meta.target_qname) yield from self.write_dataclass(obj, qname=qname, xsi_type=xsi_type) def write_dataclass( self, obj: Any, - namespace: NoneStr = None, - qname: NoneStr = None, + namespace: Optional[str] = None, + qname: Optional[str] = None, nillable: bool = False, xsi_type: Optional[str] = None, - ) -> Generator: - """ - Produce an events stream from a dataclass. + ) -> EventIterator: + """Convert a model instance to sax events. + + Optionally override the qualified name and the + xsi attributes type and nil. + + Args: + obj: A model instance + namespace: The field namespace URI + qname: Override the field qualified name + nillable: Specifies whether the field is nillable + xsi_type: Override the field xsi type - Optionally override the qualified name and the xsi properties - type and nil. + Yields: + An iterator of sax events. 
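The same serializer can also write straight to a text stream, as `write` above describes; a brief sketch, again using the hypothetical model.

```python
from pathlib import Path

from xsdata.formats.dataclass.serializers import XmlSerializer

serializer = XmlSerializer()
with Path("greeting.xml").open("w") as fp:
    serializer.write(fp, Greeting())  # Greeting: the hypothetical model above
```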
""" meta = self.context.build( - obj.__class__, namespace, globalns=self.config.globalns + obj.__class__, + namespace, + globalns=self.config.globalns, ) qname = qname or meta.qname nillable = nillable or meta.nillable @@ -117,10 +139,24 @@ def write_dataclass( yield XmlWriterEvent.END, qname - def write_xsi_type(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """Produce an events stream from a dataclass for the given var with xsi - abstract type check for non wildcards.""" + def write_xsi_type( + self, + value: Any, + var: XmlVar, + namespace: Optional[str], + ) -> EventIterator: + """Convert a xsi:type value to sax events. + + The value can be assigned to wildcard, element or compound fields + Args: + value: A model instance + var: The field metadata instance + namespace: The field namespace URI + + Yields: + An iterator of sax events. + """ if var.is_wildcard: choice = var.find_value_choice(value, True) if choice: @@ -130,25 +166,37 @@ def write_xsi_type(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generat elif var.is_element: xsi_type = self.xsi_type(var, value, namespace) yield from self.write_dataclass( - value, namespace, var.qname, var.nillable, xsi_type + value, + namespace, + var.qname, + var.nillable, + xsi_type, ) else: - # var elements + # var elements/compound meta = self.context.fetch(value.__class__, namespace) yield from self.write_dataclass(value, qname=meta.target_qname) - def write_value(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """ - Delegates the given value to the correct writer according to the - variable metadata. + def write_value( + self, value: Any, var: XmlVar, namespace: Optional[str] + ) -> EventIterator: + """Convert any value to sax events according to the var instance. The order of the checks is important as more than one condition can be true. + + Args: + value: The input value + var: The field metadata instance + namespace: The class namespace URI + + Yields: + An iterator of sax events. """ if var.mixed: yield from self.write_mixed_content(value, var, namespace) elif var.is_text: - yield from self.write_data(value, var, namespace) + yield from self.write_data(value, var) elif var.tokens: yield from self.write_tokens(value, var, namespace) elif var.is_elements: @@ -159,9 +207,21 @@ def write_value(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: yield from self.write_any_type(value, var, namespace) def write_list( - self, values: Iterable, var: XmlVar, namespace: NoneStr - ) -> Generator: - """Produce an events stream for the given list of values.""" + self, + values: Iterable, + var: XmlVar, + namespace: Optional[str], + ) -> EventIterator: + """Convert an array of values to sax events. + + Args: + values: A list, set, tuple instance + var: The field metadata instance + namespace: The class namespace + + Yields: + An iterator of sax events. + """ if var.wrapper is not None: yield XmlWriterEvent.START, var.wrapper for value in values: @@ -171,9 +231,19 @@ def write_list( for value in values: yield from self.write_value(value, var, namespace) - def write_tokens(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """Produce an events stream for the given tokens list or list of tokens - lists.""" + def write_tokens( + self, value: Any, var: XmlVar, namespace: Optional[str] + ) -> EventIterator: + """Convert an array of token values to sax events. 
+ + Args: + value: A list, set, tuple instance + var: The field metadata instance + namespace: The class namespace + + Yields: + An iterator of sax events. + """ if value or var.nillable or var.required: if value and collections.is_array(value[0]): for val in value: @@ -182,22 +252,39 @@ def write_tokens(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator yield from self.write_element(value, var, namespace) def write_mixed_content( - self, values: List, var: XmlVar, namespace: NoneStr - ) -> Generator: - """Produce an events stream for the given list of mixed type - objects.""" + self, + values: List, + var: XmlVar, + namespace: Optional[str], + ) -> EventIterator: + """Convert mixed content values to sax events. + + Args: + values: A list instance of mixed type values + var: The field metadata instance + namespace: The class namespace + + Yields: + An iterator of sax events. + """ for value in values: yield from self.write_any_type(value, var, namespace) - def write_any_type(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """ - Produce an events stream for the given object. + def write_any_type( + self, value: Any, var: XmlVar, namespace: Optional[str] + ) -> EventIterator: + """Convert a value assigned to a xs:anyType field to sax events. + + Args: + value: A list instance of mixed type values + var: The field metadata instance + namespace: The class namespace - The object can be a dataclass or a generic object or any other - simple type. + Yields: + An iterator of sax events. """ if isinstance(value, self.context.class_type.any_element): - yield from self.write_wildcard(value, var, namespace) + yield from self.write_any_element(value, var, namespace) elif isinstance(value, self.context.class_type.derived_element): yield from self.write_derived_element(value, namespace) elif self.context.class_type.is_model(value): @@ -205,13 +292,24 @@ def write_any_type(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generat elif var.is_element: yield from self.write_element(value, var, namespace) else: - yield from self.write_data(value, var, namespace) + yield from self.write_data(value, var) - def write_derived_element(self, value: Any, namespace: NoneStr) -> Generator: + def write_derived_element( + self, value: Any, namespace: Optional[str] + ) -> EventIterator: + """Convert a derived element instance to sax events. + + Args: + value: A list instance of mixed type values + namespace: The class namespace + + Yields: + An iterator of sax events. + """ if self.context.class_type.is_model(value.value): meta = self.context.fetch(value.value.__class__) qname = value.qname - xsi_type = namespaces.real_xsi_type(qname, meta.target_qname) + xsi_type = self.real_xsi_type(qname, meta.target_qname) yield from self.write_dataclass( value.value, namespace, qname=qname, xsi_type=xsi_type @@ -224,8 +322,19 @@ def write_derived_element(self, value: Any, namespace: NoneStr) -> Generator: yield XmlWriterEvent.DATA, value.value yield XmlWriterEvent.END, value.qname - def write_wildcard(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """Produce an element events stream for the given generic object.""" + def write_any_element( + self, value: Any, var: XmlVar, namespace: Optional[str] + ) -> EventIterator: + """Convert a generic any element instance to sax events. + + Args: + value: A list instance of mixed type values + var: The field metadata instance + namespace: The class namespace + + Yields: + An iterator of sax events. 
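For the wildcard path described above, a generic `AnyElement` is the kind of value an xs:any or xs:anyType field carries; a small sketch with hypothetical content.

```python
from xsdata.formats.dataclass.models.generics import AnyElement

# A generic element a wildcard field can carry; the serializer routes it
# through write_any_element.
note = AnyElement(
    qname="{http://example.com/ns}note",
    text="free form content",
    attributes={"lang": "en"},
)
```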
+ """ if value.qname: namespace, tag = namespaces.split_qname(value.qname) yield XmlWriterEvent.START, value.qname @@ -244,35 +353,69 @@ def write_wildcard(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generat if value.tail: yield XmlWriterEvent.DATA, value.tail - def xsi_type(self, var: XmlVar, value: Any, namespace: NoneStr) -> Optional[str]: - """Get xsi:type if the given value is a derived instance.""" + def xsi_type( + self, var: XmlVar, value: Any, namespace: Optional[str] + ) -> Optional[str]: + """Return the xsi:type for the given value and field metadata instance. + + If the value type is either a child or parent for one of the var types, + we need to declare it as n xsi:type. + + Args: + value: A list instance of mixed type values + var: The field metadata instance + namespace: The class namespace + + Raises: + SerializerError: If the value type is completely unrelated to + the field types. + """ if not value or value.__class__ in var.types: return None clazz = var.clazz if clazz is None or self.context.is_derived(value, clazz): meta = self.context.fetch(value.__class__, namespace) - return namespaces.real_xsi_type(var.qname, meta.target_qname) + return self.real_xsi_type(var.qname, meta.target_qname) raise SerializerError( f"{value.__class__.__name__} is not derived from {clazz.__name__}" ) - def write_elements(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """Produce an events stream from compound elements field.""" + def write_elements( + self, value: Any, var: XmlVar, namespace: Optional[str] + ) -> EventIterator: + """Convert the value assigned to a compound field to sax events. + + Args: + value: A list instance of mixed type values + var: The field metadata instance + namespace: The class namespace + + Yields: + An iterator of sax events. + """ if collections.is_array(value): - for choice in value: - yield from self.write_choice(choice, var, namespace) + for val in value: + yield from self.write_choice(val, var, namespace) else: yield from self.write_choice(value, var, namespace) - def write_choice(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """ - Produce an events stream for the given value of a compound elements - field. + def write_choice( + self, value: Any, var: XmlVar, namespace: Optional[str] + ) -> EventIterator: + """Convert a single value assigned to a compound field to sax events. + + Args: + value: A list instance of mixed type values + var: The field metadata instance + namespace: The class namespace - The value can be anything as long as we can match the qualified - name or its type to a choice. + Yields: + An iterator of sax events. + + Raises: + SerializerError: If the value doesn't match any choice field. """ if isinstance(value, self.context.class_type.derived_element): choice = var.find_choice(value.qname) @@ -302,8 +445,22 @@ def write_choice(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator yield from func(value, choice, namespace) - def write_element(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """Produce an element events stream for the given simple type value.""" + def write_element( + self, + value: Any, + var: XmlVar, + namespace: Optional[str], + ) -> EventIterator: + """Convert a value assigned to an element field to sax events. + + Args: + value: A list instance of mixed type values + var: The field metadata instance + namespace: The class namespace (unused) + + Yields: + An iterator of sax events. 
+ """ yield XmlWriterEvent.START, var.qname if var.nillable: @@ -318,19 +475,34 @@ def write_element(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generato yield XmlWriterEvent.END, var.qname @classmethod - def write_data(cls, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: - """Produce a data event for the given value.""" + def write_data(cls, value: Any, var: XmlVar) -> EventIterator: + """Convert a value assigned to a text field to sax events. + + Args: + value: A list instance of mixed type values + var: The field metadata instance + + Yields: + An iterator of sax events. + """ yield XmlWriterEvent.DATA, cls.encode(value, var) @classmethod def next_value(cls, obj: Any, meta: XmlMeta) -> Iterator[Tuple[XmlVar, Any]]: - """ - Return the non attribute variables with their object values in the - correct order according to their definition and the sequential metadata - property. + """Produce the next non attribute value of a model instance to convert. + + The generator will produce the values in the order the fields + are defined in the model or by their sequence number. Sequential fields need to be rendered together in parallel order eg: + + Args: + obj: The input model instance + meta: The model metadata instance + + Yields: + An iterator of field metadata instance and value tuples. """ index = 0 attrs = meta.get_element_vars() @@ -380,17 +552,18 @@ def next_attribute( xsi_type: Optional[str], ignore_optionals: bool, ) -> Iterator[Tuple[str, Any]]: - """ - Return the attribute variables with their object values if set and not - empty iterables. - - :param obj: Input object - :param meta: Object metadata - :param nillable: Is model nillable - :param xsi_type: The true xsi:type of the object - :param ignore_optionals: Skip optional attributes with default - value - :return: + """Produce the next attribute value to convert. + + Args: + obj: The input model instance + meta: The model metadata instance + nillable: Specifies if the current element supports nillable content + xsi_type: The real xsi:type of the object + ignore_optionals: Specifies if optional attributes with default + values should be ignored. + + Yields: + An iterator of attribute name-value pairs. """ for var in meta.get_attribute_vars(): if var.is_attribute: @@ -414,8 +587,7 @@ def next_attribute( @classmethod def encode(cls, value: Any, var: XmlVar) -> Any: - """ - Encode values for xml serialization. + """Encode a value for xml serialization. Converts values to strings. QName instances is an exception, those values need to wait until the XmlWriter assigns prefixes @@ -426,6 +598,13 @@ def encode(cls, value: Any, var: XmlVar) -> Any: need to carry the xml vars inside the writer. Instead of that we do the easy encoding here and leave the qualified names for later. + + Args: + value: The simple type vale to encode + var: The field metadata instance + + Returns: + The encoded value. """ if isinstance(value, (str, QName)) or var is None: return value @@ -437,3 +616,17 @@ def encode(cls, value: Any, var: XmlVar) -> Any: return cls.encode(value.value, var) return converter.serialize(value, format=var.format) + + @classmethod + def real_xsi_type(cls, qname: str, target_qname: Optional[str]) -> Optional[str]: + """Compare the qname with the target qname and return the real xsi:type. + + Args: + qname: The field type qualified name + target_qname: The value type qualified name + + Returns: + None if the qname and target qname match, otherwise + return the target qname. 
+ """ + return target_qname if target_qname != qname else None diff --git a/xsdata/formats/dataclass/transports.py b/xsdata/formats/dataclass/transports.py index 9e350cdae..20adff162 100644 --- a/xsdata/formats/dataclass/transports.py +++ b/xsdata/formats/dataclass/transports.py @@ -5,6 +5,8 @@ class Transport(abc.ABC): + """An HTTP transport interface.""" + __slots__ = () @abc.abstractmethod @@ -17,10 +19,10 @@ def post(self, url: str, data: Any, headers: Dict) -> bytes: class DefaultTransport(Transport): - """ - Default transport based on the requests library. + """Default transport based on the `requests` library. - :param timeout: Read timeout + Args: + timeout: Read timeout in seconds """ __slots__ = "timeout", "session" @@ -30,27 +32,58 @@ def __init__(self, timeout: float = 2.0, session: Optional[Session] = None): self.session = session or Session() def get(self, url: str, params: Dict, headers: Dict) -> bytes: - """ - :raises HTTPError: if status code is not valid for content unmarshalling. + """Send a GET request. + + Args: + url: The base URL + params: The query parameters + headers: A key-value map of HTTP headers + + Returns: + The encoded response content. + + Raises: + HTTPError: if status code is not valid for content unmarshalling. """ res = self.session.get( - url, params=params, headers=headers, timeout=self.timeout + url, + params=params, + headers=headers, + timeout=self.timeout, ) return self.handle_response(res) def post(self, url: str, data: Any, headers: Dict) -> Any: - """ - :raises HTTPError: if status code is not valid for content unmarshalling. + """Send a POST request. + + Args: + url: The base URL + data: The request body payload + headers: A key-value map of HTTP headers + + Returns: + The encoded response content. + + Raises: + HTTPError: if status code is not valid for content unmarshalling. """ res = self.session.post(url, data=data, headers=headers, timeout=self.timeout) return self.handle_response(res) @classmethod def handle_response(cls, response: Response) -> bytes: - """ + """Return the response content or raise an exception. + Status codes 200 or 500 means that we can unmarshall the response. - :raises HTTPError: If the response status code is not 200 or 500 + Args: + response: The response instance + + Returns: + The encoded response content. + + Raises: + HTTPError: If the response status code is not 200 or 500 """ if response.status_code not in (200, 500): response.raise_for_status() diff --git a/xsdata/formats/dataclass/typing.py b/xsdata/formats/dataclass/typing.py index d92cdaefa..a375346f1 100644 --- a/xsdata/formats/dataclass/typing.py +++ b/xsdata/formats/dataclass/typing.py @@ -37,6 +37,7 @@ def _eval_type(tp: Any, globalns: Any, localns: Any) -> Any: def is_from_typing(tp: Any) -> bool: + """Return whether the type is from the typing module.""" return str(tp).startswith(intern_typing) @@ -45,6 +46,7 @@ def evaluate( globalns: Any = None, localns: Any = None, ) -> Tuple[Type, ...]: + """Analyze/Validate the typing annotation.""" return tuple(_evaluate(_eval_type(tp, globalns, localns))) diff --git a/xsdata/formats/mixins.py b/xsdata/formats/mixins.py index 39caa45c9..8cf4dd0eb 100644 --- a/xsdata/formats/mixins.py +++ b/xsdata/formats/mixins.py @@ -12,12 +12,12 @@ class GeneratorResult(NamedTuple): - """ - Generator easy access output wrapper. + """Generator result transfer object. 
- :param path: file path to be written - :param title: result title for misc usage - :param source: source code/output to be written + Args: + path: The target file path + title: The result title for misc usage + source: The source code/output to be written """ path: Path @@ -26,16 +26,15 @@ class AbstractGenerator(abc.ABC): - """Abstract code generator class.""" + """Abstract code generator class. + + Args: + config: The generator config instance + """ __slots__ = "config" def __init__(self, config: GeneratorConfig): - """ - Generator constructor. - - :param config Generator configuration - """ self.config = config def module_name(self, module: str) -> str: @@ -74,11 +73,14 @@ def render_header(self) -> str: ) def normalize_packages(self, classes: List[Class]): - """ - Normalize the target package and module names by the given output - generator. + """Normalize the classes module and package names. + + Args: + classes: A list of class instances - :param classes: a list of codegen class instances + Raises: + CodeGenerationError: If the analyzer failed to + designate a class to a package and module. """ modules = {} packages = {} diff --git a/xsdata/models/config.py b/xsdata/models/config.py index efa89aa89..6fbc0ee6d 100644 --- a/xsdata/models/config.py +++ b/xsdata/models/config.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Pattern, TextIO +from typing import Any, Callable, Dict, List, Pattern, TextIO from xsdata import __version__ from xsdata.exceptions import CodeGenerationWarning, GeneratorConfigError @@ -14,22 +14,21 @@ from xsdata.formats.dataclass.serializers import XmlSerializer from xsdata.formats.dataclass.serializers.config import SerializerConfig from xsdata.formats.dataclass.serializers.writers import XmlEventWriter -from xsdata.logger import logger from xsdata.models.enums import Namespace from xsdata.models.mixins import array_element, attribute, element, text_node from xsdata.utils import objects, text class StructureStyle(Enum): - """ - Code writer output structure strategies. - - :cvar FILENAMES: filenames: groups classes by the schema location - :cvar NAMESPACES: namespaces: group classes by the target namespace - :cvar CLUSTERS: clusters: group by strong connected dependencies - :cvar SINGLE_PACKAGE: single-package: group all classes together - :cvar NAMESPACE_CLUSTERS: namespace-clusters: group by strong - connected dependencies and namespaces + """Output structure style enumeration. + + Attributes: + FILENAMES: filenames: groups classes by the schema location + NAMESPACES: namespaces: group classes by the target namespace + CLUSTERS: clusters: group by strong connected dependencies + SINGLE_PACKAGE: single-package: group all classes together + NAMESPACE_CLUSTERS: namespace-clusters: group by strong + connected dependencies and namespaces """ FILENAMES = "filenames" @@ -40,42 +39,22 @@ class NameCase(Enum): - """ - Code writer naming schemes. + """Naming case convention enumeration. All schemes are using a processor that splits a string into words - when it encounters non alphanumerical characters or when an upper + when it encounters non-alphanumerical characters or when an upper case letter follows a lower case letter.
- +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - | Original | Pascal | Camel | Snake | Screaming Snake | Mixed | Mixed Snake | Mixed Pascal | - +===========+===========+===========+============+=================+===========+=============+==============+ - | p00p | P00P | p00P | p00p | P00P | p00p | p00p | P00p | - +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - | USERName | Username | username | username | USERNAME | USERName | USERName | USERName | - +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - | UserNAME | UserName | userName | user_name | USER_NAME | UserNAME | User_NAME | UserNAME | - +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - | USER_name | UserName | userName | user_name | USER_NAME | USERname | USER_name | USERname | - +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - | USER-NAME | UserName | userName | user_name | USER_NAME | USERNAME | USER_NAME | USERNAME | - +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - | User_Name | UserName | userName | user_name | USER_NAME | UserName | User_Name | UserName | - +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - | user_name | UserName | userName | user_name | USER_NAME | username | user_name | Username | - +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - | SUserNAME | SuserName | suserName | suser_name | SUSER_NAME | SUserNAME | SUser_NAME | SUserNAME | - +-----------+-----------+-----------+------------+-----------------+-----------+-------------+--------------+ - - :cvar ORIGINAL: originalCase - :cvar PASCAL: pascalCase - :cvar CAMEL: camelCase - :cvar SNAKE: snakeCase - :cvar SCREAMING_SNAKE: screamingSnakeCase - :cvar MIXED: mixedCase mixedCase - :cvar MIXED_SNAKE: mixedSnakeCase - :cvar MIXED_PASCAL: mixedPascalCase - """ # noqa + Attributes: + ORIGINAL: originalCase + PASCAL: pascalCase + CAMEL: camelCase + SNAKE: snakeCase + SCREAMING_SNAKE: screamingSnakeCase + MIXED: mixedCase + MIXED_SNAKE: mixedSnakeCase + MIXED_PASCAL: mixedPascalCase + """ ORIGINAL = "originalCase" PASCAL = "pascalCase" @@ -87,6 +66,7 @@ class NameCase(Enum): MIXED_PASCAL = "mixedPascalCase" def __call__(self, string: str, **kwargs: Any) -> str: + """Apply the callback to the input string.""" return self.callback(string, **kwargs) @property @@ -108,14 +88,14 @@ def callback(self) -> Callable: class DocstringStyle(Enum): - """ - Code writer docstring styles. + """Docstring style enumeration. - :cvar RST: reStructuredText - :cvar NUMPY: NumPy - :cvar GOOGLE: Google - :cvar ACCESSIBLE: Accessible - :cvar BLANK: Blank + Attributes: + RST: reStructuredText + NUMPY: NumPy + GOOGLE: Google + ACCESSIBLE: Accessible + BLANK: Blank """ RST = "reStructuredText" @@ -126,13 +106,13 @@ class DocstringStyle(Enum): class ClassFilterStrategy(Enum): - """ - Class filter strategy. + """Class filter strategy enumeration. - :cvar ALL: all: Generate all types, discouraged!!! - :cvar ALL_GLOBALS: allGlobals: Generate all global types - :cvar REFERRED_GLOBALS: referredGlobals: Generate all global types - with at least one reference. 
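The removed table boils down to the fact that each `NameCase` member is callable and delegates to its text processor (see `__call__` above); a few examples taken from the old table.

```python
from xsdata.models.config import NameCase

NameCase.SNAKE("UserNAME")    # "user_name"
NameCase.PASCAL("USER-NAME")  # "UserName"
NameCase.CAMEL("user_name")   # "userName"
```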
+ Attributes: + ALL: all: Generate all types, discouraged!!! + ALL_GLOBALS: allGlobals: Generate all global types + REFERRED_GLOBALS: referredGlobals: Generate all global types + with at least one reference. """ ALL = "all" @@ -141,13 +121,13 @@ class ClassFilterStrategy(Enum): class ObjectType(Enum): - """ - Object type enumeration. + """Object type enumeration. - :cvar CLASS: class - :cvar FIELD: field - :cvar MODULE: module - :cvar PACKAGE: package + Attributes: + CLASS: class + FIELD: field + MODULE: module + PACKAGE: package """ CLASS = "class" @@ -157,11 +137,11 @@ class ObjectType(Enum): class ExtensionType(Enum): - """ - Extension type enumeration. + """Extension type enumeration. - :cvar CLASS: class - :cvar DECORATOR: decorator + Attributes: + CLASS: class + DECORATOR: decorator """ CLASS = "class" @@ -170,17 +150,17 @@ class ExtensionType(Enum): @dataclass class OutputFormat: - """ - Output format options. - - :param value: Output format name - :param repr: Generate __repr__ method - :param eq: Generate __eq__ method - :param order: Generate __lt__, __le__, __gt__, and __ge__ methods - :param unsafe_hash: Generate __hash__ method if not frozen - :param frozen: Enable read only properties - :param slots: Enable __slots__, python>=3.10 Only - :param kw_only: Enable keyword only arguments, python>=3.10 Only + """Output format model representation. + + Args: + value: Output format name + repr: Generate __repr__ method + eq: Generate __eq__ method + order: Generate __lt__, __le__, __gt__, and __ge__ methods + unsafe_hash: Generate __hash__ method + frozen: Enable read only properties + slots: Enable __slots__, python>=3.10 Only + kw_only: Enable keyword only arguments, python>=3.10 Only """ value: str = text_node(default="dataclasses", cli="output") @@ -193,9 +173,11 @@ class OutputFormat: kw_only: bool = attribute(default=False) def __post_init__(self): + """Post initialization method.""" self.validate() def validate(self): + """Validate and reset configuration conflicts.""" if self.order and not self.eq: raise GeneratorConfigError("eq must be true if order is true") @@ -217,19 +199,19 @@ def validate(self): @dataclass class CompoundFields: - """ - Compound fields options. - - :param enabled: Use compound fields for repeatable elements - :param default_name: Default compound field name - :param use_substitution_groups: Use substitution groups if they - exist, instead of element names. - :param force_default_name: Always use the default compound field - name, or try to generate one by the list of element names if - they are no longer than the max name parts. e.g. - hat_or_dress_or_something. - :param max_name_parts: Maximum number of element names before using - the default name. + """Compound fields model representation. + + Args: + enabled: Use compound fields for repeatable elements + default_name: Default compound field name + use_substitution_groups: Use substitution groups if they + exist, instead of element names. + force_default_name: Always use the default compound field + name, or try to generate one by the list of element names if + they are no longer than the max name parts. e.g. + hat_or_dress_or_something. + max_name_parts: Maximum number of element names before using + the default name. """ enabled: bool = text_node(default=False, cli="compound-fields") @@ -241,26 +223,26 @@ class CompoundFields: @dataclass class GeneratorOutput: - """ - Main generator output options. 
- - :param package: Target package - :param format: Output format - :param structure_style: Output structure style - :param docstring_style: Docstring style - :param filter_strategy: Class filter strategy - :param relative_imports: Use relative imports - :param compound_fields: Use compound fields for repeatable elements - :param max_line_length: Adjust the maximum line length - :param subscriptable_types: Use PEP-585 generics for standard - collections, python>=3.9 Only - :param union_type: Use PEP-604 union type, python>=3.10 Only - :param postponed_annotations: Enable postponed evaluation of - annotations - :param unnest_classes: Move inner classes to upper level - :param ignore_patterns: Ignore pattern restrictions - :param include_header: Include a header with codegen information in - the output + """Generator output model representation. + + Args: + package: Target package + format: Output format + structure_style: Output structure style + docstring_style: Docstring style + filter_strategy: Class filter strategy + relative_imports: Use relative imports + compound_fields: Use compound fields for repeatable elements + max_line_length: Adjust the maximum line length + subscriptable_types: Use PEP-585 generics for standard + collections, python>=3.9 Only + union_type: Use PEP-604 union type, python>=3.10 Only + postponed_annotations: Enable postponed evaluation of + annotations + unnest_classes: Move inner classes to upper level + ignore_patterns: Ignore pattern restrictions + include_header: Include a header with codegen information in + the output """ package: str = element(default="generated") @@ -283,9 +265,11 @@ class GeneratorOutput: include_header: bool = element(default=False) def __post_init__(self): + """Post initialization method.""" self.validate() def validate(self): + """Reset configuration conflicts.""" if self.subscriptable_types and sys.version_info < (3, 9): self.subscriptable_types = False warnings.warn( @@ -308,22 +292,19 @@ def validate(self): ) def update(self, **kwargs: Any): + """Update instance attributes recursively.""" objects.update(self, **kwargs) self.format.validate() @dataclass class NameConvention: - """ - Name convention model. - - :param case: Naming scheme, e.g. camelCase, snakeCase - :param safe_prefix: A prefix to be prepended into names that match - one of the reserved words: and, except, lambda, with, as, - finally, nonlocal, while, assert, false, none, yield, break, - for, not, class, from, or, continue, global, pass, def, if, - raise, del, import, return, elif, in, true, else, is, try, - str, int, bool, float, list, optional, dict, field + """Name convention model representation. + + Args: + case: Naming scheme, e.g. camelCase, snakeCase + safe_prefix: A prefix to be prepended into names that match + one of the reserved words. """ case: NameCase = attribute(optional=False) @@ -332,13 +313,13 @@ class NameConvention: @dataclass class GeneratorConventions: - """ - Generator global naming conventions. + """Generator naming conventions model representation. - :param class_name: Class naming conventions. - :param field_name: Field naming conventions. - :param module_name: Module naming conventions. - :param package_name: Package naming conventions. + Args: + class_name: Class naming conventions. + field_name: Field naming conventions. + module_name: Module naming conventions. + package_name: Package naming conventions. 
""" class_name: NameConvention = element( @@ -360,7 +341,8 @@ class GeneratorConventions: @dataclass class GeneratorAlias: - """ + """Generator alias model representation. + Define an alias for a module, package, class and field Alias definition model. @@ -370,8 +352,9 @@ class GeneratorAlias: filename or target namespace depending on the selected output structure. - :param source: The source name from schema definition - :param target: The target name of the object. + Args: + source: The source name from schema definition + target: The target name of the object. """ source: str = attribute(required=True) @@ -380,17 +363,17 @@ class GeneratorAlias: @dataclass class GeneratorAliases: - """ - Generator aliases for classes, fields, packages and modules that bypass the - global naming conventions. + """Generator aliases model representation. - .. warning:: - The generator doesn't validate aliases. + Generator aliases for classes, fields, packages and modules + that bypass the global naming conventions. The aliases + are not validated as valid python identifiers. - :param class_name: list of class name aliases - :param field_name: list of field name aliases - :param package_name: list of package name aliases - :param module_name: list of module name aliases + Args: + class_name: A list of class name aliases + field_name: A list of field name aliases + package_name: A list of package name aliases + module_name: A list of module name aliases """ class_name: List[GeneratorAlias] = array_element() @@ -401,13 +384,14 @@ class GeneratorAliases: @dataclass class GeneratorSubstitution: - """ - Search and replace substitution for a specific target type based on - :func:`re.sub` + """Generator substitution model representation. - :param type: The target object type - :param search: The search string or a pattern object - :param replace: The replacement string or pattern object + Search and replace substitutions based on `re.sub`. + + Args: + type: The target object type + search: The search string or a pattern object + replace: The replacement string or pattern object """ type: ObjectType = attribute(required=True) @@ -417,17 +401,22 @@ class GeneratorSubstitution: @dataclass class GeneratorExtension: - """ - Add decorators or base classes on the generated classes that match the - class name pattern. - - :param type: The extension type - :param class_name: The class name or a pattern to apply the - extension - :param import_string: The import string of the extension type - :param prepend: Prepend or append decorator or base class - :param apply_if_derived: Apply or skip if the class is already a - subclass + """Generator extension model representation. + + Add decorators or base classes on the generated classes + that match the class name pattern. + + Args: + type: The extension type + class_name: The class name or a pattern to apply the extension + import_string: The import string of the extension type + prepend: Prepend or append decorator or base class + apply_if_derived: Apply or skip if the class is already a subclass + + Attributes: + module_path: The module path of the base class or the annotation + func_name: The annotation or base class name + pattern: The compiled search class name pattern """ type: ExtensionType = attribute(required=True) @@ -450,6 +439,13 @@ class name pattern. ) def __post_init__(self): + """Post initialization method. + + Set the module, func_name and pattern instance attributes. + + Raises: + GeneratorConfigError: If the pattern can not be compiled. 
+ """ try: self.module_path, self.func_name = self.import_string.rsplit(".", 1) except (ValueError, AttributeError): @@ -465,14 +461,14 @@ def __post_init__(self): @dataclass class GeneratorSubstitutions: - """ + """Generator substitutions model representation. + Generator search and replace substitutions for classes, fields, packages and modules names. The process runs before and after the default naming conventions. - .. warning:: The generator doesn't validate substitutions. - - :param substitution: The list of substitutions + Args: + substitution: The list of substitution instances """ substitution: List[GeneratorSubstitution] = array_element() @@ -480,13 +476,14 @@ class GeneratorSubstitutions: @dataclass class GeneratorExtensions: - """ - Generator extensions for classes. The process runs after the default naming - conventions. + """Generator extensions model representation. - .. warning:: The generator doesn't validate imports! + Generator extensions for classes. The process runs after the + default naming conventions. The generator doesn't validate + imports! - :param extension: The list of extensions + Args: + extension: The list of extension instances """ extension: List[GeneratorExtension] = array_element() @@ -494,51 +491,36 @@ class GeneratorExtensions: @dataclass class GeneratorConfig: - """ - Generator configuration binding model. - - :cvar version: xsdata version number the config was created/updated - :param output: Output options - :param conventions: Generator conventions - :param aliases: Generator aliases, Deprecated since v21.12, use - substitutions - :param substitutions: Generator search and replace substitutions for - classes, fields, packages and modules names. - :param extensions: Generator custom base classes and decorators for - classes. + """Generator configuration model representation. + + Args: + output: Output options + conventions: Generator conventions + substitutions: Search and replace substitutions for + classes, fields, packages and modules names. + extensions: Generator custom base classes and decorators for classes. 
+ + Attributes: + version: The xsdata version number the config was created/updated """ class Meta: + """Metadata options.""" + name = "Config" namespace = "http://pypi.org/project/xsdata" version: str = attribute(init=False, default=__version__) output: GeneratorOutput = element(default_factory=GeneratorOutput) conventions: GeneratorConventions = element(default_factory=GeneratorConventions) - aliases: Optional[GeneratorAliases] = element(default=None) substitutions: GeneratorSubstitutions = element( default_factory=GeneratorSubstitutions ) extensions: GeneratorExtensions = element(default_factory=GeneratorExtensions) - def __post_init__(self): - if self.aliases: - alias_map = { - ObjectType.CLASS: self.aliases.class_name, - ObjectType.FIELD: self.aliases.field_name, - ObjectType.PACKAGE: self.aliases.package_name, - ObjectType.MODULE: self.aliases.module_name, - } - for object_type, aliases in alias_map.items(): - for alias in aliases: - self.substitutions.substitution.append( - GeneratorSubstitution( - type=object_type, search=alias.source, replace=alias.target - ) - ) - @classmethod def create(cls) -> "GeneratorConfig": + """Initialize with default substitutions for common namespaces.""" obj = cls() for ns in Namespace: @@ -558,6 +540,7 @@ def create(cls) -> "GeneratorConfig": @classmethod def read(cls, path: Path) -> "GeneratorConfig": + """Load configuration from a file path.""" if not path.exists(): return cls() @@ -572,23 +555,11 @@ def read(cls, path: Path) -> "GeneratorConfig": fail_on_converter_warnings=True, ), ) - config = parser.from_path(path, cls) - - if config.aliases and ( - config.aliases.class_name - or config.aliases.field_name - or config.aliases.package_name - or config.aliases.module_name - ): - config.aliases = None - logger.warning("Migrating aliases to substitutions config, verify output!") - with path.open("w") as fp: - config.write(fp, config) - - return config + return parser.from_path(path, cls) @classmethod def write(cls, output: TextIO, obj: "GeneratorConfig"): + """Write the configuration to the output stream as xml.""" ctx = XmlContext( element_name_generator=text.pascal_case, attribute_name_generator=text.camel_case, diff --git a/xsdata/models/datatype.py b/xsdata/models/datatype.py index ef65d1f2f..1b4d6d31f 100644 --- a/xsdata/models/datatype.py +++ b/xsdata/models/datatype.py @@ -31,6 +31,8 @@ class DateFormat: + """Xml date formats.""" + DATE = "%Y-%m-%d%z" TIME = "%H:%M:%S%z" DATE_TIME = "%Y-%m-%dT%H:%M:%S%z" @@ -42,17 +44,16 @@ class DateFormat: class XmlDate(NamedTuple): - """ - Concrete xs:date builtin type. + """Concrete xs:date builtin type. Represents iso 8601 date format [-]CCYY-MM-DD[Z|(+|-)hh:mm] with rich comparisons and hashing. 
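Putting `create`, `read` and `write` together, a short sketch of loading, tweaking and persisting a project config; the package name is hypothetical.

```python
from pathlib import Path

from xsdata.models.config import GeneratorConfig, StructureStyle

path = Path(".xsdata.xml")
config = GeneratorConfig.read(path)  # falls back to defaults if the file is missing
config.output.package = "myapp.models"
config.output.structure_style = StructureStyle.CLUSTERS

with path.open("w") as fp:
    GeneratorConfig.write(fp, config)
```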
- :param year: Any signed integer, eg (0, -535, 2020) - :param month: Unsigned integer between 1-12 - :param day: Unsigned integer between 1-31 - :param offset: Signed integer representing timezone offset in - minutes + Args: + year: Any signed integer, eg (0, -535, 2020) + month: Unsigned integer between 1-12 + day: Unsigned integer between 1-31 + offset: Signed integer representing timezone offset in minutes """ year: int @@ -67,9 +68,7 @@ def replace( day: Optional[int] = None, offset: Optional[int] = True, ) -> "XmlDate": - """Return a new instance replacing the specified fields with new - values.""" - + """Return a new instance replacing the specified fields with new values.""" if year is None: year = self.year if month is None: @@ -83,51 +82,48 @@ def replace( @classmethod def from_string(cls, string: str) -> "XmlDate": - """Initialize from string with format ``%Y-%m-%dT%z``""" + """Initialize from string with format `%Y-%m-%dT%z`.""" return cls(*parse_date_args(string, DateFormat.DATE)) @classmethod def from_date(cls, obj: datetime.date) -> "XmlDate": - """ - Initialize from :class:`datetime.date` instance. - - .. warning:: - - date instances don't have timezone information! - """ + """Initialize from a `datetime.date` instance.""" return cls(obj.year, obj.month, obj.day) @classmethod def from_datetime(cls, obj: datetime.datetime) -> "XmlDate": - """Initialize from :class:`datetime.datetime` instance.""" + """Initialize from `datetime.datetime` instance.""" return cls(obj.year, obj.month, obj.day, calculate_offset(obj)) @classmethod def today(cls) -> "XmlDate": - """Initialize from datetime.date.today()""" + """Initialize with the current date.""" return cls.from_date(datetime.date.today()) def to_date(self) -> datetime.date: - """Return a :class:`datetime.date` instance.""" + """Convert to a :`datetime.date` instance.""" return datetime.date(self.year, self.month, self.day) def to_datetime(self) -> datetime.datetime: - """Return a :class:`datetime.datetime` instance.""" + """Convert to a `datetime.datetime` instance.""" tz_info = calculate_timezone(self.offset) return datetime.datetime(self.year, self.month, self.day, tzinfo=tz_info) def __str__(self) -> str: - """ - Return the date formatted according to ISO 8601 for xml. + """Return the date formatted according to ISO 8601 for xml. Examples: - 2001-10-26 - 2001-10-26+02:00 - 2001-10-26Z + + Returns: + The str result. """ return format_date(self.year, self.month, self.day) + format_offset(self.offset) def __repr__(self) -> str: + """Return the instance string representation.""" args = [self.year, self.month, self.day, self.offset] if args[-1] is None: del args[-1] @@ -136,20 +132,20 @@ def __repr__(self) -> str: class XmlDateTime(NamedTuple): - """ - Concrete xs:dateTime builtin type. - - Represents iso 8601 date time format [-]CCYY-MM-DDThh - :mm: ss[Z|(+|-)hh:mm] with rich comparisons and hashing. - :param year: Any signed integer, eg (0, -535, 2020) - :param month: Unsigned integer between 1-12 - :param day: Unsigned integer between 1-31 - :param hour: Unsigned integer between 0-24 - :param minute: Unsigned integer between 0-59 - :param second: Unsigned integer between 0-59 - :param fractional_second: Unsigned integer between 0-999999999 - :param offset: Signed integer representing timezone offset in - minutes + """Concrete xs:dateTime builtin type. + + Represents iso 8601 date time format `[-]CCYY-MM-DDThh:mm: ss[Z|(+|-)hh:mm]` + with rich comparisons and hashing. 
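A few concrete `XmlDate` calls matching the docstrings above; the commented values follow the documented `%Y-%m-%d%z` format and minute-based offsets.

```python
from xsdata.models.datatype import XmlDate

date = XmlDate.from_string("2002-01-01-02:00")
(date.year, date.month, date.day, date.offset)  # (2002, 1, 1, -120)

date.replace(day=15)        # XmlDate(2002, 1, 15, -120)
str(XmlDate(2001, 10, 26))  # "2001-10-26"
```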
+ + Args: + year: Any signed integer, eg (0, -535, 2020) + month: Unsigned integer between 1-12 + day: Unsigned integer between 1-31 + hour: Unsigned integer between 0-24 + minute: Unsigned integer between 0-59 + second: Unsigned integer between 0-59 + fractional_second: Unsigned integer between 0-999999999 + offset: Signed integer representing timezone offset in minutes """ year: int @@ -163,10 +159,12 @@ class XmlDateTime(NamedTuple): @property def microsecond(self) -> int: + """Calculate the instance microseconds.""" return self.fractional_second // 1000 @property def duration(self) -> float: + """Calculate the instance signed duration in seconds.""" if self.year < 0: negative = True year = -self.year @@ -188,7 +186,7 @@ def duration(self) -> float: @classmethod def from_string(cls, string: str) -> "XmlDateTime": - """Initialize from string with format ``%Y-%m-%dT%H:%M:%S%z``""" + """Initialize from string with format `%Y-%m-%dT%H:%M:%S%z`.""" ( year, month, @@ -206,7 +204,7 @@ def from_string(cls, string: str) -> "XmlDateTime": @classmethod def from_datetime(cls, obj: datetime.datetime) -> "XmlDateTime": - """Initialize from :class:`datetime.datetime` instance.""" + """Initialize from `datetime.datetime` instance.""" return cls( obj.year, obj.month, @@ -220,16 +218,16 @@ def from_datetime(cls, obj: datetime.datetime) -> "XmlDateTime": @classmethod def now(cls, tz: Optional[datetime.timezone] = None) -> "XmlDateTime": - """Initialize from datetime.datetime.now()""" + """Initialize with the current datetime and the given timezone.""" return cls.from_datetime(datetime.datetime.now(tz=tz)) @classmethod def utcnow(cls) -> "XmlDateTime": - """Initialize from datetime.now(timezone.utc)""" + """Initialize with the current datetime and utc timezone.""" return cls.from_datetime(datetime.datetime.now(datetime.timezone.utc)) def to_datetime(self) -> datetime.datetime: - """Return a :class:`datetime.datetime` instance.""" + """Return a `datetime.datetime` instance.""" return datetime.datetime( self.year, self.month, @@ -252,9 +250,7 @@ def replace( fractional_second: Optional[int] = None, offset: Optional[int] = True, ) -> "XmlDateTime": - """Return a new instance replacing the specified fields with new - values.""" - + """Return a new instance replacing the specified fields with new values.""" if year is None: year = self.year if month is None: @@ -277,8 +273,7 @@ def replace( ) def __str__(self) -> str: - """ - Return the datetime formatted according to ISO 8601 for xml. + """Return the datetime formatted according to ISO 8601 for xml. 
Examples: - 2001-10-26T21:32:52 @@ -295,6 +290,7 @@ def __str__(self) -> str: ) def __repr__(self) -> str: + """Return the instance string representation.""" args = tuple(self) if args[-1] is None: args = args[:-1] @@ -305,36 +301,42 @@ def __repr__(self) -> str: return f"{self.__class__.__qualname__}({', '.join(map(str, args))})" def __eq__(self, other: Any) -> bool: - return cmp(self, other, operator.eq) + """Return self == other.""" + return _cmp(self, other, operator.eq) def __ne__(self, other: Any) -> bool: - return cmp(self, other, operator.ne) + """Return self != other.""" + return _cmp(self, other, operator.ne) def __lt__(self, other: Any) -> bool: - return cmp(self, other, operator.lt) + """Return self < other.""" + return _cmp(self, other, operator.lt) def __le__(self, other: Any) -> bool: - return cmp(self, other, operator.le) + """Return self <= other.""" + return _cmp(self, other, operator.le) def __gt__(self, other: Any) -> bool: - return cmp(self, other, operator.gt) + """Return self > other.""" + return _cmp(self, other, operator.gt) def __ge__(self, other: Any) -> bool: - return cmp(self, other, operator.ge) + """Return self >= other.""" + return _cmp(self, other, operator.ge) class XmlTime(NamedTuple): - """ - Concrete xs:time builtin type. - - Represents iso 8601 time format hh - :mm: ss[Z|(+|-)hh:mm] with rich comparisons and hashing. - :param hour: Unsigned integer between 0-24 - :param minute: Unsigned integer between 0-59 - :param second: Unsigned integer between 0-59 - :param fractional_second: Unsigned integer between 0-999999999 - :param offset: Signed integer representing timezone offset in - minutes + """Concrete xs:time builtin type. + + Represents iso 8601 time format `hh:mm: ss[Z|(+|-)hh:mm]` + with rich comparisons and hashing. 
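Since the comparison operators now delegate to `_cmp` over the `duration` property, equal instants expressed with different offsets compare equal; a small sketch.

```python
from xsdata.models.datatype import XmlDateTime

a = XmlDateTime.from_string("2002-01-01T12:00:00Z")
b = XmlDateTime.from_string("2002-01-01T13:00:00+01:00")

a == b                  # True: both describe the same instant
a < b.replace(hour=14)  # True
```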
+ + Args: + hour: Unsigned integer between 0-24 + minute: Unsigned integer between 0-59 + second: Unsigned integer between 0-59 + fractional_second: Unsigned integer between 0-999999999 + offset: Signed integer representing timezone offset in minutes """ hour: int @@ -345,10 +347,12 @@ class XmlTime(NamedTuple): @property def microsecond(self) -> int: + """Calculate the instance microseconds.""" return self.fractional_second // 1000 @property def duration(self) -> float: + """Calculate the total duration in seconds.""" return ( self.hour * DS_HOUR + self.minute * DS_MINUTE @@ -365,9 +369,7 @@ def replace( fractional_second: Optional[int] = None, offset: Optional[int] = True, ) -> "XmlTime": - """Return a new instance replacing the specified fields with new - values.""" - + """Return a new instance replacing the specified fields with new values.""" if hour is None: hour = self.hour if minute is None: @@ -383,7 +385,7 @@ def replace( @classmethod def from_string(cls, string: str) -> "XmlTime": - """Initialize from string format ``%H:%M:%S%z``""" + """Initialize from string format `%H:%M:%S%z`.""" hour, minute, second, fractional_second, offset = parse_date_args( string, DateFormat.TIME ) @@ -392,7 +394,7 @@ def from_string(cls, string: str) -> "XmlTime": @classmethod def from_time(cls, obj: datetime.time) -> "XmlTime": - """Initialize from :class:`datetime.time` instance.""" + """Initialize from `datetime.time` instance.""" return cls( obj.hour, obj.minute, @@ -403,16 +405,16 @@ def from_time(cls, obj: datetime.time) -> "XmlTime": @classmethod def now(cls, tz: Optional[datetime.timezone] = None) -> "XmlTime": - """Initialize from datetime.datetime.now()""" + """Initialize with the current time and the given timezone.""" return cls.from_time(datetime.datetime.now(tz=tz).time()) @classmethod def utcnow(cls) -> "XmlTime": - """Initialize from datetime.now(timezone.utc)""" + """Initialize with the current time and utc timezone.""" return cls.from_time(datetime.datetime.now(datetime.timezone.utc).time()) def to_time(self) -> datetime.time: - """Return a :class:`datetime.time` instance.""" + """Convert to a `datetime.time` instance.""" return datetime.time( self.hour, self.minute, @@ -422,8 +424,7 @@ def to_time(self) -> datetime.time: ) def __str__(self) -> str: - """ - Return the time formatted according to ISO 8601 for xml. + """Return the time formatted according to ISO 8601 for xml. 
Examples: - 21:32:52 @@ -438,6 +439,7 @@ def __str__(self) -> str: ) def __repr__(self) -> str: + """Return the instance string representation.""" args = list(self) if args[-1] is None: del args[-1] @@ -445,28 +447,34 @@ def __repr__(self) -> str: return f"{self.__class__.__qualname__}({', '.join(map(str, args))})" def __eq__(self, other: Any) -> bool: - return cmp(self, other, operator.eq) + """Return self == other.""" + return _cmp(self, other, operator.eq) def __ne__(self, other: Any) -> bool: - return cmp(self, other, operator.ne) + """Return self != other.""" + return _cmp(self, other, operator.ne) def __lt__(self, other: Any) -> bool: - return cmp(self, other, operator.lt) + """Return self < other.""" + return _cmp(self, other, operator.lt) def __le__(self, other: Any) -> bool: - return cmp(self, other, operator.le) + """Return self <= other.""" + return _cmp(self, other, operator.le) def __gt__(self, other: Any) -> bool: - return cmp(self, other, operator.gt) + """Return self > other.""" + return _cmp(self, other, operator.gt) def __ge__(self, other: Any) -> bool: - return cmp(self, other, operator.ge) + """Return self >= other.""" + return _cmp(self, other, operator.ge) DurationType = Union[XmlTime, XmlDateTime] -def cmp(a: DurationType, b: DurationType, op: Callable) -> bool: +def _cmp(a: DurationType, b: DurationType, op: Callable) -> bool: if isinstance(b, a.__class__): return op(a.duration, b.duration) @@ -474,6 +482,8 @@ def cmp(a: DurationType, b: DurationType, op: Callable) -> bool: class TimeInterval(NamedTuple): + """Time interval model representation.""" + negative: bool years: Optional[int] months: Optional[int] @@ -484,8 +494,7 @@ class TimeInterval(NamedTuple): class XmlDuration(UserString): - """ - Concrete xs:duration builtin type. + """Concrete xs:duration builtin type. Represents iso 8601 duration format PnYnMnDTnHnMnS with rich comparisons and hashing. @@ -500,7 +509,8 @@ class XmlDuration(UserString): - **nM**: the number of minutes followed by a literal M - **nS**: the number of seconds followed by a literal S - :param value: String representation of a xs:duration, eg **P2Y6M5DT12H** + Args: + value: String representation of a xs:duration, eg **P2Y6M5DT12H** """ def __init__(self, value: str) -> None: @@ -566,13 +576,17 @@ def _parse_interval(cls, value: str) -> TimeInterval: ) def asdict(self) -> Dict: + """Return instance as a dict.""" return self._interval._asdict() def __repr__(self) -> str: + """Return the instance string representation.""" return f'{self.__class__.__qualname__}("{self.data}")' class TimePeriod(NamedTuple): + """Time period model representation.""" + year: Optional[int] month: Optional[int] day: Optional[int] @@ -580,8 +594,7 @@ class TimePeriod(NamedTuple): class XmlPeriod(UserString): - """ - Concrete xs:gYear/Month/Day builtin type. + """Concrete xs:gYear/Month/Day builtin type. Represents iso 8601 period formats with rich comparisons and hashing. 
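The datatype hunks above only touch docstrings, so the documented formats can be exercised unchanged. A minimal usage sketch of the XmlTime and XmlDuration APIs shown in those hunks (assuming both classes are imported from xsdata.models.datatypes, as elsewhere in the code base; the literal values are illustrative examples of the documented formats, not taken from the patch):

# Illustrative sketch only, based on the signatures visible in the hunks above;
# the import path and sample values are assumptions, not part of this patch.
from xsdata.models.datatypes import XmlDuration, XmlTime

# Parse an xs:time value in the documented `%H:%M:%S%z` format.
time = XmlTime.from_string("21:32:52+02:00")
assert time.hour == 21
assert time.offset == 120  # offset is documented as signed minutes

# Round-trip through the standard library type via to_time()/from_time().
assert XmlTime.from_time(time.to_time()) == time

# Parse an xs:duration value in PnYnMnDTnHnMnS form; asdict() exposes the parsed interval.
duration = XmlDuration("P2Y6M5DT12H")
assert duration.asdict()["years"] == 2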
@@ -592,7 +605,8 @@ class XmlPeriod(UserString): - xs:gMonthDay: **--%m-%d%z** - xs:gYearMonth: **%Y-%m%z** - :param value: String representation of a xs:period, eg **--11-01Z** + Args: + value: String representation of a xs:period, eg **--11-01Z** """ def __init__(self, value: str) -> None: @@ -653,9 +667,11 @@ def as_dict(self) -> Dict: return self._period._asdict() def __repr__(self) -> str: + """Return the instance string representation.""" return f'{self.__class__.__qualname__}("{self.data}")' def __eq__(self, other: Any) -> bool: + """Return self == other.""" if isinstance(other, XmlPeriod): return self._period == other._period @@ -663,8 +679,7 @@ def __eq__(self, other: Any) -> bool: class XmlHexBinary(bytes): - """ - Subclass bytes to infer base16 format. + """Subclass bytes to infer base16 format. This type can be used with xs:anyType fields that don't have a format property to specify the target output format. @@ -672,8 +687,7 @@ class XmlHexBinary(bytes): class XmlBase64Binary(bytes): - """ - Subclass bytes to infer base64 format. + """Subclass bytes to infer base64 format. This type can be used with xs:anyType fields that don't have a format property to specify the target output format. diff --git a/xsdata/models/dtd.py b/xsdata/models/dtd.py index 7bc307a05..9e96e3ed0 100644 --- a/xsdata/models/dtd.py +++ b/xsdata/models/dtd.py @@ -7,6 +7,8 @@ class DtdElementType(enum.Enum): + """DTD Element type enumeration.""" + UNDEFINED = "undefined" EMPTY = "empty" ANY = "any" @@ -15,6 +17,8 @@ class DtdElementType(enum.Enum): class DtdAttributeDefault(enum.Enum): + """DTD Attribute default enumeration.""" + REQUIRED = "required" IMPLIED = "implied" FIXED = "fixed" @@ -22,6 +26,8 @@ class DtdAttributeDefault(enum.Enum): class DtdAttributeType(enum.Enum): + """DTD Attribute type enumeration.""" + CDATA = "cdata" ID = "id" IDREF = "idref" @@ -35,6 +41,8 @@ class DtdAttributeType(enum.Enum): class DtdContentType(enum.Enum): + """DTD Content type enumeration.""" + PCDATA = "pcdata" ELEMENT = "element" SEQ = "seq" @@ -42,6 +50,8 @@ class DtdContentType(enum.Enum): class DtdContentOccur(enum.Enum): + """DTD Content occur enumeration.""" + ONCE = "once" OPT = "opt" MULT = "mult" @@ -50,6 +60,17 @@ class DtdContentOccur(enum.Enum): @dataclass class DtdAttribute: + """DTD Attribute model representation. + + Args: + name: The attribute name + prefix: The attribute namespace prefix + type: The attribute type + default: The attribute default type + default_value: The attribute default value + values: The available choices as value + """ + name: str prefix: Optional[str] type: DtdAttributeType @@ -59,11 +80,22 @@ class DtdAttribute: @property def data_type(self) -> DataType: + """Return the data type instance from the attribute type.""" return DataType.from_code(self.type.value.lower()) @dataclass class DtdContent: + """DTD Content model representation. + + Args: + name: The content name + type: The content type + occur: The content occur type + left: The parent content + right: The child content + """ + name: str type: DtdContentType occur: DtdContentOccur @@ -73,6 +105,17 @@ class DtdContent: @dataclass class DtdElement: + """DTD Element model representation. 
+ + Args: + name: The element name + type: The element type + prefix: The element namespace prefix + content: The element content + attributes: The element attribute list + ns_map: The namespace prefix-URI map + """ + name: str type: DtdElementType prefix: Optional[str] @@ -82,11 +125,19 @@ class DtdElement: @property def qname(self) -> str: + """Return the element qualified name.""" namespace = self.ns_map.get(self.prefix) return build_qname(namespace, self.name) @dataclass class Dtd: + """The DTD Document model representation. + + Args: + location: The source location URI + elements: The list of included elements + """ + location: str elements: List[DtdElement] diff --git a/xsdata/models/enums.py b/xsdata/models/enums.py index 986a196e7..3f947d9e1 100644 --- a/xsdata/models/enums.py +++ b/xsdata/models/enums.py @@ -37,15 +37,18 @@ def __init__(self, uri: str, prefix: str): @property def location(self) -> Optional[str]: + """The location of the local file.""" local_path = COMMON_SCHEMA_DIR.joinpath(f"{self.prefix}.xsd") return local_path.as_uri() if local_path.exists() else None @classmethod def get_enum(cls, uri: Optional[str]) -> Optional["Namespace"]: + """Get the enum member instance from the uri.""" return __STANDARD_NAMESPACES__.get(uri) if uri else None @classmethod def common(cls) -> Tuple["Namespace", ...]: + """Return the common namespaces.""" return Namespace.XS, Namespace.XSI, Namespace.XML, Namespace.XLINK @@ -64,15 +67,13 @@ class QNames: class NamespaceType: - """ - Wildcard elements/attributes namespace types. - - :cvar ANY_NS: elements from any namespace is allowed - :cvar OTHER_NS: elements from any namespace except the parent - element's namespace - :cvar LOCAL_NS: elements must come from no namespace - :cvar TARGET_NS: elements from the namespace of the parent element - can be present + """Wildcard elements/attributes namespace types. 
+ + Attributes: + ANY_NS: elements from any namespace is allowed + OTHER_NS: elements from any namespace except the parent element's namespace + LOCAL_NS: elements must come from no namespace + TARGET_NS: elements from the namespace of the parent element can be present """ ANY_NS = "##any" @@ -173,14 +174,16 @@ def __init__( self.wrapper = wrapper def __str__(self) -> str: + """Return the qualified string representation of the datatype.""" return f"{{{Namespace.XS.uri}}}{self.code}" def prefixed(self, prefix: Optional[str] = Namespace.XS.prefix) -> str: + """Return the prefixed string representation of the datatype.""" return f"{prefix}:{self.code}" if prefix else self.code @classmethod def from_value(cls, value: Any) -> "DataType": - """Infer the xsd type from the value itself.""" + """Load from a literal value.""" _type = type(value) calculate = __DataTypeInferIndex__.get(_type) if calculate: @@ -190,18 +193,22 @@ def from_value(cls, value: Any) -> "DataType": @classmethod def from_type(cls, tp: Type) -> "DataType": + """Load from a python type.""" return __DataTypeIndex__.get(tp, DataType.STRING) @classmethod def from_qname(cls, qname: str) -> Optional["DataType"]: + """Load from a qualified name.""" return __DataTypeQNameIndex__.get(qname) @classmethod def from_code(cls, code: str) -> "DataType": + """Load from the code name.""" return __DataTypeCodeIndex__.get(code, DataType.STRING) def period_datatype(value: XmlPeriod) -> DataType: + """Infer the datatype of a xml period instance.""" if value.year is not None: return DataType.G_YEAR_MONTH if value.month else DataType.G_YEAR if value.month: @@ -210,6 +217,7 @@ def period_datatype(value: XmlPeriod) -> DataType: def int_datatype(value: int) -> DataType: + """Infer the datatype of an int value.""" if -32768 <= value <= 32767: return DataType.SHORT if -2147483648 <= value <= 2147483647: @@ -220,6 +228,7 @@ def int_datatype(value: int) -> DataType: def float_datatype(value: float) -> DataType: + """Infer the datatype of a float value.""" if -1.175494351e-38 <= value <= 3.402823466e38: return DataType.FLOAT return DataType.DOUBLE @@ -259,7 +268,7 @@ class EventType: class Tag: - """Xml Schema tag names.""" + """Xml Schema tags.""" ALL = "All" ANNOTATION = "Annotation" @@ -329,13 +338,3 @@ class ProcessType(Enum): LAX = "lax" SKIP = "skip" STRICT = "strict" - - -class BindingStyle(Enum): - RPC = "rpc" - DOCUMENT = "document" - - -class UseChoice(Enum): - LITERAL = "literal" - ENCODED = "encoded" diff --git a/xsdata/models/mixins.py b/xsdata/models/mixins.py index 4c4f2a19d..d246350cf 100644 --- a/xsdata/models/mixins.py +++ b/xsdata/models/mixins.py @@ -10,11 +10,11 @@ @dataclass class ElementBase: - """ - Base xsd schema model. + """Base xsd schema model representation. 
- :param index: Occurrence position inside the definition - :param ns_map: Namespace prefix-URI map + Attributes: + index: The element position in the schema + ns_map: The element namespace prefix-URI map """ index: int = field( @@ -35,8 +35,7 @@ def class_name(self) -> str: @property def default_type(self) -> str: - """Return the default type if the given element has not specific - type.""" + """The element's inferred default type qname.""" return DataType.STRING.prefixed(self.xs_prefix) @property @@ -60,17 +59,17 @@ def bases(self) -> Iterator[str]: @property def has_children(self) -> bool: - """Return whether or not this element has any children.""" + """Return whether this element has any children.""" return next(self.children(), None) is not None @property def has_form(self) -> bool: - """Return whether or not this element has the form attribute.""" + """Return whether this element has the form attribute.""" return hasattr(self, "form") @property def is_abstract(self) -> bool: - """Return whether or not this element is defined as abstract.""" + """Return whether this element is defined as abstract.""" return getattr(self, "abstract", False) @property @@ -80,23 +79,22 @@ def is_property(self) -> bool: @property def is_fixed(self) -> bool: - """Return whether or not this element has a fixed value.""" + """Return whether this element has a fixed value.""" return getattr(self, "fixed", None) is not None @property def is_mixed(self) -> bool: - """Return whether or not this element accepts mixed content value.""" + """Return whether this element accepts mixed content value.""" return False @property def is_nillable(self) -> bool: - """Return whether or not this element is accepts empty empty values.""" + """Return whether this element accepts nillable content.""" return getattr(self, "nillable", False) @property def is_qualified(self) -> bool: - """Return whether or not this element name needs to be referenced with - the target namespace.""" + """Return whether this element must be referenced with the target namespace.""" if self.has_form: if getattr(self, "form", FormType.UNQUALIFIED) == FormType.QUALIFIED: return True @@ -108,14 +106,12 @@ def is_qualified(self) -> bool: @property def is_ref(self) -> bool: - """Return whether or not this element is a reference to another - element.""" + """Return whether this element is a reference to another element.""" return getattr(self, "ref", None) is not None @property def is_wildcard(self) -> bool: - """Return whether or not this element is a wildcard - element/attribute.""" + """Return whether this element is a wildcard element/attribute.""" return False @property @@ -131,13 +127,7 @@ def raw_namespace(self) -> Optional[str]: @property def real_name(self) -> str: - """ - Return the real name for this element by looking by looking either to - the name or ref attribute value. - - :raises SchemaValueError: when instance has no name/ref - attribute. 
- """ + """Return the real name for this element.""" name = getattr(self, "name", None) or getattr(self, "ref", None) if name: return text.suffix(name) @@ -146,7 +136,7 @@ def real_name(self) -> str: @property def attr_types(self) -> Iterator[str]: - """Return the attribute types for this element.""" + """Return the attr types for this element.""" yield from () @property @@ -156,7 +146,7 @@ def substitutions(self) -> List[str]: @property def xs_prefix(self) -> Optional[str]: - """Return the xml schema uri prefix.""" + """Return the xml schema URI prefix.""" for prefix, uri in self.ns_map.items(): if uri == Namespace.XS.uri: return prefix @@ -168,8 +158,7 @@ def get_restrictions(self) -> Dict[str, Any]: return {} def children(self, condition: Callable = return_true) -> Iterator["ElementBase"]: - """Iterate over all the ElementBase children of this element that match - the given condition if any.""" + """Yield the children recursively that match the given condition.""" for f in fields(self): value = getattr(self, f.name) if isinstance(value, list) and value and isinstance(value[0], ElementBase): @@ -203,9 +192,7 @@ def element(optional: bool = True, **kwargs: Any) -> Any: def add_default_value(params: Dict, optional: bool): - """Add default value to the params if it's missing and its marked as - optional.""" - + """Add the default value if it's missing and the optional flag is true.""" if optional and not ("default" in params or "default_factory" in params): params["default"] = None @@ -225,8 +212,7 @@ def array_any_element(**kwargs: Any) -> Any: def extract_metadata(params: Dict, **kwargs: Any) -> Dict: - """Extract not standard dataclass field parameters to a new metadata - dictionary and merge with any provided keyword arguments.""" + """Remove dataclasses standard field properties and merge any additional.""" metadata = { key: params.pop(key) for key in list(params.keys()) if key not in FIELD_PARAMS } diff --git a/xsdata/models/wsdl.py b/xsdata/models/wsdl.py index ea5403bde..3051c3a8b 100644 --- a/xsdata/models/wsdl.py +++ b/xsdata/models/wsdl.py @@ -12,8 +12,10 @@ @dataclass class Documentation: - """ - :params elements: + """WSDL Documentation model representation. + + Args: + elements: A list of generic any elements """ elements: List[object] = array_any_element() @@ -21,11 +23,13 @@ class Documentation: @dataclass class WsdlElement: - """ - :param name: - :param documentation: - :param location: - :param ns_map + """WSDL Base element model representation. + + Args: + name: The element name + documentation: The element documentation + location: The element location + ns_map: The element namespace prefix-URI map """ name: str = attribute() @@ -38,22 +42,31 @@ class WsdlElement: @dataclass class ExtensibleElement(WsdlElement): - """ - :param extended: + """WSDL Extensible element model representation. + + Args: + name: The element name + documentation: The element documentation + location: The element location + ns_map: The element namespace prefix-URI map + extended: A list of generic elements """ extended: List[object] = array_any_element() @property def extended_elements(self) -> Iterator[AnyElement]: + """Yields all generic element instances.""" yield from (ext for ext in self.extended if isinstance(ext, AnyElement)) @dataclass class Types: - """ - :param schemas: - :param documentation: + """WSDL Types model representation. 
+ + Args: + schemas: Inline xml schema definitions + documentation: The type documentation """ schemas: List[Schema] = array_element(name="schema", namespace=Namespace.XS.uri) @@ -62,9 +75,11 @@ class Types: @dataclass class Import: - """ - :param location: - :param namespace: + """WSDL Import model representation. + + Args: + location: The location URI + namespace: The namespace URI """ location: Optional[str] = attribute() @@ -73,9 +88,15 @@ class Import: @dataclass class Part(WsdlElement): - """ - :param type: - :param element: + """WSDL Part model representation. + + Args: + name: The part name + documentation: The part documentation + location: The part location + ns_map: The part namespace prefix-URI map + type: The part type + element: The part element """ type: Optional[str] = attribute() @@ -84,8 +105,14 @@ class Part(WsdlElement): @dataclass class Message(WsdlElement): - """ - :param part: + """WSDL Message model representation. + + Args: + name: The message name + documentation: The message documentation + location: The message location + ns_map: The message namespace prefix-URI map + parts: The message parts """ parts: List[Part] = array_element(name="part") @@ -93,8 +120,15 @@ class Message(WsdlElement): @dataclass class PortTypeMessage(WsdlElement): - """ - :param message: + """WSDL Port type message model representation. + + Args: + Args: + name: The port type name + documentation: The port type documentation + location: The port type location + ns_map: The port type namespace prefix-URI map + message: The port type message """ message: str = attribute() @@ -102,10 +136,12 @@ class PortTypeMessage(WsdlElement): @dataclass class PortTypeOperation(WsdlElement): - """ - :param input: - :param output: - :param faults: + """WSDL Port type operation model representation. + + Args: + input: The input port type message instance + output: The output port type message instance + faults: The list of error port type message instances """ input: PortTypeMessage = element() @@ -115,27 +151,50 @@ class PortTypeOperation(WsdlElement): @dataclass class PortType(ExtensibleElement): - """ - :param operations: + """WSDL Port type model representation. + + Args: + name: The port type name + documentation: The port type documentation + location: The port type location + ns_map: The port type namespace prefix-URI map + extended: The port type extended elements + operations: The port type operations """ operations: List[PortTypeOperation] = array_element(name="operation") def find_operation(self, name: str) -> PortTypeOperation: + """Find an operation by name or raise an error.""" return find_or_die(self.operations, name, "PortTypeOperation") @dataclass class BindingMessage(ExtensibleElement): - pass + """WSDL Binding message model representation. + + Args: + name: The message name + documentation: The message documentation + location: The message location + ns_map: The message namespace prefix-URI map + extended: The message extended elements + """ @dataclass class BindingOperation(ExtensibleElement): - """ - :param input: - :param output: - :param faults: + """WSDL Binding operation model representation. 
+ + Args: + input: The input binding message instance + output: The output binding message instance + faults: The list of error binding message instances + name: The operation name + documentation: The operation documentation + location: The operation location + ns_map: The operation namespace prefix-URI map + extended: The operation extended elements """ input: BindingMessage = element() @@ -145,16 +204,23 @@ class BindingOperation(ExtensibleElement): @dataclass class Binding(ExtensibleElement): - """ - :param type: - :param operations: - :param extended: + """WSDL Binding model representation. + + Args: + name: The binding name + documentation: The binding documentation + location: The binding location + ns_map: The binding namespace prefix-URI map + extended: The binding extended elements + type: The binding type + operations: The binding operations """ type: str = attribute() operations: List[BindingOperation] = array_element(name="operation") def unique_operations(self) -> Iterator[BindingOperation]: + """Yields all unique operation instances.""" grouped_operations = collections.group_by(self.operations, key=get_name) for operations in grouped_operations.values(): @@ -163,8 +229,15 @@ def unique_operations(self) -> Iterator[BindingOperation]: @dataclass class ServicePort(ExtensibleElement): - """ - :param binding: + """WSDL Service port model representation. + + Args: + name: The port name + documentation: The port documentation + location: The port location + ns_map: The port namespace prefix-URI map + extended: The port extended elements + binding: The port binding """ binding: str = attribute() @@ -172,8 +245,14 @@ class ServicePort(ExtensibleElement): @dataclass class Service(WsdlElement): - """ - :param ports: + """WSDL Service model representation. + + Args: + name: The service name + documentation: The service documentation + location: The service location + ns_map: The service namespace prefix-URI map + ports: The service ports """ ports: List[ServicePort] = array_element(name="port") @@ -181,17 +260,26 @@ class Service(WsdlElement): @dataclass class Definitions(ExtensibleElement): - """ - :param types: - :param imports: - :param messages: - :param port_types: - :param bindings: - :param services: - :param extended: + """WSDL Definitions model representation. 
+ + Args: + name: The definition name + documentation: The definition documentation + location: The definition location + ns_map: The definition namespace prefix-URI map + extended: A list of generic elements + types: The definition types + imports: The definition imports + messages: The definition messages + port_types: The definition port types + bindings: The definition bindings + services: The definition services + extended: The definition extended elements """ class Meta: + """Metadata options.""" + name = "definitions" namespace = "http://schemas.xmlsoap.org/wsdl/" @@ -204,20 +292,25 @@ class Meta: services: List[Service] = array_element(name="service") @property - def schemas(self): + def schemas(self) -> Iterator[Schema]: + """Yield all schema definitions.""" if self.types: yield from self.types.schemas def find_binding(self, name: str) -> Binding: + """Find a binding by name or raise an error.""" return find_or_die(self.bindings, name, "Binding") def find_message(self, name: str) -> Message: + """Find a message by name or raise an error.""" return find_or_die(self.messages, name, "Message") def find_port_type(self, name: str) -> PortType: + """Find a port type by name or raise an error.""" return find_or_die(self.port_types, name, "PortType") def merge(self, source: "Definitions"): + """Merge the source instance with this instance.""" if not self.types: self.types = source.types elif source.types: @@ -230,6 +323,7 @@ def merge(self, source: "Definitions"): self.extended.extend(source.extended) def included(self) -> Iterator[Import]: + """Yield all imports.""" yield from self.imports @@ -237,6 +331,7 @@ def included(self) -> Iterator[Import]: def find_or_die(items: List[T], name: str, type_name: str) -> T: + """Find an item by name or raise an error.""" for msg in items: if msg.name == name: return msg diff --git a/xsdata/models/xsd.py b/xsdata/models/xsd.py index 3e066cda9..a3e285abc 100644 --- a/xsdata/models/xsd.py +++ b/xsdata/models/xsd.py @@ -35,30 +35,32 @@ @dataclass(frozen=True) class Docstring: + """Docstring model representation. + + Args: + content: A list of mixed content elements + """ + class Meta: + """Metadata options.""" + namespace = "http://www.w3.org/1999/xhtml" - elements: Array[object] = array_any_element() + content: Array[object] = array_any_element() @dataclass class Documentation(ElementBase): - """ - Model representation of a schema xs:documentation element. - - :param lang: language - :param source: anyURI - :param elements: ({any})* - :param attributes: any attributes with non-schema namespace - """ + """XSD MinLength model representation.""" lang: Optional[str] = attribute() source: Optional[str] = attribute() - elements: Array[object] = array_any_element(mixed=True) attributes: Optional["AnyAttribute"] = element() + content: Array[object] = array_any_element(mixed=True) def tostring(self) -> Optional[str]: - obj = Docstring(self.elements) + """Convert the content to a help string.""" + obj = Docstring(self.content) ns_map = {None: "http://www.w3.org/1999/xhtml"} xml = docstring_serializer.render(obj, ns_map=ns_map) start = xml.find(">") + 1 @@ -68,46 +70,30 @@ def tostring(self) -> Optional[str]: @dataclass class Appinfo(ElementBase): - """ - Model representation of a schema xs:appinfo element. 
- - :param lang: language - :param source: anyURI - :param attributes: any attributes with non-schema namespace - """ + """XSD Appinfo model representation.""" class Meta: + """Metadata options.""" + mixed = True source: Optional[str] = attribute() - elements: Array[object] = array_any_element() any_attribute: Optional["AnyAttribute"] = element(name="anyAttribute") + content: Array[object] = array_any_element(mixed=True) @dataclass class Annotation(ElementBase): - """ - Model representation of a schema xs:annotation element. + """XSD Annotation model representation.""" - :param appinfos: - :param documentations: - :param any_attribute: any attributes with non-schema namespace - """ - - appinfos: Array[Appinfo] = array_element(name="appinfo") + app_infos: Array[Appinfo] = array_element(name="appinfo") documentations: Array[Documentation] = array_element(name="documentation") any_attribute: Optional["AnyAttribute"] = element(name="anyAttribute") @dataclass class AnnotationBase(ElementBase): - """ - Base Class for elements that can contain annotations. - - :param id: ID - :param annotations: - :param any_attribute: any attributes with non-schema namespace - """ + """XSD AnnotationBase model representation.""" id: Optional[str] = attribute() annotations: Array[Annotation] = array_element(name="annotation") @@ -115,6 +101,7 @@ class AnnotationBase(ElementBase): @property def display_help(self) -> Optional[str]: + """Return all annotation documentations concatenated.""" help_str = "\n".join( documentation.tostring() or "" for annotation in self.annotations @@ -126,59 +113,49 @@ def display_help(self) -> Optional[str]: @dataclass class AnyAttribute(AnnotationBase): - """ - Model representation of a schema xs:anyAttribute element. - - :param namespace: ##any | ##other) | List of anyURI | - (##targetNamespace | ##local) - :param process_contents: (lax | skip | strict) : strict - """ + """XSD AnyAttribute model representation.""" namespace: str = attribute(default="##any") - process_contents: Optional[ProcessType] = attribute(name="processContents") + process_contents: Optional[ProcessType] = attribute( + name="processContents", default="strict" + ) def __post_init__(self): + """Clean the namespace value.""" self.namespace = " ".join(unique_sequence(self.namespace.split())) @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def raw_namespace(self) -> Optional[str]: + """The element explicit namespace.""" return self.namespace @property def real_name(self) -> str: + """Return the real name for this element.""" clean_ns = "_".join(map(clean_uri, self.namespace.split())) return f"@{clean_ns}_attributes" @property def attr_types(self) -> Iterator[str]: + """Yields the attr types for this element.""" yield DataType.ANY_TYPE.prefixed(self.xs_prefix) @dataclass class Assertion(AnnotationBase): - """ - Model representation of a schema xs:assertion element. - - :param test: an XPath expression - """ + """XSD Assertion model representation.""" test: Optional[str] = attribute() @dataclass class SimpleType(AnnotationBase): - """ - Model representation of a schema xs:simpleType element. 
- - :param name: NCName - :param restriction: - :param list: - :param union: - """ + """XSD SimpleType model representation.""" name: Optional[str] = attribute() restriction: Optional["Restriction"] = element() @@ -187,20 +164,24 @@ class SimpleType(AnnotationBase): @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def is_enumeration(self) -> bool: + """Return whether it is an enumeration restriction.""" return self.restriction is not None and len(self.restriction.enumerations) > 0 @property def real_name(self) -> str: + """Return the real name for this element.""" if self.name: return self.name return DEFAULT_ATTR_NAME @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" if not self.is_enumeration and self.restriction: yield from self.restriction.attr_types elif self.list: @@ -209,6 +190,7 @@ def attr_types(self) -> Iterator[str]: yield from self.union.bases def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" if self.restriction: return self.restriction.get_restrictions() if self.list: @@ -218,60 +200,58 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class List(AnnotationBase): - """ - Model representation of a schema xs:list element. - - :param simple_type: - :param item_type: QName - """ + """XSD List model representation.""" simple_type: Optional[SimpleType] = element(name="simpleType") item_type: str = attribute(name="itemType", default="") @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def real_name(self) -> str: + """Return the real name for this element.""" return DEFAULT_ATTR_NAME @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" if self.item_type: yield self.item_type def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" return {"tokens": True} @dataclass class Union(AnnotationBase): - """ - Model representation of a schema xs:union element. - - :param member_types: List of QName - :param simple_types: - """ + """XSD Union model representation.""" member_types: Optional[str] = attribute(name="memberTypes") simple_types: Array[SimpleType] = array_element(name="simpleType") @property def bases(self) -> Iterator[str]: + """Return an iterator of all the base types.""" if self.member_types: yield from self.member_types.split() @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def real_name(self) -> str: + """Return the real name for this element.""" return DEFAULT_ATTR_NAME @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" for simple_type in self.simple_types: yield from simple_type.attr_types @@ -279,6 +259,7 @@ def attr_types(self) -> Iterator[str]: yield from self.member_types.split() def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" restrictions = {} for simple_type in self.simple_types: restrictions.update(simple_type.get_restrictions()) @@ -287,19 +268,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class Attribute(AnnotationBase): - """ - Model representation of a schema xs:attribute element. 
- - :param default: string - :param fixed: string - :param form: qualified | unqualified - :param name: NCName - :param ref: QName - :param type: QName - :param target_namespace: anyURI - :param simple_type: - :param use: (optional | prohibited | required) : optional - """ + """XSD Attribute model representation.""" default: Optional[str] = attribute() fixed: Optional[str] = attribute() @@ -313,6 +282,7 @@ class Attribute(AnnotationBase): @property def bases(self) -> Iterator[str]: + """Return an iterator of all the base types.""" if self.type: yield self.type elif not self.has_children: @@ -320,10 +290,12 @@ def bases(self) -> Iterator[str]: @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" if self.simple_type: yield from self.simple_type.attr_types elif self.type: @@ -333,10 +305,12 @@ def attr_types(self) -> Iterator[str]: @property def default_type(self) -> str: + """Returned the inferred default type qname.""" datatype = DataType.STRING if self.fixed else DataType.ANY_SIMPLE_TYPE return datatype.prefixed(self.xs_prefix) def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" if self.use == UseType.REQUIRED: restrictions = {"min_occurs": 1, "max_occurs": 1} elif self.use == UseType.PROHIBITED: @@ -352,14 +326,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class AttributeGroup(AnnotationBase): - """ - Model representation of a schema xs:attributeGroup element. - - :param name: NCName - :param ref: QName - :param attributes: any attributes with non-schema namespace - :param attribute_groups: - """ + """XSD AttributeGroup model representation.""" ref: str = attribute(default="") name: Optional[str] = attribute() @@ -368,24 +335,19 @@ class AttributeGroup(AnnotationBase): @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" if self.ref: yield self.ref @dataclass class Any(AnnotationBase): - """ - Model representation of a schema xs:any element. 
- - :param min_occurs: nonNegativeInteger : 1 - :param max_occurs: (nonNegativeInteger | unbounded) : 1 - :param namespace: List of (anyURI | (##targetNamespace | ##local)) - :param process_contents: (lax | skip | strict) : strict - """ + """XSD Any model representation.""" namespace: str = attribute(default="##any") min_occurs: int = attribute(default=1, name="minOccurs") @@ -395,26 +357,32 @@ class Any(AnnotationBase): ) def __post_init__(self): + """Clean the namespace value.""" self.namespace = " ".join(unique_sequence(self.namespace.split())) @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def real_name(self) -> str: + """Return the real name for this element.""" clean_ns = "_".join(map(clean_uri, self.namespace.split())) return f"@{clean_ns}_element" @property def raw_namespace(self) -> Optional[str]: + """The element explicit namespace.""" return self.namespace @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" yield DataType.ANY_TYPE.prefixed(self.xs_prefix) def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs return { @@ -426,15 +394,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class All(AnnotationBase): - """ - Model representation of a schema xs:all element. - - :param min_occurs: nonNegativeInteger : 1 - :param max_occurs: (nonNegativeInteger | unbounded) : 1 - :param any: - :param elements: - :param groups: - """ + """XSD All model representation.""" min_occurs: int = attribute(default=1, name="minOccurs") max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs") @@ -443,6 +403,7 @@ class All(AnnotationBase): groups: Array["Group"] = array_element(name="group") def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs return { @@ -452,17 +413,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class Sequence(AnnotationBase): - """ - Model representation of a schema xs:sequence element. - - :param min_occurs: nonNegativeInteger : 1 - :param max_occurs: (nonNegativeInteger | unbounded) : 1 - :param elements: - :param groups: - :param choices: - :param sequences: - :param any: - """ + """XSD Sequence model representation.""" min_occurs: int = attribute(default=1, name="minOccurs") max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs") @@ -473,6 +424,7 @@ class Sequence(AnnotationBase): any: Array["Any"] = array_element() def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs return { @@ -482,17 +434,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class Choice(AnnotationBase): - """ - Model representation of a schema xs:choice element. 
- - :param min_occurs: nonNegativeInteger : 1 - :param max_occurs: (nonNegativeInteger | unbounded) : 1 - :param elements: - :param groups: - :param choices: - :param sequences: - :param any: - """ + """XSD Choice model representation.""" min_occurs: int = attribute(default=1, name="minOccurs") max_occurs: UnionType[int, str] = attribute(default=1, name="maxOccurs") @@ -503,6 +445,7 @@ class Choice(AnnotationBase): any: Array["Any"] = array_element() def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs return { @@ -512,17 +455,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class Group(AnnotationBase): - """ - Model representation of a schema xs:group element. - - :param name: NCName - :param ref: QName - :param min_occurs: nonNegativeInteger : 1 - :param max_occurs: (nonNegativeInteger | unbounded) : 1 - :param all: - :param choice: - :param sequence: - """ + """XSD Group model representation.""" name: Optional[str] = attribute() ref: str = attribute(default="") @@ -534,14 +467,17 @@ class Group(AnnotationBase): @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" if self.ref: yield self.ref def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs return { @@ -551,13 +487,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class OpenContent(AnnotationBase): - """ - Model representation of a schema xs:openContent element. - - :param applies_to_empty: default false - :param mode: (none | interleave | suffix) : interleave - :param any: - """ + """XSD OpenContent model representation.""" applies_to_empty: bool = attribute(default=False, name="appliesToEmpty") mode: Mode = attribute(default=Mode.INTERLEAVE) @@ -566,25 +496,12 @@ class OpenContent(AnnotationBase): @dataclass class DefaultOpenContent(OpenContent): - """Model representation of a schema xs:defaultOpenContent element.""" + """XSD DefaultOpenContent model representation.""" @dataclass class Extension(AnnotationBase): - """ - Model representation of a schema xs:extension element. - - :param base: QName - :param group: - :param all: - :param choice: - :param sequence: - :param any_attribute: any attributes with non-schema namespace - :param open_content: - :param attributes: - :param attribute_groups: - :param assertions: - """ + """XSD Extension model representation.""" base: Optional[str] = attribute() group: Optional[Group] = element() @@ -599,166 +516,118 @@ class Extension(AnnotationBase): @property def bases(self) -> Iterator[str]: + """Return an iterator of all the base types.""" if self.base: yield self.base @dataclass class Enumeration(AnnotationBase): - """ - Model representation of a schema xs:enumeration element. 
- - :param value: anySimpleType - """ + """XSD Enumeration model representation.""" value: str = attribute() @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def real_name(self) -> str: + """Return the enumeration value as its name.""" return self.value @property def default(self) -> str: + """Return the enumeration value as its default value.""" return self.value @property def is_fixed(self) -> bool: + """Specify this element has a fixed value.""" return True @dataclass class FractionDigits(AnnotationBase): - """ - Model representation of a schema xs:fractionDigits element. - - :param value: nonNegativeInteger - """ + """XSD FractionDigits model representation.""" value: int = attribute() @dataclass class Length(AnnotationBase): - """ - Model representation of a schema xs:length element. - - :param value: nonNegativeInteger - """ + """XSD Length model representation.""" value: int = attribute() @dataclass class MaxExclusive(AnnotationBase): - """ - Model representation of a schema xs:maxExclusive element. - - :param value: anySimpleType - """ + """XSD MaxExclusive model representation.""" value: str = attribute() @dataclass class MaxInclusive(AnnotationBase): - """ - Model representation of a schema xs:maxInclusive element. - - :param value: anySimpleType - """ + """XSD MaxInclusive model representation.""" value: str = attribute() @dataclass class MaxLength(AnnotationBase): - """ - Model representation of a schema xs:maxLength element. - - :param value: nonNegativeInteger - """ + """XSD MaxLength model representation.""" value: int = attribute() @dataclass class MinExclusive(AnnotationBase): - """ - Model representation of a schema xs:minExclusive element. - - :param value: anySimpleType - """ + """XSD MinExclusive model representation.""" value: str = attribute() @dataclass class MinInclusive(AnnotationBase): - """ - Model representation of a schema xs:minInclusive element. - - :param value: anySimpleType - """ + """XSD MinInclusive model representation.""" value: str = attribute() @dataclass class MinLength(AnnotationBase): - """ - Model representation of a schema xs:minLength element. - - :param value: nonNegativeInteger - """ + """XSD MinLength model representation.""" value: int = attribute() @dataclass class Pattern(AnnotationBase): - """ - Model representation of a schema xs:pattern element. - - :param value: string - """ + """XSD Pattern model representation.""" value: str = attribute() @dataclass class TotalDigits(AnnotationBase): - """ - Model representation of a schema xs:totalDigits element. - - :param value: positiveInteger - """ + """XSD TotalDigits model representation.""" value: int = attribute() @dataclass class WhiteSpace(AnnotationBase): - """ - Model representation of a schema xs:whiteSpace element. - - :param value: (collapse | preserve | replace) - """ + """XSD WhiteSpace model representation.""" value: str = attribute() @dataclass class ExplicitTimezone(AnnotationBase): - """ - Model representation of a schema xs:explicitTimezone element. - - :param value: NCName - :param fixed: default false - """ + """XSD ExplicitTimezone model representation.""" value: str = attribute() fixed: bool = attribute(default=False) @@ -766,35 +635,7 @@ class ExplicitTimezone(AnnotationBase): @dataclass class Restriction(AnnotationBase): - """ - Model representation of a schema xs:restriction element. 
- - :param base: QName - :param group: - :param all: - :param choice: - :param sequence: - :param open_content: - :param attributes: - :param attribute_groups: - :param enumerations: - :param asserts: - :param assertions: - :param any_element: - :param min_exclusive: - :param min_inclusive: - :param min_length: - :param max_exclusive: - :param max_inclusive: - :param max_length: - :param total_digits: - :param fraction_digits: - :param length: - :param white_space: - :param patterns: - :param explicit_timezone: - :param simple_type: - """ + """XSD Restriction model representation.""" base: Optional[str] = attribute() group: Optional[Group] = element() @@ -824,6 +665,7 @@ class Restriction(AnnotationBase): @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" if self.simple_type: yield from self.simple_type.attr_types elif self.base and not self.enumerations: @@ -831,14 +673,17 @@ def attr_types(self) -> Iterator[str]: @property def real_name(self) -> str: + """Return the real name for this element.""" return DEFAULT_ATTR_NAME @property def bases(self) -> Iterator[str]: + """Return an iterator of all the base types.""" if self.base: yield self.base def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" restrictions = {} if self.simple_type: restrictions.update(self.simple_type.get_restrictions()) @@ -874,12 +719,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class SimpleContent(AnnotationBase): - """ - Model representation of a schema xs:simpleContent element. - - :param restriction: - :param extension: - """ + """XSD SimpleContent model representation.""" restriction: Optional[Restriction] = element() extension: Optional[Extension] = element() @@ -887,38 +727,14 @@ class SimpleContent(AnnotationBase): @dataclass class ComplexContent(SimpleContent): - """ - Model representation of a schema xs:complexContent element. - - :param fixed: - """ + """XSD ComplexContent model representation.""" mixed: bool = attribute(default=False) @dataclass class ComplexType(AnnotationBase): - """ - Model representation of a schema xs:complexType element. - - :param name: NCName - :param block: (#all | List of (extension | restriction)) - :param final: (#all | List of (extension | restriction)) - :param simple_content: - :param complex_content: - :param group: - :param all: - :param choice: - :param sequence: - :param any_attribute: - :param open_content: - :param attributes: - :param attribute_groups: - :param assertion: - :param abstract: - :param mixed: - :param default_attributes_apply: - """ + """XSD ComplexType model representation.""" name: Optional[str] = attribute() block: Optional[str] = attribute() @@ -942,6 +758,7 @@ class ComplexType(AnnotationBase): @property def is_mixed(self) -> bool: + """Return whether this element accepts mixed content value.""" if self.mixed: return True @@ -953,30 +770,19 @@ def is_mixed(self) -> bool: @dataclass class Field(AnnotationBase): - """ - Model representation of a schema xs:field element. - - :param xpath: a subset of XPath expression - """ + """XSD Field model representation.""" xpath: Optional[str] = attribute() @dataclass class Selector(Field): - """Schema Model representation of a schema xs:selectorModel element..""" + """XSD Selector model representation.""" @dataclass class Unique(AnnotationBase): - """ - Model representation of a schema xs:unique element. 
- - :param name: NCName - :param ref: QName - :param selector: - :param fields: - """ + """XSD Unique model representation.""" name: Optional[str] = attribute() ref: Optional[str] = attribute() @@ -986,14 +792,7 @@ class Unique(AnnotationBase): @dataclass class Key(AnnotationBase): - """ - Model representation of a schema xs:key element. - - :param name: NCName - :param ref: QName - :param selector: - :param fields: - """ + """XSD Key model representation.""" name: Optional[str] = attribute() ref: Optional[str] = attribute() @@ -1003,15 +802,7 @@ class Key(AnnotationBase): @dataclass class Keyref(AnnotationBase): - """ - Model representation of a schema xs:keyref element. - - :param name: NCName - :param ref: QName - :param refer: QName - :param selector: - :param fields: - """ + """XSD Keyref model representation.""" name: Optional[str] = attribute() ref: Optional[str] = attribute() @@ -1022,14 +813,7 @@ class Keyref(AnnotationBase): @dataclass class Alternative(AnnotationBase): - """ - Model representation of a schema xs:alternative element. - - :param type: QName - :param test: an XPath expression - :param simple_type: - :param complex_type: - """ + """XSD Alternative model representation.""" type: Optional[str] = attribute() test: Optional[str] = attribute() @@ -1038,6 +822,7 @@ class Alternative(AnnotationBase): @property def real_name(self) -> str: + """Return the real name for this element.""" if self.test: return text.snake_case(self.test) if self.id: @@ -1046,10 +831,12 @@ def real_name(self) -> str: @property def bases(self) -> Iterator[str]: + """Return an iterator of all the base types.""" if self.type: yield self.type def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" return { "path": [("alt", id(self), 0, 1)], } @@ -1057,31 +844,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class Element(AnnotationBase): - """ - Model representation of a schema xs:element element. 
- - :param name: NCName - :param ref: QName - :param type: QName - :param substitution_group: List of QName - :param default: - :param fixed: - :param form: qualified | unqualified - :param block: (#all | List of (extension | restriction | - substitution)) - :param final: (#all | List of (extension | restriction)) - :param target_namespace: anyURI - :param simple_type: - :param complex_type: - :param alternatives: - :param uniques: - :param keys: - :param keyrefs: - :param min_occurs: nonNegativeInteger : 1 - :param max_occurs: (nonNegativeInteger | unbounded) : 1 - :param nillable: - :param abstract: - """ + """XSD Element model representation.""" name: Optional[str] = attribute() ref: Optional[str] = attribute() @@ -1106,6 +869,7 @@ class Element(AnnotationBase): @property def bases(self) -> Iterator[str]: + """Return an iterator of all the base types.""" if self.type: yield self.type elif not self.has_children: @@ -1113,19 +877,23 @@ def bases(self) -> Iterator[str]: @property def is_property(self) -> bool: + """Specify it is qualified to be a class property.""" return True @property def is_mixed(self) -> bool: + """Return whether this element accepts mixed content value.""" return self.complex_type.is_mixed if self.complex_type else False @property def default_type(self) -> str: + """Returned the inferred default type qname.""" datatype = DataType.STRING if self.fixed else DataType.ANY_TYPE return datatype.prefixed(self.xs_prefix) @property def attr_types(self) -> Iterator[str]: + """Return the attr types for this element.""" if self.type: yield self.type elif self.ref: @@ -1137,9 +905,11 @@ def attr_types(self) -> Iterator[str]: @property def substitutions(self) -> Array[str]: + """Return a list of the substitution groups.""" return self.substitution_group.split() if self.substitution_group else [] def get_restrictions(self) -> Dict[str, Anything]: + """Return the restrictions dictionary of this element.""" max_occurs = sys.maxsize if self.max_occurs == "unbounded" else self.max_occurs restrictions = { @@ -1158,13 +928,7 @@ def get_restrictions(self) -> Dict[str, Anything]: @dataclass class Notation(AnnotationBase): - """ - Model representation of a schema xs:notation element. - - :param name: NCName - :param public: token - :param system: anyURI - """ + """XSD Notation model representation.""" name: Optional[str] = attribute() public: Optional[str] = attribute() @@ -1172,75 +936,37 @@ class Notation(AnnotationBase): @dataclass -class SchemaLocation(AnnotationBase): - """ - Model representation of a schema xs:schemaLocation element. Base schema - location. - - :param location: any url with a urllib supported scheme file - : http: - """ - - location: Optional[str] = field(default=None) - - -@dataclass -class Import(SchemaLocation): - """ - Model representation of a schema xs:import element. - - :param namespace: anyURI - :param schema_location: anyURI - """ +class Import(AnnotationBase): + """XSD Import model representation.""" namespace: Optional[str] = attribute() schema_location: Optional[str] = attribute(name="schemaLocation") + location: Optional[str] = field(default=None, metadata={"type": "ignore"}) @dataclass -class Include(SchemaLocation): - """ - Model representation of a schema xs:include element. 
- - :param schema_location: anyURI - """ +class Include(AnnotationBase): + """XSD Include model representation.""" schema_location: Optional[str] = attribute(name="schemaLocation") + location: Optional[str] = field(default=None, metadata={"type": "ignore"}) @dataclass -class Redefine(SchemaLocation): - """ - Model representation of a schema xs:redefine element. - - :param schema_location: anyURI - :param simple_types: - :param complex_types: - :param groups: - :param attribute_groups: - """ +class Redefine(AnnotationBase): + """XSD Redefine model representation.""" schema_location: Optional[str] = attribute(name="schemaLocation") simple_types: Array[SimpleType] = array_element(name="simpleType") complex_types: Array[ComplexType] = array_element(name="complexType") groups: Array[Group] = array_element(name="group") attribute_groups: Array[AttributeGroup] = array_element(name="attributeGroup") + location: Optional[str] = field(default=None, metadata={"type": "ignore"}) @dataclass -class Override(SchemaLocation): - """ - Model representation of a schema xs:override element. - - :param schema_location: anyURI - :param simple_types: - :param complex_types: - :param groups: - :param attribute_groups: - :param elements: - :param attributes: - :param notations: - """ +class Override(AnnotationBase): + """XSD Override model representation.""" schema_location: Optional[str] = attribute(name="schemaLocation") simple_types: Array[SimpleType] = array_element(name="simpleType") @@ -1250,40 +976,16 @@ class Override(SchemaLocation): elements: Array[Element] = array_element(name="element") attributes: Array[Attribute] = array_element(name="attribute") notations: Array[Notation] = array_element(name="notation") + location: Optional[str] = field(default=None, metadata={"type": "ignore"}) @dataclass -class Schema(SchemaLocation): - """ - Model representation of a schema xs:schema element. 
- - :param target: - :param block_default: (#all | List of (extension | restriction | - substitution)) - :param default_attributes: QName - :param final_default: (#all | List of extension | restriction | list - | union) : '' - :param target_namespace: anyURI - :param version: token - :param xmlns: - :param element_form_default: (qualified | unqualified) : unqualified - :param attribute_form_default: (qualified | unqualified) : - unqualified - :param default_open_content: - :param imports: - :param redefines: - :param overrides: - :param annotations: - :param simple_types: - :param complex_types: - :param groups: - :param attribute_groups: - :param elements: - :param attributes: - :param notations: - """ +class Schema(AnnotationBase): + """XSD Schema model representation.""" class Meta: + """Metadata options.""" + name = "schema" namespace = Namespace.XS.uri @@ -1295,10 +997,12 @@ class Meta: version: Optional[str] = attribute() xmlns: Optional[str] = attribute() element_form_default: FormType = attribute( - default=FormType.UNQUALIFIED, name="elementFormDefault" + default=FormType.UNQUALIFIED, + name="elementFormDefault", ) attribute_form_default: FormType = attribute( - default=FormType.UNQUALIFIED, name="attributeFormDefault" + default=FormType.UNQUALIFIED, + name="attributeFormDefault", ) default_open_content: Optional[DefaultOpenContent] = element( name="defaultOpenContent" @@ -1315,8 +1019,10 @@ class Meta: elements: Array[Element] = array_element(name="element") attributes: Array[Attribute] = array_element(name="attribute") notations: Array[Notation] = array_element(name="notation") + location: Optional[str] = field(default=None, metadata={"type": "ignore"}) def included(self) -> Iterator[UnionType[Import, Include, Redefine, Override]]: + """Yields an iterator of included resources.""" yield from self.imports yield from self.includes diff --git a/xsdata/utils/click.py b/xsdata/utils/click.py index fec47e9a7..ca07bdaf8 100644 --- a/xsdata/utils/click.py +++ b/xsdata/utils/click.py @@ -1,4 +1,5 @@ import enum +import inspect import logging from dataclasses import fields, is_dataclass from typing import ( @@ -24,6 +25,8 @@ def model_options(obj: Any) -> Callable[[FC], FC]: + """Decorate click commands to add model options.""" + def decorator(f: F) -> F: for option in reversed(list(build_options(obj, ""))): option(f) @@ -33,6 +36,7 @@ def decorator(f: F) -> F: def build_options(obj: Any, parent: str) -> Iterator[Callable[[FC], FC]]: + """Build click options by a data class.""" type_hints = get_type_hints(obj) doc_hints = get_doc_hints(obj) @@ -87,25 +91,36 @@ def build_options(obj: Any, parent: str) -> Iterator[Callable[[FC], FC]]: def get_doc_hints(obj: Any) -> Dict[str, str]: + """Return a param-docstring map of the class arguments.""" + docstrings = inspect.getdoc(obj) + assert docstrings is not None + + start = docstrings.index("Args:") + 6 + params = docstrings[start:].replace("\n ", " ") + result = {} - for line in obj.__doc__.split(":param "): - if line[0].isalpha(): - param, hint = line.split(":", 1) - result[param] = " ".join(hint.split()) + for line in params.splitlines(): + param, hint = line.split(":", 1) + result[param.strip()] = " ".join(hint.split()) return result class EnumChoice(click.Choice): + """Custom click choice widget for enumerations.""" + def __init__(self, enumeration: Type[enum.Enum]): self.enumeration = enumeration super().__init__([e.value for e in enumeration]) def convert(self, value: Any, *args: Any) -> enum.Enum: + """Parse the value into an enumeration 
member.""" return self.enumeration(value) class LogFormatter(logging.Formatter): + """Custom log formatter with click colors.""" + colors: Dict[str, Any] = { "error": {"fg": "red"}, "exception": {"fg": "red"}, @@ -115,6 +130,7 @@ class LogFormatter(logging.Formatter): } def format(self, record: logging.LogRecord) -> str: + """Format the log record with click styles.""" if not record.exc_info: level = record.levelname.lower() msg = record.getMessage() @@ -127,11 +143,14 @@ def format(self, record: logging.LogRecord) -> str: class LogHandler(logging.Handler): + """Custom click log handler to record warnings.""" + def __init__(self, level: Union[int, str] = logging.NOTSET): super().__init__(level) self.warnings: List[str] = [] def emit(self, record: logging.LogRecord): + """Override emit to record warnings.""" try: msg = self.format(record) if record.levelno > logging.INFO: @@ -142,6 +161,7 @@ def emit(self, record: logging.LogRecord): self.handleError(record) def emit_warnings(self): + """Print all recorded warnings to click stdout.""" num = len(self.warnings) if num: click.echo(click.style(f"Warnings: {num}", bold=True)) diff --git a/xsdata/utils/collections.py b/xsdata/utils/collections.py index fd744cd28..0bdc1bff6 100644 --- a/xsdata/utils/collections.py +++ b/xsdata/utils/collections.py @@ -16,6 +16,7 @@ def is_array(value: Any) -> bool: + """Return whether the value is a list style type.""" if isinstance(value, tuple): return not hasattr(value, "_fields") @@ -23,11 +24,14 @@ def is_array(value: Any) -> bool: def unique_sequence(items: Iterable[T], key: Optional[str] = None) -> List[T]: - """ - Return a new list with the unique values from an iterable. + """Return a new unique list, preserving the original order. + + Args: + items: The iterable to filter + key: An optional callable to generate the unique keys - Optionally you can also provide a lambda to generate the unique key - of each item in the iterable object. + Returns: + A new unique list. """ seen = set() @@ -64,8 +68,15 @@ def apply(items: Iterable, func: Callable): def find(items: Sequence, value: Any) -> int: - """Return the index of the value in the given sequence without raising - exception in case of failure.""" + """Return the index of the value in the given sequence. + + Args: + items: The sequence to search in + value: The value to search for + + Returns: + The index in the sequence or -1 if the value is not found. + """ try: return items.index(value) except ValueError: @@ -83,12 +94,7 @@ def prepend(target: List, *args: Any): def connected_components(lists: List[List[Any]]) -> Iterator[List[Any]]: - """ - Merge lists of lists that share common elements. 
- - https://stackoverflow.com/questions/4842613/merge-lists-that-share- - common-elements - """ + """Merge lists of lists that share common elements.""" neighbors = defaultdict(set) for each in lists: for item in each: @@ -109,6 +115,7 @@ def component(node: Any, neigh: Dict[Any, Set], see: Set[Any]): def find_connected_component(groups: List[List[Any]], value: Any) -> int: + """Find the list index that contains the given value.""" for index, group in enumerate(groups): if value in group: return index diff --git a/xsdata/utils/dates.py b/xsdata/utils/dates.py index 267c97480..79f0a54ce 100644 --- a/xsdata/utils/dates.py +++ b/xsdata/utils/dates.py @@ -1,9 +1,10 @@ import datetime from calendar import isleap -from typing import Any, Generator, Optional, Union +from typing import Any, Iterator, Optional, Union -def parse_date_args(value: Any, fmt: str) -> Generator: +def parse_date_args(value: Any, fmt: str) -> Iterator[int]: + """Parse the fmt args from the value.""" if not isinstance(value, str): raise ValueError("") @@ -12,6 +13,7 @@ def parse_date_args(value: Any, fmt: str) -> Generator: def calculate_timezone(offset: Optional[int]) -> Optional[datetime.timezone]: + """Return a timezone instance by the given hours offset.""" if offset is None: return None @@ -22,6 +24,7 @@ def calculate_timezone(offset: Optional[int]) -> Optional[datetime.timezone]: def calculate_offset(obj: Union[datetime.time, datetime.datetime]) -> Optional[int]: + """Convert the datetime offset to signed minutes.""" offset = obj.utcoffset() if offset is None: return None @@ -30,6 +33,7 @@ def calculate_offset(obj: Union[datetime.time, datetime.datetime]) -> Optional[i def format_date(year: int, month: int, day: int) -> str: + """Return a xml formatted signed date.""" if year < 0: year = -year sign = "-" @@ -40,6 +44,7 @@ def format_date(year: int, month: int, day: int) -> str: def format_time(hour: int, minute: int, second: int, fractional_second: int) -> str: + """Return a xml formatted time.""" if not fractional_second: return f"{hour:02d}:{minute:02d}:{second:02d}" @@ -55,6 +60,7 @@ def format_time(hour: int, minute: int, second: int, fractional_second: int) -> def format_offset(offset: Optional[int]) -> str: + """Return a xml formatted time offset.""" if offset is None: return "" @@ -77,10 +83,12 @@ def format_offset(offset: Optional[int]) -> str: def monthlen(year: int, month: int) -> int: + """Return the number of days for a specific month and year.""" return mdays[month] + (month == 2 and isleap(year)) def validate_date(year: int, month: int, day: int): + """Validate the given year, month day is a valid date.""" if not 1 <= month <= 12: raise ValueError("Month must be in 1..12") @@ -90,6 +98,7 @@ def validate_date(year: int, month: int, day: int): def validate_time(hour: int, minute: int, second: int, franctional_second: int): + """Validate the time args are valid.""" if not 0 <= hour <= 24: raise ValueError("Hour must be in 0..24") @@ -110,6 +119,19 @@ def validate_time(hour: int, minute: int, second: int, franctional_second: int): class DateTimeParser: + """XML Datetime parser. 
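# Illustrative sketch, not part of the patch: behaviour of the small date
# helpers documented above. monthlen() accounts for leap years and
# validate_date() raises ValueError for out-of-range parts.
from xsdata.utils.dates import monthlen, validate_date

assert monthlen(2024, 2) == 29  # 2024 is a leap year
assert monthlen(2023, 2) == 28

try:
    validate_date(2023, 13, 1)  # month outside 1..12
except ValueError as exc:
    print(exc)  # Month must be in 1..12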
+ + Args: + value: The datetime string + fmt: The target format string + + Attributes: + vlen: The length of the datetime string + flen: The length of the format string + vidx: The current position of the datetime string + fidx: The current position of the format string + """ + def __init__(self, value: str, fmt: str): self.format = fmt self.value = value @@ -118,7 +140,8 @@ def __init__(self, value: str, fmt: str): self.vidx = 0 self.fidx = 0 - def parse(self): + def parse(self) -> Iterator[int]: + """Yield the parsed datetime string arguments.""" try: while self.fidx < self.flen: char = self.next_format_char() @@ -138,23 +161,28 @@ def parse(self): ) def next_format_char(self) -> str: + """Return the next format character to evaluate.""" char = self.format[self.fidx] self.fidx += 1 return char def has_more(self) -> bool: + """Return whether the value is not fully parsed yet.""" return self.vidx < self.vlen def peek(self) -> str: + """Return the current evaluated character of the datetime string.""" return self.value[self.vidx] def skip(self, char: str): + """Validate and skip over the given char.""" if not self.has_more() or self.peek() != char: raise ValueError() self.vidx += 1 def parse_var(self, var: str): + """Parse the given var from the datetime string.""" if var in SIMPLE_TWO_DIGITS_FORMATS: yield self.parse_digits(2) elif var == "Y": @@ -169,6 +197,7 @@ def parse_var(self, var: str): raise ValueError() def parse_year(self) -> int: + """Parse the year argument.""" negative = False if self.peek() == "-": self.vidx += 1 @@ -195,6 +224,7 @@ def parse_year(self) -> int: return year def parse_fractional_second(self) -> int: + """Parse the fractional second argument.""" if self.has_more() and self.peek() == ".": self.vidx += 1 return self.parse_fixed_digits(9) @@ -202,11 +232,13 @@ def parse_fractional_second(self) -> int: return 0 def parse_digits(self, digits: int) -> int: + """Parse the given number of digits.""" start = self.vidx self.vidx += digits return int(self.value[start : self.vidx]) def parse_minimum_digits(self, min_digits: int) -> int: + """Parse until the next character is not a digit.""" start = self.vidx self.vidx += min_digits @@ -216,6 +248,7 @@ def parse_minimum_digits(self, min_digits: int) -> int: return int(self.value[start : self.vidx]) def parse_fixed_digits(self, max_digits: int) -> int: + """Parse a fixed number of digits.""" start = self.vidx just = max_digits while max_digits and self.has_more() and self.peek().isdigit(): @@ -225,6 +258,7 @@ def parse_fixed_digits(self, max_digits: int) -> int: return int(self.value[start : self.vidx].ljust(just, "0")) def parse_offset(self) -> Optional[int]: + """Parse the xml timezone offset as minutes.""" if not self.has_more(): return None diff --git a/xsdata/utils/debug.py b/xsdata/utils/debug.py index c347da42e..1825b425a 100644 --- a/xsdata/utils/debug.py +++ b/xsdata/utils/debug.py @@ -4,8 +4,7 @@ def dump(obj: Any): - """ - Write any object into a dump json file. + """Write any object into a dump json file. For internal troubleshooting purposes only!!! """ diff --git a/xsdata/utils/downloader.py b/xsdata/utils/downloader.py index 8a73a5a37..7eb92b8f0 100644 --- a/xsdata/utils/downloader.py +++ b/xsdata/utils/downloader.py @@ -11,11 +11,17 @@ class Downloader: - """ + """Remote recursive resource downloader. + Helper class to download a schema or a definitions with all their imports locally. The imports paths will be adjusted if necessary. 
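# Illustrative usage sketch, not part of the patch, assuming the constructor
# takes the output directory and wget() fetches the root resource plus every
# schema/definition it imports, as the docstring above describes. The URL and
# output path are made up.
from pathlib import Path

from xsdata.utils.downloader import Downloader

downloader = Downloader(output=Path("schemas"))
downloader.wget("https://example.com/air/Air.wsdl")
# The fetched files are written under ./schemas with their import locations
# rewritten to point at the downloaded copies.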
- :param output: Output path + Args: + output: The output path + + Attributes: + base_path: The base path for the resources + downloaded: A cache of the downloaded resources """ __slots__ = ("output", "base_path", "downloaded") @@ -49,9 +55,7 @@ def parse_schema(self, uri: str, content: bytes): self.wget_included(schema) def parse_definitions(self, uri: str, content: bytes): - """Convert content to a definitions instance and process all sub - imports.""" - + """Convert content to a definitions instance and process all sub imports.""" parser = DefinitionsParser(location=uri) definitions = parser.from_bytes(content, Definitions) self.wget_included(definitions) @@ -60,18 +64,21 @@ def parse_definitions(self, uri: str, content: bytes): self.wget_included(schema) def wget_included(self, definition: Union[Schema, Definitions]): + """Download the definitions included resources.""" for included in definition.included(): if included.location: schema_location = getattr(included, "schema_location", None) self.wget(included.location, schema_location) def adjust_base_path(self, uri: str): - """ - Adjust base path for every new uri loaded. + """Adjust base path for every new uri loaded. Example runs: - file:///schemas/air_v48_0/Air.wsdl -> file:///schemas/air_v48_0 - file:///schemas/common_v48_0/CommonReqRsp.xsd -> file:///schemas + + Args: + uri: A resource location URI """ if not self.base_path: self.base_path = Path(uri).parent @@ -86,8 +93,7 @@ def adjust_base_path(self, uri: str): logger.info("Adjusting base path to %s", self.base_path) def adjust_imports(self, path: Path, content: str) -> str: - """Try to adjust the import locations for external locations that are - not relative to the first requested uri.""" + """Update the location of the imports to point to the downloaded files.""" matches = re.findall(r"ocation=\"(.*)\"", content) for match in matches: if isinstance(self.downloaded.get(match), Path): @@ -98,13 +104,16 @@ def adjust_imports(self, path: Path, content: str) -> str: return content def write_file(self, uri: str, location: Optional[str], content: str): - """ - Write the given uri and it's content according to the base path and if - the uri is relative to first requested uri. + """Write the downloaded uri to a local file. Keep track of all the written file paths, in case we have to modify the location attribute in an upcoming schema/definition import. + + Args: + uri: The resource URI + location: The import location of the resource + content: The raw content string """ common_path = os.path.commonpath((self.base_path or "", uri)) if common_path: diff --git a/xsdata/utils/graphs.py b/xsdata/utils/graphs.py index bd442821c..5d4c86065 100644 --- a/xsdata/utils/graphs.py +++ b/xsdata/utils/graphs.py @@ -2,13 +2,16 @@ def strongly_connected_components(edges: Dict[str, List[str]]) -> Iterator[Set[str]]: - """ - Compute Strongly Connected Components of a directed graph. + """Compute Strongly Connected Components of a directed graph. 
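# Illustrative sketch, not part of the patch: the input/output shape of the
# strongly_connected_components helper documented above. The order of the
# yielded sets is an implementation detail.
from xsdata.utils.graphs import strongly_connected_components

edges = {"a": ["b"], "b": ["a", "c"], "c": []}
components = list(strongly_connected_components(edges))
# "a" and "b" reference each other, "c" stands alone, e.g. [{"c"}, {"a", "b"}]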
From https://code.activestate.com/recipes/578507/ From https://github.com/python/mypy/blob/master/mypy/build.py - :param edges: Mapping of vertex-edges values + Args: + edges: A vertex-edges map + + Yields: + A set of the strongly connected components """ identified: Set[str] = set() stack: List[str] = [] diff --git a/xsdata/utils/hooks.py b/xsdata/utils/hooks.py index d69bb7d5e..c417610f4 100644 --- a/xsdata/utils/hooks.py +++ b/xsdata/utils/hooks.py @@ -2,6 +2,7 @@ def load_entry_points(name: str): + """Load the plugins for the given hook name.""" entry_points = metadata.entry_points() if hasattr(entry_points, "select"): diff --git a/xsdata/utils/namespaces.py b/xsdata/utils/namespaces.py index b41bb2350..441032246 100644 --- a/xsdata/utils/namespaces.py +++ b/xsdata/utils/namespaces.py @@ -15,8 +15,7 @@ def load_prefix(uri: str, ns_map: Dict) -> Optional[str]: - """Get or create a prefix for the given uri in the prefix-URI namespace - mapping.""" + """Get or create a prefix for the uri in the prefix-URI map.""" for prefix, ns in ns_map.items(): if ns == uri: return prefix @@ -25,8 +24,7 @@ def load_prefix(uri: str, ns_map: Dict) -> Optional[str]: def generate_prefix(uri: str, ns_map: Dict) -> str: - """Generate and add a prefix for the given uri in the prefix-URI namespace - mapping.""" + """Generate a prefix for the given uri and append it in the prefix-URI map.""" namespace = Namespace.get_enum(uri) if namespace: prefix = namespace.prefix @@ -66,7 +64,7 @@ def clean_prefixes(ns_map: Dict) -> Dict: def clean_uri(namespace: str) -> str: - """Remove common prefixes and suffixes from a uri string.""" + """Remove common prefixes and suffixes from an URI string.""" if namespace[:2] == "##": namespace = namespace[2:] @@ -80,12 +78,6 @@ def clean_uri(namespace: str) -> str: return "_".join(x for x in namespace.split(".") if x not in __uri_ignore__) -def real_xsi_type(qname: str, target_qname: Optional[str]) -> Optional[str]: - """Determine if the given target qualified name should be used to define a - derived type.""" - return target_qname if target_qname != qname else None - - @functools.lru_cache(maxsize=50) def build_qname(tag_or_uri: Optional[str], tag: Optional[str] = None) -> str: """Create namespace qualified strings.""" @@ -99,22 +91,24 @@ def build_qname(tag_or_uri: Optional[str], tag: Optional[str] = None) -> str: @functools.lru_cache(maxsize=50) -def split_qname(tag: str) -> Tuple: +def split_qname(qname: str) -> Tuple: """Split namespace qualified strings.""" - if tag[0] == "{": - left, right = text.split(tag[1:], "}") + if qname[0] == "{": + left, right = text.split(qname[1:], "}") if left: return left, right - return None, tag + return None, qname -def target_uri(tag: str) -> Optional[str]: - return split_qname(tag)[0] +def target_uri(qname: str) -> Optional[str]: + """Return the URI namespace of the qname.""" + return split_qname(qname)[0] -def local_name(tag: str) -> str: - return split_qname(tag)[1] +def local_name(qname: str) -> str: + """Return the local name of the qname.""" + return split_qname(qname)[1] NCNAME_PUNCTUATION = {"\u00B7", "\u0387", ".", "-", "_"} diff --git a/xsdata/utils/objects.py b/xsdata/utils/objects.py index 54377bbba..209e7f26c 100644 --- a/xsdata/utils/objects.py +++ b/xsdata/utils/objects.py @@ -5,20 +5,21 @@ def update(obj: Any, **kwargs: Any): """Update an object from keyword arguments with dotted keys.""" - for key, value in kwargs.items(): - attrsetter(obj, key, value) + def attrsetter(obj: Any, attr: str, value: Any): + names = attr.split(".") 
+ last = names.pop() + for name in names: + obj = getattr(obj, name) -def attrsetter(obj: Any, attr: str, value: Any): - names = attr.split(".") - last = names.pop() - for name in names: - obj = getattr(obj, name) + setattr(obj, last, value) - setattr(obj, last, value) + for key, value in kwargs.items(): + attrsetter(obj, key, value) def literal_value(value: Any) -> str: + """Return the value for code generation.""" if isinstance(value, float): return str(value) if math.isfinite(value) else f'float("{value}")' diff --git a/xsdata/utils/package.py b/xsdata/utils/package.py index a87159cc8..a0848b47c 100644 --- a/xsdata/utils/package.py +++ b/xsdata/utils/package.py @@ -16,7 +16,15 @@ def module_path(module: str) -> Path: @functools.lru_cache(maxsize=50) -def module_name(source: str) -> str: - module = source.split("/")[-1] +def module_name(uri: str) -> str: + """Convert a file uri to a module name. + + Args: + uri: A file URI location + + Returns: + The last part of the URI path stripped from known extensions. + """ + module = uri.split("/")[-1] name, extension = os.path.splitext(module) return name if extension in (".xsd", ".dtd", ".wsdl", ".xml", ".json") else module diff --git a/xsdata/utils/text.py b/xsdata/utils/text.py index dabd0f88f..560d7d1fd 100644 --- a/xsdata/utils/text.py +++ b/xsdata/utils/text.py @@ -48,7 +48,7 @@ "list", "nonlocal", "not", - "object", # py36 specific + "object", "or", "pass", "raise", @@ -76,12 +76,7 @@ def suffix(value: str, sep: str = ":") -> str: def split(value: str, sep: str = ":") -> Tuple: - """ - Separate the given string with the given separator and return a tuple of - the prefix and suffix. - - If the separator isn't present in the string return None as prefix. - """ + """Split the given value with the given separator once.""" left, _, right = value.partition(sep) return (left, right) if right else (None, left) @@ -138,8 +133,7 @@ def kebab_case(value: str, **kwargs: Any) -> str: def split_words(value: str) -> List[str]: - """Split a string on new capital letters and not alphanumeric - characters.""" + """Split a string on capital letters and not alphanumeric characters.""" words: List[str] = [] buffer: List[str] = [] previous = None @@ -151,11 +145,11 @@ def flush(): for char in value: tp = classify(char) - if tp == StringType.OTHER: + if tp == CharType.OTHER: flush() elif not previous or tp == previous: buffer.append(char) - elif tp == StringType.UPPER and previous != StringType.UPPER: + elif tp == CharType.UPPER and previous != CharType.UPPER: flush() buffer.append(char) else: @@ -167,7 +161,9 @@ def flush(): return words -class StringType: +class CharType: + """Character types.""" + UPPER = 1 LOWER = 2 NUMERIC = 3 @@ -178,15 +174,15 @@ def classify(character: str) -> int: """String classifier.""" code_point = ord(character) if 64 < code_point < 91: - return StringType.UPPER + return CharType.UPPER if 96 < code_point < 123: - return StringType.LOWER + return CharType.LOWER if 47 < code_point < 58: - return StringType.NUMERIC + return CharType.NUMERIC - return StringType.OTHER + return CharType.OTHER ESCAPE = re.compile(r'[\x00-\x1f\\"\b\f\n\r\t]') @@ -204,11 +200,7 @@ def classify(character: str) -> int: def escape_string(value: str) -> str: - """ - Escape a string for code generation. 
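# Illustrative sketch, not part of the patch: the qname helpers renamed and
# documented earlier in this patch split Clark-notation names into their
# namespace URI and local part.
from xsdata.utils.namespaces import local_name, split_qname, target_uri

qname = "{http://www.w3.org/2001/XMLSchema}element"
assert split_qname(qname) == ("http://www.w3.org/2001/XMLSchema", "element")
assert target_uri(qname) == "http://www.w3.org/2001/XMLSchema"
assert local_name(qname) == "element"
assert split_qname("element") == (None, "element")  # no namespace prefix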
-
-    Source: json.encoder.py_encode_basestring
-    """
+    """Escape a string for code generation."""
 
     def replace(match: Match) -> str:
         return ESCAPE_DCT[match.group(0)]
 
@@ -220,8 +212,7 @@ def replace(match: Match) -> str:
 
 
 def alnum(value: str) -> str:
-    """Return a lower case version of the string only with ascii alphanumerical
-    characters."""
+    """Return the ascii alphanumerical characters in lower case."""
     return "".join(filter(__alnum_ascii__.__contains__, value)).lower()
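# Illustrative sketch, not part of the patch: the text helpers touched above.
# alnum() keeps only ASCII letters and digits and lower-cases them, while
# escape_string() escapes quotes, backslashes and control characters for
# generated code (using the ESCAPE_DCT table).
from xsdata.utils.text import alnum, escape_string

assert alnum("Foo-Bar 1.2") == "foobar12"
assert escape_string('say "hi"\n') == 'say \\"hi\\"\\n'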