Posted to commits@uniffle.apache.org by xi...@apache.org on 2023/06/16 16:46:39 UTC

[incubator-uniffle] branch master updated: [#940] feat: Support columnar shuffle with gluten (#950)

This is an automated email from the ASF dual-hosted git repository.

xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new bb7f9c33 [#940] feat: Support columnar shuffle with gluten (#950)
bb7f9c33 is described below

commit bb7f9c33d4b905ca6ff0a1ec4bdbda341e3d11d0
Author: summaryzb <su...@gmail.com>
AuthorDate: Sat Jun 17 00:46:33 2023 +0800

    [#940] feat: Support columnar shuffle with gluten (#950)
    
    ### What changes were proposed in this pull request?
    Support reading and writing serialized columnar data.
    
    ### Why are the changes needed?
    Fix: #940
    
    ### Does this PR introduce any user-facing change?
    No.
    
    ### How was this patch tested?
    Unit tests; the change is also covered by the existing tests.
---
 .../spark/shuffle/writer/WriteBufferManager.java   | 60 ++++++++++++++--------
 .../shuffle/writer/WriteBufferManagerTest.java     | 20 ++++++++
 .../apache/spark/shuffle/RssShuffleManager.java    |  2 +-
 .../spark/shuffle/reader/RssShuffleReader.java     |  9 +++-
 .../spark/shuffle/reader/RssShuffleReaderTest.java |  1 +
 5 files changed, 67 insertions(+), 25 deletions(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 326ccfba..176434bf 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -128,7 +128,6 @@ public class WriteBufferManager extends MemoryConsumer {
     super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
     this.bufferSize = bufferManagerOptions.getBufferSize();
     this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
-    this.instance = serializer.newInstance();
     this.buffers = Maps.newHashMap();
     this.shuffleId = shuffleId;
     this.taskId = taskId;
@@ -141,7 +140,11 @@ public class WriteBufferManager extends MemoryConsumer {
     this.requireMemoryInterval = bufferManagerOptions.getRequireMemoryInterval();
     this.requireMemoryRetryMax = bufferManagerOptions.getRequireMemoryRetryMax();
     this.arrayOutputStream = new WrappedByteArrayOutputStream(serializerBufferSize);
-    this.serializeStream = instance.serializeStream(arrayOutputStream);
+    // In a columnar shuffle, the serializer is never used, so it may be null.
+    if (serializer != null) {
+      this.instance = serializer.newInstance();
+      this.serializeStream = instance.serializeStream(arrayOutputStream);
+    }
     boolean compress = rssConf.getBoolean(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY
             .substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
         RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
@@ -151,26 +154,16 @@ public class WriteBufferManager extends MemoryConsumer {
     this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
   }
 
-  public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
-    final long start = System.currentTimeMillis();
-    arrayOutputStream.reset();
-    if (key != null) {
-      serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass()));
-    } else {
-      serializeStream.writeKey(null, ManifestFactory$.MODULE$.Null());
-    }
-    if (value != null) {
-      serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass()));
-    } else {
-      serializeStream.writeValue(null, ManifestFactory$.MODULE$.Null());
-    }
-    serializeStream.flush();
-    serializeTime += System.currentTimeMillis() - start;
-    byte[] serializedData = arrayOutputStream.getBuf();
-    int serializedDataLength = arrayOutputStream.size();
-    if (serializedDataLength == 0) {
-      return null;
-    }
+  /**
+   * Add serialized columnar data directly when integrating with Gluten.
+   */
+  public List<ShuffleBlockInfo> addPartitionData(int partitionId, byte[] serializedData) {
+    return addPartitionData(
+            partitionId, serializedData, serializedData.length, System.currentTimeMillis());
+  }
+
+  public List<ShuffleBlockInfo> addPartitionData(
+          int partitionId, byte[] serializedData, int serializedDataLength, long start) {
     List<ShuffleBlockInfo> result = Lists.newArrayList();
     if (buffers.containsKey(partitionId)) {
       WriterBuffer wb = buffers.get(partitionId);
@@ -202,6 +195,29 @@ public class WriteBufferManager extends MemoryConsumer {
     return result;
   }
 
+  public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
+    final long start = System.currentTimeMillis();
+    arrayOutputStream.reset();
+    if (key != null) {
+      serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass()));
+    } else {
+      serializeStream.writeKey(null, ManifestFactory$.MODULE$.Null());
+    }
+    if (value != null) {
+      serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass()));
+    } else {
+      serializeStream.writeValue(null, ManifestFactory$.MODULE$.Null());
+    }
+    serializeStream.flush();
+    serializeTime += System.currentTimeMillis() - start;
+    byte[] serializedData = arrayOutputStream.getBuf();
+    int serializedDataLength = arrayOutputStream.size();
+    if (serializedDataLength == 0) {
+      return null;
+    }
+    return addPartitionData(partitionId, serializedData, serializedDataLength, start);
+  }
+
   // transform all [partition, records] to [partition, ShuffleBlockInfo] and clear cache
   public synchronized List<ShuffleBlockInfo> clear() {
     List<ShuffleBlockInfo> result = Lists.newArrayList();
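For context, here is a minimal sketch of how a columnar writer might feed
pre-serialized bytes into the manager. The wrapper class and its send path are
hypothetical; only WriteBufferManager#addPartitionData comes from this commit:

    // Hypothetical caller: a columnar shuffle writer (e.g. one backed by
    // Gluten's native engine) that already holds serialized columnar batches.
    public class ColumnarWriterSketch {
      private final WriteBufferManager bufferManager;

      public ColumnarWriterSketch(WriteBufferManager bufferManager) {
        this.bufferManager = bufferManager;
      }

      void writeBatch(int partitionId, byte[] columnarBytes) {
        // No row-level key/value serialization happens here; the bytes go
        // straight into the per-partition buffer.
        List<ShuffleBlockInfo> readyBlocks =
            bufferManager.addPartitionData(partitionId, columnarBytes);
        if (!readyBlocks.isEmpty()) {
          // Buffers that crossed their size threshold were flushed into
          // blocks; a real writer would now ship them to the shuffle servers.
          sendToShuffleServers(readyBlocks);
        }
      }

      private void sendToShuffleServers(List<ShuffleBlockInfo> blocks) {
        // Placeholder for the actual send path.
      }
    }

Note that a manager constructed with a null serializer (as the columnar path
and the new test do) must only be fed through addPartitionData; calling
addRecord on it would dereference the uninitialized serializeStream.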
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 44cd9124..9535f15d 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -177,6 +177,26 @@ public class WriteBufferManagerTest {
     assertEquals(1, wbm.getBuffers().size());
   }
 
+  @Test
+  public void addPartitionDataTest() {
+    SparkConf conf = getConf();
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    WriteBufferManager wbm = new WriteBufferManager(
+            0, 0, bufferOptions, null,
+            Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf));
+    WriteBufferManager spyManager = spy(wbm);
+    doReturn(512L).when(spyManager).acquireMemory(anyLong());
+
+    List<ShuffleBlockInfo> shuffleBlockInfos = spyManager.addPartitionData(0, new byte[64]);
+    assertEquals(1, spyManager.getBuffers().size());
+    assertEquals(0, shuffleBlockInfos.size());
+    shuffleBlockInfos = spyManager.addPartitionData(0, new byte[64]);
+    assertEquals(0, spyManager.getBuffers().size());
+    assertEquals(1, shuffleBlockInfos.size());
+    assertEquals(128, shuffleBlockInfos.get(0).getUncompressLength());
+  }
+
   @Test
   public void createBlockIdTest() {
     SparkConf conf = getConf();
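A note on the expected values in addPartitionDataTest above: the first 64-byte
append stays buffered (one open buffer, zero returned blocks), while the second
append pushes the partition buffer past its configured size threshold, so the
buffer is flushed into a single block whose uncompressed length is the two
payloads combined, 64 + 64 = 128 bytes.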
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index d0892ce6..aed994a8 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -474,7 +474,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
     if (!(handle instanceof RssShuffleHandle)) {
       throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
     }
-    RssShuffleHandle<K, C, ?> rssShuffleHandle = (RssShuffleHandle<K, C, ?>) handle;
+    RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
     final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
     int shuffleId = rssShuffleHandle.getShuffleId();
     Map<Integer, List<ShuffleServerInfo>> allPartitionToServers = rssShuffleHandle.getPartitionToServers();
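The cast fix above follows Spark's ShuffleDependency<K, V, C> convention (key,
map-side value, combined value): a reader that produces Product2<K, C> records
must leave the map-side value type open rather than the combined type. A
minimal fragment, with illustrative names, showing the relationship inside a
method generic in <K, C>:

    // Spark shuffles are typed ShuffleDependency<K, V, C>; readers consume C.
    RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
    ShuffleDependency<K, ?, C> dependency = rssShuffleHandle.getDependency();
    // The read path then yields Iterator<Product2<K, C>>, never the raw V.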
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index b9c8e36f..6cd4a13d 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -65,7 +65,8 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private int startPartition;
   private int endPartition;
   private TaskContext context;
-  private ShuffleDependency<K, C, ?> shuffleDependency;
+  private ShuffleDependency<K, ?, C> shuffleDependency;
+  private int numMaps;
   private Serializer serializer;
   private String taskId;
   private String basePath;
@@ -85,7 +86,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
       int mapStartIndex,
       int mapEndIndex,
       TaskContext context,
-      RssShuffleHandle<K, C, ?> rssShuffleHandle,
+      RssShuffleHandle<K, ?, C> rssShuffleHandle,
       String basePath,
       Configuration hadoopConf,
       int partitionNum,
@@ -100,6 +101,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
     this.mapStartIndex = mapStartIndex;
     this.mapEndIndex = mapEndIndex;
     this.context = context;
+    this.numMaps = rssShuffleHandle.getNumMaps();
     this.shuffleDependency = rssShuffleHandle.getDependency();
     this.shuffleId = shuffleDependency.shuffleId();
     this.serializer = rssShuffleHandle.getDependency().serializer();
@@ -205,6 +207,9 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
 
     MultiPartitionIterator() {
       List<CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>>> iterators = Lists.newArrayList();
+      if (numMaps <= 0) {
+        return;
+      }
       for (int partition = startPartition; partition < endPartition; partition++) {
         if (partitionToExpectBlocks.get(partition).isEmpty()) {
           LOG.info("{} partition is empty partition", partition);
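On the new numMaps guard above: when the handle reports zero map tasks there is
no shuffle output to read, so the constructor now returns early and the
iterator stays empty, skipping the per-partition setup entirely.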
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index 1bbd5def..268c75a2 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -73,6 +73,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
     when(handleMock.getAppId()).thenReturn("appId");
     when(handleMock.getDependency()).thenReturn(dependencyMock);
     when(handleMock.getShuffleId()).thenReturn(1);
+    when(handleMock.getNumMaps()).thenReturn(1);
     Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
     partitionToServers.put(0, Lists.newArrayList(ssi));
     partitionToServers.put(1, Lists.newArrayList(ssi));