You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/11/30 18:28:31 UTC

spark git commit: [SPARK-22570][SQL] Avoid to create a lot of global variables by using a local variable with allocation of an object in generated code

Repository: spark
Updated Branches:
  refs/heads/master 932bd09c8 -> 999ec137a


[SPARK-22570][SQL] Avoid to create a lot of global variables by using a local variable with allocation of an object in generated code

## What changes were proposed in this pull request?

This PR reduces the number of global variables in generated code by replacing each such global variable with a local variable that allocates a new object every time the code runs. When many global variables are generated, the generated class may hit the JVM's 64K constant pool limit.
This PR reduces the number of generated global variables in the following three operations:
* `Cast` with String to primitive byte/short/int/long
* `RegExpReplace`
* `CreateArray`

I intentionally leave [this part](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L595-L603) unchanged, because that variable holds a class that is generated dynamically. In other words, it is not possible to reuse a single class there.

## How was this patch tested?

Added test cases

Author: Kazuaki Ishizaki <is...@jp.ibm.com>

Closes #19797 from kiszk/SPARK-22570.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/999ec137
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/999ec137
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/999ec137

Branch: refs/heads/master
Commit: 999ec137a97844abbbd483dd98c7ded2f8ff356c
Parents: 932bd09
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Fri Dec 1 02:28:24 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Dec 1 02:28:24 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 24 ++++++-------
 .../expressions/complexTypeCreator.scala        | 36 ++++++++++++--------
 .../expressions/regexpExpressions.scala         |  8 ++---
 .../sql/catalyst/expressions/CastSuite.scala    |  8 +++++
 .../expressions/RegexpExpressionsSuite.scala    | 11 +++++-
 .../catalyst/optimizer/complexTypesSuite.scala  |  7 ++++
 6 files changed, 61 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8cafaef..f4ecbdb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -799,16 +799,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
 
   private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType =>
-      val wrapper = ctx.freshName("wrapper")
-      ctx.addMutableState("UTF8String.IntWrapper", wrapper,
-        s"$wrapper = new UTF8String.IntWrapper();")
+      val wrapper = ctx.freshName("intWrapper")
       (c, evPrim, evNull) =>
         s"""
+          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
           if ($c.toByte($wrapper)) {
             $evPrim = (byte) $wrapper.value;
           } else {
             $evNull = true;
           }
+          $wrapper = null;
         """
     case BooleanType =>
       (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
@@ -826,16 +826,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       from: DataType,
       ctx: CodegenContext): CastFunction = from match {
     case StringType =>
-      val wrapper = ctx.freshName("wrapper")
-      ctx.addMutableState("UTF8String.IntWrapper", wrapper,
-        s"$wrapper = new UTF8String.IntWrapper();")
+      val wrapper = ctx.freshName("intWrapper")
       (c, evPrim, evNull) =>
         s"""
+          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
           if ($c.toShort($wrapper)) {
             $evPrim = (short) $wrapper.value;
           } else {
             $evNull = true;
           }
+          $wrapper = null;
         """
     case BooleanType =>
       (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
@@ -851,16 +851,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
 
   private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType =>
-      val wrapper = ctx.freshName("wrapper")
-      ctx.addMutableState("UTF8String.IntWrapper", wrapper,
-        s"$wrapper = new UTF8String.IntWrapper();")
+      val wrapper = ctx.freshName("intWrapper")
       (c, evPrim, evNull) =>
         s"""
+          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
           if ($c.toInt($wrapper)) {
             $evPrim = $wrapper.value;
           } else {
             $evNull = true;
           }
+          $wrapper = null;
         """
     case BooleanType =>
       (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
@@ -876,17 +876,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
 
   private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType =>
-      val wrapper = ctx.freshName("wrapper")
-      ctx.addMutableState("UTF8String.LongWrapper", wrapper,
-        s"$wrapper = new UTF8String.LongWrapper();")
+      val wrapper = ctx.freshName("longWrapper")
 
       (c, evPrim, evNull) =>
         s"""
+          UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
           if ($c.toLong($wrapper)) {
             $evPrim = $wrapper.value;
           } else {
             $evNull = true;
           }
+          $wrapper = null;
         """
     case BooleanType =>
       (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"

http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 57a7f2e..fc68bf4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
     val (preprocess, assigns, postprocess, arrayData) =
       GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
     ev.copy(
-      code = preprocess + ctx.splitExpressions(assigns) + postprocess,
+      code = preprocess + assigns + postprocess,
       value = arrayData,
       isNull = "false")
   }
@@ -77,24 +77,22 @@ private [sql] object GenArrayData {
    *
    * @param ctx a [[CodegenContext]]
    * @param elementType data type of underlying array elements
-   * @param elementsCode a set of [[ExprCode]] for each element of an underlying array
+   * @param elementsCode concatenated set of [[ExprCode]] for each element of an underlying array
    * @param isMapKey if true, throw an exception when the element is null
-   * @return (code pre-assignments, assignments to each array elements, code post-assignments,
-   *           arrayData name)
+   * @return (code pre-assignments, concatenated assignments to each array elements,
+   *           code post-assignments, arrayData name)
    */
   def genCodeToCreateArrayData(
       ctx: CodegenContext,
       elementType: DataType,
       elementsCode: Seq[ExprCode],
-      isMapKey: Boolean): (String, Seq[String], String, String) = {
-    val arrayName = ctx.freshName("array")
+      isMapKey: Boolean): (String, String, String, String) = {
     val arrayDataName = ctx.freshName("arrayData")
     val numElements = elementsCode.length
 
     if (!ctx.isPrimitiveType(elementType)) {
+      val arrayName = ctx.freshName("arrayObject")
       val genericArrayClass = classOf[GenericArrayData].getName
-      ctx.addMutableState("Object[]", arrayName,
-        s"$arrayName = new Object[$numElements];")
 
       val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
         val isNullAssignment = if (!isMapKey) {
@@ -110,17 +108,21 @@ private [sql] object GenArrayData {
          }
        """
       }
+      val assignmentString = ctx.splitExpressions(
+        expressions = assignments,
+        funcName = "apply",
+        extraArguments = ("Object[]", arrayDataName) :: Nil)
 
-      ("",
-       assignments,
+      (s"Object[] $arrayName = new Object[$numElements];",
+       assignmentString,
        s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
        arrayDataName)
     } else {
+      val arrayName = ctx.freshName("array")
       val unsafeArraySizeInBytes =
         UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
         ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
       val baseOffset = Platform.BYTE_ARRAY_OFFSET
-      ctx.addMutableState("UnsafeArrayData", arrayDataName)
 
       val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
       val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
@@ -137,14 +139,18 @@ private [sql] object GenArrayData {
          }
        """
       }
+      val assignmentString = ctx.splitExpressions(
+        expressions = assignments,
+        funcName = "apply",
+        extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil)
 
       (s"""
         byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
-        $arrayDataName = new UnsafeArrayData();
+        UnsafeArrayData $arrayDataName = new UnsafeArrayData();
         Platform.putLong($arrayName, $baseOffset, $numElements);
         $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
       """,
-       assignments,
+       assignmentString,
        "",
        arrayDataName)
     }
@@ -216,10 +222,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
       s"""
        final boolean ${ev.isNull} = false;
        $preprocessKeyData
-       ${ctx.splitExpressions(assignKeys)}
+       $assignKeys
        $postprocessKeyData
        $preprocessValueData
-       ${ctx.splitExpressions(assignValues)}
+       $assignValues
        $postprocessValueData
        final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
       """

http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index d0d663f..53d7096 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -321,8 +321,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
 
     val termLastReplacement = ctx.freshName("lastReplacement")
     val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
-
-    val termResult = ctx.freshName("result")
+    val termResult = ctx.freshName("termResult")
 
     val classNamePattern = classOf[Pattern].getCanonicalName
     val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
@@ -334,8 +333,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
     ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
     ctx.addMutableState("UTF8String",
       termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
-    ctx.addMutableState(classNameStringBuffer,
-      termResult, s"${termResult} = new $classNameStringBuffer();")
 
     val setEvNotNull = if (nullable) {
       s"${ev.isNull} = false;"
@@ -355,7 +352,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
         ${termLastReplacementInUTF8} = $rep.clone();
         ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
       }
-      ${termResult}.delete(0, ${termResult}.length());
+      $classNameStringBuffer ${termResult} = new $classNameStringBuffer();
       java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
 
       while (${matcher}.find()) {
@@ -363,6 +360,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
       }
       ${matcher}.appendTail(${termResult});
       ${ev.value} = UTF8String.fromString(${termResult}.toString());
+      ${termResult} = null;
       $setEvNotNull
     """
     })

http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 7837d65..65617be 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
@@ -845,4 +846,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner))
     checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter)
   }
+
+  test("SPARK-22570: Cast should not create a lot of global variables") {
+    val ctx = new CodegenContext
+    cast("1", IntegerType).genCode(ctx)
+    cast("2", LongType).genCode(ctx)
+    assert(ctx.mutableStates.length == 0)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 1ce150e..4fa61fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.types.StringType
 
 /**
  * Unit tests for regular expression (regexp) related SQL expressions.
@@ -178,6 +179,14 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(nonNullExpr, "num-num", row1)
   }
 
+  test("SPARK-22570: RegExpReplace should not create a lot of global variables") {
+    val ctx = new CodegenContext
+    RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx)
+    // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8)
+    // are always required
+    assert(ctx.mutableStates.length == 4)
+  }
+
   test("RegexExtract") {
     val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1)
     val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)

http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index 3634acc..e367536 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -164,6 +165,12 @@ class ComplexTypesSuite extends PlanTest{
     comparePlans(Optimizer execute query, expected)
   }
 
+  test("SPARK-22570: CreateArray should not create a lot of global variables") {
+    val ctx = new CodegenContext
+    CreateArray(Seq(Literal(1))).genCode(ctx)
+    assert(ctx.mutableStates.length == 0)
+  }
+
   test("simplify map ops") {
     val rel = relation
       .select(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org