Posted to commits@spark.apache.org by li...@apache.org on 2017/07/08 06:05:41 UTC

spark git commit: [SPARK-21281][SQL] Use string types by default if array and map have no argument

Repository: spark
Updated Branches:
  refs/heads/master e1a172c20 -> 7896e7b99


[SPARK-21281][SQL] Use string types by default if array and map have no argument

## What changes were proposed in this pull request?
This PR modifies the code to use string types by default if the `array` and `map` functions have no arguments. This behaviour is the same as Hive's:
```
hive> CREATE TEMPORARY TABLE t1 AS SELECT map();
hive> DESCRIBE t1;
_c0   map<string,string>

hive> CREATE TEMPORARY TABLE t2 AS SELECT array();
hive> DESCRIBE t2;
_c0   array<string>
```
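
For reference, the resulting Spark-side schemas match the tests added below in `DataFrameFunctionsSuite`; a minimal Scala sketch, assuming a running `SparkSession` named `spark`:
```
// With this change, empty array()/map() default to string element types
// instead of NullType (schema comments mirror the expectations in the
// added DataFrameFunctionsSuite test).
import org.apache.spark.sql.functions.{array, map}

val df = spark.range(1).select(array().as("a"), map().as("m"))
df.printSchema()
// root
//  |-- a: array (nullable = false)
//  |    |-- element: string (containsNull = false)
//  |-- m: map (nullable = false)
//  |    |-- key: string
//  |    |-- value: string (valueContainsNull = false)
```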

## How was this patch tested?
Added tests in `DataFrameFunctionsSuite`.
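
The patch also rewords the zero-argument error messages and moves the `greatest`/`least` checks from call-time `require`s in `functions.scala` into analysis; a minimal sketch of the resulting behaviour, again assuming a `SparkSession` named `spark`:
```
// Zero-argument calls now raise an AnalysisException during analysis;
// the messages come from the new checkInputDataTypes checks.
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.functions.{coalesce, greatest}

val df = spark.range(1).toDF("a")

try {
  df.select(coalesce())  // "input to function coalesce requires at least one argument"
} catch {
  case e: AnalysisException => println(e.getMessage)
}

try {
  df.select(greatest())  // "input to function greatest requires at least two arguments"
} catch {
  case e: AnalysisException => println(e.getMessage)
}
```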

Author: Takeshi Yamamuro <ya...@apache.org>

Closes #18516 from maropu/SPARK-21281.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7896e7b9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7896e7b9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7896e7b9

Branch: refs/heads/master
Commit: 7896e7b99d95d28800f5644bd36b3990cf0ef8c4
Parents: e1a172c
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Fri Jul 7 23:05:38 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Fri Jul 7 23:05:38 2017 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 10 +++---
 .../expressions/complexTypeCreator.scala        | 35 +++++++++++--------
 .../spark/sql/catalyst/expressions/hash.scala   |  5 +--
 .../catalyst/expressions/nullExpressions.scala  |  7 ++--
 .../analysis/ExpressionTypeCheckingSuite.scala  |  4 +--
 .../scala/org/apache/spark/sql/functions.scala  | 10 ++----
 .../spark/sql/DataFrameFunctionsSuite.scala     | 36 ++++++++++++++++++++
 7 files changed, 74 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index ec6e6ba..423bf66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -527,13 +527,14 @@ case class Least(children: Seq[Expression]) extends Expression {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.length <= 1) {
-      TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments")
+      TypeCheckResult.TypeCheckFailure(
+        s"input to function $prettyName requires at least two arguments")
     } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
       TypeCheckResult.TypeCheckFailure(
         s"The expressions should all have the same type," +
           s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
     } else {
-      TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
     }
   }
 
@@ -592,13 +593,14 @@ case class Greatest(children: Seq[Expression]) extends Expression {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.length <= 1) {
-      TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments")
+      TypeCheckResult.TypeCheckFailure(
+        s"input to function $prettyName requires at least two arguments")
     } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
       TypeCheckResult.TypeCheckFailure(
         s"The expressions should all have the same type," +
           s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
     } else {
-      TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 98c4cbe..d9eeb53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -41,12 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
 
   override def foldable: Boolean = children.forall(_.foldable)
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
+  }
 
   override def dataType: ArrayType = {
     ArrayType(
-      children.headOption.map(_.dataType).getOrElse(NullType),
+      children.headOption.map(_.dataType).getOrElse(StringType),
       containsNull = children.exists(_.nullable))
   }
 
@@ -93,7 +94,7 @@ private [sql] object GenArrayData {
     if (!ctx.isPrimitiveType(elementType)) {
       val genericArrayClass = classOf[GenericArrayData].getName
       ctx.addMutableState("Object[]", arrayName,
-        s"$arrayName = new Object[${numElements}];")
+        s"$arrayName = new Object[$numElements];")
 
       val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
         val isNullAssignment = if (!isMapKey) {
@@ -119,7 +120,7 @@ private [sql] object GenArrayData {
         UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
         ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
       val baseOffset = Platform.BYTE_ARRAY_OFFSET
-      ctx.addMutableState("UnsafeArrayData", arrayDataName, "");
+      ctx.addMutableState("UnsafeArrayData", arrayDataName, "")
 
       val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
       val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
@@ -169,13 +170,16 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size % 2 != 0) {
-      TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.")
+      TypeCheckResult.TypeCheckFailure(
+        s"$prettyName expects a positive even number of arguments.")
     } else if (keys.map(_.dataType).distinct.length > 1) {
-      TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
-        "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+      TypeCheckResult.TypeCheckFailure(
+        "The given keys of function map should all be the same type, but they are " +
+          keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
     } else if (values.map(_.dataType).distinct.length > 1) {
-      TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
-        "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+      TypeCheckResult.TypeCheckFailure(
+        "The given values of function map should all be the same type, but they are " +
+          values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
@@ -183,8 +187,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
 
   override def dataType: DataType = {
     MapType(
-      keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
-      valueType = values.headOption.map(_.dataType).getOrElse(NullType),
+      keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
+      valueType = values.headOption.map(_.dataType).getOrElse(StringType),
       valueContainsNull = values.exists(_.nullable))
   }
 
@@ -292,14 +296,17 @@ trait CreateNamedStructLike extends Expression {
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (children.size % 2 != 0) {
+    if (children.length < 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to function $prettyName requires at least one argument")
+    } else if (children.size % 2 != 0) {
       TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
     } else {
       val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
       if (invalidNames.nonEmpty) {
         TypeCheckResult.TypeCheckFailure(
           "Only foldable StringType expressions are allowed to appear at odd position, got:" +
-          s" ${invalidNames.mkString(",")}")
+            s" ${invalidNames.mkString(",")}")
       } else if (!names.contains(null)) {
         TypeCheckResult.TypeCheckSuccess
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index ffd0e64..2476fc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -247,8 +247,9 @@ abstract class HashExpression[E] extends Expression {
   override def nullable: Boolean = false
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (children.isEmpty) {
-      TypeCheckResult.TypeCheckFailure("function hash requires at least one argument")
+    if (children.length < 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to function $prettyName requires at least one argument")
     } else {
       TypeCheckResult.TypeCheckSuccess
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 0866b8d..1b62514 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -52,10 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    if (children == Nil) {
-      TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
+    if (children.length < 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to function $prettyName requires at least one argument")
     } else {
-      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
+      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 30459f1..3072577 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       "input to function array should all be the same type")
     assertError(Coalesce(Seq('intField, 'booleanField)),
       "input to function coalesce should all be the same type")
-    assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
+    assertError(Coalesce(Nil), "function coalesce requires at least one argument")
     assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
     assertError(Explode('intField),
       "input to function explode should be array or map type")
@@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
   test("check types for Greatest/Least") {
     for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
-      assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
+      assertError(operator(Seq('booleanField)), "requires at least two arguments")
       assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
       assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3c67960..1263071 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1565,10 +1565,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def greatest(exprs: Column*): Column = withExpr {
-    require(exprs.length > 1, "greatest requires at least 2 arguments.")
-    Greatest(exprs.map(_.expr))
-  }
+  def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) }
 
   /**
    * Returns the greatest value of the list of column names, skipping null values.
@@ -1672,10 +1669,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def least(exprs: Column*): Column = withExpr {
-    require(exprs.length > 1, "least requires at least 2 arguments.")
-    Least(exprs.map(_.expr))
-  }
+  def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) }
 
   /**
    * Returns the least value of the list of column names, skipping null values.

http://git-wip-us.apache.org/repos/asf/spark/blob/7896e7b9/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0e9a2c6..0681b9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       rand(Random.nextLong()), randn(Random.nextLong())
     ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
   }
+
+  test("SPARK-21281 use string types by default if array and map have no argument") {
+    val ds = spark.range(1)
+    var expectedSchema = new StructType()
+      .add("x", ArrayType(StringType, containsNull = false), nullable = false)
+    assert(ds.select(array().as("x")).schema == expectedSchema)
+    expectedSchema = new StructType()
+      .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false)
+    assert(ds.select(map().as("x")).schema == expectedSchema)
+  }
+
+  test("SPARK-21281 fails if functions have no argument") {
+    val df = Seq(1).toDF("a")
+
+    val funcsMustHaveAtLeastOneArg =
+      ("coalesce", (df: DataFrame) => df.select(coalesce())) ::
+      ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) ::
+      ("named_struct", (df: DataFrame) => df.select(struct())) ::
+      ("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) ::
+      ("hash", (df: DataFrame) => df.select(hash())) ::
+      ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil
+    funcsMustHaveAtLeastOneArg.foreach { case (name, func) =>
+      val errMsg = intercept[AnalysisException] { func(df) }.getMessage
+      assert(errMsg.contains(s"input to function $name requires at least one argument"))
+    }
+
+    val funcsMustHaveAtLeastTwoArgs =
+      ("greatest", (df: DataFrame) => df.select(greatest())) ::
+      ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) ::
+      ("least", (df: DataFrame) => df.select(least())) ::
+      ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil
+    funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) =>
+      val errMsg = intercept[AnalysisException] { func(df) }.getMessage
+      assert(errMsg.contains(s"input to function $name requires at least two arguments"))
+    }
+  }
 }
 
 object DataFrameFunctionsSuite {

