Posted to commits@spark.apache.org by we...@apache.org on 2018/09/27 07:04:54 UTC

spark git commit: [SPARK-25522][SQL] Improve type promotion for input arguments of elementAt function

Repository: spark
Updated Branches:
  refs/heads/master ff876137f -> d03e0af80


[SPARK-25522][SQL] Improve type promotion for input arguments of elementAt function

## What changes were proposed in this pull request?
In ElementAt, when the first argument is of MapType, we should coerce the map's key type and the second argument based on findTightestCommonType. This is not happening currently, so we may produce wrong output: the right-hand-side double expression is incorrectly downcast to int.

```SQL
spark-sql> select element_at(map(1,"one", 2, "two"), 2.2);

two
```
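
After this change, the map's key type and the lookup key are promoted to their tightest common type, so a double key is compared as a double instead of being downcast to int. A sketch of the expected post-fix behavior, based on the tests added below (the decimal literal 2.2 above has no tightest common type with int and should now be rejected at analysis time):

```SQL
spark-sql> select element_at(map(1,"one", 2, "two"), 2.0D);

two
spark-sql> select element_at(map(1,"one", 2, "two"), 2.2D);

NULL
```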

Also, when the first argument is of ArrayType, the second argument should be an integer type, or a smaller integral type that can be safely cast to an integer type. Currently we may perform an unsafe cast. In the following case we should fail with an error, as 1.24 is not an integer index; instead, we currently downcast it to int and return a result.

```SQL
spark-sql> select element_at(array(1,2), 1.24D);

1
```
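
Smaller integral index types remain valid, since they can be safely widened to int; a sketch per the tests added below:

```SQL
spark-sql> select element_at(array(1,2), 1Y);

1
spark-sql> select element_at(array(1,2), 2S);

2
```
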
This PR also supports implicit casts between two MapTypes, following the same logic that exists today for implicit casts between two array types.
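
For example, per the tests added below, a map<int,string> argument is implicitly cast to map<bigint,string> when the lookup key is a bigint:

```SQL
spark-sql> select element_at(map(1,"one", 2, "two"), 2L);

two
```
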
## How was this patch tested?
Added new tests in DataFrameFunctionsSuite and TypeCoercionSuite.

Closes #22544 from dilipbiswal/SPARK-25522.

Authored-by: Dilip Biswal <db...@us.ibm.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d03e0af8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d03e0af8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d03e0af8

Branch: refs/heads/master
Commit: d03e0af80d7659f12821cc2442efaeaee94d3985
Parents: ff87613
Author: Dilip Biswal <db...@us.ibm.com>
Authored: Thu Sep 27 15:04:59 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Sep 27 15:04:59 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/analysis/TypeCoercion.scala    | 19 +++++
 .../spark/sql/catalyst/expressions/Cast.scala   |  2 +-
 .../expressions/collectionOperations.scala      | 37 ++++++----
 .../catalyst/analysis/TypeCoercionSuite.scala   | 43 +++++++++--
 .../spark/sql/DataFrameFunctionsSuite.scala     | 75 +++++++++++++++++++-
 5 files changed, 154 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 49d286f..72ac80e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -950,6 +950,25 @@ object TypeCoercion {
             if !Cast.forceNullable(fromType, toType) =>
           implicitCast(fromType, toType).map(ArrayType(_, false)).orNull
 
+        // Implicit cast between Map types.
+        // Follows the same semantics as implicit casting between two array types;
+        // refer to the documentation above. Make sure that neither the key nor the
+        // value can become nullable after the implicit cast, by calling the
+        // forceNullable method.
+        case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn))
+            if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) =>
+          if (Cast.forceNullable(fromValueType, toValueType) && !tn) {
+            null
+          } else {
+            val newKeyType = implicitCast(fromKeyType, toKeyType).orNull
+            val newValueType = implicitCast(fromValueType, toValueType).orNull
+            if (newKeyType != null && newValueType != null) {
+              MapType(newKeyType, newValueType, tn)
+            } else {
+              null
+            }
+          }
+
         case _ => null
       }
       Option(ret)

http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8f77799..ee463bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -183,7 +183,7 @@ object Cast {
     case _ => false
   }
 
-  private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
+  def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 9cc7dba..b24d748 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2154,21 +2154,34 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
   }
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(TypeCollection(ArrayType, MapType),
-      left.dataType match {
-        case _: ArrayType => IntegerType
-        case _: MapType => mapKeyType
-        case _ => AnyDataType // no match for a wrong 'left' expression type
-      }
-    )
+    (left.dataType, right.dataType) match {
+      case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) =>
+        Seq(arr, IntegerType)
+      case (MapType(keyType, valueType, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(keyType, e2) match {
+          case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case (l, r) => Seq.empty
+
+    }
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case f: TypeCheckResult.TypeCheckFailure => f
-      case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
-        TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName")
-      case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
+    (left.dataType, right.dataType) match {
+      case (_: ArrayType, e2) if e2 != IntegerType =>
+        TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+          s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " +
+          s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+      case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) =>
+        TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+          s"been ${MapType.simpleString} followed by a value of same key type, but it's " +
+          s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+      case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) =>
+        TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " +
+          s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " +
+          s"${left.dataType.catalogString} type.")
+      case _ => TypeCheckResult.TypeCheckSuccess
     }
   }
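
For reference, a sketch of how the tightened checks should surface in spark-sql; the message text is taken verbatim from the error strings above, though the surrounding error prefix may differ:

```SQL
spark-sql> select element_at(array("a", "b"), 1L);
-- fails analysis: Input to function element_at should have been array
-- followed by a int, but it's [array<string>, bigint].

spark-sql> select element_at(map(1, "one", 2, "two"), "1");
-- fails analysis: Input to function element_at should have been map
-- followed by a value of same key type, but it's [map<int,string>, string].
```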
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 0594673..0eba1c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -257,12 +257,43 @@ class TypeCoercionSuite extends AnalysisTest {
     shouldNotCast(checkedType, IntegralType)
   }
 
-  test("implicit type cast - MapType(StringType, StringType)") {
-    val checkedType = MapType(StringType, StringType)
-    checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
-    shouldNotCast(checkedType, DecimalType)
-    shouldNotCast(checkedType, NumericType)
-    shouldNotCast(checkedType, IntegralType)
+  test("implicit type cast between two Map types") {
+    val sourceType = MapType(IntegerType, IntegerType, true)
+    val castableTypes = numericTypes ++ Seq(StringType).filter(!Cast.forceNullable(IntegerType, _))
+    val targetTypes = numericTypes.filter(!Cast.forceNullable(IntegerType, _)).map { t =>
+      MapType(t, sourceType.valueType, valueContainsNull = true)
+    }
+    val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map {t =>
+      MapType(t, sourceType.valueType, valueContainsNull = true)
+    }
+
+    // Tests that it's possible to set up implicit casts between two map types when
+    // the source map's key type is integer and the target map's key type is Byte, Short,
+    // Long, Double, Float, Decimal(38, 18) or String.
+    targetTypes.foreach { targetType =>
+      shouldCast(sourceType, targetType, targetType)
+    }
+
+    // Tests that it's not possible to set up implicit casts between two map types when
+    // the source map's key type is integer and the target map's key type is Binary,
+    // Boolean, Date, Timestamp, Array, Struct, CalendarIntervalType or NullType.
+    nonCastableTargetTypes.foreach { targetType =>
+      shouldNotCast(sourceType, targetType)
+    }
+
+    // Tests that it's not possible to cast from a nullable map type to a non-nullable map type.
+    val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t =>
+      MapType(t, sourceType.valueType, valueContainsNull = false)
+    }
+    val sourceMapExprWithValueNull =
+      CreateMap(Seq(Literal.default(sourceType.keyType),
+        Literal.create(null, sourceType.valueType)))
+    targetNotNullableTypes.foreach { targetType =>
+      val castDefault =
+        TypeCoercion.ImplicitTypeCasts.implicitCast(sourceMapExprWithValueNull, targetType)
+      assert(castDefault.isEmpty,
+        s"Should not be able to cast $sourceType to $targetType, but got $castDefault")
+    }
   }
 
   test("implicit type cast - StructType().add(\"a1\", StringType)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/d03e0af8/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 88dbae8..60ebc5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1211,11 +1211,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row("3"), Row(""), Row(null))
     )
 
-    val e = intercept[AnalysisException] {
+    val e1 = intercept[AnalysisException] {
       Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)")
     }
-    assert(e.message.contains(
-      "argument 1 requires (array or map) type, however, '`_1`' is of string type"))
+    val errorMsg1 =
+      s"""
+         |The first argument to function element_at should have been array or map type, but
+         |its string type.
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e1.message.contains(errorMsg1))
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(array(2, 1), 2S)"),
+      Seq(Row(1))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(array('a', 'b'), 1Y)"),
+      Seq(Row("a"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(array(1, 2, 3), 3)"),
+      Seq(Row(3))
+    )
+
+    val e2 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("element_at(array('a', 'b'), 1L)")
+    }
+    val errorMsg2 =
+      s"""
+         |Input to function element_at should have been array followed by a int, but it's
+         |[array<string>, bigint].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e2.message.contains(errorMsg2))
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2Y)"),
+      Seq(Row("b"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1S)"),
+      Seq(Row("a"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2)"),
+      Seq(Row("b"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 2L)"),
+      Seq(Row("b"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.0D)"),
+      Seq(Row("a"))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), 1.23D)"),
+      Seq(Row(null))
+    )
+
+    val e3 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("element_at(map(1, 'a', 2, 'b'), '1')")
+    }
+    val errorMsg3 =
+      s"""
+         |Input to function element_at should have been map followed by a value of same
+         |key type, but it's [map<int,string>, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e3.message.contains(errorMsg3))
   }
 
   test("array_union functions") {

