refactor!: proof_of_sql_parser::intermediate_ast::BinaryOp with `sqlparser::ast::BinaryOp` in the proof-of-sql crate (#362)

Please be sure to look over the pull request guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist
- [x] The PR title and commit messages adhere to guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md.
In particular `!` is used if and only if at least one breaking change
has been introduced.
- [x] I have run the ci check script with `source
scripts/run_ci_checks.sh`.

# Rationale for this change
This PR replaces
`proof_of_sql_parser::intermediate_ast::BinaryOperator` with
`sqlparser::ast::BinaryOperator` in the `proof-of-sql` crate as part of a
larger transition toward integrating the `sqlparser` crate.

This change is a subtask of issue #235, with the main goal of
streamlining the repository by switching to the `sqlparser` crate and
gradually replacing intermediary constructs like
`proof_of_sql_parser::intermediate_ast` with `sqlparser::ast`.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.

 Example:
 Add `NestedLoopJoinExec`.
 Closes #345.

Since we added `HashJoinExec` in #323 it has been possible to do
provable inner joins. However performance is not satisfactory in some
cases. Hence we need to fix the problem by implement
`NestedLoopJoinExec` and speed up the code
 for `HashJoinExec`.
-->

# What changes are included in this PR?
- All instances of `proof_of_sql_parser::intermediate_ast::BinaryOperator`
in the `proof-of-sql` crate have been replaced with
`sqlparser::ast::BinaryOperator`.
- Every usage of `BinaryOperator` has been updated to preserve the original
functionality, with no changes to logic or behavior.
- Unsupported `BinaryOperator` variants from `sqlparser` are handled
through the existing error handling mechanisms (e.g., the `Unsupported`
variant of `ExpressionEvaluationError`).
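
The variant renames and the fallback arm described above can be sketched in isolation. This is a minimal illustration under stated assumptions, not the crate's code: `OldBinaryOperator`, `NewBinaryOperator`, `EvalError`, and `evaluate_arithmetic` are hypothetical stand-ins for `proof_of_sql_parser::intermediate_ast::BinaryOperator`, `sqlparser::ast::BinaryOperator`, `ExpressionEvaluationError`, and the evaluation functions in the diff, and plain `i64` values stand in for columns. The `Custom(String)` variant is there to show why an operator type that owns data cannot be `Copy`, which is presumably why the updated signatures take `&BinaryOperator`.

```rust
/// Hypothetical stand-in for the old `intermediate_ast::BinaryOperator`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OldBinaryOperator {
    And,
    Or,
    Equal,
    GreaterThanOrEqual,
    LessThanOrEqual,
    Add,
    Subtract,
    Multiply,
    Division,
}

/// Hypothetical stand-in for `sqlparser::ast::BinaryOperator`.
/// `Custom` owns a `String`, so this enum can be `Clone` but not `Copy`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NewBinaryOperator {
    And,
    Or,
    Eq,
    GtEq,
    LtEq,
    Plus,
    Minus,
    Multiply,
    Divide,
    Custom(String),
}

/// Maps each old variant to its renamed counterpart, so call sites can
/// write `(*op).into()` as the diff does.
impl From<OldBinaryOperator> for NewBinaryOperator {
    fn from(op: OldBinaryOperator) -> Self {
        match op {
            OldBinaryOperator::And => Self::And,
            OldBinaryOperator::Or => Self::Or,
            OldBinaryOperator::Equal => Self::Eq,
            OldBinaryOperator::GreaterThanOrEqual => Self::GtEq,
            OldBinaryOperator::LessThanOrEqual => Self::LtEq,
            OldBinaryOperator::Add => Self::Plus,
            OldBinaryOperator::Subtract => Self::Minus,
            OldBinaryOperator::Multiply => Self::Multiply,
            OldBinaryOperator::Division => Self::Divide,
        }
    }
}

/// Hypothetical stand-in for `ExpressionEvaluationError`.
#[derive(Debug, PartialEq, Eq)]
pub enum EvalError {
    Unsupported { expression: String },
}

/// Evaluates a few arithmetic operators on integers. Because the new enum
/// has more variants than this function handles, the wildcard arm reports
/// everything else as unsupported instead of failing to compile.
pub fn evaluate_arithmetic(
    op: &NewBinaryOperator,
    left: i64,
    right: i64,
) -> Result<i64, EvalError> {
    match op {
        NewBinaryOperator::Plus => Ok(left + right),
        NewBinaryOperator::Minus => Ok(left - right),
        NewBinaryOperator::Multiply => Ok(left * right),
        _ => Err(EvalError::Unsupported {
            expression: format!("Binary operator '{op:?}' is not supported."),
        }),
    }
}
```

In this sketch, `OldBinaryOperator::Add.into()` yields `NewBinaryOperator::Plus`, and any operator outside the handled set falls through to the `Unsupported` arm.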

<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.

Example:
- Add `NestedLoopJoinExec`.
- Speed up `HashJoinExec`.
- Route joins to `NestedLoopJoinExec` if the outer input is sufficiently
small.
-->

# Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?

Example:
Yes.
-->
Yes

Closes #349 
Part of #235
iajoiner authored Nov 17, 2024
2 parents 1371386 + e337e65 commit 2b31823
Showing 5 changed files with 76 additions and 55 deletions.
25 changes: 15 additions & 10 deletions crates/proof-of-sql/src/base/database/expression_evaluation.rs
@@ -9,18 +9,20 @@ use crate::base::{
 };
 use alloc::{format, string::ToString, vec};
 use proof_of_sql_parser::{
-    intermediate_ast::{BinaryOperator, Expression, Literal},
+    intermediate_ast::{Expression, Literal},
     Identifier,
 };
-use sqlparser::ast::UnaryOperator;
+use sqlparser::ast::{BinaryOperator, UnaryOperator};
 
 impl<S: Scalar> OwnedTable<S> {
     /// Evaluate an expression on the table.
     pub fn evaluate(&self, expr: &Expression) -> ExpressionEvaluationResult<OwnedColumn<S>> {
         match expr {
             Expression::Column(identifier) => self.evaluate_column(identifier),
             Expression::Literal(lit) => self.evaluate_literal(lit),
-            Expression::Binary { op, left, right } => self.evaluate_binary_expr(*op, left, right),
+            Expression::Binary { op, left, right } => {
+                self.evaluate_binary_expr(&(*op).into(), left, right)
+            }
             Expression::Unary { op, expr } => self.evaluate_unary_expr((*op).into(), expr),
             _ => Err(ExpressionEvaluationError::Unsupported {
                 expression: format!("Expression {expr:?} is not supported yet"),
@@ -84,7 +86,7 @@ impl<S: Scalar> OwnedTable<S> {
 
     fn evaluate_binary_expr(
         &self,
-        op: BinaryOperator,
+        op: &BinaryOperator,
         left: &Expression,
         right: &Expression,
     ) -> ExpressionEvaluationResult<OwnedColumn<S>> {
@@ -93,13 +95,16 @@ impl<S: Scalar> OwnedTable<S> {
         match op {
             BinaryOperator::And => Ok(left.element_wise_and(&right)?),
             BinaryOperator::Or => Ok(left.element_wise_or(&right)?),
-            BinaryOperator::Equal => Ok(left.element_wise_eq(&right)?),
-            BinaryOperator::GreaterThanOrEqual => Ok(left.element_wise_ge(&right)?),
-            BinaryOperator::LessThanOrEqual => Ok(left.element_wise_le(&right)?),
-            BinaryOperator::Add => Ok(left.element_wise_add(&right)?),
-            BinaryOperator::Subtract => Ok(left.element_wise_sub(&right)?),
+            BinaryOperator::Eq => Ok(left.element_wise_eq(&right)?),
+            BinaryOperator::GtEq => Ok(left.element_wise_ge(&right)?),
+            BinaryOperator::LtEq => Ok(left.element_wise_le(&right)?),
+            BinaryOperator::Plus => Ok(left.element_wise_add(&right)?),
+            BinaryOperator::Minus => Ok(left.element_wise_sub(&right)?),
             BinaryOperator::Multiply => Ok(left.element_wise_mul(&right)?),
-            BinaryOperator::Division => Ok(left.element_wise_div(&right)?),
+            BinaryOperator::Divide => Ok(left.element_wise_div(&right)?),
+            _ => Err(ExpressionEvaluationError::Unsupported {
+                expression: format!("Binary operator '{op}' is not supported."),
+            }),
         }
     }
 }
28 changes: 18 additions & 10 deletions crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs
@@ -19,11 +19,11 @@ use crate::{
 };
 use alloc::{borrow::ToOwned, boxed::Box, format, string::ToString};
 use proof_of_sql_parser::{
-    intermediate_ast::{AggregationOperator, BinaryOperator, Expression, Literal},
+    intermediate_ast::{AggregationOperator, Expression, Literal},
     posql_time::{PoSQLTimeUnit, PoSQLTimestampError},
     Identifier,
 };
-use sqlparser::ast::UnaryOperator;
+use sqlparser::ast::{BinaryOperator, UnaryOperator};
 
 /// Builder that enables building a `proofs::sql::proof_exprs::DynProofExpr` from
 /// a `proof_of_sql_parser::intermediate_ast::Expression`.
@@ -60,7 +60,9 @@ impl DynProofExprBuilder<'_> {
         match expr {
             Expression::Column(identifier) => self.visit_column(*identifier),
             Expression::Literal(lit) => self.visit_literal(lit),
-            Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right),
+            Expression::Binary { op, left, right } => {
+                self.visit_binary_expr(&(*op).into(), left, right)
+            }
             Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr),
             Expression::Aggregation { op, expr } => self.visit_aggregate_expr(*op, expr),
             _ => Err(ConversionError::Unprovable {
@@ -146,7 +148,7 @@ impl DynProofExprBuilder<'_> {
 
     fn visit_binary_expr(
         &self,
-        op: BinaryOperator,
+        op: &BinaryOperator,
         left: &Expression,
         right: &Expression,
     ) -> Result<DynProofExpr, ConversionError> {
@@ -161,27 +163,27 @@ impl DynProofExprBuilder<'_> {
                 let right = self.visit_expr(right);
                 DynProofExpr::try_new_or(left?, right?)
             }
-            BinaryOperator::Equal => {
+            BinaryOperator::Eq => {
                 let left = self.visit_expr(left);
                 let right = self.visit_expr(right);
                 DynProofExpr::try_new_equals(left?, right?)
             }
-            BinaryOperator::GreaterThanOrEqual => {
+            BinaryOperator::GtEq => {
                 let left = self.visit_expr(left);
                 let right = self.visit_expr(right);
                 DynProofExpr::try_new_inequality(left?, right?, false)
             }
-            BinaryOperator::LessThanOrEqual => {
+            BinaryOperator::LtEq => {
                 let left = self.visit_expr(left);
                 let right = self.visit_expr(right);
                 DynProofExpr::try_new_inequality(left?, right?, true)
             }
-            BinaryOperator::Add => {
+            BinaryOperator::Plus => {
                 let left = self.visit_expr(left);
                 let right = self.visit_expr(right);
                 DynProofExpr::try_new_add(left?, right?)
             }
-            BinaryOperator::Subtract => {
+            BinaryOperator::Minus => {
                 let left = self.visit_expr(left);
                 let right = self.visit_expr(right);
                 DynProofExpr::try_new_subtract(left?, right?)
@@ -191,9 +193,15 @@ impl DynProofExprBuilder<'_> {
                 let right = self.visit_expr(right);
                 DynProofExpr::try_new_multiply(left?, right?)
             }
-            BinaryOperator::Division => Err(ConversionError::Unprovable {
+            BinaryOperator::Divide => Err(ConversionError::Unprovable {
                 error: format!("Binary operator {op:?} is not supported at this location"),
             }),
+            _ => {
+                // Handle unsupported binary operations
+                Err(ConversionError::UnsupportedOperation {
+                    message: format!("{op:?}"),
+                })
+            }
         }
     }
 
47 changes: 29 additions & 18 deletions crates/proof-of-sql/src/sql/parse/query_context_builder.rs
@@ -12,12 +12,12 @@ use crate::base::{
 use alloc::{boxed::Box, format, string::ToString, vec::Vec};
 use proof_of_sql_parser::{
     intermediate_ast::{
-        AggregationOperator, AliasedResultExpr, BinaryOperator, Expression, Literal, OrderBy,
-        SelectResultExpr, Slice, TableExpression,
+        AggregationOperator, AliasedResultExpr, Expression, Literal, OrderBy, SelectResultExpr,
+        Slice, TableExpression,
     },
     Identifier, ResourceId,
 };
-use sqlparser::ast::UnaryOperator;
+use sqlparser::ast::{BinaryOperator, UnaryOperator};
 pub struct QueryContextBuilder<'a> {
     context: QueryContext,
     schema_accessor: &'a dyn SchemaAccessor,
@@ -138,7 +138,9 @@ impl<'a> QueryContextBuilder<'a> {
             Expression::Literal(literal) => self.visit_literal(literal),
             Expression::Column(_) => self.visit_column_expr(expr),
             Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr),
-            Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right),
+            Expression::Binary { op, left, right } => {
+                self.visit_binary_expr(&(*op).into(), left, right)
+            }
             Expression::Aggregation { op, expr } => self.visit_agg_expr(*op, expr),
         }
     }
@@ -156,7 +158,7 @@ impl<'a> QueryContextBuilder<'a> {
 
     fn visit_binary_expr(
         &mut self,
-        op: BinaryOperator,
+        op: &BinaryOperator,
         left: &Expression,
         right: &Expression,
     ) -> ConversionResult<ColumnType> {
@@ -166,13 +168,19 @@ impl<'a> QueryContextBuilder<'a> {
         match op {
             BinaryOperator::And
             | BinaryOperator::Or
-            | BinaryOperator::Equal
-            | BinaryOperator::GreaterThanOrEqual
-            | BinaryOperator::LessThanOrEqual => Ok(ColumnType::Boolean),
+            | BinaryOperator::Eq
+            | BinaryOperator::GtEq
+            | BinaryOperator::LtEq => Ok(ColumnType::Boolean),
             BinaryOperator::Multiply
-            | BinaryOperator::Division
-            | BinaryOperator::Subtract
-            | BinaryOperator::Add => Ok(left_dtype),
+            | BinaryOperator::Divide
+            | BinaryOperator::Minus
+            | BinaryOperator::Plus => Ok(left_dtype),
+            _ => {
+                // Handle unsupported binary operations
+                Err(ConversionError::UnsupportedOperation {
+                    message: format!("{op:?}"),
+                })
+            }
         }
     }
 
@@ -268,7 +276,7 @@ impl<'a> QueryContextBuilder<'a> {
 pub(crate) fn type_check_binary_operation(
     left_dtype: &ColumnType,
     right_dtype: &ColumnType,
-    binary_operator: BinaryOperator,
+    binary_operator: &BinaryOperator,
 ) -> bool {
     match binary_operator {
         BinaryOperator::And | BinaryOperator::Or => {
@@ -277,7 +285,7 @@ pub(crate) fn type_check_binary_operation(
                 (ColumnType::Boolean, ColumnType::Boolean)
             )
         }
-        BinaryOperator::Equal => {
+        BinaryOperator::Eq => {
             matches!(
                 (left_dtype, right_dtype),
                 (ColumnType::VarChar, ColumnType::VarChar)
@@ -287,7 +295,7 @@ pub(crate) fn type_check_binary_operation(
                     | (ColumnType::Scalar, _)
             ) || (left_dtype.is_numeric() && right_dtype.is_numeric())
         }
-        BinaryOperator::GreaterThanOrEqual | BinaryOperator::LessThanOrEqual => {
+        BinaryOperator::GtEq | BinaryOperator::LtEq => {
             if left_dtype == &ColumnType::VarChar || right_dtype == &ColumnType::VarChar {
                 return false;
             }
@@ -309,19 +317,22 @@ pub(crate) fn type_check_binary_operation(
                     | (ColumnType::TimestampTZ(_, _), ColumnType::TimestampTZ(_, _))
             )
         }
-        BinaryOperator::Add => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(),
-        BinaryOperator::Subtract => {
+        BinaryOperator::Plus | BinaryOperator::Minus => {
             try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok()
         }
         BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(),
-        BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(),
+        BinaryOperator::Divide => left_dtype.is_numeric() && right_dtype.is_numeric(),
+        _ => {
+            // Handle unsupported binary operations
+            false
+        }
     }
 }
 
 fn check_dtypes(
     left_dtype: ColumnType,
     right_dtype: ColumnType,
-    binary_operator: BinaryOperator,
+    binary_operator: &BinaryOperator,
 ) -> ConversionResult<()> {
     if type_check_binary_operation(&left_dtype, &right_dtype, binary_operator) {
         Ok(())
14 changes: 7 additions & 7 deletions crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs
@@ -10,9 +10,9 @@ use crate::{
 use alloc::string::ToString;
 use bumpalo::Bump;
 use core::cmp::{max, Ordering};
-use proof_of_sql_parser::intermediate_ast::BinaryOperator;
 #[cfg(feature = "rayon")]
 use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
+use sqlparser::ast::BinaryOperator;
 
 #[allow(clippy::unnecessary_wraps)]
 fn unchecked_subtract_impl<'a, S: Scalar>(
@@ -48,11 +48,11 @@ pub fn scale_and_subtract_literal<S: Scalar>(
     let lhs_type = lhs.column_type();
     let rhs_type = rhs.column_type();
     let operator = if is_equal {
-        BinaryOperator::Equal
+        BinaryOperator::Eq
     } else {
-        BinaryOperator::LessThanOrEqual
+        BinaryOperator::LtEq
     };
-    if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
+    if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) {
         return Err(ConversionError::DataTypeMismatch {
             left_type: lhs_type.to_string(),
             right_type: rhs_type.to_string(),
@@ -121,11 +121,11 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
     let lhs_type = lhs.column_type();
     let rhs_type = rhs.column_type();
     let operator = if is_equal {
-        BinaryOperator::Equal
+        BinaryOperator::Eq
     } else {
-        BinaryOperator::LessThanOrEqual
+        BinaryOperator::LtEq
     };
-    if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
+    if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) {
         return Err(ConversionError::DataTypeMismatch {
             left_type: lhs_type.to_string(),
             right_type: rhs_type.to_string(),
17 changes: 7 additions & 10 deletions crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs
@@ -17,8 +17,9 @@ use crate::{
 use alloc::{boxed::Box, string::ToString};
 use bumpalo::Bump;
 use core::fmt::Debug;
-use proof_of_sql_parser::intermediate_ast::{AggregationOperator, BinaryOperator};
+use proof_of_sql_parser::intermediate_ast::AggregationOperator;
 use serde::{Deserialize, Serialize};
+use sqlparser::ast::BinaryOperator;
 
 /// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`.
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -75,7 +76,7 @@ impl DynProofExpr {
     pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
         let lhs_datatype = lhs.data_type();
         let rhs_datatype = rhs.data_type();
-        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Equal) {
+        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Eq) {
             Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
         } else {
             Err(ConversionError::DataTypeMismatch {
@@ -92,11 +93,7 @@ impl DynProofExpr {
     ) -> ConversionResult<Self> {
         let lhs_datatype = lhs.data_type();
         let rhs_datatype = rhs.data_type();
-        if type_check_binary_operation(
-            &lhs_datatype,
-            &rhs_datatype,
-            BinaryOperator::LessThanOrEqual,
-        ) {
+        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::LtEq) {
             Ok(Self::Inequality(InequalityExpr::new(
                 Box::new(lhs),
                 Box::new(rhs),
@@ -114,7 +111,7 @@ impl DynProofExpr {
     pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
         let lhs_datatype = lhs.data_type();
         let rhs_datatype = rhs.data_type();
-        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Add) {
+        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Plus) {
             Ok(Self::AddSubtract(AddSubtractExpr::new(
                 Box::new(lhs),
                 Box::new(rhs),
@@ -132,7 +129,7 @@ impl DynProofExpr {
     pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
         let lhs_datatype = lhs.data_type();
         let rhs_datatype = rhs.data_type();
-        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Subtract) {
+        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Minus) {
             Ok(Self::AddSubtract(AddSubtractExpr::new(
                 Box::new(lhs),
                 Box::new(rhs),
@@ -150,7 +147,7 @@ impl DynProofExpr {
     pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
         let lhs_datatype = lhs.data_type();
         let rhs_datatype = rhs.data_type();
-        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Multiply) {
+        if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Multiply) {
             Ok(Self::Multiply(MultiplyExpr::new(
                 Box::new(lhs),
                 Box::new(rhs),
