You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/11/30 18:28:31 UTC
spark git commit: [SPARK-22570][SQL] Avoid to create a lot of global
variables by using a local variable with allocation of an object in generated
code
Repository: spark
Updated Branches:
refs/heads/master 932bd09c8 -> 999ec137a
[SPARK-22570][SQL] Avoid to create a lot of global variables by using a local variable with allocation of an object in generated code
## What changes were proposed in this pull request?
This PR reduces the number of global variables in generated code by replacing each global variable with a local variable that allocates an object each time. When a lot of global variables were generated, the generated code could exceed the 64K constant pool limit.
This PR reduces # of generated global variables in the following three operations:
* `Cast` with String to primitive byte/short/int/long
* `RegExpReplace`
* `CreateArray`
I intentionally leave [this part](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L595-L603). This is because this variable keeps a class that is dynamically generated. In other words, it is not possible to reuse one class.
## How was this patch tested?
Added test cases
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Closes #19797 from kiszk/SPARK-22570.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/999ec137
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/999ec137
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/999ec137
Branch: refs/heads/master
Commit: 999ec137a97844abbbd483dd98c7ded2f8ff356c
Parents: 932bd09
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Fri Dec 1 02:28:24 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Dec 1 02:28:24 2017 +0800
----------------------------------------------------------------------
.../spark/sql/catalyst/expressions/Cast.scala | 24 ++++++-------
.../expressions/complexTypeCreator.scala | 36 ++++++++++++--------
.../expressions/regexpExpressions.scala | 8 ++---
.../sql/catalyst/expressions/CastSuite.scala | 8 +++++
.../expressions/RegexpExpressionsSuite.scala | 11 +++++-
.../catalyst/optimizer/complexTypesSuite.scala | 7 ++++
6 files changed, 61 insertions(+), 33 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8cafaef..f4ecbdb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -799,16 +799,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val wrapper = ctx.freshName("wrapper")
- ctx.addMutableState("UTF8String.IntWrapper", wrapper,
- s"$wrapper = new UTF8String.IntWrapper();")
+ val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
+ UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toByte($wrapper)) {
$evPrim = (byte) $wrapper.value;
} else {
$evNull = true;
}
+ $wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
@@ -826,16 +826,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val wrapper = ctx.freshName("wrapper")
- ctx.addMutableState("UTF8String.IntWrapper", wrapper,
- s"$wrapper = new UTF8String.IntWrapper();")
+ val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
+ UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toShort($wrapper)) {
$evPrim = (short) $wrapper.value;
} else {
$evNull = true;
}
+ $wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
@@ -851,16 +851,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val wrapper = ctx.freshName("wrapper")
- ctx.addMutableState("UTF8String.IntWrapper", wrapper,
- s"$wrapper = new UTF8String.IntWrapper();")
+ val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
+ UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toInt($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
+ $wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
@@ -876,17 +876,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
- val wrapper = ctx.freshName("wrapper")
- ctx.addMutableState("UTF8String.LongWrapper", wrapper,
- s"$wrapper = new UTF8String.LongWrapper();")
+ val wrapper = ctx.freshName("longWrapper")
(c, evPrim, evNull) =>
s"""
+ UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
if ($c.toLong($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
+ $wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 57a7f2e..fc68bf4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
val (preprocess, assigns, postprocess, arrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
- code = preprocess + ctx.splitExpressions(assigns) + postprocess,
+ code = preprocess + assigns + postprocess,
value = arrayData,
isNull = "false")
}
@@ -77,24 +77,22 @@ private [sql] object GenArrayData {
*
* @param ctx a [[CodegenContext]]
* @param elementType data type of underlying array elements
- * @param elementsCode a set of [[ExprCode]] for each element of an underlying array
+ * @param elementsCode concatenated set of [[ExprCode]] for each element of an underlying array
* @param isMapKey if true, throw an exception when the element is null
- * @return (code pre-assignments, assignments to each array elements, code post-assignments,
- * arrayData name)
+ * @return (code pre-assignments, concatenated assignments to each array elements,
+ * code post-assignments, arrayData name)
*/
def genCodeToCreateArrayData(
ctx: CodegenContext,
elementType: DataType,
elementsCode: Seq[ExprCode],
- isMapKey: Boolean): (String, Seq[String], String, String) = {
- val arrayName = ctx.freshName("array")
+ isMapKey: Boolean): (String, String, String, String) = {
val arrayDataName = ctx.freshName("arrayData")
val numElements = elementsCode.length
if (!ctx.isPrimitiveType(elementType)) {
+ val arrayName = ctx.freshName("arrayObject")
val genericArrayClass = classOf[GenericArrayData].getName
- ctx.addMutableState("Object[]", arrayName,
- s"$arrayName = new Object[$numElements];")
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
@@ -110,17 +108,21 @@ private [sql] object GenArrayData {
}
"""
}
+ val assignmentString = ctx.splitExpressions(
+ expressions = assignments,
+ funcName = "apply",
+ extraArguments = ("Object[]", arrayDataName) :: Nil)
- ("",
- assignments,
+ (s"Object[] $arrayName = new Object[$numElements];",
+ assignmentString,
s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
arrayDataName)
} else {
+ val arrayName = ctx.freshName("array")
val unsafeArraySizeInBytes =
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
- ctx.addMutableState("UnsafeArrayData", arrayDataName)
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
@@ -137,14 +139,18 @@ private [sql] object GenArrayData {
}
"""
}
+ val assignmentString = ctx.splitExpressions(
+ expressions = assignments,
+ funcName = "apply",
+ extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil)
(s"""
byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
- $arrayDataName = new UnsafeArrayData();
+ UnsafeArrayData $arrayDataName = new UnsafeArrayData();
Platform.putLong($arrayName, $baseOffset, $numElements);
$arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
""",
- assignments,
+ assignmentString,
"",
arrayDataName)
}
@@ -216,10 +222,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
s"""
final boolean ${ev.isNull} = false;
$preprocessKeyData
- ${ctx.splitExpressions(assignKeys)}
+ $assignKeys
$postprocessKeyData
$preprocessValueData
- ${ctx.splitExpressions(assignValues)}
+ $assignValues
$postprocessValueData
final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index d0d663f..53d7096 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -321,8 +321,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
val termLastReplacement = ctx.freshName("lastReplacement")
val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
-
- val termResult = ctx.freshName("result")
+ val termResult = ctx.freshName("termResult")
val classNamePattern = classOf[Pattern].getCanonicalName
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
@@ -334,8 +333,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
ctx.addMutableState("UTF8String",
termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
- ctx.addMutableState(classNameStringBuffer,
- termResult, s"${termResult} = new $classNameStringBuffer();")
val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
@@ -355,7 +352,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
${termLastReplacementInUTF8} = $rep.clone();
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
}
- ${termResult}.delete(0, ${termResult}.length());
+ $classNameStringBuffer ${termResult} = new $classNameStringBuffer();
java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
while (${matcher}.find()) {
@@ -363,6 +360,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
}
${matcher}.appendTail(${termResult});
${ev.value} = UTF8String.fromString(${termResult}.toString());
+ ${termResult} = null;
$setEvNotNull
"""
})
http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 7837d65..65617be 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
@@ -845,4 +846,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner))
checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter)
}
+
+ test("SPARK-22570: Cast should not create a lot of global variables") {
+ val ctx = new CodegenContext
+ cast("1", IntegerType).genCode(ctx)
+ cast("2", LongType).genCode(ctx)
+ assert(ctx.mutableStates.length == 0)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 1ce150e..4fa61fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.types.StringType
/**
* Unit tests for regular expression (regexp) related SQL expressions.
@@ -178,6 +179,14 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(nonNullExpr, "num-num", row1)
}
+ test("SPARK-22570: RegExpReplace should not create a lot of global variables") {
+ val ctx = new CodegenContext
+ RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx)
+ // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8)
+ // are always required
+ assert(ctx.mutableStates.length == 4)
+ }
+
test("RegexExtract") {
val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1)
val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index 3634acc..e367536 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -164,6 +165,12 @@ class ComplexTypesSuite extends PlanTest{
comparePlans(Optimizer execute query, expected)
}
+ test("SPARK-22570: CreateArray should not create a lot of global variables") {
+ val ctx = new CodegenContext
+ CreateArray(Seq(Literal(1))).genCode(ctx)
+ assert(ctx.mutableStates.length == 0)
+ }
+
test("simplify map ops") {
val rel = relation
.select(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org