Skip to content

Commit

Permalink
add partition filter
Browse files Browse the repository at this point in the history
  • Loading branch information
houqp committed May 27, 2024
1 parent aee976a commit 0267450
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 15 deletions.
12 changes: 8 additions & 4 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ use arrow::{
compute::kernels::cast::{cast_with_options, CastOptions},
datatypes::{
i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType,
Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
Date32Type, Date64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type,
Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION,
Expand Down Expand Up @@ -3188,8 +3188,12 @@ impl fmt::Display for ScalarValue {
ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?,
ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?,
ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?,
ScalarValue::Date32(e) => format_option!(f, e)?,
ScalarValue::Date64(e) => format_option!(f, e)?,
ScalarValue::Date32(e) => {
format_option!(f, e.map(|v| Date32Type::to_naive_date(v).to_string()))?
}
ScalarValue::Date64(e) => {
format_option!(f, e.map(|v| Date64Type::to_naive_date(v).to_string()))?
}
ScalarValue::Time32Second(e) => format_option!(f, e)?,
ScalarValue::Time32Millisecond(e) => format_option!(f, e)?,
ScalarValue::Time64Microsecond(e) => format_option!(f, e)?,
Expand Down
217 changes: 214 additions & 3 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

//! Helper functions for the table implementation
use std::collections::HashMap;
use std::sync::Arc;

use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use crate::execution::context::SessionState;
use crate::logical_expr::{BinaryExpr, Operator};
use crate::{error::Result, scalar::ScalarValue};

use arrow::{
Expand Down Expand Up @@ -185,9 +187,17 @@ async fn list_partitions(
store: &dyn ObjectStore,
table_path: &ListingTableUrl,
max_depth: usize,
partition_prefix: Option<Path>,
) -> Result<Vec<Partition>> {
let partition = Partition {
path: table_path.prefix().clone(),
path: match partition_prefix {
Some(prefix) => Path::from_iter(
Path::from(table_path.prefix().as_ref())
.parts()
.chain(Path::from(prefix.as_ref()).parts()),
),
None => table_path.prefix().clone(),
},
depth: 0,
files: None,
};
Expand Down Expand Up @@ -321,6 +331,81 @@ async fn prune_partitions(
Ok(filtered)
}

#[derive(Debug)]
enum PartitionValue {
Single(String),
Multi,
}

fn populate_partition_values<'a>(
partition_values: &mut HashMap<&'a str, PartitionValue>,
filter: &'a Expr,
) {
if let Expr::BinaryExpr(BinaryExpr {
ref left,
op,
ref right,
}) = filter
{
match op {
Operator::Eq => match (left.as_ref(), right.as_ref()) {
(Expr::Column(Column { ref name, .. }), Expr::Literal(val))
| (Expr::Literal(val), Expr::Column(Column { ref name, .. })) => {
if partition_values
.insert(name, PartitionValue::Single(val.to_string()))
.is_some()
{
partition_values.insert(name, PartitionValue::Multi);
}
}
_ => {}
},
Operator::And => {
populate_partition_values(partition_values, left);
populate_partition_values(partition_values, right);
}
_ => {}
}
}
}

fn evaluate_partition_prefix<'a>(
partition_cols: &'a [(String, DataType)],
filters: &'a [Expr],
) -> Option<Path> {
let mut partition_values = HashMap::new();

if filters.len() > 1 {
return None;
}

for filter in filters {
populate_partition_values(&mut partition_values, filter);
}

if partition_values.is_empty() {
return None;
}

let mut parts = vec![];
for (p, _) in partition_cols {
match partition_values.get(p.as_str()) {
Some(PartitionValue::Single(val)) => {
parts.push(format!("{p}={val}"));
}
_ => {
break;
}
}
}

if parts.is_empty() {
None
} else {
Some(Path::from_iter(parts))
}
}

/// Discover the partitions on the given path and prune out files
/// that belong to irrelevant partitions using `filters` expressions.
/// `filters` might contain expressions that can be resolved only at the
Expand All @@ -343,7 +428,10 @@ pub async fn pruned_partition_list<'a>(
));
}

let partitions = list_partitions(store, table_path, partition_cols.len()).await?;
let partition_prefix = evaluate_partition_prefix(partition_cols, filters);
let partitions =
list_partitions(store, table_path, partition_cols.len(), partition_prefix)
.await?;
debug!("Listed {} partitions", partitions.len());

let pruned =
Expand Down Expand Up @@ -433,7 +521,7 @@ mod tests {

use futures::StreamExt;

use crate::logical_expr::{case, col, lit};
use crate::logical_expr::{case, col, lit, Expr, Operator};
use crate::test::object_store::make_test_store_and_state;

use super::*;
Expand Down Expand Up @@ -692,4 +780,127 @@ mod tests {
// this helper function
assert!(expr_applicable_for_cols(&[], &lit(true)));
}

#[test]
fn test_evaluate_partition_prefix() {
let partitions = &[
("a".to_string(), DataType::Utf8),
("b".to_string(), DataType::Int16),
("c".to_string(), DataType::Boolean),
];

assert_eq!(
evaluate_partition_prefix(partitions, &[Expr::eq(col("a"), lit("foo"))],),
Some(Path::from("a=foo")),
);

assert_eq!(
evaluate_partition_prefix(
partitions,
&[Expr::and(
Expr::eq(col("a"), lit("foo")),
Expr::eq(col("b"), lit("bar")),
)],
),
Some(Path::from("a=foo/b=bar")),
);

assert_eq!(
evaluate_partition_prefix(
partitions,
&[Expr::and(
Expr::eq(col("a"), lit("foo")),
Expr::and(
Expr::eq(col("b"), lit("1")),
Expr::eq(col("c"), lit("true")),
),
)],
),
Some(Path::from("a=foo/b=1/c=true")),
);

// no prefix when filter is empty
assert_eq!(evaluate_partition_prefix(partitions, &[]), None);

// b=foo results in no prefix because a is not restricted
assert_eq!(
evaluate_partition_prefix(partitions, &[Expr::eq(col("b"), lit("foo"))],),
None,
);

// a=foo and c=baz only results in preifx a=foo because b is not restricted
assert_eq!(
evaluate_partition_prefix(
partitions,
&[Expr::and(
Expr::eq(col("a"), lit("foo")),
Expr::eq(col("c"), lit("baz")),
)],
),
Some(Path::from("a=foo")),
);

// a=foo or b=bar results in no prefix
assert_eq!(
evaluate_partition_prefix(
partitions,
&[
Expr::eq(col("a"), lit("foo")),
Expr::eq(col("b"), lit("bar")),
],
),
None,
);

// partition with multiple values results in no prefix
assert_eq!(
evaluate_partition_prefix(
partitions,
&[Expr::and(
Expr::eq(col("a"), lit("foo")),
Expr::eq(col("a"), lit("bar")),
)],
),
None,
);

// no prefix because partition a is not restricted to a single literal
assert_eq!(
evaluate_partition_prefix(
partitions,
&[Expr::or(
Expr::eq(col("a"), lit("foo")),
Expr::eq(col("a"), lit("bar")),
)],
),
None,
);
}

#[test]
fn test_evaluate_date_partition_prefix() {
let partitions = &[("a".to_string(), DataType::Date32)];
assert_eq!(
evaluate_partition_prefix(
partitions,
&[Expr::eq(
col("a"),
Expr::Literal(ScalarValue::Date32(Some(3)))
)],
),
Some(Path::from("a=1970-01-04")),
);

let partitions = &[("a".to_string(), DataType::Date64)];
assert_eq!(
evaluate_partition_prefix(
partitions,
&[Expr::eq(
col("a"),
Expr::Literal(ScalarValue::Date64(Some(4 * 24 * 60 * 60 * 1000)))
)],
),
Some(Path::from("a=1970-01-05")),
);
}
}
2 changes: 1 addition & 1 deletion datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ fn select_date_plus_interval() -> Result<()> {

// Note that constant folder runs and folds the entire
// expression down to a single constant (true)
let expected = r#"Projection: Date32("18636") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408")
let expected = r#"Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408")
TableScan: test"#;
let actual = get_optimized_plan_formatted(&plan, &time);

Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ fn between_date32_plus_interval() -> Result<()> {
let expected =
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
\n Projection: \
\n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\
\n Filter: test.col_date32 >= Date32(\"1998-03-18\") AND test.col_date32 <= Date32(\"1998-06-16\")\
\n TableScan: test projection=[col_date32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
Expand All @@ -200,7 +200,7 @@ fn between_date64_plus_interval() -> Result<()> {
let expected =
"Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\
\n Projection: \
\n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\
\n Filter: test.col_date64 >= Date64(\"1998-03-18\") AND test.col_date64 <= Date64(\"1998-06-16\")\
\n TableScan: test projection=[col_date64]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
Expand Down
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/tpch/q1.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST
--Projection: lineitem.l_returnflag, lineitem.l_linestatus, SUM(lineitem.l_quantity) AS sum_qty, SUM(lineitem.l_extendedprice) AS sum_base_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order
----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(UInt8(1)) AS COUNT(*)]]
------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus
--------Filter: lineitem.l_shipdate <= Date32("10471")
----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("10471")]
--------Filter: lineitem.l_shipdate <= Date32("1998-09-02")
----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")]
physical_plan
SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS LAST]
--SortExec: expr=[l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS LAST]
Expand All @@ -56,7 +56,7 @@ SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS
------------AggregateExec: mode=Partial, gby=[l_returnflag@5 as l_returnflag, l_linestatus@6 as l_linestatus], aggr=[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(*)]
--------------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, l_quantity@0 as l_quantity, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, l_tax@3 as l_tax, l_returnflag@4 as l_returnflag, l_linestatus@5 as l_linestatus]
----------------CoalesceBatchesExec: target_batch_size=8192
------------------FilterExec: l_shipdate@6 <= 10471
------------------FilterExec: l_shipdate@6 <= 1998-09-02
--------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], has_header=false

query TTRRRRRRRI
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/tpch/q10.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Limit: skip=0, fetch=10
------------------Inner Join: customer.c_custkey = orders.o_custkey
--------------------TableScan: customer projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_comment]
--------------------Projection: orders.o_orderkey, orders.o_custkey
----------------------Filter: orders.o_orderdate >= Date32("8674") AND orders.o_orderdate < Date32("8766")
----------------------Filter: orders.o_orderdate >= Date32("1993-10-01") AND orders.o_orderdate < Date32("1994-01-01")
------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("8674"), orders.o_orderdate < Date32("8766")]
----------------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount
------------------Filter: lineitem.l_returnflag = Utf8("R")
Expand Down Expand Up @@ -96,7 +96,7 @@ GlobalLimitExec: skip=0, fetch=10
--------------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4
----------------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_custkey@1 as o_custkey]
------------------------------------------CoalesceBatchesExec: target_batch_size=8192
--------------------------------------------FilterExec: o_orderdate@2 >= 8674 AND o_orderdate@2 < 8766
--------------------------------------------FilterExec: o_orderdate@2 >= 1993-10-01 AND o_orderdate@2 < 1994-01-01
----------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false
----------------------------CoalesceBatchesExec: target_batch_size=8192
------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4
Expand Down

0 comments on commit 0267450

Please sign in to comment.