You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/09/08 03:32:24 UTC

[spark] branch branch-2.4 updated: [SPARK-31511][SQL][2.4] Make BytesToBytesMap iterators thread-safe

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 277ccba  [SPARK-31511][SQL][2.4] Make BytesToBytesMap iterators thread-safe
277ccba is described below

commit 277ccba382ac19debeaf15cd648cf6ab69603012
Author: sychen <sy...@ctrip.com>
AuthorDate: Tue Sep 8 03:23:59 2020 +0000

    [SPARK-31511][SQL][2.4] Make BytesToBytesMap iterators thread-safe
    
    ### What changes were proposed in this pull request?
    This PR increases the thread safety of the `BytesToBytesMap`:
    - It makes the `iterator()` and `destructiveIterator()` methods use their own `Location` object. This used to be shared, which caused issues when the map was being iterated over in two threads by two different iterators.
    - Removes the `safeIterator()` function. This is not needed anymore.
    - Improves the documentation of a couple of methods w.r.t. thread-safety.
    
    ### Why are the changes needed?
    It is unexpected for an iterator to share the object it returns with all other iterators. This is a violation of the iterator contract, and it causes issues with iterators over a map that are consumed in different threads.
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test to `HashedRelationSuite`.
    
    Closes #29605 from cxzl25/SPARK-31511.
    
    Lead-authored-by: sychen <sy...@ctrip.com>
    Co-authored-by: herman <he...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../apache/spark/unsafe/map/BytesToBytesMap.java   | 18 +++++-----
 .../sql/execution/joins/HashedRelationSuite.scala  | 39 ++++++++++++++++++++++
 2 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 2b23fbbf..5ab52cc 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -421,11 +421,11 @@ public final class BytesToBytesMap extends MemoryConsumer {
    *
    * For efficiency, all calls to `next()` will return the same {@link Location} object.
    *
-   * If any other lookups or operations are performed on this map while iterating over it, including
-   * `lookup()`, the behavior of the returned iterator is undefined.
+   * The returned iterator is thread-safe. However if the map is modified while iterating over it,
+   * the behavior of the returned iterator is undefined.
    */
   public MapIterator iterator() {
-    return new MapIterator(numValues, loc, false);
+    return new MapIterator(numValues, new Location(), false);
   }
 
   /**
@@ -435,19 +435,20 @@ public final class BytesToBytesMap extends MemoryConsumer {
    *
    * For efficiency, all calls to `next()` will return the same {@link Location} object.
    *
-   * If any other lookups or operations are performed on this map while iterating over it, including
-   * `lookup()`, the behavior of the returned iterator is undefined.
+   * The returned iterator is thread-safe. However if the map is modified while iterating over it,
+   * the behavior of the returned iterator is undefined.
    */
   public MapIterator destructiveIterator() {
     updatePeakMemoryUsed();
-    return new MapIterator(numValues, loc, true);
+    return new MapIterator(numValues, new Location(), true);
   }
 
   /**
    * Looks up a key, and return a {@link Location} handle that can be used to test existence
    * and read/write values.
    *
-   * This function always return the same {@link Location} instance to avoid object allocation.
+   * This function always returns the same {@link Location} instance to avoid object allocation.
+   * This function is not thread-safe.
    */
   public Location lookup(Object keyBase, long keyOffset, int keyLength) {
     safeLookup(keyBase, keyOffset, keyLength, loc,
@@ -459,7 +460,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
    * Looks up a key, and return a {@link Location} handle that can be used to test existence
    * and read/write values.
    *
-   * This function always return the same {@link Location} instance to avoid object allocation.
+   * This function always returns the same {@link Location} instance to avoid object allocation.
+   * This function is not thread-safe.
    */
   public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
     safeLookup(keyBase, keyOffset, keyLength, loc, hash);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index d9b34dc..1bdd6fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -341,6 +341,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray))
   }
 
+  test("SPARK-31511: Make BytesToBytesMap iterators thread-safe") {
+    val ser = sparkContext.env.serializer.newInstance()
+    val key = Seq(BoundReference(0, LongType, false))
+
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, LongType, false), BoundReference(1, IntegerType, true)))
+    val rows = (0 until 10000).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy())
+    val unsafeHashed = UnsafeHashedRelation(rows.iterator, key, 1, mm)
+
+    val os = new ByteArrayOutputStream()
+    val thread1 = new Thread {
+      override def run(): Unit = {
+        val out = new ObjectOutputStream(os)
+        unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
+        out.flush()
+      }
+    }
+
+    val thread2 = new Thread {
+      override def run(): Unit = {
+        val threadOut = new ObjectOutputStream(new ByteArrayOutputStream())
+        unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(threadOut)
+        threadOut.flush()
+      }
+    }
+
+    thread1.start()
+    thread2.start()
+    thread1.join()
+    thread2.join()
+
+    val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    unsafeHashed2.writeExternal(out2)
+    out2.flush()
+    assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray))
+  }
+
   // This test require 4G heap to run, should run it manually
   ignore("build HashedRelation that is larger than 1G") {
     val unsafeProj = UnsafeProjection.create(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org