Skip to content

Commit

Permalink
refactor!: PosqlTimeZone to use sqlparser::ast::TimezoneInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith257 committed Jan 4, 2025
1 parent 91173a3 commit eb75d22
Show file tree
Hide file tree
Showing 31 changed files with 275 additions and 207 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* text=auto
50 changes: 50 additions & 0 deletions crates/proof-of-sql-parser/src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
OrderBy as PoSqlOrderBy, OrderByDirection, SelectResultExpr, SetExpression,
TableExpression, UnaryOperator as PoSqlUnaryOperator,
},
posql_time::{PoSQLTimeZone, PoSQLTimestamp},
Identifier, ResourceId, SelectStatement,
};
use alloc::{boxed::Box, string::ToString, vec};
Expand All @@ -28,6 +29,50 @@ fn id(id: Identifier) -> Expr {
Expr::Identifier(id.into())
}

/// Provides an extension for the `TimezoneInfo` type for offsets.
pub trait TimezoneInfoExt {
/// Retrieve the offset in seconds for `TimezoneInfo`.
fn offset(&self, timezone_str: Option<&str>) -> i32;
}

impl TimezoneInfoExt for TimezoneInfo {
fn offset(&self, timezone_str: Option<&str>) -> i32 {
match self {
TimezoneInfo::None => PoSQLTimeZone::utc().offset(),
TimezoneInfo::WithTimeZone => match timezone_str {
Some(tz_str) => PoSQLTimeZone::try_from(&Some(tz_str.into()))
.unwrap_or_else(|_| PoSQLTimeZone::utc())
.offset(),
None => PoSQLTimeZone::utc().offset(),
},
_ => panic!("Offsets are not applicable for WithoutTimeZone or Tz variants."),
}
}
}

/// Convert a timestamp string into an [`Expr`].
impl From<&PoSQLTimestamp> for Expr {
fn from(timestamp: &PoSQLTimestamp) -> Self {
Expr::TypedString {
data_type: DataType::Timestamp(
Some(timestamp.timeunit().into()),
timestamp.timezone().into(),
),
value: timestamp.timestamp().to_string(),
}
}
}

/// Parses [`PoSQLTimeZone`] into a `TimezoneInfo`.
impl From<PoSQLTimeZone> for TimezoneInfo {
fn from(posql_timezone: PoSQLTimeZone) -> Self {
match posql_timezone.offset() {
0 => TimezoneInfo::None,
_ => TimezoneInfo::WithTimeZone,
}
}
}

impl From<Identifier> for Ident {
fn from(id: Identifier) -> Self {
Ident::new(id.as_str())
Expand Down Expand Up @@ -268,6 +313,11 @@ mod test {
"select timestamp '2024-11-07T04:55:12.345+03:00' as time from t;",
"select timestamp(3) '2024-11-07 01:55:12.345 UTC' as time from t;",
);

check_posql_intermediate_ast_to_sqlparser_equivalence(
"select timestamp '2024-11-07T04:55:12+00:00' as time from t;",
"select timestamp(0) '2024-11-07 04:55:12 UTC' as time from t;",
);
}

// Check that PoSQL intermediate AST can be converted to SQL parser AST and that the two are equal.
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/benches/bench_append_rows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use proof_of_sql::{
DoryCommitment, DoryProverPublicSetup, DoryScalar, ProverSetup, PublicParameters,
},
};
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;
use rand::Rng;
use sqlparser::ast::TimezoneInfo;

/// Bench dory performance when appending rows to a table. This includes the computation of
/// commitments. Chose the number of columns to randomly generate across supported `PoSQL`
Expand Down Expand Up @@ -121,7 +122,7 @@ pub fn generate_random_owned_table<S: Scalar>(
"timestamptz" => columns.push(timestamptz(
&*identifier,
PoSQLTimeUnit::Second,
PoSQLTimeZone::utc(),
TimezoneInfo::None,
vec![rng.gen::<i64>(); num_rows],
)),
_ => unreachable!(),
Expand Down
120 changes: 65 additions & 55 deletions crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,60 +202,69 @@ impl ArrayRefExt for ArrayRef {
}
}
// Handle all possible TimeStamp TimeUnit instances
DataType::Timestamp(time_unit, tz) => match time_unit {
ArrowTimeUnit::Second => {
if let Some(array) = self.as_any().downcast_ref::<TimestampSecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
DataType::Timestamp(time_unit, tz) => {
let timezone = PoSQLTimeZone::try_from(tz)?;
match time_unit {
ArrowTimeUnit::Second => {
if let Some(array) = self.as_any().downcast_ref::<TimestampSecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Second,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Millisecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampMillisecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Millisecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Millisecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampMillisecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Millisecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Microsecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampMicrosecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Microsecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Microsecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampMicrosecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Microsecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Nanosecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampNanosecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Nanosecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampNanosecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
},
}
DataType::Utf8 => {
if let Some(array) = self.as_any().downcast_ref::<StringArray>() {
let vals = alloc
Expand Down Expand Up @@ -292,6 +301,7 @@ mod tests {
use alloc::sync::Arc;
use arrow::array::Decimal256Builder;
use core::str::FromStr;
use sqlparser::ast::TimezoneInfo;

#[test]
fn we_can_convert_timestamp_array_normal_range() {
Expand All @@ -305,7 +315,7 @@ mod tests {
let result = array.to_column::<TestScalar>(&alloc, &(1..3), None);
assert_eq!(
result.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3])
);
}

Expand All @@ -323,7 +333,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}

Expand All @@ -339,7 +349,7 @@ mod tests {
let result = array.to_column::<DoryScalar>(&alloc, &(1..1), None);
assert_eq!(
result.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}

Expand Down Expand Up @@ -1006,7 +1016,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[..])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[..])
);
}

Expand Down Expand Up @@ -1076,7 +1086,7 @@ mod tests {
array
.to_column::<TestScalar>(&alloc, &(1..3), None)
.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3])
);
}

Expand Down Expand Up @@ -1134,7 +1144,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl TryFrom<DataType> for ColumnType {
};
Ok(ColumnType::TimestampTZ(
posql_time_unit,
PoSQLTimeZone::try_from(&timezone_option)?,
PoSQLTimeZone::try_from(&timezone_option)?.into(),
))
}
DataType::Utf8 => Ok(ColumnType::VarChar),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -252,7 +252,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Millisecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -266,7 +266,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Microsecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -280,7 +280,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/src/base/commitment/column_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ mod tests {
};
use alloc::{string::String, vec};
use itertools::Itertools;
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;
use sqlparser::ast::TimezoneInfo;

#[test]
fn we_can_construct_bounds_by_method() {
Expand Down Expand Up @@ -563,7 +564,7 @@ mod tests {

let timestamp_column = OwnedColumn::<TestScalar>::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::utc(),
TimezoneInfo::None,
vec![1_i64, 2, 3, 4],
);
let committable_timestamp_column = CommittableColumn::from(&timestamp_column);
Expand Down
Loading

0 comments on commit eb75d22

Please sign in to comment.