You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/04/19 15:38:30 UTC

spark git commit: [SPARK-14577][SQL] Add spark.sql.codegen.maxCaseBranches config option

Repository: spark
Updated Branches:
  refs/heads/master 74fe235ab -> 3d46d796a


[SPARK-14577][SQL] Add spark.sql.codegen.maxCaseBranches config option

## What changes were proposed in this pull request?

We currently disable codegen for `CaseWhen` when the number of branches is 20 or more (the check is `branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN`, and the constant is 20). It would be better if this threshold were a non-public config defined in SQLConf.

## How was this patch tested?

Pass the Jenkins tests (including a new testcase `Support spark.sql.codegen.maxCaseBranches option`)

Author: Dongjoon Hyun <do...@apache.org>

Closes #12353 from dongjoon-hyun/SPARK-14577.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d46d796
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d46d796
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d46d796

Branch: refs/heads/master
Commit: 3d46d796a3a2b60b37dc318652eded5e992be1e5
Parents: 74fe235
Author: Dongjoon Hyun <do...@apache.org>
Authored: Tue Apr 19 21:38:15 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Apr 19 21:38:15 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/CatalystConf.scala       |   4 +-
 .../expressions/conditionalExpressions.scala    |  86 ++++++++++------
 .../sql/catalyst/optimizer/Optimizer.scala      |  14 ++-
 .../optimizer/OptimizeCodegenSuite.scala        | 102 +++++++++++++++++++
 .../spark/sql/execution/WholeStageCodegen.scala |   1 -
 .../org/apache/spark/sql/internal/SQLConf.scala |   8 ++
 6 files changed, 180 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3d46d796/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index abba866..0efe3c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -29,6 +29,7 @@ trait CatalystConf {
   def groupByOrdinal: Boolean
 
   def optimizerMaxIterations: Int
+  def maxCaseBranchesForCodegen: Int
 
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
@@ -45,6 +46,7 @@ case class SimpleCatalystConf(
     caseSensitiveAnalysis: Boolean,
     orderByOrdinal: Boolean = true,
     groupByOrdinal: Boolean = true,
-    optimizerMaxIterations: Int = 100)
+    optimizerMaxIterations: Int = 100,
+    maxCaseBranchesForCodegen: Int = 20)
   extends CatalystConf {
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3d46d796/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 336649c..e97e089 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -81,18 +81,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 }
 
 /**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
- * When a = true, returns b; when c = true, returns d; else returns e.
+ * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
  *
  * @param branches seq of (branch condition, branch value)
  * @param elseValue optional value for the else branch
  */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
-// scalastyle:on line.size.limit
-case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
-  extends Expression with CodegenFallback {
+abstract class CaseWhenBase(
+    branches: Seq[(Expression, Expression)],
+    elseValue: Option[Expression])
+  extends Expression with Serializable {
 
   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
 
@@ -142,16 +139,58 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
     }
   }
 
-  def shouldCodegen: Boolean = {
-    branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN
+  override def toString: String = {
+    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
+    "CASE" + cases + elseCase + " END"
+  }
+
+  override def sql: String = {
+    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
+    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
+    "CASE" + cases + elseCase + " END"
+  }
+}
+
+
+/**
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * When a = true, returns b; when c = true, returns d; else returns e.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
+// scalastyle:on line.size.limit
+case class CaseWhen(
+    val branches: Seq[(Expression, Expression)],
+    val elseValue: Option[Expression] = None)
+  extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    super[CodegenFallback].doGenCode(ctx, ev)
+  }
+
+  def toCodegen(): CaseWhenCodegen = {
+    CaseWhenCodegen(branches, elseValue)
   }
+}
+
+/**
+ * CaseWhen expression used when code generation condition is satisfied.
+ * The OptimizeCodegen optimizer rule replaces CaseWhen with CaseWhenCodegen.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
+ */
+case class CaseWhenCodegen(
+    val branches: Seq[(Expression, Expression)],
+    val elseValue: Option[Expression] = None)
+  extends CaseWhenBase(branches, elseValue) with Serializable {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    if (!shouldCodegen) {
-      // Fallback to interpreted mode if there are too many branches, as it may reach the
-      // 64K limit (limit on bytecode size for a single function).
-      return super[CodegenFallback].doGenCode(ctx, ev)
-    }
     // Generate code that looks like:
     //
     // condA = ...
@@ -202,26 +241,10 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
       $generatedCode""")
   }
-
-  override def toString: String = {
-    val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
-    val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
-    "CASE" + cases + elseCase + " END"
-  }
-
-  override def sql: String = {
-    val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
-    val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
-    "CASE" + cases + elseCase + " END"
-  }
 }
 
 /** Factory methods for CaseWhen. */
 object CaseWhen {
-
-  // The maximum number of switches supported with codegen.
-  val MAX_NUM_CASES_FOR_CODEGEN = 20
-
   def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
     CaseWhen(branches, Option(elseValue))
   }
@@ -242,7 +265,6 @@ object CaseWhen {
   }
 }
 
-
 /**
  * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
  * When a = b, returns c; when a = d, returns e; else returns f.

http://git-wip-us.apache.org/repos/asf/spark/blob/3d46d796/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c46bdfb..b806b72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -104,7 +104,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation) ::
     Batch("Subquery", Once,
-      OptimizeSubqueries) :: Nil
+      OptimizeSubqueries) ::
+    Batch("OptimizeCodegen", Once,
+      OptimizeCodegen(conf)) :: Nil
   }
 
   /**
@@ -864,6 +866,16 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
 }
 
 /**
+ * Replaces expressions with their codegen variants according to the codegen configuration.
+ */
+case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen =>
+      e.toCodegen()
+  }
+}
+
+/**
  * Combines all adjacent [[Union]] operators into a single [[Union]].
  */
 object CombineUnions extends Rule[LogicalPlan] {

http://git-wip-us.apache.org/repos/asf/spark/blob/3d46d796/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
new file mode 100644
index 0000000..4385b0e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class OptimizeCodegenSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil
+  }
+
+  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
+    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
+    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
+    comparePlans(actual, correctAnswer)
+  }
+
+  test("Codegen only when the number of branches is small.") {
+    assertEquivalent(
+      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())
+
+    assertEquivalent(
+      CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)),
+      CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)))
+  }
+
+  test("Nested CaseWhen Codegen.") {
+    assertEquivalent(
+      CaseWhen(
+        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))),
+        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
+      CaseWhen(
+        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))),
+        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
+  }
+
+  test("Multiple CaseWhen in one operator.") {
+    val plan = OneRowRelation
+      .select(
+        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
+        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
+        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
+    val correctAnswer = OneRowRelation
+      .select(
+        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
+        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
+        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
+        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
+    val optimized = Optimize.execute(plan)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Multiple CaseWhen in different operators") {
+    val plan = OneRowRelation
+      .select(
+        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
+        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+      .where(
+        LessThan(
+          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
+          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+      ).analyze
+    val correctAnswer = OneRowRelation
+      .select(
+        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
+        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
+        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+      .where(
+        LessThan(
+          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
+          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+      ).analyze
+    val optimized = Optimize.execute(plan)
+    comparePlans(optimized, correctAnswer)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3d46d796/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 29b66e3..46eaede 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -429,7 +429,6 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
 
   private def supportCodegen(e: Expression): Boolean = e match {
     case e: LeafExpression => true
-    case e: CaseWhen => e.shouldCodegen
     // CodegenFallback requires the input to be an InternalRow
     case e: CodegenFallback => false
     case _ => true

http://git-wip-us.apache.org/repos/asf/spark/blob/3d46d796/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7f206bd..4ae8278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -402,6 +402,12 @@ object SQLConf {
     .intConf
     .createWithDefault(200)
 
+  val MAX_CASES_BRANCHES = SQLConfigBuilder("spark.sql.codegen.maxCaseBranches")
+    .internal()
+    .doc("The maximum number of switches supported with codegen.")
+    .intConf
+    .createWithDefault(20)
+
   val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
     .doc("The maximum number of bytes to pack into a single partition when reading files.")
     .longConf
@@ -529,6 +535,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)
 
+  def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)
+
   def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)
 
   def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org