You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by xi...@apache.org on 2023/06/16 16:46:39 UTC
[incubator-uniffle] branch master updated: [#940] feat: Support columnar shuffle with gluten (#950)
This is an automated email from the ASF dual-hosted git repository.
xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new bb7f9c33 [#940] feat: Support columnar shuffle with gluten (#950)
bb7f9c33 is described below
commit bb7f9c33d4b905ca6ff0a1ec4bdbda341e3d11d0
Author: summaryzb <su...@gmail.com>
AuthorDate: Sat Jun 17 00:46:33 2023 +0800
[#940] feat: Support columnar shuffle with gluten (#950)
### What changes were proposed in this pull request?
Support reading and writing serialized columnar data.
### Why are the changes needed?
Fix: #940
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Unit tests; the change is covered by the existing tests.
---
.../spark/shuffle/writer/WriteBufferManager.java | 60 ++++++++++++++--------
.../shuffle/writer/WriteBufferManagerTest.java | 20 ++++++++
.../apache/spark/shuffle/RssShuffleManager.java | 2 +-
.../spark/shuffle/reader/RssShuffleReader.java | 9 +++-
.../spark/shuffle/reader/RssShuffleReaderTest.java | 1 +
5 files changed, 67 insertions(+), 25 deletions(-)
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 326ccfba..176434bf 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -128,7 +128,6 @@ public class WriteBufferManager extends MemoryConsumer {
super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
this.bufferSize = bufferManagerOptions.getBufferSize();
this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
- this.instance = serializer.newInstance();
this.buffers = Maps.newHashMap();
this.shuffleId = shuffleId;
this.taskId = taskId;
@@ -141,7 +140,11 @@ public class WriteBufferManager extends MemoryConsumer {
this.requireMemoryInterval = bufferManagerOptions.getRequireMemoryInterval();
this.requireMemoryRetryMax = bufferManagerOptions.getRequireMemoryRetryMax();
this.arrayOutputStream = new WrappedByteArrayOutputStream(serializerBufferSize);
- this.serializeStream = instance.serializeStream(arrayOutputStream);
+ // in columnar shuffle, the serializer here is never used
+ if (serializer != null) {
+ this.instance = serializer.newInstance();
+ this.serializeStream = instance.serializeStream(arrayOutputStream);
+ }
boolean compress = rssConf.getBoolean(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY
.substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
@@ -151,26 +154,16 @@ public class WriteBufferManager extends MemoryConsumer {
this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
}
- public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
- final long start = System.currentTimeMillis();
- arrayOutputStream.reset();
- if (key != null) {
- serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass()));
- } else {
- serializeStream.writeKey(null, ManifestFactory$.MODULE$.Null());
- }
- if (value != null) {
- serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass()));
- } else {
- serializeStream.writeValue(null, ManifestFactory$.MODULE$.Null());
- }
- serializeStream.flush();
- serializeTime += System.currentTimeMillis() - start;
- byte[] serializedData = arrayOutputStream.getBuf();
- int serializedDataLength = arrayOutputStream.size();
- if (serializedDataLength == 0) {
- return null;
- }
+ /**
+ * add serialized columnar data directly when integrate with gluten
+ */
+ public List<ShuffleBlockInfo> addPartitionData(int partitionId, byte[] serializedData) {
+ return addPartitionData(
+ partitionId, serializedData, serializedData.length, System.currentTimeMillis());
+ }
+
+ public List<ShuffleBlockInfo> addPartitionData(
+ int partitionId, byte[] serializedData, int serializedDataLength, long start) {
List<ShuffleBlockInfo> result = Lists.newArrayList();
if (buffers.containsKey(partitionId)) {
WriterBuffer wb = buffers.get(partitionId);
@@ -202,6 +195,29 @@ public class WriteBufferManager extends MemoryConsumer {
return result;
}
+ public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
+ final long start = System.currentTimeMillis();
+ arrayOutputStream.reset();
+ if (key != null) {
+ serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass()));
+ } else {
+ serializeStream.writeKey(null, ManifestFactory$.MODULE$.Null());
+ }
+ if (value != null) {
+ serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass()));
+ } else {
+ serializeStream.writeValue(null, ManifestFactory$.MODULE$.Null());
+ }
+ serializeStream.flush();
+ serializeTime += System.currentTimeMillis() - start;
+ byte[] serializedData = arrayOutputStream.getBuf();
+ int serializedDataLength = arrayOutputStream.size();
+ if (serializedDataLength == 0) {
+ return null;
+ }
+ return addPartitionData(partitionId, serializedData, serializedDataLength, start);
+ }
+
// transform all [partition, records] to [partition, ShuffleBlockInfo] and clear cache
public synchronized List<ShuffleBlockInfo> clear() {
List<ShuffleBlockInfo> result = Lists.newArrayList();
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 44cd9124..9535f15d 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -177,6 +177,26 @@ public class WriteBufferManagerTest {
assertEquals(1, wbm.getBuffers().size());
}
+ @Test
+ public void addPartitionDataTest() {
+ SparkConf conf = getConf();
+ TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+ BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+ WriteBufferManager wbm = new WriteBufferManager(
+ 0, 0, bufferOptions, null,
+ Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf));
+ WriteBufferManager spyManager = spy(wbm);
+ doReturn(512L).when(spyManager).acquireMemory(anyLong());
+
+ List<ShuffleBlockInfo> shuffleBlockInfos = spyManager.addPartitionData(0, new byte[64]);
+ assertEquals(1, spyManager.getBuffers().size());
+ assertEquals(0, shuffleBlockInfos.size());
+ shuffleBlockInfos = spyManager.addPartitionData(0, new byte[64]);
+ assertEquals(0, spyManager.getBuffers().size());
+ assertEquals(1, shuffleBlockInfos.size());
+ assertEquals(128, shuffleBlockInfos.get(0).getUncompressLength());
+ }
+
@Test
public void createBlockIdTest() {
SparkConf conf = getConf();
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index d0892ce6..aed994a8 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -474,7 +474,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
if (!(handle instanceof RssShuffleHandle)) {
throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
}
- RssShuffleHandle<K, C, ?> rssShuffleHandle = (RssShuffleHandle<K, C, ?>) handle;
+ RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers = rssShuffleHandle.getPartitionToServers();
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index b9c8e36f..6cd4a13d 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -65,7 +65,8 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
private int startPartition;
private int endPartition;
private TaskContext context;
- private ShuffleDependency<K, C, ?> shuffleDependency;
+ private ShuffleDependency<K, ?, C> shuffleDependency;
+ private int numMaps;
private Serializer serializer;
private String taskId;
private String basePath;
@@ -85,7 +86,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
int mapStartIndex,
int mapEndIndex,
TaskContext context,
- RssShuffleHandle<K, C, ?> rssShuffleHandle,
+ RssShuffleHandle<K, ?, C> rssShuffleHandle,
String basePath,
Configuration hadoopConf,
int partitionNum,
@@ -100,6 +101,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
this.mapStartIndex = mapStartIndex;
this.mapEndIndex = mapEndIndex;
this.context = context;
+ this.numMaps = rssShuffleHandle.getNumMaps();
this.shuffleDependency = rssShuffleHandle.getDependency();
this.shuffleId = shuffleDependency.shuffleId();
this.serializer = rssShuffleHandle.getDependency().serializer();
@@ -205,6 +207,9 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
MultiPartitionIterator() {
List<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterators = Lists.newArrayList();
+ if (numMaps <= 0) {
+ return;
+ }
for (int partition = startPartition; partition < endPartition; partition++) {
if (partitionToExpectBlocks.get(partition).isEmpty()) {
LOG.info("{} partition is empty partition", partition);
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index 1bbd5def..268c75a2 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -73,6 +73,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
when(handleMock.getAppId()).thenReturn("appId");
when(handleMock.getDependency()).thenReturn(dependencyMock);
when(handleMock.getShuffleId()).thenReturn(1);
+ when(handleMock.getNumMaps()).thenReturn(1);
Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
partitionToServers.put(0, Lists.newArrayList(ssi));
partitionToServers.put(1, Lists.newArrayList(ssi));