You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by ro...@apache.org on 2022/08/22 09:04:33 UTC
[incubator-uniffle] branch master updated: [BUGFIX] Fix resource leak when shuffle read (#174)
This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new df4163fd [BUGFIX] Fix resource leak when shuffle read (#174)
df4163fd is described below
commit df4163fd4e0970d420678bcb77ebe5bc4f79a9f6
Author: Chen Zhang <67...@users.noreply.github.com>
AuthorDate: Mon Aug 22 17:04:28 2022 +0800
[BUGFIX] Fix resource leak when shuffle read (#174)
### What changes were proposed in this pull request?
Use `org.apache.spark.TaskContext#addTaskCompletionListener` to clean up resources used by `RssShuffleDataIterator`. This avoids possible resource leaks.
### Why are the changes needed?
Before this PR, `RssShuffleDataIterator` would only clean up the resources it used after all records had been read.
When the `Spark Task` fails or is cancelled, or runs special logic such as `LocalLimit`, the resources will not be cleaned up, creating potential resource leaks.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a UT case
Co-authored-by: zhangchen351 <zh...@jd.com>
---
.../spark/shuffle/reader/RssShuffleDataIterator.java | 14 +++++++++++---
.../shuffle/reader/RssShuffleDataIteratorTest.java | 15 +++++++++++++++
.../apache/spark/shuffle/reader/RssShuffleReader.java | 18 ++++++++++++++----
.../apache/spark/shuffle/reader/RssShuffleReader.java | 17 +++++++++++++----
4 files changed, 53 insertions(+), 11 deletions(-)
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index bd8184c5..23e03641 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -34,6 +34,7 @@ import scala.Product2;
import scala.Tuple2;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;
+import scala.runtime.BoxedUnit;
import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
@@ -130,9 +131,7 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
readTime += fetchDuration;
serializeTime += serializationDuration;
} else {
- // finish reading records, close related reader and check data consistent
- clearDeserializationStream();
- shuffleReadClient.close();
+ // finish reading records, check data consistent
shuffleReadClient.checkProcessedBlockIds();
shuffleReadClient.logStatics();
LOG.info("Fetch " + shuffleReadMetrics.remoteBytesRead() + " bytes cost " + readTime + " ms and "
@@ -150,6 +149,15 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
return (Product2<K, C>) recordsIterator.next();
}
+ public BoxedUnit cleanup() {
+ clearDeserializationStream();
+ if (shuffleReadClient != null) {
+ shuffleReadClient.close();
+ }
+ shuffleReadClient = null;
+ return BoxedUnit.UNIT;
+ }
+
@VisibleForTesting
protected ShuffleReadMetrics getShuffleReadMetrics() {
return shuffleReadMetrics;
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
index 50c05033..12f5a318 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
@@ -24,6 +24,7 @@ import static org.mockito.ArgumentMatchers.any;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
+import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.client.util.DefaultIdHelper;
@@ -46,6 +47,11 @@ import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
private static final Serializer KRYO_SERIALIZER = new KryoSerializer(new SparkConf(false));
@@ -235,4 +241,13 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
}
}
+ @Test
+ public void cleanup() throws Exception {
+ ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
+ doNothing().when(mockClient).close();
+ RssShuffleDataIterator dataIterator = new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics());
+ dataIterator.cleanup();
+ verify(mockClient, times(1)).close();
+ }
+
}
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 3130521c..ef97bea3 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -26,6 +26,7 @@ import org.apache.spark.TaskContext;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.util.CompletionIterator;
import org.apache.spark.util.CompletionIterator$;
import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.collection.ExternalSorter;
@@ -113,6 +114,16 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
context.taskMetrics().shuffleReadMetrics());
+ CompletionIterator completionIterator =
+ CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, new AbstractFunction0<BoxedUnit>() {
+ @Override
+ public BoxedUnit apply() {
+ return rssShuffleDataIterator.cleanup();
+ }
+ });
+ context.addTaskCompletionListener(context -> {
+ completionIterator.completion();
+ });
Iterator<Product2<K, C>> resultIter = null;
Iterator<Product2<K, C>> aggregatedIter = null;
@@ -120,16 +131,15 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
if (shuffleDependency.aggregator().isDefined()) {
if (shuffleDependency.mapSideCombine()) {
// We are reading values that are already combined
- aggregatedIter = shuffleDependency.aggregator().get().combineCombinersByKey(
- rssShuffleDataIterator, context);
+ aggregatedIter = shuffleDependency.aggregator().get().combineCombinersByKey(completionIterator, context);
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
- aggregatedIter = shuffleDependency.aggregator().get().combineValuesByKey(rssShuffleDataIterator, context);
+ aggregatedIter = shuffleDependency.aggregator().get().combineValuesByKey(completionIterator, context);
}
} else {
- aggregatedIter = rssShuffleDataIterator;
+ aggregatedIter = completionIterator;
}
if (shuffleDependency.keyOrdering().isDefined()) {
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 6d6025aa..a565cfe4 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -30,6 +30,7 @@ import org.apache.spark.executor.ShuffleReadMetrics;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.util.CompletionIterator;
import org.apache.spark.util.CompletionIterator$;
import org.apache.spark.util.collection.ExternalSorter;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -183,11 +184,11 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
}
class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
- java.util.Iterator<RssShuffleDataIterator> iterator;
- RssShuffleDataIterator dataIterator;
+ java.util.Iterator<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterator;
+ CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> dataIterator;
MultiPartitionIterator() {
- List<RssShuffleDataIterator> iterators = Lists.newArrayList();
+ List<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterators = Lists.newArrayList();
for (int partition = startPartition; partition < endPartition; partition++) {
if (partitionToExpectBlocks.get(partition).isEmpty()) {
LOG.info("{} partition is empty partition", partition);
@@ -201,13 +202,21 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
RssShuffleDataIterator iterator = new RssShuffleDataIterator<K, C>(
shuffleDependency.serializer(), shuffleReadClient,
readMetrics);
- iterators.add(iterator);
+ CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
+ CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());
+ iterators.add(completionIterator);
}
iterator = iterators.iterator();
if (iterator.hasNext()) {
dataIterator = iterator.next();
iterator.remove();
}
+ context.addTaskCompletionListener((taskContext) -> {
+ if (dataIterator != null) {
+ dataIterator.completion();
+ }
+ iterator.forEachRemaining(ci -> ci.completion());
+ });
}
@Override