diff --git a/xsdata/codegen/handlers/validate_attributes_overrides.py b/xsdata/codegen/handlers/validate_attributes_overrides.py index 9e93bfc97..c0c8691b9 100644 --- a/xsdata/codegen/handlers/validate_attributes_overrides.py +++ b/xsdata/codegen/handlers/validate_attributes_overrides.py @@ -7,6 +7,7 @@ from xsdata.codegen.models import Attr from xsdata.codegen.models import Class from xsdata.codegen.models import get_slug +from xsdata.codegen.models import Restrictions from xsdata.codegen.utils import ClassUtils from xsdata.utils import collections @@ -24,6 +25,12 @@ class ValidateAttributesOverrides(RelativeHandlerInterface): __slots__ = () def process(self, target: Class): + restriction_attrs: dict[str, Attr] = {} + if len([ext for ext in target.extensions if ext.tag == "Restriction"]) > 0: + restriction_attrs = { + attr.slug: attr for attr in target.attrs if not attr.is_attribute + } + base_attrs_map = self.base_attrs_map(target) for attr in list(target.attrs): base_attrs = base_attrs_map.get(attr.slug) @@ -37,6 +44,18 @@ def process(self, target: Class): elif attr.is_prohibited: self.remove_attribute(target, attr) + if len([ext for ext in target.extensions if ext.tag == "Restriction"]) > 0: + # What we want here is to check the restriction.attrs against base_attrs_map + for slug, attr in base_attrs_map.items(): + if not attr[0].is_attribute and slug not in restriction_attrs: + attr_restricted = Attr( + tag=attr[0].tag, + name=attr[0].name, + index=attr[0].index, + restrictions=Restrictions(is_restricted=True), + ) + target.attrs.append(attr_restricted) + @classmethod def overrides(cls, a: Attr, b: Attr) -> bool: return a.xml_type == b.xml_type and a.namespace == b.namespace diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index a054a653b..5a95cc8ee 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -62,6 +62,7 @@ class Restrictions: group: Optional[int] = field(default=None) process_contents: Optional[str] = field(default=None) path: List[Tuple[str, int, int, int]] = field(default_factory=list) + is_restricted: bool = field(default=False) @property def is_list(self) -> bool: diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index e7e114b38..c5888ba7b 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -127,6 +127,7 @@ def register(self, env: Environment): "field_default": self.field_default_value, "field_metadata": self.field_metadata, "field_definition": self.field_definition, + "is_restricted": self.is_restricted, "class_name": self.class_name, "class_bases": self.class_bases, "class_annotations": self.class_annotations, @@ -223,6 +224,9 @@ def apply_substitutions(self, name: str, obj_type: ObjectType) -> str: return name + def is_restricted(self, attr: Attr) -> bool: + return attr.restrictions.is_restricted + def field_definition( self, attr: Attr, diff --git a/xsdata/formats/dataclass/templates/class.jinja2 b/xsdata/formats/dataclass/templates/class.jinja2 index 4075911ae..75384b4fb 100644 --- a/xsdata/formats/dataclass/templates/class.jinja2 +++ b/xsdata/formats/dataclass/templates/class.jinja2 @@ -44,7 +44,7 @@ class {{ class_name }}{{"({})".format(base_classes) if base_classes }}: {%- for attr in obj.attrs %} {%- set field_typing = attr|field_type(parents) %} {%- set field_definition = attr|field_definition(obj.ns_map, parent_namespace, parents) %} - {{ attr.name|field_name(obj.name) }}: {{ field_typing }} = {{ field_definition }} + {{ attr.name|field_name(obj.name) }}: {%- if attr|is_restricted %} RestrictedVar {%- else %} {{ field_typing }} = {{ field_definition }} {%- endif -%} {%- endfor -%} {%- for inner in obj.inner %} {%- set tpl = "enum.jinja2" if inner.is_enumeration else "class.jinja2" -%} diff --git a/xsdata/formats/dataclass/templates/imports.jinja2 b/xsdata/formats/dataclass/templates/imports.jinja2 index 7c59053d8..1d535a370 100644 --- a/xsdata/formats/dataclass/templates/imports.jinja2 +++ b/xsdata/formats/dataclass/templates/imports.jinja2 @@ -9,3 +9,7 @@ from {{ source | import_module(module) }} import ( ) {% endif -%} {%- endfor %} +{# +TODO: remove +#} +from typing import ClassVar as RestrictedVar