From 1dfa10d7128841f24c5d94cd1c5bdd2c742ff9de Mon Sep 17 00:00:00 2001 From: Mikhail Cheshkov Date: Tue, 25 Feb 2025 12:31:22 +0200 Subject: [PATCH] fix(cubesql): Generate typed null literals (#9238) This is to avoid expressions like SUM(NULL), which are ambiguous in PostgreSQL --- .../cubesql/src/compile/engine/df/wrapper.rs | 147 ++++++++++++++---- .../cubesql/src/compile/test/test_wrapper.rs | 39 +++++ 2 files changed, 153 insertions(+), 33 deletions(-) diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index e0cd241c238c7..afae6974c0b13 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -511,7 +511,7 @@ pub struct SqlGenerationResult { static DATE_PART_REGEX: LazyLock = LazyLock::new(|| Regex::new("^[A-Za-z_ ]+$").unwrap()); macro_rules! generate_sql_for_timestamp { - (@generic $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => { + (@generic $literal:ident, $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => { if let Some($value) = $value { let value = $value_block.to_rfc3339_opts(SecondsFormat::Millis, true); ( @@ -527,27 +527,27 @@ macro_rules!
generate_sql_for_timestamp { $sql_query, ) } else { - ("NULL".to_string(), $sql_query) + (Self::generate_null_for_literal($sql_generator, &$literal)?, $sql_query) } }; - ($value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => { + ($literal:ident, $value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => { generate_sql_for_timestamp!( - @generic $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query + @generic $literal, $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query ) }; - ($value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => { + ($literal:ident, $value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => { generate_sql_for_timestamp!( - @generic $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query + @generic $literal, $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query ) }; - ($value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => { + ($literal:ident, $value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => { generate_sql_for_timestamp!( - @generic $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query + @generic $literal, $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query ) }; - ($value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => { + ($literal:ident, $value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => { generate_sql_for_timestamp!( - @generic $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query + @generic $literal, $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query ) }; } @@ -1606,6 +1606,27 @@ impl CubeScanWrapperNode { .map_err(|e| DataFusionError::Internal(format!("Can't generate SQL for type: {}", e))) } + fn generate_typed_null( + sql_generator: Arc, + data_type: Option, + ) -> result::Result { + let 
Some(data_type) = data_type else { + return Ok("NULL".to_string()); + }; + + let sql_type = Self::generate_sql_type(sql_generator.clone(), data_type)?; + let result = Self::generate_sql_cast_expr(sql_generator, "NULL".to_string(), sql_type)?; + Ok(result) + } + + fn generate_null_for_literal( + sql_generator: Arc, + value: &ScalarValue, + ) -> result::Result { + let data_type = value.get_datatype(); + Self::generate_typed_null(sql_generator, Some(data_type)) + } + /// This function is async to be able to call to JS land, /// in case some SQL generation could not be done through Jinja pub fn generate_sql_for_expr<'ctx>( @@ -2083,15 +2104,25 @@ impl CubeScanWrapperNode { )) }) }) - .unwrap_or(Ok("NULL".to_string()))?, + .transpose()? + .map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::Float32(f) => ( - f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()), + f.map(|f| format!("{f}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::Float64(f) => ( - f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()), + f.map(|f| format!("{f}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::Decimal128(x, precision, scale) => { @@ -2111,41 +2142,65 @@ impl CubeScanWrapperNode { data_type, )? } else { - "NULL".to_string() + Self::generate_null_for_literal(sql_generator, &literal)? 
}, sql_query, ) } ScalarValue::Int8(x) => ( - x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()), + x.map(|x| format!("{x}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::Int16(x) => ( - x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()), + x.map(|x| format!("{x}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::Int32(x) => ( - x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()), + x.map(|x| format!("{x}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::Int64(x) => ( - x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()), + x.map(|x| format!("{x}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::UInt8(x) => ( - x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()), + x.map(|x| format!("{x}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::UInt16(x) => ( - x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()), + x.map(|x| format!("{x}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::UInt32(x) => ( - x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()), + x.map(|x| format!("{x}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::UInt64(x) => ( - x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()), + x.map(|x| format!("{x}")).map_or_else( + || Self::generate_null_for_literal(sql_generator, &literal), + Ok, + )?, sql_query, ), ScalarValue::Utf8(x) => { @@ -2153,7 +2208,10 @@ impl CubeScanWrapperNode { let param_index = sql_query.add_value(x); (format!("${}$", param_index), sql_query) } else { - ("NULL".into(), sql_query) + ( + 
Self::generate_typed_null(sql_generator, Some(DataType::Utf8))?, + sql_query, + ) } } // ScalarValue::LargeUtf8(_) => {} @@ -2194,42 +2252,54 @@ impl CubeScanWrapperNode { sql_query, ) } else { - ("NULL".to_string(), sql_query) + ( + Self::generate_null_for_literal(sql_generator, &literal)?, + sql_query, + ) } } // ScalarValue::Date64(_) => {} // generate_sql_for_timestamp will call Utc constructors, so only support UTC zone for now // DataFusion can return "UTC" for stuff like `NOW()` during constant folding - ScalarValue::TimestampSecond(s, tz) + ScalarValue::TimestampSecond(s, ref tz) if matches!(tz.as_deref(), None | Some("UTC")) => { - generate_sql_for_timestamp!(s, timestamp, sql_generator, sql_query) + generate_sql_for_timestamp!( + literal, + s, + timestamp, + sql_generator, + sql_query + ) } - ScalarValue::TimestampMillisecond(ms, tz) + ScalarValue::TimestampMillisecond(ms, ref tz) if matches!(tz.as_deref(), None | Some("UTC")) => { generate_sql_for_timestamp!( + literal, ms, timestamp_millis_opt, sql_generator, sql_query ) } - ScalarValue::TimestampMicrosecond(ms, tz) + ScalarValue::TimestampMicrosecond(ms, ref tz) if matches!(tz.as_deref(), None | Some("UTC")) => { generate_sql_for_timestamp!( + literal, ms, timestamp_micros, sql_generator, sql_query ) } - ScalarValue::TimestampNanosecond(nanoseconds, tz) + ScalarValue::TimestampNanosecond(nanoseconds, ref tz) if matches!(tz.as_deref(), None | Some("UTC")) => { generate_sql_for_timestamp!( + literal, nanoseconds, timestamp_nanos, sql_generator, @@ -2253,7 +2323,10 @@ impl CubeScanWrapperNode { sql_query, ) } else { - ("NULL".to_string(), sql_query) + ( + Self::generate_null_for_literal(sql_generator, &literal)?, + sql_query, + ) } } ScalarValue::IntervalDayTime(x) => { @@ -2263,7 +2336,10 @@ impl CubeScanWrapperNode { let generated_sql = decomposed.generate_interval_sql(&templates)?; (generated_sql, sql_query) } else { - ("NULL".to_string(), sql_query) + ( + Self::generate_null_for_literal(sql_generator, 
&literal)?, + sql_query, + ) } } ScalarValue::IntervalMonthDayNano(x) => { @@ -2273,11 +2349,16 @@ impl CubeScanWrapperNode { let generated_sql = decomposed.generate_interval_sql(&templates)?; (generated_sql, sql_query) } else { - ("NULL".to_string(), sql_query) + ( + Self::generate_null_for_literal(sql_generator, &literal)?, + sql_query, + ) } } // ScalarValue::Struct(_, _) => {} - ScalarValue::Null => ("NULL".to_string(), sql_query), + ScalarValue::Null => { + (Self::generate_typed_null(sql_generator, None)?, sql_query) + } x => { return Err(DataFusionError::Internal(format!( "Can't generate SQL for literal: {:?}", diff --git a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs index d4786c4d06b84..489062c800272 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs @@ -1559,3 +1559,42 @@ async fn wrapper_cast_limit_explicit_members() { assert_eq!(request.measures.unwrap().len(), 1); assert_eq!(request.dimensions.unwrap().len(), 0); } + +#[tokio::test] +async fn wrapper_typed_null() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + // language=PostgreSQL + r#" + SELECT + dim_str0, + AVG(avgPrice), + CASE + WHEN SUM((NULLIF(0.0, 0.0))) IS NOT NULL THEN SUM((NULLIF(0.0, 0.0))) + ELSE 0 + END + FROM MultiTypeCube + GROUP BY 1 + ;"# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + + assert!(query_plan + .as_logical_plan() + .find_cube_scan_wrapped_sql() + .wrapped_sql + .sql + .contains("SUM(CAST(NULL AS DOUBLE))")); +}