You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/07/13 06:06:35 UTC

[spark] branch master updated: [SPARK-39651][SQL] Prune filter condition if compare with rand is deterministic

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c800d296eb0 [SPARK-39651][SQL] Prune filter condition if compare with rand is deterministic
c800d296eb0 is described below

commit c800d296eb000331cc708da99a741b7e1a5f7c48
Author: Jiaan Geng <be...@163.com>
AuthorDate: Wed Jul 13 14:06:14 2022 +0800

    [SPARK-39651][SQL] Prune filter condition if compare with rand is deterministic
    
    ### What changes were proposed in this pull request?
    Currently, the SQL shown below evaluates rand(1) < 2 row by row.
    `SELECT * FROM tab WHERE rand(1) < 2`
    
    In fact, rand(1) always returns a value in [0, 1), so the condition is always true and the filter can be pruned.
    
    ### Why are the changes needed?
    Pruning the filter condition avoids evaluating rand for every row and improves performance.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    Only the internal optimizer behavior changes.
    
    ### How was this patch tested?
    New tests.
    
    Closes #37040 from beliefer/SPARK-39651.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/optimizer/OptimizeRand.scala      |  63 ++++++++
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   1 +
 .../sql/catalyst/rules/RuleIdCollection.scala      |   1 +
 .../sql/catalyst/optimizer/OptimizeRandSuite.scala | 176 +++++++++++++++++++++
 4 files changed, 241 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala
new file mode 100644
index 00000000000..09aa50c6d39
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, DoubleLiteral, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Rand}
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, EXPRESSION_WITH_RANDOM_SEED, LITERAL}
+
+/**
+ * Rand() generates i.i.d. uniformly distributed values in [0, 1), so some comparisons of Rand()
+ * with a double literal (e.g. `rand() < 1.0`) have a fixed result and Rand() can be eliminated:
+ * 1. The comparison becomes a true literal when it holds for every value in [0, 1).
+ * 2. The comparison becomes a false literal when it holds for no value in [0, 1).
+ * Undecidable comparisons are returned as-is, so the plan only changes when a Rand() is pruned.
+ */
+object OptimizeRand extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan =
+    plan.transformAllExpressionsWithPruning(_.containsAllPatterns(
+      EXPRESSION_WITH_RANDOM_SEED, LITERAL, BINARY_COMPARISON), ruleId) {
+      case op @ BinaryComparison(DoubleLiteral(_), _: Rand) =>
+        eliminateRand(swapComparison(op)).getOrElse(op)
+      case op @ BinaryComparison(_: Rand, DoubleLiteral(_)) => eliminateRand(op).getOrElse(op)
+    }
+
+  /** Swaps the sides of <, <=, > and >= comparisons, e.g. "a < b" becomes "b > a". */
+  private def swapComparison(comparison: BinaryComparison): BinaryComparison = comparison match {
+    case a LessThan b => GreaterThan(b, a)
+    case a LessThanOrEqual b => GreaterThanOrEqual(b, a)
+    case a GreaterThan b => LessThan(b, a)
+    case a GreaterThanOrEqual b => LessThanOrEqual(b, a)
+    case o => o
+  }
+
+  /** The fixed result of `rand op literal`, or None if it depends on the random value. */
+  private def eliminateRand(op: BinaryComparison): Option[Expression] = op match {
+    case GreaterThan(_: Rand, DoubleLiteral(value)) => decide(value < 0.0, value >= 1.0)
+    case GreaterThanOrEqual(_: Rand, DoubleLiteral(value)) => decide(value <= 0.0, value >= 1.0)
+    case LessThan(_: Rand, DoubleLiteral(value)) => decide(value >= 1.0, value <= 0.0)
+    case LessThanOrEqual(_: Rand, DoubleLiteral(value)) => decide(value >= 1.0, value < 0.0)
+    case _ => None
+  }
+
+  /** The literal for a comparison that always/never holds; None if it may go either way. */
+  private def decide(alwaysTrue: Boolean, alwaysFalse: Boolean): Option[Expression] =
+    if (alwaysTrue) Some(TrueLiteral) else if (alwaysFalse) Some(FalseLiteral) else None
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8ab08ba878e..6afb6a5424b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -107,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         ConstantPropagation,
         FoldablePropagation,
         OptimizeIn,
+        OptimizeRand,
         ConstantFolding,
         EliminateAggregateFilter,
         ReorderAssociativeOperator,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 1204fa8c604..2f118db8248 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -122,6 +122,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeRand" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan" ::
       "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRandSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRandSuite.scala
new file mode 100644
index 00000000000..55b4d4ff928
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRandSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Literal, Or}
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+/**
+ * Tests for [[OptimizeRand]], run together with the rules that fold and prune its results.
+ */
+class OptimizeRandSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("ConstantFolding", FixedPoint(10),
+        ConstantFolding,
+        BooleanSimplification,
+        OptimizeRand,
+        PruneFilters) :: Nil
+  }
+
+  val testRelation = LocalRelation($"a".int, $"b".int, $"c".int)
+  val x = testRelation.where($"a".attr.in(1, 3, 5)).subquery("x")
+  val literal0d = Literal(0d)
+  val literal1d = Literal(1d)
+  val literalHalf = Literal(0.5)
+  val negativeLiteral1d = Literal(-1d)
+  val rand5 = rand(5)
+
+  test("Optimize binary comparison with rand") {
+    // Optimize Rand to true literals.
+    Seq(
+      literal1d > rand5,
+      rand5 > negativeLiteral1d,
+      literal1d >= rand5,
+      rand5 >= literal0d,
+      rand5 < literal1d,
+      negativeLiteral1d < rand5,
+      rand5 <= literal1d,
+      literal0d <= rand5
+    ).foreach { comparison =>
+      val plan = testRelation.select(comparison.as("flag")).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.select(Alias(TrueLiteral, "flag")()).analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    // Optimize Rand to false literals.
+    Seq(
+      literal0d > rand5,
+      rand5 > literal1d,
+      negativeLiteral1d >= rand5,
+      rand5 >= literal1d,
+      rand5 < literal0d,
+      literal1d < rand5,
+      rand5 <= negativeLiteral1d,
+      literal1d <= rand5
+    ).foreach { comparison =>
+      val plan = testRelation.select(comparison.as("flag")).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.select(Alias(FalseLiteral, "flag")()).analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    // Rand cannot be eliminated: it may return exactly 0.0, and 0.5 lies inside [0, 1).
+    Seq(
+      rand5 > literal0d,
+      rand5 >= literalHalf,
+      rand5 < literalHalf,
+      rand5 <= literal0d
+    ).foreach { comparison =>
+      val plan = testRelation.select(comparison.as("flag")).analyze
+      val actual = Optimize.execute(plan)
+      comparePlans(actual, plan)
+    }
+  }
+
+  test("Prune filter conditions with rand") {
+    // Optimize Rand to true literals.
+    Seq(
+      literal1d > rand5,
+      literal1d >= rand5,
+      rand5 < literal1d,
+      rand5 <= literal1d
+    ).foreach { condition =>
+      val plan = x.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = x.analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    // Optimize Rand to false literals.
+    Seq(
+      literal1d <= rand5,
+      literal1d < rand5,
+      rand5 >= literal1d,
+      rand5 > literal1d
+    ).foreach { condition =>
+      val plan = x.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.analyze
+      comparePlans(actual, correctAnswer)
+    }
+  }
+
+  test("Constant folding with rand") {
+    Seq(
+      And(literal1d > rand5, literal1d >= rand5),
+      And(rand5 < literal1d, rand5 <= literal1d)
+    ).foreach { condition =>
+      val plan = x.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = x.analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    Seq(
+      Or(literal1d <= rand5, literal1d < rand5),
+      Or(rand5 >= literal1d, rand5 > literal1d)
+    ).foreach { condition =>
+      val plan = x.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.analyze
+      comparePlans(actual, correctAnswer)
+    }
+  }
+
+  test("Simplify filter conditions with rand") {
+    val aIsNotNull = $"a".isNotNull
+
+    Seq(
+      And(literal1d > rand5, aIsNotNull),
+      And(literal1d >= rand5, aIsNotNull),
+      And(rand5 < literal1d, aIsNotNull),
+      And(rand5 <= literal1d, aIsNotNull)
+    ).foreach { condition =>
+      val plan = testRelation.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.where(condition.right).analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    Seq(
+      Or(literal1d <= rand5, aIsNotNull),
+      Or(literal1d < rand5, aIsNotNull),
+      Or(rand5 >= literal1d, aIsNotNull),
+      Or(rand5 > literal1d, aIsNotNull)
+    ).foreach { condition =>
+      val plan = testRelation.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.where(condition.right).analyze
+      comparePlans(actual, correctAnswer)
+    }
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org