Posted to commits@spark.apache.org by we...@apache.org on 2018/05/22 13:08:54 UTC

spark git commit: [SPARK-24313][SQL] Fix collection operations' interpreted evaluation for complex types

Repository: spark
Updated Branches:
  refs/heads/master a4470bc78 -> d3d180731


[SPARK-24313][SQL] Fix collection operations' interpreted evaluation for complex types

## What changes were proposed in this pull request?

The interpreted evaluation of several collection operations works only for simple data types. For complex data types it is broken: `array_contains`, for instance, always returns `false`. The affected functions are `array_contains`, `array_position`, `element_at` and `GetMapValue`.

This PR fixes the behavior for all data types.
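
The root cause is that interpreted evaluation compared values with `==`, which for JVM arrays (the runtime representation of `BinaryType`) means reference equality. A minimal sketch of the problem and of the fix's approach, assuming a `spark-catalyst` dependency for the second half (`TypeUtils` is an internal catalyst API):

```scala
// Plain Scala: `==` on arrays is reference equality, so two Array[Byte]
// instances with identical contents never compare equal.
val a = Array[Byte](1, 2)
val b = Array[Byte](1, 2)
println(a == b)                         // false: different references
println(java.util.Arrays.equals(a, b))  // true: same contents

// The fix routes comparisons through the type-aware interpreted ordering
// instead of `==` (hedged sketch, not the exact patched code path):
import org.apache.spark.sql.types.BinaryType
import org.apache.spark.sql.catalyst.util.TypeUtils
val ordering = TypeUtils.getInterpretedOrdering(BinaryType)
println(ordering.equiv(a, b))           // true: content-based comparison
```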

## How was this patch tested?

Added unit tests.
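
As a hedged end-to-end illustration (mirroring the new `DataFrameSuite` test, not part of the patch itself), something like the following in a `spark-shell` built with this change should now return the value rather than `null`:

```scala
import org.apache.spark.sql.functions._

// Look up a map value using a binary key; interpreted evaluation of
// GetMapValue previously compared Array[Byte] keys by reference and
// therefore never found a match.
val result = spark.range(1)
  .select(map(lit(Array[Byte](1.toByte)), lit(1)).getItem(Array[Byte](1.toByte)))
result.show()  // expected: 1 with the fix, null before it
```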

Author: Marco Gaido <ma...@gmail.com>

Closes #21361 from mgaido91/SPARK-24313.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d3d18073
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d3d18073
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d3d18073

Branch: refs/heads/master
Commit: d3d18073152cab4408464d1417ec644d939cfdf7
Parents: a4470bc
Author: Marco Gaido <ma...@gmail.com>
Authored: Tue May 22 21:08:49 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue May 22 21:08:49 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 41 ++++++++++++----
 .../expressions/complexTypeExtractors.scala     | 19 ++++++--
 .../CollectionExpressionsSuite.scala            | 49 +++++++++++++++++++-
 .../catalyst/optimizer/complexTypesSuite.scala  | 13 ++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   |  5 ++
 5 files changed, 113 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d3d18073/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 8d763dc..7da4c3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -657,6 +657,9 @@ case class ArrayContains(left: Expression, right: Expression)
 
   override def dataType: DataType = BooleanType
 
+  @transient private lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(right.dataType)
+
   override def inputTypes: Seq[AbstractDataType] = right.dataType match {
     case NullType => Seq.empty
     case _ => left.dataType match {
@@ -673,7 +676,7 @@ case class ArrayContains(left: Expression, right: Expression)
       TypeCheckResult.TypeCheckFailure(
         "Arguments must be an array followed by a value of same type as the array members")
     } else {
-      TypeCheckResult.TypeCheckSuccess
+      TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
     }
   }
 
@@ -686,7 +689,7 @@ case class ArrayContains(left: Expression, right: Expression)
     arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
       if (v == null) {
         hasNull = true
-      } else if (v == value) {
+      } else if (ordering.equiv(v, value)) {
         return true
       }
     )
@@ -735,11 +738,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
 
   override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
     case TypeCheckResult.TypeCheckSuccess =>
-      if (RowOrdering.isOrderable(elementType)) {
-        TypeCheckResult.TypeCheckSuccess
-      } else {
-        TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.")
-      }
+      TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
     case failure => failure
   }
 
@@ -1391,13 +1390,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
 case class ArrayPosition(left: Expression, right: Expression)
   extends BinaryExpression with ImplicitCastInputTypes {
 
+  @transient private lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(right.dataType)
+
   override def dataType: DataType = LongType
   override def inputTypes: Seq[AbstractDataType] =
     Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    super.checkInputDataTypes() match {
+      case f: TypeCheckResult.TypeCheckFailure => f
+      case TypeCheckResult.TypeCheckSuccess =>
+        TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+    }
+  }
+
   override def nullSafeEval(arr: Any, value: Any): Any = {
     arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
-      if (v == value) {
+      if (v != null && ordering.equiv(v, value)) {
         return (i + 1).toLong
       }
     )
@@ -1446,6 +1456,9 @@ case class ArrayPosition(left: Expression, right: Expression)
   since = "2.4.0")
 case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
 
+  @transient private lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)
+
   override def dataType: DataType = left.dataType match {
     case ArrayType(elementType, _) => elementType
     case MapType(_, valueType, _) => valueType
@@ -1460,6 +1473,16 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
     )
   }
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    super.checkInputDataTypes() match {
+      case f: TypeCheckResult.TypeCheckFailure => f
+      case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
+        TypeUtils.checkForOrderingExpr(
+          left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName")
+      case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
   override def nullable: Boolean = true
 
   override def nullSafeEval(value: Any, ordinal: Any): Any = {
@@ -1484,7 +1507,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
           }
         }
       case _: MapType =>
-        getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
+        getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d3d18073/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 3fba52d..99671d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
-import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 
 abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
   // todo: current search is O(n), improve it.
-  def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
+  def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
     val map = value.asInstanceOf[MapData]
     val length = map.numElements()
     val keys = map.keyArray()
@@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
     var i = 0
     var found = false
     while (i < length && !found) {
-      if (keys.get(i, keyType) == ordinal) {
+      if (ordering.equiv(keys.get(i, keyType), ordinal)) {
         found = true
       } else {
         i += 1
@@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
 case class GetMapValue(child: Expression, key: Expression)
   extends GetMapValueUtil with ExtractValue with NullIntolerant {
 
+  @transient private lazy val ordering: Ordering[Any] =
+    TypeUtils.getInterpretedOrdering(keyType)
+
   private def keyType = child.dataType.asInstanceOf[MapType].keyType
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    super.checkInputDataTypes() match {
+      case f: TypeCheckResult.TypeCheckFailure => f
+      case TypeCheckResult.TypeCheckSuccess =>
+        TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName")
+    }
+  }
+
   // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
 
@@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression)
 
   // todo: current search is O(n), improve it.
   override def nullSafeEval(value: Any, ordinal: Any): Any = {
-    getValueEval(value, ordinal, keyType)
+    getValueEval(value, ordinal, keyType, ordering)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {

http://git-wip-us.apache.org/repos/asf/spark/blob/d3d18073/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 71ff96b..3fc0b08 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -157,6 +157,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     checkEvaluation(ArrayContains(a3, Literal("")), null)
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
+
+    // binary
+    val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
+      ArrayType(BinaryType))
+    val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
+      ArrayType(BinaryType))
+    val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
+      ArrayType(BinaryType))
+    val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)),
+      ArrayType(BinaryType))
+    val be = Literal.create(Array[Byte](1, 2), BinaryType)
+    val nullBinary = Literal.create(null, BinaryType)
+
+    checkEvaluation(ArrayContains(b0, be), true)
+    checkEvaluation(ArrayContains(b1, be), false)
+    checkEvaluation(ArrayContains(b0, nullBinary), null)
+    checkEvaluation(ArrayContains(b2, be), null)
+    checkEvaluation(ArrayContains(b3, be), true)
+
+    // complex data types
+    val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+      ArrayType(ArrayType(IntegerType)))
+    val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+      ArrayType(ArrayType(IntegerType)))
+    val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
+    checkEvaluation(ArrayContains(aa0, aae), true)
+    checkEvaluation(ArrayContains(aa1, aae), false)
   }
 
   test("ArraysOverlap") {
@@ -372,6 +399,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
     checkEvaluation(ArrayPosition(a3, Literal("")), null)
     checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
+
+    val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+      ArrayType(ArrayType(IntegerType)))
+    val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+      ArrayType(ArrayType(IntegerType)))
+    val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
+    checkEvaluation(ArrayPosition(aa0, aae), 1L)
+    checkEvaluation(ArrayPosition(aa1, aae), 0L)
   }
 
   test("elementAt") {
@@ -409,7 +444,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
     val m2 = Literal.create(null, MapType(StringType, StringType))
 
-    checkEvaluation(ElementAt(m0, Literal(1.0)), null)
+    assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure)
 
     checkEvaluation(ElementAt(m0, Literal("d")), null)
 
@@ -420,6 +455,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ElementAt(m0, Literal("c")), null)
 
     checkEvaluation(ElementAt(m2, Literal("a")), null)
+
+    // test binary type as keys
+    val mb0 = Literal.create(
+      Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
+      MapType(BinaryType, StringType))
+    val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))
+
+    checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null)
+
+    checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null)
+    checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
+    checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
   }
 
   test("Concat") {

http://git-wip-us.apache.org/repos/asf/spark/blob/d3d18073/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index 633d86d..5452e72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
         .select('c as 'sCol2, 'a as 'sCol1)
     checkRule(originalQuery, correctAnswer)
   }
+
+  test("SPARK-24313: support binary type as map keys in GetMapValue") {
+    val mb0 = Literal.create(
+      Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
+      MapType(BinaryType, StringType))
+    val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))
+
+    checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null)
+
+    checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null)
+    checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
+    checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d3d18073/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 60e84e6..1cc8cb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2265,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.range(1).select($"id", new Column(Uuid()))
     checkAnswer(df, df.collect())
   }
+
+  test("SPARK-24313: access map with binary keys") {
+    val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
+    checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
+  }
 }

