Posted to commits@spark.apache.org by da...@apache.org on 2016/02/23 21:55:54 UTC

spark git commit: [SPARK-13329] [SQL] considering output for statistics of logical plan

Repository: spark
Updated Branches:
  refs/heads/master c5bfe5d2a -> c481bdf51


[SPARK-13329] [SQL] considering output for statistics of logical plan

The current implementation of statistics for UnaryNode does not consider the output (for example, Project may produce far fewer columns than its child); we should take the output into account to get a better estimate.

We usually join on only a few columns from a Parquet table, so the size of the projected plan can be much smaller than the original Parquet files. Having a better size estimate helps us choose between broadcast join and sort merge join.
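
To illustrate, here is a standalone sketch (not the committed code; the actual change is in UnaryNode.statistics in the diff below) of the proportional scaling this PR applies: per-row sizes are built from each column's defaultSize plus a small fixed overhead, and the child's estimate is scaled by their ratio.

def estimateProjectedSize(
    childSizeInBytes: BigInt,
    childColumnSizes: Seq[Int],
    outputColumnSizes: Seq[Int]): BigInt = {
  val rowOverhead = 8  // keeps row sizes non-zero, preventing divide-by-zero
  val childRowSize = childColumnSizes.sum + rowOverhead
  val outputRowSize = outputColumnSizes.sum + rowOverhead
  val estimate = childSizeInBytes * outputRowSize / childRowSize
  if (estimate == 0) BigInt(1) else estimate  // a zero estimate would zero out join size products
}

For example, projecting one 20-byte string column out of ten scales a 1 GB child estimate by (20 + 8) / (200 + 8), down to roughly 135 MB.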

After this PR, I saw a few queries choose broadcast join instead of sort merge join without tuning spark.sql.autoBroadcastJoinThreshold for every query, ending up with about 6-8X improvements in end-to-end time.
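
As a minimal sketch, assuming a simplified form of the planner's broadcast rule, the decision this estimate feeds into looks roughly like:

def shouldBroadcast(estimatedSizeInBytes: BigInt, autoBroadcastJoinThreshold: Long): Boolean =
  autoBroadcastJoinThreshold > 0 && estimatedSizeInBytes <= autoBroadcastJoinThreshold

With smaller, more realistic estimates for projected plans, more join sides pass this check without having to raise the threshold per query.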

We use the `defaultSize` of a DataType to estimate the size of a column. Currently, for DecimalType/StringType/BinaryType and UDT, we over-estimate by far too much (4096 bytes), so this PR changes them to more reasonable values. Here are the new defaultSize values for them:

DecimalType:  8 or 16 bytes, based on the precision
StringType:  20 bytes
BinaryType: 100 bytes
UDT: default size of its underlying SQL type

These numbers are not perfect (it is hard to pick a perfect number for them), but they should be much better than 4096.
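
These element sizes also propagate into the container estimates updated in DataTypeSuite below. Assuming, as those test numbers imply, that containers are estimated at 100 elements, the new expectations check out:

val stringDefault = 20
val intDefault = 4
assert(100 * stringDefault == 2000)                 // ArrayType(StringType, false)
assert(100 * (intDefault + stringDefault) == 2400)  // MapType(IntegerType, StringType, true)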

Author: Davies Liu <da...@databricks.com>

Closes #11210 from davies/statics.
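
For reference, a hedged way to observe the new estimate on a build from this era (accessor names such as `statistics` may differ in other versions; the path and column name below are placeholders):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object InspectPlanSize {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("plan-size").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)
    val wide = sqlContext.read.parquet("/path/to/wide_table")  // placeholder path
    val narrow = wide.select("id")                             // placeholder column
    // The projection's estimate now reflects only the selected column, not the full
    // width of the underlying Parquet schema.
    println(narrow.queryExecution.optimizedPlan.statistics.sizeInBytes)
    sc.stop()
  }
}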


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c481bdf5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c481bdf5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c481bdf5

Branch: refs/heads/master
Commit: c481bdf512f09060c9b9f341a5ce9fce00427d08
Parents: c5bfe5d
Author: Davies Liu <da...@databricks.com>
Authored: Tue Feb 23 12:55:44 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Feb 23 12:55:44 2016 -0800

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 |  4 +-
 .../catalyst/plans/logical/LogicalPlan.scala    | 15 +++++
 .../catalyst/plans/logical/basicOperators.scala | 34 +++++++++-
 .../org/apache/spark/sql/types/BinaryType.scala |  4 +-
 .../apache/spark/sql/types/DecimalType.scala    |  4 +-
 .../org/apache/spark/sql/types/StringType.scala |  4 +-
 .../spark/sql/types/UserDefinedType.scala       |  5 +-
 .../apache/spark/sql/types/DataTypeSuite.scala  | 12 ++--
 .../org/apache/spark/sql/execution/Expand.scala |  2 -
 .../scala/org/apache/spark/sql/JoinSuite.scala  | 67 +++++++++++---------
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  5 +-
 11 files changed, 100 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 83b034f..bf43452 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -551,8 +551,8 @@ class DataFrame(object):
         >>> df_as1 = df.alias("df_as1")
         >>> df_as2 = df.alias("df_as2")
         >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
-        >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect()
-        [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)]
+        >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").collect()
+        [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)]
         """
         assert isinstance(alias, basestring), "alias should be a string"
         return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 35df242..8095083 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -316,6 +316,21 @@ abstract class UnaryNode extends LogicalPlan {
   override def children: Seq[LogicalPlan] = child :: Nil
 
   override protected def validConstraints: Set[Expression] = child.constraints
+
+  override def statistics: Statistics = {
+    // There should be some overhead in Row object, the size should not be zero when there is
+    // no columns, this help to prevent divide-by-zero error.
+    val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
+    val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
+    // Assume there will be the same number of rows as child has.
+    var sizeInBytes = (child.statistics.sizeInBytes * outputRowSize) / childRowSize
+    if (sizeInBytes == 0) {
+      // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
+      // (product of children).
+      sizeInBytes = 1
+    }
+    Statistics(sizeInBytes = sizeInBytes)
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index c98d33d..af43cb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -176,6 +176,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
       Some(children.flatMap(_.maxRows).min)
     }
   }
+
+  override def statistics: Statistics = {
+    val leftSize = left.statistics.sizeInBytes
+    val rightSize = right.statistics.sizeInBytes
+    val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
+    Statistics(sizeInBytes = sizeInBytes)
+  }
 }
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
@@ -188,6 +195,10 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
     childrenResolved &&
       left.output.length == right.output.length &&
       left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+
+  override def statistics: Statistics = {
+    Statistics(sizeInBytes = left.statistics.sizeInBytes)
+  }
 }
 
 /** Factory for constructing new `Union` nodes. */
@@ -426,6 +437,14 @@ case class Aggregate(
 
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
   override def maxRows: Option[Long] = child.maxRows
+
+  override def statistics: Statistics = {
+    if (groupingExpressions.isEmpty) {
+      Statistics(sizeInBytes = 1)
+    } else {
+      super.statistics
+    }
+  }
 }
 
 case class Window(
@@ -521,9 +540,7 @@ case class Expand(
     AttributeSet(projections.flatten.flatMap(_.references))
 
   override def statistics: Statistics = {
-    // TODO shouldn't we factor in the size of the projection versus the size of the backing child
-    //      row?
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
+    val sizeInBytes = super.statistics.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }
 }
@@ -648,6 +665,17 @@ case class Sample(
     val isTableSample: java.lang.Boolean = false) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
+
+  override def statistics: Statistics = {
+    val ratio = upperBound - lowerBound
+    // BigInt can't multiply with Double
+    var sizeInBytes = child.statistics.sizeInBytes * (ratio * 100).toInt / 100
+    if (sizeInBytes == 0) {
+      sizeInBytes = 1
+    }
+    Statistics(sizeInBytes = sizeInBytes)
+  }
+
   override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
 }
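
A minimal sketch of the workaround used in Sample above: BigInt has no multiplication by Double, so the sampling ratio is first converted to an integer percentage before scaling the child's estimate.

def scaleBySampleRatio(sizeInBytes: BigInt, lowerBound: Double, upperBound: Double): BigInt = {
  val ratio = upperBound - lowerBound
  val scaled = sizeInBytes * (ratio * 100).toInt / 100
  if (scaled == 0) BigInt(1) else scaled  // never report an estimate of zero bytes
}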
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
index f2c6f34..c40e140 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
@@ -47,9 +47,9 @@ class BinaryType private() extends AtomicType {
   }
 
   /**
-   * The default size of a value of the BinaryType is 4096 bytes.
+   * The default size of a value of the BinaryType is 100 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = 100
 
   private[spark] override def asNullable: BinaryType = this
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 71ea5b8..2e03dda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -91,9 +91,9 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   }
 
   /**
-   * The default size of a value of the DecimalType is 4096 bytes.
+   * The default size of a value of the DecimalType is 8 bytes (precision <= 18) or 16 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16
 
   override def simpleString: String = s"decimal($precision,$scale)"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
index a7627a2..44a2536 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -38,9 +38,9 @@ class StringType private() extends AtomicType {
   private[sql] val ordering = implicitly[Ordering[InternalType]]
 
   /**
-   * The default size of a value of the StringType is 4096 bytes.
+   * The default size of a value of the StringType is 20 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = 20
 
   private[spark] override def asNullable: StringType = this
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 7664c30..9d2449f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -71,10 +71,7 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
    */
   def userClass: java.lang.Class[UserType]
 
-  /**
-   * The default size of a value of the UserDefinedType is 4096 bytes.
-   */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = sqlType.defaultSize
 
   /**
    * For UDT, asNullable will not change the nullability of its internal sqlType and just returns

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index c2bbca7..6b85f12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -248,15 +248,15 @@ class DataTypeSuite extends SparkFunSuite {
   checkDefaultSize(LongType, 8)
   checkDefaultSize(FloatType, 4)
   checkDefaultSize(DoubleType, 8)
-  checkDefaultSize(DecimalType(10, 5), 4096)
-  checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096)
+  checkDefaultSize(DecimalType(10, 5), 8)
+  checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 16)
   checkDefaultSize(DateType, 4)
   checkDefaultSize(TimestampType, 8)
-  checkDefaultSize(StringType, 4096)
-  checkDefaultSize(BinaryType, 4096)
+  checkDefaultSize(StringType, 20)
+  checkDefaultSize(BinaryType, 100)
   checkDefaultSize(ArrayType(DoubleType, true), 800)
-  checkDefaultSize(ArrayType(StringType, false), 409600)
-  checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
+  checkDefaultSize(ArrayType(StringType, false), 2000)
+  checkDefaultSize(MapType(IntegerType, StringType, true), 2400)
   checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
   checkDefaultSize(structType, 812)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index d26a0b7..a9e77ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.immutable.IndexedSeq
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8f2a0c0..a1211e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -63,36 +63,40 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   test("join operator selection") {
     sqlContext.cacheManager.clearCache()
 
-    Seq(
-      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
-      ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
-      ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[SortMergeJoin]), // converted from Right Outer to Inner
-      ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData full outer join testData2 ON key = a",
-        classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin]),
-      ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin]),
-      ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin])
-    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
+      Seq(
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
+        ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
+          classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
+          classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+          classOf[SortMergeJoin]),
+        ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+          classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData full outer join testData2 ON key = a",
+          classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    }
   }
 
 //  ignore("SortMergeJoin shouldn't work on unsortable columns") {
@@ -118,9 +122,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   test("broadcasted hash outer join operator selection") {
     sqlContext.cacheManager.clearCache()
     sql("CACHE TABLE testData")
+    sql("CACHE TABLE testData2")
     Seq(
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
-        classOf[SortMergeOuterJoin]),
+        classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
         classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",

http://git-wip-us.apache.org/repos/asf/spark/blob/c481bdf5/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b4d6f4e..1d9db27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -22,7 +22,7 @@ import java.sql.Timestamp
 
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
 import org.apache.spark.sql.test.SQLTestData._
@@ -780,8 +780,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     assert(cp.isEmpty, "should not use CartesianProduct for null-safe join")
     val smj = df.queryExecution.sparkPlan.collect {
       case smj: SortMergeJoin => smj
+      case j: BroadcastHashJoin => j
     }
-    assert(smj.size > 0, "should use SortMergeJoin")
+    assert(smj.size > 0, "should use SortMergeJoin or BroadcastHashJoin")
     checkAnswer(df, Row(100) :: Nil)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org