Posted to commits@spark.apache.org by we...@apache.org on 2022/11/18 13:50:20 UTC

[spark] branch master updated: [SPARK-40999] Hint propagation to subqueries

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fa9c554fc0 [SPARK-40999] Hint propagation to subqueries
0fa9c554fc0 is described below

commit 0fa9c554fc0b3940a47c3d1c6a5a17ca9a8cee8e
Author: fred-db <fr...@databricks.com>
AuthorDate: Fri Nov 18 21:49:24 2022 +0800

    [SPARK-40999] Hint propagation to subqueries
    
    ### What changes were proposed in this pull request?
    
    We add a hint field to the `SubqueryExpression` class, pull hints on subqueries into this field during `EliminateResolvedHint`, and propagate the hint to joins formed from the subquery in `RewritePredicateSubquery`.
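    
    The core API change, condensed from the diff below (a sketch that elides unrelated members; see `subquery.scala` in the diff for the full definition):
    
    ```
    abstract class SubqueryExpression(
        plan: LogicalPlan,
        outerAttrs: Seq[Expression],
        exprId: ExprId,
        joinCond: Seq[Expression],
        hint: Option[HintInfo]) extends PlanExpression[LogicalPlan] {
      // New: the hint pulled out of the subquery plan, later handed to the join.
      def hint: Option[HintInfo]
      def withNewHint(hint: Option[HintInfo]): SubqueryExpression
      // ...existing members unchanged
    }
    ```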
    
    ### Why are the changes needed?
    
    Currently, if a user writes a query like the following, the hints on the subquery won't be respected.
    
    ```
    SELECT * FROM target t WHERE EXISTS
    (SELECT /*+ BROADCAST */ * FROM source s WHERE s.key = t.key)
    ```
    This happens because hints are removed from the plan and pulled into joins at the beginning of the optimization stage, while subqueries are only turned into joins later during optimization. Since any hint that is not directly below a join is removed at that point, hints below a subquery end up being discarded.
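    
    The fix, simplified from the diff below: `EliminateResolvedHint` now also pulls hints found inside subquery expressions into the new `hint` field, and `RewritePredicateSubquery` re-attaches that hint to the right side of the join it builds (the subquery plan always ends up on the right):
    
    ```
    // In EliminateResolvedHint: stash hints found inside the subquery plan.
    case s: SubqueryExpression if s.hint.isEmpty =>
      val (newPlan, subqueryHints) = extractHintsFromPlan(s.plan)
      s.withNewPlan(newPlan).withNewHint(mergeHints(subqueryHints))
    
    // In RewritePredicateSubquery: re-attach the hint when the join is formed.
    case (p, Exists(sub, _, _, conditions, subHint)) =>
      val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
      buildJoin(outerPlan, sub, LeftSemi, joinCond, subHint) // becomes JoinHint(None, subHint)
    ```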
    
    Hints on subqueries worked prior to a refactoring that added hints as a field to joins ([SPARK-26065](https://issues.apache.org/jira/browse/SPARK-26065)), so this is a regression for anyone who relied on them.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Hints on subqueries will now work.
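    
    A quick way to check, as a hypothetical `spark-shell` session (assumes tables `target` and `source` exist):
    
    ```
    import org.apache.spark.sql.catalyst.plans.logical.Join
    
    val df = spark.sql("""
      SELECT * FROM target t WHERE EXISTS
      (SELECT /*+ BROADCAST */ * FROM source s WHERE s.key = t.key)""")
    
    // The LEFT SEMI join formed from the EXISTS now carries the hint on its right
    // side, i.e. JoinHint(None, Some(HintInfo(strategy = Some(BROADCAST)))).
    df.queryExecution.optimizedPlan.collect { case j: Join => j.hint }.foreach(println)
    ```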
    
    ### How was this patch tested?
    
    Added unit tests that check whether hints are correctly propagated to joins formed from subqueries.
    
    Closes #38497 from fred-db/hint-propagation-40999.
    
    Lead-authored-by: fred-db <fr...@databricks.com>
    Co-authored-by: Fredrik Klauss <fr...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   8 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala      |   5 +-
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |   2 +-
 .../sql/catalyst/expressions/DynamicPruning.scala  |   9 +-
 .../spark/sql/catalyst/expressions/subquery.scala  |  35 +++-
 .../catalyst/optimizer/EliminateResolvedHint.scala |  22 +-
 .../catalyst/optimizer/InjectRuntimeFilter.scala   |   8 +-
 .../catalyst/optimizer/MergeScalarSubqueries.scala |   2 +
 .../sql/catalyst/optimizer/finishAnalysis.scala    |   3 +-
 .../spark/sql/catalyst/optimizer/subquery.scala    |  75 ++++---
 .../org/apache/spark/sql/internal/SQLConf.scala    |   7 +
 .../adaptive/InsertAdaptiveSparkPlan.scala         |   6 +-
 .../adaptive/PlanAdaptiveSubqueries.scala          |   6 +-
 .../dynamicpruning/PlanDynamicPruningFilters.scala |   2 +-
 .../org/apache/spark/sql/execution/subquery.scala  |   2 +-
 .../spark/sql/SubqueryHintPropagationSuite.scala   | 221 +++++++++++++++++++++
 .../execution/command/PlanResolutionSuite.scala    |   4 +-
 17 files changed, 348 insertions(+), 69 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b84a03e77d6..104c5c1e080 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2554,17 +2554,17 @@ class Analyzer(override val catalogManager: CatalogManager)
      */
     private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
       plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
-        case s @ ScalarSubquery(sub, _, exprId, _) if !sub.resolved =>
+        case s @ ScalarSubquery(sub, _, exprId, _, _) if !sub.resolved =>
           resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
-        case e @ Exists(sub, _, exprId, _) if !sub.resolved =>
+        case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
           resolveSubQuery(e, outer)(Exists(_, _, exprId))
-        case InSubquery(values, l @ ListQuery(_, _, exprId, _, _))
+        case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
             if values.forall(_.resolved) && !l.resolved =>
           val expr = resolveSubQuery(l, outer)((plan, exprs) => {
             ListQuery(plan, exprs, exprId, plan.output)
           })
           InSubquery(values, expr.asInstanceOf[ListQuery])
-        case s @ LateralSubquery(sub, _, exprId, _) if !sub.resolved =>
+        case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
           resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 115b1014404..aecf36660cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -853,12 +853,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
       }
     }
 
-    // Skip subquery aliases added by the Analyzer.
+    // Skip subquery aliases added by the Analyzer as well as hints.
     // For projects, do the necessary mapping and skip to its child.
     @scala.annotation.tailrec
     def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match {
       case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child)
       case p: Project => cleanQueryInScalarSubquery(p.child)
+      case h: ResolvedHint => cleanQueryInScalarSubquery(h.child)
       case child => child
     }
 
@@ -892,7 +893,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     checkOuterReference(plan, expr)
 
     expr match {
-      case ScalarSubquery(query, outerAttrs, _, _) =>
+      case ScalarSubquery(query, outerAttrs, _, _, _) =>
         // Scalar subquery must return one column as output.
         if (query.output.size != 1) {
           expr.failAnalysis(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 08847fbe958..e57d1075d2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -360,7 +360,7 @@ abstract class TypeCoercionBase {
 
       // Handle type casting required between value expression and subquery output
       // in IN subquery.
-      case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions))
+      case i @ InSubquery(lhs, ListQuery(sub, children, exprId, _, conditions, _))
           if !i.resolved && lhs.length == sub.output.length =>
         // LHS is the value expressions of IN subquery.
         // RHS is the subquery output.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index dd9e9307e74..1e94188bd18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 
@@ -45,8 +45,9 @@ case class DynamicPruningSubquery(
     buildKeys: Seq[Expression],
     broadcastKeyIndex: Int,
     onlyInBroadcast: Boolean,
-    exprId: ExprId = NamedExpression.newExprId)
-  extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId)
+    exprId: ExprId = NamedExpression.newExprId,
+    hint: Option[HintInfo] = None)
+  extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId, Seq.empty, hint)
   with DynamicPruning
   with Unevaluable
   with UnaryLike[Expression] {
@@ -59,6 +60,8 @@ case class DynamicPruningSubquery(
 
   override def withNewPlan(plan: LogicalPlan): DynamicPruningSubquery = copy(buildQuery = plan)
 
+  override def withNewHint(hint: Option[HintInfo]): SubqueryExpression = copy(hint = hint)
+
   override lazy val resolved: Boolean = {
     pruningKey.resolved &&
       buildQuery.resolved &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 069251734db..e7384dac2d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
@@ -68,18 +68,23 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
  * @param exprId: ID of the expression
  * @param joinCond: the join conditions with the outer query. It contains both inner and outer
  *                  query references.
+ * @param hint: An optional hint for this subquery that will be passed to the join formed from
+ *              this subquery.
  */
 abstract class SubqueryExpression(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression],
     exprId: ExprId,
-    joinCond: Seq[Expression] = Nil) extends PlanExpression[LogicalPlan] {
+    joinCond: Seq[Expression],
+    hint: Option[HintInfo]) extends PlanExpression[LogicalPlan] {
   override lazy val resolved: Boolean = childrenResolved && plan.resolved
   override lazy val references: AttributeSet =
     AttributeSet.fromAttributeSets(outerAttrs.map(_.references))
   override def children: Seq[Expression] = outerAttrs ++ joinCond
   override def withNewPlan(plan: LogicalPlan): SubqueryExpression
   def isCorrelated: Boolean = outerAttrs.nonEmpty
+  def hint: Option[HintInfo]
+  def withNewHint(hint: Option[HintInfo]): SubqueryExpression
 }
 
 object SubqueryExpression {
@@ -254,14 +259,16 @@ case class ScalarSubquery(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
-    joinCond: Seq[Expression] = Seq.empty)
-  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
+    joinCond: Seq[Expression] = Seq.empty,
+    hint: Option[HintInfo] = None)
+  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
   override def dataType: DataType = {
     assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
     plan.schema.fields.head.dataType
   }
   override def nullable: Boolean = true
   override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
+  override def withNewHint(hint: Option[HintInfo]): ScalarSubquery = copy(hint = hint)
   override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     ScalarSubquery(
@@ -299,11 +306,13 @@ case class LateralSubquery(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
-    joinCond: Seq[Expression] = Seq.empty)
-  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
+    joinCond: Seq[Expression] = Seq.empty,
+    hint: Option[HintInfo] = None)
+  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
   override def dataType: DataType = plan.output.toStructType
   override def nullable: Boolean = true
   override def withNewPlan(plan: LogicalPlan): LateralSubquery = copy(plan = plan)
+  override def withNewHint(hint: Option[HintInfo]): LateralSubquery = copy(hint = hint)
   override def toString: String = s"lateral-subquery#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     LateralSubquery(
@@ -339,8 +348,9 @@ case class ListQuery(
     outerAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
     childOutputs: Seq[Attribute] = Seq.empty,
-    joinCond: Seq[Expression] = Seq.empty)
-  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Unevaluable {
+    joinCond: Seq[Expression] = Seq.empty,
+    hint: Option[HintInfo] = None)
+  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
   override def dataType: DataType = if (childOutputs.length > 1) {
     childOutputs.toStructType
   } else {
@@ -349,6 +359,7 @@ case class ListQuery(
   override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
+  override def withNewHint(hint: Option[HintInfo]): ListQuery = copy(hint = hint)
   override def toString: String = s"list#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     ListQuery(
@@ -397,10 +408,14 @@ case class Exists(
     plan: LogicalPlan,
     outerAttrs: Seq[Expression] = Seq.empty,
     exprId: ExprId = NamedExpression.newExprId,
-    joinCond: Seq[Expression] = Seq.empty)
-  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond) with Predicate with Unevaluable {
+    joinCond: Seq[Expression] = Seq.empty,
+    hint: Option[HintInfo] = None)
+  extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint)
+  with Predicate
+  with Unevaluable {
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
+  override def withNewHint(hint: Option[HintInfo]): Exists = copy(hint = hint)
   override def toString: String = s"exists#${exprId.id} $conditionString"
   override lazy val canonicalized: Expression = {
     Exists(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
index 71cbbbb763c..6fce47fbacd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Replaces [[ResolvedHint]] operators from the plan. Move the [[HintInfo]] to associated [[Join]]
@@ -31,20 +34,35 @@ object EliminateResolvedHint extends Rule[LogicalPlan] {
   // This is also called in the beginning of the optimization phase, and as a result
   // is using transformUp rather than resolveOperators.
   def apply(plan: LogicalPlan): LogicalPlan = {
-    val pulledUp = plan transformUp {
+    val joinsWithHints = plan transformUp {
       case j: Join if j.hint == JoinHint.NONE =>
         val (newLeft, leftHints) = extractHintsFromPlan(j.left)
         val (newRight, rightHints) = extractHintsFromPlan(j.right)
         val newJoinHint = JoinHint(mergeHints(leftHints), mergeHints(rightHints))
         j.copy(left = newLeft, right = newRight, hint = newJoinHint)
     }
-    pulledUp.transformUp {
+    val shouldPullHintsIntoSubqueries = SQLConf.get.getConf(SQLConf.PULL_HINTS_INTO_SUBQUERIES)
+    val joinsAndSubqueriesWithHints = if (shouldPullHintsIntoSubqueries) {
+      pullHintsIntoSubqueries(joinsWithHints)
+    } else {
+      joinsWithHints
+    }
+    joinsAndSubqueriesWithHints.transformUp {
       case h: ResolvedHint =>
         hintErrorHandler.joinNotFoundForJoinHint(h.hints)
         h.child
     }
   }
 
+  def pullHintsIntoSubqueries(plan: LogicalPlan): LogicalPlan = {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
+      case s: SubqueryExpression if s.hint.isEmpty =>
+        val (newPlan, subqueryHints) = extractHintsFromPlan(s.plan)
+        val newHint = mergeHints(subqueryHints)
+        s.withNewPlan(newPlan).withNewHint(newHint)
+    }
+  }
+
   /**
    * Combine a list of [[HintInfo]]s into one [[HintInfo]].
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index efcf607b589..1d5129ff7f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -231,9 +231,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
       leftKey: Expression,
       rightKey: Expression): Boolean = {
     (left, right) match {
-      case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) =>
+      case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) =>
         pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey)
-      case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) =>
+      case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) =>
         pruningKey.fastEquals(rightKey) ||
           hasDynamicPruningSubquery(left, plan, leftKey, rightKey)
       case _ => false
@@ -264,10 +264,10 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
       rightKey: Expression): Boolean = {
     (left, right) match {
       case (Filter(InSubquery(Seq(key),
-      ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) =>
+      ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _), _) =>
         key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey)))
       case (_, Filter(InSubquery(Seq(key),
-      ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) =>
+      ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _, _)), _)) =>
         key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey)))
       case _ => false
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
index 43be160f08f..6184160829b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
@@ -162,6 +162,8 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
   private def insertReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]): LogicalPlan = {
     plan.transformUpWithSubqueries {
       case n => n.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
+        // The subquery could contain a hint that is not propagated once we cache it, but since
+        // a non-correlated scalar subquery won't be turned into a Join, the loss of hints is fine.
         case s: ScalarSubquery if !s.isCorrelated && s.deterministic =>
           val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache)
           ScalarSubqueryReference(subqueryIndex, headerIndex, s.dataType, s.exprId)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index a33069051d9..466781fa1de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -66,7 +66,8 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
       IsNotNull(
         ScalarSubquery(
           plan = Limit(Literal(1), Project(Seq(Alias(Literal(1), "col")()), exists.plan)),
-          exprId = exists.exprId))
+          exprId = exists.exprId,
+          hint = exists.hint))
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 52619b38098..984dddf3dc4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -52,10 +52,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       outerPlan: LogicalPlan,
       subplan: LogicalPlan,
       joinType: JoinType,
-      condition: Option[Expression]): Join = {
+      condition: Option[Expression],
+      subHint: Option[HintInfo]): Join = {
     // Deduplicate conflicting attributes if any.
     val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition)
-    Join(outerPlan, dedupSubplan, joinType, condition, JoinHint.NONE)
+    // Add the subquery hint as the right hint, since the subquery is on the right side of the join
+    Join(outerPlan, dedupSubplan, joinType, condition, JoinHint(None, subHint))
   }
 
   private def dedupSubqueryOnSelfJoin(
@@ -111,19 +113,19 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
       // Filter the plan by applying left semi and left anti joins.
       withSubquery.foldLeft(newFilter) {
-        case (p, Exists(sub, _, _, conditions)) =>
+        case (p, Exists(sub, _, _, conditions, subHint)) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
-          buildJoin(outerPlan, sub, LeftSemi, joinCond)
-        case (p, Not(Exists(sub, _, _, conditions))) =>
+          buildJoin(outerPlan, sub, LeftSemi, joinCond, subHint)
+        case (p, Not(Exists(sub, _, _, conditions, subHint))) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
-          buildJoin(outerPlan, sub, LeftAnti, joinCond)
-        case (p, InSubquery(values, ListQuery(sub, _, _, _, conditions))) =>
+          buildJoin(outerPlan, sub, LeftAnti, joinCond, subHint)
+        case (p, InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint))) =>
           // Deduplicate conflicting attributes if any.
           val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
           val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
-          Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint.NONE)
-        case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, conditions)))) =>
+          Join(outerPlan, newSub, LeftSemi, joinCond, JoinHint(None, subHint))
+        case (p, Not(InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint)))) =>
           // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
           // Construct the condition. A NULL in one of the conditions is regarded as a positive
           // result; such a row will be filtered out by the Anti-Join operator.
@@ -148,7 +150,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           // will have the final conditions in the LEFT ANTI as
           // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
           val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
-          Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond), JoinHint.NONE)
+          Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond), JoinHint(None, subHint))
         case (p, predicate) =>
           val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
           Project(p.output, Filter(newCond.get, inputPlan))
@@ -177,12 +179,13 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     var newPlan = plan
     val newExprs = exprs.map { e =>
       e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
-        case Exists(sub, _, _, conditions) =>
+        case Exists(sub, _, _, conditions, subHint) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
+          val existenceJoin = ExistenceJoin(exists)
           newPlan =
-            buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
+            buildJoin(newPlan, sub, existenceJoin, conditions.reduceLeftOption(And), subHint)
           exists
-        case Not(InSubquery(values, ListQuery(sub, _, _, _, conditions))) =>
+        case Not(InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint))) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           // Deduplicate conflicting attributes if any.
           val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
@@ -201,15 +204,17 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           //     +- Relation[id#80] parquet
           val nullAwareJoinConds = inConditions.map(c => Or(c, IsNull(c)))
           val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
-          newPlan = Join(newPlan, newSub, ExistenceJoin(exists), Some(finalJoinCond), JoinHint.NONE)
+          val joinHint = JoinHint(None, subHint)
+          newPlan = Join(newPlan, newSub, ExistenceJoin(exists), Some(finalJoinCond), joinHint)
           Not(exists)
-        case InSubquery(values, ListQuery(sub, _, _, _, conditions)) =>
+        case InSubquery(values, ListQuery(sub, _, _, _, conditions, subHint)) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           // Deduplicate conflicting attributes if any.
           val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
           val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
           val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
-          newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions, JoinHint.NONE)
+          val joinHint = JoinHint(None, subHint)
+          newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions, joinHint)
           exists
       }
     }
@@ -319,18 +324,19 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
     }
 
     plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
-      case ScalarSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
+      case ScalarSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = decorrelate(sub, plan)
-        ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
-      case Exists(sub, children, exprId, conditions) if children.nonEmpty =>
+        ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
+      case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
-        Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions))
-      case ListQuery(sub, children, exprId, childOutputs, conditions) if children.nonEmpty =>
+        Exists(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
+      case ListQuery(sub, children, exprId, childOutputs, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
-        ListQuery(newPlan, children, exprId, childOutputs, getJoinCondition(newCond, conditions))
-      case LateralSubquery(sub, children, exprId, conditions) if children.nonEmpty =>
+        val joinCond = getJoinCondition(newCond, conditions)
+        ListQuery(newPlan, children, exprId, childOutputs, joinCond, hint)
+      case LateralSubquery(sub, children, exprId, conditions, hint) if children.nonEmpty =>
         val (newPlan, newCond) = decorrelate(sub, plan, handleCountBug = true)
-        LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions))
+        LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint)
     }
   }
 
@@ -562,14 +568,17 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
       subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
     val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
     val newChild = subqueries.foldLeft(child) {
-      case (currentChild, ScalarSubquery(sub, _, _, conditions)) =>
+      case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint)) =>
         val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions)
         val origOutput = query.output.head
+        // The subquery appears on the right side of the join, hence add its hint to the
+        // right side of the join hint
+        val joinHint = JoinHint(None, subHint)
 
         val resultWithZeroTups = evalSubqueryOnZeroTups(query)
         lazy val planWithoutCountBug = Project(
           currentChild.output :+ origOutput,
-          Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
+          Join(currentChild, query, LeftOuter, conditions.reduceOption(And), joinHint))
 
         if (resultWithZeroTups.isEmpty) {
           // CASE 1: Subquery guaranteed not to have the COUNT bug
@@ -607,7 +616,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
                 currentChild.output :+ subqueryResultExpr,
                 Join(currentChild,
                   Project(query.output :+ alwaysTrueExpr, query),
-                  LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
+                  LeftOuter, conditions.reduceOption(And), joinHint))
 
             } else {
               // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
@@ -639,7 +648,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
                 currentChild.output :+ caseExpr,
                 Join(currentChild,
                   Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
-                  LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
+                  LeftOuter, conditions.reduceOption(And), joinHint))
             }
           }
         }
@@ -730,10 +739,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
 object RewriteLateralSubquery extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(LATERAL_JOIN)) {
-    case LateralJoin(left, LateralSubquery(sub, _, _, joinCond), joinType, condition) =>
+    case LateralJoin(left, LateralSubquery(sub, _, _, joinCond, subHint), joinType, condition) =>
       val newRight = DecorrelateInnerQuery.rewriteDomainJoins(left, sub, joinCond)
       val newCond = (condition ++ joinCond).reduceOption(And)
-      Join(left, newRight, joinType, newCond, JoinHint.NONE)
+      // The subquery appears on the right side of the join, hence add the hint to the right side
+      Join(left, newRight, joinType, newCond, JoinHint(None, subHint))
   }
 }
 
@@ -763,12 +773,13 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
    * if there is no nested subqueries in the subquery plan.
    */
   private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
-    case LateralJoin(left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _), _, None)
+    case LateralJoin(
+      left, right @ LateralSubquery(OneRowSubquery(projectList), _, _, _, _), _, None)
         if !hasCorrelatedSubquery(right.plan) && right.joinCond.isEmpty =>
       Project(left.output ++ projectList, left)
     case p: LogicalPlan => p.transformExpressionsUpWithPruning(
       _.containsPattern(SCALAR_SUBQUERY)) {
-      case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _)
+      case s @ ScalarSubquery(OneRowSubquery(projectList), _, _, _, _)
           if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty =>
         assert(projectList.size == 1)
         projectList.head
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2f4b03cf591..84d78f365ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3149,6 +3149,13 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val PULL_HINTS_INTO_SUBQUERIES =
+    buildConf("spark.sql.optimizer.pullHintsIntoSubqueries")
+      .internal()
+      .doc("Pull hints into subqueries in EliminateResolvedHint if enabled.")
+      .booleanConf
+      .createWithDefault(true)
+
   val TOP_K_SORT_FALLBACK_THRESHOLD =
     buildConf("spark.sql.execution.topKSortFallbackThreshold")
       .doc("In SQL queries with a SORT followed by a LIMIT like " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index d05f8c8c98d..939d245304b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -122,21 +122,21 @@ case class InsertAdaptiveSparkPlan(
       return subqueryMap.toMap
     }
     plan.foreach(_.expressions.filter(_.containsPattern(PLAN_EXPRESSION)).foreach(_.foreach {
-      case expressions.ScalarSubquery(p, _, exprId, _)
+      case expressions.ScalarSubquery(p, _, exprId, _, _)
           if !subqueryMap.contains(exprId.id) =>
         val executedPlan = compileSubquery(p)
         verifyAdaptivePlan(executedPlan, p)
         val subquery = SubqueryExec.createForScalarSubquery(
           s"subquery#${exprId.id}", executedPlan)
         subqueryMap.put(exprId.id, subquery)
-      case expressions.InSubquery(_, ListQuery(query, _, exprId, _, _))
+      case expressions.InSubquery(_, ListQuery(query, _, exprId, _, _, _))
           if !subqueryMap.contains(exprId.id) =>
         val executedPlan = compileSubquery(query)
         verifyAdaptivePlan(executedPlan, query)
         val subquery = SubqueryExec(s"subquery#${exprId.id}", executedPlan)
         subqueryMap.put(exprId.id, subquery)
       case expressions.DynamicPruningSubquery(value, buildPlan,
-          buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId)
+          buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _)
           if !subqueryMap.contains(exprId.id) =>
         val executedPlan = compileSubquery(buildPlan)
         verifyAdaptivePlan(executedPlan, buildPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index e1c07d31704..f88d2ffd541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -31,9 +31,9 @@ case class PlanAdaptiveSubqueries(
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressionsWithPruning(
       _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
-      case expressions.ScalarSubquery(_, _, exprId, _) =>
+      case expressions.ScalarSubquery(_, _, exprId, _, _) =>
         execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
-      case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _)) =>
+      case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) =>
         val expr = if (values.length == 1) {
           values.head
         } else {
@@ -44,7 +44,7 @@ case class PlanAdaptiveSubqueries(
           )
         }
         InSubqueryExec(expr, subqueryMap(exprId.id), exprId, shouldBroadcast = true)
-      case expressions.DynamicPruningSubquery(value, _, _, _, _, exprId) =>
+      case expressions.DynamicPruningSubquery(value, _, _, _, _, exprId, _) =>
         DynamicPruningExpression(InSubqueryExec(value, subqueryMap(exprId.id), exprId))
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index c9ff28eb045..42c4c20e20d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -51,7 +51,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp
 
     plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
       case DynamicPruningSubquery(
-          value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) =>
+          value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
         val sparkPlan = QueryExecution.createSparkPlan(
           sparkSession, sparkSession.sessionState.planner, buildPlan)
         // Using `sparkPlan` is a little hacky as it is based on the assumption that this rule is
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 7f183d744bc..aaebe99eeee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -182,7 +182,7 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
           SubqueryExec.createForScalarSubquery(
             s"scalar-subquery#${subquery.exprId.id}", executedPlan),
           subquery.exprId)
-      case expressions.InSubquery(values, ListQuery(query, _, exprId, _, _)) =>
+      case expressions.InSubquery(values, ListQuery(query, _, exprId, _, _, _)) =>
         val expr = if (values.length == 1) {
           values.head
         } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubqueryHintPropagationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubqueryHintPropagationSuite.scala
new file mode 100644
index 00000000000..eefb762b59c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubqueryHintPropagationSuite.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.{InnerLike, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint, LogicalPlan}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.test.SharedSparkSession
+
+class SubqueryHintPropagationSuite extends QueryTest with SharedSparkSession {
+
+  setupTestData()
+
+  private val expectedHint =
+    Some(HintInfo(strategy = Some(BROADCAST)))
+  private val hints = Seq("BROADCAST", "SHUFFLE_MERGE")
+  private val hintStringified = hints.map("/*+ " + _ + " */").mkString
+
+  def verifyJoinContainsHint(plan: LogicalPlan): Unit = {
+    val expectedJoinHint = JoinHint(leftHint = None, rightHint = expectedHint)
+    val joinsFound = plan.collect {
+      case j @ Join(_, _, _, _, foundHint) =>
+        assert(expectedJoinHint == foundHint)
+    }
+    assert(joinsFound.size == 1)
+  }
+
+  test("Correlated Exists") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE EXISTS
+         |(SELECT $hintStringified
+         |s2.key FROM testData s2 WHERE s1.key = s2.key AND s1.value = s2.value)
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, testData)
+  }
+
+  test("Correlated Exists with hints in tempView") {
+    val tempView = "tmpView"
+    withTempView(tempView) {
+      val df = spark
+        .range(1, 30)
+        .where("true")
+      val dfWithHints = hints.foldRight(df)((hint, newDf) => newDf.hint(hint))
+        .selectExpr("id as key", "id as value")
+        .withColumn("value", col("value").cast("string"))
+      dfWithHints.createOrReplaceTempView(tempView)
+
+      val queryDf = sql(
+        s"""SELECT * FROM testData s1 WHERE EXISTS
+           |(SELECT s2.key FROM $tempView s2 WHERE s1.key = s2.key AND s1.value = s2.value)
+           |""".stripMargin)
+
+      verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+      checkAnswer(queryDf, dfWithHints)
+    }
+  }
+
+  test("Correlated Exists containing join with hint") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE EXISTS
+         |(SELECT s2.key FROM
+         |(SELECT $hintStringified * FROM testData) s2 JOIN testData s3
+         |ON s2.key = s3.key
+         |WHERE s2.key = s1.key)
+         |""".stripMargin)
+    val optimized = queryDf.queryExecution.optimizedPlan
+
+    // the subquery will be turned into a left semi join and should not contain any hints
+    optimized.foreach {
+      case Join(_, _, joinType, _, hint) =>
+        joinType match {
+          case _: InnerLike => assert(expectedHint == hint.leftHint)
+          case LeftSemi => assert(hint.leftHint.isEmpty && hint.rightHint.isEmpty)
+          case _ => throw new IllegalArgumentException("Unexpected join found.")
+        }
+      case _ =>
+    }
+    checkAnswer(queryDf, testData)
+  }
+
+  test("Negated Exists with hint") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE NOT EXISTS
+         |(SELECT $hintStringified
+         |* FROM testData s2 WHERE s1.key = s2.key AND s1.value = s2.value)
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, spark.emptyDataFrame)
+  }
+
+  test("Exists with complex predicate") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE EXISTS
+         |(SELECT $hintStringified
+         |* FROM testData s2 WHERE s1.key = s2.key AND s1.value = s2.value) OR s1.key = 5
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, testData)
+  }
+
+  test("Non-correlated IN") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE key IN
+         |(SELECT $hintStringified key FROM testData s2)
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, testData)
+  }
+
+  test("Correlated IN") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE key IN
+         |(SELECT $hintStringified
+         |key FROM testData s2 WHERE s1.key = s2.key AND s1.value = s2.value)
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, testData)
+  }
+
+  test("Negated IN with hint") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE key NOT IN
+         |(SELECT $hintStringified
+         |key FROM testData s2 WHERE s1.key = s2.key AND s1.value = s2.value)
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, spark.emptyDataFrame)
+  }
+
+  test("IN with complex predicate") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE key in
+         |(SELECT $hintStringified
+         | key FROM testData s2 WHERE s1.key = s2.key AND s1.value = s2.value) OR s1.key = 5
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, testData)
+  }
+
+  test("Scalar subquery") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE key =
+         |(SELECT $hintStringified MAX(key) FROM
+         |testData s2 WHERE s1.key = s2.key AND s1.value = s2.value)
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, testData)
+  }
+
+  test("Scalar subquery with COUNT") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE key =
+         |(SELECT $hintStringified COUNT(key) FROM
+         |testData s2 WHERE s1.key = s2.key AND s1.value = s2.value)
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, Row(1, "1"))
+  }
+
+  test("Scalar subquery with non-equality predicates") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE key =
+         |(SELECT $hintStringified MAX(key) FROM
+         |testData s2 WHERE s1.key > s2.key AND s1.value > s2.value)
+         |""".stripMargin)
+    val condContainsMax = (condition: Expression) => {
+      condition.find {
+        case e: AttributeReference if e.name.contains("max") =>
+          true
+        case _ => false
+      }.isDefined
+    }
+    val optimizedPlan = queryDf.queryExecution.optimizedPlan
+    val expectedJoinHint = JoinHint(leftHint = None, rightHint = expectedHint)
+    val joinsFound = optimizedPlan.collect {
+      case j: Join if j.condition.nonEmpty && condContainsMax(j.condition.get) =>
+        assert(expectedJoinHint == j.hint)
+    }
+    assert(joinsFound.size == 1)
+    checkAnswer(queryDf, spark.emptyDataFrame)
+  }
+
+  test("Scalar subquery nested subquery") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1 WHERE key =
+         |(SELECT MAX(key) FROM
+         |(SELECT $hintStringified key FROM testData s2 WHERE
+         |s1.key = s2.key AND s1.value = s2.value))
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    checkAnswer(queryDf, testData)
+  }
+
+  test("Lateral subquery") {
+    val queryDf = sql(
+      s"""SELECT * FROM testData s1, LATERAL
+         |(SELECT $hintStringified * FROM testData s2)
+         |""".stripMargin)
+    verifyJoinContainsHint(queryDf.queryExecution.optimizedPlan)
+    // No condition, should be the same as a cross join.
+    val expectedAnswer = testData.crossJoin(testData)
+    checkAnswer(queryDf, expectedAnswer)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index a3c34b22a28..80f258c4659 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -1000,7 +1000,7 @@ class PlanResolutionSuite extends AnalysisTest {
           query match {
             case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
                 UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
-                _, _, _, _) =>
+                _, _, _, _, _) =>
               assert(projects.size == 1 && projects.head.name == "s.name")
               assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
             case o => fail("Unexpected subquery: \n" + o.treeString)
@@ -1090,7 +1090,7 @@ class PlanResolutionSuite extends AnalysisTest {
           query match {
             case ListQuery(Project(projects, SubqueryAlias(AliasIdentifier("s", Seq()),
                 UnresolvedSubqueryColumnAliases(outputColumnNames, Project(_, _: OneRowRelation)))),
-                _, _, _, _) =>
+                _, _, _, _, _) =>
               assert(projects.size == 1 && projects.head.name == "s.name")
               assert(outputColumnNames.size == 1 && outputColumnNames.head == "name")
             case o => fail("Unexpected subquery: \n" + o.treeString)

