You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/30 08:24:25 UTC

spark git commit: [SPARK-9428] [SQL] Add test cases for null inputs for expression unit tests

Repository: spark
Updated Branches:
  refs/heads/master 712465b68 -> e127ec34d


[SPARK-9428] [SQL] Add test cases for null inputs for expression unit tests

JIRA: https://issues.apache.org/jira/browse/SPARK-9428

Author: Yijie Shen <he...@gmail.com>

Closes #7748 from yjshen/string_cleanup and squashes the following commits:

e0c2b3d [Yijie Shen] update codegen in RegExpExtract and RegExpReplace
26614d2 [Yijie Shen] MathFunctionSuite
a402859 [Yijie Shen] complex_create, conditional and cast
6e4e608 [Yijie Shen] arithmetic and cast
52593c1 [Yijie Shen] null input test cases for StringExpressionSuite


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e127ec34
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e127ec34
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e127ec34

Branch: refs/heads/master
Commit: e127ec34d58ceb0a9d45748c2f2918786ba0a83d
Parents: 712465b
Author: Yijie Shen <he...@gmail.com>
Authored: Wed Jul 29 23:24:20 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Jul 29 23:24:20 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 12 ++--
 .../expressions/complexTypeCreator.scala        | 16 +++--
 .../sql/catalyst/expressions/conditionals.scala | 10 ++--
 .../spark/sql/catalyst/expressions/math.scala   | 14 ++---
 .../catalyst/expressions/stringOperations.scala | 11 ++--
 .../analysis/ExpressionTypeCheckingSuite.scala  |  7 ++-
 .../expressions/ArithmeticExpressionSuite.scala |  3 +
 .../sql/catalyst/expressions/CastSuite.scala    | 52 +++++++++++++++-
 .../catalyst/expressions/ComplexTypeSuite.scala | 23 +++----
 .../ConditionalExpressionSuite.scala            |  4 ++
 .../expressions/MathFunctionsSuite.scala        | 63 +++++++++++---------
 .../sql/catalyst/expressions/RandomSuite.scala  |  1 -
 .../expressions/StringExpressionsSuite.scala    | 26 ++++++++
 .../scala/org/apache/spark/sql/functions.scala  |  6 +-
 14 files changed, 167 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c6e8af2..8c01c13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -599,7 +599,7 @@ case class Cast(child: Expression, dataType: DataType)
           }
          """
     case BooleanType =>
-      (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;"
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
     case _: IntegralType =>
       (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
     case DateType =>
@@ -665,7 +665,7 @@ case class Cast(child: Expression, dataType: DataType)
           }
         """
     case BooleanType =>
-      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+      (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
     case DateType =>
       (c, evPrim, evNull) => s"$evNull = true;"
     case TimestampType =>
@@ -687,7 +687,7 @@ case class Cast(child: Expression, dataType: DataType)
           }
         """
     case BooleanType =>
-      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+      (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
     case DateType =>
       (c, evPrim, evNull) => s"$evNull = true;"
     case TimestampType =>
@@ -731,7 +731,7 @@ case class Cast(child: Expression, dataType: DataType)
           }
         """
     case BooleanType =>
-      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
     case DateType =>
       (c, evPrim, evNull) => s"$evNull = true;"
     case TimestampType =>
@@ -753,7 +753,7 @@ case class Cast(child: Expression, dataType: DataType)
           }
         """
     case BooleanType =>
-      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;"
     case DateType =>
       (c, evPrim, evNull) => s"$evNull = true;"
     case TimestampType =>
@@ -775,7 +775,7 @@ case class Cast(child: Expression, dataType: DataType)
           }
         """
     case BooleanType =>
-      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+      (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;"
     case DateType =>
       (c, evPrim, evNull) => s"$evNull = true;"
     case TimestampType =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index d8c9087..0517050 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.unsafe.types.UTF8String
+
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -127,11 +129,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
   private lazy val (nameExprs, valExprs) =
     children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
 
-  private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
+  private lazy val names = nameExprs.map(_.eval(EmptyRow))
 
   override lazy val dataType: StructType = {
     val fields = names.zip(valExprs).map { case (name, valExpr) =>
-      StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+      StructField(name.asInstanceOf[UTF8String].toString,
+        valExpr.dataType, valExpr.nullable, Metadata.empty)
     }
     StructType(fields)
   }
@@ -144,14 +147,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
     if (children.size % 2 != 0) {
       TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
     } else {
-      val invalidNames =
-        nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable)
+      val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
       if (invalidNames.nonEmpty) {
         TypeCheckResult.TypeCheckFailure(
-          s"Odd position only allow foldable and not-null StringType expressions, got :" +
+          s"Only foldable StringType expressions are allowed to appear at odd position , got :" +
             s" ${invalidNames.mkString(",")}")
-      } else {
+      } else if (names.forall(_ != null)){
         TypeCheckResult.TypeCheckSuccess
+      } else {
+        TypeCheckResult.TypeCheckFailure("Field name should not be null")
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index 15b33da..961b1d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -315,7 +315,6 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
  * It takes at least 2 parameters, and returns null iff all parameters are null.
  */
 case class Least(children: Seq[Expression]) extends Expression {
-  require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length)
 
   override def nullable: Boolean = children.forall(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
@@ -323,7 +322,9 @@ case class Least(children: Seq[Expression]) extends Expression {
   private lazy val ordering = TypeUtils.getOrdering(dataType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+    if (children.length <= 1) {
+      TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments")
+    } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
       TypeCheckResult.TypeCheckFailure(
         s"The expressions should all have the same type," +
           s" got LEAST (${children.map(_.dataType)}).")
@@ -369,7 +370,6 @@ case class Least(children: Seq[Expression]) extends Expression {
  * It takes at least 2 parameters, and returns null iff all parameters are null.
  */
 case class Greatest(children: Seq[Expression]) extends Expression {
-  require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length)
 
   override def nullable: Boolean = children.forall(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
@@ -377,7 +377,9 @@ case class Greatest(children: Seq[Expression]) extends Expression {
   private lazy val ordering = TypeUtils.getOrdering(dataType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+    if (children.length <= 1) {
+      TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments")
+    } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
       TypeCheckResult.TypeCheckFailure(
         s"The expressions should all have the same type," +
           s" got GREATEST (${children.map(_.dataType)}).")

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 68cca0a..e6d807f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -646,19 +646,19 @@ case class Logarithm(left: Expression, right: Expression)
 /**
  * Round the `child`'s result to `scale` decimal place when `scale` >= 0
  * or round at integral part when `scale` < 0.
- * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
+ * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30.
  *
- * Child of IntegralType would eval to itself when `scale` >= 0.
- * Child of FractionalType whose value is NaN or Infinite would always eval to itself.
+ * Child of IntegralType would round to itself when `scale` >= 0.
+ * Child of FractionalType whose value is NaN or Infinite would always round to itself.
  *
- * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]],
- * which leads to scale update in DecimalType's [[PrecisionInfo]]
+ * Round's dataType always equals `child`'s dataType, except for DecimalType,
+ * which leads to a scale decrease from the original DecimalType.
  *
  * @param child expr to be round, all [[NumericType]] is allowed as Input
  * @param scale new scale to be round to, this should be a constant int at runtime
  */
 case class Round(child: Expression, scale: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes {
 
   import BigDecimal.RoundingMode.HALF_UP
 
@@ -838,6 +838,4 @@ case class Round(child: Expression, scale: Expression)
       """
     }
   }
-
-  override def prettyName: String = "round"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 6db4e19..5b3a64a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -22,7 +22,6 @@ import java.util.Locale
 import java.util.regex.{MatchResult, Pattern}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -52,7 +51,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val evals = children.map(_.gen(ctx))
     val inputs = evals.map { eval =>
-      s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+      s"${eval.isNull} ? null : ${eval.primitive}"
     }.mkString(", ")
     evals.map(_.code).mkString("\n") + s"""
       boolean ${ev.isNull} = false;
@@ -1008,7 +1007,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
 
     s"""
       ${evalSubject.code}
-      boolean ${ev.isNull} = ${evalSubject.isNull};
+      boolean ${ev.isNull} = true;
       ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       if (!${evalSubject.isNull}) {
         ${evalRegexp.code}
@@ -1103,9 +1102,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
     val evalIdx = idx.gen(ctx)
 
     s"""
-      ${ctx.javaType(dataType)} ${ev.primitive} = null;
-      boolean ${ev.isNull} = true;
       ${evalSubject.code}
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      boolean ${ev.isNull} = true;
       if (!${evalSubject.isNull}) {
         ${evalRegexp.code}
         if (!${evalRegexp.isNull}) {
@@ -1117,7 +1116,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
               ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
             }
             ${classOf[java.util.regex.Matcher].getCanonicalName} m =
-                                   ${termPattern}.matcher(${evalSubject.primitive}.toString());
+              ${termPattern}.matcher(${evalSubject.primitive}.toString());
             if (m.find()) {
               ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult();
               ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive}));

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 8acd4c6..a52e4cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -167,10 +167,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
     assertError(
       CreateNamedStruct(Seq(1, "a", "b", 2.0)),
-        "Odd position only allow foldable and not-null StringType expressions")
+        "Only foldable StringType expressions are allowed to appear at odd position")
     assertError(
       CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
-        "Odd position only allow foldable and not-null StringType expressions")
+        "Only foldable StringType expressions are allowed to appear at odd position")
+    assertError(
+      CreateNamedStruct(Seq(Literal.create(null, StringType), "a")),
+        "Field name should not be null")
   }
 
   test("check types for ROUND") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 7773e09..d03b0fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -116,9 +116,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
 
   test("Abs") {
     testNumericDataTypes { convert =>
+      val input = Literal(convert(1))
+      val dataType = input.dataType
       checkEvaluation(Abs(Literal(convert(0))), convert(0))
       checkEvaluation(Abs(Literal(convert(1))), convert(1))
       checkEvaluation(Abs(Literal(convert(-1))), convert(1))
+      checkEvaluation(Abs(Literal.create(null, dataType)), null)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 0e0213b..a517da9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -43,6 +43,42 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(cast(v, Literal(expected).dataType), expected)
   }
 
+  private def checkNullCast(from: DataType, to: DataType): Unit = {
+    checkEvaluation(Cast(Literal.create(null, from), to), null)
+  }
+
+  test("null cast") {
+    import DataTypeTestUtils._
+
+    // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic
+    // to ensure we test every possible cast situation here
+    atomicTypes.zip(atomicTypes).foreach { case (from, to) =>
+      checkNullCast(from, to)
+    }
+
+    atomicTypes.foreach(dt => checkNullCast(NullType, dt))
+    atomicTypes.foreach(dt => checkNullCast(dt, StringType))
+    checkNullCast(StringType, BinaryType)
+    checkNullCast(StringType, BooleanType)
+    checkNullCast(DateType, BooleanType)
+    checkNullCast(TimestampType, BooleanType)
+    numericTypes.foreach(dt => checkNullCast(dt, BooleanType))
+
+    checkNullCast(StringType, TimestampType)
+    checkNullCast(BooleanType, TimestampType)
+    checkNullCast(DateType, TimestampType)
+    numericTypes.foreach(dt => checkNullCast(dt, TimestampType))
+
+    atomicTypes.foreach(dt => checkNullCast(dt, DateType))
+
+    checkNullCast(StringType, CalendarIntervalType)
+    numericTypes.foreach(dt => checkNullCast(StringType, dt))
+    numericTypes.foreach(dt => checkNullCast(BooleanType, dt))
+    numericTypes.foreach(dt => checkNullCast(DateType, dt))
+    numericTypes.foreach(dt => checkNullCast(TimestampType, dt))
+    for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to)
+  }
+
   test("cast string to date") {
     var c = Calendar.getInstance()
     c.set(2015, 0, 1, 0, 0, 0)
@@ -69,8 +105,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("cast string to timestamp") {
-    checkEvaluation(Cast(Literal("123"), TimestampType),
-      null)
+    checkEvaluation(Cast(Literal("123"), TimestampType), null)
 
     var c = Calendar.getInstance()
     c.set(2015, 0, 1, 0, 0, 0)
@@ -473,6 +508,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     val array_notNull = Literal.create(Seq("123", "abc", ""),
       ArrayType(StringType, containsNull = false))
 
+    checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
+
     {
       val ret = cast(array, ArrayType(IntegerType, containsNull = true))
       assert(ret.resolved === true)
@@ -526,6 +563,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
       Map("a" -> "123", "b" -> "abc", "c" -> ""),
       MapType(StringType, StringType, valueContainsNull = false))
 
+    checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
+
     {
       val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
       assert(ret.resolved === true)
@@ -580,6 +619,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("cast from struct") {
+    checkNullCast(
+      StructType(Seq(
+        StructField("a", StringType),
+        StructField("b", IntegerType))),
+      StructType(Seq(
+        StructField("a", StringType),
+        StructField("b", StringType))))
+
     val struct = Literal.create(
       InternalRow(
         UTF8String.fromString("123"),
@@ -728,5 +775,4 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
       StringType),
       "interval 1 years 3 months -3 days")
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index fc84277..5de5ddc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -132,6 +132,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
     checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
     checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
+    checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }
 
   test("CreateStruct") {
@@ -139,26 +140,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     val c1 = 'a.int.at(0)
     val c3 = 'c.int.at(2)
     checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
+    checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null))
   }
 
   test("CreateNamedStruct") {
-    val row = InternalRow(1, 2, 3)
+    val row = create_row(1, 2, 3)
     val c1 = 'a.int.at(0)
     val c3 = 'c.int.at(2)
-    checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row)
-  }
-
-  test("CreateNamedStruct with literal field") {
-    val row = InternalRow(1, 2, 3)
-    val c1 = 'a.int.at(0)
+    checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row)
     checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")),
-      InternalRow(1, UTF8String.fromString("y")), row)
-  }
-
-  test("CreateNamedStruct from all literal fields") {
-    checkEvaluation(
-      CreateNamedStruct(Seq("a", "x", "b", 2.0)),
-      InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty)
+      create_row(1, UTF8String.fromString("y")), row)
+    checkEvaluation(CreateNamedStruct(Seq("a", "x", "b", 2.0)),
+      create_row(UTF8String.fromString("x"), 2.0))
+    checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))),
+      create_row(null))
   }
 
   test("test dsl for complex type") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index b31d666..d26bcdb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -149,6 +149,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Least(Seq(c1, c2, Literal(-1))), -1, row)
     checkEvaluation(Least(Seq(c4, c5, c3, c3, Literal("a"))), "a", row)
 
+    val nullLiteral = Literal.create(null, IntegerType)
+    checkEvaluation(Least(Seq(nullLiteral, nullLiteral)), null)
     checkEvaluation(Least(Seq(Literal(null), Literal(null))), null, InternalRow.empty)
     checkEvaluation(Least(Seq(Literal(-1.0), Literal(2.5))), -1.0, InternalRow.empty)
     checkEvaluation(Least(Seq(Literal(-1), Literal(2))), -1, InternalRow.empty)
@@ -188,6 +190,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row)
     checkEvaluation(Greatest(Seq(c4, c5, c3, Literal("ccc"))), "ccc", row)
 
+    val nullLiteral = Literal.create(null, IntegerType)
+    checkEvaluation(Greatest(Seq(nullLiteral, nullLiteral)), null)
     checkEvaluation(Greatest(Seq(Literal(null), Literal(null))), null, InternalRow.empty)
     checkEvaluation(Greatest(Seq(Literal(-1.0), Literal(2.5))), 2.5, InternalRow.empty)
     checkEvaluation(Greatest(Seq(Literal(-1), Literal(2))), 2, InternalRow.empty)

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 21459a7..9fcb548 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -110,35 +110,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
   }
 
-  test("conv") {
-    checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
-    checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
-    checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
-    checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
-    checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
-    checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
-    checkEvaluation(
-      Conv(Literal("1234"), Literal(10), Literal(37)), null)
-    checkEvaluation(
-      Conv(Literal(""), Literal(10), Literal(16)), null)
-    checkEvaluation(
-      Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
-    // If there is an invalid digit in the number, the longest valid prefix should be converted.
-    checkEvaluation(
-      Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
-  }
-
   private def checkNaN(
-      expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
+    expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
     checkNaNWithoutCodegen(expression, inputRow)
     checkNaNWithGeneratedProjection(expression, inputRow)
     checkNaNWithOptimization(expression, inputRow)
   }
 
   private def checkNaNWithoutCodegen(
-      expression: Expression,
-      expected: Any,
-      inputRow: InternalRow = EmptyRow): Unit = {
+    expression: Expression,
+    expected: Any,
+    inputRow: InternalRow = EmptyRow): Unit = {
     val actual = try evaluate(expression, inputRow) catch {
       case e: Exception => fail(s"Exception evaluating $expression", e)
     }
@@ -149,7 +131,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-
   private def checkNaNWithGeneratedProjection(
     expression: Expression,
     inputRow: InternalRow = EmptyRow): Unit = {
@@ -172,6 +153,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
   }
 
+  test("conv") {
+    checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
+    checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
+    checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
+    checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
+    checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
+    checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
+    checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null)
+    checkEvaluation(
+      Conv(Literal("1234"), Literal(10), Literal(37)), null)
+    checkEvaluation(
+      Conv(Literal(""), Literal(10), Literal(16)), null)
+    checkEvaluation(
+      Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
+    // If there is an invalid digit in the number, the longest valid prefix should be converted.
+    checkEvaluation(
+      Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
+  }
+
   test("e") {
     testLeaf(EulerNumber, math.E)
   }
@@ -417,7 +417,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("round") {
-    val domain = -6 to 6
+    val scales = -6 to 6
     val doublePi: Double = math.Pi
     val shortPi: Short = 31415
     val intPi: Int = 314159265
@@ -437,17 +437,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
       Seq.fill(7)(31415926535897932L)
 
-    val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
-      BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
-      BigDecimal(3.141593), BigDecimal(3.1415927))
-
-    domain.zipWithIndex.foreach { case (scale, i) =>
+    scales.zipWithIndex.foreach { case (scale, i) =>
       checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
       checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
       checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
       checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
     }
 
+    val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
+      BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
+      BigDecimal(3.141593), BigDecimal(3.1415927))
     // round_scale > current_scale would result in precision increase
     // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
     (0 to 7).foreach { i =>
@@ -456,5 +455,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     (8 to 10).foreach { scale =>
       checkEvaluation(Round(bdPi, scale), null, EmptyRow)
     }
+
+    DataTypeTestUtils.numericTypes.foreach { dataType =>
+      checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null)
+      checkEvaluation(Round(Literal.create(null, dataType),
+        Literal.create(null, IntegerType)), null)
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index 5db9926..4a644d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -21,7 +21,6 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
 
-
 class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("random") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 3d294fd..07b9525 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -348,6 +348,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row("  花花世界 "))
     checkEvaluation(StringTrim(s), "花花世界", create_row("  花花世界 "))
     // scalastyle:on
+    checkEvaluation(StringTrim(Literal.create(null, StringType)), null)
+    checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null)
+    checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null)
   }
 
   test("FORMAT") {
@@ -391,6 +394,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val s3 = 'c.string.at(2)
     val s4 = 'd.int.at(3)
     val row1 = create_row("aaads", "aa", "zz", 1)
+    val row2 = create_row(null, "aa", "zz", 0)
+    val row3 = create_row("aaads", null, "zz", 0)
+    val row4 = create_row(null, null, null, 0)
 
     checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
     checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1)
@@ -402,6 +408,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(StringLocate(s2, s1, s4), 2, row1)
     checkEvaluation(new StringLocate(s3, s1), 0, row1)
     checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1)
+    checkEvaluation(new StringLocate(s2, s1), null, row2)
+    checkEvaluation(new StringLocate(s2, s1), null, row3)
+    checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4)
   }
 
   test("LPAD/RPAD") {
@@ -448,6 +457,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val row1 = create_row("abccc")
     checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
     checkEvaluation(StringReverse(s), "cccba", row1)
+    checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
   }
 
   test("SPACE") {
@@ -466,6 +476,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val row1 = create_row("100-200", "(\\d+)", "num")
     val row2 = create_row("100-200", "(\\d+)", "###")
     val row3 = create_row("100-200", "(-)", "###")
+    val row4 = create_row(null, "(\\d+)", "###")
+    val row5 = create_row("100-200", null, "###")
+    val row6 = create_row("100-200", "(-)", null)
 
     val s = 's.string.at(0)
     val p = 'p.string.at(1)
@@ -475,6 +488,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(expr, "num-num", row1)
     checkEvaluation(expr, "###-###", row2)
     checkEvaluation(expr, "100###200", row3)
+    checkEvaluation(expr, null, row4)
+    checkEvaluation(expr, null, row5)
+    checkEvaluation(expr, null, row6)
   }
 
   test("RegexExtract") {
@@ -482,6 +498,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
     val row3 = create_row("100-200", "(\\d+).*", 1)
     val row4 = create_row("100-200", "([a-z])", 1)
+    val row5 = create_row(null, "([a-z])", 1)
+    val row6 = create_row("100-200", null, 1)
+    val row7 = create_row("100-200", "([a-z])", null)
 
     val s = 's.string.at(0)
     val p = 'p.string.at(1)
@@ -492,6 +511,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(expr, "200", row2)
     checkEvaluation(expr, "100", row3)
     checkEvaluation(expr, "", row4) // will not match anything, empty string get
+    checkEvaluation(expr, null, row5)
+    checkEvaluation(expr, null, row6)
+    checkEvaluation(expr, null, row7)
 
     val expr1 = new RegExpExtract(s, p)
     checkEvaluation(expr1, "100", row1)
@@ -501,11 +523,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val s1 = 'a.string.at(0)
     val s2 = 'b.string.at(1)
     val row1 = create_row("aa2bb3cc", "[1-9]+")
+    val row2 = create_row(null, "[1-9]+")
+    val row3 = create_row("aa2bb3cc", null)
 
     checkEvaluation(
       StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1)
     checkEvaluation(
       StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
+    checkEvaluation(StringSplit(s1, s2), null, row2)
+    checkEvaluation(StringSplit(s1, s2), null, row3)
   }
 
   test("length for string / binary") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4261a5e..4e68a88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1423,7 +1423,8 @@ object functions {
   def round(columnName: String): Column = round(Column(columnName), 0)
 
   /**
-   * Returns the value of `e` rounded to `scale` decimal places.
+   * Round the value of `e` to `scale` decimal places if `scale` >= 0
+   * or at integral part when `scale` < 0.
    *
    * @group math_funcs
    * @since 1.5.0
@@ -1431,7 +1432,8 @@ object functions {
   def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))
 
   /**
-   * Returns the value of the given column rounded to `scale` decimal places.
+   * Round the value of the given column to `scale` decimal places if `scale` >= 0
+   * or at integral part when `scale` < 0.
    *
    * @group math_funcs
    * @since 1.5.0


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org