You are viewing a plain text version of this content. The canonical link for it is here.
Posted to by on 2015/10/20 22:40:31 UTC

spark git commit: [SPARK-11111] [SQL] fast null-safe join

Repository: spark
Updated Branches:
  refs/heads/master 478c7ce86 -> 67d468f8d

[SPARK-11111] [SQL] fast null-safe join

Currently, we use CartesianProduct for joins with a null-safe-equal (`<=>`) condition:
scala> sqlContext.sql("select * from t a join t b on (a.i <=> b.i)").explain
== Physical Plan ==
TungstenProject [i#2,j#3,i#7,j#8]
 Filter (i#2 <=> i#7)
  CartesianProduct
   LocalTableScan [i#2,j#3], [[1,1]]
   LocalTableScan [i#7,j#8], [[1,1]]
Actually, we can express this as an equi-join condition, `coalesce(a.i, default) = coalesce(b.i, default)`, so that a partitioned join algorithm can be used.

After this PR, the plan will become:
>>> sqlContext.sql("select * from a join b ON a.id <=> b.id").explain()
TungstenProject [id#0L,id#1L]
 Filter (id#0L <=> id#1L)
  SortMergeJoin [coalesce(id#0L,0)], [coalesce(id#1L,0)]
   TungstenSort [coalesce(id#0L,0) ASC], false, 0
    TungstenExchange hashpartitioning(coalesce(id#0L,0),200)
      Scan PhysicalRDD[id#0L]
   TungstenSort [coalesce(id#1L,0) ASC], false, 0
    TungstenExchange hashpartitioning(coalesce(id#1L,0),200)
      Scan PhysicalRDD[id#1L]

Author: Davies Liu <>

Closes #9120 from davies/null_safe.


Branch: refs/heads/master
Commit: 67d468f8d9172569ec9846edc6432240547696dd
Parents: 478c7ce
Author: Davies Liu <>
Authored: Tue Oct 20 13:40:24 2015 -0700
Committer: Davies Liu <>
Committed: Tue Oct 20 13:40:24 2015 -0700

 .../sql/catalyst/expressions/literals.scala     | 29 ++++++++++++++--
 .../spark/sql/catalyst/planning/patterns.scala  | 35 +++++++++++++-------
 .../expressions/LiteralExpressionSuite.scala    | 28 +++++++++++++++-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 14 ++++++++
 .../sql/execution/joins/InnerJoinSuite.scala    | 14 ++++++++
 5 files changed, 105 insertions(+), 15 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 51be819..455fa24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Date, Timestamp}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -52,6 +51,32 @@ object Literal {
   def create(v: Any, dataType: DataType): Literal = {
     Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
+  /**
+   * Create a literal with default value for given DataType
+   */
+  def default(dataType: DataType): Literal = dataType match {
+    case NullType => create(null, NullType)
+    case BooleanType => Literal(false)
+    case ByteType => Literal(0.toByte)
+    case ShortType => Literal(0.toShort)
+    case IntegerType => Literal(0)
+    case LongType => Literal(0L)
+    case FloatType => Literal(0.0f)
+    case DoubleType => Literal(0.0)
+    case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
+    case DateType => create(0, DateType)
+    case TimestampType => create(0L, TimestampType)
+    case StringType => Literal("")
+    case BinaryType => Literal("".getBytes)
+    case CalendarIntervalType => Literal(new CalendarInterval(0, 0))
+    case arr: ArrayType => create(Array(), arr)
+    case map: MapType => create(Map(), map)
+    case struct: StructType =>
+      create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
+    case other =>
+      throw new RuntimeException(s"no default for type $dataType")
+  }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 5353779..3b975b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -18,10 +18,10 @@
 package org.apache.spark.sql.catalyst.planning
 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
  * A pattern that matches any number of project or filter operations on top of another relational
@@ -160,6 +160,9 @@ object PartialAggregation {
  * A pattern that finds joins with equality conditions that can be evaluated using equi-join.
+ *
+ * Null-safe equality will be transformed into equality as joining key (replace null with default
+ * value).
 object ExtractEquiJoinKeys extends Logging with PredicateHelper {
   /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
@@ -171,17 +174,25 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
       logDebug(s"Considering join on: $condition")
       // Find equi-join predicates that can be evaluated before the join, and thus can be used
       // as join keys.
-      val (joinPredicates, otherPredicates) =
- {
-          case EqualTo(l, r) =>
-            (canEvaluate(l, left) && canEvaluate(r, right)) ||
-            (canEvaluate(l, right) && canEvaluate(r, left))
-          case _ => false
-        }
-      val joinKeys = joinPredicates.map {
-        case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
-        case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l)
+      val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
+      val joinKeys = predicates.flatMap {
+        case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
+        case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
+        // Replace null with default value for joining key, then those rows with null in it could
+        // be joined together
+        case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) =>
+          Some((Coalesce(Seq(l, Literal.default(l.dataType))),
+            Coalesce(Seq(r, Literal.default(r.dataType)))))
+        case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
+          Some((Coalesce(Seq(r, Literal.default(r.dataType))),
+            Coalesce(Seq(l, Literal.default(l.dataType)))))
+        case other => None
+      }
+      val otherPredicates = predicates.filterNot {
+        case EqualTo(l, r) =>
+          canEvaluate(l, left) && canEvaluate(r, right) ||
+            canEvaluate(l, right) && canEvaluate(r, left)
+        case other => false
       if (joinKeys.nonEmpty) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 015eb18..7b85286 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,10 @@
 package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
 class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -30,15 +33,38 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Literal.create(null, IntegerType), null)
     checkEvaluation(Literal.create(null, LongType), null)
     checkEvaluation(Literal.create(null, FloatType), null)
-    checkEvaluation(Literal.create(null, LongType), null)
+    checkEvaluation(Literal.create(null, DoubleType), null)
     checkEvaluation(Literal.create(null, StringType), null)
     checkEvaluation(Literal.create(null, BinaryType), null)
     checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null)
+    checkEvaluation(Literal.create(null, DateType), null)
+    checkEvaluation(Literal.create(null, TimestampType), null)
+    checkEvaluation(Literal.create(null, CalendarIntervalType), null)
     checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
     checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
     checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
+  test("default") {
+    checkEvaluation(Literal.default(BooleanType), false)
+    checkEvaluation(Literal.default(ByteType), 0.toByte)
+    checkEvaluation(Literal.default(ShortType), 0.toShort)
+    checkEvaluation(Literal.default(IntegerType), 0)
+    checkEvaluation(Literal.default(LongType), 0L)
+    checkEvaluation(Literal.default(FloatType), 0.0f)
+    checkEvaluation(Literal.default(DoubleType), 0.0)
+    checkEvaluation(Literal.default(StringType), "")
+    checkEvaluation(Literal.default(BinaryType), "".getBytes)
+    checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0))
+    checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0))
+    checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0))
+    checkEvaluation(Literal.default(TimestampType), DateTimeUtils.toJavaTimestamp(0L))
+    checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0L))
+    checkEvaluation(Literal.default(ArrayType(StringType)), Array())
+    checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
+    checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
+  }
   test("boolean literals") {
     checkEvaluation(Literal(true), true)
     checkEvaluation(Literal(false), false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 6365916..a35a7f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.joins.{SortMergeJoin, CartesianProduct}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
@@ -850,6 +851,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(null, null, 6, "F") :: Nil)
+  test("SPARK-11111 null-safe join should not use cartesian product") {
+    val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)")
+    val cp = df.queryExecution.executedPlan.collect {
+      case cp: CartesianProduct => cp
+    }
+    assert(cp.isEmpty, "should not use CartesianProduct for null-safe join")
+    val smj = df.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoin => smj
+    }
+    assert(smj.size > 0, "should use SortMergeJoin")
+    checkAnswer(df, Row(100) :: Nil)
+  }
   test("SPARK-3349 partitioning after limit") {
     sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 4174ee0..da58e96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -212,4 +212,18 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
+  {
+    lazy val left = Seq((1, Some(0)), (2, None)).toDF("a", "b")
+    lazy val right = Seq((1, Some(0)), (2, None)).toDF("a", "b")
+    testInnerJoin(
+      "inner join, null safe",
+      left,
+      right,
+      () => (left.col("b") <=> right.col("b")).expr,
+      Seq(
+        (1, 0, 1, 0),
+        (2, null, 2, null)
+      )
+    )
+  }

To unsubscribe, e-mail:
For additional commands, e-mail: