Posted to commits@kylin.apache.org by xx...@apache.org on 2023/05/06 06:59:13 UTC
[kylin] 12/38: [DIRTY] Rewrite spark InferFiltersFromConstraints
This is an automated email from the ASF dual-hosted git repository.
xxyu pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git
commit 00e52f4b3f3da7715a6e0fbde34c9d5db7ff6e61
Author: Mingming Ge <7m...@gmail.com>
AuthorDate: Sun Jan 29 18:45:04 2023 +0800
[DIRTY] Rewrite spark InferFiltersFromConstraints
---
.../java/org/apache/kylin/common/KapConfig.java | 4 +
.../scala/org/apache/spark/sql/SparderEnv.scala | 31 +-
.../RewriteInferFiltersFromConstraints.scala | 101 +++++++
.../RewriteInferFiltersFromConstraintsSuite.scala | 320 +++++++++++++++++++++
4 files changed, 439 insertions(+), 17 deletions(-)
diff --git a/src/core-common/src/main/java/org/apache/kylin/common/KapConfig.java b/src/core-common/src/main/java/org/apache/kylin/common/KapConfig.java
index c8aae82ddd..4f2132760e 100644
--- a/src/core-common/src/main/java/org/apache/kylin/common/KapConfig.java
+++ b/src/core-common/src/main/java/org/apache/kylin/common/KapConfig.java
@@ -265,6 +265,10 @@ public class KapConfig {
return Integer.parseInt(config.getOptional("kylin.query.engine.spark-sql-shuffle-partitions", "-1"));
}
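+ /**
+ * Whether to inject RewriteInferFiltersFromConstraints into the Spark
+ * optimizer (see SparderEnv.injectExtensions); disabled by default.
+ */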
+ public Boolean isConstraintPropagationEnabled() {
+ return Boolean.parseBoolean(config.getOptional("kylin.query.engine.spark-constraint-propagation-enabled", FALSE));
+ }
+
/**
* LDAP filter
*/
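The flag defaults to false (KapConfig's FALSE constant), so the rewrite is strictly opt-in. A deployment would enable it in kylin.properties, for example:

    kylin.query.engine.spark-constraint-propagation-enabled=true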
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
index 41f6abe771..59ce599fce 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/SparderEnv.scala
@@ -23,13 +23,12 @@ import java.security.PrivilegedAction
import java.util.Map
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.{Callable, ExecutorService}
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.UserGroupInformation
import org.apache.kylin.common.exception.{KylinException, KylinTimeoutException, ServerErrorCode}
import org.apache.kylin.common.msg.MsgPicker
import org.apache.kylin.common.util.{DefaultHostInfoFetcher, HadoopUtil, S3AUtil}
-import org.apache.kylin.common.{KylinConfig, QueryContext}
+import org.apache.kylin.common.{KapConfig, KylinConfig, QueryContext}
import org.apache.kylin.metadata.model.{NTableMetadataManager, TableExtDesc}
import org.apache.kylin.metadata.project.NProjectManager
import org.apache.kylin.query.runtime.plan.QueryToExecutionIDCache
@@ -39,7 +38,7 @@ import org.apache.spark.sql.KylinSession._
import org.apache.spark.sql.catalyst.optimizer.ConvertInnerJoinToSemiJoin
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.datasource.{KylinSourceStrategy, LayoutFileSourceStrategy}
+import org.apache.spark.sql.execution.datasource.{KylinSourceStrategy, LayoutFileSourceStrategy, RewriteInferFiltersFromConstraints}
import org.apache.spark.sql.execution.ui.PostQueryExecutionForKylin
import org.apache.spark.sql.hive.ReplaceLocationRule
import org.apache.spark.sql.udf.UdfManager
@@ -223,29 +222,17 @@ object SparderEnv extends Logging {
SparkSession.builder
.master("local")
.appName("sparder-local-sql-context")
- .withExtensions { ext =>
- ext.injectPlannerStrategy(_ => KylinSourceStrategy)
- ext.injectPlannerStrategy(_ => LayoutFileSourceStrategy)
- ext.injectPostHocResolutionRule(ReplaceLocationRule)
- ext.injectOptimizerRule(_ => new ConvertInnerJoinToSemiJoin())
- }
.enableHiveSupport()
.getOrCreateKylinSession()
case _ =>
SparkSession.builder
.appName(appName)
.master("yarn")
- //if user defined other master in kylin.properties,
- // it will get overwrite later in org.apache.spark.sql.KylinSession.KylinBuilder.initSparkConf
- .withExtensions { ext =>
- ext.injectPlannerStrategy(_ => KylinSourceStrategy)
- ext.injectPlannerStrategy(_ => LayoutFileSourceStrategy)
- ext.injectPostHocResolutionRule(ReplaceLocationRule)
- ext.injectOptimizerRule(_ => new ConvertInnerJoinToSemiJoin())
- }
.enableHiveSupport()
.getOrCreateKylinSession()
}
+
+ injectExtensions(sparkSession.extensions)
spark = sparkSession
logInfo("Spark context started successfully with stack trace:")
logInfo(Thread.currentThread().getStackTrace.mkString("\n"))
@@ -277,6 +264,16 @@ object SparderEnv extends Logging {
}
}
+ def injectExtensions(sse: SparkSessionExtensions): Unit = {
+ sse.injectPlannerStrategy(_ => KylinSourceStrategy)
+ sse.injectPlannerStrategy(_ => LayoutFileSourceStrategy)
+ sse.injectPostHocResolutionRule(ReplaceLocationRule)
+ sse.injectOptimizerRule(_ => new ConvertInnerJoinToSemiJoin())
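+ // RewriteInferFiltersFromConstraints is opt-in, gated by
+ // kylin.query.engine.spark-constraint-propagation-enabled (see KapConfig).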
+ if (KapConfig.getInstanceFromEnv.isConstraintPropagationEnabled) {
+ sse.injectOptimizerRule(_ => RewriteInferFiltersFromConstraints)
+ }
+ }
+
def registerListener(sc: SparkContext): Unit = {
val sparkListener = new SparkListener {
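With the duplicated withExtensions blocks folded into injectExtensions, the same planner strategies and optimizer rules can be applied to any session. A minimal sketch (hypothetical, not part of this patch) of reusing the helper when building a session by hand, e.g. in a test harness; it assumes KylinConfig/KapConfig is already initialized, since injectExtensions reads the new flag:

    import org.apache.spark.sql.{SparderEnv, SparkSession}

    // Build a local session carrying the same Kylin planner strategies and
    // optimizer rules that production Sparder sessions receive.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("kylin-extensions-sketch")
      .withExtensions(SparderEnv.injectExtensions)
      .getOrCreate()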
diff --git a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraints.scala b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraints.scala
new file mode 100644
index 0000000000..7d5132b744
--- /dev/null
+++ b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraints.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasource
+
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, ExpressionSet}
+import org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints.{constructIsNotNullConstraints, inferAdditionalConstraints}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, JOIN}
+
+object RewriteInferFiltersFromConstraints extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
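+ // In addition to the KapConfig gate at injection time, honor Spark's own
+ // spark.sql.constraintPropagation.enabled switch, like the built-in rule.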
+ if (conf.constraintPropagationEnabled) {
+ inferFilters(plan)
+ } else {
+ plan
+ }
+ }
+
+ private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+ _.containsAnyPattern(FILTER, JOIN)) {
+ case filter @ Filter(condition, child) =>
+ val newFilters = filter.constraints --
+ (child.constraints ++ splitConjunctivePredicates(condition))
+ if (newFilters.nonEmpty) {
+ Filter(And(newFilters.reduce(And), condition), child)
+ } else {
+ filter
+ }
+
+ case join @ Join(left, right, joinType, conditionOpt, _) =>
+ joinType match {
+ // For inner joins we can infer additional filters for both sides. LeftSemi is
+ // effectively an inner join that just drops the right side from the output.
+ // Unlike Spark's built-in InferFiltersFromConstraints, LeftOuter is included
+ // here, so filters are inferred for both sides of left outer joins as well.
+ case _: InnerLike | LeftSemi | LeftOuter =>
+ val allConstraints = getAllConstraints(left, right, conditionOpt)
+ val newLeft = inferNewFilter(left, allConstraints)
+ val newRight = inferNewFilter(right, allConstraints)
+ join.copy(left = newLeft, right = newRight)
+
+ // For right outer joins, we can only infer additional filters for the left
+ // (null-supplying) side; filtering the preserved right side could drop rows.
+ case RightOuter =>
+ val allConstraints = getAllConstraints(left, right, conditionOpt)
+ val newLeft = inferNewFilter(left, allConstraints)
+ join.copy(left = newLeft)
+
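+ // For left anti join, filters can only be inferred for the right side:
+ // every left-side row must remain visible to the anti-join check.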
+ case LeftAnti =>
+ val allConstraints = getAllConstraints(left, right, conditionOpt)
+ val newRight = inferNewFilter(right, allConstraints)
+ join.copy(right = newRight)
+
+ case _ => join
+ }
+ }
+
+ private def getAllConstraints(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ conditionOpt: Option[Expression]): ExpressionSet = {
+ val baseConstraints = left.constraints.union(right.constraints)
+ .union(ExpressionSet(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil)))
+ baseConstraints.union(inferAdditionalConstraints(baseConstraints))
+ }
+
+ private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
+ val newPredicates = constraints
+ .union(constructIsNotNullConstraints(constraints, plan.output))
+ .filter { c =>
+ c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
+ } -- plan.constraints
+ if (newPredicates.isEmpty) {
+ plan
+ } else {
+ Filter(newPredicates.reduce(And), plan)
+ }
+ }
+
+ protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
+ condition match {
+ case And(cond1, cond2) =>
+ splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
+ case other => other :: Nil
+ }
+ }
+}
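A minimal sketch (not part of the patch) of the rule applied in isolation, using the same Catalyst test DSL as the suite below; it assumes Spark's default spark.sql.constraintPropagation.enabled=true so the rule is active:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val rel = LocalRelation('a.int, 'b.int)
    val x = rel.subquery('x)
    val y = rel.subquery('y)
    // SELECT * FROM x JOIN y ON x.a = y.a AND x.a = 1
    val plan = x.join(y,
      condition = Some("x.a".attr === "y.a".attr && "x.a".attr === 1)).analyze
    val inferred = RewriteInferFiltersFromConstraints(plan)
    // Both join inputs are now wrapped in inferred Filters: the left gains
    // IsNotNull(a) && a = 1, and the right gains the same predicates via the
    // transitive constraint x.a = y.a.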
diff --git a/src/spark-project/spark-common/src/test/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraintsSuite.scala b/src/spark-project/spark-common/src/test/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraintsSuite.scala
new file mode 100644
index 0000000000..e766dd3b53
--- /dev/null
+++ b/src/spark-project/spark-common/src/test/scala/org/apache/spark/sql/execution/datasource/RewriteInferFiltersFromConstraintsSuite.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasource
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, CombineFilters, PruneFilters, PushPredicateThroughJoin, PushPredicateThroughNonJoin, SimplifyBinaryComparison}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, LongType}
+
+class RewriteInferFiltersFromConstraintsSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("InferAndPushDownFilters", FixedPoint(100),
+ PushPredicateThroughJoin,
+ PushPredicateThroughNonJoin,
+ RewriteInferFiltersFromConstraints,
+ CombineFilters,
+ SimplifyBinaryComparison,
+ BooleanSimplification,
+ PruneFilters) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ private def testConstraintsAfterJoin(
+ x: LogicalPlan,
+ y: LogicalPlan,
+ expectedLeft: LogicalPlan,
+ expectedRight: LogicalPlan,
+ joinType: JoinType,
+ condition: Option[Expression] = Some("x.a".attr === "y.a".attr)) = {
+ val originalQuery = x.join(y, joinType, condition).analyze
+ val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("filter: filter out constraints in condition") {
+ val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
+ val correctAnswer = testRelation
+ .where(IsNotNull('a) && IsNotNull('b) && 'a === 'b && 'a === 1 && 'b === 1).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("single inner join: filter out values on either side on equi-join keys") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val originalQuery = x.join(y,
+ condition = Some(("x.a".attr === "y.a".attr) && ("x.a".attr === 1) && ("y.c".attr > 5)))
+ .analyze
+ val left = x.where(IsNotNull('a) && "x.a".attr === 1)
+ val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5 && "y.a".attr === 1)
+ val correctAnswer = left.join(right, condition = Some("x.a".attr === "y.a".attr)).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("single inner join: filter out nulls on either side on non equal keys") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val originalQuery = x.join(y,
+ condition = Some(("x.a".attr =!= "y.a".attr) && ("x.b".attr === 1) && ("y.c".attr > 5)))
+ .analyze
+ val left = x.where(IsNotNull('a) && IsNotNull('b) && "x.b".attr === 1)
+ val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5)
+ val correctAnswer = left.join(right, condition = Some("x.a".attr =!= "y.a".attr)).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("single inner join with pre-existing filters: filter out values on either side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val originalQuery = x.where('b > 5).join(y.where('a === 10),
+ condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).analyze
+ val left = x.where(IsNotNull('a) && 'a === 10 && IsNotNull('b) && 'b > 5)
+ val right = y.where(IsNotNull('a) && IsNotNull('b) && 'a === 10 && 'b > 5)
+ val correctAnswer = left.join(right,
+ condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("single outer join: no null filters are generated") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val originalQuery = x.join(y, FullOuter,
+ condition = Some("x.a".attr === "y.a".attr)).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
+ }
+
+ test("multiple inner joins: filter out values on all sides on equi-join keys") {
+ val t1 = testRelation.subquery('t1)
+ val t2 = testRelation.subquery('t2)
+ val t3 = testRelation.subquery('t3)
+ val t4 = testRelation.subquery('t4)
+
+ val originalQuery = t1.where('b > 5)
+ .join(t2, condition = Some("t1.b".attr === "t2.b".attr))
+ .join(t3, condition = Some("t2.b".attr === "t3.b".attr))
+ .join(t4, condition = Some("t3.b".attr === "t4.b".attr)).analyze
+ val correctAnswer = t1.where(IsNotNull('b) && 'b > 5)
+ .join(t2.where(IsNotNull('b) && 'b > 5), condition = Some("t1.b".attr === "t2.b".attr))
+ .join(t3.where(IsNotNull('b) && 'b > 5), condition = Some("t2.b".attr === "t3.b".attr))
+ .join(t4.where(IsNotNull('b) && 'b > 5), condition = Some("t3.b".attr === "t4.b".attr))
+ .analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("inner join with filter: filter out values on all sides on equi-join keys") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery =
+ x.join(y, Inner, Some("x.a".attr === "y.a".attr)).where("x.a".attr > 5).analyze
+ val correctAnswer = x.where(IsNotNull('a) && 'a.attr > 5)
+ .join(y.where(IsNotNull('a) && 'a.attr > 5), Inner, Some("x.a".attr === "y.a".attr)).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("inner join with alias: alias contains multiple attributes") {
+ val t1 = testRelation.subquery('t1)
+ val t2 = testRelation.subquery('t2)
+
+ val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+ .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
+ .analyze
+ val correctAnswer = t1
+ .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)))
+ .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+ .join(t2.where(IsNotNull('a)), Inner,
+ Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
+ .analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("inner join with alias: alias contains single attributes") {
+ val t1 = testRelation.subquery('t1)
+ val t2 = testRelation.subquery('t2)
+
+ val originalQuery = t1.select('a, 'b.as('d)).as("t")
+ .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
+ .analyze
+ val correctAnswer = t1
+ .where(IsNotNull('a) && IsNotNull('b) && 'a === 'b)
+ .select('a, 'b.as('d)).as("t")
+ .join(t2.where(IsNotNull('a)), Inner,
+ Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
+ .analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("generate correct filters for alias that don't produce recursive constraints") {
+ val t1 = testRelation.subquery('t1)
+
+ val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze
+ val correctAnswer =
+ t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b))
+ .select('a.as('x), 'b.as('y)).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("No inferred filter when constraint propagation is disabled") {
+ withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
+ val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, originalQuery)
+ }
+ }
+
+ test("constraints should be inferred from aliased literals") {
+ val originalLeft = testRelation.subquery('left).as("left")
+ val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a <=> 2).as("left")
+
+ val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right")
+ val condition = Some("left.a".attr === "right.two".attr)
+
+ val original = originalLeft.join(right, Inner, condition)
+ val correct = optimizedLeft.join(right, Inner, condition)
+
+ comparePlans(Optimize.execute(original.analyze), correct.analyze)
+ }
+
+ test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
+ }
+
+ test("SPARK-21479: Outer join after-join filters push down to null-supplying side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val condition = Some("x.a".attr === "y.a".attr)
+ val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze
+ val left = x.where(IsNotNull('a) && 'a === 2)
+ val right = y.where(IsNotNull('a) && 'a === 2)
+ val correctAnswer = left.join(right, LeftOuter, condition).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val condition = Some("x.a".attr === "y.a".attr)
+ val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze
+ val left = x.where(IsNotNull('a) && 'a > 5)
+ val right = y.where(IsNotNull('a) && 'a > 5)
+ val correctAnswer = left.join(right, RightOuter, condition).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-21479: Outer join no filter push down to preserved side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(
+ x.where("a".attr === 1), y.where("a".attr === 1),
+ x.where(IsNotNull('a) && 'a === 1), y.where(IsNotNull('a) && 'a === 1),
+ LeftOuter)
+ }
+
+ test("SPARK-23564: left anti join should filter out null join keys on right side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti)
+ }
+
+ test("SPARK-23564: left outer join should filter out null join keys on right side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftOuter)
+ }
+
+ test("SPARK-23564: right outer join should filter out null join keys on left side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
+ }
+
+ test("Constraints should be inferred from cast equality constraint(filter higher data type)") {
+ val testRelation1 = LocalRelation('a.int)
+ val testRelation2 = LocalRelation('b.long)
+ val originalLeft = testRelation1.subquery('left)
+ val originalRight = testRelation2.where('b === 1L).subquery('right)
+
+ val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left)
+ val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right)
+
+ Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
+ Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
+ testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+ }
+
+ Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
+ Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
+ testConstraintsAfterJoin(
+ originalLeft,
+ originalRight,
+ testRelation1.where(IsNotNull('a)).subquery('left),
+ right,
+ Inner,
+ condition)
+ }
+ }
+
+ test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") {
+ val testRelation1 = LocalRelation('a.int)
+ val testRelation2 = LocalRelation('b.long)
+ val originalLeft = testRelation1.where('a === 1).subquery('left)
+ val originalRight = testRelation2.subquery('right)
+
+ val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left)
+ val right = testRelation2.where(IsNotNull('b)).subquery('right)
+
+ Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
+ Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
+ testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+ }
+
+ Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
+ Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
+ testConstraintsAfterJoin(
+ originalLeft,
+ originalRight,
+ left,
+ testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right),
+ Inner,
+ condition)
+ }
+ }
+}