You are viewing a plain text version of this content. (The hyperlink to the canonical version was not preserved in this plain-text copy.)
Posted to commits@spark.apache.org by ma...@apache.org on 2015/04/09 03:42:41 UTC

spark git commit: [SPARK-6451][SQL] supported code generation for CombineSum

Repository: spark
Updated Branches:
  refs/heads/master 941828054 -> 7d7384c78


[SPARK-6451][SQL] supported code generation for CombineSum

Author: Venkata Ramana Gollamudi <ra...@huawei.com>

Closes #5138 from gvramana/sum_fix_codegen and squashes the following commits:

95f5fe4 [Venkata Ramana Gollamudi] rebase merge changes
12f45a5 [Venkata Ramana Gollamudi] Combined and added code generations tests as per comment
d6a76ac [Venkata Ramana Gollamudi] added support for codegeneration for CombineSum and tests


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7d7384c7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7d7384c7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7d7384c7

Branch: refs/heads/master
Commit: 7d7384c781ea72e1eabab3daca2e237e3b0fc666
Parents: 9418280
Author: Venkata Ramana Gollamudi <ra...@huawei.com>
Authored: Wed Apr 8 18:42:34 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Apr 8 18:42:34 2015 -0700

----------------------------------------------------------------------
 .../sql/execution/GeneratedAggregate.scala      | 44 +++++++++-
 .../spark/sql/execution/SparkStrategies.scala   |  2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 92 +++++++++++++++++++-
 3 files changed, 133 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7d7384c7/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index a8018b9..861a2c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -99,7 +99,10 @@ case class GeneratedAggregate(
         // but really, common sub expression elimination would be better....
         val zero = Cast(Literal(0), calcType)
         val updateFunction = Coalesce(
-          Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil)
+          Add(
+            Coalesce(currentSum :: zero :: Nil),
+            Cast(expr, calcType)
+          ) :: currentSum :: zero :: Nil)
         val result =
           expr.dataType match {
             case DecimalType.Fixed(_, _) =>
@@ -109,6 +112,45 @@ case class GeneratedAggregate(
 
         AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
 
+      case cs @ CombineSum(expr) =>
+        val calcType = expr.dataType
+          expr.dataType match {
+            case DecimalType.Fixed(_, _) =>
+              DecimalType.Unlimited
+            case _ =>
+              expr.dataType
+          }
+
+        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
+        val initialValue = Literal.create(null, calcType)
+
+        // Coalasce avoids double calculation...
+        // but really, common sub expression elimination would be better....
+        val zero = Cast(Literal(0), calcType)
+        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+        val actualExpr = expr match {
+          case UnscaledValue(e) => e
+          case _ => expr
+        }
+        // partial sum result can be null only when no input rows present 
+        val updateFunction = If(
+          IsNotNull(actualExpr),
+          Coalesce(
+            Add(
+              Coalesce(currentSum :: zero :: Nil), 
+              Cast(expr, calcType)) :: currentSum :: zero :: Nil),
+          currentSum)
+          
+        val result =
+          expr.dataType match {
+            case DecimalType.Fixed(_, _) =>
+              Cast(currentSum, cs.dataType)
+            case _ => currentSum
+          }
+
+        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+        
       case a @ Average(expr) =>
         val calcType =
           expr.dataType match {

http://git-wip-us.apache.org/repos/asf/spark/blob/7d7384c7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f754fa7..23f7e56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
 
     def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
-      case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
+      case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
       // The generated set implementation is pretty limited ATM.
       case CollectHashSet(exprs) if exprs.size == 1  &&
            Seq(IntegerType, LongType).contains(exprs.head.dataType) => false

http://git-wip-us.apache.org/repos/asf/spark/blob/7d7384c7/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 87e7cf8..1ad92a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import org.apache.spark.sql.test.TestSQLContext
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -102,14 +103,99 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       sql("SELECT ABS(2.5)"),
       Row(2.5))
   }
-
+  
   test("aggregation with codegen") {
     val originalValue = conf.codegenEnabled
     setConf(SQLConf.CODEGEN_ENABLED, "true")
-    sql("SELECT key FROM testData GROUP BY key").collect()
+    // Prepare a table that we can group some rows.
+    table("testData")
+      .unionAll(table("testData"))
+      .unionAll(table("testData"))
+      .registerTempTable("testData3x")
+
+    def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
+      val df = sql(sqlText)
+      // First, check if we have GeneratedAggregate.
+      var hasGeneratedAgg = false
+      df.queryExecution.executedPlan.foreach {
+        case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+        case _ =>
+      }
+      if (!hasGeneratedAgg) {
+        fail(
+          s"""
+             |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
+             |${df.queryExecution.simpleString}
+           """.stripMargin)
+      }
+      // Then, check results.
+      checkAnswer(df, expectedResults)
+    }
+
+    // Just to group rows.
+    testCodeGen(
+      "SELECT key FROM testData3x GROUP BY key",
+      (1 to 100).map(Row(_)))
+    // COUNT
+    testCodeGen(
+      "SELECT key, count(value) FROM testData3x GROUP BY key",
+      (1 to 100).map(i => Row(i, 3)))
+    testCodeGen(
+      "SELECT count(key) FROM testData3x",
+      Row(300) :: Nil)
+    // COUNT DISTINCT ON int
+    testCodeGen(
+      "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
+      (1 to 100).map(i => Row(i.toString, 1)))
+    testCodeGen(
+      "SELECT count(distinct key) FROM testData3x",
+      Row(100) :: Nil)
+    // SUM
+     testCodeGen(
+       "SELECT value, sum(key) FROM testData3x GROUP BY value",
+       (1 to 100).map(i => Row(i.toString, 3 * i)))
+     testCodeGen(
+      "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",      
+      Row(5050 * 3, 5050 * 3.0) :: Nil)
+    // AVERAGE
+    testCodeGen(
+      "SELECT value, avg(key) FROM testData3x GROUP BY value",
+      (1 to 100).map(i => Row(i.toString, i)))
+    testCodeGen(
+      "SELECT avg(key) FROM testData3x",
+      Row(50.5) :: Nil)
+    // MAX
+    testCodeGen(
+      "SELECT value, max(key) FROM testData3x GROUP BY value",
+      (1 to 100).map(i => Row(i.toString, i)))
+    testCodeGen(
+      "SELECT max(key) FROM testData3x",
+      Row(100) :: Nil)
+    // Some combinations.
+    testCodeGen(
+      """
+        |SELECT
+        |  value,
+        |  sum(key),
+        |  max(key),
+        |  avg(key),
+        |  count(key),
+        |  count(distinct key)
+        |FROM testData3x
+        |GROUP BY value
+      """.stripMargin,
+      (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1)))
+    testCodeGen(
+      "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x",
+      Row(100, 50.5, 300, 100) :: Nil)
+    // Aggregate with Code generation handling all null values
+    testCodeGen(
+      "SELECT  sum('a'), avg('a'), count(null) FROM testData",
+      Row(0, null, 0) :: Nil)
+      
+    dropTempTable("testData3x")
     setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
   }
-
   test("Add Parser of SQL COALESCE()") {
     checkAnswer(
       sql("""SELECT COALESCE(1, 2)"""),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org