Posted to commits@spark.apache.org by yh...@apache.org on 2016/11/16 22:32:40 UTC

spark git commit: [SPARK-18186] Migrate HiveUDAFFunction to TypedImperativeAggregate for partial aggregation support

Repository: spark
Updated Branches:
  refs/heads/master a36a76ac4 -> 2ca8ae9aa


[SPARK-18186] Migrate HiveUDAFFunction to TypedImperativeAggregate for partial aggregation support

## What changes were proposed in this pull request?

While being evaluated in Spark SQL, Hive UDAFs don't support partial aggregation. This PR migrates `HiveUDAFFunction`s to `TypedImperativeAggregate`, which already provides partial aggregation support for aggregate functions that may use arbitrary Java objects as aggregation states.

The following snippet shows the effect of this PR:

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax
sql(s"CREATE FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")

spark.range(100).createOrReplaceTempView("t")

// A query using both Spark SQL native `max` and Hive `max`
sql(s"SELECT max(id), hive_max(id) FROM t").explain()
```

Before this PR:

```
== Physical Plan ==
SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax7475f57e), id#1L, false, 0, 0)])
+- Exchange SinglePartition
   +- *Range (0, 100, step=1, splits=Some(1))
```

After this PR:

```
== Physical Plan ==
SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)])
+- Exchange SinglePartition
   +- SortAggregate(key=[], functions=[partial_max(id#1L), partial_default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax5e18a6a7), id#1L, false, 0, 0)])
      +- *Range (0, 100, step=1, splits=Some(1))
```
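The two-level plan above follows the usual partial aggregation pattern: each partition builds its own aggregation state, the states are serialized and shuffled, and a final operator merges them. The sketch below models that flow in plain Scala with no Spark dependency. `SimpleTypedAggregate`, `MaxBuffer`, `MaxAgg`, and `runTwoPhase` are hypothetical names made up for illustration; the method names only loosely mirror the ones `HiveUDAFFunction` overrides in this PR and are not Spark's actual API.

```scala
import java.nio.ByteBuffer

object PartialAggSketch {
  // Hypothetical, simplified stand-in for a typed imperative aggregate contract.
  trait SimpleTypedAggregate[In, Buf, Out] {
    def createAggregationBuffer(): Buf
    def update(buffer: Buf, input: In): Unit
    def merge(buffer: Buf, other: Buf): Unit
    def eval(buffer: Buf): Out
    def serialize(buffer: Buf): Array[Byte]
    def deserialize(bytes: Array[Byte]): Buf
  }

  // Mutable state for max; `seen` distinguishes "no rows yet" from a real value.
  final class MaxBuffer(var seen: Boolean, var max: Long)

  object MaxAgg extends SimpleTypedAggregate[Long, MaxBuffer, Option[Long]] {
    def createAggregationBuffer(): MaxBuffer = new MaxBuffer(false, 0L)

    def update(buffer: MaxBuffer, input: Long): Unit = {
      if (!buffer.seen || input > buffer.max) buffer.max = input
      buffer.seen = true
    }

    def merge(buffer: MaxBuffer, other: MaxBuffer): Unit =
      if (other.seen) update(buffer, other.max)

    def eval(buffer: MaxBuffer): Option[Long] =
      if (buffer.seen) Some(buffer.max) else None

    // One flag byte followed by an 8-byte long.
    def serialize(buffer: MaxBuffer): Array[Byte] = {
      val out = ByteBuffer.allocate(9)
      out.put(if (buffer.seen) 1.toByte else 0.toByte).putLong(buffer.max)
      out.array()
    }

    def deserialize(bytes: Array[Byte]): MaxBuffer = {
      val in = ByteBuffer.wrap(bytes)
      val seen = in.get() == 1.toByte
      new MaxBuffer(seen, in.getLong())
    }
  }

  // Partial phase per partition, a simulated shuffle of bytes, then a final merge.
  def runTwoPhase[In, Buf, Out](
      agg: SimpleTypedAggregate[In, Buf, Out],
      partitions: Seq[Seq[In]]): Out = {
    val shuffled: Seq[Array[Byte]] = partitions.map { rows =>
      val buffer = agg.createAggregationBuffer()
      rows.foreach(agg.update(buffer, _))
      agg.serialize(buffer)
    }
    val finalBuffer = agg.createAggregationBuffer()
    shuffled.foreach(bytes => agg.merge(finalBuffer, agg.deserialize(bytes)))
    agg.eval(finalBuffer)
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(3L, 7L), Seq.empty[Long], Seq(42L, 5L))
    println(runTwoPhase(MaxAgg, partitions))
  }
}
```

Only the serialized bytes cross the simulated shuffle boundary, which is why an aggregate needs `serialize()` and `deserialize()` before it can take part in a two-level plan.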

The tricky part of this PR is mostly about updating and passing around the aggregation states of `HiveUDAFFunction`s, since the aggregation state of a Hive UDAF may appear in three different forms. Let's take the testing `MockUDAF` added in this PR as an example. This UDAF computes the count of non-null values together with the count of nulls of a given column. Its aggregation state may appear in the following forms at different times:

1. A `MockUDAFBuffer`, which is a concrete subclass of `GenericUDAFEvaluator.AggregationBuffer`

   The form used by the Hive UDAF API. This form is required by the following scenarios:

   - Calling `GenericUDAFEvaluator.iterate()` to update an existing aggregation state with new input values.
   - Calling `GenericUDAFEvaluator.terminate()` to get the final aggregated value from an existing aggregation state.
   - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state.

     The existing aggregation state to be updated must be in this form.

   Conversions:

   - To form 2:

     `GenericUDAFEvaluator.terminatePartial()`

   - To form 3:

     Convert to form 2 first, and then to 3.

2. An `Object[]` array containing two `java.lang.Long` values.

   The form used to interact with Hive's `ObjectInspector`s. This form is required by the following scenarios:

   - Calling `GenericUDAFEvaluator.terminatePartial()` to convert an existing aggregation state in form 1 to form 2.
   - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state.

     The input aggregation state must be in this form.

   Conversions:

   - To form 1:

     No direct method. We have to create an empty `AggregationBuffer` and merge the form 2 object into it.

   - To form 3:

     `unwrapperFor()`/`unwrap()` method of `HiveInspectors`

3. The byte array that holds data of an `UnsafeRow` with two `LongType` fields.

   The form used by Spark SQL to shuffle partial aggregation results. This form is required because `TypedImperativeAggregate` always asks its subclasses to serialize their aggregation states into a byte array.

   Conversions:

   - To form 1:

     Convert to form 2 first, and then to 1.

   - To form 2:

     `wrapperFor()`/`wrap()` method of `HiveInspectors`
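The conversions listed above can be sketched in plain Scala without Hive or Spark on the classpath. This is a simplified model under stated assumptions: `CountBuffer` stands in for `MockUDAFBuffer`, a 16-byte array of two longs stands in for the bytes of an `UnsafeRow`, and `toBuffer`, `toBytes`, and `fromBytes` are hypothetical helper names playing the roles of merge-into-empty-buffer, `unwrap()`, and `wrap()` respectively.

```scala
import java.nio.ByteBuffer

object AggregationStateForms {
  // Form 1: a mutable buffer, standing in for `MockUDAFBuffer`.
  final class CountBuffer(var nonNullCount: Long, var nullCount: Long)

  // Form 1 -> form 2: plays the role of `terminatePartial()`.
  def terminatePartial(buffer: CountBuffer): Array[AnyRef] =
    Array[AnyRef](Long.box(buffer.nonNullCount), Long.box(buffer.nullCount))

  // Merges a form 2 state into an existing form 1 buffer, like `merge()`.
  def merge(buffer: CountBuffer, partial: Array[AnyRef]): Unit = {
    buffer.nonNullCount += partial(0).asInstanceOf[java.lang.Long].longValue()
    buffer.nullCount += partial(1).asInstanceOf[java.lang.Long].longValue()
  }

  // Form 2 -> form 1: no direct conversion, so merge into an empty buffer.
  def toBuffer(partial: Array[AnyRef]): CountBuffer = {
    val buffer = new CountBuffer(0L, 0L)
    merge(buffer, partial)
    buffer
  }

  // Form 2 -> form 3: two 8-byte longs stand in for the `UnsafeRow` bytes.
  def toBytes(partial: Array[AnyRef]): Array[Byte] = {
    val out = ByteBuffer.allocate(16)
    partial.foreach(v => out.putLong(v.asInstanceOf[java.lang.Long].longValue()))
    out.array()
  }

  // Form 3 -> form 2.
  def fromBytes(bytes: Array[Byte]): Array[AnyRef] = {
    val in = ByteBuffer.wrap(bytes)
    Array[AnyRef](Long.box(in.getLong()), Long.box(in.getLong()))
  }

  // Form 1 -> form 3 and back always go through form 2.
  def serialize(buffer: CountBuffer): Array[Byte] = toBytes(terminatePartial(buffer))
  def deserialize(bytes: Array[Byte]): CountBuffer = toBuffer(fromBytes(bytes))
}
```

A round trip such as `AggregationStateForms.deserialize(AggregationStateForms.serialize(buffer))` passes through form 2 in both directions, which matches the "convert to form 2 first" rule in the list.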

Here are some micro-benchmark results produced by the most recent master and this PR branch.

Master:

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU  2.60GHz

hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
w/o groupBy                                    339 /  372          3.1         323.2       1.0X
w/ groupBy                                     503 /  529          2.1         479.7       0.7X
```

This PR:

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU  2.60GHz

hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
w/o groupBy                                    116 /  126          9.0         110.8       1.0X
w/ groupBy                                     151 /  159          6.9         144.0       0.8X
```
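For a quick read of the two tables, the ratio of best times gives the speedup per case. The snippet below only does that arithmetic on the numbers reported above; `BenchmarkSpeedup` and its members are hypothetical names, and nothing here re-runs the benchmark.

```scala
object BenchmarkSpeedup {
  // Best times in milliseconds, copied from the "Master" and "This PR" tables.
  val masterBestMs: Map[String, Double] = Map("w/o groupBy" -> 339.0, "w/ groupBy" -> 503.0)
  val prBestMs: Map[String, Double] = Map("w/o groupBy" -> 116.0, "w/ groupBy" -> 151.0)

  // Speedup of the PR branch over master for a given benchmark case.
  def speedup(name: String): Double = masterBestMs(name) / prBestMs(name)

  def main(args: Array[String]): Unit =
    masterBestMs.keys.foreach { name =>
      println(f"$name%-12s ${speedup(name)}%.1fx")
    }
}
```

By this measure the PR branch is roughly 2.9 times faster without `groupBy` and roughly 3.3 times faster with it.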

Benchmark code snippet:

```scala
  test("Hive UDAF benchmark") {
    val N = 1 << 20

    sparkSession.sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")

    val benchmark = new Benchmark(
      name = "hive udaf vs spark af",
      valuesPerIteration = N,
      minNumIters = 5,
      warmupTime = 5.seconds,
      minTime = 5.seconds,
      outputPerIteration = true
    )

    benchmark.addCase("w/o groupBy") { _ =>
      sparkSession.range(N).agg("id" -> "hive_max").collect()
    }

    benchmark.addCase("w/ groupBy") { _ =>
      sparkSession.range(N).groupBy($"id" % 10).agg("id" -> "hive_max").collect()
    }

    benchmark.run()

    sparkSession.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max")
  }
```

## How was this patch tested?

A new test suite, `HiveUDAFSuite`, is added.

Author: Cheng Lian <li...@databricks.com>

Closes #15703 from liancheng/partial-agg-hive-udaf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ca8ae9a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ca8ae9a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ca8ae9a

Branch: refs/heads/master
Commit: 2ca8ae9aa1b29bf1f46d0b656d9885e438e67f53
Parents: a36a76a
Author: Cheng Lian <li...@databricks.com>
Authored: Wed Nov 16 14:32:36 2016 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Wed Nov 16 14:32:36 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/hive/hiveUDFs.scala    | 199 ++++++++++++++-----
 .../sql/hive/execution/HiveUDAFSuite.scala      | 152 ++++++++++++++
 2 files changed, 301 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2ca8ae9a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 4203308..32edd4a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -17,16 +17,18 @@
 
 package org.apache.spark.sql.hive
 
+import java.nio.ByteBuffer
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.hive.ql.exec._
 import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
-import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector,
-  ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
 
 import org.apache.spark.internal.Logging
@@ -58,7 +60,7 @@ private[hive] case class HiveSimpleUDF(
 
   @transient
   private lazy val isUDFDeterministic = {
-    val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
+    val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
     udfType != null && udfType.deterministic()
   }
 
@@ -75,7 +77,7 @@ private[hive] case class HiveSimpleUDF(
 
   @transient
   lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
-    method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
+    method.getGenericReturnType, ObjectInspectorOptions.JAVA))
 
   @transient
   private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -263,8 +265,35 @@ private[hive] case class HiveGenericUDTF(
 }
 
 /**
- * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt
- * performance a lot.
+ * While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be in the following
+ * three formats:
+ *
+ *  1. An instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class
+ *
+ *     This is the native Hive representation of an aggregation state. Hive `GenericUDAFEvaluator`
+ *     methods like `iterate()`, `merge()`, `terminatePartial()`, and `terminate()` use this format.
+ *     We call these methods to evaluate Hive UDAFs.
+ *
+ *  2. A Java object that can be inspected using the `ObjectInspector` returned by the
+ *     `GenericUDAFEvaluator.init()` method.
+ *
+ *     Hive uses this format to produce a serializable aggregation state so that it can shuffle
+ *     partial aggregation results. Whenever we need to convert a Hive `AggregationBuffer` instance
+ *     into a Spark SQL value, we have to convert it to this format first and then do the conversion
+ *     with the help of `ObjectInspector`s.
+ *
+ *  3. A Spark SQL value
+ *
+ *     We use this format for serializing Hive UDAF aggregation states on Spark side. To be more
+ *     specific, we convert `AggregationBuffer`s into equivalent Spark SQL values, write them into
+ *     `UnsafeRow`s, and then retrieve the byte array behind those `UnsafeRow`s as serialization
+ *     results.
+ *
+ * We may use the following methods to convert the aggregation state back and forth:
+ *
+ *  - `wrap()`/`wrapperFor()`: from 3 to 2
+ *  - `unwrap()`/`unwrapperFor()`: from 2 to 3
+ *  - `GenericUDAFEvaluator.terminatePartial()`: from 1 to 2
  */
 private[hive] case class HiveUDAFFunction(
     name: String,
@@ -273,7 +302,7 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with HiveInspectors {
+  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
@@ -281,73 +310,73 @@ private[hive] case class HiveUDAFFunction(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
+  // Hive `ObjectInspector`s for all child expressions (input parameters of the function).
   @transient
-  private lazy val resolver =
-    if (isUDAFBridgeRequired) {
+  private lazy val inputInspectors = children.map(toInspector).toArray
+
+  // Spark SQL data types of input parameters.
+  @transient
+  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
+  private def newEvaluator(): GenericUDAFEvaluator = {
+    val resolver = if (isUDAFBridgeRequired) {
       new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
       funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }
 
-  @transient
-  private lazy val inspectors = children.map(toInspector).toArray
-
-  @transient
-  private lazy val functionAndInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
-    val f = resolver.getEvaluator(parameterInfo)
-    f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+    val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
+    resolver.getEvaluator(parameterInfo)
   }
 
+  // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
   @transient
-  private lazy val function = functionAndInspector._1
+  private lazy val partial1ModeEvaluator = newEvaluator()
 
+  // Hive `ObjectInspector` used to inspect partial aggregation results.
   @transient
-  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+  private val partialResultInspector = partial1ModeEvaluator.init(
+    GenericUDAFEvaluator.Mode.PARTIAL1,
+    inputInspectors
+  )
 
+  // The UDAF evaluator used to merge partial aggregation results.
   @transient
-  private lazy val returnInspector = functionAndInspector._2
+  private lazy val partial2ModeEvaluator = {
+    val evaluator = newEvaluator()
+    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
+    evaluator
+  }
 
+  // Spark SQL data type of partial aggregation results
   @transient
-  private lazy val unwrapper = unwrapperFor(returnInspector)
+  private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
 
+  // The UDAF evaluator used to compute the final result from partial aggregation result objects.
   @transient
-  private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
-
-  override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer))
+  private lazy val finalModeEvaluator = newEvaluator()
 
+  // Hive `ObjectInspector` used to inspect the final aggregation result object.
   @transient
-  private lazy val inputProjection = new InterpretedProjection(children)
+  private val returnInspector = finalModeEvaluator.init(
+    GenericUDAFEvaluator.Mode.FINAL,
+    Array(partialResultInspector)
+  )
 
+  // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
-  private lazy val cached = new Array[AnyRef](children.length)
+  private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
 
+  // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
+  // Spark SQL specific format.
   @transient
-  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
-
-  // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
-  // buffer for it.
-  override def aggBufferSchema: StructType = StructType(Nil)
-
-  override def update(_buffer: InternalRow, input: InternalRow): Unit = {
-    val inputs = inputProjection(input)
-    function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes))
-  }
-
-  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
-    throw new UnsupportedOperationException(
-      "Hive UDAF doesn't support partial aggregate")
-  }
+  private lazy val resultUnwrapper = unwrapperFor(returnInspector)
 
-  override def initialize(_buffer: InternalRow): Unit = {
-    buffer = function.getNewAggregationBuffer
-  }
-
-  override val aggBufferAttributes: Seq[AttributeReference] = Nil
+  @transient
+  private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
 
-  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
-  // in the superclass because that will lead to initialization ordering issues.
-  override val inputAggBufferAttributes: Seq[AttributeReference] = Nil
+  @transient
+  private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe
 
   // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
   // catalyst type checking framework.
@@ -355,7 +384,7 @@ private[hive] case class HiveUDAFFunction(
 
   override def nullable: Boolean = true
 
-  override def supportsPartial: Boolean = false
+  override def supportsPartial: Boolean = true
 
   override lazy val dataType: DataType = inspectorToDataType(returnInspector)
 
@@ -365,4 +394,74 @@ private[hive] case class HiveUDAFFunction(
     val distinct = if (isDistinct) "DISTINCT " else " "
     s"$name($distinct${children.map(_.sql).mkString(", ")})"
   }
+
+  override def createAggregationBuffer(): AggregationBuffer =
+    partial1ModeEvaluator.getNewAggregationBuffer
+
+  @transient
+  private lazy val inputProjection = UnsafeProjection.create(children)
+
+  override def update(buffer: AggregationBuffer, input: InternalRow): Unit = {
+    partial1ModeEvaluator.iterate(
+      buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+  }
+
+  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = {
+    // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
+    // buffer in the 2nd format mentioned in the ScalaDoc of this class. Originally, Hive converts
+    // `AggregationBuffer`s into this format before shuffling partial aggregation results by
+    // calling `GenericUDAFEvaluator.terminatePartial()`.
+    partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
+  }
+
+  override def eval(buffer: AggregationBuffer): Any = {
+    resultUnwrapper(finalModeEvaluator.terminate(buffer))
+  }
+
+  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+    // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
+    // shuffle it for global aggregation later.
+    aggBufferSerDe.serialize(buffer)
+  }
+
+  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+    // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
+    // for global aggregation by merging multiple partial aggregation results within a single group.
+    aggBufferSerDe.deserialize(bytes)
+  }
+
+  // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
+  private class AggregationBufferSerDe {
+    private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
+
+    private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
+
+    private val projection = UnsafeProjection.create(Array(partialResultDataType))
+
+    private val mutableRow = new GenericInternalRow(1)
+
+    def serialize(buffer: AggregationBuffer): Array[Byte] = {
+      // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
+      // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
+      // Then we can unwrap it to a Spark SQL value.
+      mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
+      val unsafeRow = projection(mutableRow)
+      val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
+      unsafeRow.writeTo(bytes)
+      bytes.array()
+    }
+
+    def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+      // `GenericUDAFEvaluator` doesn't provide any method that is capable of converting an object
+      // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
+      // workaround here is to create an initial `AggregationBuffer` first and then merge the
+      // deserialized object into the buffer.
+      val buffer = partial2ModeEvaluator.getNewAggregationBuffer
+      val unsafeRow = new UnsafeRow(1)
+      unsafeRow.pointTo(bytes, bytes.length)
+      val partialResult = unsafeRow.get(0, partialResultDataType)
+      partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
+      buffer
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2ca8ae9a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
new file mode 100644
index 0000000..c9ef72e
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode}
+import org.apache.hadoop.hive.ql.util.JavaDataModel
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
+  import testImplicits._
+
+  protected override def beforeAll(): Unit = {
+    sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
+    sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+
+    Seq(
+      (0: Integer) -> "val_0",
+      (1: Integer) -> "val_1",
+      (2: Integer) -> null,
+      (3: Integer) -> null
+    ).toDF("key", "value").repartition(2).createOrReplaceTempView("t")
+  }
+
+  protected override def afterAll(): Unit = {
+    sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock")
+    sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max")
+  }
+
+  test("built-in Hive UDAF") {
+    val df = sql("SELECT key % 2, hive_max(key) FROM t GROUP BY key % 2")
+
+    val aggs = df.queryExecution.executedPlan.collect {
+      case agg: ObjectHashAggregateExec => agg
+    }
+
+    // There should be two aggregate operators, one for partial aggregation, and the other for
+    // global aggregation.
+    assert(aggs.length == 2)
+
+    checkAnswer(df, Seq(
+      Row(0, 2),
+      Row(1, 3)
+    ))
+  }
+
+  test("customized Hive UDAF") {
+    val df = sql("SELECT key % 2, mock(value) FROM t GROUP BY key % 2")
+
+    val aggs = df.queryExecution.executedPlan.collect {
+      case agg: ObjectHashAggregateExec => agg
+    }
+
+    // There should be two aggregate operators, one for partial aggregation, and the other for
+    // global aggregation.
+    assert(aggs.length == 2)
+
+    checkAnswer(df, Seq(
+      Row(0, Row(1, 1)),
+      Row(1, Row(1, 1))
+    ))
+  }
+}
+
+/**
+ * A testing Hive UDAF that computes the counts of both non-null values and nulls of a given column.
+ */
+class MockUDAF extends AbstractGenericUDAFResolver {
+  override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
+}
+
+class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
+  extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+  override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
+class MockUDAFEvaluator extends GenericUDAFEvaluator {
+  private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val bufferOI = {
+    val fieldNames = Seq("nonNullCount", "nullCount").asJava
+    val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+  }
+
+  private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+  private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+  override def getNewAggregationBuffer: AggregationBuffer = new MockUDAFBuffer(0L, 0L)
+
+  override def reset(agg: AggregationBuffer): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    buffer.nonNullCount = 0L
+    buffer.nullCount = 0L
+  }
+
+  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = bufferOI
+
+  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    if (parameters.head eq null) {
+      buffer.nullCount += 1L
+    } else {
+      buffer.nonNullCount += 1L
+    }
+  }
+
+  override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+    if (partial ne null) {
+      val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+      val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+      val buffer = agg.asInstanceOf[MockUDAFBuffer]
+      buffer.nonNullCount += nonNullCount
+      buffer.nullCount += nullCount
+    }
+  }
+
+  override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+  }
+
+  override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
+}

