You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by xi...@apache.org on 2023/02/03 12:14:58 UTC

[incubator-uniffle] branch master updated: [ISSUE-476][FEATURE] Respect `spark.shuffle.compress` configuration in Uniffle (#495)

This is an automated email from the ASF dual-hosted git repository.

xianjin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 0bb1e8c3 [ISSUE-476][FEATURE] Respect `spark.shuffle.compress` configuration in Uniffle (#495)
0bb1e8c3 is described below

commit 0bb1e8c31de32f04e58380f0d9dee89117fb98b8
Author: jiafu zhang <ji...@intel.com>
AuthorDate: Fri Feb 3 20:14:52 2023 +0800

    [ISSUE-476][FEATURE] Respect `spark.shuffle.compress` configuration in Uniffle (#495)
    
    ### What changes were proposed in this pull request?
    In Spark, the configuration `spark.shuffle.compress` controls whether shuffle data is compressed; it defaults to true. Uniffle, however, always compresses shuffle data and ignores this setting. Uniffle should respect it so that its behavior is aligned with vanilla Spark.
    
    ### Why are the changes needed?
    Users should be able to decide whether compression is applied. Some workloads do not compress well, and for them disabling compression avoids wasting CPU cycles.
    
    ### Does this PR introduce _any_ user-facing change?
    NO
    
    ### How was this patch tested?
    Tested with `spark.shuffle.compress` set to true and to false; both worked as expected.
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |  5 ++
 .../shuffle/reader/RssShuffleDataIterator.java     | 65 +++++++++++--------
 .../spark/shuffle/writer/WriteBufferManager.java   | 15 +++--
 .../shuffle/reader/AbstractRssReaderTest.java      | 25 +++++++-
 .../shuffle/reader/RssShuffleDataIteratorTest.java | 73 +++++++++++++++-------
 .../shuffle/writer/WriteBufferManagerTest.java     | 25 +++++++-
 6 files changed, 152 insertions(+), 56 deletions(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 9da401d2..c8e3478c 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -265,6 +265,11 @@ public class RssSparkConfig {
                    + " spark.rss.estimate.server.assignment.enabled"))
       .createWithDefault(RssClientConfig.RSS_ESTIMATE_TASK_CONCURRENCY_PER_SERVER_DEFAULT_VALUE);
 
+  // spark2 doesn't have this key defined
+  public static final String SPARK_SHUFFLE_COMPRESS_KEY = "spark.shuffle.compress";
+
+  public static final boolean SPARK_SHUFFLE_COMPRESS_DEFAULT = true;
+
   public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
       ImmutableSet.of(RSS_STORAGE_TYPE.key(), RSS_REMOTE_STORAGE_PATH.key());
 
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index af2be209..4c26a5c3 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -27,6 +27,7 @@ import org.apache.spark.executor.ShuffleReadMetrics;
 import org.apache.spark.serializer.DeserializationStream;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.RssSparkConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.Product2;
@@ -53,7 +54,7 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
   private long decompressTime = 0;
   private DeserializationStream deserializationStream = null;
   private ByteBufInputStream byteBufInputStream = null;
-  private long compressedBytesLength = 0;
+  private long totalRawBytesLength = 0;
   private long unCompressedBytesLength = 0;
   private ByteBuffer uncompressedData;
   private Codec codec;
@@ -66,12 +67,15 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
     this.serializerInstance = serializer.newInstance();
     this.shuffleReadClient = shuffleReadClient;
     this.shuffleReadMetrics = shuffleReadMetrics;
-    this.codec = Codec.newInstance(rssConf);
+    boolean compress = rssConf.getBoolean(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY
+            .substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
+        RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
+    this.codec = compress ? Codec.newInstance(rssConf) : null;
   }
 
   public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data, int size) {
     clearDeserializationStream();
-    byteBufInputStream = new ByteBufInputStream(Unpooled.wrappedBuffer(data.array(), 0, size), true);
+    byteBufInputStream = new ByteBufInputStream(Unpooled.wrappedBuffer(data.array(), data.position(), size), true);
     deserializationStream = serializerInstance.deserializeStream(byteBufInputStream);
     return deserializationStream.asKeyValueIterator();
   }
@@ -96,31 +100,15 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
     if (recordsIterator == null || !recordsIterator.hasNext()) {
       // read next segment
       long startFetch = System.currentTimeMillis();
-      CompressedShuffleBlock compressedBlock = shuffleReadClient.readShuffleBlockData();
+      // depends on spark.shuffle.compress, shuffled block may not be compressed
+      CompressedShuffleBlock rawBlock = shuffleReadClient.readShuffleBlockData();
       // If ShuffleServer delete
 
-      ByteBuffer compressedData = null;
-      if (compressedBlock != null) {
-        compressedData = compressedBlock.getByteBuffer();
-      }
+      ByteBuffer rawData = rawBlock != null ? rawBlock.getByteBuffer() : null;
       long fetchDuration = System.currentTimeMillis() - startFetch;
       shuffleReadMetrics.incFetchWaitTime(fetchDuration);
-      if (compressedData != null) {
-        long compressedDataLength = compressedData.limit() - compressedData.position();
-        compressedBytesLength += compressedDataLength;
-        shuffleReadMetrics.incRemoteBytesRead(compressedDataLength);
-
-        int uncompressedLen = compressedBlock.getUncompressLength();
-        if (uncompressedData == null || uncompressedData.capacity() < uncompressedLen) {
-          // todo: support off-heap bytebuffer
-          uncompressedData = ByteBuffer.allocate(uncompressedLen);
-        }
-        uncompressedData.clear();
-        long startDecompress = System.currentTimeMillis();
-        codec.decompress(compressedData, uncompressedLen, uncompressedData, 0);
-        unCompressedBytesLength += compressedBlock.getUncompressLength();
-        long decompressDuration = System.currentTimeMillis() - startDecompress;
-        decompressTime += decompressDuration;
+      if (rawData != null) {
+        int uncompressedLen = uncompress(rawBlock, rawData);
         // create new iterator for shuffle data
         long startSerialization = System.currentTimeMillis();
         recordsIterator = createKVIterator(uncompressedData, uncompressedLen);
@@ -131,15 +119,40 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
         // finish reading records, check data consistent
         shuffleReadClient.checkProcessedBlockIds();
         shuffleReadClient.logStatics();
-        LOG.info("Fetch " + compressedBytesLength + " bytes cost " + readTime + " ms and "
-            + serializeTime + " ms to serialize, " + decompressTime + " ms to decompress with unCompressionLength["
+        String decInfo = codec == null ? "." : (", " + decompressTime
+            + " ms to decompress with unCompressionLength["
             + unCompressedBytesLength + "]");
+        LOG.info("Fetch {} bytes cost {} ms and {} ms to serialize{}",
+            totalRawBytesLength, readTime, serializeTime, decInfo);
         return false;
       }
     }
     return recordsIterator.hasNext();
   }
 
+  private int uncompress(CompressedShuffleBlock rawBlock, ByteBuffer rawData) {
+    long rawDataLength = rawData.limit() - rawData.position();
+    totalRawBytesLength += rawDataLength;
+    shuffleReadMetrics.incRemoteBytesRead(rawDataLength);
+
+    int uncompressedLen = rawBlock.getUncompressLength();
+    if (codec != null) {
+      if (uncompressedData == null || uncompressedData.capacity() < uncompressedLen) {
+        // todo: support off-heap bytebuffer
+        uncompressedData = ByteBuffer.allocate(uncompressedLen);
+      }
+      uncompressedData.clear();
+      long startDecompress = System.currentTimeMillis();
+      codec.decompress(rawData, uncompressedLen, uncompressedData, 0);
+      unCompressedBytesLength += uncompressedLen;
+      long decompressDuration = System.currentTimeMillis() - startDecompress;
+      decompressTime += decompressDuration;
+    } else {
+      uncompressedData = rawData;
+    }
+    return uncompressedLen;
+  }
+
   @Override
   public Product2<K, C> next() {
     shuffleReadMetrics.incRecordsRead(1L);
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index fcce0597..7b10fb83 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -32,6 +32,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.RssSparkConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.reflect.ClassTag$;
@@ -106,7 +107,10 @@ public class WriteBufferManager extends MemoryConsumer {
     this.requireMemoryRetryMax = bufferManagerOptions.getRequireMemoryRetryMax();
     this.arrayOutputStream = new WrappedByteArrayOutputStream(serializerBufferSize);
     this.serializeStream = instance.serializeStream(arrayOutputStream);
-    this.codec = Codec.newInstance(rssConf);
+    boolean compress = rssConf.getBoolean(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY
+            .substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
+        RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
+    this.codec = compress ? Codec.newInstance(rssConf) : null;
   }
 
   public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
@@ -178,10 +182,13 @@ public class WriteBufferManager extends MemoryConsumer {
   protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb) {
     byte[] data = wb.getData();
     final int uncompressLength = data.length;
-    long start = System.currentTimeMillis();
-    final byte[] compressed = codec.compress(data);
+    byte[] compressed = data;
+    if (codec != null) {
+      long start = System.currentTimeMillis();
+      compressed = codec.compress(data);
+      compressTime += System.currentTimeMillis() - start;
+    }
     final long crc32 = ChecksumUtils.getCrc32(compressed);
-    compressTime += System.currentTimeMillis() - start;
     final long blockId = ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId));
     uncompressedDataLen += data.length;
     shuffleWriteMetrics.incBytesWritten(compressed.length);
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
index fd290cb4..b1e33ddf 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
@@ -71,6 +71,20 @@ public abstract class AbstractRssReaderTest extends HdfsTestBase {
       String keyPrefix,
       Serializer serializer,
       int partitionID) throws Exception {
+    writeTestData(handler, blockNum, recordNum, expectedData, blockIdBitmap, keyPrefix, serializer,
+        partitionID, true);
+  }
+
+  protected void writeTestData(
+      ShuffleWriteHandler handler,
+      int blockNum,
+      int recordNum,
+      Map<String, String> expectedData,
+      Roaring64NavigableMap blockIdBitmap,
+      String keyPrefix,
+      Serializer serializer,
+      int partitionID,
+      boolean compress) throws Exception {
     List<ShufflePartitionedBlock> blocks = Lists.newArrayList();
     SerializerInstance serializerInstance = serializer.newInstance();
     for (int i = 0; i < blockNum; i++) {
@@ -84,14 +98,21 @@ public abstract class AbstractRssReaderTest extends HdfsTestBase {
       }
       long blockId = ClientUtils.getBlockId(partitionID, 0, atomicInteger.getAndIncrement());
       blockIdBitmap.add(blockId);
-      blocks.add(createShuffleBlock(output.toBytes(), blockId));
+      blocks.add(createShuffleBlock(output.toBytes(), blockId, compress));
       serializeStream.close();
     }
     handler.write(blocks);
   }
 
   protected ShufflePartitionedBlock createShuffleBlock(byte[] data, long blockId) {
-    byte[] compressData = Codec.newInstance(new RssConf()).compress(data);
+    return createShuffleBlock(data, blockId, true);
+  }
+
+  protected ShufflePartitionedBlock createShuffleBlock(byte[] data, long blockId, boolean compress) {
+    byte[] compressData = data;
+    if (compress) {
+      compressData = Codec.newInstance(new RssConf()).compress(data);
+    }
     long crc = ChecksumUtils.getCrc32(compressData);
     return new ShufflePartitionedBlock(compressData.length, data.length, crc, blockId, 0,
         compressData);
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
index 9983d9a0..b1be1ec3 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
@@ -23,6 +23,7 @@ import java.util.Map;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import org.apache.commons.lang.reflect.FieldUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileUtil;
 import org.apache.hadoop.fs.Path;
@@ -30,6 +31,8 @@ import org.apache.spark.SparkConf;
 import org.apache.spark.executor.ShuffleReadMetrics;
 import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.serializer.Serializer;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 import org.mockito.MockedStatic;
 import org.mockito.Mockito;
@@ -98,36 +101,30 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
 
   private RssShuffleDataIterator getDataIterator(String basePath, Roaring64NavigableMap blockIdBitmap,
       Roaring64NavigableMap taskIdBitmap, List<ShuffleServerInfo> serverInfos) {
+    return getDataIterator(basePath, blockIdBitmap, taskIdBitmap, serverInfos, true);
+  }
+
+  private RssShuffleDataIterator getDataIterator(String basePath, Roaring64NavigableMap blockIdBitmap,
+       Roaring64NavigableMap taskIdBitmap, List<ShuffleServerInfo> serverInfos, boolean compress) {
     ShuffleReadClientImpl readClient = new ShuffleReadClientImpl(
         StorageType.HDFS.name(), "appId", 0, 1, 100, 2,
         10, 10000, basePath, blockIdBitmap, taskIdBitmap, Lists.newArrayList(serverInfos),
         new Configuration(), new DefaultIdHelper());
+    RssConf rc;
+    if (!compress) {
+      SparkConf sc = new SparkConf();
+      sc.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false));
+      rc = RssSparkConfig.toRssConf(sc);
+    } else {
+      rc = new RssConf();
+    }
     return new RssShuffleDataIterator(KRYO_SERIALIZER, readClient,
-        new ShuffleReadMetrics(), new RssConf());
+        new ShuffleReadMetrics(), rc);
   }
 
   @Test
   public void readTest2() throws Exception {
-    String basePath = HDFS_URI + "readTest2";
-    HdfsShuffleWriteHandler writeHandler1 =
-        new HdfsShuffleWriteHandler("appId", 0, 0, 1, basePath, ssi1.getId(), conf);
-    HdfsShuffleWriteHandler writeHandler2 =
-        new HdfsShuffleWriteHandler("appId", 0, 0, 1, basePath, ssi2.getId(), conf);
-
-    Map<String, String> expectedData = Maps.newHashMap();
-    Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
-    Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
-    writeTestData(writeHandler1, 2, 5, expectedData,
-        blockIdBitmap, "key1", KRYO_SERIALIZER, 0);
-    writeTestData(writeHandler2, 2, 5, expectedData,
-        blockIdBitmap, "key2", KRYO_SERIALIZER, 0);
-
-    RssShuffleDataIterator rssShuffleDataIterator = getDataIterator(basePath, blockIdBitmap,
-        taskIdBitmap, Lists.newArrayList(ssi1, ssi2));
-
-    validateResult(rssShuffleDataIterator, expectedData, 20);
-    assertEquals(20, rssShuffleDataIterator.getShuffleReadMetrics().recordsRead());
-    assertTrue(rssShuffleDataIterator.getShuffleReadMetrics().fetchWaitTime() > 0);
+    readTestCompressOrNot("readTest2", true);
   }
 
   @Test
@@ -265,6 +262,40 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
     }
   }
 
+  @Test
+  public void readTestUncompressedShuffle() throws Exception {
+    readTestCompressOrNot("readTestUncompressedShuffle", false);
+  }
+
+  private void readTestCompressOrNot(String path, boolean compress) throws Exception {
+    String basePath = HDFS_URI + path;
+    HdfsShuffleWriteHandler writeHandler1 =
+        new HdfsShuffleWriteHandler("appId", 0, 0, 1, basePath, ssi1.getId(), conf);
+    HdfsShuffleWriteHandler writeHandler2 =
+        new HdfsShuffleWriteHandler("appId", 0, 0, 1, basePath, ssi2.getId(), conf);
+
+    Map<String, String> expectedData = Maps.newHashMap();
+    Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
+    Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0);
+    writeTestData(writeHandler1, 2, 5, expectedData,
+        blockIdBitmap, "key1", KRYO_SERIALIZER, 0, compress);
+    writeTestData(writeHandler2, 2, 5, expectedData,
+        blockIdBitmap, "key2", KRYO_SERIALIZER, 0, compress);
+
+    RssShuffleDataIterator rssShuffleDataIterator = getDataIterator(basePath, blockIdBitmap,
+        taskIdBitmap, Lists.newArrayList(ssi1, ssi2), compress);
+    Object codec = FieldUtils.readField(rssShuffleDataIterator, "codec", true);
+    if (compress) {
+      Assertions.assertNotNull(codec);
+    } else {
+      Assertions.assertNull(codec);
+    }
+
+    validateResult(rssShuffleDataIterator, expectedData, 20);
+    assertEquals(20, rssShuffleDataIterator.getShuffleReadMetrics().recordsRead());
+    assertTrue(rssShuffleDataIterator.getShuffleReadMetrics().fetchWaitTime() > 0);
+  }
+
   @Test
   public void cleanup() throws Exception {
     ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 4c2be50a..4f57e265 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -20,16 +20,17 @@ package org.apache.spark.shuffle.writer;
 import java.util.List;
 
 import com.google.common.collect.Maps;
+import org.apache.commons.lang.reflect.FieldUtils;
 import org.apache.spark.SparkConf;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssSparkConfig;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
-import org.apache.uniffle.common.config.RssConf;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -48,7 +49,7 @@ public class WriteBufferManagerTest {
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager wbm = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), new RssConf());
+        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf));
     WriteBufferManager spyManager = spy(wbm);
     doReturn(512L).when(spyManager).acquireMemory(anyLong());
     return spyManager;
@@ -65,9 +66,27 @@ public class WriteBufferManagerTest {
   }
 
   @Test
-  public void addRecordTest() {
+  public void addRecordCompressedTest() throws Exception {
+    addRecord(true);
+  }
+
+  @Test
+  public void addRecordUnCompressedTest() throws Exception {
+    addRecord(false);
+  }
+
+  private void addRecord(boolean compress) throws IllegalAccessException {
     SparkConf conf = getConf();
+    if (!compress) {
+      conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false));
+    }
     WriteBufferManager wbm = createManager(conf);
+    Object codec = FieldUtils.readField(wbm, "codec", true);
+    if (compress) {
+      Assertions.assertNotNull(codec);
+    } else {
+      Assertions.assertNull(codec);
+    }
     wbm.setShuffleWriteMetrics(new ShuffleWriteMetrics());
     String testKey = "Key";
     String testValue = "Value";