You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/07/06 17:05:05 UTC
spark git commit: [SPARK-21228][SQL] InSet incorrect handling of structs
Repository: spark
Updated Branches:
refs/heads/master 565e7a8d4 -> 26ac085de
[SPARK-21228][SQL] InSet incorrect handling of structs
## What changes were proposed in this pull request?
When the data type is a struct (more precisely, any type that is neither atomic nor null), InSet now uses TypeUtils.getInterpretedOrdering (similar to EqualTo) to build a TreeSet. In all other cases it keeps using a HashSet as before (which should be faster). Similarly, In.eval now uses Ordering.equiv instead of equals.
## How was this patch tested?
New test in SQLQuerySuite.
Author: Bogdan Raducanu <bo...@databricks.com>
Closes #18455 from bogdanrdc/SPARK-21228.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/26ac085d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/26ac085d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/26ac085d
Branch: refs/heads/master
Commit: 26ac085debb54d0104762d1cd4187cdf73f301ba
Parents: 565e7a8
Author: Bogdan Raducanu <bo...@databricks.com>
Authored: Fri Jul 7 01:04:57 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Jul 7 01:04:57 2017 +0800
----------------------------------------------------------------------
.../sql/catalyst/expressions/predicates.scala | 57 +++++++++++++-------
.../catalyst/expressions/PredicateSuite.scala | 31 ++++++-----
.../catalyst/optimizer/OptimizeInSuite.scala | 2 +-
.../org/apache/spark/sql/SQLQuerySuite.scala | 22 ++++++++
4 files changed, 78 insertions(+), 34 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/26ac085d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index f3fe58c..7bf10f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.immutable.TreeSet
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -175,20 +176,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
- TypeCheckResult.TypeCheckSuccess
+ TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
case _ =>
- if (list.exists(l => l.dataType != value.dataType)) {
- TypeCheckResult.TypeCheckFailure("Arguments must be same type")
+ val mismatchOpt = list.find(l => l.dataType != value.dataType)
+ if (mismatchOpt.isDefined) {
+ TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
+ s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
- TypeCheckResult.TypeCheckSuccess
+ TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
}
override def children: Seq[Expression] = value +: list
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
@@ -203,10 +207,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
- if (v == evaluatedValue) {
- return true
- } else if (v == null) {
+ if (v == null) {
hasNull = true
+ } else if (ordering.equiv(v, evaluatedValue)) {
+ return true
}
}
if (hasNull) {
@@ -265,7 +269,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
override def nullable: Boolean = child.nullable || hasNull
protected override def nullSafeEval(value: Any): Any = {
- if (hset.contains(value)) {
+ if (set.contains(value)) {
true
} else if (hasNull) {
null
@@ -274,27 +278,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}
- def getHSet(): Set[Any] = hset
+ @transient private[this] lazy val set = child.dataType match {
+ case _: AtomicType => hset
+ case _: NullType => hset
+ case _ =>
+ // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
+ TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
+ }
+
+ def getSet(): Set[Any] = set
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val setName = classOf[Set[Any]].getName
val InSetName = classOf[InSet].getName
val childGen = child.genCode(ctx)
ctx.references += this
- val hsetTerm = ctx.freshName("hset")
- val hasNullTerm = ctx.freshName("hasNull")
- ctx.addMutableState(setName, hsetTerm,
- s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
- ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
+ val setTerm = ctx.freshName("set")
+ val setNull = if (hasNull) {
+ s"""
+ |if (!${ev.value}) {
+ | ${ev.isNull} = true;
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+ ctx.addMutableState(setName, setTerm,
+ s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
boolean ${ev.value} = false;
if (!${ev.isNull}) {
- ${ev.value} = $hsetTerm.contains(${childGen.value});
- if (!${ev.value} && $hasNullTerm) {
- ${ev.isNull} = true;
- }
+ ${ev.value} = $setTerm.contains(${childGen.value});
+ $setNull
}
""")
}
http://git-wip-us.apache.org/repos/asf/spark/blob/26ac085d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 6fe295c..ef510a9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test(s"3VL $name") {
truthTable.foreach {
case (l, r, answer) =>
- val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType))
+ val expr = op(NonFoldableLiteral.create(l, BooleanType),
+ NonFoldableLiteral.create(r, BooleanType))
checkEvaluation(expr, answer)
}
}
@@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(false, true) ::
(null, null) :: Nil
notTrueTable.foreach { case (v, answer) =>
- checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer)
+ checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer)
}
checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType)
}
@@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, null, null) :: Nil)
test("IN") {
- checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null)
- checkEvaluation(In(NonFoldableLiteral(null, IntegerType),
- Seq(NonFoldableLiteral(null, IntegerType))), null)
- checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null)
+ checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
+ Literal(2))), null)
+ checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
+ Seq(NonFoldableLiteral.create(null, IntegerType))), null)
+ checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
checkEvaluation(In(Literal(1), Seq.empty), false)
- checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null)
- checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true)
- checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null)
+ checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
+ checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
+ true)
+ checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
+ null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
checkEvaluation(
- And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
+ And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
+ Literal(2)))),
true)
- val ns = NonFoldableLiteral(null, StringType)
+ val ns = NonFoldableLiteral.create(null, StringType)
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
checkEvaluation(In(ns, Seq(ns)), null)
checkEvaluation(In(Literal("a"), Seq(ns)), null)
@@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
case _ => value
}
}
- val input = inputData.map(NonFoldableLiteral(_, t))
+ val input = inputData.map(NonFoldableLiteral.create(_, t))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
@@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test("BinaryComparison: null test") {
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
val normalInt = Literal(-1)
- val nullInt = NonFoldableLiteral(null, IntegerType)
+ val nullInt = NonFoldableLiteral.create(null, IntegerType)
def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)
http://git-wip-us.apache.org/repos/asf/spark/blob/26ac085d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 6a77580..28bf7b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -169,7 +169,7 @@ class OptimizeInSuite extends PlanTest {
val optimizedPlan = OptimizeIn(plan)
optimizedPlan match {
case Filter(cond, _)
- if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 =>
+ if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
// pass
case _ => fail("Unexpected result for OptimizedIn")
}
http://git-wip-us.apache.org/repos/asf/spark/blob/26ac085d/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 68f61cf..5171aae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2616,4 +2616,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
assert(e.message.contains("Invalid number of arguments"))
}
+
+ test("SPARK-21228: InSet incorrect handling of structs") {
+ withTempView("A") {
+ // reduce this from the default of 10 so the repro query text is not too long
+ withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) {
+ // a relation that has 1 column of struct type with values (1,1), ..., (9, 9)
+ spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a")
+ .createOrReplaceTempView("A")
+ val df = sql(
+ """
+ |SELECT * from
+ | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows
+ | -- the IN will become InSet with a Set of GenericInternalRows
+ | -- a GenericInternalRow is never equal to an UnsafeRow so the query would
+ | -- returns 0 results, which is incorrect
+ | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L),
+ | NAMED_STRUCT('a', 3L, 'b', 3L))
+ """.stripMargin)
+ checkAnswer(df, Row(Row(1, 1)))
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org