Skip to content

Commit

Permalink
fix(cubesql): Generate typed null literals (#9238)
Browse files Browse the repository at this point in the history
This is to avoid expressions like SUM(NULL), which are ambiguous in PostgreSQL
  • Loading branch information
mcheshkov authored Feb 25, 2025
1 parent 75095e1 commit 1dfa10d
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 33 deletions.
147 changes: 114 additions & 33 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ pub struct SqlGenerationResult {
static DATE_PART_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new("^[A-Za-z_ ]+$").unwrap());

macro_rules! generate_sql_for_timestamp {
(@generic $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => {
(@generic $literal:ident, $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => {
if let Some($value) = $value {
let value = $value_block.to_rfc3339_opts(SecondsFormat::Millis, true);
(
Expand All @@ -527,27 +527,27 @@ macro_rules! generate_sql_for_timestamp {
$sql_query,
)
} else {
("NULL".to_string(), $sql_query)
(Self::generate_null_for_literal($sql_generator, &$literal)?, $sql_query)
}
};
($value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => {
($literal:ident, $value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => {
generate_sql_for_timestamp!(
@generic $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query
@generic $literal, $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query
)
};
($value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => {
($literal:ident, $value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => {
generate_sql_for_timestamp!(
@generic $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query
@generic $literal, $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query
)
};
($value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => {
($literal:ident, $value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => {
generate_sql_for_timestamp!(
@generic $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query
@generic $literal, $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query
)
};
($value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => {
($literal:ident, $value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => {
generate_sql_for_timestamp!(
@generic $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query
@generic $literal, $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query
)
};
}
Expand Down Expand Up @@ -1606,6 +1606,27 @@ impl CubeScanWrapperNode {
.map_err(|e| DataFusionError::Internal(format!("Can't generate SQL for type: {}", e)))
}

/// Renders a SQL `NULL` literal, cast to `data_type` when one is supplied.
///
/// With `data_type == None` the bare string `NULL` is returned. Otherwise the
/// SQL type name is produced by `generate_sql_type` and `NULL` is wrapped by
/// `generate_sql_cast_expr`, so the emitted literal carries an explicit type.
///
/// # Errors
///
/// Propagates any error returned by `generate_sql_type` or
/// `generate_sql_cast_expr`.
fn generate_typed_null(
    sql_generator: Arc<dyn SqlGenerator>,
    data_type: Option<DataType>,
) -> result::Result<String, DataFusionError> {
    match data_type {
        // No type information available: fall back to an untyped NULL.
        None => Ok("NULL".to_string()),
        Some(null_type) => {
            // The generator is needed twice, so hand a clone to the first call
            // and move the original into the second.
            let type_sql = Self::generate_sql_type(sql_generator.clone(), null_type)?;
            let cast_sql =
                Self::generate_sql_cast_expr(sql_generator, "NULL".to_string(), type_sql)?;
            Ok(cast_sql)
        }
    }
}

/// Renders a typed SQL `NULL` whose type is taken from the literal `value`.
///
/// The data type is read from `value.get_datatype()` and forwarded to
/// `generate_typed_null`; the contents of `value` are not otherwise inspected.
///
/// # Errors
///
/// Propagates any error returned by `generate_typed_null`.
fn generate_null_for_literal(
    sql_generator: Arc<dyn SqlGenerator>,
    value: &ScalarValue,
) -> result::Result<String, DataFusionError> {
    Self::generate_typed_null(sql_generator, Some(value.get_datatype()))
}

/// This function is async to be able to call to JS land,
/// in case some SQL generation could not be done through Jinja
pub fn generate_sql_for_expr<'ctx>(
Expand Down Expand Up @@ -2083,15 +2104,25 @@ impl CubeScanWrapperNode {
))
})
})
.unwrap_or(Ok("NULL".to_string()))?,
.transpose()?
.map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Float32(f) => (
f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()),
f.map(|f| format!("{f}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Float64(f) => (
f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()),
f.map(|f| format!("{f}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Decimal128(x, precision, scale) => {
Expand All @@ -2111,49 +2142,76 @@ impl CubeScanWrapperNode {
data_type,
)?
} else {
"NULL".to_string()
Self::generate_null_for_literal(sql_generator, &literal)?
},
sql_query,
)
}
ScalarValue::Int8(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Int16(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Int32(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Int64(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::UInt8(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::UInt16(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::UInt32(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::UInt64(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Utf8(x) => {
if x.is_some() {
let param_index = sql_query.add_value(x);
(format!("${}$", param_index), sql_query)
} else {
("NULL".into(), sql_query)
(
Self::generate_typed_null(sql_generator, Some(DataType::Utf8))?,
sql_query,
)
}
}
// ScalarValue::LargeUtf8(_) => {}
Expand Down Expand Up @@ -2194,42 +2252,54 @@ impl CubeScanWrapperNode {
sql_query,
)
} else {
("NULL".to_string(), sql_query)
(
Self::generate_null_for_literal(sql_generator, &literal)?,
sql_query,
)
}
}
// ScalarValue::Date64(_) => {}

// generate_sql_for_timestamp will call Utc constructors, so only support UTC zone for now
// DataFusion can return "UTC" for stuff like `NOW()` during constant folding
ScalarValue::TimestampSecond(s, tz)
ScalarValue::TimestampSecond(s, ref tz)
if matches!(tz.as_deref(), None | Some("UTC")) =>
{
generate_sql_for_timestamp!(s, timestamp, sql_generator, sql_query)
generate_sql_for_timestamp!(
literal,
s,
timestamp,
sql_generator,
sql_query
)
}
ScalarValue::TimestampMillisecond(ms, tz)
ScalarValue::TimestampMillisecond(ms, ref tz)
if matches!(tz.as_deref(), None | Some("UTC")) =>
{
generate_sql_for_timestamp!(
literal,
ms,
timestamp_millis_opt,
sql_generator,
sql_query
)
}
ScalarValue::TimestampMicrosecond(ms, tz)
ScalarValue::TimestampMicrosecond(ms, ref tz)
if matches!(tz.as_deref(), None | Some("UTC")) =>
{
generate_sql_for_timestamp!(
literal,
ms,
timestamp_micros,
sql_generator,
sql_query
)
}
ScalarValue::TimestampNanosecond(nanoseconds, tz)
ScalarValue::TimestampNanosecond(nanoseconds, ref tz)
if matches!(tz.as_deref(), None | Some("UTC")) =>
{
generate_sql_for_timestamp!(
literal,
nanoseconds,
timestamp_nanos,
sql_generator,
Expand All @@ -2253,7 +2323,10 @@ impl CubeScanWrapperNode {
sql_query,
)
} else {
("NULL".to_string(), sql_query)
(
Self::generate_null_for_literal(sql_generator, &literal)?,
sql_query,
)
}
}
ScalarValue::IntervalDayTime(x) => {
Expand All @@ -2263,7 +2336,10 @@ impl CubeScanWrapperNode {
let generated_sql = decomposed.generate_interval_sql(&templates)?;
(generated_sql, sql_query)
} else {
("NULL".to_string(), sql_query)
(
Self::generate_null_for_literal(sql_generator, &literal)?,
sql_query,
)
}
}
ScalarValue::IntervalMonthDayNano(x) => {
Expand All @@ -2273,11 +2349,16 @@ impl CubeScanWrapperNode {
let generated_sql = decomposed.generate_interval_sql(&templates)?;
(generated_sql, sql_query)
} else {
("NULL".to_string(), sql_query)
(
Self::generate_null_for_literal(sql_generator, &literal)?,
sql_query,
)
}
}
// ScalarValue::Struct(_, _) => {}
ScalarValue::Null => ("NULL".to_string(), sql_query),
ScalarValue::Null => {
(Self::generate_typed_null(sql_generator, None)?, sql_query)
}
x => {
return Err(DataFusionError::Internal(format!(
"Can't generate SQL for literal: {:?}",
Expand Down
39 changes: 39 additions & 0 deletions rust/cubesql/cubesql/src/compile/test/test_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1559,3 +1559,42 @@ async fn wrapper_cast_limit_explicit_members() {
assert_eq!(request.measures.unwrap().len(), 1);
assert_eq!(request.dimensions.unwrap().len(), 0);
}

/// Checks that a NULL literal inside an aggregate is emitted with an explicit
/// cast in the wrapped SQL, i.e. `SUM(CAST(NULL AS DOUBLE))`.
#[tokio::test]
async fn wrapper_typed_null() {
    // Skip unless SQL push down is enabled.
    if !Rewriter::sql_push_down_enabled() {
        return;
    }
    init_testing_logger();

    // language=PostgreSQL
    let query = r#"
SELECT
dim_str0,
AVG(avgPrice),
CASE
WHEN SUM((NULLIF(0.0, 0.0))) IS NOT NULL THEN SUM((NULLIF(0.0, 0.0)))
ELSE 0
END
FROM MultiTypeCube
GROUP BY 1
;"#
    .to_string();

    let query_plan = convert_select_to_query_plan(query, DatabaseProtocol::PostgreSQL).await;

    // Build and print the physical plan first, as an aid when the assertion fails.
    let physical_plan = query_plan.as_physical_plan().await.unwrap();
    println!(
        "Physical plan: {}",
        displayable(physical_plan.as_ref()).indent()
    );

    let logical_plan = query_plan.as_logical_plan();
    let wrapped = logical_plan.find_cube_scan_wrapped_sql();
    assert!(wrapped
        .wrapped_sql
        .sql
        .contains("SUM(CAST(NULL AS DOUBLE))"));
}

0 comments on commit 1dfa10d

Please sign in to comment.