You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/03/09 21:48:16 UTC

spark git commit: [SPARK-23628][SQL][BACKPORT-2.3] calculateParamLength should not return 1 + num of expressions

Repository: spark
Updated Branches:
  refs/heads/branch-2.3 bc5ce0476 -> 3ec25d5a8


[SPARK-23628][SQL][BACKPORT-2.3] calculateParamLength should not return 1 + num of expressions

## What changes were proposed in this pull request?

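Backport of ea480990e726aed59750f1cea8d40adba56d991a to branch-2.3.

`CodegenContext.calculateParamLength` had an operator-precedence bug. `(if (input.nullable) 1 else 0) + javaType(input.dataType) match { ... }` parses as `((if ...) + javaType(...)) match { ... }`, so the match ran on a concatenated string (e.g. `"0long"`) that never equals `JAVA_LONG` or `JAVA_DOUBLE`. As a result, every parameter counted as one slot and the method returned 1 + the number of expressions. The fix first computes the slot count for the Java type (2 for `long`/`double`, otherwise 1). It then adds 1 for the extra null-flag parameter of nullable expressions.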

## How was this patch tested?

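Added a unit test for `calculateParamLength` to `CodeGenerationSuite`. Updated the "Skip splitting consume function when parameter number exceeds JVM limit" test in `WholeStageCodegenSuite` to use the corrected parameter counts.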

cc cloud-fan hvanhovell

Author: Marco Gaido <ma...@gmail.com>

Closes #20783 from mgaido91/SPARK-23628_2.3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3ec25d5a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3ec25d5a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3ec25d5a

Branch: refs/heads/branch-2.3
Commit: 3ec25d5a803888e5e24a47a511e9d88c423c5310
Parents: bc5ce04
Author: Marco Gaido <ma...@gmail.com>
Authored: Fri Mar 9 13:48:12 2018 -0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Mar 9 13:48:12 2018 -0800

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala         |  7 ++++---
 .../catalyst/expressions/CodeGenerationSuite.scala  |  6 ++++++
 .../sql/execution/WholeStageCodegenSuite.scala      | 16 ++++++++--------
 3 files changed, 18 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3ec25d5a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4dcbb70..a54af03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1253,14 +1253,15 @@ class CodegenContext {
    */
   def calculateParamLength(params: Seq[Expression]): Int = {
     def paramLengthForExpr(input: Expression): Int = {
-      // For a nullable expression, we need to pass in an extra boolean parameter.
-      (if (input.nullable) 1 else 0) + javaType(input.dataType) match {
+      val javaParamLength = javaType(input.dataType) match {
         case JAVA_LONG | JAVA_DOUBLE => 2
         case _ => 1
       }
+      // For a nullable expression, we need to pass in an extra boolean parameter.
+      (if (input.nullable) 1 else 0) + javaParamLength
     }
     // Initial value is 1 for `this`.
-    1 + params.map(paramLengthForExpr(_)).sum
+    1 + params.map(paramLengthForExpr).sum
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/3ec25d5a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 676ba39..3958bff 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -436,4 +436,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     ctx.addImmutableStateIfNotExists("String", mutableState2)
     assert(ctx.inlinedMutableStates.length == 2)
   }
+
+  test("SPARK-23628: calculateParamLength should compute properly the param length") {
+    val ctx = new CodegenContext
+    assert(ctx.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101)
+    assert(ctx.calculateParamLength(Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3ec25d5a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index ef16292..0fb9dd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
-import org.apache.spark.sql.functions.{avg, broadcast, col, max}
+import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -249,12 +249,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
   }
 
   test("Skip splitting consume function when parameter number exceeds JVM limit") {
-    import testImplicits._
-
-    Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) =>
+    // since every field is nullable we have 2 params for each input column (one for the value
+    // and one for the isNull variable)
+    Seq((128, false), (127, true)).foreach { case (columnNum, hasSplit) =>
       withTempPath { dir =>
         val path = dir.getCanonicalPath
-        spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*)
+        spark.range(10).select(Seq.tabulate(columnNum) {i => lit(i).as(s"c$i")} : _*)
           .write.mode(SaveMode.Overwrite).parquet(path)
 
         withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
@@ -263,10 +263,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
           val df = spark.read.parquet(path).selectExpr(projection: _*)
 
           val plan = df.queryExecution.executedPlan
-          val wholeStageCodeGenExec = plan.find(p => p match {
-            case wp: WholeStageCodegenExec => true
+          val wholeStageCodeGenExec = plan.find {
+            case _: WholeStageCodegenExec => true
             case _ => false
-          })
+          }
           assert(wholeStageCodeGenExec.isDefined)
           val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
           assert(code.body.contains("project_doConsume") == hasSplit)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org