You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2018/09/08 18:37:10 UTC
spark git commit: [SPARK-20636] Add new optimization rule to
transpose adjacent Window expressions.
Repository: spark
Updated Branches:
refs/heads/master 26f74b7cb -> 78981efc2
[SPARK-20636] Add new optimization rule to transpose adjacent Window expressions.
## What changes were proposed in this pull request?
Add a new optimization rule, TransposeWindow, that eliminates unnecessary shuffling by transposing adjacent Window operators when the parent's partition spec is compatible with the child's.
## How was this patch tested?
Tested with unit tests, integration tests, and manual tests.
Closes #17899 from ptkool/adjacent_window_optimization.
Authored-by: ptkool <mi...@shopify.com>
Signed-off-by: gatorsmile <ga...@gmail.com>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/78981efc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/78981efc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/78981efc
Branch: refs/heads/master
Commit: 78981efc2cf321ce93e176d30d49bb1a8bd59eb2
Parents: 26f74b7
Author: ptkool <mi...@shopify.com>
Authored: Sat Sep 8 11:36:55 2018 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Sat Sep 8 11:36:55 2018 -0700
----------------------------------------------------------------------
.../sql/catalyst/optimizer/Optimizer.scala | 22 ++++
.../optimizer/TransposeWindowSuite.scala | 114 +++++++++++++++++++
.../sql/DataFrameWindowFunctionsSuite.scala | 45 ++++++--
3 files changed, 170 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/78981efc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e4b4f1e..b432ce2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -735,6 +735,28 @@ object CollapseWindow extends Rule[LogicalPlan] {
}
/**
+ * Transposes two adjacent [[Window]] operators.
+ *
+ * A parent Window `w1` sitting directly on a child Window `w2` is rewritten so that `w1` is
+ * evaluated below `w2`, provided that:
+ *  - the partition spec of `w1` is compatible with that of `w2` (see `compatiblePartitions`);
+ *  - `w1` does not reference any attribute produced by `w2`'s window expressions; and
+ *  - every expression of both operators is deterministic.
+ *
+ * The intent (per the commit message) is to avoid an unnecessary shuffle: the operator with
+ * the coarser partitioning is evaluated first. A [[Project]] on top restores `w1`'s original
+ * output column order.
+ */
+object TransposeWindow extends Rule[LogicalPlan] {
+
+  /**
+   * Returns true when `ps1` has strictly fewer expressions than `ps2` and every expression in
+   * `ps1` semantically occurs somewhere in `ps2`.
+   *
+   * The strict length check also guarantees the rule cannot swap a pair back after
+   * transposing it. This replaces an earlier check that enumerated every permutation of a
+   * prefix of `ps2`. That check was O(n!) in the number of partition expressions and
+   * rejected keys that were not in the leading positions of `ps2`.
+   */
+  private def compatiblePartitions(ps1: Seq[Expression], ps2: Seq[Expression]): Boolean = {
+    ps1.length < ps2.length && ps1.forall(e1 => ps2.exists(e1.semanticEquals))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
+        if w1.references.intersect(w2.windowOutputSet).isEmpty &&
+          w1.expressions.forall(_.deterministic) &&
+          w2.expressions.forall(_.deterministic) &&
+          compatiblePartitions(ps1, ps2) =>
+      // Push w1 below w2, then project w1's output to keep the original column order.
+      Project(w1.output, Window(we2, ps2, os2, Window(we1, ps1, os1, grandChild)))
+  }
+}
+
+/**
* Generate a list of additional filters from an operator's existing constraint but remove those
* that are either already part of the operator's condition or are part of the operator's child
* constraints. These filters are currently inserted to the existing conditions in the Filter
http://git-wip-us.apache.org/repos/asf/spark/blob/78981efc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
new file mode 100644
index 0000000..58b3d1c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Rand
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class TransposeWindowSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveRedundantProject) ::
+      Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.string, 'b.string, 'c.int, 'd.string)
+
+  val a = testRelation.output(0)
+  val b = testRelation.output(1)
+  val c = testRelation.output(2)
+  val d = testRelation.output(3)
+
+  val partitionSpec1 = Seq(a)
+  val partitionSpec2 = Seq(a, b)
+  val partitionSpec3 = Seq(d)
+  val partitionSpec4 = Seq(b, a, d)
+
+  val orderSpec1 = Seq(d.asc)
+  val orderSpec2 = Seq(d.desc)
+
+  test("transpose two adjacent windows with compatible partitions") {
+    val query = testRelation
+      .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2)
+      .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1)
+
+    val analyzed = query.analyze
+    val optimized = Optimize.execute(analyzed)
+
+    val correctAnswer = testRelation
+      .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1)
+      .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2)
+      .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1)
+
+    comparePlans(optimized, correctAnswer.analyze)
+  }
+
+  test("transpose two adjacent windows with differently ordered compatible partitions") {
+    val query = testRelation
+      .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty)
+      .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty)
+
+    val analyzed = query.analyze
+    val optimized = Optimize.execute(analyzed)
+
+    val correctAnswer = testRelation
+      .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty)
+      .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty)
+      .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1)
+
+    comparePlans(optimized, correctAnswer.analyze)
+  }
+
+  test("don't transpose two adjacent windows with incompatible partitions") {
+    val query = testRelation
+      .window(Seq(sum(c).as('sum_a_2)), partitionSpec3, Seq.empty)
+      .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty)
+
+    val analyzed = query.analyze
+    val optimized = Optimize.execute(analyzed)
+
+    comparePlans(optimized, analyzed)
+  }
+
+  test("don't transpose two adjacent windows with intersection of partition and output set") {
+    // The parent partitions by 'e, which the child window produces. Note that the partition
+    // specs here are also incompatible, so both the reference check and the partition check
+    // reject this pair.
+    val query = testRelation
+      .window(Seq(('a + 'b).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty)
+      .window(Seq(sum(c).as('sum_a_1)), Seq(a, 'e), Seq.empty)
+
+    val analyzed = query.analyze
+    val optimized = Optimize.execute(analyzed)
+
+    comparePlans(optimized, analyzed)
+  }
+
+  test("don't transpose when the parent window references the child's window output") {
+    // Partition specs are compatible (Seq(a) vs Seq(a, b)), so only the reference check
+    // prevents the transpose: sum_a_1 is computed from the child's sum_a_2.
+    val query = testRelation
+      .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, Seq.empty)
+      .window(Seq(sum('sum_a_2).as('sum_a_1)), partitionSpec1, Seq.empty)
+
+    val analyzed = query.analyze
+    val optimized = Optimize.execute(analyzed)
+
+    comparePlans(optimized, analyzed)
+  }
+
+  test("don't transpose two adjacent windows with non-deterministic expressions") {
+    // Partition specs are compatible, so only the determinism check on the child window
+    // prevents the transpose.
+    val query = testRelation
+      .window(Seq(Rand(0).as('e), sum(c).as('sum_a_2)), partitionSpec2, Seq.empty)
+      .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty)
+
+    val analyzed = query.analyze
+    val optimized = Optimize.execute(analyzed)
+
+    comparePlans(optimized, analyzed)
+  }
+
+  test("don't transpose when the parent window has non-deterministic expressions") {
+    // Partition specs are compatible, so only the determinism check on the parent window
+    // prevents the transpose.
+    val query = testRelation
+      .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, Seq.empty)
+      .window(Seq(Rand(0).as('e), sum(c).as('sum_a_1)), partitionSpec1, Seq.empty)
+
+    val analyzed = query.analyze
+    val optimized = Optimize.execute(analyzed)
+
+    comparePlans(optimized, analyzed)
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/78981efc/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 97a8439..78277d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
* Window function testing for DataFrame API.
*/
class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
+
import testImplicits._
test("reuse window partitionBy") {
@@ -72,9 +73,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
cume_dist().over(Window.partitionBy("value").orderBy("key")),
percent_rank().over(Window.partitionBy("value").orderBy("key"))),
Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) ::
- Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) ::
- Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) ::
- Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil)
+ Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) ::
+ Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) ::
+ Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil)
}
test("window function should fail if order by clause is not specified") {
@@ -162,12 +163,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(
Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
- Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ),
- Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ),
- Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
- Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
- Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
- Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
+ Row("c", 0.0, 0.0, 0.0, 0.0, 0.0),
+ Row("d", 0.0, 0.0, 0.0, 0.0, 0.0),
+ Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+ Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+ Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
+ Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)))
}
@@ -326,7 +327,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
var_samp($"value").over(window),
approx_count_distinct($"value").over(window)),
Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2))
- ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3)))
+ ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3)))
}
test("window function with aggregates") {
@@ -624,7 +625,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
def checkAnalysisError(df: => DataFrame): Unit = {
- val thrownException = the [AnalysisException] thrownBy {
+ val thrownException = the[AnalysisException] thrownBy {
df.queryExecution.analyzed
}
assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses"))
@@ -658,4 +659,26 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
|GROUP BY a
|HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
}
+
+  test("window functions in multiple selects") {
+    // Each row is (supplier number, part number, quantity).
+    val df = Seq(
+      ("S1", "P1", 100),
+      ("S1", "P1", 700),
+      ("S2", "P1", 200),
+      ("S2", "P2", 300)
+    ).toDF("sno", "pno", "qty")
+
+    val bySnoAndPno = Window.partitionBy("sno", "pno")
+    val bySno = Window.partitionBy("sno")
+
+    // The outer window partitions on a subset of the inner window's keys, so the two
+    // adjacent Window operators are a candidate for transposition by the optimizer.
+    val withInnerSum = df.select(
+      $"sno", $"pno", $"qty",
+      sum($"qty").over(bySnoAndPno).alias("sum_qty_2"))
+    val withBothSums = withInnerSum.select(
+      $"sno", $"pno", $"qty", $"sum_qty_2",
+      sum($"qty").over(bySno).alias("sum_qty_1"))
+
+    val expectedRows = Seq(
+      Row("S1", "P1", 100, 800, 800),
+      Row("S1", "P1", 700, 800, 800),
+      Row("S2", "P1", 200, 200, 500),
+      Row("S2", "P2", 300, 300, 500))
+
+    checkAnswer(withBothSums, expectedRows)
+  }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org