You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/03/30 15:21:14 UTC
spark git commit: [SPARK-23500][SQL][FOLLOWUP] Fix complex type
simplification rules to apply to entire plan
Repository: spark
Updated Branches:
refs/heads/master 5b5a36ed6 -> bc8d09311
[SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan
## What changes were proposed in this pull request?
This PR is to improve the test coverage of the original PR https://github.com/apache/spark/pull/20687
## How was this patch tested?
N/A
Author: gatorsmile <ga...@gmail.com>
Closes #20911 from gatorsmile/addTests.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc8d0931
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc8d0931
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc8d0931
Branch: refs/heads/master
Commit: bc8d0931170cfa20a4fb64b3b11a2027ddb0d6e9
Parents: 5b5a36e
Author: gatorsmile <ga...@gmail.com>
Authored: Fri Mar 30 23:21:07 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Mar 30 23:21:07 2018 +0800
----------------------------------------------------------------------
.../catalyst/optimizer/complexTypesSuite.scala | 176 +++++++++++++------
.../apache/spark/sql/ComplexTypesSuite.scala | 109 ++++++++++++
2 files changed, 233 insertions(+), 52 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/bc8d0931/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index e44a669..21ed987 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -47,10 +47,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
SimplifyExtractValueOps) :: Nil
}
- val idAtt = ('id).long.notNull
- val nullableIdAtt = ('nullable_id).long
+ private val idAtt = ('id).long.notNull
+ private val nullableIdAtt = ('nullable_id).long
- lazy val relation = LocalRelation(idAtt, nullableIdAtt)
+ private val relation = LocalRelation(idAtt, nullableIdAtt)
+ private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int)
+
+ /**
+ * Analyzes `originalQuery`, runs it through this suite's `Optimizer`, asserts that the
+ * optimized plan is still resolved, and compares it with the analyzed `correctAnswer`.
+ * Callers pass un-analyzed plans; both plans are analyzed here.
+ */
+ private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
+ val optimized = Optimizer.execute(originalQuery.analyze)
+ assert(optimized.resolved, "optimized plans must be still resolvable")
+ comparePlans(optimized, correctAnswer.analyze)
+ }
test("explicit get from namedStruct") {
val query = relation
@@ -58,31 +65,28 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
GetStructField(
CreateNamedStruct(Seq("att", 'id )),
0,
- None) as "outerAtt").analyze
- val expected = relation.select('id as "outerAtt").analyze
+ None) as "outerAtt")
+ val expected = relation.select('id as "outerAtt")
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("explicit get from named_struct- expression maintains original deduced alias") {
val query = relation
.select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None))
- .analyze
val expected = relation
.select('id as "named_struct(att, id).att")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("collapsed getStructField ontop of namedStruct") {
val query = relation
.select(CreateNamedStruct(Seq("att", 'id)) as "struct1")
.select(GetStructField('struct1, 0, None) as "struct1Att")
- .analyze
- val expected = relation.select('id as "struct1Att").analyze
- comparePlans(Optimizer execute query, expected)
+ val expected = relation.select('id as "struct1Att")
+ checkRule(query, expected)
}
test("collapse multiple CreateNamedStruct/GetStructField pairs") {
@@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.select(
GetStructField('struct1, 0, None) as "struct1Att1",
GetStructField('struct1, 1, None) as "struct1Att2")
- .analyze
val expected =
relation.
select(
'id as "struct1Att1",
('id * 'id) as "struct1Att2")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("collapsed2 - deduced names") {
@@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.select(
GetStructField('struct1, 0, None),
GetStructField('struct1, 1, None))
- .analyze
val expected =
relation.
select(
'id as "struct1.att1",
('id * 'id) as "struct1.att2")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("simplified array ops") {
@@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
1,
false),
1) as "a4")
- .analyze
val expected = relation
.select(
@@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
"att2", (('id + 1L) * ('id + 1L)))) as "a2",
('id + 1L) as "a3",
('id + 1L) as "a4")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("SPARK-22570: CreateArray should not create a lot of global variables") {
@@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
GetStructField(GetMapValue('m, "r1"), 0, None) as "a2",
GetMapValue('m, "r32") as "a3",
GetStructField(GetMapValue('m, "r32"), 0, None) as "a4")
- .analyze
val expected =
relation.select(
@@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
)
) as "a3",
Literal.create(null, LongType) as "a4")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("simplify map ops, constant lookup, dynamic keys") {
@@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
13L) as "a")
- .analyze
val expected = relation
.select(
@@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
(EqualTo(13L, ('id + 1L)), ('id + 2L)),
(EqualTo(13L, ('id + 2L)), ('id + 3L)),
(Literal(true), 'id))) as "a")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") {
@@ -240,7 +234,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
('id + 3L)) as "a")
- .analyze
val expected = relation
.select(
CaseWhen(Seq(
@@ -248,8 +241,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
(EqualTo('id + 3L, ('id + 1L)), ('id + 2L)),
(EqualTo('id + 3L, ('id + 2L)), ('id + 3L)),
(Literal(true), ('id + 4L)))) as "a")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("simplify map ops, no positive match") {
@@ -263,7 +255,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
'id + 30L) as "a")
- .analyze
val expected = relation.select(
CaseWhen(Seq(
(EqualTo('id + 30L, 'id), ('id + 1L)),
@@ -271,8 +262,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
(EqualTo('id + 30L, ('id + 2L)), ('id + 3L)),
(EqualTo('id + 30L, ('id + 3L)), ('id + 4L)),
(EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a")
- .analyze
- comparePlans(Optimizer execute rel, expected)
+ checkRule(rel, expected)
}
test("simplify map ops, constant lookup, mixed keys, eliminated constants") {
@@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
13L) as "a")
- .analyze
val expected = relation
.select(
@@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 2L), ('id + 3L),
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))) as "a")
- .analyze
- comparePlans(Optimizer execute rel, expected)
+ checkRule(rel, expected)
}
test("simplify map ops, potential dynamic match with null value + an absolute constant match") {
@@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
2L ) as "a")
- .analyze
val expected = relation
.select(
@@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
// but it cannot override a potential match with ('id + 2L),
// which is exactly what [[Coalesce]] would do in this case.
(Literal.TrueLiteral, 'id))) as "a")
- .analyze
- comparePlans(Optimizer execute rel, expected)
+ checkRule(rel, expected)
+ }
+
+ test("SPARK-23500: Simplify array ops that are not at the top node") {
+ // The array of two structs is created in the bottom Project and read in the Project above
+ // it; the orderBy puts a further operator on top, so the extract ops are not at the root.
+ val query = LocalRelation('id.long)
+ .select(
+ CreateArray(Seq(
+ CreateNamedStruct(Seq(
+ "att1", 'id,
+ "att2", 'id * 'id)),
+ CreateNamedStruct(Seq(
+ "att1", 'id + 1,
+ "att2", ('id + 1) * ('id + 1))
+ ))
+ ) as "arr")
+ .select(
+ GetStructField(GetArrayItem('arr, 1), 0, None) as "a1",
+ GetArrayItem(
+ GetArrayStructFields('arr,
+ StructField("att1", LongType, nullable = false),
+ ordinal = 0,
+ numFields = 1,
+ containsNull = false),
+ ordinal = 1) as "a2")
+ .orderBy('id.asc)
+
+ // Both a1 and a2 pick "att1" of the array element at index 1, which was built as 'id + 1.
+ val expected = LocalRelation('id.long)
+ .select(
+ ('id + 1L) as "a1",
+ ('id + 1L) as "a2")
+ .orderBy('id.asc)
+ checkRule(query, expected)
+ }
+
+ test("SPARK-23500: Simplify map ops that are not top nodes") {
+ // The map is created in the bottom Project and looked up in the Project above it; an
+ // orderBy and a final select are stacked on top of the lookups.
+ val query =
+ LocalRelation('id.long)
+ .select(
+ CreateMap(Seq(
+ "r1", 'id,
+ "r2", 'id + 1L)) as "m")
+ .select(
+ GetMapValue('m, "r1") as "a1",
+ GetMapValue('m, "r32") as "a2")
+ .orderBy('id.asc)
+ .select('a1, 'a2)
+
+ // "r1" is one of the map keys and maps to 'id; "r32" is not a key of the map, so the
+ // expected plan holds a null literal of LongType for it.
+ val expected =
+ LocalRelation('id.long).select(
+ 'id as "a1",
+ Literal.create(null, LongType) as "a2")
+ .orderBy('id.asc)
+ checkRule(query, expected)
}
test("SPARK-23500: Simplify complex ops that aren't at the plan root") {
val structRel = relation
.select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo")
- .groupBy($"foo")("1").analyze
+ .groupBy($"foo")("1")
val structExpected = relation
.select('nullable_id as "foo")
- .groupBy($"foo")("1").analyze
- comparePlans(Optimizer execute structRel, structExpected)
+ .groupBy($"foo")("1")
+ checkRule(structRel, structExpected)
// These tests must use nullable attributes from the base relation for the following reason:
// in the 'original' plans below, the Aggregate node produced by groupBy() has a
@@ -351,17 +389,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
// SPARK-23634.
val arrayRel = relation
.select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
- .groupBy($"a1")("1").analyze
- val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze
- comparePlans(Optimizer execute arrayRel, arrayExpected)
+ .groupBy($"a1")("1")
+ val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1")
+ checkRule(arrayRel, arrayExpected)
val mapRel = relation
.select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1")
- .groupBy($"m1")("1").analyze
+ .groupBy($"m1")("1")
val mapExpected = relation
.select('nullable_id as "m1")
- .groupBy($"m1")("1").analyze
- comparePlans(Optimizer execute mapRel, mapExpected)
+ .groupBy($"m1")("1")
+ checkRule(mapRel, mapExpected)
}
test("SPARK-23500: Ensure that aggregation expressions are not simplified") {
@@ -369,11 +407,45 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
// grouping exprs so aren't tested here.
val structAggRel = relation.groupBy(
CreateNamedStruct(Seq("att1", 'nullable_id)))(
- GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze
- comparePlans(Optimizer execute structAggRel, structAggRel)
+ GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None))
+ checkRule(structAggRel, structAggRel)
val arrayAggRel = relation.groupBy(
- CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze
- comparePlans(Optimizer execute arrayAggRel, arrayAggRel)
+ CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
+ checkRule(arrayAggRel, arrayAggRel)
+
+ // This could be done if we had a more complex rule that checks that
+ // the CreateMap does not come from key.
+ val originalQuery = relation
+ .groupBy('id)(
+ GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
+ )
+ checkRule(originalQuery, originalQuery)
+ }
+
+ test("SPARK-23500: namedStruct and getField in the same Project #1") {
+ // The middle select both reads a field of s1 (created below it) and creates s2, whose
+ // field is read by the top select.
+ val originalQuery =
+ testRelation
+ .select(
+ namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b)
+ .select('s1 getField "col2" as 's1Col2,
+ namedStruct("col1", 'a, "col2", 'b).as("s2"))
+ .select('s1Col2, 's2 getField "col2" as 's2Col2)
+ // s1.col2 was built from 'c and s2.col2 from 'b, so the expected plan is a single select
+ // of those two columns over testRelation.
+ val correctAnswer =
+ testRelation
+ .select('c as 's1Col2, 'b as 's2Col2)
+ checkRule(originalQuery, correctAnswer)
+ }
+
+ test("SPARK-23500: namedStruct and getField in the same Project #2") {
+ // Each struct is created and immediately read inside the same select expression.
+ val originalQuery =
+ testRelation
+ .select(
+ namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2,
+ namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1)
+ // "col2" of the first struct is 'c and "col1" of the second struct is 'a.
+ val correctAnswer =
+ testRelation
+ .select('c as 'sCol2, 'a as 'sCol1)
+ checkRule(originalQuery, correctAnswer)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bc8d0931/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
new file mode 100644
index 0000000..b74fe2f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * Runs queries against a small saved table and checks both their answers and the number of
+ * [[CreateNamedStruct]] expressions left in the optimized plan.
+ */
+class ComplexTypesSuite extends QueryTest with SharedSQLContext {
+
+  /** Saves table `tab` with columns i1..i5, where iN is `id + N` over `spark.range(10)`. */
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.range(10).selectExpr(
+      "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5")
+      .write.saveAsTable("tab")
+  }
+
+  /** Drops `tab`; `super.afterAll()` runs even if the drop throws. */
+  override def afterAll(): Unit = {
+    try {
+      spark.sql("DROP TABLE IF EXISTS tab")
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  /**
+   * Fails the test unless `plan` contains exactly `expectedCount` [[CreateNamedStruct]]
+   * expressions, counted over the expressions of every operator in the plan.
+   *
+   * @param plan the plan to inspect (the tests pass the optimized plan of a DataFrame)
+   * @param expectedCount the number of [[CreateNamedStruct]] expressions that must be found
+   */
+  def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = {
+    var count = 0
+    plan.foreach { operator =>
+      // transformExpressions is used only to visit the expressions: the case returns the
+      // matched node unchanged and the transformed operator is discarded.
+      operator.transformExpressions {
+        case c: CreateNamedStruct =>
+          count += 1
+          c
+      }
+    }
+
+    if (expectedCount != count) {
+      fail(s"expect $expectedCount CreateNamedStruct but got $count.")
+    }
+  }
+
+  test("simple case") {
+    // col2.c is i4, so the filter keeps i4 in {12, 13}; col1.a is i1, i.e. 9 and 10.
+    val df = spark.table("tab").selectExpr(
+      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2")
+      .filter("col2.c > 11").selectExpr("col1.a")
+    checkAnswer(df, Row(9) :: Row(10) :: Nil)
+    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
+  }
+
+  test("named_struct is used in the top Project") {
+    // col1 is returned as a whole next to col1.a, and the test expects exactly one
+    // CreateNamedStruct to remain in the optimized plan.
+    val df = spark.table("tab").selectExpr(
+      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)")
+      .selectExpr("col1.a", "col1")
+      .filter("col1.a > 8")
+    checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil)
+    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1)
+
+    // Here col1 is used as the sort key before only col1.a is selected; one
+    // CreateNamedStruct is again expected.
+    val df1 = spark.table("tab").selectExpr(
+      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)")
+      .sort("col1")
+      .selectExpr("col1.a")
+      .filter("col1.a > 8")
+    checkAnswer(df1, Row(9) :: Row(10) :: Nil)
+    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1)
+  }
+
+  test("expression in named_struct") {
+    // Struct built with struct(...) whose first field is an aliased column.
+    val df = spark.table("tab")
+      .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola")
+      .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10")
+    checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil)
+    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
+
+    // Same shape, but the first field is a computed expression (i1 + 1) that is only filtered on.
+    val df1 = spark.table("tab")
+      .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola")
+      .selectExpr("cola.i3").filter("cola.exp > 10")
+    checkAnswer(df1, Row(12) :: Nil)
+    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0)
+  }
+
+  test("nested case") {
+    // A struct nested directly inside another struct within one select.
+    val df = spark.table("tab")
+      .selectExpr("struct(struct(i2, i3) as exp, i4) as cola")
+      .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10")
+    checkAnswer(df, Row(11, 13) :: Nil)
+    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
+
+    // The inner struct is created in one select and wrapped by the outer struct in the next.
+    val df1 = spark.table("tab")
+      .selectExpr("struct(i2, i3) as exp", "i4")
+      .selectExpr("struct(exp, i4) as cola")
+      .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11")
+    checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil)
+    // Inspect the plan of df1 (the query built just above); df was already verified earlier.
+    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0)
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org