Posted to commits@spark.apache.org by we...@apache.org on 2021/08/09 08:49:56 UTC

[spark] branch branch-3.2 updated: [SPARK-36352][SQL] Spark should check result plan's output schema name

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new a5ecf2a  [SPARK-36352][SQL] Spark should check result plan's output schema name
a5ecf2a is described below

commit a5ecf2a490727fec97790b149f59bdc498b445be
Author: Angerszhuuuu <an...@gmail.com>
AuthorDate: Mon Aug 9 16:47:56 2021 +0800

    [SPARK-36352][SQL] Spark should check result plan's output schema name
    
    ### What changes were proposed in this pull request?
    Check that the optimizer's result plan keeps the original plan's output schema names, and fix `SchemaPruning` and `V2ScanRelationPushDown` to restore the original output names.
    
    ### Why are the changes needed?
    In the current code, some optimizer rules may change a plan's output schema, because we always use semantic equality to check the output, which ignores attribute names.
    For example, with SchemaPruning, if we have the plan
    ```
    Project[a, B]
    |--Scan[A, b, c]
    ```
    the original output schema is `a, B`; after SchemaPruning it becomes
    ```
    Project[A, b]
    |--Scan[A, b]
    ```
    This changes the plan's schema. With CTAS, the table schema is the same as the query plan's output, so once the schema is changed it is no longer consistent with the original SQL. We therefore need to check the final result plan's schema against the original plan's schema.
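    
    Below is a minimal, self-contained Scala sketch of the idea behind the new check (illustration only, not the rule's exact code; the real check calls `DataType.equalsIgnoreNullability`, now widened to `private[sql]`, from `isPlanIntegral`): field names and types of the previous and optimized plan schemas must match, while nullability differences are ignored.
    ```
    import org.apache.spark.sql.types._

    // Illustration only: compare two schemas by field name and type, ignoring nullability.
    def sameIgnoringNullability(left: DataType, right: DataType): Boolean =
      (left, right) match {
        case (ArrayType(l, _), ArrayType(r, _)) => sameIgnoringNullability(l, r)
        case (MapType(lk, lv, _), MapType(rk, rv, _)) =>
          sameIgnoringNullability(lk, rk) && sameIgnoringNullability(lv, rv)
        case (StructType(ls), StructType(rs)) =>
          ls.length == rs.length && ls.zip(rs).forall { case (l, r) =>
            l.name == r.name && sameIgnoringNullability(l.dataType, r.dataType)
          }
        case _ => left == right
      }

    val before = StructType(Seq(StructField("a", StringType), StructField("B", IntegerType)))
    val after  = StructType(Seq(StructField("A", StringType), StructField("b", IntegerType)))

    // Same types, but the attribute names changed casing, so the check fails.
    assert(!sameIgnoringNullability(before, after))
    ```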
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UTs, plus a new test added in `SchemaPruningSuite`.
    
    Closes #33583 from AngersZhuuuu/SPARK-36352.
    
    Authored-by: Angerszhuuuu <an...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit e051a540a10cdda42dc86a6195c0357aea8900e4)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  6 ++++--
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 22 +++++++++-------------
 .../spark/sql/catalyst/rules/RuleExecutor.scala    |  6 +++---
 .../org/apache/spark/sql/types/DataType.scala      |  2 +-
 .../org/apache/spark/sql/util/SchemaUtils.scala    | 11 +++++++++++
 .../sql/catalyst/trees/RuleExecutorSuite.scala     |  8 ++++++--
 .../sql/execution/adaptive/AQEOptimizer.scala      | 12 ++++++++----
 .../execution/datasources/DataSourceStrategy.scala |  2 +-
 .../sql/execution/datasources/SchemaPruning.scala  | 10 ++++++----
 .../datasources/v2/V2ScanRelationPushDown.scala    |  3 ++-
 .../execution/datasources/SchemaPruningSuite.scala | 12 ++++++++++++
 11 files changed, 63 insertions(+), 31 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 963b42b..b6228d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -174,8 +174,10 @@ class Analyzer(override val catalogManager: CatalogManager)
 
   private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
 
-  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
-    !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)
+  override protected def isPlanIntegral(
+      previousPlan: LogicalPlan,
+      currentPlan: LogicalPlan): Boolean = {
+    !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan)
   }
 
   override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 369fb51..40b4c01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils._
 import org.apache.spark.util.Utils
 
 /**
@@ -46,10 +47,14 @@ abstract class Optimizer(catalogManager: CatalogManager)
   // - is still resolved
   // - only host special expressions in supported operators
   // - has globally-unique attribute IDs
-  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
-    !Utils.isTesting || (plan.resolved &&
-      plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
-      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
+  // - the optimized plan has the same schema as the previous plan.
+  override protected def isPlanIntegral(
+      previousPlan: LogicalPlan,
+      currentPlan: LogicalPlan): Boolean = {
+    !Utils.isTesting || (currentPlan.resolved &&
+      currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
+      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) &&
+      DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema))
   }
 
   override protected val excludedOnceBatches: Set[String] =
@@ -515,15 +520,6 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
  * Remove no-op operators from the query plan that do not make any modifications.
  */
 object RemoveNoopOperators extends Rule[LogicalPlan] {
-  def restoreOriginalOutputNames(
-      projectList: Seq[NamedExpression],
-      originalNames: Seq[String]): Seq[NamedExpression] = {
-    projectList.zip(originalNames).map {
-      case (attr: Attribute, name) => attr.withName(name)
-      case (alias: Alias, name) => alias.withName(name)
-      case (other, _) => other
-    }
-  }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsAnyPattern(PROJECT, WINDOW), ruleId) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 17d7794..759eba6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -156,7 +156,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
    * `Optimizer`, so we can catch rules that return invalid plans. The check function returns
    * `false` if the given plan doesn't pass the structural integrity check.
    */
-  protected def isPlanIntegral(plan: TreeType): Boolean = true
+  protected def isPlanIntegral(previousPlan: TreeType, currentPlan: TreeType): Boolean = true
 
   /**
    * Util method for checking whether a plan remains the same if re-optimized.
@@ -192,7 +192,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
     val beforeMetrics = RuleExecutor.getCurrentMetrics()
 
     // Run the structural integrity checker against the initial input
-    if (!isPlanIntegral(plan)) {
+    if (!isPlanIntegral(plan, plan)) {
       throw QueryExecutionErrors.structuralIntegrityOfInputPlanIsBrokenInClassError(
         this.getClass.getName.stripSuffix("$"))
     }
@@ -224,7 +224,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
             tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective))
 
             // Run the structural integrity checker against the plan after each rule.
-            if (effective && !isPlanIntegral(result)) {
+            if (effective && !isPlanIntegral(plan, result)) {
               throw QueryExecutionErrors.structuralIntegrityIsBrokenAfterApplyingRuleError(
                 rule.ruleName, batch.name)
             }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 585045d..ef1aeec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -292,7 +292,7 @@ object DataType {
   /**
    * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
    */
-  private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+  private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
     (left, right) match {
       case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
         equalsIgnoreNullability(leftElementType, rightElementType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
index da105af..63c1f18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
@@ -21,6 +21,7 @@ import java.util.Locale
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
 import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
@@ -273,6 +274,16 @@ private[spark] object SchemaUtils {
     field._1
   }
 
+  def restoreOriginalOutputNames(
+      projectList: Seq[NamedExpression],
+      originalNames: Seq[String]): Seq[NamedExpression] = {
+    projectList.zip(originalNames).map {
+      case (attr: Attribute, name) => attr.withName(name)
+      case (alias: Alias, name) => alias.withName(name)
+      case (other, _) => other
+    }
+  }
+
   /**
    * @param str The string to be escaped.
    * @return The escaped string.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
index 25352e2..b14686b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
@@ -73,7 +73,9 @@ class RuleExecutorSuite extends SparkFunSuite {
 
   test("structural integrity checker - verify initial input") {
     object WithSIChecker extends RuleExecutor[Expression] {
-      override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
+      override protected def isPlanIntegral(
+          previousPlan: Expression,
+          currentPlan: Expression): Boolean = currentPlan match {
         case IntegerLiteral(_) => true
         case _ => false
       }
@@ -91,7 +93,9 @@ class RuleExecutorSuite extends SparkFunSuite {
 
   test("structural integrity checker - verify rule execution result") {
     object WithSICheckerForPositiveLiteral extends RuleExecutor[Expression] {
-      override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
+      override protected def isPlanIntegral(
+          previousPlan: Expression,
+          currentPlan: Expression): Boolean = currentPlan match {
         case IntegerLiteral(i) if i > 0 => true
         case _ => false
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
index 0767039..f8cba90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
 
 /**
@@ -64,9 +65,12 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
     }
   }
 
-  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
-    !Utils.isTesting || (plan.resolved &&
-      plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
-      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
+  override protected def isPlanIntegral(
+      previousPlan: LogicalPlan,
+      currentPlan: LogicalPlan): Boolean = {
+    !Utils.isTesting || (currentPlan.resolved &&
+      currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
+      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) &&
+      DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 81ecb2c..11d23f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -464,7 +464,7 @@ object DataSourceStrategy
    */
   protected[sql] def normalizeExprs(
       exprs: Seq[Expression],
-      attributes: Seq[AttributeReference]): Seq[Expression] = {
+      attributes: Seq[Attribute]): Seq[Expression] = {
     exprs.map { e =>
       e transform {
         case a: AttributeReference =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index a197445..4f331c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.util.SchemaUtils._
 
 /**
  * Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation.
@@ -82,8 +83,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
         val prunedRelation = leafNodeBuilder(prunedDataSchema)
         val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
 
-        Some(buildNewProjection(normalizedProjects, normalizedFilters, prunedRelation,
-          projectionOverSchema))
+        Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
+          prunedRelation, projectionOverSchema))
       } else {
         None
       }
@@ -125,6 +126,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
    */
   private def buildNewProjection(
       projects: Seq[NamedExpression],
+      normalizedProjects: Seq[NamedExpression],
       filters: Seq[Expression],
       leafNode: LeafNode,
       projectionOverSchema: ProjectionOverSchema): Project = {
@@ -143,7 +145,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
 
     // Construct the new projections of our Project by
     // rewriting the original projections
-    val newProjects = projects.map(_.transformDown {
+    val newProjects = normalizedProjects.map(_.transformDown {
       case projectionOverSchema(expr) => expr
     }).map { case expr: NamedExpression => expr }
 
@@ -151,7 +153,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
       logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}")
     }
 
-    Project(newProjects, projectionChild)
+    Project(restoreOriginalOutputNames(newProjects, projects.map(_.name)), projectionChild)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index d05519b..ab5a0fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownA
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.SchemaUtils._
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
@@ -207,7 +208,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         val newProjects = normalizedProjects
           .map(projectionFunc)
           .asInstanceOf[Seq[NamedExpression]]
-        Project(newProjects, withFilter)
+        Project(restoreOriginalOutputNames(newProjects, project.map(_.name)), withFilter)
       } else {
         withFilter
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index ac5c289..395ee6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -870,4 +870,16 @@ abstract class SchemaPruningSuite
       checkAnswer(query, Row(1) :: Row(2) :: Nil)
     }
   }
+
+  test("SPARK-36352: Spark should check result plan's output schema name") {
+    withMixedCaseData {
+      val query = sql("select cOL1, cOl2.B from mixedcase")
+      assert(query.queryExecution.executedPlan.schema.catalogString ==
+        "struct<cOL1:string,B:int>")
+      checkAnswer(query.orderBy("id"),
+        Row("r0c1", 1) ::
+          Row("r1c1", 2) ::
+          Nil)
+    }
+  }
 }
