Posted to commits@spark.apache.org by we...@apache.org on 2022/06/16 06:25:37 UTC
[spark] branch master updated: [SPARK-39488][SQL] Simplify the error handling of TempResolvedColumn
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6d86d41b53c [SPARK-39488][SQL] Simplify the error handling of TempResolvedColumn
6d86d41b53c is described below
commit 6d86d41b53c338a1897b27668eb22623383828bb
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Thu Jun 16 14:25:17 2022 +0800
[SPARK-39488][SQL] Simplify the error handling of TempResolvedColumn
### What changes were proposed in this pull request?
This is a follow-up of https://github.com/apache/spark/pull/35404 and https://github.com/apache/spark/pull/36746, to simplify the error handling of `TempResolvedColumn`.
The idea is:
1. The rule `ResolveAggregationFunctions` in the main resolution batch creates `TempResolvedColumn` and only removes it once the aggregate expression is fully resolved. It either strips `TempResolvedColumn` if it is inside an aggregate function or a grouping expression, or otherwise restores it to `UnresolvedAttribute`, hoping other rules can resolve it.
2. The rule `RemoveTempResolvedColumn` in a later batch can still hit `TempResolvedColumn` if the aggregate expression is unresolved (for example, due to an input type mismatch such as `avg(bool_col)` or `date_add(int_group_col, 1)`). At this stage, there is no way to restore `TempResolvedColumn` to `UnresolvedAttribute` and resolve it differently. The query will fail anyway, so we should blindly strip `TempResolvedColumn` to provide a better error message, as illustrated below.
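A minimal SQL illustration of case 2, taken from the new test added to `having.sql` in this diff (it relies on the `hav` test view defined in that file, whose grouping column `v` is an int):
```sql
-- `v` is an int, so `v = array(1)` can never be resolved. TempResolvedColumn
-- survives the main resolution batch, RemoveTempResolvedColumn strips it, and
-- CheckAnalysis can then report a data type mismatch instead of a confusing
-- unresolved-column error.
SELECT count(k) FROM hav GROUP BY v HAVING v = array(1);

-- Expected failure (see having.sql.out in the diff below):
-- cannot resolve '(hav.v = array(1))' due to data type mismatch:
-- differing types in '(hav.v = array(1))' (int and array<int>)
```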
### Why are the changes needed?
code cleanup
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
Closes #36809 from cloud-fan/error.
Lead-authored-by: Wenchen Fan <we...@databricks.com>
Co-authored-by: Wenchen Fan <cl...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 65 ++++++++++++----------
.../sql/catalyst/analysis/CheckAnalysis.scala | 17 +-----
.../sql/catalyst/analysis/AnalysisSuite.scala | 2 +-
.../src/test/resources/sql-tests/inputs/having.sql | 3 +
.../resources/sql-tests/results/having.sql.out | 9 +++
5 files changed, 49 insertions(+), 47 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 446bc46d9b1..9fe9d490539 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -28,7 +28,6 @@ import scala.util.{Failure, Random, Success, Try}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.DATA_TYPE_MISMATCH_ERROR_MESSAGE
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
@@ -2647,10 +2646,6 @@ class Analyzer(override val catalogManager: CatalogManager)
(extraAggExprs.toSeq, transformed)
}
- private def trimTempResolvedField(input: Expression): Expression = input.transform {
- case t: TempResolvedColumn => t.child
- }
-
private def buildAggExprList(
expr: Expression,
agg: Aggregate,
@@ -2666,12 +2661,12 @@ class Analyzer(override val catalogManager: CatalogManager)
} else {
expr match {
case ae: AggregateExpression =>
- val cleaned = trimTempResolvedField(ae)
+ val cleaned = RemoveTempResolvedColumn.trimTempResolvedColumn(ae)
val alias = Alias(cleaned, cleaned.toString)()
aggExprList += alias
alias.toAttribute
case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) =>
- trimTempResolvedField(grouping) match {
+ RemoveTempResolvedColumn.trimTempResolvedColumn(grouping) match {
case ne: NamedExpression =>
aggExprList += ne
ne.toAttribute
@@ -2683,7 +2678,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case t: TempResolvedColumn =>
// Undo the resolution as this column is neither inside aggregate functions nor a
// grouping column. It shouldn't be resolved with `agg.child.output`.
- CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
+ RemoveTempResolvedColumn.restoreTempResolvedColumn(t)
case other =>
other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList)))
}
@@ -4345,32 +4340,42 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
}
/**
- * Removes all [[TempResolvedColumn]]s in the query plan. This is the last resort, in case some
- * rules in the main resolution batch miss to remove [[TempResolvedColumn]]s. We should run this
- * rule right after the main resolution batch.
+ * The rule `ResolveAggregationFunctions` in the main resolution batch creates
+ * [[TempResolvedColumn]] in filter conditions and sort expressions to hold the temporarily resolved
+ * column with `agg.child`. When filter conditions or sort expressions are resolved,
+ * `ResolveAggregationFunctions` will replace [[TempResolvedColumn]], to [[AttributeReference]] if
+ * it's inside aggregate functions or group expressions, or to [[UnresolvedAttribute]] otherwise,
+ * hoping other rules can resolve it.
+ *
+ * This rule runs after the main resolution batch, and can still hit [[TempResolvedColumn]] if
+ * filter conditions or sort expressions are not resolved. When this happens, there is no point to
+ * turn [[TempResolvedColumn]] to [[UnresolvedAttribute]], as we can't resolve the column
+ * differently, and query will fail. This rule strips all [[TempResolvedColumn]]s in Filter/Sort and
+ * turns them to [[AttributeReference]] so that the error message can tell users why the filter
+ * conditions or sort expressions were not resolved.
*/
object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
- plan.foreachUp {
- // HAVING clause will be resolved as a Filter. When having func(column with wrong data type),
- // the column could be wrapped by a TempResolvedColumn, e.g. mean(tempresolvedcolumn(t.c)).
- // Because TempResolvedColumn can still preserve column data type, here is a chance to check
- // if the data type matches with the required data type of the function. We can throw an error
- // when data types mismatches.
- case operator: Filter =>
- operator.expressions.foreach(_.foreachUp {
- case e: Expression if e.childrenResolved && e.checkInputDataTypes().isFailure =>
- e.checkInputDataTypes() match {
- case TypeCheckResult.TypeCheckFailure(message) =>
- e.setTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE, message)
- }
- case _ =>
- })
- case _ =>
+ plan.resolveOperatorsUp {
+ case f @ Filter(cond, agg: Aggregate) if agg.resolved =>
+ withOrigin(f.origin)(f.copy(condition = trimTempResolvedColumn(cond)))
+ case s @ Sort(sortOrder, _, agg: Aggregate) if agg.resolved =>
+ val newSortOrder = sortOrder.map { order =>
+ trimTempResolvedColumn(order).asInstanceOf[SortOrder]
+ }
+ withOrigin(s.origin)(s.copy(order = newSortOrder))
+ case other => other.transformExpressionsUp {
+ // This should not happen. We restore TempResolvedColumn to UnresolvedAttribute to be safe.
+ case t: TempResolvedColumn => restoreTempResolvedColumn(t)
+ }
}
+ }
- plan.resolveExpressions {
- case t: TempResolvedColumn => UnresolvedAttribute(t.nameParts)
- }
+ def trimTempResolvedColumn(input: Expression): Expression = input.transform {
+ case t: TempResolvedColumn => t.child
+ }
+
+ def restoreTempResolvedColumn(t: TempResolvedColumn): Expression = {
+ CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 45e70bdcb6c..416e3a2b834 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -50,8 +50,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Boolean]("dataTypeMismatchError")
- val DATA_TYPE_MISMATCH_ERROR_MESSAGE = TreeNodeTag[String]("dataTypeMismatchError")
-
protected def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg)
}
@@ -176,20 +174,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
}
}
- val expressions = getAllExpressions(operator)
-
- expressions.foreach(_.foreachUp {
- case e: Expression =>
- e.getTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE) match {
- case Some(message) =>
- e.failAnalysis(s"cannot resolve '${e.sql}' due to data type mismatch: $message" +
- extraHintForAnsiTypeCoercionExpression(operator))
- case _ =>
- }
- case _ =>
- })
-
- expressions.foreach(_.foreachUp {
+ getAllExpressions(operator).foreach(_.foreachUp {
case a: Attribute if !a.resolved =>
val missingCol = a.sql
val candidates = operator.inputSet.toSeq.map(_.qualifiedName)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a6e952fd865..5c3f4b5f558 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1172,7 +1172,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
|FROM t
|GROUP BY t.c, t.d
|HAVING ${func}(c) > 0d""".stripMargin),
- Seq(s"cannot resolve '$func(c)' due to data type mismatch"),
+ Seq(s"cannot resolve '$func(t.c)' due to data type mismatch"),
false)
}
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 2799b1a94d0..056b99e363d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -11,6 +11,9 @@ SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2;
-- having condition contains grouping column
SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2;
+-- invalid having condition contains grouping column
+SELECT count(k) FROM hav GROUP BY v HAVING v = array(1);
+
-- SPARK-11032: resolve having correctly
SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index fff470b3d81..e9e24562d1b 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -29,6 +29,15 @@ struct<count(k):bigint>
1
+-- !query
+SELECT count(k) FROM hav GROUP BY v HAVING v = array(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(hav.v = array(1))' due to data type mismatch: differing types in '(hav.v = array(1))' (int and array<int>).; line 1 pos 43
+
+
-- !query
SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0)
-- !query schema