diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index ae298c08dc..9f7a65381f 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1224,6 +1224,7 @@ def with_( append: bool = True, dialect: DialectType = None, copy: bool = True, + scalar: bool = False, **opts, ) -> Q: """ @@ -1244,6 +1245,7 @@ def with_( Otherwise, this resets the expressions. dialect: the dialect used to parse the input expression. copy: if `False`, modify this expression instance in-place. + scalar: if `True`, this is a scalar common table expression. opts: other options to use to parse the input expressions. Returns: @@ -1258,6 +1260,7 @@ def with_( append=append, dialect=dialect, copy=copy, + scalar=scalar, **opts, ) @@ -7036,11 +7039,15 @@ def _apply_cte_builder( append: bool = True, dialect: DialectType = None, copy: bool = True, + scalar: bool = False, **opts, ) -> E: alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts) - as_expression = maybe_parse(as_, dialect=dialect, **opts) - cte = CTE(this=as_expression, alias=alias_expression, materialized=materialized) + as_expression = maybe_parse(as_, dialect=dialect, copy=copy, **opts) + if scalar and not isinstance(as_expression, Subquery): + # scalar CTE must be wrapped in a subquery + as_expression = Subquery(this=as_expression) + cte = CTE(this=as_expression, alias=alias_expression, materialized=materialized, scalar=scalar) return _apply_child_list_builder( cte, instance=instance, diff --git a/tests/test_build.py b/tests/test_build.py index bd0962183e..64d1e96152 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -822,6 +822,22 @@ def test_build(self): lambda: exp.union("SELECT 1", "SELECT 2", "SELECT 3", "SELECT 4"), "SELECT 1 UNION SELECT 2 UNION SELECT 3 UNION SELECT 4", ), + ( + lambda: select("x") + .with_("var1", as_=select("x").from_("tbl2").subquery(), scalar=True) + .from_("tbl") + .where("x > var1"), + "WITH (SELECT x FROM tbl2) AS var1 SELECT x FROM tbl WHERE x > var1", + "clickhouse", + ), + ( + lambda: select("x") + .with_("var1", as_=select("x").from_("tbl2"), scalar=True) + .from_("tbl") + .where("x > var1"), + "WITH (SELECT x FROM tbl2) AS var1 SELECT x FROM tbl WHERE x > var1", + "clickhouse", + ), ]: with self.subTest(sql): self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)