Skip to content

Commit

Permalink
refactor!: PoSQLBinaryOP to use sqlparser::ast::BinaryOP
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith257 committed Nov 13, 2024
1 parent be228d2 commit 59499d1
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 53 deletions.
24 changes: 15 additions & 9 deletions crates/proof-of-sql/src/base/database/expression_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@ use crate::base::{
};
use alloc::{format, string::ToString, vec};
use proof_of_sql_parser::{
intermediate_ast::{BinaryOperator, Expression, Literal},
intermediate_ast::{Expression, Literal},
Identifier,
};
use sqlparser::ast::UnaryOperator;
use sqlparser::ast::{BinaryOperator, UnaryOperator};

impl<S: Scalar> OwnedTable<S> {
/// Evaluate an expression on the table.
pub fn evaluate(&self, expr: &Expression) -> ExpressionEvaluationResult<OwnedColumn<S>> {
match expr {
Expression::Column(identifier) => self.evaluate_column(identifier),
Expression::Literal(lit) => self.evaluate_literal(lit),
Expression::Binary { op, left, right } => self.evaluate_binary_expr(*op, left, right),
Expression::Binary { op, left, right } => {
self.evaluate_binary_expr((*op).into(), left, right)
}
Expression::Unary { op, expr } => self.evaluate_unary_expr((*op).into(), expr),
_ => Err(ExpressionEvaluationError::Unsupported {
expression: format!("Expression {expr:?} is not supported yet"),
Expand Down Expand Up @@ -82,6 +84,7 @@ impl<S: Scalar> OwnedTable<S> {
}
}

#[allow(clippy::needless_pass_by_value)]
fn evaluate_binary_expr(
&self,
op: BinaryOperator,
Expand All @@ -93,13 +96,16 @@ impl<S: Scalar> OwnedTable<S> {
match op {
BinaryOperator::And => Ok(left.element_wise_and(&right)?),
BinaryOperator::Or => Ok(left.element_wise_or(&right)?),
BinaryOperator::Equal => Ok(left.element_wise_eq(&right)?),
BinaryOperator::GreaterThanOrEqual => Ok(left.element_wise_ge(&right)?),
BinaryOperator::LessThanOrEqual => Ok(left.element_wise_le(&right)?),
BinaryOperator::Add => Ok(left.element_wise_add(&right)?),
BinaryOperator::Subtract => Ok(left.element_wise_sub(&right)?),
BinaryOperator::Eq => Ok(left.element_wise_eq(&right)?),
BinaryOperator::GtEq => Ok(left.element_wise_ge(&right)?),
BinaryOperator::LtEq => Ok(left.element_wise_le(&right)?),
BinaryOperator::Plus => Ok(left.element_wise_add(&right)?),
BinaryOperator::Minus => Ok(left.element_wise_sub(&right)?),
BinaryOperator::Multiply => Ok(left.element_wise_mul(&right)?),
BinaryOperator::Division => Ok(left.element_wise_div(&right)?),
BinaryOperator::Divide => Ok(left.element_wise_div(&right)?),
_ => Err(ExpressionEvaluationError::Unsupported {
expression: format!("Binary operator '{op}' is not supported."),
}),
}
}
}
27 changes: 18 additions & 9 deletions crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ use crate::{
};
use alloc::{borrow::ToOwned, boxed::Box, format, string::ToString};
use proof_of_sql_parser::{
intermediate_ast::{AggregationOperator, BinaryOperator, Expression, Literal},
intermediate_ast::{AggregationOperator, Expression, Literal},
posql_time::{PoSQLTimeUnit, PoSQLTimestampError},
Identifier,
};
use sqlparser::ast::UnaryOperator;
use sqlparser::ast::{BinaryOperator, UnaryOperator};

/// Builder that enables building a `proofs::sql::proof_exprs::DynProofExpr` from
/// a `proof_of_sql_parser::intermediate_ast::Expression`.
Expand Down Expand Up @@ -60,7 +60,9 @@ impl DynProofExprBuilder<'_> {
match expr {
Expression::Column(identifier) => self.visit_column(*identifier),
Expression::Literal(lit) => self.visit_literal(lit),
Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right),
Expression::Binary { op, left, right } => {
self.visit_binary_expr((*op).into(), left, right)
}
Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr),
Expression::Aggregation { op, expr } => self.visit_aggregate_expr(*op, expr),
_ => Err(ConversionError::Unprovable {
Expand Down Expand Up @@ -144,6 +146,7 @@ impl DynProofExprBuilder<'_> {
}
}

#[allow(clippy::needless_pass_by_value)]
fn visit_binary_expr(
&self,
op: BinaryOperator,
Expand All @@ -161,27 +164,27 @@ impl DynProofExprBuilder<'_> {
let right = self.visit_expr(right);
DynProofExpr::try_new_or(left?, right?)
}
BinaryOperator::Equal => {
BinaryOperator::Eq => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_equals(left?, right?)
}
BinaryOperator::GreaterThanOrEqual => {
BinaryOperator::GtEq => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_inequality(left?, right?, false)
}
BinaryOperator::LessThanOrEqual => {
BinaryOperator::LtEq => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_inequality(left?, right?, true)
}
BinaryOperator::Add => {
BinaryOperator::Plus => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_add(left?, right?)
}
BinaryOperator::Subtract => {
BinaryOperator::Minus => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_subtract(left?, right?)
Expand All @@ -191,9 +194,15 @@ impl DynProofExprBuilder<'_> {
let right = self.visit_expr(right);
DynProofExpr::try_new_multiply(left?, right?)
}
BinaryOperator::Division => Err(ConversionError::Unprovable {
BinaryOperator::Divide => Err(ConversionError::Unprovable {
error: format!("Binary operator {op:?} is not supported at this location"),
}),
_ => {
// Handle unsupported binary operations
Err(ConversionError::UnsupportedOperation {
message: format!("{op:?}"),
})
}
}
}

Expand Down
49 changes: 31 additions & 18 deletions crates/proof-of-sql/src/sql/parse/query_context_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ use crate::base::{
use alloc::{boxed::Box, format, string::ToString, vec::Vec};
use proof_of_sql_parser::{
intermediate_ast::{
AggregationOperator, AliasedResultExpr, BinaryOperator, Expression, Literal, OrderBy,
SelectResultExpr, Slice, TableExpression,
AggregationOperator, AliasedResultExpr, Expression, Literal, OrderBy, SelectResultExpr,
Slice, TableExpression,
},
Identifier, ResourceId,
};
use sqlparser::ast::UnaryOperator;
use sqlparser::ast::{BinaryOperator, UnaryOperator};
pub struct QueryContextBuilder<'a> {
context: QueryContext,
schema_accessor: &'a dyn SchemaAccessor,
Expand Down Expand Up @@ -138,7 +138,9 @@ impl<'a> QueryContextBuilder<'a> {
Expression::Literal(literal) => self.visit_literal(literal),
Expression::Column(_) => self.visit_column_expr(expr),
Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr),
Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right),
Expression::Binary { op, left, right } => {
self.visit_binary_expr((*op).into(), left, right)
}
Expression::Aggregation { op, expr } => self.visit_agg_expr(*op, expr),
}
}
Expand All @@ -154,6 +156,7 @@ impl<'a> QueryContextBuilder<'a> {
self.visit_column_identifier(identifier)
}

#[allow(clippy::needless_pass_by_value)]
fn visit_binary_expr(
&mut self,
op: BinaryOperator,
Expand All @@ -162,17 +165,23 @@ impl<'a> QueryContextBuilder<'a> {
) -> ConversionResult<ColumnType> {
let left_dtype = self.visit_expr(left)?;
let right_dtype = self.visit_expr(right)?;
check_dtypes(left_dtype, right_dtype, op)?;
check_dtypes(left_dtype, right_dtype, op.clone())?;
match op {
BinaryOperator::And
| BinaryOperator::Or
| BinaryOperator::Equal
| BinaryOperator::GreaterThanOrEqual
| BinaryOperator::LessThanOrEqual => Ok(ColumnType::Boolean),
| BinaryOperator::Eq
| BinaryOperator::GtEq
| BinaryOperator::LtEq => Ok(ColumnType::Boolean),
BinaryOperator::Multiply
| BinaryOperator::Division
| BinaryOperator::Subtract
| BinaryOperator::Add => Ok(left_dtype),
| BinaryOperator::Divide
| BinaryOperator::Minus
| BinaryOperator::Plus => Ok(left_dtype),
_ => {
// Handle unsupported binary operations
Err(ConversionError::UnsupportedOperation {
message: format!("{op:?}"),
})
}
}
}

Expand Down Expand Up @@ -268,7 +277,7 @@ impl<'a> QueryContextBuilder<'a> {
pub(crate) fn type_check_binary_operation(
left_dtype: &ColumnType,
right_dtype: &ColumnType,
binary_operator: BinaryOperator,
binary_operator: &BinaryOperator,
) -> bool {
match binary_operator {
BinaryOperator::And | BinaryOperator::Or => {
Expand All @@ -277,7 +286,7 @@ pub(crate) fn type_check_binary_operation(
(ColumnType::Boolean, ColumnType::Boolean)
)
}
BinaryOperator::Equal => {
BinaryOperator::Eq => {
matches!(
(left_dtype, right_dtype),
(ColumnType::VarChar, ColumnType::VarChar)
Expand All @@ -287,7 +296,7 @@ pub(crate) fn type_check_binary_operation(
| (ColumnType::Scalar, _)
) || (left_dtype.is_numeric() && right_dtype.is_numeric())
}
BinaryOperator::GreaterThanOrEqual | BinaryOperator::LessThanOrEqual => {
BinaryOperator::GtEq | BinaryOperator::LtEq => {
if left_dtype == &ColumnType::VarChar || right_dtype == &ColumnType::VarChar {
return false;
}
Expand All @@ -309,21 +318,25 @@ pub(crate) fn type_check_binary_operation(
| (ColumnType::TimestampTZ(_, _), ColumnType::TimestampTZ(_, _))
)
}
BinaryOperator::Add => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(),
BinaryOperator::Subtract => {
BinaryOperator::Plus | BinaryOperator::Minus => {
try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok()
}
BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(),
BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(),
BinaryOperator::Divide => left_dtype.is_numeric() && right_dtype.is_numeric(),
_ => {
// Handle unsupported binary operations
false
}
}
}

#[allow(clippy::needless_pass_by_value)]
fn check_dtypes(
left_dtype: ColumnType,
right_dtype: ColumnType,
binary_operator: BinaryOperator,
) -> ConversionResult<()> {
if type_check_binary_operation(&left_dtype, &right_dtype, binary_operator) {
if type_check_binary_operation(&left_dtype, &right_dtype, &binary_operator) {
Ok(())
} else {
Err(ConversionError::DataTypeMismatch {
Expand Down
14 changes: 7 additions & 7 deletions crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ use crate::{
use alloc::string::ToString;
use bumpalo::Bump;
use core::cmp::{max, Ordering};
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
#[cfg(feature = "rayon")]
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use sqlparser::ast::BinaryOperator;

#[allow(clippy::unnecessary_wraps)]
fn unchecked_subtract_impl<'a, S: Scalar>(
Expand Down Expand Up @@ -48,11 +48,11 @@ pub fn scale_and_subtract_literal<S: Scalar>(
let lhs_type = lhs.column_type();
let rhs_type = rhs.column_type();
let operator = if is_equal {
BinaryOperator::Equal
BinaryOperator::Eq
} else {
BinaryOperator::LessThanOrEqual
BinaryOperator::LtEq
};
if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) {
return Err(ConversionError::DataTypeMismatch {
left_type: lhs_type.to_string(),
right_type: rhs_type.to_string(),
Expand Down Expand Up @@ -121,11 +121,11 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
let lhs_type = lhs.column_type();
let rhs_type = rhs.column_type();
let operator = if is_equal {
BinaryOperator::Equal
BinaryOperator::Eq
} else {
BinaryOperator::LessThanOrEqual
BinaryOperator::LtEq
};
if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) {
return Err(ConversionError::DataTypeMismatch {
left_type: lhs_type.to_string(),
right_type: rhs_type.to_string(),
Expand Down
17 changes: 7 additions & 10 deletions crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ use crate::{
use alloc::{boxed::Box, string::ToString};
use bumpalo::Bump;
use core::fmt::Debug;
use proof_of_sql_parser::intermediate_ast::{AggregationOperator, BinaryOperator};
use proof_of_sql_parser::intermediate_ast::AggregationOperator;
use serde::{Deserialize, Serialize};
use sqlparser::ast::BinaryOperator;

/// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
Expand Down Expand Up @@ -75,7 +76,7 @@ impl DynProofExpr {
pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Equal) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Eq) {
Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
} else {
Err(ConversionError::DataTypeMismatch {
Expand All @@ -92,11 +93,7 @@ impl DynProofExpr {
) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(
&lhs_datatype,
&rhs_datatype,
BinaryOperator::LessThanOrEqual,
) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::LtEq) {
Ok(Self::Inequality(InequalityExpr::new(
Box::new(lhs),
Box::new(rhs),
Expand All @@ -114,7 +111,7 @@ impl DynProofExpr {
pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Add) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Plus) {
Ok(Self::AddSubtract(AddSubtractExpr::new(
Box::new(lhs),
Box::new(rhs),
Expand All @@ -132,7 +129,7 @@ impl DynProofExpr {
pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Subtract) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Minus) {
Ok(Self::AddSubtract(AddSubtractExpr::new(
Box::new(lhs),
Box::new(rhs),
Expand All @@ -150,7 +147,7 @@ impl DynProofExpr {
pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Multiply) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Multiply) {
Ok(Self::Multiply(MultiplyExpr::new(
Box::new(lhs),
Box::new(rhs),
Expand Down

0 comments on commit 59499d1

Please sign in to comment.