You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/05/22 17:50:29 UTC
[2/2] spark git commit: [SPARK-24121][SQL] Add API for handling
expression code generation
[SPARK-24121][SQL] Add API for handling expression code generation
## What changes were proposed in this pull request?
This patch tries to implement this [proposal](https://github.com/apache/spark/pull/19813#issuecomment-354045400) to add an API for handling expression code generation. It should allow us to manipulate how code is generated for expressions.
In detail, this adds a new abstraction `CodeBlock` to `JavaCode`. `CodeBlock` holds the code snippet and the inputs needed to generate the actual Java code.
For example, in following java code:
```java
int ${variable} = 1;
boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)};
```
`variable` and `isNull` are two `VariableValue`s and `CodeGenerator.defaultValue(BooleanType)` is a string. They are all inputs to this code block and are held by the `CodeBlock` representing this code.
For codegen, we provide a specialized string interpolator `code`, so you can define a code block like this:
```scala
val codeBlock =
code"""
|int ${variable} = 1;
|boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)};
""".stripMargin
// Generates actual java code.
codeBlock.toString
```
Because those inputs are held separately in `CodeBlock` before generating code, we can safely manipulate them, e.g., replacing statements with aliased variables.
## How was this patch tested?
Added tests.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #21193 from viirya/SPARK-24121.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f9f055af
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f9f055af
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f9f055af
Branch: refs/heads/master
Commit: f9f055afa47412eec8228c843b34a90decb9be43
Parents: 8086acc
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Wed May 23 01:50:22 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed May 23 01:50:22 2018 +0800
----------------------------------------------------------------------
.../catalyst/expressions/BoundAttribute.scala | 5 +-
.../spark/sql/catalyst/expressions/Cast.scala | 10 +-
.../sql/catalyst/expressions/Expression.scala | 26 ++--
.../expressions/MonotonicallyIncreasingID.scala | 3 +-
.../sql/catalyst/expressions/ScalaUDF.scala | 3 +-
.../sql/catalyst/expressions/SortOrder.scala | 3 +-
.../catalyst/expressions/SparkPartitionID.scala | 3 +-
.../sql/catalyst/expressions/TimeWindow.scala | 3 +-
.../sql/catalyst/expressions/arithmetic.scala | 13 +-
.../expressions/codegen/CodeGenerator.scala | 25 ++--
.../expressions/codegen/CodegenFallback.scala | 5 +-
.../codegen/GenerateSafeProjection.scala | 7 +-
.../codegen/GenerateUnsafeProjection.scala | 5 +-
.../catalyst/expressions/codegen/javaCode.scala | 145 ++++++++++++++++++-
.../expressions/collectionOperations.scala | 19 +--
.../expressions/complexTypeCreator.scala | 7 +-
.../expressions/conditionalExpressions.scala | 5 +-
.../expressions/datetimeExpressions.scala | 23 +--
.../expressions/decimalExpressions.scala | 5 +-
.../sql/catalyst/expressions/generators.scala | 3 +-
.../spark/sql/catalyst/expressions/hash.scala | 5 +-
.../catalyst/expressions/inputFileBlock.scala | 14 +-
.../catalyst/expressions/mathExpressions.scala | 5 +-
.../spark/sql/catalyst/expressions/misc.scala | 5 +-
.../catalyst/expressions/nullExpressions.scala | 9 +-
.../catalyst/expressions/objects/objects.scala | 48 +++---
.../sql/catalyst/expressions/predicates.scala | 15 +-
.../expressions/randomExpressions.scala | 5 +-
.../expressions/regexpExpressions.scala | 9 +-
.../expressions/stringExpressions.scala | 25 ++--
.../expressions/ExpressionEvalHelperSuite.scala | 3 +-
.../expressions/codegen/CodeBlockSuite.scala | 136 +++++++++++++++++
.../spark/sql/execution/ColumnarBatchScan.scala | 9 +-
.../apache/spark/sql/execution/ExpandExec.scala | 3 +-
.../spark/sql/execution/GenerateExec.scala | 5 +-
.../sql/execution/WholeStageCodegenExec.scala | 15 +-
.../execution/aggregate/HashAggregateExec.scala | 7 +-
.../execution/aggregate/HashMapGenerator.scala | 3 +-
.../execution/joins/BroadcastHashJoinExec.scala | 3 +-
.../sql/execution/joins/SortMergeJoinExec.scala | 5 +-
.../spark/sql/GeneratorFunctionSuite.scala | 4 +-
41 files changed, 479 insertions(+), 172 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 4cc84b2..df3ab05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
- s"""
+ code"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|$javaType ${ev.value} = ${ev.isNull} ?
| ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
- ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
+ ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 12330bf..699ea53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
- ev.copy(code = eval.code +
- castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
+
+ ev.copy(code =
+ code"""
+ ${eval.code}
+ // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull}
+ ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)}
+ """)
}
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 97dff6a..9b9fa41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -22,6 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] {
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
reduceCodeSize(ctx, eval)
- if (eval.code.nonEmpty) {
+ if (eval.code.toString.nonEmpty) {
// Add `this` in the comment.
- eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
+ eval.copy(code = ctx.registerComment(this.toString) + eval.code)
} else {
eval
}
@@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] {
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
- if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
@@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] {
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
- | ${eval.code.trim}
+ | ${eval.code}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)
eval.value = JavaCode.variable(newValue, dataType)
- eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
+ eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
@@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression {
if (nullable) {
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
@@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression {
}
}
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${leftGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression {
}
}
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${leftGen.code}
${midGen.code}
${rightGen.code}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 9f07796..f1da592 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}
/**
@@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
- ev.copy(code = s"""
+ ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = FalseLiteral)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index e869258..3e7ca88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType
/**
@@ -1030,7 +1031,7 @@ case class ScalaUDF(
""".stripMargin
ev.copy(code =
- s"""
+ code"""
|$evalCode
|${initArgs.mkString("\n")}
|$callFunc
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index ff7c98f..2ce9d07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
@@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
}
ev.copy(code = childCode.code +
- s"""
+ code"""
|long ${ev.value} = 0L;
|boolean ${ev.isNull} = ${childCode.isNull};
|if (!${childCode.isNull}) {
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 787bcaf..9856b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
- ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
+ ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = FalseLiteral)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 6c4a360..84e38a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -164,7 +165,7 @@ case class PreciseTimestampConversion(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
- s"""boolean ${ev.isNull} = ${eval.isNull};
+ code"""boolean ${ev.isNull} = ${eval.isNull};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index efd4e99..fe91e52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic {
${ev.value} = $operation;
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
}
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
$result
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
@@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d382d9a..66315e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
-case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
+case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
object ExprCode {
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
- ExprCode(code = "", isNull, value)
+ ExprCode(code = EmptyBlock, isNull, value)
}
def forNullValue(dataType: DataType): ExprCode = {
- ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
+ ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
}
def forNonNullValue(value: ExprValue): ExprCode = {
- ExprCode(code = "", isNull = FalseLiteral, value = value)
+ ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value)
}
}
@@ -330,9 +331,9 @@ class CodegenContext {
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
val value = addMutableState(javaType(dataType), variableName)
val code = dataType match {
- case StringType => s"$value = $initCode.clone();"
- case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
- case _ => s"$value = $initCode;"
+ case StringType => code"$value = $initCode.clone();"
+ case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();"
+ case _ => code"$value = $initCode;"
}
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
}
@@ -1056,7 +1057,7 @@ class CodegenContext {
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(localSubExprEliminationExprs.put(_, state))
- eval.code.trim
+ eval.code.toString
}
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
}
@@ -1084,7 +1085,7 @@ class CodegenContext {
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
- | ${eval.code.trim}
+ | ${eval.code}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
@@ -1141,7 +1142,7 @@ class CodegenContext {
def registerComment(
text: => String,
placeholderId: String = "",
- force: Boolean = false): String = {
+ force: Boolean = false): Block = {
// By default, disable comments in generated code because computing the comments themselves can
// be extremely expensive in certain cases, such as deeply-nested expressions which operate over
// inputs with wide schemas. For more details on the performance issues that motivated this
@@ -1160,9 +1161,9 @@ class CodegenContext {
s"// $text"
}
placeHolderToComments += (name -> comment)
- s"/*$name*/"
+ code"/*$name*/"
} else {
- ""
+ EmptyBlock
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index a91989e..3f4704d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
/**
* A trait that can be used to provide a fallback mode for expression code generation.
@@ -46,7 +47,7 @@ trait CodegenFallback extends Expression {
val placeHolder = ctx.registerComment(this.toString)
val javaType = CodeGenerator.javaType(this.dataType)
if (nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
@@ -55,7 +56,7 @@ trait CodegenFallback extends Expression {
${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
$javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 01c350e..3977866 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -22,6 +22,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
)
val code =
- s"""
+ code"""
|final InternalRow $tmpInput = $input;
|final Object[] $values = new Object[${schema.length}];
|$allFields
@@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
ctx,
JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
elementType)
- val code = s"""
+ val code = code"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
final Object[] $values = new Object[$numElements];
@@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType)
val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType)
- val code = s"""
+ val code = code"""
final MapData $tmpInput = $input;
${keyConverter.code}
${valueConverter.code}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 01b4d6c..8f2a5a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
val code =
- s"""
+ code"""
|$rowWriter.reset();
|$evalSubexpr
|$writeExpressions
@@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
| }
|
| public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
- | ${eval.code.trim}
+ | ${eval.code}
| return ${eval.value};
| }
|
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index 74ff018..250ce48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import java.lang.{Boolean => JBool}
+import scala.collection.mutable.ArrayBuffer
import scala.language.{existentials, implicitConversions}
import org.apache.spark.sql.types.{BooleanType, DataType}
@@ -115,6 +116,147 @@ object JavaCode {
}
/**
+ * A trait representing a block of java code.
+ */
+trait Block extends JavaCode {
+
+ // The expressions to be evaluated inside this block.
+ def exprValues: Set[ExprValue]
+
+ // Returns java code string for this code block.
+ override def toString: String = _marginChar match {
+ case Some(c) => code.stripMargin(c).trim
+ case _ => code.trim
+ }
+
+ def length: Int = toString.length
+
+ def nonEmpty: Boolean = toString.nonEmpty
+
+ // The leading prefix that should be stripped from each line.
+ // By default we strip blanks or control characters followed by '|' from the line.
+ var _marginChar: Option[Char] = Some('|')
+
+ def stripMargin(c: Char): this.type = {
+ _marginChar = Some(c)
+ this
+ }
+
+ def stripMargin: this.type = {
+ _marginChar = Some('|')
+ this
+ }
+
+ // Concatenates this block with other block.
+ def + (other: Block): Block
+}
+
+object Block {
+
+ val CODE_BLOCK_BUFFER_LENGTH: Int = 512
+
+ implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks)
+
+ implicit class BlockHelper(val sc: StringContext) extends AnyVal {
+ def code(args: Any*): Block = {
+ sc.checkLengths(args)
+ if (sc.parts.length == 0) {
+ EmptyBlock
+ } else {
+ args.foreach {
+ case _: ExprValue =>
+ case _: Int | _: Long | _: Float | _: Double | _: String =>
+ case _: Block =>
+ case other => throw new IllegalArgumentException(
+ s"Can not interpolate ${other.getClass.getName} into code block.")
+ }
+
+ val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args)
+ CodeBlock(codeParts, blockInputs)
+ }
+ }
+ }
+
+ // Folds eagerly the literal args into the code parts.
+ private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = {
+ val codeParts = ArrayBuffer.empty[String]
+ val blockInputs = ArrayBuffer.empty[JavaCode]
+
+ val strings = parts.iterator
+ val inputs = args.iterator
+ val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
+
+ buf.append(strings.next)
+ while (strings.hasNext) {
+ val input = inputs.next
+ input match {
+ case _: ExprValue | _: Block =>
+ codeParts += buf.toString
+ buf.clear
+ blockInputs += input.asInstanceOf[JavaCode]
+ case _ =>
+ buf.append(input)
+ }
+ buf.append(strings.next)
+ }
+ if (buf.nonEmpty) {
+ codeParts += buf.toString
+ }
+
+ (codeParts.toSeq, blockInputs.toSeq)
+ }
+}
+
+/**
+ * A block of java code. Including a sequence of code parts and some inputs to this block.
+ * The actual java code is generated by embedding the inputs into the code parts.
+ */
+case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
+ override lazy val exprValues: Set[ExprValue] = {
+ blockInputs.flatMap {
+ case b: Block => b.exprValues
+ case e: ExprValue => Set(e)
+ }.toSet
+ }
+
+ override lazy val code: String = {
+ val strings = codeParts.iterator
+ val inputs = blockInputs.iterator
+ val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
+ buf.append(StringContext.treatEscapes(strings.next))
+ while (strings.hasNext) {
+ buf.append(inputs.next)
+ buf.append(StringContext.treatEscapes(strings.next))
+ }
+ buf.toString
+ }
+
+ override def + (other: Block): Block = other match {
+ case c: CodeBlock => Blocks(Seq(this, c))
+ case b: Blocks => Blocks(Seq(this) ++ b.blocks)
+ case EmptyBlock => this
+ }
+}
+
+case class Blocks(blocks: Seq[Block]) extends Block {
+ override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet
+ override lazy val code: String = blocks.map(_.toString).mkString("\n")
+
+ override def + (other: Block): Block = other match {
+ case c: CodeBlock => Blocks(blocks :+ c)
+ case b: Blocks => Blocks(blocks ++ b.blocks)
+ case EmptyBlock => this
+ }
+}
+
+object EmptyBlock extends Block with Serializable {
+ override val code: String = ""
+ override val exprValues: Set[ExprValue] = Set.empty
+
+ override def + (other: Block): Block = other
+}
+
+/**
* A typed java fragment that must be a valid java expression.
*/
trait ExprValue extends JavaCode {
@@ -123,10 +265,9 @@ trait ExprValue extends JavaCode {
}
object ExprValue {
- implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
+ implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code
}
-
/**
* A java expression fragment.
*/
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 7da4c3c..c28eab7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -91,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = false;
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
@@ -1177,14 +1178,14 @@ case class ArrayJoin(
}
if (nullable) {
ev.copy(
- s"""
+ code"""
|boolean ${ev.isNull} = true;
|UTF8String ${ev.value} = null;
|$code
""".stripMargin)
} else {
ev.copy(
- s"""
+ code"""
|UTF8String ${ev.value} = null;
|$code
""".stripMargin, FalseLiteral)
@@ -1269,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
- val item = ExprCode("",
+ val item = ExprCode(EmptyBlock,
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
- s"""
+ code"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -1334,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
- val item = ExprCode("",
+ val item = ExprCode(EmptyBlock,
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
- s"""
+ code"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -1653,7 +1654,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"$javaType[]", args) :: Nil)
- ev.copy(s"""
+ ev.copy(code"""
$initCode
$codes
$javaType ${ev.value} = $concatenator.concat($args);
@@ -1963,7 +1964,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic)
ev.copy(code =
- s"""
+ code"""
|boolean ${ev.isNull} = false;
|${leftGen.code}
|${rightGen.code}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 67876a8..a9867aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
@@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
val (preprocess, assigns, postprocess, arrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
- code = preprocess + assigns + postprocess,
+ code = code"${preprocess}${assigns}${postprocess}",
value = JavaCode.variable(arrayData, dataType),
isNull = FalseLiteral)
}
@@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false)
val code =
- s"""
+ code"""
final boolean ${ev.isNull} = false;
$preprocessKeyData
$assignKeys
@@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
extraArguments = "Object[]" -> values :: Nil)
ev.copy(code =
- s"""
+ code"""
|Object[] $values = new Object[${valExprs.size}];
|$valuesCode
|final InternalRow ${ev.value} = new $rowClass($values);
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 205d77f..77ac6c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
// scalastyle:off line.size.limit
@@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
val falseEval = falseValue.genCode(ctx)
val code =
- s"""
+ code"""
|${condEval.code}
|boolean ${ev.isNull} = false;
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -265,7 +266,7 @@ case class CaseWhen(
}.mkString)
ev.copy(code =
- s"""
+ code"""
|${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $codes
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 03422fe..e8d85f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -717,7 +718,7 @@ abstract class UnixTime
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -746,7 +747,7 @@ abstract class UnixTime
})
case TimestampType =>
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -757,7 +758,7 @@ abstract class UnixTime
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val t = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio
val tz = ctx.addReferenceObj("timeZone", timeZone)
val longOpt = ctx.freshName("longOpt")
val eval = child.genCode(ctx)
- val code = s"""
+ val code = code"""
|${eval.code}
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true;
|${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)};
@@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
if (right.foldable) {
val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
@@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
@@ -1287,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
if (right.foldable) {
val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
@@ -1301,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
@@ -1444,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
val javaType = CodeGenerator.javaType(dataType)
if (format.foldable) {
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
val t = instant.genCode(ctx)
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index db1579b..04de833 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.types._
/**
@@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
override def eval(input: InternalRow): Any = child.eval(input)
/** Just a simple pass-through for code generation. */
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ ev.copy(EmptyBlock)
override def prettyName: String = "promote_precision"
override def sql: String = child.sql
override lazy val canonicalized: Expression = child.canonicalized
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 3af4bfe..b7c52f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
@@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
// Create the collection.
val wrapperClass = classOf[mutable.WrappedArray[_]].getName
ev.copy(code =
- s"""
+ code"""
|$code
|$wrapperClass<InternalRow> ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);
""".stripMargin, isNull = FalseLiteral)
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index ef79033..cec00b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
@@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression {
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|$hashResultType ${ev.value} = $seed;
|$codes
""".stripMargin)
@@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
ev.copy(code =
- s"""
+ code"""
|${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
|${CodeGenerator.JAVA_INT} $childHash = 0;
|$codes
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
index 2a3cc58..3b0141a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
- s"$className.getInputFilePath();", isNull = FalseLiteral)
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
+ isNull = FalseLiteral)
}
}
@@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
- s"$className.getStartOffset();", isNull = FalseLiteral)
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
}
}
@@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
- s"$className.getLength();", isNull = FalseLiteral)
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index bc4cfce..c2e1720 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression,
val javaType = CodeGenerator.javaType(dataType)
if (scaleV == null) { // if scale is null, no need to eval its child at all
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${ce.code}
boolean ${ev.isNull} = ${ce.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index b783469..5d98dac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -21,6 +21,7 @@ import java.util.UUID
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the value is null or false.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
- ExprCode(code = s"""${eval.code}
+ ExprCode(code = code"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException($errMsgField);
|}""".stripMargin, isNull = TrueLiteral,
@@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
ctx.addPartitionInitializationStatement(s"$randomGen = " +
"new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
s"${randomSeed.get}L + partitionIndex);")
- ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
+ ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
isNull = FalseLiteral)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 0787342..2eeed3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|do {
@@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression
val eval = child.genCode(ctx)
child.dataType match {
case DoubleType | FloatType =>
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral)
@@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression)
val rightGen = right.genCode(ctx)
left.dataType match {
case DoubleType | FloatType =>
- ev.copy(code = s"""
+ ev.copy(code = code"""
${leftGen.code}
boolean ${ev.isNull} = false;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}.mkString)
ev.copy(code =
- s"""
+ code"""
|${CodeGenerator.JAVA_INT} $nonnull = 0;
|do {
| $codes
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index f974fd8..2bf4203 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -269,7 +270,7 @@ case class StaticInvoke(
s"${ev.value} = $callFunc;"
}
- val code = s"""
+ val code = code"""
$argCode
$prepareIsNull
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -385,8 +386,7 @@ case class Invoke(
"""
}
- val code = s"""
- ${obj.code}
+ val code = obj.code + code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${obj.isNull}) {
@@ -492,7 +492,7 @@ case class NewInstance(
s"new $className($argString)"
}
- val code = s"""
+ val code = code"""
$argCode
${outer.map(_.code).getOrElse("")}
final $javaType ${ev.value} = ${ev.isNull} ?
@@ -532,9 +532,7 @@ case class UnwrapOption(
val javaType = CodeGenerator.javaType(dataType)
val inputObject = child.genCode(ctx)
- val code = s"""
- ${inputObject.code}
-
+ val code = inputObject.code + code"""
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} :
(${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get();
@@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputObject = child.genCode(ctx)
- val code = s"""
- ${inputObject.code}
-
+ val code = inputObject.code + code"""
scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
@@ -935,8 +931,7 @@ case class MapObjects private(
)
}
- val code = s"""
- ${genInputData.code}
+ val code = genInputData.code + code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
@@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private(
"""
val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
- val code = s"""
- ${genInputData.code}
+ val code = genInputData.code + code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
@@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private(
val mapCls = classOf[ArrayBasedMapData].getName
val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType)
val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType)
- val code =
- s"""
- ${inputMap.code}
+ val code = inputMap.code +
+ code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${inputMap.isNull}) {
final int $length = ${inputMap.value}.size();
@@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
val schemaField = ctx.addReferenceObj("schema", schema)
val code =
- s"""
+ code"""
|Object[] $values = new Object[${children.size}];
|$childrenCode
|final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
@@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
val javaType = CodeGenerator.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"
- val code = s"""
- ${input.code}
+ val code = input.code + code"""
final $javaType ${ev.value} =
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;
"""
@@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
- val code = s"""
- ${input.code}
+ val code = input.code + code"""
final $javaType ${ev.value} =
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;
"""
@@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
funcName = "initializeJavaBean",
extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil)
- val code =
- s"""
- |${instanceGen.code}
+ val code = instanceGen.code +
+ code"""
|$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value};
|if (!${instanceGen.isNull}) {
| $initializeCode
@@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
// because errMsgField is used only when the value is null.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
- val code = s"""
- ${childGen.code}
-
+ val code = childGen.code + code"""
if (${childGen.isNull}) {
throw new NullPointerException($errMsgField);
}
@@ -1709,7 +1697,7 @@ case class GetExternalRowField(
// because errMsgField is used only when the field is null.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val row = child.genCode(ctx)
- val code = s"""
+ val code = code"""
${row.code}
if (${row.isNull}) {
@@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
}
- val code = s"""
+ val code = code"""
${input.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${input.isNull}) {
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index f8c6dc4..f54103c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}.mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${valueGen.code}
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
@@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
""
}
ev.copy(code =
- s"""
+ code"""
|${childGen.code}
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
@@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
// The result should be `false`, if any of them is `false` whenever the other is null or not.
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.value} = false;
@@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
${ev.value} = ${eval2.value};
}""", isNull = FalseLiteral)
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.value} = false;
@@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
// The result should be `true`, if any of them is `true` whenever the other is null or not.
if (!left.nullable && !right.nullable) {
ev.isNull = FalseLiteral
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.value} = true;
@@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
${ev.value} = ${eval2.value};
}""", isNull = FalseLiteral)
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.value} = true;
@@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
val eval1 = left.genCode(ctx)
val eval2 = right.genCode(ctx)
val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value)
- ev.copy(code = eval1.code + eval2.code + s"""
+ ev.copy(code = eval1.code + eval2.code + code"""
boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
(!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 2653b28..926c2f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG {
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
- ev.copy(code = s"""
+ ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
isNull = FalseLiteral)
}
@@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG {
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
- ev.copy(code = s"""
+ ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
isNull = FalseLiteral)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index ad0c079..7b68bb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
}
""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
@@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
}
""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index ea005a2..9823b2f 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression])
expressions = inputs,
funcName = "valueConcatWs",
extraArguments = ("UTF8String[]", args) :: Nil)
- ev.copy(s"""
+ ev.copy(code"""
UTF8String[] $args = new UTF8String[$numArgs];
${separator.code}
$codes
@@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression])
}
}.unzip
- val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code))
+ val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString))
val varargCounts = ctx.splitExpressionsWithCurrentInputs(
expressions = varargCount,
@@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression])
foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n"))
ev.copy(
- s"""
+ code"""
$codes
int $varargNum = ${children.count(_.dataType == StringType) - 1};
int $idxVararg = 0;
@@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
}.mkString)
ev.copy(
- s"""
+ code"""
|${index.code}
|final int $indexVal = ${index.value};
|${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
@@ -654,7 +655,7 @@ case class StringTrim(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -671,7 +672,7 @@ case class StringTrim(
} else {
${ev.value} = ${srcString.value}.trim(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -754,7 +755,7 @@ case class StringTrimLeft(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -771,7 +772,7 @@ case class StringTrimLeft(
} else {
${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -856,7 +857,7 @@ case class StringTrimRight(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -873,7 +874,7 @@ case class StringTrimRight(
} else {
${ev.value} = ${srcString.value}.trimRight(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
val substrGen = substr.genCode(ctx)
val strGen = str.genCode(ctx)
val startGen = start.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
int ${ev.value} = 0;
boolean ${ev.isNull} = false;
${startGen.code}
@@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val formatter = classOf[java.util.Formatter].getName
val sb = ctx.freshName("sb")
val stringBuffer = classOf[StringBuffer].getName
- ev.copy(code = s"""
+ ev.copy(code = code"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
index 64b65e2..7c7c4cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression {
override def eval(input: InternalRow): Any = 10
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.copy(code =
- s"""
+ code"""
|int some_variable = 11;
|int ${ev.value} = 10;
""".stripMargin)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org