You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/13 19:02:36 UTC

(arrow-datafusion) branch main updated: Implement `array_union` (#7897)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new cbb2fd7846 Implement `array_union` (#7897)
cbb2fd7846 is described below

commit cbb2fd784655b24424cb811a5e215a522d1961b9
Author: Edmondo Porcu <ed...@gmail.com>
AuthorDate: Mon Nov 13 11:02:30 2023 -0800

    Implement `array_union` (#7897)
    
    * Initial implementation of array union without deduplication
    
    * Update datafusion/physical-expr/src/array_expressions.rs
    
    Co-authored-by: comphead <co...@users.noreply.github.com>
    
    * Update docs/source/user-guide/expressions.md
    
    Co-authored-by: comphead <co...@users.noreply.github.com>
    
    * Row based implementation of array_union
    
    * Added asymmetrical test
    
    * Addressing PR comments
    
    * Implementing code review feedback
    
    * Added script
    
    * Added tests for array
    
    * Additional tests
    
    * Removing spurious import from array_intersect
    
    ---------
    
    Co-authored-by: comphead <co...@users.noreply.github.com>
---
 datafusion/expr/src/built_in_function.rs          |  6 ++
 datafusion/expr/src/expr_fn.rs                    |  2 +
 datafusion/physical-expr/src/array_expressions.rs | 83 +++++++++++++++++++-
 datafusion/physical-expr/src/functions.rs         |  4 +-
 datafusion/proto/proto/datafusion.proto           |  1 +
 datafusion/proto/src/generated/pbjson.rs          |  3 +
 datafusion/proto/src/generated/prost.rs           |  3 +
 datafusion/proto/src/logical_plan/from_proto.rs   |  7 ++
 datafusion/proto/src/logical_plan/to_proto.rs     |  1 +
 datafusion/sqllogictest/test_files/array.slt      | 95 +++++++++++++++++++++++
 docs/source/user-guide/expressions.md             |  1 +
 docs/source/user-guide/sql/scalar_functions.md    | 38 +++++++++
 12 files changed, 242 insertions(+), 2 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index ca3ca18e4d..0d2c1f2e3c 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -176,6 +176,8 @@ pub enum BuiltinScalarFunction {
     ArrayToString,
     /// array_intersect
     ArrayIntersect,
+    /// array_union
+    ArrayUnion,
     /// cardinality
     Cardinality,
     /// construct an array from columns
@@ -401,6 +403,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
             BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
             BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
+            BuiltinScalarFunction::ArrayUnion => Volatility::Immutable,
             BuiltinScalarFunction::Cardinality => Volatility::Immutable,
             BuiltinScalarFunction::MakeArray => Volatility::Immutable,
             BuiltinScalarFunction::Ascii => Volatility::Immutable,
@@ -581,6 +584,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
             BuiltinScalarFunction::ArrayToString => Ok(Utf8),
             BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()),
+            BuiltinScalarFunction::ArrayUnion => Ok(input_expr_types[0].clone()),
             BuiltinScalarFunction::Cardinality => Ok(UInt64),
             BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
                 0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
@@ -885,6 +889,7 @@ impl BuiltinScalarFunction {
                 Signature::variadic_any(self.volatility())
             }
             BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
+            BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
             BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
             BuiltinScalarFunction::MakeArray => {
                 // 0 or more arguments of arbitrary type
@@ -1508,6 +1513,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
             "array_join",
             "list_join",
         ],
+        BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"],
         BuiltinScalarFunction::Cardinality => &["cardinality"],
         BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
         BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 0e0ad46da1..0d920beb41 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -717,6 +717,8 @@ scalar_expr!(
     array delimiter,
     "converts each element to its text representation."
 );
+scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates.");
+
 scalar_expr!(
     Cardinality,
     cardinality,
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 60e09c5a9c..9b074ff0ee 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type};
 use arrow::row::{RowConverter, SortField};
 use arrow_buffer::NullBuffer;
 
+use arrow_schema::FieldRef;
 use datafusion_common::cast::{
     as_generic_string_array, as_int64_array, as_list_array, as_string_array,
 };
@@ -36,8 +37,8 @@ use datafusion_common::{
     DataFusionError, Result,
 };
 
-use hashbrown::HashSet;
 use itertools::Itertools;
+use std::collections::HashSet;
 
 macro_rules! downcast_arg {
     ($ARG:expr, $ARRAY_TYPE:ident) => {{
@@ -1382,6 +1383,86 @@ macro_rules! to_string {
     }};
 }
 
+fn union_generic_lists<OffsetSize: OffsetSizeTrait>(
+    l: &GenericListArray<OffsetSize>,
+    r: &GenericListArray<OffsetSize>,
+    field: &FieldRef,
+) -> Result<GenericListArray<OffsetSize>> {
+    let converter = RowConverter::new(vec![SortField::new(l.value_type().clone())])?;
+
+    let nulls = NullBuffer::union(l.nulls(), r.nulls());
+    let l_values = l.values().clone();
+    let r_values = r.values().clone();
+    let l_values = converter.convert_columns(&[l_values])?;
+    let r_values = converter.convert_columns(&[r_values])?;
+
+    // Might be worth adding an upstream OffsetBufferBuilder
+    let mut offsets = Vec::<OffsetSize>::with_capacity(l.len() + 1);
+    offsets.push(OffsetSize::usize_as(0));
+    let mut rows = Vec::with_capacity(l_values.num_rows() + r_values.num_rows());
+    let mut dedup = HashSet::new();
+    for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) {
+        let l_slice = l_w[0].as_usize()..l_w[1].as_usize();
+        let r_slice = r_w[0].as_usize()..r_w[1].as_usize();
+        for i in l_slice {
+            let left_row = l_values.row(i);
+            if dedup.insert(left_row) {
+                rows.push(left_row);
+            }
+        }
+        for i in r_slice {
+            let right_row = r_values.row(i);
+            if dedup.insert(right_row) {
+                rows.push(right_row);
+            }
+        }
+        offsets.push(OffsetSize::usize_as(rows.len()));
+        dedup.clear();
+    }
+
+    let values = converter.convert_rows(rows)?;
+    let offsets = OffsetBuffer::new(offsets.into());
+    let result = values[0].clone();
+    Ok(GenericListArray::<OffsetSize>::new(
+        field.clone(),
+        offsets,
+        result,
+        nulls,
+    ))
+}
+
+/// Array_union SQL function
+pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 2 {
+        return exec_err!("array_union needs two arguments");
+    }
+    let array1 = &args[0];
+    let array2 = &args[1];
+    match (array1.data_type(), array2.data_type()) {
+        (DataType::Null, _) => Ok(array2.clone()),
+        (_, DataType::Null) => Ok(array1.clone()),
+        (DataType::List(field_ref), DataType::List(_)) => {
+            check_datatypes("array_union", &[&array1, &array2])?;
+            let list1 = array1.as_list::<i32>();
+            let list2 = array2.as_list::<i32>();
+            let result = union_generic_lists::<i32>(list1, list2, field_ref)?;
+            Ok(Arc::new(result))
+        }
+        (DataType::LargeList(field_ref), DataType::LargeList(_)) => {
+            check_datatypes("array_union", &[&array1, &array2])?;
+            let list1 = array1.as_list::<i64>();
+            let list2 = array2.as_list::<i64>();
+            let result = union_generic_lists::<i64>(list1, list2, field_ref)?;
+            Ok(Arc::new(result))
+        }
+        _ => {
+            internal_err!(
+                "array_union only support list with offsets of type int32 and int64"
+            )
+        }
+    }
+}
+
 /// Array_to_string SQL function
 pub fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
     let arr = &args[0];
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 9185ade313..80c0eaf054 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -407,7 +407,9 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::MakeArray => {
             Arc::new(|args| make_scalar_function(array_expressions::make_array)(args))
         }
-
+        BuiltinScalarFunction::ArrayUnion => {
+            Arc::new(|args| make_scalar_function(array_expressions::array_union)(args))
+        }
         // struct functions
         BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr),
 
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 62b226e333..793378a1ea 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -635,6 +635,7 @@ enum ScalarFunction {
   StringToArray = 117;
   ToTimestampNanos = 118;
   ArrayIntersect = 119;
+  ArrayUnion = 120;
 }
 
 message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 7602e1a366..a78da2a51c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20908,6 +20908,7 @@ impl serde::Serialize for ScalarFunction {
             Self::StringToArray => "StringToArray",
             Self::ToTimestampNanos => "ToTimestampNanos",
             Self::ArrayIntersect => "ArrayIntersect",
+            Self::ArrayUnion => "ArrayUnion",
         };
         serializer.serialize_str(variant)
     }
@@ -21039,6 +21040,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             "StringToArray",
             "ToTimestampNanos",
             "ArrayIntersect",
+            "ArrayUnion",
         ];
 
         struct GeneratedVisitor;
@@ -21199,6 +21201,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
                     "StringToArray" => Ok(ScalarFunction::StringToArray),
                     "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos),
                     "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect),
+                    "ArrayUnion" => Ok(ScalarFunction::ArrayUnion),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 825481a188..7b7b0afb92 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2562,6 +2562,7 @@ pub enum ScalarFunction {
     StringToArray = 117,
     ToTimestampNanos = 118,
     ArrayIntersect = 119,
+    ArrayUnion = 120,
 }
 impl ScalarFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2690,6 +2691,7 @@ impl ScalarFunction {
             ScalarFunction::StringToArray => "StringToArray",
             ScalarFunction::ToTimestampNanos => "ToTimestampNanos",
             ScalarFunction::ArrayIntersect => "ArrayIntersect",
+            ScalarFunction::ArrayUnion => "ArrayUnion",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2815,6 +2817,7 @@ impl ScalarFunction {
             "StringToArray" => Some(Self::StringToArray),
             "ToTimestampNanos" => Some(Self::ToTimestampNanos),
             "ArrayIntersect" => Some(Self::ArrayIntersect),
+            "ArrayUnion" => Some(Self::ArrayUnion),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 674492edef..f7e38757e9 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -484,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
             ScalarFunction::ArraySlice => Self::ArraySlice,
             ScalarFunction::ArrayToString => Self::ArrayToString,
             ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
+            ScalarFunction::ArrayUnion => Self::ArrayUnion,
             ScalarFunction::Cardinality => Self::Cardinality,
             ScalarFunction::Array => Self::MakeArray,
             ScalarFunction::NullIf => Self::NullIf,
@@ -1424,6 +1425,12 @@ pub fn parse_expr(
                 ScalarFunction::ArrayNdims => {
                     Ok(array_ndims(parse_expr(&args[0], registry)?))
                 }
+                ScalarFunction::ArrayUnion => Ok(array(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, registry))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
                 ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)),
                 ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)),
                 ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)),
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 946f2c6964..2bb7f89c7d 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1487,6 +1487,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
             BuiltinScalarFunction::ArraySlice => Self::ArraySlice,
             BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
             BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
+            BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion,
             BuiltinScalarFunction::Cardinality => Self::Cardinality,
             BuiltinScalarFunction::MakeArray => Self::Array,
             BuiltinScalarFunction::NullIf => Self::NullIf,
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 9207f0f0e3..54741afdf8 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -1919,6 +1919,101 @@ select array_to_string(make_array(), ',')
 ----
 (empty)
 
+
+## array_union (aliases: `list_union`)
+
+# array_union scalar function #1
+query ?
+select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
+----
+[1, 2, 3, 4, 5, 6]
+
+# array_union scalar function #2
+query ?
+select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
+----
+[1, 2, 3, 4, 5, 6, 7, 8]
+
+# array_union scalar function #3
+query ?
+select array_union([1,2,3], []);
+----
+[1, 2, 3]
+
+# array_union scalar function #4
+query ?
+select array_union([1, 2, 3, 4], [5, 4]);
+----
+[1, 2, 3, 4, 5]
+
+# array_union scalar function #5
+statement ok
+CREATE TABLE arrays_with_repeating_elements_for_union
+AS VALUES
+  ([1], [2]),
+  ([2, 3], [3]),
+  ([3], [3, 4])
+;
+
+query ?
+select array_union(column1, column2) from arrays_with_repeating_elements_for_union;
+----
+[1, 2]
+[2, 3]
+[3, 4]
+
+statement ok
+drop table arrays_with_repeating_elements_for_union;
+
+# array_union scalar function #6
+query ?
+select array_union([], []);
+----
+NULL
+
+# array_union scalar function #7
+query ?
+select array_union([[null]], []);
+----
+[[]]
+
+# array_union scalar function #8
+query ?
+select array_union([null], [null]);
+----
+[]
+
+# array_union scalar function #9
+query ?
+select array_union(null, []);
+----
+NULL
+
+# array_union scalar function #10
+query ?
+select array_union(null, null);
+----
+NULL
+
+# array_union scalar function #11
+query ?
+select array_union([1.2, 3.0], [1.2, 3.0, 5.7]);
+----
+[1.2, 3.0, 5.7]
+
+# array_union scalar function #12
+query ?
+select array_union(['hello'], ['hello','datafusion']);
+----
+[hello, datafusion]
+
+
+
+
+
+
+
+
 # list_to_string scalar function #4 (function alias `array_to_string`)
 query TTT
 select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|');
diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index 27384dccff..bec3ba9bb2 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -233,6 +233,7 @@ Unlike to some databases the math functions in Datafusion works the same way as
 | array_slice(array, index)             | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]`                                                                              |
 | array_to_string(array, delimiter)     | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4`                                                                        |
 | array_intersect(array1, array2)       | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]`                                       |
+| array_union(array1, array2)           | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]`                   |
 | cardinality(array)                    | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6`                                                                            |
 | make_array(value1, [value2 [, ...]])  | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]`                                                                         |
 | trim_array(array, n)                  | Deprecated                                                                                                                                                               |
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index be05084fb2..2959e82024 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -2211,6 +2211,44 @@ array_to_string(array, delimiter)
 - list_join
 - list_to_string
 
+### `array_union`
+
+Returns an array containing every element that appears in either array (the union of both arrays), without duplicates.
+
+```
+array_union(array1, array2)
+```
+
+#### Arguments
+
+- **array1**: Array expression.
+  Can be a constant, column, or function, and any combination of array operators.
+- **array2**: Array expression.
+  Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```
+❯ select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
++----------------------------------------------------+
+| array_union([1, 2, 3, 4], [5, 6, 3, 4]);           |
++----------------------------------------------------+
+| [1, 2, 3, 4, 5, 6]                                 |
++----------------------------------------------------+
+❯ select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
++----------------------------------------------------+
+| array_union([1, 2, 3, 4], [5, 6, 7, 8]);           |
++----------------------------------------------------+
+| [1, 2, 3, 4, 5, 6, 7, 8]                           |
++----------------------------------------------------+
+```
+
+---
+
+#### Aliases
+
+- list_union
+
 ### `cardinality`
 
 Returns the total number of elements in the array.