Posted to commits@spark.apache.org by we...@apache.org on 2019/03/04 11:09:52 UTC

[spark] branch master updated: [SPARK-19712][SQL] Pushing Left Semi and Left Anti joins through Project, Aggregate, Window, Union etc.

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ad4823c  [SPARK-19712][SQL] Pushing Left Semi and Left Anti joins through Project, Aggregate, Window, Union etc.
ad4823c is described below

commit ad4823c99dcbb3d48f0cb8c556450ff857208709
Author: Dilip Biswal <db...@us.ibm.com>
AuthorDate: Mon Mar 4 19:09:24 2019 +0800

    [SPARK-19712][SQL] Pushing Left Semi and Left Anti joins through Project, Aggregate, Window, Union etc.
    
    ## What changes were proposed in this pull request?
    This PR adds support for pushing down LeftSemi and LeftAnti joins below operators such as
    Project, Aggregate, Window, and Union. It is the initial piece of work needed for the
    subsequent task of moving the subquery rewrites to the beginning of the optimization phase.

    The larger PR is [here](https://github.com/apache/spark/pull/23211). This PR addresses the comment at [link](https://github.com/apache/spark/pull/23211#issuecomment-445705922).

    ## How was this patch tested?
    Added a new test suite LeftSemiAntiJoinPushDownSuite.
    
    Closes #23750 from dilipbiswal/SPARK-19712-pushleftsemi.
    
    Authored-by: Dilip Biswal <db...@us.ibm.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  34 +--
 .../optimizer/PushDownLeftSemiAntiJoin.scala       | 197 +++++++++++++++
 .../spark/sql/catalyst/plans/joinTypes.scala       |   7 +
 .../optimizer/LeftSemiAntiJoinPushDownSuite.scala  | 279 +++++++++++++++++++++
 .../sql/execution/metric/SQLMetricsSuite.scala     |   4 +-
 5 files changed, 505 insertions(+), 16 deletions(-)
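
To illustrate what this change does at the plan level, here is a minimal sketch using the
Catalyst test DSL (illustrative only, mirroring the relations in the new test suite; this is
not code from the patch):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.LeftSemi
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val left  = LocalRelation('a.int, 'b.int, 'c.int)
    val right = LocalRelation('d.int)

    // Before: the left semi join sits above the Project.
    val before = left.select('a, 'b, 'c).join(right, LeftSemi, Some('b === 'd))

    // After PushDownLeftSemiAntiJoin, the join runs first and the Project is
    // applied to the (typically smaller) filtered result:
    //   Project [a, b, c]
    //   +- Join LeftSemi, (b = d)
    //      :- LocalRelation [a, b, c]
    //      +- LocalRelation [d]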

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ad25898..3c59e4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -95,6 +95,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
         EliminateOuterJoin,
         PushPredicateThroughJoin,
         PushDownPredicate,
+        PushDownLeftSemiAntiJoin,
         LimitPushDown,
         ColumnPruning,
         InferFiltersFromConstraints,
@@ -1012,24 +1013,13 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     // This also applies to Aggregate.
     case Filter(condition, project @ Project(fields, grandChild))
       if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
-
-      // Create a map of Aliases to their values from the child projection.
-      // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
-      val aliasMap = AttributeMap(fields.collect {
-        case a: Alias => (a.toAttribute, a.child)
-      })
-
+      val aliasMap = getAliasMap(project)
       project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
 
     case filter @ Filter(condition, aggregate: Aggregate)
       if aggregate.aggregateExpressions.forall(_.deterministic)
         && aggregate.groupingExpressions.nonEmpty =>
-      // Find all the aliased expressions in the aggregate list that don't include any actual
-      // AggregateExpression, and create a map from the alias to the expression
-      val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
-        case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
-          (a.toAttribute, a.child)
-      })
+      val aliasMap = getAliasMap(aggregate)
 
       // For each filter, expand the alias and check if the filter can be evaluated using
       // attributes produced by the aggregate operator's child operator.
@@ -1127,7 +1117,23 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
       }
   }
 
-  private def canPushThrough(p: UnaryNode): Boolean = p match {
+  def getAliasMap(plan: Project): AttributeMap[Expression] = {
+    // Create a map of Aliases to their values from the child projection.
+    // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
+    AttributeMap(plan.projectList.collect { case a: Alias => (a.toAttribute, a.child) })
+  }
+
+  def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
+    // Find all the aliased expressions in the aggregate list that don't include any actual
+    // AggregateExpression, and create a map from the alias to the expression.
+    val aliasMap = plan.aggregateExpressions.collect {
+      case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+        (a.toAttribute, a.child)
+    }
+    AttributeMap(aliasMap)
+  }
+
+  def canPushThrough(p: UnaryNode): Boolean = p match {
     // Note that some operators (e.g. project, aggregate, union) are being handled separately
     // (earlier in this rule).
     case _: AppendColumns => true
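
As a quick illustration of the refactored helper (a hedged sketch using the Catalyst test DSL,
not code from the patch): for a projection equivalent to SELECT a + b AS c, d FROM t,
getAliasMap returns a map from the alias c to the expression a + b, which replaceAlias can use
to rewrite conditions against the Project's child.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.optimizer.PushDownPredicate
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

    val t = LocalRelation('a.int, 'b.int, 'd.int)
    // SELECT a + b AS c, d FROM t
    val proj = t.select(('a + 'b).as("c"), 'd).analyze.asInstanceOf[Project]
    val aliasMap = PushDownPredicate.getAliasMap(proj)
    // aliasMap maps c -> (a + b); replaceAlias('c === 1, aliasMap) yields ((a + b) === 1),
    // which only references attributes produced by t.
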
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
new file mode 100644
index 0000000..bc868df
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * This rule is a variant of [[PushDownPredicate]] which can handle
+ * pushing down LeftSemi and LeftAnti joins below the following operators:
+ *  1) Project
+ *  2) Window
+ *  3) Union
+ *  4) Aggregate
+ *  5) Other permissible unary operators; please see [[PushDownPredicate.canPushThrough]].
+ */
+object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // LeftSemi/LeftAnti over Project
+    case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if pList.forall(_.deterministic) &&
+        !pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
+        canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
+      if (joinCond.isEmpty) {
+        // No join condition, just push down the Join below Project
+        p.copy(child = Join(gChild, rightOp, joinType, joinCond, hint))
+      } else {
+        val aliasMap = PushDownPredicate.getAliasMap(p)
+        val newJoinCond = if (aliasMap.nonEmpty) {
+          Option(replaceAlias(joinCond.get, aliasMap))
+        } else {
+          joinCond
+        }
+        p.copy(child = Join(gChild, rightOp, joinType, newJoinCond, hint))
+      }
+
+    // LeftSemi/LeftAnti over Aggregate
+    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
+        !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
+      if (joinCond.isEmpty) {
+        // No join condition, just push down Join below Aggregate
+        agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint))
+      } else {
+        val aliasMap = PushDownPredicate.getAliasMap(agg)
+
+        // For each join condition, expand the alias and check if the condition can be evaluated
+        // using attributes produced by the aggregate operator's child operator.
+        val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
+          val replaced = replaceAlias(cond, aliasMap)
+          cond.references.nonEmpty &&
+            replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet)
+        }
+
+        // Check that the remaining predicates do not reference columns from the right-hand
+        // side of the join. Since the remaining predicates will be kept as a Filter above the
+        // Aggregate once the left semi or left anti join is moved below it, this check is
+        // necessary: for these join types only the left leg's attributes are output, so the
+        // right side's columns would not be resolvable above the Aggregate.
+        val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
+
+        if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
+          val pushDownPredicate = pushDown.reduce(And)
+          val replaced = replaceAlias(pushDownPredicate, aliasMap)
+          val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint))
+          // If there is no more filter to stay up, just return the Aggregate over Join.
+          // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)".
+          if (stayUp.isEmpty) newAgg else Filter(stayUp.reduce(And), newAgg)
+        } else {
+          // The join condition is not a subset of the Aggregate's GROUP BY columns,
+          // no push down.
+          join
+        }
+      }
+
+    // LeftSemi/LeftAnti over Window
+    case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
+      if (joinCond.isEmpty) {
+        // No join condition, just push down Join below Window
+        w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint))
+      } else {
+        val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++
+          rightOp.outputSet
+
+        val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
+          cond.references.subsetOf(partitionAttrs)
+        }
+
+        // Check that the remaining predicates do not reference columns from the right-hand
+        // side of the join. Since the remaining predicates will be kept as a Filter above the
+        // Window once the left semi or left anti join is moved below it, this check is
+        // necessary: for these join types only the left leg's attributes are output, so the
+        // right side's columns would not be resolvable above the Window.
+        val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
+
+        if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
+          val predicate = pushDown.reduce(And)
+          val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint))
+          if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan)
+        } else {
+          // The join condition is not a subset of the Window's PARTITION BY clause,
+          // no push down.
+          join
+        }
+      }
+
+    // LeftSemi/LeftAnti over Union
+    case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if canPushThroughCondition(union.children, joinCond, rightOp) =>
+      if (joinCond.isEmpty) {
+        // Push down the Join below Union
+        val newGrandChildren = union.children.map { Join(_, rightOp, joinType, joinCond, hint) }
+        union.withNewChildren(newGrandChildren)
+      } else {
+        val output = union.output
+        val newGrandChildren = union.children.map { grandchild =>
+          val newCond = joinCond.get transform {
+            case e if output.exists(_.semanticEquals(e)) =>
+              grandchild.output(output.indexWhere(_.semanticEquals(e)))
+          }
+          assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
+          Join(grandchild, rightOp, joinType, Option(newCond), hint)
+        }
+        union.withNewChildren(newGrandChildren)
+      }
+
+    // LeftSemi/LeftAnti over UnaryNode
+    case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
+      if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
+      pushDownJoin(join, u.child) { joinCond =>
+        u.withNewChildren(Seq(Join(u.child, rightOp, joinType, joinCond, hint)))
+      }
+  }
+
+  /**
+   * Checks whether we can safely push a join through a Project or Union by making sure the
+   * attributes referenced in the join condition do not overlap with the output of the plan the
+   * join is moved into. Overlap can happen when both sides of the join refer to the same source
+   * (a self join). This function ensures the join condition does not refer to ambiguous
+   * attributes (i.e. present in both legs of the join); otherwise the resulting plan is invalid.
+   */
+  private def canPushThroughCondition(
+      plans: Seq[LogicalPlan],
+      condition: Option[Expression],
+      rightOp: LogicalPlan): Boolean = {
+    val attributes = AttributeSet(plans.flatMap(_.output))
+    if (condition.isDefined) {
+      val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
+      matched.isEmpty
+    } else {
+      true
+    }
+  }
+
+
+  private def pushDownJoin(
+      join: Join,
+      grandchild: LogicalPlan)(insertJoin: Option[Expression] => LogicalPlan): LogicalPlan = {
+    if (join.condition.isEmpty) {
+      insertJoin(None)
+    } else {
+      val (pushDown, stayUp) = splitConjunctivePredicates(join.condition.get)
+        .partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
+
+      val rightOpColumns = AttributeSet(stayUp.toSet).intersect(join.right.outputSet)
+      if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
+        val newChild = insertJoin(Option(pushDown.reduceLeft(And)))
+        if (stayUp.nonEmpty) {
+          Filter(stayUp.reduceLeft(And), newChild)
+        } else {
+          newChild
+        }
+      } else {
+        join
+      }
+    }
+  }
+}
+
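
A hedged sketch of the Aggregate case above, mirroring the partial-pushdown test in the new
suite (relation and column names are illustrative): only the conjunct that references grouping
columns is pushed below the Aggregate; the rest stays behind as a Filter.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.LeftSemi
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val t = LocalRelation('b.int, 'c.int)
    val s = LocalRelation('d.int)

    // 'b === 'd can be evaluated below the Aggregate (b is a grouping column),
    // while 'sum === 10 references an aggregate result and must stay above it.
    val query = t.groupBy('b)('b, sum('c).as('sum))
      .join(s, LeftSemi, Some('b === 'd && 'sum === 10))

    // Expected shape after the rule:
    //   Filter (sum = 10)
    //   +- Aggregate [b], [b, sum(c) AS sum]
    //      +- Join LeftSemi, (b = d)
    //         :- LocalRelation [b, c]
    //         +- LocalRelation [d]
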
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index c778490..86cdc26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -114,3 +114,10 @@ object LeftExistence {
     case _ => None
   }
 }
+
+object LeftSemiOrAnti {
+  def unapply(joinType: JoinType): Option[JoinType] = joinType match {
+    case LeftSemi | LeftAnti => Some(joinType)
+    case _ => None
+  }
+}
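
The new LeftSemiOrAnti extractor lets the rule above match both join types with a single
pattern. A minimal usage sketch (the function name is made up for illustration):

    import org.apache.spark.sql.catalyst.plans._

    def describe(joinType: JoinType): String = joinType match {
      case LeftSemiOrAnti(jt) => s"$jt only outputs columns from the left side"
      case other => s"$other may output columns from both sides"
    }
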
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
new file mode 100644
index 0000000..1a0231e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.IntegerType
+
+class LeftSemiAntiJoinPushDownSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subqueries", Once,
+        EliminateSubqueryAliases) ::
+      Batch("Filter Pushdown", FixedPoint(10),
+        CombineFilters,
+        PushDownPredicate,
+        PushDownLeftSemiAntiJoin,
+        BooleanSimplification,
+        CollapseProject) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  val testRelation1 = LocalRelation('d.int)
+
+  test("Project: LeftSemiAnti join pushdown") {
+    val originalQuery = testRelation
+      .select(star())
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .select('a, 'b, 'c)
+      .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
+    val originalQuery = testRelation
+      .select(Rand('a), 'b, 'c)
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Project: LeftSemiAnti join non correlated scalar subq") {
+    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
+    val originalQuery = testRelation
+      .select(subq.as("sum"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some(subq === 'd))
+      .select(subq.as("sum"))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+    val testRelation2 = LocalRelation('e.int, 'f.int)
+    val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
+    val subqExpr = ScalarSubquery(subqPlan)
+    val originalQuery = testRelation
+      .select(subqExpr.as("sum"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Aggregate: LeftSemiAnti join pushdown") {
+    val originalQuery = testRelation
+      .groupBy('b)('b, sum('c))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .groupBy('b)('b, sum('c))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
+    val originalQuery = testRelation
+      .groupBy('b)('b, Rand(10).as('c))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Aggregate: LeftSemiAnti join partial pushdown") {
+    val originalQuery = testRelation
+      .groupBy('b)('b, sum('c).as('sum))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 10))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .groupBy('b)('b, sum('c).as('sum))
+      .where('sum === 10)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("LeftSemiAnti join over aggregate - no pushdown") {
+    val originalQuery = testRelation
+      .groupBy('b)('b, sum('c).as('sum))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+    val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
+    val originalQuery = testRelation
+      .groupBy('a) ('a, subq.as("sum"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd && 'a === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some(subq === 'd && 'a === 'd))
+      .groupBy('a) ('a, subq.as("sum"))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("LeftSemiAnti join over Window") {
+    val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
+
+    val originalQuery = testRelation
+      .select('a, 'b, 'c, winExpr.as('window))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+      .select('a, 'b, 'c)
+      .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil)
+      .select('a, 'b, 'c, 'window).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Window: LeftSemiAnti partial pushdown") {
+    // Predicates from the join condition that do not refer to the window partition spec
+    // are kept in the plan as a Filter operator above the Window.
+    val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
+
+    val originalQuery = testRelation
+      .select('a, 'b, 'c, winExpr.as('window))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd && 'b > 5))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+      .select('a, 'b, 'c)
+      .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil)
+      .where('b > 5)
+      .select('a, 'b, 'c, 'window).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Union: LeftSemiAnti join pushdown") {
+    val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
+
+    val originalQuery = Union(Seq(testRelation, testRelation2))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = Union(Seq(
+      testRelation.join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)),
+      testRelation2.join(testRelation1, joinType = LeftSemi, condition = Some('x === 'd))))
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Union: LeftSemiAnti join no pushdown in self join scenario") {
+    val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
+
+    val originalQuery = Union(Seq(testRelation, testRelation2))
+      .join(testRelation2, joinType = LeftSemi, condition = Some('a === 'x))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown") {
+    val originalQuery = testRelation
+      .select(star())
+      .repartition(1)
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .select('a, 'b, 'c)
+      .repartition(1)
+      .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+    val originalQuery = testRelation
+      .select(star())
+      .repartition(1)
+      .join(testRelation1, joinType = LeftSemi, condition = None)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .join(testRelation1, joinType = LeftSemi, condition = None)
+      .select('a, 'b, 'c)
+      .repartition(1)
+      .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown - partial pushdown") {
+    val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+    val originalQuery = testRelationWithArrayType
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'b === 'out_col))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelationWithArrayType
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .where('b === 'out_col)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+    val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
+    val originalQuery = testRelationWithArrayType
+      .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
+      .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'd === 'out_col))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery.analyze)
+  }
+}
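
Each test follows the same pattern: build a plan with the DSL, run Optimize.execute on the
analyzed plan, and compare against a hand-built expected plan with comparePlans, which
normalizes expression ids before checking structural equality. To run just this suite,
something like the usual sbt invocation should work (hedged; module and pattern follow
Spark's standard layout):

    build/sbt "catalyst/testOnly *LeftSemiAntiJoinPushDownSuite"
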
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 98a8ad5..b77048a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -345,10 +345,10 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
     val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
     // Assume the execution plan is
-    // ... -> BroadcastHashJoin(nodeId = 0)
+    // ... -> BroadcastHashJoin(nodeId = 1)
     val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi")
     testSparkPlanMetrics(df, 2, Map(
-      0L -> (("BroadcastHashJoin", Map(
+      1L -> (("BroadcastHashJoin", Map(
         "number of output rows" -> 2L))))
     )
   }
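
A plausible reading of this metric fix (inferred from the new rule, not stated in the patch):
pushing the left semi join below the Project introduced by toDF leaves that Project as the
topmost node, so BroadcastHashJoin shifts from nodeId 0 to nodeId 1:

    // Before this patch:                 After this patch:
    //   BroadcastHashJoin (nodeId = 0)     Project (nodeId = 0)
    //   ...                                +- BroadcastHashJoin (nodeId = 1)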

