You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/11/21 00:42:11 UTC

spark git commit: [SPARK-22549][SQL] Fix 64KB JVM bytecode limit problem with concat_ws

Repository: spark
Updated Branches:
  refs/heads/master c13b60e01 -> 41c6f3601


[SPARK-22549][SQL] Fix 64KB JVM bytecode limit problem with concat_ws

## What changes were proposed in this pull request?

This PR changes `concat_ws` code generation to place the generated code for the argument expressions into separate methods if its size could be large.
This PR resolves the case of `concat_ws` with a large number of arguments.

## How was this patch tested?

Added new test cases to `StringExpressionsSuite`.

Author: Kazuaki Ishizaki <is...@jp.ibm.com>

Closes #19777 from kiszk/SPARK-22549.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/41c6f360
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/41c6f360
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/41c6f360

Branch: refs/heads/master
Commit: 41c6f36018eb086477f21574aacd71616513bd8e
Parents: c13b60e
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Tue Nov 21 01:42:05 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Nov 21 01:42:05 2017 +0100

----------------------------------------------------------------------
 .../expressions/stringExpressions.scala         | 100 ++++++++++++++-----
 .../expressions/StringExpressionsSuite.scala    |  13 +++
 2 files changed, 89 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/41c6f360/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index d5bb7e9..e6f55f4 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -137,13 +137,34 @@ case class ConcatWs(children: Seq[Expression])
     if (children.forall(_.dataType == StringType)) {
       // All children are strings. In that case we can construct a fixed size array.
       val evals = children.map(_.genCode(ctx))
-
-      val inputs = evals.map { eval =>
-        s"${eval.isNull} ? (UTF8String) null : ${eval.value}"
-      }.mkString(", ")
-
-      ev.copy(evals.map(_.code).mkString("\n") + s"""
-        UTF8String ${ev.value} = UTF8String.concatWs($inputs);
+      val separator = evals.head
+      val strings = evals.tail
+      val numArgs = strings.length
+      val args = ctx.freshName("args")
+
+      val inputs = strings.zipWithIndex.map { case (eval, index) =>
+        if (eval.isNull != "true") {
+          s"""
+             ${eval.code}
+             if (!${eval.isNull}) {
+               $args[$index] = ${eval.value};
+             }
+           """
+        } else {
+          ""
+        }
+      }
+      val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
+        ctx.splitExpressions(inputs, "valueConcatWs",
+          ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
+      } else {
+        inputs.mkString("\n")
+      }
+      ev.copy(s"""
+        UTF8String[] $args = new UTF8String[$numArgs];
+        ${separator.code}
+        $codes
+        UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args);
         boolean ${ev.isNull} = ${ev.value} == null;
       """)
     } else {
@@ -156,32 +177,63 @@ case class ConcatWs(children: Seq[Expression])
         child.dataType match {
           case StringType =>
             ("", // we count all the StringType arguments num at once below.
-              s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};")
+             if (eval.isNull == "true") {
+               ""
+             } else {
+               s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
+             })
           case _: ArrayType =>
             val size = ctx.freshName("n")
-            (s"""
-              if (!${eval.isNull}) {
-                $varargNum += ${eval.value}.numElements();
-              }
-            """,
-            s"""
-            if (!${eval.isNull}) {
-              final int $size = ${eval.value}.numElements();
-              for (int j = 0; j < $size; j ++) {
-                $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
-              }
+            if (eval.isNull == "true") {
+              ("", "")
+            } else {
+              (s"""
+                if (!${eval.isNull}) {
+                  $varargNum += ${eval.value}.numElements();
+                }
+                """,
+               s"""
+                if (!${eval.isNull}) {
+                  final int $size = ${eval.value}.numElements();
+                  for (int j = 0; j < $size; j ++) {
+                    $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
+                  }
+                }
+                """)
             }
-            """)
         }
       }.unzip
 
-      ev.copy(evals.map(_.code).mkString("\n") +
-      s"""
+      val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
+      val varargCounts = ctx.splitExpressions(varargCount, "varargCountsConcatWs",
+        ("InternalRow", ctx.INPUT_ROW) :: Nil,
+        "int",
+        { body =>
+          s"""
+           int $varargNum = 0;
+           $body
+           return $varargNum;
+         """
+        },
+        _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";"))
+      val varargBuilds = ctx.splitExpressions(varargBuild, "varargBuildsConcatWs",
+        ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
+        "int",
+        { body =>
+          s"""
+           $body
+           return $idxInVararg;
+         """
+        },
+        _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";"))
+      ev.copy(
+        s"""
+        $codes
         int $varargNum = ${children.count(_.dataType == StringType) - 1};
         int $idxInVararg = 0;
-        ${varargCount.mkString("\n")}
+        $varargCounts
         UTF8String[] $array = new UTF8String[$varargNum];
-        ${varargBuild.mkString("\n")}
+        $varargBuilds
         UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array);
         boolean ${ev.isNull} = ${ev.value} == null;
       """)

http://git-wip-us.apache.org/repos/asf/spark/blob/41c6f360/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index aa9d5a0..7ce4306 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -80,6 +80,19 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     // scalastyle:on
   }
 
+  test("SPARK-22549: ConcatWs should not generate codes beyond 64KB") {
+    val N = 5000
+    val sepExpr = Literal.create("#", StringType)
+    val strings1 = (1 to N).map(x => s"s$x")
+    val inputsExpr1 = strings1.map(Literal.create(_, StringType))
+    checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow)
+
+    val strings2 = (1 to N).map(x => Seq(s"s$x"))
+    val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType)))
+    checkEvaluation(
+      ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow)
+  }
+
   test("elt") {
     def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
       checkEvaluation(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org