You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2018/03/05 10:39:10 UTC
[2/2] spark git commit: [SPARK-23546][SQL] Refactor stateless
methods/values in CodegenContext
[SPARK-23546][SQL] Refactor stateless methods/values in CodegenContext
## What changes were proposed in this pull request?
The current `CodegenContext` class also contains immutable values and methods that do not depend on any mutable state.
This refactoring moves those stateless values and methods to the `CodeGenerator` object, where they can be accessed from anywhere in the program without an instantiated `CodegenContext`.
## How was this patch tested?
Existing tests
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Closes #20700 from kiszk/SPARK-23546.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ce37b50
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ce37b50
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ce37b50
Branch: refs/heads/master
Commit: 2ce37b50fc01558f49ad22f89c8659f50544ffec
Parents: 269cd53
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Mon Mar 5 11:39:01 2018 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Mon Mar 5 11:39:01 2018 +0100
----------------------------------------------------------------------
.../catalyst/expressions/BoundAttribute.scala | 9 +-
.../spark/sql/catalyst/expressions/Cast.scala | 35 +-
.../sql/catalyst/expressions/Expression.scala | 16 +-
.../expressions/MonotonicallyIncreasingID.scala | 8 +-
.../sql/catalyst/expressions/ScalaUDF.scala | 7 +-
.../catalyst/expressions/SparkPartitionID.scala | 7 +-
.../sql/catalyst/expressions/TimeWindow.scala | 4 +-
.../sql/catalyst/expressions/arithmetic.scala | 51 +--
.../expressions/bitwiseExpressions.scala | 2 +-
.../expressions/codegen/CodeGenerator.scala | 458 ++++++++++---------
.../expressions/codegen/CodegenFallback.scala | 7 +-
.../codegen/GenerateMutableProjection.scala | 6 +-
.../expressions/codegen/GenerateOrdering.scala | 4 +-
.../codegen/GenerateSafeProjection.scala | 6 +-
.../codegen/GenerateUnsafeProjection.scala | 11 +-
.../expressions/collectionOperations.scala | 6 +-
.../expressions/complexTypeCreator.scala | 4 +-
.../expressions/complexTypeExtractors.scala | 15 +-
.../expressions/conditionalExpressions.scala | 10 +-
.../expressions/datetimeExpressions.scala | 18 +-
.../spark/sql/catalyst/expressions/hash.scala | 25 +-
.../catalyst/expressions/inputFileBlock.scala | 8 +-
.../sql/catalyst/expressions/literals.scala | 8 +-
.../catalyst/expressions/mathExpressions.scala | 5 +-
.../catalyst/expressions/nullExpressions.scala | 22 +-
.../catalyst/expressions/objects/objects.scala | 99 ++--
.../sql/catalyst/expressions/predicates.scala | 14 +-
.../expressions/randomExpressions.scala | 8 +-
.../expressions/regexpExpressions.scala | 8 +-
.../expressions/stringExpressions.scala | 39 +-
.../expressions/CodeGenerationSuite.scala | 4 +-
.../spark/sql/execution/ColumnarBatchScan.scala | 13 +-
.../apache/spark/sql/execution/ExpandExec.scala | 5 +-
.../spark/sql/execution/GenerateExec.scala | 8 +-
.../apache/spark/sql/execution/SortExec.scala | 5 +-
.../sql/execution/WholeStageCodegenExec.scala | 2 +-
.../execution/aggregate/HashAggregateExec.scala | 16 +-
.../execution/aggregate/HashMapGenerator.scala | 8 +-
.../aggregate/RowBasedHashMapGenerator.scala | 8 +-
.../aggregate/VectorizedHashMapGenerator.scala | 11 +-
.../sql/execution/basicPhysicalOperators.scala | 10 +-
.../columnar/GenerateColumnAccessor.scala | 2 +-
.../execution/joins/BroadcastHashJoinExec.scala | 5 +-
.../sql/execution/joins/SortMergeJoinExec.scala | 8 +-
.../org/apache/spark/sql/execution/limit.scala | 7 +-
45 files changed, 535 insertions(+), 497 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 6a17a39..89ffbb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types._
/**
@@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
ev.copy(code = oev.code)
} else {
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
- val javaType = ctx.javaType(dataType)
- val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
+ val javaType = CodeGenerator.javaType(dataType)
+ val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
s"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
- |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
+ |$javaType ${ev.value} = ${ev.isNull} ?
+ | ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 79b0516..12330bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
s"""
boolean $resultIsNull = $inputIsNull;
- ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
+ ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)};
if (!$inputIsNull) {
${cast(input, result, resultIsNull)}
}
@@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val funcName = ctx.freshName("elementToString")
val elementToStringFunc = ctx.addNewFunction(funcName,
s"""
- |private UTF8String $funcName(${ctx.javaType(et)} element) {
+ |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) {
| UTF8String elementStr = null;
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
| return elementStr;
@@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|$buffer.append("[");
|if ($array.numElements() > 0) {
| if (!$array.isNullAt(0)) {
- | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
+ | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")}));
| }
| for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
| $buffer.append(",");
| if (!$array.isNullAt($loopIndex)) {
| $buffer.append(" ");
- | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
+ | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)}));
| }
| }
|}
@@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val dataToStringCode = castToStringCode(dataType, ctx)
ctx.addNewFunction(funcName,
s"""
- |private UTF8String $funcName(${ctx.javaType(dataType)} data) {
+ |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) {
| UTF8String dataStr = null;
| ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)}
| return dataStr;
@@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val keyToStringFunc = dataToStringFunc("keyToString", kt)
val valueToStringFunc = dataToStringFunc("valueToString", vt)
val loopIndex = ctx.freshName("loopIndex")
+ val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0")
+ val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0")
+ val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex)
+ val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex)
s"""
|$buffer.append("[");
|if ($map.numElements() > 0) {
- | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")}));
+ | $buffer.append($keyToStringFunc($getMapFirstKey));
| $buffer.append(" ->");
| if (!$map.valueArray().isNullAt(0)) {
| $buffer.append(" ");
- | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")}));
+ | $buffer.append($valueToStringFunc($getMapFirstValue));
| }
| for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) {
| $buffer.append(", ");
- | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)}));
+ | $buffer.append($keyToStringFunc($getMapKeyArray));
| $buffer.append(" ->");
| if (!$map.valueArray().isNullAt($loopIndex)) {
| $buffer.append(" ");
- | $buffer.append($valueToStringFunc(
- | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)}));
+ | $buffer.append($valueToStringFunc($getMapValueArray));
| }
| }
|}
@@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
| ${if (i != 0) s"""$buffer.append(" ");""" else ""}
|
| // Append $i field into the string buffer
- | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
+ | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")};
| UTF8String $fieldStr = null;
| ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
| $buffer.append($fieldStr);
@@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
$values[$j] = null;
} else {
boolean $fromElementNull = false;
- ${ctx.javaType(fromType)} $fromElementPrim =
- ${ctx.getValue(c, fromType, j)};
+ ${CodeGenerator.javaType(fromType)} $fromElementPrim =
+ ${CodeGenerator.getValue(c, fromType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, toType, elementCast)}
if ($toElementNull) {
@@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val fromFieldNull = ctx.freshName("ffn")
val toFieldPrim = ctx.freshName("tfp")
val toFieldNull = ctx.freshName("tfn")
- val fromType = ctx.javaType(from.fields(i).dataType)
+ val fromType = CodeGenerator.javaType(from.fields(i).dataType)
s"""
boolean $fromFieldNull = $tmpInput.isNullAt($i);
if ($fromFieldNull) {
$tmpResult.setNullAt($i);
} else {
$fromType $fromFieldPrim =
- ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
+ ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
$tmpResult.setNullAt($i);
} else {
- ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
+ ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
}
}
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 4568714..ed90b18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -119,7 +119,7 @@ abstract class Expression extends TreeNode[Expression] {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
- val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
+ val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = globalIsNull
s"$globalIsNull = $localIsNull;"
@@ -127,7 +127,7 @@ abstract class Expression extends TreeNode[Expression] {
""
}
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val newValue = ctx.freshName("value")
val funcName = ctx.freshName(nodeName)
@@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression {
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${childGen.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
@@ -510,7 +510,7 @@ abstract class BinaryExpression extends Expression {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
@@ -518,7 +518,7 @@ abstract class BinaryExpression extends Expression {
boolean ${ev.isNull} = false;
${leftGen.code}
${rightGen.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
@@ -654,7 +654,7 @@ abstract class TernaryExpression extends Expression {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = s"""
@@ -662,7 +662,7 @@ abstract class TernaryExpression extends Expression {
${leftGen.code}
${midGen.code}
${rightGen.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 11fb579..4523079 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}
/**
@@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
+ val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
val partitionMaskTerm = "partitionMask"
- ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
+ ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
ev.copy(code = s"""
- final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
+ final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = "false")
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 989c023..e869258 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -1018,11 +1018,12 @@ case class ScalaUDF(
val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
val resultConverter = s"$convertersTerm[${children.length}]"
+ val boxedType = CodeGenerator.boxedType(dataType)
val callFunc =
s"""
- |${ctx.boxedType(dataType)} $resultTerm = null;
+ |$boxedType $resultTerm = null;
|try {
- | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult);
+ | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult);
|} catch (Exception e) {
| throw new org.apache.spark.SparkException($errorMsgTerm, e);
|}
@@ -1035,7 +1036,7 @@ case class ScalaUDF(
|$callFunc
|
|boolean ${ev.isNull} = $resultTerm == null;
- |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index a160b9b..cc6a769 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = "partitionId"
- ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
+ ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
+ ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
+ isNull = "false")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 9a9f579..6c4a360 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -165,7 +165,7 @@ case class PreciseTimestampConversion(
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
s"""boolean ${ev.isNull} = ${eval.isNull};
- |${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}
override def nullSafeEval(input: Any): Any = input
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 8bb1459..508bdd5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -49,8 +49,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
// codegen would fail to compile if we just write (-($c))
// for example, we could not write --9223372036854775808L in code
s"""
- ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
- ${ev.value} = (${ctx.javaType(dt)})(-($originValue));
+ ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval);
+ ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue));
"""})
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}
@@ -107,7 +107,7 @@ case class Abs(child: Expression)
case dt: DecimalType =>
defineCodeGen(ctx, ev, c => s"$c.abs()")
case dt: NumericType =>
- defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))")
+ defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
}
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
@@ -129,7 +129,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
- (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+ (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
@@ -167,7 +167,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
- (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+ (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
case _ =>
@@ -203,7 +203,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
- (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+ (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
case _ =>
@@ -278,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
} else {
s"${eval2.value} == 0"
}
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val divide = if (dataType.isInstanceOf[DecimalType]) {
s"${eval1.value}.$decimalMethod(${eval2.value})"
} else {
@@ -288,7 +288,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
${ev.isNull} = true;
} else {
@@ -299,7 +299,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
@@ -365,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
} else {
s"${eval2.value} == 0"
}
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val remainder = if (dataType.isInstanceOf[DecimalType]) {
s"${eval1.value}.$decimalMethod(${eval2.value})"
} else {
@@ -375,7 +375,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
${ev.isNull} = true;
} else {
@@ -386,7 +386,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
@@ -454,13 +454,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
s"${eval2.value} == 0"
}
val remainder = ctx.freshName("remainder")
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val result = dataType match {
case DecimalType.Fixed(_, _) =>
val decimalAdd = "$plus"
s"""
- ${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value});
+ $javaType $remainder = ${eval1.value}.remainder(${eval2.value});
if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value});
} else {
@@ -470,17 +470,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
s"""
- ${ctx.javaType(dataType)} $remainder =
- (${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value});
+ $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value});
if ($remainder < 0) {
- ${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value});
+ ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value});
} else {
${ev.value}=$remainder;
}
"""
case _ =>
s"""
- ${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value};
+ $javaType $remainder = ${eval1.value} % ${eval2.value};
if ($remainder < 0) {
${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
} else {
@@ -493,7 +492,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
${ev.isNull} = true;
} else {
@@ -504,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
@@ -602,7 +601,7 @@ case class Least(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
@@ -614,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression {
""".stripMargin
)
- val resultType = ctx.javaType(dataType)
+ val resultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "least",
@@ -629,7 +628,7 @@ case class Least(children: Seq[Expression]) extends Expression {
ev.copy(code =
s"""
|${ev.isNull} = true;
- |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
""".stripMargin)
}
@@ -681,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
@@ -693,7 +692,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
""".stripMargin
)
- val resultType = ctx.javaType(dataType)
+ val resultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "greatest",
@@ -708,7 +707,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
ev.copy(code =
s"""
|${ev.isNull} = true;
- |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
""".stripMargin)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 173481f..cc24e39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
+ defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)")
}
protected override def nullSafeEval(input: Any): Any = not(input)
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 60a6f50..793824b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -59,6 +59,11 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
case class ExprCode(var code: String, var isNull: String, var value: String)
object ExprCode {
+ def forNullValue(dataType: DataType): ExprCode = {
+ val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true)
+ ExprCode(code = "", isNull = "true", value = defaultValueLiteral)
+ }
+
def forNonNullValue(value: String): ExprCode = {
ExprCode(code = "", isNull = "false", value = value)
}
@@ -105,6 +110,8 @@ private[codegen] case class NewFunctionSpec(
*/
class CodegenContext {
+ import CodeGenerator._
+
/**
* Holding a list of objects that could be used passed into generated class.
*/
@@ -196,11 +203,11 @@ class CodegenContext {
/**
* Returns the reference of next available slot in current compacted array. The size of each
- * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
+ * compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`.
* Once reaching the threshold, new compacted array is created.
*/
def getNextSlot(): String = {
- if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) {
+ if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) {
val res = s"${arrayNames.last}[$currentIndex]"
currentIndex += 1
res
@@ -247,10 +254,10 @@ class CodegenContext {
* are satisfied:
* 1. forceInline is true
* 2. its type is primitive type and the total number of the inlined mutable variables
- * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD`
+ * is less than `OUTER_CLASS_VARIABLES_THRESHOLD`
* 3. its type is multi-dimensional array
* When a variable is compacted into an array, the max size of the array for compaction
- * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
+ * is given by `MUTABLESTATEARRAY_SIZE_LIMIT`.
*/
def addMutableState(
javaType: String,
@@ -261,7 +268,7 @@ class CodegenContext {
// want to put a primitive type variable at outerClass for performance
val canInlinePrimitive = isPrimitiveType(javaType) &&
- (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
+ (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD)
if (forceInline || canInlinePrimitive || javaType.contains("[][]")) {
val varName = if (useFreshName) freshName(variableName) else variableName
val initCode = initFunc(varName)
@@ -339,7 +346,7 @@ class CodegenContext {
val length = if (index + 1 == numArrays) {
mutableStateArrays.getCurrentIndex
} else {
- CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT
+ MUTABLESTATEARRAY_SIZE_LIMIT
}
if (javaType.contains("[]")) {
// initializer had an one-dimensional array variable
@@ -468,7 +475,7 @@ class CodegenContext {
inlineToOuterClass: Boolean): NewFunctionSpec = {
val (className, classInstance) = if (inlineToOuterClass) {
outerClassName -> ""
- } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) {
+ } else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) {
val className = freshName("NestedClass")
val classInstance = freshName("nestedClassInstance")
@@ -537,14 +544,6 @@ class CodegenContext {
extraClasses.append(code)
}
- final val JAVA_BOOLEAN = "boolean"
- final val JAVA_BYTE = "byte"
- final val JAVA_SHORT = "short"
- final val JAVA_INT = "int"
- final val JAVA_LONG = "long"
- final val JAVA_FLOAT = "float"
- final val JAVA_DOUBLE = "double"
-
/**
* The map from a variable name to it's next ID.
*/
@@ -581,196 +580,6 @@ class CodegenContext {
}
/**
- * Returns the specialized code to access a value from `inputRow` at `ordinal`.
- */
- def getValue(input: String, dataType: DataType, ordinal: String): String = {
- val jt = javaType(dataType)
- dataType match {
- case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)"
- case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})"
- case StringType => s"$input.getUTF8String($ordinal)"
- case BinaryType => s"$input.getBinary($ordinal)"
- case CalendarIntervalType => s"$input.getInterval($ordinal)"
- case t: StructType => s"$input.getStruct($ordinal, ${t.size})"
- case _: ArrayType => s"$input.getArray($ordinal)"
- case _: MapType => s"$input.getMap($ordinal)"
- case NullType => "null"
- case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
- case _ => s"($jt)$input.get($ordinal, null)"
- }
- }
-
- /**
- * Returns the code to update a column in Row for a given DataType.
- */
- def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
- val jt = javaType(dataType)
- dataType match {
- case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
- case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
- case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
- // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
- // it to avoid keeping a "pointer" to a memory region which may get updated afterwards.
- case StringType | _: StructType | _: ArrayType | _: MapType =>
- s"$row.update($ordinal, $value.copy())"
- case _ => s"$row.update($ordinal, $value)"
- }
- }
-
- /**
- * Update a column in MutableRow from ExprCode.
- *
- * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
- */
- def updateColumn(
- row: String,
- dataType: DataType,
- ordinal: Int,
- ev: ExprCode,
- nullable: Boolean,
- isVectorized: Boolean = false): String = {
- if (nullable) {
- // Can't call setNullAt on DecimalType, because we need to keep the offset
- if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
- s"""
- if (!${ev.isNull}) {
- ${setColumn(row, dataType, ordinal, ev.value)};
- } else {
- ${setColumn(row, dataType, ordinal, "null")};
- }
- """
- } else {
- s"""
- if (!${ev.isNull}) {
- ${setColumn(row, dataType, ordinal, ev.value)};
- } else {
- $row.setNullAt($ordinal);
- }
- """
- }
- } else {
- s"""${setColumn(row, dataType, ordinal, ev.value)};"""
- }
- }
-
- /**
- * Returns the specialized code to set a given value in a column vector for a given `DataType`.
- */
- def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
- val jt = javaType(dataType)
- dataType match {
- case _ if isPrimitiveType(jt) =>
- s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
- case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
- case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
- case _ =>
- throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
- }
- }
-
- /**
- * Returns the specialized code to set a given value in a column vector for a given `DataType`
- * that could potentially be nullable.
- */
- def updateColumn(
- vector: String,
- rowId: String,
- dataType: DataType,
- ev: ExprCode,
- nullable: Boolean): String = {
- if (nullable) {
- s"""
- if (!${ev.isNull}) {
- ${setValue(vector, rowId, dataType, ev.value)}
- } else {
- $vector.putNull($rowId);
- }
- """
- } else {
- s"""${setValue(vector, rowId, dataType, ev.value)};"""
- }
- }
-
- /**
- * Returns the specialized code to access a value from a column vector for a given `DataType`.
- */
- def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
- if (dataType.isInstanceOf[StructType]) {
- // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
- // `ordinal` parameter.
- s"$vector.getStruct($rowId)"
- } else {
- getValue(vector, dataType, rowId)
- }
- }
-
- /**
- * Returns the name used in accessor and setter for a Java primitive type.
- */
- def primitiveTypeName(jt: String): String = jt match {
- case JAVA_INT => "Int"
- case _ => boxedType(jt)
- }
-
- def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt))
-
- /**
- * Returns the Java type for a DataType.
- */
- def javaType(dt: DataType): String = dt match {
- case BooleanType => JAVA_BOOLEAN
- case ByteType => JAVA_BYTE
- case ShortType => JAVA_SHORT
- case IntegerType | DateType => JAVA_INT
- case LongType | TimestampType => JAVA_LONG
- case FloatType => JAVA_FLOAT
- case DoubleType => JAVA_DOUBLE
- case dt: DecimalType => "Decimal"
- case BinaryType => "byte[]"
- case StringType => "UTF8String"
- case CalendarIntervalType => "CalendarInterval"
- case _: StructType => "InternalRow"
- case _: ArrayType => "ArrayData"
- case _: MapType => "MapData"
- case udt: UserDefinedType[_] => javaType(udt.sqlType)
- case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
- case ObjectType(cls) => cls.getName
- case _ => "Object"
- }
-
- /**
- * Returns the boxed type in Java.
- */
- def boxedType(jt: String): String = jt match {
- case JAVA_BOOLEAN => "Boolean"
- case JAVA_BYTE => "Byte"
- case JAVA_SHORT => "Short"
- case JAVA_INT => "Integer"
- case JAVA_LONG => "Long"
- case JAVA_FLOAT => "Float"
- case JAVA_DOUBLE => "Double"
- case other => other
- }
-
- def boxedType(dt: DataType): String = boxedType(javaType(dt))
-
- /**
- * Returns the representation of default value for a given Java Type.
- */
- def defaultValue(jt: String): String = jt match {
- case JAVA_BOOLEAN => "false"
- case JAVA_BYTE => "(byte)-1"
- case JAVA_SHORT => "(short)-1"
- case JAVA_INT => "-1"
- case JAVA_LONG => "-1L"
- case JAVA_FLOAT => "-1.0f"
- case JAVA_DOUBLE => "-1.0"
- case _ => "null"
- }
-
- def defaultValue(dt: DataType): String = defaultValue(javaType(dt))
-
- /**
* Generates code for equal expression in Java.
*/
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
@@ -812,6 +621,7 @@ class CodegenContext {
val isNullB = freshName("isNullB")
val compareFunc = freshName("compareArray")
val minLength = freshName("minLength")
+ val jt = javaType(elementType)
val funcCode: String =
s"""
public int $compareFunc(ArrayData a, ArrayData b) {
@@ -833,8 +643,8 @@ class CodegenContext {
} else if ($isNullB) {
return 1;
} else {
- ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")};
- ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")};
+ $jt $elementA = ${getValue("a", elementType, "i")};
+ $jt $elementB = ${getValue("b", elementType, "i")};
int comp = ${genComp(elementType, elementA, elementB)};
if (comp != 0) {
return comp;
@@ -907,19 +717,6 @@ class CodegenContext {
}
/**
- * List of java data types that have special accessors and setters in [[InternalRow]].
- */
- val primitiveTypes =
- Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE)
-
- /**
- * Returns true if the Java type has a special accessor and setter in [[InternalRow]].
- */
- def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
-
- def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
-
- /**
* Splits the generated code of expressions into multiple functions, because function has
* 64kb code size limit in JVM. If the class to which the function would be inlined would grow
* beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
@@ -1089,7 +886,7 @@ class CodegenContext {
// for performance reasons, the functions are prepended, instead of appended,
// thus here they are in reversed order
val orderedFunctions = innerClassFunctions.reverse
- if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
+ if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) {
// Adding a new function to each inner class which contains the invocation of all the
// ones which have been added to that inner class. For example,
// private class NestedClass {
@@ -1289,7 +1086,7 @@ class CodegenContext {
* length less than a pre-defined constant.
*/
def isValidParamLength(paramLength: Int): Boolean = {
- paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+ paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
}
}
@@ -1524,4 +1321,221 @@ object CodeGenerator extends Logging {
result
}
})
+
+ /**
+ * Names of Java primitive data types
+ */
+ final val JAVA_BOOLEAN = "boolean"
+ final val JAVA_BYTE = "byte"
+ final val JAVA_SHORT = "short"
+ final val JAVA_INT = "int"
+ final val JAVA_LONG = "long"
+ final val JAVA_FLOAT = "float"
+ final val JAVA_DOUBLE = "double"
+
+ /**
+ * List of java primitive data types
+ */
+ val primitiveTypes =
+ Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE)
+
+ /**
+ * Returns true if a Java type is a Java primitive type
+ */
+ def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
+
+ def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
+
+ /**
+ * Returns the specialized code to access a value from `inputRow` at `ordinal`.
+ */
+ def getValue(input: String, dataType: DataType, ordinal: String): String = {
+ val jt = javaType(dataType)
+ dataType match {
+ case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)"
+ case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})"
+ case StringType => s"$input.getUTF8String($ordinal)"
+ case BinaryType => s"$input.getBinary($ordinal)"
+ case CalendarIntervalType => s"$input.getInterval($ordinal)"
+ case t: StructType => s"$input.getStruct($ordinal, ${t.size})"
+ case _: ArrayType => s"$input.getArray($ordinal)"
+ case _: MapType => s"$input.getMap($ordinal)"
+ case NullType => "null"
+ case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
+ case _ => s"($jt)$input.get($ordinal, null)"
+ }
+ }
+
+ /**
+ * Returns the code to update a column in Row for a given DataType.
+ */
+ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
+ val jt = javaType(dataType)
+ dataType match {
+ case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
+ case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
+ case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
+ // The UTF8String, InternalRow, ArrayData and MapData may come from UnsafeRow, we should copy
+ // it to avoid keeping a "pointer" to a memory region which may get updated afterwards.
+ case StringType | _: StructType | _: ArrayType | _: MapType =>
+ s"$row.update($ordinal, $value.copy())"
+ case _ => s"$row.update($ordinal, $value)"
+ }
+ }
+
+ /**
+ * Update a column in MutableRow from ExprCode.
+ *
+ * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
+ */
+ def updateColumn(
+ row: String,
+ dataType: DataType,
+ ordinal: Int,
+ ev: ExprCode,
+ nullable: Boolean,
+ isVectorized: Boolean = false): String = {
+ if (nullable) {
+ // Can't call setNullAt on DecimalType, because we need to keep the offset
+ if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
+ s"""
+ |if (!${ev.isNull}) {
+ | ${setColumn(row, dataType, ordinal, ev.value)};
+ |} else {
+ | ${setColumn(row, dataType, ordinal, "null")};
+ |}
+ """.stripMargin
+ } else {
+ s"""
+ |if (!${ev.isNull}) {
+ | ${setColumn(row, dataType, ordinal, ev.value)};
+ |} else {
+ | $row.setNullAt($ordinal);
+ |}
+ """.stripMargin
+ }
+ } else {
+ s"""${setColumn(row, dataType, ordinal, ev.value)};"""
+ }
+ }
+
+ /**
+ * Returns the specialized code to set a given value in a column vector for a given `DataType`.
+ */
+ def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
+ val jt = javaType(dataType)
+ dataType match {
+ case _ if isPrimitiveType(jt) =>
+ s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
+ case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
+ case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
+ case _ =>
+ throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+ }
+ }
+
+ /**
+ * Returns the specialized code to set a given value in a column vector for a given `DataType`
+ * that could potentially be nullable.
+ */
+ def updateColumn(
+ vector: String,
+ rowId: String,
+ dataType: DataType,
+ ev: ExprCode,
+ nullable: Boolean): String = {
+ if (nullable) {
+ s"""
+ |if (!${ev.isNull}) {
+ | ${setValue(vector, rowId, dataType, ev.value)}
+ |} else {
+ | $vector.putNull($rowId);
+ |}
+ """.stripMargin
+ } else {
+ s"""${setValue(vector, rowId, dataType, ev.value)};"""
+ }
+ }
+
+ /**
+ * Returns the specialized code to access a value from a column vector for a given `DataType`.
+ */
+ def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
+ if (dataType.isInstanceOf[StructType]) {
+ // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
+ // `ordinal` parameter.
+ s"$vector.getStruct($rowId)"
+ } else {
+ getValue(vector, dataType, rowId)
+ }
+ }
+
+ /**
+ * Returns the name used in accessor and setter for a Java primitive type.
+ */
+ def primitiveTypeName(jt: String): String = jt match {
+ case JAVA_INT => "Int"
+ case _ => boxedType(jt)
+ }
+
+ def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt))
+
+ /**
+ * Returns the Java type for a DataType.
+ */
+ def javaType(dt: DataType): String = dt match {
+ case BooleanType => JAVA_BOOLEAN
+ case ByteType => JAVA_BYTE
+ case ShortType => JAVA_SHORT
+ case IntegerType | DateType => JAVA_INT
+ case LongType | TimestampType => JAVA_LONG
+ case FloatType => JAVA_FLOAT
+ case DoubleType => JAVA_DOUBLE
+ case _: DecimalType => "Decimal"
+ case BinaryType => "byte[]"
+ case StringType => "UTF8String"
+ case CalendarIntervalType => "CalendarInterval"
+ case _: StructType => "InternalRow"
+ case _: ArrayType => "ArrayData"
+ case _: MapType => "MapData"
+ case udt: UserDefinedType[_] => javaType(udt.sqlType)
+ case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
+ case ObjectType(cls) => cls.getName
+ case _ => "Object"
+ }
+
+ /**
+ * Returns the boxed type in Java.
+ */
+ def boxedType(jt: String): String = jt match {
+ case JAVA_BOOLEAN => "Boolean"
+ case JAVA_BYTE => "Byte"
+ case JAVA_SHORT => "Short"
+ case JAVA_INT => "Integer"
+ case JAVA_LONG => "Long"
+ case JAVA_FLOAT => "Float"
+ case JAVA_DOUBLE => "Double"
+ case other => other
+ }
+
+ def boxedType(dt: DataType): String = boxedType(javaType(dt))
+
+ /**
+ * Returns the representation of default value for a given Java Type.
+ * @param jt the string name of the Java type
+ * @param typedNull if true, for null literals, return a typed (with a cast) version
+ */
+ def defaultValue(jt: String, typedNull: Boolean): String = jt match {
+ case JAVA_BOOLEAN => "false"
+ case JAVA_BYTE => "(byte)-1"
+ case JAVA_SHORT => "(short)-1"
+ case JAVA_INT => "-1"
+ case JAVA_LONG => "-1L"
+ case JAVA_FLOAT => "-1.0f"
+ case JAVA_DOUBLE => "-1.0"
+ case _ => if (typedNull) s"(($jt)null)" else "null"
+ }
+
+ def defaultValue(dt: DataType, typedNull: Boolean = false): String =
+ defaultValue(javaType(dt), typedNull)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 0322d1d..e12420b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -44,20 +44,21 @@ trait CodegenFallback extends Expression {
}
val objectTerm = ctx.freshName("obj")
val placeHolder = ctx.registerComment(this.toString)
+ val javaType = CodeGenerator.javaType(this.dataType)
if (nullable) {
ev.copy(code = s"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
- ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)};
if (!${ev.isNull}) {
- ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
+ ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
}""")
} else {
ev.copy(code = s"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
- ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
+ $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
""", isNull = "false")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index b53c008..d35fd8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -62,9 +62,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
case (ev, i) =>
val e = expressions(i)
- val value = ctx.addMutableState(ctx.javaType(e.dataType), "value")
+ val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value")
if (e.nullable) {
- val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull")
+ val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull")
(s"""
|${ev.code}
|$isNull = ${ev.isNull};
@@ -84,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val updates = validExpr.zip(projectionCodes).map {
case (e, (_, isNull, value, i)) =>
val ev = ExprCode("", isNull, value)
- ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
+ CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 4a45957..9a51be6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
s"""
${ctx.INPUT_ROW} = a;
boolean $isNullA;
- ${ctx.javaType(order.child.dataType)} $primitiveA;
+ ${CodeGenerator.javaType(order.child.dataType)} $primitiveA;
{
${eval.code}
$isNullA = ${eval.isNull};
@@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
}
${ctx.INPUT_ROW} = b;
boolean $isNullB;
- ${ctx.javaType(order.child.dataType)} $primitiveB;
+ ${CodeGenerator.javaType(order.child.dataType)} $primitiveB;
{
${eval.code}
$isNullB = ${eval.isNull};
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 3dcbb51..f92f70e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val rowClass = classOf[GenericInternalRow].getName
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
- val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt)
+ val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt)
s"""
if (!$tmpInput.isNullAt($i)) {
${converter.code}
@@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val arrayClass = classOf[GenericArrayData].getName
val elementConverter = convertToSafe(
- ctx, ctx.getValue(tmpInput, elementType, index), elementType)
+ ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType)
val code = s"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
@@ -153,7 +153,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
mutableRow.setNullAt($i);
} else {
${converter.code}
- ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)};
+ ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)};
}
"""
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 36ffa8d..22717f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
- ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString))
+ ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString))
}
s"""
@@ -195,16 +195,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case other => other
}
- val jt = ctx.javaType(et)
+ val jt = CodeGenerator.javaType(et)
val elementOrOffsetSize = et match {
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
- case _ if ctx.isPrimitiveType(jt) => et.defaultSize
+ case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize
case _ => 8 // we need 8 bytes to store offset and length
}
val tmpCursor = ctx.freshName("tmpCursor")
- val element = ctx.getValue(tmpInput, et, index)
+ val element = CodeGenerator.getValue(tmpInput, et, index)
val writeElement = et match {
case t: StructType =>
s"""
@@ -235,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$arrayWriter.write($index, $element);"
}
- val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
+ val primitiveTypeName =
+ if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else ""
s"""
final ArrayData $tmpInput = $input;
if ($tmpInput instanceof UnsafeArrayData) {
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 4270b98..beb8469 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -20,7 +20,7 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@@ -54,7 +54,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${childGen.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
(${childGen.value}).numElements();""", isNull = "false")
}
}
@@ -270,7 +270,7 @@ case class ArrayContains(left: Expression, right: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (arr, value) => {
val i = ctx.freshName("i")
- val getValue = ctx.getValue(arr, right.dataType, i)
+ val getValue = CodeGenerator.getValue(arr, right.dataType, i)
s"""
for (int $i = 0; $i < $arr.numElements(); $i ++) {
if ($arr.isNullAt($i)) {
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 047b80a..85facda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -90,7 +90,7 @@ private [sql] object GenArrayData {
val arrayDataName = ctx.freshName("arrayData")
val numElements = elementsCode.length
- if (!ctx.isPrimitiveType(elementType)) {
+ if (!CodeGenerator.isPrimitiveType(elementType)) {
val arrayName = ctx.freshName("arrayObject")
val genericArrayClass = classOf[GenericArrayData].getName
@@ -124,7 +124,7 @@ private [sql] object GenArrayData {
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
- val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
s"$arrayDataName.setNullAt($i);"
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 7e53ca3..6cdad19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
- ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
+ ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)};
}
"""
} else {
s"""
- ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
+ ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)};
"""
}
})
@@ -205,7 +205,7 @@ case class GetArrayStructFields(
} else {
final InternalRow $row = $eval.getStruct($j, $numFields);
$nullSafeEval {
- $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
+ $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)};
}
}
}
@@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
${ev.isNull} = true;
} else {
- ${ev.value} = ${ctx.getValue(eval1, dataType, index)};
+ ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
}
"""
})
@@ -327,6 +327,7 @@ case class GetMapValue(child: Expression, key: Expression)
} else {
""
}
+ val keyJavaType = CodeGenerator.javaType(keyType)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
final int $length = $eval1.numElements();
@@ -336,7 +337,7 @@ case class GetMapValue(child: Expression, key: Expression)
int $index = 0;
boolean $found = false;
while ($index < $length && !$found) {
- final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)};
+ final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)};
if (${ctx.genEqual(keyType, key, eval2)}) {
$found = true;
} else {
@@ -347,7 +348,7 @@ case class GetMapValue(child: Expression, key: Expression)
if (!$found$nullCheck) {
${ev.isNull} = true;
} else {
- ${ev.value} = ${ctx.getValue(values, dataType, index)};
+ ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
}
"""
})
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index b444c3a..f4e9619 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
s"""
|${condEval.code}
|boolean ${ev.isNull} = false;
- |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${condEval.isNull} && ${condEval.value}) {
| ${trueEval.code}
| ${ev.isNull} = ${trueEval.isNull};
@@ -191,7 +191,7 @@ case class CaseWhen(
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
// We won't go on anymore on the computation.
val resultState = ctx.freshName("caseWhenResultState")
- ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
+ ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value)
// these blocks are meant to be inside a
// do {
@@ -244,10 +244,10 @@ case class CaseWhen(
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = allConditions,
funcName = "caseWhen",
- returnType = ctx.JAVA_BYTE,
+ returnType = CodeGenerator.JAVA_BYTE,
makeSplitFunction = func =>
s"""
- |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
+ |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $func
|} while (false);
@@ -264,7 +264,7 @@ case class CaseWhen(
ev.copy(code =
s"""
- |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
+ |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $codes
|} while (false);
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 424871f..1ae4e5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -673,18 +673,19 @@ abstract class UnixTime
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val javaType = CodeGenerator.javaType(dataType)
left.dataType match {
case StringType if right.foldable =>
val df = classOf[DateFormat].getName
if (formatter == null) {
- ExprCode("", "true", ctx.defaultValue(dataType))
+ ExprCode.forNullValue(dataType)
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
try {
${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L;
@@ -713,7 +714,7 @@ abstract class UnixTime
ev.copy(code = s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = ${eval1.value} / 1000000L;
}""")
@@ -724,7 +725,7 @@ abstract class UnixTime
ev.copy(code = s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L;
}""")
@@ -819,7 +820,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
ev.copy(code = s"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
try {
${ev.value} = UTF8String.fromString($formatterName.format(
@@ -1344,18 +1345,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
: ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ val javaType = CodeGenerator.javaType(dataType)
if (format.foldable) {
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
val t = instant.genCode(ctx)
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
ev.copy(code = s"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.$truncFuncStr;
}""")
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 055ebf6..b702422 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -278,7 +278,7 @@ abstract class HashExpression[E] extends Expression {
}
}
- val hashResultType = ctx.javaType(dataType)
+ val hashResultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
@@ -307,9 +307,10 @@ abstract class HashExpression[E] extends Expression {
ctx: CodegenContext): String = {
val element = ctx.freshName("element")
+ val jt = CodeGenerator.javaType(elementType)
ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
s"""
- final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
+ final $jt $element = ${CodeGenerator.getValue(input, elementType, index)};
${computeHash(element, elementType, result, ctx)}
"""
}
@@ -407,7 +408,7 @@ abstract class HashExpression[E] extends Expression {
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}
- val hashResultType = ctx.javaType(dataType)
+ val hashResultType = CodeGenerator.javaType(dataType)
ctx.splitExpressions(
expressions = fieldsHash,
funcName = "computeHashForStruct",
@@ -651,11 +652,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
- extraArguments = Seq(ctx.JAVA_INT -> ev.value),
- returnType = ctx.JAVA_INT,
+ extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value),
+ returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
- |${ctx.JAVA_INT} $childHash = 0;
+ |${CodeGenerator.JAVA_INT} $childHash = 0;
|$body
|return ${ev.value};
""".stripMargin,
@@ -664,8 +665,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
ev.copy(code =
s"""
- |${ctx.JAVA_INT} ${ev.value} = $seed;
- |${ctx.JAVA_INT} $childHash = 0;
+ |${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
+ |${CodeGenerator.JAVA_INT} $childHash = 0;
|$codes
""".stripMargin)
}
@@ -780,14 +781,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
""".stripMargin
}
- s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
+ s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
expressions = fieldsHash,
funcName = "computeHashForStruct",
- arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result),
- returnType = ctx.JAVA_INT,
+ arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
+ returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
- |${ctx.JAVA_INT} $childResult = 0;
+ |${CodeGenerator.JAVA_INT} $childResult = 0;
|$body
|return $result;
""".stripMargin,
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
index 7a8edab..07785e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -42,7 +42,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
+ ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getInputFilePath();", isNull = "false")
}
}
@@ -65,7 +65,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
+ ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getStartOffset();", isNull = "false")
}
}
@@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
+ ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getLength();", isNull = "false")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2ce37b50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index c1e65e3..7395609 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -277,13 +277,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
override def eval(input: InternalRow): Any = value
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
if (value == null) {
- val defaultValueLiteral = ctx.defaultValue(javaType) match {
- case "null" => s"(($javaType)null)"
- case lit => lit
- }
- ExprCode(code = "", isNull = "true", value = defaultValueLiteral)
+ ExprCode.forNullValue(dataType)
} else {
dataType match {
case BooleanType | IntegerType | DateType =>
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org