You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/01/12 07:19:12 UTC

[spark] branch master updated: [SPARK-27296][SQL] Allows Aggregator to be registered as a UDF

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f50a58  [SPARK-27296][SQL] Allows Aggregator to be registered as a UDF
1f50a58 is described below

commit 1f50a5875b46885a40668c058a1a28e736776244
Author: Erik Erlandson <ee...@redhat.com>
AuthorDate: Sun Jan 12 15:18:30 2020 +0800

    [SPARK-27296][SQL] Allows Aggregator to be registered as a UDF
    
    ## What changes were proposed in this pull request?
    Defines a new subclass of UDF, `UserDefinedAggregator`, and allows an `Aggregator` to be registered as a UDF. Under the hood, the implementation is based on the internal `TypedImperativeAggregate` class that Spark's predefined aggregators use. As a result, custom user-defined aggregators are now serialized only at partition boundaries, instead of being serialized and deserialized for every input row.
    
    The two new modes of using `Aggregator` are as follows:
    ```scala
    val agg: Aggregator[IN, BUF, OUT] = // typed aggregator
    val udaf1 = functions.udaf(agg)
    val udaf2 = spark.udf.register("agg", functions.udaf(agg))
    ```
    
    ## How was this patch tested?
    Unit tests have been added that mirror the existing test suites for `UserDefinedAggregateFunction`. In addition, a unit test explicitly counts the aggregator's ser/de cycles to verify that their number depends only on the number of data partitions.
    
    To evaluate the performance impact, I did two comparisons.
    The code and REPL results are recorded on [this gist](https://gist.github.com/erikerlandson/b0e106a4dbaf7f80b4f4f3a21f05f892)
    To characterize its behavior I benchmarked both a relatively simple aggregator and then an aggregator with a complex structure (a t-digest).
    
    ### performance
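    The following compares the new `Aggregator`-based aggregation against a UDAF. In this scenario, the new aggregation is about 100x faster (roughly 17.8 vs. 0.18 per sample in the timings below). The size of the speedup depends on the complexity of the aggregator: for very simple aggregators (e.g. one implementing 'sum'), the improvement is more like 25-30%. Note that this transcript was recorded with an earlier form of the API, in which the `Aggregator` itself was passed to `spark.udf.register`; with the final API, register `udaf(agg)` instead.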
    
    ```scala
    scala> import scala.util.Random._, org.apache.spark.sql.Row, org.apache.spark.tdigest._
    import scala.util.Random._
    import org.apache.spark.sql.Row
    import org.apache.spark.tdigest._
    
    scala> val data = sc.parallelize(Vector.fill(50000){(nextInt(2), nextGaussian, nextGaussian.toFloat)}, 5).toDF("cat", "x1", "x2")
    data: org.apache.spark.sql.DataFrame = [cat: int, x1: double ... 1 more field]
    
    scala> val udaf = TDigestUDAF(0.5, 0)
    udaf: org.apache.spark.tdigest.TDigestUDAF = TDigestUDAF(0.5,0)
    
    scala> val bs = Benchmark.sample(10) { data.agg(udaf($"x1"), udaf($"x2")).first }
    bs: Array[(Double, org.apache.spark.sql.Row)] = Array((16.523,[TDigestSQL(TDigest(0.5,0,130,TDigestMap(-4.9171836327285225 -> (1.0, 1.0), -3.9615949140987685 -> (1.0, 2.0), -3.792874086327091 -> (0.7500781537109753, 2.7500781537109753), -3.720534874164185 -> (1.796754196108008, 4.546832349818983), -3.702105588052377 -> (0.4531676501810167, 5.0), -3.665883591332569 -> (2.3434687534153142, 7.343468753415314), -3.649982231368131 -> (0.6565312465846858, 8.0), -3.5914188829817744 -> (4.0,  [...]
    
    scala> bs.map(_._1)
    res0: Array[Double] = Array(16.523, 17.138, 17.863, 17.801, 17.769, 17.786, 17.744, 17.8, 17.939, 17.854)
    
    scala> val agg = TDigestAggregator(0.5, 0)
    agg: org.apache.spark.tdigest.TDigestAggregator = TDigestAggregator(0.5,0)
    
    scala> val udaa = spark.udf.register("tdigest", agg)
    udaa: org.apache.spark.sql.expressions.UserDefinedAggregator[Double,org.apache.spark.tdigest.TDigestSQL,org.apache.spark.tdigest.TDigestSQL] = UserDefinedAggregator(TDigestAggregator(0.5,0),None,true,true)
    
    scala> val bs = Benchmark.sample(10) { data.agg(udaa($"x1"), udaa($"x2")).first }
    bs: Array[(Double, org.apache.spark.sql.Row)] = Array((0.313,[TDigestSQL(TDigest(0.5,0,130,TDigestMap(-4.9171836327285225 -> (1.0, 1.0), -3.9615949140987685 -> (1.0, 2.0), -3.792874086327091 -> (0.7500781537109753, 2.7500781537109753), -3.720534874164185 -> (1.796754196108008, 4.546832349818983), -3.702105588052377 -> (0.4531676501810167, 5.0), -3.665883591332569 -> (2.3434687534153142, 7.343468753415314), -3.649982231368131 -> (0.6565312465846858, 8.0), -3.5914188829817744 -> (4.0, 1 [...]
    
    scala> bs.map(_._1)
    res1: Array[Double] = Array(0.313, 0.193, 0.175, 0.185, 0.174, 0.176, 0.16, 0.186, 0.171, 0.179)
    
    scala>
    ```
    
    Closes #25024 from erikerlandson/spark-27296.
    
    Authored-by: Erik Erlandson <ee...@redhat.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../org/apache/spark/sql/UDFRegistration.scala     |  18 +-
 .../spark/sql/execution/aggregate/udaf.scala       |  72 +++-
 .../sql/expressions/UserDefinedFunction.scala      |  48 ++-
 .../scala/org/apache/spark/sql/functions.scala     |  64 +++-
 .../spark/sql/DataFrameWindowFunctionsSuite.scala  |  38 +-
 .../spark/sql/hive/execution/UDAQuerySuite.scala   | 417 +++++++++++++++++++++
 6 files changed, 643 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index bb05c76..a4ff095 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -28,10 +28,11 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
-import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedFunction}
+import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
 
@@ -101,9 +102,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 2.2.0
    */
   def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
-    def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
-    functionRegistry.createOrReplaceTempFunction(name, builder)
-    udf
+    // An aggregator registers its bare ScalaAggregator; any other UDF registers the
+    // expression of the Column returned by `udf.apply`.
+    val builder: Seq[Expression] => Expression = udf match {
+      case aggregatorUdf: UserDefinedAggregator[_, _, _] =>
+        (children: Seq[Expression]) => aggregatorUdf.scalaAggregator(children)
+      case _ =>
+        (children: Seq[Expression]) => udf.apply(children.map(Column.apply) : _*).expr
+    }
+    functionRegistry.createOrReplaceTempFunction(name, builder)
+    udf
   }
 
   // scalastyle:off line.size.limit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 100486f..dfae5c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _}
-import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection}
+import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.types._
 
 /**
@@ -450,3 +454,63 @@ case class ScalaUDAF(
 
   override def nodeName: String = udaf.getClass.getSimpleName
 }
+/** Evaluates a typed [[Aggregator]] as a [[TypedImperativeAggregate]] over `children`. */
+case class ScalaAggregator[IN, BUF, OUT](
+    children: Seq[Expression],
+    agg: Aggregator[IN, BUF, OUT],
+    inputEncoderNR: ExpressionEncoder[IN], // resolved and bound lazily (see inputEncoder)
+    nullable: Boolean = true,
+    isDeterministic: Boolean = true,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[BUF]
+  with NonSQLExpression
+  with UserDefinedExpression
+  with ImplicitCastInputTypes
+  with Logging {
+
+  private[this] lazy val inputEncoder = inputEncoderNR.resolveAndBind()
+  private[this] lazy val bufferEncoder =
+    agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
+  private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
+
+  def dataType: DataType = outputEncoder.objSerializer.dataType
+
+  def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)
+
+  override lazy val deterministic: Boolean = isDeterministic
+
+  def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ScalaAggregator[IN, BUF, OUT] =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  private[this] lazy val inputProjection = UnsafeProjection.create(children)
+  // Buffers start at `agg.zero`; each input row is projected from `children` and decoded to IN.
+  def createAggregationBuffer(): BUF = agg.zero
+
+  def update(buffer: BUF, input: InternalRow): BUF =
+    agg.reduce(buffer, inputEncoder.fromRow(inputProjection(input)))
+
+  def merge(buffer: BUF, input: BUF): BUF = agg.merge(buffer, input)
+  // A struct-serialized OUT is returned as the row itself; otherwise column 0 is returned.
+  def eval(buffer: BUF): Any = {
+    val row = outputEncoder.toRow(agg.finish(buffer))
+    if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
+  }
+
+  private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)
+  // Buffers are (de)serialized as UnsafeRow bytes; deserialize reuses the single `bufferRow`.
+  def serialize(buffer: BUF): Array[Byte] =
+    bufferEncoder.toRow(buffer).asInstanceOf[UnsafeRow].getBytes()
+
+  def deserialize(storageFormat: Array[Byte]): BUF = {
+    bufferRow.pointTo(storageFormat, storageFormat.length)
+    bufferEncoder.fromRow(bufferRow)
+  }
+
+  override def toString: String = s"$nodeName(${children.mkString(",")})"
+
+  override def nodeName: String = agg.getClass.getSimpleName
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 0c956ec..85b2cd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -17,10 +17,15 @@
 
 package org.apache.spark.sql.expressions
 
-import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.Column
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.annotation.{Experimental, Stable}
+import org.apache.spark.sql.{Column, Encoder}
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
+import org.apache.spark.sql.execution.aggregate.ScalaAggregator
 import org.apache.spark.sql.types.{AnyDataType, DataType}
 
 /**
@@ -136,3 +141,42 @@ private[sql] case class SparkUserDefinedFunction(
     }
   }
 }
+
+/**
+ * A [[UserDefinedFunction]] backed by a typed [[Aggregator]].
+ *
+ * @param aggregator the typed aggregator to evaluate
+ * @param inputEncoder encoder for IN; must be an [[ExpressionEncoder]] (it is cast to one)
+ * @param name optional name recorded by `withName`
+ * @param nullable forwarded to the [[ScalaAggregator]] built by `scalaAggregator`
+ * @param deterministic forwarded to the [[ScalaAggregator]] built by `scalaAggregator`
+ */
+private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
+    aggregator: Aggregator[IN, BUF, OUT],
+    inputEncoder: Encoder[IN],
+    name: Option[String] = None,
+    nullable: Boolean = true,
+    deterministic: Boolean = true) extends UserDefinedFunction {
+
+  /** Wraps the aggregator in a non-distinct, `Complete`-mode [[AggregateExpression]]. */
+  @scala.annotation.varargs
+  def apply(exprs: Column*): Column = {
+    val aggregateFunction = scalaAggregator(exprs.map(_.expr))
+    Column(AggregateExpression(aggregateFunction, Complete, isDistinct = false))
+  }
+
+  /** Builds the bare aggregate function; `udf.register(...)` also calls this directly. */
+  def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] =
+    ScalaAggregator(
+      exprs, aggregator, inputEncoder.asInstanceOf[ExpressionEncoder[IN]], nullable, deterministic)
+
+  // NOTE(review): `name` is not passed to ScalaAggregator, so withName has no effect on it.
+  override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] =
+    copy(name = Option(name))
+
+  override def asNonNullable(): UserDefinedAggregator[IN, BUF, OUT] =
+    if (nullable) copy(nullable = false) else this
+
+  override def asNondeterministic(): UserDefinedAggregator[IN, BUF, OUT] =
+    if (deterministic) copy(deterministic = false) else this
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 59dbe3e..fde6d3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -32,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
 import org.apache.spark.sql.execution.SparkSqlParser
-import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction}
+import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-
 /**
  * Commonly used functions available for DataFrame operations. Using functions defined here provides
  * a little bit more compile-time safety to make sure the function exists.
@@ -4232,6 +4231,67 @@ object functions {
   //////////////////////////////////////////////////////////////////////////////////////////////
 
   /**
+   * Returns a `UserDefinedFunction` wrapping a typed `Aggregator` for use on untyped DataFrames.
+   * {{{
+   *   val agg = // Aggregator[IN, BUF, OUT]
+   *
+   *   // apply agg directly to DataFrame columns
+   *   val aggUDF = udaf(agg)
+   *   val aggData = df.agg(aggUDF($"colname"))
+   *
+   *   // or register it under a name for use in SQL
+   *   spark.udf.register("myAggName", udaf(agg))
+   * }}}
+   *
+   * @tparam IN the aggregator input type
+   * @tparam BUF the aggregating buffer type
+   * @tparam OUT the finalized output type
+   *
+   * @param agg the typed Aggregator
+   *
+   * @return a UserDefinedFunction that can be used as an aggregating expression.
+   *
+   * @note The input encoder is derived from IN's `TypeTag` via `ExpressionEncoder`.
+   */
+  def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction = {
+    val inputEncoder: Encoder[IN] = ExpressionEncoder[IN]()
+    udaf(agg, inputEncoder)
+  }
+
+  /**
+   * Obtains a `UserDefinedFunction` that wraps the given `Aggregator`
+   * so that it may be used with untyped DataFrames.
+   * {{{
+   *   Aggregator<IN, BUF, OUT> agg = // custom Aggregator
+   *   Encoder<IN> enc = // input encoder
+   *
+   *   // declare a UDF based on agg
+   *   UserDefinedFunction aggUDF = udaf(agg, enc);
+   *   Dataset<Row> aggData = df.agg(aggUDF.apply(col("colname")));
+   *
+   *   // register agg as a named function
+   *   spark.udf().register("myAggName", udaf(agg, enc));
+   * }}}
+   *
+   * @tparam IN the aggregator input type
+   * @tparam BUF the aggregating buffer type
+   * @tparam OUT the finalized output type
+   *
+   * @param agg the typed Aggregator
+   * @param inputEncoder a specific input encoder to use
+   *
+   * @return a UserDefinedFunction that can be used as an aggregating expression.
+   *
+   * @note This overloading takes an explicit input encoder, to support UDAF
+   * declarations in Java.
+   */
+  def udaf[IN, BUF, OUT](
+      agg: Aggregator[IN, BUF, OUT],
+      inputEncoder: Encoder[IN]): UserDefinedFunction = {
+    UserDefinedAggregator(agg, inputEncoder)
+  }
+
+  /**
    * Defines a Scala closure of 0 arguments as user-defined function (UDF).
    * The data types are automatically inferred based on the Scala closure's
    * signature. By default the returned UDF is deterministic. To change it to
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 696b056..2e37879 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.Matchers.the
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
 import org.apache.spark.sql.execution.exchange.Exchange
-import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
+import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -412,6 +412,42 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSparkSession {
         Row("b", 2, 4, 8)))
   }
 
+  test("window function with aggregator") {
+    // Running sum of a * b per key, over rows whose `a` is <= the current row's `a`.
+    val productSum = udaf(new Aggregator[(Long, Long), Long, Long] {
+      def zero: Long = 0L
+      def reduce(acc: Long, pair: (Long, Long)): Long = acc + pair._1 * pair._2
+      def merge(left: Long, right: Long): Long = left + right
+      def finish(acc: Long): Long = acc
+      def bufferEncoder: Encoder[Long] = Encoders.scalaLong
+      def outputEncoder: Encoder[Long] = Encoders.scalaLong
+    })
+
+    val input = Seq(
+      ("a", 1, 1),
+      ("a", 1, 5),
+      ("a", 2, 10),
+      ("a", 2, -1),
+      ("b", 4, 7),
+      ("b", 3, 8),
+      ("b", 2, 4)).toDF("key", "a", "b")
+    val runningWindow = Window.partitionBy($"key").orderBy($"a")
+      .rangeBetween(Window.unboundedPreceding, Window.currentRow)
+
+    val expected = Seq(
+      Row("a", 1, 1, 6),
+      Row("a", 1, 5, 6),
+      Row("a", 2, 10, 24),
+      Row("a", 2, -1, 24),
+      Row("b", 4, 7, 60),
+      Row("b", 3, 8, 32),
+      Row("b", 2, 4, 8))
+
+    checkAnswer(
+      input.select($"key", $"a", $"b", productSum($"a", $"b").over(runningWindow)),
+      expected)
+  }
+
   test("null inputs") {
     val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2))
       .toDF("key", "value")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
new file mode 100644
index 0000000..e6856a5
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.lang.{Double => jlDouble, Integer => jlInt, Long => jlLong}
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import test.org.apache.spark.sql.MyDoubleAvg
+import test.org.apache.spark.sql.MyDoubleSum
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.expressions.{Aggregator}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
+
+/** Skips null inputs; `finish` reports 100 + their average, or null if there were none. */
+class MyDoubleAvgAggBase extends Aggregator[jlDouble, (Double, Long), jlDouble] {
+  def zero: (Double, Long) = (0.0, 0L)
+  def reduce(b: (Double, Long), a: jlDouble): (Double, Long) =
+    Option(a).fold(b)(x => (b._1 + x, b._2 + 1L))
+  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
+    (b1._1 + b2._1, b1._2 + b2._2)
+  def finish(r: (Double, Long)): jlDouble =
+    if (r._2 <= 0L) null else 100.0 + r._1 / r._2.toDouble
+  def bufferEncoder: Encoder[(Double, Long)] =
+    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
+  def outputEncoder: Encoder[jlDouble] = Encoders.DOUBLE
+}
+
+object MyDoubleAvgAgg extends MyDoubleAvgAggBase // 100 + average of the non-null inputs
+object MyDoubleSumAgg extends MyDoubleAvgAggBase { // sum of non-null inputs, or null if none
+  override def finish(r: (Double, Long)): jlDouble = if (r._2 <= 0L) null else r._1
+}
+
+/** Sums a * b over (a, b) pairs, skipping any pair that contains a null. */
+object LongProductSumAgg extends Aggregator[(jlLong, jlLong), Long, jlLong] {
+  def zero: Long = 0L
+  def reduce(b: Long, a: (jlLong, jlLong)): Long =
+    if (a._1 == null || a._2 == null) b else b + a._1 * a._2
+  def merge(b1: Long, b2: Long): Long = b1 + b2
+  def finish(r: Long): jlLong = r
+  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
+  def outputEncoder: Encoder[jlLong] = Encoders.LONG
+}
+
+@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
+case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int) // ser/de counters plus a running sum
+
+/**
+ * UDT for [[CountSerDeSQL]]: serialize increments nSer and deserialize increments nDeSer,
+ * so tests can count how many ser/de round trips an aggregation buffer went through.
+ */
+class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
+  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]
+
+  override def typeName: String = "count-ser-de"
+
+  private[spark] override def asNullable: CountSerDeUDT = this
+
+  def sqlType: DataType = StructType(Seq(
+    StructField("nSer", IntegerType, nullable = false),
+    StructField("nDeSer", IntegerType, nullable = false),
+    StructField("sum", IntegerType, nullable = false)))
+
+  def serialize(sql: CountSerDeSQL): Any = {
+    val fields = Array[Any](1 + sql.nSer, sql.nDeSer, sql.sum)
+    new GenericInternalRow(fields)
+  }
+
+  def deserialize(any: Any): CountSerDeSQL = any match {
+    case row: InternalRow if row.numFields == 3 =>
+      CountSerDeSQL(row.getInt(0), 1 + row.getInt(1), row.getInt(2))
+    case u =>
+      throw new Exception(s"failed to deserialize: $u")
+  }
+
+  override def equals(obj: Any): Boolean = obj match {
+    case _: CountSerDeUDT => true
+    case _ => false
+  }
+
+  override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()
+}
+
+// Singleton instance of the UDT.
+case object CountSerDeUDT extends CountSerDeUDT
+/** Sums Int inputs into `sum`; `merge` also adds up the nSer/nDeSer counters. */
+object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
+  def zero: CountSerDeSQL = CountSerDeSQL(nSer = 0, nDeSer = 0, sum = 0)
+  def reduce(b: CountSerDeSQL, a: Int): CountSerDeSQL = b.copy(sum = a + b.sum)
+  def merge(b1: CountSerDeSQL, b2: CountSerDeSQL): CountSerDeSQL = CountSerDeSQL(
+    nSer = b1.nSer + b2.nSer, nDeSer = b1.nDeSer + b2.nDeSer, sum = b1.sum + b2.sum)
+  def finish(r: CountSerDeSQL): CountSerDeSQL = r
+  def bufferEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
+  def outputEncoder: Encoder[CountSerDeSQL] = bufferEncoder
+}
+/** Tests for typed Aggregators wrapped with `udaf(...)`, used from SQL and DataFrames. */
+abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+  import testImplicits._
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val data1 = Seq[(Integer, Integer)](
+      (1, 10),
+      (null, -60),
+      (1, 20),
+      (1, 30),
+      (2, 0),
+      (null, -10),
+      (2, -1),
+      (2, null),
+      (2, null),
+      (null, 100),
+      (3, null),
+      (null, null),
+      (3, null)).toDF("key", "value")
+    data1.write.saveAsTable("agg1")
+
+    val data2 = Seq[(Integer, Integer, Integer)](
+      (1, 10, -10),
+      (null, -60, 60),
+      (1, 30, -30),
+      (1, 30, 30),
+      (2, 1, 1),
+      (null, -10, 10),
+      (2, -1, null),
+      (2, 1, 1),
+      (2, null, 1),
+      (null, 100, -10),
+      (3, null, 3),
+      (null, null, null),
+      (3, null, null)).toDF("key", "value1", "value2")
+    data2.write.saveAsTable("agg2")
+
+    val data3 = Seq[(Seq[Integer], Integer, Integer)](
+      (Seq[Integer](1, 1), 10, -10),
+      (Seq[Integer](null), -60, 60),
+      (Seq[Integer](1, 1), 30, -30),
+      (Seq[Integer](1), 30, 30),
+      (Seq[Integer](2), 1, 1),
+      (null, -10, 10),
+      (Seq[Integer](2, 3), -1, null),
+      (Seq[Integer](2, 3), 1, 1),
+      (Seq[Integer](2, 3, 4), null, 1),
+      (Seq[Integer](null), 100, -10),
+      (Seq[Integer](3), null, 3),
+      (null, null, null),
+      (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+    data3.write.saveAsTable("agg3")
+
+    val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
+    data4.write.saveAsTable("agg4")
+
+    val emptyDF = spark.createDataFrame(
+      sparkContext.emptyRDD[Row],
+      StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
+    emptyDF.createOrReplaceTempView("emptyTable")
+
+    // Register the typed aggregators as SQL functions, each wrapped with udaf(...)
+    spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg))
+    spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg)) // finish reports 100 + average
+    spark.udf.register("longProductSum", udaf(LongProductSumAgg))
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      spark.sql("DROP TABLE IF EXISTS agg1")
+      spark.sql("DROP TABLE IF EXISTS agg2")
+      spark.sql("DROP TABLE IF EXISTS agg3")
+      spark.sql("DROP TABLE IF EXISTS agg4")
+      spark.catalog.dropTempView("emptyTable")
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  test("aggregators") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoublesum(value + 1.5 * key),
+          |  mydoubleavg(value),
+          |  avg(value - key),
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) :: // mydoubleavg = 100 + avg(value)
+        Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) ::
+        Row(3, null, null, null, null, null) ::
+        Row(null, null, 110.0, null, null, 10.0) :: Nil)
+  }
+
+  test("non-deterministic children expressions of aggregator") {
+    val e = intercept[AnalysisException] {
+      spark.sql(
+        """
+          |SELECT mydoublesum(value + 1.5 * key + rand())
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin)
+    }.getMessage
+    assert(Seq("nondeterministic expression",
+      "should not appear in the arguments of an aggregate function").forall(e.contains))
+  }
+
+  test("interpreted aggregate function") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value), key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
+
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value) FROM agg1
+        """.stripMargin),
+      Row(89.0) :: Nil)
+
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(null)
+        """.stripMargin),
+      Row(null) :: Nil)
+  }
+
+  test("interpreted and expression-based aggregation functions") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value), key, avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1, 20.0) ::
+        Row(-1.0, 2, -0.5) ::
+        Row(null, 3, null) ::
+        Row(30.0, null, 10.0) :: Nil)
+
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  mydoublesum(value + 1.5 * key),
+          |  avg(value - key),
+          |  key,
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(64.5, 19.0, 1, 55.5, 20.0) ::
+        Row(5.0, -2.5, 2, -7.0, -0.5) ::
+        Row(null, null, 3, null, null) ::
+        Row(null, null, null, null, 10.0) :: Nil)
+  }
+
+  test("single distinct column set") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  mydoubleavg(distinct value1),
+          |  avg(value1),
+          |  avg(value2),
+          |  key,
+          |  mydoubleavg(value1 - 1),
+          |  mydoubleavg(distinct value1) * 0.1,
+          |  avg(value1 + value2)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+        Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+        Row(null, null, 3.0, 3, null, null, null) ::
+        Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoubleavg(distinct value1),
+          |  mydoublesum(value2),
+          |  mydoublesum(distinct value1),
+          |  mydoubleavg(distinct value1),
+          |  mydoubleavg(value1)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+        Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+        Row(3, null, 3.0, null, null, null) ::
+        Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+  }
+
+  test("multiple distinct multiple columns sets") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  count(distinct value1),
+          |  sum(distinct value1),
+          |  count(distinct value2),
+          |  sum(distinct value2),
+          |  count(distinct value1, value2),
+          |  longProductSum(distinct value1, value2),
+          |  count(value1),
+          |  sum(value1),
+          |  count(value2),
+          |  sum(value2),
+          |  longProductSum(value1, value2),
+          |  count(*),
+          |  count(1)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
+        Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
+        Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
+        Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
+  }
+
+  test("verify aggregator ser/de behavior") {
+    val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1") // 3 partitions
+    val agg = udaf(CountSerDeAgg)
+    checkAnswer(
+      data.agg(agg($"value1")),
+      Row(CountSerDeSQL(4, 4, 5050)) :: Nil) // nSer, nDeSer, sum of 1..100
+  }
+
+  test("verify type casting failure") {
+    assertThrows[org.apache.spark.sql.AnalysisException] {
+      spark.sql(
+        """
+          |SELECT mydoublesum(boolvalues) FROM agg4
+        """.stripMargin)
+    }
+  }
+}
+
+class HashUDAQuerySuite extends UDAQuerySuite // runs the UDAQuerySuite tests with no overrides
+/** [[UDAQuerySuite]] that also checks each answer under controlled aggregation fallback. */
+class HashUDAQueryWithControlledFallbackSuite extends UDAQuerySuite {
+  // Re-checks each answer with two-level maps on/off and fallbackStartsAt values 1 to 3.
+  override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
+    super.checkAnswer(actual, expectedAnswer)
+    Seq("true", "false").foreach { enableTwoLevelMaps =>
+      withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enabled" ->
+        enableTwoLevelMaps) {
+        (1 to 3).foreach { fallbackStartsAt =>
+          withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
+            s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
+            QueryTest.getErrorMessageInCheckAnswer(actual, expectedAnswer) match {
+              case Some(errorMessage) =>
+                val newErrorMessage =
+                  s"""
+                     |The following aggregation query failed when using HashAggregate with
+                     |controlled fallback (it falls back to bytes to bytes map once it has processed
+                     |${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has
+                     |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
+                     |
+                    |$errorMessage
+                  """.stripMargin
+
+                fail(newErrorMessage)
+              case None => // Success
+            }
+          }
+        }
+      }
+    }
+  }
+
+  // Override it to make sure we call the actually overridden checkAnswer.
+  override protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
+    checkAnswer(df, Seq(expectedAnswer))
+  }
+
+  // Override it to make sure we call the actually overridden checkAnswer.
+  override protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = {
+    checkAnswer(df, expectedAnswer.collect())
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org