From 94c772a19c10e22007c40f8b4a218acccf386146 Mon Sep 17 00:00:00 2001 From: Ted Conbeer Date: Sun, 17 Dec 2023 04:02:50 -0700 Subject: [PATCH] fix: detect multiline strings inside jinja tags --- CHANGELOG.md | 3 +++ src/sqlfmt/jinjafmt.py | 27 ++++++++++++++++--- .../302_jinjafmt_multiline_str.sql | 18 +++++++++++++ .../test_general_formatting.py | 1 + tests/unit_tests/test_jinjafmt.py | 13 +++++++++ 5 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 tests/data/preformatted/302_jinjafmt_multiline_str.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index 878a76ab..272df996 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Bug Fixes + +- Fixes a bug where extra indentation was added inside multiline jinja tags if those jinja tags contained a python multiline string ([#536](https://github.com/tconbeer/sqlfmt/issues/536) - thank you [@yassun7010](https://github.com/yassun7010)!). ## [0.21.0] - 2023-10-20 ### Bug Fixes diff --git a/src/sqlfmt/jinjafmt.py b/src/sqlfmt/jinjafmt.py index 19bddd0b..a0b5942d 100644 --- a/src/sqlfmt/jinjafmt.py +++ b/src/sqlfmt/jinjafmt.py @@ -1,10 +1,11 @@ +import ast import keyword import re from dataclasses import dataclass, field from importlib import import_module from itertools import chain, product from types import ModuleType -from typing import Dict, List, NamedTuple, Optional, Tuple +from typing import Dict, List, MutableSet, NamedTuple, Optional, Tuple from sqlfmt.line import Line from sqlfmt.mode import Mode @@ -259,6 +260,7 @@ def _multiline_str(self) -> str: will already be indented to the proper depth (because of the Line). 
""" indent = " " * 4 * (self.depth[0] + self.depth[1]) + no_indent_lines = self._find_multiline_python_str_lines() code_lines = iter(self.code.splitlines(keepends=False)) if self.verb: @@ -268,8 +270,10 @@ def _multiline_str(self) -> str: lines = [f"{self.opening_marker}"] extra_indent = " " * 4 - for code_line in code_lines: - lines.append(f"{indent}{extra_indent}{code_line}") + for i, code_line in enumerate(code_lines, start=1 if self.verb else 0): + lines.append( + f"{indent}{'' if i in no_indent_lines else extra_indent}{code_line}" + ) if self.verb: lines[-1] = f"{indent}{lines[-1].lstrip()} {self.closing_marker}" @@ -281,6 +285,23 @@ def _multiline_str(self) -> str: def _basic_str(self) -> str: return f"{self.opening_marker} {self.verb}{self.code} {self.closing_marker}" + def _find_multiline_python_str_lines(self) -> MutableSet[int]: + # we don't have to worry about syntax errors here because black has already + # run on this code. + tree = ast.parse(self.code, mode="eval") + + line_indicies: MutableSet[int] = set() + for node in ast.walk(tree): + if ( + isinstance(node, ast.Constant) + and isinstance(node.value, str) + and "\n" in node.value + and node.end_lineno is not None + ): + line_indicies |= set(range(node.lineno, node.end_lineno)) + + return line_indicies + def _remove_trailing_comma(self) -> None: """ dbt Jinja doesn't allow trailing commas in macro definitions. 
Mutates diff --git a/tests/data/preformatted/302_jinjafmt_multiline_str.sql b/tests/data/preformatted/302_jinjafmt_multiline_str.sql new file mode 100644 index 00000000..a124d3be --- /dev/null +++ b/tests/data/preformatted/302_jinjafmt_multiline_str.sql @@ -0,0 +1,18 @@ +{{ + config( + materialized="incremental", + pre_hook=""" + delete from + dwh.user as t using ( + select distinct campaign_name, date + from datalake.conversion + where date_part = date('{{ execution_date }}') + ) as s + where + t.campaign_name = s.campaign_name + and to_date(t.imported_at) <= s.date_part + """, + ) +}} + +select campaign_name, date_part, count(distinct user_id) as users diff --git a/tests/functional_tests/test_general_formatting.py b/tests/functional_tests/test_general_formatting.py index 93287e8a..2bf42d12 100644 --- a/tests/functional_tests/test_general_formatting.py +++ b/tests/functional_tests/test_general_formatting.py @@ -16,6 +16,7 @@ "preformatted/006_fmt_off_447.sql", "preformatted/007_fmt_off_comments.sql", "preformatted/301_multiline_jinjafmt.sql", + "preformatted/302_jinjafmt_multiline_str.sql", "preformatted/400_create_table.sql", "unformatted/100_select_case.sql", "unformatted/101_multiline.sql", diff --git a/tests/unit_tests/test_jinjafmt.py b/tests/unit_tests/test_jinjafmt.py index 5d191a3a..dc78ff58 100644 --- a/tests/unit_tests/test_jinjafmt.py +++ b/tests/unit_tests/test_jinjafmt.py @@ -481,3 +481,16 @@ def test_preprocess_and_postprocess_are_inverse_ops(source_string: str) -> None: assert BlackWrapper._postprocess_string( *BlackWrapper._preprocess_string(source_string) ).replace(" ", "") == source_string.replace(" ", "") + +@pytest.mark.parametrize( + "source_string", + [ + """{{\n config(\n foo="bar",\n )\n}}""", + '''{{\n config(\n foo="""\n\nbar\n\n""",\n )\n}}''', + ] +) +def test_multiline_str(source_string: str) -> None: + tag = JinjaTag.from_string(source_string=source_string, depth=(0, 0)) + tag.code, tag.is_blackened = 
BlackWrapper().format_string(source_string=tag.code, max_length=88) + assert tag.is_blackened + assert str(tag) == source_string