You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2019/03/24 23:08:06 UTC

[spark] branch master updated: [SPARK-24935][SQL] fix Hive UDAF with two aggregation buffers

This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a6c207c  [SPARK-24935][SQL] fix Hive UDAF with two aggregation buffers
a6c207c is described below

commit a6c207c9c0c7aa057cfa27d16fe882b396440113
Author: pgandhi <pg...@verizonmedia.com>
AuthorDate: Sun Mar 24 16:07:35 2019 -0700

    [SPARK-24935][SQL] fix Hive UDAF with two aggregation buffers
    
    ## What changes were proposed in this pull request?
    
    A Hive UDAF knows the aggregation mode when it creates the aggregation buffer, so it can create different buffers for different inputs: the original data or an aggregation buffer. Please see an example in the [sketches library](https://github.com/DataSketches/sketches-hive/blob/7f9e76e9e03807277146291beb2c7bec40e8672b/src/main/java/com/yahoo/sketches/hive/cpc/DataToSketchUDAF.java#L107).
    
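    However, the Hive UDAF adapter in Spark always creates the buffer in PARTIAL1 mode, which can only handle one kind of input: the original data. This PR fixes that by creating the buffer on demand in `update` (with the PARTIAL1 evaluator) and in `merge` (with the FINAL evaluator), once it is known which kind of input is being processed.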
    
    All credit goes to pgandhi999, who investigated the problem, studied the Hive UDAF behavior, and wrote the tests.
    
    Closes https://github.com/apache/spark/pull/23778
    
    ## How was this patch tested?
    
    A new test in `HiveUDAFSuite`: "customized Hive UDAF with two aggregation buffers".
    
    Closes #24144 from cloud-fan/hive.
    
    Lead-authored-by: pgandhi <pg...@verizonmedia.com>
    Co-authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: gatorsmile <ga...@gmail.com>
---
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala |  64 ++++++++-----
 .../spark/sql/hive/execution/HiveUDAFSuite.scala   | 106 +++++++++++++++++++++
 2 files changed, 147 insertions(+), 23 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 4a84509..8ece4b5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -352,12 +352,14 @@ private[hive] case class HiveUDAFFunction(
     HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
   }
 
-  // The UDAF evaluator used to merge partial aggregation results.
+  // The UDAF evaluator used to consume partial aggregation results and produce final results.
+  // Hive `ObjectInspector` used to inspect final results.
   @transient
-  private lazy val partial2ModeEvaluator = {
+  private lazy val finalHiveEvaluator = {
     val evaluator = newEvaluator()
-    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
-    evaluator
+    HiveEvaluator(
+      evaluator,
+      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
   }
 
   // Spark SQL data type of partial aggregation results
@@ -365,16 +367,6 @@ private[hive] case class HiveUDAFFunction(
   private lazy val partialResultDataType =
     inspectorToDataType(partial1HiveEvaluator.objectInspector)
 
-  // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
-  // Hive `ObjectInspector` used to inspect the final aggregation result object.
-  @transient
-  private lazy val finalHiveEvaluator = {
-    val evaluator = newEvaluator()
-    HiveEvaluator(
-      evaluator,
-      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
-  }
-
   // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
   private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
@@ -401,25 +393,43 @@ private[hive] case class HiveUDAFFunction(
     s"$name($distinct${children.map(_.sql).mkString(", ")})"
   }
 
-  override def createAggregationBuffer(): AggregationBuffer =
-    partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+  // The hive UDAF may create different buffers to handle different inputs: original data or
+  // aggregate buffer. However, the Spark UDAF framework does not expose this information when
+  // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
+  // on demand, so that we can know what input we are dealing with.
+  override def createAggregationBuffer(): AggregationBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
   override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+    // The input is original data, we create buffer with the partial1 evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+    } else {
+      buffer
+    }
+
     partial1HiveEvaluator.evaluator.iterate(
-      buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
-    buffer
+      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+    nonNullBuffer
   }
 
   override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+    // The input is aggregate buffer, we create buffer with the final evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+    } else {
+      buffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
-    partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    buffer
+    finalHiveEvaluator.evaluator.merge(
+      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
+    nonNullBuffer
   }
 
   override def eval(buffer: AggregationBuffer): Any = {
@@ -450,11 +460,19 @@ private[hive] case class HiveUDAFFunction(
     private val mutableRow = new GenericInternalRow(1)
 
     def serialize(buffer: AggregationBuffer): Array[Byte] = {
+      // The buffer may be null if there is no input. It's unclear if the hive UDAF accepts null
+      // buffer, for safety we create an empty buffer here.
+      val nonNullBuffer = if (buffer == null) {
+        partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      } else {
+        buffer
+      }
+
       // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
       // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
       // Then we can unwrap it to a Spark SQL value.
       mutableRow.update(0, partialResultUnwrapper(
-        partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
+        partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer)))
       val unsafeRow = projection(mutableRow)
       val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
       unsafeRow.writeTo(bytes)
@@ -466,11 +484,11 @@ private[hive] case class HiveUDAFFunction(
       // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
       // workaround here is creating an initial `AggregationBuffer` first and then merge the
       // deserialized object into the buffer.
-      val buffer = partial2ModeEvaluator.getNewAggregationBuffer
+      val buffer = finalHiveEvaluator.evaluator.getNewAggregationBuffer
       val unsafeRow = new UnsafeRow(1)
       unsafeRow.pointTo(bytes, bytes.length)
       val partialResult = unsafeRow.get(0, partialResultDataType)
-      partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
+      finalHiveEvaluator.evaluator.merge(buffer, partialResultWrapper(partialResult))
       buffer
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index fe3dece..ef40323 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -40,6 +41,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     super.beforeAll()
     sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
     sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+    sql(s"CREATE TEMPORARY FUNCTION mock2 AS '${classOf[MockUDAF2].getName}'")
 
     Seq(
       (0: Integer) -> "val_0",
@@ -92,6 +94,23 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }
 
+  test("customized Hive UDAF with two aggregation buffers") {
+    val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+
+    val aggs = df.queryExecution.executedPlan.collect {
+      case agg: ObjectHashAggregateExec => agg
+    }
+
+    // There should be two aggregate operators, one for partial aggregation, and the other for
+    // global aggregation.
+    assert(aggs.length == 2)
+
+    checkAnswer(df, Seq(
+      Row(0, Row(1, 1)),
+      Row(1, Row(1, 1))
+    ))
+  }
+
   test("call JAVA UDAF") {
     withTempView("temp") {
       withUserDefinedFunction("myDoubleAvg" -> false) {
@@ -127,12 +146,22 @@ class MockUDAF extends AbstractGenericUDAFResolver {
   override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
 }
 
+class MockUDAF2 extends AbstractGenericUDAFResolver {
+  override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator2
+}
+
 class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
   extends GenericUDAFEvaluator.AbstractAggregationBuffer {
 
   override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
 }
 
+class MockUDAFBuffer2(var nonNullCount: Long, var nullCount: Long)
+  extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+  override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
 class MockUDAFEvaluator extends GenericUDAFEvaluator {
   private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
 
@@ -184,3 +213,80 @@ class MockUDAFEvaluator extends GenericUDAFEvaluator {
 
   override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
 }
+
+// Same as MockUDAFEvaluator but using two aggregation buffers, one for PARTIAL1 and the other
+// for PARTIAL2.
+class MockUDAFEvaluator2 extends GenericUDAFEvaluator {
+  private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+  private var aggMode: Mode = null
+
+  private val bufferOI = {
+    val fieldNames = Seq("nonNullCount", "nullCount").asJava
+    val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+  }
+
+  private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+  private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+  override def getNewAggregationBuffer: AggregationBuffer = {
+    // These 2 modes consume original data.
+    if (aggMode == Mode.PARTIAL1 || aggMode == Mode.COMPLETE) {
+      new MockUDAFBuffer(0L, 0L)
+    } else {
+      new MockUDAFBuffer2(0L, 0L)
+    }
+  }
+
+  override def reset(agg: AggregationBuffer): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    buffer.nonNullCount = 0L
+    buffer.nullCount = 0L
+  }
+
+  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
+    aggMode = mode
+    bufferOI
+  }
+
+  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    if (parameters.head eq null) {
+      buffer.nullCount += 1L
+    } else {
+      buffer.nonNullCount += 1L
+    }
+  }
+
+  override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+    if (partial ne null) {
+      val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+      val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      buffer.nonNullCount += nonNullCount
+      buffer.nullCount += nullCount
+    }
+  }
+
+  // As this method is called for both states, Partial1 and Partial2, the hack in the method
+  // to check for class of aggregation buffer was necessary.
+  override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+    var result: AnyRef = null
+    if (agg.getClass.toString.contains("MockUDAFBuffer2")) {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    } else {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    }
+    result
+  }
+
+  override def terminate(agg: AggregationBuffer): AnyRef = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+    Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org