Posted to commits@tez.apache.org by bi...@apache.org on 2013/10/08 03:29:05 UTC

git commit: TEZ-534. Add an InputFormat that combines original splits into groups (bikas)

Updated Branches:
  refs/heads/master eb92543b3 -> 4dd5e195d


TEZ-534. Add an InputFormat that combines original splits into groups (bikas)


Project: http://git-wip-us.apache.org/repos/asf/incubator-tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-tez/commit/4dd5e195
Tree: http://git-wip-us.apache.org/repos/asf/incubator-tez/tree/4dd5e195
Diff: http://git-wip-us.apache.org/repos/asf/incubator-tez/diff/4dd5e195

Branch: refs/heads/master
Commit: 4dd5e195d074d36edfd414fac4f1554031257847
Parents: eb92543
Author: Bikas Saha <bi...@apache.org>
Authored: Mon Oct 7 18:22:09 2013 -0700
Committer: Bikas Saha <bi...@apache.org>
Committed: Mon Oct 7 18:22:09 2013 -0700

----------------------------------------------------------------------
 .../hadoop/mapred/split/TezGroupedSplit.java    | 137 +++++++
 .../split/TezGroupedSplitsInputFormat.java      | 380 +++++++++++++++++++
 .../hadoop/mapreduce/split/TezGroupedSplit.java | 146 +++++++
 .../split/TezGroupedSplitsInputFormat.java      | 341 +++++++++++++++++
 .../hadoop/mapred/split/TestGroupedSplits.java  | 278 ++++++++++++++
 5 files changed, 1282 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplit.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplit.java b/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplit.java
new file mode 100644
index 0000000..0bb383d
--- /dev/null
+++ b/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplit.java
@@ -0,0 +1,137 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred.split;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.tez.dag.api.TezUncheckedException;
+
+public class TezGroupedSplit implements InputSplit {
+
+  List<InputSplit> wrappedSplits = null;
+  String wrappedInputFormatName = null;
+  String[] locations = null;
+  long length = 0;
+  
+  public TezGroupedSplit() {
+    
+  }
+  
+  public TezGroupedSplit(int numSplits, String wrappedInputFormatName,
+      String[] locations) {
+    this.wrappedSplits = new ArrayList<InputSplit>(numSplits);
+    this.wrappedInputFormatName = wrappedInputFormatName;
+    this.locations = locations;
+  }
+  
+  public void addSplit(InputSplit split) {
+    wrappedSplits.add(split);
+    try {
+      length += split.getLength();
+    } catch (Exception e) {
+      throw new TezUncheckedException(e);
+    }
+  }
+  
+  @Override
+  public void write(DataOutput out) throws IOException {
+    if (wrappedSplits == null) {
+      throw new TezUncheckedException("Wrapped splits cannot be empty");
+    }
+
+    Text.writeString(out, wrappedInputFormatName);
+    Text.writeString(out, wrappedSplits.get(0).getClass().getCanonicalName());
+    out.writeInt(wrappedSplits.size());
+    for(InputSplit split : wrappedSplits) {
+      writeWrappedSplit(split, out);
+    }
+    out.writeLong(length);
+    
+    if (locations == null) {
+      out.writeInt(0);
+    } else {
+      out.writeInt(locations.length);
+      for (String location : locations) {
+        Text.writeString(out, location);
+      }
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    wrappedInputFormatName = Text.readString(in);
+    String inputSplitClassName = Text.readString(in);
+    Class<? extends InputSplit> clazz = 
+        (Class<? extends InputSplit>) 
+        TezGroupedSplitsInputFormat.getClassFromName(inputSplitClassName);
+    
+    int numSplits = in.readInt();
+    
+    wrappedSplits = new ArrayList<InputSplit>(numSplits);
+    for (int i=0; i<numSplits; ++i) {
+      addSplit(readWrappedSplit(in, clazz));
+    }
+    
+    long recordedLength = in.readLong();
+    if(recordedLength != length) {
+      throw new TezUncheckedException("Expected length: " + recordedLength
+          + " actual length: " + length);
+    }
+    int numLocs = in.readInt();
+    if (numLocs > 0) {
+      locations = new String[numLocs];
+      for (int i=0; i<numLocs; ++i) {
+        locations[i] = Text.readString(in);
+      }
+    }
+  }
+  
+  void writeWrappedSplit(InputSplit split, DataOutput out) throws IOException {
+    split.write(out);
+  }
+  
+  InputSplit readWrappedSplit(DataInput in, Class<? extends InputSplit> clazz) 
+      throws IOException {
+    InputSplit split;
+    try {
+      split = clazz.newInstance();
+    } catch (Exception e) {
+      throw new TezUncheckedException(e);
+    }
+    split.readFields(in);
+    return split;
+  }
+  
+  @Override
+  public long getLength() throws IOException {
+    return length;
+  }
+
+  @Override
+  public String[] getLocations() throws IOException {
+    return locations;
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplitsInputFormat.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplitsInputFormat.java b/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplitsInputFormat.java
new file mode 100644
index 0000000..6edd740
--- /dev/null
+++ b/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplitsInputFormat.java
@@ -0,0 +1,380 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred.split;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.mapred.InputFormat;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.RecordReader;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.tez.dag.api.TezUncheckedException;
+
+import com.google.common.base.Preconditions;
+
+public class TezGroupedSplitsInputFormat<K, V> implements InputFormat<K, V> {
+  private static final Log LOG = LogFactory.getLog(TezGroupedSplitsInputFormat.class);
+
+  InputFormat<K, V> wrappedInputFormat;
+  int desiredNumSplits = 0;
+  
+  public TezGroupedSplitsInputFormat() {
+    
+  }
+  
+  public void setInputFormat(InputFormat<K, V> wrappedInputFormat) {
+    this.wrappedInputFormat = wrappedInputFormat;
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("wrappedInputFormat: " + wrappedInputFormat.getClass().getName());
+    }
+  }
+  
+  public void setDesiredNumberOfSPlits(int num) {
+    Preconditions.checkArgument(num > 0);
+    this.desiredNumSplits = num;
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("desiredNumSplits: " + desiredNumSplits);
+    }
+  }
+  
+  class SplitHolder {
+    InputSplit split;
+    boolean isProcessed = false;
+    SplitHolder(InputSplit split) {
+      this.split = split;
+    }
+  }
+  
+  class LocationHolder {
+    List<SplitHolder> splits;
+    int headIndex = 0;
+    LocationHolder(int capacity) {
+      splits = new ArrayList<SplitHolder>(capacity);
+    }
+    boolean isEmpty() {
+      return (headIndex == splits.size());
+    }
+    SplitHolder getUnprocessedHeadSplit() {
+      while (!isEmpty()) {
+        SplitHolder holder = splits.get(headIndex);
+        if (!holder.isProcessed) {
+          return holder;
+        }
+        incrementHeadIndex();
+      }
+      return null;
+    }
+    void incrementHeadIndex() {
+      headIndex++;
+    }
+  }
+  
+  @Override
+  public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
+    if (desiredNumSplits > 0) {
+      // get the desired num splits directly if possible
+      numSplits = desiredNumSplits;
+    }
+    InputSplit[] originalSplits = wrappedInputFormat.getSplits(job, numSplits);
+    String wrappedInputFormatName = wrappedInputFormat.getClass().getCanonicalName();
+    if (desiredNumSplits == 0 ||
+        originalSplits.length == 0 ||
+        desiredNumSplits >= originalSplits.length) {
+      // nothing set. so return all the splits as is
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Using original number of splits: " + originalSplits.length);
+      }
+      InputSplit[] groupedSplits = new TezGroupedSplit[originalSplits.length];
+      int i=0;
+      for (InputSplit split : originalSplits) {
+        TezGroupedSplit newSplit = 
+            new TezGroupedSplit(1, wrappedInputFormatName, split.getLocations());
+        newSplit.addSplit(split);
+        groupedSplits[i++] = newSplit;
+      }
+      return groupedSplits;
+    }
+    
+    String[] emptyLocations = {"EmptyLocation"};
+    List<InputSplit> groupedSplitsList = new ArrayList<InputSplit>(desiredNumSplits);
+    
+    // sort the splits by length
+    Arrays.sort(originalSplits, new Comparator<InputSplit>() {
+      @Override
+      public int compare(InputSplit o1, InputSplit o2) {
+        try {
+          if (o1.getLength() < o2.getLength()) {
+            return -1;
+          } else if (o1.getLength() > o2.getLength()) {
+            return 1;
+          }
+        } catch (Exception e) {
+          throw new TezUncheckedException(e);
+        }
+        return 0;
+      }
+    });
+    
+    long totalLength = 0;
+    Map<String, LocationHolder> distinctLocations = new HashMap<String, LocationHolder>();
+    // go through splits in sorted order and add them to locations
+    for (InputSplit split : originalSplits) {
+      totalLength += split.getLength();
+      String[] locations = split.getLocations();
+      if (locations == null || locations.length == 0) {
+        locations = emptyLocations;
+      }
+      for (String location : locations ) {
+        distinctLocations.put(location, null);
+      }
+    }
+    
+    long lengthPerSplit = totalLength/desiredNumSplits;
+    int numLocations = distinctLocations.size();
+    int numSplitsPerLocation = originalSplits.length/numLocations;
+    int numSplitsInGroup = originalSplits.length/desiredNumSplits;
+    for (String location : distinctLocations.keySet()) {
+      distinctLocations.put(location, new LocationHolder(numSplitsPerLocation));
+    }
+    
+    for (InputSplit split : originalSplits) {
+      SplitHolder splitHolder = new SplitHolder(split);
+      String[] locations = split.getLocations();
+      if (locations == null || locations.length == 0) {
+        locations = emptyLocations;
+      }
+      for (String location : locations ) {
+        LocationHolder holder = distinctLocations.get(location);
+        holder.splits.add(splitHolder); // added smallest to largest
+      }
+    }
+    
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Desired lengthPerSplit: " + lengthPerSplit +
+          " numLocations: " + numLocations +
+          " numSplitsPerLocation: " + numSplitsPerLocation +
+          " numSplitsInGroup: " + numSplitsInGroup + 
+          " totalLength: " + totalLength);
+    }
+    
+    // go through locations and group splits
+    int splitsProcessed = 0;
+    List<SplitHolder> group = new ArrayList<SplitHolder>(numSplitsInGroup+1);
+    boolean allowSmallGroups = false;
+    int iterations = 0;
+    while (splitsProcessed < originalSplits.length) {
+      group.clear();
+      iterations++;
+      int numFullGroupsCreated = 0;
+      for (Map.Entry<String, LocationHolder> entry : distinctLocations.entrySet()) {
+        String location = entry.getKey();
+        LocationHolder holder = entry.getValue();
+        SplitHolder splitHolder = holder.getUnprocessedHeadSplit();
+        if (splitHolder == null) {
+          // all splits on node processed
+          continue;
+        }
+        int oldHeadIndex = holder.headIndex;
+        long groupLength = 0;
+        do {
+          group.add(splitHolder);
+          groupLength += splitHolder.split.getLength();
+          holder.incrementHeadIndex();
+          splitHolder = holder.getUnprocessedHeadSplit();
+        } while(splitHolder != null && 
+            groupLength + splitHolder.split.getLength() <= lengthPerSplit);
+
+        if (holder.isEmpty() && groupLength < lengthPerSplit/2 && !allowSmallGroups) {
+          // group too small, reset it
+          holder.headIndex = oldHeadIndex;
+          continue;
+        }
+        
+        numFullGroupsCreated++;
+
+        // One split group created
+        String[] groupLocation = {location};
+        if (location == emptyLocations[0]) {
+          groupLocation = null;
+        }
+        TezGroupedSplit groupedSplit = 
+            new TezGroupedSplit(group.size(), wrappedInputFormatName, groupLocation);
+        for (SplitHolder groupedSplitHolder : group) {
+          groupedSplit.addSplit(groupedSplitHolder.split);
+          groupedSplitHolder.isProcessed = true;
+          splitsProcessed++;
+        }
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Grouped " + group.size() + " split at: " + groupLocation);
+        }
+        groupedSplitsList.add(groupedSplit);
+      }
+      
+      if (!allowSmallGroups && numFullGroupsCreated < numLocations/4) {
+        // a few nodes have a lot of data or data is thinly spread across nodes
+        // so allow small groups now
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Allowing small groups");
+        }
+        allowSmallGroups = true;
+      }
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Iteration: " + iterations +
+            " splitsProcessed: " + splitsProcessed + 
+            " numFullGroupsInRound: " + numFullGroupsCreated +
+            " totalGroups: " + groupedSplitsList.size());
+      }
+    }
+    InputSplit[] groupedSplits = new InputSplit[groupedSplitsList.size()];
+    groupedSplitsList.toArray(groupedSplits);
+    LOG.info("Number of splits created: " + groupedSplitsList.size());
+    return groupedSplits;
+  }
+  
+  @Override
+  public RecordReader<K, V> getRecordReader(InputSplit split, JobConf job,
+      Reporter reporter) throws IOException {
+    TezGroupedSplit groupedSplit = (TezGroupedSplit) split;
+    initInputFormatFromSplit(groupedSplit);
+    return new TezGroupedSplitsRecordReader(groupedSplit, job, reporter);
+  }
+  
+  @SuppressWarnings({ "unchecked", "rawtypes" })
+  void initInputFormatFromSplit(TezGroupedSplit split) {
+    if (wrappedInputFormat == null) {
+      Class<? extends InputFormat> clazz = (Class<? extends InputFormat>) 
+          getClassFromName(split.wrappedInputFormatName);
+      try {
+        wrappedInputFormat = clazz.newInstance();
+      } catch (Exception e) {
+        throw new TezUncheckedException(e);
+      }
+    }
+  }
+  
+  static Class<?> getClassFromName(String name) {
+    try {
+      return Class.forName(name);
+    } catch (ClassNotFoundException e1) {
+      throw new TezUncheckedException(e1);
+    }
+  }
+  
+  public class TezGroupedSplitsRecordReader implements RecordReader<K, V> {
+
+    TezGroupedSplit groupedSplit;
+    JobConf job;
+    Reporter reporter;
+    int idx = 0;
+    long progress;
+    RecordReader<K, V> curReader;
+    
+    public TezGroupedSplitsRecordReader(TezGroupedSplit split, JobConf job,
+        Reporter reporter) throws IOException {
+      this.groupedSplit = split;
+      this.job = job;
+      this.reporter = reporter;
+      initNextRecordReader();
+    }
+    
+    @Override
+    public boolean next(K key, V value) throws IOException {
+
+      while ((curReader == null) || !curReader.next(key, value)) {
+        if (!initNextRecordReader()) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    @Override
+    public K createKey() {
+      return curReader.createKey();
+    }
+    
+    @Override
+    public V createValue() {
+      return curReader.createValue();
+    }
+    
+    @Override
+    public float getProgress() throws IOException {
+      return Math.min(1.0f,  getPos()/(float)(groupedSplit.getLength()));
+    }
+    
+    @Override
+    public void close() throws IOException {
+      if (curReader != null) {
+        curReader.close();
+        curReader = null;
+      }
+    }
+    
+    protected boolean initNextRecordReader() throws IOException {
+      if (curReader != null) {
+        curReader.close();
+        curReader = null;
+        if (idx > 0) {
+          progress += groupedSplit.wrappedSplits.get(idx-1).getLength();
+        }
+      }
+
+      // if all chunks have been processed, nothing more to do.
+      if (idx == groupedSplit.wrappedSplits.size()) {
+        return false;
+      }
+
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Init record reader for index " + idx + " of " + 
+                  groupedSplit.wrappedSplits.size());
+      }
+
+      // get a record reader for the idx-th chunk
+      try {
+        curReader = wrappedInputFormat.getRecordReader(
+            groupedSplit.wrappedSplits.get(idx), job, reporter);
+      } catch (Exception e) {
+        throw new RuntimeException (e);
+      }
+      idx++;
+      return true;
+    }
+
+    @Override
+    public long getPos() throws IOException {
+      long subprogress = 0;    // bytes processed in current split
+      if (null != curReader) {
+        // idx is always one past the current subsplit's true index.
+        subprogress = curReader.getPos();
+      }
+      return (progress + subprogress);
+    }
+  }  
+
+}

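For context, a minimal usage sketch of the old-API (org.apache.hadoop.mapred) grouping classes added above, modeled on the TestGroupedSplits test included later in this patch. It is not part of the patch itself: the input path and the desired split count are illustrative assumptions, and the method name setDesiredNumberOfSPlits is used exactly as it is spelled in this commit.

    import org.apache.hadoop.fs.Path;
    import org.apache.hadoop.io.LongWritable;
    import org.apache.hadoop.io.Text;
    import org.apache.hadoop.mapred.FileInputFormat;
    import org.apache.hadoop.mapred.InputSplit;
    import org.apache.hadoop.mapred.JobConf;
    import org.apache.hadoop.mapred.RecordReader;
    import org.apache.hadoop.mapred.Reporter;
    import org.apache.hadoop.mapred.TextInputFormat;
    import org.apache.hadoop.mapred.split.TezGroupedSplitsInputFormat;

    public class GroupedSplitsExample {
      public static void main(String[] args) throws Exception {
        JobConf job = new JobConf();
        // Assumed input path for illustration only.
        FileInputFormat.setInputPaths(job, new Path("/tmp/grouped-splits-demo"));
        TextInputFormat realFormat = new TextInputFormat();
        realFormat.configure(job);

        // Wrap the real format and ask for a target number of grouped splits.
        TezGroupedSplitsInputFormat<LongWritable, Text> grouped =
            new TezGroupedSplitsInputFormat<LongWritable, Text>();
        grouped.setInputFormat(realFormat);
        grouped.setDesiredNumberOfSPlits(4); // method name as spelled in this patch

        // Each returned split is a TezGroupedSplit wrapping one or more original splits.
        InputSplit[] splits = grouped.getSplits(job, 4);

        // Reading a group chains through its wrapped splits one record reader at a time.
        RecordReader<LongWritable, Text> reader =
            grouped.getRecordReader(splits[0], job, Reporter.NULL);
        LongWritable key = reader.createKey();
        Text value = reader.createValue();
        while (reader.next(key, value)) {
          // process key/value
        }
        reader.close();
      }
    }

The wrapped InputFormat still computes the original splits; the grouping format only re-buckets them by location and size, so an existing mapred InputFormat can be wrapped without modification.
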
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplit.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplit.java b/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplit.java
new file mode 100644
index 0000000..afc3108
--- /dev/null
+++ b/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplit.java
@@ -0,0 +1,146 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapreduce.split;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.tez.dag.api.TezUncheckedException;
+
+public class TezGroupedSplit extends InputSplit implements Writable {
+
+  List<InputSplit> wrappedSplits = null;
+  String wrappedInputFormatName = null;
+  String[] locations = null;
+  long length = 0;
+  
+  public TezGroupedSplit() {
+    
+  }
+  
+  public TezGroupedSplit(int numSplits, String wrappedInputFormatName,
+      String[] locations) {
+    this.wrappedSplits = new ArrayList<InputSplit>(numSplits);
+    this.wrappedInputFormatName = wrappedInputFormatName;
+    this.locations = locations;
+  }
+  
+  public void addSplit(InputSplit split) {
+    wrappedSplits.add(split);
+    try {
+      length += split.getLength();
+    } catch (Exception e) {
+      throw new TezUncheckedException(e);
+    }
+  }
+  
+  @Override
+  public void write(DataOutput out) throws IOException {
+    if (wrappedSplits == null) {
+      throw new TezUncheckedException("Wrapped splits cannot be empty");
+    }
+
+    Text.writeString(out, wrappedInputFormatName);
+    Text.writeString(out, wrappedSplits.get(0).getClass().getCanonicalName());
+    out.writeInt(wrappedSplits.size());
+    for(InputSplit split : wrappedSplits) {
+      writeWrappedSplit(split, out);
+    }
+    out.writeLong(length);
+    
+    if (locations == null) {
+      out.writeInt(0);
+    } else {
+      out.writeInt(locations.length);
+      for (String location : locations) {
+        Text.writeString(out, location);
+      }
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    wrappedInputFormatName = Text.readString(in);
+    String inputSplitClassName = Text.readString(in);
+    Class<? extends InputSplit> clazz = 
+        (Class<? extends InputSplit>) 
+        TezGroupedSplitsInputFormat.getClassFromName(inputSplitClassName);
+    
+    int numSplits = in.readInt();
+    
+    wrappedSplits = new ArrayList<InputSplit>(numSplits);
+    for (int i=0; i<numSplits; ++i) {
+      addSplit(readWrappedSplit(in, clazz));
+    }
+    
+    long recordedLength = in.readLong();
+    if(recordedLength != length) {
+      throw new TezUncheckedException("Expected length: " + recordedLength
+          + " actual length: " + length);
+    }
+    int numLocs = in.readInt();
+    if (numLocs > 0) {
+      locations = new String[numLocs];
+      for (int i=0; i<numLocs; ++i) {
+        locations[i] = Text.readString(in);
+      }
+    }
+  }
+  
+  void writeWrappedSplit(InputSplit split, DataOutput out) throws IOException {
+    if (split instanceof Writable) {
+      ((Writable) split).write(out);
+    } else {
+      throw new TezUncheckedException(
+          split.getClass().getName() + " is not a Writable");
+    }
+  }
+  
+  InputSplit readWrappedSplit(DataInput in, Class<? extends InputSplit> clazz) {
+    try {
+      InputSplit split = clazz.newInstance();
+      if (split instanceof Writable) {
+        ((Writable) split).readFields(in);
+        return split;
+      } else {
+        throw new TezUncheckedException(
+            split.getClass().getName() + " is not a Writable");          
+      }
+    } catch (Exception e) {
+      throw new TezUncheckedException(e);
+    }
+  }
+  
+  @Override
+  public long getLength() throws IOException, InterruptedException {
+    return length;
+  }
+
+  @Override
+  public String[] getLocations() throws IOException, InterruptedException {
+    return locations;
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplitsInputFormat.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplitsInputFormat.java b/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplitsInputFormat.java
new file mode 100644
index 0000000..4fd44d7
--- /dev/null
+++ b/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplitsInputFormat.java
@@ -0,0 +1,341 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapreduce.split;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.tez.dag.api.TezUncheckedException;
+
+public class TezGroupedSplitsInputFormat<K, V> extends InputFormat<K, V> {
+
+  InputFormat<K, V> wrappedInputFormat;
+  int desiredNumSplits = 0;
+  List<InputSplit> groupedSplits = null;
+  
+  public TezGroupedSplitsInputFormat() {
+    
+  }
+  
+  public void setInputFormat(InputFormat<K, V> wrappedInputFormat) {
+    this.wrappedInputFormat = wrappedInputFormat;
+  }
+  
+  public void setDesiredNumberOfSPlits(int num) {
+    this.desiredNumSplits = num;
+  }
+  
+  class SplitHolder {
+    InputSplit split;
+    boolean isProcessed = false;
+    SplitHolder(InputSplit split) {
+      this.split = split;
+    }
+  }
+  
+  class LocationHolder {
+    List<SplitHolder> splits;
+    int headIndex = 0;
+    LocationHolder(int capacity) {
+      splits = new ArrayList<SplitHolder>(capacity);
+    }
+    boolean isEmpty() {
+      return (headIndex == splits.size());
+    }
+    SplitHolder getUnprocessedHeadSplit() {
+      while (!isEmpty()) {
+        SplitHolder holder = splits.get(headIndex);
+        if (!holder.isProcessed) {
+          return holder;
+        }
+        incrementHeadIndex();
+      }
+      return null;
+    }
+    void incrementHeadIndex() {
+      headIndex++;
+    }
+  }
+  
+  @Override
+  public List<InputSplit> getSplits(JobContext context) throws IOException,
+      InterruptedException {
+    List<InputSplit> originalSplits = wrappedInputFormat.getSplits(context);
+    String wrappedInputFormatName = wrappedInputFormat.getClass().getCanonicalName();
+    if (desiredNumSplits == 0 ||
+        originalSplits.size() == 0 ||
+        desiredNumSplits >= originalSplits.size()) {
+      // nothing set. so return all the splits as is
+      groupedSplits = new ArrayList<InputSplit>(originalSplits.size());
+      for (InputSplit split : originalSplits) {
+        TezGroupedSplit newSplit = 
+            new TezGroupedSplit(1, wrappedInputFormatName, split.getLocations());
+        newSplit.addSplit(split);
+        groupedSplits.add(newSplit);
+      }
+      return groupedSplits;
+    }
+    
+    String[] emptyLocations = {"EmptyLocation"};
+    groupedSplits = new ArrayList<InputSplit>(desiredNumSplits);
+    
+    // sort the splits by length
+    Collections.sort(originalSplits, new Comparator<InputSplit>() {
+      @Override
+      public int compare(InputSplit o1, InputSplit o2) {
+        try {
+          if (o1.getLength() < o2.getLength()) {
+            return -1;
+          } else if (o1.getLength() > o2.getLength()) {
+            return 1;
+          }
+        } catch (Exception e) {
+          throw new TezUncheckedException(e);
+        }
+        return 0;
+      }
+    });
+    
+    long totalLength = 0;
+    Map<String, LocationHolder> distinctLocations = new HashMap<String, LocationHolder>();
+    // go through splits in sorted order and add them to locations
+    for (InputSplit split : originalSplits) {
+      totalLength += split.getLength();
+      String[] locations = split.getLocations();
+      if (locations == null || locations.length == 0) {
+        locations = emptyLocations;
+      }
+      for (String location : locations ) {
+        distinctLocations.put(location, null);
+      }
+    }
+    
+    long lengthPerSplit = totalLength/desiredNumSplits;
+    int numLocations = distinctLocations.size();
+    int numSplitsPerLocation = originalSplits.size()/numLocations;
+    int numSplitsInGroup = originalSplits.size()/desiredNumSplits;
+    for (String location : distinctLocations.keySet()) {
+      distinctLocations.put(location, new LocationHolder(numSplitsPerLocation));
+    }
+    
+    for (InputSplit split : originalSplits) {
+      SplitHolder splitHolder = new SplitHolder(split);
+      String[] locations = split.getLocations();
+      if (locations == null || locations.length == 0) {
+        locations = emptyLocations;
+      }
+      for (String location : locations ) {
+        LocationHolder holder = distinctLocations.get(location);
+        holder.splits.add(splitHolder); // added smallest to largest
+      }
+    }
+    
+    // go through locations and group splits
+    int splitsProcessed = 0;
+    List<SplitHolder> group = new ArrayList<SplitHolder>(numSplitsInGroup);
+    boolean allowSmallGroups = false;
+    while (splitsProcessed < originalSplits.size()) {
+      group.clear();
+      int numFullGroupsCreated = 0;
+      for (Map.Entry<String, LocationHolder> entry : distinctLocations.entrySet()) {
+        String location = entry.getKey();
+        LocationHolder holder = entry.getValue();
+        SplitHolder splitHolder = holder.getUnprocessedHeadSplit();
+        if (splitHolder == null) {
+          // all splits on node processed
+          continue;
+        }
+        int oldHeadIndex = holder.headIndex;
+        long groupLength = 0;
+        do {
+          group.add(splitHolder);
+          groupLength += splitHolder.split.getLength();
+          holder.incrementHeadIndex();
+          splitHolder = holder.getUnprocessedHeadSplit();
+        } while(splitHolder != null && 
+            groupLength + splitHolder.split.getLength() <= lengthPerSplit);
+
+        if (holder.isEmpty() && groupLength < lengthPerSplit/2 && !allowSmallGroups) {
+          // group too small, reset it
+          holder.headIndex = oldHeadIndex;
+          continue;
+        }
+        
+        numFullGroupsCreated++;
+
+        // One split group created
+        String[] groupLocation = {location};
+        if (location == emptyLocations[0]) {
+          groupLocation = null;
+        }
+        TezGroupedSplit groupedSplit = 
+            new TezGroupedSplit(group.size(), wrappedInputFormatName, groupLocation);
+        for (SplitHolder groupedSplitHolder : group) {
+          groupedSplit.addSplit(groupedSplitHolder.split);
+          groupedSplitHolder.isProcessed = true;
+          splitsProcessed++;
+        }
+        groupedSplits.add(groupedSplit);
+      }
+      
+      if (!allowSmallGroups && numFullGroupsCreated < numLocations/4) {
+        // a few nodes have a lot of data or data is thinly spread across nodes
+        // so allow small groups now
+        allowSmallGroups = true;
+      }
+    }
+    
+    return groupedSplits;
+  }
+
+  @Override
+  public RecordReader<K, V> createRecordReader(InputSplit split,
+      TaskAttemptContext context) throws IOException, InterruptedException {
+    TezGroupedSplit groupedSplit = (TezGroupedSplit) split;
+    initInputFormatFromSplit(groupedSplit);
+    return new TezGroupedSplitsRecordReader(groupedSplit, context);
+  }
+  
+  @SuppressWarnings({ "rawtypes", "unchecked" })
+  void initInputFormatFromSplit(TezGroupedSplit split) {
+    if (wrappedInputFormat == null) {
+      Class<? extends InputFormat> clazz = (Class<? extends InputFormat>) 
+          getClassFromName(split.wrappedInputFormatName);
+      try {
+        wrappedInputFormat = clazz.newInstance();
+      } catch (Exception e) {
+        throw new TezUncheckedException(e);
+      }
+    }
+  }
+  
+  static Class<?> getClassFromName(String name) {
+    try {
+      return Class.forName(name);
+    } catch (ClassNotFoundException e1) {
+      throw new TezUncheckedException(e1);
+    }
+  }
+  
+  public class TezGroupedSplitsRecordReader  extends RecordReader<K, V> {
+
+    TezGroupedSplit groupedSplit;
+    TaskAttemptContext context;
+    int idx = 0;
+    long progress;
+    RecordReader<K, V> curReader;
+    
+    public TezGroupedSplitsRecordReader(TezGroupedSplit split,
+        TaskAttemptContext context) throws IOException {
+      this.groupedSplit = split;
+      this.context = context;
+    }
+    
+    public void initialize(InputSplit split,
+        TaskAttemptContext context) throws IOException, InterruptedException {
+      if (this.groupedSplit != split) {
+        throw new TezUncheckedException("Splits dont match");
+      }
+      if (this.context != context) {
+        throw new TezUncheckedException("Contexts dont match");
+      }
+      initNextRecordReader();
+    }
+    
+    public boolean nextKeyValue() throws IOException, InterruptedException {
+      while ((curReader == null) || !curReader.nextKeyValue()) {
+        // false return finishes. true return loops back for nextKeyValue()
+        if (!initNextRecordReader()) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    public K getCurrentKey() throws IOException, InterruptedException {
+      return curReader.getCurrentKey();
+    }
+    
+    public V getCurrentValue() throws IOException, InterruptedException {
+      return curReader.getCurrentValue();
+    }
+    
+    public void close() throws IOException {
+      if (curReader != null) {
+        curReader.close();
+        curReader = null;
+      }
+    }
+    
+    protected boolean initNextRecordReader() throws IOException {
+      if (curReader != null) {
+        curReader.close();
+        curReader = null;
+        if (idx > 0) {
+          try {
+            progress += groupedSplit.wrappedSplits.get(idx-1).getLength();
+          } catch (InterruptedException e) {
+            throw new TezUncheckedException(e);
+          }
+        }
+      }
+
+      // if all chunks have been processed, nothing more to do.
+      if (idx == groupedSplit.wrappedSplits.size()) {
+        return false;
+      }
+
+      // get a record reader for the idx-th chunk
+      try {
+        curReader = wrappedInputFormat.createRecordReader(
+            groupedSplit.wrappedSplits.get(idx), context);
+
+        curReader.initialize(groupedSplit, context);
+      } catch (Exception e) {
+        throw new RuntimeException (e);
+      }
+      idx++;
+      return true;
+    }
+    
+    /**
+     * return progress based on the amount of data processed so far.
+     */
+    public float getProgress() throws IOException, InterruptedException {
+      long subprogress = 0;    // bytes processed in current split
+      if (null != curReader) {
+        // idx is always one past the current subsplit's true index.
+        subprogress = (long) (curReader.getProgress() * groupedSplit.wrappedSplits
+            .get(idx - 1).getLength());
+      }
+      return Math.min(1.0f,  (progress + subprogress)/(float)(groupedSplit.getLength()));
+    }
+  }
+
+}

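A corresponding sketch for the new-API (org.apache.hadoop.mapreduce) variant above, again not part of the patch. The Job-based setup and the input path are assumptions made only for illustration; Job is used here because it implements the JobContext that getSplits() expects.

    import java.util.List;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.hadoop.io.LongWritable;
    import org.apache.hadoop.io.Text;
    import org.apache.hadoop.mapreduce.InputSplit;
    import org.apache.hadoop.mapreduce.Job;
    import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
    import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
    import org.apache.hadoop.mapreduce.split.TezGroupedSplitsInputFormat;

    public class GroupedSplitsNewApiExample {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Job job = Job.getInstance(conf); // Job doubles as the JobContext passed to getSplits()
        // Assumed input path for illustration only.
        FileInputFormat.setInputPaths(job, new Path("/tmp/grouped-splits-demo"));

        TezGroupedSplitsInputFormat<LongWritable, Text> grouped =
            new TezGroupedSplitsInputFormat<LongWritable, Text>();
        grouped.setInputFormat(new TextInputFormat());
        grouped.setDesiredNumberOfSPlits(4); // method name as spelled in this patch

        // Original splits are bucketed by node and grouped up to roughly
        // totalLength/desiredNumSplits bytes per TezGroupedSplit.
        List<InputSplit> splits = grouped.getSplits(job);
        System.out.println("grouped splits: " + splits.size());
      }
    }
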
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/test/java/org/apache/hadoop/mapred/split/TestGroupedSplits.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/test/java/org/apache/hadoop/mapred/split/TestGroupedSplits.java b/tez-mapreduce/src/test/java/org/apache/hadoop/mapred/split/TestGroupedSplits.java
new file mode 100644
index 0000000..f3e549f
--- /dev/null
+++ b/tez-mapreduce/src/test/java/org/apache/hadoop/mapred/split/TestGroupedSplits.java
@@ -0,0 +1,278 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred.split;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.io.compress.GzipCodec;
+import org.apache.hadoop.mapred.FileInputFormat;
+import org.apache.hadoop.mapred.InputFormat;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.RecordReader;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapred.TextInputFormat;
+import org.apache.hadoop.util.ReflectionUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestGroupedSplits {
+  private static final Log LOG =
+    LogFactory.getLog(TestGroupedSplits.class);
+
+  private static JobConf defaultConf = new JobConf();
+  private static FileSystem localFs = null;
+
+  static {
+    try {
+      defaultConf.set("fs.defaultFS", "file:///");
+      localFs = FileSystem.getLocal(defaultConf);
+    } catch (IOException e) {
+      throw new RuntimeException("init failure", e);
+    }
+  }
+
+  @SuppressWarnings("deprecation")
+  private static Path workDir =
+    new Path(new Path(System.getProperty("test.build.data", "/tmp")),
+             "TestCombineTextInputFormat").makeQualified(localFs);
+
+  // A reporter that does nothing
+  private static final Reporter voidReporter = Reporter.NULL;
+
+  @Test(timeout=10000)
+  public void testFormat() throws Exception {
+    JobConf job = new JobConf(defaultConf);
+
+    Random random = new Random();
+    long seed = random.nextLong();
+    LOG.info("seed = "+seed);
+    random.setSeed(seed);
+
+    localFs.delete(workDir, true);
+    FileInputFormat.setInputPaths(job, workDir);
+
+    final int length = 10000;
+    final int numFiles = 10;
+
+    createFiles(length, numFiles, random);
+
+    // create a combined split for the files
+    TextInputFormat wrappedFormat = new TextInputFormat();
+    wrappedFormat.configure(job);
+    TezGroupedSplitsInputFormat<LongWritable , Text> format = 
+        new TezGroupedSplitsInputFormat<LongWritable, Text>();
+    format.setDesiredNumberOfSPlits(1);
+    format.setInputFormat(wrappedFormat);
+    LongWritable key = new LongWritable();
+    Text value = new Text();
+    for (int i = 0; i < 3; i++) {
+      int numSplits = random.nextInt(length/20)+1;
+      LOG.info("splitting: requesting = " + numSplits);
+      InputSplit[] splits = format.getSplits(job, numSplits);
+      LOG.info("splitting: got =        " + splits.length);
+
+      // we should have a single split as the length is comfortably smaller than
+      // the block size
+      Assert.assertEquals("We got more than one splits!", 1, splits.length);
+      InputSplit split = splits[0];
+      Assert.assertEquals("It should be TezGroupedSplit",
+        TezGroupedSplit.class, split.getClass());
+
+      // check the split
+      BitSet bits = new BitSet(length);
+      LOG.debug("split= " + split);
+      RecordReader<LongWritable, Text> reader =
+        format.getRecordReader(split, job, voidReporter);
+      try {
+        int count = 0;
+        while (reader.next(key, value)) {
+          int v = Integer.parseInt(value.toString());
+          LOG.debug("read " + v);
+          if (bits.get(v)) {
+            LOG.warn("conflict with " + v +
+                     " at position "+reader.getPos());
+          }
+          Assert.assertFalse("Key in multiple partitions.", bits.get(v));
+          bits.set(v);
+          count++;
+        }
+        LOG.info("splits="+split+" count=" + count);
+      } finally {
+        reader.close();
+      }
+      Assert.assertEquals("Some keys in no partition.", length, bits.cardinality());
+    }
+  }
+
+  private static class Range {
+    private final int start;
+    private final int end;
+
+    Range(int start, int end) {
+      this.start = start;
+      this.end = end;
+    }
+
+    @Override
+    public String toString() {
+      return "(" + start + ", " + end + ")";
+    }
+  }
+
+  private static Range[] createRanges(int length, int numFiles, Random random) {
+    // generate a number of files with various lengths
+    Range[] ranges = new Range[numFiles];
+    for (int i = 0; i < numFiles; i++) {
+      int start = i == 0 ? 0 : ranges[i-1].end;
+      int end = i == numFiles - 1 ?
+        length :
+        (length/numFiles)*(2*i + 1)/2 + random.nextInt(length/numFiles) + 1;
+      ranges[i] = new Range(start, end);
+    }
+    return ranges;
+  }
+
+  private static void createFiles(int length, int numFiles, Random random)
+    throws IOException {
+    Range[] ranges = createRanges(length, numFiles, random);
+
+    for (int i = 0; i < numFiles; i++) {
+      Path file = new Path(workDir, "test_" + i + ".txt");
+      Writer writer = new OutputStreamWriter(localFs.create(file));
+      Range range = ranges[i];
+      try {
+        for (int j = range.start; j < range.end; j++) {
+          writer.write(Integer.toString(j));
+          writer.write("\n");
+        }
+      } finally {
+        writer.close();
+      }
+    }
+  }
+
+  private static void writeFile(FileSystem fs, Path name,
+                                CompressionCodec codec,
+                                String contents) throws IOException {
+    OutputStream stm;
+    if (codec == null) {
+      stm = fs.create(name);
+    } else {
+      stm = codec.createOutputStream(fs.create(name));
+    }
+    stm.write(contents.getBytes());
+    stm.close();
+  }
+
+  private static List<Text> readSplit(InputFormat<LongWritable,Text> format,
+                                      InputSplit split,
+                                      JobConf job) throws IOException {
+    List<Text> result = new ArrayList<Text>();
+    RecordReader<LongWritable, Text> reader =
+      format.getRecordReader(split, job, voidReporter);
+    LongWritable key = reader.createKey();
+    Text value = reader.createValue();
+    while (reader.next(key, value)) {
+      result.add(value);
+      value = reader.createValue();
+    }
+    reader.close();
+    return result;
+  }
+
+  /**
+   * Test using the gzip codec for reading
+   */
+  @Test(timeout=10000)
+  public void testGzip() throws IOException {
+    JobConf job = new JobConf(defaultConf);
+    CompressionCodec gzip = new GzipCodec();
+    ReflectionUtils.setConf(gzip, job);
+    localFs.delete(workDir, true);
+    writeFile(localFs, new Path(workDir, "part1.txt.gz"), gzip,
+              "the quick\nbrown\nfox jumped\nover\n the lazy\n dog\n");
+    writeFile(localFs, new Path(workDir, "part2.txt.gz"), gzip,
+              "is\ngzip\n");
+    writeFile(localFs, new Path(workDir, "part3.txt.gz"), gzip,
+        "one\nmore\nsplit\n");
+    FileInputFormat.setInputPaths(job, workDir);
+    TextInputFormat wrappedFormat = new TextInputFormat();
+    wrappedFormat.configure(job);
+    TezGroupedSplitsInputFormat<LongWritable , Text> format = 
+        new TezGroupedSplitsInputFormat<LongWritable, Text>();
+    format.setInputFormat(wrappedFormat);
+    
+    // TextInputFormat will produce 3 splits
+    for (int j=1; j<=3; ++j) {
+      format.setDesiredNumberOfSPlits(j);
+      InputSplit[] splits = format.getSplits(job, 100);
+      Assert.assertEquals("compressed splits == " + j, j, splits.length);
+      // j==3 cases exercises the code where desired == actual
+      // and does not do grouping
+      List<Text> results = new ArrayList<Text>();
+      for (int i=0; i<splits.length; ++i) { 
+        List<Text> read = readSplit(format, splits[i], job);
+        results.addAll(read);
+      }
+      Assert.assertEquals("splits length", 11, results.size());
+  
+      final String[] firstList =
+        {"the quick", "brown", "fox jumped", "over", " the lazy", " dog"};
+      final String[] secondList = {"is", "gzip"};
+      final String[] thirdList = {"one", "more", "split"};
+      String first = results.get(0).toString();
+      int start = 0;
+      switch (first.charAt(0)) {
+      case 't':
+        start = testResults(results, firstList, start);
+        break;
+      case 'i':
+        start = testResults(results, secondList, start);
+        break;
+      case 'o':
+        start = testResults(results, thirdList, start);
+        break;
+      default:
+        Assert.fail("unexpected first token - " + first);
+      }
+    }
+  }
+
+  private static int testResults(List<Text> results, String[] first, int start) {
+    for (int i = 0; i < first.length; i++) {
+      Assert.assertEquals("splits["+i+"]", first[i], results.get(start+i).toString());
+    }
+    return first.length+start;
+  }
+}