Posted to commits@spark.apache.org by ue...@apache.org on 2018/07/12 08:42:36 UTC
spark git commit: [SPARK-23914][SQL] Add array_union function
Repository: spark
Updated Branches:
refs/heads/master 5ad4735bd -> 301bff706
[SPARK-23914][SQL] Add array_union function
## What changes were proposed in this pull request?
The PR adds the SQL function `array_union`, whose behavior is based on Presto's function of the same name.
This function returns an array of the elements in the union of array1 and array2.
Note: The order of elements in the result is not defined.
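For illustration, a minimal DataFrame API sketch (a hypothetical spark-shell session on Spark 2.4+, with `spark.implicits._` in scope; output abbreviated):

    import org.apache.spark.sql.functions.array_union

    val df = Seq((Seq(1, 2, 3), Seq(1, 3, 5))).toDF("a", "b")
    df.select(array_union($"a", $"b")).show()
    // [1, 2, 3, 5]  <- duplicates dropped; element order is not guaranteed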
## How was this patch tested?
Added unit tests.
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Closes #21061 from kiszk/SPARK-23914.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/301bff70
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/301bff70
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/301bff70
Branch: refs/heads/master
Commit: 301bff70637983426d76b106b7c659c1f28ed7bf
Parents: 5ad4735
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Thu Jul 12 17:42:29 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Thu Jul 12 17:42:29 2018 +0900
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 19 ++
.../catalyst/expressions/UnsafeArrayData.java | 19 +-
.../catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/collectionOperations.scala | 319 +++++++++++++++++++
.../CollectionExpressionsSuite.scala | 81 +++++
.../scala/org/apache/spark/sql/functions.scala | 11 +
.../spark/sql/DataFrameFunctionsSuite.scala | 52 +++
7 files changed, 499 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/301bff70/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9f61e29..5ef7398 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2033,6 +2033,25 @@ def array_distinct(col):
return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))
+@ignore_unicode_prefix
+@since(2.4)
+def array_union(col1, col2):
+ """
+ Collection function: returns an array of the elements in the union of col1 and col2,
+ without duplicates.
+
+ :param col1: name of column containing array
+ :param col2: name of column containing array
+
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+ >>> df.select(array_union(df.c1, df.c2)).collect()
+ [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2)))
+
+
@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
http://git-wip-us.apache.org/repos/asf/spark/blob/301bff70/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 4dd2b73..cf2a5ed 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -450,7 +450,7 @@ public final class UnsafeArrayData extends ArrayData {
return values;
}
- private static UnsafeArrayData fromPrimitiveArray(
+ public static UnsafeArrayData fromPrimitiveArray(
Object arr, int offset, int length, int elementSize) {
final long headerInBytes = calculateHeaderPortionInBytes(length);
final long valueRegionInBytes = (long)elementSize * length;
@@ -463,14 +463,27 @@ public final class UnsafeArrayData extends ArrayData {
final long[] data = new long[(int)totalSizeInLongs];
Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
- Platform.copyMemory(arr, offset, data,
- Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
+ if (arr != null) {
+ Platform.copyMemory(arr, offset, data,
+ Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
+ }
UnsafeArrayData result = new UnsafeArrayData();
result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
return result;
}
+ public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) {
+ return fromPrimitiveArray(null, offset, length, elementSize);
+ }
+
+ public static boolean shouldUseGenericArrayData(int elementSize, int length) {
+ final long headerInBytes = calculateHeaderPortionInBytes(length);
+ final long valueRegionInBytes = (long)elementSize * length;
+ final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
+ return totalSizeInLongs > Integer.MAX_VALUE / 8;
+ }
+
public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) {
return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1);
}
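A note on the new `shouldUseGenericArrayData` check above: an `UnsafeArrayData` is backed by a `long[]` addressed with int byte offsets, so when the header plus value region would exceed `Integer.MAX_VALUE` bytes the caller must fall back to `GenericArrayData`. A standalone Scala sketch of the same arithmetic (the header layout, an 8-byte element count plus a null bitmap rounded up to 8-byte words, is an assumption mirroring `calculateHeaderPortionInBytes`):

    def wouldOverflowUnsafeArray(elementSize: Int, length: Int): Boolean = {
      // assumed header: 8-byte count word + one null bit per element,
      // rounded up to whole 8-byte words
      val headerInBytes = 8L + ((length.toLong + 63) / 64) * 8
      val valueRegionInBytes = elementSize.toLong * length
      // size of the backing long[] needed, in longs
      val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8
      // 8 bytes per long: totalSizeInLongs * 8 must still fit in an int
      totalSizeInLongs > Integer.MAX_VALUE / 8
    }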
http://git-wip-us.apache.org/repos/asf/spark/blob/301bff70/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e7517e8..1d9e470 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -414,6 +414,7 @@ object FunctionRegistry {
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
+ expression[ArrayUnion]("array_union"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
http://git-wip-us.apache.org/repos/asf/spark/blob/301bff70/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b8f2aa3..0f4f4f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3486,3 +3486,322 @@ case class ArrayDistinct(child: Expression)
override def prettyName: String = "array_distinct"
}
+
+/**
+ * Will become a common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
+ */
+abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
+ override def dataType: DataType = {
+ val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType])
+ ArrayType(elementType, dataTypes.exists(_.containsNull))
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val typeCheckResult = super.checkInputDataTypes()
+ if (typeCheckResult.isSuccess) {
+ TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType,
+ s"function $prettyName")
+ } else {
+ typeCheckResult
+ }
+ }
+
+ @transient protected lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(elementType)
+
+ @transient protected lazy val elementTypeSupportEquals = elementType match {
+ case BinaryType => false
+ case _: AtomicType => true
+ case _ => false
+ }
+}
+
+object ArraySetLike {
+ def throwUnionLengthOverflowException(length: Int): Unit = {
+ throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
+ s"elements due to exceeding the array size limit " +
+ s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+ }
+}
+
+
+/**
+ * Returns an array of the elements in the union of array1 and array2, without duplicates.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2,
+ without duplicates.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
+ array(1, 2, 3, 5)
+ """,
+ since = "2.4.0")
+case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike {
+ var hsInt: OpenHashSet[Int] = _
+ var hsLong: OpenHashSet[Long] = _
+
+ def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
+ val elem = array.getInt(idx)
+ if (!hsInt.contains(elem)) {
+ if (resultArray != null) {
+ resultArray.setInt(pos, elem)
+ }
+ hsInt.add(elem)
+ true
+ } else {
+ false
+ }
+ }
+
+ def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
+ val elem = array.getLong(idx)
+ if (!hsLong.contains(elem)) {
+ if (resultArray != null) {
+ resultArray.setLong(pos, elem)
+ }
+ hsLong.add(elem)
+ true
+ } else {
+ false
+ }
+ }
+
+ def evalIntLongPrimitiveType(
+ array1: ArrayData,
+ array2: ArrayData,
+ resultArray: ArrayData,
+ isLongType: Boolean): Int = {
+ // store elements into resultArray
+ var nullElementSize = 0
+ var pos = 0
+ Seq(array1, array2).foreach { array =>
+ var i = 0
+ while (i < array.numElements()) {
+ val size = if (!isLongType) hsInt.size else hsLong.size
+ if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ ArraySetLike.throwUnionLengthOverflowException(size)
+ }
+ if (array.isNullAt(i)) {
+ if (nullElementSize == 0) {
+ if (resultArray != null) {
+ resultArray.setNullAt(pos)
+ }
+ pos += 1
+ nullElementSize = 1
+ }
+ } else {
+ val assigned = if (!isLongType) {
+ assignInt(array, i, resultArray, pos)
+ } else {
+ assignLong(array, i, resultArray, pos)
+ }
+ if (assigned) {
+ pos += 1
+ }
+ }
+ i += 1
+ }
+ }
+ pos
+ }
+
+ override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val array1 = input1.asInstanceOf[ArrayData]
+ val array2 = input2.asInstanceOf[ArrayData]
+
+ if (elementTypeSupportEquals) {
+ elementType match {
+ case IntegerType =>
+ // avoid boxing of primitive int array elements
+ // calculate result array size
+ hsInt = new OpenHashSet[Int]
+ val elements = evalIntLongPrimitiveType(array1, array2, null, false)
+ hsInt = new OpenHashSet[Int]
+ val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
+ IntegerType.defaultSize, elements)) {
+ new GenericArrayData(new Array[Any](elements))
+ } else {
+ UnsafeArrayData.forPrimitiveArray(
+ Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
+ }
+ evalIntLongPrimitiveType(array1, array2, resultArray, false)
+ resultArray
+ case LongType =>
+ // avoid boxing of primitive long array elements
+ // calculate result array size
+ hsLong = new OpenHashSet[Long]
+ val elements = evalIntLongPrimitiveType(array1, array2, null, true)
+ hsLong = new OpenHashSet[Long]
+ val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
+ LongType.defaultSize, elements)) {
+ new GenericArrayData(new Array[Any](elements))
+ } else {
+ UnsafeArrayData.forPrimitiveArray(
+ Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
+ }
+ evalIntLongPrimitiveType(array1, array2, resultArray, true)
+ resultArray
+ case _ =>
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ val hs = new OpenHashSet[Any]
+ var foundNullElement = false
+ Seq(array1, array2).foreach { array =>
+ var i = 0
+ while (i < array.numElements()) {
+ if (array.isNullAt(i)) {
+ if (!foundNullElement) {
+ arrayBuffer += null
+ foundNullElement = true
+ }
+ } else {
+ val elem = array.get(i, elementType)
+ if (!hs.contains(elem)) {
+ if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
+ }
+ arrayBuffer += elem
+ hs.add(elem)
+ }
+ }
+ i += 1
+ }
+ }
+ new GenericArrayData(arrayBuffer)
+ }
+ } else {
+ ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val i = ctx.freshName("i")
+ val pos = ctx.freshName("pos")
+ val value = ctx.freshName("value")
+ val size = ctx.freshName("size")
+ val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
+ if (elementTypeSupportEquals) {
+ elementType match {
+ case ByteType | ShortType | IntegerType | LongType =>
+ val ptName = CodeGenerator.primitiveTypeName(elementType)
+ val unsafeArray = ctx.freshName("unsafeArray")
+ (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
+ if (elementType == LongType) "Long" else "Int",
+ s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
+ if (elementType == LongType) "(long)" else "(int)",
+ s"""
+ |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
+ |${ev.value} = $unsafeArray;
+ """.stripMargin)
+ case _ =>
+ val genericArrayData = classOf[GenericArrayData].getName
+ val et = ctx.addReferenceObj("elementType", elementType)
+ ("", "Object",
+ s"get($i, $et)", s"update($pos, $value)", "Object", "",
+ s"${ev.value} = new $genericArrayData(new Object[$size]);")
+ }
+ } else {
+ ("", "", "", "", "", "", "")
+ }
+
+ nullSafeCodeGen(ctx, ev, (array1, array2) => {
+ if (openHashElementType != "") {
+ // Here, we ensure elementTypeSupportEquals is true
+ val foundNullElement = ctx.freshName("foundNullElement")
+ val openHashSet = classOf[OpenHashSet[_]].getName
+ val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
+ val hs = ctx.freshName("hs")
+ val arrayData = classOf[ArrayData].getName
+ val arrays = ctx.freshName("arrays")
+ val array = ctx.freshName("array")
+ val arrayDataIdx = ctx.freshName("arrayDataIdx")
+ s"""
+ |$openHashSet $hs = new $openHashSet$postFix($classTag);
+ |boolean $foundNullElement = false;
+ |$arrayData[] $arrays = new $arrayData[]{$array1, $array2};
+ |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
+ | $arrayData $array = $arrays[$arrayDataIdx];
+ | for (int $i = 0; $i < $array.numElements(); $i++) {
+ | if ($array.isNullAt($i)) {
+ | $foundNullElement = true;
+ | } else {
+ | $hs.add$postFix($array.$getter);
+ | }
+ | }
+ |}
+ |int $size = $hs.size() + ($foundNullElement ? 1 : 0);
+ |$arrayBuilder
+ |$hs = new $openHashSet$postFix($classTag);
+ |$foundNullElement = false;
+ |int $pos = 0;
+ |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
+ | $arrayData $array = $arrays[$arrayDataIdx];
+ | for (int $i = 0; $i < $array.numElements(); $i++) {
+ | if ($array.isNullAt($i)) {
+ | if (!$foundNullElement) {
+ | ${ev.value}.setNullAt($pos++);
+ | $foundNullElement = true;
+ | }
+ | } else {
+ | $javaTypeName $value = $array.$getter;
+ | if (!$hs.contains($castOp $value)) {
+ | $hs.add$postFix($value);
+ | ${ev.value}.$setter;
+ | $pos++;
+ | }
+ | }
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val arrayUnion = classOf[ArrayUnion].getName
+ val et = ctx.addReferenceObj("elementTypeUnion", elementType)
+ val order = ctx.addReferenceObj("orderingUnion", ordering)
+ val method = "unionOrdering"
+ s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);"
+ }
+ })
+ }
+
+ override def prettyName: String = "array_union"
+}
+
+object ArrayUnion {
+ def unionOrdering(
+ array1: ArrayData,
+ array2: ArrayData,
+ elementType: DataType,
+ ordering: Ordering[Any]): ArrayData = {
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ var alreadyIncludeNull = false
+ Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
+ var found = false
+ if (elem == null) {
+ if (alreadyIncludeNull) {
+ found = true
+ } else {
+ alreadyIncludeNull = true
+ }
+ } else {
+ // check whether elem is already stored in arrayBuffer
+ var j = 0
+ while (!found && j < arrayBuffer.size) {
+ val va = arrayBuffer(j)
+ if (va != null && ordering.equiv(va, elem)) {
+ found = true
+ }
+ j = j + 1
+ }
+ }
+ if (!found) {
+ if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
+ }
+ arrayBuffer += elem
+ }
+ }))
+ new GenericArrayData(arrayBuffer)
+ }
+}
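To summarize the evaluation strategy added above: for Int/Long element types, `nullSafeEval` runs `evalIntLongPrimitiveType` twice with an `OpenHashSet`, once to count distinct elements so the result array can be sized up front, and once to write them out; other orderable atomic types go through a hash set of boxed values, and types without reliable equality (e.g. `BinaryType` or nested types) fall back to the O(n^2) `unionOrdering` scan. A minimal sketch of the two-pass idea, using plain Scala collections in place of `OpenHashSet` and `UnsafeArrayData` (the names here are illustrative, not Spark APIs):

    import scala.collection.mutable

    def unionInts(a: Array[Int], b: Array[Int]): Array[Int] = {
      // pass 1: count distinct elements across both inputs
      val seen = mutable.HashSet.empty[Int]
      Seq(a, b).foreach(_.foreach(seen += _))
      val result = new Array[Int](seen.size)
      // pass 2: rescan in order, writing each element on first sight
      seen.clear()
      var pos = 0
      Seq(a, b).foreach(_.foreach { e =>
        if (seen.add(e)) { result(pos) = e; pos += 1 }
      })
      result
    }

The second pass trades an extra scan for an exactly sized primitive result array, which avoids boxing every element into a growable buffer.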
http://git-wip-us.apache.org/repos/asf/spark/blob/301bff70/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index a838a2e..85d6a1b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1304,4 +1304,85 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
}
+
+ test("Array Union") {
+ val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
+ val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false))
+ val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true))
+ val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false))
+ val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
+ val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false))
+ val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false))
+ val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false))
+ val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false))
+
+ val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false))
+ val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false))
+ val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true))
+ val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false))
+ val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false))
+
+ val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false))
+ val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false))
+ val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true))
+
+ val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType))
+ val a31 = Literal.create(null, ArrayType(StringType))
+
+ checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4))
+ checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3))
+ checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5))
+ checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5))
+ checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4))
+ checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4))
+
+ checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L))
+ checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))
+ checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L))
+ checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L))
+
+ checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f"))
+ checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g"))
+
+ checkEvaluation(ArrayUnion(a30, a30), Seq(null))
+ checkEvaluation(ArrayUnion(a20, a31), null)
+ checkEvaluation(ArrayUnion(a31, a20), null)
+
+ val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
+ ArrayType(BinaryType))
+ val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)),
+ ArrayType(BinaryType))
+ val b3 = Literal.create(Seq[Array[Byte]](
+ Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType))
+ val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType))
+ val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType))
+ val b6 = Literal.create(Seq.empty, ArrayType(BinaryType))
+ val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType))
+
+ checkEvaluation(ArrayUnion(b0, b1),
+ Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3)))
+ checkEvaluation(ArrayUnion(b0, b2),
+ Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3)))
+ checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null))
+ checkEvaluation(ArrayUnion(b3, b0),
+ Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6)))
+ checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6)))
+ checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null))
+ checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null))
+ checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null))
+
+ val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+ ArrayType(ArrayType(IntegerType)))
+ val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+ ArrayType(ArrayType(IntegerType)))
+ checkEvaluation(ArrayUnion(aa0, aa1),
+ Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1)))
+
+ assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false)
+ assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true)
+ assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
+ assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/301bff70/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6b956dd..b98ab11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3204,6 +3204,7 @@ object functions {
/**
* Remove all elements that equal to element from the given array.
+ *
* @group collection_funcs
* @since 2.4.0
*/
@@ -3219,6 +3220,16 @@ object functions {
def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) }
/**
+ * Returns an array of the elements in the union of the given two arrays, without duplicates.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_union(col1: Column, col2: Column): Column = withExpr {
+ ArrayUnion(col1.expr, col2.expr)
+ }
+
+ /**
* Creates a new row for each element in the given array or map column.
*
* @group collection_funcs
http://git-wip-us.apache.org/repos/asf/spark/blob/301bff70/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index d60ed7a..d461571 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1198,6 +1198,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
"argument 1 requires (array or map) type, however, '`_1`' is of string type"))
}
+ test("array_union functions") {
+ val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
+ val ans1 = Row(Seq(1, 2, 3, 4))
+ checkAnswer(df1.select(array_union($"a", $"b")), ans1)
+ checkAnswer(df1.selectExpr("array_union(a, b)"), ans1)
+
+ val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b")
+ val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1))
+ checkAnswer(df2.select(array_union($"a", $"b")), ans2)
+ checkAnswer(df2.selectExpr("array_union(a, b)"), ans2)
+
+ val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b")
+ val ans3 = Row(Seq(1L, 2L, 3L, 4L))
+ checkAnswer(df3.select(array_union($"a", $"b")), ans3)
+ checkAnswer(df3.selectExpr("array_union(a, b)"), ans3)
+
+ val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L)))
+ .toDF("a", "b")
+ val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))
+ checkAnswer(df4.select(array_union($"a", $"b")), ans4)
+ checkAnswer(df4.selectExpr("array_union(a, b)"), ans4)
+
+ val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b")
+ val ans5 = Row(Seq("b", "a", "c", null, "g"))
+ checkAnswer(df5.select(array_union($"a", $"b")), ans5)
+ checkAnswer(df5.selectExpr("array_union(a, b)"), ans5)
+
+ val df6 = Seq((null, Array("a"))).toDF("a", "b")
+ intercept[AnalysisException] {
+ df6.select(array_union($"a", $"b"))
+ }
+ intercept[AnalysisException] {
+ df6.selectExpr("array_union(a, b)")
+ }
+
+ val df7 = Seq((null, null)).toDF("a", "b")
+ intercept[AnalysisException] {
+ df7.select(array_union($"a", $"b"))
+ }
+ intercept[AnalysisException] {
+ df7.selectExpr("array_union(a, b)")
+ }
+
+ val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b")
+ intercept[AnalysisException] {
+ df8.select(array_union($"a", $"b"))
+ }
+ intercept[AnalysisException] {
+ df8.selectExpr("array_union(a, b)")
+ }
+ }
+
test("concat function - arrays") {
val nseqi : Seq[Int] = null
val nseqs : Seq[String] = null