You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2016/11/23 12:15:25 UTC
spark git commit: [SPARK-18053][SQL] compare unsafe and safe
complex-type values correctly
Repository: spark
Updated Branches:
refs/heads/master 85235ed6c -> 84284e8c8
[SPARK-18053][SQL] compare unsafe and safe complex-type values correctly
## What changes were proposed in this pull request?
In Spark SQL, some expressions may output values in safe format, e.g. `CreateArray`, `CreateStruct`, `Cast`, etc. When we compare two values, we should be able to compare a safe-format value with an unsafe-format one.
`GreaterThan`, `LessThan`, etc. in Spark SQL already handle this, but `EqualTo` doesn't. This PR fixes it.
## How was this patch tested?
A new unit test (in `PredicateSuite`) and a new regression test (in `SQLQuerySuite`).
Author: Wenchen Fan <we...@databricks.com>
Closes #15929 from cloud-fan/type-aware.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84284e8c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84284e8c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84284e8c
Branch: refs/heads/master
Commit: 84284e8c82542d80dad94e458a0c0210bf803db3
Parents: 85235ed
Author: Wenchen Fan <we...@databricks.com>
Authored: Wed Nov 23 04:15:19 2016 -0800
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Wed Nov 23 04:15:19 2016 -0800
----------------------------------------------------------------------
.../sql/catalyst/expressions/UnsafeRow.java | 6 +---
.../expressions/codegen/CodeGenerator.scala | 20 ++++++++++--
.../sql/catalyst/expressions/predicates.scala | 32 +++-----------------
.../catalyst/expressions/PredicateSuite.scala | 29 ++++++++++++++++++
.../org/apache/spark/sql/SQLQuerySuite.scala | 7 +++++
5 files changed, 59 insertions(+), 35 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index c3f0aba..d205547 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -578,12 +578,8 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
return (sizeInBytes == o.sizeInBytes) &&
ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
sizeInBytes);
- } else if (!(other instanceof InternalRow)) {
- return false;
- } else {
- throw new IllegalArgumentException(
- "Cannot compare UnsafeRow to " + other.getClass().getName());
}
+ return false;
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 9c3c6d3..09007b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -481,8 +481,13 @@ class CodegenContext {
case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
+ case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
+ case array: ArrayType => genComp(array, c1, c2) + " == 0"
+ case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
- case other => s"$c1.equals($c2)"
+ case _ =>
+ throw new IllegalArgumentException(
+ "cannot generate equality code for un-comparable type: " + dataType.simpleString)
}
/**
@@ -512,6 +517,11 @@ class CodegenContext {
val funcCode: String =
s"""
public int $compareFunc(ArrayData a, ArrayData b) {
+ // when comparing unsafe arrays, try equals first as it compares the binary directly
+ // which is very fast.
+ if (a instanceof UnsafeArrayData && b instanceof UnsafeArrayData && a.equals(b)) {
+ return 0;
+ }
int lengthA = a.numElements();
int lengthB = b.numElements();
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
@@ -551,6 +561,11 @@ class CodegenContext {
val funcCode: String =
s"""
public int $compareFunc(InternalRow a, InternalRow b) {
+ // when comparing unsafe rows, try equals first as it compares the binary directly
+ // which is very fast.
+ if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) {
+ return 0;
+ }
InternalRow i = null;
$comparisons
return 0;
@@ -561,7 +576,8 @@ class CodegenContext {
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
- throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
+ throw new IllegalArgumentException(
+ "cannot generate compare code for un-comparable type: " + dataType.simpleString)
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2ad452b..3fcbb05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -388,6 +388,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0")
}
}
+
+ protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
}
@@ -429,17 +431,7 @@ case class EqualTo(left: Expression, right: Expression)
override def symbol: String = "="
- protected override def nullSafeEval(input1: Any, input2: Any): Any = {
- if (left.dataType == FloatType) {
- Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
- } else if (left.dataType == DoubleType) {
- Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
- } else if (left.dataType != BinaryType) {
- input1 == input2
- } else {
- java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
- }
- }
+ protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
@@ -482,15 +474,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
} else if (input1 == null || input2 == null) {
false
} else {
- if (left.dataType == FloatType) {
- Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
- } else if (left.dataType == DoubleType) {
- Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
- } else if (left.dataType != BinaryType) {
- input1 == input2
- } else {
- java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
- }
+ ordering.equiv(input1, input2)
}
}
@@ -513,8 +497,6 @@ case class LessThan(left: Expression, right: Expression)
override def symbol: String = "<"
- private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
-
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
}
@@ -527,8 +509,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)
override def symbol: String = "<="
- private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
-
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
}
@@ -541,8 +521,6 @@ case class GreaterThan(left: Expression, right: Expression)
override def symbol: String = ">"
- private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
-
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
}
@@ -555,7 +533,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression)
override def symbol: String = ">="
- private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
-
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 2a445b8..f9f6799 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -21,6 +21,8 @@ import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
@@ -293,4 +295,31 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
}
+
+ test("EqualTo on complex type") {
+ val array = new GenericArrayData(Array(1, 2, 3))
+ val struct = create_row("a", 1L, array)
+
+ val arrayType = ArrayType(IntegerType)
+ val structType = new StructType()
+ .add("1", StringType)
+ .add("2", LongType)
+ .add("3", ArrayType(IntegerType))
+
+ val projection = UnsafeProjection.create(
+ new StructType().add("array", arrayType).add("struct", structType))
+
+ val unsafeRow = projection(InternalRow(array, struct))
+
+ val unsafeArray = unsafeRow.getArray(0)
+ val unsafeStruct = unsafeRow.getStruct(1, 3)
+
+ checkEvaluation(EqualTo(
+ Literal.create(array, arrayType),
+ Literal.create(unsafeArray, arrayType)), true)
+
+ checkEvaluation(EqualTo(
+ Literal.create(struct, structType),
+ Literal.create(unsafeStruct, structType)), true)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/84284e8c/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a715176..d2ec3cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2469,4 +2469,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ test("SPARK-18053: ARRAY equality is broken") {
+ withTable("array_tbl") {
+ spark.range(10).select(array($"id").as("arr")).write.saveAsTable("array_tbl")
+ assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org