You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/13 19:02:36 UTC
(arrow-datafusion) branch main updated: Implement `array_union` (#7897)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new cbb2fd7846 Implement `array_union` (#7897)
cbb2fd7846 is described below
commit cbb2fd784655b24424cb811a5e215a522d1961b9
Author: Edmondo Porcu <ed...@gmail.com>
AuthorDate: Mon Nov 13 11:02:30 2023 -0800
Implement `array_union` (#7897)
* Initial implementation of array union without deduplication
* Update datafusion/physical-expr/src/array_expressions.rs
Co-authored-by: comphead <co...@users.noreply.github.com>
* Update docs/source/user-guide/expressions.md
Co-authored-by: comphead <co...@users.noreply.github.com>
* Row based implementation of array_union
* Added asymmetrical test
* Addressing PR comments
* Implementing code review feedback
* Added script
* Added tests for array
* Additional tests
* Removing spurious import from array_intersect
---------
Co-authored-by: comphead <co...@users.noreply.github.com>
---
datafusion/expr/src/built_in_function.rs | 6 ++
datafusion/expr/src/expr_fn.rs | 2 +
datafusion/physical-expr/src/array_expressions.rs | 83 +++++++++++++++++++-
datafusion/physical-expr/src/functions.rs | 4 +-
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 +
datafusion/proto/src/generated/prost.rs | 3 +
datafusion/proto/src/logical_plan/from_proto.rs | 7 ++
datafusion/proto/src/logical_plan/to_proto.rs | 1 +
datafusion/sqllogictest/test_files/array.slt | 95 +++++++++++++++++++++++
docs/source/user-guide/expressions.md | 1 +
docs/source/user-guide/sql/scalar_functions.md | 38 +++++++++
12 files changed, 242 insertions(+), 2 deletions(-)
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index ca3ca18e4d..0d2c1f2e3c 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -176,6 +176,8 @@ pub enum BuiltinScalarFunction {
ArrayToString,
/// array_intersect
ArrayIntersect,
+ /// array_union
+ ArrayUnion,
/// cardinality
Cardinality,
/// construct an array from columns
@@ -401,6 +403,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
+ BuiltinScalarFunction::ArrayUnion => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
@@ -581,6 +584,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()),
+ BuiltinScalarFunction::ArrayUnion => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::Cardinality => Ok(UInt64),
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
@@ -885,6 +889,7 @@ impl BuiltinScalarFunction {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
+ BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::MakeArray => {
// 0 or more arguments of arbitrary type
@@ -1508,6 +1513,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
"array_join",
"list_join",
],
+ BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"],
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 0e0ad46da1..0d920beb41 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -717,6 +717,8 @@ scalar_expr!(
array delimiter,
"converts each element to its text representation."
);
+scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates.");
+
scalar_expr!(
Cardinality,
cardinality,
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 60e09c5a9c..9b074ff0ee 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::NullBuffer;
+use arrow_schema::FieldRef;
use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_list_array, as_string_array,
};
@@ -36,8 +37,8 @@ use datafusion_common::{
DataFusionError, Result,
};
-use hashbrown::HashSet;
use itertools::Itertools;
+use std::collections::HashSet;
macro_rules! downcast_arg {
($ARG:expr, $ARRAY_TYPE:ident) => {{
@@ -1382,6 +1383,86 @@ macro_rules! to_string {
}};
}
+/// Computes the row-wise union of two list arrays, keeping first-seen order.
+///
+/// For every row, the elements of `l` followed by the elements of `r` are
+/// emitted, skipping any element already produced for that row. `field` becomes
+/// the element field of the result, whose validity buffer is
+/// `NullBuffer::union` of the two inputs' null buffers.
+fn union_generic_lists<OffsetSize: OffsetSizeTrait>(
+    l: &GenericListArray<OffsetSize>,
+    r: &GenericListArray<OffsetSize>,
+    field: &FieldRef,
+) -> Result<GenericListArray<OffsetSize>> {
+    // Row encoding gives every element a hashable, comparable representation.
+    let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;
+    let left_rows = converter.convert_columns(&[l.values().clone()])?;
+    let right_rows = converter.convert_columns(&[r.values().clone()])?;
+
+    let mut merged = Vec::with_capacity(left_rows.num_rows() + right_rows.num_rows());
+    let mut seen = HashSet::new();
+    // Might be worth adding an upstream OffsetBufferBuilder
+    let mut offsets = Vec::<OffsetSize>::with_capacity(l.len() + 1);
+    offsets.push(OffsetSize::usize_as(0));
+
+    for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) {
+        let from_left =
+            (l_w[0].as_usize()..l_w[1].as_usize()).map(|i| left_rows.row(i));
+        let from_right =
+            (r_w[0].as_usize()..r_w[1].as_usize()).map(|i| right_rows.row(i));
+        for row in from_left.chain(from_right) {
+            // `insert` returns false for an element already emitted in this row.
+            if seen.insert(row) {
+                merged.push(row);
+            }
+        }
+        offsets.push(OffsetSize::usize_as(merged.len()));
+        // Deduplication is scoped to a single row of the list array.
+        seen.clear();
+    }
+
+    // Decode the surviving rows back into a single Arrow array.
+    let values = converter.convert_rows(merged)?;
+    Ok(GenericListArray::<OffsetSize>::new(
+        field.clone(),
+        OffsetBuffer::new(offsets.into()),
+        values[0].clone(),
+        NullBuffer::union(l.nulls(), r.nulls()),
+    ))
+}
+
+/// Array_union SQL function
+///
+/// Requires exactly two arguments. A `Null`-typed argument yields the other one
+/// unchanged; `List`/`LargeList` pairs are merged by `union_generic_lists`.
+pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let (lhs, rhs) = match args {
+        [lhs, rhs] => (lhs, rhs),
+        _ => return exec_err!("array_union needs two arguments"),
+    };
+    match (lhs.data_type(), rhs.data_type()) {
+        (DataType::Null, _) => Ok(Arc::clone(rhs)),
+        (_, DataType::Null) => Ok(Arc::clone(lhs)),
+        (DataType::List(field), DataType::List(_)) => {
+            check_datatypes("array_union", &[lhs, rhs])?;
+            let left = lhs.as_list::<i32>();
+            let right = rhs.as_list::<i32>();
+            Ok(Arc::new(union_generic_lists::<i32>(left, right, field)?))
+        }
+        (DataType::LargeList(field), DataType::LargeList(_)) => {
+            check_datatypes("array_union", &[lhs, rhs])?;
+            let left = lhs.as_list::<i64>();
+            let right = rhs.as_list::<i64>();
+            Ok(Arc::new(union_generic_lists::<i64>(left, right, field)?))
+        }
+        _ => {
+            internal_err!(
+                "array_union only support list with offsets of type int32 and int64"
+            )
+        }
+    }
+}
+
/// Array_to_string SQL function
pub fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
let arr = &args[0];
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 9185ade313..80c0eaf054 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -407,7 +407,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::MakeArray => {
Arc::new(|args| make_scalar_function(array_expressions::make_array)(args))
}
-
+ BuiltinScalarFunction::ArrayUnion => {
+ Arc::new(|args| make_scalar_function(array_expressions::array_union)(args))
+ }
// struct functions
BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr),
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 62b226e333..793378a1ea 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -635,6 +635,7 @@ enum ScalarFunction {
StringToArray = 117;
ToTimestampNanos = 118;
ArrayIntersect = 119;
+ ArrayUnion = 120;
}
message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 7602e1a366..a78da2a51c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20908,6 +20908,7 @@ impl serde::Serialize for ScalarFunction {
Self::StringToArray => "StringToArray",
Self::ToTimestampNanos => "ToTimestampNanos",
Self::ArrayIntersect => "ArrayIntersect",
+ Self::ArrayUnion => "ArrayUnion",
};
serializer.serialize_str(variant)
}
@@ -21039,6 +21040,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"StringToArray",
"ToTimestampNanos",
"ArrayIntersect",
+ "ArrayUnion",
];
struct GeneratedVisitor;
@@ -21199,6 +21201,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"StringToArray" => Ok(ScalarFunction::StringToArray),
"ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos),
"ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect),
+ "ArrayUnion" => Ok(ScalarFunction::ArrayUnion),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 825481a188..7b7b0afb92 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2562,6 +2562,7 @@ pub enum ScalarFunction {
StringToArray = 117,
ToTimestampNanos = 118,
ArrayIntersect = 119,
+ ArrayUnion = 120,
}
impl ScalarFunction {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2690,6 +2691,7 @@ impl ScalarFunction {
ScalarFunction::StringToArray => "StringToArray",
ScalarFunction::ToTimestampNanos => "ToTimestampNanos",
ScalarFunction::ArrayIntersect => "ArrayIntersect",
+ ScalarFunction::ArrayUnion => "ArrayUnion",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2815,6 +2817,7 @@ impl ScalarFunction {
"StringToArray" => Some(Self::StringToArray),
"ToTimestampNanos" => Some(Self::ToTimestampNanos),
"ArrayIntersect" => Some(Self::ArrayIntersect),
+ "ArrayUnion" => Some(Self::ArrayUnion),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 674492edef..f7e38757e9 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -484,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArraySlice => Self::ArraySlice,
ScalarFunction::ArrayToString => Self::ArrayToString,
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
+ ScalarFunction::ArrayUnion => Self::ArrayUnion,
ScalarFunction::Cardinality => Self::Cardinality,
ScalarFunction::Array => Self::MakeArray,
ScalarFunction::NullIf => Self::NullIf,
@@ -1424,6 +1425,12 @@ pub fn parse_expr(
ScalarFunction::ArrayNdims => {
Ok(array_ndims(parse_expr(&args[0], registry)?))
}
+ ScalarFunction::ArrayUnion => Ok(array(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, registry))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry)?)),
ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry)?)),
ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry)?)),
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 946f2c6964..2bb7f89c7d 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1487,6 +1487,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArraySlice => Self::ArraySlice,
BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
+ BuiltinScalarFunction::ArrayUnion => Self::ArrayUnion,
BuiltinScalarFunction::Cardinality => Self::Cardinality,
BuiltinScalarFunction::MakeArray => Self::Array,
BuiltinScalarFunction::NullIf => Self::NullIf,
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 9207f0f0e3..54741afdf8 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -1919,6 +1919,101 @@ select array_to_string(make_array(), ',')
----
(empty)
+
+## array_union (aliases: `list_union`)
+
+# array_union scalar function #1
+query ?
+select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
+----
+[1, 2, 3, 4, 5, 6]
+
+# array_union scalar function #2
+query ?
+select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
+----
+[1, 2, 3, 4, 5, 6, 7, 8]
+
+# array_union scalar function #3
+query ?
+select array_union([1,2,3], []);
+----
+[1, 2, 3]
+
+# array_union scalar function #4
+query ?
+select array_union([1, 2, 3, 4], [5, 4]);
+----
+[1, 2, 3, 4, 5]
+
+# array_union scalar function #5
+statement ok
+CREATE TABLE arrays_with_repeating_elements_for_union
+AS VALUES
+ ([1], [2]),
+ ([2, 3], [3]),
+ ([3], [3, 4])
+;
+
+query ?
+select array_union(column1, column2) from arrays_with_repeating_elements_for_union;
+----
+[1, 2]
+[2, 3]
+[3, 4]
+
+statement ok
+drop table arrays_with_repeating_elements_for_union;
+
+# array_union scalar function #6
+query ?
+select array_union([], []);
+----
+NULL
+
+# array_union scalar function #7
+query ?
+select array_union([[null]], []);
+----
+[[]]
+
+# array_union scalar function #8
+query ?
+select array_union([null], [null]);
+----
+[]
+
+# array_union scalar function #9
+query ?
+select array_union(null, []);
+----
+NULL
+
+# array_union scalar function #10
+query ?
+select array_union(null, null);
+----
+NULL
+
+# array_union scalar function #11
+query ?
+select array_union([1.2, 3.0], [1.2, 3.0, 5.7]);
+----
+[1.2, 3.0, 5.7]
+
+# array_union scalar function #12
+query ?
+select array_union(['hello'], ['hello','datafusion']);
+----
+[hello, datafusion]
+
+
+
+
+
+
+
+
# list_to_string scalar function #4 (function alias `array_to_string`)
query TTT
select list_to_string(['h', 'e', 'l', 'l', 'o'], ','), list_to_string([1, 2, 3, 4, 5], '-'), list_to_string([1.0, 2.0, 3.0], '|');
diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index 27384dccff..bec3ba9bb2 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -233,6 +233,7 @@ Unlike to some databases the math functions in Datafusion works the same way as
| array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` |
| array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` |
| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` |
+| array_union(array1, array2) | Returns an array of the elements in the union of array1 and array2 without duplicates. `array_union([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2, 3, 4, 5, 6]` |
| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` |
| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` |
| trim_array(array, n) | Deprecated |
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index be05084fb2..2959e82024 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -2211,6 +2211,44 @@ array_to_string(array, delimiter)
- list_join
- list_to_string
+### `array_union`
+
+Returns an array of the elements in the union of array1 and array2 (all elements from both arrays) without duplicates.
+
+```
+array_union(array1, array2)
+```
+
+#### Arguments
+
+- **array1**: Array expression.
+ Can be a constant, column, or function, and any combination of array operators.
+- **array2**: Array expression.
+ Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```
+❯ select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
++----------------------------------------------------+
+| array_union([1, 2, 3, 4], [5, 6, 3, 4]); |
++----------------------------------------------------+
+| [1, 2, 3, 4, 5, 6] |
++----------------------------------------------------+
+❯ select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
++----------------------------------------------------+
+| array_union([1, 2, 3, 4], [5, 6, 7, 8]); |
++----------------------------------------------------+
+| [1, 2, 3, 4, 5, 6, 7, 8] |
++----------------------------------------------------+
+```
+
+---
+
+#### Aliases
+
+- list_union
+
### `cardinality`
Returns the total number of elements in the array.