Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: proof_of_sql_parser::intermediate_ast::BinaryOp with sqlparser::ast::BinaryOp in the proof-of-sql crate #362

Merged
merged 1 commit into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions crates/proof-of-sql/src/base/database/expression_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@ use crate::base::{
};
use alloc::{format, string::ToString, vec};
use proof_of_sql_parser::{
intermediate_ast::{BinaryOperator, Expression, Literal},
intermediate_ast::{Expression, Literal},
Identifier,
};
use sqlparser::ast::UnaryOperator;
use sqlparser::ast::{BinaryOperator, UnaryOperator};

impl<S: Scalar> OwnedTable<S> {
/// Evaluate an expression on the table.
pub fn evaluate(&self, expr: &Expression) -> ExpressionEvaluationResult<OwnedColumn<S>> {
match expr {
Expression::Column(identifier) => self.evaluate_column(identifier),
Expression::Literal(lit) => self.evaluate_literal(lit),
Expression::Binary { op, left, right } => self.evaluate_binary_expr(*op, left, right),
Expression::Binary { op, left, right } => {
self.evaluate_binary_expr(&(*op).into(), left, right)
}
Expression::Unary { op, expr } => self.evaluate_unary_expr((*op).into(), expr),
_ => Err(ExpressionEvaluationError::Unsupported {
expression: format!("Expression {expr:?} is not supported yet"),
Expand Down Expand Up @@ -84,7 +86,7 @@ impl<S: Scalar> OwnedTable<S> {

fn evaluate_binary_expr(
&self,
op: BinaryOperator,
op: &BinaryOperator,
left: &Expression,
right: &Expression,
) -> ExpressionEvaluationResult<OwnedColumn<S>> {
Expand All @@ -93,13 +95,16 @@ impl<S: Scalar> OwnedTable<S> {
match op {
BinaryOperator::And => Ok(left.element_wise_and(&right)?),
BinaryOperator::Or => Ok(left.element_wise_or(&right)?),
BinaryOperator::Equal => Ok(left.element_wise_eq(&right)?),
BinaryOperator::GreaterThanOrEqual => Ok(left.element_wise_ge(&right)?),
BinaryOperator::LessThanOrEqual => Ok(left.element_wise_le(&right)?),
BinaryOperator::Add => Ok(left.element_wise_add(&right)?),
BinaryOperator::Subtract => Ok(left.element_wise_sub(&right)?),
BinaryOperator::Eq => Ok(left.element_wise_eq(&right)?),
BinaryOperator::GtEq => Ok(left.element_wise_ge(&right)?),
BinaryOperator::LtEq => Ok(left.element_wise_le(&right)?),
BinaryOperator::Plus => Ok(left.element_wise_add(&right)?),
BinaryOperator::Minus => Ok(left.element_wise_sub(&right)?),
BinaryOperator::Multiply => Ok(left.element_wise_mul(&right)?),
BinaryOperator::Division => Ok(left.element_wise_div(&right)?),
BinaryOperator::Divide => Ok(left.element_wise_div(&right)?),
_ => Err(ExpressionEvaluationError::Unsupported {
expression: format!("Binary operator '{op}' is not supported."),
}),
}
}
}
28 changes: 18 additions & 10 deletions crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ use crate::{
};
use alloc::{borrow::ToOwned, boxed::Box, format, string::ToString};
use proof_of_sql_parser::{
intermediate_ast::{AggregationOperator, BinaryOperator, Expression, Literal},
intermediate_ast::{AggregationOperator, Expression, Literal},
posql_time::{PoSQLTimeUnit, PoSQLTimestampError},
Identifier,
};
use sqlparser::ast::UnaryOperator;
use sqlparser::ast::{BinaryOperator, UnaryOperator};

/// Builder that enables building a `proofs::sql::proof_exprs::DynProofExpr` from
/// a `proof_of_sql_parser::intermediate_ast::Expression`.
Expand Down Expand Up @@ -60,7 +60,9 @@ impl DynProofExprBuilder<'_> {
match expr {
Expression::Column(identifier) => self.visit_column(*identifier),
Expression::Literal(lit) => self.visit_literal(lit),
Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right),
Expression::Binary { op, left, right } => {
self.visit_binary_expr(&(*op).into(), left, right)
}
Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr),
Expression::Aggregation { op, expr } => self.visit_aggregate_expr(*op, expr),
_ => Err(ConversionError::Unprovable {
Expand Down Expand Up @@ -146,7 +148,7 @@ impl DynProofExprBuilder<'_> {

fn visit_binary_expr(
&self,
op: BinaryOperator,
op: &BinaryOperator,
left: &Expression,
right: &Expression,
) -> Result<DynProofExpr, ConversionError> {
Expand All @@ -161,27 +163,27 @@ impl DynProofExprBuilder<'_> {
let right = self.visit_expr(right);
DynProofExpr::try_new_or(left?, right?)
}
BinaryOperator::Equal => {
BinaryOperator::Eq => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_equals(left?, right?)
}
BinaryOperator::GreaterThanOrEqual => {
BinaryOperator::GtEq => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_inequality(left?, right?, false)
}
BinaryOperator::LessThanOrEqual => {
BinaryOperator::LtEq => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_inequality(left?, right?, true)
}
BinaryOperator::Add => {
BinaryOperator::Plus => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_add(left?, right?)
}
BinaryOperator::Subtract => {
BinaryOperator::Minus => {
let left = self.visit_expr(left);
let right = self.visit_expr(right);
DynProofExpr::try_new_subtract(left?, right?)
Expand All @@ -191,9 +193,15 @@ impl DynProofExprBuilder<'_> {
let right = self.visit_expr(right);
DynProofExpr::try_new_multiply(left?, right?)
}
BinaryOperator::Division => Err(ConversionError::Unprovable {
BinaryOperator::Divide => Err(ConversionError::Unprovable {
error: format!("Binary operator {op:?} is not supported at this location"),
}),
_ => {
// Handle unsupported binary operations
Err(ConversionError::UnsupportedOperation {
message: format!("{op:?}"),
})
}
}
}

Expand Down
47 changes: 29 additions & 18 deletions crates/proof-of-sql/src/sql/parse/query_context_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ use crate::base::{
use alloc::{boxed::Box, format, string::ToString, vec::Vec};
use proof_of_sql_parser::{
intermediate_ast::{
AggregationOperator, AliasedResultExpr, BinaryOperator, Expression, Literal, OrderBy,
SelectResultExpr, Slice, TableExpression,
AggregationOperator, AliasedResultExpr, Expression, Literal, OrderBy, SelectResultExpr,
Slice, TableExpression,
},
Identifier, ResourceId,
};
use sqlparser::ast::UnaryOperator;
use sqlparser::ast::{BinaryOperator, UnaryOperator};
pub struct QueryContextBuilder<'a> {
context: QueryContext,
schema_accessor: &'a dyn SchemaAccessor,
Expand Down Expand Up @@ -138,7 +138,9 @@ impl<'a> QueryContextBuilder<'a> {
Expression::Literal(literal) => self.visit_literal(literal),
Expression::Column(_) => self.visit_column_expr(expr),
Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr),
Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right),
Expression::Binary { op, left, right } => {
self.visit_binary_expr(&(*op).into(), left, right)
}
Expression::Aggregation { op, expr } => self.visit_agg_expr(*op, expr),
}
}
Expand All @@ -156,7 +158,7 @@ impl<'a> QueryContextBuilder<'a> {

fn visit_binary_expr(
&mut self,
op: BinaryOperator,
op: &BinaryOperator,
left: &Expression,
right: &Expression,
) -> ConversionResult<ColumnType> {
Expand All @@ -166,13 +168,19 @@ impl<'a> QueryContextBuilder<'a> {
match op {
BinaryOperator::And
| BinaryOperator::Or
| BinaryOperator::Equal
| BinaryOperator::GreaterThanOrEqual
| BinaryOperator::LessThanOrEqual => Ok(ColumnType::Boolean),
| BinaryOperator::Eq
| BinaryOperator::GtEq
| BinaryOperator::LtEq => Ok(ColumnType::Boolean),
BinaryOperator::Multiply
| BinaryOperator::Division
| BinaryOperator::Subtract
| BinaryOperator::Add => Ok(left_dtype),
| BinaryOperator::Divide
| BinaryOperator::Minus
| BinaryOperator::Plus => Ok(left_dtype),
_ => {
// Handle unsupported binary operations
Err(ConversionError::UnsupportedOperation {
message: format!("{op:?}"),
})
}
}
}

Expand Down Expand Up @@ -268,7 +276,7 @@ impl<'a> QueryContextBuilder<'a> {
pub(crate) fn type_check_binary_operation(
left_dtype: &ColumnType,
right_dtype: &ColumnType,
binary_operator: BinaryOperator,
binary_operator: &BinaryOperator,
) -> bool {
match binary_operator {
BinaryOperator::And | BinaryOperator::Or => {
Expand All @@ -277,7 +285,7 @@ pub(crate) fn type_check_binary_operation(
(ColumnType::Boolean, ColumnType::Boolean)
)
}
BinaryOperator::Equal => {
BinaryOperator::Eq => {
matches!(
(left_dtype, right_dtype),
(ColumnType::VarChar, ColumnType::VarChar)
Expand All @@ -287,7 +295,7 @@ pub(crate) fn type_check_binary_operation(
| (ColumnType::Scalar, _)
) || (left_dtype.is_numeric() && right_dtype.is_numeric())
}
BinaryOperator::GreaterThanOrEqual | BinaryOperator::LessThanOrEqual => {
BinaryOperator::GtEq | BinaryOperator::LtEq => {
if left_dtype == &ColumnType::VarChar || right_dtype == &ColumnType::VarChar {
return false;
}
Expand All @@ -309,19 +317,22 @@ pub(crate) fn type_check_binary_operation(
| (ColumnType::TimestampTZ(_, _), ColumnType::TimestampTZ(_, _))
)
}
BinaryOperator::Add => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(),
BinaryOperator::Subtract => {
BinaryOperator::Plus | BinaryOperator::Minus => {
try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok()
}
BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(),
BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(),
BinaryOperator::Divide => left_dtype.is_numeric() && right_dtype.is_numeric(),
_ => {
// Handle unsupported binary operations
false
}
}
}

fn check_dtypes(
left_dtype: ColumnType,
right_dtype: ColumnType,
binary_operator: BinaryOperator,
binary_operator: &BinaryOperator,
) -> ConversionResult<()> {
if type_check_binary_operation(&left_dtype, &right_dtype, binary_operator) {
Ok(())
Expand Down
14 changes: 7 additions & 7 deletions crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ use crate::{
use alloc::string::ToString;
use bumpalo::Bump;
use core::cmp::{max, Ordering};
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
#[cfg(feature = "rayon")]
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use sqlparser::ast::BinaryOperator;

#[allow(clippy::unnecessary_wraps)]
fn unchecked_subtract_impl<'a, S: Scalar>(
Expand Down Expand Up @@ -48,11 +48,11 @@ pub fn scale_and_subtract_literal<S: Scalar>(
let lhs_type = lhs.column_type();
let rhs_type = rhs.column_type();
let operator = if is_equal {
BinaryOperator::Equal
BinaryOperator::Eq
} else {
BinaryOperator::LessThanOrEqual
BinaryOperator::LtEq
};
if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) {
return Err(ConversionError::DataTypeMismatch {
left_type: lhs_type.to_string(),
right_type: rhs_type.to_string(),
Expand Down Expand Up @@ -121,11 +121,11 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
let lhs_type = lhs.column_type();
let rhs_type = rhs.column_type();
let operator = if is_equal {
BinaryOperator::Equal
BinaryOperator::Eq
} else {
BinaryOperator::LessThanOrEqual
BinaryOperator::LtEq
};
if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) {
return Err(ConversionError::DataTypeMismatch {
left_type: lhs_type.to_string(),
right_type: rhs_type.to_string(),
Expand Down
17 changes: 7 additions & 10 deletions crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ use crate::{
use alloc::{boxed::Box, string::ToString};
use bumpalo::Bump;
use core::fmt::Debug;
use proof_of_sql_parser::intermediate_ast::{AggregationOperator, BinaryOperator};
use proof_of_sql_parser::intermediate_ast::AggregationOperator;
use serde::{Deserialize, Serialize};
use sqlparser::ast::BinaryOperator;

/// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
Expand Down Expand Up @@ -75,7 +76,7 @@ impl DynProofExpr {
pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Equal) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Eq) {
Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs))))
} else {
Err(ConversionError::DataTypeMismatch {
Expand All @@ -92,11 +93,7 @@ impl DynProofExpr {
) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(
&lhs_datatype,
&rhs_datatype,
BinaryOperator::LessThanOrEqual,
) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::LtEq) {
Ok(Self::Inequality(InequalityExpr::new(
Box::new(lhs),
Box::new(rhs),
Expand All @@ -114,7 +111,7 @@ impl DynProofExpr {
pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Add) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Plus) {
Ok(Self::AddSubtract(AddSubtractExpr::new(
Box::new(lhs),
Box::new(rhs),
Expand All @@ -132,7 +129,7 @@ impl DynProofExpr {
pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Subtract) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Minus) {
Ok(Self::AddSubtract(AddSubtractExpr::new(
Box::new(lhs),
Box::new(rhs),
Expand All @@ -150,7 +147,7 @@ impl DynProofExpr {
pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Multiply) {
if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Multiply) {
Ok(Self::Multiply(MultiplyExpr::new(
Box::new(lhs),
Box::new(rhs),
Expand Down
Loading