You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/07/24 17:18:31 UTC

spark git commit: [SPARK-17528][SQL][FOLLOWUP] remove unnecessary data copy in object hash aggregate

Repository: spark
Updated Branches:
  refs/heads/master 481f07929 -> 86664338f


[SPARK-17528][SQL][FOLLOWUP] remove unnecessary data copy in object hash aggregate

## What changes were proposed in this pull request?

In #18483, we fixed the data copy bug when saving into `InternalRow` and removed all workarounds for this bug in the aggregate code path. However, the object hash aggregate was missed; this PR fixes it.

This patch is also a prerequisite for #17419, which shows that the DataFrame version is slower than the RDD version because of this issue.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <we...@databricks.com>

Closes #18712 from cloud-fan/minor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86664338
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86664338
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86664338

Branch: refs/heads/master
Commit: 86664338f25f58b2f59db93b68cd57de671a4c0b
Parents: 481f079
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Jul 24 10:18:28 2017 -0700
Committer: Cheng Lian <li...@databricks.com>
Committed: Mon Jul 24 10:18:28 2017 -0700

----------------------------------------------------------------------
 .../aggregate/ObjectAggregationIterator.scala   | 20 ++++----------------
 1 file changed, 4 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/86664338/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 6e47f9d..eef2c4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -70,10 +70,6 @@ class ObjectAggregationIterator(
     generateProcessRow(newExpressions, newFunctions, newInputAttributes)
   }
 
-  // A safe projection used to do deep clone of input rows to prevent false sharing.
-  private[this] val safeProjection: Projection =
-    FromUnsafeProjection(outputAttributes.map(_.dataType))
-
   /**
    * Start processing input rows.
    */
@@ -151,12 +147,11 @@ class ObjectAggregationIterator(
       val groupingKey = groupingProjection.apply(null)
       val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
       while (inputRows.hasNext) {
-        val newInput = safeProjection(inputRows.next())
-        processRow(buffer, newInput)
+        processRow(buffer, inputRows.next())
       }
     } else {
       while (inputRows.hasNext && !sortBased) {
-        val newInput = safeProjection(inputRows.next())
+        val newInput = inputRows.next()
         val groupingKey = groupingProjection.apply(newInput)
         val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
         processRow(buffer, newInput)
@@ -266,9 +261,7 @@ class SortBasedAggregator(
           // Firstly, update the aggregation buffer with input rows.
           while (hasNextInput &&
             groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
-            // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
-            // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
-            processRow(result.aggregationBuffer, inputIterator.getValue.copy())
+            processRow(result.aggregationBuffer, inputIterator.getValue)
             hasNextInput = inputIterator.next()
           }
 
@@ -277,12 +270,7 @@ class SortBasedAggregator(
           // be called after calling processRow.
           while (hasNextAggBuffer &&
             groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
-            mergeAggregationBuffers(
-              result.aggregationBuffer,
-              // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
-              // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
-              initialAggBufferIterator.getValue.copy()
-            )
+            mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
             hasNextAggBuffer = initialAggBufferIterator.next()
           }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org