You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/04/20 09:22:54 UTC

[spark] branch master updated: [SPARK-34974][SQL] Improve subquery decorrelation framework

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b6bb24c  [SPARK-34974][SQL] Improve subquery decorrelation framework
b6bb24c is described below

commit b6bb24ca1bd432f86bd7483b7813e5ea36ce426f
Author: allisonwang-db <66...@users.noreply.github.com>
AuthorDate: Tue Apr 20 09:22:22 2021 +0000

    [SPARK-34974][SQL] Improve subquery decorrelation framework
    
    ### What changes were proposed in this pull request?
    This PR implements the decorrelation technique from the paper "Unnesting Arbitrary Queries" by T. Neumann and A. Kemper
    (http://www.btw-2015.de/res/proceedings/Hauptband/Wiss/Neumann-Unnesting_Arbitrary_Querie.pdf). It currently supports Filter, Project, Aggregate, Join, and any UnaryNode that passes CheckAnalysis.
    
    This feature can be controlled by the config `spark.sql.optimizer.decorrelateInnerQuery.enabled` (default: true).
    
    A few notes:
    1. This PR does not relax any constraints in CheckAnalysis for correlated subqueries, even though this new framework can support some additional cases, such as aggregates with correlated non-equality predicates. This PR focuses on adding the new framework and making sure all existing cases are still supported. Constraints can be relaxed gradually in future PRs.
    2. The new framework is only enabled for correlated scalar subqueries, as a first step. EXISTS/IN subqueries can be supported in the future.
    
    ### Why are the changes needed?
    Currently, Spark has limited support for correlated subqueries. It only allows `Filter` to reference outer query columns and does not support non-equality predicates when the subquery is aggregated. This new framework will allow more types of operators to host outer column references and will support correlated non-equality predicates in correlated subqueries.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit and SQL query tests and new optimizer plan tests.
    
    Closes #32072 from allisonwang-db/spark-34974-decorrelation.
    
    Authored-by: allisonwang-db <66...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/expressions/subquery.scala  |  32 +-
 .../catalyst/optimizer/DecorrelateInnerQuery.scala | 484 +++++++++++++++++++++
 .../spark/sql/catalyst/optimizer/subquery.scala    |  38 +-
 .../plans/logical/basicLogicalOperators.scala      |  11 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 +
 .../optimizer/DecorrelateInnerQuerySuite.scala     | 283 ++++++++++++
 6 files changed, 819 insertions(+), 39 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 4e07e72..2bedf84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -140,14 +140,11 @@ object SubExprUtils extends PredicateHelper {
    * Given a logical plan, returns TRUE if it has an outer reference and false otherwise.
    */
   def hasOuterReferences(plan: LogicalPlan): Boolean = {
-    plan.find {
-      case f: Filter => containsOuter(f.condition)
-      case other => false
-    }.isDefined
+    plan.find(_.expressions.exists(containsOuter)).isDefined  // Any operator, not just Filter.
   }
 
   /**
-   * Given a list of expressions, returns the expressions which have outer references. Aggregate
+   * Given an expression, returns the expressions which have outer references. Aggregate
    * expressions are treated in a special way. If the children of aggregate expression contains an
    * outer reference, then the entire aggregate expression is marked as an outer reference.
    * Example (SQL):
@@ -183,18 +180,18 @@ object SubExprUtils extends PredicateHelper {
    * }}}
    * The code below needs to change when we support the above cases.
    */
-  def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = {
+  def getOuterReferences(expr: Expression): Seq[Expression] = {
     val outerExpressions = ArrayBuffer.empty[Expression]
-    conditions foreach { expr =>
-      expr transformDown {
-        case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) =>
-          val newExpr = stripOuterReference(a)
-          outerExpressions += newExpr
-          newExpr
-        case OuterReference(e) =>
-          outerExpressions += e
-          e
-      }
+    expr transformDown {
+      case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) =>
+        // Collect and update the sub-tree so that outer references inside this aggregate
+        // expression will not be collected. For example: min(outer(a)) -> min(a).
+        val newExpr = stripOuterReference(a)
+        outerExpressions += newExpr
+        newExpr
+      case OuterReference(e) =>
+        outerExpressions += e
+        e
     }
     outerExpressions.toSeq
   }
@@ -204,8 +201,7 @@ object SubExprUtils extends PredicateHelper {
    * Filter operator can host outer references.
    */
   def getOuterReferences(plan: LogicalPlan): Seq[Expression] = {
-    val conditions = plan.collect { case Filter(cond, _) => cond }
-    getOuterReferences(conditions)
+    plan.flatMap(_.expressions.flatMap(getOuterReferences))  // Every operator, not only Filter.
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
new file mode 100644
index 0000000..377dcd6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -0,0 +1,484 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+
+/**
+ * Decorrelate the inner query by eliminating outer references and creating domain joins.
+ * The implementation is based on the paper: Unnesting Arbitrary Queries by Thomas Neumann
+ * and Alfons Kemper. https://dl.gi.de/handle/20.500.12116/2418.
+ *
+ * A correlated subquery can be viewed as a "dependent" nested loop join between the outer and
+ * the inner query. For each row produced by the outer query, we bind the [[OuterReference]]s in
+ * the inner query with the corresponding values in the row, and then evaluate the inner query.
+ *
+ * Dependent Join
+ * :- Outer Query
+ * +- Inner Query
+ *
+ * If the [[OuterReference]]s are bound to the same value, the inner query will return the same
+ * result. Based on this, we can reduce the number of inner query evaluations by first getting
+ * all distinct values of the [[OuterReference]]s.
+ *
+ * Normal Join
+ * :- Outer Query
+ * +- Dependent Join
+ *    :- Inner Query
+ *    +- Distinct Aggregate (outer_ref1, outer_ref2, ...)
+ *       +- Outer Query
+ *
+ * The distinct aggregate of the outer references is called a "domain", and the dependent join
+ * between the inner query and the domain is called a "domain join". We need to push down the
+ * domain join through the inner query until there is no outer reference in the sub-tree and
+ * the domain join will turn into a normal join.
+ *
+ * The decorrelation function returns a new query plan with optional placeholder [[DomainJoin]]s
+ * added and a list of join conditions with the outer query. [[DomainJoin]]s need to be rewritten
+ * into actual inner joins between the inner query sub-tree and the outer query.
+ *
+ * E.g. decorrelate an inner query with equality predicates:
+ *
+ * SELECT (SELECT MIN(b) FROM t1 WHERE t2.c = t1.a) FROM t2
+ *
+ * Aggregate [] [min(b)]            Aggregate [a] [min(b), a]
+ * +- Filter (outer(c) = a)   =>   +- Relation [t1]
+ *    +- Relation [t1]
+ *
+ * Join conditions: [c = a]
+ *
+ * E.g. decorrelate an inner query with non-equality predicates:
+ *
+ * SELECT (SELECT MIN(b) FROM t1 WHERE t2.c > t1.a) FROM t2
+ *
+ * Aggregate [] [min(b)]            Aggregate [c'] [min(b), c']
+ * +- Filter (outer(c) > a)   =>   +- Filter (c' > a)
+ *    +- Relation [t1]                  +- DomainJoin [c']
+ *                                         +- Relation [t1]
+ *
+ * Join conditions: [c <=> c']
+ */
+object DecorrelateInnerQuery extends PredicateHelper {
+
+  /**
+   * Check if an expression contains any attribute. Attributes wrapped in an OuterReference are
+   * not found here, because OuterReference is a leaf node.
+   */
+  private def containsAttribute(expression: Expression): Boolean = {
+    expression.find(_.isInstanceOf[Attribute]).isDefined
+  }
+
+  /**
+   * Check if an expression can be pulled up over an [[Aggregate]] without changing the
+   * semantics of the plan. The expression must either contain no inner query attributes at
+   * all, or be an equality predicate that guarantees a one-to-one mapping between inner and
+   * outer attributes: one side is an attribute and the other side has no inner attributes.
+   * For example:
+   *   (a = outer(c)) -> true
+   *   (outer(c) > 1) -> true
+   *   (a > outer(c)) -> false
+   *   (a + b = outer(c)) -> false
+   *   (a = outer(c) - b) -> false
+   */
+  private def canPullUpOverAgg(expression: Expression): Boolean = expression match {
+    case Equality(_: Attribute, b) => !containsAttribute(b)
+    case Equality(a, _: Attribute) => !containsAttribute(a)
+    case o => !containsAttribute(o)
+  }
+
+  /**
+   * Collect the outer plan attributes referenced via [[OuterReference]]s in an expression.
+   */
+  private def collectOuterReferences(expression: Expression): AttributeSet = {
+    AttributeSet(expression.collect { case o: OuterReference => o.toAttribute })
+  }
+
+  /**
+   * Collect the outer plan attributes referenced via [[OuterReference]]s in a sequence of
+   * expressions.
+   */
+  private def collectOuterReferences(expressions: Seq[Expression]): AttributeSet = {
+    AttributeSet.fromAttributeSets(expressions.map(collectOuterReferences))
+  }
+
+  /**
+   * Build a mapping from outer references to equivalent inner attributes in equality predicates.
+   * E.g. [outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y}
+   */
+  private def collectEquivalentOuterReferences(
+      expressions: Seq[Expression]): Map[Attribute, Attribute] = {
+    expressions.collect {
+      case Equality(o: OuterReference, a: Attribute) => (o.toAttribute, a.toAttribute)
+      case Equality(a: Attribute, o: OuterReference) => (o.toAttribute, a.toAttribute)
+    }.toMap
+  }
+
+  /**
+   * Replace outer references that appear in the given map; other outer references are kept.
+   */
+  private def replaceOuterReference[E <: Expression](
+      expression: E,
+      outerReferenceMap: Map[Attribute, Attribute]): E = {
+    expression.transform {
+      case o: OuterReference => outerReferenceMap.getOrElse(o.toAttribute, o)
+    }.asInstanceOf[E]
+  }
+
+  /**
+   * Replace outer references in each of the given expressions using the outer reference map;
+   * outer references that are not in the map are kept.
+   */
+  private def replaceOuterReferences[E <: Expression](
+      expressions: Seq[E],
+      outerReferenceMap: Map[Attribute, Attribute]): Seq[E] = {
+    expressions.map(replaceOuterReference(_, outerReferenceMap))
+  }
+
+  /**
+   * Return all references that are present in the join conditions but not in the output
+   * of the given named expressions.
+   */
+  private def missingReferences(
+      namedExpressions: Seq[NamedExpression],
+      joinCond: Seq[Expression]): AttributeSet = {
+    val output = namedExpressions.map(_.toAttribute)
+    AttributeSet(joinCond.flatMap(_.references)) -- AttributeSet(output)
+  }
+
+  /**
+   * Deduplicate the inner and the outer query attributes and return an aliased
+   * subquery plan and join conditions if duplicates are found. Duplicated attributes
+   * can break the structural integrity when joining the inner and outer plan together.
+   */
+  def deduplicate(
+      innerPlan: LogicalPlan,
+      conditions: Seq[Expression],
+      outerOutputSet: AttributeSet): (LogicalPlan, Seq[Expression]) = {
+    val duplicates = innerPlan.outputSet.intersect(outerOutputSet)
+    if (duplicates.nonEmpty) {
+      val aliasMap = AttributeMap(duplicates.map { dup =>
+        dup -> Alias(dup, dup.toString)()
+      }.toSeq)
+      val aliasedExpressions = innerPlan.output.map { ref =>
+        aliasMap.getOrElse(ref, ref)
+      }
+      val aliasedProjection = Project(aliasedExpressions, innerPlan)
+      val aliasedConditions = conditions.map(_.transform {
+        case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
+      })
+      (aliasedProjection, aliasedConditions)
+    } else {
+      (innerPlan, conditions)
+    }
+  }
+
+  /**
+   * Build a mapping between domain attributes and corresponding outer query expressions
+   * using the join conditions.
+   */
+  private def buildDomainAttrMap(
+      conditions: Seq[Expression],
+      domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = {
+    val domainAttrSet = AttributeSet(domainAttrs)
+    conditions.collect {
+      // When we build the join conditions between the domain attributes and outer references,
+      // the left hand side is always the domain attribute used in the inner query and the right
+      // hand side is the attribute from the outer query. Note here the right hand side of a
+      // condition is not necessarily an attribute, for example it can be a literal (if foldable)
+      // or a cast expression after the optimization.
+      case EqualNullSafe(left: Attribute, right: Expression) if domainAttrSet.contains(left) =>
+        left -> right
+    }.toMap
+  }
+
+  /**
+   * Rewrite all [[DomainJoin]]s in the inner query to actual inner joins with the outer query.
+   */
+  def rewriteDomainJoins(
+      outerPlan: LogicalPlan,
+      innerPlan: LogicalPlan,
+      conditions: Seq[Expression]): LogicalPlan = {
+    innerPlan transform {
+      case d @ DomainJoin(domainAttrs, child) =>
+        val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs)
+        // We should only rewrite a domain join when all corresponding outer plan attributes
+        // can be found from the join condition.
+        if (domainAttrMap.size == domainAttrs.size) {
+          val groupingExprs = domainAttrs.map(domainAttrMap)
+          val aggregateExprs = groupingExprs.zip(domainAttrs).map {
+            // Rebuild the aliases.
+            case (inputAttr, outputAttr) => Alias(inputAttr, outputAttr.name)(outputAttr.exprId)
+          }
+          // Construct a domain with the outer query plan.
+          // DomainJoin [a', b']  =>  Aggregate [a, b] [a AS a', b AS b']
+          //                          +- Relation [a, b]
+          val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan)
+          child match {
+            // A special optimization for OneRowRelation.
+            // TODO: add a more general rule to optimize join with OneRowRelation.
+            case _: OneRowRelation => domain
+            // Construct a domain join.
+            // Join Inner
+            // :- Inner Query
+            // +- Domain
+            case _ => Join(child, domain, Inner, None, JoinHint.NONE)
+          }
+        } else {
+          throw new UnsupportedOperationException(
+            s"Unable to rewrite domain join with conditions: $conditions\n$d")
+        }
+    }
+  }
+
+  def apply(
+      innerPlan: LogicalPlan,
+      outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
+    apply(innerPlan, Seq(outerPlan))
+  }
+
+  def apply(
+      innerPlan: LogicalPlan,
+      outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
+    val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet))
+
+    // The return type of the recursion: (1) a new logical plan with correlation eliminated,
+    // (2) a list of join conditions with the outer query, and (3) a mapping from outer
+    // references to equivalent inner query attributes, used to replace outer references.
+    type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute])
+
+    // Decorrelate the input plan with a set of parent outer references and a boolean flag
+    // indicating whether the result of the plan will be aggregated. Steps:
+    // 1. Recursively collects outer references from the inner query until it reaches a node
+    //    that does not contain correlated values.
+    // 2. Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is
+    //    needed between the outer query and the specific sub-tree of the inner query.
+    // 3. Returns a list of join conditions with the outer query and a mapping from outer
+    //    references to references inside the inner query. Parent nodes must preserve the
+    //    references in the join conditions and substitute outer references using the mapping.
+    def decorrelate(
+        plan: LogicalPlan,
+        parentOuterReferences: AttributeSet,
+        aggregated: Boolean = false): ReturnType = {
+      val isCorrelated = hasOuterReferences(plan)
+      if (!isCorrelated) {
+        // We have reached a plan without correlation to the outer plan.
+        if (parentOuterReferences.isEmpty) {
+          // If there are no outer references from the parent nodes, it means all outer
+          // attributes can be substituted by attributes from the inner plan. So no
+          // domain join is needed.
+          (plan, Nil, Map.empty[Attribute, Attribute])
+        } else {
+          // Build the domain join with the parent outer references.
+          val attributes = parentOuterReferences.toSeq
+          val domains = attributes.map(_.newInstance())
+          // A placeholder to be rewritten into domain join.
+          val domainJoin = DomainJoin(domains, plan)
+          val outerReferenceMap = attributes.zip(domains).toMap
+          // Build join conditions between domain attributes and outer references.
+          // EqualNullSafe is used to make sure null key can be joined together. Note
+          // outer referenced attributes can be changed during the outer query optimization.
+          // The equality conditions will also serve as an attribute mapping between new
+          // outer references and domain attributes when rewriting the domain joins.
+          // E.g. if the attribute a is changed to a1, the join condition a' <=> outer(a)
+          // will become a' <=> a1, and we can construct the aliases based on the condition:
+          // DomainJoin [a']        Join Inner
+          // +- InnerQuery     =>   :- InnerQuery
+          //                        +- Aggregate [a1] [a1 AS a']
+          //                           +- OuterQuery
+          val conditions = outerReferenceMap.map {
+            case (o, a) => EqualNullSafe(a, OuterReference(o))
+          }
+          (domainJoin, conditions.toSeq, outerReferenceMap)
+        }
+      } else {
+        plan match {
+          case Filter(condition, child) =>
+            val conditions = splitConjunctivePredicates(condition)
+            val (correlated, uncorrelated) = conditions.partition(containsOuter)
+            // Find outer references that can be substituted by attributes from the inner
+            // query using the equality predicates.
+            val equivalences = collectEquivalentOuterReferences(correlated)
+            // Correlated predicates can be removed from the Filter's condition and used as
+            // join conditions with the outer query. However, if the results of the sub-tree
+            // are aggregated, only certain correlated equality predicates can be used, because
+            // the references in the join conditions need to be preserved in both the grouping
+            // and aggregate expressions of an Aggregate, which may change the semantics of the
+            // plan and lead to incorrect results. Here is an example:
+            // Relations:
+            //   t1(a, b): [(1, 1)]
+            //   t2(c, d): [(1, 1), (2, 0)]
+            //
+            // Query:
+            //   SELECT * FROM t1 WHERE a = (SELECT MAX(c) FROM t2 WHERE b >= d)
+            //
+            // Subquery plan transformation if correlated predicates are used as join conditions:
+            //   Aggregate [max(c)]               Aggregate [d] [max(c), d]
+            //   +- Filter (outer(b) >= d)   =>   +- Relation [c, d]
+            //      +- Relation [c, d]
+            //
+            // Plan after rewrite:
+            //   Project [a, b]                                   -- [(1, 1)]
+            //   +- Join LeftOuter (b >= d AND a = max(c))
+            //      :- Relation [a, b]
+            //      +- Aggregate [d] [max(c), d]                  -- [(1, 1), (2, 0)]
+            //         +- Relation [c, d]
+            //
+            // The result of the original query should be an empty set but the transformed
+            // query will output an incorrect result of (1, 1). The correct transformation
+            // with domain join is illustrated below:
+            //   Aggregate [max(c)]               Aggregate [b'] [max(c), b']
+            //   +- Filter (outer(b) >= d)   =>   +- Filter (b' >= d)
+            //      +- Relation [c, d]               +- DomainJoin [b']
+            //                                          +- Relation [c, d]
+            // Plan after rewrite:
+            //   Project [a, b]
+            //   +- Join LeftOuter (b <=> b' AND a = max(c))  -- []
+            //      :- Relation [a, b]
+            //      +- Aggregate [b'] [max(c), b']            -- [(2, 1)]
+            //         +- Join Inner (b' >= d)                -- [(1, 1, 1), (2, 0, 1)] (DomainJoin)
+            //            :- Relation [c, d]
+            //            +- Aggregate [b] [b AS b']          -- [(1)] (Domain)
+            //               +- Relation [a, b]
+            if (aggregated) {
+              // Split the correlated predicates into predicates that can and cannot be directly
+              // used as join conditions with the outer query depending on whether they can
+              // be pulled up over an Aggregate without changing the semantics of the plan.
+              val (equalityCond, predicates) = correlated.partition(canPullUpOverAgg)
+              val outerReferences = collectOuterReferences(predicates)
+              val newOuterReferences =
+                parentOuterReferences ++ outerReferences -- equivalences.keySet
+              val (newChild, joinCond, outerReferenceMap) =
+                decorrelate(child, newOuterReferences, aggregated)
+              // Add the outer references mapping collected from the equality conditions.
+              val newOuterReferenceMap = outerReferenceMap ++ equivalences
+              // Replace all outer references in the non-equality predicates.
+              val newCorrelated = replaceOuterReferences(predicates, newOuterReferenceMap)
+              // The new filter condition is the original filter condition with correlated
+              // equality predicates removed.
+              val newFilterCond = newCorrelated ++ uncorrelated
+              val newFilter = newFilterCond match {
+                case Nil => newChild
+                case conditions => Filter(conditions.reduce(And), newChild)
+              }
+              // Equality predicates are used as join conditions with the outer query.
+              val newJoinCond = joinCond ++ equalityCond
+              (newFilter, newJoinCond, newOuterReferenceMap)
+            } else {
+              // Results of this sub-tree are not aggregated, so all correlated predicates
+              // can be directly used as outer query join conditions.
+              val newOuterReferences = parentOuterReferences -- equivalences.keySet
+              val (newChild, joinCond, outerReferenceMap) =
+                decorrelate(child, newOuterReferences, aggregated)
+              // Add the outer references mapping collected from the equality conditions.
+              val newOuterReferenceMap = outerReferenceMap ++ equivalences
+              val newFilter = uncorrelated match {
+                case Nil => newChild
+                case conditions => Filter(conditions.reduce(And), newChild)
+              }
+              val newJoinCond = joinCond ++ correlated
+              (newFilter, newJoinCond, newOuterReferenceMap)
+            }
+
+          case Project(projectList, child) =>
+            val outerReferences = collectOuterReferences(projectList)
+            val newOuterReferences = parentOuterReferences ++ outerReferences
+            val (newChild, joinCond, outerReferenceMap) =
+              decorrelate(child, newOuterReferences, aggregated)
+            // Replace all outer references in the original project list.
+            val newProjectList = replaceOuterReferences(projectList, outerReferenceMap)
+            // Preserve required domain attributes in the join condition by adding the missing
+            // references to the new project list.
+            val referencesToAdd = missingReferences(newProjectList, joinCond)
+            val newProject = Project(newProjectList ++ referencesToAdd, newChild)
+            (newProject, joinCond, outerReferenceMap)
+
+          case a @ Aggregate(groupingExpressions, aggregateExpressions, child) =>
+            val outerReferences = collectOuterReferences(a.expressions)
+            val newOuterReferences = parentOuterReferences ++ outerReferences
+            val (newChild, joinCond, outerReferenceMap) =
+              decorrelate(child, newOuterReferences, aggregated = true)
+            // Replace all outer references in grouping and aggregate expressions.
+            val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap)
+            val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap)
+            // Add all required domain attributes to both grouping and aggregate expressions.
+            val referencesToAdd = missingReferences(newAggExpr, joinCond)
+            val newAggregate = a.copy(
+              groupingExpressions = newGroupingExpr ++ referencesToAdd,
+              aggregateExpressions = newAggExpr ++ referencesToAdd,
+              child = newChild)
+            (newAggregate, joinCond, outerReferenceMap)
+
+          case j @ Join(left, right, joinType, condition, _) =>
+            val outerReferences = collectOuterReferences(j.expressions)
+            // Join condition containing outer references is not supported.
+            assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
+            val newOuterReferences = parentOuterReferences ++ outerReferences
+            val shouldPushToLeft = joinType match {
+              case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
+              case _ => hasOuterReferences(left)
+            }
+            val shouldPushToRight = joinType match {
+              case RightOuter | FullOuter => true
+              case _ => hasOuterReferences(right)
+            }
+            val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) {
+              decorrelate(left, newOuterReferences, aggregated)
+            } else {
+              (left, Nil, Map.empty[Attribute, Attribute])
+            }
+            val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) {
+              decorrelate(right, newOuterReferences, aggregated)
+            } else {
+              (right, Nil, Map.empty[Attribute, Attribute])
+            }
+            val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap
+            val newJoinCond = leftJoinCond ++ rightJoinCond
+            // If we push the dependent join to both sides, we can augment the join condition
+            // such that both sides are matched on the domain attributes. For example,
+            // - Left Map: {outer(c1) = c1}
+            // - Right Map: {outer(c1) = 10 - c1}
+            // Then the join condition can be augmented with (c1 <=> 10 - c1).
+            val augmentedConditions = leftOuterReferenceMap.flatMap {
+              case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _))
+            }
+            val newCondition = (condition ++ augmentedConditions).reduceOption(And)
+            val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition)
+            (newJoin, newJoinCond, newOuterReferenceMap)
+
+          case u: UnaryNode =>
+            val outerReferences = collectOuterReferences(u.expressions)
+            assert(outerReferences.isEmpty, s"Correlated column is not allowed in $u")
+            val (newChild, joinCond, outerReferenceMap) =
+              decorrelate(u.child, parentOuterReferences, aggregated)
+            // Keep the unary operator on top of its decorrelated child instead of dropping it.
+            (u.withNewChildren(newChild :: Nil), joinCond, outerReferenceMap)
+
+          case o =>
+            throw new UnsupportedOperationException(
+              s"Decorrelate inner query through ${o.nodeName} is not supported.")
+        }
+      }
+    }
+    val (newChild, joinCond, _) = decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty)
+    val (plan, conditions) = deduplicate(newChild, joinCond, outputSet)
+    (plan, stripOuterReferences(conditions))
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index ef73e58..9381796 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 /*
@@ -272,22 +273,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
     val baseConditions = predicateMap.values.flatten.toSeq
     val (newPlan, newCond) = if (outer.nonEmpty) {
       val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
-      val duplicates = transformed.outputSet.intersect(outputSet)
-      val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) {
-        val aliasMap = AttributeMap(duplicates.map { dup =>
-          dup -> Alias(dup, dup.toString)()
-        }.toSeq)
-        val aliasedExpressions = transformed.output.map { ref =>
-          aliasMap.getOrElse(ref, ref)
-        }
-        val aliasedProjection = Project(aliasedExpressions, transformed)
-        val aliasedConditions = baseConditions.map(_.transform {
-          case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
-        })
-        (aliasedProjection, aliasedConditions)
-      } else {
-        (transformed, baseConditions)
-      }
+      val (plan, deDuplicatedConditions) =
+        DecorrelateInnerQuery.deduplicate(transformed, baseConditions, outputSet)
       (plan, stripOuterReferences(deDuplicatedConditions))
     } else {
       (transformed, stripOuterReferences(baseConditions))
@@ -308,9 +295,17 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       if (newCond.isEmpty) oldCond else newCond
     }
 
+    def decorrelate(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
+      // Route through the DecorrelateInnerQuery framework when the flag is on; otherwise keep
+      // using the existing pullOutCorrelatedPredicates rewrite.
+      val useNewFramework = SQLConf.get.decorrelateInnerQueryEnabled
+      if (useNewFramework) DecorrelateInnerQuery(sub, outer)
+      else pullOutCorrelatedPredicates(sub, outer)
+    }
+
     plan transformExpressions {
       case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
-        val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+        val (newPlan, newCond) = decorrelate(sub, outerPlans)
         ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
       case Exists(sub, children, exprId) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
@@ -379,7 +374,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
       bindings: Map[ExprId, Expression]): Expression = {
     val rewrittenExpr = expr transform {
       case r: AttributeReference =>
-        bindings.getOrElse(r.exprId, Literal.default(NullType))
+        bindings.getOrElse(r.exprId, Literal.create(null, r.dataType))
     }
 
     tryEvalExpr(rewrittenExpr)
@@ -394,9 +389,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
     // Also replace attribute refs (for example, for grouping columns) with NULL.
     val rewrittenExpr = expr transform {
       case a @ AggregateExpression(aggFunc, _, _, resultId, _) =>
-        aggFunc.defaultResult.getOrElse(Literal.default(NullType))
+        aggFunc.defaultResult.getOrElse(Literal.create(null, aggFunc.dataType))
 
-      case _: AttributeReference => Literal.default(NullType)
+      case a: AttributeReference => Literal.create(null, a.dataType)
     }
 
     tryEvalExpr(rewrittenExpr)
@@ -525,7 +520,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
       subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
     val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
     val newChild = subqueries.foldLeft(child) {
-      case (currentChild, ScalarSubquery(query, conditions, _)) =>
+      case (currentChild, ScalarSubquery(sub, conditions, _)) =>
+        val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
         val origOutput = query.output.head
 
         val resultWithZeroTups = evalSubqueryOnZeroTups(query)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 21e87b4..49e3e3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1393,3 +1393,14 @@ case class CollectMetrics(
   override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics =
     copy(child = newChild)
 }
+
+/**
+ * Placeholder for a domain join introduced while decorrelating subqueries. Its output is the
+ * child's output followed by `domainAttrs`; it should be rewritten during optimization.
+ */
+case class DomainJoin(domainAttrs: Seq[Attribute], child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(domainAttrs)
+  override def output: Seq[Attribute] = child.output ++ domainAttrs
+  override protected def withNewChildInternal(newChild: LogicalPlan): DomainJoin =
+    DomainJoin(domainAttrs, newChild)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f4c236c..04e7400 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2427,6 +2427,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val DECORRELATE_INNER_QUERY_ENABLED =
+    buildConf("spark.sql.optimizer.decorrelateInnerQuery.enabled")
+      .internal()
+      .doc("Decorrelate inner query by removing correlated references and building domain joins.")
+      .version("3.2.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val TOP_K_SORT_FALLBACK_THRESHOLD =
     buildConf("spark.sql.execution.topKSortFallbackThreshold")
       .internal()
@@ -3829,6 +3837,8 @@ class SQLConf extends Serializable with Logging {
 
   def legacyIntervalEnabled: Boolean = getConf(LEGACY_INTERVAL_ENABLED)
 
+  def decorrelateInnerQueryEnabled: Boolean = getConf(DECORRELATE_INNER_QUERY_ENABLED)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
new file mode 100644
index 0000000..93b2703
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.IntegerType
+
+class DecorrelateInnerQuerySuite extends PlanTest {
+
+  val a = AttributeReference("a", IntegerType)()
+  val b = AttributeReference("b", IntegerType)()
+  val c = AttributeReference("c", IntegerType)()
+  val x = AttributeReference("x", IntegerType)()
+  val y = AttributeReference("y", IntegerType)()
+  val z = AttributeReference("z", IntegerType)()
+  val testRelation = LocalRelation(a, b, c)
+  val testRelation2 = LocalRelation(x, y, z)
+
+  private def hasOuterReferences(plan: LogicalPlan): Boolean = {
+    plan.find(node => node.expressions.exists(e => SubExprUtils.containsOuter(e))).nonEmpty
+  }
+
+  private def check(
+      inner: LogicalPlan,
+      outer: LogicalPlan,
+      expected: LogicalPlan,
+      conditions: Seq[Expression]): Unit = {
+    val (decorrelated, pulledUp) = DecorrelateInnerQuery(inner, outer)
+    assert(!hasOuterReferences(decorrelated))
+    comparePlans(decorrelated, expected)
+    assert(pulledUp.size == conditions.size)
+    pulledUp.zip(conditions).foreach { case (actual, want) => compareExpressions(actual, want) }
+  }
+
+  test("filter with correlated equality predicates only") {
+    val outer = testRelation2
+    val inner = Project(Seq(a, b), Filter(OuterReference(x) === a, testRelation))
+    // The only filter predicate is correlated: the Filter node disappears and the
+    // predicate is pulled up as the join condition.
+    val expected =
+      Project(Seq(a, b), testRelation)
+    check(inner, outer, expected, Seq(x === a))
+  }
+
+  test("filter with local and correlated equality predicates") {
+    val outer = testRelation2
+    val inner =
+      Project(Seq(a, b),
+        Filter(And(OuterReference(x) === a, b === 3), testRelation))
+    // Only the local predicate b = 3 stays in the filter; outer(x) = a is pulled up.
+    val expected =
+      Project(Seq(a, b),
+        Filter(b === 3,
+          testRelation))
+    check(inner, outer, expected, Seq(x === a))
+  }
+
+  test("filter with correlated non-equality predicates") {
+    val outer = testRelation2
+    val inner =
+      Project(Seq(a, b),
+        Filter(OuterReference(x) > a, testRelation))
+    // The correlated non-equality predicate is pulled up as the join condition.
+    val expected = Project(Seq(a, b), testRelation)
+    check(inner, outer, expected, Seq(x > a))
+  }
+
+  test("duplicated output attributes") {
+    val outer = testRelation
+    val inner =
+      Project(Seq(a),
+        Filter(OuterReference(a) === a, testRelation))
+    // Inner and outer plans share attribute `a`, so the inner output gets re-aliased.
+    val (decorrelated, pulledUp) = DecorrelateInnerQuery(inner, outer)
+    val a1 = decorrelated.output.head
+    val expected =
+      Project(Seq(Alias(a, a1.name)(a1.exprId)),
+        Project(Seq(a),
+          testRelation))
+    comparePlans(decorrelated, expected)
+    assert(pulledUp == Seq(a === a1))
+  }
+
+  test("filter with equality predicates with correlated values on both sides") {
+    val outer = testRelation2
+    val inner =
+      Project(Seq(a),
+        Filter(OuterReference(x) === OuterReference(y) + b, testRelation))
+    // `b` is referenced by the pulled-up condition, so it is added to the projection.
+    val expected = Project(Seq(a, b), testRelation)
+    check(inner, outer, expected, Seq(x === y + b))
+  }
+
+  test("aggregate with correlated equality predicates that can be pulled up") {
+    val outer = testRelation2
+    val minB = Alias(min(b), "min_b")()
+    val inner =
+      Aggregate(Nil, Seq(minB),
+        Filter(And(OuterReference(x) === a, b === 3), testRelation))
+    // The inner attribute `a` from the pulled-up predicate becomes a grouping column.
+    val expected =
+      Aggregate(Seq(a), Seq(minB, a),
+        Filter(b === 3,
+          testRelation))
+    check(inner, outer, expected, Seq(x === a))
+  }
+
+  test("aggregate with correlated equality predicates that cannot be pulled up") {
+    val outer = testRelation2
+    val minB = Alias(min(b), "min_b")()
+    val inner =
+      Aggregate(Nil, Seq(minB),
+        Filter(OuterReference(x) === OuterReference(y) + a, testRelation))
+    // The predicate stays in the filter; x and y are supplied by a DomainJoin.
+    val expected =
+      Aggregate(Seq(x, y), Seq(minB, x, y),
+        Filter(x === y + a,
+          DomainJoin(Seq(x, y), testRelation)))
+    check(inner, outer, expected, Seq(x <=> x, y <=> y))
+  }
+
+  test("aggregate with correlated equality predicates that has no attribute") {
+    val outer = testRelation2
+    val minB = Alias(min(b), "min_b")()
+    val inner =
+      Aggregate(Nil, Seq(minB),
+        Filter(OuterReference(x) === OuterReference(y), testRelation))
+    // The predicate references no inner attribute, so the aggregate stays ungrouped.
+    val expected =
+      Aggregate(Nil, Seq(minB),
+        testRelation)
+    check(inner, outer, expected, Seq(x === y))
+  }
+
+  test("aggregate with correlated non-equality predicates") {
+    val outer = testRelation2
+    val minB = Alias(min(b), "min_b")()
+    val inner =
+      Aggregate(Nil, Seq(minB),
+        Filter(OuterReference(x) > a, testRelation))
+    // The non-equality predicate is kept in the filter, evaluated over a DomainJoin on x.
+    val expected =
+      Aggregate(Seq(x), Seq(minB, x),
+        Filter(x > a,
+          DomainJoin(Seq(x), testRelation)))
+    check(inner, outer, expected, Seq(x <=> x))
+  }
+
+  test("join with correlated equality predicates") {
+    val outer = testRelation2
+    val joinCondition = Some($"t1.b" === $"t2.b")
+    val left =
+      Project(Seq(b),
+        Filter(OuterReference(x) === b, testRelation)).as("t1")
+    val right =
+      Project(Seq(b),
+        Filter(OuterReference(x) === a, testRelation)).as("t2")
+    val expectedLeft = Project(Seq(b), testRelation).as("t1")
+    val expectedRight = Project(Seq(b, a), testRelation).as("t2")
+    // Since the left-hand side has outer(x) = b and the right-hand side has outer(x) = a,
+    // the join condition is augmented with b <=> a.
+    val expectedCond = Some(And($"t1.b" <=> $"t2.a", $"t1.b" === $"t2.b"))
+    val joinTypes = Seq(Inner, LeftOuter, LeftSemi, LeftAnti, RightOuter, FullOuter, Cross)
+    for (joinType <- joinTypes) {
+      val inner = Join(left, right, joinType, joinCondition, JoinHint.NONE).analyze
+      val expected =
+        Join(expectedLeft, expectedRight, joinType, expectedCond, JoinHint.NONE).analyze
+      check(inner, outer, expected, Seq(x === b, x === a))
+    }
+  }
+
+  test("correlated values inside join condition") {
+    val outer = testRelation2
+    val inner = Join(
+      testRelation.as("t1"),
+      Filter(OuterReference(y) === 3, testRelation),
+      Inner,
+      Some(OuterReference(x) === a),
+      JoinHint.NONE)
+    // Outer references inside a join condition are rejected with an AssertionError.
+    val thrown = intercept[AssertionError](DecorrelateInnerQuery(inner, outer))
+    assert(thrown.getMessage.contains("Correlated column is not allowed in join"))
+  }
+
+  test("correlated values in project") {
+    val outer = testRelation2
+    val inner = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation())
+    val expected = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation()))
+    check(inner, outer, expected, Seq(x <=> x, y <=> y))
+  }
+
+  test("correlated values in project with alias") {
+    val outer = testRelation2
+    val aliased: Seq[NamedExpression] = Seq(
+      OuterReference(x),
+      OuterReference(y).as("y1"),
+      Add(OuterReference(x), OuterReference(y)).as("sum"))
+    val inner =
+      Project(Seq(OuterReference(x), 'y1, 'sum),
+        Project(aliased, testRelation)).analyze
+    val expected =
+      Project(Seq(x, 'y1, 'sum, y),
+        Project(Seq(x, y.as("y1"), (x + y).as("sum"), y),
+          DomainJoin(Seq(x, y), testRelation))).analyze
+    check(inner, outer, expected, Seq(x <=> x, y <=> y))
+  }
+
+  test("correlated values in project with correlated equality conditions in filter") {
+    val outer = testRelation2
+    val predicate =
+      And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1))
+    val inner =
+      Project(Seq(OuterReference(x)), Filter(predicate, testRelation))
+    // outer(x) = a lets the projected outer(x) be served by `a`, so no DomainJoin appears;
+    // `c` is kept for the pulled-up condition x + y = c.
+    val expected =
+      Project(Seq(a, c),
+        Filter(b === 1, testRelation))
+    check(inner, outer, expected, Seq(x === a, x + y === c))
+  }
+
+  test("correlated values in project without correlated equality conditions in filter") {
+    val outer = testRelation2
+    val predicate =
+      And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1))
+    val inner =
+      Project(
+        Seq(OuterReference(y)),
+        Filter(predicate, testRelation)
+      )
+    // outer(y) has no equivalent inner attribute, so it comes from a DomainJoin over y.
+    val expected =
+      Project(Seq(y, a, c),
+        Filter(b === 1,
+          DomainJoin(Seq(y), testRelation)
+        )
+      )
+    check(inner, outer, expected, Seq(y <=> y, x === a, x + y === c))
+  }
+
+  test("correlated values in project with aggregate") {
+    val outer = testRelation2
+    val predicate = And(OuterReference(x) === a, OuterReference(y) === OuterReference(z))
+    val inner =
+      Aggregate(
+        Seq('x1), Seq(min('y1).as("min_y1")),
+        Project(
+          Seq(a, OuterReference(x).as("x1"), OuterReference(y).as("y1")),
+          Filter(predicate, testRelation)
+        )
+      ).analyze
+    // outer(x) is served by `a` (via outer(x) = a) while outer(y) comes from a DomainJoin
+    // over y; both `y` and `a` are added as grouping columns.
+    val expected =
+      Aggregate(
+        Seq('x1, y, a), Seq(min('y1).as("min_y1"), y, a),
+        Project(
+          Seq(a, a.as("x1"), y.as("y1"), y),
+          DomainJoin(Seq(y), testRelation)
+        )
+      ).analyze
+    check(inner, outer, expected, Seq(y <=> y, x === a, y === z))
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org