You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/01/10 12:08:09 UTC

[spark] branch master updated: [SPARK-35442][SQL] Support propagate empty relation through aggregate/union

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ee4c4e5  [SPARK-35442][SQL] Support propagate empty relation through aggregate/union
ee4c4e5 is described below

commit ee4c4e5162f61b989a71f8f9d153845ee5e77a88
Author: ulysses-you <ul...@gmail.com>
AuthorDate: Mon Jan 10 20:07:01 2022 +0800

    [SPARK-35442][SQL] Support propagate empty relation through aggregate/union
    
    ### What changes were proposed in this pull request?
    
    - Add `LogicalQueryStage(_, agg: BaseAggregateExec)` check in `AQEPropagateEmptyRelation`
    - Add a `LeafNode` check in `PropagateEmptyRelationBase`, so that an empty `LogicalQueryStage` can be replaced by an empty `LocalRelation`
    - Merge `applyFunc` (previously in `PropagateEmptyRelation`) into `commonApplyFunc` in `PropagateEmptyRelationBase`, so both the normal and the AQE rule share the same cases
    
    ### Why are the changes needed?
    
    The Aggregate in AQE differs from other operators: its `LogicalQueryStage` looks like `LogicalQueryStage(Aggregate, BaseAggregateExec)`, so this case needs to be handled specially.
    
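    Logically, if the Aggregate has at least one grouping expression and its input is empty, it produces no rows and can be eliminated safely. An Aggregate without grouping expressions still returns one row on empty input, so it must be kept.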
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Added new tests in `AdaptiveQueryExecSuite`:
    - `Support propagate empty relation through aggregate`
    - `Support propagate empty relation through union`
    
    Closes #35149 from ulysses-you/SPARK-35442-GA-SPARK.
    
    Authored-by: ulysses-you <ul...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../optimizer/PropagateEmptyRelation.scala         | 84 ++++++++++------------
 .../adaptive/AQEPropagateEmptyRelation.scala       | 21 +++++-
 .../adaptive/AdaptiveQueryExecSuite.scala          | 56 ++++++++++++++-
 3 files changed, 110 insertions(+), 51 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 6ad0793..d02f12d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -28,14 +28,17 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_
 /**
  * The base class of two rules in the normal and AQE Optimizer. It simplifies query plans with
  * empty or non-empty relations:
- *  1. Binary-node Logical Plans
+ *  1. Higher-node Logical Plans
+ *     - Union with all empty children.
+ *  2. Binary-node Logical Plans
  *     - Join with one or two empty children (including Intersect/Except).
  *     - Left semi Join
  *       Right side is non-empty and condition is empty. Eliminate join to its left side.
  *     - Left anti join
  *       Right side is non-empty and condition is empty. Eliminate join to an empty
  *       [[LocalRelation]].
- *  2. Unary-node Logical Plans
+ *  3. Unary-node Logical Plans
+ *     - Project/Filter/Sample with all empty children.
  *     - Limit/Repartition with all empty children.
  *     - Aggregate with all empty children and at least one grouping expression.
  *     - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
@@ -59,6 +62,31 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
     plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) }
 
   protected def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = {
+    case p: Union if p.children.exists(isEmpty) =>
+      val newChildren = p.children.filterNot(isEmpty)
+      if (newChildren.isEmpty) {
+        empty(p)
+      } else {
+        val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head
+        val outputs = newPlan.output.zip(p.output)
+        // the original Union may produce different output attributes than the new one so we alias
+        // them if needed
+        if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) {
+          newPlan
+        } else {
+          val newOutput = outputs.map { case (newAttr, oldAttr) =>
+            if (newAttr.exprId == oldAttr.exprId) {
+              newAttr
+            } else {
+              val newExplicitMetadata =
+                if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None
+              Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata)
+            }
+          }
+          Project(newOutput, newPlan)
+        }
+      }
+
     // Joins on empty LocalRelations generated from streaming sources are not eliminated
     // as stateful streaming joins need to perform other state management operations other than
     // just processing the input data.
@@ -98,7 +126,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
         p
       }
 
+    // the only case can be matched here is that LogicalQueryStage is empty
+    case p: LeafNode if !p.isInstanceOf[LocalRelation] && isEmpty(p) => empty(p)
+
     case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmpty) => p match {
+      case _: Project => empty(p)
+      case _: Filter => empty(p)
+      case _: Sample => empty(p)
       case _: Sort => empty(p)
       case _: GlobalLimit if !p.isStreaming => empty(p)
       case _: LocalLimit if !p.isStreaming => empty(p)
@@ -128,53 +162,11 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
 }
 
 /**
- * This rule runs in the normal optimizer and optimizes more cases
- * compared to [[PropagateEmptyRelationBase]]:
- * 1. Higher-node Logical Plans
- *    - Union with all empty children.
- * 2. Unary-node Logical Plans
- *    - Project/Filter/Sample with all empty children.
- *
- * The reason why we don't apply this rule at AQE optimizer side is: the benefit is not big enough
- * and it may introduce extra exchanges.
+ * This rule runs in the normal optimizer
  */
 object PropagateEmptyRelation extends PropagateEmptyRelationBase {
-  private def applyFunc: PartialFunction[LogicalPlan, LogicalPlan] = {
-    case p: Union if p.children.exists(isEmpty) =>
-      val newChildren = p.children.filterNot(isEmpty)
-      if (newChildren.isEmpty) {
-        empty(p)
-      } else {
-        val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head
-        val outputs = newPlan.output.zip(p.output)
-        // the original Union may produce different output attributes than the new one so we alias
-        // them if needed
-        if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) {
-          newPlan
-        } else {
-          val outputAliases = outputs.map { case (newAttr, oldAttr) =>
-            val newExplicitMetadata =
-              if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None
-            Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata)
-          }
-          Project(outputAliases, newPlan)
-        }
-      }
-
-    case p: UnaryNode if p.children.nonEmpty && p.children.forall(isEmpty) && canPropagate(p) =>
-      empty(p)
-  }
-
-  // extract the pattern avoid conflict with commonApplyFunc
-  private def canPropagate(plan: LogicalPlan): Boolean = plan match {
-    case _: Project => true
-    case _: Filter => true
-    case _: Sample => true
-    case _ => false
-  }
-
   override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
-    applyFunc.orElse(commonApplyFunc)
+    commonApplyFunc
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
index ea2fb1c..bab7751 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBase
 import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL}
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
 
 /**
@@ -32,14 +33,28 @@ import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
  */
 object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
   override protected def isEmpty(plan: LogicalPlan): Boolean =
-    super.isEmpty(plan) || getRowCount(plan).contains(0)
+    super.isEmpty(plan) || getEstimatedRowCount(plan).contains(0)
 
   override protected def nonEmpty(plan: LogicalPlan): Boolean =
-    super.nonEmpty(plan) || getRowCount(plan).exists(_ > 0)
+    super.nonEmpty(plan) || getEstimatedRowCount(plan).exists(_ > 0)
 
-  private def getRowCount(plan: LogicalPlan): Option[BigInt] = plan match {
+  // The returned value follows:
+  //   - 0 means the plan must produce 0 row
+  //   - positive value means an estimated row count which can be over-estimated
+  //   - none means the plan has not materialized or the plan can not be estimated
+  private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match {
     case LogicalQueryStage(_, stage: QueryStageExec) if stage.isMaterialized =>
       stage.getRuntimeStatistics.rowCount
+
+    case LogicalQueryStage(_, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty &&
+      agg.child.isInstanceOf[QueryStageExec] =>
+      val stage = agg.child.asInstanceOf[QueryStageExec]
+      if (stage.isMaterialized) {
+        stage.getRuntimeStatistics.rowCount
+      } else {
+        None
+      }
+
     case _ => None
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 51d4767..a29989c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1406,6 +1406,56 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  test("SPARK-35442: Support propagate empty relation through aggregate") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult(
+        "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key")
+      assert(!plan1.isInstanceOf[LocalTableScanExec])
+      assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec])
+
+      val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult(
+        "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key limit 1")
+      assert(!plan2.isInstanceOf[LocalTableScanExec])
+      assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec])
+
+      val (plan3, adaptivePlan3) = runAdaptiveAndVerifyResult(
+        "SELECT count(*) FROM testData WHERE value = 'no_match'")
+      assert(!plan3.isInstanceOf[LocalTableScanExec])
+      assert(!stripAQEPlan(adaptivePlan3).isInstanceOf[LocalTableScanExec])
+    }
+  }
+
+  test("SPARK-35442: Support propagate empty relation through union") {
+    def checkNumUnion(plan: SparkPlan, numUnion: Int): Unit = {
+      assert(
+        collect(plan) {
+          case u: UnionExec => u
+        }.size == numUnion)
+    }
+
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key
+          |UNION ALL
+          |SELECT key, 1 FROM testData
+          |""".stripMargin)
+      checkNumUnion(plan1, 1)
+      checkNumUnion(adaptivePlan1, 0)
+      assert(!stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec])
+
+      val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key
+          |UNION ALL
+          |SELECT /*+ REPARTITION */ key, 1 FROM testData WHERE value = 'no_match'
+          |""".stripMargin)
+      checkNumUnion(plan2, 1)
+      checkNumUnion(adaptivePlan2, 0)
+      assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec])
+    }
+  }
+
   test("SPARK-32753: Only copy tags to node with no tags") {
     withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
       withTempView("v1") {
@@ -1794,7 +1844,8 @@ class AdaptiveQueryExecSuite
   test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") {
     withTable("t") {
       withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
-        SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+        SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+        SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
         spark.sql("CREATE TABLE t (c1 int) USING PARQUET")
         val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1")
         assert(
@@ -2261,7 +2312,8 @@ class AdaptiveQueryExecSuite
   test("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly plans BHJ") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584") {
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584",
+      SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
       // Spark estimates a string column as 20 bytes so with 60k rows, these relations should be
       // estimated at ~120m bytes which is greater than the broadcast join threshold.
       val joinKeyOne = "00112233445566778899"

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org