You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2016/11/23 12:15:25 UTC

spark git commit: [SPARK-18053][SQL] compare unsafe and safe complex-type values correctly

Repository: spark
Updated Branches:
  refs/heads/master 85235ed6c -> 84284e8c8


[SPARK-18053][SQL] compare unsafe and safe complex-type values correctly

## What changes were proposed in this pull request?

In Spark SQL, some expressions may output safe-format values, e.g. `CreateArray`, `CreateStruct`, `Cast`, etc. When we compare two values, we should be able to compare safe and unsafe formats.

`GreaterThan`, `LessThan`, etc. in Spark SQL already handle this, but `EqualTo` doesn't. This PR fixes it.

## How was this patch tested?

Added a new unit test (in `PredicateSuite`) and a regression test (in `SQLQuerySuite`).

Author: Wenchen Fan <we...@databricks.com>

Closes #15929 from cloud-fan/type-aware.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84284e8c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84284e8c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84284e8c

Branch: refs/heads/master
Commit: 84284e8c82542d80dad94e458a0c0210bf803db3
Parents: 85235ed
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Nov 23 04:15:19 2016 -0800
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Wed Nov 23 04:15:19 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/UnsafeRow.java     |  6 +---
 .../expressions/codegen/CodeGenerator.scala     | 20 ++++++++++--
 .../sql/catalyst/expressions/predicates.scala   | 32 +++-----------------
 .../catalyst/expressions/PredicateSuite.scala   | 29 ++++++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  7 +++++
 5 files changed, 59 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index c3f0aba..d205547 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -578,12 +578,8 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
       return (sizeInBytes == o.sizeInBytes) &&
         ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
           sizeInBytes);
-    } else if (!(other instanceof InternalRow)) {
-      return false;
-    } else {
-      throw new IllegalArgumentException(
-        "Cannot compare UnsafeRow to " + other.getClass().getName());
     }
+    return false;
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 9c3c6d3..09007b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -481,8 +481,13 @@ class CodegenContext {
     case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
     case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
     case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
+    case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
+    case array: ArrayType => genComp(array, c1, c2) + " == 0"
+    case struct: StructType => genComp(struct, c1, c2) + " == 0"
     case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
-    case other => s"$c1.equals($c2)"
+    case _ =>
+      throw new IllegalArgumentException(
+        "cannot generate equality code for un-comparable type: " + dataType.simpleString)
   }
 
   /**
@@ -512,6 +517,11 @@ class CodegenContext {
       val funcCode: String =
         s"""
           public int $compareFunc(ArrayData a, ArrayData b) {
+            // when comparing unsafe arrays, try equals first as it compares the binary directly
+            // which is very fast.
+            if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) {
+              return 0;
+            }
             int lengthA = a.numElements();
             int lengthB = b.numElements();
             int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
@@ -551,6 +561,11 @@ class CodegenContext {
       val funcCode: String =
         s"""
           public int $compareFunc(InternalRow a, InternalRow b) {
+            // when comparing unsafe rows, try equals first as it compares the binary directly
+            // which is very fast.
+            if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) {
+              return 0;
+            }
             InternalRow i = null;
             $comparisons
             return 0;
@@ -561,7 +576,8 @@ class CodegenContext {
     case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
     case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
     case _ =>
-      throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
+      throw new IllegalArgumentException(
+        "cannot generate compare code for un-comparable type: " + dataType.simpleString)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2ad452b..3fcbb05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -388,6 +388,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
       defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
     }
   }
+
+  protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
 }
 
 
@@ -429,17 +431,7 @@ case class EqualTo(left: Expression, right: Expression)
 
   override def symbol: String = "="
 
-  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
-    if (left.dataType == FloatType) {
-      Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
-    } else if (left.dataType == DoubleType) {
-      Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
-    } else if (left.dataType != BinaryType) {
-      input1 == input2
-    } else {
-      java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
-    }
-  }
+  protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
@@ -482,15 +474,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
     } else if (input1 == null || input2 == null) {
       false
     } else {
-      if (left.dataType == FloatType) {
-        Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
-      } else if (left.dataType == DoubleType) {
-        Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
-      } else if (left.dataType != BinaryType) {
-        input1 == input2
-      } else {
-        java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
-      }
+      ordering.equiv(input1, input2)
     }
   }
 
@@ -513,8 +497,6 @@ case class LessThan(left: Expression, right: Expression)
 
   override def symbol: String = "<"
 
-  private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
-
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
 }
 
@@ -527,8 +509,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)
 
   override def symbol: String = "<="
 
-  private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
-
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
 }
 
@@ -541,8 +521,6 @@ case class GreaterThan(left: Expression, right: Expression)
 
   override def symbol: String = ">"
 
-  private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
-
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
 }
 
@@ -555,7 +533,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)
 
   override def symbol: String = ">="
 
-  private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
-
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 2a445b8..f9f6799 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -21,6 +21,8 @@ import scala.collection.immutable.HashSet
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
 
 
@@ -293,4 +295,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
     checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
   }
+
+  test("EqualTo on complex type") {
+    val array = new GenericArrayData(Array(1, 2, 3))
+    val struct = create_row("a", 1L, array)
+
+    val arrayType = ArrayType(IntegerType)
+    val structType = new StructType()
+      .add("1", StringType)
+      .add("2", LongType)
+      .add("3", ArrayType(IntegerType))
+
+    val projection = UnsafeProjection.create(
+      new StructType().add("array", arrayType).add("struct", structType))
+
+    val unsafeRow = projection(InternalRow(array, struct))
+
+    val unsafeArray = unsafeRow.getArray(0)
+    val unsafeStruct = unsafeRow.getStruct(1, 3)
+
+    checkEvaluation(EqualTo(
+      Literal.create(array, arrayType),
+      Literal.create(unsafeArray, arrayType)), true)
+
+    checkEvaluation(EqualTo(
+      Literal.create(struct, structType),
+      Literal.create(unsafeStruct, structType)), true)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a715176..d2ec3cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2469,4 +2469,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-18053: ARRAY equality is broken") {
+    withTable("array_tbl") {
+      spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl")
+      assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org