From a42524bdeb7a5bbd7c61685d160a4dbc20c5f08a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Tue, 30 Jul 2024 12:43:27 +0200 Subject: [PATCH] Respect nulls in approx_percentile_cont --- .../src/approx_percentile_cont.rs | 17 ++++++++++++++--- .../sqllogictest/test_files/aggregate.slt | 6 ++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index dfb94a84cbec..40d7378bb3ed 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -19,7 +19,8 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use arrow::array::RecordBatch; +use arrow::array::{Array, RecordBatch}; +use arrow::compute::{filter, is_not_null}; use arrow::{ array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, @@ -104,6 +105,12 @@ impl ApproxPercentileCont { None }; + if args.ignore_nulls { + return not_impl_err!( + "IGNORE NULLS clause not yet supported for APPROX_PERCENTILE_CONT" + ); + } + let accumulator: ApproxPercentileAccumulator = match args.input_type { t @ (DataType::UInt8 | DataType::UInt16 @@ -393,8 +400,12 @@ impl Accumulator for ApproxPercentileAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { - let values = &values[0]; - let sorted_values = &arrow::compute::sort(values, None)?; + // respect nulls by default + let mut values = values[0]; + if let Some(nulls) = values.nulls() { + values = filter(&values, &is_not_null(values)?)?; + } + let sorted_values = &arrow::compute::sort(&values, None)?; let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?; self.digest = self.digest.merge_sorted_f64(&sorted_values); Ok(()) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index fa228d499d1f..6b5d2677fa10 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1237,6 +1237,12 @@ SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.9) AS DOUBLE) / 0.834) < 0.05 ---- true +# percentile_cont_with_nulls +query I +SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); +---- +2 + # csv_query_cube_avg query TIR SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2