Posted to commits@spark.apache.org by do...@apache.org on 2020/12/28 08:04:31 UTC

[spark] branch branch-3.1 updated: [SPARK-33923][SQL][TESTS] Fix some tests with AQE enabled

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 37a7f86  [SPARK-33923][SQL][TESTS] Fix some tests with AQE enabled
37a7f86 is described below

commit 37a7f86f658a613eb02e3a55ae8c03f749e758c9
Author: yi.wu <yi...@databricks.com>
AuthorDate: Mon Dec 28 00:03:45 2020 -0800

    [SPARK-33923][SQL][TESTS] Fix some tests with AQE enabled
    
    ### What changes were proposed in this pull request?
    
    * Remove the explicit AQE disable confs
    * Use `AdaptiveSparkPlanHelper` to check plans
    * Stop extending `DisableAdaptiveExecutionSuite` in `BucketedReadSuite` and only disable AQE for the two tests there that need it (see the sketch below).
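    
    A minimal sketch of the per-test opt-out (the tag usage matches the change in
    `BucketedReadSuite` below; the test name and body here are placeholders):
    
    ```scala
    import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
    
    // The tag tells the test framework to run just this test with
    // spark.sql.adaptive.enabled=false, instead of disabling AQE for the whole suite.
    test("hypothetical test that relies on a fixed shuffle count",
      DisableAdaptiveExecution("Expected shuffle num mismatched")) {
      // assertions that depend on the non-adaptive plan shape go here
    }
    ```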
    
    ### Why are the changes needed?
    
    Some of the tests fixed in https://github.com/apache/spark/pull/30655 don't really require AQE to be off. Instead, they can use `AdaptiveSparkPlanHelper` to pass with AQE on. It's better to run these tests with AQE on since AQE is now enabled by default.
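    
    For illustration, a minimal sketch of the traversal pattern (the suite name, query,
    and assertion are hypothetical; `collect` comes from `AdaptiveSparkPlanHelper` as in
    the suites below):
    
    ```scala
    import org.apache.spark.sql.QueryTest
    import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
    import org.apache.spark.sql.test.SharedSparkSession
    
    class ExamplePlanCheckSuite extends QueryTest with SharedSparkSession
      with AdaptiveSparkPlanHelper {
    
      test("find shuffles with AQE on") {
        // With AQE on, the executed plan is wrapped in AdaptiveSparkPlanExec, so a plain
        // plan.collect { ... } would not see the operators inside it. The collect/find
        // helpers from AdaptiveSparkPlanHelper descend into the adaptive plan.
        val plan = spark.range(10).groupBy("id").count().queryExecution.executedPlan
        val exchanges = collect(plan) { case s: ShuffleExchangeExec => s }
        assert(exchanges.nonEmpty)
      }
    }
    ```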
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Pass existing tests and the updated tests.
    
    Closes #30941 from Ngone51/SPARK-33680-follow-up.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit 00fa49aeaa601f50df81adb25184f141ba0a44ee)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 10 +--
 .../org/apache/spark/sql/DataFrameJoinSuite.scala  |  8 +-
 .../scala/org/apache/spark/sql/JoinSuite.scala     | 24 +++---
 .../apache/spark/sql/execution/PlannerSuite.scala  | 93 ++++++++--------------
 .../spark/sql/sources/BucketedReadSuite.scala      | 15 ++--
 .../sources/BucketedReadWithHiveSupportSuite.scala |  4 +-
 6 files changed, 62 insertions(+), 92 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 78983a4..6603fc0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1002,8 +1002,7 @@ class DataFrameAggregateSuite extends QueryTest
   Seq(true, false).foreach { value =>
     test(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
       withSQLConf(
-        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString,
-        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
         withTempView("t1", "t2") {
           sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
           sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
@@ -1026,14 +1025,13 @@ class DataFrameAggregateSuite extends QueryTest
 
           // test SortAggregateExec
           var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")
-          assert(df.queryExecution.executedPlan
-            .find { case _: SortAggregateExec => true }.isDefined)
+          assert(find(df.queryExecution.executedPlan)(_.isInstanceOf[SortAggregateExec]).isDefined)
           checkAnswer(df, Row("str1") :: Nil)
 
           // test ObjectHashAggregateExec
           df = sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2")
-          assert(df.queryExecution.executedPlan
-            .find { case _: ObjectHashAggregateExec => true }.isDefined)
+          assert(
+            find(df.queryExecution.executedPlan)(_.isInstanceOf[ObjectHashAggregateExec]).isDefined)
           checkAnswer(df, Row(Array(4), 4) :: Nil)
         }
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 1513c2e..ad13d7d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -345,15 +345,13 @@ class DataFrameJoinSuite extends QueryTest
 
     withTempDatabase { dbName =>
       withTable(table1Name, table2Name) {
-        withSQLConf(
-          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
           spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
           spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
 
           def checkIfHintApplied(df: DataFrame): Unit = {
             val sparkPlan = df.queryExecution.executedPlan
-            val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
+            val broadcastHashJoins = collect(sparkPlan) { case p: BroadcastHashJoinExec => p }
             assert(broadcastHashJoins.size == 1)
             val broadcastExchanges = broadcastHashJoins.head.collect {
               case p: BroadcastExchangeExec => p
@@ -368,7 +366,7 @@ class DataFrameJoinSuite extends QueryTest
 
           def checkIfHintNotApplied(df: DataFrame): Unit = {
             val sparkPlan = df.queryExecution.executedPlan
-            val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
+            val broadcastHashJoins = collect(sparkPlan) { case p: BroadcastHashJoinExec => p }
             assert(broadcastHashJoins.isEmpty)
           }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index a728e5c..1bdfdb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1107,7 +1107,6 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
 
   test("SPARK-32330: Preserve shuffled hash join build side partitioning") {
     withSQLConf(
-        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
         SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
         SQLConf.SHUFFLE_PARTITIONS.key -> "2",
         SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
@@ -1116,9 +1115,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       Seq("inner", "cross").foreach(joinType => {
         val plan = df1.join(df2, $"k1" === $"k2", joinType).groupBy($"k1").count()
           .queryExecution.executedPlan
-        assert(plan.collect { case _: ShuffledHashJoinExec => true }.size === 1)
+        assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
         // No extra shuffle before aggregate
-        assert(plan.collect { case _: ShuffleExchangeExec => true }.size === 2)
+        assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2)
       })
     }
   }
@@ -1131,7 +1130,6 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
 
     // Test broadcast hash join
     withSQLConf(
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50") {
       Seq("inner", "left_outer").foreach(joinType => {
         val plan = df1.join(df2, $"k1" === $"k2", joinType)
@@ -1139,16 +1137,15 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
           .join(df4, $"k1" === $"k4", joinType)
           .queryExecution
           .executedPlan
-        assert(plan.collect { case _: SortMergeJoinExec => true }.size === 2)
-        assert(plan.collect { case _: BroadcastHashJoinExec => true }.size === 1)
+        assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
+        assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size === 1)
         // No extra sort before last sort merge join
-        assert(plan.collect { case _: SortExec => true }.size === 3)
+        assert(collect(plan) { case _: SortExec => true }.size === 3)
       })
     }
 
     // Test shuffled hash join
     withSQLConf(
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
       SQLConf.SHUFFLE_PARTITIONS.key -> "2",
       SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
@@ -1160,10 +1157,10 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
           .join(df4, $"k1" === $"k4", joinType)
           .queryExecution
           .executedPlan
-        assert(plan.collect { case _: SortMergeJoinExec => true }.size === 2)
-        assert(plan.collect { case _: ShuffledHashJoinExec => true }.size === 1)
+        assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
+        assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1)
         // No extra sort before last sort merge join
-        assert(plan.collect { case _: SortExec => true }.size === 3)
+        assert(collect(plan) { case _: SortExec => true }.size === 3)
       })
     }
   }
@@ -1256,17 +1253,16 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       withSQLConf(
         // Set broadcast join threshold and number of shuffle partitions,
         // as shuffled hash join depends on these two configs.
-        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
         SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
         SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
         val smjDF = df1.join(df2, joinExprs, "full")
-        assert(smjDF.queryExecution.executedPlan.collect {
+        assert(collect(smjDF.queryExecution.executedPlan) {
           case _: SortMergeJoinExec => true }.size === 1)
         val smjResult = smjDF.collect()
 
         withSQLConf(SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
           val shjDF = df1.join(df2, joinExprs, "full")
-          assert(shjDF.queryExecution.executedPlan.collect {
+          assert(collect(shjDF.queryExecution.executedPlan) {
             case _: ShuffledHashJoinExec => true }.size === 1)
           // Same result between shuffled hash join and sort merge join
           checkAnswer(shjDF, smjResult)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 4e01d1c..f27acad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -877,9 +877,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("aliases in the project should not introduce extra shuffle") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTempView("df1", "df2") {
         spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
         spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
@@ -891,7 +889,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
             |  (SELECT key AS k from df2) t2
             |ON t1.k = t2.k
           """.stripMargin).queryExecution.executedPlan
-        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+        val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
         assert(exchanges.size == 2)
       }
     }
@@ -899,9 +897,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
   test("SPARK-33399: aliases should be handled properly in PartitioningCollection output" +
     " partitioning") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTempView("t1", "t2", "t3") {
         spark.range(10).repartition($"id").createTempView("t1")
         spark.range(20).repartition($"id").createTempView("t2")
@@ -916,10 +912,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
             |) t12, t3
             |WHERE t1id = t3.id
           """.stripMargin).queryExecution.executedPlan
-        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+        val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
         assert(exchanges.size == 3)
 
-        val projects = planned.collect { case p: ProjectExec => p }
+        val projects = collect(planned) { case p: ProjectExec => p }
         assert(projects.exists(_.outputPartitioning match {
           case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
             HashPartitioning(Seq(k2: AttributeReference), _))) if k1.name == "t1id" =>
@@ -931,9 +927,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("SPARK-33399: aliases should be handled properly in HashPartitioning") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTempView("t1", "t2", "t3") {
         spark.range(10).repartition($"id").createTempView("t1")
         spark.range(20).repartition($"id").createTempView("t2")
@@ -948,10 +942,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
             |) t12 INNER JOIN t3
             |WHERE t1id = t3.id
           """.stripMargin).queryExecution.executedPlan
-        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+        val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
         assert(exchanges.size == 3)
 
-        val projects = planned.collect { case p: ProjectExec => p }
+        val projects = collect(planned) { case p: ProjectExec => p }
         assert(projects.exists(_.outputPartitioning match {
           case HashPartitioning(Seq(a: AttributeReference), _) => a.name == "t1id"
           case _ => false
@@ -961,19 +955,17 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("SPARK-33399: alias handling should happen properly for RangePartitioning") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val df = spark.range(1, 100)
         .select(col("id").as("id1")).groupBy("id1").count()
       // Plan for this will be Range -> ProjectWithAlias -> HashAggregate -> HashAggregate
       // if Project normalizes alias in its Range outputPartitioning, then no Exchange should come
       // in between HashAggregates
       val planned = df.queryExecution.executedPlan
-      val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+      val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
       assert(exchanges.isEmpty)
 
-      val projects = planned.collect { case p: ProjectExec => p }
+      val projects = collect(planned) { case p: ProjectExec => p }
       assert(projects.exists(_.outputPartitioning match {
         case RangePartitioning(Seq(SortOrder(ar: AttributeReference, _, _, _)), _) =>
           ar.name == "id1"
@@ -984,9 +976,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
   test("SPARK-33399: aliased should be handled properly " +
     "for partitioning and sortorder involving complex expressions") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTempView("t1", "t2", "t3") {
         spark.range(10).select(col("id").as("id1")).createTempView("t1")
         spark.range(20).select(col("id").as("id2")).createTempView("t2")
@@ -1001,12 +991,12 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
             |) t12, t3
             |WHERE t1id * 10 = t3.id3 * 10
           """.stripMargin).queryExecution.executedPlan
-        val sortNodes = planned.collect { case s: SortExec => s }
+        val sortNodes = collect(planned) { case s: SortExec => s }
         assert(sortNodes.size == 3)
-        val exchangeNodes = planned.collect { case e: ShuffleExchangeExec => e }
+        val exchangeNodes = collect(planned) { case e: ShuffleExchangeExec => e }
         assert(exchangeNodes.size == 3)
 
-        val projects = planned.collect { case p: ProjectExec => p }
+        val projects = collect(planned) { case p: ProjectExec => p }
         assert(projects.exists(_.outputPartitioning match {
           case PartitioningCollection(Seq(HashPartitioning(Seq(Multiply(ar1, _, _)), _),
             HashPartitioning(Seq(Multiply(ar2, _, _)), _))) =>
@@ -1024,16 +1014,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("SPARK-33399: alias handling should happen properly for SinglePartition") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val df = spark.range(1, 100, 1, 1)
         .select(col("id").as("id1")).groupBy("id1").count()
       val planned = df.queryExecution.executedPlan
-      val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+      val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
       assert(exchanges.isEmpty)
 
-      val projects = planned.collect { case p: ProjectExec => p }
+      val projects = collect(planned) { case p: ProjectExec => p }
       assert(projects.exists(_.outputPartitioning match {
         case SinglePartition => true
         case _ => false
@@ -1043,9 +1031,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
   test("SPARK-33399: No extra exchanges in case of" +
     " [Inner Join -> Project with aliases -> HashAggregate]") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTempView("t1", "t2") {
         spark.range(10).repartition($"id").createTempView("t1")
         spark.range(20).repartition($"id").createTempView("t2")
@@ -1059,10 +1045,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
             |) t12
             |GROUP BY t1id, t2id
           """.stripMargin).queryExecution.executedPlan
-        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+        val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
         assert(exchanges.size == 2)
 
-        val projects = planned.collect { case p: ProjectExec => p }
+        val projects = collect(planned) { case p: ProjectExec => p }
         assert(projects.exists(_.outputPartitioning match {
           case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
           HashPartitioning(Seq(k2: AttributeReference), _))) =>
@@ -1074,9 +1060,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("SPARK-33400: Normalization of sortOrder should take care of sameOrderExprs") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTempView("t1", "t2", "t3") {
         spark.range(10).repartition($"id").createTempView("t1")
         spark.range(20).repartition($"id").createTempView("t2")
@@ -1092,10 +1076,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
             |WHERE t2id = t3.id
           """.stripMargin).queryExecution.executedPlan
 
-        val sortNodes = planned.collect { case s: SortExec => s }
+        val sortNodes = collect(planned) { case s: SortExec => s }
         assert(sortNodes.size == 3)
 
-        val projects = planned.collect { case p: ProjectExec => p }
+        val projects = collect(planned) { case p: ProjectExec => p }
         assert(projects.exists(_.outputOrdering match {
           case Seq(SortOrder(_, Ascending, NullsFirst, sameOrderExprs)) =>
             sameOrderExprs.size == 1 && sameOrderExprs.head.isInstanceOf[AttributeReference] &&
@@ -1135,9 +1119,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("aliases to expressions should not be replaced") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       withTempView("df1", "df2") {
         spark.range(10).selectExpr("id AS key", "0").repartition($"key").createTempView("df1")
         spark.range(20).selectExpr("id AS key", "0").repartition($"key").createTempView("df2")
@@ -1149,7 +1131,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
             |  (SELECT key + 1 AS k2 from df2) t2
             |ON t1.k1 = t2.k2
             |""".stripMargin).queryExecution.executedPlan
-        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+        val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
 
         // Make sure aliases to an expression (key + 1) are not replaced.
         Seq("k1", "k2").foreach { alias =>
@@ -1163,9 +1145,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("aliases in the aggregate expressions should not introduce extra shuffle") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
       val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
 
@@ -1174,17 +1154,15 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
       val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
 
-      assert(planned.collect { case h: HashAggregateExec => h }.nonEmpty)
+      assert(collect(planned) { case h: HashAggregateExec => h }.nonEmpty)
 
-      val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+      val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
       assert(exchanges.size == 2)
     }
   }
 
   test("aliases in the object hash/sort aggregate expressions should not introduce extra shuffle") {
-    withSQLConf(
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       Seq(true, false).foreach { useObjectHashAgg =>
         withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> useObjectHashAgg.toString) {
           val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
@@ -1196,12 +1174,12 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
           val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
 
           if (useObjectHashAgg) {
-            assert(planned.collect { case o: ObjectHashAggregateExec => o }.nonEmpty)
+            assert(collect(planned) { case o: ObjectHashAggregateExec => o }.nonEmpty)
           } else {
-            assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
+            assert(collect(planned) { case s: SortAggregateExec => s }.nonEmpty)
           }
 
-          val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+          val exchanges = collect(planned) { case s: ShuffleExchangeExec => s }
           assert(exchanges.size == 2)
         }
       }
@@ -1211,7 +1189,6 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   test("aliases in the sort aggregate expressions should not introduce extra sort") {
     withSQLConf(
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
       SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
       val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
       val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
@@ -1220,10 +1197,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       val agg2 = t2.groupBy("k2").agg(collect_list("k2"))
 
       val planned = agg1.join(agg2, $"k3" === $"k2").queryExecution.executedPlan
-      assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
+      assert(collect(planned) { case s: SortAggregateExec => s }.nonEmpty)
 
       // We expect two SortExec nodes on each side of join.
-      val sorts = planned.collect { case s: SortExec => s }
+      val sorts = collect(planned) { case s: SortExec => s }
       assert(sorts.size == 4)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 4ae8cdb..9dcc0cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, DisableAdaptiveExecutionSuite}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
 import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -40,7 +40,7 @@ import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.BitSet
 
 class BucketedReadWithoutHiveSupportSuite
-  extends BucketedReadSuite with DisableAdaptiveExecutionSuite with SharedSparkSession {
+  extends BucketedReadSuite with SharedSparkSession {
   protected override def beforeAll(): Unit = {
     super.beforeAll()
     assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
@@ -48,7 +48,7 @@ class BucketedReadWithoutHiveSupportSuite
 }
 
 
-abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
+abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with AdaptiveSparkPlanHelper {
   import testImplicits._
 
   protected override def beforeAll(): Unit = {
@@ -104,7 +104,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
   }
 
   private def getFileScan(plan: SparkPlan): FileSourceScanExec = {
-    val fileScan = plan.collect { case f: FileSourceScanExec => f }
+    val fileScan = collect(plan) { case f: FileSourceScanExec => f }
     assert(fileScan.nonEmpty, plan)
     fileScan.head
   }
@@ -930,7 +930,9 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
   }
 
   test("bucket coalescing eliminates shuffle") {
-    withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") {
+    withSQLConf(
+      SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
       // The side with bucketedTableTestSpec1 will be coalesced to have 4 output partitions.
       // Currently, sort will be introduced for the side that is coalesced.
       val testSpec1 = BucketedTableTestSpec(
@@ -997,7 +999,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     }
   }
 
-  test("bucket coalescing is applied when join expressions match with partitioning expressions") {
+  test("bucket coalescing is applied when join expressions match with partitioning expressions",
+    DisableAdaptiveExecution("Expected shuffle num mismatched")) {
     withTable("t1", "t2") {
       df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t1")
       df2.write.format("parquet").bucketBy(4, "i", "j").saveAsTable("t2")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala
index 0790135..35dab79 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadWithHiveSupportSuite.scala
@@ -17,12 +17,10 @@
 
 package org.apache.spark.sql.sources
 
-import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 
-class BucketedReadWithHiveSupportSuite
-  extends BucketedReadSuite with DisableAdaptiveExecutionSuite with TestHiveSingleton {
+class BucketedReadWithHiveSupportSuite extends BucketedReadSuite with TestHiveSingleton {
   protected override def beforeAll(): Unit = {
     super.beforeAll()
     assert(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org