You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/07/08 06:05:41 UTC
spark git commit: [SPARK-21281][SQL] Use string types by default if
array and map have no argument
Repository: spark
Updated Branches:
refs/heads/master e1a172c20 -> 7896e7b99
[SPARK-21281][SQL] Use string types by default if array and map have no argument
## What changes were proposed in this pull request?
This PR modifies the code to use string types by default if `array` and `map` in functions have no arguments. This behaviour is the same as Hive's:
```
hive> CREATE TEMPORARY TABLE t1 AS SELECT map();
hive> DESCRIBE t1;
_c0 map<string,string>
hive> CREATE TEMPORARY TABLE t2 AS SELECT array();
hive> DESCRIBE t2;
_c0 array<string>
```
## How was this patch tested?
Added tests in `DataFrameFunctionsSuite`.
Author: Takeshi Yamamuro <ya...@apache.org>
Closes #18516 from maropu/SPARK-21281.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7896e7b9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7896e7b9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7896e7b9
Branch: refs/heads/master
Commit: 7896e7b99d95d28800f5644bd36b3990cf0ef8c4
Parents: e1a172c
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Fri Jul 7 23:05:38 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Fri Jul 7 23:05:38 2017 -0700
----------------------------------------------------------------------
.../sql/catalyst/expressions/arithmetic.scala | 10 +++---
.../expressions/complexTypeCreator.scala | 35 +++++++++++--------
.../spark/sql/catalyst/expressions/hash.scala | 5 +--
.../catalyst/expressions/nullExpressions.scala | 7 ++--
.../analysis/ExpressionTypeCheckingSuite.scala | 4 +--
.../scala/org/apache/spark/sql/functions.scala | 10 ++----
.../spark/sql/DataFrameFunctionsSuite.scala | 36 ++++++++++++++++++++
7 files changed, 74 insertions(+), 33 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index ec6e6ba..423bf66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -527,13 +527,14 @@ case class Least(children: Seq[Expression]) extends Expression {
override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
- TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments")
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
} else {
- TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+ TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
}
}
@@ -592,13 +593,14 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
- TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments")
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
} else {
- TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+ TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 98c4cbe..d9eeb53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -41,12 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
+ override def checkInputDataTypes(): TypeCheckResult = {
+ TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
+ }
override def dataType: ArrayType = {
ArrayType(
- children.headOption.map(_.dataType).getOrElse(NullType),
+ children.headOption.map(_.dataType).getOrElse(StringType),
containsNull = children.exists(_.nullable))
}
@@ -93,7 +94,7 @@ private [sql] object GenArrayData {
if (!ctx.isPrimitiveType(elementType)) {
val genericArrayClass = classOf[GenericArrayData].getName
ctx.addMutableState("Object[]", arrayName,
- s"$arrayName = new Object[${numElements}];")
+ s"$arrayName = new Object[$numElements];")
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
@@ -119,7 +120,7 @@ private [sql] object GenArrayData {
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
- ctx.addMutableState("UnsafeArrayData", arrayDataName, "");
+ ctx.addMutableState("UnsafeArrayData", arrayDataName, "")
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
@@ -169,13 +170,16 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
- TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.")
+ TypeCheckResult.TypeCheckFailure(
+ s"$prettyName expects a positive even number of arguments.")
} else if (keys.map(_.dataType).distinct.length > 1) {
- TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
- "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+ TypeCheckResult.TypeCheckFailure(
+ "The given keys of function map should all be the same type, but they are " +
+ keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else if (values.map(_.dataType).distinct.length > 1) {
- TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
- "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+ TypeCheckResult.TypeCheckFailure(
+ "The given values of function map should all be the same type, but they are " +
+ values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else {
TypeCheckResult.TypeCheckSuccess
}
@@ -183,8 +187,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
override def dataType: DataType = {
MapType(
- keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
- valueType = values.headOption.map(_.dataType).getOrElse(NullType),
+ keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
+ valueType = values.headOption.map(_.dataType).getOrElse(StringType),
valueContainsNull = values.exists(_.nullable))
}
@@ -292,14 +296,17 @@ trait CreateNamedStructLike extends Expression {
}
override def checkInputDataTypes(): TypeCheckResult = {
- if (children.size % 2 != 0) {
+ if (children.length < 1) {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName requires at least one argument")
+ } else if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
} else {
val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
if (invalidNames.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
"Only foldable StringType expressions are allowed to appear at odd position, got:" +
- s" ${invalidNames.mkString(",")}")
+ s" ${invalidNames.mkString(",")}")
} else if (!names.contains(null)) {
TypeCheckResult.TypeCheckSuccess
} else {
http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index ffd0e64..2476fc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -247,8 +247,9 @@ abstract class HashExpression[E] extends Expression {
override def nullable: Boolean = false
override def checkInputDataTypes(): TypeCheckResult = {
- if (children.isEmpty) {
- TypeCheckResult.TypeCheckFailure("function hash requires at least one argument")
+ if (children.length < 1) {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName requires at least one argument")
} else {
TypeCheckResult.TypeCheckSuccess
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 0866b8d..1b62514 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -52,10 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
override def checkInputDataTypes(): TypeCheckResult = {
- if (children == Nil) {
- TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
+ if (children.length < 1) {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName requires at least one argument")
} else {
- TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
+ TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 30459f1..3072577 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
"input to function array should all be the same type")
assertError(Coalesce(Seq('intField, 'booleanField)),
"input to function coalesce should all be the same type")
- assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
+ assertError(Coalesce(Nil), "function coalesce requires at least one argument")
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
assertError(Explode('intField),
"input to function explode should be array or map type")
@@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
test("check types for Greatest/Least") {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
- assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
+ assertError(operator(Seq('booleanField)), "requires at least two arguments")
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3c67960..1263071 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1565,10 +1565,7 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
- def greatest(exprs: Column*): Column = withExpr {
- require(exprs.length > 1, "greatest requires at least 2 arguments.")
- Greatest(exprs.map(_.expr))
- }
+ def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) }
/**
* Returns the greatest value of the list of column names, skipping null values.
@@ -1672,10 +1669,7 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
- def least(exprs: Column*): Column = withExpr {
- require(exprs.length > 1, "least requires at least 2 arguments.")
- Least(exprs.map(_.expr))
- }
+ def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) }
/**
* Returns the least value of the list of column names, skipping null values.
http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0e9a2c6..0681b9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
rand(Random.nextLong()), randn(Random.nextLong())
).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
}
+
+ test("SPARK-21281 use string types by default if array and map have no argument") {
+ val ds = spark.range(1)
+ var expectedSchema = new StructType()
+ .add("x", ArrayType(StringType, containsNull = false), nullable = false)
+ assert(ds.select(array().as("x")).schema == expectedSchema)
+ expectedSchema = new StructType()
+ .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false)
+ assert(ds.select(map().as("x")).schema == expectedSchema)
+ }
+
+ test("SPARK-21281 fails if functions have no argument") {
+ val df = Seq(1).toDF("a")
+
+ val funcsMustHaveAtLeastOneArg =
+ ("coalesce", (df: DataFrame) => df.select(coalesce())) ::
+ ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) ::
+ ("named_struct", (df: DataFrame) => df.select(struct())) ::
+ ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) ::
+ ("hash", (df: DataFrame) => df.select(hash())) ::
+ ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil
+ funcsMustHaveAtLeastOneArg.foreach { case (name, func) =>
+ val errMsg = intercept[AnalysisException] { func(df) }.getMessage
+ assert(errMsg.contains(s"input to function $name requires at least one argument"))
+ }
+
+ val funcsMustHaveAtLeastTwoArgs =
+ ("greatest", (df: DataFrame) => df.select(greatest())) ::
+ ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) ::
+ ("least", (df: DataFrame) => df.select(least())) ::
+ ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil
+ funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) =>
+ val errMsg = intercept[AnalysisException] { func(df) }.getMessage
+ assert(errMsg.contains(s"input to function $name requires at least two arguments"))
+ }
+ }
}
object DataFrameFunctionsSuite {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org