From 1ba8e6187564722f9b3d2b7b63d4ce954a72a228 Mon Sep 17 00:00:00 2001 From: wooly18 <97658409+wooly18@users.noreply.github.com> Date: Sat, 18 Jan 2025 01:08:22 +0800 Subject: [PATCH] [`flake8-comprehensions`] strip parentheses around generators in `unnecessary-generator-set` (`C401`) (#15553) ## Summary Fixes parentheses not being stripped in C401. Pretty much the same as #11607 which fixed it for C400. ## Test Plan `cargo nextest run` --- .../fixtures/flake8_comprehensions/C401.py | 6 +- .../rules/unnecessary_generator_set.rs | 31 ++++++- ...8_comprehensions__tests__C401_C401.py.snap | 85 ++++++++++++++++--- 3 files changed, 109 insertions(+), 13 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C401.py b/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C401.py index 6f54b249a6a11..e6e488312d59c 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C401.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C401.py @@ -1,4 +1,4 @@ -# Cannot conbime with C416. Should use set comprehension here. +# Cannot combine with C416. Should use set comprehension here. even_nums = set(2 * x for x in range(3)) odd_nums = set( 2 * x + 1 for x in range(3) @@ -21,6 +21,10 @@ def f(x): print(f"{set(a for a in 'abc') - set(a for a in 'ab')}") print(f"{ set(a for a in 'abc') - set(a for a in 'ab') }") +# Strip parentheses from inner generators. +set((2 * x for x in range(3))) +set(((2 * x for x in range(3)))) +set((((2 * x for x in range(3))))) # Not built-in set. def set(*args, **kwargs): diff --git a/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_generator_set.rs b/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_generator_set.rs index 514eee6198b75..59803b0c03504 100644 --- a/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_generator_set.rs +++ b/crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_generator_set.rs @@ -2,6 +2,7 @@ use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, ViolationMetadata}; use ruff_python_ast as ast; use ruff_python_ast::comparable::ComparableExpr; +use ruff_python_ast::parenthesize::parenthesized_range; use ruff_python_ast::ExprGenerator; use ruff_text_size::{Ranged, TextSize}; @@ -27,12 +28,14 @@ use super::helpers; /// ```python /// set(f(x) for x in foo) /// set(x for x in foo) +/// set((x for x in foo)) /// ``` /// /// Use instead: /// ```python /// {f(x) for x in foo} /// set(foo) +/// set(foo) /// ``` /// /// ## Fix safety @@ -74,7 +77,10 @@ pub(crate) fn unnecessary_generator_set(checker: &mut Checker, call: &ast::ExprC }; let ast::Expr::Generator(ExprGenerator { - elt, generators, .. + elt, + generators, + parenthesized, + .. }) = argument else { return; @@ -126,7 +132,28 @@ pub(crate) fn unnecessary_generator_set(checker: &mut Checker, call: &ast::ExprC call.end(), ); - Fix::unsafe_edits(call_start, [call_end]) + // Remove the inner parentheses, if the expression is a generator. The easiest way to do + // this reliably is to use the printer. + if *parenthesized { + // The generator's range will include the innermost parentheses, but it could be + // surrounded by additional parentheses. + let range = parenthesized_range( + argument.into(), + (&call.arguments).into(), + checker.comment_ranges(), + checker.locator().contents(), + ) + .unwrap_or(argument.range()); + + // The generator always parenthesizes the expression; trim the parentheses. + let generator = checker.generator().expr(argument); + let generator = generator[1..generator.len() - 1].to_string(); + + let replacement = Edit::range_replacement(generator, range); + Fix::unsafe_edits(call_start, [call_end, replacement]) + } else { + Fix::unsafe_edits(call_start, [call_end]) + } }; checker.diagnostics.push(diagnostic.with_fix(fix)); } diff --git a/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C401_C401.py.snap b/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C401_C401.py.snap index 57ce5fc2b5de9..87579feb4ba9f 100644 --- a/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C401_C401.py.snap +++ b/crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C401_C401.py.snap @@ -3,7 +3,7 @@ source: crates/ruff_linter/src/rules/flake8_comprehensions/mod.rs --- C401.py:2:13: C401 [*] Unnecessary generator (rewrite as a set comprehension) | -1 | # Cannot conbime with C416. Should use set comprehension here. +1 | # Cannot combine with C416. Should use set comprehension here. 2 | even_nums = set(2 * x for x in range(3)) | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ C401 3 | odd_nums = set( @@ -12,7 +12,7 @@ C401.py:2:13: C401 [*] Unnecessary generator (rewrite as a set comprehension) = help: Rewrite as a set comprehension ℹ Unsafe fix -1 1 | # Cannot conbime with C416. Should use set comprehension here. +1 1 | # Cannot combine with C416. Should use set comprehension here. 2 |-even_nums = set(2 * x for x in range(3)) 2 |+even_nums = {2 * x for x in range(3)} 3 3 | odd_nums = set( @@ -21,7 +21,7 @@ C401.py:2:13: C401 [*] Unnecessary generator (rewrite as a set comprehension) C401.py:3:12: C401 [*] Unnecessary generator (rewrite as a set comprehension) | -1 | # Cannot conbime with C416. Should use set comprehension here. +1 | # Cannot combine with C416. Should use set comprehension here. 2 | even_nums = set(2 * x for x in range(3)) 3 | odd_nums = set( | ____________^ @@ -33,7 +33,7 @@ C401.py:3:12: C401 [*] Unnecessary generator (rewrite as a set comprehension) = help: Rewrite as a set comprehension ℹ Unsafe fix -1 1 | # Cannot conbime with C416. Should use set comprehension here. +1 1 | # Cannot combine with C416. Should use set comprehension here. 2 2 | even_nums = set(2 * x for x in range(3)) 3 |-odd_nums = set( 3 |+odd_nums = { @@ -188,7 +188,7 @@ C401.py:21:10: C401 [*] Unnecessary generator (rewrite using `set()`) 21 |+print(f"{set('abc') - set(a for a in 'ab')}") 22 22 | print(f"{ set(a for a in 'abc') - set(a for a in 'ab') }") 23 23 | -24 24 | +24 24 | # Strip parentheses from inner generators. C401.py:21:34: C401 [*] Unnecessary generator (rewrite using `set()`) | @@ -208,7 +208,7 @@ C401.py:21:34: C401 [*] Unnecessary generator (rewrite using `set()`) 21 |+print(f"{set(a for a in 'abc') - set('ab')}") 22 22 | print(f"{ set(a for a in 'abc') - set(a for a in 'ab') }") 23 23 | -24 24 | +24 24 | # Strip parentheses from inner generators. C401.py:22:11: C401 [*] Unnecessary generator (rewrite using `set()`) | @@ -216,6 +216,8 @@ C401.py:22:11: C401 [*] Unnecessary generator (rewrite using `set()`) 21 | print(f"{set(a for a in 'abc') - set(a for a in 'ab')}") 22 | print(f"{ set(a for a in 'abc') - set(a for a in 'ab') }") | ^^^^^^^^^^^^^^^^^^^^^ C401 +23 | +24 | # Strip parentheses from inner generators. | = help: Rewrite using `set()` @@ -226,8 +228,8 @@ C401.py:22:11: C401 [*] Unnecessary generator (rewrite using `set()`) 22 |-print(f"{ set(a for a in 'abc') - set(a for a in 'ab') }") 22 |+print(f"{ set('abc') - set(a for a in 'ab') }") 23 23 | -24 24 | -25 25 | # Not built-in set. +24 24 | # Strip parentheses from inner generators. +25 25 | set((2 * x for x in range(3))) C401.py:22:35: C401 [*] Unnecessary generator (rewrite using `set()`) | @@ -235,6 +237,8 @@ C401.py:22:35: C401 [*] Unnecessary generator (rewrite using `set()`) 21 | print(f"{set(a for a in 'abc') - set(a for a in 'ab')}") 22 | print(f"{ set(a for a in 'abc') - set(a for a in 'ab') }") | ^^^^^^^^^^^^^^^^^^^^ C401 +23 | +24 | # Strip parentheses from inner generators. | = help: Rewrite using `set()` @@ -245,5 +249,66 @@ C401.py:22:35: C401 [*] Unnecessary generator (rewrite using `set()`) 22 |-print(f"{ set(a for a in 'abc') - set(a for a in 'ab') }") 22 |+print(f"{ set(a for a in 'abc') - set('ab') }") 23 23 | -24 24 | -25 25 | # Not built-in set. +24 24 | # Strip parentheses from inner generators. +25 25 | set((2 * x for x in range(3))) + +C401.py:25:1: C401 [*] Unnecessary generator (rewrite as a set comprehension) + | +24 | # Strip parentheses from inner generators. +25 | set((2 * x for x in range(3))) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ C401 +26 | set(((2 * x for x in range(3)))) +27 | set((((2 * x for x in range(3))))) + | + = help: Rewrite as a set comprehension + +ℹ Unsafe fix +22 22 | print(f"{ set(a for a in 'abc') - set(a for a in 'ab') }") +23 23 | +24 24 | # Strip parentheses from inner generators. +25 |-set((2 * x for x in range(3))) + 25 |+{2 * x for x in range(3)} +26 26 | set(((2 * x for x in range(3)))) +27 27 | set((((2 * x for x in range(3))))) +28 28 | + +C401.py:26:1: C401 [*] Unnecessary generator (rewrite as a set comprehension) + | +24 | # Strip parentheses from inner generators. +25 | set((2 * x for x in range(3))) +26 | set(((2 * x for x in range(3)))) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ C401 +27 | set((((2 * x for x in range(3))))) + | + = help: Rewrite as a set comprehension + +ℹ Unsafe fix +23 23 | +24 24 | # Strip parentheses from inner generators. +25 25 | set((2 * x for x in range(3))) +26 |-set(((2 * x for x in range(3)))) + 26 |+{2 * x for x in range(3)} +27 27 | set((((2 * x for x in range(3))))) +28 28 | +29 29 | # Not built-in set. + +C401.py:27:1: C401 [*] Unnecessary generator (rewrite as a set comprehension) + | +25 | set((2 * x for x in range(3))) +26 | set(((2 * x for x in range(3)))) +27 | set((((2 * x for x in range(3))))) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ C401 +28 | +29 | # Not built-in set. + | + = help: Rewrite as a set comprehension + +ℹ Unsafe fix +24 24 | # Strip parentheses from inner generators. +25 25 | set((2 * x for x in range(3))) +26 26 | set(((2 * x for x in range(3)))) +27 |-set((((2 * x for x in range(3))))) + 27 |+{2 * x for x in range(3)} +28 28 | +29 29 | # Not built-in set. +30 30 | def set(*args, **kwargs):