You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/06/16 06:25:37 UTC

[spark] branch master updated: [SPARK-39488][SQL] Simplify the error handling of TempResolvedColumn

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d86d41b53c [SPARK-39488][SQL] Simplify the error handling of TempResolvedColumn
6d86d41b53c is described below

commit 6d86d41b53c338a1897b27668eb22623383828bb
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Thu Jun 16 14:25:17 2022 +0800

    [SPARK-39488][SQL] Simplify the error handling of TempResolvedColumn
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/35404 and https://github.com/apache/spark/pull/36746 , to simplify the error handling of `TempResolvedColumn`.
    
    The idea is:
    1. The rule `ResolveAggregationFunctions` in the main resolution batch creates `TempResolvedColumn` and only removes it if the aggregate expression is fully resolved. It either strips `TempResolvedColumn` if it's inside an aggregate function or a grouping expression, or restores `TempResolvedColumn` to `UnresolvedAttribute` otherwise, hoping other rules can resolve it.
    2. The rule `RemoveTempResolvedColumn` in a later batch can still hit `TempResolvedColumn` if the aggregate expression is unresolved (due to an input type mismatch, for example `avg(bool_col)` or `date_add(int_group_col, 1)`). At this stage, there is no way to restore `TempResolvedColumn` to `UnresolvedAttribute` and resolve it differently. The query will fail, so we should blindly strip `TempResolvedColumn` to provide a better error message.
    ### Why are the changes needed?
    
    code cleanup
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    existing tests
    
    Closes #36809 from cloud-fan/error.
    
    Lead-authored-by: Wenchen Fan <we...@databricks.com>
    Co-authored-by: Wenchen Fan <cl...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 65 ++++++++++++----------
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 17 +-----
 .../sql/catalyst/analysis/AnalysisSuite.scala      |  2 +-
 .../src/test/resources/sql-tests/inputs/having.sql |  3 +
 .../resources/sql-tests/results/having.sql.out     |  9 +++
 5 files changed, 49 insertions(+), 47 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 446bc46d9b1..9fe9d490539 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -28,7 +28,6 @@ import scala.util.{Failure, Random, Success, Try}
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.DATA_TYPE_MISMATCH_ERROR_MESSAGE
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
@@ -2647,10 +2646,6 @@ class Analyzer(override val catalogManager: CatalogManager)
       (extraAggExprs.toSeq, transformed)
     }
 
-    private def trimTempResolvedField(input: Expression): Expression = input.transform {
-      case t: TempResolvedColumn => t.child
-    }
-
     private def buildAggExprList(
         expr: Expression,
         agg: Aggregate,
@@ -2666,12 +2661,12 @@ class Analyzer(override val catalogManager: CatalogManager)
       } else {
         expr match {
           case ae: AggregateExpression =>
-            val cleaned = trimTempResolvedField(ae)
+            val cleaned = RemoveTempResolvedColumn.trimTempResolvedColumn(ae)
             val alias = Alias(cleaned, cleaned.toString)()
             aggExprList += alias
             alias.toAttribute
           case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) =>
-            trimTempResolvedField(grouping) match {
+            RemoveTempResolvedColumn.trimTempResolvedColumn(grouping) match {
               case ne: NamedExpression =>
                 aggExprList += ne
                 ne.toAttribute
@@ -2683,7 +2678,7 @@ class Analyzer(override val catalogManager: CatalogManager)
           case t: TempResolvedColumn =>
             // Undo the resolution as this column is neither inside aggregate functions nor a
             // grouping column. It shouldn't be resolved with `agg.child.output`.
-            CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
+            RemoveTempResolvedColumn.restoreTempResolvedColumn(t)
           case other =>
             other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList)))
         }
@@ -4345,32 +4340,42 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
 }
 
 /**
- * Removes all [[TempResolvedColumn]]s in the query plan. This is the last resort, in case some
- * rules in the main resolution batch miss to remove [[TempResolvedColumn]]s. We should run this
- * rule right after the main resolution batch.
+ * The rule `ResolveAggregationFunctions` in the main resolution batch creates
+ * [[TempResolvedColumn]] in filter conditions and sort expressions to hold the temporarily resolved
+ * column with `agg.child`. When filter conditions or sort expressions are resolved,
+ * `ResolveAggregationFunctions` will replace [[TempResolvedColumn]], to [[AttributeReference]] if
+ * it's inside aggregate functions or group expressions, or to [[UnresolvedAttribute]] otherwise,
+ * hoping other rules can resolve it.
+ *
+ * This rule runs after the main resolution batch, and can still hit [[TempResolvedColumn]] if
+ * filter conditions or sort expressions are not resolved. When this happens, there is no point to
+ * turn [[TempResolvedColumn]] to [[UnresolvedAttribute]], as we can't resolve the column
+ * differently, and query will fail. This rule strips all [[TempResolvedColumn]]s in Filter/Sort and
+ * turns them to [[AttributeReference]] so that the error message can tell users why the filter
+ * conditions or sort expressions were not resolved.
  */
 object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.foreachUp {
-      // HAVING clause will be resolved as a Filter. When having func(column with wrong data type),
-      // the column could be wrapped by a TempResolvedColumn, e.g. mean(tempresolvedcolumn(t.c)).
-      // Because TempResolvedColumn can still preserve column data type, here is a chance to check
-      // if the data type matches with the required data type of the function. We can throw an error
-      // when data types mismatches.
-      case operator: Filter =>
-        operator.expressions.foreach(_.foreachUp {
-          case e: Expression if e.childrenResolved && e.checkInputDataTypes().isFailure =>
-            e.checkInputDataTypes() match {
-              case TypeCheckResult.TypeCheckFailure(message) =>
-                e.setTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE, message)
-            }
-          case _ =>
-        })
-      case _ =>
+    plan.resolveOperatorsUp {
+      case f @ Filter(cond, agg: Aggregate) if agg.resolved =>
+        withOrigin(f.origin)(f.copy(condition = trimTempResolvedColumn(cond)))
+      case s @ Sort(sortOrder, _, agg: Aggregate) if agg.resolved =>
+        val newSortOrder = sortOrder.map { order =>
+          trimTempResolvedColumn(order).asInstanceOf[SortOrder]
+        }
+        withOrigin(s.origin)(s.copy(order = newSortOrder))
+      case other => other.transformExpressionsUp {
+        // This should not happen. We restore TempResolvedColumn to UnresolvedAttribute to be safe.
+        case t: TempResolvedColumn => restoreTempResolvedColumn(t)
+      }
     }
+  }
 
-    plan.resolveExpressions {
-      case t: TempResolvedColumn => UnresolvedAttribute(t.nameParts)
-    }
+  def trimTempResolvedColumn(input: Expression): Expression = input.transform {
+    case t: TempResolvedColumn => t.child
+  }
+
+  def restoreTempResolvedColumn(t: TempResolvedColumn): Expression = {
+    CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 45e70bdcb6c..416e3a2b834 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -50,8 +50,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
 
   val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Boolean]("dataTypeMismatchError")
 
-  val DATA_TYPE_MISMATCH_ERROR_MESSAGE = TreeNodeTag[String]("dataTypeMismatchError")
-
   protected def failAnalysis(msg: String): Nothing = {
     throw new AnalysisException(msg)
   }
@@ -176,20 +174,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
             }
         }
 
-        val expressions = getAllExpressions(operator)
-
-        expressions.foreach(_.foreachUp {
-          case e: Expression =>
-            e.getTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE) match {
-              case Some(message) =>
-                e.failAnalysis(s"cannot resolve '${e.sql}' due to data type mismatch: $message" +
-                  extraHintForAnsiTypeCoercionExpression(operator))
-              case _ =>
-            }
-          case _ =>
-        })
-
-        expressions.foreach(_.foreachUp {
+        getAllExpressions(operator).foreach(_.foreachUp {
           case a: Attribute if !a.resolved =>
             val missingCol = a.sql
             val candidates = operator.inputSet.toSeq.map(_.qualifiedName)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a6e952fd865..5c3f4b5f558 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1172,7 +1172,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
            |FROM t
            |GROUP BY t.c, t.d
            |HAVING ${func}(c) > 0d""".stripMargin),
-        Seq(s"cannot resolve '$func(c)' due to data type mismatch"),
+        Seq(s"cannot resolve '$func(t.c)' due to data type mismatch"),
         false)
     }
   }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 2799b1a94d0..056b99e363d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -11,6 +11,9 @@ SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2;
 -- having condition contains grouping column
 SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2;
 
+-- invalid having condition contains grouping column
+SELECT count(k) FROM hav GROUP BY v HAVING v = array(1);
+
 -- SPARK-11032: resolve having correctly
 SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
 
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index fff470b3d81..e9e24562d1b 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -29,6 +29,15 @@ struct<count(k):bigint>
 1
 
 
+-- !query
+SELECT count(k) FROM hav GROUP BY v HAVING v = array(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(hav.v = array(1))' due to data type mismatch: differing types in '(hav.v = array(1))' (int and array<int>).; line 1 pos 43
+
+
 -- !query
 SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0)
 -- !query schema


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org