diff --git a/rewrite-python/src/main/java/org/openrewrite/python/tree/Py.java b/rewrite-python/src/main/java/org/openrewrite/python/tree/Py.java index bf763800..556c7b5b 100644 --- a/rewrite-python/src/main/java/org/openrewrite/python/tree/Py.java +++ b/rewrite-python/src/main/java/org/openrewrite/python/tree/Py.java @@ -2253,4 +2253,86 @@ public Slice withStep(@Nullable JRightPadded step) { } } + @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) + @EqualsAndHashCode(callSuper = false) + @RequiredArgsConstructor + @AllArgsConstructor(access = AccessLevel.PRIVATE) + final class StringLiteralConcatenation implements Py, Expression, TypedTree { + @Nullable + @NonFinal + transient WeakReference padding; + + @Getter + @With + @EqualsAndHashCode.Include + UUID id; + + @Getter + @With + Space prefix; + + @Getter + @With + Markers markers; + + List> literals; + + public List getLiterals() { + return JRightPadded.getElements(literals); + } + + public StringLiteralConcatenation withLiterals(List literals) { + return getPadding().withLiterals(JRightPadded.withElements(this.literals, literals)); + } + + @Override + public JavaType getType() { + return JavaType.Primitive.String; + } + + @Override + public T withType(@Nullable JavaType type) { + //noinspection unchecked + return (T) this; + } + + @Override + public

J acceptPython(PythonVisitor

v, P p) { + return v.visitStringLiteralConcatenation(this, p); + } + + @Override + @Transient + public CoordinateBuilder.Expression getCoordinates() { + return new CoordinateBuilder.Expression(this); + } + + public Padding getPadding() { + Padding p; + if (this.padding == null) { + p = new Padding(this); + this.padding = new WeakReference<>(p); + } else { + p = this.padding.get(); + if (p == null || p.t != this) { + p = new Padding(this); + this.padding = new WeakReference<>(p); + } + } + return p; + } + + @RequiredArgsConstructor + public static class Padding { + private final StringLiteralConcatenation t; + + public List> getLiterals() { + return t.literals; + } + + public StringLiteralConcatenation withLiterals(List> literals) { + return t.literals == literals ? t : new StringLiteralConcatenation(t.id, t.prefix, t.markers, literals, t.type); + } + } + } } diff --git a/rewrite/rewrite/python/_parser_visitor.py b/rewrite/rewrite/python/_parser_visitor.py index 0f3d07f8..d7b124fa 100644 --- a/rewrite/rewrite/python/_parser_visitor.py +++ b/rewrite/rewrite/python/_parser_visitor.py @@ -1294,15 +1294,31 @@ def visit_Constant(self, node): else: break - return j.Literal( - random_id(), - prefix, - Markers.EMPTY, - None if node.value is Ellipsis else node.value, - self._source[start:self._cursor], - None, - self.__map_type(node), - ) + if isinstance(node.value, str) and '\n' in node.value: + return py.StringLiteralConcatenation( + random_id(), + prefix, + Markers.EMPTY, + [self.__pad_right(j.Literal( + random_id(), + Space.EMPTY, + Markers.EMPTY, + None if node.value is Ellipsis else node.value, + self._source[start:self._cursor], + None, + self.__map_type(node), + ), Space.EMPTY)] + ) + else: + return j.Literal( + random_id(), + prefix, + Markers.EMPTY, + None if node.value is Ellipsis else node.value, + self._source[start:self._cursor], + None, + self.__map_type(node), + ) def visit_Dict(self, node): @@ -2123,6 +2139,7 @@ def __map_fstring(self, node: ast.JoinedStr, prefix: Space, tok: TokenInfo, toke # tokenizer tokens: FSTRING_START, FSTRING_MIDDLE, OP, ..., OP, FSTRING_MIDDLE, FSTRING_END parts = [] + literals = [] for value in node.values: if tok.type == token.OP and tok.string == '{': if not isinstance(value, ast.FormattedValue): @@ -2201,7 +2218,7 @@ def __map_fstring(self, node: ast.JoinedStr, prefix: Space, tok: TokenInfo, toke self._cursor += len(tok.string) + (1 if tok.string.endswith('{') or tok.string.endswith('}') else 0) if (tok := next(tokens)).type != token.FSTRING_MIDDLE: break - parts.append(j.Literal( + literals.append(self.__pad_right(j.Literal( random_id(), Space.EMPTY, Markers.EMPTY, @@ -2209,7 +2226,7 @@ def __map_fstring(self, node: ast.JoinedStr, prefix: Space, tok: TokenInfo, toke self._source[save_cursor:self._cursor], None, self.__map_type(value), - )) + ), Space.EMPTY)) if consume_end_delim: self._cursor += len(tok.string) # FSTRING_END token @@ -2217,13 +2234,21 @@ def __map_fstring(self, node: ast.JoinedStr, prefix: Space, tok: TokenInfo, toke elif tok.type == token.FSTRING_MIDDLE and len(tok.string) == 0: tok = next(tokens) - return (py.FormattedString( - random_id(), - prefix, - Markers.EMPTY, - delimiter, - parts - ), tok) + if literals: + return (py.StringLiteralConcatenation( + random_id(), + prefix, + Markers.EMPTY, + literals + ), tok) + else: + return (py.FormattedString( + random_id(), + prefix, + Markers.EMPTY, + delimiter, + parts + ), tok) def __cursor_at(self, s: str): return self._cursor < len(self._source) and (len(s) == 1 and self._source[self._cursor] == s or self._source.startswith(s, self._cursor)) diff --git a/rewrite/tests/python/all/fstring_test.py b/rewrite/tests/python/all/fstring_test.py index 7efd0762..6fe528e5 100644 --- a/rewrite/tests/python/all/fstring_test.py +++ b/rewrite/tests/python/all/fstring_test.py @@ -169,3 +169,32 @@ def test_nested_fstring_with_format_value(): def test_adjoining_expressions(): # language=python rewrite_run(python("""a = f'{1}{0}'""")) + + +def test_fstring_literal_concatenation(): + # language=python + rewrite_run( + python( + """ + a = ( + f"foo" + f"bar" + ) + """ + ) + ) + + +def test_fstring_literal_concatenation_with_comments(): + # language=python + rewrite_run( + python( + """ + a = ( + f"foo" + # comment + f"bar" + ) + """ + ) + )