Posted to commits@tez.apache.org by ss...@apache.org on 2013/09/25 09:31:40 UTC

[33/50] [abbrv] Rename tez-engine-api to tez-runtime-api and split tez-engine into two modules:
- tez-engine-library for user-visible Input/Output/Processor implementations
- tez-engine-internals for framework internals

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Fetcher.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Fetcher.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Fetcher.java
new file mode 100644
index 0000000..f5d1802
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Fetcher.java
@@ -0,0 +1,624 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.HttpURLConnection;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.net.URLConnection;
+import java.security.GeneralSecurityException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import javax.crypto.SecretKey;
+import javax.net.ssl.HttpsURLConnection;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IOUtils;
+import org.apache.hadoop.io.compress.CodecPool;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.io.compress.Decompressor;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import org.apache.hadoop.security.ssl.SSLFactory;
+import org.apache.hadoop.util.ReflectionUtils;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.runtime.api.TezInputContext;
+import org.apache.tez.runtime.library.common.ConfigUtils;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.security.SecureShuffleUtils;
+import org.apache.tez.runtime.library.common.shuffle.impl.MapOutput.Type;
+import org.apache.tez.runtime.library.common.sort.impl.IFileInputStream;
+
+import com.google.common.annotations.VisibleForTesting;
+
+class Fetcher extends Thread {
+  
+  private static final Log LOG = LogFactory.getLog(Fetcher.class);
+  
+  /** Basic/unit connection timeout (in milliseconds) */
+  private final static int UNIT_CONNECT_TIMEOUT = 60 * 1000;
+
+  private static enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
+                                    CONNECTION, WRONG_REDUCE}
+  
+  private final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";
+  private final TezCounter connectionErrs;
+  private final TezCounter ioErrs;
+  private final TezCounter wrongLengthErrs;
+  private final TezCounter badIdErrs;
+  private final TezCounter wrongMapErrs;
+  private final TezCounter wrongReduceErrs;
+  private final MergeManager merger;
+  private final ShuffleScheduler scheduler;
+  private final ShuffleClientMetrics metrics;
+  private final Shuffle shuffle;
+  private final int id;
+  private static int nextId = 0;
+  
+  private final int connectionTimeout;
+  private final int readTimeout;
+  
+  // Decompression of map-outputs
+  private final CompressionCodec codec;
+  private final Decompressor decompressor;
+  private final SecretKey jobTokenSecret;
+
+  private volatile boolean stopped = false;
+
+  private Configuration job;
+
+  private static boolean sslShuffle;
+  private static SSLFactory sslFactory;
+
+  public Fetcher(Configuration job, 
+      ShuffleScheduler scheduler, MergeManager merger,
+      ShuffleClientMetrics metrics,
+      Shuffle shuffle, SecretKey jobTokenSecret, TezInputContext inputContext) throws IOException {
+    this.job = job;
+    this.scheduler = scheduler;
+    this.merger = merger;
+    this.metrics = metrics;
+    this.shuffle = shuffle;
+    this.id = ++nextId;
+    this.jobTokenSecret = jobTokenSecret;
+    ioErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+        ShuffleErrors.IO_ERROR.toString());
+    wrongLengthErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+        ShuffleErrors.WRONG_LENGTH.toString());
+    badIdErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+        ShuffleErrors.BAD_ID.toString());
+    wrongMapErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+        ShuffleErrors.WRONG_MAP.toString());
+    connectionErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+        ShuffleErrors.CONNECTION.toString());
+    wrongReduceErrs = inputContext.getCounters().findCounter(SHUFFLE_ERR_GRP_NAME,
+        ShuffleErrors.WRONG_REDUCE.toString());
+
+    if (ConfigUtils.isIntermediateInputCompressed(job)) {
+      Class<? extends CompressionCodec> codecClass =
+          ConfigUtils.getIntermediateInputCompressorClass(job, DefaultCodec.class);
+      codec = ReflectionUtils.newInstance(codecClass, job);
+      decompressor = CodecPool.getDecompressor(codec);
+    } else {
+      codec = null;
+      decompressor = null;
+    }
+
+    this.connectionTimeout = 
+        job.getInt(TezJobConfig.TEZ_RUNTIME_SHUFFLE_CONNECT_TIMEOUT,
+            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_STALLED_COPY_TIMEOUT);
+    this.readTimeout = 
+        job.getInt(TezJobConfig.TEZ_RUNTIME_SHUFFLE_READ_TIMEOUT, 
+            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_READ_TIMEOUT);
+
+    setName("fetcher#" + id);
+    setDaemon(true);
+
+    synchronized (Fetcher.class) {
+      sslShuffle = job.getBoolean(TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL,
+          TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_SSL);
+      if (sslShuffle && sslFactory == null) {
+        sslFactory = new SSLFactory(SSLFactory.Mode.CLIENT, job);
+        try {
+          sslFactory.init();
+        } catch (Exception ex) {
+          sslFactory.destroy();
+          throw new RuntimeException(ex);
+        }
+      }
+    }
+  }
+  
+  public void run() {
+    try {
+      while (!stopped && !Thread.currentThread().isInterrupted()) {
+        MapHost host = null;
+        try {
+          // If merge is on, block
+          merger.waitForInMemoryMerge();
+
+          // Get a host to shuffle from
+          host = scheduler.getHost();
+          metrics.threadBusy();
+
+          // Shuffle
+          copyFromHost(host);
+        } finally {
+          if (host != null) {
+            scheduler.freeHost(host);
+            metrics.threadFree();            
+          }
+        }
+      }
+    } catch (InterruptedException ie) {
+      return;
+    } catch (Throwable t) {
+      shuffle.reportException(t);
+    }
+  }
+
+  public void shutDown() throws InterruptedException {
+    this.stopped = true;
+    interrupt();
+    try {
+      join(5000);
+    } catch (InterruptedException ie) {
+      LOG.warn("Got interrupt while joining " + getName(), ie);
+    }
+    if (sslFactory != null) {
+      sslFactory.destroy();
+    }
+  }
+
+  @VisibleForTesting
+  protected HttpURLConnection openConnection(URL url) throws IOException {
+    HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+    if (sslShuffle) {
+      HttpsURLConnection httpsConn = (HttpsURLConnection) conn;
+      try {
+        httpsConn.setSSLSocketFactory(sslFactory.createSSLSocketFactory());
+      } catch (GeneralSecurityException ex) {
+        throw new IOException(ex);
+      }
+      httpsConn.setHostnameVerifier(sslFactory.getHostnameVerifier());
+    }
+    return conn;
+  }
+  
+  /**
+   * The crux of the matter: fetch all map-outputs currently available
+   * on the given host over a single connection.
+   * 
+   * @param host {@link MapHost} from which to shuffle available map-outputs.
+   */
+  @VisibleForTesting
+  protected void copyFromHost(MapHost host) throws IOException {
+    // Get completed maps on 'host'
+    List<InputAttemptIdentifier> srcAttempts = scheduler.getMapsForHost(host);
+    
+    // Sanity check to catch hosts with only 'OBSOLETE' maps, 
+    // especially at the tail of large jobs
+    if (srcAttempts.size() == 0) {
+      return;
+    }
+    
+    if(LOG.isDebugEnabled()) {
+      LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: "
+        + srcAttempts);
+    }
+    
+    // List of maps to be fetched yet
+    Set<InputAttemptIdentifier> remaining = new HashSet<InputAttemptIdentifier>(srcAttempts);
+    
+    // Construct the url and connect
+    DataInputStream input;
+    boolean connectSucceeded = false;
+    
+    try {
+      URL url = getMapOutputURL(host, srcAttempts);
+      HttpURLConnection connection = openConnection(url);
+      
+      // generate hash of the url
+      String msgToEncode = SecureShuffleUtils.buildMsgFrom(url);
+      String encHash = SecureShuffleUtils.hashFromString(msgToEncode, jobTokenSecret);
+      
+      // put url hash into http header
+      connection.addRequestProperty(
+          SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash);
+      // set the read timeout
+      connection.setReadTimeout(readTimeout);
+      // put shuffle version into http header
+      connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+          ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+      connection.addRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+          ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+      connect(connection, connectionTimeout);
+      connectSucceeded = true;
+      input = new DataInputStream(connection.getInputStream());
+
+      // Validate response code
+      int rc = connection.getResponseCode();
+      if (rc != HttpURLConnection.HTTP_OK) {
+        throw new IOException(
+            "Got invalid response code " + rc + " from " + url +
+            ": " + connection.getResponseMessage());
+      }
+      // get the shuffle version
+      if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(
+          connection.getHeaderField(ShuffleHeader.HTTP_HEADER_NAME))
+          || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(
+              connection.getHeaderField(ShuffleHeader.HTTP_HEADER_VERSION))) {
+        throw new IOException("Incompatible shuffle response version");
+      }
+      // get the replyHash which is HMac of the encHash we sent to the server
+      String replyHash = connection.getHeaderField(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH);
+      if(replyHash==null) {
+        throw new IOException("security validation of TT Map output failed");
+      }
+      LOG.debug("url="+msgToEncode+";encHash="+encHash+";replyHash="+replyHash);
+      // verify that replyHash is HMac of encHash
+      SecureShuffleUtils.verifyReply(replyHash, encHash, jobTokenSecret);
+      LOG.info("for url="+msgToEncode+" sent hash and receievd reply");
+    } catch (IOException ie) {
+      ioErrs.increment(1);
+      LOG.warn("Failed to connect to " + host + " with " + remaining.size() + 
+               " map outputs", ie);
+
+      // If connect did not succeed, just mark all the maps as failed,
+      // indirectly penalizing the host
+      if (!connectSucceeded) {
+        for(InputAttemptIdentifier left: remaining) {
+          scheduler.copyFailed(left, host, connectSucceeded);
+        }
+      } else {
+        // If we got a read error at this stage, it implies there was a problem
+        // with the first map, typically a lost map. So, penalize only that map
+        // and add the rest
+        InputAttemptIdentifier firstMap = srcAttempts.get(0);
+        scheduler.copyFailed(firstMap, host, connectSucceeded);
+      }
+      
+      // Add back all the remaining maps, WITHOUT marking them as failed
+      for(InputAttemptIdentifier left: remaining) {
+        // TODO Should the first one be skipped ?
+        scheduler.putBackKnownMapOutput(host, left);
+      }
+      
+      return;
+    }
+    
+    try {
+      // Loop through available map-outputs and fetch them
+      // On any error, failedTasks is not null and we exit
+      // after putting back the remaining maps to the 
+      // yet_to_be_fetched list and marking the failed tasks.
+      InputAttemptIdentifier[] failedTasks = null;
+      while (!remaining.isEmpty() && failedTasks == null) {
+        failedTasks = copyMapOutput(host, input, remaining);
+      }
+      
+      if(failedTasks != null && failedTasks.length > 0) {
+        LOG.warn("copyMapOutput failed for tasks "+Arrays.toString(failedTasks));
+        for(InputAttemptIdentifier left: failedTasks) {
+          scheduler.copyFailed(left, host, true);
+        }
+      }
+      
+      IOUtils.cleanup(LOG, input);
+      
+      // Sanity check
+      if (failedTasks == null && !remaining.isEmpty()) {
+        throw new IOException("server didn't return all expected map outputs: "
+            + remaining.size() + " left.");
+      }
+    } finally {
+      for (InputAttemptIdentifier left : remaining) {
+        scheduler.putBackKnownMapOutput(host, left);
+      }
+    }
+  }
+  
+  private static InputAttemptIdentifier[] EMPTY_ATTEMPT_ID_ARRAY = new InputAttemptIdentifier[0];
+  
+  private InputAttemptIdentifier[] copyMapOutput(MapHost host,
+                                DataInputStream input,
+                                Set<InputAttemptIdentifier> remaining) {
+    MapOutput mapOutput = null;
+    InputAttemptIdentifier srcAttemptId = null;
+    long decompressedLength = -1;
+    long compressedLength = -1;
+    
+    try {
+      long startTime = System.currentTimeMillis();
+      int forReduce = -1;
+      //Read the shuffle header
+      try {
+        ShuffleHeader header = new ShuffleHeader();
+        header.readFields(input);
+        String pathComponent = header.mapId;
+        srcAttemptId = scheduler.getIdentifierForPathComponent(pathComponent);
+        compressedLength = header.compressedLength;
+        decompressedLength = header.uncompressedLength;
+        forReduce = header.forReduce;
+      } catch (IllegalArgumentException e) {
+        badIdErrs.increment(1);
+        LOG.warn("Invalid map id ", e);
+        //Don't know which one was bad, so consider all of them as bad
+        return remaining.toArray(new InputAttemptIdentifier[remaining.size()]);
+      }
+
+ 
+      // Do some basic sanity verification
+      if (!verifySanity(compressedLength, decompressedLength, forReduce,
+          remaining, srcAttemptId)) {
+        return new InputAttemptIdentifier[] {srcAttemptId};
+      }
+      
+      if(LOG.isDebugEnabled()) {
+        LOG.debug("header: " + srcAttemptId + ", len: " + compressedLength + 
+            ", decomp len: " + decompressedLength);
+      }
+      
+      // Get the location for the map output - either in-memory or on-disk
+      mapOutput = merger.reserve(srcAttemptId, decompressedLength, id);
+      
+      // Check if we can shuffle *now* ...
+      if (mapOutput.getType() == Type.WAIT) {
+        LOG.info("fetcher#" + id + " - MergerManager returned Status.WAIT ...");
+        //Not an error but wait to process data.
+        return EMPTY_ATTEMPT_ID_ARRAY;
+      } 
+      
+      // Go!
+      LOG.info("fetcher#" + id + " about to shuffle output of map " + 
+               mapOutput.getAttemptIdentifier() + " decomp: " +
+               decompressedLength + " len: " + compressedLength + " to " +
+               mapOutput.getType());
+      if (mapOutput.getType() == Type.MEMORY) {
+        shuffleToMemory(host, mapOutput, input, 
+                        (int) decompressedLength, (int) compressedLength);
+      } else {
+        shuffleToDisk(host, mapOutput, input, compressedLength);
+      }
+      
+      // Inform the shuffle scheduler
+      long endTime = System.currentTimeMillis();
+      scheduler.copySucceeded(srcAttemptId, host, compressedLength, 
+                              endTime - startTime, mapOutput);
+      // Note successful shuffle
+      remaining.remove(srcAttemptId);
+      metrics.successFetch();
+      return null;
+    } catch (IOException ioe) {
+      ioErrs.increment(1);
+      if (srcAttemptId == null || mapOutput == null) {
+        LOG.info("fetcher#" + id + " failed to read map header" + 
+                 srcAttemptId + " decomp: " + 
+                 decompressedLength + ", " + compressedLength, ioe);
+        if(srcAttemptId == null) {
+          return remaining.toArray(new InputAttemptIdentifier[remaining.size()]);
+        } else {
+          return new InputAttemptIdentifier[] {srcAttemptId};
+        }
+      }
+      
+      LOG.warn("Failed to shuffle output of " + srcAttemptId + 
+               " from " + host.getHostName(), ioe); 
+
+      // Inform the shuffle-scheduler
+      mapOutput.abort();
+      metrics.failedFetch();
+      return new InputAttemptIdentifier[] {srcAttemptId};
+    }
+
+  }
+  
+  /**
+   * Do some basic verification on the input received -- Being defensive
+   * @param compressedLength
+   * @param decompressedLength
+   * @param forReduce
+   * @param remaining
+   * @param srcAttemptId
+   * @return true if the verification succeeded, false otherwise
+   */
+  private boolean verifySanity(long compressedLength, long decompressedLength,
+      int forReduce, Set<InputAttemptIdentifier> remaining, InputAttemptIdentifier srcAttemptId) {
+    if (compressedLength < 0 || decompressedLength < 0) {
+      wrongLengthErrs.increment(1);
+      LOG.warn(getName() + " invalid lengths in map output header: id: " +
+          srcAttemptId + " len: " + compressedLength + ", decomp len: " + 
+               decompressedLength);
+      return false;
+    }
+    
+    int reduceStartId = shuffle.getReduceStartId();
+    int reduceRange = shuffle.getReduceRange();
+    if (forReduce < reduceStartId || forReduce >= reduceStartId+reduceRange) {
+      wrongReduceErrs.increment(1);
+      LOG.warn(getName() + " data for the wrong reduce map: " +
+               srcAttemptId + " len: " + compressedLength + " decomp len: " +
+               decompressedLength + " for reduce " + forReduce);
+      return false;
+    }
+
+    // Sanity check
+    if (!remaining.contains(srcAttemptId)) {
+      wrongMapErrs.increment(1);
+      LOG.warn("Invalid map-output! Received output for " + srcAttemptId);
+      return false;
+    }
+    
+    return true;
+  }
+
+  /**
+   * Create the map-output URL, containing the path components of all
+   * the requested map ids, separated by commas.
+   * @param host
+   * @param srcAttempts
+   * @return the map-output URL
+   * @throws MalformedURLException
+   */
+  private URL getMapOutputURL(MapHost host, List<InputAttemptIdentifier> srcAttempts
+                              )  throws MalformedURLException {
+    // Get the base url
+    StringBuffer url = new StringBuffer(host.getBaseUrl());
+    
+    boolean first = true;
+    for (InputAttemptIdentifier mapId : srcAttempts) {
+      if (!first) {
+        url.append(",");
+      }
+      url.append(mapId.getPathComponent());
+      first = false;
+    }
+   
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("MapOutput URL for " + host + " -> " + url.toString());
+    }
+    return new URL(url.toString());
+  }
+  
+  /** 
+   * The connection establishment is attempted multiple times and given up
+   * only when the full timeout budget is exhausted. Instead of a single
+   * connect with a timeout of X, we retry with a unit timeout x < X.
+   */
+  private void connect(URLConnection connection, int connectionTimeout)
+  throws IOException {
+    int unit = 0;
+    if (connectionTimeout < 0) {
+      throw new IOException("Invalid timeout "
+                            + "[timeout = " + connectionTimeout + " ms]");
+    } else if (connectionTimeout > 0) {
+      unit = Math.min(UNIT_CONNECT_TIMEOUT, connectionTimeout);
+    }
+    // set the connect timeout to the unit-connect-timeout
+    connection.setConnectTimeout(unit);
+    while (true) {
+      try {
+        connection.connect();
+        break;
+      } catch (IOException ioe) {
+        // update the total remaining connect-timeout
+        connectionTimeout -= unit;
+
+        // throw an exception if we have waited for timeout amount of time
+        // note that the updated value of the timeout is used here
+        if (connectionTimeout == 0) {
+          throw ioe;
+        }
+
+        // reset the connect timeout for the last try
+        if (connectionTimeout < unit) {
+          unit = connectionTimeout;
+          // reset the connect time out for the final connect
+          connection.setConnectTimeout(unit);
+        }
+      }
+    }
+  }
+
+  private void shuffleToMemory(MapHost host, MapOutput mapOutput, 
+                               InputStream input, 
+                               int decompressedLength, 
+                               int compressedLength) throws IOException {    
+    IFileInputStream checksumIn = 
+      new IFileInputStream(input, compressedLength, job);
+
+    input = checksumIn;       
+  
+    // Are map-outputs compressed?
+    if (codec != null) {
+      decompressor.reset();
+      input = codec.createInputStream(input, decompressor);
+    }
+  
+    // Copy map-output into an in-memory buffer
+    byte[] shuffleData = mapOutput.getMemory();
+    
+    try {
+      IOUtils.readFully(input, shuffleData, 0, shuffleData.length);
+      metrics.inputBytes(shuffleData.length);
+      LOG.info("Read " + shuffleData.length + " bytes from map-output for " +
+               mapOutput.getAttemptIdentifier());
+    } catch (IOException ioe) {      
+      // Close the streams
+      IOUtils.cleanup(LOG, input);
+
+      // Re-throw
+      throw ioe;
+    }
+
+  }
+  
+  private void shuffleToDisk(MapHost host, MapOutput mapOutput, 
+                             InputStream input, 
+                             long compressedLength) 
+  throws IOException {
+    // Copy data to local-disk
+    OutputStream output = mapOutput.getDisk();
+    long bytesLeft = compressedLength;
+    try {
+      final int BYTES_TO_READ = 64 * 1024;
+      byte[] buf = new byte[BYTES_TO_READ];
+      while (bytesLeft > 0) {
+        int n = input.read(buf, 0, (int) Math.min(bytesLeft, BYTES_TO_READ));
+        if (n < 0) {
+          throw new IOException("read past end of stream reading " + 
+                                mapOutput.getAttemptIdentifier());
+        }
+        output.write(buf, 0, n);
+        bytesLeft -= n;
+        metrics.inputBytes(n);
+      }
+
+      LOG.info("Read " + (compressedLength - bytesLeft) + 
+               " bytes from map-output for " +
+               mapOutput.getAttemptIdentifier());
+
+      output.close();
+    } catch (IOException ioe) {
+      // Close the streams
+      IOUtils.cleanup(LOG, input, output);
+
+      // Re-throw
+      throw ioe;
+    }
+
+    // Sanity check
+    if (bytesLeft != 0) {
+      throw new IOException("Incomplete map output received for " +
+                            mapOutput.getAttemptIdentifier() + " from " +
+                            host.getHostName() + " (" + 
+                            bytesLeft + " bytes missing of " + 
+                            compressedLength + ")"
+      );
+    }
+  }
+}
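
For illustration, here is a minimal standalone sketch of the unit-timeout
retry pattern that connect() above implements: rather than one attempt with
the full timeout X, the fetcher retries with a smaller unit timeout until the
budget X is spent. The class, method name, and URL below are illustrative and
not part of the patch.

    import java.io.IOException;
    import java.net.URL;
    import java.net.URLConnection;

    // Illustrative sketch (not part of the patch) of Fetcher.connect():
    // retry with a unit timeout until the total budget is exhausted.
    public class ConnectRetrySketch {
      private static final int UNIT_CONNECT_TIMEOUT = 60 * 1000; // same unit as Fetcher

      static void connectWithRetries(URLConnection connection, int totalTimeout)
          throws IOException {
        int unit = Math.min(UNIT_CONNECT_TIMEOUT, totalTimeout);
        connection.setConnectTimeout(unit);
        while (true) {
          try {
            connection.connect();
            return;
          } catch (IOException ioe) {
            totalTimeout -= unit;          // one unit of the budget spent
            if (totalTimeout <= 0) {
              throw ioe;                   // budget exhausted: give up
            }
            if (totalTimeout < unit) {     // one last, shorter attempt
              unit = totalTimeout;
              connection.setConnectTimeout(unit);
            }
          }
        }
      }

      public static void main(String[] args) throws IOException {
        URLConnection conn =
            new URL("http://localhost:8080/mapOutput").openConnection(); // hypothetical URL
        connectWithRetries(conn, 3 * 60 * 1000); // three-minute budget
      }
    }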

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryReader.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryReader.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryReader.java
new file mode 100644
index 0000000..ae95268
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryReader.java
@@ -0,0 +1,156 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.sort.impl.IFile;
+import org.apache.tez.runtime.library.common.sort.impl.IFile.Reader;
+
+/**
+ * <code>IFile.InMemoryReader</code> to read map-outputs present in-memory.
+ */
+@InterfaceAudience.Private
+@InterfaceStability.Unstable
+public class InMemoryReader extends Reader {
+  private final InputAttemptIdentifier taskAttemptId;
+  private final MergeManager merger;
+  DataInputBuffer memDataIn = new DataInputBuffer();
+  private int start;
+  private int length;
+  private int prevKeyPos;
+
+  public InMemoryReader(MergeManager merger, InputAttemptIdentifier taskAttemptId,
+                        byte[] data, int start, int length)
+  throws IOException {
+    super(null, null, length - start, null, null);
+    this.merger = merger;
+    this.taskAttemptId = taskAttemptId;
+
+    buffer = data;
+    bufferSize = (int)fileLength;
+    memDataIn.reset(buffer, start, length);
+    this.start = start;
+    this.length = length;
+  }
+
+  @Override
+  public void reset(int offset) {
+    memDataIn.reset(buffer, start + offset, length);
+    bytesRead = offset;
+    eof = false;
+  }
+
+  @Override
+  public long getPosition() throws IOException {
+    // InMemoryReader does not initialize streams like Reader, so in.getPos()
+    // would not work. Instead, return the number of uncompressed bytes read,
+    // which will be correct since in-memory data is not compressed.
+    return bytesRead;
+  }
+  
+  @Override
+  public long getLength() { 
+    return fileLength;
+  }
+  
+  private void dumpOnError() {
+    File dumpFile = new File("../output/" + taskAttemptId + ".dump");
+    System.err.println("Dumping corrupt map-output of " + taskAttemptId + 
+                       " to " + dumpFile.getAbsolutePath());
+    try {
+      FileOutputStream fos = new FileOutputStream(dumpFile);
+      fos.write(buffer, 0, bufferSize);
+      fos.close();
+    } catch (IOException ioe) {
+      System.err.println("Failed to dump map-output of " + taskAttemptId);
+    }
+  }
+  
+  public KeyState readRawKey(DataInputBuffer key) throws IOException {
+    try {
+      if (!positionToNextRecord(memDataIn)) {
+        return KeyState.NO_KEY;
+      }
+      // Setup the key
+      int pos = memDataIn.getPosition();
+      byte[] data = memDataIn.getData();      
+      if(currentKeyLength == IFile.RLE_MARKER) {
+        key.reset(data, prevKeyPos, prevKeyLength);
+        currentKeyLength = prevKeyLength;
+        return KeyState.SAME_KEY;
+      }      
+      key.reset(data, pos, currentKeyLength);
+      prevKeyPos = pos;
+      // Position for the next value
+      long skipped = memDataIn.skip(currentKeyLength);
+      if (skipped != currentKeyLength) {
+        throw new IOException("Rec# " + recNo + 
+            ": Failed to skip past key of length: " + 
+            currentKeyLength);
+      }
+
+      // Record the bytes read
+      bytesRead += currentKeyLength;
+      return KeyState.NEW_KEY;
+    } catch (IOException ioe) {
+      dumpOnError();
+      throw ioe;
+    }
+  }
+  
+  public void nextRawValue(DataInputBuffer value) throws IOException {
+    try {
+      int pos = memDataIn.getPosition();
+      byte[] data = memDataIn.getData();
+      value.reset(data, pos, currentValueLength);
+
+      // Position for the next record
+      long skipped = memDataIn.skip(currentValueLength);
+      if (skipped != currentValueLength) {
+        throw new IOException("Rec# " + recNo + 
+            ": Failed to skip past value of length: " + 
+            currentValueLength);
+      }
+      // Record the bytes read
+      bytesRead += currentValueLength;
+
+      ++recNo;
+    } catch (IOException ioe) {
+      dumpOnError();
+      throw ioe;
+    }
+  }
+    
+  public void close() {
+    // Release
+    dataIn = null;
+    buffer = null;
+    // Inform the MergeManager
+    if (merger != null) {
+      merger.unreserve(bufferSize);
+    }
+  }
+}
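
As a hedged usage sketch, the fragment below shows how a consumer might
iterate the records of an in-memory segment with this reader. It assumes
code in the same package, a byte[] 'data' already holding IFile-formatted
records, and that KeyState is the enum on IFile.Reader returned by
readRawKey(); passing a null MergeManager simply skips the unreserve()
callback in close().

    // Hypothetical iteration over an in-memory map-output segment.
    // 'data' and 'srcAttemptId' are assumed to exist in the caller.
    InMemoryReader reader =
        new InMemoryReader(null, srcAttemptId, data, 0, data.length);
    DataInputBuffer key = new DataInputBuffer();
    DataInputBuffer value = new DataInputBuffer();
    while (reader.readRawKey(key) != KeyState.NO_KEY) {
      // readRawKey() returns SAME_KEY for RLE runs of identical keys,
      // resetting 'key' to the previous key's bytes.
      reader.nextRawValue(value);
      // ... consume (key, value) ...
    }
    reader.close();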

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryWriter.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryWriter.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryWriter.java
new file mode 100644
index 0000000..f81b28e
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/InMemoryWriter.java
@@ -0,0 +1,100 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.io.BoundedByteArrayOutputStream;
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.tez.runtime.library.common.sort.impl.IFile;
+import org.apache.tez.runtime.library.common.sort.impl.IFileOutputStream;
+import org.apache.tez.runtime.library.common.sort.impl.IFile.Writer;
+
+@InterfaceAudience.Private
+@InterfaceStability.Unstable
+public class InMemoryWriter extends Writer {
+  private static final Log LOG = LogFactory.getLog(InMemoryWriter.class);
+
+  private DataOutputStream out;
+
+  public InMemoryWriter(BoundedByteArrayOutputStream arrayStream) {
+    super(null);
+    this.out =
+      new DataOutputStream(new IFileOutputStream(arrayStream));
+  }
+
+  public void append(Object key, Object value) throws IOException {
+    throw new UnsupportedOperationException(
+        "InMemoryWriter.append(K key, V value)");
+  }
+
+  public void append(DataInputBuffer key, DataInputBuffer value)
+  throws IOException {
+    int keyLength = key.getLength() - key.getPosition();
+    if (keyLength < 0) {
+      throw new IOException("Negative key-length not allowed: " + keyLength +
+                            " for " + key);
+    }
+
+    boolean sameKey = (key == IFile.REPEAT_KEY);
+
+    int valueLength = value.getLength() - value.getPosition();
+    if (valueLength < 0) {
+      throw new IOException("Negative value-length not allowed: " +
+                            valueLength + " for " + value);
+    }
+
+    if(sameKey) {
+      WritableUtils.writeVInt(out, IFile.RLE_MARKER);
+      WritableUtils.writeVInt(out, valueLength);
+      out.write(value.getData(), value.getPosition(), valueLength);
+    } else {
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("InMemWriter.append" +
+            " key.data=" + key.getData() +
+            " key.pos=" + key.getPosition() +
+            " key.len=" +key.getLength() +
+            " val.data=" + value.getData() +
+            " val.pos=" + value.getPosition() +
+            " val.len=" + value.getLength());
+      }
+      WritableUtils.writeVInt(out, keyLength);
+      WritableUtils.writeVInt(out, valueLength);
+      out.write(key.getData(), key.getPosition(), keyLength);
+      out.write(value.getData(), value.getPosition(), valueLength);
+    }
+
+  }
+
+  public void close() throws IOException {
+    // Write EOF_MARKER for key/value length
+    WritableUtils.writeVInt(out, IFile.EOF_MARKER);
+    WritableUtils.writeVInt(out, IFile.EOF_MARKER);
+
+    // Close the stream
+    out.close();
+    out = null;
+  }
+
+}
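
A short, hedged usage sketch: serialize a record into a bounded in-memory
buffer. The 64 KB size and the record contents are illustrative. The writer
frames each record as a vint key-length, a vint value-length, then the key
and value bytes, and close() appends the EOF_MARKER pair so the buffer can
later be read back (for example by InMemoryReader).

    // Illustrative: write one record into a 64 KB in-memory segment.
    BoundedByteArrayOutputStream buffer =
        new BoundedByteArrayOutputStream(64 * 1024);
    InMemoryWriter writer = new InMemoryWriter(buffer);

    DataInputBuffer key = new DataInputBuffer();
    DataInputBuffer value = new DataInputBuffer();
    byte[] k = "key".getBytes();
    byte[] v = "value".getBytes();
    key.reset(k, 0, k.length);
    value.reset(v, 0, v.length);

    writer.append(key, value);           // <keyLen, valLen, key, value>
    writer.close();                      // appends EOF_MARKER, EOF_MARKER
    byte[] segment = buffer.getBuffer(); // IFile-formatted bytes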

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapHost.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapHost.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapHost.java
new file mode 100644
index 0000000..b8be657
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapHost.java
@@ -0,0 +1,124 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+
+@Private
+class MapHost {
+  
+  public static enum State {
+    IDLE,               // No map outputs available
+    BUSY,               // Map outputs are being fetched
+    PENDING,            // Known map outputs which need to be fetched
+    PENALIZED           // Host penalized due to shuffle failures
+  }
+  
+  private State state = State.IDLE;
+  private final String hostName;
+  private final int partitionId;
+  private final String baseUrl;
+  private final String identifier;
+  // Tracks attempt IDs
+  private List<InputAttemptIdentifier> maps = new ArrayList<InputAttemptIdentifier>();
+  
+  public MapHost(int partitionId, String hostName, String baseUrl) {
+    this.partitionId = partitionId;
+    this.hostName = hostName;
+    this.baseUrl = baseUrl;
+    this.identifier = createIdentifier(hostName, partitionId);
+  }
+  
+  public static String createIdentifier(String hostName, int partitionId) {
+    return hostName + ":" + Integer.toString(partitionId);
+  }
+  
+  public String getIdentifier() {
+    return identifier;
+  }
+  
+  public int getPartitionId() {
+    return partitionId;
+  }
+
+  public State getState() {
+    return state;
+  }
+
+  public String getHostName() {
+    return hostName;
+  }
+
+  public String getBaseUrl() {
+    return baseUrl;
+  }
+
+  public synchronized void addKnownMap(InputAttemptIdentifier srcAttempt) {
+    maps.add(srcAttempt);
+    if (state == State.IDLE) {
+      state = State.PENDING;
+    }
+  }
+
+  public synchronized List<InputAttemptIdentifier> getAndClearKnownMaps() {
+    List<InputAttemptIdentifier> currentKnownMaps = maps;
+    maps = new ArrayList<InputAttemptIdentifier>();
+    return currentKnownMaps;
+  }
+  
+  public synchronized void markBusy() {
+    state = State.BUSY;
+  }
+  
+  public synchronized void markPenalized() {
+    state = State.PENALIZED;
+  }
+  
+  public synchronized int getNumKnownMapOutputs() {
+    return maps.size();
+  }
+
+  /**
+   * Called when the node is done with its penalty or done copying.
+   * @return the host's new state
+   */
+  public synchronized State markAvailable() {
+    if (maps.isEmpty()) {
+      state = State.IDLE;
+    } else {
+      state = State.PENDING;
+    }
+    return state;
+  }
+  
+  @Override
+  public String toString() {
+    return hostName;
+  }
+  
+  /**
+   * Mark the host as penalized
+   */
+  public synchronized void penalize() {
+    state = State.PENALIZED;
+  }
+}
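
A hedged walkthrough of the state machine above, as the shuffle scheduler
drives it: IDLE -> PENDING when a map-output is registered, PENDING -> BUSY
while a fetcher owns the host, and back to PENDING or IDLE via
markAvailable(). The fragment assumes same-package code; the host name, base
URL, and 'someAttempt' are made up for the example.

    // Hypothetical state walkthrough (same package; identifiers made up).
    MapHost host = new MapHost(0, "worker-1",
        "http://worker-1:8080/mapOutput?job=..."); // illustrative base URL
    // host.getState() == State.IDLE

    host.addKnownMap(someAttempt);       // IDLE -> PENDING
    host.markBusy();                     // a fetcher takes ownership
    List<InputAttemptIdentifier> maps = host.getAndClearKnownMaps();

    // ... fetch 'maps' from the host ...

    host.markAvailable();                // PENDING if new maps arrived, else IDLE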

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapOutput.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapOutput.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapOutput.java
new file mode 100644
index 0000000..9f673a0
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MapOutput.java
@@ -0,0 +1,227 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Comparator;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalDirAllocator;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BoundedByteArrayOutputStream;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutputFiles;
+
+
+class MapOutput {
+  private static final Log LOG = LogFactory.getLog(MapOutput.class);
+  private static AtomicInteger ID = new AtomicInteger(0);
+  
+  public static enum Type {
+    WAIT,
+    MEMORY,
+    DISK
+  }
+  
+  private InputAttemptIdentifier attemptIdentifier;
+  private final int id;
+  
+  private final MergeManager merger;
+  
+  private final long size;
+  
+  private final byte[] memory;
+  private BoundedByteArrayOutputStream byteStream;
+  
+  private final FileSystem localFS;
+  private final Path tmpOutputPath;
+  private final Path outputPath;
+  private final OutputStream disk; 
+  
+  private final Type type;
+  
+  private final boolean primaryMapOutput;
+  
+  MapOutput(InputAttemptIdentifier attemptIdentifier, MergeManager merger, long size, 
+            Configuration conf, LocalDirAllocator localDirAllocator,
+            int fetcher, boolean primaryMapOutput, 
+            TezTaskOutputFiles mapOutputFile)
+         throws IOException {
+    this.id = ID.incrementAndGet();
+    this.attemptIdentifier = attemptIdentifier;
+    this.merger = merger;
+
+    type = Type.DISK;
+
+    memory = null;
+    byteStream = null;
+
+    this.size = size;
+    
+    this.localFS = FileSystem.getLocal(conf);
+    outputPath =
+      mapOutputFile.getInputFileForWrite(this.attemptIdentifier.getInputIdentifier().getSrcTaskIndex(), size);
+    tmpOutputPath = outputPath.suffix(String.valueOf(fetcher));
+
+    disk = localFS.create(tmpOutputPath);
+    
+    this.primaryMapOutput = primaryMapOutput;
+  }
+  
+  MapOutput(InputAttemptIdentifier attemptIdentifier, MergeManager merger, int size, 
+            boolean primaryMapOutput) {
+    this.id = ID.incrementAndGet();
+    this.attemptIdentifier = attemptIdentifier;
+    this.merger = merger;
+
+    type = Type.MEMORY;
+    byteStream = new BoundedByteArrayOutputStream(size);
+    memory = byteStream.getBuffer();
+
+    this.size = size;
+    
+    localFS = null;
+    disk = null;
+    outputPath = null;
+    tmpOutputPath = null;
+    
+    this.primaryMapOutput = primaryMapOutput;
+  }
+
+  public MapOutput(InputAttemptIdentifier attemptIdentifier) {
+    this.id = ID.incrementAndGet();
+    this.attemptIdentifier = attemptIdentifier;
+
+    type = Type.WAIT;
+    merger = null;
+    memory = null;
+    byteStream = null;
+    
+    size = -1;
+    
+    localFS = null;
+    disk = null;
+    outputPath = null;
+    tmpOutputPath = null;
+
+    this.primaryMapOutput = false;
+}
+  
+  public boolean isPrimaryMapOutput() {
+    return primaryMapOutput;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (obj instanceof MapOutput) {
+      return id == ((MapOutput)obj).id;
+    }
+    return false;
+  }
+
+  @Override
+  public int hashCode() {
+    return id;
+  }
+
+  public Path getOutputPath() {
+    return outputPath;
+  }
+
+  public byte[] getMemory() {
+    return memory;
+  }
+
+  public BoundedByteArrayOutputStream getArrayStream() {
+    return byteStream;
+  }
+  
+  public OutputStream getDisk() {
+    return disk;
+  }
+
+  public InputAttemptIdentifier getAttemptIdentifier() {
+    return this.attemptIdentifier;
+  }
+
+  public Type getType() {
+    return type;
+  }
+
+  public long getSize() {
+    return size;
+  }
+
+  public void commit() throws IOException {
+    if (type == Type.MEMORY) {
+      merger.closeInMemoryFile(this);
+    } else if (type == Type.DISK) {
+      localFS.rename(tmpOutputPath, outputPath);
+      merger.closeOnDiskFile(outputPath);
+    } else {
+      throw new IOException("Cannot commit MapOutput of type WAIT!");
+    }
+  }
+  
+  public void abort() {
+    if (type == Type.MEMORY) {
+      merger.unreserve(memory.length);
+    } else if (type == Type.DISK) {
+      try {
+        localFS.delete(tmpOutputPath, false);
+      } catch (IOException ie) {
+        LOG.info("failure to clean up " + tmpOutputPath, ie);
+      }
+    } else {
+      throw new IllegalArgumentException(
+          "Cannot abort MapOutput of type WAIT!");
+    }
+  }
+  
+  public String toString() {
+    return "MapOutput( AttemptIdentifier: " + attemptIdentifier + ", Type: " + type + ")";
+  }
+  
+  public static class MapOutputComparator 
+  implements Comparator<MapOutput> {
+    public int compare(MapOutput o1, MapOutput o2) {
+      if (o1.id == o2.id) { 
+        return 0;
+      }
+      
+      if (o1.size < o2.size) {
+        return -1;
+      } else if (o1.size > o2.size) {
+        return 1;
+      }
+      
+      if (o1.id < o2.id) {
+        return -1;
+      } else {
+        return 1;
+      }
+    }
+  }
+  
+}
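
A condensed, hedged sketch of the MapOutput lifecycle that Fetcher and
MergeManager drive: reserve a destination, fill it, then commit on success or
abort on failure. In the code above the commit actually happens via
scheduler.copySucceeded(); this fragment inlines it for clarity and assumes
same-package code with the named variables in scope.

    // Condensed lifecycle sketch (same package; variables assumed).
    MapOutput mapOutput = merger.reserve(srcAttemptId, decompressedLength, fetcherId);
    if (mapOutput.getType() == MapOutput.Type.WAIT) {
      // Memory limit reached: put the attempt back and retry later.
    } else {
      try {
        // MEMORY: fill mapOutput.getMemory(); DISK: stream to mapOutput.getDisk().
        mapOutput.commit();  // MEMORY: hand buffer to merger; DISK: rename tmp file
      } catch (IOException ioe) {
        mapOutput.abort();   // MEMORY: unreserve memory; DISK: delete tmp file
      }
    }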

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeManager.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeManager.java
new file mode 100644
index 0000000..0abe530
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeManager.java
@@ -0,0 +1,782 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Set;
+import java.util.TreeSet;
+
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.ChecksumFileSystem;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalDirAllocator;
+import org.apache.hadoop.fs.LocalFileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import org.apache.hadoop.util.Progressable;
+import org.apache.hadoop.util.ReflectionUtils;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.runtime.api.TezInputContext;
+import org.apache.tez.runtime.library.common.ConfigUtils;
+import org.apache.tez.runtime.library.common.Constants;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.combine.Combiner;
+import org.apache.tez.runtime.library.common.sort.impl.IFile;
+import org.apache.tez.runtime.library.common.sort.impl.TezMerger;
+import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator;
+import org.apache.tez.runtime.library.common.sort.impl.IFile.Writer;
+import org.apache.tez.runtime.library.common.sort.impl.TezMerger.Segment;
+import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutputFiles;
+import org.apache.tez.runtime.library.hadoop.compat.NullProgressable;
+
+@InterfaceAudience.Private
+@InterfaceStability.Unstable
+@SuppressWarnings(value={"rawtypes"})
+public class MergeManager {
+  
+  private static final Log LOG = LogFactory.getLog(MergeManager.class);
+
+  private final Configuration conf;
+  private final FileSystem localFS;
+  private final FileSystem rfs;
+  private final LocalDirAllocator localDirAllocator;
+  
+  private final  TezTaskOutputFiles mapOutputFile;
+  private final Progressable nullProgressable = new NullProgressable();
+  private final Combiner combiner;  
+  
+  Set<MapOutput> inMemoryMergedMapOutputs = 
+    new TreeSet<MapOutput>(new MapOutput.MapOutputComparator());
+  private final IntermediateMemoryToMemoryMerger memToMemMerger;
+
+  Set<MapOutput> inMemoryMapOutputs = 
+    new TreeSet<MapOutput>(new MapOutput.MapOutputComparator());
+  private final InMemoryMerger inMemoryMerger;
+  
+  Set<Path> onDiskMapOutputs = new TreeSet<Path>();
+  private final OnDiskMerger onDiskMerger;
+  
+  private final long memoryLimit;
+  private long usedMemory;
+  private long commitMemory;
+  private final long maxSingleShuffleLimit;
+  
+  private final int memToMemMergeOutputsThreshold; 
+  private final long mergeThreshold;
+  
+  private final int ioSortFactor;
+
+  private final ExceptionReporter exceptionReporter;
+  
+  private final TezInputContext inputContext;
+
+  private final TezCounter spilledRecordsCounter;
+
+  private final TezCounter reduceCombineInputCounter;
+
+  private final TezCounter mergedMapOutputsCounter;
+  
+  private final CompressionCodec codec;
+  
+  private volatile boolean finalMergeComplete = false;
+
+  public MergeManager(Configuration conf, 
+                      FileSystem localFS,
+                      LocalDirAllocator localDirAllocator,  
+                      TezInputContext inputContext,
+                      Combiner combiner,
+                      TezCounter spilledRecordsCounter,
+                      TezCounter reduceCombineInputCounter,
+                      TezCounter mergedMapOutputsCounter,
+                      ExceptionReporter exceptionReporter) {
+    this.inputContext = inputContext;
+    this.conf = conf;
+    this.localDirAllocator = localDirAllocator;
+    this.exceptionReporter = exceptionReporter;
+    
+    this.combiner = combiner;
+
+    this.reduceCombineInputCounter = reduceCombineInputCounter;
+    this.spilledRecordsCounter = spilledRecordsCounter;
+    this.mergedMapOutputsCounter = mergedMapOutputsCounter;
+    this.mapOutputFile = new TezTaskOutputFiles(conf, inputContext.getUniqueIdentifier());
+    
+    this.localFS = localFS;
+    this.rfs = ((LocalFileSystem)localFS).getRaw();
+
+    if (ConfigUtils.isIntermediateInputCompressed(conf)) {
+      Class<? extends CompressionCodec> codecClass =
+          ConfigUtils.getIntermediateInputCompressorClass(conf, DefaultCodec.class);
+      codec = ReflectionUtils.newInstance(codecClass, conf);
+    } else {
+      codec = null;
+    }
+
+    final float maxInMemCopyUse =
+      conf.getFloat(
+          TezJobConfig.TEZ_RUNTIME_SHUFFLE_INPUT_BUFFER_PERCENT, 
+          TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_INPUT_BUFFER_PERCENT);
+    if (maxInMemCopyUse > 1.0 || maxInMemCopyUse < 0.0) {
+      throw new IllegalArgumentException("Invalid value for " +
+          TezJobConfig.TEZ_RUNTIME_SHUFFLE_INPUT_BUFFER_PERCENT + ": " +
+          maxInMemCopyUse);
+    }
+
+    // Allow unit tests to fix Runtime memory
+    this.memoryLimit = 
+      (long)(conf.getLong(Constants.TEZ_RUNTIME_TASK_MEMORY,
+          Math.min(Runtime.getRuntime().maxMemory(), Integer.MAX_VALUE))
+        * maxInMemCopyUse);
+ 
+    this.ioSortFactor = 
+        conf.getInt(
+            TezJobConfig.TEZ_RUNTIME_IO_SORT_FACTOR, 
+            TezJobConfig.DEFAULT_TEZ_RUNTIME_IO_SORT_FACTOR);
+
+    final float singleShuffleMemoryLimitPercent =
+        conf.getFloat(
+            TezJobConfig.TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT,
+            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT);
+    if (singleShuffleMemoryLimitPercent <= 0.0f
+        || singleShuffleMemoryLimitPercent > 1.0f) {
+      throw new IllegalArgumentException("Invalid value for "
+          + TezJobConfig.TEZ_RUNTIME_SHUFFLE_MEMORY_LIMIT_PERCENT + ": "
+          + singleShuffleMemoryLimitPercent);
+    }
+
+    this.maxSingleShuffleLimit = 
+      (long)(memoryLimit * singleShuffleMemoryLimitPercent);
+    this.memToMemMergeOutputsThreshold = 
+            conf.getInt(
+                TezJobConfig.TEZ_RUNTIME_SHUFFLE_MEMTOMEM_SEGMENTS, 
+                ioSortFactor);
+    this.mergeThreshold = 
+        (long)(this.memoryLimit * 
+               conf.getFloat(
+                   TezJobConfig.TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT, 
+                   TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT));
+    LOG.info("MergerManager: memoryLimit=" + memoryLimit + ", " +
+             "maxSingleShuffleLimit=" + maxSingleShuffleLimit + ", " +
+             "mergeThreshold=" + mergeThreshold + ", " + 
+             "ioSortFactor=" + ioSortFactor + ", " +
+             "memToMemMergeOutputsThreshold=" + memToMemMergeOutputsThreshold);
+
+    if (this.maxSingleShuffleLimit >= this.mergeThreshold) {
+      throw new RuntimeException("Invlaid configuration: "
+          + "maxSingleShuffleLimit should be less than mergeThreshold"
+          + "maxSingleShuffleLimit: " + this.maxSingleShuffleLimit
+          + "mergeThreshold: " + this.mergeThreshold);
+    }
+
+    boolean allowMemToMemMerge = 
+      conf.getBoolean(
+          TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_MEMTOMEM, 
+          TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_MEMTOMEM);
+    if (allowMemToMemMerge) {
+      this.memToMemMerger = 
+        new IntermediateMemoryToMemoryMerger(this,
+                                             memToMemMergeOutputsThreshold);
+      this.memToMemMerger.start();
+    } else {
+      this.memToMemMerger = null;
+    }
+    
+    this.inMemoryMerger = new InMemoryMerger(this);
+    this.inMemoryMerger.start();
+    
+    this.onDiskMerger = new OnDiskMerger(this);
+    this.onDiskMerger.start();
+  }
+
+  public void waitForInMemoryMerge() throws InterruptedException {
+    inMemoryMerger.waitForMerge();
+  }
+  
+  private boolean canShuffleToMemory(long requestedSize) {
+    return (requestedSize < maxSingleShuffleLimit); 
+  }
+
+  final private MapOutput stallShuffle = new MapOutput(null);
+
+  public synchronized MapOutput reserve(InputAttemptIdentifier srcAttemptIdentifier, 
+                                             long requestedSize,
+                                             int fetcher
+                                             ) throws IOException {
+    if (!canShuffleToMemory(requestedSize)) {
+      LOG.info(srcAttemptIdentifier + ": Shuffling to disk since " + requestedSize + 
+               " is greater than maxSingleShuffleLimit (" + 
+               maxSingleShuffleLimit + ")");
+      return new MapOutput(srcAttemptIdentifier, this, requestedSize, conf, 
+                                localDirAllocator, fetcher, true,
+                                mapOutputFile);
+    }
+    
+    // Stall shuffle if we are above the memory limit
+
+    // It is possible that all threads could just be stalling and not make
+    // progress at all. This could happen when:
+    //
+    // requested size is causing the used memory to go above limit &&
+    // requested size < singleShuffleLimit &&
+    // current used size < mergeThreshold (merge will not get triggered)
+    //
+    // To prevent this from happening, we allow exactly one thread to go past
+    // the memory limit. We check (usedMemory > memoryLimit) and not
+    // (usedMemory + requestedSize > memoryLimit). When this thread is done
+    // fetching, this will automatically trigger a merge thereby unlocking
+    // all the stalled threads
+    
+    if (usedMemory > memoryLimit) {
+      LOG.debug(srcAttemptIdentifier + ": Stalling shuffle since usedMemory (" + usedMemory
+          + ") is greater than memoryLimit (" + memoryLimit + ")." + 
+          " CommitMemory is (" + commitMemory + ")"); 
+      return stallShuffle;
+    }
+    
+    // Allow the in-memory shuffle to progress
+    LOG.debug(srcAttemptIdentifier + ": Proceeding with shuffle since usedMemory ("
+        + usedMemory + ") is lesser than memoryLimit (" + memoryLimit + ")."
+        + "CommitMemory is (" + commitMemory + ")"); 
+    return unconditionalReserve(srcAttemptIdentifier, requestedSize, true);
+  }
+  
+  /**
+   * Unconditional reserve, used by the memory-to-memory merger (and by
+   * {@link #reserve} once the memory checks have passed).
+   * @return a MapOutput backed by the requested amount of memory
+   */
+  private synchronized MapOutput unconditionalReserve(
+      InputAttemptIdentifier srcAttemptIdentifier, long requestedSize, boolean primaryMapOutput) {
+    usedMemory += requestedSize;
+    return new MapOutput(srcAttemptIdentifier, this, (int)requestedSize, 
+        primaryMapOutput);
+  }
+  
+  synchronized void unreserve(long size) {
+    commitMemory -= size;
+    usedMemory -= size;
+  }
+
+  public synchronized void closeInMemoryFile(MapOutput mapOutput) { 
+    inMemoryMapOutputs.add(mapOutput);
+    LOG.info("closeInMemoryFile -> map-output of size: " + mapOutput.getSize()
+        + ", inMemoryMapOutputs.size() -> " + inMemoryMapOutputs.size()
+        + ", commitMemory -> " + commitMemory + ", usedMemory ->" + usedMemory);
+
+    commitMemory+= mapOutput.getSize();
+
+    synchronized (inMemoryMerger) {
+      // Can hang if mergeThreshold is really low.
+      if (!inMemoryMerger.isInProgress() && commitMemory >= mergeThreshold) {
+        LOG.info("Starting inMemoryMerger's merge since commitMemory=" +
+            commitMemory + " > mergeThreshold=" + mergeThreshold + 
+            ". Current usedMemory=" + usedMemory);
+        inMemoryMapOutputs.addAll(inMemoryMergedMapOutputs);
+        inMemoryMergedMapOutputs.clear();
+        inMemoryMerger.startMerge(inMemoryMapOutputs);
+      } 
+    }
+    
+    if (memToMemMerger != null) {
+      synchronized (memToMemMerger) {
+        if (!memToMemMerger.isInProgress() && 
+            inMemoryMapOutputs.size() >= memToMemMergeOutputsThreshold) {
+          memToMemMerger.startMerge(inMemoryMapOutputs);
+        }
+      }
+    }
+  }
+  
+  
+  public synchronized void closeInMemoryMergedFile(MapOutput mapOutput) {
+    inMemoryMergedMapOutputs.add(mapOutput);
+    LOG.info("closeInMemoryMergedFile -> size: " + mapOutput.getSize() + 
+             ", inMemoryMergedMapOutputs.size() -> " + 
+             inMemoryMergedMapOutputs.size());
+  }
+  
+  public synchronized void closeOnDiskFile(Path file) {
+    onDiskMapOutputs.add(file);
+    
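+    // Trigger an intermediate on-disk merge only once 2 * ioSortFactor - 1
+    // files have accumulated, i.e. once there are more files than a single
+    // ioSortFactor-wide pass of the final merge could absorb (a heuristic
+    // carried over from the MapReduce shuffle).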
+    synchronized (onDiskMerger) {
+      if (!onDiskMerger.isInProgress() && 
+          onDiskMapOutputs.size() >= (2 * ioSortFactor - 1)) {
+        onDiskMerger.startMerge(onDiskMapOutputs);
+      }
+    }
+  }
+
+  /**
+   * Should <b>only</b> be used after the shuffle phase is complete; otherwise it
+   * can return an invalid state, since a merge may not be in progress due to
+   * inadequate inputs.
+   * 
+   * @return true if the merge process is complete, otherwise false
+   */
+  @Private
+  public boolean isMergeComplete() {
+    return finalMergeComplete;
+  }
+  
+  public TezRawKeyValueIterator close() throws Throwable {
+    // Wait for on-going merges to complete
+    if (memToMemMerger != null) { 
+      memToMemMerger.close();
+    }
+    inMemoryMerger.close();
+    onDiskMerger.close();
+    
+    List<MapOutput> memory = 
+      new ArrayList<MapOutput>(inMemoryMergedMapOutputs);
+    memory.addAll(inMemoryMapOutputs);
+    List<Path> disk = new ArrayList<Path>(onDiskMapOutputs);
+    TezRawKeyValueIterator kvIter = finalMerge(conf, rfs, memory, disk);
+    this.finalMergeComplete = true;
+    return kvIter;
+  }
+   
+  void runCombineProcessor(TezRawKeyValueIterator kvIter, Writer writer)
+      throws IOException, InterruptedException {
+    combiner.combine(kvIter, writer);
+  }
+
+  private class IntermediateMemoryToMemoryMerger 
+  extends MergeThread<MapOutput> {
+    
+    public IntermediateMemoryToMemoryMerger(MergeManager manager, 
+                                            int mergeFactor) {
+      super(manager, mergeFactor, exceptionReporter);
+      setName("InMemoryMerger - Thread to do in-memory merge of in-memory " +
+      		    "shuffled map-outputs");
+      setDaemon(true);
+    }
+
+    @Override
+    public void merge(List<MapOutput> inputs) throws IOException {
+      if (inputs == null || inputs.isEmpty()) {
+        return;
+      }
+
+      InputAttemptIdentifier dummyMapId = inputs.get(0).getAttemptIdentifier(); 
+      List<Segment> inMemorySegments = new ArrayList<Segment>();
+      long mergeOutputSize = 
+        createInMemorySegments(inputs, inMemorySegments, 0);
+      int noInMemorySegments = inMemorySegments.size();
+      
+      MapOutput mergedMapOutputs = 
+        unconditionalReserve(dummyMapId, mergeOutputSize, false);
+      
+      Writer writer = 
+        new InMemoryWriter(mergedMapOutputs.getArrayStream());
+      
+      LOG.info("Initiating Memory-to-Memory merge with " + noInMemorySegments +
+               " segments of total-size: " + mergeOutputSize);
+
+      TezRawKeyValueIterator rIter = 
+        TezMerger.merge(conf, rfs,
+                       ConfigUtils.getIntermediateInputKeyClass(conf),
+                       ConfigUtils.getIntermediateInputValueClass(conf),
+                       inMemorySegments, inMemorySegments.size(),
+                       new Path(inputContext.getUniqueIdentifier()),
+                       (RawComparator)ConfigUtils.getIntermediateInputKeyComparator(conf),
+                       nullProgressable, null, null, null);
+      TezMerger.writeFile(rIter, writer, nullProgressable, conf);
+      writer.close();
+
+      LOG.info(inputContext.getUniqueIdentifier() +  
+               " Memory-to-Memory merge of the " + noInMemorySegments +
+               " files in-memory complete.");
+
+      // Note the output of the merge
+      closeInMemoryMergedFile(mergedMapOutputs);
+    }
+  }
+  
+  private class InMemoryMerger extends MergeThread<MapOutput> {
+    
+    public InMemoryMerger(MergeManager manager) {
+      super(manager, Integer.MAX_VALUE, exceptionReporter);
+      setName("InMemoryMerger - Thread to merge in-memory shuffled map-outputs");
+      setDaemon(true);
+    }
+    
+    @Override
+    public void merge(List<MapOutput> inputs) throws IOException, InterruptedException {
+      if (inputs == null || inputs.isEmpty()) {
+        return;
+      }
+      
+      // Name this output file after the first file in the current list of
+      // in-memory files; that name is guaranteed to be absent on disk, so we
+      // don't overwrite a previously created spill. We also need to create
+      // the output file now, since it is not guaranteed that this file will
+      // still be present after merge is called (we delete empty files as
+      // soon as we see them in the merge method).
+
+      // Figure out the source attempt identifier.
+      InputAttemptIdentifier srcTaskIdentifier = inputs.get(0).getAttemptIdentifier();
+
+      List<Segment> inMemorySegments = new ArrayList<Segment>();
+      long mergeOutputSize =
+        createInMemorySegments(inputs, inMemorySegments, 0);
+      int noInMemorySegments = inMemorySegments.size();
+
+      Path outputPath = mapOutputFile.getInputFileForWrite(
+          srcTaskIdentifier.getInputIdentifier().getSrcTaskIndex(),
+          mergeOutputSize).suffix(Constants.MERGED_OUTPUT_PREFIX);
+
+      Writer writer = null;
+      try {
+        writer =
+            new Writer(conf, rfs, outputPath,
+                (Class)ConfigUtils.getIntermediateInputKeyClass(conf),
+                (Class)ConfigUtils.getIntermediateInputValueClass(conf),
+                codec, null);
+
+        TezRawKeyValueIterator rIter = null;
+        LOG.info("Initiating in-memory merge with " + noInMemorySegments + 
+            " segments...");
+
+        rIter = TezMerger.merge(conf, rfs,
+            (Class)ConfigUtils.getIntermediateInputKeyClass(conf),
+            (Class)ConfigUtils.getIntermediateInputValueClass(conf),
+            inMemorySegments, inMemorySegments.size(),
+            new Path(inputContext.getUniqueIdentifier()),
+            (RawComparator)ConfigUtils.getIntermediateInputKeyComparator(conf),
+            nullProgressable, spilledRecordsCounter, null, null);
+
+        if (null == combiner) {
+          TezMerger.writeFile(rIter, writer, nullProgressable, conf);
+        } else {
+          runCombineProcessor(rIter, writer);
+        }
+        writer.close();
+        writer = null;
+
+        LOG.info(inputContext.getUniqueIdentifier() +  
+            " Merge of the " + noInMemorySegments +
+            " files in-memory complete." +
+            " Local file is " + outputPath + " of size " + 
+            localFS.getFileStatus(outputPath).getLen());
+      } catch (IOException e) { 
+        //make sure that we delete the ondisk file that we created 
+        //earlier when we invoked cloneFileAttributes
+        localFS.delete(outputPath, true);
+        throw e;
+      } finally {
+        if (writer != null) {
+          writer.close();
+        }
+      }
+
+      // Note the output of the merge
+      closeOnDiskFile(outputPath);
+    }
+
+  }
+  
+  private class OnDiskMerger extends MergeThread<Path> {
+    
+    public OnDiskMerger(MergeManager manager) {
+      super(manager, Integer.MAX_VALUE, exceptionReporter);
+      setName("OnDiskMerger - Thread to merge on-disk map-outputs");
+      setDaemon(true);
+    }
+    
+    @Override
+    public void merge(List<Path> inputs) throws IOException {
+      // sanity check
+      if (inputs == null || inputs.isEmpty()) {
+        LOG.info("No ondisk files to merge...");
+        return;
+      }
+      
+      long approxOutputSize = 0;
+      int bytesPerSum = 
+        conf.getInt("io.bytes.per.checksum", 512);
+      
+      LOG.info("OnDiskMerger: We have  " + inputs.size() + 
+               " map outputs on disk. Triggering merge...");
+      
+      // 1. Prepare the list of files to be merged. 
+      for (Path file : inputs) {
+        approxOutputSize += localFS.getFileStatus(file).getLen();
+      }
+
+      // add the checksum length
+      approxOutputSize += 
+        ChecksumFileSystem.getChecksumLength(approxOutputSize, bytesPerSum);
+
+      // 2. Start the on-disk merge process
+      Path outputPath = 
+        localDirAllocator.getLocalPathForWrite(inputs.get(0).toString(), 
+            approxOutputSize, conf).suffix(Constants.MERGED_OUTPUT_PREFIX);
+      Writer writer = 
+        new Writer(conf, rfs, outputPath, 
+                        (Class)ConfigUtils.getIntermediateInputKeyClass(conf), 
+                        (Class)ConfigUtils.getIntermediateInputValueClass(conf),
+                        codec, null);
+      TezRawKeyValueIterator iter  = null;
+      Path tmpDir = new Path(inputContext.getUniqueIdentifier());
+      try {
+        iter = TezMerger.merge(conf, rfs,
+                            (Class)ConfigUtils.getIntermediateInputKeyClass(conf), 
+                            (Class)ConfigUtils.getIntermediateInputValueClass(conf),
+                            codec, inputs.toArray(new Path[inputs.size()]), 
+                            true, ioSortFactor, tmpDir, 
+                            (RawComparator)ConfigUtils.getIntermediateInputKeyComparator(conf), 
+                            nullProgressable, spilledRecordsCounter, null, 
+                            mergedMapOutputsCounter, null);
+
+        TezMerger.writeFile(iter, writer, nullProgressable, conf);
+        writer.close();
+      } catch (IOException e) {
+        localFS.delete(outputPath, true);
+        throw e;
+      }
+
+      closeOnDiskFile(outputPath);
+
+      LOG.info(inputContext.getUniqueIdentifier() +
+          " Finished merging " + inputs.size() + 
+          " map output files on disk of total-size " + 
+          approxOutputSize + "." + 
+          " Local output file is " + outputPath + " of size " +
+          localFS.getFileStatus(outputPath).getLen());
+    }
+  }
+  
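+  /**
+   * Drains map outputs from the given list into merge segments until at most
+   * leaveBytes of fetched data remain in memory, and returns the total size
+   * of the segments created.
+   */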
+  private long createInMemorySegments(List<MapOutput> inMemoryMapOutputs,
+                                      List<Segment> inMemorySegments, 
+                                      long leaveBytes
+                                      ) throws IOException {
+    long totalSize = 0L;
+    // fullSize could come from the RamManager, but files can be
+    // closed but not yet present in inMemoryMapOutputs
+    long fullSize = 0L;
+    for (MapOutput mo : inMemoryMapOutputs) {
+      fullSize += mo.getMemory().length;
+    }
+    while(fullSize > leaveBytes) {
+      MapOutput mo = inMemoryMapOutputs.remove(0);
+      byte[] data = mo.getMemory();
+      long size = data.length;
+      totalSize += size;
+      fullSize -= size;
+      IFile.Reader reader = new InMemoryReader(MergeManager.this, 
+                                                   mo.getAttemptIdentifier(),
+                                                   data, 0, (int)size);
+      inMemorySegments.add(new Segment(reader, true, 
+                                            (mo.isPrimaryMapOutput() ? 
+                                            mergedMapOutputsCounter : null)));
+    }
+    return totalSize;
+  }
+
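+  /**
+   * Adapts a TezRawKeyValueIterator to the IFile.Reader interface so the
+   * output of an intermediate merge can be fed back into the final merge as
+   * a regular Segment.
+   */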
+  class RawKVIteratorReader extends IFile.Reader {
+
+    private final TezRawKeyValueIterator kvIter;
+
+    public RawKVIteratorReader(TezRawKeyValueIterator kvIter, long size)
+        throws IOException {
+      super(null, null, size, null, spilledRecordsCounter);
+      this.kvIter = kvIter;
+    }
+    public boolean nextRawKey(DataInputBuffer key) throws IOException {
+      if (kvIter.next()) {
+        final DataInputBuffer kb = kvIter.getKey();
+        final int kp = kb.getPosition();
+        final int klen = kb.getLength() - kp;
+        key.reset(kb.getData(), kp, klen);
+        bytesRead += klen;
+        return true;
+      }
+      return false;
+    }
+    public void nextRawValue(DataInputBuffer value) throws IOException {
+      final DataInputBuffer vb = kvIter.getValue();
+      final int vp = vb.getPosition();
+      final int vlen = vb.getLength() - vp;
+      value.reset(vb.getData(), vp, vlen);
+      bytesRead += vlen;
+    }
+    public long getPosition() throws IOException {
+      return bytesRead;
+    }
+
+    public void close() throws IOException {
+      kvIter.close();
+    }
+  }
+
+  private TezRawKeyValueIterator finalMerge(Configuration job, FileSystem fs,
+                                       List<MapOutput> inMemoryMapOutputs,
+                                       List<Path> onDiskMapOutputs
+                                       ) throws IOException {
+    LOG.info("finalMerge called with " + 
+             inMemoryMapOutputs.size() + " in-memory map-outputs and " + 
+             onDiskMapOutputs.size() + " on-disk map-outputs");
+    
+    final float maxRedPer =
+      job.getFloat(
+          TezJobConfig.TEZ_RUNTIME_INPUT_BUFFER_PERCENT,
+          TezJobConfig.DEFAULT_TEZ_RUNTIME_INPUT_BUFFER_PERCENT);
+    if (maxRedPer > 1.0 || maxRedPer < 0.0) {
+      throw new IOException("Invalid value for " +
+                            TezJobConfig.TEZ_RUNTIME_INPUT_BUFFER_PERCENT +
+                            ": " + maxRedPer);
+    }
+    int maxInMemReduce = (int)Math.min(
+        Runtime.getRuntime().maxMemory() * maxRedPer, Integer.MAX_VALUE);
+    
+
+    // merge config params
+    Class keyClass = (Class)ConfigUtils.getIntermediateInputKeyClass(job);
+    Class valueClass = (Class)ConfigUtils.getIntermediateInputValueClass(job);
+    final Path tmpDir = new Path(inputContext.getUniqueIdentifier());
+    final RawComparator comparator =
+      (RawComparator)ConfigUtils.getIntermediateInputKeyComparator(job);
+
+    // segments required to vacate memory
+    List<Segment> memDiskSegments = new ArrayList<Segment>();
+    long inMemToDiskBytes = 0;
+    boolean mergePhaseFinished = false;
+    if (inMemoryMapOutputs.size() > 0) {
+      int srcTaskId = inMemoryMapOutputs.get(0).getAttemptIdentifier().getInputIdentifier().getSrcTaskIndex();
+      inMemToDiskBytes = createInMemorySegments(inMemoryMapOutputs, 
+                                                memDiskSegments,
+                                                maxInMemReduce);
+      final int numMemDiskSegments = memDiskSegments.size();
+      if (numMemDiskSegments > 0 &&
+            ioSortFactor > onDiskMapOutputs.size()) {
+        
+        // If we reach here, it implies that we have fewer than io.sort.factor
+        // disk segments, and this count will be incremented by 1 (the result
+        // of the memory-segment merge). Since the total would still be
+        // <= io.sort.factor, we will not do any more intermediate merges;
+        // the merge of all these disk segments is fed directly to the
+        // reduce method.
+        
+        mergePhaseFinished = true;
+        // must spill to disk, but can't retain in-mem for intermediate merge
+        final Path outputPath = 
+          mapOutputFile.getInputFileForWrite(srcTaskId,
+                                             inMemToDiskBytes).suffix(
+                                                 Constants.MERGED_OUTPUT_PREFIX);
+        final TezRawKeyValueIterator rIter = TezMerger.merge(job, fs,
+            keyClass, valueClass, memDiskSegments, numMemDiskSegments,
+            tmpDir, comparator, nullProgressable, spilledRecordsCounter, null, null);
+        final Writer writer = new Writer(job, fs, outputPath,
+            keyClass, valueClass, codec, null);
+        try {
+          TezMerger.writeFile(rIter, writer, nullProgressable, job);
+          // add to list of final disk outputs.
+          onDiskMapOutputs.add(outputPath);
+        } catch (IOException e) {
+          if (null != outputPath) {
+            try {
+              fs.delete(outputPath, true);
+            } catch (IOException ie) {
+              // NOTHING
+            }
+          }
+          throw e;
+        } finally {
+          if (null != writer) {
+            writer.close();
+          }
+        }
+        LOG.info("Merged " + numMemDiskSegments + " segments, " +
+                 inMemToDiskBytes + " bytes to disk to satisfy " +
+                 "reduce memory limit");
+        inMemToDiskBytes = 0;
+        memDiskSegments.clear();
+      } else if (inMemToDiskBytes != 0) {
+        LOG.info("Keeping " + numMemDiskSegments + " segments, " +
+                 inMemToDiskBytes + " bytes in memory for " +
+                 "intermediate, on-disk merge");
+      }
+    }
+
+    // segments on disk
+    List<Segment> diskSegments = new ArrayList<Segment>();
+    long onDiskBytes = inMemToDiskBytes;
+    Path[] onDisk = onDiskMapOutputs.toArray(new Path[onDiskMapOutputs.size()]);
+    for (Path file : onDisk) {
+      onDiskBytes += fs.getFileStatus(file).getLen();
+      LOG.debug("Disk file: " + file + " Length is " + 
+          fs.getFileStatus(file).getLen());
+      diskSegments.add(new Segment(job, fs, file, codec, false,
+                                         (file.toString().endsWith(
+                                             Constants.MERGED_OUTPUT_PREFIX) ?
+                                          null : mergedMapOutputsCounter)
+                                        ));
+    }
+    LOG.info("Merging " + onDisk.length + " files, " +
+             onDiskBytes + " bytes from disk");
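+    // Sort on-disk segments smallest-first so intermediate merge passes
+    // combine the smallest files and rewrite as few bytes as possible.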
+    Collections.sort(diskSegments, new Comparator<Segment>() {
+      public int compare(Segment o1, Segment o2) {
+        if (o1.getLength() == o2.getLength()) {
+          return 0;
+        }
+        return o1.getLength() < o2.getLength() ? -1 : 1;
+      }
+    });
+
+    // Build the final list of segments: merged segments backed by disk,
+    // plus the remaining in-memory segments.
+    List<Segment> finalSegments = new ArrayList<Segment>();
+    long inMemBytes = createInMemorySegments(inMemoryMapOutputs, 
+                                             finalSegments, 0);
+    LOG.info("Merging " + finalSegments.size() + " segments, " +
+             inMemBytes + " bytes from memory into reduce");
+    if (0 != onDiskBytes) {
+      final int numInMemSegments = memDiskSegments.size();
+      diskSegments.addAll(0, memDiskSegments);
+      memDiskSegments.clear();
+      TezRawKeyValueIterator diskMerge = TezMerger.merge(
+          job, fs, keyClass, valueClass, diskSegments,
+          ioSortFactor, numInMemSegments, tmpDir, comparator,
+          nullProgressable, false, spilledRecordsCounter, null, null);
+      diskSegments.clear();
+      if (0 == finalSegments.size()) {
+        return diskMerge;
+      }
+      finalSegments.add(new Segment(
+            new RawKVIteratorReader(diskMerge, onDiskBytes), true));
+    }
+    return TezMerger.merge(job, fs, keyClass, valueClass,
+                 finalSegments, finalSegments.size(), tmpDir,
+                 comparator, nullProgressable, spilledRecordsCounter, null,
+                 null);
+  
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeThread.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeThread.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeThread.java
new file mode 100644
index 0000000..d8a7722
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/MergeThread.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
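+/**
+ * A simple producer/consumer merge thread: callers hand over a batch of
+ * inputs via startMerge(), which grabs up to mergeFactor of them and wakes
+ * run(); run() invokes the subclass's merge() and then notifies anyone
+ * blocked in waitForMerge(). close() drains any in-progress merge before
+ * interrupting the thread.
+ */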
+abstract class MergeThread<T> extends Thread {
+  
+  private static final Log LOG = LogFactory.getLog(MergeThread.class);
+
+  private volatile boolean inProgress = false;
+  private List<T> inputs = new ArrayList<T>();
+  protected final MergeManager manager;
+  private final ExceptionReporter reporter;
+  private boolean closed = false;
+  private final int mergeFactor;
+  
+  public MergeThread(MergeManager manager, int mergeFactor,
+                     ExceptionReporter reporter) {
+    this.manager = manager;
+    this.mergeFactor = mergeFactor;
+    this.reporter = reporter;
+  }
+  
+  public synchronized void close() throws InterruptedException {
+    closed = true;
+    waitForMerge();
+    interrupt();
+  }
+
+  public synchronized boolean isInProgress() {
+    return inProgress;
+  }
+  
+  public synchronized void startMerge(Set<T> inputs) {
+    if (!closed) {
+      inProgress = true;
+      this.inputs = new ArrayList<T>();
+      Iterator<T> iter = inputs.iterator();
+      for (int ctr = 0; iter.hasNext() && ctr < mergeFactor; ++ctr) {
+        this.inputs.add(iter.next());
+        iter.remove();
+      }
+      LOG.info(getName() + ": Starting merge with " + this.inputs.size() + 
+               " segments, while ignoring " + inputs.size() + " segments");
+      notifyAll();
+    }
+  }
+
+  public synchronized void waitForMerge() throws InterruptedException {
+    while (inProgress) {
+      wait();
+    }
+  }
+
+  public void run() {
+    while (true) {
+      try {
+        // Wait for notification to start the merge...
+        synchronized (this) {
+          while (!inProgress) {
+            wait();
+          }
+        }
+
+        // Merge
+        merge(inputs);
+      } catch (InterruptedException ie) {
+        return;
+      } catch(Throwable t) {
+        reporter.reportException(t);
+        return;
+      } finally {
+        synchronized (this) {
+          // Clear inputs
+          inputs = null;
+          inProgress = false;        
+          notifyAll();
+        }
+      }
+    }
+  }
+
+  public abstract void merge(List<T> inputs) 
+      throws IOException, InterruptedException;
+}