Posted to commits@spark.apache.org by we...@apache.org on 2017/02/23 01:26:59 UTC

spark git commit: [SPARK-19658][SQL] Set NumPartitions of RepartitionByExpression In Parser

Repository: spark
Updated Branches:
  refs/heads/master 4661d30b9 -> dc005ed53


[SPARK-19658][SQL] Set NumPartitions of RepartitionByExpression In Parser

### What changes were proposed in this pull request?

Currently, if `numPartitions` is not set on `RepartitionByExpression`, we fill it in from `spark.sql.shuffle.partitions` during planning. However, this does not follow the general resolution process. This PR sets it in the `Parser` instead, so that the `Optimizer` can use the value for plan optimization.
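
A minimal, self-contained sketch of how the change can be observed from a user's perspective (`DistributeByExample` is a hypothetical name, and the exact plan string printed varies by Spark version):

```scala
import org.apache.spark.sql.SparkSession

object DistributeByExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("distribute-by-example")
      .master("local[*]")
      .config("spark.sql.shuffle.partitions", "10")
      .getOrCreate()
    import spark.implicits._

    Seq((1, "a"), (2, "b")).toDF("a", "b").createOrReplaceTempView("t")

    // With this change, the parsed (unanalyzed) logical plan already carries
    // numPartitions = 10 on the RepartitionByExpression node, taken from
    // spark.sql.shuffle.partitions at parse time rather than during planning.
    val df = spark.sql("SELECT * FROM t DISTRIBUTE BY a")
    println(df.queryExecution.logical)

    spark.stop()
  }
}
```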

### How was this patch tested?

Added a test case.
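
A condensed restatement of the new check added to `SparkSqlParserSuite` (see the diff below), assuming the same `SparkSqlParser(conf)` constructor and `numShufflePartitions` accessor shown there:

```scala
import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf

val conf = new SQLConf
val parser = new SparkSqlParser(conf)

// DISTRIBUTE BY now parses straight into a RepartitionByExpression whose
// numPartitions is already conf.numShufflePartitions; there is no Option
// left for the planner to fill in.
parser.parsePlan("SELECT * FROM t DISTRIBUTE BY a") match {
  case r: RepartitionByExpression =>
    assert(r.numPartitions == conf.numShufflePartitions)
  case other =>
    sys.error(s"unexpected top-level plan node: $other")
}
```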

Author: Xiao Li <ga...@gmail.com>

Closes #16988 from gatorsmile/resolveRepartition.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dc005ed5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dc005ed5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dc005ed5

Branch: refs/heads/master
Commit: dc005ed53c87216efff50268009217ba26e34a10
Parents: 4661d30
Author: Xiao Li <ga...@gmail.com>
Authored: Wed Feb 22 17:26:56 2017 -0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Feb 22 17:26:56 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala |  4 +--
 .../sql/catalyst/optimizer/Optimizer.scala      |  2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala  | 16 ++++++++--
 .../plans/logical/basicLogicalOperators.scala   |  9 ++----
 .../sql/catalyst/analysis/AnalysisSuite.scala   | 10 +++---
 .../sql/catalyst/parser/PlanParserSuite.scala   |  5 +--
 .../scala/org/apache/spark/sql/Dataset.scala    |  5 +--
 .../spark/sql/execution/SparkSqlParser.scala    | 17 +++++++++--
 .../spark/sql/execution/SparkStrategies.scala   |  6 ++--
 .../sql/execution/SparkSqlParserSuite.scala     | 32 ++++++++++++++++++--
 10 files changed, 74 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3c53132..c062e4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -373,8 +373,8 @@ package object dsl {
       def repartition(num: Integer): LogicalPlan =
         Repartition(num, shuffle = true, logicalPlan)
 
-      def distribute(exprs: Expression*)(n: Int = -1): LogicalPlan =
-        RepartitionByExpression(exprs, logicalPlan, numPartitions = if (n < 0) None else Some(n))
+      def distribute(exprs: Expression*)(n: Int): LogicalPlan =
+        RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
 
       def analyze: LogicalPlan =
         EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0c13e3e..af846a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -578,7 +578,7 @@ object CollapseRepartition extends Rule[LogicalPlan] {
       RepartitionByExpression(exprs, child, numPartitions)
     // Case 3
     case Repartition(numPartitions, _, r: RepartitionByExpression) =>
-      r.copy(numPartitions = Some(numPartitions))
+      r.copy(numPartitions = numPartitions)
     // Case 3
     case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) =>
       RepartitionByExpression(exprs, child, numPartitions)

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 08a6dd1..926a37b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -242,20 +242,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       Sort(sort.asScala.map(visitSortItem), global = false, query)
     } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
       // DISTRIBUTE BY ...
-      RepartitionByExpression(expressionList(distributeBy), query)
+      withRepartitionByExpression(ctx, expressionList(distributeBy), query)
     } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
       // SORT BY ... DISTRIBUTE BY ...
       Sort(
         sort.asScala.map(visitSortItem),
         global = false,
-        RepartitionByExpression(expressionList(distributeBy), query))
+        withRepartitionByExpression(ctx, expressionList(distributeBy), query))
     } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
       // CLUSTER BY ...
       val expressions = expressionList(clusterBy)
       Sort(
         expressions.map(SortOrder(_, Ascending)),
         global = false,
-        RepartitionByExpression(expressions, query))
+        withRepartitionByExpression(ctx, expressions, query))
     } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
       // [EMPTY]
       query
@@ -274,6 +274,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   }
 
   /**
+   * Create a clause for DISTRIBUTE BY.
+   */
+  protected def withRepartitionByExpression(
+      ctx: QueryOrganizationContext,
+      expressions: Seq[Expression],
+      query: LogicalPlan): LogicalPlan = {
+    throw new ParseException("DISTRIBUTE BY is not supported", ctx)
+  }
+
+  /**
    * Create a logical plan using a query specification.
    */
   override def visitQuerySpecification(

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index af57632..d17d12c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -844,18 +844,13 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
  * information about the number of partitions during execution. Used when a specific ordering or
  * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like
  * `coalesce` and `repartition`.
- * If `numPartitions` is not specified, the number of partitions will be the number set by
- * `spark.sql.shuffle.partitions`.
  */
 case class RepartitionByExpression(
     partitionExpressions: Seq[Expression],
     child: LogicalPlan,
-    numPartitions: Option[Int] = None) extends UnaryNode {
+    numPartitions: Int) extends UnaryNode {
 
-  numPartitions match {
-    case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.")
-    case None => // Ok
-  }
+  require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
 
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 786e0f4..01737e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -21,11 +21,12 @@ import java.util.TimeZone
 
 import org.scalatest.ShouldMatchers
 
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.Cross
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -192,12 +193,13 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
   }
 
   test("pull out nondeterministic expressions from RepartitionByExpression") {
-    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
+    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation, numPartitions = 10)
     val projected = Alias(Rand(33), "_nondeterministic")()
     val expected =
       Project(testRelation.output,
         RepartitionByExpression(Seq(projected.toAttribute),
-          Project(testRelation.output :+ projected, testRelation)))
+          Project(testRelation.output :+ projected, testRelation),
+          numPartitions = 10))
     checkAnalysis(plan, expected)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 2c14252..67d5d22 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -152,10 +152,7 @@ class PlanParserSuite extends PlanTest {
     val orderSortDistrClusterClauses = Seq(
       ("", basePlan),
       (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)),
-      (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)),
-      (" distribute by a, b", basePlan.distribute('a, 'b)()),
-      (" distribute by a sort by b", basePlan.distribute('a)().sortBy('b.asc)),
-      (" cluster by a, b", basePlan.distribute('a, 'b)().sortBy('a.asc, 'b.asc))
+      (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc))
     )
 
     orderSortDistrClusterClauses.foreach {

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 38a24cc..1ebc53d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2410,7 +2410,7 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
+    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
   }
 
   /**
@@ -2425,7 +2425,8 @@ class Dataset[T] private[sql](
    */
   @scala.annotation.varargs
   def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
+    RepartitionByExpression(
+      partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index d508002..1340aeb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -22,16 +22,17 @@ import scala.collection.JavaConverters._
 import org.antlr.v4.runtime.{ParserRuleContext, Token}
 import org.antlr.v4.runtime.tree.TerminalNode
 
-import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.parser._
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.{CreateTable, _}
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Concrete parser for Spark SQL statements.
@@ -1441,4 +1442,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
       reader, writer,
       schemaLess)
   }
+
+  /**
+   * Create a clause for DISTRIBUTE BY.
+   */
+  override protected def withRepartitionByExpression(
+      ctx: QueryOrganizationContext,
+      expressions: Seq[Expression],
+      query: LogicalPlan): LogicalPlan = {
+    RepartitionByExpression(expressions, query, conf.numShufflePartitions)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 557181e..0e3d559 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -332,8 +332,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
-    def numPartitions: Int = self.numPartitions
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
 
@@ -414,9 +412,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
       case r: logical.Range =>
         execution.RangeExec(r) :: Nil
-      case logical.RepartitionByExpression(expressions, child, nPartitions) =>
+      case logical.RepartitionByExpression(expressions, child, numPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
-          expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
+          expressions, numPartitions), planLater(child)) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/dc005ed5/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 15e490f..bb6c486 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.CreateTable
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
@@ -36,7 +38,8 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType
  */
 class SparkSqlParserSuite extends PlanTest {
 
-  private lazy val parser = new SparkSqlParser(new SQLConf)
+  val newConf = new SQLConf
+  private lazy val parser = new SparkSqlParser(newConf)
 
   /**
    * Normalizes plans:
@@ -251,4 +254,29 @@ class SparkSqlParserSuite extends PlanTest {
     assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value",
       AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value")))
   }
+
+  test("query organization") {
+    // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
+    val baseSql = "select * from t"
+    val basePlan =
+      Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("t")))
+
+    assertEqual(s"$baseSql distribute by a, b",
+      RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
+        basePlan,
+        numPartitions = newConf.numShufflePartitions))
+    assertEqual(s"$baseSql distribute by a sort by b",
+      Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
+        global = false,
+        RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
+          basePlan,
+          numPartitions = newConf.numShufflePartitions)))
+    assertEqual(s"$baseSql cluster by a, b",
+      Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
+          SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
+        global = false,
+        RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
+          basePlan,
+          numPartitions = newConf.numShufflePartitions)))
+  }
 }

