Posted to issues@spark.apache.org by "Kris Mok (Jira)" <ji...@apache.org> on 2023/03/18 01:05:00 UTC

[jira] [Created] (SPARK-42851) EquivalentExpressions methods need to be consistently guarded by supportedExpression

Kris Mok created SPARK-42851:
--------------------------------

             Summary: EquivalentExpressions methods need to be consistently guarded by supportedExpression
                 Key: SPARK-42851
                 URL: https://issues.apache.org/jira/browse/SPARK-42851
             Project: Spark
          Issue Type: Bug
          Components: SQL
    Affects Versions: 3.3.2, 3.4.0
            Reporter: Kris Mok


SPARK-41468 tried to fix a bug but introduced a new regression. Its change to {{EquivalentExpressions}} added a {{supportedExpression()}} guard to the {{addExprTree()}} and {{getExprState()}} methods, but didn't add the same guard to the other "add" entry point -- {{addExpr()}}.

As such, callers that add a single expression via {{addExpr()}} may succeed, but a later lookup of the same expression via {{getExprState()}} inconsistently returns {{None}} because the expression fails the guard.
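To make the inconsistency concrete, here is a minimal sketch in plain Scala with no Spark dependency. It is a toy model under stated assumptions, not the real {{EquivalentExpressions}} code: {{Expr}}, {{InconsistentEquivalence}} and the {{guarded}} flag are hypothetical, the map is keyed by structural rather than semantic equality, and the real {{supportedExpression()}} predicate is replaced by a flag. It can be pasted into a Scala REPL or run as a scala-cli script.

{code:scala}
import scala.collection.mutable

// Toy stand-in for a Catalyst expression. `guarded = true` marks an expression that a
// hypothetical supportedExpression() check would reject (the real predicate is not modeled).
final case class Expr(sql: String, guarded: Boolean = false)

// Hypothetical, heavily simplified model of EquivalentExpressions after SPARK-41468:
// getExprState() is guarded, but addExpr() is not.
final class InconsistentEquivalence {
  private val equivalenceMap = mutable.HashMap.empty[Expr, Int] // expression -> use count

  private def supportedExpression(e: Expr): Boolean = !e.guarded

  // Unguarded "add" path. Returns true if there was already a matching expression.
  def addExpr(e: Expr): Boolean = {
    val seen = equivalenceMap.contains(e)
    equivalenceMap(e) = equivalenceMap.getOrElse(e, 0) + 1
    seen
  }

  // Guarded "get" path.
  def getExprState(e: Expr): Option[Int] =
    if (supportedExpression(e)) equivalenceMap.get(e) else None
}

val equiv = new InconsistentEquivalence
val agg = Expr("max(transform(array(id), x -> x))", guarded = true)
val firstAdd = equiv.addExpr(agg)   // false: first occurrence is recorded
val secondAdd = equiv.addExpr(agg)  // true: reported as a common subexpression
val state = equiv.getExprState(agg) // None, even though the map holds a count of 2
{code}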

We need to make sure the "add" and "get" methods are consistent. This could be done in one of the following ways:
1. Adding the same {{supportedExpression()}} guard to {{addExpr()}}, or
2. Removing the guard from {{getExprState()}}, relying solely on the guard on the "add" path to make sure only intended state is added.
(or other alternative refactorings that fuse the guard into the various methods to make it more efficient)
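Both options restore the same property: if {{addExpr()}} reports an expression as a duplicate, {{getExprState()}} can retrieve it. The sketch below illustrates this with the same kind of hypothetical toy model; {{Expr}}, {{ToyEquivalence}}, {{GuardAddAndGet}}, {{GuardAddPathOnly}} and {{consistent}} are illustrative names, not Spark APIs, and {{updateExprInMap}} only borrows its name from the real helper.

{code:scala}
import scala.collection.mutable

// Toy expression; `guarded = true` means a hypothetical supportedExpression() rejects it.
final case class Expr(sql: String, guarded: Boolean = false)

// Hypothetical shared core of the toy model.
abstract class ToyEquivalence {
  protected val equivalenceMap = mutable.HashMap.empty[Expr, Int]
  protected def supportedExpression(e: Expr): Boolean = !e.guarded
  // Records one more occurrence; returns true if the expression was already present.
  protected def updateExprInMap(e: Expr): Boolean = {
    val seen = equivalenceMap.contains(e)
    equivalenceMap(e) = equivalenceMap.getOrElse(e, 0) + 1
    seen
  }
  def addExpr(e: Expr): Boolean
  def getExprState(e: Expr): Option[Int]
}

// Option 1: guard addExpr() as well, so unsupported expressions are never recorded.
final class GuardAddAndGet extends ToyEquivalence {
  def addExpr(e: Expr): Boolean = if (supportedExpression(e)) updateExprInMap(e) else false
  def getExprState(e: Expr): Option[Int] =
    if (supportedExpression(e)) equivalenceMap.get(e) else None
}

// Option 2: drop the guard from getExprState(); whatever an "add" method recorded is
// returned. (addExprTree() would keep its own guard; it is not modeled here.)
final class GuardAddPathOnly extends ToyEquivalence {
  def addExpr(e: Expr): Boolean = updateExprInMap(e)
  def getExprState(e: Expr): Option[Int] = equivalenceMap.get(e)
}

// The property both options restore: a duplicate reported by addExpr() is retrievable.
def consistent(equiv: ToyEquivalence, e: Expr): Boolean = {
  equiv.addExpr(e)
  val duplicate = equiv.addExpr(e)
  !duplicate || equiv.getExprState(e).isDefined
}

val agg = Expr("max(transform(array(id), x -> x))", guarded = true)
val opt1 = new GuardAddAndGet
val opt2 = new GuardAddPathOnly
val ok1 = consistent(opt1, agg) // true; opt1.getExprState(agg) is None (no CSE)
val ok2 = consistent(opt2, agg) // true; opt2.getExprState(agg) is Some(2) (CSE kept)
{code}

The difference lies in what gets CSE'd: option 1 never records the guarded expression, while option 2 keeps CSE'ing it through {{addExpr()}} as before SPARK-41468, whether or not that is safe.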

Both directions have trade-offs. Because {{addExpr()}} used to allow more expressions to be CSE'd (some of them potentially incorrectly), making it more restrictive may cause performance regressions in the cases that happened to work.

Example:
{code:sql}
select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)
{code}

Running this query on the Spark 3.2 branch returns the correct value:
{code}
scala> spark.sql("select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)").collect
res0: Array[org.apache.spark.sql.Row] = Array([WrappedArray(1),WrappedArray(1)])
{code}
Here, {{max(transform(array(id), x -> x))}} is an {{AggregateExpression}} that was (potentially unsafely) recognized by {{addExpr()}} as a common subexpression, and {{getExprState()}} does no extra guarding, so during physical planning in {{PhysicalAggregation}} this expression is CSE'd in both the aggregate expression list and the result expression list:
{code}
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
         +- Range (0, 2, step=1, splits=16)
{code}

Running the same query on current master triggers an error when binding the result expression to the aggregate expression in the Aggregate operators (for a WSCG-enabled operator like {{HashAggregateExec}}, the same error would show up during codegen):
{code}
ERROR TaskSetManager: Task 0 in stage 2.0 failed 1 times; aborting job
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 2.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2.0 (TID 16) (ip-10-110-16-93.us-west-2.compute.internal executor driver): java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in [max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))#3]
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:512)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:104)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:512)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:517)
	at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1249)
	at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1248)
	at org.apache.spark.sql.catalyst.expressions.UnaryExpression.mapChildren(Expression.scala:532)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:517)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:488)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:456)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
	at scala.collection.immutable.List.map(List.scala:297)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:161)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.generateResultProjection(AggregationIterator.scala:246)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.<init>(AggregationIterator.scala:296)
	at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.<init>(SortBasedAggregationIterator.scala:49)
	at org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1(SortAggregateExec.scala:79)
	at org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1$adapted(SortAggregateExec.scala:59)
...
{code}
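Note that the aggregate expressions are still deduplicated in {{PhysicalAggregation}} (only one {{max}} appears in the plan below), but the result expressions cannot be deduplicated consistently because of the bug described in this ticket, so the second result expression refers to an attribute ({{#4}}) that the aggregate does not produce.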
{code}
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=38]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
         +- Range (0, 2, step=1, splits=16)
{code}
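The interplay between deduplication and binding can be reproduced without Spark in a toy model. The sketch below assumes, as a simplification, that {{PhysicalAggregation}} keeps an aggregate expression only when {{addExpr()}} returns false, and points each result expression at the result attribute of whatever {{getExprState()}} returns, falling back to the expression itself. {{AggExpr}}, {{ToyEquivalence}} and {{plan}} are hypothetical names, a string stands in for the canonicalized expression, and ids 3 and 4 stand in for the {{#3}}/{{#4}} attributes in the error above.

{code:scala}
import scala.collection.mutable

// Toy aggregate expression: `sql` is its semantic (canonical) form, `id` stands in for the
// exprId of its result attribute, and `guarded = true` marks an expression that a
// hypothetical supportedExpression() check would reject.
final case class AggExpr(sql: String, id: Int, guarded: Boolean)

// Hypothetical EquivalentExpressions model keyed by semantic form (ids are ignored),
// with the guard switchable on the "add" and "get" paths independently.
final class ToyEquivalence(guardOnAdd: Boolean, guardOnGet: Boolean) {
  private val firstSeen = mutable.HashMap.empty[String, AggExpr]
  private def supported(e: AggExpr): Boolean = !e.guarded

  // Returns true if a semantically equal expression was already recorded.
  def addExpr(e: AggExpr): Boolean =
    if (guardOnAdd && !supported(e)) false
    else if (firstSeen.contains(e.sql)) true
    else { firstSeen(e.sql) = e; false }

  // Returns the first recorded equivalent expression, if the lookup is allowed.
  def getExprState(e: AggExpr): Option[AggExpr] =
    if (guardOnGet && !supported(e)) None else firstSeen.get(e.sql)
}

// Simplified dedup-then-bind flow. Returns (attribute ids produced by the aggregate,
// attribute ids referenced by the result expressions), or throws if binding fails.
def plan(results: Seq[AggExpr], guardOnAdd: Boolean, guardOnGet: Boolean): (Seq[Int], Seq[Int]) = {
  val equiv = new ToyEquivalence(guardOnAdd, guardOnGet)
  // Keep an aggregate expression only if addExpr() did not report it as a duplicate.
  val aggregates = results.filterNot(e => equiv.addExpr(e))
  // Point each result at its representative's attribute, falling back to itself.
  val referenced = results.map(r => equiv.getExprState(r).getOrElse(r).id)
  val produced = aggregates.map(_.id)
  // Binding: every referenced attribute must be produced by the aggregate.
  referenced.foreach { ref =>
    if (!produced.contains(ref))
      throw new IllegalStateException(
        s"Couldn't find #$ref in ${produced.mkString("[#", ", #", "]")}")
  }
  (produced, referenced)
}

val q = Seq(
  AggExpr("max(transform(array(id), x -> x))", id = 3, guarded = true),
  AggExpr("max(transform(array(id), x -> x))", id = 4, guarded = true))

// No guard on either path (Spark 3.2): one max() is computed and referenced twice.
val (produced32, referenced32) = plan(q, guardOnAdd = false, guardOnGet = false)
// produced32 == List(3), referenced32 == List(3, 3)

// Guard on "get" only (current master): binding fails.
val masterError =
  try { plan(q, guardOnAdd = false, guardOnGet = true); None }
  catch { case e: IllegalStateException => Some(e.getMessage) }
// masterError == Some("Couldn't find #4 in [#3]")

// Guard on both paths (method 1): no CSE, but binding succeeds.
val (producedFix, referencedFix) = plan(q, guardOnAdd = true, guardOnGet = true)
// producedFix == List(3, 4), referencedFix == List(3, 4)
{code}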

Fixing it via method 1 is more correct than method 2 in terms of avoiding incorrect CSE:
{code:diff}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 330d66a21b..12def60042 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -40,7 +40,11 @@ class EquivalentExpressions {
    * Returns true if there was already a matching expression.
    */
   def addExpr(expr: Expression): Boolean = {
-    updateExprInMap(expr, equivalenceMap)
+    if (supportedExpression(expr)) {
+      updateExprInMap(expr, equivalenceMap)
+    } else {
+      false
+    }
   }
 
   /**
{code}
With this change, the query runs correctly again, but this time the aggregate expression is NOT CSE'd anymore, and this is done consistently for both the aggregate expressions and the result expressions:
{code}
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), partial_max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
         +- Range (0, 2, step=1, splits=16)
{code}
For this particular case, however, the CSE that used to take place was actually safe, so losing it here is a performance regression.



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
