Posted to commits@spark.apache.org by yh...@apache.org on 2016/02/29 21:59:54 UTC

spark git commit: [SPARK-13123][SQL] Implement whole stage codegen for sort

Repository: spark
Updated Branches:
  refs/heads/master 644dbb641 -> 4bd697da0


[SPARK-13123][SQL] Implement whole stage codegen for sort

## What changes were proposed in this pull request?
This PR implements whole stage codegen for sort. It builds heavily on nongli's PR: https://github.com/apache/spark/pull/11008 (which implements the feature itself) and adds the following changes on top (a quick verification sketch follows the list):

- [x]  Generated code updates peak execution memory metrics
- [x]  Unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite`
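A quick way to see the effect is to look at the physical plan of a sorted query. A minimal sketch, assuming a spark-shell session built from a branch that contains this patch, with whole-stage codegen enabled (`spark.sql.codegen.wholeStage`) and the shell's predefined `sqlContext`:

```scala
// Sketch only: run in spark-shell, where `sqlContext` is already defined.
import org.apache.spark.sql.functions.col

val df = sqlContext.range(3, 0, -1).sort(col("id"))

// The physical plan should now show Sort nested under a WholeStageCodegen node.
df.explain()

// The query itself still returns the rows in sorted order: 1, 2, 3.
df.show()
```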

## How was this patch tested?

New unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite`. Further, all existing sort tests should pass.
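For an ad-hoc check outside the suites, something along these lines mirrors the new `WholeStageCodegenSuite` test (again a sketch, assuming a spark-shell on a build that includes this commit):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.{Sort, WholeStageCodegen}
import org.apache.spark.sql.functions.col

val df = sqlContext.range(3, 0, -1).sort(col("id"))
val plan = df.queryExecution.executedPlan

// Sort should be planned inside a WholeStageCodegen node.
assert(plan.find {
  case w: WholeStageCodegen => w.plan.isInstanceOf[Sort]
  case _ => false
}.isDefined)

// And the results should still come back correctly sorted.
assert(df.collect().toSeq == Seq(Row(1L), Row(2L), Row(3L)))
```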

Author: Sameer Agarwal <sa...@databricks.com>
Author: Nong Li <no...@databricks.com>

Closes #11359 from sameeragarwal/sort-codegen.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4bd697da
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4bd697da
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4bd697da

Branch: refs/heads/master
Commit: 4bd697da03079c26fd4409dc128dbff28c737701
Parents: 644dbb6
Author: Sameer Agarwal <sa...@databricks.com>
Authored: Mon Feb 29 12:59:46 2016 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Mon Feb 29 12:59:46 2016 -0800

----------------------------------------------------------------------
 .../sql/execution/UnsafeExternalRowSorter.java  |   9 +-
 .../org/apache/spark/sql/execution/Sort.scala   | 124 +++++++++++++++----
 .../spark/sql/execution/WholeStageCodegen.scala |   8 +-
 .../sql/execution/WholeStageCodegenSuite.scala  |   9 ++
 .../sql/execution/metric/SQLMetricsSuite.scala  |   7 ++
 5 files changed, 122 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4bd697da/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 27ae62f..0ad0f49 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -36,7 +36,7 @@ import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
 
-final class UnsafeExternalRowSorter {
+public final class UnsafeExternalRowSorter {
 
   /**
    * If positive, forces records to be spilled to disk at the given frequency (measured in numbers
@@ -84,8 +84,7 @@ final class UnsafeExternalRowSorter {
     testSpillFrequency = frequency;
   }
 
-  @VisibleForTesting
-  void insertRow(UnsafeRow row) throws IOException {
+  public void insertRow(UnsafeRow row) throws IOException {
     final long prefix = prefixComputer.computePrefix(row);
     sorter.insertRecord(
       row.getBaseObject(),
@@ -110,8 +109,7 @@ final class UnsafeExternalRowSorter {
     sorter.cleanupResources();
   }
 
-  @VisibleForTesting
-  Iterator<UnsafeRow> sort() throws IOException {
+  public Iterator<UnsafeRow> sort() throws IOException {
     try {
       final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
       if (!sortedIterator.hasNext()) {
@@ -160,7 +158,6 @@ final class UnsafeExternalRowSorter {
     }
   }
 
-
   public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
     while (inputIterator.hasNext()) {
       insertRow(inputIterator.next());

http://git-wip-us.apache.org/repos/asf/spark/blob/4bd697da/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 75cb6d1..2ea889e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -37,7 +39,7 @@ case class Sort(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
 
@@ -50,34 +52,36 @@ case class Sort(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
-  protected override def doExecute(): RDD[InternalRow] = {
-    val schema = child.schema
-    val childOutput = child.output
+  def createSorter(): UnsafeExternalRowSorter = {
+    val ordering = newOrdering(sortOrder, output)
+
+    // The comparator for comparing prefix
+    val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+    val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+    // The generator for prefix
+    val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+    val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+      override def computePrefix(row: InternalRow): Long = {
+        prefixProjection.apply(row).getLong(0)
+      }
+    }
 
+    val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+    val sorter = new UnsafeExternalRowSorter(
+      schema, ordering, prefixComparator, prefixComputer, pageSize)
+    if (testSpillFrequency > 0) {
+      sorter.setTestSpillFrequency(testSpillFrequency)
+    }
+    sorter
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
     val dataSize = longMetric("dataSize")
     val spillSize = longMetric("spillSize")
 
     child.execute().mapPartitionsInternal { iter =>
-      val ordering = newOrdering(sortOrder, childOutput)
-
-      // The comparator for comparing prefix
-      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
-      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
-      // The generator for prefix
-      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
-      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = {
-          prefixProjection.apply(row).getLong(0)
-        }
-      }
-
-      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-      val sorter = new UnsafeExternalRowSorter(
-        schema, ordering, prefixComparator, prefixComputer, pageSize)
-      if (testSpillFrequency > 0) {
-        sorter.setTestSpillFrequency(testSpillFrequency)
-      }
+      val sorter = createSorter()
 
       val metrics = TaskContext.get().taskMetrics()
       // Remember spill data size of this task before execute this operator so that we can
@@ -93,4 +97,74 @@ case class Sort(
       sortedIterator
     }
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  // Name of sorter variable used in codegen.
+  private var sorterVariable: String = _
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    val needToSort = ctx.freshName("needToSort")
+    ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+
+    // Initialize the class member variables. This includes the instance of the Sorter and
+    // the iterator to return sorted rows.
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    sorterVariable = ctx.freshName("sorter")
+    ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+      s"$sorterVariable = $thisPlan.createSorter();")
+    val metrics = ctx.freshName("metrics")
+    ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
+      s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
+    val sortedIterator = ctx.freshName("sortedIter")
+    ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")
+
+    val addToSorter = ctx.freshName("addToSorter")
+    ctx.addNewFunction(addToSorter,
+      s"""
+        | private void $addToSorter() throws java.io.IOException {
+        |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+        | }
+      """.stripMargin.trim)
+
+    val outputRow = ctx.freshName("outputRow")
+    val dataSize = metricTerm(ctx, "dataSize")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val spillSizeBefore = ctx.freshName("spillSizeBefore")
+    s"""
+       | if ($needToSort) {
+       |   $addToSorter();
+       |   Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+       |   $sortedIterator = $sorterVariable.sort();
+       |   $dataSize.add($sorterVariable.getPeakMemoryUsage());
+       |   $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
+       |   $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
+       |   $needToSort = false;
+       | }
+       |
+       | while ($sortedIterator.hasNext()) {
+       |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
+       |   ${consume(ctx, null, outputRow)}
+       |   if (shouldStop()) return;
+       | }
+     """.stripMargin.trim
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+      BoundReference(i, attr.dataType, attr.nullable)
+    }
+
+    ctx.currentVars = input
+    val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+
+    s"""
+       | // Convert the input attributes to an UnsafeRow and add it to the sorter
+       | ${code.code}
+       | $sorterVariable.insertRow(${code.value});
+     """.stripMargin.trim
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4bd697da/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index afaddcf..cb68ca6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
           ${code.trim}
         }
       }
-      """
+      """.trim
 
     // try to compile, helpful for debug
     val cleanedSource = CodeFormatter.stripExtraNewLines(source)
@@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       // There is an UnsafeRow already
       s"""
          |append($row.copy());
-       """.stripMargin
+       """.stripMargin.trim
     } else {
       assert(input != null)
       if (input.nonEmpty) {
@@ -351,12 +351,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
         s"""
            |${code.code.trim}
            |append(${code.value}.copy());
-         """.stripMargin
+         """.stripMargin.trim
       } else {
         // There is no columns
         s"""
            |append(unsafeRow);
-         """.stripMargin
+         """.stripMargin.trim
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/4bd697da/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 9350205..de371d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
+
+  test("Sort should be included in WholeStageCodegen") {
+    val df = sqlContext.range(3, 0, -1).sort(col("id"))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4bd697da/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index c49f243..5b4f6f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     )
   }
 
+  test("Sort metrics") {
+    // Assume the execution plan is
+    // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
+    val df = sqlContext.range(10).sort('id)
+    testSparkPlanMetrics(df, 2, Map.empty)
+  }
+
   test("SortMergeJoin metrics") {
     // Because SortMergeJoin may skip different rows if the number of partitions is different, this
     // test should use the deterministic number of partitions.

