From 319082868b5a0de111f017590dbf58971836f20e Mon Sep 17 00:00:00 2001 From: Christodoulos Tsoulloftas Date: Fri, 8 Mar 2024 15:29:35 +0200 Subject: [PATCH] feat: Detect circular references more accurately --- docs/codegen/architecture.md | 9 +- .../test_detect_circular_references.py | 75 ++++++++++++++++ .../handlers/test_process_attributes_types.py | 81 +---------------- tests/codegen/test_container.py | 5 +- xsdata/codegen/container.py | 22 ++--- xsdata/codegen/handlers/__init__.py | 4 +- .../handlers/detect_circular_references.py | 90 +++++++++++++++++++ .../handlers/process_attributes_types.py | 59 +----------- xsdata/codegen/models.py | 28 ++---- 9 files changed, 200 insertions(+), 173 deletions(-) create mode 100644 tests/codegen/handlers/test_detect_circular_references.py create mode 100644 xsdata/codegen/handlers/detect_circular_references.py diff --git a/docs/codegen/architecture.md b/docs/codegen/architecture.md index 71db0ed43..4f7471ef9 100644 --- a/docs/codegen/architecture.md +++ b/docs/codegen/architecture.md @@ -57,8 +57,6 @@ graph LR B --> C[Validate class references] ``` -API: [xsdata.codegen.analyzer.ClassAnalyzer][] - ### Validate Classes - Remove types with unknown references @@ -135,9 +133,13 @@ pass through each step before next one starts. The order of the steps is very im - [ValidateAttributesOverrides][xsdata.codegen.handlers.ValidateAttributesOverrides] -### Step: Finalize +### Step: Vacuum - [VacuumInnerClasses][xsdata.codegen.handlers.VacuumInnerClasses] + +### Step: Finalize + +- [DetectCircularReferences][xsdata.codegen.handlers.DetectCircularReferences] - [CreateCompoundFields][xsdata.codegen.handlers.CreateCompoundFields] - [DisambiguateChoices][xsdata.codegen.handlers.DisambiguateChoices] - [ResetAttributeSequenceNumbers][xsdata.codegen.handlers.ResetAttributeSequenceNumbers] @@ -145,4 +147,5 @@ pass through each step before next one starts. The order of the steps is very im ### Step: Designate - [RenameDuplicateClasses][xsdata.codegen.handlers.RenameDuplicateClasses] +- [ValidateReferences][xsdata.codegen.handlers.ValidateReferences] - [DesignateClassPackages][xsdata.codegen.handlers.DesignateClassPackages] diff --git a/tests/codegen/handlers/test_detect_circular_references.py b/tests/codegen/handlers/test_detect_circular_references.py new file mode 100644 index 000000000..8d8f73606 --- /dev/null +++ b/tests/codegen/handlers/test_detect_circular_references.py @@ -0,0 +1,75 @@ +from xsdata.codegen.container import ClassContainer +from xsdata.codegen.handlers.detect_circular_references import DetectCircularReferences +from xsdata.models.config import GeneratorConfig +from xsdata.models.enums import DataType +from xsdata.utils.testing import ( + AttrFactory, + AttrTypeFactory, + ClassFactory, + FactoryTestCase, +) + + +class DetectCircularReferencesTests(FactoryTestCase): + def setUp(self): + super().setUp() + config = GeneratorConfig() + self.container = ClassContainer(config=config) + self.processor = DetectCircularReferences(self.container) + + def test_process(self): + first = ClassFactory.create(qname="first") + second = ClassFactory.create(qname="second") + third = ClassFactory.create(qname="third") + + first.attrs.append(AttrFactory.native(DataType.STRING)) + first.attrs.append( + AttrFactory.create( + types=[ + AttrTypeFactory.create(qname="second", reference=second.ref), + AttrTypeFactory.create(qname="third", reference=third.ref), + ], + choices=[ + AttrFactory.reference("second", reference=second.ref), + AttrFactory.reference("third", reference=third.ref), + ], + ) + ) + + second.attrs = AttrFactory.list(2) + third.attrs.append(AttrFactory.reference("first", reference=first.ref)) + self.container.extend([first, second, third]) + + self.processor.process(first) + + first_flags = [tp.circular for tp in first.types()] + self.assertEqual([False, False, True, False, True], first_flags) + + second_flags = [tp.circular for tp in second.types()] + self.assertEqual([False, False], second_flags) + + # First has the flags this doesn't need it :) + third_flags = [tp.circular for tp in third.types()] + self.assertEqual([False], third_flags) + + def test_build_reference_types(self): + target = ClassFactory.create() + inner = ClassFactory.create() + + outer_attr = AttrFactory.create() + inner_attr = AttrFactory.reference("foo", reference=target.ref) + + inner.attrs.append(inner_attr) + target.inner.append(inner) + target.attrs.append(outer_attr) + + self.container.add(target) + + self.processor.build_reference_types() + + expected = { + target.ref: [inner_attr.types[0]], + inner.ref: [inner_attr.types[0]], + } + + self.assertEqual(expected, self.processor.reference_types) diff --git a/tests/codegen/handlers/test_process_attributes_types.py b/tests/codegen/handlers/test_process_attributes_types.py index 1b3ef7f7d..bfb47e25a 100644 --- a/tests/codegen/handlers/test_process_attributes_types.py +++ b/tests/codegen/handlers/test_process_attributes_types.py @@ -2,7 +2,7 @@ from xsdata.codegen.container import ClassContainer from xsdata.codegen.handlers import ProcessAttributeTypes -from xsdata.codegen.models import Class, Restrictions, Status +from xsdata.codegen.models import Restrictions, Status from xsdata.codegen.utils import ClassUtils from xsdata.models.config import GeneratorConfig from xsdata.models.enums import DataType, Tag @@ -180,11 +180,8 @@ def test_process_dependency_type_with_enumeration_type(self, mock_find_dependenc self.assertIsNone(attr.restrictions.min_length) self.assertIsNone(attr.restrictions.max_length) - @mock.patch.object(ProcessAttributeTypes, "set_circular_flag") @mock.patch.object(ProcessAttributeTypes, "find_dependency") - def test_process_dependency_type_with_complex_type( - self, mock_find_dependency, mock_set_circular_flag - ): + def test_process_dependency_type_with_complex_type(self, mock_find_dependency): complex_type = ClassFactory.elements(1) mock_find_dependency.return_value = complex_type @@ -193,13 +190,13 @@ def test_process_dependency_type_with_complex_type( attr_type = attr.types[0] self.processor.process_dependency_type(target, attr, attr_type) - mock_set_circular_flag.assert_called_once_with(complex_type, target, attr_type) self.assertFalse(attr.restrictions.nillable) complex_type.nillable = True self.processor.process_dependency_type(target, attr, attr_type) self.assertTrue(attr.restrictions.nillable) + self.assertEqual(complex_type.ref, attr_type.reference) @mock.patch.object(ProcessAttributeTypes, "find_dependency") def test_process_dependency_type_with_abstract_type_type( @@ -328,63 +325,6 @@ def test_copy_attribute_properties_set_default_value_if_none(self): self.assertEqual("foo", attr.default) self.assertTrue("foo", attr.fixed) - @mock.patch.object(ProcessAttributeTypes, "is_circular_dependency") - def test_set_circular_flag(self, mock_is_circular_dependency): - source = ClassFactory.create() - target = ClassFactory.create() - attr = AttrFactory.create() - attr_type = attr.types[0] - - mock_is_circular_dependency.return_value = True - - self.processor.set_circular_flag(source, target, attr_type) - self.assertTrue(attr_type.circular) - self.assertEqual(id(source), attr_type.reference) - - mock_is_circular_dependency.assert_called_once_with(source, target, set()) - - @mock.patch.object(ClassContainer, "find") - @mock.patch.object(Class, "dependencies") - def test_is_circular_dependency(self, mock_dependencies, mock_container_find): - source = ClassFactory.create() - target = ClassFactory.create() - another = ClassFactory.create() - processing = ClassFactory.create(status=Status.FLATTENING) - - find_classes = {"a": another, "b": target} - - mock_container_find.side_effect = lambda x: find_classes.get(x) - mock_dependencies.side_effect = [ - list("ccde"), - list("abc"), - list("xy"), - ] - - self.assertTrue( - self.processor.is_circular_dependency(processing, target, set()) - ) - - self.processor.dependencies.clear() - self.assertFalse(self.processor.is_circular_dependency(source, target, set())) - - self.processor.dependencies.clear() - self.assertTrue(self.processor.is_circular_dependency(source, target, set())) - - self.processor.dependencies.clear() - self.assertTrue(self.processor.is_circular_dependency(source, source, set())) - - mock_container_find.assert_has_calls( - [ - mock.call("c"), - mock.call("d"), - mock.call("e"), - mock.call("a"), - mock.call("x"), - mock.call("y"), - mock.call("b"), - ] - ) - def test_find_dependency(self): attr_type = AttrTypeFactory.create(qname="a") @@ -413,21 +353,6 @@ def test_find_dependency(self): actual = self.processor.find_dependency(attr_type, Tag.EXTENSION) self.assertEqual(simple_type, actual) - @mock.patch.object(Class, "dependencies") - def test_cached_dependencies(self, mock_class_dependencies): - mock_class_dependencies.return_value = ["a", "b"] - - source = ClassFactory.create() - self.processor.dependencies[id(source)] = ("a",) - - actual = self.processor.cached_dependencies(source) - self.assertEqual(("a",), actual) - - self.processor.dependencies.clear() - actual = self.processor.cached_dependencies(source) - self.assertEqual(("a", "b"), actual) - mock_class_dependencies.assert_called_once_with() - def test_update_restrictions(self): attr = AttrFactory.native(DataType.NMTOKENS) self.processor.update_restrictions(attr, attr.types[0].datatype) diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index 37ac74ed5..0fe89c264 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -53,8 +53,9 @@ def test_initialize(self): "SanitizeAttributesDefaultValue", ], 40: ["ValidateAttributesOverrides"], - 50: [ - "VacuumInnerClasses", + 50: ["VacuumInnerClasses"], + 60: [ + "DetectCircularReferences", "CreateCompoundFields", "DisambiguateChoices", "ResetAttributeSequenceNumbers", diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index 22cafe9ae..afaba2e98 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -5,6 +5,7 @@ CalculateAttributePaths, CreateCompoundFields, DesignateClassPackages, + DetectCircularReferences, DisambiguateChoices, FilterClasses, FlattenAttributeGroups, @@ -40,7 +41,8 @@ class Steps: FLATTEN = 20 SANITIZE = 30 RESOLVE = 40 - FINALIZE = 50 + CLEANUP = 50 + FINALIZE = 60 class ClassContainer(ContainerInterface): @@ -91,8 +93,11 @@ def __init__(self, config: GeneratorConfig): Steps.RESOLVE: [ ValidateAttributesOverrides(self), ], - Steps.FINALIZE: [ + Steps.CLEANUP: [ VacuumInnerClasses(), + ], + Steps.FINALIZE: [ + DetectCircularReferences(self), CreateCompoundFields(self), DisambiguateChoices(self), ResetAttributeSequenceNumbers(self), @@ -165,17 +170,7 @@ def first(self, qname: str) -> Class: return classes[0] def process(self): - """Run the processor and filter steps. - - Steps: - 1. Ungroup xs:groups and xs:attributeGroups - 2. Remove the group classes from the container - 3. Flatten extensions, attrs and attr types - 4. Remove the classes that won't be generated - 5. Resolve attrs overrides - 5. Create compound fields, cleanup classes and atts - 7. Designate final class names, packages and modules - """ + """Run the processor and filter steps.""" self.validate_classes() self.process_classes(Steps.UNGROUP) self.remove_groups() @@ -183,6 +178,7 @@ def process(self): self.filter_classes() self.process_classes(Steps.SANITIZE) self.process_classes(Steps.RESOLVE) + self.process_classes(Steps.CLEANUP) self.process_classes(Steps.FINALIZE) self.designate_classes() diff --git a/xsdata/codegen/handlers/__init__.py b/xsdata/codegen/handlers/__init__.py index 598c8d8f6..b732a6643 100644 --- a/xsdata/codegen/handlers/__init__.py +++ b/xsdata/codegen/handlers/__init__.py @@ -2,6 +2,7 @@ from .calculate_attribute_paths import CalculateAttributePaths from .create_compound_fields import CreateCompoundFields from .designate_class_packages import DesignateClassPackages +from .detect_circular_references import DetectCircularReferences from .disambiguate_choices import DisambiguateChoices from .filter_classes import FilterClasses from .flatten_attribute_groups import FlattenAttributeGroups @@ -26,6 +27,7 @@ "CalculateAttributePaths", "CreateCompoundFields", "DesignateClassPackages", + "DetectCircularReferences", "DisambiguateChoices", "FilterClasses", "FlattenAttributeGroups", @@ -35,8 +37,8 @@ "ProcessMixedContentClass", "RenameDuplicateAttributes", "RenameDuplicateClasses", - "ResetAttributeSequences", "ResetAttributeSequenceNumbers", + "ResetAttributeSequences", "SanitizeAttributesDefaultValue", "SanitizeEnumerationClass", "UnnestInnerClasses", diff --git a/xsdata/codegen/handlers/detect_circular_references.py b/xsdata/codegen/handlers/detect_circular_references.py new file mode 100644 index 000000000..5f6be8bf4 --- /dev/null +++ b/xsdata/codegen/handlers/detect_circular_references.py @@ -0,0 +1,90 @@ +from typing import Dict, List + +from xsdata.codegen.mixins import ( + ContainerInterface, + RelativeHandlerInterface, +) +from xsdata.codegen.models import AttrType, Class + + +class DetectCircularReferences(RelativeHandlerInterface): + """Accurately detect circular dependencies between classes. + + Args: + container: The class container instance + + Attributes: + reference_types: A map of class refs to dependency types + """ + + __slots__ = "container", "reference_types" + + def __init__(self, container: ContainerInterface): + super().__init__(container) + self.reference_types: Dict[int, List[AttrType]] = {} + + def process(self, target: Class): + """Go through all the attr types and find circular references. + + Args: + target: The class to inspect and update + """ + if not self.reference_types: + self.build_reference_types() + + for attr in target.attrs: + self.process_types(attr.types, target.ref) + + for choice in attr.choices: + self.process_types(choice.types, target.ref) + + def process_types(self, types: List[AttrType], class_reference: int): + """Go through the types and find circular references. + + Args: + types: A list attr/choice type instances + class_reference: The parent attr/choice class reference + """ + for tp in types: + if not tp.forward and not tp.native and not tp.circular: + tp.circular = self.is_circular(tp.reference, class_reference) + + def is_circular(self, start: int, stop: int) -> bool: + """Detect if the start reference leads to the stop reference. + + The procedure is a dfs search to avoid max recursion errors. + + Args: + start: The attr type reference + stop: The parent class reference + + Returns: + Whether the start reference leads back to the stop reference. + """ + path = set() + stack = [start] + while len(stack) != 0: + if stop in path: + return True + + ref = stack.pop() + path.add(ref) + + for tp in self.reference_types[ref]: + if not tp.circular and tp.reference not in path: + stack.append(tp.reference) + + return stop in path + + def build_reference_types(self): + """Build the reference types mapping.""" + + def generate(target: Class): + yield target.ref, [tp for tp in target.types() if tp.reference] + + for inner in target.inner: + yield from generate(inner) + + for item in self.container: + for ref, types in generate(item): + self.reference_types[ref] = types diff --git a/xsdata/codegen/handlers/process_attributes_types.py b/xsdata/codegen/handlers/process_attributes_types.py index 6450ddd15..065a4347c 100644 --- a/xsdata/codegen/handlers/process_attributes_types.py +++ b/xsdata/codegen/handlers/process_attributes_types.py @@ -1,7 +1,7 @@ -from typing import Dict, Optional, Set, Tuple +from typing import Dict, Optional from xsdata.codegen.mixins import ContainerInterface, RelativeHandlerInterface -from xsdata.codegen.models import Attr, AttrType, Class, Status +from xsdata.codegen.models import Attr, AttrType, Class from xsdata.codegen.utils import ClassUtils from xsdata.logger import logger from xsdata.models.enums import DataType, Tag @@ -204,7 +204,8 @@ def process_dependency_type(self, target: Class, attr: Attr, attr_type: AttrType else: if source.nillable: attr.restrictions.nillable = True - self.set_circular_flag(source, target, attr_type) + + attr_type.reference = id(source) self.detect_lazy_namespace(source, target, attr) @classmethod @@ -255,58 +256,6 @@ def copy_attribute_properties( attr.fixed = attr.fixed or source_attr.fixed attr.default = attr.default or source_attr.default - def set_circular_flag(self, source: Class, target: Class, attr_type: AttrType): - """Detect circular references and set the type flag. - - Args: - source: The source class instance - target: The target class instance - attr_type: The attr type instance - """ - attr_type.reference = id(source) - attr_type.circular = self.is_circular_dependency(source, target, set()) - - if attr_type.circular: - logger.debug("Possible circular reference %s, %s", target.name, source.name) - - def is_circular_dependency(self, source: Class, target: Class, seen: Set) -> bool: - """Check if any source dependencies recursively match the target class. - - Args: - source: The source class instance - target: The target class instance - seen: A set of qualified names, to guard against recursive runs - - Returns: - The bool result - """ - if source is target or source.status == Status.FLATTENING: - return True - - for qname in self.cached_dependencies(source): - if qname not in seen: - seen.add(qname) - check = self.container.find(qname) - if check and self.is_circular_dependency(check, target, seen): - return True - - return False - - def cached_dependencies(self, source: Class) -> Tuple[str]: - """Returns from cache the source class dependencies. - - Args: - source: The source class instance - - Returns: - A tuple containing the qualified names of the class dependencies. - """ - cache_key = id(source) - if cache_key not in self.dependencies: - self.dependencies[cache_key] = tuple(source.dependencies()) - - return self.dependencies[cache_key] - @classmethod def reset_attribute_type(cls, attr_type: AttrType, use_str: bool = True): """Reset the attribute type to string or any simple type. diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 1c0f0b1c4..b1c3a7f7c 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -474,8 +474,10 @@ class Status(IntEnum): SANITIZED = 31 RESOLVING = 40 RESOLVED = 41 - FINALIZING = 50 - FINALIZED = 51 + CLEANING = 50 + CLEANED = 51 + FINALIZING = 60 + FINALIZED = 61 @dataclass @@ -606,25 +608,9 @@ def is_service(self) -> bool: @property def references(self) -> Iterator[int]: """Yield all class object reference numbers.""" - - def all_refs(): - for ext in self.extensions: - yield ext.type.reference - - for attr in self.attrs: - for tp in attr.types: - yield tp.reference - - for choice in attr.choices: - for ctp in choice.types: - yield ctp.reference - - for inner in self.inner: - yield from inner.references - - for ref in all_refs(): - if ref: - yield ref + for tp in self.types(): + if tp.reference: + yield tp.reference @property def target_module(self) -> str: