You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tez.apache.org by mi...@apache.org on 2016/10/16 23:24:05 UTC

tez git commit: TEZ-3430. Make split sorting optional. (mingma)

Repository: tez
Updated Branches:
  refs/heads/master 43f7b5e3a -> 48208dc8c


TEZ-3430. Make split sorting optional. (mingma)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/48208dc8
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/48208dc8
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/48208dc8

Branch: refs/heads/master
Commit: 48208dc8cb835fa6bb2de9e82aecf85cae83e35b
Parents: 43f7b5e
Author: Ming Ma <mi...@twitter.com>
Authored: Sun Oct 16 16:23:52 2016 -0700
Committer: Ming Ma <mi...@twitter.com>
Committed: Sun Oct 16 16:23:52 2016 -0700

----------------------------------------------------------------------
 CHANGES.txt                                     |  1 +
 .../apache/tez/mapreduce/client/YARNRunner.java |  7 +-
 .../common/MRInputAMSplitGenerator.java         | 13 ++-
 .../tez/mapreduce/hadoop/MRInputHelpers.java    | 64 +++++++--------
 .../org/apache/tez/mapreduce/input/MRInput.java | 43 +++++-----
 .../src/main/proto/MRRuntimeProtos.proto        |  1 +
 .../org/apache/tez/mapreduce/TezTestUtils.java  | 81 +++++++++++++++++++
 .../common/TestMRInputSplitDistributor.java     | 84 +-------------------
 8 files changed, 152 insertions(+), 142 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/48208dc8/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index ef6b890..654e88a 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -7,6 +7,7 @@ INCOMPATIBLE CHANGES
 
 ALL CHANGES:
 
+  TEZ-3430. Make split sorting optional.
   TEZ-3466. Tez classpath building to mimic mapreduce classpath building.
   TEZ-3453. Correct the downloaded ATS dag data location for analyzer.
   TEZ-3449. Fix Spelling typos.

http://git-wip-us.apache.org/repos/asf/tez/blob/48208dc8/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/client/YARNRunner.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/client/YARNRunner.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/client/YARNRunner.java
index 820e2e4..00a68cd 100644
--- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/client/YARNRunner.java
+++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/client/YARNRunner.java
@@ -818,9 +818,10 @@ public class YARNRunner implements ClientProtocol {
   private static class MRInputHelpersInternal extends MRInputHelpers {
 
     protected static UserPayload createMRInputPayload(Configuration conf,
-                                                 MRRuntimeProtos.MRSplitsProto mrSplitsProto) throws
-        IOException {
-      return MRInputHelpers.createMRInputPayload(conf, mrSplitsProto);
+        MRRuntimeProtos.MRSplitsProto mrSplitsProto) throws
+            IOException {
+      return MRInputHelpers.createMRInputPayload(conf, mrSplitsProto, false,
+          true);
     }
   }
 

http://git-wip-us.apache.org/repos/asf/tez/blob/48208dc8/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java
index c109664..dbfdcb3 100644
--- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java
+++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java
@@ -107,9 +107,11 @@ public class MRInputAMSplitGenerator extends InputInitializer {
 
 
     boolean groupSplits = userPayloadProto.getGroupingEnabled();
+    boolean sortSplits = userPayloadProto.getSortSplitsEnabled();
     LOG.info("Input " + getContext().getInputName() + " asking for " + numTasks
-        + " tasks. Headroom: " + totalResource + " Task Resource: "
-        + taskResource + " waves: " + waves + ", groupingEnabled: " + groupSplits);
+        + " tasks. Headroom: " + totalResource + ". Task Resource: "
+        + taskResource + ". waves: " + waves + ". groupingEnabled: "
+        + groupSplits + ". SortSplitsEnabled: " + sortSplits);
 
     // Read all credentials into the credentials instance stored in JobConf.
     JobConf jobConf = new JobConf(conf);
@@ -117,11 +119,8 @@ public class MRInputAMSplitGenerator extends InputInitializer {
 
     InputSplitInfoMem inputSplitInfo = null;
 
-    if (groupSplits) {
-      inputSplitInfo = MRInputHelpers.generateInputSplitsToMem(jobConf, true, numTasks);
-    } else {
-      inputSplitInfo = MRInputHelpers.generateInputSplitsToMem(jobConf, false, 0);
-    }
+    inputSplitInfo = MRInputHelpers.generateInputSplitsToMem(jobConf,
+        groupSplits, sortSplits, groupSplits ? numTasks : 0);
     sw.stop();
     if (LOG.isDebugEnabled()) {
       LOG.debug("Time to create splits to mem: " + sw.now(TimeUnit.MILLISECONDS));

http://git-wip-us.apache.org/repos/asf/tez/blob/48208dc8/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRInputHelpers.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRInputHelpers.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRInputHelpers.java
index 9b88c4d..97e1677 100644
--- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRInputHelpers.java
+++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/hadoop/MRInputHelpers.java
@@ -111,7 +111,7 @@ public class MRInputHelpers {
 
       InputDescriptor inputDescriptor = InputDescriptor.create(useLegacyInput ? MRInputLegacy.class
           .getName() : MRInput.class.getName())
-          .setUserPayload(createMRInputPayload(conf, null));
+          .setUserPayload(createMRInputPayload(conf, null, false, true));
       Map<String, LocalResource> additionalLocalResources = new HashMap<String, LocalResource>();
       updateLocalResourcesForInputSplits(conf, inputSplitInfo,
           additionalLocalResources);
@@ -278,8 +278,8 @@ public class MRInputHelpers {
    * @throws InterruptedException
    */
   @InterfaceStability.Unstable
-  public static InputSplitInfoMem generateInputSplitsToMem(Configuration conf, boolean groupSplits,
-                                                           int targetTasks)
+  public static InputSplitInfoMem generateInputSplitsToMem(Configuration conf,
+      boolean groupSplits, boolean sortSplits, int targetTasks)
       throws IOException, ClassNotFoundException, InterruptedException {
 
     InputSplitInfoMem splitInfoMem = null;
@@ -290,7 +290,7 @@ public class MRInputHelpers {
       }
       Job job = Job.getInstance(conf);
       org.apache.hadoop.mapreduce.InputSplit[] splits =
-          generateNewSplits(job, groupSplits, targetTasks);
+          generateNewSplits(job, groupSplits, sortSplits, targetTasks);
       splitInfoMem = new InputSplitInfoMem(splits, createTaskLocationHintsFromSplits(splits),
           splits.length, job.getCredentials(), job.getConfiguration());
     } else {
@@ -298,7 +298,7 @@ public class MRInputHelpers {
         LOG.debug("Generating mapred api input splits");
       }
       org.apache.hadoop.mapred.InputSplit[] splits =
-          generateOldSplits(jobConf, groupSplits, targetTasks);
+          generateOldSplits(jobConf, groupSplits, sortSplits, targetTasks);
       splitInfoMem = new InputSplitInfoMem(splits, createTaskLocationHintsFromSplits(splits),
           splits.length, jobConf.getCredentials(), jobConf);
     }
@@ -379,8 +379,8 @@ public class MRInputHelpers {
 
   @SuppressWarnings({ "rawtypes", "unchecked" })
   private static org.apache.hadoop.mapreduce.InputSplit[] generateNewSplits(
-      JobContext jobContext, boolean groupSplits, int numTasks)
-      throws ClassNotFoundException, IOException,
+      JobContext jobContext, boolean groupSplits, boolean sortSplits,
+      int numTasks) throws ClassNotFoundException, IOException,
       InterruptedException {
     Configuration conf = jobContext.getConfiguration();
 
@@ -413,15 +413,18 @@ public class MRInputHelpers {
     org.apache.hadoop.mapreduce.InputSplit[] splits = (org.apache.hadoop.mapreduce.InputSplit[]) array
         .toArray(new org.apache.hadoop.mapreduce.InputSplit[array.size()]);
 
-    // sort the splits into order based on size, so that the biggest
-    // go first
-    Arrays.sort(splits, new InputSplitComparator());
+    if (sortSplits) {
+      // sort the splits into order based on size, so that the biggest
+      // go first
+      Arrays.sort(splits, new InputSplitComparator());
+    }
     return splits;
   }
 
   @SuppressWarnings({ "rawtypes", "unchecked" })
   private static org.apache.hadoop.mapred.InputSplit[] generateOldSplits(
-      JobConf jobConf, boolean groupSplits, int numTasks) throws IOException {
+      JobConf jobConf, boolean groupSplits, boolean sortSplits, int numTasks)
+      throws IOException {
 
     // This is the real InputFormat
     org.apache.hadoop.mapred.InputFormat inputFormat;
@@ -445,9 +448,11 @@ public class MRInputHelpers {
     }
     org.apache.hadoop.mapred.InputSplit[] splits = finalInputFormat
         .getSplits(jobConf, jobConf.getNumMapTasks());
-    // sort the splits into order based on size, so that the biggest
-    // go first
-    Arrays.sort(splits, new OldInputSplitComparator());
+    if (sortSplits) {
+      // sort the splits into order based on size, so that the biggest
+      // go first
+      Arrays.sort(splits, new OldInputSplitComparator());
+    }
     return splits;
   }
 
@@ -519,7 +524,7 @@ public class MRInputHelpers {
       ClassNotFoundException {
 
     org.apache.hadoop.mapreduce.InputSplit[] splits =
-        generateNewSplits(jobContext, false, 0);
+        generateNewSplits(jobContext, false, true, 0);
 
     Configuration conf = jobContext.getConfiguration();
 
@@ -556,7 +561,7 @@ public class MRInputHelpers {
                                                    Path inputSplitDir) throws IOException {
 
     org.apache.hadoop.mapred.InputSplit[] splits =
-        generateOldSplits(jobConf, false, 0);
+        generateOldSplits(jobConf, false, true, 0);
 
     JobSplitWriter.createSplitFiles(inputSplitDir, jobConf,
         inputSplitDir.getFileSystem(jobConf), splits);
@@ -664,8 +669,8 @@ public class MRInputHelpers {
   }
 
   /**
-   * Called to specify that grouping of input splits be performed by Tez
-   * The conf should have the input format class configuration
+   * When isGrouped is true, it specifies that grouping of input splits be
+   * performed by Tez. The conf should have the input format class configuration
    * set to the TezGroupedSplitsInputFormat. The real input format class name
    * should be passed as an argument to this method.
    * <p/>
@@ -674,27 +679,20 @@ public class MRInputHelpers {
    * or {@link org.apache.hadoop.mapreduce.split.TezGroupedSplitsInputFormat}
    */
   @InterfaceAudience.Private
-  protected static UserPayload createMRInputPayloadWithGrouping(Configuration conf) throws IOException {
-    Preconditions
-        .checkArgument(conf != null, "Configuration must be specified");
-    return createMRInputPayload(TezUtils.createByteStringFromConf(conf),
-        null, true);
-  }
-
-  @InterfaceAudience.Private
   protected static UserPayload createMRInputPayload(Configuration conf,
-                                               MRRuntimeProtos.MRSplitsProto mrSplitsProto) throws
-      IOException {
+      MRRuntimeProtos.MRSplitsProto mrSplitsProto, boolean isGrouped,
+      boolean isSorted) throws
+          IOException {
     Preconditions
         .checkArgument(conf != null, "Configuration must be specified");
 
     return createMRInputPayload(TezUtils.createByteStringFromConf(conf),
-        mrSplitsProto, false);
+        mrSplitsProto, isGrouped, isSorted);
   }
 
   private static UserPayload createMRInputPayload(ByteString bytes,
-                                             MRRuntimeProtos.MRSplitsProto mrSplitsProto,
-                                             boolean isGrouped) throws IOException {
+    MRRuntimeProtos.MRSplitsProto mrSplitsProto,
+    boolean isGrouped, boolean isSorted) throws IOException {
     MRRuntimeProtos.MRInputUserPayloadProto.Builder userPayloadBuilder =
         MRRuntimeProtos.MRInputUserPayloadProto
             .newBuilder();
@@ -703,7 +701,9 @@ public class MRInputHelpers {
       userPayloadBuilder.setSplits(mrSplitsProto);
     }
     userPayloadBuilder.setGroupingEnabled(isGrouped);
-    return UserPayload.create(userPayloadBuilder.build().toByteString().asReadOnlyByteBuffer());
+    userPayloadBuilder.setSortSplitsEnabled(isSorted);
+    return UserPayload.create(userPayloadBuilder.build().
+        toByteString().asReadOnlyByteBuffer());
   }
 
 

http://git-wip-us.apache.org/repos/asf/tez/blob/48208dc8/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/MRInput.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/MRInput.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/MRInput.java
index 1b0ffed..248a92a 100644
--- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/MRInput.java
+++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/MRInput.java
@@ -113,6 +113,7 @@ public class MRInput extends MRInputBase {
     final boolean inputFormatProvided;
     boolean useNewApi;
     boolean groupSplitsInAM = true;
+    boolean sortSplitsInAM = true;
     boolean generateSplitsInAM = true;
     String inputClassName = MRInput.class.getName();
     boolean getCredentialsForSourceFilesystem = true;
@@ -191,7 +192,17 @@ public class MRInput extends MRInputBase {
       groupSplitsInAM = value;
       return this;
     }
-    
+
+    /**
+     * Set whether splits should be sorted (default true)
+     * @param value whether to sort splits in the AM or not
+     * @return {@link org.apache.tez.mapreduce.input.MRInput.MRInputConfigBuilder}
+     */
+    public MRInputConfigBuilder sortSplits(boolean value) {
+      sortSplitsInAM = value;
+      return this;
+    }
+
     /**
      * Set whether splits should be generated in the Tez App Master (default true)
      * @param value whether to generate splits in the AM or not
@@ -266,7 +277,7 @@ public class MRInput extends MRInputBase {
       InputSplitInfo inputSplitInfo;
       setupBasicConf(conf);
       try {
-        inputSplitInfo = MRInputHelpers.generateInputSplitsToMem(conf, false, 0);
+        inputSplitInfo = MRInputHelpers.generateInputSplitsToMem(conf, false, true, 0);
       } catch (Exception e) {
         throw new TezUncheckedException(e);
       }
@@ -298,12 +309,8 @@ public class MRInput extends MRInputBase {
 
       Collection<URI> uris = maybeGetURIsForCredentials();
 
-      UserPayload payload = null;
-      if (groupSplitsInAM) {
-        payload = MRInputHelpersInternal.createMRInputPayloadWithGrouping(conf);
-      } else {
-        payload = MRInputHelpersInternal.createMRInputPayload(conf, null);
-      }
+      UserPayload payload = MRInputHelpersInternal.createMRInputPayload(
+          conf, groupSplitsInAM, sortSplitsInAM);
 
       DataSourceDescriptor ds = DataSourceDescriptor
           .create(InputDescriptor.create(inputClassName).setUserPayload(payload),
@@ -326,12 +333,8 @@ public class MRInput extends MRInputBase {
       
       Collection<URI> uris = maybeGetURIsForCredentials();
 
-      UserPayload payload = null;
-      if (groupSplitsInAM) {
-        payload = MRInputHelpersInternal.createMRInputPayloadWithGrouping(conf);
-      } else {
-        payload = MRInputHelpersInternal.createMRInputPayload(conf, null);
-      }
+      UserPayload payload = MRInputHelpersInternal.createMRInputPayload(
+          conf, groupSplitsInAM, sortSplitsInAM);
 
       DataSourceDescriptor ds = DataSourceDescriptor.create(
           InputDescriptor.create(inputClassName).setUserPayload(payload),
@@ -703,15 +706,17 @@ public class MRInput extends MRInputBase {
 
   private static class MRInputHelpersInternal extends MRInputHelpers {
 
-    protected static UserPayload createMRInputPayloadWithGrouping(Configuration conf) throws
-        IOException {
-      return MRInputHelpers.createMRInputPayloadWithGrouping(conf);
+    protected static UserPayload createMRInputPayload(Configuration conf,
+        boolean isGrouped, boolean isSorted) throws IOException {
+      return MRInputHelpers.createMRInputPayload(conf, null, isGrouped,
+          isSorted);
     }
 
     protected static UserPayload createMRInputPayload(Configuration conf,
-                                                 MRRuntimeProtos.MRSplitsProto mrSplitsProto) throws
+        MRRuntimeProtos.MRSplitsProto mrSplitsProto) throws
         IOException {
-      return MRInputHelpers.createMRInputPayload(conf, mrSplitsProto);
+      return MRInputHelpers.createMRInputPayload(conf, mrSplitsProto, false,
+          true);
     }
   }
 

http://git-wip-us.apache.org/repos/asf/tez/blob/48208dc8/tez-mapreduce/src/main/proto/MRRuntimeProtos.proto
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/proto/MRRuntimeProtos.proto b/tez-mapreduce/src/main/proto/MRRuntimeProtos.proto
index 646eb3a..8cba5fe 100644
--- a/tez-mapreduce/src/main/proto/MRRuntimeProtos.proto
+++ b/tez-mapreduce/src/main/proto/MRRuntimeProtos.proto
@@ -38,4 +38,5 @@ message MRInputUserPayloadProto {
   optional bytes configuration_bytes = 1;
   optional MRSplitsProto splits = 2;
   optional bool grouping_enabled = 3;
+  optional bool sort_splits_enabled = 4 [default = true];
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/48208dc8/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java
index 8a8b141..8912ad2 100644
--- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java
+++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java
@@ -18,10 +18,17 @@
 package org.apache.tez.mapreduce;
 
 import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.tez.dag.api.UserPayload;
+import org.apache.tez.dag.api.event.VertexState;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
+import org.apache.tez.runtime.api.InputInitializerContext;
+
+import java.io.IOException;
+import java.util.Set;
 
 public class TezTestUtils {
 
@@ -45,4 +52,78 @@ public class TezTestUtils {
             jobId), vertexId),
             taskId);
   }
+
+  public static class TezRootInputInitializerContextForTest implements
+      InputInitializerContext {
+
+    private final ApplicationId appId;
+    private final UserPayload payload;
+
+    public TezRootInputInitializerContextForTest(UserPayload payload) throws IOException {
+      appId = ApplicationId.newInstance(1000, 200);
+      this.payload = payload == null ? UserPayload.create(null) : payload;
+    }
+
+    @Override
+    public ApplicationId getApplicationId() {
+      return appId;
+    }
+
+    @Override
+    public String getDAGName() {
+      return "FakeDAG";
+    }
+
+    @Override
+    public String getInputName() {
+      return "MRInput";
+    }
+
+    @Override
+    public UserPayload getInputUserPayload() {
+      return payload;
+    }
+
+    @Override
+    public int getNumTasks() {
+      return 100;
+    }
+
+    @Override
+    public Resource getVertexTaskResource() {
+      return Resource.newInstance(1024, 1);
+    }
+
+    @Override
+    public Resource getTotalAvailableResource() {
+      return Resource.newInstance(10240, 10);
+    }
+
+    @Override
+    public int getNumClusterNodes() {
+      return 10;
+    }
+
+    @Override
+    public int getDAGAttemptNumber() {
+      return 1;
+    }
+
+    @Override
+    public int getVertexNumTasks(String vertexName) {
+      throw new UnsupportedOperationException("getVertexNumTasks not implemented in this mock");
+    }
+
+    @Override
+    public void registerForVertexStateUpdates(String vertexName, Set<VertexState> stateSet) {
+      throw new UnsupportedOperationException("getVertexNumTasks not implemented in this mock");
+    }
+
+    @Override
+    public UserPayload getUserPayload() {
+      throw new UnsupportedOperationException("getUserPayload not implemented in this mock");
+    }
+
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/48208dc8/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java
index cdf1ee4..3772cde 100644
--- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java
+++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java
@@ -29,17 +29,12 @@ import static org.junit.Assert.assertTrue;
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
-import java.util.EnumSet;
 import java.util.List;
-import java.util.Set;
 
 import org.apache.hadoop.classification.InterfaceAudience.Private;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.mapred.InputSplit;
-import org.apache.hadoop.yarn.api.records.ApplicationId;
-import org.apache.hadoop.yarn.api.records.Resource;
-import org.apache.tez.dag.api.event.VertexState;
-import org.apache.tez.dag.api.event.VertexStateUpdate;
+import org.apache.tez.mapreduce.TezTestUtils;
 import org.apache.tez.mapreduce.hadoop.MRInputHelpers;
 import org.apache.tez.mapreduce.hadoop.MRJobConfig;
 import org.apache.tez.mapreduce.lib.MRInputUtils;
@@ -75,7 +70,7 @@ public class TestMRInputSplitDistributor {
     UserPayload userPayload =
         UserPayload.create(payloadProto.build().toByteString().asReadOnlyByteBuffer());
 
-    InputInitializerContext context = new TezRootInputInitializerContextForTest(userPayload);
+    InputInitializerContext context = new TezTestUtils.TezRootInputInitializerContextForTest(userPayload);
     MRInputSplitDistributor splitDist = new MRInputSplitDistributor(context);
 
     List<Event> events = splitDist.initialize();
@@ -124,7 +119,7 @@ public class TestMRInputSplitDistributor {
     UserPayload userPayload =
         UserPayload.create(payloadProto.build().toByteString().asReadOnlyByteBuffer());
 
-    InputInitializerContext context = new TezRootInputInitializerContextForTest(userPayload);
+    InputInitializerContext context = new TezTestUtils.TezRootInputInitializerContextForTest(userPayload);
     MRInputSplitDistributor splitDist = new MRInputSplitDistributor(context);
 
     List<Event> events = splitDist.initialize();
@@ -150,79 +145,6 @@ public class TestMRInputSplitDistributor {
     assertEquals(2, ((InputSplitForTest) diEvent2.getDeserializedUserPayload()).identifier);
   }
 
-  private static class TezRootInputInitializerContextForTest implements
-      InputInitializerContext {
-
-    private final ApplicationId appId;
-    private final UserPayload payload;
-
-    TezRootInputInitializerContextForTest(UserPayload payload) throws IOException {
-      appId = ApplicationId.newInstance(1000, 200);
-      this.payload = payload == null ? UserPayload.create(null) : payload;
-    }
-
-    @Override
-    public ApplicationId getApplicationId() {
-      return appId;
-    }
-
-    @Override
-    public String getDAGName() {
-      return "FakeDAG";
-    }
-
-    @Override
-    public String getInputName() {
-      return "MRInput";
-    }
-
-    @Override
-    public UserPayload getInputUserPayload() {
-      return payload;
-    }
-
-    @Override
-    public int getNumTasks() {
-      return 100;
-    }
-
-    @Override
-    public Resource getVertexTaskResource() {
-      return Resource.newInstance(1024, 1);
-    }
-
-    @Override
-    public Resource getTotalAvailableResource() {
-      return Resource.newInstance(10240, 10);
-    }
-
-    @Override
-    public int getNumClusterNodes() {
-      return 10;
-    }
-
-    @Override
-    public int getDAGAttemptNumber() {
-      return 1;
-    }
-
-    @Override
-    public int getVertexNumTasks(String vertexName) {
-      throw new UnsupportedOperationException("getVertexNumTasks not implemented in this mock");
-    }
-
-    @Override
-    public void registerForVertexStateUpdates(String vertexName, Set<VertexState> stateSet) {
-      throw new UnsupportedOperationException("getVertexNumTasks not implemented in this mock");
-    }
-
-    @Override
-    public UserPayload getUserPayload() {
-      throw new UnsupportedOperationException("getUserPayload not implemented in this mock");
-    }
-
-  }
-
   @Private
   private static class InputSplitForTest implements InputSplit {