Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: proof_of_sql_parser::intermediate_ast::PoSQLTimezone with sqlparser::ast::TimezoneInfo in the proof-of-sql crate #451

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* text=auto
varshith257 marked this conversation as resolved.
Show resolved Hide resolved
57 changes: 56 additions & 1 deletion crates/proof-of-sql-parser/src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ use crate::{
OrderBy as PoSqlOrderBy, OrderByDirection, SelectResultExpr, SetExpression,
TableExpression, UnaryOperator as PoSqlUnaryOperator,
},
posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
Identifier, ResourceId, SelectStatement,
};
use alloc::{boxed::Box, string::ToString, vec};
use alloc::{
boxed::Box,
string::{String, ToString},
vec,
};
use core::fmt::Display;
use sqlparser::ast::{
BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, Ident,
Expand All @@ -28,6 +33,51 @@ fn id(id: Identifier) -> Expr {
Expr::Identifier(id.into())
}

/// Provides an extension for the `TimezoneInfo` type for offsets.
pub trait TimezoneInfoExt {
/// Retrieve the offset in seconds for `TimezoneInfo`.
fn offset(&self, timezone_str: Option<&str>) -> i32;
}

impl TimezoneInfoExt for TimezoneInfo {
fn offset(&self, timezone_str: Option<&str>) -> i32 {
match self {
TimezoneInfo::None => PoSQLTimeZone::utc().offset(),
TimezoneInfo::WithTimeZone => match timezone_str {
Some(tz_str) => PoSQLTimeZone::try_from(&Some(tz_str.into()))
.unwrap_or_else(|_| PoSQLTimeZone::utc())
.offset(),
None => PoSQLTimeZone::utc().offset(),
},
_ => panic!("Offsets are not applicable for WithoutTimeZone or Tz variants."),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WithoutTimeZone should be converted to UTC, right? Tz should be treated as WithTimeZone.

}
}
}

/// Utility function to create a `Timestamp` expression.
pub fn timestamp_to_expr(
value: &str,
time_unit: PoSQLTimeUnit,
timezone: TimezoneInfo,
) -> Result<Expr, String> {
let time_unit_as_u64 = u64::from(time_unit);

Ok(Expr::TypedString {
data_type: DataType::Timestamp(Some(time_unit_as_u64), timezone),
value: value.to_string(),
})
}

/// Parses [`PoSQLTimeZone`] into a `TimezoneInfo`.
impl From<PoSQLTimeZone> for TimezoneInfo {
fn from(posql_timezone: PoSQLTimeZone) -> Self {
match posql_timezone.offset() {
0 => TimezoneInfo::None,
_ => TimezoneInfo::WithTimeZone,
}
}
}

impl From<Identifier> for Ident {
fn from(id: Identifier) -> Self {
Ident::new(id.as_str())
Expand Down Expand Up @@ -268,6 +318,11 @@ mod test {
"select timestamp '2024-11-07T04:55:12.345+03:00' as time from t;",
"select timestamp(3) '2024-11-07 01:55:12.345 UTC' as time from t;",
);

check_posql_intermediate_ast_to_sqlparser_equivalence(
"select timestamp '2024-11-07T04:55:12+00:00' as time from t;",
"select timestamp(0) '2024-11-07 04:55:12 UTC' as time from t;",
);
}

// Check that PoSQL intermediate AST can be converted to SQL parser AST and that the two are equal.
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/benches/bench_append_rows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use proof_of_sql::{
DoryCommitment, DoryProverPublicSetup, DoryScalar, ProverSetup, PublicParameters,
},
};
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;
use rand::Rng;
use sqlparser::ast::TimezoneInfo;

/// Bench dory performance when appending rows to a table. This includes the computation of
/// commitments. Chose the number of columns to randomly generate across supported `PoSQL`
Expand Down Expand Up @@ -121,7 +122,7 @@ pub fn generate_random_owned_table<S: Scalar>(
"timestamptz" => columns.push(timestamptz(
&*identifier,
PoSQLTimeUnit::Second,
PoSQLTimeZone::utc(),
TimezoneInfo::None,
vec![rng.gen::<i64>(); num_rows],
)),
_ => unreachable!(),
Expand Down
120 changes: 65 additions & 55 deletions crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,60 +202,69 @@ impl ArrayRefExt for ArrayRef {
}
}
// Handle all possible TimeStamp TimeUnit instances
DataType::Timestamp(time_unit, tz) => match time_unit {
ArrowTimeUnit::Second => {
if let Some(array) = self.as_any().downcast_ref::<TimestampSecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
DataType::Timestamp(time_unit, tz) => {
let timezone = PoSQLTimeZone::try_from(tz)?;
match time_unit {
ArrowTimeUnit::Second => {
if let Some(array) = self.as_any().downcast_ref::<TimestampSecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Second,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Millisecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampMillisecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Millisecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Millisecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampMillisecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Millisecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Microsecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampMicrosecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Microsecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Microsecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampMicrosecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Microsecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Nanosecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampNanosecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Nanosecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampNanosecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
},
}
DataType::Utf8 => {
if let Some(array) = self.as_any().downcast_ref::<StringArray>() {
let vals = alloc
Expand Down Expand Up @@ -292,6 +301,7 @@ mod tests {
use alloc::sync::Arc;
use arrow::array::Decimal256Builder;
use core::str::FromStr;
use sqlparser::ast::TimezoneInfo;

#[test]
fn we_can_convert_timestamp_array_normal_range() {
Expand All @@ -305,7 +315,7 @@ mod tests {
let result = array.to_column::<TestScalar>(&alloc, &(1..3), None);
assert_eq!(
result.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3])
);
}

Expand All @@ -323,7 +333,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}

Expand All @@ -339,7 +349,7 @@ mod tests {
let result = array.to_column::<DoryScalar>(&alloc, &(1..1), None);
assert_eq!(
result.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}

Expand Down Expand Up @@ -1006,7 +1016,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[..])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[..])
);
}

Expand Down Expand Up @@ -1076,7 +1086,7 @@ mod tests {
array
.to_column::<TestScalar>(&alloc, &(1..3), None)
.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3])
);
}

Expand Down Expand Up @@ -1134,7 +1144,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl TryFrom<DataType> for ColumnType {
};
Ok(ColumnType::TimestampTZ(
posql_time_unit,
PoSQLTimeZone::try_from(&timezone_option)?,
PoSQLTimeZone::try_from(&timezone_option)?.into(),
))
}
DataType::Utf8 => Ok(ColumnType::VarChar),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -252,7 +252,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Millisecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -266,7 +266,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Microsecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -280,7 +280,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/src/base/commitment/column_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ mod tests {
};
use alloc::{string::String, vec};
use itertools::Itertools;
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;
use sqlparser::ast::TimezoneInfo;

#[test]
fn we_can_construct_bounds_by_method() {
Expand Down Expand Up @@ -563,7 +564,7 @@ mod tests {

let timestamp_column = OwnedColumn::<TestScalar>::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::utc(),
TimezoneInfo::None,
vec![1_i64, 2, 3, 4],
);
let committable_timestamp_column = CommittableColumn::from(&timestamp_column);
Expand Down
Loading
Loading