Posted to commits@spark.apache.org by ma...@apache.org on 2015/04/14 03:16:37 UTC

spark git commit: [SPARK-6877][SQL] Add code generation support for Min

Repository: spark
Updated Branches:
  refs/heads/master 5b8b324f3 -> 4898dfa46


[SPARK-6877][SQL] Add code generation support for Min

Currently, `min` is not supported in code generation. This PR adds support for it.
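
With this change, queries that use `min` can take the code-generated aggregate path. A minimal usage sketch (table and column names are hypothetical; assumes a SQLContext with codegen enabled via spark.sql.codegen=true):

  // Hypothetical example: a registered "people" table with "department" and "age" columns.
  sqlContext.setConf("spark.sql.codegen", "true")
  val minAges = sqlContext.sql(
    "SELECT department, min(age) FROM people GROUP BY department")
  minAges.collect()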

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #5487 from viirya/add_min_codegen and squashes the following commits:

0ddec23 [Liang-Chi Hsieh] Add code generation support for Min.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4898dfa4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4898dfa4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4898dfa4

Branch: refs/heads/master
Commit: 4898dfa464be55772e3f9db10c48adcb3cfc9a3d
Parents: 5b8b324
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Mon Apr 13 18:16:33 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon Apr 13 18:16:33 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 45 ++++++++++++++++++++
 .../expressions/codegen/CodeGenerator.scala     | 24 +++++++++++
 .../expressions/ExpressionEvaluationSuite.scala | 10 +++++
 .../sql/execution/GeneratedAggregate.scala      | 13 ++++++
 .../spark/sql/execution/SparkStrategies.scala   |  2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 14 ++++--
 6 files changed, 104 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4898dfa4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1f6526e..566b34f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -369,6 +369,51 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
   override def toString: String = s"MaxOf($left, $right)"
 }
 
+case class MinOf(left: Expression, right: Expression) extends Expression {
+  type EvaluatedType = Any
+
+  override def foldable: Boolean = left.foldable && right.foldable
+
+  override def nullable: Boolean = left.nullable && right.nullable
+
+  override def children: Seq[Expression] = left :: right :: Nil
+
+  override lazy val resolved =
+    left.resolved && right.resolved &&
+    left.dataType == right.dataType
+
+  override def dataType: DataType = {
+    if (!resolved) {
+      throw new UnresolvedException(this,
+        s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
+    }
+    left.dataType
+  }
+
+  lazy val ordering = left.dataType match {
+    case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+    case other => sys.error(s"Type $other does not support ordered operations")
+  }
+
+  override def eval(input: Row): Any = {
+    val evalE1 = left.eval(input)
+    val evalE2 = right.eval(input)
+    if (evalE1 == null) {
+      evalE2
+    } else if (evalE2 == null) {
+      evalE1
+    } else {
+      if (ordering.compare(evalE1, evalE2) < 0) {
+        evalE1
+      } else {
+        evalE2
+      }
+    }
+  }
+
+  override def toString: String = s"MinOf($left, $right)"
+}
+
 /**
  * A function that get the absolute value of the numeric value.
  */

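Note the null handling in the new MinOf above: a null on either side falls back to the other side's value rather than propagating the null, which is also why nullable holds only when both children are nullable. A quick behavioral sketch, mirroring the checkEvaluation tests further down:

  // MinOf ignores a null side instead of propagating it:
  //   MinOf(Literal.create(null, IntegerType), Literal(1))  evaluates to 1
  //   MinOf(Literal(1), Literal(2))                         evaluates to 1
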
http://git-wip-us.apache.org/repos/asf/spark/blob/4898dfa4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index aac56e1..d141354 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -524,6 +524,30 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
           }
         """.children
 
+      case MinOf(e1, e2) =>
+        val eval1 = expressionEvaluator(e1)
+        val eval2 = expressionEvaluator(e2)
+
+        eval1.code ++ eval2.code ++
+        q"""
+          var $nullTerm = false
+          var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
+
+          if (${eval1.nullTerm}) {
+            $nullTerm = ${eval2.nullTerm}
+            $primitiveTerm = ${eval2.primitiveTerm}
+          } else if (${eval2.nullTerm}) {
+            $nullTerm = ${eval1.nullTerm}
+            $primitiveTerm = ${eval1.primitiveTerm}
+          } else {
+            if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
+              $primitiveTerm = ${eval1.primitiveTerm}
+            } else {
+              $primitiveTerm = ${eval2.primitiveTerm}
+            }
+          }
+        """.children
+
       case UnscaledValue(child) =>
         val childEval = expressionEvaluator(child)
 

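Loosely, for two Int-typed children the quasiquote above splices into mutable-variable code of roughly this shape (a hand-written approximation with readable names, not the literal generated terms):

  // Approximate expansion of the MinOf case for Int children:
  var isNull = false
  var value: Int = -1  // placeholder default primitive
  if (leftIsNull) {
    isNull = rightIsNull
    value = rightValue
  } else if (rightIsNull) {
    isNull = leftIsNull  // left is known non-null here, so this stays false
    value = leftValue
  } else {
    value = if (leftValue < rightValue) leftValue else rightValue
  }
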
http://git-wip-us.apache.org/repos/asf/spark/blob/4898dfa4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index d2b1090..d4362a9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -233,6 +233,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
     checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2)
   }
 
+  test("MinOf") {
+    checkEvaluation(MinOf(1, 2), 1)
+    checkEvaluation(MinOf(2, 1), 1)
+    checkEvaluation(MinOf(1L, 2L), 1L)
+    checkEvaluation(MinOf(2L, 1L), 1L)
+
+    checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
+    checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
+  }
+
   test("LIKE literal Regular Expression") {
     checkEvaluation(Literal.create(null, StringType).like("a"), null)
     checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)

http://git-wip-us.apache.org/repos/asf/spark/blob/4898dfa4/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b510cf0..b1ef655 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -164,6 +164,17 @@ case class GeneratedAggregate(
           updateMax :: Nil,
           currentMax)
 
+      case m @ Min(expr) =>
+        val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
+        val initialValue = Literal.create(null, expr.dataType)
+        val updateMin = MinOf(currentMin, expr)
+
+        AggregateEvaluation(
+          currentMin :: Nil,
+          initialValue :: Nil,
+          updateMin :: Nil,
+          currentMin)
+
       case CollectHashSet(Seq(expr)) =>
         val set =
           AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
@@ -188,6 +199,8 @@ case class GeneratedAggregate(
           initialValue :: Nil,
           collectSets :: Nil,
           CountSet(set))
+
+      case o => sys.error(s"$o can't be codegened.")
     }
 
     val computationSchema = computeFunctions.flatMap(_.schema)

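The new Min case mirrors the Max case directly above it: the aggregation buffer starts as a null literal, and each input row folds through MinOf(currentMin, expr), so MinOf's null handling lets the first non-null input seed the buffer. Conceptually, in plain Scala (a restatement of the update rule, not the generated code):

  // Conceptual per-group fold matching initialValue/updateMin above:
  val values = Seq(Option.empty[Int], Some(3), Some(1), Some(2))
  val currentMin = values.foldLeft(Option.empty[Int]) {
    case (None, v)          => v
    case (cur, None)        => cur
    case (Some(c), Some(v)) => Some(math.min(c, v))
  }
  // currentMin == Some(1)
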
http://git-wip-us.apache.org/repos/asf/spark/blob/4898dfa4/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f0d92ff..5b99e40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
 
     def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
-      case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
+      case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false
       // The generated set implementation is pretty limited ATM.
       case CollectHashSet(exprs) if exprs.size == 1  &&
            Seq(IntegerType, LongType).contains(exprs.head.dataType) => false

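Note the inverted logic in canBeCodeGened: the supported aggregate types map to false inside exists, so the negated result is true only when every aggregate in the list is supported. A simplified restatement of the intent (omitting the CollectHashSet special case shown above):

  // Simplified: all aggregates must be in the supported set for codegen.
  def allSupported(aggs: Seq[AggregateExpression]): Boolean = aggs.forall {
    case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
    case _ => false  // anything else falls back to the interpreted path
  }
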
http://git-wip-us.apache.org/repos/asf/spark/blob/4898dfa4/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5e453e0..73fb791 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -172,6 +172,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     testCodeGen(
       "SELECT max(key) FROM testData3x",
       Row(100) :: Nil)
+    // MIN
+    testCodeGen(
+      "SELECT value, min(key) FROM testData3x GROUP BY value",
+      (1 to 100).map(i => Row(i.toString, i)))
+    testCodeGen(
+      "SELECT min(key) FROM testData3x",
+      Row(1) :: Nil)
     // Some combinations.
     testCodeGen(
       """
@@ -179,16 +186,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
         |  value,
         |  sum(key),
         |  max(key),
+        |  min(key),
         |  avg(key),
         |  count(key),
         |  count(distinct key)
         |FROM testData3x
         |GROUP BY value
       """.stripMargin,
-      (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1)))
+      (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
     testCodeGen(
-      "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x",
-      Row(100, 50.5, 300, 100) :: Nil)
+      "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
+      Row(100, 1, 50.5, 300, 100) :: Nil)
     // Aggregate with Code generation handling all null values
     testCodeGen(
       "SELECT  sum('a'), avg('a'), count(null) FROM testData",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org