Posted to commits@uniffle.apache.org by ro...@apache.org on 2022/08/22 09:04:33 UTC

[incubator-uniffle] branch master updated: [BUGFIX] Fix resource leak when shuffle read (#174)

This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new df4163fd [BUGFIX] Fix resource leak when shuffle read (#174)
df4163fd is described below

commit df4163fd4e0970d420678bcb77ebe5bc4f79a9f6
Author: Chen Zhang <67...@users.noreply.github.com>
AuthorDate: Mon Aug 22 17:04:28 2022 +0800

    [BUGFIX] Fix resource leak when shuffle read (#174)
    
    ### What changes were proposed in this pull request?
    Use `org.apache.spark.TaskContext#addTaskCompletionListener` to clean up the resources used by `RssShuffleDataIterator`, so that cleanup runs even when the task ends before the iterator is fully consumed. This avoids possible resource leaks.
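    
    A minimal sketch of the pattern (it mirrors the spark2 `RssShuffleReader#read` change below; `context` is the Spark `TaskContext` and `rssShuffleDataIterator` is the iterator being wrapped):
    
    ```java
    // Wrap the data iterator so cleanup() runs once iteration finishes normally.
    CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
        CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, () -> rssShuffleDataIterator.cleanup());
    // Also tie cleanup to the task lifecycle, so it runs on task failure or
    // cancellation even when the iterator is never exhausted. cleanup() nulls
    // out the read client after closing it, so a second invocation is harmless.
    context.addTaskCompletionListener(taskContext -> completionIterator.completion());
    ```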
    
    ### Why are the changes needed?
    Before this PR, `RssShuffleDataIterator` would only clean up its resources after all records had been read.
    
    When the Spark task fails or is cancelled, or runs special logic such as `LocalLimit` that stops consuming the iterator early, those resources are never released. This creates a potential resource leak.
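    
    For illustration, a consumer that stops early, as `LocalLimit` does, never drains the iterator (`process` and `limit` here are hypothetical stand-ins for the operator's logic):
    
    ```java
    Iterator<Product2<K, C>> records = rssShuffleDataIterator;
    int taken = 0;
    while (taken < limit && records.hasNext()) {
      process(records.next());
      taken++;
    }
    // hasNext() never returned false, so the old end-of-read close() path
    // inside hasNext() was never reached and the read client stayed open.
    ```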
    
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test case (`RssShuffleDataIteratorTest#cleanup`)
    
    Co-authored-by: zhangchen351 <zh...@jd.com>
---
 .../spark/shuffle/reader/RssShuffleDataIterator.java   | 14 +++++++++++---
 .../shuffle/reader/RssShuffleDataIteratorTest.java     | 15 +++++++++++++++
 .../apache/spark/shuffle/reader/RssShuffleReader.java  | 18 ++++++++++++++----
 .../apache/spark/shuffle/reader/RssShuffleReader.java  | 17 +++++++++++++----
 4 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index bd8184c5..23e03641 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -34,6 +34,7 @@ import scala.Product2;
 import scala.Tuple2;
 import scala.collection.AbstractIterator;
 import scala.collection.Iterator;
+import scala.runtime.BoxedUnit;
 
 import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.response.CompressedShuffleBlock;
@@ -130,9 +131,7 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
         readTime += fetchDuration;
         serializeTime += serializationDuration;
       } else {
-        // finish reading records, close related reader and check data consistent
-        clearDeserializationStream();
-        shuffleReadClient.close();
+        // finish reading records, check data consistent
         shuffleReadClient.checkProcessedBlockIds();
         shuffleReadClient.logStatics();
         LOG.info("Fetch " + shuffleReadMetrics.remoteBytesRead() + " bytes cost " + readTime + " ms and "
@@ -150,6 +149,15 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
     return (Product2<K, C>) recordsIterator.next();
   }
 
+  public BoxedUnit cleanup() {
+    clearDeserializationStream();
+    if (shuffleReadClient != null) {
+      shuffleReadClient.close();
+    }
+    shuffleReadClient = null;
+    return BoxedUnit.UNIT;
+  }
+
   @VisibleForTesting
   protected ShuffleReadMetrics getShuffleReadMetrics() {
     return shuffleReadMetrics;
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
index 50c05033..12f5a318 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
@@ -24,6 +24,7 @@ import static org.mockito.ArgumentMatchers.any;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.client.util.DefaultIdHelper;
@@ -46,6 +47,11 @@ import org.mockito.MockedStatic;
 import org.mockito.Mockito;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
 public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
 
   private static final Serializer KRYO_SERIALIZER = new KryoSerializer(new SparkConf(false));
@@ -235,4 +241,13 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
     }
   }
 
+  @Test
+  public void cleanup() throws Exception {
+    ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
+    doNothing().when(mockClient).close();
+    RssShuffleDataIterator dataIterator = new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics());
+    dataIterator.cleanup();
+    verify(mockClient, times(1)).close();
+  }
+
 }
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 3130521c..ef97bea3 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -26,6 +26,7 @@ import org.apache.spark.TaskContext;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.util.CompletionIterator;
 import org.apache.spark.util.CompletionIterator$;
 import org.apache.spark.util.TaskCompletionListener;
 import org.apache.spark.util.collection.ExternalSorter;
@@ -113,6 +114,16 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
     RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator<K, C>(
         shuffleDependency.serializer(), shuffleReadClient,
         context.taskMetrics().shuffleReadMetrics());
+    CompletionIterator completionIterator =
+        CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, new AbstractFunction0<BoxedUnit>() {
+          @Override
+          public BoxedUnit apply() {
+            return rssShuffleDataIterator.cleanup();
+          }
+        });
+    context.addTaskCompletionListener(context -> {
+      completionIterator.completion();
+    });
 
     Iterator<Product2<K, C>> resultIter = null;
     Iterator<Product2<K, C>> aggregatedIter = null;
@@ -120,16 +131,15 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
     if (shuffleDependency.aggregator().isDefined()) {
       if (shuffleDependency.mapSideCombine()) {
         // We are reading values that are already combined
-        aggregatedIter = shuffleDependency.aggregator().get().combineCombinersByKey(
-            rssShuffleDataIterator, context);
+        aggregatedIter = shuffleDependency.aggregator().get().combineCombinersByKey(completionIterator, context);
       } else {
         // We don't know the value type, but also don't care -- the dependency *should*
         // have made sure its compatible w/ this aggregator, which will convert the value
         // type to the combined type C
-        aggregatedIter = shuffleDependency.aggregator().get().combineValuesByKey(rssShuffleDataIterator, context);
+        aggregatedIter = shuffleDependency.aggregator().get().combineValuesByKey(completionIterator, context);
       }
     } else {
-      aggregatedIter = rssShuffleDataIterator;
+      aggregatedIter = completionIterator;
     }
 
     if (shuffleDependency.keyOrdering().isDefined()) {
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 6d6025aa..a565cfe4 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -30,6 +30,7 @@ import org.apache.spark.executor.ShuffleReadMetrics;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.util.CompletionIterator;
 import org.apache.spark.util.CompletionIterator$;
 import org.apache.spark.util.collection.ExternalSorter;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -183,11 +184,11 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   }
 
   class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
-    java.util.Iterator<RssShuffleDataIterator> iterator;
-    RssShuffleDataIterator dataIterator;
+    java.util.Iterator<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterator;
+    CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>  dataIterator;
 
     MultiPartitionIterator() {
-      List<RssShuffleDataIterator> iterators = Lists.newArrayList();
+      List<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterators = Lists.newArrayList();
       for (int partition = startPartition; partition < endPartition; partition++) {
         if (partitionToExpectBlocks.get(partition).isEmpty()) {
           LOG.info("{} partition is empty partition", partition);
@@ -201,13 +202,21 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
         RssShuffleDataIterator iterator = new RssShuffleDataIterator<K, C>(
             shuffleDependency.serializer(), shuffleReadClient,
             readMetrics);
-        iterators.add(iterator);
+        CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
+            CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());
+        iterators.add(completionIterator);
       }
       iterator = iterators.iterator();
       if (iterator.hasNext()) {
         dataIterator = iterator.next();
         iterator.remove();
       }
+      context.addTaskCompletionListener((taskContext) -> {
+        if (dataIterator != null) {
+          dataIterator.completion();
+        }
+        iterator.forEachRemaining(ci -> ci.completion());
+      });
     }
 
     @Override