You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/03/30 15:21:14 UTC

spark git commit: [SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan

Repository: spark
Updated Branches:
  refs/heads/master 5b5a36ed6 -> bc8d09311


[SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan

## What changes were proposed in this pull request?
This PR improves the test coverage of the original PR https://github.com/apache/spark/pull/20687

## How was this patch tested?
N/A

Author: gatorsmile <ga...@gmail.com>

Closes #20911 from gatorsmile/addTests.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc8d0931
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc8d0931
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc8d0931

Branch: refs/heads/master
Commit: bc8d0931170cfa20a4fb64b3b11a2027ddb0d6e9
Parents: 5b5a36e
Author: gatorsmile <ga...@gmail.com>
Authored: Fri Mar 30 23:21:07 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Mar 30 23:21:07 2018 +0800

----------------------------------------------------------------------
 .../catalyst/optimizer/complexTypesSuite.scala  | 176 +++++++++++++------
 .../apache/spark/sql/ComplexTypesSuite.scala    | 109 ++++++++++++
 2 files changed, 233 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc8d0931/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index e44a669..21ed987 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -47,10 +47,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
           SimplifyExtractValueOps) :: Nil
   }
 
-  val idAtt = ('id).long.notNull
-  val nullableIdAtt = ('nullable_id).long
+  private val idAtt = ('id).long.notNull
+  private val nullableIdAtt = ('nullable_id).long
 
-  lazy val relation = LocalRelation(idAtt, nullableIdAtt)
+  private val relation = LocalRelation(idAtt, nullableIdAtt)
+  private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int)
+
+  private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
+    val optimized = Optimizer.execute(originalQuery.analyze)
+    assert(optimized.resolved, "optimized plans must be still resolvable")
+    comparePlans(optimized, correctAnswer.analyze)
+  }
 
   test("explicit get from namedStruct") {
     val query = relation
@@ -58,31 +65,28 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
         GetStructField(
           CreateNamedStruct(Seq("att", 'id )),
           0,
-          None) as "outerAtt").analyze
-    val expected = relation.select('id as "outerAtt").analyze
+          None) as "outerAtt")
+    val expected = relation.select('id as "outerAtt")
 
-    comparePlans(Optimizer execute query, expected)
+    checkRule(query, expected)
   }
 
   test("explicit get from named_struct- expression maintains original deduced alias") {
     val query = relation
       .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None))
-      .analyze
 
     val expected = relation
       .select('id as "named_struct(att, id).att")
-      .analyze
 
-    comparePlans(Optimizer execute query, expected)
+    checkRule(query, expected)
   }
 
   test("collapsed getStructField ontop of namedStruct") {
     val query = relation
       .select(CreateNamedStruct(Seq("att", 'id)) as "struct1")
       .select(GetStructField('struct1, 0, None) as "struct1Att")
-      .analyze
-    val expected = relation.select('id as "struct1Att").analyze
-    comparePlans(Optimizer execute query, expected)
+    val expected = relation.select('id as "struct1Att")
+    checkRule(query, expected)
   }
 
   test("collapse multiple CreateNamedStruct/GetStructField pairs") {
@@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
       .select(
         GetStructField('struct1, 0, None) as "struct1Att1",
         GetStructField('struct1, 1, None) as "struct1Att2")
-      .analyze
 
     val expected =
       relation.
         select(
           'id as "struct1Att1",
           ('id * 'id) as "struct1Att2")
-      .analyze
 
-    comparePlans(Optimizer execute query, expected)
+    checkRule(query, expected)
   }
 
   test("collapsed2 - deduced names") {
@@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
       .select(
         GetStructField('struct1, 0, None),
         GetStructField('struct1, 1, None))
-      .analyze
 
     val expected =
       relation.
         select(
           'id as "struct1.att1",
           ('id * 'id) as "struct1.att2")
-        .analyze
 
-    comparePlans(Optimizer execute query, expected)
+    checkRule(query, expected)
   }
 
   test("simplified array ops") {
@@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
             1,
             false),
           1) as "a4")
-      .analyze
 
     val expected = relation
       .select(
@@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
           "att2", (('id + 1L) * ('id + 1L)))) as "a2",
         ('id + 1L) as "a3",
         ('id + 1L) as "a4")
-      .analyze
-    comparePlans(Optimizer execute query, expected)
+    checkRule(query, expected)
   }
 
   test("SPARK-22570: CreateArray should not create a lot of global variables") {
@@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
         GetStructField(GetMapValue('m, "r1"), 0, None) as "a2",
         GetMapValue('m, "r32") as "a3",
         GetStructField(GetMapValue('m, "r32"), 0, None) as "a4")
-      .analyze
 
     val expected =
       relation.select(
@@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
           )
         ) as "a3",
         Literal.create(null, LongType) as "a4")
-      .analyze
-    comparePlans(Optimizer execute query, expected)
+    checkRule(query, expected)
   }
 
   test("simplify map ops, constant lookup, dynamic keys") {
@@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
           ('id + 3L), ('id + 4L),
           ('id + 4L), ('id + 5L))),
         13L) as "a")
-      .analyze
 
     val expected = relation
       .select(
@@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
           (EqualTo(13L, ('id + 1L)), ('id + 2L)),
           (EqualTo(13L, ('id + 2L)), ('id + 3L)),
           (Literal(true), 'id))) as "a")
-      .analyze
-    comparePlans(Optimizer execute query, expected)
+    checkRule(query, expected)
   }
 
   test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") {
@@ -240,7 +234,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
             ('id + 3L), ('id + 4L),
             ('id + 4L), ('id + 5L))),
             ('id + 3L)) as "a")
-      .analyze
     val expected = relation
       .select(
         CaseWhen(Seq(
@@ -248,8 +241,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
           (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)),
           (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)),
           (Literal(true), ('id + 4L)))) as "a")
-      .analyze
-    comparePlans(Optimizer execute query, expected)
+    checkRule(query, expected)
   }
 
   test("simplify map ops, no positive match") {
@@ -263,7 +255,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
             ('id + 3L), ('id + 4L),
             ('id + 4L), ('id + 5L))),
           'id + 30L) as "a")
-      .analyze
     val expected = relation.select(
       CaseWhen(Seq(
         (EqualTo('id + 30L, 'id), ('id + 1L)),
@@ -271,8 +262,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
         (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)),
         (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)),
         (EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a")
-      .analyze
-    comparePlans(Optimizer execute rel, expected)
+    checkRule(rel, expected)
   }
 
   test("simplify map ops, constant lookup, mixed keys, eliminated constants") {
@@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
             ('id + 3L), ('id + 4L),
             ('id + 4L), ('id + 5L))),
           13L) as "a")
-      .analyze
 
     val expected = relation
       .select(
@@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
             ('id + 2L), ('id + 3L),
             ('id + 3L), ('id + 4L),
             ('id + 4L), ('id + 5L))) as "a")
-      .analyze
 
-    comparePlans(Optimizer execute rel, expected)
+    checkRule(rel, expected)
   }
 
   test("simplify map ops, potential dynamic match with null value + an absolute constant match") {
@@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
             ('id + 3L), ('id + 4L),
             ('id + 4L), ('id + 5L))),
           2L ) as "a")
-      .analyze
 
     val expected = relation
       .select(
@@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
           // but it cannot override a potential match with ('id + 2L),
           // which is exactly what [[Coalesce]] would do in this case.
           (Literal.TrueLiteral, 'id))) as "a")
-      .analyze
-    comparePlans(Optimizer execute rel, expected)
+    checkRule(rel, expected)
+  }
+
+  test("SPARK-23500: Simplify array ops that are not at the top node") {
+    val query = LocalRelation('id.long)
+      .select(
+        CreateArray(Seq(
+          CreateNamedStruct(Seq(
+            "att1", 'id,
+            "att2", 'id * 'id)),
+          CreateNamedStruct(Seq(
+            "att1", 'id + 1,
+            "att2", ('id + 1) * ('id + 1))
+          ))
+        ) as "arr")
+      .select(
+        GetStructField(GetArrayItem('arr, 1), 0, None) as "a1",
+        GetArrayItem(
+          GetArrayStructFields('arr,
+            StructField("att1", LongType, nullable = false),
+            ordinal = 0,
+            numFields = 1,
+            containsNull = false),
+          ordinal = 1) as "a2")
+      .orderBy('id.asc)
+
+    val expected = LocalRelation('id.long)
+      .select(
+        ('id + 1L) as "a1",
+        ('id + 1L) as "a2")
+      .orderBy('id.asc)
+    checkRule(query, expected)
+  }
+
+  test("SPARK-23500: Simplify map ops that are not top nodes") {
+    val query =
+      LocalRelation('id.long)
+        .select(
+          CreateMap(Seq(
+            "r1", 'id,
+            "r2", 'id + 1L)) as "m")
+        .select(
+          GetMapValue('m, "r1") as "a1",
+          GetMapValue('m, "r32") as "a2")
+        .orderBy('id.asc)
+        .select('a1, 'a2)
+
+    val expected =
+      LocalRelation('id.long).select(
+        'id as "a1",
+        Literal.create(null, LongType) as "a2")
+        .orderBy('id.asc)
+    checkRule(query, expected)
   }
 
   test("SPARK-23500: Simplify complex ops that aren't at the plan root") {
     val structRel = relation
       .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo")
-      .groupBy($"foo")("1").analyze
+      .groupBy($"foo")("1")
     val structExpected = relation
       .select('nullable_id as "foo")
-      .groupBy($"foo")("1").analyze
-    comparePlans(Optimizer execute structRel, structExpected)
+      .groupBy($"foo")("1")
+    checkRule(structRel, structExpected)
 
     // These tests must use nullable attributes from the base relation for the following reason:
     // in the 'original' plans below, the Aggregate node produced by groupBy() has a
@@ -351,17 +389,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     // SPARK-23634.
     val arrayRel = relation
       .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
-      .groupBy($"a1")("1").analyze
-    val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze
-    comparePlans(Optimizer execute arrayRel, arrayExpected)
+      .groupBy($"a1")("1")
+    val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1")
+    checkRule(arrayRel, arrayExpected)
 
     val mapRel = relation
       .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1")
-      .groupBy($"m1")("1").analyze
+      .groupBy($"m1")("1")
     val mapExpected = relation
       .select('nullable_id as "m1")
-      .groupBy($"m1")("1").analyze
-    comparePlans(Optimizer execute mapRel, mapExpected)
+      .groupBy($"m1")("1")
+    checkRule(mapRel, mapExpected)
   }
 
   test("SPARK-23500: Ensure that aggregation expressions are not simplified") {
@@ -369,11 +407,45 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     // grouping exprs so aren't tested here.
     val structAggRel = relation.groupBy(
       CreateNamedStruct(Seq("att1", 'nullable_id)))(
-      GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze
-    comparePlans(Optimizer execute structAggRel, structAggRel)
+      GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None))
+    checkRule(structAggRel, structAggRel)
 
     val arrayAggRel = relation.groupBy(
-      CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze
-    comparePlans(Optimizer execute arrayAggRel, arrayAggRel)
+      CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
+    checkRule(arrayAggRel, arrayAggRel)
+
+    // This could be done if we had a more complex rule that checks that
+    // the CreateMap does not come from key.
+    val originalQuery = relation
+      .groupBy('id)(
+        GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
+      )
+    checkRule(originalQuery, originalQuery)
+  }
+
+  test("SPARK-23500: namedStruct and getField in the same Project #1") {
+    val originalQuery =
+      testRelation
+        .select(
+          namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b)
+        .select('s1 getField "col2" as 's1Col2,
+          namedStruct("col1", 'a, "col2", 'b).as("s2"))
+        .select('s1Col2, 's2 getField "col2" as 's2Col2)
+    val correctAnswer =
+      testRelation
+        .select('c as 's1Col2, 'b as 's2Col2)
+    checkRule(originalQuery, correctAnswer)
+  }
+
+  test("SPARK-23500: namedStruct and getField in the same Project #2") {
+    val originalQuery =
+      testRelation
+        .select(
+          namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2,
+          namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1)
+    val correctAnswer =
+      testRelation
+        .select('c as 'sCol2, 'a as 'sCol1)
+    checkRule(originalQuery, correctAnswer)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc8d0931/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
new file mode 100644
index 0000000..b74fe2f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ComplexTypesSuite extends QueryTest with SharedSQLContext {
+
+  override def beforeAll() {
+    super.beforeAll()
+    spark.range(10).selectExpr(
+      "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5")
+      .write.saveAsTable("tab")
+  }
+
+  override def afterAll() {
+    try {
+      spark.sql("DROP TABLE IF EXISTS tab")
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  /**
+   * Asserts that `plan` contains exactly `expectedCount` [[CreateNamedStruct]] expressions,
+   * counting nested occurrences within the expressions of every operator in the plan tree.
+   */
+  def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = {
+    val count = plan.flatMap { operator =>
+      operator.expressions.flatMap(_.collect { case c: CreateNamedStruct => c })
+    }.size
+    assert(count == expectedCount, s"expect $expectedCount CreateNamedStruct but got $count.")
+  }
+
+  test("simple case") {
+    // Only individual fields of the structs are read, so no CreateNamedStruct should remain.
+    val df = spark.table("tab").selectExpr(
+      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2")
+      .filter("col2.c > 11").selectExpr("col1.a")
+    checkAnswer(df, Row(9) :: Row(10) :: Nil)
+    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
+  }
+
+  test("named_struct is used in the top Project") {
+    // col1 is used as a whole (output column / sort key), so one struct should survive.
+    val df = spark.table("tab").selectExpr(
+      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)")
+      .selectExpr("col1.a", "col1")
+      .filter("col1.a > 8")
+    checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil)
+    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1)
+
+    val df1 = spark.table("tab").selectExpr(
+      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)")
+      .sort("col1")
+      .selectExpr("col1.a")
+      .filter("col1.a > 8")
+    checkAnswer(df1, Row(9) :: Row(10) :: Nil)
+    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1)
+  }
+
+  test("expression in named_struct") {
+    // Reading aliased struct fields should resolve to the underlying expressions directly.
+    val df = spark.table("tab")
+      .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola")
+      .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10")
+    checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil)
+    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
+
+    val df1 = spark.table("tab")
+      .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola")
+      .selectExpr("cola.i3").filter("cola.exp > 10")
+    checkAnswer(df1, Row(12) :: Nil)
+    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0)
+  }
+
+  test("nested case") {
+    // Field access through nested structs should remove every level of CreateNamedStruct.
+    val df = spark.table("tab")
+      .selectExpr("struct(struct(i2, i3) as exp, i4) as cola")
+      .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10")
+    checkAnswer(df, Row(11, 13) :: Nil)
+    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
+
+    val df1 = spark.table("tab")
+      .selectExpr("struct(i2, i3) as exp", "i4")
+      .selectExpr("struct(exp, i4) as cola")
+      .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11")
+    checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil)
+    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org