You are viewing a plain-text version of this content. The canonical link for it was provided as a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/15 12:49:44 UTC

[arrow-datafusion] branch main updated: feat: new concatenation operator for working with arrays (#6615)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6872535358 feat: new concatenation operator for working with arrays (#6615)
6872535358 is described below

commit 6872535358c2c6b70350484e6cc7c058c7a806ae
Author: Igor Izvekov <iz...@gmail.com>
AuthorDate: Thu Jun 15 15:49:38 2023 +0300

    feat: new concatenation operator for working with arrays (#6615)
    
    * feat: new concatenation operator for working with arrays
    
    * fix: array_concat
    
    * fix: cargo fmt
    
    ---------
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 .../core/tests/sqllogictests/test_files/array.slt  |  20 +++
 datafusion/expr/src/type_coercion/binary.rs        |   6 +
 datafusion/physical-expr/src/array_expressions.rs  | 193 ++++++---------------
 datafusion/physical-expr/src/expressions/binary.rs |  10 +-
 datafusion/physical-expr/src/functions.rs          |  12 +-
 5 files changed, 98 insertions(+), 143 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt
index 44453546f3..6ebde09ee8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -302,6 +302,26 @@ select array_ndims(make_array()), array_ndims(make_array(make_array()))
 ----
 1 2
 
+# array concatenate operator #1 (like array_concat scalar function)
+query ?? rowsort
+select make_array(1, 2, 3) || make_array(4, 5, 6) || make_array(7, 8, 9), make_array([1], [2]) || make_array([3], [4]);
+----
+[1, 2, 3, 4, 5, 6, 7, 8, 9] [[1], [2], [3], [4]]
+
+# array concatenate operator #2 (like array_append scalar function)
+query ??? rowsort
+select make_array(1, 2, 3) || 4, make_array(1.0, 2.0, 3.0) || 4.0, make_array('h', 'e', 'l', 'l') || 'o';
+----
+[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
+
+# array concatenate operator #3 (like array_prepend scalar function)
+query ??? rowsort
+select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_array('e', 'l', 'l', 'o');
+----
+[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
+
+# make_array
+
 query ?
 select make_array(1, 2.0)
 ----
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index f8a04de45b..7c9179b2f3 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -673,6 +673,8 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
         (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
             string_concat_internal_coercion(from_type, &LargeUtf8)
         }
+        // TODO: cast between array elements (#6558)
+        (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()),
         _ => None,
     })
 }
@@ -697,6 +699,10 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
         (LargeUtf8, Utf8) => Some(LargeUtf8),
         (Utf8, LargeUtf8) => Some(LargeUtf8),
         (LargeUtf8, LargeUtf8) => Some(LargeUtf8),
+        // TODO: cast between array elements (#6558)
+        (List(_), List(_)) => Some(lhs_type.clone()),
+        (List(_), _) => Some(lhs_type.clone()),
+        (_, List(_)) => Some(rhs_type.clone()),
         _ => None,
     }
 }
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 44b747082a..298bb66dd9 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -22,6 +22,7 @@ use arrow::buffer::Buffer;
 use arrow::compute;
 use arrow::datatypes::{DataType, Field};
 use core::any::type_name;
+use datafusion_common::cast::as_list_array;
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::ColumnarValue;
@@ -159,6 +160,7 @@ pub fn make_array(values: &[ColumnarValue]) -> Result<ColumnarValue> {
         _ => array(values),
     }
 }
+
 macro_rules! downcast_arg {
     ($ARG:expr, $ARRAY_TYPE:ident) => {{
         $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
@@ -175,19 +177,17 @@ macro_rules! append {
         let child_array =
             downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE);
         let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
-        let concat = compute::concat(&[child_array, element])?;
+        let cat = compute::concat(&[child_array, element])?;
         let mut scalars = vec![];
-        for i in 0..concat.len() {
-            scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(
-                &concat, i,
-            )?));
+        for i in 0..cat.len() {
+            scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&cat, i)?));
         }
         scalars
     }};
 }
 
 /// Array_append SQL function
-pub fn array_append(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
     if args.len() != 2 {
         return Err(DataFusionError::Internal(format!(
             "Array_append function requires two arguments, got {}",
@@ -195,24 +195,10 @@ pub fn array_append(args: &[ColumnarValue]) -> Result<ColumnarValue> {
         )));
     }
 
-    let arr = match &args[0] {
-        ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
-        ColumnarValue::Array(arr) => arr.clone(),
-    };
+    let arr = as_list_array(&args[0])?;
+    let element = &args[1];
 
-    let element = match &args[1] {
-        ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
-        _ => {
-            return Err(DataFusionError::Internal(
-                "Array_append function requires scalar element".to_string(),
-            ))
-        }
-    };
-
-    let data_type = arr.data_type();
-    let arrays = match data_type {
-        DataType::List(field) => {
-            match (field.data_type(), element.data_type()) {
+    let scalars = match (arr.value_type(), element.data_type()) {
                 (DataType::Utf8, DataType::Utf8) => append!(arr, element, StringArray),
                 (DataType::LargeUtf8, DataType::LargeUtf8) => append!(arr, element, LargeStringArray),
                 (DataType::Boolean, DataType::Boolean) => append!(arr, element, BooleanArray),
@@ -226,22 +212,15 @@ pub fn array_append(args: &[ColumnarValue]) -> Result<ColumnarValue> {
                 (DataType::UInt16, DataType::UInt16) => append!(arr, element, UInt16Array),
                 (DataType::UInt32, DataType::UInt32) => append!(arr, element, UInt32Array),
                 (DataType::UInt64, DataType::UInt64) => append!(arr, element, UInt64Array),
-                (DataType::Null, _) => return array(&args[1..]),
+                (DataType::Null, _) => return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1)),
                 (array_data_type, element_data_type) => {
                     return Err(DataFusionError::NotImplemented(format!(
                         "Array_append is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'."
                     )))
                 }
-            }
-        }
-        data_type => {
-            return Err(DataFusionError::Internal(format!(
-                "Array is not type '{data_type:?}'."
-            )))
-        }
     };
 
-    array(arrays.as_slice())
+    Ok(array(scalars.as_slice())?.into_array(1))
 }
 
 macro_rules! prepend {
@@ -249,19 +228,17 @@ macro_rules! prepend {
         let child_array =
             downcast_arg!(downcast_arg!($ARRAY, ListArray).values(), $ARRAY_TYPE);
         let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
-        let concat = compute::concat(&[element, child_array])?;
+        let cat = compute::concat(&[element, child_array])?;
         let mut scalars = vec![];
-        for i in 0..concat.len() {
-            scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(
-                &concat, i,
-            )?));
+        for i in 0..cat.len() {
+            scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&cat, i)?));
         }
         scalars
     }};
 }
 
 /// Array_prepend SQL function
-pub fn array_prepend(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
     if args.len() != 2 {
         return Err(DataFusionError::Internal(format!(
             "Array_prepend function requires two arguments, got {}",
@@ -269,24 +246,10 @@ pub fn array_prepend(args: &[ColumnarValue]) -> Result<ColumnarValue> {
         )));
     }
 
-    let element = match &args[0] {
-        ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
-        _ => {
-            return Err(DataFusionError::Internal(
-                "Array_prepend function requires scalar element".to_string(),
-            ))
-        }
-    };
-
-    let arr = match &args[1] {
-        ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
-        ColumnarValue::Array(arr) => arr.clone(),
-    };
+    let element = &args[0];
+    let arr = as_list_array(&args[1])?;
 
-    let data_type = arr.data_type();
-    let arrays = match data_type {
-        DataType::List(field) => {
-            match (field.data_type(), element.data_type()) {
+    let scalars = match (arr.value_type(), element.data_type()) {
                 (DataType::Utf8, DataType::Utf8) => prepend!(arr, element, StringArray),
                 (DataType::LargeUtf8, DataType::LargeUtf8) => prepend!(arr, element, LargeStringArray),
                 (DataType::Boolean, DataType::Boolean) => prepend!(arr, element, BooleanArray),
@@ -300,39 +263,24 @@ pub fn array_prepend(args: &[ColumnarValue]) -> Result<ColumnarValue> {
                 (DataType::UInt16, DataType::UInt16) => prepend!(arr, element, UInt16Array),
                 (DataType::UInt32, DataType::UInt32) => prepend!(arr, element, UInt32Array),
                 (DataType::UInt64, DataType::UInt64) => prepend!(arr, element, UInt64Array),
-                (DataType::Null, _) => return array(&args[..1]),
+                (DataType::Null, _) => return Ok(array(&[ColumnarValue::Array(args[0].clone())])?.into_array(1)),
                 (array_data_type, element_data_type) => {
                     return Err(DataFusionError::NotImplemented(format!(
                         "Array_prepend is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'."
                     )))
                 }
-            }
-        }
-        data_type => {
-            return Err(DataFusionError::Internal(format!(
-                "Array is not type '{data_type:?}'."
-            )))
-        }
     };
 
-    array(arrays.as_slice())
+    Ok(array(scalars.as_slice())?.into_array(1))
 }
 
 /// Array_concat/Array_cat SQL function
-pub fn array_concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    let arrays: Vec<ArrayRef> = args
-        .iter()
-        .map(|x| match x {
-            ColumnarValue::Array(array) => array.clone(),
-            ColumnarValue::Scalar(scalar) => scalar.to_array().clone(),
-        })
-        .collect();
-    let data_type = arrays[0].data_type();
-    match data_type {
+pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args[0].data_type() {
         DataType::List(field) => match field.data_type() {
             DataType::Null => array_concat(&args[1..]),
             _ => {
-                let list_arrays = downcast_vec!(arrays, ListArray)
+                let list_arrays = downcast_vec!(args, ListArray)
                     .collect::<Result<Vec<&ListArray>>>()?;
                 let len: usize = list_arrays.iter().map(|a| a.values().len()).sum();
                 let capacity =
@@ -354,12 +302,10 @@ pub fn array_concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
                     .build()
                     .unwrap();
 
-                return Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
-                    list,
-                ))));
+                return Ok(Arc::new(arrow::array::make_array(list)));
             }
         },
-        _ => Err(DataFusionError::NotImplemented(format!(
+        data_type => Err(DataFusionError::NotImplemented(format!(
             "Array is not type '{data_type:?}'."
         ))),
     }
@@ -1128,6 +1074,7 @@ pub fn array_ndims(args: &[ColumnarValue]) -> Result<ColumnarValue> {
 mod tests {
     use super::*;
     use arrow::array::UInt8Array;
+    use arrow::datatypes::Int64Type;
     use datafusion_common::cast::{
         as_generic_string_array, as_list_array, as_uint64_array, as_uint8_array,
     };
@@ -1193,21 +1140,15 @@ mod tests {
     #[test]
     fn test_array_append() {
         // array_append([1, 2, 3], 4) = [1, 2, 3, 4]
-        let args = [
-            ColumnarValue::Scalar(ScalarValue::List(
-                Some(vec![
-                    ScalarValue::Int64(Some(1)),
-                    ScalarValue::Int64(Some(2)),
-                    ScalarValue::Int64(Some(3)),
-                ]),
-                Arc::new(Field::new("item", DataType::Int64, false)),
-            )),
-            ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
-        ];
+        let data = vec![Some(vec![Some(1), Some(2), Some(3)])];
+        let list_array =
+            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
+        let int64_array = Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef;
 
-        let array = array_append(&args)
-            .expect("failed to initialize function array_append")
-            .into_array(1);
+        let args = [list_array, int64_array];
+
+        let array =
+            array_append(&args).expect("failed to initialize function array_append");
         let result =
             as_list_array(&array).expect("failed to initialize function array_append");
 
@@ -1225,21 +1166,15 @@ mod tests {
     #[test]
     fn test_array_prepend() {
         // array_prepend(1, [2, 3, 4]) = [1, 2, 3, 4]
-        let args = [
-            ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
-            ColumnarValue::Scalar(ScalarValue::List(
-                Some(vec![
-                    ScalarValue::Int64(Some(2)),
-                    ScalarValue::Int64(Some(3)),
-                    ScalarValue::Int64(Some(4)),
-                ]),
-                Arc::new(Field::new("item", DataType::Int64, false)),
-            )),
-        ];
+        let data = vec![Some(vec![Some(2), Some(3), Some(4)])];
+        let list_array =
+            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
+        let int64_array = Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef;
 
-        let array = array_prepend(&args)
-            .expect("failed to initialize function array_append")
-            .into_array(1);
+        let args = [int64_array, list_array];
+
+        let array =
+            array_prepend(&args).expect("failed to initialize function array_append");
         let result =
             as_list_array(&array).expect("failed to initialize function array_append");
 
@@ -1257,36 +1192,20 @@ mod tests {
     #[test]
     fn test_array_concat() {
         // array_concat([1, 2, 3], [4, 5, 6], [7, 8, 9]) = [1, 2, 3, 4, 5, 6, 7, 8, 9]
-        let args = [
-            ColumnarValue::Scalar(ScalarValue::List(
-                Some(vec![
-                    ScalarValue::Int64(Some(1)),
-                    ScalarValue::Int64(Some(2)),
-                    ScalarValue::Int64(Some(3)),
-                ]),
-                Arc::new(Field::new("item", DataType::Int64, false)),
-            )),
-            ColumnarValue::Scalar(ScalarValue::List(
-                Some(vec![
-                    ScalarValue::Int64(Some(4)),
-                    ScalarValue::Int64(Some(5)),
-                    ScalarValue::Int64(Some(6)),
-                ]),
-                Arc::new(Field::new("item", DataType::Int64, false)),
-            )),
-            ColumnarValue::Scalar(ScalarValue::List(
-                Some(vec![
-                    ScalarValue::Int64(Some(7)),
-                    ScalarValue::Int64(Some(8)),
-                    ScalarValue::Int64(Some(9)),
-                ]),
-                Arc::new(Field::new("item", DataType::Int64, false)),
-            )),
-        ];
-
-        let array = array_concat(&args)
-            .expect("failed to initialize function array_concat")
-            .into_array(1);
+        let data = vec![Some(vec![Some(1), Some(2), Some(3)])];
+        let list_array1 =
+            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
+        let data = vec![Some(vec![Some(4), Some(5), Some(6)])];
+        let list_array2 =
+            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
+        let data = vec![Some(vec![Some(7), Some(8), Some(9)])];
+        let list_array3 =
+            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(data)) as ArrayRef;
+
+        let args = [list_array1, list_array2, list_array3];
+
+        let array =
+            array_concat(&args).expect("failed to initialize function array_concat");
         let result =
             as_list_array(&array).expect("failed to initialize function array_concat");
 
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index ab0c68deca..8e9e361596 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -82,6 +82,7 @@ use self::kernels_arrow::{
 };
 
 use super::column::Column;
+use crate::array_expressions::{array_append, array_concat, array_prepend};
 use crate::expressions::cast_column;
 use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
 use crate::intervals::{apply_operator, Interval};
@@ -1257,9 +1258,12 @@ impl BinaryExpr {
             BitwiseXor => bitwise_xor_dyn(left, right),
             BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
             BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
-            StringConcat => {
-                binary_string_array_op!(left, right, concat_elements)
-            }
+            StringConcat => match (left_data_type, right_data_type) {
+                (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]),
+                (DataType::List(_), _) => array_append(&[left, right]),
+                (_, DataType::List(_)) => array_prepend(&[left, right]),
+                _ => binary_string_array_op!(left, right, concat_elements),
+            },
         }
     }
 }
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index c45986eb8a..37dd492e9e 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -385,8 +385,12 @@ pub fn create_physical_fun(
         }
 
         // array functions
-        BuiltinScalarFunction::ArrayAppend => Arc::new(array_expressions::array_append),
-        BuiltinScalarFunction::ArrayConcat => Arc::new(array_expressions::array_concat),
+        BuiltinScalarFunction::ArrayAppend => {
+            Arc::new(|args| make_scalar_function(array_expressions::array_append)(args))
+        }
+        BuiltinScalarFunction::ArrayConcat => {
+            Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args))
+        }
         BuiltinScalarFunction::ArrayDims => Arc::new(array_expressions::array_dims),
         BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill),
         BuiltinScalarFunction::ArrayLength => Arc::new(array_expressions::array_length),
@@ -397,7 +401,9 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::ArrayPositions => {
             Arc::new(array_expressions::array_positions)
         }
-        BuiltinScalarFunction::ArrayPrepend => Arc::new(array_expressions::array_prepend),
+        BuiltinScalarFunction::ArrayPrepend => {
+            Arc::new(|args| make_scalar_function(array_expressions::array_prepend)(args))
+        }
         BuiltinScalarFunction::ArrayRemove => Arc::new(array_expressions::array_remove),
         BuiltinScalarFunction::ArrayReplace => Arc::new(array_expressions::array_replace),
         BuiltinScalarFunction::ArrayToString => {