Posted to commits@kylin.apache.org by xx...@apache.org on 2023/05/06 06:59:13 UTC

[kylin] 12/38: [DIRTY] Rewrite spark InferFiltersFromConstraints

This is an automated email from the ASF dual-hosted git repository.

xxyu pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git

commit 00e52f4b3f3da7715a6e0fbde34c9d5db7ff6e61
Author: Mingming Ge <7m...@gmail.com>
AuthorDate: Sun Jan 29 18:45:04 2023 +0800

    [DIRTY] Rewrite spark InferFiltersFromConstraints
---
 .../java/org/apache/kylin/common/KapConfig.java    |   4 +
 .../scala/org/apache/spark/sql/SparderEnv.scala    |  31 +-
 .../RewriteInferFiltersFromConstraints.scala       | 101 +++++++
 .../RewriteInferFiltersFromConstraintsSuite.scala  | 320 +++++++++++++++++++++
 4 files changed, 439 insertions(+), 17 deletions(-)

diff --git a/src/core-common/src/main/java/org/apache/kylin/common/KapConfig.java b/src/core-common/src/main/java/org/apache/kylin/common/KapConfig.java
index c8aae82ddd..4f2132760e 100644
--- a/src/core-common/src/main/java/org/apache/kylin/common/KapConfig.java
+++ b/src/core-common/src/main/java/org/apache/kylin/common/KapConfig.java
@@ -265,6 +265,10 @@ public class KapConfig {
         return Integer.parseInt(config.getOptional("kylin.query.engine.spark-sql-shuffle-partitions", "-1"));
     }
 
+    public Boolean isConstraintPropagationEnabled() {
+        return Boolean.parseBoolean(config.getOptional("kylin.query.engine.spark-constraint-propagation-enabled", FALSE));
+    }
+
     /**
      * LDAP filter
      */
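
The new flag defaults to "false", so the rewritten rule is opt-in. A minimal
kylin.properties snippet to enable it (key and default taken from the hunk above):

    kylin.query.engine.spark-constraint-propagation-enabled=true

Note that even when injected, the rule also honors Spark's own
spark.sql.constraintPropagation.enabled setting at execution time (see the
apply method of the new rule below).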
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
index 41f6abe771..59ce599fce 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
@@ -23,13 +23,12 @@ import java.security.PrivilegedAction
 import java.util.Map
 import java.util.concurrent.locks.ReentrantLock
 import java.util.concurrent.{Callable, ExecutorService}
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.security.UserGroupInformation
 import org.apache.kylin.common.exception.{KylinException, KylinTimeoutException, ServerErrorCode}
 import org.apache.kylin.common.msg.MsgPicker
 import org.apache.kylin.common.util.{DefaultHostInfoFetcher, HadoopUtil, S3AUtil}
-import org.apache.kylin.common.{KylinConfig, QueryContext}
+import org.apache.kylin.common.{KapConfig, KylinConfig, QueryContext}
 import org.apache.kylin.metadata.model.{NTableMetadataManager, TableExtDesc}
 import org.apache.kylin.metadata.project.NProjectManager
 import org.apache.kylin.query.runtime.plan.QueryToExecutionIDCache
@@ -39,7 +38,7 @@ import org.apache.spark.sql.KylinSession._
 import org.apache.spark.sql.catalyst.optimizer.ConvertInnerJoinToSemiJoin
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.datasource.{KylinSourceStrategy, LayoutFileSourceStrategy}
+import org.apache.spark.sql.execution.datasource.{KylinSourceStrategy, LayoutFileSourceStrategy, RewriteInferFiltersFromConstraints}
 import org.apache.spark.sql.execution.ui.PostQueryExecutionForKylin
 import org.apache.spark.sql.hive.ReplaceLocationRule
 import org.apache.spark.sql.udf.UdfManager
@@ -223,29 +222,17 @@ object SparderEnv extends Logging {
           SparkSession.builder
             .master("local")
             .appName("sparder-local-sql-context")
-            .withExtensions { ext =>
-              ext.injectPlannerStrategy(_ => KylinSourceStrategy)
-              ext.injectPlannerStrategy(_ => LayoutFileSourceStrategy)
-              ext.injectPostHocResolutionRule(ReplaceLocationRule)
-              ext.injectOptimizerRule(_ => new ConvertInnerJoinToSemiJoin())
-            }
             .enableHiveSupport()
             .getOrCreateKylinSession()
         case _ =>
           SparkSession.builder
             .appName(appName)
             .master("yarn")
-            //if user defined other master in kylin.properties,
-            // it will get overwrite later in org.apache.spark.sql.KylinSession.KylinBuilder.initSparkConf
-            .withExtensions { ext =>
-              ext.injectPlannerStrategy(_ => KylinSourceStrategy)
-              ext.injectPlannerStrategy(_ => LayoutFileSourceStrategy)
-              ext.injectPostHocResolutionRule(ReplaceLocationRule)
-              ext.injectOptimizerRule(_ => new ConvertInnerJoinToSemiJoin())
-            }
             .enableHiveSupport()
             .getOrCreateKylinSession()
       }
+
+      injectExtensions(sparkSession.extensions)
       spark = sparkSession
       logInfo("Spark context started successfully with stack trace:")
       logInfo(Thread.currentThread().getStackTrace.mkString("\n"))
@@ -277,6 +264,16 @@ object SparderEnv extends Logging {
     }
   }
 
+  def injectExtensions(sse: SparkSessionExtensions): Unit = {
+    sse.injectPlannerStrategy(_ => KylinSourceStrategy)
+    sse.injectPlannerStrategy(_ => LayoutFileSourceStrategy)
+    sse.injectPostHocResolutionRule(ReplaceLocationRule)
+    sse.injectOptimizerRule(_ => new ConvertInnerJoinToSemiJoin())
+    if (KapConfig.getInstanceFromEnv.isConstraintPropagationEnabled) {
+      sse.injectOptimizerRule(_ => RewriteInferFiltersFromConstraints)
+    }
+  }
+
   def registerListener(sc: SparkContext): Unit = {
     val sparkListener = new SparkListener {
 
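
The extension wiring removed from the two builder branches above now lives in the
reusable injectExtensions method. As a minimal standalone sketch of the same
injection pattern (hypothetical app name; assumes Spark 3.x and this commit's
classes on the classpath):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.datasource.RewriteInferFiltersFromConstraints

    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("rewrite-infer-filters-demo") // hypothetical
      .withExtensions { ext =>
        // The same call SparderEnv.injectExtensions makes when the KapConfig flag is on.
        ext.injectOptimizerRule(_ => RewriteInferFiltersFromConstraints)
      }
      .getOrCreate()
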
diff --git a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraints.scala b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraints.scala
new file mode 100644
index 0000000000..7d5132b744
--- /dev/null
+++ b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraints.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasource
+
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, ExpressionSet}
+import org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints.{constructIsNotNullConstraints, inferAdditionalConstraints}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, JOIN}
+
+object RewriteInferFiltersFromConstraints extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (conf.constraintPropagationEnabled) {
+      inferFilters(plan)
+    } else {
+      plan
+    }
+  }
+
+  private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(FILTER, JOIN)) {
+    case filter @ Filter(condition, child) =>
+      val newFilters = filter.constraints --
+        (child.constraints ++ splitConjunctivePredicates(condition))
+      if (newFilters.nonEmpty) {
+        Filter(And(newFilters.reduce(And), condition), child)
+      } else {
+        filter
+      }
+
+    case join @ Join(left, right, joinType, conditionOpt, _) =>
+      joinType match {
+        // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
+        // inner join, it just drops the right side in the final output. Unlike the stock Spark
+        // rule, this rewrite also infers filters for both sides of a left outer join,
+        // including its preserved (left) side; that looser inference is why the commit
+        // carries the [DIRTY] tag.
+        case _: InnerLike | LeftSemi | LeftOuter =>
+          val allConstraints = getAllConstraints(left, right, conditionOpt)
+          val newLeft = inferNewFilter(left, allConstraints)
+          val newRight = inferNewFilter(right, allConstraints)
+          join.copy(left = newLeft, right = newRight)
+
+        // For right outer join, we can only infer additional filters for the left side.
+        case RightOuter =>
+          val allConstraints = getAllConstraints(left, right, conditionOpt)
+          val newLeft = inferNewFilter(left, allConstraints)
+          join.copy(left = newLeft)
+
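+        // For left anti join, we can only infer additional filters for the right side.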
+        case LeftAnti =>
+          val allConstraints = getAllConstraints(left, right, conditionOpt)
+          val newRight = inferNewFilter(right, allConstraints)
+          join.copy(right = newRight)
+
+        case _ => join
+      }
+  }
+
+  private def getAllConstraints(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      conditionOpt: Option[Expression]): ExpressionSet = {
+    val baseConstraints = left.constraints.union(right.constraints)
+      .union(ExpressionSet(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil)))
+    baseConstraints.union(inferAdditionalConstraints(baseConstraints))
+  }
+
+  private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
+    val newPredicates = constraints
+      .union(constructIsNotNullConstraints(constraints, plan.output))
+      .filter { c =>
+        c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
+      } -- plan.constraints
+    if (newPredicates.isEmpty) {
+      plan
+    } else {
+      Filter(newPredicates.reduce(And), plan)
+    }
+  }
+
+  protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
+    condition match {
+      case And(cond1, cond2) =>
+        splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
+      case other => other :: Nil
+    }
+  }
+}
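
The observable difference from stock Spark is on LEFT OUTER joins: the stock rule
only infers filters for the null-supplying (right) side, while this rewrite also
filters the preserved (left) side, matching the SPARK-23564 LeftOuter test in the
suite below. A hedged sketch with made-up data, assuming a session built with the
rule injected as shown earlier:

    import spark.implicits._

    val x = Seq((1, "u"), (2, "v")).toDF("a", "bx")
    val y = Seq((1, "p"), (2, "q")).toDF("a", "by")

    // With the rewrite injected, expect Filter(isnotnull(a)) under BOTH join
    // children of the optimized plan, not just under y.
    val plan = x.join(y, x("a") === y("a"), "left_outer")
      .queryExecution.optimizedPlan
    println(plan.treeString)
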
diff --git a/src/spark-project/spark-common/src/test/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraintsSuite.scala b/src/spark-project/spark-common/src/test/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraintsSuite.scala
new file mode 100644
index 0000000000..e766dd3b53
--- /dev/null
+++ b/src/spark-project/spark-common/src/test/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraintsSuite.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasource
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, CombineFilters, InferFiltersFromConstraints, PruneFilters, PushPredicateThroughJoin, PushPredicateThroughNonJoin, SimplifyBinaryComparison}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, LongType}
+
+class RewriteInferFiltersFromConstraintsSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("InferAndPushDownFilters", FixedPoint(100),
+        PushPredicateThroughJoin,
+        PushPredicateThroughNonJoin,
+        RewriteInferFiltersFromConstraints,
+        CombineFilters,
+        SimplifyBinaryComparison,
+        BooleanSimplification,
+        PruneFilters) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  private def testConstraintsAfterJoin(
+      x: LogicalPlan,
+      y: LogicalPlan,
+      expectedLeft: LogicalPlan,
+      expectedRight: LogicalPlan,
+      joinType: JoinType,
+      condition: Option[Expression] = Some("x.a".attr === "y.a".attr)) = {
+    val originalQuery = x.join(y, joinType, condition).analyze
+    val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("filter: filter out constraints in condition") {
+    val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
+    val correctAnswer = testRelation
+      .where(IsNotNull('a) && IsNotNull('b) && 'a === 'b && 'a === 1 && 'b === 1).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("single inner join: filter out values on either side on equi-join keys") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery = x.join(y,
+      condition = Some(("x.a".attr === "y.a".attr) && ("x.a".attr === 1) && ("y.c".attr > 5)))
+      .analyze
+    val left = x.where(IsNotNull('a) && "x.a".attr === 1)
+    val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5 && "y.a".attr === 1)
+    val correctAnswer = left.join(right, condition = Some("x.a".attr === "y.a".attr)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("single inner join: filter out nulls on either side on non equal keys") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery = x.join(y,
+      condition = Some(("x.a".attr =!= "y.a".attr) && ("x.b".attr === 1) && ("y.c".attr > 5)))
+      .analyze
+    val left = x.where(IsNotNull('a) && IsNotNull('b) && "x.b".attr === 1)
+    val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5)
+    val correctAnswer = left.join(right, condition = Some("x.a".attr =!= "y.a".attr)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("single inner join with pre-existing filters: filter out values on either side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery = x.where('b > 5).join(y.where('a === 10),
+      condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).analyze
+    val left = x.where(IsNotNull('a) && 'a === 10 && IsNotNull('b) && 'b > 5)
+    val right = y.where(IsNotNull('a) && IsNotNull('b) && 'a === 10 && 'b > 5)
+    val correctAnswer = left.join(right,
+      condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("single outer join: no null filters are generated") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery = x.join(y, FullOuter,
+      condition = Some("x.a".attr === "y.a".attr)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, originalQuery)
+  }
+
+  test("multiple inner joins: filter out values on all sides on equi-join keys") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+    val t3 = testRelation.subquery('t3)
+    val t4 = testRelation.subquery('t4)
+
+    val originalQuery = t1.where('b > 5)
+      .join(t2, condition = Some("t1.b".attr === "t2.b".attr))
+      .join(t3, condition = Some("t2.b".attr === "t3.b".attr))
+      .join(t4, condition = Some("t3.b".attr === "t4.b".attr)).analyze
+    val correctAnswer = t1.where(IsNotNull('b) && 'b > 5)
+      .join(t2.where(IsNotNull('b) && 'b > 5), condition = Some("t1.b".attr === "t2.b".attr))
+      .join(t3.where(IsNotNull('b) && 'b > 5), condition = Some("t2.b".attr === "t3.b".attr))
+      .join(t4.where(IsNotNull('b) && 'b > 5), condition = Some("t3.b".attr === "t4.b".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join with filter: filter out values on all sides on equi-join keys") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery =
+      x.join(y, Inner, Some("x.a".attr === "y.a".attr)).where("x.a".attr > 5).analyze
+    val correctAnswer = x.where(IsNotNull('a) && 'a.attr > 5)
+      .join(y.where(IsNotNull('a) && 'a.attr > 5), Inner, Some("x.a".attr === "y.a".attr)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join with alias: alias contains multiple attributes") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+
+    val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)))
+      .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2.where(IsNotNull('a)), Inner,
+        Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join with alias: alias contains single attributes") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+
+    val originalQuery = t1.select('a, 'b.as('d)).as("t")
+      .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
+      .analyze
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull('b) && 'a === 'b)
+      .select('a, 'b.as('d)).as("t")
+      .join(t2.where(IsNotNull('a)), Inner,
+        Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("generate correct filters for alias that don't produce recursive constraints") {
+    val t1 = testRelation.subquery('t1)
+
+    val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze
+    val correctAnswer =
+      t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b))
+        .select('a.as('x), 'b.as('y)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("No inferred filter when constraint propagation is disabled") {
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
+      val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+  }
+
+  test("constraints should be inferred from aliased literals") {
+    val originalLeft = testRelation.subquery('left).as("left")
+    val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a <=> 2).as("left")
+
+    val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right")
+    val condition = Some("left.a".attr === "right.two".attr)
+
+    val original = originalLeft.join(right, Inner, condition)
+    val correct = optimizedLeft.join(right, Inner, condition)
+
+    comparePlans(Optimize.execute(original.analyze), correct.analyze)
+  }
+
+  test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
+  }
+
+  test("SPARK-21479: Outer join after-join filters push down to null-supplying side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val condition = Some("x.a".attr === "y.a".attr)
+    val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze
+    val left = x.where(IsNotNull('a) && 'a === 2)
+    val right = y.where(IsNotNull('a) && 'a === 2)
+    val correctAnswer = left.join(right, LeftOuter, condition).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val condition = Some("x.a".attr === "y.a".attr)
+    val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze
+    val left = x.where(IsNotNull('a) && 'a > 5)
+    val right = y.where(IsNotNull('a) && 'a > 5)
+    val correctAnswer = left.join(right, RightOuter, condition).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-21479: Outer join no filter push down to preserved side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    testConstraintsAfterJoin(
+      x.where("a".attr === 1), y.where("a".attr === 1),
+      x.where(IsNotNull('a) && 'a === 1), y.where(IsNotNull('a) && 'a === 1),
+      LeftOuter)
+  }
+
+  test("SPARK-23564: left anti join should filter out null join keys on right side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti)
+  }
+
+  test("SPARK-23564: left outer join should filter out null join keys on right side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftOuter)
+  }
+
+  test("SPARK-23564: right outer join should filter out null join keys on left side") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
+  }
+
+  test("Constraints should be inferred from cast equality constraint(filter higher data type)") {
+    val testRelation1 = LocalRelation('a.int)
+    val testRelation2 = LocalRelation('b.long)
+    val originalLeft = testRelation1.subquery('left)
+    val originalRight = testRelation2.where('b === 1L).subquery('right)
+
+    val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left)
+    val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right)
+
+    Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
+      Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
+      testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+    }
+
+    Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
+      Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
+      testConstraintsAfterJoin(
+        originalLeft,
+        originalRight,
+        testRelation1.where(IsNotNull('a)).subquery('left),
+        right,
+        Inner,
+        condition)
+    }
+  }
+
+  test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") {
+    val testRelation1 = LocalRelation('a.int)
+    val testRelation2 = LocalRelation('b.long)
+    val originalLeft = testRelation1.where('a === 1).subquery('left)
+    val originalRight = testRelation2.subquery('right)
+
+    val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left)
+    val right = testRelation2.where(IsNotNull('b)).subquery('right)
+
+    Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
+      Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
+      testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+    }
+
+    Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
+      Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
+      testConstraintsAfterJoin(
+        originalLeft,
+        originalRight,
+        left,
+        testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right),
+        Inner,
+        condition)
+    }
+  }
+}