Skip to content

Commit

Permalink
wip: constant folding original expr
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Jan 10, 2024
1 parent 0959010 commit 14bca66
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
37 changes: 35 additions & 2 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
},
CubeError,
};
use chrono::{Days, NaiveDate};
use cubeclient::models::V1LoadRequestQuery;
use datafusion::{
error::{DataFusionError, Result},
Expand All @@ -23,7 +24,7 @@ use datafusion::{
use itertools::Itertools;
use regex::{Captures, Regex};
use serde_derive::*;
use std::{any::Any, collections::HashMap, fmt, future::Future, pin::Pin, result, sync::Arc};
use std::{any::Any, collections::HashMap, convert::TryInto, fmt, future::Future, pin::Pin, result, sync::Arc};

#[derive(Debug, Clone, Deserialize)]
pub struct SqlQuery {
Expand Down Expand Up @@ -1218,7 +1219,39 @@ impl CubeScanWrapperNode {
// ScalarValue::Binary(_) => {}
// ScalarValue::LargeBinary(_) => {}
// ScalarValue::List(_, _) => {}
// ScalarValue::Date32(_) => {}
ScalarValue::Date32(x) => {
if let Some(x) = x {
let days = Days::new(x.abs().try_into().unwrap());
let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
let new_date = if x < 0 {
epoch.checked_sub_days(days)

Check warning on line 1227 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1222-L1227

Added lines #L1222 - L1227 were not covered by tests
} else {
epoch.checked_add_days(days)

Check warning on line 1229 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1229

Added line #L1229 was not covered by tests
};
let Some(new_date) = new_date else {
return Err(DataFusionError::Internal(format!("Can't generate SQL for date: day out of bounds ({})", x)));

Check warning on line 1232 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1231-L1232

Added lines #L1231 - L1232 were not covered by tests
};
let formatted_date = new_date.format("%Y-%m-%d").to_string();

Check warning on line 1234 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1234

Added line #L1234 was not covered by tests
(
sql_generator

Check warning on line 1236 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1236

Added line #L1236 was not covered by tests
.get_sql_templates()
.scalar_function(
"DATE".to_string(),
vec![format!("'{}'", formatted_date)],
None,
None
).map_err(|e| {
DataFusionError::Internal(format!(

Check warning on line 1244 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1239-L1244

Added lines #L1239 - L1244 were not covered by tests
"Can't generate SQL for date: {}",
e
))
})?,
sql_query

Check warning on line 1249 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1249

Added line #L1249 was not covered by tests
)
} else {
("NULL".to_string(), sql_query)

Check warning on line 1252 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1252

Added line #L1252 was not covered by tests
}
}
// ScalarValue::Date64(_) => {}
// ScalarValue::TimestampSecond(_, _) => {}
// ScalarValue::TimestampMillisecond(_, _) => {}
Expand Down
11 changes: 7 additions & 4 deletions rust/cubesql/cubesql/src/compile/rewrite/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ impl LogicalPlanAnalysis {
|| &fun.name == "date_to_timestamp"
|| &fun.name == "interval_mul"
{
println!("EVALUATING {} UDF CONSTANT FOLDING", fun.name);
Self::eval_constant_expr(&egraph, &expr)
} else {
None
Expand Down Expand Up @@ -1188,15 +1189,17 @@ impl Analysis<LogicalPlanLanguage> for LogicalPlanAnalysis {
// TODO: ideally all constants should be aliased, but this requires
// rewrites to extract `.data.constant` instead of `literal_expr`.
let alias_name =
if c.is_null() || matches!(c, ScalarValue::Date32(_) | ScalarValue::Date64(_) | ScalarValue::Float64(_) | ScalarValue::Int64(_)) {
egraph[id]
if c.is_null() || matches!(c, ScalarValue::Date32(_) | ScalarValue::Date64(_) | ScalarValue::Float64(_) | ScalarValue::Int64(_) | ScalarValue::IntervalYearMonth(_)) {
let original_expr = &egraph[id]
.data
.original_expr
.as_ref()
.original_expr;
println!("ORIGINAL EXPR: {:#?}", original_expr);
original_expr.as_ref()
.map(|expr| expr.name(&DFSchema::empty()).unwrap())
} else {
None
};
println!("ALIAS NAME: {:?}", alias_name);
let c = c.clone();
let value = egraph.add(LogicalPlanLanguage::LiteralExprValue(LiteralExprValue(c)));
let literal_expr = egraph.add(LogicalPlanLanguage::LiteralExpr([value]));
Expand Down

0 comments on commit 14bca66

Please sign in to comment.