Posted to commits@spark.apache.org by yh...@apache.org on 2015/09/24 18:54:14 UTC

spark git commit: [SPARK-10765] [SQL] use new aggregate interface for hive UDAF

Repository: spark
Updated Branches:
  refs/heads/master 02144d674 -> 341b13f8f


[SPARK-10765] [SQL] use new aggregate interface for hive UDAF

Author: Wenchen Fan <cl...@163.com>

Closes #8874 from cloud-fan/hive-agg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/341b13f8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/341b13f8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/341b13f8

Branch: refs/heads/master
Commit: 341b13f8f5eb118f1fb4d4f84418715ac4750a4d
Parents: 02144d6
Author: Wenchen Fan <cl...@163.com>
Authored: Thu Sep 24 09:54:07 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Sep 24 09:54:07 2015 -0700

----------------------------------------------------------------------
 .../expressions/aggregate/interfaces.scala      |   7 +-
 .../spark/sql/execution/SparkStrategies.scala   |  14 +-
 .../aggregate/AggregationIterator.scala         |   2 +-
 .../spark/sql/execution/aggregate/utils.scala   |  51 +++++++
 .../org/apache/spark/sql/hive/hiveUDFs.scala    | 139 ++++++++-----------
 .../spark/sql/hive/execution/HiveUDFSuite.scala |   2 +-
 6 files changed, 129 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/341b13f8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 576d8c7..d869953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -169,6 +168,12 @@ abstract class AggregateFunction2
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
     throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+  /**
+   * Indicates whether this function supports partial aggregation.
+   * Currently, Hive UDAF is the only function that doesn't support partial aggregation.
+   */
+  def supportsPartial: Boolean = true
 }
 
 /**

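The supportsPartial flag added above lets the planner ask an aggregate
function whether its partial buffers can be merged across partitions. A
minimal sketch of the pattern, using simplified stand-ins rather than the
real AggregateFunction2 contract:

    // Sketch only: stand-ins for AggregateFunction2 and HiveUDAFFunction.
    abstract class SketchAggregateFunction {
      // Most built-in aggregates can merge partial buffers, so default to true.
      def supportsPartial: Boolean = true
    }

    class SketchHiveUDAF extends SketchAggregateFunction {
      // Hive UDAFs keep opaque state in Hive's own buffer, so partial
      // aggregation is disabled (see HiveUDAFFunction further down).
      override def supportsPartial: Boolean = false
    }
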
http://git-wip-us.apache.org/repos/asf/spark/blob/341b13f8/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 41b215c..b078c8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -221,7 +221,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             }
 
             val aggregateOperator =
-              if (functionsWithDistinct.isEmpty) {
+              if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+                if (functionsWithDistinct.nonEmpty) {
+                  sys.error("Distinct columns cannot exist in an Aggregate operator containing " +
+                    "aggregate functions that don't support partial aggregation.")
+                } else {
+                  aggregate.Utils.planAggregateWithoutPartial(
+                    groupingExpressions,
+                    aggregateExpressions,
+                    aggregateFunctionMap,
+                    resultExpressions,
+                    planLater(child))
+                }
+              } else if (functionsWithDistinct.isEmpty) {
                 aggregate.Utils.planAggregateWithoutDistinct(
                   groupingExpressions,
                   aggregateExpressions,

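The planner now checks supportsPartial before the DISTINCT branches: a
distinct aggregate is planned as two aggregation phases, which requires
partial aggregation, so the combination is rejected up front. Illustrative
queries, assuming a HiveContext in scope and a Hive UDAF registered under
the hypothetical name my_udaf:

    // Planned via planAggregateWithoutPartial: a single Complete-mode pass.
    sql("SELECT my_udaf(value) FROM src GROUP BY key")

    // Hits the sys.error above: DISTINCT needs partial aggregation to
    // de-duplicate values across partitions.
    sql("SELECT my_udaf(DISTINCT value) FROM src GROUP BY key")
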
http://git-wip-us.apache.org/repos/asf/spark/blob/341b13f8/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index abca373..62dbc07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -26,7 +26,7 @@ import org.apache.spark.unsafe.KVIterator
 import scala.collection.mutable.ArrayBuffer
 
 /**
- * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]].
+ * The base class of [[SortBasedAggregationIterator]].
  * It mainly contains two parts:
  * 1. It initializes aggregate functions.
  * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of

http://git-wip-us.apache.org/repos/asf/spark/blob/341b13f8/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 80816a0..4f5e86c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -37,6 +37,57 @@ object Utils {
       UnsafeProjection.canSupport(groupingExpressions)
   }
 
+  def planAggregateWithoutPartial(
+      groupingExpressions: Seq[Expression],
+      aggregateExpressions: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+    val completeAggregateAttributes =
+      completeAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
+      }
+
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transformDown {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+
+    SortBasedAggregate(
+      requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+      groupingExpressions = namedGroupingExpressions.map(_._2),
+      nonCompleteAggregateExpressions = Nil,
+      nonCompleteAggregateAttributes = Nil,
+      completeAggregateExpressions = completeAggregateExpressions,
+      completeAggregateAttributes = completeAggregateAttributes,
+      initialInputBufferOffset = 0,
+      resultExpressions = rewrittenResultExpressions,
+      child = child
+    ) :: Nil
+  }
+
   def planAggregateWithoutDistinct(
       groupingExpressions: Seq[Expression],
       aggregateExpressions: Seq[AggregateExpression2],

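planAggregateWithoutPartial follows the same conventions as the existing
planners: non-named grouping expressions are wrapped in an Alias so the
operator can expose them as attributes, and result expressions are
rewritten by semantic rather than object equality. A small sketch of the
aliasing step, assuming only the catalyst expression API:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val key = AttributeReference("key", IntegerType)()
    // A grouping expression that is not a NamedExpression...
    val grouping = Add(key, Literal(1))
    // ...is wrapped in an Alias, so downstream operators see an Attribute.
    val named = Alias(grouping, grouping.toString)()
    // semanticEquals matches expressions that differ only cosmetically,
    // which is why the rewrite above uses it instead of ==.
    assert(grouping semanticEquals Add(key, Literal(1)))
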
http://git-wip-us.apache.org/repos/asf/spark/blob/341b13f8/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index cad0237..fa9012b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -65,9 +66,10 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
         HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
       } else if (
         classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children)
+        HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children)
       } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveUDAF(new HiveFunctionWrapper(functionClassName), children)
+        HiveUDAFFunction(
+          new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
       } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
         HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
       } else {
@@ -441,70 +443,6 @@ private[hive] case class HiveWindowFunction(
     new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
 }
 
-private[hive] case class HiveGenericUDAF(
-    funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression1
-  with HiveInspectors {
-
-  type UDFType = AbstractGenericUDAFResolver
-
-  @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
-
-  @transient
-  protected lazy val objectInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
-    resolver.getEvaluator(parameterInfo)
-      .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
-  }
-
-  @transient
-  protected lazy val inspectors = children.map(toInspector)
-
-  def dataType: DataType = inspectorToDataType(objectInspector)
-
-  def nullable: Boolean = true
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this)
-}
-
-/** It is used as a wrapper for the hive functions which uses UDAF interface */
-private[hive] case class HiveUDAF(
-    funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression1
-  with HiveInspectors {
-
-  type UDFType = UDAF
-
-  @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver =
-    new GenericUDAFBridge(funcWrapper.createFunction())
-
-  @transient
-  protected lazy val objectInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
-    resolver.getEvaluator(parameterInfo)
-      .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
-  }
-
-  @transient
-  protected lazy val inspectors = children.map(toInspector)
-
-  def dataType: DataType = inspectorToDataType(objectInspector)
-
-  def nullable: Boolean = true
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true)
-}
-
 /**
  * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
  * [[Generator]].  Note that the semantics of Generators do not allow
@@ -584,49 +522,86 @@ private[hive] case class HiveGenericUDTF(
   }
 }
 
+/**
+ * Currently we don't support partial aggregation for queries using Hive UDAFs, which may hurt
+ * performance significantly.
+ */
 private[hive] case class HiveUDAFFunction(
     funcWrapper: HiveFunctionWrapper,
-    exprs: Seq[Expression],
-    base: AggregateExpression1,
+    children: Seq[Expression],
     isUDAFBridgeRequired: Boolean = false)
-  extends AggregateFunction1
-  with HiveInspectors {
+  extends AggregateFunction2 with HiveInspectors {
 
-  def this() = this(null, null, null)
+  def this() = this(null, null)
 
-  private val resolver =
+  @transient
+  private lazy val resolver =
     if (isUDAFBridgeRequired) {
       new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
       funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }
 
-  private val inspectors = exprs.map(toInspector).toArray
+  @transient
+  private lazy val inspectors = children.map(toInspector).toArray
 
-  private val function = {
+  @transient
+  private lazy val functionAndInspector = {
     val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
-    resolver.getEvaluator(parameterInfo)
+    val f = resolver.getEvaluator(parameterInfo)
+    f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
   }
 
-  private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+  @transient
+  private lazy val function = functionAndInspector._1
+
+  @transient
+  private lazy val returnInspector = functionAndInspector._2
 
-  private val buffer =
-    function.getNewAggregationBuffer
+  @transient
+  private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
 
   override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
 
   @transient
-  val inputProjection = new InterpretedProjection(exprs)
+  private lazy val inputProjection = new InterpretedProjection(children)
 
   @transient
-  protected lazy val cached = new Array[AnyRef](exprs.length)
+  private lazy val cached = new Array[AnyRef](children.length)
 
   @transient
-  private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray
+  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
+  // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
+  // buffer for it.
+  override def bufferSchema: StructType = StructType(Nil)
 
-  def update(input: InternalRow): Unit = {
+  override def update(_buffer: MutableRow, input: InternalRow): Unit = {
     val inputs = inputProjection(input)
     function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
   }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    throw new UnsupportedOperationException(
+      "Hive UDAF doesn't support partial aggregation")
+  }
+
+  override def cloneBufferAttributes: Seq[Attribute] = Nil
+
+  override def initialize(_buffer: MutableRow): Unit = {
+    buffer = function.getNewAggregationBuffer
+  }
+
+  override def bufferAttributes: Seq[AttributeReference] = Nil
+
+  // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
+  // catalyst type checking framework.
+  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
+
+  override def nullable: Boolean = true
+
+  override def supportsPartial: Boolean = false
+
+  override lazy val dataType: DataType = inspectorToDataType(returnInspector)
 }
 
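Note that the rewritten HiveUDAFFunction still keeps its state in Hive's
own AggregationBuffer rather than in the catalyst aggregation buffer,
hence the empty bufferSchema. A sketch of the per-group call sequence that
SortBasedAggregate drives, assuming udaf is a HiveUDAFFunction and rows
holds one group's input rows:

    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

    // The catalyst buffer is unused because bufferSchema is empty.
    val unused = new GenericMutableRow(0)
    udaf.initialize(unused)          // getNewAggregationBuffer on the Hive side
    rows.foreach { row =>
      udaf.update(unused, row)       // projects and wraps the row, then iterate()
    }
    val result = udaf.eval(unused)   // evaluate() on the buffer, then unwrap
    // merge() is never reached: supportsPartial == false keeps the planner
    // from emitting Partial/Final modes at all.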

http://git-wip-us.apache.org/repos/asf/spark/blob/341b13f8/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d9ba895..3c8a009 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -131,7 +131,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton {
     hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
   }
 
-  test("SPARK-6409 UDAFAverage test") {
+  test("SPARK-6409 UDAF Average test") {
     sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
     checkAnswer(
       sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),

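One way to observe the patch from a session, sketched along the lines of
the suite's own setup: register the UDAF and print the physical plan,
which should now show a single SortBasedAggregate running the function in
Complete mode instead of a partial/final pair.

    sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
    sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src").explain()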
