diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index cb77cc8e79b1..ff9cdedab8b1 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -590,6 +590,13 @@ pub fn base_type(data_type: &DataType) -> DataType { } } +/// Information about how to coerce lists. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ListCoercion { + /// [`DataType::FixedSizeList`] should be coerced to [`DataType::List`]. + FixedSizedListToList, +} + /// A helper function to coerce base type in List. /// /// Example @@ -600,16 +607,22 @@ pub fn base_type(data_type: &DataType) -> DataType { /// /// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// let base_type = DataType::Float64; -/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, None); /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, + array_coercion: Option<&ListCoercion>, ) -> DataType { - match data_type { - DataType::List(field) | DataType::FixedSizeList(field, _) => { - let field_type = - coerced_type_with_base_type_only(field.data_type(), base_type); + match (data_type, array_coercion) { + (DataType::List(field), _) + | (DataType::FixedSizeList(field, _), Some(ListCoercion::FixedSizedListToList)) => + { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); DataType::List(Arc::new(Field::new( field.name(), @@ -617,9 +630,24 @@ pub fn coerced_type_with_base_type_only( field.is_nullable(), ))) } - DataType::LargeList(field) => { - let field_type = - coerced_type_with_base_type_only(field.data_type(), base_type); + (DataType::FixedSizeList(field, len), _) => { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); + + DataType::FixedSizeList( + Arc::new(Field::new(field.name(), field_type, field.is_nullable())), + *len, + ) + } + (DataType::LargeList(field), _) => { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); DataType::LargeList(Arc::new(Field::new( field.name(), diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 1bfae28af840..4ca4961d7b63 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -19,11 +19,11 @@ //! and return types of functions in DataFusion. use std::fmt::Display; -use std::num::NonZeroUsize; use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; use datafusion_common::types::{LogicalTypeRef, NativeType}; +use datafusion_common::utils::ListCoercion; use itertools::Itertools; /// Constant that is used as a placeholder for any valid timezone. @@ -227,25 +227,13 @@ impl Display for TypeSignatureClass { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { - /// Specialized Signature for ArrayAppend and similar functions - /// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list. - /// The second argument's list dimension should be one dimension less than the first argument's list dimension. - /// List dimension of the List/LargeList is equivalent to the number of List. - /// List dimension of the non-list is 0. - ArrayAndElement, - /// Specialized Signature for ArrayPrepend and similar functions - /// The first argument should be non-list or list, and the second argument should be List/LargeList. - /// The first argument's list dimension should be one dimension less than the second argument's list dimension. - ElementAndArray, - /// Specialized Signature for Array functions of the form (List/LargeList, Index+) - /// The first argument should be List/LargeList/FixedSizedList, and the next n arguments should be Int64. - ArrayAndIndexes(NonZeroUsize), - /// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index) - ArrayAndElementAndOptionalIndex, - /// Specialized Signature for ArrayEmpty and similar functions - /// The function takes a single argument that must be a List/LargeList/FixedSizeList - /// or something that can be coerced to one of those types. - Array, + /// A function takes at least one List/LargeList/FixedSizeList argument. + Array { + /// A full list of the arguments accepted by this function. + arguments: Vec, + /// Additional information about how array arguments should be coerced. + array_coercion: Option, + }, /// A function takes a single argument that must be a List/LargeList/FixedSizeList /// which gets coerced to List, with element type recursively coerced to List too if it is list-like. RecursiveArray, @@ -257,25 +245,15 @@ pub enum ArrayFunctionSignature { impl Display for ArrayFunctionSignature { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ArrayFunctionSignature::ArrayAndElement => { - write!(f, "array, element") - } - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => { - write!(f, "array, element, [index]") - } - ArrayFunctionSignature::ElementAndArray => { - write!(f, "element, array") - } - ArrayFunctionSignature::ArrayAndIndexes(count) => { - write!(f, "array")?; - for _ in 0..count.get() { - write!(f, ", index")?; + ArrayFunctionSignature::Array { arguments, .. } => { + for (idx, argument) in arguments.iter().enumerate() { + write!(f, "{argument}")?; + if idx != arguments.len() - 1 { + write!(f, ", ")?; + } } Ok(()) } - ArrayFunctionSignature::Array => { - write!(f, "array") - } ArrayFunctionSignature::RecursiveArray => { write!(f, "recursive_array") } @@ -286,6 +264,34 @@ impl Display for ArrayFunctionSignature { } } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum ArrayFunctionArgument { + /// A non-list or list argument. The list dimensions should be one less than the Array's list + /// dimensions. + Element, + /// An Int64 index argument. + Index, + /// An argument of type List/LargeList/FixedSizeList. All Array arguments must be coercible + /// to the same type. + Array, +} + +impl Display for ArrayFunctionArgument { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArrayFunctionArgument::Element => { + write!(f, "element") + } + ArrayFunctionArgument::Index => { + write!(f, "index") + } + ArrayFunctionArgument::Array => { + write!(f, "array") + } + } + } +} + impl TypeSignature { pub fn to_string_repr(&self) -> Vec { match self { @@ -580,7 +586,13 @@ impl Signature { pub fn array_and_element(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndElement, + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, ), volatility, } @@ -588,30 +600,38 @@ impl Signature { /// Specialized Signature for Array functions with an optional index pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self { Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex, - ), - volatility, - } - } - /// Specialized Signature for ArrayPrepend and similar functions - pub fn element_and_array(volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ElementAndArray, - ), + type_signature: TypeSignature::OneOf(vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), + ]), volatility, } } + /// Specialized Signature for ArrayElement and similar functions pub fn array_and_index(volatility: Volatility) -> Self { - Self::array_and_indexes(volatility, NonZeroUsize::new(1).expect("1 is non-zero")) - } - /// Specialized Signature for ArraySlice and similar functions - pub fn array_and_indexes(volatility: Volatility, count: NonZeroUsize) -> Self { Signature { type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes(count), + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }, ), volatility, } @@ -619,7 +639,12 @@ impl Signature { /// Specialized Signature for ArrayEmpty and similar functions pub fn array(volatility: Volatility) -> Self { Signature { - type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }, + ), volatility, } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index aaa65c676a42..2f04f234eb1d 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -71,8 +71,8 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; pub use datafusion_expr_common::signature::{ - ArrayFunctionSignature, Signature, TypeSignature, TypeSignatureClass, Volatility, - TIMEZONE_WILDCARD, + ArrayFunctionArgument, ArrayFunctionSignature, Signature, TypeSignature, + TypeSignatureClass, Volatility, TIMEZONE_WILDCARD, }; pub use datafusion_expr_common::type_coercion::binary; pub use expr::{ diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 7ac836ef3aeb..7fda92862be9 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,13 +21,14 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::utils::coerced_fixed_size_list_to_list; +use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, types::{LogicalType, NativeType}, utils::list_ndims, Result, }; +use datafusion_expr_common::signature::ArrayFunctionArgument; use datafusion_expr_common::{ signature::{ ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD, @@ -357,88 +358,81 @@ fn get_valid_types( signature: &TypeSignature, current_types: &[DataType], ) -> Result>> { - fn array_element_and_optional_index( + fn array_valid_types( function_name: &str, current_types: &[DataType], + arguments: &[ArrayFunctionArgument], + array_coercion: Option<&ListCoercion>, ) -> Result>> { - // make sure there's 2 or 3 arguments - if !(current_types.len() == 2 || current_types.len() == 3) { + if current_types.len() != arguments.len() { return Ok(vec![vec![]]); } - let first_two_types = ¤t_types[0..2]; - let mut valid_types = - array_append_or_prepend_valid_types(function_name, first_two_types, true)?; - - // Early return if there are only 2 arguments - if current_types.len() == 2 { - return Ok(valid_types); - } - - let valid_types_with_index = valid_types - .iter() - .map(|t| { - let mut t = t.clone(); - t.push(DataType::Int64); - t - }) - .collect::>(); - - valid_types.extend(valid_types_with_index); - - Ok(valid_types) - } - - fn array_append_or_prepend_valid_types( - function_name: &str, - current_types: &[DataType], - is_append: bool, - ) -> Result>> { - if current_types.len() != 2 { - return Ok(vec![vec![]]); - } - - let (array_type, elem_type) = if is_append { - (¤t_types[0], ¤t_types[1]) - } else { - (¤t_types[1], ¤t_types[0]) + let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| { + if *arg == ArrayFunctionArgument::Array { + Some(idx) + } else { + None + } + }); + let Some(array_idx) = array_idx else { + return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument")); }; - - // We follow Postgres on `array_append(Null, T)`, which is not valid. - if array_type.eq(&DataType::Null) { + let Some(array_type) = array(¤t_types[array_idx]) else { return Ok(vec![vec![]]); - } + }; // We need to find the coerced base type, mainly for cases like: // `array_append(List(null), i64)` -> `List(i64)` - let array_base_type = datafusion_common::utils::base_type(array_type); - let elem_base_type = datafusion_common::utils::base_type(elem_type); - let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); - - let new_base_type = new_base_type.ok_or_else(|| { - internal_datafusion_err!( - "Function '{function_name}' does not support coercion from {array_base_type:?} to {elem_base_type:?}" - ) - })?; - + let mut new_base_type = datafusion_common::utils::base_type(&array_type); + for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { + match argument_type { + ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => { + new_base_type = + coerce_array_types(function_name, current_type, &new_base_type)?; + } + ArrayFunctionArgument::Index => {} + } + } let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( - array_type, + &array_type, &new_base_type, + array_coercion, ); - match new_array_type { + let new_elem_type = match new_array_type { DataType::List(ref field) | DataType::LargeList(ref field) - | DataType::FixedSizeList(ref field, _) => { - let new_elem_type = field.data_type(); - if is_append { - Ok(vec![vec![new_array_type.clone(), new_elem_type.clone()]]) - } else { - Ok(vec![vec![new_elem_type.to_owned(), new_array_type.clone()]]) + | DataType::FixedSizeList(ref field, _) => field.data_type(), + _ => return Ok(vec![vec![]]), + }; + + let mut valid_types = Vec::with_capacity(arguments.len()); + for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { + let valid_type = match argument_type { + ArrayFunctionArgument::Element => new_elem_type.clone(), + ArrayFunctionArgument::Index => DataType::Int64, + ArrayFunctionArgument::Array => { + let Some(current_type) = array(current_type) else { + return Ok(vec![vec![]]); + }; + let new_type = + datafusion_common::utils::coerced_type_with_base_type_only( + ¤t_type, + &new_base_type, + array_coercion, + ); + // All array arguments must be coercible to the same type + if new_type != new_array_type { + return Ok(vec![vec![]]); + } + new_type } - } - _ => Ok(vec![vec![]]), + }; + valid_types.push(valid_type); } + + Ok(vec![valid_types]) } fn array(array_type: &DataType) -> Option { @@ -449,6 +443,20 @@ fn get_valid_types( } } + fn coerce_array_types( + function_name: &str, + current_type: &DataType, + base_type: &DataType, + ) -> Result { + let current_base_type = datafusion_common::utils::base_type(current_type); + let new_base_type = comparison_coercion(base_type, ¤t_base_type); + new_base_type.ok_or_else(|| { + internal_datafusion_err!( + "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}" + ) + }) + } + fn recursive_array(array_type: &DataType) -> Option { match array_type { DataType::List(_) @@ -693,40 +701,9 @@ fn get_valid_types( vec![current_types.to_vec()] } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], - TypeSignature::ArraySignature(ref function_signature) => match function_signature - { - ArrayFunctionSignature::ArrayAndElement => { - array_append_or_prepend_valid_types(function_name, current_types, true)? - } - ArrayFunctionSignature::ElementAndArray => { - array_append_or_prepend_valid_types(function_name, current_types, false)? - } - ArrayFunctionSignature::ArrayAndIndexes(count) => { - if current_types.len() != count.get() + 1 { - return Ok(vec![vec![]]); - } - array(¤t_types[0]).map_or_else( - || vec![vec![]], - |array_type| { - let mut inner = Vec::with_capacity(count.get() + 1); - inner.push(array_type); - for _ in 0..count.get() { - inner.push(DataType::Int64); - } - vec![inner] - }, - ) - } - ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => { - array_element_and_optional_index(function_name, current_types)? - } - ArrayFunctionSignature::Array => { - if current_types.len() != 1 { - return Ok(vec![vec![]]); - } - - array(¤t_types[0]) - .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) + TypeSignature::ArraySignature(ref function_signature) => match function_signature { + ArrayFunctionSignature::Array { arguments, array_coercion, } => { + array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())? } ArrayFunctionSignature::RecursiveArray => { if current_types.len() != 1 { diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index ad30c0b540af..886709779917 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -30,8 +30,8 @@ use datafusion_common::utils::take_function_args; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -50,7 +50,10 @@ impl Cardinality { Self { signature: Signature::one_of( vec![ - TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), ], Volatility::Immutable, diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index 14d4b958867f..f40417386944 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -26,6 +26,7 @@ use arrow::array::{ }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::ListCoercion; use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, @@ -33,7 +34,8 @@ use datafusion_common::{ utils::{list_ndims, take_function_args}, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; @@ -165,7 +167,18 @@ impl Default for ArrayPrepend { impl ArrayPrepend { pub fn new() -> Self { Self { - signature: Signature::element_and_array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Array, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![ String::from("list_prepend"), String::from("array_push_front"), diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 697c868fdea1..6bf4d16db636 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -30,17 +30,19 @@ use arrow::datatypes::{ use datafusion_common::cast::as_int64_array; use datafusion_common::cast::as_large_list_array; use datafusion_common::cast::as_list_array; +use datafusion_common::utils::ListCoercion; use datafusion_common::{ exec_err, internal_datafusion_err, plan_err, utils::take_function_args, DataFusionError, Result, }; -use datafusion_expr::{ArrayFunctionSignature, Expr, TypeSignature}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, Expr, TypeSignature, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; -use std::num::NonZeroUsize; use std::sync::Arc; use crate::utils::make_scalar_function; @@ -330,16 +332,23 @@ impl ArraySlice { Self { signature: Signature::one_of( vec![ - TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes( - NonZeroUsize::new(2).expect("2 is non-zero"), - ), - ), - TypeSignature::ArraySignature( - ArrayFunctionSignature::ArrayAndIndexes( - NonZeroUsize::new(3).expect("3 is non-zero"), - ), - ), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), ], Volatility::Immutable, ), @@ -665,7 +674,15 @@ pub(super) struct ArrayPopFront { impl ArrayPopFront { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_pop_front")], } } @@ -765,7 +782,15 @@ pub(super) struct ArrayPopBack { impl ArrayPopBack { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_pop_back")], } } diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 53f43de4108d..6d84e64cba4d 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -25,9 +25,11 @@ use arrow::datatypes::{DataType, Field}; use arrow::buffer::OffsetBuffer; use datafusion_common::cast::as_int64_array; +use datafusion_common::utils::ListCoercion; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; @@ -91,7 +93,19 @@ impl Default for ArrayReplace { impl ArrayReplace { pub fn new() -> Self { Self { - signature: Signature::any(3, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace")], } } @@ -160,7 +174,20 @@ pub(super) struct ArrayReplaceN { impl ArrayReplaceN { pub fn new() -> Self { Self { - signature: Signature::any(4, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace_n")], } } @@ -228,7 +255,19 @@ pub(super) struct ArrayReplaceAll { impl ArrayReplaceAll { pub fn new() -> Self { Self { - signature: Signature::any(3, Volatility::Immutable), + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Element, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + }, aliases: vec![String::from("list_replace_all")], } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 8f23bfe5ea65..4418d426cc02 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2656,6 +2656,29 @@ select list_push_front(1, arrow_cast(make_array(2, 3, 4), 'LargeList(Int64)')), ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] +# array_prepend scalar function #7 (element is fixed size list) +query ??? +select array_prepend(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'), make_array(arrow_cast(make_array(2), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(3), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(4), 'FixedSizeList(1, Int64)'))), + array_prepend(arrow_cast(make_array(1.0), 'FixedSizeList(1, Float64)'), make_array(arrow_cast([2.0], 'FixedSizeList(1, Float64)'), arrow_cast([3.0], 'FixedSizeList(1, Float64)'), arrow_cast([4.0], 'FixedSizeList(1, Float64)'))), + array_prepend(arrow_cast(make_array('h'), 'FixedSizeList(1, Utf8)'), make_array(arrow_cast(['e'], 'FixedSizeList(1, Utf8)'), arrow_cast(['l'], 'FixedSizeList(1, Utf8)'), arrow_cast(['l'], 'FixedSizeList(1, Utf8)'), arrow_cast(['o'], 'FixedSizeList(1, Utf8)'))); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +# TODO: https://github.com/apache/datafusion/issues/14613 +#query ??? +#select array_prepend(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)'), arrow_cast(make_array(make_array(2), make_array(3), make_array(4)), 'LargeList(FixedSizeList(1, Int64))')), +# array_prepend(arrow_cast(make_array(1.0), 'FixedSizeList(1, Float64)'), arrow_cast(make_array([2.0], [3.0], [4.0]), 'LargeList(FixedSizeList(1, Float64))')), +# array_prepend(arrow_cast(make_array('h'), 'FixedSizeList(1, Utf8)'), arrow_cast(make_array(['e'], ['l'], ['l'], ['o']), 'LargeList(FixedSizeList(1, Utf8))')); +#---- +#[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +query ??? +select array_prepend(arrow_cast([1], 'FixedSizeList(1, Int64)'), arrow_cast([[1], [2], [3]], 'FixedSizeList(3, FixedSizeList(1, Int64))')), + array_prepend(arrow_cast([1.0], 'FixedSizeList(1, Float64)'), arrow_cast([[2.0], [3.0], [4.0]], 'FixedSizeList(3, FixedSizeList(1, Float64))')), + array_prepend(arrow_cast(['h'], 'FixedSizeList(1, Utf8)'), arrow_cast([['e'], ['l'], ['l'], ['o']], 'FixedSizeList(4, FixedSizeList(1, Utf8))')); +---- +[[1], [1], [2], [3]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + # array_prepend with columns #1 query ? select array_prepend(column2, column1) from arrays_values; @@ -3563,6 +3586,17 @@ select list_replace( ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] +# array_replace scalar function #4 (null input) +query ? +select array_replace(make_array(1, 2, 3, 4, 5), NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace scalar function with columns #1 query ? select array_replace(column1, column2, column3) from arrays_with_repeating_elements; @@ -3728,6 +3762,17 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +# array_replace_n scalar function #4 (null input) +query ? +select array_replace_n(make_array(1, 2, 3, 4, 5), NULL, NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace_n(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace_n scalar function with columns #1 query ? select @@ -3904,6 +3949,17 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] +# array_replace_all scalar function #4 (null input) +query ? +select array_replace_all(make_array(1, 2, 3, 4, 5), NULL, NULL); +---- +[1, 2, 3, 4, 5] + +query ? +select array_replace_all(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, NULL); +---- +[1, 2, 3, 4, 5] + # array_replace_all scalar function with columns #1 query ? select