You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/11/22 18:09:31 UTC

[GitHub] [spark] peter-toth commented on a diff in pull request #38714: [WIP][SPARK-41141]. avoid introducing a new aggregate expression in the analysis phase when subquery is referencing it

peter-toth commented on code in PR #38714:
URL: https://github.com/apache/spark/pull/38714#discussion_r1029660173


##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala:
##########
@@ -270,4 +277,66 @@ class ResolveSubquerySuite extends AnalysisTest {
       ), Seq(a, b)).as("sub") :: Nil, t1)
     )
   }
+
+  test("SPARK-41141 aggregates of outer query referenced in subquery should not create" +
+    " new aggregates if possible") {
+    withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> s"${PropagateEmptyRelation.ruleName}") {
+      val a = 'a.int
+      val b = 'b.int
+      val c = 'c.int
+      val d = 'd.int
+
+      val t1 = LocalRelation(a, b)
+      val t2 = LocalRelation(c, d)
+      val optimizer = new SimpleTestOptimizer()
+
+      val plansToTest = Seq(
+        t1.select($"a", $"b").
+          having($"b")(Cos(sum($"a")))(Exists(t2.select($"c").
+            where($"d" === Cos(sum($"a"))))) -> 1,
+
+        t1.select($"a", $"b").
+          having($"b")(sum($"a"))(Exists(t2.select($"c").
+            where($"d" === Cos(sum($"a"))))) -> 1,
+        t1.select($"a", $"b").
+          having($"b")(Cos(sum($"a")))(Exists(t2.select($"c").
+            where($"d" === sum($"a")))) -> 2,
+        t1.select($"a", $"b").
+          having($"b")(sum($"a"), Cos(sum($"b")))(Exists(t2.select($"c").
+            where($"d" === Cos(sum($"a")) + sum($"a") + sum($"b") + Cos(sum($"b"))))) -> 3
+      )
+
+      plansToTest.foreach {
+        case (logicalPlan: LogicalPlan, numAggFunctions) =>
+          assertAnalysis(logicalPlan, numAggFunctions)
+      }
+
+      def assertAnalysis(logicalPlan: LogicalPlan, expectedAggregateFunctions: Int): Unit = {
+        val analyzedQuery = logicalPlan.analyze
+        Assert.assertTrue(analyzedQuery.analyzed)
+        val optimizedQuery = optimizer.execute(analyzedQuery)

Review Comment:
   Why do we need `optimizedQuery` in this test?
   



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala:
##########
@@ -17,13 +17,20 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
+import org.junit.Assert

Review Comment:
   ScalaTest's `assert` should already be available, so there is no need for JUnit here.



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala:
##########
@@ -17,13 +17,20 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
+import org.junit.Assert
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference, ScalarSubquery}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cos, CreateArray, Exists, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference, ScalarSubquery, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Count}
+import org.apache.spark.sql.catalyst.optimizer.{PropagateEmptyRelation, SimpleTestOptimizer}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
+

Review Comment:
   No need for this extra line.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala:
##########
@@ -208,20 +208,33 @@ object SubExprUtils extends PredicateHelper {
    */
   def getOuterReferences(expr: Expression): Seq[Expression] = {
     val outerExpressions = ArrayBuffer.empty[Expression]
-    def collectOutRefs(input: Expression): Unit = input match {
+
+    def collectOutRefs(input: Expression): Boolean = input match {
       case a: AggregateExpression if containsOuter(a) =>
         if (a.references.nonEmpty) {
           throw QueryCompilationErrors.mixedRefsInAggFunc(a.sql, a.origin)
         } else {
           // Collect and update the sub-tree so that outer references inside this aggregate
           // expression will not be collected. For example: min(outer(a)) -> min(a).
-          val newExpr = stripOuterReference(a)
-          outerExpressions += newExpr
+          // delay collecting outer expression as we want to go as much up as possible

Review Comment:
   I wonder if it would make sense to handle expressions like `sum(a) + 1` as outer expressions too?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org