You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/07/13 06:06:35 UTC
[spark] branch master updated: [SPARK-39651][SQL] Prune filter condition if compare with rand is deterministic
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c800d296eb0 [SPARK-39651][SQL] Prune filter condition if compare with rand is deterministic
c800d296eb0 is described below
commit c800d296eb000331cc708da99a741b7e1a5f7c48
Author: Jiaan Geng <be...@163.com>
AuthorDate: Wed Jul 13 14:06:14 2022 +0800
[SPARK-39651][SQL] Prune filter condition if compare with rand is deterministic
### What changes were proposed in this pull request?
Currently, the query shown below evaluates rand(1) < 2 row by row.
`SELECT * FROM tab WHERE rand(1) < 2`
Since rand() always returns a value in [0, 1), this comparison is always true, so the filter condition can be pruned.
### Why are the changes needed?
Pruning such filter conditions avoids evaluating rand() for every row and improves performance.
### Does this PR introduce _any_ user-facing change?
No.
This only changes internal optimizer behavior.
### How was this patch tested?
Added unit tests in OptimizeRandSuite.
Closes #37040 from beliefer/SPARK-39651.
Authored-by: Jiaan Geng <be...@163.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../sql/catalyst/optimizer/OptimizeRand.scala | 63 ++++++++
.../spark/sql/catalyst/optimizer/Optimizer.scala | 1 +
.../sql/catalyst/rules/RuleIdCollection.scala | 1 +
.../sql/catalyst/optimizer/OptimizeRandSuite.scala | 176 +++++++++++++++++++++
4 files changed, 241 insertions(+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala
new file mode 100644
index 00000000000..09aa50c6d39
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, DoubleLiteral, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Rand}
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, EXPRESSION_WITH_RANDOM_SEED, LITERAL}
+
+/**
+ * Rand() generates i.i.d. values uniformly distributed in [0, 1), so a binary comparison between
+ * Rand() and a double literal at or beyond the ends of that range has a fixed outcome:
+ * 1. Converts the comparison to a true literal when it holds for every value in [0, 1).
+ * 2. Converts the comparison to a false literal when it holds for no value in [0, 1).
+ * Comparisons whose outcome depends on the generated value are left exactly as written.
+ */
+object OptimizeRand extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan =
+    plan.transformAllExpressionsWithPruning(_.containsAllPatterns(
+      EXPRESSION_WITH_RANDOM_SEED, LITERAL, BINARY_COMPARISON), ruleId) {
+    case op @ BinaryComparison(DoubleLiteral(_), _: Rand) =>
+      eliminateRand(swapComparison(op)).getOrElse(op) // keep the operand order if undecided
+    case op @ BinaryComparison(_: Rand, DoubleLiteral(_)) => eliminateRand(op).getOrElse(op)
+  }
+
+  /** Mirrors an ordering comparison, e.g. "a < b" becomes "b > a"; others are returned as is. */
+  private def swapComparison(comparison: BinaryComparison): BinaryComparison = comparison match {
+    case a LessThan b => GreaterThan(b, a)
+    case a LessThanOrEqual b => GreaterThanOrEqual(b, a)
+    case a GreaterThan b => LessThan(b, a)
+    case a GreaterThanOrEqual b => LessThanOrEqual(b, a)
+    case o => o
+  }
+
+  /** Literal that "Rand() <cmp> literal" always yields given the [0, 1) range, else None. */
+  private def eliminateRand(op: BinaryComparison): Option[Expression] = op match {
+    case GreaterThan(_: Rand, DoubleLiteral(v)) => toLiteral(v < 0.0, v >= 1.0)
+    case GreaterThanOrEqual(_: Rand, DoubleLiteral(v)) => toLiteral(v <= 0.0, v >= 1.0)
+    case LessThan(_: Rand, DoubleLiteral(v)) => toLiteral(v >= 1.0, v <= 0.0)
+    case LessThanOrEqual(_: Rand, DoubleLiteral(v)) => toLiteral(v >= 1.0, v < 0.0)
+    case _ => None
+  }
+
+  /** Turns (always true, always false) verdicts into a boolean literal; None when undecided. */
+  private def toLiteral(alwaysTrue: Boolean, alwaysFalse: Boolean): Option[Expression] =
+    if (alwaysTrue) Some(TrueLiteral) else if (alwaysFalse) Some(FalseLiteral) else None
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8ab08ba878e..6afb6a5424b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -107,6 +107,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
ConstantPropagation,
FoldablePropagation,
OptimizeIn,
+ OptimizeRand,
ConstantFolding,
EliminateAggregateFilter,
ReorderAssociativeOperator,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 1204fa8c604..2f118db8248 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -122,6 +122,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
+ "org.apache.spark.sql.catalyst.optimizer.OptimizeRand" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeOneRowPlan" ::
"org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRandSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRandSuite.scala
new file mode 100644
index 00000000000..55b4d4ff928
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRandSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Literal, Or}
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class OptimizeRandSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("ConstantFolding", FixedPoint(10),
+        ConstantFolding,
+        BooleanSimplification,
+        OptimizeRand,
+        PruneFilters) :: Nil
+  }
+
+  val testRelation = LocalRelation($"a".int, $"b".int, $"c".int)
+  val x = testRelation.where($"a".attr.in(1, 3, 5)).subquery("x")
+  val literal0d = Literal(0d)
+  val literal1d = Literal(1d)
+  val literalHalf = Literal(0.5)
+  val negativeLiteral1d = Literal(-1d)
+  val rand5 = rand(5)
+
+  test("Optimize binary comparison with rand") {
+    // Each comparison is projected as column "flag" so its rewritten form appears in the plan.
+    // Comparisons that hold for every value in [0, 1) fold to true, in either operand order.
+    Seq(
+      literal1d > rand5,
+      rand5 > negativeLiteral1d,
+      literal1d >= rand5,
+      rand5 >= literal0d,
+      rand5 < literal1d,
+      negativeLiteral1d < rand5,
+      rand5 <= literal1d,
+      literal0d <= rand5
+    ).foreach { comparison =>
+      val plan = testRelation.select(comparison.as("flag")).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.select(Alias(TrueLiteral, "flag")()).analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    // Comparisons that hold for no value in [0, 1) fold to false, in either operand order.
+    Seq(
+      literal0d > rand5,
+      rand5 > literal1d,
+      negativeLiteral1d >= rand5,
+      rand5 >= literal1d,
+      rand5 < literal0d,
+      literal1d < rand5,
+      rand5 <= negativeLiteral1d,
+      literal1d <= rand5
+    ).foreach { comparison =>
+      val plan = testRelation.select(comparison.as("flag")).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.select(Alias(FalseLiteral, "flag")()).analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    // Comparisons whose result depends on the generated value are left untouched.
+    Seq(
+      rand5 > literal0d,
+      rand5 >= literalHalf,
+      rand5 < literalHalf,
+      rand5 <= literal0d
+    ).foreach { comparison =>
+      val plan = testRelation.select(comparison.as("flag")).analyze
+      val actual = Optimize.execute(plan)
+      comparePlans(actual, plan)
+    }
+  }
+
+  test("Prune filter conditions with rand") {
+    // A filter folded to true is removed; one folded to false leaves an empty relation.
+    // Conditions that always hold: the Filter is dropped from the plan.
+    Seq(
+      literal1d > rand5,
+      literal1d >= rand5,
+      rand5 < literal1d,
+      rand5 <= literal1d
+    ).foreach { condition =>
+      val plan = x.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = x.analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    // Conditions that never hold: the plan collapses to an empty relation.
+    Seq(
+      literal1d <= rand5,
+      literal1d < rand5,
+      rand5 >= literal1d,
+      rand5 > literal1d
+    ).foreach { condition =>
+      val plan = x.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.analyze
+      comparePlans(actual, correctAnswer)
+    }
+  }
+
+  test("Constant folding with rand") {
+    // Both operands fold to the same literal, which the batch then combines into one.
+    Seq(
+      And(literal1d > rand5, literal1d >= rand5),
+      And(rand5 < literal1d, rand5 <= literal1d)
+    ).foreach { condition =>
+      val plan = x.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = x.analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    Seq(
+      Or(literal1d <= rand5, literal1d < rand5),
+      Or(rand5 >= literal1d, rand5 > literal1d)
+    ).foreach { condition =>
+      val plan = x.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.analyze
+      comparePlans(actual, correctAnswer)
+    }
+  }
+
+  test("Simplify filter conditions with rand") {
+    val aIsNotNull = $"a".isNotNull
+    // Only the Rand() operand folds away; the predicate on "a" remains as the whole condition.
+    Seq(
+      And(literal1d > rand5, aIsNotNull),
+      And(literal1d >= rand5, aIsNotNull),
+      And(rand5 < literal1d, aIsNotNull),
+      And(rand5 <= literal1d, aIsNotNull)
+    ).foreach { condition =>
+      val plan = testRelation.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.where(condition.right).analyze
+      comparePlans(actual, correctAnswer)
+    }
+
+    Seq(
+      Or(literal1d <= rand5, aIsNotNull),
+      Or(literal1d < rand5, aIsNotNull),
+      Or(rand5 >= literal1d, aIsNotNull),
+      Or(rand5 > literal1d, aIsNotNull)
+    ).foreach { condition =>
+      val plan = testRelation.where(condition).analyze
+      val actual = Optimize.execute(plan)
+      val correctAnswer = testRelation.where(condition.right).analyze
+      comparePlans(actual, correctAnswer)
+    }
+  }
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org