Posted to commits@spark.apache.org by do...@apache.org on 2020/07/24 18:07:41 UTC

[spark] branch branch-3.0 updated: [SPARK-32430][SQL] Extend SparkSessionExtensions to inject rules into AQE query stage preparation

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 7004c98  [SPARK-32430][SQL] Extend SparkSessionExtensions to inject rules into AQE query stage preparation
7004c98 is described below

commit 7004c989048b08891fb5f62ce2fcf0c89ce1496a
Author: Andy Grove <an...@nvidia.com>
AuthorDate: Fri Jul 24 11:03:57 2020 -0700

    [SPARK-32430][SQL] Extend SparkSessionExtensions to inject rules into AQE query stage preparation
    
    ### What changes were proposed in this pull request?
    
    Provide a generic mechanism for plugins to inject rules into the AQE "query prep" stage that happens before query stage creation.
    
    This goes along with https://issues.apache.org/jira/browse/SPARK-32332, which tracks the fact that the current AQE implementation does not allow users to properly extend it for columnar processing.
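    
    A minimal sketch of how a plugin might use the new extension point (the rule and object names here are illustrative, not part of this commit): register a `Rule[SparkPlan]` through `SparkSession.Builder.withExtensions`, and AQE will run it during query stage preparation.
    
    ```scala
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan
    
    // Hypothetical plugin rule; a real rule would inspect or rewrite the physical plan.
    case class MyPrepRule(session: SparkSession) extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan = plan
    }
    
    object PluginExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .withExtensions { extensions =>
            // New extension point added by this commit: the injected rule runs in the
            // query stage preparation phase, before AQE creates query stages.
            extensions.injectQueryStagePrepRule(session => MyPrepRule(session))
          }
          .getOrCreate()
        spark.stop()
      }
    }
    ```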
    
    ### Why are the changes needed?
    
    The issue is that when AQE creates a new query stage, we no longer have access to the parent plan of that stage, so decisions that depend on what the parent did cannot be made. With this change, a plugin can run a rule before the query stages are created and attach tags to the plan nodes, so later rules can recover that context (see the sketch below).
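    
    For illustration, a sketch of that tagging pattern, mirroring the unit test added in this PR (the tag name and rule name are hypothetical): the prep rule records context on each node while the full plan is still visible, and later rules read it back.
    
    ```scala
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.catalyst.trees.TreeNodeTag
    import org.apache.spark.sql.execution.SparkPlan
    
    object PrepTags {
      // Hypothetical tag carrying information recorded during query stage preparation.
      val parentInfoTag: TreeNodeTag[String] = TreeNodeTag[String]("parentInfo")
    }
    
    case class TagPlanNodesRule() extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
        case node =>
          // Record context before the plan is broken into query stages; a later rule
          // (e.g. a columnar rule) can read it with node.getTagValue(PrepTags.parentInfoTag).
          node.setTagValue(PrepTags.parentInfoTag, node.nodeName)
          node
      }
    }
    ```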
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    A new unit test is included in the PR.
    
    Closes #29224 from andygrove/insert-aqe-rule.
    
    Authored-by: Andy Grove <an...@nvidia.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit 64a01c0a559396fccd615dc00576a80bc8cc5648)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../apache/spark/sql/SparkSessionExtensions.scala  | 20 ++++++++-
 .../execution/adaptive/AdaptiveSparkPlanExec.scala |  2 +-
 .../sql/internal/BaseSessionStateBuilder.scala     |  9 +++-
 .../apache/spark/sql/internal/SessionState.scala   |  4 +-
 .../spark/sql/SparkSessionExtensionSuite.scala     | 49 ++++++++++++++++++++++
 5 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 1c2bf9e..bd870fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.ColumnarRule
+import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
 
 /**
  * :: Experimental ::
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.ColumnarRule
  * <li>Customized Parser.</li>
  * <li>(External) Catalog listeners.</li>
  * <li>Columnar Rules.</li>
+ * <li>Adaptive Query Stage Preparation Rules.</li>
  * </ul>
  *
  * The extensions can be used by calling `withExtensions` on the [[SparkSession.Builder]], for
@@ -96,8 +97,10 @@ class SparkSessionExtensions {
   type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
   type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
   type ColumnarRuleBuilder = SparkSession => ColumnarRule
+  type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
 
   private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+  private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
 
   /**
    * Build the override rules for columnar execution.
@@ -107,12 +110,27 @@ class SparkSessionExtensions {
   }
 
   /**
+   * Build the override rules for the query stage preparation phase of adaptive query execution.
+   */
+  private[sql] def buildQueryStagePrepRules(session: SparkSession): Seq[Rule[SparkPlan]] = {
+    queryStagePrepRuleBuilders.map(_.apply(session)).toSeq
+  }
+
+  /**
    * Inject a rule that can override the columnar execution of an executor.
    */
   def injectColumnar(builder: ColumnarRuleBuilder): Unit = {
     columnarRuleBuilders += builder
   }
 
+  /**
+   * Inject a rule that can override the query stage preparation phase of adaptive query
+   * execution.
+   */
+  def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = {
+    queryStagePrepRuleBuilders += builder
+  }
+
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index f6a3333..5714c33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -90,7 +90,7 @@ case class AdaptiveSparkPlanExec(
   // Exchange nodes) after running these rules.
   private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
     ensureRequirements
-  )
+  ) ++ context.session.sessionState.queryStagePrepRules
 
   // A list of physical optimizer rules to be applied to a new stage before its execution. These
   // optimizations should be stage-independent.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 4ae12f8..83a7a557 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
-import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
 import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.command.CommandCheck
@@ -286,6 +286,10 @@ abstract class BaseSessionStateBuilder(
     extensions.buildColumnarRules(session)
   }
 
+  protected def queryStagePrepRules: Seq[Rule[SparkPlan]] = {
+    extensions.buildQueryStagePrepRules(session)
+  }
+
   /**
    * Create a query execution object.
    */
@@ -337,7 +341,8 @@ abstract class BaseSessionStateBuilder(
       () => resourceLoader,
       createQueryExecution,
       createClone,
-      columnarRules)
+      columnarRules,
+      queryStagePrepRules)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index abd1250..cd425b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.streaming.StreamingQueryManager
@@ -73,7 +74,8 @@ private[sql] class SessionState(
     resourceLoaderBuilder: () => SessionResourceLoader,
     createQueryExecution: LogicalPlan => QueryExecution,
     createClone: (SparkSession, SessionState) => SessionState,
-    val columnarRules: Seq[ColumnarRule]) {
+    val columnarRules: Seq[ColumnarRule],
+    val queryStagePrepRules: Seq[Rule[SparkPlan]]) {
 
   // The following fields are lazy to avoid creating the Hive client when creating SessionState.
   lazy val catalog: SessionCatalog = catalogBuilder()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index d9c90c7..44e784d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -26,7 +26,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnresolvedHint}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
@@ -145,6 +147,28 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
+  test("inject adaptive query prep rule") {
+    val extensions = create { extensions =>
+      // inject rule that will run during AQE query stage preparation and will add custom tags
+      // to the plan
+      extensions.injectQueryStagePrepRule(session => MyQueryStagePrepRule())
+      // inject rule that will run during AQE query stage optimization and will verify that the
+      // custom tags were written in the preparation phase
+      extensions.injectColumnar(session =>
+        MyColumarRule(MyNewQueryStageRule(), MyNewQueryStageRule()))
+    }
+    withSession(extensions) { session =>
+      session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
+      assert(session.sessionState.queryStagePrepRules.contains(MyQueryStagePrepRule()))
+      assert(session.sessionState.columnarRules.contains(
+        MyColumarRule(MyNewQueryStageRule(), MyNewQueryStageRule())))
+      import session.sqlContext.implicits._
+      val data = Seq((100L), (200L), (300L)).toDF("vals").repartition(1)
+      val df = data.selectExpr("vals + 1")
+      df.collect()
+    }
+  }
+
   test("inject columnar") {
     val extensions = create { extensions =>
       extensions.injectColumnar(session =>
@@ -731,6 +755,31 @@ class MyExtensions extends (SparkSessionExtensions => Unit) {
   }
 }
 
+object QueryPrepRuleHelper {
+  val myPrepTag: TreeNodeTag[String] = TreeNodeTag[String]("myPrepTag")
+  val myPrepTagValue: String = "myPrepTagValue"
+}
+
+// this rule will run during AQE query preparation and will write custom tags to each node
+case class MyQueryStagePrepRule() extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
+    case plan =>
+      plan.setTagValue(QueryPrepRuleHelper.myPrepTag, QueryPrepRuleHelper.myPrepTagValue)
+      plan
+  }
+}
+
+// this rule will run during AQE query stage optimization and will verify custom tags were
+// already written during query preparation phase
+case class MyNewQueryStageRule() extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
+    case plan if !plan.isInstanceOf[AdaptiveSparkPlanExec] =>
+      assert(plan.getTagValue(QueryPrepRuleHelper.myPrepTag).get ==
+          QueryPrepRuleHelper.myPrepTagValue)
+      plan
+  }
+}
+
 case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org