Posted to commits@spark.apache.org by jo...@apache.org on 2015/10/12 03:11:18 UTC

spark git commit: [SPARK-11053] Remove use of KVIterator in SortBasedAggregationIterator

Repository: spark
Updated Branches:
  refs/heads/master a16396df7 -> 595012ea8


[SPARK-11053] Remove use of KVIterator in SortBasedAggregationIterator

SortBasedAggregationIterator uses a KVIterator interface in order to process input rows as key-value pairs, but this use of KVIterator is unnecessary, slightly complicates the code, and might hurt performance. This patch refactors this code to remove the use of this extra layer of iterator wrapping and simplifies other parts of the code in the process.

Author: Josh Rosen <jo...@databricks.com>

Closes #9066 from JoshRosen/sort-iterator-cleanup.
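
The shape of the change, as a minimal self-contained Scala sketch (plain Seq[Any] rows and a simple per-key count stand in for Spark's InternalRow and aggregation buffers, which are not part of this sketch): the old path wrapped the input iterator in a KVIterator that projected the grouping key behind next()/getKey()/getValue(), while the new path iterates the rows directly and applies the same projection at the point where each row is consumed.

    import scala.collection.mutable

    object KVIteratorSketch {
      // Stand-ins for Spark's internal row types, so the sketch is self-contained.
      type Row = Seq[Any]
      type Key = Seq[Any]

      // The interface being removed: next() advances, getKey/getValue expose the pair.
      trait KVIterator[K, V] {
        def next(): Boolean
        def getKey(): K
        def getValue(): V
        def close(): Unit
      }

      // Before: wrap the row iterator so every row is re-exposed as a (key, row) pair.
      def kvIterator(rows: Iterator[Row], project: Row => Key): KVIterator[Key, Row] =
        new KVIterator[Key, Row] {
          private var key: Key = _
          private var row: Row = _
          override def next(): Boolean =
            if (rows.hasNext) { row = rows.next(); key = project(row); true } else false
          override def getKey(): Key = key
          override def getValue(): Row = row
          override def close(): Unit = ()
        }

      // Consuming through the wrapper: one extra object plus getKey/getValue calls per row.
      def countViaKV(rows: Iterator[Row], project: Row => Key): Map[Key, Long] = {
        val counts = mutable.Map.empty[Key, Long].withDefaultValue(0L)
        val kv = kvIterator(rows, project)
        while (kv.next()) counts(kv.getKey()) += 1
        counts.toMap
      }

      // After: iterate the rows directly and apply the same projection inline,
      // which is what the refactored SortBasedAggregationIterator now does.
      def countDirect(rows: Iterator[Row], project: Row => Key): Map[Key, Long] = {
        val counts = mutable.Map.empty[Key, Long].withDefaultValue(0L)
        while (rows.hasNext) {
          counts(project(rows.next())) += 1
        }
        counts.toMap
      }

      def main(args: Array[String]): Unit = {
        val project: Row => Key = row => Seq(row.head)
        println(countViaKV(Iterator(Seq("a", 1), Seq("a", 2), Seq("b", 3)), project))
        println(countDirect(Iterator(Seq("a", 1), Seq("a", 2), Seq("b", 3)), project))
      }
    }

Both paths produce the same per-key counts (List(a) -> 2, List(b) -> 1 in the example); the direct version just drops the wrapper layer between the input iterator and the consumer.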


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/595012ea
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/595012ea
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/595012ea

Branch: refs/heads/master
Commit: 595012ea8b9c6afcc2fc024d5a5e198df765bd75
Parents: a16396d
Author: Josh Rosen <jo...@databricks.com>
Authored: Sun Oct 11 18:11:08 2015 -0700
Committer: Josh Rosen <jo...@databricks.com>
Committed: Sun Oct 11 18:11:08 2015 -0700

----------------------------------------------------------------------
 .../aggregate/AggregationIterator.scala         | 83 ------------------
 .../aggregate/SortBasedAggregate.scala          | 20 +++--
 .../SortBasedAggregationIterator.scala          | 89 +++++---------------
 3 files changed, 33 insertions(+), 159 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/595012ea/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 5f7341e..8e0fbd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.unsafe.KVIterator
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -412,85 +411,3 @@ abstract class AggregationIterator(
    */
   protected def newBuffer: MutableRow
 }
-
-object AggregationIterator {
-  def kvIterator(
-    groupingExpressions: Seq[NamedExpression],
-    newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
-    inputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = {
-    new KVIterator[InternalRow, InternalRow] {
-      private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes)
-
-      private[this] var groupingKey: InternalRow = _
-
-      private[this] var value: InternalRow = _
-
-      override def next(): Boolean = {
-        if (inputIter.hasNext) {
-          // Read the next input row.
-          val inputRow = inputIter.next()
-          // Get groupingKey based on groupingExpressions.
-          groupingKey = groupingKeyGenerator(inputRow)
-          // The value is the inputRow.
-          value = inputRow
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): InternalRow = {
-        groupingKey
-      }
-
-      override def getValue(): InternalRow = {
-        value
-      }
-
-      override def close(): Unit = {
-        // Do nothing
-      }
-    }
-  }
-
-  def unsafeKVIterator(
-      groupingExpressions: Seq[NamedExpression],
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = {
-    new KVIterator[UnsafeRow, InternalRow] {
-      private[this] val groupingKeyGenerator =
-        UnsafeProjection.create(groupingExpressions, inputAttributes)
-
-      private[this] var groupingKey: UnsafeRow = _
-
-      private[this] var value: InternalRow = _
-
-      override def next(): Boolean = {
-        if (inputIter.hasNext) {
-          // Read the next input row.
-          val inputRow = inputIter.next()
-          // Get groupingKey based on groupingExpressions.
-          groupingKey = groupingKeyGenerator.apply(inputRow)
-          // The value is the inputRow.
-          value = inputRow
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): UnsafeRow = {
-        groupingKey
-      }
-
-      override def getValue(): InternalRow = {
-        value
-      }
-
-      override def close(): Unit = {
-        // Do nothing
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/595012ea/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index f4c14a9..4d37106 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
 
 case class SortBasedAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -79,18 +78,23 @@ case class SortBasedAggregate(
         // so return an empty iterator.
         Iterator[InternalRow]()
       } else {
-        val outputIter = SortBasedAggregationIterator.createFromInputIterator(
-          groupingExpressions,
+        val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) {
+          UnsafeProjection.create(groupingExpressions, child.output)
+        } else {
+          newMutableProjection(groupingExpressions, child.output)()
+        }
+        val outputIter = new SortBasedAggregationIterator(
+          groupingKeyProjection,
+          groupingExpressions.map(_.toAttribute),
+          child.output,
+          iter,
           nonCompleteAggregateExpressions,
           nonCompleteAggregateAttributes,
           completeAggregateExpressions,
           completeAggregateAttributes,
           initialInputBufferOffset,
           resultExpressions,
-          newMutableProjection _,
-          newProjection _,
-          child.output,
-          iter,
+          newMutableProjection,
           outputsUnsafeRows,
           numInputRows,
           numOutputRows)

http://git-wip-us.apache.org/repos/asf/spark/blob/595012ea/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index a9e5d17..64c6730 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -21,16 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.unsafe.KVIterator
 
 /**
  * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
  * sorted by values of [[groupingKeyAttributes]].
  */
 class SortBasedAggregationIterator(
+    groupingKeyProjection: InternalRow => InternalRow,
     groupingKeyAttributes: Seq[Attribute],
     valueAttributes: Seq[Attribute],
-    inputKVIterator: KVIterator[InternalRow, InternalRow],
+    inputIterator: Iterator[InternalRow],
     nonCompleteAggregateExpressions: Seq[AggregateExpression2],
     nonCompleteAggregateAttributes: Seq[Attribute],
     completeAggregateExpressions: Seq[AggregateExpression2],
@@ -90,6 +90,22 @@ class SortBasedAggregationIterator(
   // The aggregation buffer used by the sort-based aggregation.
   private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
 
+  protected def initialize(): Unit = {
+    if (inputIterator.hasNext) {
+      initializeBuffer(sortBasedAggregationBuffer)
+      val inputRow = inputIterator.next()
+      nextGroupingKey = groupingKeyProjection(inputRow).copy()
+      firstRowInNextGroup = inputRow.copy()
+      numInputRows += 1
+      sortedInputHasNewGroup = true
+    } else {
+      // This inputIter is empty.
+      sortedInputHasNewGroup = false
+    }
+  }
+
+  initialize()
+
   /** Processes rows in the current group. It will stop when it finds a new group. */
   protected def processCurrentSortedGroup(): Unit = {
     currentGroupingKey = nextGroupingKey
@@ -101,18 +117,15 @@ class SortBasedAggregationIterator(
 
     // The search will stop when we see the next group or there is no
     // input row left in the iter.
-    var hasNext = inputKVIterator.next()
-    while (!findNextPartition && hasNext) {
+    while (!findNextPartition && inputIterator.hasNext) {
       // Get the grouping key.
-      val groupingKey = inputKVIterator.getKey
-      val currentRow = inputKVIterator.getValue
+      val currentRow = inputIterator.next()
+      val groupingKey = groupingKeyProjection(currentRow)
       numInputRows += 1
 
       // Check if the current row belongs to the current group.
       if (currentGroupingKey == groupingKey) {
         processRow(sortBasedAggregationBuffer, currentRow)
-
-        hasNext = inputKVIterator.next()
       } else {
         // We find a new group.
         findNextPartition = true
@@ -149,68 +162,8 @@ class SortBasedAggregationIterator(
     }
   }
 
-  protected def initialize(): Unit = {
-    if (inputKVIterator.next()) {
-      initializeBuffer(sortBasedAggregationBuffer)
-
-      nextGroupingKey = inputKVIterator.getKey().copy()
-      firstRowInNextGroup = inputKVIterator.getValue().copy()
-      numInputRows += 1
-      sortedInputHasNewGroup = true
-    } else {
-      // This inputIter is empty.
-      sortedInputHasNewGroup = false
-    }
-  }
-
-  initialize()
-
   def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
     initializeBuffer(sortBasedAggregationBuffer)
     generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
   }
 }
-
-object SortBasedAggregationIterator {
-  // scalastyle:off
-  def createFromInputIterator(
-      groupingExprs: Seq[NamedExpression],
-      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-      nonCompleteAggregateAttributes: Seq[Attribute],
-      completeAggregateExpressions: Seq[AggregateExpression2],
-      completeAggregateAttributes: Seq[Attribute],
-      initialInputBufferOffset: Int,
-      resultExpressions: Seq[NamedExpression],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-      newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow],
-      outputsUnsafeRows: Boolean,
-      numInputRows: LongSQLMetric,
-      numOutputRows: LongSQLMetric): SortBasedAggregationIterator = {
-    val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
-      AggregationIterator.unsafeKVIterator(
-        groupingExprs,
-        inputAttributes,
-        inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]]
-    } else {
-      AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
-    }
-
-    new SortBasedAggregationIterator(
-      groupingExprs.map(_.toAttribute),
-      inputAttributes,
-      kvIterator,
-      nonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions,
-      newMutableProjection,
-      outputsUnsafeRows,
-      numInputRows,
-      numOutputRows)
-  }
-  // scalastyle:on
-}

