You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tez.apache.org by bi...@apache.org on 2013/10/08 03:29:05 UTC
git commit: TEZ-534. Add an InputFormat that combines original splits
into groups (bikas)
Updated Branches:
refs/heads/master eb92543b3 -> 4dd5e195d
TEZ-534. Add an InputFormat that combines original splits into groups (bikas)
Project: http://git-wip-us.apache.org/repos/asf/incubator-tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-tez/commit/4dd5e195
Tree: http://git-wip-us.apache.org/repos/asf/incubator-tez/tree/4dd5e195
Diff: http://git-wip-us.apache.org/repos/asf/incubator-tez/diff/4dd5e195
Branch: refs/heads/master
Commit: 4dd5e195d074d36edfd414fac4f1554031257847
Parents: eb92543
Author: Bikas Saha <bi...@apache.org>
Authored: Mon Oct 7 18:22:09 2013 -0700
Committer: Bikas Saha <bi...@apache.org>
Committed: Mon Oct 7 18:22:09 2013 -0700
----------------------------------------------------------------------
.../hadoop/mapred/split/TezGroupedSplit.java | 137 +++++++
.../split/TezGroupedSplitsInputFormat.java | 380 +++++++++++++++++++
.../hadoop/mapreduce/split/TezGroupedSplit.java | 146 +++++++
.../split/TezGroupedSplitsInputFormat.java | 341 +++++++++++++++++
.../hadoop/mapred/split/TestGroupedSplits.java | 278 ++++++++++++++
5 files changed, 1282 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplit.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplit.java b/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplit.java
new file mode 100644
index 0000000..0bb383d
--- /dev/null
+++ b/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplit.java
@@ -0,0 +1,137 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred.split;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.tez.dag.api.TezUncheckedException;
+
+/**
+ * An old-API (mapred) {@link InputSplit} that wraps a group of splits produced
+ * by another InputFormat so that a single task can process them together.
+ * <p>
+ * All wrapped splits must be of the same class: only the class of the first
+ * split is recorded when the group is serialized, and every wrapped split is
+ * deserialized as an instance of that class.
+ */
+public class TezGroupedSplit implements InputSplit {
+
+  List<InputSplit> wrappedSplits = null;
+  String wrappedInputFormatName = null;
+  String[] locations = null;
+  // sum of getLength() over all wrapped splits
+  long length = 0;
+
+  /** No-arg constructor, needed to create an instance before readFields(). */
+  public TezGroupedSplit() {
+
+  }
+
+  /**
+   * @param numSplits expected number of wrapped splits (initial list capacity)
+   * @param wrappedInputFormatName class name of the InputFormat that produced
+   *          the wrapped splits; it is loaded with Class.forName() on the
+   *          reading side
+   * @param locations preferred locations of this group; may be null
+   */
+  public TezGroupedSplit(int numSplits, String wrappedInputFormatName,
+      String[] locations) {
+    this.wrappedSplits = new ArrayList<InputSplit>(numSplits);
+    this.wrappedInputFormatName = wrappedInputFormatName;
+    this.locations = locations;
+  }
+
+  /**
+   * Appends a split to this group and adds its length to the group length.
+   * @throws TezUncheckedException if the split's length cannot be obtained
+   */
+  public void addSplit(InputSplit split) {
+    wrappedSplits.add(split);
+    try {
+      length += split.getLength();
+    } catch (Exception e) {
+      throw new TezUncheckedException(e);
+    }
+  }
+
+  /**
+   * Serializes the wrapped InputFormat name, the wrapped split class, the
+   * wrapped splits, the total length and the locations.
+   * @throws TezUncheckedException if there are no wrapped splits
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    // an empty list would otherwise fail below with IndexOutOfBoundsException
+    if (wrappedSplits == null || wrappedSplits.isEmpty()) {
+      throw new TezUncheckedException("Wrapped splits cannot be empty");
+    }
+
+    Text.writeString(out, wrappedInputFormatName);
+    // Binary name, not canonical name: the reader resolves it with
+    // Class.forName(), which cannot load the canonical name of a nested class.
+    Text.writeString(out, wrappedSplits.get(0).getClass().getName());
+    out.writeInt(wrappedSplits.size());
+    for (InputSplit split : wrappedSplits) {
+      writeWrappedSplit(split, out);
+    }
+    out.writeLong(length);
+
+    if (locations == null) {
+      out.writeInt(0);
+    } else {
+      out.writeInt(locations.length);
+      for (String location : locations) {
+        Text.writeString(out, location);
+      }
+    }
+  }
+
+  /**
+   * Restores the state written by {@link #write(DataOutput)}, replacing any
+   * state this instance already had.
+   * @throws TezUncheckedException if the recorded length does not match the
+   *           sum of the deserialized split lengths
+   */
+  @SuppressWarnings("unchecked")
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    wrappedInputFormatName = Text.readString(in);
+    String inputSplitClassName = Text.readString(in);
+    Class<? extends InputSplit> clazz =
+        (Class<? extends InputSplit>)
+        TezGroupedSplitsInputFormat.getClassFromName(inputSplitClassName);
+
+    int numSplits = in.readInt();
+
+    // addSplit() accumulates into length, so reset it; otherwise a reused
+    // instance would carry over the previous group's length
+    length = 0;
+    wrappedSplits = new ArrayList<InputSplit>(numSplits);
+    for (int i = 0; i < numSplits; ++i) {
+      addSplit(readWrappedSplit(in, clazz));
+    }
+
+    long recordedLength = in.readLong();
+    if (recordedLength != length) {
+      throw new TezUncheckedException("Expected length: " + recordedLength
+          + " actual length: " + length);
+    }
+    int numLocs = in.readInt();
+    // zero recorded locations means "no locations", as written by write()
+    locations = null;
+    if (numLocs > 0) {
+      locations = new String[numLocs];
+      for (int i = 0; i < numLocs; ++i) {
+        locations[i] = Text.readString(in);
+      }
+    }
+  }
+
+  void writeWrappedSplit(InputSplit split, DataOutput out) throws IOException {
+    split.write(out);
+  }
+
+  // Instantiates clazz through its no-arg constructor and lets it read itself.
+  InputSplit readWrappedSplit(DataInput in, Class<? extends InputSplit> clazz)
+      throws IOException {
+    InputSplit split;
+    try {
+      split = clazz.newInstance();
+    } catch (Exception e) {
+      throw new TezUncheckedException(e);
+    }
+    split.readFields(in);
+    return split;
+  }
+
+  /** @return the summed length of all wrapped splits */
+  @Override
+  public long getLength() throws IOException {
+    return length;
+  }
+
+  /** @return the preferred locations of this group; may be null */
+  @Override
+  public String[] getLocations() throws IOException {
+    return locations;
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplitsInputFormat.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplitsInputFormat.java b/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplitsInputFormat.java
new file mode 100644
index 0000000..6edd740
--- /dev/null
+++ b/tez-mapreduce/src/main/java/org/apache/hadoop/mapred/split/TezGroupedSplitsInputFormat.java
@@ -0,0 +1,380 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred.split;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.mapred.InputFormat;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.RecordReader;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.tez.dag.api.TezUncheckedException;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * An old-API (mapred) {@link InputFormat} that wraps another InputFormat and
+ * combines the splits it produces into {@link TezGroupedSplit}s.
+ * <p>
+ * If a desired number of splits has been set and the wrapped format returns
+ * more splits than that, splits are grouped per location so that each group
+ * holds roughly totalLength/desiredNumSplits bytes. Otherwise every original
+ * split is returned wrapped in its own single-element group.
+ */
+public class TezGroupedSplitsInputFormat<K, V> implements InputFormat<K, V> {
+  private static final Log LOG = LogFactory.getLog(TezGroupedSplitsInputFormat.class);
+
+  InputFormat<K, V> wrappedInputFormat;
+  int desiredNumSplits = 0;
+
+  public TezGroupedSplitsInputFormat() {
+
+  }
+
+  /** Sets the InputFormat whose splits are grouped and read. */
+  public void setInputFormat(InputFormat<K, V> wrappedInputFormat) {
+    this.wrappedInputFormat = wrappedInputFormat;
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("wrappedInputFormat: " + wrappedInputFormat.getClass().getName());
+    }
+  }
+
+  /**
+   * Sets the number of grouped splits getSplits() should aim for.
+   * @param num desired number of splits; must be positive
+   * @throws IllegalArgumentException if num is not positive
+   */
+  public void setDesiredNumberOfSPlits(int num) {
+    Preconditions.checkArgument(num > 0);
+    this.desiredNumSplits = num;
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("desiredNumSplits: " + desiredNumSplits);
+    }
+  }
+
+  // Wraps an original split and records whether it has been put in a group.
+  // One holder is shared by every location the split lists.
+  class SplitHolder {
+    InputSplit split;
+    boolean isProcessed = false;
+    SplitHolder(InputSplit split) {
+      this.split = split;
+    }
+  }
+
+  // The splits that list one location, in ascending length order. headIndex
+  // is the first entry not yet consumed at this location.
+  class LocationHolder {
+    List<SplitHolder> splits;
+    int headIndex = 0;
+    LocationHolder(int capacity) {
+      splits = new ArrayList<SplitHolder>(capacity);
+    }
+    boolean isEmpty() {
+      return (headIndex == splits.size());
+    }
+    // Skips entries already grouped via another location; returns the first
+    // unprocessed split, or null if none remain.
+    SplitHolder getUnprocessedHeadSplit() {
+      while (!isEmpty()) {
+        SplitHolder holder = splits.get(headIndex);
+        if (!holder.isProcessed) {
+          return holder;
+        }
+        incrementHeadIndex();
+      }
+      return null;
+    }
+    void incrementHeadIndex() {
+      headIndex++;
+    }
+  }
+
+  /**
+   * Asks the wrapped InputFormat for splits and groups them.
+   * <p>
+   * Each round visits every location and builds at most one group there from
+   * its smallest unprocessed splits, stopping before the group would exceed
+   * totalLength/desiredNumSplits. A group that exhausts its location while
+   * still under half that size is discarded, unless small groups have been
+   * allowed. Small groups are allowed once a round produces no groups, or
+   * fewer than a quarter as many groups as there are locations.
+   */
+  @Override
+  public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
+    if (desiredNumSplits > 0) {
+      // get the desired num splits directly if possible
+      numSplits = desiredNumSplits;
+    }
+    InputSplit[] originalSplits = wrappedInputFormat.getSplits(job, numSplits);
+    // Binary name so that Class.forName() in initInputFormatFromSplit() can
+    // load it; it differs from the canonical name only for nested classes.
+    String wrappedInputFormatName = wrappedInputFormat.getClass().getName();
+    if (desiredNumSplits == 0 ||
+        originalSplits.length == 0 ||
+        desiredNumSplits >= originalSplits.length) {
+      // nothing set. so return all the splits as is
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Using original number of splits: " + originalSplits.length);
+      }
+      InputSplit[] groupedSplits = new TezGroupedSplit[originalSplits.length];
+      int i = 0;
+      for (InputSplit split : originalSplits) {
+        TezGroupedSplit newSplit =
+            new TezGroupedSplit(1, wrappedInputFormatName, split.getLocations());
+        newSplit.addSplit(split);
+        groupedSplits[i++] = newSplit;
+      }
+      return groupedSplits;
+    }
+
+    // placeholder location for splits that report none
+    String[] emptyLocations = {"EmptyLocation"};
+    List<InputSplit> groupedSplitsList = new ArrayList<InputSplit>(desiredNumSplits);
+
+    // sort the splits by length
+    Arrays.sort(originalSplits, new Comparator<InputSplit>() {
+      @Override
+      public int compare(InputSplit o1, InputSplit o2) {
+        try {
+          if (o1.getLength() < o2.getLength()) {
+            return -1;
+          } else if (o1.getLength() > o2.getLength()) {
+            return 1;
+          }
+        } catch (Exception e) {
+          throw new TezUncheckedException(e);
+        }
+        return 0;
+      }
+    });
+
+    long totalLength = 0;
+    Map<String, LocationHolder> distinctLocations = new HashMap<String, LocationHolder>();
+    // collect the total length and the set of distinct locations
+    for (InputSplit split : originalSplits) {
+      totalLength += split.getLength();
+      String[] locations = split.getLocations();
+      if (locations == null || locations.length == 0) {
+        locations = emptyLocations;
+      }
+      for (String location : locations) {
+        distinctLocations.put(location, null);
+      }
+    }
+
+    long lengthPerSplit = totalLength/desiredNumSplits;
+    int numLocations = distinctLocations.size();
+    int numSplitsPerLocation = originalSplits.length/numLocations;
+    int numSplitsInGroup = originalSplits.length/desiredNumSplits;
+    for (String location : distinctLocations.keySet()) {
+      distinctLocations.put(location, new LocationHolder(numSplitsPerLocation));
+    }
+
+    // go through splits in sorted order and add them to their locations
+    for (InputSplit split : originalSplits) {
+      SplitHolder splitHolder = new SplitHolder(split);
+      String[] locations = split.getLocations();
+      if (locations == null || locations.length == 0) {
+        locations = emptyLocations;
+      }
+      for (String location : locations) {
+        LocationHolder holder = distinctLocations.get(location);
+        holder.splits.add(splitHolder); // added smallest to largest
+      }
+    }
+
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Desired lengthPerSplit: " + lengthPerSplit +
+          " numLocations: " + numLocations +
+          " numSplitsPerLocation: " + numSplitsPerLocation +
+          " numSplitsInGroup: " + numSplitsInGroup +
+          " totalLength: " + totalLength);
+    }
+
+    // go through locations and group splits
+    int splitsProcessed = 0;
+    List<SplitHolder> group = new ArrayList<SplitHolder>(numSplitsInGroup+1);
+    boolean allowSmallGroups = false;
+    int iterations = 0;
+    while (splitsProcessed < originalSplits.length) {
+      iterations++;
+      int numFullGroupsCreated = 0;
+      for (Map.Entry<String, LocationHolder> entry : distinctLocations.entrySet()) {
+        // Every location starts a fresh group. Without this, a group would
+        // also contain the splits collected for earlier locations, including
+        // discarded too-small attempts, duplicating splits across groups.
+        group.clear();
+        String location = entry.getKey();
+        LocationHolder holder = entry.getValue();
+        SplitHolder splitHolder = holder.getUnprocessedHeadSplit();
+        if (splitHolder == null) {
+          // all splits on node processed
+          continue;
+        }
+        int oldHeadIndex = holder.headIndex;
+        long groupLength = 0;
+        do {
+          group.add(splitHolder);
+          groupLength += splitHolder.split.getLength();
+          holder.incrementHeadIndex();
+          splitHolder = holder.getUnprocessedHeadSplit();
+        } while (splitHolder != null &&
+            groupLength + splitHolder.split.getLength() <= lengthPerSplit);
+
+        if (holder.isEmpty() && groupLength < lengthPerSplit/2 && !allowSmallGroups) {
+          // group too small, reset it
+          holder.headIndex = oldHeadIndex;
+          continue;
+        }
+
+        numFullGroupsCreated++;
+
+        // One split group created
+        String[] groupLocation = {location};
+        if (emptyLocations[0].equals(location)) {
+          groupLocation = null;
+        }
+        TezGroupedSplit groupedSplit =
+            new TezGroupedSplit(group.size(), wrappedInputFormatName, groupLocation);
+        for (SplitHolder groupedSplitHolder : group) {
+          groupedSplit.addSplit(groupedSplitHolder.split);
+          groupedSplitHolder.isProcessed = true;
+          splitsProcessed++;
+        }
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Grouped " + group.size() + " split at: " +
+              Arrays.toString(groupLocation));
+        }
+        groupedSplitsList.add(groupedSplit);
+      }
+
+      // With fewer than 4 locations numLocations/4 is 0, so the zero-groups
+      // check is needed; otherwise small groups are never allowed and a
+      // round that only finds small tail groups would repeat forever.
+      if (!allowSmallGroups &&
+          (numFullGroupsCreated == 0 || numFullGroupsCreated < numLocations/4)) {
+        // a few nodes have a lot of data or data is thinly spread across nodes
+        // so allow small groups now
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Allowing small groups");
+        }
+        allowSmallGroups = true;
+      }
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Iteration: " + iterations +
+            " splitsProcessed: " + splitsProcessed +
+            " numFullGroupsInRound: " + numFullGroupsCreated +
+            " totalGroups: " + groupedSplitsList.size());
+      }
+    }
+    InputSplit[] groupedSplits = new InputSplit[groupedSplitsList.size()];
+    groupedSplitsList.toArray(groupedSplits);
+    LOG.info("Number of splits created: " + groupedSplitsList.size());
+    return groupedSplits;
+  }
+
+  /**
+   * Returns a reader that reads the wrapped splits of the given
+   * {@link TezGroupedSplit} one after another.
+   */
+  @Override
+  public RecordReader<K, V> getRecordReader(InputSplit split, JobConf job,
+      Reporter reporter) throws IOException {
+    TezGroupedSplit groupedSplit = (TezGroupedSplit) split;
+    initInputFormatFromSplit(groupedSplit);
+    return new TezGroupedSplitsRecordReader(groupedSplit, job, reporter);
+  }
+
+  // Instantiates the wrapped InputFormat from the class name recorded in the
+  // split when none was set explicitly.
+  @SuppressWarnings({ "unchecked", "rawtypes" })
+  void initInputFormatFromSplit(TezGroupedSplit split) {
+    if (wrappedInputFormat == null) {
+      Class<? extends InputFormat> clazz = (Class<? extends InputFormat>)
+          getClassFromName(split.wrappedInputFormatName);
+      try {
+        wrappedInputFormat = clazz.newInstance();
+      } catch (Exception e) {
+        throw new TezUncheckedException(e);
+      }
+    }
+  }
+
+  /**
+   * Loads a class by binary name.
+   * @throws TezUncheckedException if the class cannot be found
+   */
+  static Class<?> getClassFromName(String name) {
+    try {
+      return Class.forName(name);
+    } catch (ClassNotFoundException e1) {
+      throw new TezUncheckedException(e1);
+    }
+  }
+
+  /**
+   * Reads the wrapped splits of a {@link TezGroupedSplit} sequentially, using
+   * one reader from the wrapped InputFormat per wrapped split.
+   */
+  public class TezGroupedSplitsRecordReader implements RecordReader<K, V> {
+
+    TezGroupedSplit groupedSplit;
+    JobConf job;
+    Reporter reporter;
+    // index of the next wrapped split to open
+    int idx = 0;
+    // summed length of the wrapped splits already fully read
+    long progress;
+    RecordReader<K, V> curReader;
+
+    public TezGroupedSplitsRecordReader(TezGroupedSplit split, JobConf job,
+        Reporter reporter) throws IOException {
+      this.groupedSplit = split;
+      this.job = job;
+      this.reporter = reporter;
+      initNextRecordReader();
+    }
+
+    // Reads from the current wrapped reader, moving on to the next wrapped
+    // split whenever the current one is exhausted.
+    @Override
+    public boolean next(K key, V value) throws IOException {
+
+      while ((curReader == null) || !curReader.next(key, value)) {
+        if (!initNextRecordReader()) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    @Override
+    public K createKey() {
+      return curReader.createKey();
+    }
+
+    @Override
+    public V createValue() {
+      return curReader.createValue();
+    }
+
+    @Override
+    public float getProgress() throws IOException {
+      long totalLength = groupedSplit.getLength();
+      if (totalLength <= 0) {
+        // avoid 0/0 (NaN): report done only once every wrapped split is read
+        return (curReader == null && idx == groupedSplit.wrappedSplits.size())
+            ? 1.0f : 0.0f;
+      }
+      return Math.min(1.0f, getPos()/(float)totalLength);
+    }
+
+    @Override
+    public void close() throws IOException {
+      if (curReader != null) {
+        curReader.close();
+        curReader = null;
+      }
+    }
+
+    // Closes the current reader and opens one for the next wrapped split.
+    // Returns false when no wrapped splits are left.
+    protected boolean initNextRecordReader() throws IOException {
+      if (curReader != null) {
+        curReader.close();
+        curReader = null;
+        if (idx > 0) {
+          progress += groupedSplit.wrappedSplits.get(idx-1).getLength();
+        }
+      }
+
+      // if all chunks have been processed, nothing more to do.
+      if (idx == groupedSplit.wrappedSplits.size()) {
+        return false;
+      }
+
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Init record reader for index " + idx + " of " +
+            groupedSplit.wrappedSplits.size());
+      }
+
+      // get a record reader for the idx-th chunk
+      try {
+        curReader = wrappedInputFormat.getRecordReader(
+            groupedSplit.wrappedSplits.get(idx), job, reporter);
+      } catch (Exception e) {
+        throw new RuntimeException (e);
+      }
+      idx++;
+      return true;
+    }
+
+    /**
+     * @return an estimate of the bytes consumed across the whole group
+     */
+    @Override
+    public long getPos() throws IOException {
+      long subprogress = 0; // bytes processed in current split
+      if (null != curReader) {
+        // The wrapped reader's getPos() is a position in its own input, not a
+        // count of bytes read from this split, so it cannot be added to the
+        // bytes of earlier splits. Scale its fractional progress instead.
+        // idx is always one past the current subsplit's true index.
+        subprogress = (long) (curReader.getProgress() *
+            groupedSplit.wrappedSplits.get(idx - 1).getLength());
+      }
+      return (progress + subprogress);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplit.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplit.java b/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplit.java
new file mode 100644
index 0000000..afc3108
--- /dev/null
+++ b/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplit.java
@@ -0,0 +1,146 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapreduce.split;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.tez.dag.api.TezUncheckedException;
+
+/**
+ * A new-API (mapreduce) {@link InputSplit} that wraps a group of splits
+ * produced by another InputFormat so that a single task can process them
+ * together.
+ * <p>
+ * All wrapped splits must be of the same class and must implement
+ * {@link Writable}: only the class of the first split is recorded when the
+ * group is serialized, and every wrapped split is deserialized as an instance
+ * of that class.
+ */
+public class TezGroupedSplit extends InputSplit implements Writable {
+
+  List<InputSplit> wrappedSplits = null;
+  String wrappedInputFormatName = null;
+  String[] locations = null;
+  // sum of getLength() over all wrapped splits
+  long length = 0;
+
+  /** No-arg constructor, needed to create an instance before readFields(). */
+  public TezGroupedSplit() {
+
+  }
+
+  /**
+   * @param numSplits expected number of wrapped splits (initial list capacity)
+   * @param wrappedInputFormatName class name of the InputFormat that produced
+   *          the wrapped splits; it is loaded with Class.forName() on the
+   *          reading side
+   * @param locations preferred locations of this group; may be null
+   */
+  public TezGroupedSplit(int numSplits, String wrappedInputFormatName,
+      String[] locations) {
+    this.wrappedSplits = new ArrayList<InputSplit>(numSplits);
+    this.wrappedInputFormatName = wrappedInputFormatName;
+    this.locations = locations;
+  }
+
+  /**
+   * Appends a split to this group and adds its length to the group length.
+   * @throws TezUncheckedException if the split's length cannot be obtained
+   */
+  public void addSplit(InputSplit split) {
+    wrappedSplits.add(split);
+    try {
+      length += split.getLength();
+    } catch (Exception e) {
+      throw new TezUncheckedException(e);
+    }
+  }
+
+  /**
+   * Serializes the wrapped InputFormat name, the wrapped split class, the
+   * wrapped splits, the total length and the locations.
+   * @throws TezUncheckedException if there are no wrapped splits, or a
+   *           wrapped split is not Writable
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    // an empty list would otherwise fail below with IndexOutOfBoundsException
+    if (wrappedSplits == null || wrappedSplits.isEmpty()) {
+      throw new TezUncheckedException("Wrapped splits cannot be empty");
+    }
+
+    Text.writeString(out, wrappedInputFormatName);
+    // Binary name, not canonical name: the reader resolves it with
+    // Class.forName(), which cannot load the canonical name of a nested class.
+    Text.writeString(out, wrappedSplits.get(0).getClass().getName());
+    out.writeInt(wrappedSplits.size());
+    for (InputSplit split : wrappedSplits) {
+      writeWrappedSplit(split, out);
+    }
+    out.writeLong(length);
+
+    if (locations == null) {
+      out.writeInt(0);
+    } else {
+      out.writeInt(locations.length);
+      for (String location : locations) {
+        Text.writeString(out, location);
+      }
+    }
+  }
+
+  /**
+   * Restores the state written by {@link #write(DataOutput)}, replacing any
+   * state this instance already had.
+   * @throws TezUncheckedException if the recorded length does not match the
+   *           sum of the deserialized split lengths
+   */
+  @SuppressWarnings("unchecked")
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    wrappedInputFormatName = Text.readString(in);
+    String inputSplitClassName = Text.readString(in);
+    Class<? extends InputSplit> clazz =
+        (Class<? extends InputSplit>)
+        TezGroupedSplitsInputFormat.getClassFromName(inputSplitClassName);
+
+    int numSplits = in.readInt();
+
+    // addSplit() accumulates into length, so reset it; otherwise a reused
+    // instance would carry over the previous group's length
+    length = 0;
+    wrappedSplits = new ArrayList<InputSplit>(numSplits);
+    for (int i = 0; i < numSplits; ++i) {
+      addSplit(readWrappedSplit(in, clazz));
+    }
+
+    long recordedLength = in.readLong();
+    if (recordedLength != length) {
+      throw new TezUncheckedException("Expected length: " + recordedLength
+          + " actual length: " + length);
+    }
+    int numLocs = in.readInt();
+    // zero recorded locations means "no locations", as written by write()
+    locations = null;
+    if (numLocs > 0) {
+      locations = new String[numLocs];
+      for (int i = 0; i < numLocs; ++i) {
+        locations[i] = Text.readString(in);
+      }
+    }
+  }
+
+  // Writes a wrapped split, which must implement Writable.
+  void writeWrappedSplit(InputSplit split, DataOutput out) throws IOException {
+    if (split instanceof Writable) {
+      ((Writable) split).write(out);
+    } else {
+      throw new TezUncheckedException(
+          split.getClass().getName() + " is not a Writable");
+    }
+  }
+
+  // Instantiates clazz through its no-arg constructor and lets it read itself.
+  // Failures are reported as TezUncheckedException.
+  InputSplit readWrappedSplit(DataInput in, Class<? extends InputSplit> clazz) {
+    InputSplit split;
+    try {
+      split = clazz.newInstance();
+    } catch (Exception e) {
+      throw new TezUncheckedException(e);
+    }
+    // checked outside the try so this exception is not wrapped a second time
+    if (!(split instanceof Writable)) {
+      throw new TezUncheckedException(
+          split.getClass().getName() + " is not a Writable");
+    }
+    try {
+      ((Writable) split).readFields(in);
+    } catch (IOException e) {
+      throw new TezUncheckedException(e);
+    }
+    return split;
+  }
+
+  /** @return the summed length of all wrapped splits */
+  @Override
+  public long getLength() throws IOException, InterruptedException {
+    return length;
+  }
+
+  /** @return the preferred locations of this group; may be null */
+  @Override
+  public String[] getLocations() throws IOException, InterruptedException {
+    return locations;
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplitsInputFormat.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplitsInputFormat.java b/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplitsInputFormat.java
new file mode 100644
index 0000000..4fd44d7
--- /dev/null
+++ b/tez-mapreduce/src/main/java/org/apache/hadoop/mapreduce/split/TezGroupedSplitsInputFormat.java
@@ -0,0 +1,341 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapreduce.split;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.tez.dag.api.TezUncheckedException;
+
+/**
+ * A new-API (mapreduce) {@link InputFormat} that wraps another InputFormat and
+ * combines the splits it produces into {@link TezGroupedSplit}s.
+ * <p>
+ * If a desired number of splits has been set and the wrapped format returns
+ * more splits than that, splits are grouped per location so that each group
+ * holds roughly totalLength/desiredNumSplits bytes. Otherwise every original
+ * split is returned wrapped in its own single-element group.
+ */
+public class TezGroupedSplitsInputFormat<K, V> extends InputFormat<K, V> {
+
+  InputFormat<K, V> wrappedInputFormat;
+  int desiredNumSplits = 0;
+  // result of the most recent getSplits() call
+  List<InputSplit> groupedSplits = null;
+
+  public TezGroupedSplitsInputFormat() {
+
+  }
+
+  /** Sets the InputFormat whose splits are grouped and read. */
+  public void setInputFormat(InputFormat<K, V> wrappedInputFormat) {
+    this.wrappedInputFormat = wrappedInputFormat;
+  }
+
+  /** Sets the number of grouped splits getSplits() should aim for; 0 disables grouping. */
+  public void setDesiredNumberOfSPlits(int num) {
+    this.desiredNumSplits = num;
+  }
+
+  // Wraps an original split and records whether it has been put in a group.
+  // One holder is shared by every location the split lists.
+  class SplitHolder {
+    InputSplit split;
+    boolean isProcessed = false;
+    SplitHolder(InputSplit split) {
+      this.split = split;
+    }
+  }
+
+  // The splits that list one location, in ascending length order. headIndex
+  // is the first entry not yet consumed at this location.
+  class LocationHolder {
+    List<SplitHolder> splits;
+    int headIndex = 0;
+    LocationHolder(int capacity) {
+      splits = new ArrayList<SplitHolder>(capacity);
+    }
+    boolean isEmpty() {
+      return (headIndex == splits.size());
+    }
+    // Skips entries already grouped via another location; returns the first
+    // unprocessed split, or null if none remain.
+    SplitHolder getUnprocessedHeadSplit() {
+      while (!isEmpty()) {
+        SplitHolder holder = splits.get(headIndex);
+        if (!holder.isProcessed) {
+          return holder;
+        }
+        incrementHeadIndex();
+      }
+      return null;
+    }
+    void incrementHeadIndex() {
+      headIndex++;
+    }
+  }
+
+  /**
+   * Asks the wrapped InputFormat for splits and groups them.
+   * <p>
+   * Each round visits every location and builds at most one group there from
+   * its smallest unprocessed splits, stopping before the group would exceed
+   * totalLength/desiredNumSplits. A group that exhausts its location while
+   * still under half that size is discarded, unless small groups have been
+   * allowed. Small groups are allowed once a round produces no groups, or
+   * fewer than a quarter as many groups as there are locations.
+   */
+  @Override
+  public List<InputSplit> getSplits(JobContext context) throws IOException,
+      InterruptedException {
+    List<InputSplit> originalSplits = wrappedInputFormat.getSplits(context);
+    // Binary name so that Class.forName() in initInputFormatFromSplit() can
+    // load it; it differs from the canonical name only for nested classes.
+    String wrappedInputFormatName = wrappedInputFormat.getClass().getName();
+    if (desiredNumSplits == 0 ||
+        originalSplits.size() == 0 ||
+        desiredNumSplits >= originalSplits.size()) {
+      // nothing set. so return all the splits as is
+      groupedSplits = new ArrayList<InputSplit>(originalSplits.size());
+      for (InputSplit split : originalSplits) {
+        TezGroupedSplit newSplit =
+            new TezGroupedSplit(1, wrappedInputFormatName, split.getLocations());
+        newSplit.addSplit(split);
+        groupedSplits.add(newSplit);
+      }
+      return groupedSplits;
+    }
+
+    // placeholder location for splits that report none
+    String[] emptyLocations = {"EmptyLocation"};
+    groupedSplits = new ArrayList<InputSplit>(desiredNumSplits);
+
+    // Sort a copy by length: the list returned by the wrapped InputFormat may
+    // be unmodifiable or still referenced by it.
+    List<InputSplit> sortedSplits = new ArrayList<InputSplit>(originalSplits);
+    Collections.sort(sortedSplits, new Comparator<InputSplit>() {
+      @Override
+      public int compare(InputSplit o1, InputSplit o2) {
+        try {
+          if (o1.getLength() < o2.getLength()) {
+            return -1;
+          } else if (o1.getLength() > o2.getLength()) {
+            return 1;
+          }
+        } catch (Exception e) {
+          throw new TezUncheckedException(e);
+        }
+        return 0;
+      }
+    });
+
+    long totalLength = 0;
+    Map<String, LocationHolder> distinctLocations = new HashMap<String, LocationHolder>();
+    // collect the total length and the set of distinct locations
+    for (InputSplit split : sortedSplits) {
+      totalLength += split.getLength();
+      String[] locations = split.getLocations();
+      if (locations == null || locations.length == 0) {
+        locations = emptyLocations;
+      }
+      for (String location : locations) {
+        distinctLocations.put(location, null);
+      }
+    }
+
+    long lengthPerSplit = totalLength/desiredNumSplits;
+    int numLocations = distinctLocations.size();
+    int numSplitsPerLocation = sortedSplits.size()/numLocations;
+    int numSplitsInGroup = sortedSplits.size()/desiredNumSplits;
+    for (String location : distinctLocations.keySet()) {
+      distinctLocations.put(location, new LocationHolder(numSplitsPerLocation));
+    }
+
+    // go through splits in sorted order and add them to their locations
+    for (InputSplit split : sortedSplits) {
+      SplitHolder splitHolder = new SplitHolder(split);
+      String[] locations = split.getLocations();
+      if (locations == null || locations.length == 0) {
+        locations = emptyLocations;
+      }
+      for (String location : locations) {
+        LocationHolder holder = distinctLocations.get(location);
+        holder.splits.add(splitHolder); // added smallest to largest
+      }
+    }
+
+    // go through locations and group splits
+    int splitsProcessed = 0;
+    List<SplitHolder> group = new ArrayList<SplitHolder>(numSplitsInGroup);
+    boolean allowSmallGroups = false;
+    while (splitsProcessed < sortedSplits.size()) {
+      int numFullGroupsCreated = 0;
+      for (Map.Entry<String, LocationHolder> entry : distinctLocations.entrySet()) {
+        // Every location starts a fresh group. Without this, a group would
+        // also contain the splits collected for earlier locations, including
+        // discarded too-small attempts, duplicating splits across groups.
+        group.clear();
+        String location = entry.getKey();
+        LocationHolder holder = entry.getValue();
+        SplitHolder splitHolder = holder.getUnprocessedHeadSplit();
+        if (splitHolder == null) {
+          // all splits on node processed
+          continue;
+        }
+        int oldHeadIndex = holder.headIndex;
+        long groupLength = 0;
+        do {
+          group.add(splitHolder);
+          groupLength += splitHolder.split.getLength();
+          holder.incrementHeadIndex();
+          splitHolder = holder.getUnprocessedHeadSplit();
+        } while (splitHolder != null &&
+            groupLength + splitHolder.split.getLength() <= lengthPerSplit);
+
+        if (holder.isEmpty() && groupLength < lengthPerSplit/2 && !allowSmallGroups) {
+          // group too small, reset it
+          holder.headIndex = oldHeadIndex;
+          continue;
+        }
+
+        numFullGroupsCreated++;
+
+        // One split group created
+        String[] groupLocation = {location};
+        if (emptyLocations[0].equals(location)) {
+          groupLocation = null;
+        }
+        TezGroupedSplit groupedSplit =
+            new TezGroupedSplit(group.size(), wrappedInputFormatName, groupLocation);
+        for (SplitHolder groupedSplitHolder : group) {
+          groupedSplit.addSplit(groupedSplitHolder.split);
+          groupedSplitHolder.isProcessed = true;
+          splitsProcessed++;
+        }
+        groupedSplits.add(groupedSplit);
+      }
+
+      // With fewer than 4 locations numLocations/4 is 0, so the zero-groups
+      // check is needed; otherwise small groups are never allowed and a
+      // round that only finds small tail groups would repeat forever.
+      if (!allowSmallGroups &&
+          (numFullGroupsCreated == 0 || numFullGroupsCreated < numLocations/4)) {
+        // a few nodes have a lot of data or data is thinly spread across nodes
+        // so allow small groups now
+        allowSmallGroups = true;
+      }
+    }
+
+    return groupedSplits;
+  }
+
+  /**
+   * Returns a reader that reads the wrapped splits of the given
+   * {@link TezGroupedSplit} one after another.
+   */
+  @Override
+  public RecordReader<K, V> createRecordReader(InputSplit split,
+      TaskAttemptContext context) throws IOException, InterruptedException {
+    TezGroupedSplit groupedSplit = (TezGroupedSplit) split;
+    initInputFormatFromSplit(groupedSplit);
+    return new TezGroupedSplitsRecordReader(groupedSplit, context);
+  }
+
+  // Instantiates the wrapped InputFormat from the class name recorded in the
+  // split when none was set explicitly.
+  @SuppressWarnings({ "rawtypes", "unchecked" })
+  void initInputFormatFromSplit(TezGroupedSplit split) {
+    if (wrappedInputFormat == null) {
+      Class<? extends InputFormat> clazz = (Class<? extends InputFormat>)
+          getClassFromName(split.wrappedInputFormatName);
+      try {
+        wrappedInputFormat = clazz.newInstance();
+      } catch (Exception e) {
+        throw new TezUncheckedException(e);
+      }
+    }
+  }
+
+  /**
+   * Loads a class by binary name.
+   * @throws TezUncheckedException if the class cannot be found
+   */
+  static Class<?> getClassFromName(String name) {
+    try {
+      return Class.forName(name);
+    } catch (ClassNotFoundException e1) {
+      throw new TezUncheckedException(e1);
+    }
+  }
+
+  /**
+   * Reads the wrapped splits of a {@link TezGroupedSplit} sequentially, using
+   * one reader from the wrapped InputFormat per wrapped split.
+   */
+  public class TezGroupedSplitsRecordReader extends RecordReader<K, V> {
+
+    TezGroupedSplit groupedSplit;
+    TaskAttemptContext context;
+    // index of the next wrapped split to open
+    int idx = 0;
+    // summed length of the wrapped splits already fully read
+    long progress;
+    RecordReader<K, V> curReader;
+
+    public TezGroupedSplitsRecordReader(TezGroupedSplit split,
+        TaskAttemptContext context) throws IOException {
+      this.groupedSplit = split;
+      this.context = context;
+    }
+
+    /**
+     * Opens the first wrapped split. The split and context must be the same
+     * instances this reader was created with.
+     */
+    public void initialize(InputSplit split,
+        TaskAttemptContext context) throws IOException, InterruptedException {
+      if (this.groupedSplit != split) {
+        throw new TezUncheckedException("Splits dont match");
+      }
+      if (this.context != context) {
+        throw new TezUncheckedException("Contexts dont match");
+      }
+      initNextRecordReader();
+    }
+
+    public boolean nextKeyValue() throws IOException, InterruptedException {
+      while ((curReader == null) || !curReader.nextKeyValue()) {
+        // false return finishes. true return loops back for nextKeyValue()
+        if (!initNextRecordReader()) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    public K getCurrentKey() throws IOException, InterruptedException {
+      return curReader.getCurrentKey();
+    }
+
+    public V getCurrentValue() throws IOException, InterruptedException {
+      return curReader.getCurrentValue();
+    }
+
+    public void close() throws IOException {
+      if (curReader != null) {
+        curReader.close();
+        curReader = null;
+      }
+    }
+
+    // Closes the current reader and opens and initializes one for the next
+    // wrapped split. Returns false when no wrapped splits are left.
+    protected boolean initNextRecordReader() throws IOException {
+      if (curReader != null) {
+        curReader.close();
+        curReader = null;
+        if (idx > 0) {
+          try {
+            progress += groupedSplit.wrappedSplits.get(idx-1).getLength();
+          } catch (InterruptedException e) {
+            // preserve the interrupt status for callers up the stack
+            Thread.currentThread().interrupt();
+            throw new TezUncheckedException(e);
+          }
+        }
+      }
+
+      // if all chunks have been processed, nothing more to do.
+      if (idx == groupedSplit.wrappedSplits.size()) {
+        return false;
+      }
+
+      // get a record reader for the idx-th chunk
+      try {
+        InputSplit wrappedSplit = groupedSplit.wrappedSplits.get(idx);
+        curReader = wrappedInputFormat.createRecordReader(wrappedSplit, context);
+        // Initialize with the wrapped split the reader was created for, not
+        // the grouped split: readers typically expect their own split type.
+        curReader.initialize(wrappedSplit, context);
+      } catch (Exception e) {
+        throw new RuntimeException (e);
+      }
+      idx++;
+      return true;
+    }
+
+    /**
+     * return progress based on the amount of data processed so far.
+     */
+    public float getProgress() throws IOException, InterruptedException {
+      long totalLength = groupedSplit.getLength();
+      if (totalLength <= 0) {
+        // avoid 0/0 (NaN): report done only once every wrapped split is read
+        return (curReader == null && idx == groupedSplit.wrappedSplits.size())
+            ? 1.0f : 0.0f;
+      }
+      long subprogress = 0; // bytes processed in current split
+      if (null != curReader) {
+        // idx is always one past the current subsplit's true index.
+        subprogress = (long) (curReader.getProgress() * groupedSplit.wrappedSplits
+            .get(idx - 1).getLength());
+      }
+      return Math.min(1.0f, (progress + subprogress)/(float)totalLength);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/4dd5e195/tez-mapreduce/src/test/java/org/apache/hadoop/mapred/split/TestGroupedSplits.java
----------------------------------------------------------------------
diff --git a/tez-mapreduce/src/test/java/org/apache/hadoop/mapred/split/TestGroupedSplits.java b/tez-mapreduce/src/test/java/org/apache/hadoop/mapred/split/TestGroupedSplits.java
new file mode 100644
index 0000000..f3e549f
--- /dev/null
+++ b/tez-mapreduce/src/test/java/org/apache/hadoop/mapred/split/TestGroupedSplits.java
@@ -0,0 +1,278 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred.split;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.io.compress.GzipCodec;
+import org.apache.hadoop.mapred.FileInputFormat;
+import org.apache.hadoop.mapred.InputFormat;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.RecordReader;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapred.TextInputFormat;
+import org.apache.hadoop.util.ReflectionUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestGroupedSplits {
+ private static final Log LOG =
+ LogFactory.getLog(TestGroupedSplits.class);
+
+ private static JobConf defaultConf = new JobConf();
+ private static FileSystem localFs = null;
+
+ static {
+ try {
+ defaultConf.set("fs.defaultFS", "file:///");
+ localFs = FileSystem.getLocal(defaultConf);
+ } catch (IOException e) {
+ throw new RuntimeException("init failure", e);
+ }
+ }
+
+ @SuppressWarnings("deprecation")
+ private static Path workDir =
+ new Path(new Path(System.getProperty("test.build.data", "/tmp")),
+ "TestCombineTextInputFormat").makeQualified(localFs);
+
+ // A reporter that does nothing
+ private static final Reporter voidReporter = Reporter.NULL;
+
+ @Test(timeout=10000)
+ public void testFormat() throws Exception {
+ JobConf job = new JobConf(defaultConf);
+
+ Random random = new Random();
+ long seed = random.nextLong();
+ LOG.info("seed = "+seed);
+ random.setSeed(seed);
+
+ localFs.delete(workDir, true);
+ FileInputFormat.setInputPaths(job, workDir);
+
+ final int length = 10000;
+ final int numFiles = 10;
+
+ createFiles(length, numFiles, random);
+
+ // create a combined split for the files
+ TextInputFormat wrappedFormat = new TextInputFormat();
+ wrappedFormat.configure(job);
+ TezGroupedSplitsInputFormat<LongWritable , Text> format =
+ new TezGroupedSplitsInputFormat<LongWritable, Text>();
+ format.setDesiredNumberOfSPlits(1);
+ format.setInputFormat(wrappedFormat);
+ LongWritable key = new LongWritable();
+ Text value = new Text();
+ for (int i = 0; i < 3; i++) {
+ int numSplits = random.nextInt(length/20)+1;
+ LOG.info("splitting: requesting = " + numSplits);
+ InputSplit[] splits = format.getSplits(job, numSplits);
+ LOG.info("splitting: got = " + splits.length);
+
+ // we should have a single split as the length is comfortably smaller than
+ // the block size
+ Assert.assertEquals("We got more than one splits!", 1, splits.length);
+ InputSplit split = splits[0];
+ Assert.assertEquals("It should be TezGroupedSplit",
+ TezGroupedSplit.class, split.getClass());
+
+ // check the split
+ BitSet bits = new BitSet(length);
+ LOG.debug("split= " + split);
+ RecordReader<LongWritable, Text> reader =
+ format.getRecordReader(split, job, voidReporter);
+ try {
+ int count = 0;
+ while (reader.next(key, value)) {
+ int v = Integer.parseInt(value.toString());
+ LOG.debug("read " + v);
+ if (bits.get(v)) {
+ LOG.warn("conflict with " + v +
+ " at position "+reader.getPos());
+ }
+ Assert.assertFalse("Key in multiple partitions.", bits.get(v));
+ bits.set(v);
+ count++;
+ }
+ LOG.info("splits="+split+" count=" + count);
+ } finally {
+ reader.close();
+ }
+ Assert.assertEquals("Some keys in no partition.", length, bits.cardinality());
+ }
+ }
+
+ private static class Range {
+ private final int start;
+ private final int end;
+
+ Range(int start, int end) {
+ this.start = start;
+ this.end = end;
+ }
+
+ @Override
+ public String toString() {
+ return "(" + start + ", " + end + ")";
+ }
+ }
+
+ private static Range[] createRanges(int length, int numFiles, Random random) {
+ // generate a number of files with various lengths
+ Range[] ranges = new Range[numFiles];
+ for (int i = 0; i < numFiles; i++) {
+ int start = i == 0 ? 0 : ranges[i-1].end;
+ int end = i == numFiles - 1 ?
+ length :
+ (length/numFiles)*(2*i + 1)/2 + random.nextInt(length/numFiles) + 1;
+ ranges[i] = new Range(start, end);
+ }
+ return ranges;
+ }
+
+ private static void createFiles(int length, int numFiles, Random random)
+ throws IOException {
+ Range[] ranges = createRanges(length, numFiles, random);
+
+ for (int i = 0; i < numFiles; i++) {
+ Path file = new Path(workDir, "test_" + i + ".txt");
+ Writer writer = new OutputStreamWriter(localFs.create(file));
+ Range range = ranges[i];
+ try {
+ for (int j = range.start; j < range.end; j++) {
+ writer.write(Integer.toString(j));
+ writer.write("\n");
+ }
+ } finally {
+ writer.close();
+ }
+ }
+ }
+
+ private static void writeFile(FileSystem fs, Path name,
+ CompressionCodec codec,
+ String contents) throws IOException {
+ OutputStream stm;
+ if (codec == null) {
+ stm = fs.create(name);
+ } else {
+ stm = codec.createOutputStream(fs.create(name));
+ }
+ stm.write(contents.getBytes());
+ stm.close();
+ }
+
+ private static List<Text> readSplit(InputFormat<LongWritable,Text> format,
+ InputSplit split,
+ JobConf job) throws IOException {
+ List<Text> result = new ArrayList<Text>();
+ RecordReader<LongWritable, Text> reader =
+ format.getRecordReader(split, job, voidReporter);
+ LongWritable key = reader.createKey();
+ Text value = reader.createValue();
+ while (reader.next(key, value)) {
+ result.add(value);
+ value = reader.createValue();
+ }
+ reader.close();
+ return result;
+ }
+
+ /**
+ * Test using the gzip codec for reading
+ */
+ @Test(timeout=10000)
+ public void testGzip() throws IOException {
+ JobConf job = new JobConf(defaultConf);
+ CompressionCodec gzip = new GzipCodec();
+ ReflectionUtils.setConf(gzip, job);
+ localFs.delete(workDir, true);
+ writeFile(localFs, new Path(workDir, "part1.txt.gz"), gzip,
+ "the quick\nbrown\nfox jumped\nover\n the lazy\n dog\n");
+ writeFile(localFs, new Path(workDir, "part2.txt.gz"), gzip,
+ "is\ngzip\n");
+ writeFile(localFs, new Path(workDir, "part3.txt.gz"), gzip,
+ "one\nmore\nsplit\n");
+ FileInputFormat.setInputPaths(job, workDir);
+ TextInputFormat wrappedFormat = new TextInputFormat();
+ wrappedFormat.configure(job);
+ TezGroupedSplitsInputFormat<LongWritable , Text> format =
+ new TezGroupedSplitsInputFormat<LongWritable, Text>();
+ format.setInputFormat(wrappedFormat);
+
+ // TextInputFormat will produce 3 splits
+ for (int j=1; j<=3; ++j) {
+ format.setDesiredNumberOfSPlits(j);
+ InputSplit[] splits = format.getSplits(job, 100);
+ Assert.assertEquals("compressed splits == " + j, j, splits.length);
+ // j==3 cases exercises the code where desired == actual
+ // and does not do grouping
+ List<Text> results = new ArrayList<Text>();
+ for (int i=0; i<splits.length; ++i) {
+ List<Text> read = readSplit(format, splits[i], job);
+ results.addAll(read);
+ }
+ Assert.assertEquals("splits length", 11, results.size());
+
+ final String[] firstList =
+ {"the quick", "brown", "fox jumped", "over", " the lazy", " dog"};
+ final String[] secondList = {"is", "gzip"};
+ final String[] thirdList = {"one", "more", "split"};
+ String first = results.get(0).toString();
+ int start = 0;
+ switch (first.charAt(0)) {
+ case 't':
+ start = testResults(results, firstList, start);
+ break;
+ case 'i':
+ start = testResults(results, secondList, start);
+ break;
+ case 'o':
+ start = testResults(results, thirdList, start);
+ break;
+ default:
+ Assert.fail("unexpected first token - " + first);
+ }
+ }
+ }
+
+ private static int testResults(List<Text> results, String[] first, int start) {
+ for (int i = 0; i < first.length; i++) {
+ Assert.assertEquals("splits["+i+"]", first[i], results.get(start+i).toString());
+ }
+ return first.length+start;
+ }
+}