You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/23 08:44:18 UTC

spark git commit: [SPARK-8935] [SQL] Implement code generation for all casts

Repository: spark
Updated Branches:
  refs/heads/master 825ab1e45 -> 6d0d8b406


[SPARK-8935] [SQL] Implement code generation for all casts

JIRA: https://issues.apache.org/jira/browse/SPARK-8935

Author: Yijie Shen <he...@gmail.com>

Closes #7365 from yjshen/cast_codegen and squashes the following commits:

ef6e8b5 [Yijie Shen] getColumn and setColumn in struct cast, autounboxing in array and map
eaece18 [Yijie Shen] remove null case in cast code gen
fd7eba4 [Yijie Shen] resolve comments
80378a5 [Yijie Shen] the missing self cast
611d66e [Yijie Shen] Bug fix: NullType & primitive object unboxing
6d5c0fe [Yijie Shen] rebase and add Interval codegen
9424b65 [Yijie Shen] tiny style fix
4a1c801 [Yijie Shen] remove CodeHolder class, use function instead.
3f5df88 [Yijie Shen] CodeHolder for complex dataTypes
c286f13 [Yijie Shen] moved all the cast code into class body
4edfd76 [Yijie Shen] [WIP] finished primitive part


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d0d8b40
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d0d8b40
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d0d8b40

Branch: refs/heads/master
Commit: 6d0d8b406942edcf9fc97e76fb227ff1eb35ca3a
Parents: 825ab1e
Author: Yijie Shen <he...@gmail.com>
Authored: Wed Jul 22 23:44:08 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Jul 22 23:44:08 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 523 +++++++++++++++++--
 .../expressions/DateExpressionsSuite.scala      |  36 +-
 2 files changed, 508 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6d0d8b40/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 3346d3c..e66cd82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{Interval, UTF8String}
 
+import scala.collection.mutable
+
 
 object Cast {
 
@@ -418,51 +420,506 @@ case class Cast(child: Expression, dataType: DataType)
   protected override def nullSafeEval(input: Any): Any = cast(input)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    // TODO: Add support for more data types.
-    (child.dataType, dataType) match {
+    val eval = child.gen(ctx)
+    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
+    eval.code +
+      castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast)
+  }
+
+  // Arguments: the child value variable, the result value variable and the result null flag.
+  // Returns Java statements that castCode places inside its `if (!childNull)` guard.
+  private[this] type CastFunction = (String, String, String) => String
+
+  private[this] def nullSafeCastFunction(
+      from: DataType,
+      to: DataType,
+      ctx: CodeGenContext): CastFunction = to match {
+
+    case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;"
+    case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;"
+    case StringType => castToStringCode(from, ctx)
+    case BinaryType => castToBinaryCode(from)
+    case DateType => castToDateCode(from, ctx)
+    case decimal: DecimalType => castToDecimalCode(from, decimal)
+    case TimestampType => castToTimestampCode(from, ctx)
+    case IntervalType => castToIntervalCode(from)
+    case BooleanType => castToBooleanCode(from)
+    case ByteType => castToByteCode(from)
+    case ShortType => castToShortCode(from)
+    case IntegerType => castToIntCode(from)
+    case FloatType => castToFloatCode(from)
+    case LongType => castToLongCode(from)
+    case DoubleType => castToDoubleCode(from)
+
+    case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx)
+    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
+    case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
+  }
+
+  // Complex types (array elements, map keys/values, struct fields) cast their children
+  // recursively, so every variable name involved in a cast is passed in explicitly.
+  private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String,
+    resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = {
+    s"""
+      boolean $resultNull = $childNull;
+      ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
+      if (!${childNull}) {
+        ${cast(childPrim, resultPrim, resultNull)}
+      }
+    """
+  }
+
+  private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = {
+    from match {
+      case BinaryType =>
+        (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);"
+      case DateType =>
+        (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
+          org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));"""
+      case TimestampType =>
+        (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
+          org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));"""
+      case _ =>
+        (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
+    }
+  }
+
+  private[this] def castToBinaryCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) => s"$evPrim = $c.getBytes();"
+  }
+
+  private[this] def castToDateCode(
+      from: DataType,
+      ctx: CodeGenContext): CastFunction = from match {
+    case StringType =>
+      val intOpt = ctx.freshName("intOpt")
+      (c, evPrim, evNull) => s"""
+        scala.Option<Integer> $intOpt =
+          org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c);
+        if ($intOpt.isDefined()) {
+          $evPrim = ((Integer) $intOpt.get()).intValue();
+        } else {
+          $evNull = true;
+        }
+       """
+    case TimestampType =>
+      (c, evPrim, evNull) =>
+        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L);";
+    case _ =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+  }
+
+  private[this] def changePrecision(d: String, decimalType: DecimalType,
+      evPrim: String, evNull: String): String = {
+    decimalType match {
+      case DecimalType.Unlimited =>
+        s"$evPrim = $d;"
+      case DecimalType.Fixed(precision, scale) =>
+        s"""
+          if ($d.changePrecision($precision, $scale)) {
+            $evPrim = $d;
+          } else {
+            $evNull = true;
+          }
+        """
+    }
+  }
 
-      case (BinaryType, StringType) =>
-        defineCodeGen (ctx, ev, c =>
-          s"UTF8String.fromBytes($c)")
+  private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = {
+    from match {
+      case StringType =>
+        (c, evPrim, evNull) =>
+          s"""
+            try {
+              org.apache.spark.sql.types.Decimal tmpDecimal =
+                new org.apache.spark.sql.types.Decimal().set(
+                  new scala.math.BigDecimal(
+                    new java.math.BigDecimal($c.toString())));
+              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+            } catch (java.lang.NumberFormatException e) {
+              $evNull = true;
+            }
+          """
+      case BooleanType =>
+        (c, evPrim, evNull) =>
+          s"""
+            org.apache.spark.sql.types.Decimal tmpDecimal = null;
+            if ($c) {
+              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1);
+            } else {
+              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0);
+            }
+            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+          """
+      case DateType =>
+        // date can't cast to decimal in Hive
+        (c, evPrim, evNull) => s"$evNull = true;"
+      case TimestampType =>
+        // Note that we lose precision here.
+        (c, evPrim, evNull) =>
+          s"""
+            org.apache.spark.sql.types.Decimal tmpDecimal =
+              new org.apache.spark.sql.types.Decimal().set(
+                scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
+            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+          """
+      case DecimalType() =>
+        (c, evPrim, evNull) =>
+          s"""
+            org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone();
+            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+          """
+      case LongType =>
+        (c, evPrim, evNull) =>
+          s"""
+            org.apache.spark.sql.types.Decimal tmpDecimal =
+              new org.apache.spark.sql.types.Decimal().set($c);
+            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+          """
+      case x: NumericType =>
+        // All other numeric types can be represented precisely as Doubles
+        (c, evPrim, evNull) =>
+          s"""
+            try {
+              org.apache.spark.sql.types.Decimal tmpDecimal =
+                new org.apache.spark.sql.types.Decimal().set(
+                  scala.math.BigDecimal.valueOf((double) $c));
+              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+            } catch (java.lang.NumberFormatException e) {
+              $evNull = true;
+            }
+          """
+    }
+  }
 
-      case (DateType, StringType) =>
-        defineCodeGen(ctx, ev, c =>
-          s"""UTF8String.fromString(
-                org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
+  private[this] def castToTimestampCode(
+      from: DataType,
+      ctx: CodeGenContext): CastFunction = from match {
+    case StringType =>
+      val longOpt = ctx.freshName("longOpt")
+      (c, evPrim, evNull) =>
+        s"""
+          scala.Option<Long> $longOpt =
+            org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c);
+          if ($longOpt.isDefined()) {
+            $evPrim = ((Long) $longOpt.get()).longValue();
+          } else {
+            $evNull = true;
+          }
+         """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;"
+    case _: IntegralType =>
+      (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
+    case DateType =>
+      (c, evPrim, evNull) =>
+        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000;"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};"
+    case DoubleType =>
+      (c, evPrim, evNull) =>
+        s"""
+          if (Double.isNaN($c) || Double.isInfinite($c)) {
+            $evNull = true;
+          } else {
+            $evPrim = (long)($c * 1000000L);
+          }
+        """
+    case FloatType =>
+      (c, evPrim, evNull) =>
+        s"""
+          if (Float.isNaN($c) || Float.isInfinite($c)) {
+            $evNull = true;
+          } else {
+            $evPrim = (long)($c * 1000000L);
+          }
+        """
+  }
 
-      case (TimestampType, StringType) =>
-        defineCodeGen(ctx, ev, c =>
-          s"""UTF8String.fromString(
-                org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
+  private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) =>
+        s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());"
+  }
+
+  private[this] def decimalToTimestampCode(d: String): String =
+    s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()"
+  private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L"
+  private[this] def timestampToIntegerCode(ts: String): String =
+    s"java.lang.Math.floor((double) $ts / 1000000L)"
+  private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0"
+
+  private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = $c != 0;"
+    case DateType =>
+      // Hive would return null when cast from date to boolean
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = !$c.isZero();"
+    case n: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = $c != 0;"
+  }
+
+  private[this] def castToByteCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Byte.valueOf($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toByte();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (byte) $c;"
+  }
 
-      case (_, StringType) =>
-        defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))")
+  private[this] def castToShortCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Short.valueOf($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toShort();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (short) $c;"
+  }
 
-      case (StringType, IntervalType) =>
-        defineCodeGen(ctx, ev, c =>
-          s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())")
+  private[this] def castToIntCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Integer.valueOf($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toInt();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (int) $c;"
+  }
 
-      // fallback for DecimalType, this must be before other numeric types
-      case (_, dt: DecimalType) =>
-        super.genCode(ctx, ev)
+  private[this] def castToLongCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Long.valueOf($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toLong();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (long) $c;"
+  }
 
-      case (BooleanType, dt: NumericType) =>
-        defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
+  private[this] def castToFloatCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Float.valueOf($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toFloat();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (float) $c;"
+  }
 
-      case (dt: DecimalType, BooleanType) =>
-        defineCodeGen(ctx, ev, c => s"!$c.isZero()")
+  private[this] def castToDoubleCode(from: DataType): CastFunction = from match {
+    case StringType =>
+      (c, evPrim, evNull) =>
+        s"""
+          try {
+            $evPrim = Double.valueOf($c.toString());
+          } catch (java.lang.NumberFormatException e) {
+            $evNull = true;
+          }
+        """
+    case BooleanType =>
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+    case DateType =>
+      (c, evPrim, evNull) => s"$evNull = true;"
+    case TimestampType =>
+      (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};"
+    case DecimalType() =>
+      (c, evPrim, evNull) => s"$evPrim = $c.toDouble();"
+    case x: NumericType =>
+      (c, evPrim, evNull) => s"$evPrim = (double) $c;"
+  }
 
-      case (dt: NumericType, BooleanType) =>
-        defineCodeGen(ctx, ev, c => s"$c != 0")
+  private[this] def castArrayCode(
+      from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
+    val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx)
+
+    val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
+    val fromElementNull = ctx.freshName("feNull")
+    val fromElementPrim = ctx.freshName("fePrim")
+    val toElementNull = ctx.freshName("teNull")
+    val toElementPrim = ctx.freshName("tePrim")
+    val size = ctx.freshName("n")
+    val j = ctx.freshName("j")
+    val result = ctx.freshName("result")
+
+    (c, evPrim, evNull) =>
+      s"""
+        final int $size = $c.size();
+        final $arraySeqClass<Object> $result = new $arraySeqClass<Object>($size);
+        for (int $j = 0; $j < $size; $j ++) {
+          if ($c.apply($j) == null) {
+            $result.update($j, null);
+          } else {
+            boolean $fromElementNull = false;
+            ${ctx.javaType(from.elementType)} $fromElementPrim =
+              (${ctx.boxedType(from.elementType)}) $c.apply($j);
+            ${castCode(ctx, fromElementPrim,
+              fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)}
+            if ($toElementNull) {
+              $result.update($j, null);
+            } else {
+              $result.update($j, $toElementPrim);
+            }
+          }
+        }
+        $evPrim = $result;
+      """
+  }
 
-      case (_: DecimalType, dt: NumericType) =>
-        defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()")
+  private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = {
+    val keyCast = nullSafeCastFunction(from.keyType, to.keyType, ctx)
+    val valueCast = nullSafeCastFunction(from.valueType, to.valueType, ctx)
+    val hashMapClass = classOf[mutable.HashMap[Any, Any]].getName
+    val fromKeyPrim = ctx.freshName("fkp")
+    val fromKeyNull = ctx.freshName("fkn")
+    val fromValuePrim = ctx.freshName("fvp")
+    val fromValueNull = ctx.freshName("fvn")
+    val toKeyPrim = ctx.freshName("tkp")
+    val toKeyNull = ctx.freshName("tkn")
+    val toValuePrim = ctx.freshName("tvp")
+    val toValueNull = ctx.freshName("tvn")
+    val result = ctx.freshName("result")
+    // Fresh names: a nested map cast (e.g. a map-typed value) must not redeclare these locals.
+    val iter = ctx.freshName("iter")
+    val kv = ctx.freshName("kv")
+    (c, evPrim, evNull) =>
+      s"""
+        final $hashMapClass $result = new $hashMapClass();
+        scala.collection.Iterator $iter = $c.iterator();
+        while ($iter.hasNext()) {
+          scala.Tuple2 $kv = (scala.Tuple2) $iter.next();
+          boolean $fromKeyNull = false;
+          ${ctx.javaType(from.keyType)} $fromKeyPrim =
+            (${ctx.boxedType(from.keyType)}) $kv._1();
+          ${castCode(ctx, fromKeyPrim,
+            fromKeyNull, toKeyPrim, toKeyNull, to.keyType, keyCast)}
+          boolean $fromValueNull = $kv._2() == null;
+          if ($fromValueNull) {
+            $result.put($toKeyPrim, null);
+          } else {
+            ${ctx.javaType(from.valueType)} $fromValuePrim =
+              (${ctx.boxedType(from.valueType)}) $kv._2();
+            ${castCode(ctx, fromValuePrim,
+              fromValueNull, toValuePrim, toValueNull, to.valueType, valueCast)}
+            if ($toValueNull) {
+              $result.put($toKeyPrim, null);
+            } else {
+              $result.put($toKeyPrim, $toValuePrim);
+            }
+          }
+        }
+        $evPrim = $result;
+      """
+  }
 
-      case (_: NumericType, dt: NumericType) =>
-        defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
+  private[this] def castStructCode(
+      from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = {
 
-      case other =>
-        super.genCode(ctx, ev)
+    val fieldsCasts = from.fields.zip(to.fields).map {
+      case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
     }
+    val rowClass = classOf[GenericMutableRow].getName
+    val result = ctx.freshName("result")
+    val tmpRow = ctx.freshName("tmpRow")
+
+    val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => {
+      val fromFieldPrim = ctx.freshName("ffp")
+      val fromFieldNull = ctx.freshName("ffn")
+      val toFieldPrim = ctx.freshName("tfp")
+      val toFieldNull = ctx.freshName("tfn")
+      val fromType = ctx.javaType(from.fields(i).dataType)
+      s"""
+        boolean $fromFieldNull = $tmpRow.isNullAt($i);
+        if ($fromFieldNull) {
+          $result.setNullAt($i);
+        } else {
+          $fromType $fromFieldPrim =
+            ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)};
+          ${castCode(ctx, fromFieldPrim,
+            fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
+          if ($toFieldNull) {
+            $result.setNullAt($i);
+          } else {
+            ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)};
+          }
+        }
+       """
+      }
+    }.mkString("\n")
+
+    (c, evPrim, evNull) =>
+      s"""
+        final $rowClass $result = new $rowClass(${fieldsCasts.size});
+        final InternalRow $tmpRow = $c;
+        $fieldsEvalCode
+        $evPrim = $result.copy();
+      """
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6d0d8b40/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index f724bab..bdba6ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -39,7 +39,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           val c = Calendar.getInstance()
           c.set(y, m, 28, 0, 0, 0)
           c.add(Calendar.DATE, i)
-          checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
             sdfDay.format(c.getTime).toInt)
         }
       }
@@ -51,7 +51,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           val c = Calendar.getInstance()
           c.set(y, m, 28, 0, 0, 0)
           c.add(Calendar.DATE, i)
-          checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
             sdfDay.format(c.getTime).toInt)
         }
       }
@@ -63,7 +63,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           val c = Calendar.getInstance()
           c.set(y, m, 28, 0, 0, 0)
           c.add(Calendar.DATE, i)
-          checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
             sdfDay.format(c.getTime).toInt)
         }
       }
@@ -75,7 +75,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           val c = Calendar.getInstance()
           c.set(y, m, 28, 0, 0, 0)
           c.add(Calendar.DATE, i)
-          checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
             sdfDay.format(c.getTime).toInt)
         }
       }
@@ -87,7 +87,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           val c = Calendar.getInstance()
           c.set(y, m, 28, 0, 0, 0)
           c.add(Calendar.DATE, i)
-          checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
             sdfDay.format(c.getTime).toInt)
         }
       }
@@ -96,7 +96,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("Year") {
     checkEvaluation(Year(Literal.create(null, DateType)), null)
-    checkEvaluation(Year(Cast(Literal(d), DateType)), 2015)
+    checkEvaluation(Year(Literal(d)), 2015)
     checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType)), 2015)
     checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013)
 
@@ -106,7 +106,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         c.set(y, m, 28)
         (0 to 5 * 24).foreach { i =>
           c.add(Calendar.HOUR_OF_DAY, 1)
-          checkEvaluation(Year(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(Year(Literal(new Date(c.getTimeInMillis))),
             c.get(Calendar.YEAR))
         }
       }
@@ -115,7 +115,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("Quarter") {
     checkEvaluation(Quarter(Literal.create(null, DateType)), null)
-    checkEvaluation(Quarter(Cast(Literal(d), DateType)), 2)
+    checkEvaluation(Quarter(Literal(d)), 2)
     checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType)), 2)
     checkEvaluation(Quarter(Cast(Literal(ts), DateType)), 4)
 
@@ -125,7 +125,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         c.set(y, m, 28, 0, 0, 0)
         (0 to 5 * 24).foreach { i =>
           c.add(Calendar.HOUR_OF_DAY, 1)
-          checkEvaluation(Quarter(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(Quarter(Literal(new Date(c.getTimeInMillis))),
             c.get(Calendar.MONTH) / 3 + 1)
         }
       }
@@ -134,7 +134,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("Month") {
     checkEvaluation(Month(Literal.create(null, DateType)), null)
-    checkEvaluation(Month(Cast(Literal(d), DateType)), 4)
+    checkEvaluation(Month(Literal(d)), 4)
     checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType)), 4)
     checkEvaluation(Month(Cast(Literal(ts), DateType)), 11)
 
@@ -144,7 +144,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           val c = Calendar.getInstance()
           c.set(y, m, 28, 0, 0, 0)
           c.add(Calendar.HOUR_OF_DAY, i)
-          checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))),
             c.get(Calendar.MONTH) + 1)
         }
       }
@@ -156,7 +156,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           val c = Calendar.getInstance()
           c.set(y, m, 28, 0, 0, 0)
           c.add(Calendar.HOUR_OF_DAY, i)
-          checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+          checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))),
             c.get(Calendar.MONTH) + 1)
         }
       }
@@ -166,7 +166,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("Day / DayOfMonth") {
     checkEvaluation(DayOfMonth(Cast(Literal("2000-02-29"), DateType)), 29)
     checkEvaluation(DayOfMonth(Literal.create(null, DateType)), null)
-    checkEvaluation(DayOfMonth(Cast(Literal(d), DateType)), 8)
+    checkEvaluation(DayOfMonth(Literal(d)), 8)
     checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType)), 8)
     checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType)), 8)
 
@@ -175,7 +175,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       c.set(y, 0, 1, 0, 0, 0)
       (0 to 365).foreach { d =>
         c.add(Calendar.DATE, 1)
-        checkEvaluation(DayOfMonth(Cast(Literal(new Date(c.getTimeInMillis)), DateType)),
+        checkEvaluation(DayOfMonth(Literal(new Date(c.getTimeInMillis))),
           c.get(Calendar.DAY_OF_MONTH))
       }
     }
@@ -190,14 +190,14 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val c = Calendar.getInstance()
     (0 to 60 by 5).foreach { s =>
       c.set(2015, 18, 3, 3, 5, s)
-      checkEvaluation(Second(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)),
+      checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))),
         c.get(Calendar.SECOND))
     }
   }
 
   test("WeekOfYear") {
     checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null)
-    checkEvaluation(WeekOfYear(Cast(Literal(d), DateType)), 15)
+    checkEvaluation(WeekOfYear(Literal(d)), 15)
     checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15)
     checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45)
     checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18)
@@ -223,7 +223,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       (0 to 60 by 15).foreach { m =>
         (0 to 60 by 15).foreach { s =>
           c.set(2015, 18, 3, h, m, s)
-          checkEvaluation(Hour(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)),
+          checkEvaluation(Hour(Literal(new Timestamp(c.getTimeInMillis))),
             c.get(Calendar.HOUR_OF_DAY))
         }
       }
@@ -240,7 +240,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     (0 to 60 by 5).foreach { m =>
       (0 to 60 by 15).foreach { s =>
         c.set(2015, 18, 3, 3, m, s)
-        checkEvaluation(Minute(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)),
+        checkEvaluation(Minute(Literal(new Timestamp(c.getTimeInMillis))),
           c.get(Calendar.MINUTE))
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org