You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/01/25 11:50:34 UTC

spark git commit: [SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen

Repository: spark
Updated Branches:
  refs/heads/master 39ee2acf9 -> d20bbc2d8


[SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen

## What changes were proposed in this pull request?

It has been observed in SPARK-21603 that whole-stage codegen suffers performance degradation if the generated functions are too long to be optimized by the JIT compiler.

We basically produce a single function that incorporates the generated code from all physical operators in a whole stage. As a result, the generated function can grow beyond the size threshold above which the JIT no longer optimizes it.

This patch decouples the logic of consuming rows in each physical operator into separate functions, to avoid a single giant function that processes all rows.

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #18931 from viirya/SPARK-21717.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d20bbc2d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d20bbc2d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d20bbc2d

Branch: refs/heads/master
Commit: d20bbc2d87ae6bd56d236a7c3d036b52c5f20ff5
Parents: 39ee2ac
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Jan 25 19:49:58 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Jan 25 19:49:58 2018 +0800

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     |  38 +++++-
 .../org/apache/spark/sql/internal/SQLConf.scala |  12 ++
 .../sql/execution/WholeStageCodegenExec.scala   | 135 +++++++++++++++----
 .../sql/execution/WholeStageCodegenSuite.scala  |  47 ++++++-
 4 files changed, 203 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d20bbc2d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f96ed76..4dcbb70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1245,6 +1245,31 @@ class CodegenContext {
       ""
     }
   }
+
+  /**
+   * Returns the length of parameters for a Java method descriptor. `this` contributes one unit,
+   * a parameter of type long or double contributes two units and any other type one unit. A
+   * nullable parameter contributes one more unit for the extra boolean carrying its null status.
+   */
+  def calculateParamLength(params: Seq[Expression]): Int = {
+    def paramLengthForExpr(input: Expression): Int = {
+      // Match on the type alone: `x + y match {...}` would parse as `(x + y) match {...}`.
+      val javaParamLength = javaType(input.dataType) match {
+        case JAVA_LONG | JAVA_DOUBLE => 2
+        case _ => 1
+      }
+      (if (input.nullable) 1 else 0) + javaParamLength
+    }
+    1 + params.map(paramLengthForExpr(_)).sum // Initial value is 1 for `this`.
+  }
+
+  /**
+   * In Java, a method descriptor is valid only if it represents method parameters with a total
+   * length less than a pre-defined constant.
+   */
+  def isValidParamLength(paramLength: Int): Boolean = {
+    paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+  }
 }
 
 /**
@@ -1311,26 +1336,29 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
 object CodeGenerator extends Logging {
 
   // This is the value of HugeMethodLimit in the OpenJDK JVM settings
-  val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000
+  final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000
+
+  // The max valid length of method parameters in JVM.
+  final val MAX_JVM_METHOD_PARAMS_LENGTH = 255
 
   // This is the threshold over which the methods in an inner class are grouped in a single
   // method which is going to be called by the outer class instead of the many small ones
-  val MERGE_SPLIT_METHODS_THRESHOLD = 3
+  final val MERGE_SPLIT_METHODS_THRESHOLD = 3
 
   // The number of named constants that can exist in the class is limited by the Constant Pool
   // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a
   // threshold of 1000k bytes to determine when a function should be inlined to a private, inner
   // class.
-  val GENERATED_CLASS_SIZE_THRESHOLD = 1000000
+  final val GENERATED_CLASS_SIZE_THRESHOLD = 1000000
 
   // This is the threshold for the number of global variables, whose types are primitive type or
   // complex type (e.g. more than one-dimensional array), that will be placed at the outer class
-  val OUTER_CLASS_VARIABLES_THRESHOLD = 10000
+  final val OUTER_CLASS_VARIABLES_THRESHOLD = 10000
 
   // This is the maximum number of array elements to keep global variables in one Java array
   // 32767 is the maximum integer value that does not require a constant pool entry in a Java
   // bytecode instruction
-  val MUTABLESTATEARRAY_SIZE_LIMIT = 32768
+  final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768
 
   /**
    * Compile the Java source code into a Java class, using Janino.

http://git-wip-us.apache.org/repos/asf/spark/blob/d20bbc2d/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 1cef09a..470f88c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -661,6 +661,15 @@ object SQLConf {
     .intConf
     .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT)
 
+  val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR =
+    buildConf("spark.sql.codegen.splitConsumeFuncByOperator")
+      .internal()
+      .doc("When true, whole stage codegen would put the logic of consuming rows of each " +
+        "physical operator into individual methods, instead of a single big method. This can be " +
+        "used to avoid oversized function that can miss the opportunity of JIT optimization.")
+      .booleanConf
+      .createWithDefault(true)
+
   val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
     .doc("The maximum number of bytes to pack into a single partition when reading files.")
     .longConf
@@ -1263,6 +1272,9 @@ class SQLConf extends Serializable with Logging {
 
   def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)
 
+  def wholeStageSplitConsumeFuncByOperator: Boolean =
+    getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR)
+
   def tableRelationCacheSize: Int =
     getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d20bbc2d/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6102937..8ea9e81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
 
 import java.util.Locale
 
+import scala.collection.mutable
+
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan {
    */
   protected def doProduce(ctx: CodegenContext): String
 
+  private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
+    if (row != null) {
+      ExprCode("", "false", row)
+    } else {
+      if (colVars.nonEmpty) {
+        val colExprs = output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable)
+        }
+        val evaluateInputs = evaluateVariables(colVars)
+        // generate the code to create a UnsafeRow
+        ctx.INPUT_ROW = row
+        ctx.currentVars = colVars
+        val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+        val code = s"""
+          |$evaluateInputs
+          |${ev.code.trim}
+         """.stripMargin.trim
+        ExprCode(code, "false", ev.value)
+      } else {
+        // There are no columns
+        ExprCode("", "false", "unsafeRow")
+      }
+    }
+  }
+
   /**
    * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`.
    *
@@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan {
         }
       }
 
-    val rowVar = if (row != null) {
-      ExprCode("", "false", row)
-    } else {
-      if (outputVars.nonEmpty) {
-        val colExprs = output.zipWithIndex.map { case (attr, i) =>
-          BoundReference(i, attr.dataType, attr.nullable)
-        }
-        val evaluateInputs = evaluateVariables(outputVars)
-        // generate the code to create a UnsafeRow
-        ctx.INPUT_ROW = row
-        ctx.currentVars = outputVars
-        val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
-        val code = s"""
-          |$evaluateInputs
-          |${ev.code.trim}
-         """.stripMargin.trim
-        ExprCode(code, "false", ev.value)
-      } else {
-        // There is no columns
-        ExprCode("", "false", "unsafeRow")
-      }
-    }
+    val rowVar = prepareRowVar(ctx, row, outputVars)
 
     // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
     // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
@@ -156,14 +162,97 @@ trait CodegenSupport extends SparkPlan {
     ctx.INPUT_ROW = null
     ctx.freshNamePrefix = parent.variablePrefix
     val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
+
+    // Under certain conditions, we can put the logic to consume the rows of this operator into
+    // another function, which keeps the generated function from growing too long for the JIT.
+    // The conditions:
+    // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
+    // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
+    //    all variables in output (see `requireAllOutput`).
+    // 3. The total length of the consume function's parameters (see `calculateParamLength`) must
+    //    not exceed the JVM limit for a method descriptor (see `isValidParamLength`).
+    val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator
+    val requireAllOutput = output.forall(parent.usedInputs.contains(_))
+    val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0)
+    val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) {
+      constructDoConsumeFunction(ctx, inputVars, row)
+    } else {
+      parent.doConsume(ctx, inputVars, rowVar)
+    }
     s"""
        |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")}
        |$evaluated
-       |${parent.doConsume(ctx, inputVars, rowVar)}
+       |$consumeFunc
+     """.stripMargin
+  }
+
+  /**
+   * Puts the parent's `doConsume` code of a `CodegenSupport` operator into a separate function and
+   * returns the code that calls it, so the concatenated function doesn't grow too long for the JIT.
+   */
+  private def constructDoConsumeFunction(
+      ctx: CodegenContext,
+      inputVars: Seq[ExprCode],
+      row: String): String = {
+    val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
+    val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)
+
+    val doConsume = ctx.freshName("doConsume")
+    ctx.currentVars = inputVarsInFunc
+    ctx.INPUT_ROW = null
+
+    val doConsumeFuncName = ctx.addNewFunction(doConsume,
+      s"""
+         | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException {
+         |   ${parent.doConsume(ctx, inputVarsInFunc, rowVar)}
+         | }
+       """.stripMargin)
+
+    s"""
+       | $doConsumeFuncName(${args.mkString(", ")});
      """.stripMargin
   }
 
   /**
+   * Returns arguments for calling method and method definition parameters of the consume function.
+   * And also returns the list of `ExprCode` for the parameters.
+   */
+  private def constructConsumeParameters(
+      ctx: CodegenContext,
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      row: String): (Seq[String], Seq[String], Seq[ExprCode]) = {
+    val arguments = mutable.ArrayBuffer[String]()
+    val parameters = mutable.ArrayBuffer[String]()
+    val paramVars = mutable.ArrayBuffer[ExprCode]()
+
+    if (row != null) {
+      arguments += row
+      parameters += s"InternalRow $row"
+    }
+
+    variables.zipWithIndex.foreach { case (ev, i) =>
+      val paramName = ctx.freshName(s"expr_$i")
+      val paramType = ctx.javaType(attributes(i).dataType)
+
+      arguments += ev.value
+      parameters += s"$paramType $paramName"
+      val paramIsNull = if (!attributes(i).nullable) {
+        // Use constant `false` without passing `isNull` for non-nullable variable.
+        "false"
+      } else {
+        val isNull = ctx.freshName(s"exprIsNull_$i")
+        arguments += ev.isNull
+        parameters += s"boolean $isNull"
+        isNull
+      }
+
+      paramVars += ExprCode("", paramIsNull, paramName)
+    }
+    (arguments, parameters, paramVars)
+  }
+
+  /**
    * Returns source code to evaluate all the variables, and clear the code of them, to prevent
    * them to be evaluated twice.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/d20bbc2d/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 22ca128..242bb48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -205,7 +205,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
     val codeWithShortFunctions = genGroupByCode(3)
     val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
     assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
-    val codeWithLongFunctions = genGroupByCode(20)
+    val codeWithLongFunctions = genGroupByCode(50)
     val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions)
     assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
   }
@@ -228,4 +228,49 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("Control splitting consume function by operators with config") {
+    import testImplicits._
+    val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*)
+
+    Seq(true, false).foreach { config =>
+      withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") {
+        val plan = df.queryExecution.executedPlan
+        val wholeStageCodeGenExec = plan.find(p => p match {
+          case wp: WholeStageCodegenExec => true
+          case _ => false
+        })
+        assert(wholeStageCodeGenExec.isDefined)
+        val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
+        assert(code.body.contains("project_doConsume") == config)
+      }
+    }
+  }
+
+  test("Skip splitting consume function when parameter number exceeds JVM limit") {
+    import testImplicits._
+    // Parquet columns are read back as nullable longs, i.e. 3 units each, plus 1 for `this`.
+    Seq((85, false), (84, true)).foreach { case (columnNum, hasSplit) =>
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*)
+          .write.mode(SaveMode.Overwrite).parquet(path)
+
+        withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
+            SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") {
+          val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i")
+          val df = spark.read.parquet(path).selectExpr(projection: _*)
+
+          val plan = df.queryExecution.executedPlan
+          val wholeStageCodeGenExec = plan.find(p => p match {
+            case wp: WholeStageCodegenExec => true
+            case _ => false
+          })
+          assert(wholeStageCodeGenExec.isDefined)
+          val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
+          assert(code.body.contains("project_doConsume") == hasSplit)
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org