Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(cubesql): Generate typed null literals #9238

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 114 additions & 33 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@
static DATE_PART_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new("^[A-Za-z_ ]+$").unwrap());

macro_rules! generate_sql_for_timestamp {
(@generic $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => {
(@generic $literal:ident, $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => {
if let Some($value) = $value {
let value = $value_block.to_rfc3339_opts(SecondsFormat::Millis, true);
(
Expand All @@ -527,27 +527,27 @@
$sql_query,
)
} else {
("NULL".to_string(), $sql_query)
(Self::generate_null_for_literal($sql_generator, &$literal)?, $sql_query)
}
};
($value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => {
($literal:ident, $value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => {
generate_sql_for_timestamp!(
@generic $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query
@generic $literal, $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query
)
};
($value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => {
($literal:ident, $value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => {
generate_sql_for_timestamp!(
@generic $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query
@generic $literal, $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query
)
};
($value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => {
($literal:ident, $value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => {
generate_sql_for_timestamp!(
@generic $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query
@generic $literal, $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query
)
};
($value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => {
($literal:ident, $value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => {
generate_sql_for_timestamp!(
@generic $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query
@generic $literal, $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query
)
};
}
Expand Down Expand Up @@ -1606,6 +1606,27 @@
.map_err(|e| DataFusionError::Internal(format!("Can't generate SQL for type: {}", e)))
}

fn generate_typed_null(
sql_generator: Arc<dyn SqlGenerator>,
data_type: Option<DataType>,
) -> result::Result<String, DataFusionError> {
let Some(data_type) = data_type else {
return Ok("NULL".to_string());
};

let sql_type = Self::generate_sql_type(sql_generator.clone(), data_type)?;
let result = Self::generate_sql_cast_expr(sql_generator, "NULL".to_string(), sql_type)?;
Ok(result)
}

fn generate_null_for_literal(
sql_generator: Arc<dyn SqlGenerator>,
value: &ScalarValue,
) -> result::Result<String, DataFusionError> {
let data_type = value.get_datatype();
Self::generate_typed_null(sql_generator, Some(data_type))
}

/// This function is async to be able to call to JS land,
/// in case some SQL generation could not be done through Jinja
pub fn generate_sql_for_expr<'ctx>(
Expand Down Expand Up @@ -2083,15 +2104,25 @@
))
})
})
.unwrap_or(Ok("NULL".to_string()))?,
.transpose()?
.map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Float32(f) => (
f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()),
f.map(|f| format!("{f}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,

Check warning on line 2118 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2115-L2118

Added lines #L2115 - L2118 were not covered by tests
sql_query,
),
ScalarValue::Float64(f) => (
f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()),
f.map(|f| format!("{f}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::Decimal128(x, precision, scale) => {
Expand All @@ -2111,49 +2142,76 @@
data_type,
)?
} else {
"NULL".to_string()
Self::generate_null_for_literal(sql_generator, &literal)?

Check warning on line 2145 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2145

Added line #L2145 was not covered by tests
},
sql_query,
)
}
ScalarValue::Int8(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,

Check warning on line 2154 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2151-L2154

Added lines #L2151 - L2154 were not covered by tests
sql_query,
),
ScalarValue::Int16(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,

Check warning on line 2161 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2158-L2161

Added lines #L2158 - L2161 were not covered by tests
sql_query,
),
ScalarValue::Int32(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,

Check warning on line 2168 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2165-L2168

Added lines #L2165 - L2168 were not covered by tests
sql_query,
),
ScalarValue::Int64(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,
sql_query,
),
ScalarValue::UInt8(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,

Check warning on line 2182 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2179-L2182

Added lines #L2179 - L2182 were not covered by tests
sql_query,
),
ScalarValue::UInt16(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,

Check warning on line 2189 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2186-L2189

Added lines #L2186 - L2189 were not covered by tests
sql_query,
),
ScalarValue::UInt32(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,

Check warning on line 2196 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2193-L2196

Added lines #L2193 - L2196 were not covered by tests
sql_query,
),
ScalarValue::UInt64(x) => (
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
x.map(|x| format!("{x}")).map_or_else(
|| Self::generate_null_for_literal(sql_generator, &literal),
Ok,
)?,

Check warning on line 2203 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2200-L2203

Added lines #L2200 - L2203 were not covered by tests
sql_query,
),
ScalarValue::Utf8(x) => {
if x.is_some() {
let param_index = sql_query.add_value(x);
(format!("${}$", param_index), sql_query)
} else {
("NULL".into(), sql_query)
(
Self::generate_typed_null(sql_generator, Some(DataType::Utf8))?,
sql_query,
)
}
}
// ScalarValue::LargeUtf8(_) => {}
Expand Down Expand Up @@ -2194,42 +2252,54 @@
sql_query,
)
} else {
("NULL".to_string(), sql_query)
(
Self::generate_null_for_literal(sql_generator, &literal)?,
sql_query,

Check warning on line 2257 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2256-L2257

Added lines #L2256 - L2257 were not covered by tests
)
}
}
// ScalarValue::Date64(_) => {}

// generate_sql_for_timestamp will call Utc constructors, so only support UTC zone for now
// DataFusion can return "UTC" for stuff like `NOW()` during constant folding
ScalarValue::TimestampSecond(s, tz)
ScalarValue::TimestampSecond(s, ref tz)

Check warning on line 2265 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2265

Added line #L2265 was not covered by tests
if matches!(tz.as_deref(), None | Some("UTC")) =>
{
generate_sql_for_timestamp!(s, timestamp, sql_generator, sql_query)
generate_sql_for_timestamp!(

Check warning on line 2268 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2268

Added line #L2268 was not covered by tests
literal,
s,
timestamp,
sql_generator,
sql_query

Check warning on line 2273 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2272-L2273

Added lines #L2272 - L2273 were not covered by tests
)
}
ScalarValue::TimestampMillisecond(ms, tz)
ScalarValue::TimestampMillisecond(ms, ref tz)

Check warning on line 2276 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2276

Added line #L2276 was not covered by tests
if matches!(tz.as_deref(), None | Some("UTC")) =>
{
generate_sql_for_timestamp!(
literal,
ms,
timestamp_millis_opt,
sql_generator,
sql_query
)
}
ScalarValue::TimestampMicrosecond(ms, tz)
ScalarValue::TimestampMicrosecond(ms, ref tz)

Check warning on line 2287 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2287

Added line #L2287 was not covered by tests
if matches!(tz.as_deref(), None | Some("UTC")) =>
{
generate_sql_for_timestamp!(
literal,
ms,
timestamp_micros,
sql_generator,
sql_query
)
}
ScalarValue::TimestampNanosecond(nanoseconds, tz)
ScalarValue::TimestampNanosecond(nanoseconds, ref tz)
if matches!(tz.as_deref(), None | Some("UTC")) =>
{
generate_sql_for_timestamp!(
literal,
nanoseconds,
timestamp_nanos,
sql_generator,
Expand All @@ -2253,7 +2323,10 @@
sql_query,
)
} else {
("NULL".to_string(), sql_query)
(
Self::generate_null_for_literal(sql_generator, &literal)?,
sql_query,

Check warning on line 2328 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2327-L2328

Added lines #L2327 - L2328 were not covered by tests
)
}
}
ScalarValue::IntervalDayTime(x) => {
Expand All @@ -2263,7 +2336,10 @@
let generated_sql = decomposed.generate_interval_sql(&templates)?;
(generated_sql, sql_query)
} else {
("NULL".to_string(), sql_query)
(
Self::generate_null_for_literal(sql_generator, &literal)?,
sql_query,

Check warning on line 2341 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2340-L2341

Added lines #L2340 - L2341 were not covered by tests
)
}
}
ScalarValue::IntervalMonthDayNano(x) => {
Expand All @@ -2273,11 +2349,16 @@
let generated_sql = decomposed.generate_interval_sql(&templates)?;
(generated_sql, sql_query)
} else {
("NULL".to_string(), sql_query)
(
Self::generate_null_for_literal(sql_generator, &literal)?,
sql_query,

Check warning on line 2354 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L2353-L2354

Added lines #L2353 - L2354 were not covered by tests
)
}
}
// ScalarValue::Struct(_, _) => {}
ScalarValue::Null => ("NULL".to_string(), sql_query),
ScalarValue::Null => {
(Self::generate_typed_null(sql_generator, None)?, sql_query)
}
x => {
return Err(DataFusionError::Internal(format!(
"Can't generate SQL for literal: {:?}",
Expand Down
39 changes: 39 additions & 0 deletions rust/cubesql/cubesql/src/compile/test/test_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1559,3 +1559,42 @@ async fn wrapper_cast_limit_explicit_members() {
assert_eq!(request.measures.unwrap().len(), 1);
assert_eq!(request.dimensions.unwrap().len(), 0);
}

#[tokio::test]
async fn wrapper_typed_null() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_testing_logger();

let query_plan = convert_select_to_query_plan(
// language=PostgreSQL
r#"
SELECT
dim_str0,
AVG(avgPrice),
CASE
WHEN SUM((NULLIF(0.0, 0.0))) IS NOT NULL THEN SUM((NULLIF(0.0, 0.0)))
ELSE 0
END
FROM MultiTypeCube
GROUP BY 1
;"#
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
"Physical plan: {}",
displayable(physical_plan.as_ref()).indent()
);

assert!(query_plan
.as_logical_plan()
.find_cube_scan_wrapped_sql()
.wrapped_sql
.sql
.contains("SUM(CAST(NULL AS DOUBLE))"));
}
Loading