Skip to content

Commit

Permalink
fix(cubesql): Fix CASE type with NULL values
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Feb 5, 2024
1 parent 64a7ebd commit 67960e6
Show file tree
Hide file tree
Showing 18 changed files with 144 additions and 65 deletions.
16 changes: 8 additions & 8 deletions packages/cubejs-backend-native/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions rust/cubesql/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rust/cubesql/cubesql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ homepage = "https://cube.dev"

[dependencies]
arc-swap = "1"
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "a0b4a6d2953c67857a3e24343fb2cba8ce2297cd", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "3c85ef6583587f5b0b037be5810e979bede9c7dc", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
anyhow = "1.0"
thiserror = "1.0.50"
cubeclient = { path = "../cubeclient" }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
---
source: cubesql/e2e/tests/postgres.rs
assertion_line: 297
expression: "self.print_query_result(res, with_description, true).await"
---
Utf8(NULL) type: 25 (text)
NULL type: 25 (text)
f32 type: 700 (float4)
f64 type: 701 (float8)
i16 type: 21 (int2)
Expand All @@ -27,8 +26,8 @@ interval_month_day_nano type: 1186 (interval)
str_arr type: 1009 (_text)
i64_arr type: 1016 (_int8)
f64_arr type: 1022 (_float8)
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
| Utf8(NULL) | f32 | f64 | i16 | u16 | i32 | u32 | i64 | u64 | bool_true | bool_false | str | d0 | d2 | d5 | d10 | date | tsmp | interval_year_month | interval_day_time | interval_month_day_nano | str_arr | i64_arr | f64_arr |
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
| NULL | 1.234 | 1.234 | 1 | 1 | 1 | 1 | 1 | 1 | true | false | test | 1 | 1.25 | 1.25000 | 1.2500000000 | 2022-04-25 | 2022-04-25 16:25:01.164774 | 1 year 1 mons | 01:30:00 | 1 year 1 mons 1 days 01:30:00 | test1,test2 | 1,2,3 | 1.2,2.3,3.4 |
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
+------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
| NULL | f32 | f64 | i16 | u16 | i32 | u32 | i64 | u64 | bool_true | bool_false | str | d0 | d2 | d5 | d10 | date | tsmp | interval_year_month | interval_day_time | interval_month_day_nano | str_arr | i64_arr | f64_arr |
+------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
| NULL | 1.234 | 1.234 | 1 | 1 | 1 | 1 | 1 | 1 | true | false | test | 1 | 1.25 | 1.25000 | 1.2500000000 | 2022-04-25 | 2022-04-25 16:25:01.164774 | 1 year 1 mons | 01:30:00 | 1 year 1 mons 1 days 01:30:00 | test1,test2 | 1,2,3 | 1.2,2.3,3.4 |
+------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
10 changes: 10 additions & 0 deletions rust/cubesql/cubesql/src/compile/engine/df/coerce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Null
)
}

Expand All @@ -33,6 +34,9 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
}

match (lhs_type, rhs_type) {
(_, DataType::Null) => Some(lhs_type.clone()),
(DataType::Null, _) => Some(rhs_type.clone()),

Check warning on line 38 in rust/cubesql/cubesql/src/compile/engine/df/coerce.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/coerce.rs#L37-L38

Added lines #L37 - L38 were not covered by tests
//
(_, DataType::UInt64) => Some(DataType::UInt64),
(DataType::UInt64, _) => Some(DataType::UInt64),
//
Expand All @@ -50,6 +54,9 @@ pub fn if_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
}

let hack_ty = match (lhs_type, rhs_type) {
(_, DataType::Null) => Some(lhs_type.clone()),
(DataType::Null, _) => Some(rhs_type.clone()),

Check warning on line 58 in rust/cubesql/cubesql/src/compile/engine/df/coerce.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/coerce.rs#L58

Added line #L58 was not covered by tests
//
(DataType::Utf8, DataType::UInt64) => Some(DataType::Utf8),
(DataType::Utf8, DataType::Int64) => Some(DataType::Utf8),
//
Expand All @@ -69,6 +76,9 @@ pub fn least_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
}

let hack_ty = match (lhs_type, rhs_type) {
(_, DataType::Null) => Some(lhs_type.clone()),
(DataType::Null, _) => Some(rhs_type.clone()),
//
(DataType::Utf8, DataType::UInt64) => Some(DataType::Utf8),
(DataType::Utf8, DataType::Int64) => Some(DataType::Utf8),
//
Expand Down
14 changes: 12 additions & 2 deletions rust/cubesql/cubesql/src/compile/engine/df/columar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,23 @@ use std::sync::Arc;

macro_rules! if_then_else {
($BUILDER_TYPE:ty, $ARRAY_TYPE:ty, $BOOLS:expr, $TRUE:expr, $FALSE:expr) => {{
let true_values = $TRUE
let true_values = if $TRUE.data_type() == &DataType::Null {
Arc::new(<$ARRAY_TYPE>::from(vec![None; $TRUE.len()]))
} else {
$TRUE
};
let true_values = true_values
.as_ref()
.as_any()
.downcast_ref::<$ARRAY_TYPE>()
.expect("true_values downcast failed");

let false_values = $FALSE
let false_values = if $FALSE.data_type() == &DataType::Null {
Arc::new(<$ARRAY_TYPE>::from(vec![None; $FALSE.len()]))
} else {
$FALSE
};
let false_values = false_values
.as_ref()
.as_any()
.downcast_ref::<$ARRAY_TYPE>()
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,7 @@ impl CubeScanWrapperNode {
}
// ScalarValue::IntervalMonthDayNano(_) => {}
// ScalarValue::Struct(_, _) => {}
ScalarValue::Null => ("NULL".to_string(), sql_query),
x => {
return Err(DataFusionError::Internal(format!(
"Can't generate SQL for literal: {:?}",
Expand Down
34 changes: 24 additions & 10 deletions rust/cubesql/cubesql/src/compile/engine/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,17 +395,31 @@ pub fn create_isnull_udf() -> ScalarUDF {
Arc::new(builder.finish()) as ArrayRef
}
2 => {
if args[0].data_type() != &DataType::Utf8 || args[1].data_type() != &DataType::Utf8
{
return Err(DataFusionError::Internal(format!(
"isnull with 2 arguments supports only (Utf8, Utf8), actual: ({}, {})",
args[0].data_type(),
args[1].data_type(),
)));
}
let expr = match args[0].data_type() {
DataType::Utf8 => Arc::clone(&args[0]),
DataType::Null => cast(&args[0], &DataType::Utf8)?,
_ => {
return Err(DataFusionError::Internal(format!(

Check warning on line 402 in rust/cubesql/cubesql/src/compile/engine/udf.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf.rs#L402

Added line #L402 was not covered by tests
"isnull with 2 arguments supports only (Utf8, Utf8), actual: ({}, {})",
args[0].data_type(),
args[1].data_type(),

Check warning on line 405 in rust/cubesql/cubesql/src/compile/engine/udf.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf.rs#L404-L405

Added lines #L404 - L405 were not covered by tests
)))
}
};
let replacement = match args[1].data_type() {
DataType::Utf8 => Arc::clone(&args[1]),
DataType::Null => cast(&args[1], &DataType::Utf8)?,
_ => {
return Err(DataFusionError::Internal(format!(

Check warning on line 413 in rust/cubesql/cubesql/src/compile/engine/udf.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf.rs#L413

Added line #L413 was not covered by tests
"isnull with 2 arguments supports only (Utf8, Utf8), actual: ({}, {})",
args[0].data_type(),
args[1].data_type(),

Check warning on line 416 in rust/cubesql/cubesql/src/compile/engine/udf.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf.rs#L415-L416

Added lines #L415 - L416 were not covered by tests
)))
}
};

let exprs = downcast_string_arg!(&args[0], "expr", i32);
let replacements = downcast_string_arg!(&args[1], "replacement", i32);
let exprs = downcast_string_arg!(expr, "expr", i32);
let replacements = downcast_string_arg!(replacement, "replacement", i32);

let result = exprs
.iter()
Expand Down
27 changes: 27 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21233,4 +21233,31 @@ limit
assert!(sql.contains("-(EXTRACT(YEAR FROM"));
assert!(sql.contains("* INTERVAL '1 DAY'"));
}

/// Snapshot test for a `CASE` expression whose branches mix an untyped `NULL`
/// with a float value.
///
/// The query feeds `generate_series(0, 5)` through a `CASE` that yields `NULL`
/// for `i = 0` and `i::float / 10.0` otherwise, applies `ACOS`, casts the result
/// to text and keeps the first 10 characters via `LEFT(..., 10)`. The formatted
/// query output is compared against the stored snapshot named
/// `test_case_mixed_values_with_null`.
///
/// Any error from executing the query is propagated to the test harness through
/// the `Result<(), CubeError>` return type.
// NOTE(review): the LEFT(..., 10) truncation presumably keeps the float text
// stable across platforms — confirm with the author.
#[tokio::test]
async fn test_case_mixed_values_with_null() -> Result<(), CubeError> {
init_logger();

// The first macro argument is the explicit snapshot name; the second is the
// value being snapshotted (the result of running the query below).
insta::assert_snapshot!(
"test_case_mixed_values_with_null",
execute_query(
"
SELECT LEFT(ACOS(
CASE i
WHEN 0 THEN NULL
ELSE (i::float / 10.0)
END
)::text, 10) AS acos
FROM (
SELECT generate_series(0, 5) AS i
) AS t
"
.to_string(),
DatabaseProtocol::PostgreSQL
)
.await?
);

Ok(())
}
}
3 changes: 1 addition & 2 deletions rust/cubesql/cubesql/src/compile/rewrite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ use datafusion::{
JoinConstraint, JoinType, Operator,
},
physical_plan::{
aggregates::AggregateFunction, functions::BuiltinScalarFunction,
window_functions::WindowFunction,
aggregates::AggregateFunction, functions::BuiltinScalarFunction, windows::WindowFunction,
},
scalar::ScalarValue,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
},
var, var_iter,
};
use datafusion::physical_plan::window_functions::WindowFunction;
use datafusion::physical_plan::windows::WindowFunction;
use egg::{EGraph, Rewrite, Subst};

impl WrapperRules {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
source: cubesql/src/compile/mod.rs
expression: "execute_query(\"select NULL, NULL, NULL\".to_string(),\n DatabaseProtocol::PostgreSQL).await?"
---
+------------+-------+-------+
| Utf8(NULL) | NULL2 | NULL3 |
+------------+-------+-------+
| NULL | NULL | NULL |
+------------+-------+-------+
+------+-------+-------+
| NULL | NULL2 | NULL3 |
+------+-------+-------+
| NULL | NULL | NULL |
+------+-------+-------+
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
source: cubesql/src/compile/mod.rs
expression: "execute_query(\"\n SELECT LEFT(ACOS(\n CASE i\n WHEN 0 THEN NULL\n ELSE (i::float / 10.0)\n END\n )::text, 10) AS acos\n FROM (\n SELECT generate_series(0, 5) AS i\n ) AS t\n \".to_string(),\n DatabaseProtocol::PostgreSQL).await?"
---
+------------+
| acos |
+------------+
| NULL |
| 1.47062890 |
| 1.36943840 |
| 1.26610367 |
| 1.15927948 |
| 1.04719755 |
+------------+
Loading

0 comments on commit 67960e6

Please sign in to comment.