You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ya...@apache.org on 2020/05/01 13:12:15 UTC

[spark] branch branch-3.0 updated: [SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements

This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 1795a70  [SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements
1795a70 is described below

commit 1795a70bb04fad1b8cf76271443a448f8d72fc8a
Author: Pablo Langa <so...@gmail.com>
AuthorDate: Fri May 1 22:09:04 2020 +0900

    [SPARK-31500][SQL] collect_set() of BinaryType returns duplicate elements
    
    ### What changes were proposed in this pull request?
    
    The collect_set() aggregate function should produce a set of distinct elements. When the column argument's type is BinaryType, this is not the case.
    
    Example:
    ```scala
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.expressions.Window
    
    case class R(id: String, value: String, bytes: Array[Byte])
    def makeR(id: String, value: String) = R(id, value, value.getBytes)
    val df = Seq(makeR("a", "dog"), makeR("a", "cat"), makeR("a", "cat"), makeR("b", "fish")).toDF()
    // In the example below "byteSet" erroneously has duplicates, but "stringSet" does not (as expected).
    df.agg(collect_set('value) as "stringSet", collect_set('bytes) as "byteSet").show(truncate=false)
    // The same problem is displayed when using window functions.
    val win = Window.partitionBy('id).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    val result = df.select(
      collect_set('value).over(win) as "stringSet",
      collect_set('bytes).over(win) as "bytesSet"
    )
    .select('stringSet, 'bytesSet, size('stringSet) as "stringSetSize", size('bytesSet) as "bytesSetSize")
    .show()
    ```
    
    We use a HashSet buffer to accumulate the results. The problem is that array equality in Scala doesn't behave as expected: Scala arrays are plain Java arrays, and `==` compares references, not the contents of the arrays.
    Array(1, 2, 3) == Array(1, 2, 3)  => false
    As a result, duplicates are not removed from the HashSet.
    
    The proposed solution is to convert each value to an array of bytes (ArrayType(ByteType), backed by UnsafeArrayData, which compares and hashes by content) when it is added to the HashSet buffer, so duplicates are removed. In eval(), the elements are then converted back to the original byte arrays.
    This transformation is applied only when the child's data type is BinaryType.
    
    ### Why are the changes needed?
    To fix the bug described above.
    
    ### Does this PR introduce any user-facing change?
    Yes. Now `collect_set()` correctly deduplicates arrays of bytes.
    
    ### How was this patch tested?
    Added a unit test to DataFrameAggregateSuite.
    
    Closes #28351 from planga82/feature/SPARK-31500_COLLECT_SET_bug.
    
    Authored-by: Pablo Langa <so...@gmail.com>
    Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
    (cherry picked from commit 4fecc20f6ecdfe642890cf0a368a85558c40a47c)
    Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
 .../catalyst/expressions/aggregate/collect.scala   | 45 +++++++++++++++++++---
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 16 ++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 5848aa3..0a3d876 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 
@@ -46,13 +47,15 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
   // actual order of input rows.
   override lazy val deterministic: Boolean = false
 
+  protected def convertToBufferElement(value: Any): Any
+
   override def update(buffer: T, input: InternalRow): T = {
     val value = child.eval(input)
 
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
     if (value != null) {
-      buffer += InternalRow.copyValue(value)
+      buffer += convertToBufferElement(value)
     }
     buffer
   }
@@ -61,12 +64,10 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
     buffer ++= other
   }
 
-  override def eval(buffer: T): Any = {
-    new GenericArrayData(buffer.toArray)
-  }
+  protected val bufferElementType: DataType
 
   private lazy val projection = UnsafeProjection.create(
-    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+    Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
   private lazy val row = new UnsafeRow(1)
 
   override def serialize(obj: T): Array[Byte] = {
@@ -77,7 +78,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
   override def deserialize(bytes: Array[Byte]): T = {
     val buffer = createAggregationBuffer()
     row.pointTo(bytes, bytes.length)
-    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    row.getArray(0).foreach(bufferElementType, (_, x: Any) => buffer += x)
     buffer
   }
 }
@@ -105,6 +106,10 @@ case class CollectList(
 
   def this(child: Expression) = this(child, 0, 0)
 
+  override lazy val bufferElementType = child.dataType
+
+  override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
@@ -114,6 +119,10 @@ case class CollectList(
   override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
 
   override def prettyName: String = "collect_list"
+
+  override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
+    new GenericArrayData(buffer.toArray)
+  }
 }
 
 /**
@@ -139,6 +148,30 @@ case class CollectSet(
 
   def this(child: Expression) = this(child, 0, 0)
 
+  override lazy val bufferElementType = child.dataType match {
+    case BinaryType => ArrayType(ByteType)
+    case other => other
+  }
+
+  override def convertToBufferElement(value: Any): Any = child.dataType match {
+    /*
+     * collect_set() of BinaryType should not return duplicate elements,
+     * Java byte arrays use referential equality and identity hash codes
+     * so we need to use a different catalyst value for arrays
+     */
+    case BinaryType => UnsafeArrayData.fromPrimitiveArray(value.asInstanceOf[Array[Byte]])
+    case _ => InternalRow.copyValue(value)
+  }
+
+  override def eval(buffer: mutable.HashSet[Any]): Any = {
+    val array = child.dataType match {
+      case BinaryType =>
+        buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray).toArray
+      case _ => buffer.toArray
+    }
+    new GenericArrayData(array)
+  }
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
       TypeCheckResult.TypeCheckSuccess
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 288f3da..4edf3a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -530,6 +530,22 @@ class DataFrameAggregateSuite extends QueryTest
     )
   }
 
+  test("SPARK-31500: collect_set() of BinaryType returns duplicate elements") {
+    val bytesTest1 = "test1".getBytes
+    val bytesTest2 = "test2".getBytes
+    val df = Seq(bytesTest1, bytesTest1, bytesTest2).toDF("a")
+    checkAnswer(df.select(size(collect_set($"a"))), Row(2) :: Nil)
+
+    val a = "aa".getBytes
+    val b = "bb".getBytes
+    val c = "cc".getBytes
+    val d = "dd".getBytes
+    val df1 = Seq((a, b), (a, b), (c, d))
+      .toDF("x", "y")
+      .select(struct($"x", $"y").as("a"))
+    checkAnswer(df1.select(size(collect_set($"a"))), Row(2) :: Nil)
+  }
+
   test("collect_set functions cannot have maps") {
     val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1))
       .toDF("a", "x", "y")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org