You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/30 08:24:25 UTC
spark git commit: [SPARK-9428] [SQL] Add test cases for null inputs
for expression unit tests
Repository: spark
Updated Branches:
refs/heads/master 712465b68 -> e127ec34d
[SPARK-9428] [SQL] Add test cases for null inputs for expression unit tests
JIRA: https://issues.apache.org/jira/browse/SPARK-9428
Author: Yijie Shen <he...@gmail.com>
Closes #7748 from yjshen/string_cleanup and squashes the following commits:
e0c2b3d [Yijie Shen] update codegen in RegExpExtract and RegExpReplace
26614d2 [Yijie Shen] MathFunctionSuite
a402859 [Yijie Shen] complex_create, conditional and cast
6e4e608 [Yijie Shen] arithmetic and cast
52593c1 [Yijie Shen] null input test cases for StringExpressionSuite
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e127ec34
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e127ec34
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e127ec34
Branch: refs/heads/master
Commit: e127ec34d58ceb0a9d45748c2f2918786ba0a83d
Parents: 712465b
Author: Yijie Shen <he...@gmail.com>
Authored: Wed Jul 29 23:24:20 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Jul 29 23:24:20 2015 -0700
----------------------------------------------------------------------
.../spark/sql/catalyst/expressions/Cast.scala | 12 ++--
.../expressions/complexTypeCreator.scala | 16 +++--
.../sql/catalyst/expressions/conditionals.scala | 10 ++--
.../spark/sql/catalyst/expressions/math.scala | 14 ++---
.../catalyst/expressions/stringOperations.scala | 11 ++--
.../analysis/ExpressionTypeCheckingSuite.scala | 7 ++-
.../expressions/ArithmeticExpressionSuite.scala | 3 +
.../sql/catalyst/expressions/CastSuite.scala | 52 +++++++++++++++-
.../catalyst/expressions/ComplexTypeSuite.scala | 23 +++----
.../ConditionalExpressionSuite.scala | 4 ++
.../expressions/MathFunctionsSuite.scala | 63 +++++++++++---------
.../sql/catalyst/expressions/RandomSuite.scala | 1 -
.../expressions/StringExpressionsSuite.scala | 26 ++++++++
.../scala/org/apache/spark/sql/functions.scala | 6 +-
14 files changed, 167 insertions(+), 81 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c6e8af2..8c01c13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -599,7 +599,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
case _: IntegralType =>
(c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
case DateType =>
@@ -665,7 +665,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -687,7 +687,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -731,7 +731,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -753,7 +753,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
@@ -775,7 +775,7 @@ case class Cast(child: Expression, dataType: DataType)
}
"""
case BooleanType =>
- (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
+ (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;"
case DateType =>
(c, evPrim, evNull) => s"$evNull = true;"
case TimestampType =>
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index d8c9087..0517050 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.unsafe.types.UTF8String
+
import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
@@ -127,11 +129,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
private lazy val (nameExprs, valExprs) =
children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip
- private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
+ private lazy val names = nameExprs.map(_.eval(EmptyRow))
override lazy val dataType: StructType = {
val fields = names.zip(valExprs).map { case (name, valExpr) =>
- StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+ StructField(name.asInstanceOf[UTF8String].toString,
+ valExpr.dataType, valExpr.nullable, Metadata.empty)
}
StructType(fields)
}
@@ -144,14 +147,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
} else {
- val invalidNames =
- nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable)
+ val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
if (invalidNames.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
- s"Odd position only allow foldable and not-null StringType expressions, got :" +
+ s"Only foldable StringType expressions are allowed to appear at odd position , got :" +
s" ${invalidNames.mkString(",")}")
- } else {
+ } else if (names.forall(_ != null)){
TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure("Field name should not be null")
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index 15b33da..961b1d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -315,7 +315,6 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
* It takes at least 2 parameters, and returns null iff all parameters are null.
*/
case class Least(children: Seq[Expression]) extends Expression {
- require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length)
override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
@@ -323,7 +322,9 @@ case class Least(children: Seq[Expression]) extends Expression {
private lazy val ordering = TypeUtils.getOrdering(dataType)
override def checkInputDataTypes(): TypeCheckResult = {
- if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+ if (children.length <= 1) {
+ TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments")
+ } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got LEAST (${children.map(_.dataType)}).")
@@ -369,7 +370,6 @@ case class Least(children: Seq[Expression]) extends Expression {
* It takes at least 2 parameters, and returns null iff all parameters are null.
*/
case class Greatest(children: Seq[Expression]) extends Expression {
- require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length)
override def nullable: Boolean = children.forall(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
@@ -377,7 +377,9 @@ case class Greatest(children: Seq[Expression]) extends Expression {
private lazy val ordering = TypeUtils.getOrdering(dataType)
override def checkInputDataTypes(): TypeCheckResult = {
- if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+ if (children.length <= 1) {
+ TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments")
+ } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got GREATEST (${children.map(_.dataType)}).")
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 68cca0a..e6d807f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -646,19 +646,19 @@ case class Logarithm(left: Expression, right: Expression)
/**
* Round the `child`'s result to `scale` decimal place when `scale` >= 0
* or round at integral part when `scale` < 0.
- * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
+ * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30.
*
- * Child of IntegralType would eval to itself when `scale` >= 0.
- * Child of FractionalType whose value is NaN or Infinite would always eval to itself.
+ * Child of IntegralType would round to itself when `scale` >= 0.
+ * Child of FractionalType whose value is NaN or Infinite would always round to itself.
*
- * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]],
- * which leads to scale update in DecimalType's [[PrecisionInfo]]
+ * Round's dataType is always equal to `child`'s dataType, except for DecimalType,
+ * which leads to a decrease in scale from the original DecimalType.
*
* @param child expression to be rounded; any [[NumericType]] is allowed as input
* @param scale new scale to round to; this should be a constant int at runtime
*/
case class Round(child: Expression, scale: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes {
import BigDecimal.RoundingMode.HALF_UP
@@ -838,6 +838,4 @@ case class Round(child: Expression, scale: Expression)
"""
}
}
-
- override def prettyName: String = "round"
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 6db4e19..5b3a64a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -22,7 +22,6 @@ import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -52,7 +51,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val evals = children.map(_.gen(ctx))
val inputs = evals.map { eval =>
- s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+ s"${eval.isNull} ? null : ${eval.primitive}"
}.mkString(", ")
evals.map(_.code).mkString("\n") + s"""
boolean ${ev.isNull} = false;
@@ -1008,7 +1007,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
s"""
${evalSubject.code}
- boolean ${ev.isNull} = ${evalSubject.isNull};
+ boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${evalSubject.isNull}) {
${evalRegexp.code}
@@ -1103,9 +1102,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
val evalIdx = idx.gen(ctx)
s"""
- ${ctx.javaType(dataType)} ${ev.primitive} = null;
- boolean ${ev.isNull} = true;
${evalSubject.code}
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ boolean ${ev.isNull} = true;
if (!${evalSubject.isNull}) {
${evalRegexp.code}
if (!${evalRegexp.isNull}) {
@@ -1117,7 +1116,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
${classOf[java.util.regex.Matcher].getCanonicalName} m =
- ${termPattern}.matcher(${evalSubject.primitive}.toString());
+ ${termPattern}.matcher(${evalSubject.primitive}.toString());
if (m.find()) {
${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult();
${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive}));
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 8acd4c6..a52e4cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -167,10 +167,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
assertError(
CreateNamedStruct(Seq(1, "a", "b", 2.0)),
- "Odd position only allow foldable and not-null StringType expressions")
+ "Only foldable StringType expressions are allowed to appear at odd position")
assertError(
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
- "Odd position only allow foldable and not-null StringType expressions")
+ "Only foldable StringType expressions are allowed to appear at odd position")
+ assertError(
+ CreateNamedStruct(Seq(Literal.create(null, StringType), "a")),
+ "Field name should not be null")
}
test("check types for ROUND") {
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 7773e09..d03b0fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -116,9 +116,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
test("Abs") {
testNumericDataTypes { convert =>
+ val input = Literal(convert(1))
+ val dataType = input.dataType
checkEvaluation(Abs(Literal(convert(0))), convert(0))
checkEvaluation(Abs(Literal(convert(1))), convert(1))
checkEvaluation(Abs(Literal(convert(-1))), convert(1))
+ checkEvaluation(Abs(Literal.create(null, dataType)), null)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 0e0213b..a517da9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -43,6 +43,42 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(v, Literal(expected).dataType), expected)
}
+ private def checkNullCast(from: DataType, to: DataType): Unit = {
+ checkEvaluation(Cast(Literal.create(null, from), to), null)
+ }
+
+ test("null cast") {
+ import DataTypeTestUtils._
+
+ // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic
+ // to ensure we test every possible cast situation here
+ atomicTypes.zip(atomicTypes).foreach { case (from, to) =>
+ checkNullCast(from, to)
+ }
+
+ atomicTypes.foreach(dt => checkNullCast(NullType, dt))
+ atomicTypes.foreach(dt => checkNullCast(dt, StringType))
+ checkNullCast(StringType, BinaryType)
+ checkNullCast(StringType, BooleanType)
+ checkNullCast(DateType, BooleanType)
+ checkNullCast(TimestampType, BooleanType)
+ numericTypes.foreach(dt => checkNullCast(dt, BooleanType))
+
+ checkNullCast(StringType, TimestampType)
+ checkNullCast(BooleanType, TimestampType)
+ checkNullCast(DateType, TimestampType)
+ numericTypes.foreach(dt => checkNullCast(dt, TimestampType))
+
+ atomicTypes.foreach(dt => checkNullCast(dt, DateType))
+
+ checkNullCast(StringType, CalendarIntervalType)
+ numericTypes.foreach(dt => checkNullCast(StringType, dt))
+ numericTypes.foreach(dt => checkNullCast(BooleanType, dt))
+ numericTypes.foreach(dt => checkNullCast(DateType, dt))
+ numericTypes.foreach(dt => checkNullCast(TimestampType, dt))
+ for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to)
+ }
+
test("cast string to date") {
var c = Calendar.getInstance()
c.set(2015, 0, 1, 0, 0, 0)
@@ -69,8 +105,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast string to timestamp") {
- checkEvaluation(Cast(Literal("123"), TimestampType),
- null)
+ checkEvaluation(Cast(Literal("123"), TimestampType), null)
var c = Calendar.getInstance()
c.set(2015, 0, 1, 0, 0, 0)
@@ -473,6 +508,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val array_notNull = Literal.create(Seq("123", "abc", ""),
ArrayType(StringType, containsNull = false))
+ checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
+
{
val ret = cast(array, ArrayType(IntegerType, containsNull = true))
assert(ret.resolved === true)
@@ -526,6 +563,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
Map("a" -> "123", "b" -> "abc", "c" -> ""),
MapType(StringType, StringType, valueContainsNull = false))
+ checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
+
{
val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
assert(ret.resolved === true)
@@ -580,6 +619,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast from struct") {
+ checkNullCast(
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", IntegerType))),
+ StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", StringType))))
+
val struct = Literal.create(
InternalRow(
UTF8String.fromString("123"),
@@ -728,5 +775,4 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StringType),
"interval 1 years 3 months -3 days")
}
-
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index fc84277..5de5ddc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -132,6 +132,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow)
checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow)
+ checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
}
test("CreateStruct") {
@@ -139,26 +140,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val c1 = 'a.int.at(0)
val c3 = 'c.int.at(2)
checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row)
+ checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null))
}
test("CreateNamedStruct") {
- val row = InternalRow(1, 2, 3)
+ val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
val c3 = 'c.int.at(2)
- checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row)
- }
-
- test("CreateNamedStruct with literal field") {
- val row = InternalRow(1, 2, 3)
- val c1 = 'a.int.at(0)
+ checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row)
checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")),
- InternalRow(1, UTF8String.fromString("y")), row)
- }
-
- test("CreateNamedStruct from all literal fields") {
- checkEvaluation(
- CreateNamedStruct(Seq("a", "x", "b", 2.0)),
- InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty)
+ create_row(1, UTF8String.fromString("y")), row)
+ checkEvaluation(CreateNamedStruct(Seq("a", "x", "b", 2.0)),
+ create_row(UTF8String.fromString("x"), 2.0))
+ checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))),
+ create_row(null))
}
test("test dsl for complex type") {
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index b31d666..d26bcdb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -149,6 +149,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Least(Seq(c1, c2, Literal(-1))), -1, row)
checkEvaluation(Least(Seq(c4, c5, c3, c3, Literal("a"))), "a", row)
+ val nullLiteral = Literal.create(null, IntegerType)
+ checkEvaluation(Least(Seq(nullLiteral, nullLiteral)), null)
checkEvaluation(Least(Seq(Literal(null), Literal(null))), null, InternalRow.empty)
checkEvaluation(Least(Seq(Literal(-1.0), Literal(2.5))), -1.0, InternalRow.empty)
checkEvaluation(Least(Seq(Literal(-1), Literal(2))), -1, InternalRow.empty)
@@ -188,6 +190,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row)
checkEvaluation(Greatest(Seq(c4, c5, c3, Literal("ccc"))), "ccc", row)
+ val nullLiteral = Literal.create(null, IntegerType)
+ checkEvaluation(Greatest(Seq(nullLiteral, nullLiteral)), null)
checkEvaluation(Greatest(Seq(Literal(null), Literal(null))), null, InternalRow.empty)
checkEvaluation(Greatest(Seq(Literal(-1.0), Literal(2.5))), 2.5, InternalRow.empty)
checkEvaluation(Greatest(Seq(Literal(-1), Literal(2))), 2, InternalRow.empty)
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 21459a7..9fcb548 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -110,35 +110,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
}
- test("conv") {
- checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
- checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
- checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
- checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
- checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
- checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
- checkEvaluation(
- Conv(Literal("1234"), Literal(10), Literal(37)), null)
- checkEvaluation(
- Conv(Literal(""), Literal(10), Literal(16)), null)
- checkEvaluation(
- Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
- // If there is an invalid digit in the number, the longest valid prefix should be converted.
- checkEvaluation(
- Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
- }
-
private def checkNaN(
- expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
+ expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
checkNaNWithoutCodegen(expression, inputRow)
checkNaNWithGeneratedProjection(expression, inputRow)
checkNaNWithOptimization(expression, inputRow)
}
private def checkNaNWithoutCodegen(
- expression: Expression,
- expected: Any,
- inputRow: InternalRow = EmptyRow): Unit = {
+ expression: Expression,
+ expected: Any,
+ inputRow: InternalRow = EmptyRow): Unit = {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
@@ -149,7 +131,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
-
private def checkNaNWithGeneratedProjection(
expression: Expression,
inputRow: InternalRow = EmptyRow): Unit = {
@@ -172,6 +153,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
}
+ test("conv") {
+ checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
+ checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
+ checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
+ checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
+ checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
+ checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
+ checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null)
+ checkEvaluation(
+ Conv(Literal("1234"), Literal(10), Literal(37)), null)
+ checkEvaluation(
+ Conv(Literal(""), Literal(10), Literal(16)), null)
+ checkEvaluation(
+ Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
+ // If there is an invalid digit in the number, the longest valid prefix should be converted.
+ checkEvaluation(
+ Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
+ }
+
test("e") {
testLeaf(EulerNumber, math.E)
}
@@ -417,7 +417,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("round") {
- val domain = -6 to 6
+ val scales = -6 to 6
val doublePi: Double = math.Pi
val shortPi: Short = 31415
val intPi: Int = 314159265
@@ -437,17 +437,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
Seq.fill(7)(31415926535897932L)
- val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
- BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
- BigDecimal(3.141593), BigDecimal(3.1415927))
-
- domain.zipWithIndex.foreach { case (scale, i) =>
+ scales.zipWithIndex.foreach { case (scale, i) =>
checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
}
+ val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
+ BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
+ BigDecimal(3.141593), BigDecimal(3.1415927))
// round_scale > current_scale would result in a precision increase, which is
// not allowed by o.a.s.s.types.Decimal.changePrecision, so the result is null
(0 to 7).foreach { i =>
@@ -456,5 +455,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
(8 to 10).foreach { scale =>
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
}
+
+ DataTypeTestUtils.numericTypes.foreach { dataType =>
+ checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null)
+ checkEvaluation(Round(Literal.create(null, dataType),
+ Literal.create(null, IntegerType)), null)
+ }
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index 5db9926..4a644d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -21,7 +21,6 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkFunSuite
-
class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
test("random") {
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 3d294fd..07b9525 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -348,6 +348,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 "))
checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 "))
// scalastyle:on
+ checkEvaluation(StringTrim(Literal.create(null, StringType)), null)
+ checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null)
+ checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null)
}
test("FORMAT") {
@@ -391,6 +394,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val s3 = 'c.string.at(2)
val s4 = 'd.int.at(3)
val row1 = create_row("aaads", "aa", "zz", 1)
+ val row2 = create_row(null, "aa", "zz", 0)
+ val row3 = create_row("aaads", null, "zz", 0)
+ val row4 = create_row(null, null, null, 0)
checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1)
@@ -402,6 +408,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringLocate(s2, s1, s4), 2, row1)
checkEvaluation(new StringLocate(s3, s1), 0, row1)
checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1)
+ checkEvaluation(new StringLocate(s2, s1), null, row2)
+ checkEvaluation(new StringLocate(s2, s1), null, row3)
+ checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4)
}
test("LPAD/RPAD") {
@@ -448,6 +457,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row1 = create_row("abccc")
checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
checkEvaluation(StringReverse(s), "cccba", row1)
+ checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
}
test("SPACE") {
@@ -466,6 +476,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row1 = create_row("100-200", "(\\d+)", "num")
val row2 = create_row("100-200", "(\\d+)", "###")
val row3 = create_row("100-200", "(-)", "###")
+ val row4 = create_row(null, "(\\d+)", "###")
+ val row5 = create_row("100-200", null, "###")
+ val row6 = create_row("100-200", "(-)", null)
val s = 's.string.at(0)
val p = 'p.string.at(1)
@@ -475,6 +488,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, "num-num", row1)
checkEvaluation(expr, "###-###", row2)
checkEvaluation(expr, "100###200", row3)
+ checkEvaluation(expr, null, row4)
+ checkEvaluation(expr, null, row5)
+ checkEvaluation(expr, null, row6)
}
test("RegexExtract") {
@@ -482,6 +498,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
val row3 = create_row("100-200", "(\\d+).*", 1)
val row4 = create_row("100-200", "([a-z])", 1)
+ val row5 = create_row(null, "([a-z])", 1)
+ val row6 = create_row("100-200", null, 1)
+ val row7 = create_row("100-200", "([a-z])", null)
val s = 's.string.at(0)
val p = 'p.string.at(1)
@@ -492,6 +511,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(expr, "200", row2)
checkEvaluation(expr, "100", row3)
checkEvaluation(expr, "", row4) // will not match anything, so an empty string is returned
+ checkEvaluation(expr, null, row5)
+ checkEvaluation(expr, null, row6)
+ checkEvaluation(expr, null, row7)
val expr1 = new RegExpExtract(s, p)
checkEvaluation(expr1, "100", row1)
@@ -501,11 +523,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val s1 = 'a.string.at(0)
val s2 = 'b.string.at(1)
val row1 = create_row("aa2bb3cc", "[1-9]+")
+ val row2 = create_row(null, "[1-9]+")
+ val row3 = create_row("aa2bb3cc", null)
checkEvaluation(
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1)
checkEvaluation(
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
+ checkEvaluation(StringSplit(s1, s2), null, row2)
+ checkEvaluation(StringSplit(s1, s2), null, row3)
}
test("length for string / binary") {
http://git-wip-us.apache.org/repos/asf/spark/blob/e127ec34/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4261a5e..4e68a88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1423,7 +1423,8 @@ object functions {
def round(columnName: String): Column = round(Column(columnName), 0)
/**
- * Returns the value of `e` rounded to `scale` decimal places.
+ * Rounds the value of `e` to `scale` decimal places if `scale` >= 0,
+ * or at the integral part when `scale` < 0.
*
* @group math_funcs
* @since 1.5.0
@@ -1431,7 +1432,8 @@ object functions {
def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))
/**
- * Returns the value of the given column rounded to `scale` decimal places.
+ * Rounds the value of the given column to `scale` decimal places if `scale` >= 0,
+ * or at the integral part when `scale` < 0.
*
* @group math_funcs
* @since 1.5.0
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org