You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/06/07 23:11:25 UTC
[2/2] spark git commit: [SPARK-8117] [SQL] Push codegen
implementation into each Expression
[SPARK-8117] [SQL] Push codegen implementation into each Expression
This PR moves the codegen implementation of each expression into the Expression class itself, making it easier to manage.
It introduces two APIs in Expression:
```
def gen(ctx: CodeGenContext): GeneratedExpressionCode
def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code
```
gen(ctx) calls genCode(ctx, ev) to generate Java source code for the current expression. An expression needs to override genCode().
Here are the types:
```
type Term = String
type Code = String
/**
* Java source for evaluating an [[Expression]] given a [[Row]] of input.
*/
case class GeneratedExpressionCode(var code: Code,
nullTerm: Term,
primitiveTerm: Term,
objectTerm: Term)
/**
* A context for codegen, which is used to bookkeeping the expressions those are not supported
* by codegen, then they are evaluated directly. The unsupported expression is appended at the
* end of `references`, the position of it is kept in the code, used to access and evaluate it.
*/
class CodeGenContext {
/**
* Holding all the expressions those do not support codegen, will be evaluated directly.
*/
val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]()
}
```
This is basically #6660, with the style violations and the compilation failure fixed.
Author: Davies Liu <da...@databricks.com>
Author: Reynold Xin <rx...@databricks.com>
Closes #6690 from rxin/codegen and squashes the following commits:
e1368c2 [Reynold Xin] Fixed tests.
73db80e [Reynold Xin] Fixed compilation failure.
19d6435 [Reynold Xin] Fixed style violation.
9adaeaf [Davies Liu] address comments
f42c732 [Davies Liu] improve coverage and tests
bad6828 [Davies Liu] address comments
e03edaa [Davies Liu] consts fold
86fac2c [Davies Liu] fix style
02262c9 [Davies Liu] address comments
b5d3617 [Davies Liu] Merge pull request #5 from rxin/codegen
48c454f [Reynold Xin] Some code gen update.
2344bc0 [Davies Liu] fix test
12ff88a [Davies Liu] fix build
c5fb514 [Davies Liu] rename
8c6d82d [Davies Liu] update docs
b145047 [Davies Liu] fix style
e57959d [Davies Liu] add type alias
3ff25f8 [Davies Liu] refactor
593d617 [Davies Liu] pushing codegen into Expression
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5e7b6b67
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5e7b6b67
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5e7b6b67
Branch: refs/heads/master
Commit: 5e7b6b67bed9cd0d8c7d4e78df666b807e8f7ef2
Parents: b127ff8
Author: Davies Liu <da...@databricks.com>
Authored: Sun Jun 7 14:11:20 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sun Jun 7 14:11:20 2015 -0700
----------------------------------------------------------------------
.../catalyst/expressions/BoundAttribute.scala | 9 +
.../spark/sql/catalyst/expressions/Cast.scala | 42 ++
.../sql/catalyst/expressions/Expression.scala | 100 +++
.../sql/catalyst/expressions/arithmetic.scala | 161 +++-
.../expressions/codegen/CodeGenerator.scala | 750 ++++---------------
.../codegen/GenerateMutableProjection.scala | 6 +-
.../expressions/codegen/GenerateOrdering.scala | 20 +-
.../expressions/codegen/GeneratePredicate.scala | 4 +-
.../codegen/GenerateProjection.scala | 26 +-
.../catalyst/expressions/codegen/package.scala | 3 +
.../catalyst/expressions/decimalFunctions.scala | 19 +
.../sql/catalyst/expressions/literals.scala | 54 ++
.../catalyst/expressions/mathfuncs/binary.scala | 24 +-
.../catalyst/expressions/mathfuncs/unary.scala | 30 +-
.../catalyst/expressions/namedExpressions.scala | 6 +-
.../catalyst/expressions/nullFunctions.scala | 55 ++
.../sql/catalyst/expressions/predicates.scala | 192 ++++-
.../spark/sql/catalyst/expressions/sets.scala | 54 +-
.../catalyst/expressions/stringOperations.scala | 18 +
.../expressions/ExpressionEvaluationSuite.scala | 87 ++-
.../expressions/GeneratedEvaluationSuite.scala | 27 +-
.../GeneratedMutableEvaluationSuite.scala | 61 --
.../ParquetPartitionDiscoverySuite.scala | 6 +-
23 files changed, 1036 insertions(+), 718 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 1ffc95c..005de31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees
@@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ s"""
+ boolean ${ev.isNull} = i.isNullAt($ordinal);
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
+ ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
+ """
+ }
}
object BindReferences extends Logging {
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 21adac1..5f76a51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
@@ -433,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val evaluated = child.eval(input)
if (evaluated == null) null else cast(evaluated)
}
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ // TODO(cg): Add support for more data types.
+ (child.dataType, dataType) match {
+
+ case (BinaryType, StringType) =>
+ defineCodeGen (ctx, ev, c =>
+ s"new ${ctx.stringType}().set($c)")
+ case (DateType, StringType) =>
+ defineCodeGen(ctx, ev, c =>
+ s"""new ${ctx.stringType}().set(
+ org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
+ // Special handling required for timestamps in hive test cases since the toString function
+ // does not match the expected output.
+ case (TimestampType, StringType) =>
+ super.genCode(ctx, ev)
+ case (_, StringType) =>
+ defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")
+
+ // fallback for DecimalType, this must be before other numeric types
+ case (_, dt: DecimalType) =>
+ super.genCode(ctx, ev)
+
+ case (BooleanType, dt: NumericType) =>
+ defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
+ case (dt: DecimalType, BooleanType) =>
+ defineCodeGen(ctx, ev, c => s"$c.isZero()")
+ case (dt: NumericType, BooleanType) =>
+ defineCodeGen(ctx, ev, c => s"$c != 0")
+
+ case (_: DecimalType, IntegerType) =>
+ defineCodeGen(ctx, ev, c => s"($c).toInt()")
+ case (_: DecimalType, dt: NumericType) =>
+ defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
+ case (_: NumericType, dt: NumericType) =>
+ defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
+
+ case other =>
+ super.genCode(ctx, ev)
+ }
+ }
}
object Cast {
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b2b9d1a..0ed576b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
@@ -52,6 +53,44 @@ abstract class Expression extends TreeNode[Expression] {
def eval(input: Row = null): Any
/**
+ * Returns an [[GeneratedExpressionCode]], which contains Java source code that
+ * can be used to generate the result of evaluating the expression on an input row.
+ *
+ * @param ctx a [[CodeGenContext]]
+ * @return [[GeneratedExpressionCode]]
+ */
+ def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
+ val isNull = ctx.freshName("isNull")
+ val primitive = ctx.freshName("primitive")
+ val ve = GeneratedExpressionCode("", isNull, primitive)
+ ve.code = genCode(ctx, ve)
+ ve
+ }
+
+ /**
+ * Returns Java source code that can be compiled to evaluate this expression.
+ * The default behavior is to call the eval method of the expression. Concrete expression
+ * implementations should override this to do actual code generation.
+ *
+ * @param ctx a [[CodeGenContext]]
+ * @param ev an [[GeneratedExpressionCode]] with unique terms.
+ * @return Java source code
+ */
+ protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ ctx.references += this
+ val objectTerm = ctx.freshName("obj")
+ s"""
+ /* expression: ${this} */
+ Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
+ boolean ${ev.isNull} = ${objectTerm} == null;
+ ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
+ if (!${ev.isNull}) {
+ ${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
+ }
+ """
+ }
+
+ /**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and input data types checking passed, and `false` if it still contains any unresolved
* placeholders or has data types mismatch.
@@ -116,6 +155,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def nullable: Boolean = left.nullable || right.nullable
override def toString: String = s"($left $symbol $right)"
+
+ /**
+ * Short hand for generating binary evaluation code, which depends on two sub-evaluations of
+ * the same type. If either of the sub-expressions is null, the result of this computation
+ * is assumed to be null.
+ *
+ * @param f accepts two variable names and returns Java code to compute the output.
+ */
+ protected def defineCodeGen(
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode,
+ f: (Term, Term) => Code): String = {
+ // TODO: Right now some timestamp tests fail if we enforce this...
+ if (left.dataType != right.dataType) {
+ // log.warn(s"${left.dataType} != ${right.dataType}")
+ }
+
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+ val resultCode = f(eval1.primitive, eval2.primitive)
+
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = ${eval1.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${eval2.code}
+ if(!${eval2.isNull}) {
+ ${ev.primitive} = $resultCode;
+ } else {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ }
}
private[sql] object BinaryExpression {
@@ -128,6 +202,32 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
+
+ /**
+ * Called by unary expressions to generate a code block that returns null if its parent returns
+ * null, and if not not null, use `f` to generate the expression.
+ *
+ * As an example, the following does a boolean inversion (i.e. NOT).
+ * {{{
+ * defineCodeGen(ctx, ev, c => s"!($c)")
+ * }}}
+ *
+ * @param f function that accepts a variable name and returns Java code to compute the output.
+ */
+ protected def defineCodeGen(
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode,
+ f: Term => Code): Code = {
+ val eval = child.gen(ctx)
+ // reuse the previous isNull
+ ev.isNull = eval.isNull
+ eval.code + s"""
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.primitive} = ${f(eval.primitive)};
+ }
+ """
+ }
}
// TODO Semantically we probably not need GroupExpression
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index a3770f9..3ac7c92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -49,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
private lazy val numeric = TypeUtils.getNumeric(dataType)
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
+ case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") // Decimal.unary_- on child term
+ case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
+ }
+
protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
}
@@ -67,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
if (value < 0) null
else math.sqrt(value)
}
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val eval = child.gen(ctx)
+ eval.code + s"""
+ boolean ${ev.isNull} = ${eval.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ if (${eval.primitive} < 0.0) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
+ }
+ }
+ """
+ }
}
/**
@@ -86,6 +107,9 @@ case class Abs(child: Expression) extends UnaryArithmetic {
abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>
+ /** Name of the function for this expression on a [[Decimal]] type. */
+ def decimalMethod: String = ""
+
override def dataType: DataType = left.dataType
override def checkInputDataTypes(): TypeCheckResult = {
@@ -114,6 +138,17 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
+ case dt: DecimalType =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
+ // byte and short are casted into int when add, minus, times or divide
+ case ByteType | ShortType =>
+ defineCodeGen(ctx, ev, (eval1, eval2) =>
+ s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+ case _ =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
+ }
+
protected def evalInternal(evalE1: Any, evalE2: Any): Any =
sys.error(s"BinaryArithmetics must override either eval or evalInternal")
}
@@ -124,6 +159,7 @@ private[sql] object BinaryArithmetic {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"
+ override def decimalMethod: String = "$plus"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
@@ -138,6 +174,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "-"
+ override def decimalMethod: String = "$minus"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
@@ -152,6 +189,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "*"
+ override def decimalMethod: String = "$times"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
@@ -166,6 +204,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "/"
+ override def decimalMethod: String = "$divide"
+
override def nullable: Boolean = true
override lazy val resolved =
@@ -192,10 +232,40 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}
}
+
+ /**
+ * Special case handling due to division by 0 => null.
+ */
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+ val test = if (left.dataType.isInstanceOf[DecimalType]) {
+ s"${eval2.primitive}.isZero()"
+ } else {
+ s"${eval2.primitive} == 0"
+ }
+ val method = if (left.dataType.isInstanceOf[DecimalType]) {
+ s".$decimalMethod"
+ } else {
+ s"$symbol"
+ }
+ eval1.code + eval2.code +
+ s"""
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
+ if (${eval1.isNull} || ${eval2.isNull} || $test) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
+ }
+ """
+ }
}
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "%"
+ override def decimalMethod: String = "reminder"
+
override def nullable: Boolean = true
override lazy val resolved =
@@ -222,6 +292,34 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}
}
+
+ /**
+ * Special case handling for x % 0 ==> null.
+ */
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+ val test = if (left.dataType.isInstanceOf[DecimalType]) {
+ s"${eval2.primitive}.isZero()"
+ } else {
+ s"${eval2.primitive} == 0"
+ }
+ val method = if (left.dataType.isInstanceOf[DecimalType]) {
+ s".$decimalMethod"
+ } else {
+ s"$symbol"
+ }
+ eval1.code + eval2.code +
+ s"""
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
+ if (${eval1.isNull} || ${eval2.isNull} || $test) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
+ }
+ """
+ }
}
/**
@@ -271,7 +369,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
}
/**
- * A function that calculates bitwise xor(^) of two numbers.
+ * A function that calculates bitwise xor of two numbers.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"
@@ -313,6 +411,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
+ }
+
protected override def evalInternal(evalE: Any) = not(evalE)
}
@@ -340,6 +442,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ if (ctx.isNativeType(left.dataType)) {
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+ eval1.code + eval2.code + s"""
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(left.dataType)} ${ev.primitive} =
+ ${ctx.defaultValue(left.dataType)};
+
+ if (${eval1.isNull}) {
+ ${ev.isNull} = ${eval2.isNull};
+ ${ev.primitive} = ${eval2.primitive};
+ } else if (${eval2.isNull}) {
+ ${ev.isNull} = ${eval1.isNull};
+ ${ev.primitive} = ${eval1.primitive};
+ } else {
+ if (${eval1.primitive} > ${eval2.primitive}) {
+ ${ev.primitive} = ${eval1.primitive};
+ } else {
+ ${ev.primitive} = ${eval2.primitive};
+ }
+ }
+ """
+ } else {
+ super.genCode(ctx, ev)
+ }
+ }
override def toString: String = s"MaxOf($left, $right)"
}
@@ -367,5 +496,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ if (ctx.isNativeType(left.dataType)) {
+
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+
+ eval1.code + eval2.code + s"""
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(left.dataType)} ${ev.primitive} =
+ ${ctx.defaultValue(left.dataType)};
+
+ if (${eval1.isNull}) {
+ ${ev.isNull} = ${eval2.isNull};
+ ${ev.primitive} = ${eval2.primitive};
+ } else if (${eval2.isNull}) {
+ ${ev.isNull} = ${eval1.isNull};
+ ${ev.primitive} = ${eval1.primitive};
+ } else {
+ if (${eval1.primitive} < ${eval2.primitive}) {
+ ${ev.primitive} = ${eval1.primitive};
+ } else {
+ ${ev.primitive} = ${eval2.primitive};
+ }
+ }
+ """
+ } else {
+ super.genCode(ctx, ev)
+ }
+ }
+
override def toString: String = s"MinOf($left, $right)"
}
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index cd60412..c8d0aaf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.codehaus.janino.ClassBodyEvaluator
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -33,20 +32,166 @@ class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
/**
+ * Java source for evaluating an [[Expression]] given a [[Row]] of input.
+ *
+ * @param code The sequence of statements required to evaluate the expression.
+ * @param isNull A term that holds a boolean value representing whether the expression evaluated
+ * to null.
+ * @param primitive A term for a possible primitive value of the result of the evaluation. Not
+ * valid if `isNull` is set to `true`.
+ */
+case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term)
+
+/**
+ * A context for codegen, which is used to bookkeeping the expressions those are not supported
+ * by codegen, then they are evaluated directly. The unsupported expression is appended at the
+ * end of `references`, the position of it is kept in the code, used to access and evaluate it.
+ */
+class CodeGenContext {
+
+ /**
+ * Holding all the expressions those do not support codegen, will be evaluated directly.
+ */
+ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
+
+ val stringType: String = classOf[UTF8String].getName
+ val decimalType: String = classOf[Decimal].getName
+
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+ /**
+ * Returns a term name that is unique within this instance of a `CodeGenerator`.
+ *
+ * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
+ * function.)
+ */
+ def freshName(prefix: String): Term = {
+ s"$prefix${curId.getAndIncrement}"
+ }
+
+ /**
+ * Return the code to access a column for given DataType
+ */
+ def getColumn(dataType: DataType, ordinal: Int): Code = {
+ if (isNativeType(dataType)) {
+ s"i.${accessorForType(dataType)}($ordinal)"
+ } else {
+ s"(${boxedType(dataType)})i.apply($ordinal)"
+ }
+ }
+
+ /**
+ * Return the code to update a column in Row for given DataType
+ */
+ def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = {
+ if (isNativeType(dataType)) {
+ s"${mutatorForType(dataType)}($ordinal, $value)"
+ } else {
+ s"update($ordinal, $value)"
+ }
+ }
+
+ /**
+ * Return the name of accessor in Row for a DataType
+ */
+ def accessorForType(dt: DataType): Term = dt match {
+ case IntegerType => "getInt"
+ case other => s"get${boxedType(dt)}"
+ }
+
+ /**
+ * Return the name of mutator in Row for a DataType
+ */
+ def mutatorForType(dt: DataType): Term = dt match {
+ case IntegerType => "setInt"
+ case other => s"set${boxedType(dt)}"
+ }
+
+ /**
+ * Return the Java type for a DataType
+ */
+ def javaType(dt: DataType): Term = dt match {
+ case IntegerType => "int"
+ case LongType => "long"
+ case ShortType => "short"
+ case ByteType => "byte"
+ case DoubleType => "double"
+ case FloatType => "float"
+ case BooleanType => "boolean"
+ case dt: DecimalType => decimalType
+ case BinaryType => "byte[]"
+ case StringType => stringType
+ case DateType => "int"
+ case TimestampType => "java.sql.Timestamp"
+ case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
+ case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
+ case _ => "Object"
+ }
+
+ /**
+ * Return the boxed type in Java
+ */
+ def boxedType(dt: DataType): Term = dt match {
+ case IntegerType => "Integer"
+ case LongType => "Long"
+ case ShortType => "Short"
+ case ByteType => "Byte"
+ case DoubleType => "Double"
+ case FloatType => "Float"
+ case BooleanType => "Boolean"
+ case DateType => "Integer"
+ case _ => javaType(dt)
+ }
+
+ /**
+ * Return the representation of default value for given DataType
+ */
+ def defaultValue(dt: DataType): Term = dt match {
+ case BooleanType => "false"
+ case FloatType => "-1.0f"
+ case ShortType => "(short)-1"
+ case LongType => "-1L"
+ case ByteType => "(byte)-1"
+ case DoubleType => "-1.0"
+ case IntegerType => "-1"
+ case DateType => "-1"
+ case _ => "null"
+ }
+
+ /**
+ * Returns a function to generate equal expression in Java
+ */
+ def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match {
+ case BinaryType => { case (eval1, eval2) =>
+ s"java.util.Arrays.equals($eval1, $eval2)" }
+ case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
+ { case (eval1, eval2) => s"$eval1 == $eval2" }
+ case other =>
+ { case (eval1, eval2) => s"$eval1.equals($eval2)" }
+ }
+
+ /**
+ * List of data types that have special accessors and setters in [[Row]].
+ */
+ val nativeTypes =
+ Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
+
+ /**
+ * Returns true if the data type has a special accessor and setter in [[Row]].
+ */
+ def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)
+}
+
+/**
* A base class for generators of byte code to perform expression evaluation. Includes a set of
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
* expressions.
*/
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
- protected val rowType = classOf[Row].getName
- protected val stringType = classOf[UTF8String].getName
- protected val decimalType = classOf[Decimal].getName
- protected val exprType = classOf[Expression].getName
- protected val mutableRowType = classOf[MutableRow].getName
- protected val genericMutableRowType = classOf[GenericMutableRow].getName
-
- private val curId = new java.util.concurrent.atomic.AtomicInteger()
+ protected val exprType: String = classOf[Expression].getName
+ protected val mutableRowType: String = classOf[MutableRow].getName
+ protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
/**
* Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
@@ -75,10 +220,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
*/
protected def compile(code: String): Class[_] = {
val startTime = System.nanoTime()
- val clazz = new ClassBodyEvaluator(code).getClazz()
+ val clazz = try {
+ new ClassBodyEvaluator(code).getClazz()
+ } catch {
+ case e: Exception =>
+ logError(s"failed to compile:\n $code", e)
+ throw e
+ }
val endTime = System.nanoTime()
def timeMs: Double = (endTime - startTime).toDouble / 1000000
- logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms")
+ logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms")
clazz
}
@@ -113,585 +264,10 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
def generate(expressions: InType): OutType = cache.get(canonicalize(expressions))
/**
- * Returns a term name that is unique within this instance of a `CodeGenerator`.
- *
- * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
- * function.)
- */
- protected def freshName(prefix: String): String = {
- s"$prefix${curId.getAndIncrement}"
- }
-
- /**
- * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input.
- *
- * @param code The sequence of statements required to evaluate the expression.
- * @param nullTerm A term that holds a boolean value representing whether the expression evaluated
- * to null.
- * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
- * valid if `nullTerm` is set to `true`.
- * @param objectTerm A possibly boxed version of the result of evaluating this expression.
- */
- protected case class EvaluatedExpression(
- code: String,
- nullTerm: String,
- primitiveTerm: String,
- objectTerm: String)
-
- /**
- * A context for codegen, which is used to bookkeeping the expressions those are not supported
- * by codegen, then they are evaluated directly. The unsupported expression is appended at the
- * end of `references`, the position of it is kept in the code, used to access and evaluate it.
- */
- protected class CodeGenContext {
- /**
- * Holding all the expressions those do not support codegen, will be evaluated directly.
- */
- val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
- }
-
- /**
* Create a new codegen context for expression evaluator, used to store those
* expressions that don't support codegen
*/
def newCodeGenContext(): CodeGenContext = {
- new CodeGenContext()
- }
-
- /**
- * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that
- * can be used to determine the result of evaluating the expression on an input row.
- */
- def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = {
- val primitiveTerm = freshName("primitiveTerm")
- val nullTerm = freshName("nullTerm")
- val objectTerm = freshName("objectTerm")
-
- implicit class Evaluate1(e: Expression) {
- def castOrNull(f: String => String, dataType: DataType): String = {
- val eval = expressionEvaluator(e, ctx)
- eval.code +
- s"""
- boolean $nullTerm = ${eval.nullTerm};
- ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
- if (!$nullTerm) {
- $primitiveTerm = ${f(eval.primitiveTerm)};
- }
- """
- }
- }
-
- implicit class Evaluate2(expressions: (Expression, Expression)) {
-
- /**
- * Short hand for generating binary evaluation code, which depends on two sub-evaluations of
- * the same type. If either of the sub-expressions is null, the result of this computation
- * is assumed to be null.
- *
- * @param f a function from two primitive term names to a tree that evaluates them.
- */
- def evaluate(f: (String, String) => String): String =
- evaluateAs(expressions._1.dataType)(f)
-
- def evaluateAs(resultType: DataType)(f: (String, String) => String): String = {
- // TODO: Right now some timestamp tests fail if we enforce this...
- if (expressions._1.dataType != expressions._2.dataType) {
- log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}")
- }
-
- val eval1 = expressionEvaluator(expressions._1, ctx)
- val eval2 = expressionEvaluator(expressions._2, ctx)
- val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
-
- eval1.code + eval2.code +
- s"""
- boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm};
- ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)};
- if(!$nullTerm) {
- $primitiveTerm = (${primitiveForType(resultType)})($resultCode);
- }
- """
- }
- }
-
- val inputTuple = "i"
-
- // TODO: Skip generation of null handling code when expression are not nullable.
- val primitiveEvaluation: PartialFunction[Expression, String] = {
- case b @ BoundReference(ordinal, dataType, nullable) =>
- s"""
- final boolean $nullTerm = $inputTuple.isNullAt($ordinal);
- final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ?
- ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)});
- """
-
- case expressions.Literal(null, dataType) =>
- s"""
- final boolean $nullTerm = true;
- ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
- """
-
- case expressions.Literal(value: UTF8String, StringType) =>
- val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}"
- s"""
- final boolean $nullTerm = false;
- ${stringType} $primitiveTerm =
- new ${stringType}().set(${arr});
- """
-
- case expressions.Literal(value, FloatType) =>
- s"""
- final boolean $nullTerm = false;
- float $primitiveTerm = ${value}f;
- """
-
- case expressions.Literal(value, dt @ DecimalType()) =>
- s"""
- final boolean $nullTerm = false;
- ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value);
- """
-
- case expressions.Literal(value, dataType) =>
- s"""
- final boolean $nullTerm = false;
- ${primitiveForType(dataType)} $primitiveTerm = $value;
- """
-
- case Cast(child @ BinaryType(), StringType) =>
- child.castOrNull(c =>
- s"new ${stringType}().set($c)",
- StringType)
-
- case Cast(child @ DateType(), StringType) =>
- child.castOrNull(c =>
- s"""new ${stringType}().set(
- org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
- StringType)
-
- case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
- child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt)
-
- case Cast(child @ DecimalType(), IntegerType) =>
- child.castOrNull(c => s"($c).toInt()", IntegerType)
-
- case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
- child.castOrNull(c => s"($c).to${termForType(dt)}()", dt)
-
- case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
- child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt)
-
- // Special handling required for timestamps in hive test cases since the toString function
- // does not match the expected output.
- case Cast(e, StringType) if e.dataType != TimestampType =>
- e.castOrNull(c =>
- s"new ${stringType}().set(String.valueOf($c))",
- StringType)
-
- case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
- (e1, e2).evaluateAs (BooleanType) {
- case (eval1, eval2) =>
- s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)"
- }
-
- case EqualTo(e1, e2) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" }
-
- case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" }
- case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" }
- case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" }
- case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" }
-
- case And(e1, e2) =>
- val eval1 = expressionEvaluator(e1, ctx)
- val eval2 = expressionEvaluator(e2, ctx)
- s"""
- ${eval1.code}
- boolean $nullTerm = false;
- boolean $primitiveTerm = false;
-
- if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) {
- } else {
- ${eval2.code}
- if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) {
- } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
- $primitiveTerm = true;
- } else {
- $nullTerm = true;
- }
- }
- """
-
- case Or(e1, e2) =>
- val eval1 = expressionEvaluator(e1, ctx)
- val eval2 = expressionEvaluator(e2, ctx)
-
- s"""
- ${eval1.code}
- boolean $nullTerm = false;
- boolean $primitiveTerm = false;
-
- if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) {
- $primitiveTerm = true;
- } else {
- ${eval2.code}
- if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) {
- $primitiveTerm = true;
- } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
- $primitiveTerm = false;
- } else {
- $nullTerm = true;
- }
- }
- """
-
- case Not(child) =>
- // Uh, bad function name...
- child.castOrNull(c => s"!$c", BooleanType)
-
- case Add(e1 @ DecimalType(), e2 @ DecimalType()) =>
- (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" }
- case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) =>
- (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" }
- case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) =>
- (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" }
- case Divide(e1 @ DecimalType(), e2 @ DecimalType()) =>
- val eval1 = expressionEvaluator(e1, ctx)
- val eval2 = expressionEvaluator(e2, ctx)
- eval1.code + eval2.code +
- s"""
- boolean $nullTerm = false;
- ${primitiveForType(e1.dataType)} $primitiveTerm = null;
- if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
- $nullTerm = true;
- } else {
- $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm});
- }
- """
- case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) =>
- val eval1 = expressionEvaluator(e1, ctx)
- val eval2 = expressionEvaluator(e2, ctx)
- eval1.code + eval2.code +
- s"""
- boolean $nullTerm = false;
- ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
- if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
- $nullTerm = true;
- } else {
- $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm});
- }
- """
-
- case Add(e1, e2) =>
- (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" }
- case Subtract(e1, e2) =>
- (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" }
- case Multiply(e1, e2) =>
- (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" }
- case Divide(e1, e2) =>
- val eval1 = expressionEvaluator(e1, ctx)
- val eval2 = expressionEvaluator(e2, ctx)
- eval1.code + eval2.code +
- s"""
- boolean $nullTerm = false;
- ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
- if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
- $nullTerm = true;
- } else {
- $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm};
- }
- """
- case Remainder(e1, e2) =>
- val eval1 = expressionEvaluator(e1, ctx)
- val eval2 = expressionEvaluator(e2, ctx)
- eval1.code + eval2.code +
- s"""
- boolean $nullTerm = false;
- ${primitiveForType(e1.dataType)} $primitiveTerm = 0;
- if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
- $nullTerm = true;
- } else {
- $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm};
- }
- """
-
- case IsNotNull(e) =>
- val eval = expressionEvaluator(e, ctx)
- s"""
- ${eval.code}
- boolean $nullTerm = false;
- boolean $primitiveTerm = !${eval.nullTerm};
- """
-
- case IsNull(e) =>
- val eval = expressionEvaluator(e, ctx)
- s"""
- ${eval.code}
- boolean $nullTerm = false;
- boolean $primitiveTerm = ${eval.nullTerm};
- """
-
- case e @ Coalesce(children) =>
- s"""
- boolean $nullTerm = true;
- ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
- """ +
- children.map { c =>
- val eval = expressionEvaluator(c, ctx)
- s"""
- if($nullTerm) {
- ${eval.code}
- if(!${eval.nullTerm}) {
- $nullTerm = false;
- $primitiveTerm = ${eval.primitiveTerm};
- }
- }
- """
- }.mkString("\n")
-
- case e @ expressions.If(condition, trueValue, falseValue) =>
- val condEval = expressionEvaluator(condition, ctx)
- val trueEval = expressionEvaluator(trueValue, ctx)
- val falseEval = expressionEvaluator(falseValue, ctx)
-
- s"""
- boolean $nullTerm = false;
- ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
- ${condEval.code}
- if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
- ${trueEval.code}
- $nullTerm = ${trueEval.nullTerm};
- $primitiveTerm = ${trueEval.primitiveTerm};
- } else {
- ${falseEval.code}
- $nullTerm = ${falseEval.nullTerm};
- $primitiveTerm = ${falseEval.primitiveTerm};
- }
- """
-
- case NewSet(elementType) =>
- s"""
- boolean $nullTerm = false;
- ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}();
- """
-
- case AddItemToSet(item, set) =>
- val itemEval = expressionEvaluator(item, ctx)
- val setEval = expressionEvaluator(set, ctx)
-
- val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
- val htype = hashSetForType(elementType)
-
- itemEval.code + setEval.code +
- s"""
- if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
- (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
- }
- boolean $nullTerm = false;
- ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm};
- """
-
- case CombineSets(left, right) =>
- val leftEval = expressionEvaluator(left, ctx)
- val rightEval = expressionEvaluator(right, ctx)
-
- val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
- val htype = hashSetForType(elementType)
-
- leftEval.code + rightEval.code +
- s"""
- boolean $nullTerm = false;
- ${htype} $primitiveTerm =
- (${htype})${leftEval.primitiveTerm};
- $primitiveTerm.union((${htype})${rightEval.primitiveTerm});
- """
-
- case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
- val eval1 = expressionEvaluator(e1, ctx)
- val eval2 = expressionEvaluator(e2, ctx)
-
- eval1.code + eval2.code +
- s"""
- boolean $nullTerm = false;
- ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
-
- if (${eval1.nullTerm}) {
- $nullTerm = ${eval2.nullTerm};
- $primitiveTerm = ${eval2.primitiveTerm};
- } else if (${eval2.nullTerm}) {
- $nullTerm = ${eval1.nullTerm};
- $primitiveTerm = ${eval1.primitiveTerm};
- } else {
- if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
- $primitiveTerm = ${eval1.primitiveTerm};
- } else {
- $primitiveTerm = ${eval2.primitiveTerm};
- }
- }
- """
-
- case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
- val eval1 = expressionEvaluator(e1, ctx)
- val eval2 = expressionEvaluator(e2, ctx)
-
- eval1.code + eval2.code +
- s"""
- boolean $nullTerm = false;
- ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
-
- if (${eval1.nullTerm}) {
- $nullTerm = ${eval2.nullTerm};
- $primitiveTerm = ${eval2.primitiveTerm};
- } else if (${eval2.nullTerm}) {
- $nullTerm = ${eval1.nullTerm};
- $primitiveTerm = ${eval1.primitiveTerm};
- } else {
- if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
- $primitiveTerm = ${eval1.primitiveTerm};
- } else {
- $primitiveTerm = ${eval2.primitiveTerm};
- }
- }
- """
-
- case UnscaledValue(child) =>
- val childEval = expressionEvaluator(child, ctx)
-
- childEval.code +
- s"""
- boolean $nullTerm = ${childEval.nullTerm};
- long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong();
- """
-
- case MakeDecimal(child, precision, scale) =>
- val eval = expressionEvaluator(child, ctx)
-
- eval.code +
- s"""
- boolean $nullTerm = ${eval.nullTerm};
- org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())};
-
- if (!$nullTerm) {
- $primitiveTerm = new org.apache.spark.sql.types.Decimal();
- $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale);
- $nullTerm = $primitiveTerm == null;
- }
- """
- }
-
- // If there was no match in the partial function above, we fall back on calling the interpreted
- // expression evaluator.
- val code: String =
- primitiveEvaluation.lift.apply(e).getOrElse {
- logError(s"No rules to generate $e")
- ctx.references += e
- s"""
- /* expression: ${e} */
- Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
- boolean $nullTerm = $objectTerm == null;
- ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
- if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm;
- """
- }
-
- EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
- }
-
- protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = {
- dataType match {
- case StringType => s"(${stringType})$inputRow.apply($ordinal)"
- case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)"
- case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)"
- }
- }
-
- protected def setColumn(
- destinationRow: String,
- dataType: DataType,
- ordinal: Int,
- value: String): String = {
- dataType match {
- case StringType => s"$destinationRow.update($ordinal, $value)"
- case dt: DataType if isNativeType(dt) =>
- s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
- case _ => s"$destinationRow.update($ordinal, $value)"
- }
- }
-
- protected def accessorForType(dt: DataType) = dt match {
- case IntegerType => "getInt"
- case other => s"get${termForType(dt)}"
- }
-
- protected def mutatorForType(dt: DataType) = dt match {
- case IntegerType => "setInt"
- case other => s"set${termForType(dt)}"
- }
-
- protected def hashSetForType(dt: DataType): String = dt match {
- case IntegerType => classOf[IntegerHashSet].getName
- case LongType => classOf[LongHashSet].getName
- case unsupportedType =>
- sys.error(s"Code generation not support for hashset of type $unsupportedType")
+ new CodeGenContext
}
-
- protected def primitiveForType(dt: DataType): String = dt match {
- case IntegerType => "int"
- case LongType => "long"
- case ShortType => "short"
- case ByteType => "byte"
- case DoubleType => "double"
- case FloatType => "float"
- case BooleanType => "boolean"
- case dt: DecimalType => decimalType
- case BinaryType => "byte[]"
- case StringType => stringType
- case DateType => "int"
- case TimestampType => "java.sql.Timestamp"
- case _ => "Object"
- }
-
- protected def defaultPrimitive(dt: DataType): String = dt match {
- case BooleanType => "false"
- case FloatType => "-1.0f"
- case ShortType => "-1"
- case LongType => "-1"
- case ByteType => "-1"
- case DoubleType => "-1.0"
- case IntegerType => "-1"
- case DateType => "-1"
- case dt: DecimalType => "null"
- case StringType => "null"
- case _ => "null"
- }
-
- protected def termForType(dt: DataType): String = dt match {
- case IntegerType => "Integer"
- case LongType => "Long"
- case ShortType => "Short"
- case ByteType => "Byte"
- case DoubleType => "Double"
- case FloatType => "Float"
- case BooleanType => "Boolean"
- case dt: DecimalType => decimalType
- case BinaryType => "byte[]"
- case StringType => stringType
- case DateType => "Integer"
- case TimestampType => "java.sql.Timestamp"
- case _ => "Object"
- }
-
- /**
- * List of data types that have special accessors and setters in [[Row]].
- */
- protected val nativeTypes =
- Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
-
- /**
- * Returns true if the data type has a special accessor and setter in [[Row]].
- */
- protected def isNativeType(dt: DataType) = nativeTypes.contains(dt)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 638b53f..e5ee2ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -37,13 +37,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
- val evaluationCode = expressionEvaluator(e, ctx)
+ val evaluationCode = e.gen(ctx)
evaluationCode.code +
s"""
- if(${evaluationCode.nullTerm})
+ if(${evaluationCode.isNull})
mutableRow.setNullAt($i);
else
- ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)};
+ mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)};
"""
}.mkString("\n")
val code = s"""
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 0ff840d..36e155d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -52,15 +52,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
val ctx = newCodeGenContext()
val comparisons = ordering.zipWithIndex.map { case (order, i) =>
- val evalA = expressionEvaluator(order.child, ctx)
- val evalB = expressionEvaluator(order.child, ctx)
+ val evalA = order.child.gen(ctx)
+ val evalB = order.child.gen(ctx)
val asc = order.direction == Ascending
val compare = order.child.dataType match {
case BinaryType =>
s"""
{
- byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm};
- byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm};
+ // x and y must be opposite operands; a descending order swaps which row is x.
+ byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
+ byte[] y = ${if (asc) evalB.primitive else evalA.primitive};
int j = 0;
while (j < x.length && j < y.length) {
if (x[j] != y[j]) return x[j] - y[j];
@@ -73,8 +73,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
}"""
case _: NumericType =>
s"""
- if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) {
- if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) {
+ if (${evalA.primitive} != ${evalB.primitive}) {
+ if (${evalA.primitive} > ${evalB.primitive}) {
return ${if (asc) "1" else "-1"};
} else {
return ${if (asc) "-1" else "1"};
@@ -82,7 +82,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
}"""
case _ =>
s"""
- int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm});
+ int comp = ${evalA.primitive}.compare(${evalB.primitive});
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}"""
@@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
${evalA.code}
i = $b;
${evalB.code}
- if (${evalA.nullTerm} && ${evalB.nullTerm}) {
+ if (${evalA.isNull} && ${evalB.isNull}) {
// Nothing
- } else if (${evalA.nullTerm}) {
+ } else if (${evalA.isNull}) {
return ${if (order.direction == Ascending) "-1" else "1"};
- } else if (${evalB.nullTerm}) {
+ } else if (${evalB.isNull}) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
$compare
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index fb18769..4a547b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -38,7 +38,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
protected def create(predicate: Expression): ((Row) => Boolean) = {
val ctx = newCodeGenContext()
- val eval = expressionEvaluator(predicate, ctx)
+ val eval = predicate.gen(ctx)
val code = s"""
import org.apache.spark.sql.Row;
@@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
@Override
public boolean eval(Row i) {
${eval.code}
- return !${eval.nullTerm} && ${eval.primitiveTerm};
+ return !${eval.isNull} && ${eval.primitive};
}
}"""
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index d5be1fc..7caf4aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -45,19 +45,19 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val ctx = newCodeGenContext()
val columns = expressions.zipWithIndex.map {
case (e, i) =>
- s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n"
+ s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
}.mkString("\n ")
val initColumns = expressions.zipWithIndex.map {
case (e, i) =>
- val eval = expressionEvaluator(e, ctx)
+ val eval = e.gen(ctx)
s"""
{
// column$i
${eval.code}
- nullBits[$i] = ${eval.nullTerm};
- if(!${eval.nullTerm}) {
- c$i = ${eval.primitiveTerm};
+ nullBits[$i] = ${eval.isNull};
+ if (!${eval.isNull}) {
+ c$i = ${eval.primitive};
}
}
"""
@@ -68,10 +68,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}.mkString("\n ")
val updateCases = expressions.zipWithIndex.map { case (e, i) =>
- s"case $i: { c$i = (${termForType(e.dataType)})value; return;}"
+ s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
}.mkString("\n ")
- val specificAccessorFunctions = nativeTypes.map { dataType =>
+ val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
val cases = expressions.zipWithIndex.map {
case (e, i) if e.dataType == dataType =>
s"case $i: return c$i;"
@@ -80,21 +80,21 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if (cases.count(_ != '\n') > 0) {
s"""
@Override
- public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) {
+ public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
if (isNullAt(i)) {
- return ${defaultPrimitive(dataType)};
+ return ${ctx.defaultValue(dataType)};
}
switch (i) {
$cases
}
- return ${defaultPrimitive(dataType)};
+ return ${ctx.defaultValue(dataType)};
}"""
} else {
""
}
}.mkString("\n")
- val specificMutatorFunctions = nativeTypes.map { dataType =>
+ val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
val cases = expressions.zipWithIndex.map {
case (e, i) if e.dataType == dataType =>
s"case $i: { c$i = value; return; }"
@@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if (cases.count(_ != '\n') > 0) {
s"""
@Override
- public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) {
+ public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
nullBits[i] = false;
switch (i) {
$cases
@@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case LongType => s"$col ^ ($col >>> 32)"
case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
- s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)"
+ s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
case _ => s"$col.hashCode()"
}
s"isNullAt($i) ? 0 : ($nonNull)"
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
index 7f1b12c..6f9589d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
@@ -27,6 +27,9 @@ import org.apache.spark.util.Utils
*/
package object codegen {
+ /** A Java term in generated code: a variable name or an inlined constant such as "true". */
+ type Term = String
+ /** A fragment of generated Java source code. */
+ type Code = String
+
/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
val batches =
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index 65ba189..ddfadf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.types._
/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
@@ -35,6 +36,10 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
childResult.asInstanceOf[Decimal].toUnscaledLong
}
}
+
+ /** Codegen: the result is `toUnscaledLong()` invoked on the child's Decimal term. */
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code =
+ defineCodeGen(ctx, ev, decimalTerm => s"$decimalTerm.toUnscaledLong()")
}
/** Create a Decimal from an unscaled Long value */
@@ -53,4 +58,18 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
}
}
+
+ /**
+ * Generates code that builds a Decimal from the child's value (an unscaled Long, per the
+ * class doc) with this expression's `precision` and `scale`. The result is null when the
+ * child is null or when `setOrNull` returns null (NOTE(review): presumably when the value
+ * does not fit `precision`; confirm against Decimal.setOrNull).
+ */
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ // The child's code is emitted first so its isNull/primitive terms are in scope below.
+ val eval = child.gen(ctx)
+ eval.code + s"""
+ boolean ${ev.isNull} = ${eval.isNull};
+ ${ctx.decimalType} ${ev.primitive} = null;
+
+ if (!${ev.isNull}) {
+ ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull(
+ ${eval.primitive}, $precision, $scale);
+ ${ev.isNull} = ${ev.primitive} == null;
+ }
+ """
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index d3ca3d9..3a92716 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
@@ -78,7 +79,60 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
override def toString: String = if (value != null) value.toString else "null"
+ /** Literals are equal when their data types match and their values are both null or equal. */
+ override def equals(other: Any): Boolean = other match {
+ case that: Literal =>
+ dataType.equals(that.dataType) &&
+ (if (value == null) that.value == null else value.equals(that.value))
+ case _ => false
+ }
+
override def eval(input: Row): Any = value
+
+ /**
+ * Generates code for this literal.
+ *
+ * Null and primitive literals (boolean, byte, short, int, long, finite float/double) emit no
+ * code; instead `ev.isNull` and `ev.primitive` are replaced by Java constants so they are
+ * inlined at every use. NaN/infinite floating point values (which have no Java literal form)
+ * and all other types fall back to `super.genCode`.
+ */
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ // change the isNull and primitive to consts, to inline them
+ if (value == null) {
+ ev.isNull = "true"
+ ev.primitive = ctx.defaultValue(dataType)
+ ""
+ } else {
+ dataType match {
+ case BooleanType =>
+ ev.isNull = "false"
+ ev.primitive = value.toString
+ ""
+ case FloatType => // This must go before NumericType
+ val v = value.asInstanceOf[Float]
+ if (v.isNaN || v.isInfinite) {
+ super.genCode(ctx, ev)
+ } else {
+ ev.isNull = "false"
+ ev.primitive = s"${value}f"
+ ""
+ }
+ case DoubleType => // This must go before NumericType
+ val v = value.asInstanceOf[Double]
+ if (v.isNaN || v.isInfinite) {
+ super.genCode(ctx, ev)
+ } else {
+ ev.isNull = "false"
+ ev.primitive = s"${value}"
+ ""
+ }
+
+ case ByteType | ShortType => // This must go before NumericType
+ ev.isNull = "false"
+ ev.primitive = s"(${ctx.javaType(dataType)})$value"
+ ""
+ case LongType => // This must go before NumericType
+ // Without the `L` suffix Java types the constant as int: values outside the int range
+ // do not compile, and arithmetic between inlined constants would overflow as int.
+ ev.isNull = "false"
+ ev.primitive = s"${value}L"
+ ""
+ case dt: NumericType if !dt.isInstanceOf[DecimalType] =>
+ ev.isNull = "false"
+ ev.primitive = value.toString
+ ""
+ // eval() version may be faster for non-primitive types
+ case other =>
+ super.genCode(ctx, ev)
+ }
+ }
+ }
}
// TODO: Specialize
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
index db853a2..88211ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.mathfuncs
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row}
import org.apache.spark.sql.types._
@@ -49,6 +50,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
}
}
}
+
+ // Calls the java.lang.Math method named by the lower-cased SQL name (e.g. HYPOT -> hypot).
+ // Subclasses whose SQL name is not a java.lang.Math method must override this.
+ // NOTE(review): unlike the Atan2 and Pow overrides, no NaN -> null conversion is done
+ // here; confirm that this matches the interpreted eval() for every subclass.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val mathMethod = s"java.lang.Math.${name.toLowerCase}"
+ defineCodeGen(ctx, ev, (a, b) => s"$mathMethod($a, $b)")
+ }
}
case class Atan2(left: Expression, right: Expression)
@@ -70,9 +75,26 @@ case class Atan2(left: Expression, right: Expression)
}
}
}
+
+ // Computes java.lang.Math.atan2. Adding 0.0 to each argument turns -0.0 into +0.0
+ // (IEEE 754 addition), presumably to mirror the interpreted eval(); confirm.
+ // The appended Java block marks the result null when atan2 yields NaN. This assumes
+ // defineCodeGen declares ev.isNull / ev.primitive as assignable Java locals.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
+ if (Double.valueOf(${ev.primitive}).isNaN()) {
+ ${ev.isNull} = true;
+ }
+ """
+ }
}
+// HYPOT: relies on the inherited codegen, which calls java.lang.Math.hypot.
case class Hypot(left: Expression, right: Expression)
extends BinaryMathExpression(math.hypot, "HYPOT")
-case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")
+// POWER: the inherited codegen would call java.lang.Math.power (the lower-cased SQL name),
+// which does not exist. So genCode calls java.lang.Math.pow and turns a NaN result into null.
+case class Pow(left: Expression, right: Expression)
+ extends BinaryMathExpression(math.pow, "POWER") {
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
+ if (Double.valueOf(${ev.primitive}).isNaN()) {
+ ${ev.isNull} = true;
+ }
+ """
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
index 41b4223..5563cd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.mathfuncs
+import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression}
import org.apache.spark.sql.types._
@@ -44,6 +45,23 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
if (result.isNaN) null else result
}
}
+
+ // Name of the java.lang.Math method called by the generated code. It defaults to the
+ // lower-cased SQL name; subclasses override it when the two differ (e.g. ROUND -> rint).
+ def funcName: String = name.toLowerCase
+
+ // Generated code evaluates the child first. If the child is non-null, it applies
+ // java.lang.Math.<funcName> to the child's value and treats a NaN result as null.
+ // If the child is null, the result is null and keeps the type's default value.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val eval = child.gen(ctx)
+ eval.code + s"""
+ boolean ${ev.isNull} = ${eval.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
+ if (Double.valueOf(${ev.primitive}).isNaN()) {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ }
}
+// ACOS: codegen calls java.lang.Math.acos (the lower-cased SQL name).
case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
@@ -72,7 +90,9 @@ case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG
+// LOG1P: codegen calls java.lang.Math.log1p (the lower-cased SQL name).
case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
-case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND")
+// The SQL name is ROUND, but the matching java.lang.Math method is rint, so funcName is
+// overridden for codegen.
+case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
+ override def funcName: String = "rint"
+}
+// SIGNUM: codegen calls java.lang.Math.signum (the lower-cased SQL name).
case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
@@ -84,6 +104,10 @@ case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
+// TANH: codegen calls java.lang.Math.tanh (the lower-cased SQL name).
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
-case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES")
+// The SQL name is DEGREES, but the matching java.lang.Math method is toDegrees, so
+// funcName is overridden for codegen.
+case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
+ override def funcName: String = "toDegrees"
+}
-case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS")
+// The SQL name is RADIANS, but the matching java.lang.Math method is toRadians, so
+// funcName is overridden for codegen.
+case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
+ override def funcName: String = "toRadians"
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 00565ec..2e4b9ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.trees.LeafNode
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._
object NamedExpression {
@@ -116,6 +116,8 @@ case class Alias(child: Expression, name: String)(
override def eval(input: Row): Any = child.eval(input)
+ // An alias is transparent for code generation. It overrides gen itself (not genCode)
+ // and returns the child's generated code and result terms unchanged.
+ override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
+
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
override def metadata: Metadata = {
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 5070570..9ecfb3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.types.DataType
@@ -51,6 +52,25 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
result
}
+
+ // Generated code starts with a null result. It then visits the children in order and
+ // evaluates each one only while the result is still null, so the first non-null child
+ // value wins. If every child is null, the result stays null.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ s"""
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ """ +
+ children.map { e =>
+ val eval = e.gen(ctx)
+ s"""
+ if (${ev.isNull}) {
+ ${eval.code}
+ if (!${eval.isNull}) {
+ ${ev.isNull} = false;
+ ${ev.primitive} = ${eval.primitive};
+ }
+ }
+ """
+ }.mkString("\n")
+ }
}
case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
@@ -61,6 +81,13 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
child.eval(input) == null
}
+ // IS NULL never yields null. Its value is the child's null flag itself, so only the
+ // child's code has to be emitted.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val childEval = child.gen(ctx)
+ ev.primitive = childEval.isNull
+ ev.isNull = "false"
+ childEval.code
+ }
+
+ // Renders as "IS NULL <child>".
override def toString: String = s"IS NULL $child"
}
@@ -72,6 +99,13 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
+ // Interpreted path: true iff the child evaluates to a non-null value; never null itself.
override def eval(input: Row): Any = {
child.eval(input) != null
}
+
+ // IS NOT NULL never yields null. Its value is the negated child null flag, inlined as
+ // an expression, so only the child's code has to be emitted.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val childEval = child.gen(ctx)
+ ev.primitive = s"(!(${childEval.isNull}))"
+ ev.isNull = "false"
+ childEval.code
+ }
}
/**
@@ -95,4 +129,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}
numNonNulls >= n
}
+
+ // Counts non-null children in a fresh int local. Once n non-nulls have been seen, the
+ // remaining children are not evaluated. The result (count >= n) is never null.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val nonnull = ctx.freshName("nonnull")
+ val code = children.map { e =>
+ val eval = e.gen(ctx)
+ s"""
+ if ($nonnull < $n) {
+ ${eval.code}
+ if (!${eval.isNull}) {
+ $nonnull += 1;
+ }
+ }
+ """
+ }.mkString("\n")
+ s"""
+ int $nonnull = 0;
+ $code
+ boolean ${ev.isNull} = false;
+ boolean ${ev.primitive} = $nonnull >= $n;
+ """
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/5e7b6b67/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 58273b1..1d0f19a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,9 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types._
object InterpretedPredicate {
def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
@@ -82,6 +83,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
case b: Boolean => !b
}
}
+
+ // Negates the child's boolean term with Java `!`. Null propagation is presumably handled
+ // by defineCodeGen, which is not shown here.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code =
+ defineCodeGen(ctx, ev, childTerm => s"!($childTerm)")
}
/**
@@ -141,6 +146,29 @@ case class And(left: Expression, right: Expression)
}
}
}
+
+ // SQL three-valued AND:
+ // - false if either side is a non-null false;
+ // - true if both sides are non-null (and therefore both true);
+ // - null otherwise.
+ // The right side's code runs only when the left side is not a non-null false.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+
+ // The result is `false` if either side is `false`, regardless of whether the other is null.
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = false;
+ boolean ${ev.primitive} = false;
+
+ if (!${eval1.isNull} && !${eval1.primitive}) {
+ } else {
+ ${eval2.code}
+ if (!${eval2.isNull} && !${eval2.primitive}) {
+ } else if (!${eval1.isNull} && !${eval2.isNull}) {
+ ${ev.primitive} = true;
+ } else {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ }
}
case class Or(left: Expression, right: Expression)
@@ -167,6 +195,29 @@ case class Or(left: Expression, right: Expression)
}
}
}
+
+ // SQL three-valued OR:
+ // - true if either side is a non-null true;
+ // - false if both sides are non-null (and therefore both false);
+ // - null otherwise.
+ // The right side's code runs only when the left side is not a non-null true.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+
+ // The result is `true` if either side is `true`, regardless of whether the other is null.
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = false;
+ boolean ${ev.primitive} = true;
+
+ if (!${eval1.isNull} && ${eval1.primitive}) {
+ } else {
+ ${eval2.code}
+ if (!${eval2.isNull} && ${eval2.primitive}) {
+ } else if (!${eval1.isNull} && !${eval2.isNull}) {
+ ${ev.primitive} = false;
+ } else {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ }
}
abstract class BinaryComparison extends BinaryExpression with Predicate {
@@ -198,6 +249,20 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
}
}
+ // How the comparison is generated depends on the left operand's type:
+ // - timestamps use the inherited genCode, because java.sql.Timestamp has no compare();
+ // - native numeric types use the Java operator directly;
+ // - every other type is compared via its compare() method against 0.
+ // NOTE(review): primitive Java types that are not NumericType (e.g. boolean, or dates if
+ // ctx.javaType maps them to int) would reach the compare() branch. Confirm such types
+ // never get here.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code =
+ left.dataType match {
+ case TimestampType =>
+ super.genCode(ctx, ev)
+ case dt: NumericType if ctx.isNativeType(dt) =>
+ defineCodeGen(ctx, ev, (l, r) => s"$l $symbol $r")
+ case _ =>
+ defineCodeGen(ctx, ev, (l, r) => s"$l.compare($r) $symbol 0")
+ }
+
+ // Default that fails at runtime: each concrete comparison must override eval or this.
protected def evalInternal(evalE1: Any, evalE2: Any): Any =
sys.error(s"BinaryComparisons must override either eval or evalInternal")
}
@@ -215,6 +280,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
if (left.dataType != BinaryType) l == r
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
+ // The Java equality expression comes from ctx.equalFunc for the operand type. Judging by
+ // its use elsewhere, it presumably picks == or an equals-style call per type.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val eqExpr: (String, String) => String = ctx.equalFunc(left.dataType)
+ defineCodeGen(ctx, ev, eqExpr)
+ }
}
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
@@ -235,6 +303,17 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
l == r
}
}
+
+ // Null-safe equality (<=>) is never null. It is true when both sides are null, or when
+ // both are non-null and equal. Both null flags must be checked before the primitives
+ // are compared. A null operand's primitive term typically holds a default value (the
+ // genCode implementations here initialise it with ctx.defaultValue). Without the check,
+ // e.g. `0 <=> NULL` would compare 0 against that default and wrongly return true.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+ val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive)
+ ev.isNull = "false"
+ eval1.code + eval2.code + s"""
+ boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) ||
+ (!${eval1.isNull} && !${eval2.isNull} && $equalCode);
+ """
+ }
}
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
@@ -309,6 +388,27 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
}
+ // A null or false predicate selects the false branch. Only the selected branch's code
+ // is executed, and its null flag and value are copied into the result.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val condEval = predicate.gen(ctx)
+ val trueEval = trueValue.gen(ctx)
+ val falseEval = falseValue.gen(ctx)
+
+ s"""
+ ${condEval.code}
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${condEval.isNull} && ${condEval.primitive}) {
+ ${trueEval.code}
+ ${ev.isNull} = ${trueEval.isNull};
+ ${ev.primitive} = ${trueEval.primitive};
+ } else {
+ ${falseEval.code}
+ ${ev.isNull} = ${falseEval.isNull};
+ ${ev.primitive} = ${falseEval.primitive};
+ }
+ """
+ }
+
+ // Renders as "if (<predicate>) <trueValue> else <falseValue>".
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
}
@@ -393,6 +493,48 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
return res
}
+ // branchesArr holds (condition, value) pairs, optionally followed by a single ELSE value
+ // (which makes the length odd). Generated code tries the conditions in order. The first
+ // one that is non-null and true sets the fresh `got` flag, which skips all later
+ // branches. With no match and no ELSE, the result stays null.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val len = branchesArr.length
+ val got = ctx.freshName("got")
+
+ val cases = (0 until len/2).map { i =>
+ val cond = branchesArr(i * 2).gen(ctx)
+ val res = branchesArr(i * 2 + 1).gen(ctx)
+ s"""
+ if (!$got) {
+ ${cond.code}
+ if (!${cond.isNull} && ${cond.primitive}) {
+ $got = true;
+ ${res.code}
+ ${ev.isNull} = ${res.isNull};
+ ${ev.primitive} = ${res.primitive};
+ }
+ }
+ """
+ }.mkString("\n")
+
+ // Trailing ELSE value: used only if no WHEN condition matched.
+ val other = if (len % 2 == 1) {
+ val res = branchesArr(len - 1).gen(ctx)
+ s"""
+ if (!$got) {
+ ${res.code}
+ ${ev.isNull} = ${res.isNull};
+ ${ev.primitive} = ${res.primitive};
+ }
+ """
+ } else {
+ ""
+ }
+
+ s"""
+ boolean $got = false;
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ $cases
+ $other
+ """
+ }
+
override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
@@ -444,6 +586,52 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
return res
}
+ // CASE key WHEN ... form. The key is evaluated once, before all branches. A branch
+ // matches when its candidate equals the key null-safely: both null, or both non-null and
+ // equal per ctx.equalFunc. The first match sets the fresh `got` flag and skips later
+ // branches. A trailing odd element is the ELSE value. With no match and no ELSE, the
+ // result is null.
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
+ val keyEval = key.gen(ctx)
+ val len = branchesArr.length
+ val got = ctx.freshName("got")
+
+ val cases = (0 until len/2).map { i =>
+ val cond = branchesArr(i * 2).gen(ctx)
+ val res = branchesArr(i * 2 + 1).gen(ctx)
+ s"""
+ if (!$got) {
+ ${cond.code}
+ if (${keyEval.isNull} && ${cond.isNull} ||
+ !${keyEval.isNull} && !${cond.isNull}
+ && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
+ $got = true;
+ ${res.code}
+ ${ev.isNull} = ${res.isNull};
+ ${ev.primitive} = ${res.primitive};
+ }
+ }
+ """
+ }.mkString("\n")
+
+ // Trailing ELSE value: used only if no WHEN candidate matched the key.
+ val other = if (len % 2 == 1) {
+ val res = branchesArr(len - 1).gen(ctx)
+ s"""
+ if (!$got) {
+ ${res.code}
+ ${ev.isNull} = ${res.isNull};
+ ${ev.primitive} = ${res.primitive};
+ }
+ """
+ } else {
+ ""
+ }
+
+ s"""
+ boolean $got = false;
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ ${keyEval.code}
+ $cases
+ $other
+ """
+ }
+
private def equalNullSafe(l: Any, r: Any) = {
if (l == null && r == null) {
true
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org