Posted to commits@spark.apache.org by hv...@apache.org on 2016/11/07 18:09:28 UTC

spark git commit: [SPARK-18125][SQL][BRANCH-2.0] Fix a compilation error in codegen due to splitExpression

Repository: spark
Updated Branches:
  refs/heads/branch-2.0 dd5cb0a98 -> b5d7217af


[SPARK-18125][SQL][BRANCH-2.0] Fix a compilation error in codegen due to splitExpression

## What changes were proposed in this pull request?

This is a backport to branch-2.0.

As reported in the JIRA ticket, the Java code generated by codegen sometimes fails to compile.

Code snippet to reproduce it (run in `spark-shell`, where `sc` and the `toDF` implicits are in scope):

    case class Route(src: String, dest: String, cost: Int)
    case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])

    val ds = sc.parallelize(Array(
      Route("a", "b", 1),
      Route("a", "b", 2),
      Route("a", "c", 2),
      Route("a", "d", 10),
      Route("b", "a", 1),
      Route("b", "a", 5),
      Route("b", "c", 6))
    ).toDF.as[Route]

    val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
      .groupByKey(r => (r.src, r.dest))
      .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
        GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
      }.map(_._2)

The problem is that `ReferenceToExpressions` evaluates its children into local variables, and the result expression is then generated to read those variables. In the case above, the code for the result expression is so long that `CodegenContext.splitExpression` splits it into separate methods; the local variables are not visible from those methods, so the generated code fails to compile. The fix promotes those locals to class-level fields, which remain visible after the split.
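
As a minimal, hypothetical illustration of why the fix restores visibility (plain Scala for readability, not Spark's actual generated Java):

    // A sketch of the scoping problem and the fix pattern. If classChildVar
    // were a local in eval(), the split-out splitPart() method could not see
    // it; as a class field (what ctx.addMutableState creates in the real
    // patch) it is visible to both methods.
    class GeneratedSketch {
      private var classChildVar: Int = _   // class field instead of a local

      def eval(): Int = {
        classChildVar = 42                 // the init code assigns the field
        splitPart()                        // the split-out method reads the field
      }

      private def splitPart(): Int = classChildVar + 1
    }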

## How was this patch tested?

Jenkins tests.


Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #15796 from viirya/fix-codege-compilation-error-2.0.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b5d7217a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b5d7217a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b5d7217a

Branch: refs/heads/branch-2.0
Commit: b5d7217aff80c4c407672dc1858c824954953b1d
Parents: dd5cb0a
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Mon Nov 7 19:09:18 2016 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Mon Nov 7 19:09:18 2016 +0100

----------------------------------------------------------------------
 .../expressions/ReferenceToExpressions.scala    | 28 +++++++++++----
 .../org/apache/spark/sql/DatasetSuite.scala     | 37 ++++++++++++++++++++
 2 files changed, 59 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b5d7217a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 502d791..6c75a7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -45,6 +45,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
     var maxOrdinal = -1
     result foreach {
       case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
+      case _ =>
     }
     if (maxOrdinal > children.length) {
       return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
@@ -62,15 +63,30 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val childrenGen = children.map(_.genCode(ctx))
-    val childrenVars = childrenGen.zip(children).map {
-      case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
-    }
+    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
+      case (childGen, child) =>
+        // SPARK-18125: The children vars are local variables. If the result expression uses
+        // splitExpression, those variables cannot be accessed so compilation fails.
+        // To fix it, we use class variables to hold those local variables.
+        val classChildVarName = ctx.freshName("classChildVar")
+        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
+        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
+        ctx.addMutableState("boolean", classChildVarIsNull, "")
+
+        val classChildVar =
+          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
+
+        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
+          s"${classChildVar.isNull} = ${childGen.isNull};"
+
+        (classChildVar, initCode)
+    }.unzip
 
     val resultGen = result.transform {
-      case b: BoundReference => childrenVars(b.ordinal)
+      case b: BoundReference => classChildrenVars(b.ordinal)
     }.genCode(ctx)
 
-    ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code,
-      isNull = resultGen.isNull, value = resultGen.value)
+    ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
+      resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
   }
 }
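
For context on the control flow of the patched `doGenCode` above, here is a self-contained, hypothetical miniature of the pattern (plain Scala; `MiniCodegenContext` and `MiniExprCode` are illustrative stand-ins, not Spark's real codegen API):

    object CodegenPatternSketch {
      import scala.collection.mutable.ArrayBuffer

      // Hypothetical miniature of the pattern in the patch above; the names
      // mirror the real code but this is not Spark's CodegenContext.
      case class MiniExprCode(code: String, value: String)

      class MiniCodegenContext {
        private var counter = 0
        val mutableStates = ArrayBuffer.empty[(String, String)]  // (javaType, name)

        def freshName(prefix: String): String = { counter += 1; s"$prefix$counter" }

        def addMutableState(javaType: String, name: String): Unit =
          mutableStates += ((javaType, name))
      }

      // One class field is registered per child, an init statement copies each
      // child's value into its field, and the init block is spliced in before
      // the result expression's code.
      def genCode(ctx: MiniCodegenContext,
                  childrenGen: Seq[MiniExprCode],
                  resultCode: String): String = {
        val initCode = childrenGen.map { childGen =>
          val field = ctx.freshName("classChildVar")
          ctx.addMutableState("int", field)  // field survives method splitting
          s"$field = ${childGen.value};"     // copy the local into the field
        }
        (childrenGen.map(_.code) ++ initCode :+ resultCode).mkString("\n")
      }
    }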

http://git-wip-us.apache.org/repos/asf/spark/blob/b5d7217a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6113e5d..7a98915 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -895,6 +895,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
     assert(dataset.collect() sameElements Array(resultValue, resultValue))
   }
+
+  test("SPARK-18125: Spark generated code causes CompileException") {
+    val data = Array(
+      Route("a", "b", 1),
+      Route("a", "b", 2),
+      Route("a", "c", 2),
+      Route("a", "d", 10),
+      Route("b", "a", 1),
+      Route("b", "a", 5),
+      Route("b", "c", 6))
+    val ds = sparkContext.parallelize(data).toDF.as[Route]
+
+    val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
+      .groupByKey(r => (r.src, r.dest))
+      .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
+        GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
+      }.map(_._2)
+
+    val expected = Seq(
+      GroupedRoutes("a", "d", Seq(Route("a", "d", 10))),
+      GroupedRoutes("b", "c", Seq(Route("b", "c", 6))),
+      GroupedRoutes("a", "b", Seq(Route("a", "b", 1), Route("a", "b", 2))),
+      GroupedRoutes("b", "a", Seq(Route("b", "a", 1), Route("b", "a", 5))),
+      GroupedRoutes("a", "c", Seq(Route("a", "c", 2)))
+    )
+
+    implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new Ordering[GroupedRoutes] {
+      override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = {
+        x.toString.compareTo(y.toString)
+      }
+    }
+
+    checkDatasetUnorderly(grped, expected: _*)
+  }
 }
 
 case class Generic[T](id: T, value: Double)
@@ -967,3 +1001,6 @@ object DatasetTransform {
     ds.map(_ + 1)
   }
 }
+
+case class Route(src: String, dest: String, cost: Int)
+case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org