Skip to content

Commit

Permalink
feat: add projection to FilterExec
Browse files Browse the repository at this point in the history
  • Loading branch information
junjunjd committed Oct 26, 2023
1 parent 4881b5d commit 8149f9b
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ mod tests {
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
&schema,
)?,
None,
source,
)?);

Expand Down Expand Up @@ -578,6 +579,7 @@ mod tests {
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
&schema,
)?,
None,
source,
)?);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2001,7 +2001,7 @@ mod tests {
Operator::Eq,
Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
));
Arc::new(FilterExec::try_new(predicate, input).unwrap())
Arc::new(FilterExec::try_new(predicate, None, input).unwrap())
}

fn sort_exec(
Expand Down Expand Up @@ -2649,7 +2649,7 @@ mod tests {
)?;

let filter_top_join: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, top_join)?);
Arc::new(FilterExec::try_new(predicate, None, top_join)?);

// The bottom joins' join key ordering is adjusted based on the top join. And the top join should not introduce additional RepartitionExec
let expected = &[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ mod tests {
schema,
)
.unwrap();
Arc::new(FilterExec::try_new(predicate, input).unwrap())
Arc::new(FilterExec::try_new(predicate, None, input).unwrap())
}

fn coalesce_batches_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ pub fn filter_exec(
predicate: Arc<dyn PhysicalExpr>,
input: Arc<dyn ExecutionPlan>,
) -> Arc<dyn ExecutionPlan> {
Arc::new(FilterExec::try_new(predicate, input).unwrap())
Arc::new(FilterExec::try_new(predicate, None, input).unwrap())
}

pub fn sort_preserving_merge_exec(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ impl DefaultPhysicalPlanner {
&input_schema,
session_state,
)?;
Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?))
Ok(Arc::new(FilterExec::try_new(runtime_expr, None, physical_input)?))
}
LogicalPlan::Union(Union { inputs, schema }) => {
let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?;
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/src/test_util/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,11 @@ impl TestParquetFile {
None,
));

let exec = Arc::new(FilterExec::try_new(physical_filter_expr, parquet_exec)?);
let exec = Arc::new(FilterExec::try_new(
physical_filter_expr,
None,
parquet_exec,
)?);
Ok(exec)
} else {
Ok(Arc::new(ParquetExec::new(scan_config, None, None)))
Expand Down
46 changes: 34 additions & 12 deletions datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::stats::Precision;
use datafusion_common::{plan_err, DataFusionError, Result};
use datafusion_common::{plan_err, project_schema, DataFusionError, Result};
use datafusion_execution::TaskContext;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::BinaryExpr;
Expand All @@ -59,6 +59,10 @@ use log::trace;
pub struct FilterExec {
/// The expression to filter on. This expression must evaluate to a boolean value.
predicate: Arc<dyn PhysicalExpr>,
/// Optional projection
projection: Option<Vec<usize>>,
/// Schema representing the data after the optional projection is applied
projected_schema: SchemaRef,
/// The input plan
input: Arc<dyn ExecutionPlan>,
/// Execution metrics
Expand All @@ -69,11 +73,15 @@ impl FilterExec {
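/// Create a FilterExec on an input.
///
/// When `projection` is given, it is applied to each input batch before the
/// predicate is evaluated, so `predicate` must be valid against the projected
/// schema and must evaluate to a boolean on it.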
pub fn try_new(
predicate: Arc<dyn PhysicalExpr>,
projection: Option<Vec<usize>>,
input: Arc<dyn ExecutionPlan>,
) -> Result<Self> {
match predicate.data_type(input.schema().as_ref())? {
let projected_schema = project_schema(&input.schema(), projection.as_ref())?;
match predicate.data_type(projected_schema.as_ref())? {
DataType::Boolean => Ok(Self {
predicate,
projection,
projected_schema,
input: input.clone(),
metrics: ExecutionPlanMetricsSet::new(),
}),
Expand Down Expand Up @@ -117,7 +125,7 @@ impl ExecutionPlan for FilterExec {
/// Get the schema for this execution plan
fn schema(&self) -> SchemaRef {
// Filtering itself does not alter the schema, but the optional projection may,
// so report the projected schema computed in `try_new` rather than the input's.
self.input.schema()
self.projected_schema.clone()
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
Expand Down Expand Up @@ -175,6 +183,7 @@ impl ExecutionPlan for FilterExec {
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(FilterExec::try_new(
self.predicate.clone(),
self.projection.clone(),
children[0].clone(),
)?))
}
Expand All @@ -187,8 +196,9 @@ impl ExecutionPlan for FilterExec {
trace!("Start FilterExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
Ok(Box::pin(FilterExecStream {
schema: self.input.schema(),
schema: self.schema(),
predicate: self.predicate.clone(),
projection: self.projection.clone(),
input: self.input.execute(partition, context)?,
baseline_metrics,
}))
Expand All @@ -212,7 +222,7 @@ impl ExecutionPlan for FilterExec {
let num_rows = input_stats.num_rows;
let total_byte_size = input_stats.total_byte_size;
let input_analysis_ctx = AnalysisContext::try_from_statistics(
&self.input.schema(),
&self.schema(),
&input_stats.column_statistics,
)?;
let analysis_ctx = analyze(predicate, input_analysis_ctx)?;
Expand Down Expand Up @@ -285,6 +295,8 @@ struct FilterExecStream {
schema: SchemaRef,
/// The expression to filter on. This expression must evaluate to a boolean value.
predicate: Arc<dyn PhysicalExpr>,
/// Optional projection
projection: Option<Vec<usize>>,
/// The input partition to filter.
input: SendableRecordBatchStream,
/// runtime metrics recording
Expand Down Expand Up @@ -318,6 +330,11 @@ impl Stream for FilterExecStream {
Poll::Ready(value) => match value {
Some(Ok(batch)) => {
let timer = self.baseline_metrics.elapsed_compute().timer();
// load just the columns requested
let batch = match self.projection.as_ref() {
Some(columns) => batch.project(columns)?,
None => batch.clone(),
};
let filtered_batch = batch_filter(&batch, &self.predicate)?;
// skip entirely filtered batches
if filtered_batch.num_rows() == 0 {
Expand Down Expand Up @@ -468,7 +485,7 @@ mod tests {

// WHERE a <= 25
let filter: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, input)?);
Arc::new(FilterExec::try_new(predicate, None, input)?);

let statistics = filter.statistics()?;
assert_eq!(statistics.num_rows, Precision::Inexact(25));
Expand Down Expand Up @@ -509,6 +526,7 @@ mod tests {
// WHERE a <= 25
let sub_filter: Arc<dyn ExecutionPlan> = Arc::new(FilterExec::try_new(
binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?,
None,
input,
)?);

Expand All @@ -517,6 +535,7 @@ mod tests {
// WHERE a >= 10
let filter: Arc<dyn ExecutionPlan> = Arc::new(FilterExec::try_new(
binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?,
None,
sub_filter,
)?);

Expand Down Expand Up @@ -566,18 +585,21 @@ mod tests {
// WHERE a <= 25
let a_lte_25: Arc<dyn ExecutionPlan> = Arc::new(FilterExec::try_new(
binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?,
None,
input,
)?);

// WHERE b > 45
let b_gt_5: Arc<dyn ExecutionPlan> = Arc::new(FilterExec::try_new(
binary(col("b", &schema)?, Operator::Gt, lit(45i32), &schema)?,
None,
a_lte_25,
)?);

// WHERE a >= 10
let filter: Arc<dyn ExecutionPlan> = Arc::new(FilterExec::try_new(
binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?,
None,
b_gt_5,
)?);
let statistics = filter.statistics()?;
Expand Down Expand Up @@ -623,7 +645,7 @@ mod tests {

// WHERE a <= 25
let filter: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, input)?);
Arc::new(FilterExec::try_new(predicate, None, input)?);

let statistics = filter.statistics()?;
assert_eq!(statistics.num_rows, Precision::Absent);
Expand Down Expand Up @@ -697,7 +719,7 @@ mod tests {
)),
));
let filter: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, input)?);
Arc::new(FilterExec::try_new(predicate, None, input)?);
let statistics = filter.statistics()?;
// 0.5 (from a) * 0.333333... (from b) * 0.798387... (from c) ≈ 0.1330...
// num_rows after ceil => 133.0... => 134
Expand Down Expand Up @@ -796,7 +818,7 @@ mod tests {
// Since filter predicate passes all entries, statistics after filter shouldn't change.
let expected = input.statistics()?.column_statistics;
let filter: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, input)?);
Arc::new(FilterExec::try_new(predicate, None, input)?);
let statistics = filter.statistics()?;

assert_eq!(statistics.num_rows, Precision::Inexact(1000));
Expand Down Expand Up @@ -849,7 +871,7 @@ mod tests {
)),
));
let filter: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, input)?);
Arc::new(FilterExec::try_new(predicate, None, input)?);
let statistics = filter.statistics()?;

assert_eq!(statistics.num_rows, Precision::Inexact(0));
Expand Down Expand Up @@ -905,7 +927,7 @@ mod tests {
Arc::new(Literal::new(ScalarValue::Int32(Some(50)))),
));
let filter: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, input)?);
Arc::new(FilterExec::try_new(predicate, None, input)?);
let statistics = filter.statistics()?;

assert_eq!(statistics.num_rows, Precision::Inexact(490));
Expand Down Expand Up @@ -955,7 +977,7 @@ mod tests {
)),
));
let filter: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, input)?);
Arc::new(FilterExec::try_new(predicate, None, input)?);
let filter_statistics = filter.statistics()?;

let expected_filter_statistics = Statistics {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
.to_owned(),
)
})?;
Ok(Arc::new(FilterExec::try_new(predicate, input)?))
Ok(Arc::new(FilterExec::try_new(predicate, None, input)?))
}
PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new(
parse_protobuf_file_scan_config(
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ fn roundtrip_filter_with_not_and_in_list() -> Result<()> {
let and = binary(not, Operator::And, in_list, &schema)?;
roundtrip_test(Arc::new(FilterExec::try_new(
and,
None,
Arc::new(EmptyExec::new(false, schema.clone())),
)?))
}
Expand Down

0 comments on commit 8149f9b

Please sign in to comment.