Posted to commits@spark.apache.org by rx...@apache.org on 2016/12/15 00:12:16 UTC

spark git commit: [SPARK-18854][SQL] numberedTreeString and apply(i) inconsistent for subqueries

Repository: spark
Updated Branches:
  refs/heads/master 786274257 -> ffdd1fcd1


[SPARK-18854][SQL] numberedTreeString and apply(i) inconsistent for subqueries

## What changes were proposed in this pull request?
This is a bug introduced by subquery handling: numberedTreeString (which uses generateTreeString under the hood) numbers tree nodes including innerChildren (used to print subqueries), whereas apply (which uses getNodeNumbered) ignores innerChildren. As a result, apply(i) returns the wrong plan node whenever the plan contains subqueries.

This patch fixes the bug.
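
For illustration, the contract the patch restores, sketched as a hypothetical spark-shell session (assumes a SparkSession named `spark`; the query is the one used in the new test, but exact plan shapes may vary across versions):

    // Each line of numberedTreeString is prefixed with a node number i,
    // and plan(i) should return exactly the node printed on that line.
    val df = spark.sql(
      "select * from range(10) where id not in " +
        "(select id from range(2) union all select id from range(2))")
    val plan = df.queryExecution.analyzed
    val lines = plan.numberedTreeString.split("\n")
    for (i <- lines.indices) {
      assert(lines(i).contains(plan(i).nodeName))
    }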

## How was this patch tested?
Added a test case in SubquerySuite.scala that verifies both the depth-first numbering order and the consistency of the two methods.

Author: Reynold Xin <rx...@databricks.com>

Closes #16277 from rxin/SPARK-18854.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ffdd1fcd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ffdd1fcd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ffdd1fcd

Branch: refs/heads/master
Commit: ffdd1fcd1e8f4f6453d5b0517c0ce82766b8e75f
Parents: 7862742
Author: Reynold Xin <rx...@databricks.com>
Authored: Wed Dec 14 16:12:14 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Dec 14 16:12:14 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/plans/QueryPlan.scala    |  9 ++++
 .../plans/logical/basicLogicalOperators.scala   |  2 +-
 .../spark/sql/catalyst/trees/TreeNode.scala     | 46 +++++++++++---------
 .../execution/columnar/InMemoryRelation.scala   |  3 +-
 .../org/apache/spark/sql/SubquerySuite.scala    | 18 ++++++++
 5 files changed, 55 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ffdd1fcd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b108017..e67f2be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -24,6 +24,15 @@ import org.apache.spark.sql.types.{DataType, StructType}
 abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
   self: PlanType =>
 
+  /**
+   * Override [[TreeNode.apply]] so we can return a narrower type.
+   *
+   * Note that this cannot return BaseType because a logical plan's innerChildren may contain
+   * physical plans, e.g. the in-memory relation logical plan node holds a reference to the
+   * physical plan it caches.
+   */
+  override def apply(number: Int): QueryPlan[_] = super.apply(number).asInstanceOf[QueryPlan[_]]
+
   def output: Seq[Attribute]
 
   /**

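A minimal sketch of what the narrowed override buys at call sites (hypothetical usage; `plan` stands for an analyzed LogicalPlan obtained elsewhere):

    // Without the override, TreeNode.apply would hand back a TreeNode[_],
    // and callers would need an asInstanceOf to reach query-plan members.
    val node: QueryPlan[_] = plan(1)
    println(node.output)  // QueryPlan-specific members are directly accessible
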
http://git-wip-us.apache.org/repos/asf/spark/blob/ffdd1fcd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index b9bdd53..0de5aa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -393,7 +393,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)])
     s"CTE $cteAliases"
   }
 
-  override def innerChildren: Seq[QueryPlan[_]] = cteRelations.map(_._2)
+  override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
 }
 
 case class WithWindowDefinition(

http://git-wip-us.apache.org/repos/asf/spark/blob/ffdd1fcd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index ea8d8fe..670fa2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.trees
 import java.util.UUID
 
 import scala.collection.Map
-import scala.collection.mutable.Stack
 import scala.reflect.ClassTag
 
 import org.apache.commons.lang3.ClassUtils
@@ -28,12 +27,9 @@ import org.json4s.JsonAST._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.SparkContext
-import org.apache.spark.rdd.{EmptyRDD, RDD}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.ScalaReflection._
-import org.apache.spark.sql.catalyst.ScalaReflectionLock
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -493,7 +489,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   /**
    * Returns a string representation of the nodes in this tree, where each operator is numbered.
-   * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees.
+   * The numbers can be used with [[TreeNode.apply]] to easily access specific subtrees.
+   *
+   * The numbers are based on a depth-first traversal of the tree (with innerChildren traversed
+   * before children).
    */
   def numberedTreeString: String =
     treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n")
@@ -501,17 +500,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   /**
    * Returns the tree node at the specified number.
    * Numbers for each node can be found in the [[numberedTreeString]].
+   *
+   * Note that this cannot return BaseType because a logical plan's innerChildren may contain
+   * physical plans, e.g. the in-memory relation logical plan node holds a reference to the
+   * physical plan it caches.
    */
-  def apply(number: Int): BaseType = getNodeNumbered(new MutableInt(number))
+  def apply(number: Int): TreeNode[_] = getNodeNumbered(new MutableInt(number)).orNull
 
-  protected def getNodeNumbered(number: MutableInt): BaseType = {
+  private def getNodeNumbered(number: MutableInt): Option[TreeNode[_]] = {
     if (number.i < 0) {
-      null.asInstanceOf[BaseType]
+      None
     } else if (number.i == 0) {
-      this
+      Some(this)
     } else {
       number.i -= 1
-      children.map(_.getNodeNumbered(number)).find(_ != null).getOrElse(null.asInstanceOf[BaseType])
+      // Note that this traversal order must be the same as numberedTreeString.
+      innerChildren.map(_.getNodeNumbered(number)).find(_ != None).getOrElse {
+        children.map(_.getNodeNumbered(number)).find(_ != None).flatten
+      }
     }
   }
 
@@ -527,6 +533,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at
    * depth `i + 1` is the last child of its own parent node.  The depth of the root node is 0, and
    * `lastChildren` for the root node should be empty.
+   *
+   * Note that this traversal (numbering) order must be the same as [[getNodeNumbered]].
    */
   def generateTreeString(
       depth: Int,
@@ -534,19 +542,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       builder: StringBuilder,
       verbose: Boolean,
       prefix: String = ""): StringBuilder = {
+
     if (depth > 0) {
       lastChildren.init.foreach { isLast =>
-        val prefixFragment = if (isLast) "   " else ":  "
-        builder.append(prefixFragment)
+        builder.append(if (isLast) "   " else ":  ")
       }
-
-      val branch = if (lastChildren.last) "+- " else ":- "
-      builder.append(branch)
+      builder.append(if (lastChildren.last) "+- " else ":- ")
     }
 
     builder.append(prefix)
-    val headline = if (verbose) verboseString else simpleString
-    builder.append(headline)
+    builder.append(if (verbose) verboseString else simpleString)
     builder.append("\n")
 
     if (innerChildren.nonEmpty) {
@@ -557,9 +562,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     }
 
     if (children.nonEmpty) {
-      children.init.foreach(
-        _.generateTreeString(depth + 1, lastChildren :+ false, builder, verbose, prefix))
-      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder, verbose, prefix)
+      children.init.foreach(_.generateTreeString(
+        depth + 1, lastChildren :+ false, builder, verbose, prefix))
+      children.last.generateTreeString(
+        depth + 1, lastChildren :+ true, builder, verbose, prefix)
     }
 
     builder

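To make the shared traversal order concrete, here is a self-contained toy sketch (Node, inner, and kids are made-up names, not Spark code): a node is numbered first, then its innerChildren subtrees, then its regular children, all depth-first:

    case class Node(name: String, inner: Seq[Node] = Nil, kids: Seq[Node] = Nil) {
      // Mirrors the order shared by generateTreeString and getNodeNumbered:
      // self first, then innerChildren, then children, depth-first.
      def flattened: Seq[Node] = this +: (inner ++ kids).flatMap(_.flattened)
    }

    val tree = Node("Filter",
      inner = Seq(Node("Union", kids = Seq(Node("Range"), Node("Range")))),
      kids = Seq(Node("Range")))

    // Prints 00 Filter, 01 Union, 02 Range, 03 Range, 04 Range:
    // the subquery (innerChildren) subtree is numbered before the real child.
    tree.flattened.zipWithIndex.foreach { case (n, i) => println(f"$i%02d ${n.name}") }
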
http://git-wip-us.apache.org/repos/asf/spark/blob/ffdd1fcd/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 56bd5c1..03cc046 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.SparkPlan
@@ -64,7 +63,7 @@ case class InMemoryRelation(
     val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
   extends logical.LeafNode with MultiInstanceRelation {
 
-  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)
+  override protected def innerChildren: Seq[SparkPlan] = Seq(child)
 
   override def producedAttributes: AttributeSet = outputSet
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ffdd1fcd/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5a4b1cf..2ef8b18 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -54,6 +54,24 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
     t.createOrReplaceTempView("t")
   }
 
+  test("SPARK-18854 numberedTreeString for subquery") {
+    val df = sql("select * from range(10) where id not in " +
+      "(select id from range(2) union all select id from range(2))")
+
+    // The depth-first traversal order of the plan tree
+    val dfs = Seq("Project", "Filter", "Union", "Project", "Range", "Project", "Range", "Range")
+    val numbered = df.queryExecution.analyzed.numberedTreeString.split("\n")
+
+    // There should be 8 plan nodes in total
+    assert(numbered.size == dfs.size)
+
+    for (i <- dfs.indices) {
+      val node = df.queryExecution.analyzed(i)
+      assert(node.nodeName == dfs(i))
+      assert(numbered(i).contains(node.nodeName))
+    }
+  }
+
   test("rdd deserialization does not crash [SPARK-15791]") {
     sql("select (select 1 as b) as b").rdd.count()
   }

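For reference, the analyzed plan in the test above is numbered approximately as follows (a sketch only: per-node details are elided and the branch glyphs are approximate, but the order matches the `dfs` sequence, with the NOT IN subquery's Union subtree numbered before the outer Range):

    00 Project
    01 +- Filter
    02    : +- Union
    03    :    :- Project
    04    :    :  +- Range
    05    :    +- Project
    06    :       +- Range
    07    +- Range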
