Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2019/04/26 06:39:09 UTC

[GitHub] [spark] cloud-fan commented on a change in pull request #24459: [SPARK-24935][SQL][followup] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter

URL: https://github.com/apache/spark/pull/24459#discussion_r278823210
 
 

 ##########
 File path: sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
 ##########
 @@ -410,55 +417,70 @@ private[hive] case class HiveUDAFFunction(
   // aggregate buffer. However, the Spark UDAF framework does not expose this information when
   // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
   // on demand, so that we can know what input we are dealing with.
-  override def createAggregationBuffer(): AggregationBuffer = null
+  override def createAggregationBuffer(): HiveUDAFBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
-  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
     // The input is original data, we create buffer with the partial1 evaluator.
     val nonNullBuffer = if (buffer == null) {
-      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
     } else {
       buffer
     }
 
+    assert(!nonNullBuffer.canDoMerge, "cannot call `update` after `merge` on a Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     nonNullBuffer
   }
 
-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
     // The input is aggregate buffer, we create buffer with the final evaluator.
     val nonNullBuffer = if (buffer == null) {
-      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
     } else {
       buffer
     }
 
+    // It's possible that `update` has already been called on this Hive UDAF, and some Hive UDAF
+    // implementations can't mix `update` and `merge` calls during their life cycle. To work
+    // around this, we create a fresh buffer with the final evaluator, merge the existing buffer
+    // into it, and replace the existing buffer with the new one.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // these `AggregationBuffer`s into this format before shuffling partial aggregation results,
     // calling `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
     finalHiveEvaluator.evaluator.merge(
-      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    nonNullBuffer
+      mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }
 
-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }
 
-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }
 
-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
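
For readers skimming the diff, here is a minimal, self-contained sketch of the
buffer lifecycle the patch implements. Everything in `ToySumUDAF` is an invented
stand-in for illustration, not Spark or Hive API; the real adapter delegates to
Hive's `GenericUDAFEvaluator` as shown above, and the buffer migration in the
real `merge` uses `terminatePartial` plus `merge` on the final evaluator rather
than a plain field copy.

object ToySumUDAF {
  // The aggregation state plus the `canDoMerge` flag the patch introduces.
  final case class Buffer(var sum: Long, canDoMerge: Boolean)

  def update(buffer: Buffer, input: Long): Buffer = {
    // Buffers are created lazily, mirroring createAggregationBuffer() = null.
    val nonNull = if (buffer == null) Buffer(0L, canDoMerge = false) else buffer
    assert(!nonNull.canDoMerge, "cannot call `update` after `merge`")
    nonNull.sum += input
    nonNull
  }

  def merge(buffer: Buffer, input: Buffer): Buffer = {
    val nonNull = if (buffer == null) Buffer(0L, canDoMerge = true) else buffer
    // The patch's workaround: a buffer that has only seen `update` calls is
    // migrated into a fresh merge-capable buffer before merging into it.
    val mergeable =
      if (!nonNull.canDoMerge) Buffer(nonNull.sum, canDoMerge = true) else nonNull
    mergeable.sum += input.sum
    mergeable
  }

  def main(args: Array[String]): Unit = {
    // INIT -> UPDATE -> MERGE -> FINISH on a single buffer: the call sequence
    // this PR adds support for.
    var buf: Buffer = null
    buf = update(buf, 1L)
    buf = update(buf, 2L)
    buf = merge(buf, Buffer(10L, canDoMerge = false)) // a deserialized partial result
    println(buf.sum) // 13
  }
}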
 
 Review comment:
  The deserialized buffer can only appear as the second parameter of `merge`, so `canDoMerge` doesn't matter here.
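
A small illustration of that point, using toy types rather than the real Hive
classes: `merge` only reads the wrapped state of its second argument (the
analogue of `input.buf` in the diff) and never inspects that argument's flag,
so any value chosen in `deserialize` gives the same result.

object FlagIrrelevance {
  final case class Buf(sum: Long, canDoMerge: Boolean)

  // Mirrors the shape of the adapter's merge: only input.sum is read;
  // input.canDoMerge is never consulted.
  def merge(acc: Buf, input: Buf): Buf = acc.copy(sum = acc.sum + input.sum)

  def main(args: Array[String]): Unit = {
    val withFalse = merge(Buf(5L, canDoMerge = true), Buf(7L, canDoMerge = false))
    val withTrue  = merge(Buf(5L, canDoMerge = true), Buf(7L, canDoMerge = true))
    assert(withFalse == withTrue) // identical either way
  }
}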
