You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by ro...@apache.org on 2023/06/12 09:33:50 UTC

[incubator-uniffle] branch master updated: [#854][FOLLOWUP] feat(tez): Add RssTezFetcherTask to fetch data from worker for OrderedInput (#935)

This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 64ed9351 [#854][FOLLOWUP] feat(tez): Add RssTezFetcherTask to fetch data from worker for OrderedInput (#935)
64ed9351 is described below

commit 64ed9351b283878e7c80f038b440905e040936f6
Author: Qing <11...@qq.com>
AuthorDate: Mon Jun 12 17:33:44 2023 +0800

    [#854][FOLLOWUP] feat(tez): Add RssTezFetcherTask to fetch data from worker for OrderedInput (#935)
    
    ### What changes were proposed in this pull request?
    Add RssTezShuffleDataFetcher (called RssTezFetcherTask in the title) to fetch shuffle data from the shuffle server (worker) for OrderedInput.
    
    ### Why are the changes needed?
    Fix: https://github.com/apache/incubator-uniffle/issues/854
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    
    Unit test (RssTezShuffleDataFetcherTest).
---
 .../orderedgrouped/RssTezShuffleDataFetcher.java   | 262 +++++++++++++++
 .../RssTezShuffleDataFetcherTest.java              | 371 +++++++++++++++++++++
 2 files changed, 633 insertions(+)

diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
new file mode 100644
index 00000000..97378270
--- /dev/null
+++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.tez.common.CallableWithNdc;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.response.CompressedShuffleBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+
+/**
+ * Fetches the shuffle blocks of one partition from the remote shuffle service through a
+ * {@link ShuffleReadClient}, decompresses each block with the configured {@link Codec}, and hands
+ * it to the Tez {@link MergeManager} as an individual map output.
+ *
+ * <p>If the merge manager cannot reserve space for a block, the decompressed bytes are kept and the
+ * same block is offered again on the next {@link #copyFromRssServer()} call, so it is neither
+ * fetched nor decompressed twice. Fetching ends when the read client returns no more blocks or
+ * when {@link #shutDown()} is called.
+ */
+public class RssTezShuffleDataFetcher extends CallableWithNdc<Void> {
+  private static final Logger LOG = LoggerFactory.getLogger(RssTezShuffleDataFetcher.class);
+
+  private enum ShuffleErrors {
+    IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
+    CONNECTION, WRONG_REDUCE
+  }
+
+  private static final String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";
+
+  // "Shuffle Errors" / IO_ERROR counter, bumped when reserving or writing a map output fails.
+  private final TezCounter ioErrs;
+  private final MergeManager merger;
+  // Expected number of blocks; only reported in log messages.
+  private final long totalBlockCount;
+
+  // Number of blocks successfully handed to the merger.
+  private long copyBlockCount = 0;
+  // volatile: may be set by shutDown() from a thread other than the one running the fetch loop.
+  private volatile boolean stopped = false;
+
+  private final ShuffleReadClient shuffleReadClient;
+  // Accumulated timings, in milliseconds.
+  private long readTime = 0;
+  private long decompressTime = 0;
+  private long serializeTime = 0;
+  private long waitTime = 0;
+  private long copyTime = 0;  // the sum of readTime + decompressTime + serializeTime + waitTime
+  private long unCompressionLength = 0;
+  private final InputAttemptIdentifier inputAttemptIdentifier;
+  // Sequence used to give every block its own synthetic InputAttemptIdentifier.
+  private int uniqueMapId = 0;
+
+  // True while uncompressedData holds a block the merger has not accepted yet.
+  private boolean hasPendingData = false;
+  // When the current pending block was first rejected by the merger.
+  private long startWait;
+  // Number of times the merger asked this fetcher to wait.
+  private int waitCount = 0;
+  private byte[] uncompressedData = null;
+  private final Codec rssCodec;
+  private Integer partitionId;
+  private final ExceptionReporter exceptionReporter;
+
+  // Number of reserve attempts made against the merger (retries included).
+  private final AtomicInteger issuedCnt = new AtomicInteger(0);
+
+  /**
+   * @param inputAttemptIdentifier identifies this input in log and error messages
+   * @param partitionId the partition this fetcher reads; logged and exposed via getPartitionId()
+   * @param merger receives every decompressed block as a map output
+   * @param tezCounters source of the "Shuffle Errors" IO_ERROR counter
+   * @param shuffleReadClient returns compressed blocks, and {@code null} once all are consumed
+   * @param totalBlockCount expected number of blocks; only reported in logs
+   * @param rssConf configuration used to create the decompression codec
+   * @param exceptionReporter receives any failure raised while running {@link #callInternal()}
+   */
+  public RssTezShuffleDataFetcher(InputAttemptIdentifier inputAttemptIdentifier,
+        Integer partitionId,
+        MergeManager merger,
+        TezCounters tezCounters,
+        ShuffleReadClient shuffleReadClient,
+        long totalBlockCount,
+        RssConf rssConf,
+        ExceptionReporter exceptionReporter) {
+    this.merger = merger;
+    this.partitionId = partitionId;
+    this.inputAttemptIdentifier = inputAttemptIdentifier;
+    this.exceptionReporter = exceptionReporter;
+    ioErrs = tezCounters.findCounter(SHUFFLE_ERR_GRP_NAME, RssTezShuffleDataFetcher.ShuffleErrors.IO_ERROR.toString());
+    this.shuffleReadClient = shuffleReadClient;
+    this.totalBlockCount = totalBlockCount;
+
+    this.rssCodec = Codec.newInstance(rssConf);
+    LOG.info("RssTezShuffleDataFetcher, partitionId:{}, inputAttemptIdentifier:{}.",
+        this.partitionId, this.inputAttemptIdentifier);
+  }
+
+  /**
+   * Runs the fetch loop. On interrupt, the thread's interrupt flag is restored and the method
+   * returns. Any other failure is logged and forwarded to the exception reporter.
+   */
+  @Override
+  public Void callInternal() {
+    try {
+      fetchAllRssBlocks();
+    } catch (InterruptedException ie) {
+      //might not be respected when fetcher is in progress / server is busy.  TEZ-711
+      //Set the status back
+      LOG.warn(ie.getMessage(), ie);
+      Thread.currentThread().interrupt();
+      return null;
+    } catch (Throwable t) {
+      LOG.warn(t.getMessage(), t);
+      exceptionReporter.reportException(t);
+      // Shuffle knows how to deal with failures post shutdown via the onFailure hook
+    }
+    return null;
+  }
+
+  /**
+   * Repeatedly waits on the merger and copies one block, until the fetcher is stopped (all blocks
+   * read, or {@link #shutDown()} called).
+   *
+   * @throws IOException if reserving space in the merger fails
+   * @throws InterruptedException if interrupted while waiting for an in-memory merge
+   */
+  public void fetchAllRssBlocks() throws IOException, InterruptedException {
+    while (!stopped) {
+      try {
+        // If merge is on, block
+        merger.waitForInMemoryMerge();
+        // Do shuffle
+        copyFromRssServer();
+      } catch (Exception e) {
+        LOG.warn(e.getMessage(), e);
+        throw e;
+      }
+    }
+  }
+
+  /**
+   * Performs one step of the fetch loop. It reads and decompresses the next block, unless a
+   * previously rejected block is still pending, and offers the block to the merger. When the read
+   * client has no more blocks, it closes the client, checks the processed block ids, logs
+   * statistics and stops the fetcher.
+   *
+   * @throws IOException if reserving space in the merger fails
+   */
+  @VisibleForTesting
+  public void copyFromRssServer() throws IOException {
+    // Remember whether this call is a retry of a block the merger rejected earlier.
+    final boolean retryingPendingBlock = hasPendingData;
+    CompressedShuffleBlock compressedBlock = null;
+    ByteBuffer compressedData = null;
+    // fetch a block
+    if (!retryingPendingBlock) {
+      final long startFetch = System.currentTimeMillis();
+      compressedBlock = shuffleReadClient.readShuffleBlockData();
+      if (compressedBlock != null) {
+        compressedData = compressedBlock.getByteBuffer();
+      }
+      readTime += System.currentTimeMillis() - startFetch;
+    }
+
+    // uncompress the block
+    if (!retryingPendingBlock && compressedData != null) {
+      final long startDecompress = System.currentTimeMillis();
+      int uncompressedLen = compressedBlock.getUncompressLength();
+      ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+      rssCodec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
+      uncompressedData = decompressedBuffer.array();
+      unCompressionLength += uncompressedLen;
+      decompressTime += System.currentTimeMillis() - startDecompress;
+    }
+
+    if (uncompressedData != null) {
+      // start to merge
+      final long startSerialization = System.currentTimeMillis();
+      if (issueMapOutputMerge()) {
+        serializeTime += System.currentTimeMillis() - startSerialization;
+        // if reserve succeeds, account the wait (if any) and reset status for next fetch
+        if (retryingPendingBlock) {
+          waitTime += System.currentTimeMillis() - startWait;
+        }
+        hasPendingData = false;
+        uncompressedData = null;
+      } else {
+        // Reserve failed: start the wait clock only on the first rejection of this block, so that
+        // waitTime covers the whole wait across all retries rather than just the last one.
+        if (!retryingPendingBlock) {
+          startWait = System.currentTimeMillis();
+        }
+        return;
+      }
+
+      // update some status
+      copyBlockCount++;
+      copyTime = readTime + decompressTime + serializeTime + waitTime;
+      updateStatus();
+    } else {
+      // finish reading data, close related reader and check data consistent
+      shuffleReadClient.close();
+      shuffleReadClient.checkProcessedBlockIds();
+      shuffleReadClient.logStatics();
+      LOG.info("Reduce task " + inputAttemptIdentifier + " read block cnt: " + copyBlockCount
+              + " cost " + readTime + " ms to fetch and "
+              + decompressTime + " ms to decompress with unCompressionLength["
+              + unCompressionLength + "] and " + serializeTime + " ms to serialize and "
+              + waitTime + " ms to wait resource" + ", copy time:" + copyTime);
+      stopFetch();
+    }
+  }
+
+  public Integer getPartitionId() {
+    return partitionId;
+  }
+
+  public void setPartitionId(Integer partitionId) {
+    this.partitionId = partitionId;
+  }
+
+  /**
+   * Reserves a map output in the merger for the current {@code uncompressedData}, writes the bytes
+   * into it and commits it.
+   *
+   * @return {@code true} if the block was committed; {@code false} if the merger asked us to wait,
+   *     in which case the block is kept as pending data for a later retry
+   * @throws IOException if reserving the map output fails
+   * @throws RssException if writing or committing the block fails; the original error is attached
+   *     as the cause
+   */
+  private boolean issueMapOutputMerge() throws IOException {
+    // In RSS a map output is sent as multiple blocks, so each block is treated as a separate
+    // "fake" map output. To avoid name conflicts, every block gets a fresh
+    // InputAttemptIdentifier(increased_seq++, 0).
+    InputAttemptIdentifier uniqueInputAttemptIdentifier = getNextUniqueInputAttemptIdentifier();
+    MapOutput mapOutput = null;
+    try {
+      int issued = issuedCnt.incrementAndGet();
+      // Emitted once per block, so keep it at debug level to avoid flooding the task log.
+      LOG.debug("IssueMapOutputMerge, uncompressedData length:{}, issueCnt:{}, totalBlockCount:{}",
+          uncompressedData.length, issued, totalBlockCount);
+      mapOutput = merger.reserve(uniqueInputAttemptIdentifier, uncompressedData.length, 0, 1);
+    } catch (IOException ioe) {
+      // kill this reduce attempt
+      ioErrs.increment(1);
+      throw ioe;
+    }
+    // Check if we can shuffle *now* ...
+    if (mapOutput == null) {
+      LOG.info("RssTezShuffleDataFetcher - MergeManager returned status WAIT ...");
+      // Not an error but wait to process data.
+      // Use a retry flag to avoid re-fetch and re-uncompress.
+      hasPendingData = true;
+      waitCount++;
+      return false;
+    }
+
+    // write data to mapOutput
+    try {
+      RssTezBypassWriter.write(mapOutput, uncompressedData);
+      // let the merger know this block is ready for merging
+      mapOutput.commit();
+    } catch (Throwable t) {
+      ioErrs.increment(1);
+      mapOutput.abort();
+      // Attach the original failure as the cause so the real error is not lost.
+      throw new RssException("Reduce: " + inputAttemptIdentifier + " cannot write block to "
+              + mapOutput.getClass().getSimpleName() + " due to: " + t.getClass().getName(), t);
+    }
+    return true;
+  }
+
+  /** Returns a new synthetic identifier (uniqueMapId++, 0) for the next block. */
+  private InputAttemptIdentifier getNextUniqueInputAttemptIdentifier() {
+    return new InputAttemptIdentifier(uniqueMapId++, 0);
+  }
+
+  /** Hook invoked after each committed block; currently does nothing. */
+  private void updateStatus() {
+  }
+
+  /** Returns how many times the merger asked this fetcher to wait. */
+  @VisibleForTesting
+  public int getRetryCount() {
+    return waitCount;
+  }
+
+  private void stopFetch() {
+    LOG.info("RssTezShuffleDataFetcher stop fetch");
+    stopped = true;
+  }
+
+  /** Asks the fetch loop to stop before its next iteration. */
+  public void shutDown() {
+    stopFetch();
+    LOG.info("RssTezShuffleDataFetcher shutdown");
+  }
+}
diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcherTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcherTest.java
new file mode 100644
index 00000000..aa7d994b
--- /dev/null
+++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcherTest.java
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.LinkedListMultimap;
+import com.google.common.collect.ListMultimap;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalDirAllocator;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BoundedByteArrayOutputStream;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.common.counters.GenericCounter;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+import org.apache.tez.runtime.library.common.ValuesIterator;
+import org.apache.tez.runtime.library.common.combine.Combiner;
+import org.apache.tez.runtime.library.common.comparator.TezBytesComparator;
+import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator;
+import org.junit.jupiter.api.Test;
+import org.mockito.internal.util.collections.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.response.CompressedShuffleBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.compression.Lz4Codec;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Feeds in-memory IFile streams through a mocked {@link ShuffleReadClient} into
+ * {@link RssTezShuffleDataFetcher} and a real {@link MergeManager}, and checks that the merged
+ * output contains every written key/value pair in comparator order.
+ */
+public class RssTezShuffleDataFetcherTest {
+  private static final Logger LOG = LoggerFactory.getLogger(RssTezShuffleDataFetcherTest.class);
+
+  enum TestWithComparator {
+    LONG, INT, BYTES, TEZ_BYTES, TEXT
+  }
+
+  Configuration conf;
+  FileSystem fs;
+
+  final Class keyClass;
+  final Class valClass;
+  final RawComparator comparator;
+  final boolean expectedTestResult;
+
+  int mergeFactor;
+  // For storing original data
+  final ListMultimap<Writable, Writable> originalData;
+
+  TezRawKeyValueIterator rawKeyValueIterator;
+
+  Path baseDir;
+  Path tmpDir;
+  // Serialized IFile streams, one per mocked shuffle block. Kept per instance (not static) so
+  // data from one test cannot accumulate into another.
+  final List<byte[]> bytesData = new ArrayList<>();
+  static Codec codec = new Lz4Codec();
+
+  public RssTezShuffleDataFetcherTest() throws IOException, ClassNotFoundException {
+    this.keyClass = Class.forName("org.apache.hadoop.io.Text");
+    this.valClass = Class.forName("org.apache.hadoop.io.Text");
+    this.comparator = getComparator(TestWithComparator.TEXT);
+    this.expectedTestResult = true;
+    originalData = LinkedListMultimap.create();
+    setupConf();
+  }
+
+  private void setupConf() throws IOException, ClassNotFoundException {
+    mergeFactor = 2;
+    conf = new Configuration();
+    conf.setInt(TezRuntimeConfiguration.TEZ_RUNTIME_IO_SORT_FACTOR, mergeFactor);
+    conf.setClass(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS,
+        Class.forName("org.apache.tez.runtime.library.common.comparator.TezBytesComparator"),
+        Class.forName("org.apache.hadoop.io.WritableComparator"));
+    baseDir = new Path(".", this.getClass().getName());
+    String localDirs = baseDir.toString();
+    conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, localDirs);
+    fs = FileSystem.getLocal(conf);
+  }
+
+  @Test
+  public void testIteratorWithInMemoryReader() throws Throwable {
+    fs.mkdirs(baseDir);
+    tmpDir = new Path(baseDir, "tmp");
+    try {
+      ValuesIterator iterator = createIterator();
+      verifyIteratorData(iterator);
+    } finally {
+      // Clean up even when verification fails, so later runs start from a clean state.
+      fs.delete(baseDir, true);
+      originalData.clear();
+      bytesData.clear();
+    }
+  }
+
+  /** Asserts that advancing an already exhausted iterator is rejected with an IOException. */
+  private void getNextFromFinishedIterator(ValuesIterator iterator) {
+    try {
+      iterator.moveToNext();
+      fail("moveToNext() on an exhausted iterator should throw IOException");
+    } catch (IOException e) {
+      assertTrue(e.getMessage().contains("Please check if you are invoking moveToNext()"));
+    }
+  }
+
+  /**
+   * Verifies that {@code valuesIterator} returns every key of {@link #originalData} exactly once,
+   * in comparator order, together with all values written for that key.
+   *
+   * @param valuesIterator iterator over the merged output
+   * @throws IOException if reading from the iterator fails
+   */
+  private void verifyIteratorData(ValuesIterator valuesIterator) throws IOException {
+    // sort original data based on comparator
+    ListMultimap<Writable, Writable> sortedMap =
+        new ImmutableListMultimap.Builder<Writable, Writable>()
+            .orderKeysBy(this.comparator).putAll(originalData).build();
+
+    // keySet() of a multimap built with orderKeysBy() iterates in comparator order, which is the
+    // order the merged iterator must produce. A hash-based set of entries has no defined order.
+    for (Writable oriKey : sortedMap.keySet()) {
+      assertTrue(valuesIterator.moveToNext());
+      Object key = valuesIterator.getKey();
+      assertTrue(oriKey.equals(key), "Expected key " + oriKey + " but got " + key);
+
+      int valueCount = 0;
+      Iterator<Writable> vItr = valuesIterator.getValues().iterator();
+      for (Writable val : sortedMap.get(oriKey)) {
+        assertTrue(vItr.hasNext());
+        Object actual = vItr.next();
+        assertTrue(val.equals(actual),
+            "Expected value " + val + " for key " + oriKey + " but got " + actual);
+        valueCount++;
+      }
+      // The iterator must not return more values for this key than were written.
+      assertFalse(vItr.hasNext());
+      assertTrue(valueCount > 0);
+    }
+    assertFalse(valuesIterator.moveToNext());
+    getNextFromFinishedIterator(valuesIterator);
+  }
+
+  /**
+   * Serializes sample data into in-memory IFile streams, serves them through a mocked
+   * {@link ShuffleReadClient}, lets {@link RssTezShuffleDataFetcher} feed them into a
+   * {@link MergeManager}, and returns a {@link ValuesIterator} over the merged result.
+   *
+   * @return ValuesIterator over the merged data
+   * @throws Throwable if serialization, fetching or merging fails
+   */
+  @SuppressWarnings("unchecked")
+  private ValuesIterator createIterator() throws Throwable {
+    createInMemStreams();
+
+    ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(bytesData);
+
+    FileSystem localFS = FileSystem.getLocal(this.conf);
+    LocalDirAllocator localDirAllocator = new LocalDirAllocator(TezRuntimeFrameworkConfigs.LOCAL_DIRS);
+
+    InputContext inputContext = createTezInputContext();
+
+    Combiner combiner = TezRuntimeUtils.instantiateCombiner(conf, inputContext);
+
+    MergeManager mergeManager = new MergeManager(
+        this.conf,
+        localFS,
+        localDirAllocator,
+        inputContext,
+        combiner,
+        null,
+        null,
+        null,
+        null,
+        1024 * 1024 * 256,
+        null,
+        false,
+        0);
+
+    // totalBlockCount must match the number of blocks the mocked client actually serves.
+    RssTezShuffleDataFetcher fetcher = new RssTezShuffleDataFetcher(
+        new InputAttemptIdentifier(1, 0),
+        9,
+        mergeManager, new TezCounters(), shuffleReadClient, bytesData.size(),
+        RssTezConfig.toRssConf(conf), null);
+
+    fetcher.fetchAllRssBlocks();
+
+    rawKeyValueIterator = mergeManager.close(true);
+
+    return new ValuesIterator(rawKeyValueIterator, comparator,
+        keyClass, valClass, conf, (TezCounter) new GenericCounter("inputKeyCounter", "y3"),
+        (TezCounter) new GenericCounter("inputValueCounter", "y4"));
+  }
+
+  private RawComparator getComparator(TestWithComparator comparator) {
+    switch (comparator) {
+      case LONG:
+        return new LongWritable.Comparator();
+      case INT:
+        return new IntWritable.Comparator();
+      case BYTES:
+        return new BytesWritable.Comparator();
+      case TEZ_BYTES:
+        return new TezBytesComparator();
+      case TEXT:
+        return new Text.Comparator();
+      default:
+        return null;
+    }
+  }
+
+  /**
+   * Creates in-memory IFile streams of sample key/value pairs, appends each stream's written bytes
+   * to {@link #bytesData}, and records every pair in {@link #originalData}.
+   *
+   * @throws IOException if serialization fails
+   */
+  public void createInMemStreams() throws IOException {
+    int numberOfStreams = 5;
+    LOG.info("No of streams : " + numberOfStreams);
+
+    SerializationFactory serializationFactory = new SerializationFactory(conf);
+    Serializer keySerializer = serializationFactory.getSerializer(keyClass);
+    Serializer valueSerializer = serializationFactory.getSerializer(valClass);
+
+    DataOutputBuffer keyBuf = new DataOutputBuffer();
+    DataOutputBuffer valBuf = new DataOutputBuffer();
+    DataInputBuffer keyIn = new DataInputBuffer();
+    DataInputBuffer valIn = new DataInputBuffer();
+    keySerializer.open(keyBuf);
+    valueSerializer.open(valBuf);
+
+    for (int i = 0; i < numberOfStreams; i++) {
+      BoundedByteArrayOutputStream bout = new BoundedByteArrayOutputStream(1024 * 1024 * 10);
+      InMemoryWriter writer = new InMemoryWriter(bout);
+      Map<Writable, Writable> data = createData();
+      //write data
+      for (Map.Entry<Writable, Writable> entry : data.entrySet()) {
+        keySerializer.serialize(entry.getKey());
+        valueSerializer.serialize(entry.getValue());
+        keyIn.reset(keyBuf.getData(), 0, keyBuf.getLength());
+        valIn.reset(valBuf.getData(), 0, valBuf.getLength());
+        writer.append(keyIn, valIn);
+
+        originalData.put(entry.getKey(), entry.getValue());
+        keyBuf.reset();
+        valBuf.reset();
+        keyIn.reset();
+        valIn.reset();
+      }
+      data.clear();
+      writer.close();
+      // getBuffer() returns the whole 10 MB backing array. Keep only the bytes actually written,
+      // so each mocked block has its real size instead of ~10 MB of trailing zeros.
+      byte[] written = new byte[bout.size()];
+      System.arraycopy(bout.getBuffer(), 0, written, 0, written.length);
+      bytesData.add(written);
+    }
+  }
+
+  private InputContext createTezInputContext() {
+    TezCounters counters = new TezCounters();
+    InputContext inputContext = mock(InputContext.class);
+    doReturn(1024 * 1024 * 100L).when(inputContext).getTotalMemoryAvailableToTask();
+    doReturn(counters).when(inputContext).getCounters();
+    doReturn(1).when(inputContext).getInputIndex();
+    doReturn("srcVertex").when(inputContext).getSourceVertexName();
+    doReturn(1).when(inputContext).getTaskVertexIndex();
+    doReturn(UserPayload.create(ByteBuffer.wrap(new byte[1024]))).when(inputContext).getUserPayload();
+    doReturn("test_input").when(inputContext).getUniqueIdentifier();
+    return inputContext;
+  }
+
+  private Map<Writable, Writable> createData() {
+    Map<Writable, Writable> map = new TreeMap<Writable, Writable>(comparator);
+    for (int j = 0; j < 10; j++) {
+      Writable key = new Text(String.valueOf(j));
+      Writable value = new Text(String.valueOf(j));
+      map.put(key, value);
+    }
+    return map;
+  }
+
+  /** Serves the given byte arrays, LZ4-compressed, as shuffle blocks in order, then null. */
+  static class MockedShuffleReadClient implements ShuffleReadClient {
+    // Blocks still to be served, in the order they were supplied.
+    final LinkedList<CompressedShuffleBlock> blocks = new LinkedList<>();
+
+    MockedShuffleReadClient(List<byte[]> data) {
+      for (byte[] bytes : data) {
+        byte[] compressed = codec.compress(bytes);
+        blocks.add(new CompressedShuffleBlock(ByteBuffer.wrap(compressed), bytes.length));
+      }
+    }
+
+    @Override
+    public CompressedShuffleBlock readShuffleBlockData() {
+      // poll() returns null once every block has been served, which ends the fetch.
+      return blocks.poll();
+    }
+
+    @Override
+    public void checkProcessedBlockIds() {
+    }
+
+    @Override
+    public void close() {
+    }
+
+    @Override
+    public void logStatics() {
+    }
+  }
+}