You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/07/13 09:36:26 UTC

[spark] branch branch-3.2 updated: [SPARK-35551][SQL] Handle the COUNT bug for lateral subqueries

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 78c4e37  [SPARK-35551][SQL] Handle the COUNT bug for lateral subqueries
78c4e37 is described below

commit 78c4e3710da37c3358b8eec38aee5fa99d4b34bb
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Tue Jul 13 17:35:03 2021 +0800

    [SPARK-35551][SQL] Handle the COUNT bug for lateral subqueries
    
    ### What changes were proposed in this pull request?
    This PR modifies `DecorrelateInnerQuery` to handle the COUNT bug for lateral subqueries. Similar to SPARK-15370, rewriting lateral subqueries as joins can change the semantics of the subquery and lead to incorrect answers.
    
    However, we can't reuse the existing code that handles the COUNT bug for correlated scalar subqueries, because it assumes the subquery has a specific shape (either Filter + Aggregate or Aggregate as the root node). Instead, this PR proposes a more generic way to handle the COUNT bug. If an Aggregate is subject to the COUNT bug, we insert a left outer domain join between the outer query and the aggregate with an `alwaysTrue` marker and rewrite the final result conditioned on the marker [...]
    
    ```sql
    -- t1: [(0, 1), (1, 2)]
    -- t2: [(0, 2), (0, 3)]
    select * from t1 left outer join lateral (select count(*) from t2 where t2.c1 = t1.c1)
    ```
    
    Without COUNT bug handling, the query plan is:
    ```
    Project [c1#44, c2#45, count(1)#53L]
    +- Join LeftOuter, (c1#48 = c1#44)
       :- LocalRelation [c1#44, c2#45]
       +- Aggregate [c1#48], [count(1) AS count(1)#53L, c1#48]
          +- LocalRelation [c1#48]
    ```
    and the answer is wrong:
    ```
    +---+---+--------+
    |c1 |c2 |count(1)|
    +---+---+--------+
    |0  |1  |2       |
    |1  |2  |null    |
    +---+---+--------+
    ```
    
    With COUNT bug handling, the query plan is:
    ```
    Project [c1#1, c2#2, count(1)#10L]
    +- Join LeftOuter, (c1#34 <=> c1#1)
       :- LocalRelation [c1#1, c2#2]
       +- Project [if (isnull(alwaysTrue#32)) 0 else count(1)#33L AS count(1)#10L, c1#34]
          +- Join LeftOuter, (c1#5 = c1#34)
             :- Aggregate [c1#1], [c1#1 AS c1#34]
             :  +- LocalRelation [c1#1]
             +- Aggregate [c1#5], [count(1) AS count(1)#33L, c1#5, true AS alwaysTrue#32]
                +- LocalRelation [c1#5]
    ```
    and we have the correct answer:
    ```
    +---+---+--------+
    |c1 |c2 |count(1)|
    +---+---+--------+
    |0  |1  |2       |
    |1  |2  |0       |
    +---+---+--------+
    ```
    
    ### Why are the changes needed?
    Fixes a correctness bug in the lateral join rewrite.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added SQL query tests. The results are consistent with PostgreSQL's.
    
    Closes #33070 from allisonwang-db/spark-35551-lateral-count-bug.
    
    Authored-by: allisonwang-db <al...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 4f760f2b1fd6c6fc8be34157ec9db5cc112f4d13)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/expressions/AttributeMap.scala    |   2 +
 .../sql/catalyst/expressions/AttributeMap.scala    |   2 +
 .../catalyst/optimizer/DecorrelateInnerQuery.scala | 200 ++++++++++++++++---
 .../spark/sql/catalyst/optimizer/subquery.scala    |  44 +++--
 .../plans/logical/basicLogicalOperators.scala      |  16 +-
 .../resources/sql-tests/inputs/join-lateral.sql    |  61 +++++-
 .../sql-tests/results/join-lateral.sql.out         | 211 ++++++++++++++++++++-
 7 files changed, 488 insertions(+), 48 deletions(-)

diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 189318a..3a42457 100644
--- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -50,4 +50,6 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
   override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
 
   override def -(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
+
+  def ++(other: AttributeMap[A]): AttributeMap[A] = new AttributeMap(baseMap ++ other.baseMap)
 }
diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 7715291..1f1df2d 100644
--- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -51,4 +51,6 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
   override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
 
   override def removed(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
+
+  def ++(other: AttributeMap[A]): AttributeMap[A] = new AttributeMap(baseMap ++ other.baseMap)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
index f0441e3..f30dd99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.plans._
@@ -126,11 +128,11 @@ object DecorrelateInnerQuery extends PredicateHelper {
    * E.g. [outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y}
    */
   private def collectEquivalentOuterReferences(
-      expressions: Seq[Expression]): Map[Attribute, Attribute] = {
-    expressions.collect {
+      expressions: Seq[Expression]): AttributeMap[Attribute] = {
+    AttributeMap(expressions.collect {
       case Equality(o: OuterReference, a: Attribute) => (o.toAttribute, a.toAttribute)
       case Equality(a: Attribute, o: OuterReference) => (o.toAttribute, a.toAttribute)
-    }.toMap
+    })
   }
 
   /**
@@ -138,7 +140,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
    */
   private def replaceOuterReference[E <: Expression](
       expression: E,
-      outerReferenceMap: Map[Attribute, Attribute]): E = {
+      outerReferenceMap: AttributeMap[Attribute]): E = {
     expression.transformWithPruning(_.containsPattern(OUTER_REFERENCE)) {
       case o: OuterReference => outerReferenceMap.getOrElse(o.toAttribute, o)
     }.asInstanceOf[E]
@@ -150,7 +152,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
    */
   private def replaceOuterReferences[E <: Expression](
       expressions: Seq[E],
-      outerReferenceMap: Map[Attribute, Attribute]): Seq[E] = {
+      outerReferenceMap: AttributeMap[Attribute]): Seq[E] = {
     expressions.map(replaceOuterReference(_, outerReferenceMap))
   }
 
@@ -212,14 +214,40 @@ object DecorrelateInnerQuery extends PredicateHelper {
   }
 
   /**
-   * Rewrite all [[DomainJoin]]s in the inner query to actual inner joins with the outer query.
+   * Rewrite all [[DomainJoin]]s in the inner query to actual joins with the outer query.
    */
   def rewriteDomainJoins(
       outerPlan: LogicalPlan,
       innerPlan: LogicalPlan,
       conditions: Seq[Expression]): LogicalPlan = innerPlan match {
-    case d @ DomainJoin(domainAttrs, child) =>
+    case d @ DomainJoin(domainAttrs, child, joinType, condition) =>
       val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs)
+
+      val newChild = joinType match {
+        // Left outer domain joins are used to handle the COUNT bug.
+        case LeftOuter =>
+          // Replace the attributes in the domain join condition with the actual outer expressions
+          // and use the new join conditions to rewrite domain joins in its child. For example:
+          // DomainJoin [c'] LeftOuter (a = c') with domainAttrMap: { c' -> _1 }.
+          // Then the new conditions to use will be [(a = _1)].
+          assert(condition.isDefined,
+            s"LeftOuter domain join should always have the join condition defined:\n$d")
+          val newCond = condition.get.transform {
+            case a: Attribute => domainAttrMap.getOrElse(a, a)
+          }
+          // Recursively rewrite domain joins using the new conditions.
+          rewriteDomainJoins(outerPlan, child, splitConjunctivePredicates(newCond))
+        case Inner =>
+          // The decorrelation framework adds domain inner joins by traversing down the plan tree
+          // recursively until it reaches a node that is not correlated with the outer query.
+          // So the child node of a domain inner join shouldn't contain another domain join.
+          assert(child.find(_.isInstanceOf[DomainJoin]).isEmpty,
+            s"Child of a domain inner join shouldn't contain another domain join.\n$child")
+          child
+        case o =>
+          throw new IllegalStateException(s"Unexpected domain join type $o")
+      }
+
       // We should only rewrite a domain join when all corresponding outer plan attributes
       // can be found from the join condition.
       if (domainAttrMap.size == domainAttrs.size) {
@@ -232,21 +260,15 @@ object DecorrelateInnerQuery extends PredicateHelper {
         // DomainJoin [a', b']  =>  Aggregate [a, b] [a AS a', b AS b']
         //                          +- Relation [a, b]
         val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan)
-        child match {
+        newChild match {
           // A special optimization for OneRowRelation.
           // TODO: add a more general rule to optimize join with OneRowRelation.
           case _: OneRowRelation => domain
           // Construct a domain join.
-          // Join Inner
-          // :- Inner Query
-          // +- Domain
-          case _ =>
-            // The decorrelation framework adds domain joins by traversing down the plan tree
-            // recursively until it reaches a node that is not correlated with the outer query.
-            // So the child node of a domain join shouldn't contain another domain join.
-            assert(child.find(_.isInstanceOf[DomainJoin]).isEmpty,
-              s"Child of a domain join shouldn't contain another domain join.\n$child")
-            Join(child, domain, Inner, None, JoinHint.NONE)
+          // Join joinType condition
+          // :- Domain
+          // +- Inner Query
+          case _ => Join(domain, newChild, joinType, condition, JoinHint.NONE)
         }
       } else {
         throw QueryExecutionErrors.cannotRewriteDomainJoinWithConditionsError(conditions, d)
@@ -257,7 +279,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
 
   def apply(
       innerPlan: LogicalPlan,
-      outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+      outerPlan: LogicalPlan,
+      handleCountBug: Boolean = false): (LogicalPlan, Seq[Expression]) = {
     val outputPlanInputAttrs = outerPlan.inputSet
 
     // The return type of the recursion.
@@ -265,7 +288,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
     // The second parameter is a list of join conditions with the outer query.
     // The third parameter is a mapping between the outer references and equivalent
     // expressions from the inner query that is used to replace outer references.
-    type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute])
+    type ReturnType = (LogicalPlan, Seq[Expression], AttributeMap[Attribute])
 
     // Decorrelate the input plan with a set of parent outer references and a boolean flag
     // indicating whether the result of the plan will be aggregated. Steps:
@@ -288,7 +311,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
           // If there is no outer references from the parent nodes, it means all outer
           // attributes can be substituted by attributes from the inner plan. So no
           // domain join is needed.
-          (plan, Nil, Map.empty[Attribute, Attribute])
+          (plan, Nil, AttributeMap.empty[Attribute])
         } else {
           // Build the domain join with the parent outer references.
           val attributes = parentOuterReferences.toSeq
@@ -310,7 +333,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
           val conditions = outerReferenceMap.map {
             case (o, a) => EqualNullSafe(a, OuterReference(o))
           }
-          (domainJoin, conditions.toSeq, outerReferenceMap)
+          (domainJoin, conditions.toSeq, AttributeMap(outerReferenceMap))
         }
       } else {
         plan match {
@@ -428,7 +451,134 @@ object DecorrelateInnerQuery extends PredicateHelper {
               groupingExpressions = newGroupingExpr ++ referencesToAdd,
               aggregateExpressions = newAggExpr ++ referencesToAdd,
               child = newChild)
-            (newAggregate, joinCond, outerReferenceMap)
+
+            // Preserving domain attributes over an Aggregate with an empty grouping expression
+            // is subject to the "COUNT bug" that can lead to wrong answer:
+            //
+            // Suppose the original query is:
+            //   SELECT a, (SELECT COUNT(*) cnt FROM t2 WHERE t1.a = t2.c) FROM t1
+            //
+            // Decorrelated plan:
+            //   Project [a, scalar-subquery [a = c]]
+            //   :  +- Aggregate [c] [count(*) AS cnt, c]
+            //   :     +- Relation [c, d]
+            //   +- Relation [a, b]
+            //
+            // After rewrite:
+            //   Project [a, cnt]
+            //   +- Join LeftOuter (a = c)
+            //      :- Relation [a, b]
+            //      +- Aggregate [c] [count(*) AS cnt, c]
+            //         +- Relation [c, d]
+            //
+            //     T1            T2          T2' (GROUP BY c)
+            // +---+---+     +---+---+     +---+-----+
+            // | a | b |     | c | d |     | c | cnt |
+            // +---+---+     +---+---+     +---+-----+
+            // | 0 | 1 |     | 0 | 2 |     | 0 | 2   |
+            // | 1 | 2 |     | 0 | 3 |     +---+-----+
+            // +---+---+     +---+---+
+            //
+            // T1 nested loop join T2     T1 left outer join T2'
+            // on (a = c):                on (a = c):
+            // +---+-----+                +---+-----++
+            // | a | cnt |                | a | cnt  |
+            // +---+-----+                +---+------+
+            // | 0 | 2   |                | 0 | 2    |
+            // | 1 | 0   | <--- correct   | 1 | null | <--- wrong result
+            // +---+-----+                +---+------+
+            //
+            // If an aggregate is subject to the COUNT bug:
+            // 1) add a column `true AS alwaysTrue` to the result of the aggregate
+            // 2) insert a left outer domain join between the outer query and this aggregate
+            // 3) rewrite the original aggregate's output column using the default value of the
+            //    aggregate function and the alwaysTrue column.
+            //
+            // For example, T1 left outer join T2' with `alwaysTrue` marker:
+            // +---+------+------------+--------------------------------+
+            // | c | cnt  | alwaysTrue | if(isnull(alwaysTrue), 0, cnt) |
+            // +---+------+------------+--------------------------------+
+            // | 0 | 2    | true       | 2                              |
+            // | 0 | null | null       | 0                              |  <--- correct result
+            // +---+------+------------+--------------------------------+
+            if (groupingExpressions.isEmpty && handleCountBug) {
+              // Evaluate the aggregate expressions with zero tuples.
+              val resultMap = RewriteCorrelatedScalarSubquery.evalAggregateOnZeroTups(newAggregate)
+              val alwaysTrue = Alias(Literal.TrueLiteral, "alwaysTrue")()
+              val alwaysTrueRef = alwaysTrue.toAttribute.withNullability(true)
+              val expressions = ArrayBuffer.empty[NamedExpression]
+              // Create new aliases for aggregate expressions that have non-null default
+              // values and reconstruct the output with the `alwaysTrue` marker.
+              val projectList = newAggregate.aggregateExpressions.map { a =>
+                resultMap.get(a.exprId) match {
+                  // Aggregate expression is not subject to the count bug.
+                  case Some(Literal(null, _)) | None =>
+                    expressions += a
+                    // The attribute is nullable since it is from the right-hand side of a
+                    // left outer join.
+                    a.toAttribute.withNullability(true)
+                  case Some(default) =>
+                    assert(a.isInstanceOf[Alias], s"Cannot have non-aliased expression $a in " +
+                      s"aggregate that evaluates to non-null value with zero tuples.")
+                    val newAttr = a.newInstance()
+                    val ref = newAttr.toAttribute.withNullability(true)
+                    expressions += newAttr
+                    Alias(If(IsNull(alwaysTrueRef), default, ref), a.name)(a.exprId)
+                }
+              }
+              // Insert a placeholder left outer domain join between the outer query and
+              // and aggregate node and use the current collected join conditions as the
+              // left outer join condition.
+              //
+              // Original subquery:
+              //   Aggregate [count(1) AS cnt]
+              //   +- Filter (a = outer(c))
+              //      +- Relation [a, b]
+              //
+              // After decorrelation and before COUNT bug handling:
+              //   Aggregate [a] [count(1) AS cnt, a]
+              //   +- Relation [a, b]
+              //
+              // joinCond with the outer query: (a = outer(c))
+              //
+              // Handle the COUNT bug:
+              //   Project [if(isnull(alwaysTrue), 0, cnt') AS cnt, c']
+              //   +- DomainJoin [c'] LeftOuter (a = c')
+              //      +- Aggregate [a] [count(1) AS cnt', a, true AS alwaysTrue]
+              //         +- Relation [a, b]
+              //
+              // New joinCond with the outer query: (c' <=> outer(c)), and the DomainJoin
+              // will be written as:
+              //   Project [if(isnull(alwaysTrue), 0, cnt') AS cnt, c']
+              //   +- Join LeftOuter (a = c')
+              //      :- Aggregate [c] [c AS c']
+              //      :  +- OuterQuery [c, d]
+              //      +- Aggregate [a] [count(1) AS cnt', a, true AS alwaysTrue]
+              //         +- Relation [a, b]
+              //
+              val agg = newAggregate.copy(aggregateExpressions = expressions.toSeq :+ alwaysTrue)
+              // Find all outer references that are used in the join conditions.
+              val outerAttrs = collectOuterReferences(joinCond).toSeq
+              // Create new instance of the outer attributes as if they are generated inside
+              // the subquery by a left outer join with the outer query. Use new instance here
+              // to avoid conflicting join attributes with the inner query.
+              val domainAttrs = outerAttrs.map(_.newInstance())
+              val mapping = AttributeMap(outerAttrs.zip(domainAttrs))
+              // Use the current join conditions returned from the recursive call as the join
+              // conditions for the left outer join. All outer references in the join
+              // conditions are replaced by the newly created domain attributes.
+              val condition = replaceOuterReferences(joinCond, mapping).reduceOption(And)
+              val domainJoin = DomainJoin(domainAttrs, agg, LeftOuter, condition)
+              // Original domain attributes preserved through Aggregate are no longer needed.
+              val newProjectList = projectList.filter(!referencesToAdd.contains(_))
+              val project = Project(newProjectList ++ domainAttrs, domainJoin)
+              val newJoinCond = outerAttrs.zip(domainAttrs).map { case (outer, inner) =>
+                EqualNullSafe(inner, OuterReference(outer))
+              }
+              (project, newJoinCond, mapping)
+            } else {
+              (newAggregate, joinCond, outerReferenceMap)
+            }
 
           case j @ Join(left, right, joinType, condition, _) =>
             val outerReferences = collectOuterReferences(j.expressions)
@@ -446,12 +596,12 @@ object DecorrelateInnerQuery extends PredicateHelper {
             val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) {
               decorrelate(left, newOuterReferences, aggregated)
             } else {
-              (left, Nil, Map.empty[Attribute, Attribute])
+              (left, Nil, AttributeMap.empty[Attribute])
             }
             val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) {
               decorrelate(right, newOuterReferences, aggregated)
             } else {
-              (right, Nil, Map.empty[Attribute, Attribute])
+              (right, Nil, AttributeMap.empty[Attribute])
             }
             val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap
             val newJoinCond = leftJoinCond ++ rightJoinCond
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 7914d14..53448fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -296,9 +296,12 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       if (newCond.isEmpty) oldCond else newCond
     }
 
-    def decorrelate(sub: LogicalPlan, outer: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+    def decorrelate(
+        sub: LogicalPlan,
+        outer: LogicalPlan,
+        handleCountBug: Boolean = false): (LogicalPlan, Seq[Expression]) = {
       if (SQLConf.get.decorrelateInnerQueryEnabled) {
-        DecorrelateInnerQuery(sub, outer)
+        DecorrelateInnerQuery(sub, outer, handleCountBug)
       } else {
         pullOutCorrelatedPredicates(sub, outer)
       }
@@ -315,7 +318,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
         ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
       case LateralSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
-        val (newPlan, newCond) = decorrelate(sub, plan)
+        val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true)
         LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
     }
   }
@@ -396,7 +399,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
   /**
    * Statically evaluate an expression containing one or more aggregates on an empty input.
    */
-  private def evalAggOnZeroTups(expr: Expression) : Expression = {
+  private def evalAggExprOnZeroTups(expr: Expression) : Expression = {
     // AggregateExpressions are Unevaluable, so we need to replace all aggregates
     // in the expression with the value they would return for zero input tuples.
     // Also replace attribute refs (for example, for grouping columns) with NULL.
@@ -411,6 +414,24 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
   }
 
   /**
+   * Statically evaluate an [[Aggregate]] on an empty input and return a mapping
+   * between its output attribute expression ID and evaluated result.
+   */
+  def evalAggregateOnZeroTups(a: Aggregate): Map[ExprId, Expression] = {
+    // Some of the expressions under the Aggregate node are the join columns
+    // for joining with the outer query block. Fill those expressions in with
+    // nulls and statically evaluate the remainder.
+    a.aggregateExpressions.map {
+      case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType))
+      case alias @ Alias(_: AttributeReference, _) =>
+        (alias.exprId, Literal.create(null, alias.dataType))
+      case alias @ Alias(l: Literal, _) =>
+        (alias.exprId, l.copy(value = null))
+      case ne => (ne.exprId, evalAggExprOnZeroTups(ne))
+    }.toMap
+  }
+
+  /**
    * Statically evaluate a scalar subquery on an empty input.
    *
    * <b>WARNING:</b> This method only covers subqueries that pass the checks under
@@ -454,18 +475,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
           projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
         }
 
-      case Aggregate(_, aggExprs, _) =>
-        // Some of the expressions under the Aggregate node are the join columns
-        // for joining with the outer query block. Fill those expressions in with
-        // nulls and statically evaluate the remainder.
-        aggExprs.map {
-          case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType))
-          case alias @ Alias(_: AttributeReference, _) =>
-            (alias.exprId, Literal.create(null, alias.dataType))
-          case alias @ Alias(l: Literal, _) =>
-            (alias.exprId, l.copy(value = null))
-          case ne => (ne.exprId, evalAggOnZeroTups(ne))
-        }.toMap
+      case a: Aggregate =>
+        evalAggregateOnZeroTups(a)
 
       case l: LeafNode =>
         l.output.map(a => (a.exprId, Literal.create(null, a.dataType))).toMap
@@ -695,7 +706,6 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(LATERAL_JOIN)) {
     case LateralJoin(left, LateralSubquery(sub, _, _, joinCond), joinType, condition) =>
-      // TODO(SPARK-35551): handle the COUNT bug
       val newRight = DecorrelateInnerQuery.rewriteDomainJoins(left, sub, joinCond)
       val newCond = (condition ++ joinCond).reduceOption(And)
       Join(left, newRight, joinType, newCond, JoinHint.NONE)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ba7a028..4633a36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1449,9 +1449,21 @@ case class CollectMetrics(
  * A placeholder for domain join that can be added when decorrelating subqueries.
  * It should be rewritten during the optimization phase.
  */
-case class DomainJoin(domainAttrs: Seq[Attribute], child: LogicalPlan) extends UnaryNode {
-  override def output: Seq[Attribute] = child.output ++ domainAttrs
+case class DomainJoin(
+    domainAttrs: Seq[Attribute],
+    child: LogicalPlan,
+    joinType: JoinType = Inner,
+    condition: Option[Expression] = None) extends UnaryNode {
+
+  require(Seq(Inner, LeftOuter).contains(joinType), s"Unsupported domain join type $joinType")
+
+  override def output: Seq[Attribute] = joinType match {
+    case LeftOuter => domainAttrs ++ child.output.map(_.withNullability(true))
+    case _ => domainAttrs ++ child.output
+  }
+
   override def producedAttributes: AttributeSet = AttributeSet(domainAttrs)
+
   override protected def withNewChildInternal(newChild: LogicalPlan): DomainJoin =
     copy(child = newChild)
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
index 30eaca6..cbc085f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
@@ -86,8 +86,65 @@ SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a));
 -- lateral join inside correlated subquery
 SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE c1 = t1.c1);
 
--- TODO(SPARK-35551): handle the COUNT bug (the expected result should be (1, 2, 0))
-SELECT * FROM t1, LATERAL (SELECT COUNT(*) AS cnt FROM t2 WHERE c1 = t1.c1) WHERE cnt = 0;
+-- COUNT bug with a single aggregate expression
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1);
+
+-- COUNT bug with multiple aggregate expressions
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt, SUM(c2) sum FROM t2 WHERE c1 = t1.c1);
+
+-- COUNT bug without count aggregate
+SELECT * FROM t1, LATERAL (SELECT SUM(c2) IS NULL FROM t2 WHERE t1.c1 = t2.c1);
+
+-- COUNT bug with complex aggregate expressions
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) + CASE WHEN sum(c2) IS NULL THEN 0 ELSE sum(c2) END FROM t2 WHERE t1.c1 = t2.c1);
+
+-- COUNT bug with non-empty group by columns (should not handle the count bug)
+SELECT * FROM t1, LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1);
+SELECT * FROM t1, LATERAL (SELECT c2, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c2);
+
+-- COUNT bug with different join types
+SELECT * FROM t1 JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1);
+SELECT * FROM t1 LEFT JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1);
+SELECT * FROM t1 CROSS JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1);
+
+-- COUNT bug with group by columns and different join types
+SELECT * FROM t1 LEFT JOIN LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1);
+SELECT * FROM t1 CROSS JOIN LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1);
+
+-- COUNT bug with non-empty join conditions
+SELECT * FROM t1 JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) ON cnt + 1 = c1;
+
+-- COUNT bug with self join
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1);
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt > 0);
+
+-- COUNT bug with multiple aggregates
+SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1));
+SELECT * FROM t1, LATERAL (SELECT COUNT(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1));
+SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1));
+SELECT * FROM t1, LATERAL (
+  SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+  JOIN t2 ON cnt = t2.c1
+);
+
+-- COUNT bug with correlated predicates above the left outer join
+SELECT * FROM t1, LATERAL (SELECT * FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) WHERE cnt = c1 - 1);
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) WHERE cnt = c1 - 1);
+SELECT * FROM t1, LATERAL (
+  SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+  WHERE cnt = c1 - 1 GROUP BY cnt
+);
+
+-- COUNT bug with joins in the subquery
+SELECT * FROM t1, LATERAL (
+  SELECT * FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+  JOIN t2 ON cnt = t2.c1
+);
+SELECT * FROM t1, LATERAL (
+  SELECT l.cnt + r.cnt
+  FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) l
+  JOIN (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) r
+);
 
 -- lateral subquery with group by
 SELECT * FROM t1 LEFT JOIN LATERAL (SELECT MIN(c2) FROM t2 WHERE c1 = t1.c1 GROUP BY c1);
diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
index 1892c91..0dd2c41 100644
--- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 44
+-- Number of queries: 66
 
 
 -- !query
@@ -389,12 +389,219 @@ struct<c1:int,c2:int>
 
 
 -- !query
-SELECT * FROM t1, LATERAL (SELECT COUNT(*) AS cnt FROM t2 WHERE c1 = t1.c1) WHERE cnt = 0
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1)
 -- !query schema
 struct<c1:int,c2:int,cnt:bigint>
 -- !query output
+0	1	2
+1	2	0
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt, SUM(c2) sum FROM t2 WHERE c1 = t1.c1)
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint,sum:bigint>
+-- !query output
+0	1	2	5
+1	2	0	NULL
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT SUM(c2) IS NULL FROM t2 WHERE t1.c1 = t2.c1)
+-- !query schema
+struct<c1:int,c2:int,(sum(c2) IS NULL):boolean>
+-- !query output
+0	1	false
+1	2	true
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) + CASE WHEN sum(c2) IS NULL THEN 0 ELSE sum(c2) END FROM t2 WHERE t1.c1 = t2.c1)
+-- !query schema
+struct<c1:int,c2:int,(count(1) + CASE WHEN (sum(c2) IS NULL) THEN 0 ELSE sum(c2) END):bigint>
+-- !query output
+0	1	7
+1	2	0
 
 
+-- !query
+SELECT * FROM t1, LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1)
+-- !query schema
+struct<c1:int,c2:int,c1:int,cnt:bigint>
+-- !query output
+0	1	0	2
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT c2, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c2)
+-- !query schema
+struct<c1:int,c2:int,c2:int,cnt:bigint>
+-- !query output
+0	1	2	1
+0	1	3	1
+
+
+-- !query
+SELECT * FROM t1 JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint>
+-- !query output
+0	1	2
+1	2	0
+
+
+-- !query
+SELECT * FROM t1 LEFT JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint>
+-- !query output
+0	1	2
+1	2	0
+
+
+-- !query
+SELECT * FROM t1 CROSS JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint>
+-- !query output
+0	1	2
+1	2	0
+
+
+-- !query
+SELECT * FROM t1 LEFT JOIN LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1)
+-- !query schema
+struct<c1:int,c2:int,c1:int,cnt:bigint>
+-- !query output
+0	1	0	2
+1	2	NULL	NULL
+
+
+-- !query
+SELECT * FROM t1 CROSS JOIN LATERAL (SELECT c1, COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1)
+-- !query schema
+struct<c1:int,c2:int,c1:int,cnt:bigint>
+-- !query output
+0	1	0	2
+
+
+-- !query
+SELECT * FROM t1 JOIN LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) ON cnt + 1 = c1
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint>
+-- !query output
+1	2	0
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1)
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint>
+-- !query output
+0	1	1
+1	2	1
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt > 0)
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint>
+-- !query output
+0	1	1
+1	2	1
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1))
+-- !query schema
+struct<c1:int,c2:int,sum(cnt):bigint>
+-- !query output
+0	1	2
+1	2	0
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT COUNT(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1))
+-- !query schema
+struct<c1:int,c2:int,count(cnt):bigint>
+-- !query output
+0	1	1
+1	2	1
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT SUM(cnt) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1 GROUP BY c1))
+-- !query schema
+struct<c1:int,c2:int,sum(cnt):bigint>
+-- !query output
+0	1	2
+1	2	NULL
+
+
+-- !query
+SELECT * FROM t1, LATERAL (
+  SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+  JOIN t2 ON cnt = t2.c1
+)
+-- !query schema
+struct<c1:int,c2:int,count(1):bigint>
+-- !query output
+0	1	0
+1	2	2
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT * FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) WHERE cnt = c1 - 1)
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint>
+-- !query output
+1	2	0
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) WHERE cnt = c1 - 1)
+-- !query schema
+struct<c1:int,c2:int,count(1):bigint>
+-- !query output
+0	1	0
+1	2	1
+
+
+-- !query
+SELECT * FROM t1, LATERAL (
+  SELECT COUNT(*) FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+  WHERE cnt = c1 - 1 GROUP BY cnt
+)
+-- !query schema
+struct<c1:int,c2:int,count(1):bigint>
+-- !query output
+1	2	1
+
+
+-- !query
+SELECT * FROM t1, LATERAL (
+  SELECT * FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1)
+  JOIN t2 ON cnt = t2.c1
+)
+-- !query schema
+struct<c1:int,c2:int,cnt:bigint,c1:int,c2:int>
+-- !query output
+1	2	0	0	2
+1	2	0	0	3
+
+
+-- !query
+SELECT * FROM t1, LATERAL (
+  SELECT l.cnt + r.cnt
+  FROM (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) l
+  JOIN (SELECT COUNT(*) cnt FROM t2 WHERE t1.c1 = t2.c1) r
+)
+-- !query schema
+struct<c1:int,c2:int,(cnt + cnt):bigint>
+-- !query output
+0	1	4
+1	2	0
+
 
 -- !query
 SELECT * FROM t1 LEFT JOIN LATERAL (SELECT MIN(c2) FROM t2 WHERE c1 = t1.c1 GROUP BY c1)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org