diff --git a/tests/test_protected.py b/tests/test_protected.py new file mode 100644 index 0000000..163ef7a --- /dev/null +++ b/tests/test_protected.py @@ -0,0 +1,15 @@ +import znflow + + +class MyNode(znflow.Node): + _protected_ = znflow.Node._protected_ + ["a"] + + a: int = 42 + b: int = 42 + + +def test_protected(): + with znflow.DiGraph(): + node = MyNode() + assert node.a == 42 + assert isinstance(node.b, znflow.Connection) diff --git a/znflow/node.py b/znflow/node.py index 052991a..015a664 100644 --- a/znflow/node.py +++ b/znflow/node.py @@ -70,7 +70,7 @@ def __new__(cls, *args, **kwargs): graph.add_node(instance, this_uuid=this_uuid) return instance - def __getattribute__(self, item): + def __getattribute__(self, item: str): if item.startswith("_"): return super().__getattribute__(item) if self._graph_ not in [empty_graph, None]: @@ -80,7 +80,7 @@ def __getattribute__(self, item): f"'{self.__class__.__name__}' object has no attribute '{item}'" ) - if item not in type(self)._protected_: + if item not in self._protected_: if self._in_construction: return super().__getattribute__(item) return Connection(instance=self, attribute=item)