You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by ro...@apache.org on 2023/06/12 09:33:50 UTC
[incubator-uniffle] branch master updated: [#854][FOLLOWUP] feat(tez): Add RssTezFetcherTask to fetch data from worker for OrderedInput (#935)
This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 64ed9351 [#854][FOLLOWUP] feat(tez): Add RssTezFetcherTask to fetch data from worker for OrderedInput (#935)
64ed9351 is described below
commit 64ed9351b283878e7c80f038b440905e040936f6
Author: Qing <11...@qq.com>
AuthorDate: Mon Jun 12 17:33:44 2023 +0800
[#854][FOLLOWUP] feat(tez): Add RssTezFetcherTask to fetch data from worker for OrderedInput (#935)
### What changes were proposed in this pull request?
Add RssTezShuffleDataFetcher, which fetches shuffle data from the shuffle servers (workers) for OrderedInput.
### Why are the changes needed?
Fix: https://github.com/apache/incubator-uniffle/issues/854
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
unit test
---
.../orderedgrouped/RssTezShuffleDataFetcher.java | 262 +++++++++++++++
.../RssTezShuffleDataFetcherTest.java | 371 +++++++++++++++++++++
2 files changed, 633 insertions(+)
diff --git a/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
new file mode 100644
index 00000000..97378270
--- /dev/null
+++ b/client-tez/src/main/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcher.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.tez.common.CallableWithNdc;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.response.CompressedShuffleBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
+
+/**
+ * Reads the compressed shuffle blocks of one partition through a {@link ShuffleReadClient},
+ * decompresses each block with the configured RSS {@link Codec} and hands it to the Tez
+ * {@link MergeManager} as if it were an individual map output.
+ *
+ * <p>When the merger cannot reserve memory for a block, the decompressed bytes are kept and
+ * the same block is retried on the next loop iteration instead of being re-fetched.
+ */
+public class RssTezShuffleDataFetcher extends CallableWithNdc<Void> {
+  private static final Logger LOG = LoggerFactory.getLogger(RssTezShuffleDataFetcher.class);
+
+  private enum ShuffleErrors {
+    IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
+    CONNECTION, WRONG_REDUCE
+  }
+
+  private static final String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";
+
+  private final TezCounter ioErrs;
+  private final MergeManager merger;
+  // only used for logging
+  private final long totalBlockCount;
+
+  private long copyBlockCount = 0;
+  // volatile so that shutDown() from another thread is observed by the fetch loop
+  private volatile boolean stopped = false;
+
+  private final ShuffleReadClient shuffleReadClient;
+  private long readTime = 0;
+  private long decompressTime = 0;
+  private long serializeTime = 0;
+  private long waitTime = 0;
+  private long copyTime = 0; // the sum of readTime + decompressTime + serializeTime + waitTime
+  private long unCompressionLength = 0;
+  private final InputAttemptIdentifier inputAttemptIdentifier;
+  private int uniqueMapId = 0;
+
+  // true while uncompressedData holds a block whose memory reservation has not succeeded yet
+  private boolean hasPendingData = false;
+  // time of the first failed reservation of the currently pending block
+  private long startWait;
+  private int waitCount = 0;
+  private byte[] uncompressedData = null;
+  private final Codec rssCodec;
+  private Integer partitionId;
+  private final ExceptionReporter exceptionReporter;
+
+  private final AtomicInteger issuedCnt = new AtomicInteger(0);
+
+  /**
+   * @param inputAttemptIdentifier identifier of the input, used in log/error messages
+   * @param partitionId partition this fetcher reads
+   * @param merger merge manager that receives the decompressed blocks
+   * @param tezCounters counters in which the shuffle IO error counter is registered
+   * @param shuffleReadClient client that supplies the compressed blocks
+   * @param totalBlockCount expected number of blocks (only used for logging)
+   * @param rssConf configuration used to create the decompression codec
+   * @param exceptionReporter receives failures raised by {@link #callInternal()}
+   */
+  public RssTezShuffleDataFetcher(InputAttemptIdentifier inputAttemptIdentifier,
+      Integer partitionId,
+      MergeManager merger,
+      TezCounters tezCounters,
+      ShuffleReadClient shuffleReadClient,
+      long totalBlockCount,
+      RssConf rssConf,
+      ExceptionReporter exceptionReporter) {
+    this.merger = merger;
+    this.partitionId = partitionId;
+    this.inputAttemptIdentifier = inputAttemptIdentifier;
+    this.exceptionReporter = exceptionReporter;
+    ioErrs = tezCounters.findCounter(SHUFFLE_ERR_GRP_NAME, RssTezShuffleDataFetcher.ShuffleErrors.IO_ERROR.toString());
+    this.shuffleReadClient = shuffleReadClient;
+    this.totalBlockCount = totalBlockCount;
+
+    this.rssCodec = Codec.newInstance(rssConf);
+    LOG.info("RssTezShuffleDataFetcher, partitionId:{}, inputAttemptIdentifier:{}.",
+        this.partitionId, this.inputAttemptIdentifier);
+  }
+
+  /**
+   * Runs the fetch loop. An interrupt restores the thread's interrupt flag; any other
+   * failure is forwarded to the {@link ExceptionReporter}.
+   *
+   * @return always {@code null}
+   */
+  @Override
+  public Void callInternal() {
+    try {
+      fetchAllRssBlocks();
+    } catch (InterruptedException ie) {
+      //might not be respected when fetcher is in progress / server is busy. TEZ-711
+      //Set the status back
+      LOG.warn(ie.getMessage(), ie);
+      Thread.currentThread().interrupt();
+      return null;
+    } catch (Throwable t) {
+      LOG.warn(t.getMessage(), t);
+      exceptionReporter.reportException(t);
+      // Shuffle knows how to deal with failures post shutdown via the onFailure hook
+    }
+    return null;
+  }
+
+  /**
+   * Repeatedly waits for any in-memory merge and copies the next block until the fetcher is
+   * stopped (end of data reached or {@link #shutDown()} called).
+   *
+   * @throws IOException if reserving memory for a block fails
+   * @throws InterruptedException propagated from {@code MergeManager#waitForInMemoryMerge()}
+   */
+  public void fetchAllRssBlocks() throws IOException, InterruptedException {
+    while (!stopped) {
+      try {
+        // If merge is on, block
+        merger.waitForInMemoryMerge();
+        // Do shuffle
+        copyFromRssServer();
+      } catch (Exception e) {
+        LOG.warn(e.getMessage(), e);
+        throw e;
+      }
+    }
+  }
+
+  /**
+   * Performs one step of the fetch loop.
+   *
+   * <p>If no block is pending, it reads and decompresses the next block. It then tries to hand
+   * the decompressed block to the merger. If the merger asks to wait, the block stays pending
+   * for the next call. When the read client returns no more data, the client is closed, its
+   * processed block ids are checked and the fetcher stops.
+   *
+   * @throws IOException if reserving memory for the block fails
+   */
+  @VisibleForTesting
+  public void copyFromRssServer() throws IOException {
+    CompressedShuffleBlock compressedBlock = null;
+    ByteBuffer compressedData = null;
+    // fetch a block
+    if (!hasPendingData) {
+      final long startFetch = System.currentTimeMillis();
+      compressedBlock = shuffleReadClient.readShuffleBlockData();
+      if (compressedBlock != null) {
+        compressedData = compressedBlock.getByteBuffer();
+      }
+      long fetchDuration = System.currentTimeMillis() - startFetch;
+      readTime += fetchDuration;
+    }
+
+    // uncompress the block
+    if (!hasPendingData && compressedData != null) {
+      final long startDecompress = System.currentTimeMillis();
+      int uncompressedLen = compressedBlock.getUncompressLength();
+      ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+      rssCodec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
+      uncompressedData = decompressedBuffer.array();
+      unCompressionLength += uncompressedLen;
+      long decompressDuration = System.currentTimeMillis() - startDecompress;
+      decompressTime += decompressDuration;
+    }
+
+    if (uncompressedData != null) {
+      // start to merge
+      final long startSerialization = System.currentTimeMillis();
+      // issueMapOutputMerge() sets hasPendingData on failure, so capture the prior state here
+      final boolean wasPending = hasPendingData;
+      if (issueMapOutputMerge()) {
+        long serializationDuration = System.currentTimeMillis() - startSerialization;
+        serializeTime += serializationDuration;
+        // if reserve successes, reset status for next fetch
+        if (wasPending) {
+          // the wait lasted from the first failed reservation until this successful attempt
+          waitTime += startSerialization - startWait;
+        }
+        hasPendingData = false;
+        uncompressedData = null;
+      } else {
+        // if reserve fail, return and wait; start the clock only on the first failure
+        if (!wasPending) {
+          startWait = startSerialization;
+        }
+        return;
+      }
+
+      // update some status
+      copyBlockCount++;
+      copyTime = readTime + decompressTime + serializeTime + waitTime;
+      updateStatus();
+    } else {
+      // finish reading data, close related reader and check data consistent
+      shuffleReadClient.close();
+      shuffleReadClient.checkProcessedBlockIds();
+      shuffleReadClient.logStatics();
+      LOG.info("Reduce task " + inputAttemptIdentifier + " read block cnt: " + copyBlockCount
+          + " cost " + readTime + " ms to fetch and "
+          + decompressTime + " ms to decompress with unCompressionLength["
+          + unCompressionLength + "] and " + serializeTime + " ms to serialize and "
+          + waitTime + " ms to wait resource" + ", copy time:" + copyTime);
+      stopFetch();
+    }
+  }
+
+  public Integer getPartitionId() {
+    return partitionId;
+  }
+
+  public void setPartitionId(Integer partitionId) {
+    this.partitionId = partitionId;
+  }
+
+  /**
+   * Reserves a map output for the current decompressed block and writes and commits it.
+   *
+   * @return {@code true} if the block was committed; {@code false} if the merger returned no
+   *     output (WAIT), in which case the block is kept as pending
+   * @throws IOException if the reservation fails (the IO error counter is incremented)
+   * @throws RssException if writing or committing the block fails; the original failure is
+   *     attached as the cause
+   */
+  private boolean issueMapOutputMerge() throws IOException {
+    // Allocate a MapOutput (either in-memory or on-disk) to put uncompressed block
+    // In Rss, a MapOutput is sent as multiple blocks, so the reducer needs to
+    // treat each "block" as a faked "mapout".
+    // To avoid name conflicts, we use getNextUniqueTaskAttemptID instead.
+    // It will generate a unique TaskAttemptID(increased_seq++, 0).
+    InputAttemptIdentifier uniqueInputAttemptIdentifier = getNextUniqueInputAttemptIdentifier();
+    MapOutput mapOutput = null;
+    try {
+      issuedCnt.incrementAndGet();
+      // per-block message: keep it at debug level to avoid flooding the task log
+      LOG.debug("IssueMapOutputMerge, uncompressedData length:{}, issueCnt:{}, totalBlockCount:{}",
+          uncompressedData.length, issuedCnt.get(), totalBlockCount);
+      mapOutput = merger.reserve(uniqueInputAttemptIdentifier, uncompressedData.length, 0, 1);
+    } catch (IOException ioe) {
+      // kill this reduce attempt
+      ioErrs.increment(1);
+      throw ioe;
+    }
+    // Check if we can shuffle *now* ...
+    if (mapOutput == null) {
+      LOG.info("RssTezShuffleDataFetcher - MergeManager returned status WAIT ...");
+      // Not an error but wait to process data.
+      // Use a retry flag to avoid re-fetch and re-uncompress.
+      hasPendingData = true;
+      waitCount++;
+      return false;
+    }
+
+    // write data to mapOutput
+    try {
+      RssTezBypassWriter.write(mapOutput, uncompressedData);
+      // let the merger knows this block is ready for merging
+      mapOutput.commit();
+    } catch (Throwable t) {
+      ioErrs.increment(1);
+      mapOutput.abort();
+      // keep the original failure as the cause so its stack trace is not lost
+      throw new RssException("Reduce: " + inputAttemptIdentifier + " cannot write block to "
+          + mapOutput.getClass().getSimpleName() + " due to: " + t.getClass().getName(), t);
+    }
+    return true;
+  }
+
+  private InputAttemptIdentifier getNextUniqueInputAttemptIdentifier() {
+    return new InputAttemptIdentifier(uniqueMapId++, 0);
+  }
+
+  // currently a no-op hook
+  private void updateStatus() {
+  }
+
+  /** Returns how many times the merger answered WAIT for a reservation. */
+  @VisibleForTesting
+  public int getRetryCount() {
+    return waitCount;
+  }
+
+  private void stopFetch() {
+    LOG.info("RssTezShuffleDataFetcher stop fetch");
+    stopped = true;
+  }
+
+  /** Asks the fetch loop to stop after its current iteration. */
+  public void shutDown() {
+    stopFetch();
+    LOG.info("RssTezShuffleDataFetcher shutdown");
+  }
+}
diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcherTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcherTest.java
new file mode 100644
index 00000000..aa7d994b
--- /dev/null
+++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/shuffle/orderedgrouped/RssTezShuffleDataFetcherTest.java
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.orderedgrouped;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.LinkedListMultimap;
+import com.google.common.collect.ListMultimap;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalDirAllocator;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BoundedByteArrayOutputStream;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.tez.common.RssTezConfig;
+import org.apache.tez.common.TezRuntimeFrameworkConfigs;
+import org.apache.tez.common.counters.GenericCounter;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.common.counters.TezCounters;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.runtime.api.InputContext;
+import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+import org.apache.tez.runtime.library.common.ValuesIterator;
+import org.apache.tez.runtime.library.common.combine.Combiner;
+import org.apache.tez.runtime.library.common.comparator.TezBytesComparator;
+import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator;
+import org.junit.jupiter.api.Test;
+import org.mockito.internal.util.collections.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleReadClient;
+import org.apache.uniffle.client.response.CompressedShuffleBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.compression.Lz4Codec;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Verifies that {@link RssTezShuffleDataFetcher} feeds blocks read from a
+ * {@link ShuffleReadClient} into a {@link MergeManager}, and that the merged output read back
+ * through a {@link ValuesIterator} matches the originally written key/value pairs.
+ */
+public class RssTezShuffleDataFetcherTest {
+  private static final Logger LOG = LoggerFactory.getLogger(RssTezShuffleDataFetcherTest.class);
+
+  enum TestWithComparator {
+    LONG, INT, BYTES, TEZ_BYTES, TEXT
+  }
+
+  Configuration conf;
+  FileSystem fs;
+
+  final Class keyClass;
+  final Class valClass;
+  final RawComparator comparator;
+  final boolean expectedTestResult;
+
+  int mergeFactor;
+  //For storing original data
+  final ListMultimap<Writable, Writable> originalData;
+
+  TezRawKeyValueIterator rawKeyValueIterator;
+
+  Path baseDir;
+  Path tmpDir;
+  // Serialized IFile streams, one per mocked shuffle block. This is an instance field (not
+  // static) so blocks from an earlier run cannot leak into a later one, which would not match
+  // the per-instance originalData.
+  final List<byte[]> bytesData = new ArrayList<>();
+  static Codec codec = new Lz4Codec();
+
+  public RssTezShuffleDataFetcherTest() throws IOException, ClassNotFoundException {
+    this.keyClass = Class.forName("org.apache.hadoop.io.Text");
+    this.valClass = Class.forName("org.apache.hadoop.io.Text");
+    this.comparator = getComparator(TestWithComparator.TEXT);
+    this.expectedTestResult = true;
+    originalData = LinkedListMultimap.create();
+    setupConf();
+  }
+
+  private void setupConf() throws IOException, ClassNotFoundException {
+    mergeFactor = 2;
+    conf = new Configuration();
+    conf.setInt(TezRuntimeConfiguration.TEZ_RUNTIME_IO_SORT_FACTOR, mergeFactor);
+    conf.setClass(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS,
+        Class.forName("org.apache.tez.runtime.library.common.comparator.TezBytesComparator"),
+        Class.forName("org.apache.hadoop.io.WritableComparator"));
+    baseDir = new Path(".", this.getClass().getName());
+    String localDirs = baseDir.toString();
+    conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, localDirs);
+    fs = FileSystem.getLocal(conf);
+  }
+
+  @Test
+  public void testIteratorWithInMemoryReader() throws Throwable {
+    fs.mkdirs(baseDir);
+    tmpDir = new Path(baseDir, "tmp");
+    try {
+      ValuesIterator iterator = createIterator();
+      verifyIteratorData(iterator);
+    } finally {
+      // clean up even when verification fails, so a rerun starts from a clean state
+      fs.delete(baseDir, true);
+      originalData.clear();
+      bytesData.clear();
+    }
+  }
+
+  /** Asserts that advancing an exhausted iterator fails with the expected IOException. */
+  private void getNextFromFinishedIterator(ValuesIterator iterator) {
+    try {
+      iterator.moveToNext();
+      fail();
+    } catch (IOException e) {
+      assertTrue(e.getMessage().contains("Please check if you are invoking moveToNext()"));
+    }
+  }
+
+  /**
+   * Checks that the data in {@code valuesIterator} matches the original input sorted by
+   * {@link #comparator}: the same keys in the same order and, for each key, the same values.
+   * Also checks that the iterator is exhausted afterwards and rejects further calls.
+   *
+   * @param valuesIterator iterator over the merged output
+   * @throws IOException if reading from the iterator fails
+   */
+  private void verifyIteratorData(ValuesIterator valuesIterator) throws IOException {
+    boolean result = true;
+
+    // sort original data based on comparator
+    ListMultimap<Writable, Writable> sortedMap =
+        new ImmutableListMultimap.Builder<Writable, Writable>()
+            .orderKeysBy(this.comparator).putAll(originalData).build();
+
+    Set<Map.Entry<Writable, Writable>> oriKeySet = Sets.newSet();
+    oriKeySet.addAll(sortedMap.entries());
+
+    //Iterate through sorted data and valuesIterator for verification
+    for (Map.Entry<Writable, Writable> entry : oriKeySet) {
+      assertTrue(valuesIterator.moveToNext());
+      Writable oriKey = entry.getKey();
+      //Verify if the key and the original key are same
+      if (!oriKey.equals((Writable) valuesIterator.getKey())) {
+        result = false;
+        break;
+      }
+
+      int valueCount = 0;
+      //Verify values
+      Iterator<Writable> vItr = valuesIterator.getValues().iterator();
+      for (Writable val : sortedMap.get(oriKey)) {
+        assertTrue(vItr.hasNext());
+        //Verify if the values are same
+        if (!val.equals((Writable) vItr.next())) {
+          result = false;
+          break;
+        }
+        valueCount++;
+      }
+      assertTrue(valueCount > 0);
+    }
+    assertTrue(result);
+    assertFalse(valuesIterator.moveToNext());
+    getNextFromFinishedIterator(valuesIterator);
+  }
+
+  /**
+   * Creates the in-memory sample blocks, runs them through {@link RssTezShuffleDataFetcher}
+   * into a {@link MergeManager}, and returns a {@link ValuesIterator} over the merged result.
+   *
+   * @return ValuesIterator over the merged data
+   * @throws Throwable if creating, fetching or merging the data fails
+   */
+  @SuppressWarnings("unchecked")
+  private ValuesIterator createIterator() throws Throwable {
+    createInMemStreams();
+
+    ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(bytesData);
+
+    FileSystem localFS = FileSystem.getLocal(this.conf);
+    LocalDirAllocator localDirAllocator = new LocalDirAllocator(TezRuntimeFrameworkConfigs.LOCAL_DIRS);
+
+    InputContext inputContext = createTezInputContext();
+
+    Combiner combiner = TezRuntimeUtils.instantiateCombiner(conf, inputContext);
+
+    MergeManager mergeManager = new MergeManager(
+        this.conf,
+        localFS,
+        localDirAllocator,
+        inputContext,
+        combiner,
+        null,
+        null,
+        null,
+        null,
+        1024 * 1024 * 256,
+        null,
+        false,
+        0);
+
+    RssTezShuffleDataFetcher fetcher = new RssTezShuffleDataFetcher(
+        new InputAttemptIdentifier(1, 0),
+        9,
+        mergeManager, new TezCounters(), shuffleReadClient, bytesData.size(),
+        RssTezConfig.toRssConf(conf), null);
+
+    fetcher.fetchAllRssBlocks();
+
+    rawKeyValueIterator = mergeManager.close(true);
+
+    return new ValuesIterator(rawKeyValueIterator, comparator,
+        keyClass, valClass, conf, (TezCounter) new GenericCounter("inputKeyCounter", "y3"),
+        (TezCounter) new GenericCounter("inputValueCounter", "y4"));
+  }
+
+  private RawComparator getComparator(TestWithComparator comparator) {
+    switch (comparator) {
+      case LONG:
+        return new LongWritable.Comparator();
+      case INT:
+        return new IntWritable.Comparator();
+      case BYTES:
+        return new BytesWritable.Comparator();
+      case TEZ_BYTES:
+        return new TezBytesComparator();
+      case TEXT:
+        return new Text.Comparator();
+      default:
+        return null;
+    }
+  }
+
+  /**
+   * Serializes several IFile streams of sample data with {@link InMemoryWriter}. Each stream is
+   * appended to {@link #bytesData}, and every written pair is recorded in {@link #originalData}.
+   *
+   * @throws IOException if serialization fails
+   */
+  public void createInMemStreams() throws IOException {
+    int numberOfStreams = 5;
+    LOG.info("No of streams : " + numberOfStreams);
+
+    SerializationFactory serializationFactory = new SerializationFactory(conf);
+    Serializer keySerializer = serializationFactory.getSerializer(keyClass);
+    Serializer valueSerializer = serializationFactory.getSerializer(valClass);
+
+    DataOutputBuffer keyBuf = new DataOutputBuffer();
+    DataOutputBuffer valBuf = new DataOutputBuffer();
+    DataInputBuffer keyIn = new DataInputBuffer();
+    DataInputBuffer valIn = new DataInputBuffer();
+    keySerializer.open(keyBuf);
+    valueSerializer.open(valBuf);
+
+    for (int i = 0; i < numberOfStreams; i++) {
+      BoundedByteArrayOutputStream bout = new BoundedByteArrayOutputStream(1024 * 1024 * 10);
+      InMemoryWriter writer = new InMemoryWriter(bout);
+      Map<Writable, Writable> data = createData();
+      //write data
+      for (Map.Entry<Writable, Writable> entry : data.entrySet()) {
+        keySerializer.serialize(entry.getKey());
+        valueSerializer.serialize(entry.getValue());
+        keyIn.reset(keyBuf.getData(), 0, keyBuf.getLength());
+        valIn.reset(valBuf.getData(), 0, valBuf.getLength());
+        writer.append(keyIn, valIn);
+
+        originalData.put(entry.getKey(), entry.getValue());
+        keyBuf.reset();
+        valBuf.reset();
+        keyIn.reset();
+        valIn.reset();
+      }
+      data.clear();
+      writer.close();
+      // getBuffer() returns the whole 10 MB backing array; keep only the bytes actually written
+      byte[] written = new byte[bout.size()];
+      System.arraycopy(bout.getBuffer(), 0, written, 0, written.length);
+      bytesData.add(written);
+    }
+  }
+
+  private InputContext createTezInputContext() {
+    TezCounters counters = new TezCounters();
+    InputContext inputContext = mock(InputContext.class);
+    doReturn(1024 * 1024 * 100L).when(inputContext).getTotalMemoryAvailableToTask();
+    doReturn(counters).when(inputContext).getCounters();
+    doReturn(1).when(inputContext).getInputIndex();
+    doReturn("srcVertex").when(inputContext).getSourceVertexName();
+    doReturn(1).when(inputContext).getTaskVertexIndex();
+    doReturn(UserPayload.create(ByteBuffer.wrap(new byte[1024]))).when(inputContext).getUserPayload();
+    doReturn("test_input").when(inputContext).getUniqueIdentifier();
+    return inputContext;
+  }
+
+  /** Builds 10 sorted Text pairs "0"->"0" ... "9"->"9". */
+  private Map<Writable, Writable> createData() {
+    Map<Writable, Writable> map = new TreeMap<Writable, Writable>(comparator);
+    for (int j = 0; j < 10; j++) {
+      Writable key = new Text(String.valueOf(j));
+      Writable value = new Text(String.valueOf(j));
+      map.put(key, value);
+    }
+    return map;
+  }
+
+  /** A {@link ShuffleReadClient} that serves the given byte arrays as compressed blocks in order. */
+  static class MockedShuffleReadClient implements ShuffleReadClient {
+    List<CompressedShuffleBlock> blocks;
+    int index = 0;
+
+    MockedShuffleReadClient(List<byte[]> data) {
+      this.blocks = new LinkedList<>();
+      data.forEach(bytes -> {
+        byte[] compressed = codec.compress(bytes);
+        blocks.add(new CompressedShuffleBlock(ByteBuffer.wrap(compressed), bytes.length));
+      });
+    }
+
+    @Override
+    public CompressedShuffleBlock readShuffleBlockData() {
+      if (index < blocks.size()) {
+        return blocks.get(index++);
+      } else {
+        return null;
+      }
+    }
+
+    @Override
+    public void checkProcessedBlockIds() {
+    }
+
+    @Override
+    public void close() {
+    }
+
+    @Override
+    public void logStatics() {
+    }
+  }
+}