Posted to commits@spark.apache.org by da...@apache.org on 2016/01/20 19:02:59 UTC

spark git commit: [SPARK-12881] [SQL] subexpression elimination in mutable projection

Repository: spark
Updated Branches:
  refs/heads/master 753b19451 -> 8e4f894e9


[SPARK-12881] [SQL] subexpression elimination in mutable projection

Author: Davies Liu <da...@databricks.com>

Closes #10814 from davies/mutable_subexpr.
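
This wires the subexpression elimination machinery already used by GenerateUnsafeProjection into GenerateMutableProjection as well, behind a new useSubexprElimination flag on SparkPlan.newMutableProjection (default false). SortBasedAggregate and TungstenAggregate forward subexpressionEliminationEnabled, so common subexpressions in their projections are evaluated once per input row. A minimal sketch of the effect, mirroring the regression test added to SQLQuerySuite in this patch (it assumes a single-row DataFrame df with an integer column b, org.apache.spark.sql.functions._ imported, and the SQLContext implicits in scope for $):

    // The three occurrences of testUdf($"b") are one common subexpression, so the
    // UDF body runs once per input row inside the aggregate instead of three times.
    val countAcc = sparkContext.accumulator(0, "udfCallCount")
    val testUdf = udf((x: Int) => { countAcc.add(1); x })
    df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))).collect()
    // countAcc.value should now be 1 (it would be 3 without elimination)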


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8e4f894e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8e4f894e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8e4f894e

Branch: refs/heads/master
Commit: 8e4f894e986ccd943df9ddf55fc853eb0558886f
Parents: 753b194
Author: Davies Liu <da...@databricks.com>
Authored: Wed Jan 20 10:02:40 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Jan 20 10:02:40 2016 -0800

----------------------------------------------------------------------
 .../expressions/EquivalentExpressions.scala     |  5 ++-
 .../sql/catalyst/expressions/Expression.scala   |  4 +-
 .../expressions/codegen/CodeGenerator.scala     |  8 ++--
 .../codegen/GenerateMutableProjection.scala     | 43 +++++++++++++++-----
 .../codegen/GenerateUnsafeProjection.scala      |  6 +--
 .../SubexpressionEliminationSuite.scala         | 13 ++++++
 .../apache/spark/sql/execution/SparkPlan.scala  |  6 ++-
 .../org/apache/spark/sql/execution/Window.scala |  8 +++-
 .../aggregate/SortBasedAggregate.scala          |  3 +-
 .../execution/aggregate/TungstenAggregate.scala |  3 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  8 ++++
 11 files changed, 80 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index f7162e4..affd1bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.mutable
 
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+
 /**
  * This class is used to compute equality of (sub)expression trees. Expressions can be added
  * to this class and they subsequently query for expression equality. Expression trees are
@@ -67,7 +69,8 @@ class EquivalentExpressions {
    */
   def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
     val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
-    if (!skip && !addExpr(root)) {
+    // the children of CodegenFallback will not be used to generate code (call eval() instead)
+    if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) {
       root.children.foreach(addExprTree(_, ignoreLeaf))
     }
   }
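
A CodegenFallback expression is compiled as a call to its interpreted eval(), so no code is generated for its children; counting those children as extra occurrences would let the eliminator extract values that the generated code can never actually reuse. A small sketch of the new behaviour, based on the test added to SubexpressionEliminationSuite later in this diff:

    import org.apache.spark.sql.catalyst.expressions._

    val two = Add(Literal(1), Literal(1))
    val add = Add(two, Explode(two))             // Explode is a CodegenFallback
    val equivalence = new EquivalentExpressions
    equivalence.addExprTree(add, ignoreLeaf = true)
    // The `two` under Explode is not collected, so nothing is reported as common.
    assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)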

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 25cf210..db17ba7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -100,8 +100,8 @@ abstract class Expression extends TreeNode[Expression] {
       ExprCode(code, subExprState.isNull, subExprState.value)
     }.getOrElse {
       val isNull = ctx.freshName("isNull")
-      val primitive = ctx.freshName("primitive")
-      val ve = ExprCode("", isNull, primitive)
+      val value = ctx.freshName("value")
+      val ve = ExprCode("", isNull, value)
       ve.code = genCode(ctx, ve)
       // Add `this` in the comment.
       ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 683029f..2747c31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -125,7 +125,7 @@ class CodegenContext {
   val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
 
   // The collection of sub-expression result resetting methods that need to be called on each row.
-  val subExprResetVariables = mutable.ArrayBuffer.empty[String]
+  val subexprFunctions = mutable.ArrayBuffer.empty[String]
 
   def declareAddedFunctions(): String = {
     addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
@@ -424,9 +424,9 @@ class CodegenContext {
     val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
     commonExprs.foreach(e => {
       val expr = e.head
-      val isNull = freshName("isNull")
-      val value = freshName("value")
       val fnName = freshName("evalExpr")
+      val isNull = s"${fnName}IsNull"
+      val value = s"${fnName}Value"
 
       // Generate the code for this expression tree and wrap it in a function.
       val code = expr.gen(this)
@@ -461,7 +461,7 @@ class CodegenContext {
       addMutableState(javaType(expr.dataType), value,
         s"$value = ${defaultValue(expr.dataType)};")
 
-      subExprResetVariables += s"$fnName($INPUT_ROW);"
+      subexprFunctions += s"$fnName($INPUT_ROW);"
       val state = SubExprEliminationState(isNull, value)
       e.foreach(subExprEliminationExprs.put(_, state))
     })
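
Wrapping each common subexpression in its own private method is pre-existing behaviour; this change names the cached result fields after that method (evalExprNIsNull / evalExprNValue) and renames the collection to subexprFunctions, since the recorded strings are calls that evaluate subexpressions rather than variable resets. A rough, purely illustrative sketch of the emitted shape (not the exact generated text):

    val generatedSketch =
      """
        |private void evalExpr4(InternalRow i) {
        |  // ... generated evaluation of the shared expression tree ...
        |  evalExpr4IsNull = false;   // cached null flag, read by every occurrence
        |  evalExpr4Value = 42;       // cached value, read by every occurrence
        |}
        |// recorded in subexprFunctions: "evalExpr4(i);"
      """.stripMargin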

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 59ef0f5..d9fe761 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -38,12 +38,29 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
     in.map(BindReferences.bindReference(_, inputSchema))
 
+  def generate(
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute],
+      useSubexprElimination: Boolean): (() => MutableProjection) = {
+    create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination)
+  }
+
   protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
+    create(expressions, false)
+  }
+
+  private def create(
+      expressions: Seq[Expression],
+      useSubexprElimination: Boolean): (() => MutableProjection) = {
     val ctx = newCodeGenContext()
-    val projectionCodes = expressions.zipWithIndex.map {
-      case (NoOp, _) => ""
-      case (e, i) =>
-        val evaluationCode = e.gen(ctx)
+    val (validExpr, index) = expressions.zipWithIndex.filter {
+      case (NoOp, _) => false
+      case _ => true
+    }.unzip
+    val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)
+    val projectionCodes = exprVals.zip(index).map {
+      case (ev, i) =>
+        val e = expressions(i)
         if (e.nullable) {
           val isNull = s"isNull_$i"
           val value = s"value_$i"
@@ -51,22 +68,25 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
           ctx.addMutableState(ctx.javaType(e.dataType), value,
             s"this.$value = ${ctx.defaultValue(e.dataType)};")
           s"""
-            ${evaluationCode.code}
-            this.$isNull = ${evaluationCode.isNull};
-            this.$value = ${evaluationCode.value};
+            ${ev.code}
+            this.$isNull = ${ev.isNull};
+            this.$value = ${ev.value};
            """
         } else {
           val value = s"value_$i"
           ctx.addMutableState(ctx.javaType(e.dataType), value,
             s"this.$value = ${ctx.defaultValue(e.dataType)};")
           s"""
-            ${evaluationCode.code}
-            this.$value = ${evaluationCode.value};
+            ${ev.code}
+            this.$value = ${ev.value};
            """
         }
     }
-    val updates = expressions.zipWithIndex.map {
-      case (NoOp, _) => ""
+
+    // Evaluate all the subexpressions.
+    val evalSubexpr = ctx.subexprFunctions.mkString("\n")
+
+    val updates = validExpr.zip(index).map {
       case (e, i) =>
         if (e.nullable) {
           if (e.dataType.isInstanceOf[DecimalType]) {
@@ -128,6 +148,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
 
         public java.lang.Object apply(java.lang.Object _i) {
           InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
+          $evalSubexpr
           $allProjections
           // copy all the results into MutableRow
           $allUpdates
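
The generated apply() now runs every recorded subexpression function once, up front, and the per-column snippets only read the cached fields. A minimal, self-contained usage sketch of the new overload (values and expected results are illustrative, not taken from the patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val shared = Add(a, a)                       // common subexpression of both columns
    val makeProj = GenerateMutableProjection.generate(
      Seq(Add(shared, Literal(1)), Add(shared, Literal(2))),
      Seq(a),
      useSubexprElimination = true)
    val row = makeProj().apply(InternalRow(3))   // a + a is evaluated once for this row
    // row.getInt(0) should be 7 and row.getInt(1) should be 8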

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 61e7469..72bf39a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -294,13 +294,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val holderClass = classOf[BufferHolder].getName
     ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
 
-    // Reset the subexpression values for each row.
-    val subexprReset = ctx.subExprResetVariables.mkString("\n")
+    // Evaluate all the subexpressions.
+    val evalSubexpr = ctx.subexprFunctions.mkString("\n")
 
     val code =
       s"""
         $bufferHolder.reset();
-        $subexprReset
+        $evalSubexpr
         ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
 
         $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index a61297b..43a3eb9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -154,4 +154,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     equivalence.addExpr(sum)
     assert(equivalence.getAllEquivalentExprs.isEmpty)
   }
+
+  test("Children of CodegenFallback") {
+    val one = Literal(1)
+    val two = Add(one, one)
+    val explode = Explode(two)
+    val add = Add(two, explode)
+
+    var equivalence = new EquivalentExpressions
+    equivalence.addExprTree(add, true)
+    // the `two` inside `explode` should not be added
+    assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 0)
+    assert(equivalence.getAllEquivalentExprs.filter(_.size == 1).size == 3)  // add, two, explode
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 75101ea..b19b772 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -196,10 +196,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   private[this] def isTesting: Boolean = sys.props.contains("spark.testing")
 
   protected def newMutableProjection(
-      expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = {
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute],
+      useSubexprElimination: Boolean = false): () => MutableProjection = {
     log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema")
     try {
-      GenerateMutableProjection.generate(expressions, inputSchema)
+      GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
     } catch {
       case e: Exception =>
         if (isTesting) {

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 168b5ab..26a7340 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -194,7 +194,11 @@ case class Window(
         val functions = functionSeq.toArray
 
         // Construct an aggregate processor if we need one.
-        def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection)
+        def processor = AggregateProcessor(
+          functions,
+          ordinal,
+          child.output,
+          (expressions, schema) => newMutableProjection(expressions, schema))
 
         // Create the factory
         val factory = key match {
@@ -206,7 +210,7 @@ case class Window(
                 ordinal,
                 functions,
                 child.output,
-                newMutableProjection,
+                (expressions, schema) => newMutableProjection(expressions, schema),
                 offset)
 
           // Growing Frame.
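
Window used to pass the newMutableProjection method itself where a (Seq[Expression], Seq[Attribute]) => (() => MutableProjection) factory is expected; with the defaulted third parameter the method no longer conforms to that two-argument function type directly, so these call sites wrap it in an explicit lambda (leaving elimination off here). A toy analogue, illustrative only:

    def make(exprs: Seq[String], schema: Seq[String], useSubexpr: Boolean = false): () => String =
      () => s"proj(${exprs.size} exprs, subexpr=$useSubexpr)"

    // val f: (Seq[String], Seq[String]) => (() => String) = make   // no longer compiles:
    //                                                               // defaults are not applied
    //                                                               // during eta-expansion
    val f: (Seq[String], Seq[String]) => (() => String) = (e, s) => make(e, s)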

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index 1d56592..06a3991 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -87,7 +87,8 @@ case class SortBasedAggregate(
           aggregateAttributes,
           initialInputBufferOffset,
           resultExpressions,
-          newMutableProjection,
+          (expressions, inputSchema) =>
+            newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
           numInputRows,
           numOutputRows)
         if (!hasInput && groupingExpressions.isEmpty) {

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a9cf043..8dcbab4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -94,7 +94,8 @@ case class TungstenAggregate(
             aggregateAttributes,
             initialInputBufferOffset,
             resultExpressions,
-            newMutableProjection,
+            (expressions, inputSchema) =>
+              newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
             child.output,
             iter,
             testFallbackStartsAt,
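
Both SortBasedAggregate above and TungstenAggregate here forward SparkPlan.subexpressionEliminationEnabled into the projection factory, so the optimization follows the existing SQL configuration instead of being unconditionally on. A hedged sketch (the key name is assumed from SQLConf and is not part of this diff; sqlContext is assumed in scope):

    // Toggle the flag to compare UDF call counts with and without elimination.
    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")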

http://git-wip-us.apache.org/repos/asf/spark/blob/8e4f894e/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d7f1823..b159346 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.parser.ParserConf
 import org.apache.spark.sql.execution.{aggregate, SparkQl}
 import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
 import org.apache.spark.sql.test.SQLTestData._
@@ -1968,6 +1969,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     verifyCallCount(
       df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
 
+    val testUdf = functions.udf((x: Int) => {
+      countAcc.++=(1)
+      x
+    })
+    verifyCallCount(
+      df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+
     // Would be nice if semantic equals for `+` understood commutative
     verifyCallCount(
       df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)

