You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/09/24 13:37:57 UTC

spark git commit: [SPARK-25416][SQL] ArrayPosition function may return incorrect result when right expression is implicitly down casted

Repository: spark
Updated Branches:
  refs/heads/master 804515f82 -> bb49661e1


[SPARK-25416][SQL] ArrayPosition function may return incorrect result when right expression is implicitly down casted

## What changes were proposed in this pull request?
In ArrayPosition, we currently cast the right-hand side expression to match the element type of the left-hand side array. This may result in down casting and may return a wrong or questionable result.

Examples:
```SQL
spark-sql> select array_position(array(1), 1.34);
1
```
```SQL
spark-sql> select array_position(array(1), 'foo');
null
```

Instead, we should safely coerce both the left-hand and right-hand side expressions to their tightest common type.
## How was this patch tested?
Added tests in DataFrameFunctionsSuite

Closes #22407 from dilipbiswal/SPARK-25416.

Authored-by: Dilip Biswal <db...@us.ibm.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bb49661e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bb49661e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bb49661e

Branch: refs/heads/master
Commit: bb49661e192eed78a8a306deffd83c73bd4a9eff
Parents: 804515f
Author: Dilip Biswal <db...@us.ibm.com>
Authored: Mon Sep 24 21:37:51 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Mon Sep 24 21:37:51 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 21 +++++---
 .../spark/sql/DataFrameFunctionsSuite.scala     | 57 +++++++++++++++++---
 2 files changed, 64 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bb49661e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 161adc9..85bc1cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2071,18 +2071,23 @@ case class ArrayPosition(left: Expression, right: Expression)
   override def dataType: DataType = LongType
 
   override def inputTypes: Seq[AbstractDataType] = {
-    val elementType = left.dataType match {
-      case t: ArrayType => t.elementType
-      case _ => AnyDataType
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
     }
-    Seq(ArrayType, elementType)
   }
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case f: TypeCheckResult.TypeCheckFailure => f
-      case TypeCheckResult.TypeCheckSuccess =>
-        TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+        TypeUtils.checkForOrderingExpr(e2, s"function $prettyName")
+      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
+        s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " +
+        s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bb49661e/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index ad52fd0..fd71f24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1097,18 +1097,63 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
 
     checkAnswer(
-      df.selectExpr("array_position(array(array(1), null)[0], 1)"),
-      Seq(Row(1L), Row(1L))
+      OneRowRelation().selectExpr("array_position(array(1), 1.23D)"),
+      Seq(Row(0L))
     )
+
     checkAnswer(
-      df.selectExpr("array_position(array(1, null), array(1, null)[0])"),
-      Seq(Row(1L), Row(1L))
+      OneRowRelation().selectExpr("array_position(array(1), 1.0D)"),
+      Seq(Row(1L))
     )
 
-    val e = intercept[AnalysisException] {
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1.D), 1)"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1.23D), 1)"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1)), array(1.0D))"),
+      Seq(Row(1L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1)), array(1.23D))"),
+      Seq(Row(0L))
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(array(1), null)[0], 1)"),
+      Seq(Row(1L))
+    )
+    checkAnswer(
+      OneRowRelation().selectExpr("array_position(array(1, null), array(1, null)[0])"),
+      Seq(Row(1L))
+    )
+
+    val e1 = intercept[AnalysisException] {
       Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)")
     }
-    assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type"))
+    val errorMsg1 =
+      s"""
+         |Input to function array_position should have been array followed by a
+         |value with same element type, but it's [string, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e1.message.contains(errorMsg1))
+
+    val e2 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("array_position(array(1), '1')")
+    }
+    val errorMsg2 =
+      s"""
+         |Input to function array_position should have been array followed by a
+         |value with same element type, but it's [array<int>, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e2.message.contains(errorMsg2))
   }
 
   test("element_at function") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org