You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/12/18 20:26:09 UTC

(arrow-datafusion) branch main updated: support LargeList in array_element (#8570)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d33ca4dd37 support LargeList in array_element (#8570)
d33ca4dd37 is described below

commit d33ca4dd37b8b47120579b7c3e0456c1fcbcb06f
Author: Alex Huang <hu...@gmail.com>
AuthorDate: Mon Dec 18 21:26:02 2023 +0100

    support LargeList in array_element (#8570)
---
 datafusion/expr/src/built_in_function.rs          |  3 +-
 datafusion/physical-expr/src/array_expressions.rs | 82 ++++++++++++++++-------
 datafusion/sqllogictest/test_files/array.slt      | 72 +++++++++++++++++++-
 3 files changed, 130 insertions(+), 27 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 289704ed98..3818e8ee56 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -591,8 +591,9 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
             BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
                 List(field) => Ok(field.data_type().clone()),
+                LargeList(field) => Ok(field.data_type().clone()),
                 _ => plan_err!(
-                    "The {self} function can only accept list as the first argument"
+                    "The {self} function can only accept list or largelist as the first argument"
                 ),
             },
             BuiltinScalarFunction::ArrayLength => Ok(UInt64),
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index cc4b2899fc..d396581083 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -370,18 +370,14 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
     }
 }
 
-/// array_element SQL function
-///
-/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index.
-/// `array_element(array, index)`
-///
-/// For example:
-/// > array_element(\[1, 2, 3], 2) -> 2
-pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let list_array = as_list_array(&args[0])?;
-    let indexes = as_int64_array(&args[1])?;
-
-    let values = list_array.values();
+fn general_array_element<O: OffsetSizeTrait>(
+    array: &GenericListArray<O>,
+    indexes: &Int64Array,
+) -> Result<ArrayRef>
+where
+    i64: TryInto<O>,
+{
+    let values = array.values();
     let original_data = values.to_data();
     let capacity = Capacities::Array(original_data.len());
 
@@ -389,37 +385,47 @@ pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
     let mut mutable =
         MutableArrayData::with_capacities(vec![&original_data], true, capacity);
 
-    fn adjusted_array_index(index: i64, len: usize) -> Option<i64> {
+    fn adjusted_array_index<O: OffsetSizeTrait>(index: i64, len: O) -> Result<Option<O>>
+    where
+        i64: TryInto<O>,
+    {
+        let index: O = index.try_into().map_err(|_| {
+            DataFusionError::Execution(format!(
+                "array_element got invalid index: {}",
+                index
+            ))
+        })?;
         // 0 ~ len - 1
-        let adjusted_zero_index = if index < 0 {
-            index + len as i64
+        let adjusted_zero_index = if index < O::usize_as(0) {
+            index + len
         } else {
-            index - 1
+            index - O::usize_as(1)
         };
 
-        if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 {
-            Some(adjusted_zero_index)
+        if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len {
+            Ok(Some(adjusted_zero_index))
         } else {
             // Out of bounds
-            None
+            Ok(None)
         }
     }
 
-    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
-        let start = offset_window[0] as usize;
-        let end = offset_window[1] as usize;
+    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
+        let start = offset_window[0];
+        let end = offset_window[1];
         let len = end - start;
 
         // array is null
-        if len == 0 {
+        if len == O::usize_as(0) {
             mutable.extend_nulls(1);
             continue;
         }
 
-        let index = adjusted_array_index(indexes.value(row_index), len);
+        let index = adjusted_array_index::<O>(indexes.value(row_index), len)?;
 
         if let Some(index) = index {
-            mutable.extend(0, start + index as usize, start + index as usize + 1);
+            let start = start.as_usize() + index.as_usize();
+            mutable.extend(0, start, start + 1_usize);
         } else {
             // Index out of bounds
             mutable.extend_nulls(1);
@@ -430,6 +436,32 @@ pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
     Ok(arrow_array::make_array(data))
 }
 
+/// array_element SQL function
+///
+/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index.
+/// `array_element(array, index)`
+///
+/// For example:
+/// > array_element(\[1, 2, 3], 2) -> 2
+pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match &args[0].data_type() {
+        DataType::List(_) => {
+            let array = as_list_array(&args[0])?;
+            let indexes = as_int64_array(&args[1])?;
+            general_array_element::<i32>(array, indexes)
+        }
+        DataType::LargeList(_) => {
+            let array = as_large_list_array(&args[0])?;
+            let indexes = as_int64_array(&args[1])?;
+            general_array_element::<i64>(array, indexes)
+        }
+        _ => not_impl_err!(
+            "array_element does not support type: {:?}",
+            args[0].data_type()
+        ),
+    }
+}
+
 fn general_except<OffsetSize: OffsetSizeTrait>(
     l: &GenericListArray<OffsetSize>,
     r: &GenericListArray<OffsetSize>,
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index d148f71181..b38f73ecb8 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -717,7 +717,7 @@ from arrays_values_without_nulls;
 ## array_element (aliases: array_extract, list_extract, list_element)
 
 # array_element error
-query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument
+query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument
 select array_element(1, 2);
 
 
@@ -727,58 +727,106 @@ select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h'
 ----
 2 l
 
+query IT
+select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
+----
+2 l
+
 # array_element scalar function #2 (with positive index; out of bounds)
 query IT
 select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11);
 ----
 NULL NULL
 
+query IT
+select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 11);
+----
+NULL NULL
+
 # array_element scalar function #3 (with zero)
 query IT
 select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0);
 ----
 NULL NULL
 
+query IT
+select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0);
+----
+NULL NULL
+
 # array_element scalar function #4 (with NULL)
 query error
 select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL);
 
+query error
+select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL);
+
 # array_element scalar function #5 (with negative index)
 query IT
 select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3);
 ----
 4 l
 
+query IT
+select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3);
+----
+4 l
+
 # array_element scalar function #6 (with negative index; out of bounds)
 query IT
 select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7);
 ----
 NULL NULL
 
+query IT
+select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7);
+----
+NULL NULL
+
 # array_element scalar function #7 (nested array)
 query ?
 select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1);
 ----
 [1, 2, 3, 4, 5]
 
+query ?
+select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1);
+----
+[1, 2, 3, 4, 5]
+
 # array_extract scalar function #8 (function alias `array_slice`)
 query IT
 select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3);
 ----
 2 l
 
+query IT
+select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
+----
+2 l
+
 # list_element scalar function #9 (function alias `array_slice`)
 query IT
 select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3);
 ----
 2 l
 
+query IT
+select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
+----
+2 l
+
 # list_extract scalar function #10 (function alias `array_slice`)
 query IT
 select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3);
 ----
 2 l
 
+query IT
+select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
+----
+2 l
+
 # array_element with columns
 query I
 select array_element(column1, column2) from slices;
@@ -791,6 +839,17 @@ NULL
 NULL
 55
 
+query I
+select array_element(arrow_cast(column1, 'LargeList(Int64)'), column2) from slices;
+----
+NULL
+12
+NULL
+37
+NULL
+NULL
+55
+
 # array_element with columns and scalars
 query II
 select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices;
@@ -803,6 +862,17 @@ NULL 23
 NULL 43
 5 NULL
 
+query II
+select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_element(arrow_cast(column1, 'LargeList(Int64)'), 3) from slices;
+----
+1 3
+2 13
+NULL 23
+2 33
+4 NULL
+NULL 43
+5 NULL
+
 ## array_pop_back (aliases: `list_pop_back`)
 
 # array_pop_back scalar function #1