You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2016/11/29 01:01:08 UTC

spark git commit: [SPARK-18403][SQL] Fix unsafe data false sharing issue in ObjectHashAggregateExec

Repository: spark
Updated Branches:
  refs/heads/master 05f7c6ffa -> 2e809903d


[SPARK-18403][SQL] Fix unsafe data false sharing issue in ObjectHashAggregateExec

## What changes were proposed in this pull request?

This PR fixes a random OOM issue that occurred while running `ObjectHashAggregateSuite`.

This issue can be reliably reproduced under the following conditions:

1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
2. There must be an input column whose data type involves `ArrayType` (an input column of `MapType` may even cause SIGSEGV);
3. Sort-based aggregation fallback must be triggered during evaluation.

The root cause is that, when falling back to sort-based aggregation, we must sort the already evaluated partial aggregation buffers living in the hash map and feed them to the sort-based aggregator using an external sorter. However, the underlying mutable byte buffer of the `UnsafeRow`s produced by the external sorter's iterator is reused and may be overwritten when the iterator steps forward. After the last entry is consumed, the byte buffer points to a block of uninitialized memory filled with `5a` bytes. Therefore, when an `UnsafeArrayData` is read out of the `UnsafeRow`, `5a5a5a5a` is interpreted as the array size, which triggers a memory allocation for a ridiculously large array and immediately blows up the JVM with an OOM.

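To fix this issue, we only need to call `.copy()` on the `UnsafeRow`s returned by the sorter iterators before passing them to `processRow` and `mergeAggregationBuffers`.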

## How was this patch tested?

A new regression test case was added to `ObjectHashAggregateSuite`.

Author: Cheng Lian <li...@databricks.com>

Closes #15976 from liancheng/investigate-oom.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2e809903
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2e809903
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2e809903

Branch: refs/heads/master
Commit: 2e809903d459b5b5aa6fd882b5c4a0c915af4d43
Parents: 05f7c6f
Author: Cheng Lian <li...@databricks.com>
Authored: Tue Nov 29 09:01:03 2016 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Nov 29 09:01:03 2016 +0800

----------------------------------------------------------------------
 .../aggregate/ObjectAggregationIterator.scala   |  11 +-
 .../execution/ObjectHashAggregateSuite.scala    | 164 +++++++++++--------
 2 files changed, 101 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2e809903/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 3c7b9ee..3a7fcf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -262,7 +262,9 @@ class SortBasedAggregator(
           // Firstly, update the aggregation buffer with input rows.
           while (hasNextInput &&
             groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
-            processRow(result.aggregationBuffer, inputIterator.getValue)
+            // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
+            // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
+            processRow(result.aggregationBuffer, inputIterator.getValue.copy())
             hasNextInput = inputIterator.next()
           }
 
@@ -271,7 +273,12 @@ class SortBasedAggregator(
           // be called after calling processRow.
           while (hasNextAggBuffer &&
             groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
-            mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
+            mergeAggregationBuffers(
+              result.aggregationBuffer,
+              // Since `initialAggBufferIterator.getValue` is an `UnsafeRow` whose underlying buffer
+              // is overwritten when that iterator steps forward, we need to do a deep copy here.
+              initialAggBufferIterator.getValue.copy()
+            )
             hasNextAggBuffer = initialAggBufferIterator.next()
           }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2e809903/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index b7f91d8..9a8d449 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -205,23 +205,19 @@ class ObjectHashAggregateSuite
     // A TypedImperativeAggregate function
     val typed = percentile_approx($"c0", 0.5)
 
-    // A Hive UDAF without partial aggregation support
-    val withoutPartial = function("hive_max", $"c1")
-
     // A Spark SQL native aggregate function with partial aggregation support that can be executed
     // by the Tungsten `HashAggregateExec`
-    val withPartialUnsafe = max($"c2")
+    val withPartialUnsafe = max($"c1")
 
     // A Spark SQL native aggregate function with partial aggregation support that cannot be
     // executed by the Tungsten `HashAggregateExec`
-    val withPartialSafe = max($"c3")
+    val withPartialSafe = max($"c2")
 
     // A Spark SQL native distinct aggregate function
-    val withDistinct = countDistinct($"c4")
+    val withDistinct = countDistinct($"c3")
 
     val allAggs = Seq(
       "typed" -> typed,
-      "without partial" -> withoutPartial,
       "with partial + unsafe" -> withPartialUnsafe,
       "with partial + safe" -> withPartialSafe,
       "with distinct" -> withDistinct
@@ -276,10 +272,9 @@ class ObjectHashAggregateSuite
     // Generates a random schema for the randomized data generator
     val schema = new StructType()
       .add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true)
-      .add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true)
-      .add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
-      .add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
-      .add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true)
+      .add("c1", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true)
+      .add("c2", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
+      .add("c3", allTypes(random.nextInt(allTypes.length)), nullable = true)
 
     logInfo(
       s"""Using the following random schema to generate all the randomized aggregation tests:
@@ -325,70 +320,67 @@ class ObjectHashAggregateSuite
 
             // Currently Spark SQL doesn't support evaluating distinct aggregate function together
             // with aggregate functions without partial aggregation support.
-            if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
-              // TODO Re-enables them after fixing SPARK-18403
-              ignore(
-                s"randomized aggregation test - " +
-                  s"${names.mkString("[", ", ", "]")} - " +
-                  s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
-                  s"with ${if (emptyInput) "empty" else "non-empty"} input"
-              ) {
-                var expected: Seq[Row] = null
-                var actual1: Seq[Row] = null
-                var actual2: Seq[Row] = null
-
-                // Disables `ObjectHashAggregateExec` to obtain a standard answer
-                withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
-                  val aggDf = doAggregation(df)
-
-                  if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
-                    assert(containsSortAggregateExec(aggDf))
-                    assert(!containsObjectHashAggregateExec(aggDf))
-                    assert(!containsHashAggregateExec(aggDf))
-                  } else {
-                    assert(!containsSortAggregateExec(aggDf))
-                    assert(!containsObjectHashAggregateExec(aggDf))
-                    assert(containsHashAggregateExec(aggDf))
-                  }
-
-                  expected = aggDf.collect().toSeq
+            test(
+              s"randomized aggregation test - " +
+                s"${names.mkString("[", ", ", "]")} - " +
+                s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
+                s"with ${if (emptyInput) "empty" else "non-empty"} input"
+            ) {
+              var expected: Seq[Row] = null
+              var actual1: Seq[Row] = null
+              var actual2: Seq[Row] = null
+
+              // Disables `ObjectHashAggregateExec` to obtain a standard answer
+              withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+                val aggDf = doAggregation(df)
+
+                if (aggs.intersect(Seq(withPartialSafe, typed)).nonEmpty) {
+                  assert(containsSortAggregateExec(aggDf))
+                  assert(!containsObjectHashAggregateExec(aggDf))
+                  assert(!containsHashAggregateExec(aggDf))
+                } else {
+                  assert(!containsSortAggregateExec(aggDf))
+                  assert(!containsObjectHashAggregateExec(aggDf))
+                  assert(containsHashAggregateExec(aggDf))
                 }
 
-                // Enables `ObjectHashAggregateExec`
-                withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
-                  val aggDf = doAggregation(df)
-
-                  if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
-                    assert(!containsSortAggregateExec(aggDf))
-                    assert(containsObjectHashAggregateExec(aggDf))
-                    assert(!containsHashAggregateExec(aggDf))
-                  } else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
-                    assert(containsSortAggregateExec(aggDf))
-                    assert(!containsObjectHashAggregateExec(aggDf))
-                    assert(!containsHashAggregateExec(aggDf))
-                  } else {
-                    assert(!containsSortAggregateExec(aggDf))
-                    assert(!containsObjectHashAggregateExec(aggDf))
-                    assert(containsHashAggregateExec(aggDf))
-                  }
-
-                  // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
-                  // big enough) to obtain a result to be checked.
-                  withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
-                    actual1 = aggDf.collect().toSeq
-                  }
-
-                  // Enables sort-based aggregation fallback to obtain another result to be checked.
-                  withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
-                    // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
-                    // cached and won't be re-planned using the new fallback threshold.
-                    actual2 = doAggregation(df).collect().toSeq
-                  }
+                expected = aggDf.collect().toSeq
+              }
+
+              // Enables `ObjectHashAggregateExec`
+              withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+                val aggDf = doAggregation(df)
+
+                if (aggs.contains(typed)) {
+                  assert(!containsSortAggregateExec(aggDf))
+                  assert(containsObjectHashAggregateExec(aggDf))
+                  assert(!containsHashAggregateExec(aggDf))
+                } else if (aggs.contains(withPartialSafe)) {
+                  assert(containsSortAggregateExec(aggDf))
+                  assert(!containsObjectHashAggregateExec(aggDf))
+                  assert(!containsHashAggregateExec(aggDf))
+                } else {
+                  assert(!containsSortAggregateExec(aggDf))
+                  assert(!containsObjectHashAggregateExec(aggDf))
+                  assert(containsHashAggregateExec(aggDf))
                 }
 
-                doubleSafeCheckRows(actual1, expected, 1e-4)
-                doubleSafeCheckRows(actual2, expected, 1e-4)
+                // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
+                // big enough) to obtain a result to be checked.
+                withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+                  actual1 = aggDf.collect().toSeq
+                }
+
+                // Enables sort-based aggregation fallback to obtain another result to be checked.
+                withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
+                  // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
+                  // cached and won't be re-planned using the new fallback threshold.
+                  actual2 = doAggregation(df).collect().toSeq
+                }
               }
+
+              doubleSafeCheckRows(actual1, expected, 1e-4)
+              doubleSafeCheckRows(actual2, expected, 1e-4)
             }
           }
         }
@@ -425,7 +417,35 @@ class ObjectHashAggregateSuite
     }
   }
 
-  private def function(name: String, args: Column*): Column = {
-    Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
+  test("SPARK-18403 Fix unsafe data false sharing issue in ObjectHashAggregateExec") {
+    // SPARK-18403: An unsafe data false sharing issue may trigger OOM / SIGSEGV when evaluating
+    // certain aggregate functions. To reproduce this issue, the following conditions must be
+    // met:
+    //
+    //  1. The aggregation must be evaluated using `ObjectHashAggregateExec`;
+    //  2. There must be an input column whose data type involves `ArrayType` or `MapType`;
+    //  3. Sort-based aggregation fallback must be triggered during evaluation.
+    withSQLConf(
+      SQLConf.USE_OBJECT_HASH_AGG.key -> "true",
+      SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1"
+    ) {
+      checkAnswer(
+        Seq
+          .fill(2)(Tuple1(Array.empty[Int]))
+          .toDF("c0")
+          .groupBy(lit(1))
+          .agg(typed_count($"c0"), max($"c0")),
+        Row(1, 2, Array.empty[Int])
+      )
+
+      checkAnswer(
+        Seq
+          .fill(2)(Tuple1(Map.empty[Int, Int]))
+          .toDF("c0")
+          .groupBy(lit(1))
+          .agg(typed_count($"c0"), first($"c0")),
+        Row(1, 2, Map.empty[Int, Int])
+      )
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org