You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/04/30 02:35:53 UTC

[spark] branch master updated: [SPARK-24935][SQL][FOLLOWUP] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7432e7d  [SPARK-24935][SQL][FOLLOWUP] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter
7432e7d is described below

commit 7432e7ded44cc0014590d229827546f5d8f93868
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Apr 30 10:35:23 2019 +0800

    [SPARK-24935][SQL][FOLLOWUP] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter
    
    ## What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/24144 (#24144), which missed one case: when hash aggregate falls back to sort aggregate, the life cycle of a UDAF is INIT -> UPDATE -> MERGE -> FINISH.
    
    However, not all Hive UDAFs support this. A Hive UDAF knows the aggregation mode when it creates the aggregation buffer, so it can create different buffers for different inputs: the original data or another aggregation buffer. See an example in the [sketches library](https://github.com/DataSketches/sketches-hive/blob/7f9e76e9e03807277146291beb2c7bec40e8672b/src/main/java/com/yahoo/sketches/hive/cpc/DataToSketchUDAF.java#L107). A buffer created for UPDATE may not support MERGE.
    
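    This PR updates the Hive UDAF adapter in Spark to support INIT -> UPDATE -> MERGE -> FINISH by turning it into INIT -> UPDATE -> FINISH + INIT -> MERGE -> FINISH. When MERGE is called on a buffer that was created for UPDATE, the adapter converts that buffer with `terminatePartial` and merges it into a fresh buffer created by the final-mode evaluator.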
    
    ## How was this patch tested?
    
    a new test case
    
    Closes #24459 from cloud-fan/hive-udaf.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 54 ++++++++++++++++------
 .../spark/sql/hive/execution/HiveUDAFSuite.scala   | 38 +++++++++------
 2 files changed, 64 insertions(+), 28 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 0938576..76e4085 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -304,6 +304,13 @@ private[hive] case class HiveGenericUDTF(
  *  - `wrap()`/`wrapperFor()`: from 3 to 1
  *  - `unwrap()`/`unwrapperFor()`: from 1 to 3
  *  - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3
+ *
+ *  Note that, Hive UDAF is initialized with aggregate mode, and some specific Hive UDAFs can't
+ *  mix UPDATE and MERGE actions during its life cycle. However, Spark may do UPDATE on a UDAF and
+ *  then do MERGE, in case of hash aggregate falling back to sort aggregate. To work around this
+ *  issue, we track the ability to do MERGE in the Hive UDAF aggregate buffer. If Spark does
+ *  UPDATE then MERGE, we can detect it and re-create the aggregate buffer with a different
+ *  aggregate mode.
  */
 private[hive] case class HiveUDAFFunction(
     name: String,
@@ -312,7 +319,7 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+  extends TypedImperativeAggregate[HiveUDAFBuffer]
   with HiveInspectors
   with UserDefinedExpression {
 
@@ -410,55 +417,70 @@ private[hive] case class HiveUDAFFunction(
   // aggregate buffer. However, the Spark UDAF framework does not expose this information when
   // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
   // on demand, so that we can know what input we are dealing with.
-  override def createAggregationBuffer(): AggregationBuffer = null
+  override def createAggregationBuffer(): HiveUDAFBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
-  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
     // The input is original data, we create buffer with the partial1 evaluator.
     val nonNullBuffer = if (buffer == null) {
-      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
     } else {
       buffer
     }
 
+    assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     nonNullBuffer
   }
 
-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
     // The input is aggregate buffer, we create buffer with the final evaluator.
     val nonNullBuffer = if (buffer == null) {
-      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
     } else {
       buffer
     }
 
+    // It's possible that we've called `update` of this Hive UDAF, and some specific Hive UDAF
+    // implementation can't mix the `update` and `merge` calls during its life cycle. To work
+    // around it, here we create a fresh buffer with final evaluator, and merge the existing buffer
+    // to it, and replace the existing buffer with it.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
     finalHiveEvaluator.evaluator.merge(
-      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    nonNullBuffer
+      mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }
 
-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }
 
-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }
 
-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
   }
 
   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
@@ -506,3 +528,5 @@ private[hive] case class HiveUDAFFunction(
     }
   }
 }
+
+case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index ef40323..3252cda 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -28,10 +28,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 
 class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -94,21 +94,33 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }
 
-  test("customized Hive UDAF with two aggregation buffers") {
-    val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+  test("SPARK-24935: customized Hive UDAF with two aggregation buffers") {
+    withTempView("v") {
+      spark.range(100).createTempView("v")
+      val df = sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2")
 
-    val aggs = df.queryExecution.executedPlan.collect {
-      case agg: ObjectHashAggregateExec => agg
-    }
+      val aggs = df.queryExecution.executedPlan.collect {
+        case agg: ObjectHashAggregateExec => agg
+      }
 
-    // There should be two aggregate operators, one for partial aggregation, and the other for
-    // global aggregation.
-    assert(aggs.length == 2)
+      // There should be two aggregate operators, one for partial aggregation, and the other for
+      // global aggregation.
+      assert(aggs.length == 2)
 
-    checkAnswer(df, Seq(
-      Row(0, Row(1, 1)),
-      Row(1, Row(1, 1))
-    ))
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+    }
   }
 
   test("call JAVA UDAF") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org