From 56b6c6444ecac08b011aea83d11373fc4fe70d03 Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Sat, 19 Oct 2024 00:06:46 +0800 Subject: [PATCH 1/4] nonlocal by copying global --- astroid/rebuilder.py | 11 ++++++++++- tests/test_locals.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 tests/test_locals.py diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 4c77906e02..27f34ce846 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -61,6 +61,7 @@ def __init__( self._manager = manager self._data = data.split("\n") if data else None self._global_names: list[dict[str, list[nodes.Global]]] = [] + self._nonlocal_names: list[dict[str, list[nodes.Nonlocal]]] = [] self._import_from_nodes: list[nodes.ImportFrom] = [] self._delayed_assattr: list[nodes.AssignAttr] = [] self._visit_meths: dict[type[ast.AST], Callable[[ast.AST, NodeNG], NodeNG]] = {} @@ -451,6 +452,8 @@ def _save_assignment(self, node: nodes.AssignName | nodes.DelName) -> None: """Save assignment situation since node.parent is not available yet.""" if self._global_names and node.name in self._global_names[-1]: node.root().set_local(node.name, node) + elif self._nonlocal_names and node.name in self._nonlocal_names[-1]: + node.root().set_local(node.name, node) else: assert node.parent assert node.name @@ -1065,6 +1068,7 @@ def _visit_functiondef( ) -> _FunctionT: """Visit an FunctionDef node to become astroid.""" self._global_names.append({}) + self._nonlocal_names.append({}) node, doc_ast_node = self._get_doc(node) lineno = node.lineno @@ -1113,6 +1117,7 @@ def _visit_functiondef( ), ) self._global_names.pop() + self._nonlocal_names.pop() parent.set_local(newnode.name, newnode) return newnode @@ -1383,7 +1388,7 @@ def visit_name( def visit_nonlocal(self, node: ast.Nonlocal, parent: NodeNG) -> nodes.Nonlocal: """Visit a Nonlocal node and return a new instance of it.""" - return nodes.Nonlocal( + newnode = nodes.Nonlocal( names=node.names, lineno=node.lineno, col_offset=node.col_offset, @@ -1391,6 +1396,10 @@ def visit_nonlocal(self, node: ast.Nonlocal, parent: NodeNG) -> nodes.Nonlocal: end_col_offset=node.end_col_offset, parent=parent, ) + if self._nonlocal_names: + for name in node.names: + self._nonlocal_names[-1].setdefault(name, []).append(newnode) + return newnode def visit_constant(self, node: ast.Constant, parent: NodeNG) -> nodes.Const: """Visit a Constant node by returning a fresh instance of Const.""" diff --git a/tests/test_locals.py b/tests/test_locals.py new file mode 100644 index 0000000000..bf98e2ee88 --- /dev/null +++ b/tests/test_locals.py @@ -0,0 +1,29 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt + +import unittest + +from astroid.nodes.scoped_nodes.scoped_nodes import FunctionDef + +from astroid import Uninferable, builder, extract_node, nodes +from astroid.exceptions import InferenceError + + +class TestLocals(unittest.TestCase): + def test(self) -> None: + module = builder.parse( + """ + x1 = 1 + def f1(): + x2 = 2 + def f2(): + global x1 + nonlocal x2 + x1 = 1 + x2 = 2 + x3 = 3 + """ + ) + x = module.locals["f1"][0].locals["f2"][0].locals + pass From 0331862e44199e6a4dac462bef055b955dcac3f6 Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Sat, 19 Oct 2024 13:37:00 +0800 Subject: [PATCH 2/4] find nonlocal --- astroid/rebuilder.py | 29 ++++++++++++++++++++++++----- tests/test_builder.py | 33 +++++++++++++++++++++++++++++++++ tests/test_locals.py | 29 ----------------------------- 3 files changed, 57 insertions(+), 34 deletions(-) delete mode 100644 tests/test_locals.py diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 27f34ce846..743cc83d25 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -61,7 +61,11 @@ def __init__( self._manager = manager self._data = data.split("\n") if data else None self._global_names: list[dict[str, list[nodes.Global]]] = [] - self._nonlocal_names: list[dict[str, list[nodes.Nonlocal]]] = [] + # In _nonlocal_names, + # what we save is the function where the variable is created, + # rather than nodes.Nonlocal. + # We don't really need the Nonlocal statement. + self._nonlocal_names: list[dict[str, nodes.FunctionDef]] = [] self._import_from_nodes: list[nodes.ImportFrom] = [] self._delayed_assattr: list[nodes.AssignAttr] = [] self._visit_meths: dict[type[ast.AST], Callable[[ast.AST, NodeNG], NodeNG]] = {} @@ -453,7 +457,8 @@ def _save_assignment(self, node: nodes.AssignName | nodes.DelName) -> None: if self._global_names and node.name in self._global_names[-1]: node.root().set_local(node.name, node) elif self._nonlocal_names and node.name in self._nonlocal_names[-1]: - node.root().set_local(node.name, node) + function_def = self._nonlocal_names[-1][node.name] + function_def.set_local(node.name, node) else: assert node.parent assert node.name @@ -1396,9 +1401,23 @@ def visit_nonlocal(self, node: ast.Nonlocal, parent: NodeNG) -> nodes.Nonlocal: end_col_offset=node.end_col_offset, parent=parent, ) - if self._nonlocal_names: - for name in node.names: - self._nonlocal_names[-1].setdefault(name, []).append(newnode) + names = set(newnode.names) + # Go through the tree and find where those names are created + scope = newnode + while len(names) != 0: + scope = scope.parent + if not scope: + # It's not inside a nested function or there are no variables with that name. + # Just ignore it as visit_global does when global is used in module scope. + break + if isinstance(scope, nodes.FunctionDef): + found = [] + for name in names: + if name in scope.locals: + found.append(name) + self._nonlocal_names[-1][name] = scope + for name in found: + names.remove(name) return newnode def visit_constant(self, node: ast.Constant, parent: NodeNG) -> nodes.Const: diff --git a/tests/test_builder.py b/tests/test_builder.py index 9de7f16ba7..41c9e7870e 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -716,6 +716,39 @@ def test_type_comments_without_content(self) -> None: ) assert node + def test_locals_with_global_and_nonlocal(self) -> None: + module = builder.parse( + """ + x1 = 1 # Line 2 + def f1(): # Line 3 + x2 = 2 # Line 4 + def f2(): # Line 5 + global x1 # Line 6 + nonlocal x2 # Line 7 + x1 = 1 # Line 8 + x2 = 2 # Line 9 + x3 = 3 # Line 10 + """ + ) + self.assertSetEqual(set(module.locals), {"x1", "f1"}) + x1 = module.locals["x1"] + f1 = module.locals["f1"][0] + self.assertEqual(len(x1), 2) + self.assertEqual(x1[0].lineno, 2) + self.assertEqual(x1[1].lineno, 8) + + self.assertSetEqual(set(f1.locals), {"x2", "f2"}) + x2 = f1.locals["x2"] + f2 = f1.locals["f2"][0] + self.assertEqual(len(x2), 2) + self.assertEqual(x2[0].lineno, 4) + self.assertEqual(x2[1].lineno, 9) + + self.assertSetEqual(set(f2.locals), {"x3"}) + x3 = f2.locals["x3"] + self.assertEqual(len(x3), 1) + self.assertEqual(x3[0].lineno, 10) + class FileBuildTest(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/test_locals.py b/tests/test_locals.py deleted file mode 100644 index bf98e2ee88..0000000000 --- a/tests/test_locals.py +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html -# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE -# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt - -import unittest - -from astroid.nodes.scoped_nodes.scoped_nodes import FunctionDef - -from astroid import Uninferable, builder, extract_node, nodes -from astroid.exceptions import InferenceError - - -class TestLocals(unittest.TestCase): - def test(self) -> None: - module = builder.parse( - """ - x1 = 1 - def f1(): - x2 = 2 - def f2(): - global x1 - nonlocal x2 - x1 = 1 - x2 = 2 - x3 = 3 - """ - ) - x = module.locals["f1"][0].locals["f2"][0].locals - pass From 11db5e270bf552ecd0859f71396404d45fbd7907 Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Mon, 21 Oct 2024 11:29:04 +0800 Subject: [PATCH 3/4] Update rebuilder.py --- astroid/rebuilder.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 743cc83d25..5c74ff8bec 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -61,10 +61,8 @@ def __init__( self._manager = manager self._data = data.split("\n") if data else None self._global_names: list[dict[str, list[nodes.Global]]] = [] - # In _nonlocal_names, - # what we save is the function where the variable is created, - # rather than nodes.Nonlocal. - # We don't really need the Nonlocal statement. + # In _nonlocal_names saves the FunctionDef where the variable is created, + # rather than Nonlocal, since we don't really need the Nonlocal statement. self._nonlocal_names: list[dict[str, nodes.FunctionDef]] = [] self._import_from_nodes: list[nodes.ImportFrom] = [] self._delayed_assattr: list[nodes.AssignAttr] = [] @@ -1403,7 +1401,7 @@ def visit_nonlocal(self, node: ast.Nonlocal, parent: NodeNG) -> nodes.Nonlocal: ) names = set(newnode.names) # Go through the tree and find where those names are created - scope = newnode + scope: nodes.NodeNG = newnode while len(names) != 0: scope = scope.parent if not scope: From 8422de0bd9caa777852edefa7cb233a06ff5f9f9 Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Mon, 21 Oct 2024 11:31:49 +0800 Subject: [PATCH 4/4] Update rebuilder.py --- astroid/rebuilder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 5c74ff8bec..92d2127787 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -1401,7 +1401,7 @@ def visit_nonlocal(self, node: ast.Nonlocal, parent: NodeNG) -> nodes.Nonlocal: ) names = set(newnode.names) # Go through the tree and find where those names are created - scope: nodes.NodeNG = newnode + scope = newnode while len(names) != 0: scope = scope.parent if not scope: