You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tez.apache.org by ss...@apache.org on 2013/09/25 09:31:39 UTC
[32/50] [abbrv] Rename tez-engine-api to tez-runtime-api, and split
tez-engine into two modules: tez-engine-library, for user-visible
Input/Output/Processor implementations, and tez-engine-internals, for
framework internals.
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Shuffle.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Shuffle.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Shuffle.java
new file mode 100644
index 0000000..8689d11
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/Shuffle.java
@@ -0,0 +1,278 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import javax.crypto.SecretKey;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalDirAllocator;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.TezInputContext;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+import org.apache.tez.runtime.library.common.combine.Combiner;
+import org.apache.tez.runtime.library.common.shuffle.server.ShuffleHandler;
+import org.apache.tez.runtime.library.common.sort.impl.TezRawKeyValueIterator;
+import org.apache.tez.runtime.library.shuffle.common.ShuffleUtils;
+
+import com.google.common.base.Preconditions;
+
+@InterfaceAudience.Private
+@InterfaceStability.Unstable
+public class Shuffle implements ExceptionReporter {
+
+ private static final Log LOG = LogFactory.getLog(Shuffle.class);
+ private static final int PROGRESS_FREQUENCY = 2000;
+
+ private final Configuration conf;
+ private final TezInputContext inputContext;
+ private final ShuffleClientMetrics metrics;
+
+ private final ShuffleInputEventHandler eventHandler;
+ private final ShuffleScheduler scheduler;
+ private final MergeManager merger;
+ private Throwable throwable = null;
+ private String throwingThreadName = null;
+ private final int numInputs;
+ private final AtomicInteger reduceStartId;
+ private final SecretKey jobTokenSecret;
+ private AtomicInteger reduceRange = new AtomicInteger(
+ TezJobConfig.TEZ_RUNTIME_SHUFFLE_PARTITION_RANGE_DEFAULT);
+
+ private FutureTask<TezRawKeyValueIterator> runShuffleFuture;
+
+ public Shuffle(TezInputContext inputContext, Configuration conf, int numInputs) throws IOException {
+ this.inputContext = inputContext;
+ this.conf = conf;
+ this.metrics = new ShuffleClientMetrics(inputContext.getDAGName(),
+ inputContext.getTaskVertexName(), inputContext.getTaskIndex(),
+ this.conf, UserGroupInformation.getCurrentUser().getShortUserName());
+
+ this.numInputs = numInputs;
+
+ this.jobTokenSecret = ShuffleUtils
+ .getJobTokenSecretFromTokenBytes(inputContext
+ .getServiceConsumerMetaData(ShuffleHandler.MAPREDUCE_SHUFFLE_SERVICEID));
+
+ Combiner combiner = TezRuntimeUtils.instantiateCombiner(conf, inputContext);
+
+ FileSystem localFS = FileSystem.getLocal(this.conf);
+ LocalDirAllocator localDirAllocator =
+ new LocalDirAllocator(TezJobConfig.LOCAL_DIRS);
+
+ // TODO TEZ Get rid of Map / Reduce references.
+ TezCounter shuffledMapsCounter =
+ inputContext.getCounters().findCounter(TaskCounter.SHUFFLED_MAPS);
+ TezCounter reduceShuffleBytes =
+ inputContext.getCounters().findCounter(TaskCounter.REDUCE_SHUFFLE_BYTES);
+ TezCounter failedShuffleCounter =
+ inputContext.getCounters().findCounter(TaskCounter.FAILED_SHUFFLE);
+ TezCounter spilledRecordsCounter =
+ inputContext.getCounters().findCounter(TaskCounter.SPILLED_RECORDS);
+ TezCounter reduceCombineInputCounter =
+ inputContext.getCounters().findCounter(TaskCounter.COMBINE_INPUT_RECORDS);
+ TezCounter mergedMapOutputsCounter =
+ inputContext.getCounters().findCounter(TaskCounter.MERGED_MAP_OUTPUTS);
+
+ reduceStartId = new AtomicInteger(inputContext.getTaskIndex());
+ LOG.info("Shuffle assigned reduce start id: " + reduceStartId.get()
+ + " with default reduce range: " + reduceRange.get());
+
+ scheduler = new ShuffleScheduler(
+ this.inputContext,
+ this.conf,
+ this.numInputs,
+ this,
+ shuffledMapsCounter,
+ reduceShuffleBytes,
+ failedShuffleCounter);
+ eventHandler= new ShuffleInputEventHandler(
+ inputContext,
+ this,
+ scheduler);
+ merger = new MergeManager(
+ this.conf,
+ localFS,
+ localDirAllocator,
+ inputContext,
+ combiner,
+ spilledRecordsCounter,
+ reduceCombineInputCounter,
+ mergedMapOutputsCounter,
+ this);
+ }
+
+ public void handleEvents(List<Event> events) {
+ eventHandler.handleEvents(events);
+ }
+
+ /**
+ * Indicates whether the Shuffle and Merge processing is complete.
+ * @return false if not complete, true if complete or if an error occurred.
+ */
+ public boolean isInputReady() {
+ if (runShuffleFuture == null) {
+ return false;
+ }
+ return runShuffleFuture.isDone();
+ //return scheduler.isDone() && merger.isMergeComplete();
+ }
+
+ /**
+ * Waits for the Shuffle and Merge to complete, and returns an iterator over the input.
+ * @return an iterator over the fetched input.
+ * @throws IOException
+ * @throws InterruptedException
+ */
+ public TezRawKeyValueIterator waitForInput() throws IOException, InterruptedException {
+ Preconditions.checkState(runShuffleFuture != null,
+ "waitForInput can only be called after run");
+ TezRawKeyValueIterator kvIter;
+ try {
+ kvIter = runShuffleFuture.get();
+ } catch (ExecutionException e) {
+ Throwable cause = e.getCause();
+ if (cause instanceof IOException) {
+ throw (IOException) cause;
+ } else if (cause instanceof InterruptedException) {
+ throw (InterruptedException) cause;
+ } else {
+ throw new TezUncheckedException(
+ "Unexpected exception type while running Shuffle and Merge", cause);
+ }
+ }
+ return kvIter;
+ }
+
+ public void run() {
+ RunShuffleCallable runShuffle = new RunShuffleCallable();
+ runShuffleFuture = new FutureTask<TezRawKeyValueIterator>(runShuffle);
+ new Thread(runShuffleFuture, "ShuffleMergeRunner").start();
+ }
+
+ private class RunShuffleCallable implements Callable<TezRawKeyValueIterator> {
+ @Override
+ public TezRawKeyValueIterator call() throws IOException, InterruptedException {
+ // TODO NEWTEZ Limit # fetchers to number of inputs
+ final int numFetchers =
+ conf.getInt(
+ TezJobConfig.TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES,
+ TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES);
+ Fetcher[] fetchers = new Fetcher[numFetchers];
+ for (int i = 0; i < numFetchers; ++i) {
+ fetchers[i] = new Fetcher(conf, scheduler, merger, metrics, Shuffle.this, jobTokenSecret, inputContext);
+ fetchers[i].start();
+ }
+
+ while (!scheduler.waitUntilDone(PROGRESS_FREQUENCY)) {
+ synchronized (this) {
+ if (throwable != null) {
+ throw new ShuffleError("error in shuffle in " + throwingThreadName,
+ throwable);
+ }
+ }
+ }
+
+ // Stop the map-output fetcher threads
+ for (Fetcher fetcher : fetchers) {
+ fetcher.shutDown();
+ }
+ fetchers = null;
+
+ // stop the scheduler
+ scheduler.close();
+
+
+ // Finish the on-going merges...
+ TezRawKeyValueIterator kvIter = null;
+ try {
+ kvIter = merger.close();
+ } catch (Throwable e) {
+ throw new ShuffleError("Error while doing final merge " , e);
+ }
+
+ // Sanity check
+ synchronized (Shuffle.this) {
+ if (throwable != null) {
+ throw new ShuffleError("error in shuffle in " + throwingThreadName,
+ throwable);
+ }
+ }
+ return kvIter;
+ }
+ }
+
+ public int getReduceStartId() {
+ return reduceStartId.get();
+ }
+
+ public int getReduceRange() {
+ return reduceRange.get();
+ }
+
+ public synchronized void reportException(Throwable t) {
+ if (throwable == null) {
+ throwable = t;
+ throwingThreadName = Thread.currentThread().getName();
+ // Notify the scheduler so that the reporting thread finds the
+ // exception immediately.
+ synchronized (scheduler) {
+ scheduler.notifyAll();
+ }
+ }
+ }
+
+ public static class ShuffleError extends IOException {
+ private static final long serialVersionUID = 5753909320586607881L;
+
+ ShuffleError(String msg, Throwable t) {
+ super(msg, t);
+ }
+ }
+
+ public void setPartitionRange(int range) {
+ if (range == reduceRange.get()) {
+ return;
+ }
+ if (reduceRange.compareAndSet(
+ TezJobConfig.TEZ_RUNTIME_SHUFFLE_PARTITION_RANGE_DEFAULT, range)) {
+ LOG.info("Reduce range set to: " + range);
+ } else {
+ TezUncheckedException e =
+ new TezUncheckedException("Reduce range can be set only once.");
+ reportException(e);
+ throw e;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleClientMetrics.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleClientMetrics.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleClientMetrics.java
new file mode 100644
index 0000000..70de31f
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleClientMetrics.java
@@ -0,0 +1,91 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.metrics.MetricsContext;
+import org.apache.hadoop.metrics.MetricsRecord;
+import org.apache.hadoop.metrics.MetricsUtil;
+import org.apache.hadoop.metrics.Updater;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.runtime.library.common.Constants;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+
/**
 * Publishes shuffle-client metrics (fetch successes/failures, bytes moved,
 * fetcher-thread utilisation) through the Hadoop metrics framework. The
 * instance registers itself as an {@link Updater} in the constructor and is
 * polled by the framework via {@link #doUpdates(MetricsContext)}.
 *
 * Thread-safety: counters are mutated under this object's monitor;
 * doUpdates() snapshots and resets them under the same lock.
 */
class ShuffleClientMetrics implements Updater {

  private MetricsRecord shuffleMetrics = null;
  // Counters accumulated since the last doUpdates() flush.
  private int numFailedFetches = 0;
  private int numSuccessFetches = 0;
  private long numBytes = 0;
  // Gauge: fetcher threads currently copying; NOT reset by doUpdates().
  private int numThreadsBusy = 0;
  // Configured fetcher parallelism; denominator for the busy-percent gauge.
  private final int numCopiers;

  ShuffleClientMetrics(String dagName, String vertexName, int taskIndex, Configuration conf,
      String user) {
    this.numCopiers =
        conf.getInt(
            TezJobConfig.TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES,
            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_PARALLEL_COPIES);

    MetricsContext metricsContext = MetricsUtil.getContext(Constants.TEZ);
    this.shuffleMetrics =
        MetricsUtil.createRecord(metricsContext, "shuffleInput");
    this.shuffleMetrics.setTag("user", user);
    this.shuffleMetrics.setTag("dagName", dagName);
    this.shuffleMetrics.setTag("taskId", TezRuntimeUtils.getTaskIdentifier(vertexName, taskIndex));
    this.shuffleMetrics.setTag("sessionId",
        conf.get(
            TezJobConfig.TEZ_RUNTIME_METRICS_SESSION_ID,
            TezJobConfig.DEFAULT_TEZ_RUNTIME_METRICS_SESSION_ID));
    // NOTE(review): registering "this" with the framework before the
    // constructor returns publishes a not-fully-constructed object; benign
    // here only if the framework cannot call doUpdates() concurrently with
    // construction -- verify.
    metricsContext.registerUpdater(this);
  }
  /** Adds fetched bytes to the running total. */
  public synchronized void inputBytes(long numBytes) {
    this.numBytes += numBytes;
  }
  /** Counts one failed fetch attempt. */
  public synchronized void failedFetch() {
    ++numFailedFetches;
  }
  /** Counts one successful fetch. */
  public synchronized void successFetch() {
    ++numSuccessFetches;
  }
  /** Marks one fetcher thread as busy (copying). */
  public synchronized void threadBusy() {
    ++numThreadsBusy;
  }
  /** Marks one fetcher thread as idle again. */
  public synchronized void threadFree() {
    --numThreadsBusy;
  }
  /**
   * Flushes the accumulated counters into the metrics record and resets the
   * delta counters (bytes, success/failed fetches). Invoked periodically by
   * the metrics framework.
   */
  public void doUpdates(MetricsContext unused) {
    synchronized (this) {
      shuffleMetrics.incrMetric("shuffle_input_bytes", numBytes);
      shuffleMetrics.incrMetric("shuffle_failed_fetches",
                                numFailedFetches);
      shuffleMetrics.incrMetric("shuffle_success_fetches",
                                numSuccessFetches);
      if (numCopiers != 0) {
        shuffleMetrics.setMetric("shuffle_fetchers_busy_percent",
            100*((float)numThreadsBusy/numCopiers));
      } else {
        shuffleMetrics.setMetric("shuffle_fetchers_busy_percent", 0);
      }
      numBytes = 0;
      numSuccessFetches = 0;
      numFailedFetches = 0;
    }
    // update() runs outside the lock -- presumably so pushing to the metrics
    // sink cannot block the fetcher threads' counter updates; confirm before
    // changing the lock scope.
    shuffleMetrics.update();
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleHeader.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleHeader.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleHeader.java
new file mode 100644
index 0000000..327473e
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleHeader.java
@@ -0,0 +1,94 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableUtils;
+
/**
 * Wire-format header that precedes each map-output segment in the shuffle
 * HTTP stream: written by the shuffle server for every segment and read back
 * by the fetcher before consuming the payload bytes. The field order in
 * {@link #write(DataOutput)} and {@link #readFields(DataInput)} defines the
 * on-the-wire protocol and must stay symmetric.
 */
@InterfaceAudience.Private
@InterfaceStability.Stable
public class ShuffleHeader implements Writable {

  /** Header info of the shuffle http request/response */
  public static final String HTTP_HEADER_NAME = "name";
  public static final String DEFAULT_HTTP_HEADER_NAME = "mapreduce";
  public static final String HTTP_HEADER_VERSION = "version";
  public static final String DEFAULT_HTTP_HEADER_VERSION = "1.0.0";

  /**
   * The longest possible length of task attempt id that we will accept.
   */
  private static final int MAX_ID_LENGTH = 1000;

  // Identifier of the source (map) attempt this segment belongs to.
  String mapId;
  // Payload sizes before and after compression, in bytes.
  long uncompressedLength;
  long compressedLength;
  // Partition (reduce) index this segment is destined for.
  int forReduce;

  /** No-arg constructor required by the Writable deserialization contract. */
  public ShuffleHeader() { }

  public ShuffleHeader(String mapId, long compressedLength,
      long uncompressedLength, int forReduce) {
    this.mapId = mapId;
    this.compressedLength = compressedLength;
    this.uncompressedLength = uncompressedLength;
    this.forReduce = forReduce;
  }

  public String getMapId() {
    return this.mapId;
  }

  public int getPartition() {
    return this.forReduce;
  }

  public long getUncompressedLength() {
    return uncompressedLength;
  }

  public long getCompressedLength() {
    return compressedLength;
  }

  /** Deserializes a header; bounds the id string to guard against bad input. */
  public void readFields(DataInput in) throws IOException {
    mapId = WritableUtils.readStringSafely(in, MAX_ID_LENGTH);
    compressedLength = WritableUtils.readVLong(in);
    uncompressedLength = WritableUtils.readVLong(in);
    forReduce = WritableUtils.readVInt(in);
  }

  /** Serializes the header; must mirror readFields() field-for-field. */
  public void write(DataOutput out) throws IOException {
    Text.writeString(out, mapId);
    WritableUtils.writeVLong(out, compressedLength);
    WritableUtils.writeVLong(out, uncompressedLength);
    WritableUtils.writeVInt(out, forReduce);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandler.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandler.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandler.java
new file mode 100644
index 0000000..8b323b5
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleInputEventHandler.java
@@ -0,0 +1,134 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.net.URI;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.tez.dag.api.TezUncheckedException;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.TezInputContext;
+import org.apache.tez.runtime.api.events.DataMovementEvent;
+import org.apache.tez.runtime.api.events.InputFailedEvent;
+import org.apache.tez.runtime.api.events.InputInformationEvent;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.DataMovementEventPayloadProto;
+import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.InputInformationEventPayloadProto;
+
+import com.google.common.base.Preconditions;
+import com.google.protobuf.InvalidProtocolBufferException;
+
/**
 * Translates the Input's incoming {@link Event}s into actions on the shuffle
 * machinery: registers newly available source outputs (DataMovementEvent)
 * with the {@link ShuffleScheduler}, marks failed sources obsolete
 * (InputFailedEvent), and forwards the partition range
 * (InputInformationEvent) to the {@link Shuffle}.
 */
public class ShuffleInputEventHandler {

  private static final Log LOG = LogFactory.getLog(ShuffleInputEventHandler.class);

  private final ShuffleScheduler scheduler;
  private final TezInputContext inputContext;
  private final Shuffle shuffle;

  // Largest source-task run duration seen so far; forwarded to the scheduler.
  private int maxMapRuntime = 0;
  // Set once an InputInformationEvent supplies the partition range.
  private boolean shuffleRangeSet = false;

  public ShuffleInputEventHandler(TezInputContext inputContext,
      Shuffle shuffle, ShuffleScheduler scheduler) {
    this.inputContext = inputContext;
    this.shuffle = shuffle;
    this.scheduler = scheduler;
  }

  /** Dispatches each event in order to the matching type-specific handler. */
  public void handleEvents(List<Event> events) {
    for (Event event : events) {
      handleEvent(event);
    }
  }


  // Events of any other type fall through silently.
  private void handleEvent(Event event) {
    if (event instanceof InputInformationEvent) {
      processInputInformationEvent((InputInformationEvent) event);
    }
    else if (event instanceof DataMovementEvent) {
      processDataMovementEvent((DataMovementEvent) event);
    } else if (event instanceof InputFailedEvent) {
      processTaskFailedEvent((InputFailedEvent) event);
    }
  }

  /** Parses the proto payload and installs the partition range on the Shuffle. */
  private void processInputInformationEvent(InputInformationEvent iiEvent) {
    InputInformationEventPayloadProto inputInfoPayload;
    try {
      inputInfoPayload = InputInformationEventPayloadProto.parseFrom(iiEvent.getUserPayload());
    } catch (InvalidProtocolBufferException e) {
      throw new TezUncheckedException("Unable to parse InputInformationEvent payload", e);
    }
    int partitionRange = inputInfoPayload.getPartitionRange();
    shuffle.setPartitionRange(partitionRange);
    this.shuffleRangeSet = true;
  }

  /**
   * Registers one newly-available source output with the scheduler, and
   * forwards the source's run duration if it is a new maximum.
   */
  private void processDataMovementEvent(DataMovementEvent dmEvent) {
    // FIXME TODO NEWTEZ
    // Preconditions.checkState(shuffleRangeSet == true, "Shuffle Range must be set before a DataMovementEvent is processed");
    DataMovementEventPayloadProto shufflePayload;
    try {
      shufflePayload = DataMovementEventPayloadProto.parseFrom(dmEvent.getUserPayload());
    } catch (InvalidProtocolBufferException e) {
      throw new TezUncheckedException("Unable to parse DataMovementEvent payload", e);
    }
    int partitionId = dmEvent.getSourceIndex();
    URI baseUri = getBaseURI(shufflePayload.getHost(), shufflePayload.getPort(), partitionId);

    InputAttemptIdentifier srcAttemptIdentifier = new InputAttemptIdentifier(dmEvent.getTargetIndex(), dmEvent.getVersion(), shufflePayload.getPathComponent());
    scheduler.addKnownMapOutput(shufflePayload.getHost(), partitionId, baseUri.toString(), srcAttemptIdentifier);

    // TODO NEWTEZ See if this duration hack can be removed.
    int duration = shufflePayload.getRunDuration();
    if (duration > maxMapRuntime) {
      maxMapRuntime = duration;
      scheduler.informMaxMapRunTime(maxMapRuntime);
    }
  }

  /** Tells the scheduler a source attempt's output is no longer fetchable. */
  private void processTaskFailedEvent(InputFailedEvent ifEvent) {
    InputAttemptIdentifier taIdentifier = new InputAttemptIdentifier(ifEvent.getSourceIndex(), ifEvent.getVersion());
    scheduler.obsoleteMapOutput(taIdentifier);
    LOG.info("Obsoleting output of src-task: " + taIdentifier);
  }

  // TODO NEWTEZ Handle encrypted shuffle
  /**
   * Builds the base fetch URL for a source host/port and partition. The URL
   * deliberately ends with an empty "map=" parameter -- presumably the
   * fetcher appends the source-attempt path components; confirm with the
   * Fetcher implementation.
   */
  private URI getBaseURI(String host, int port, int partitionId) {
    StringBuilder sb = new StringBuilder("http://");
    sb.append(host);
    sb.append(":");
    sb.append(String.valueOf(port));
    sb.append("/");

    sb.append("mapOutput?job=");
    // Required to use the existing ShuffleHandler
    sb.append(inputContext.getApplicationId().toString().replace("application", "job"));

    sb.append("&reduce=");
    sb.append(partitionId);
    sb.append("&map=");
    URI u = URI.create(sb.toString());
    return u;
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleScheduler.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleScheduler.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleScheduler.java
new file mode 100644
index 0000000..a682a09
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/impl/ShuffleScheduler.java
@@ -0,0 +1,521 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.runtime.library.common.shuffle.impl;
+
+import java.io.IOException;
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.DelayQueue;
+import java.util.concurrent.Delayed;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.commons.lang.mutable.MutableInt;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.TezInputContext;
+import org.apache.tez.runtime.api.events.InputReadErrorEvent;
+import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+
+import com.google.common.collect.Lists;
+
+class ShuffleScheduler {
+ static ThreadLocal<Long> shuffleStart = new ThreadLocal<Long>() {
+ protected Long initialValue() {
+ return 0L;
+ }
+ };
+
+ private static final Log LOG = LogFactory.getLog(ShuffleScheduler.class);
+ private static final int MAX_MAPS_AT_ONCE = 20;
+ private static final long INITIAL_PENALTY = 10000;
+ private static final float PENALTY_GROWTH_RATE = 1.3f;
+
+ // TODO NEWTEZ May need to be a string if attempting to fetch from multiple inputs.
+ private final Map<Integer, MutableInt> finishedMaps;
+ private final int numInputs;
+ private int remainingMaps;
+ private Map<InputAttemptIdentifier, MapHost> mapLocations = new HashMap<InputAttemptIdentifier, MapHost>();
+ //TODO NEWTEZ Clean this and other maps at some point
+ private ConcurrentMap<String, InputAttemptIdentifier> pathToIdentifierMap = new ConcurrentHashMap<String, InputAttemptIdentifier>();
+ private Set<MapHost> pendingHosts = new HashSet<MapHost>();
+ private Set<InputAttemptIdentifier> obsoleteMaps = new HashSet<InputAttemptIdentifier>();
+
+ private final Random random = new Random(System.currentTimeMillis());
+ private final DelayQueue<Penalty> penalties = new DelayQueue<Penalty>();
+ private final Referee referee = new Referee();
+ private final Map<InputAttemptIdentifier, IntWritable> failureCounts =
+ new HashMap<InputAttemptIdentifier,IntWritable>();
+ private final Map<String,IntWritable> hostFailures =
+ new HashMap<String,IntWritable>();
+ private final TezInputContext inputContext;
+ private final Shuffle shuffle;
+ private final int abortFailureLimit;
+ private final TezCounter shuffledMapsCounter;
+ private final TezCounter reduceShuffleBytes;
+ private final TezCounter failedShuffleCounter;
+
+ private final long startTime;
+ private long lastProgressTime;
+
+ private int maxMapRuntime = 0;
+ private int maxFailedUniqueFetches = 5;
+ private int maxFetchFailuresBeforeReporting;
+
+ private long totalBytesShuffledTillNow = 0;
+ private DecimalFormat mbpsFormat = new DecimalFormat("0.00");
+
+ private boolean reportReadErrorImmediately = true;
+
  /**
   * Builds the scheduler for one shuffle.
   *
   * @param inputContext context of the consuming input
   * @param conf source of fetch-failure tuning settings
   * @param tasksInDegree number of source tasks whose output must be copied
   * @param shuffle owner; receives fatal errors via reportException
   * @param shuffledMapsCounter counter of completed source copies
   * @param reduceShuffleBytes counter of bytes shuffled
   * @param failedShuffleCounter counter of failed copy attempts
   */
  public ShuffleScheduler(TezInputContext inputContext,
                          Configuration conf,
                          int tasksInDegree,
                          Shuffle shuffle,
                          TezCounter shuffledMapsCounter,
                          TezCounter reduceShuffleBytes,
                          TezCounter failedShuffleCounter) {
    this.inputContext = inputContext;
    this.numInputs = tasksInDegree;
    // Abort after this many failures for a single source: at least 30,
    // scaling with 10% of the in-degree for large fan-ins.
    abortFailureLimit = Math.max(30, tasksInDegree / 10);
    remainingMaps = tasksInDegree;
    //TODO NEWTEZ May need to be a string or a more usable construct if attempting to fetch from multiple inputs. Define a taskId / taskAttemptId pair
    finishedMaps = new HashMap<Integer, MutableInt>(remainingMaps);
    this.shuffle = shuffle;
    this.shuffledMapsCounter = shuffledMapsCounter;
    this.reduceShuffleBytes = reduceShuffleBytes;
    this.failedShuffleCounter = failedShuffleCounter;
    this.startTime = System.currentTimeMillis();
    this.lastProgressTime = startTime;
    // NOTE(review): starting the Referee thread before construction finishes
    // publishes a partially-constructed "this"; benign only if Referee reads
    // nothing but the (final) penalties queue -- verify.
    referee.start();
    this.maxFailedUniqueFetches = Math.min(tasksInDegree,
        this.maxFailedUniqueFetches);
    this.maxFetchFailuresBeforeReporting =
        conf.getInt(
            TezJobConfig.TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES,
            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES_LIMIT);
    this.reportReadErrorImmediately =
        conf.getBoolean(
            TezJobConfig.TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR,
            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR);
  }
+
  /**
   * Records a successfully fetched source output: clears failure bookkeeping,
   * commits the output, updates counters/progress, and wakes waiters when the
   * final pending input completes.
   */
  public synchronized void copySucceeded(InputAttemptIdentifier srcAttemptIdentifier,
                                         MapHost host,
                                         long bytes,
                                         long milis,
                                         MapOutput output
                                         ) throws IOException {
    // NOTE(review): failureCounts is declared Map<InputAttemptIdentifier,
    // IntWritable> but a String is passed here, so this remove() can never
    // match an entry -- failure counts are apparently never cleared on a
    // later success. Likely should remove by srcAttemptIdentifier; verify.
    String taskIdentifier = TezRuntimeUtils.getTaskAttemptIdentifier(srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex(), srcAttemptIdentifier.getAttemptNumber());
    failureCounts.remove(taskIdentifier);
    hostFailures.remove(host.getHostName());

    // Ignore duplicate completions (e.g. a second attempt of the same task).
    if (!isFinishedTaskTrue(srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex())) {
      output.commit();
      if(incrementTaskCopyAndCheckCompletion(srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex())) {
        shuffledMapsCounter.increment(1);
        // Wake threads blocked in waitUntilDone() once nothing remains.
        if (--remainingMaps == 0) {
          notifyAll();
        }
      }

      // update the status
      lastProgressTime = System.currentTimeMillis();
      totalBytesShuffledTillNow += bytes;
      logProgress();
      reduceShuffleBytes.increment(bytes);
      if (LOG.isDebugEnabled()) {
        LOG.debug("src task: "
            + TezRuntimeUtils.getTaskAttemptIdentifier(
                inputContext.getSourceVertexName(), srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex(),
                srcAttemptIdentifier.getAttemptNumber()) + " done");
      }
    }
    // TODO NEWTEZ Should this be releasing the output, if not committed ? Possible memory leak in case of speculation.
  }
+
  /** Logs copy progress (done/total) and the mean transfer rate since start. */
  private void logProgress() {
    float mbs = (float) totalBytesShuffledTillNow / (1024 * 1024);
    int mapsDone = numInputs - remainingMaps;
    // +1 avoids division by zero in the first second after start.
    long secsSinceStart = (System.currentTimeMillis() - startTime) / 1000 + 1;

    float transferRate = mbs / secsSinceStart;
    LOG.info("copy(" + mapsDone + " of " + numInputs + " at "
        + mbpsFormat.format(transferRate) + " MB/s)");
  }
+
  /**
   * Records a failed copy attempt: bumps per-attempt and per-host failure
   * counts, reports a fatal error to the Shuffle once the abort limit is
   * reached, notifies the AM when warranted, and penalizes the host with an
   * exponentially growing back-off.
   */
  public synchronized void copyFailed(InputAttemptIdentifier srcAttempt,
                                      MapHost host,
                                      boolean readError) {
    host.penalize();
    int failures = 1;
    if (failureCounts.containsKey(srcAttempt)) {
      IntWritable x = failureCounts.get(srcAttempt);
      x.set(x.get() + 1);
      failures = x.get();
    } else {
      failureCounts.put(srcAttempt, new IntWritable(1));
    }
    String hostname = host.getHostName();
    if (hostFailures.containsKey(hostname)) {
      IntWritable x = hostFailures.get(hostname);
      x.set(x.get() + 1);
    } else {
      hostFailures.put(hostname, new IntWritable(1));
    }
    if (failures >= abortFailureLimit) {
      // throw/catch builds an IOException carrying a stack trace; it is
      // reported to the Shuffle (which aborts) rather than rethrown here.
      try {
        throw new IOException(failures
            + " failures downloading "
            + TezRuntimeUtils.getTaskAttemptIdentifier(
                inputContext.getSourceVertexName(), srcAttempt.getInputIdentifier().getSrcTaskIndex(),
                srcAttempt.getAttemptNumber()));
      } catch (IOException ie) {
        shuffle.reportException(ie);
      }
    }

    checkAndInformJobTracker(failures, srcAttempt, readError);

    checkReducerHealth();

    // Exponential back-off: INITIAL_PENALTY * PENALTY_GROWTH_RATE^failures.
    long delay = (long) (INITIAL_PENALTY *
        Math.pow(PENALTY_GROWTH_RATE, failures));

    penalties.add(new Penalty(host, delay));

    failedShuffleCounter.increment(1);
  }
+
+  // Notify the AM (via an InputReadErrorEvent):
+  // after every read error, if 'reportReadErrorImmediately' is true, or
+  // after every 'maxFetchFailuresBeforeReporting' failures.
+  private void checkAndInformJobTracker(
+      int failures, InputAttemptIdentifier srcAttempt, boolean readError) {
+    if ((reportReadErrorImmediately && readError)
+        || ((failures % maxFetchFailuresBeforeReporting) == 0)) {
+      LOG.info("Reporting fetch failure for "
+          + TezRuntimeUtils.getTaskAttemptIdentifier(
+              inputContext.getSourceVertexName(), srcAttempt.getInputIdentifier().getSrcTaskIndex(),
+              srcAttempt.getAttemptNumber()) + " to jobtracker.");
+
+      List<Event> failedEvents = Lists.newArrayListWithCapacity(1);
+      failedEvents.add(new InputReadErrorEvent("Fetch failure for "
+          + TezRuntimeUtils.getTaskAttemptIdentifier(
+              inputContext.getSourceVertexName(), srcAttempt.getInputIdentifier().getSrcTaskIndex(),
+              srcAttempt.getAttemptNumber()) + " to jobtracker.", srcAttempt.getInputIdentifier()
+              .getSrcTaskIndex(), srcAttempt.getAttemptNumber()));
+
+      inputContext.sendEvents(failedEvents);
+      //status.addFailedDependency(mapId);
+    }
+  }
+
+  /**
+   * Decide whether the consumer should be failed outright, based on fetch
+   * failure density, overall progress, and stall time. Called on every fetch
+   * failure; reports a fatal exception to the shuffle when unhealthy.
+   */
+  private void checkReducerHealth() {
+    final float MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT = 0.5f;
+    final float MIN_REQUIRED_PROGRESS_PERCENT = 0.5f;
+    final float MAX_ALLOWED_STALL_TIME_PERCENT = 0.5f;
+
+    long totalFailures = failedShuffleCounter.getValue();
+    int doneMaps = numInputs - remainingMaps;
+
+    // NOTE(review): if totalFailures + doneMaps is 0 this is 0f/0f (NaN),
+    // which makes reducerHealthy false — confirm that is the intent.
+    boolean reducerHealthy =
+      (((float)totalFailures / (totalFailures + doneMaps))
+          < MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT);
+
+    // check if the reducer has progressed enough
+    boolean reducerProgressedEnough =
+      (((float)doneMaps / numInputs)
+          >= MIN_REQUIRED_PROGRESS_PERCENT);
+
+    // check if the reducer is stalled for a long time
+    // duration for which the reducer is stalled
+    int stallDuration =
+      (int)(System.currentTimeMillis() - lastProgressTime);
+
+    // duration for which the reducer ran with progress
+    int shuffleProgressDuration =
+      (int)(lastProgressTime - startTime);
+
+    // min time the reducer should run without getting killed
+    int minShuffleRunDuration =
+      (shuffleProgressDuration > maxMapRuntime)
+      ? shuffleProgressDuration
+          : maxMapRuntime;
+
+    boolean reducerStalled =
+      (((float)stallDuration / minShuffleRunDuration)
+          >= MAX_ALLOWED_STALL_TIME_PERCENT);
+
+    // kill if not healthy and has insufficient progress
+    if ((failureCounts.size() >= maxFailedUniqueFetches ||
+        failureCounts.size() == (numInputs - doneMaps))
+        && !reducerHealthy
+        && (!reducerProgressedEnough || reducerStalled)) {
+      LOG.fatal("Shuffle failed with too many fetch failures " +
+          "and insufficient progress!");
+      String errorMsg = "Exceeded MAX_FAILED_UNIQUE_FETCHES; bailing-out.";
+      shuffle.reportException(new IOException(errorMsg));
+    }
+
+  }
+
+  /** Mark a failed source task as finished so fetchers stop expecting it. */
+  public synchronized void tipFailed(int srcTaskIndex) {
+    if (isFinishedTaskTrue(srcTaskIndex)) {
+      return; // already fully accounted for
+    }
+    setFinishedTaskTrue(srcTaskIndex);
+    remainingMaps--;
+    if (remainingMaps == 0) {
+      notifyAll(); // wake waitUntilDone()
+    }
+    logProgress();
+  }
+
+  /**
+   * Register a newly-discovered map output location and queue its host for
+   * fetching if the host just became pending.
+   */
+  public synchronized void addKnownMapOutput(String hostName,
+                                             int partitionId,
+                                             String hostUrl,
+                                             InputAttemptIdentifier srcAttempt) {
+    String identifier = MapHost.createIdentifier(hostName, partitionId);
+    MapHost host = mapLocations.get(identifier);
+    if (host == null) {
+      host = new MapHost(partitionId, hostName, hostUrl);
+      assert identifier.equals(host.getIdentifier());
+      // FIX: key by the host identifier — the same key used for the lookup
+      // above. Keying by srcAttempt meant the host was never found again and
+      // a fresh MapHost was created for every known map output on the host.
+      mapLocations.put(identifier, host);
+    }
+    host.addKnownMap(srcAttempt);
+    pathToIdentifierMap.put(srcAttempt.getPathComponent(), srcAttempt);
+
+    // Queue the host for fetching if it is in the PENDING state.
+    if (host.getState() == MapHost.State.PENDING) {
+      pendingHosts.add(host);
+      notifyAll();
+    }
+  }
+
+  /**
+   * Mark a source attempt's output as obsolete so it is skipped by
+   * getMapsForHost(). The incoming srcAttempt does not contain a path component.
+   */
+  public synchronized void obsoleteMapOutput(InputAttemptIdentifier srcAttempt) {
+    obsoleteMaps.add(srcAttempt);
+  }
+
+  /** Return an un-fetched attempt to its host's known-map list for a retry. */
+  public synchronized void putBackKnownMapOutput(MapHost host,
+                                                 InputAttemptIdentifier srcAttempt) {
+    host.addKnownMap(srcAttempt);
+  }
+
+  /**
+   * Hand a pending host to a fetcher, blocking until one is available.
+   * The host is picked at random from the pending set to spread load.
+   * @throws InterruptedException if interrupted while waiting for a host
+   */
+  public synchronized MapHost getHost() throws InterruptedException {
+    while(pendingHosts.isEmpty()) {
+      wait();
+    }
+
+    MapHost host = null;
+    Iterator<MapHost> iter = pendingHosts.iterator();
+    // Advance a random number of steps to pick a uniformly random host.
+    int numToPick = random.nextInt(pendingHosts.size());
+    for (int i=0; i <= numToPick; ++i) {
+      host = iter.next();
+    }
+
+    pendingHosts.remove(host);
+    host.markBusy();
+
+    LOG.info("Assigning " + host + " with " + host.getNumKnownMapOutputs() +
+             " to " + Thread.currentThread().getName());
+    // Start time of this fetch assignment (read back in freeHost()).
+    shuffleStart.set(System.currentTimeMillis());
+
+    return host;
+  }
+
+  /**
+   * Look up the InputAttemptIdentifier registered for a shuffle path component.
+   * FIX: synchronized, because pathToIdentifierMap is mutated by
+   * addKnownMapOutput()/resetKnownMaps() under the scheduler lock, so an
+   * unsynchronized read could observe the map mid-update.
+   * TODO confirm the map is not already a concurrent implementation.
+   */
+  public synchronized InputAttemptIdentifier getIdentifierForPathComponent(String pathComponent) {
+    return pathToIdentifierMap.get(pathComponent);
+  }
+
+  /**
+   * Drain a host's known map outputs, returning up to MAX_MAPS_AT_ONCE
+   * attempts that still need fetching. Still-needed attempts past the cap are
+   * pushed back onto the host; obsolete/finished ones are dropped entirely.
+   */
+  public synchronized List<InputAttemptIdentifier> getMapsForHost(MapHost host) {
+    List<InputAttemptIdentifier> list = host.getAndClearKnownMaps();
+    Iterator<InputAttemptIdentifier> itr = list.iterator();
+    List<InputAttemptIdentifier> result = new ArrayList<InputAttemptIdentifier>();
+    int includedMaps = 0;
+    int totalSize = list.size();
+    // find the maps that we still need, up to the limit
+    while (itr.hasNext()) {
+      InputAttemptIdentifier id = itr.next();
+      if (!obsoleteMaps.contains(id) && !isFinishedTaskTrue(id.getInputIdentifier().getSrcTaskIndex())) {
+        result.add(id);
+        if (++includedMaps >= MAX_MAPS_AT_ONCE) {
+          break;
+        }
+      }
+    }
+    // put back the maps left after the limit
+    while (itr.hasNext()) {
+      InputAttemptIdentifier id = itr.next();
+      if (!obsoleteMaps.contains(id) && !isFinishedTaskTrue(id.getInputIdentifier().getSrcTaskIndex())) {
+        host.addKnownMap(id);
+      }
+    }
+    LOG.info("assigned " + includedMaps + " of " + totalSize + " to " +
+             host + " to " + Thread.currentThread().getName());
+    return result;
+  }
+
+  /**
+   * Return a host from a fetcher; re-queue it as pending unless it is
+   * currently serving a penalty.
+   */
+  public synchronized void freeHost(MapHost host) {
+    if (host.getState() != MapHost.State.PENALIZED) {
+      if (host.markAvailable() == MapHost.State.PENDING) {
+        pendingHosts.add(host);
+        notifyAll();
+      }
+    }
+    // FIX: the elapsed time is in milliseconds; the log suffix said "s".
+    LOG.info(host + " freed by " + Thread.currentThread().getName() + " in " +
+             (System.currentTimeMillis()-shuffleStart.get()) + "ms");
+  }
+
+  /** Clear all host/attempt bookkeeping (used when shuffle state is reset). */
+  public synchronized void resetKnownMaps() {
+    mapLocations.clear();
+    obsoleteMaps.clear();
+    pendingHosts.clear();
+    pathToIdentifierMap.clear();
+  }
+
+  /**
+   * Utility method to check if the Shuffle data fetch is complete.
+   * @return true when every input has been fetched
+   */
+  public synchronized boolean isDone() {
+    return remainingMaps == 0;
+  }
+
+  /**
+   * Block until the shuffle completes or the timeout elapses.
+   * @param millis maximum wait time in milliseconds
+   * @return true if the shuffle is done
+   * @throws InterruptedException if the wait is interrupted
+   */
+  public synchronized boolean waitUntilDone(int millis
+                                            ) throws InterruptedException {
+    if (remainingMaps == 0) {
+      return true;
+    }
+    wait(millis);
+    return remainingMaps == 0;
+  }
+
+  /**
+   * Delay-queue entry recording when a penalized host becomes usable again.
+   */
+  private static class Penalty implements Delayed {
+    MapHost host;
+    private long endTime;
+
+    Penalty(MapHost host, long delay) {
+      this.host = host;
+      this.endTime = System.currentTimeMillis() + delay;
+    }
+
+    public long getDelay(TimeUnit unit) {
+      // Remaining penalty time, converted to the caller's unit.
+      return unit.convert(endTime - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
+    }
+
+    public int compareTo(Delayed o) {
+      // Order strictly by expiry time.
+      long other = ((Penalty) o).endTime;
+      if (endTime < other) {
+        return -1;
+      }
+      if (endTime > other) {
+        return 1;
+      }
+      return 0;
+    }
+
+  }
+
+  /**
+   * A thread that takes hosts off of the penalty list when the timer expires.
+   */
+  private class Referee extends Thread {
+    public Referee() {
+      setName("ShufflePenaltyReferee");
+      // Daemon so it never blocks JVM exit.
+      setDaemon(true);
+    }
+
+    public void run() {
+      try {
+        while (true) {
+          // take() blocks until the earliest penalty has expired.
+          MapHost host = penalties.take().host;
+          synchronized (ShuffleScheduler.this) {
+            if (host.markAvailable() == MapHost.State.PENDING) {
+              pendingHosts.add(host);
+              ShuffleScheduler.this.notifyAll();
+            }
+          }
+        }
+      } catch (InterruptedException ie) {
+        // Normal shutdown path: close() interrupts this thread.
+        return;
+      } catch (Throwable t) {
+        shuffle.reportException(t);
+      }
+    }
+  }
+
+  /** Stop the penalty referee thread and wait for it to exit. */
+  public void close() throws InterruptedException {
+    referee.interrupt();
+    referee.join();
+  }
+
+  /** Track the longest reported source-task runtime seen so far. */
+  public synchronized void informMaxMapRunTime(int duration) {
+    maxMapRuntime = Math.max(maxMapRuntime, duration);
+  }
+
+  /** Force-mark a source task as fully fetched (used when a task fails upstream). */
+  void setFinishedTaskTrue(int srcTaskIndex) {
+    synchronized (finishedMaps) {
+      // Jump the copy count straight to the completion target.
+      finishedMaps.put(srcTaskIndex, new MutableInt(shuffle.getReduceRange()));
+    }
+  }
+
+  /**
+   * Record one completed copy for the given source task.
+   * @return true once the task has reached its completion target.
+   */
+  boolean incrementTaskCopyAndCheckCompletion(int srcTaskIndex) {
+    synchronized (finishedMaps) {
+      MutableInt copyCount = finishedMaps.get(srcTaskIndex);
+      if (copyCount == null) {
+        copyCount = new MutableInt(0);
+        finishedMaps.put(srcTaskIndex, copyCount);
+      }
+      copyCount.increment();
+      return isFinishedTaskTrue(srcTaskIndex);
+    }
+  }
+
+  /** @return true if all expected copies for the source task have completed. */
+  boolean isFinishedTaskTrue(int srcTaskIndex) {
+    synchronized (finishedMaps) {
+      MutableInt copyCount = finishedMaps.get(srcTaskIndex);
+      return copyCount != null && copyCount.intValue() == shuffle.getReduceRange();
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/server/ShuffleHandler.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/server/ShuffleHandler.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/server/ShuffleHandler.java
new file mode 100644
index 0000000..9a206c6
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/shuffle/server/ShuffleHandler.java
@@ -0,0 +1,572 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.tez.runtime.library.common.shuffle.server;
+
+import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer;
+import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE;
+import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
+import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.URL;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+
+import javax.crypto.SecretKey;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DataInputByteBuffer;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.metrics2.MetricsSystem;
+import org.apache.hadoop.metrics2.annotation.Metric;
+import org.apache.hadoop.metrics2.annotation.Metrics;
+import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem;
+import org.apache.hadoop.metrics2.lib.MutableCounterInt;
+import org.apache.hadoop.metrics2.lib.MutableCounterLong;
+import org.apache.hadoop.metrics2.lib.MutableGaugeInt;
+import org.apache.hadoop.security.ssl.SSLFactory;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
+import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext;
+import org.apache.hadoop.yarn.server.api.AuxiliaryService;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.runtime.api.TezOutputContext;
+import org.apache.tez.runtime.library.common.security.JobTokenIdentifier;
+import org.apache.tez.runtime.library.common.security.JobTokenSecretManager;
+import org.apache.tez.runtime.library.common.security.SecureShuffleUtils;
+import org.apache.tez.runtime.library.common.shuffle.impl.ShuffleHeader;
+import org.apache.tez.runtime.library.common.sort.impl.ExternalSorter;
+import org.apache.tez.runtime.library.shuffle.common.ShuffleUtils;
+import org.jboss.netty.bootstrap.ServerBootstrap;
+import org.jboss.netty.buffer.ChannelBuffers;
+import org.jboss.netty.channel.Channel;
+import org.jboss.netty.channel.ChannelFactory;
+import org.jboss.netty.channel.ChannelFuture;
+import org.jboss.netty.channel.ChannelFutureListener;
+import org.jboss.netty.channel.ChannelHandlerContext;
+import org.jboss.netty.channel.ChannelPipeline;
+import org.jboss.netty.channel.ChannelPipelineFactory;
+import org.jboss.netty.channel.Channels;
+import org.jboss.netty.channel.ExceptionEvent;
+import org.jboss.netty.channel.MessageEvent;
+import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
+import org.jboss.netty.channel.group.ChannelGroup;
+import org.jboss.netty.channel.group.DefaultChannelGroup;
+import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
+import org.jboss.netty.handler.codec.frame.TooLongFrameException;
+import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
+import org.jboss.netty.handler.codec.http.HttpChunkAggregator;
+import org.jboss.netty.handler.codec.http.HttpRequest;
+import org.jboss.netty.handler.codec.http.HttpRequestDecoder;
+import org.jboss.netty.handler.codec.http.HttpResponse;
+import org.jboss.netty.handler.codec.http.HttpResponseEncoder;
+import org.jboss.netty.handler.codec.http.HttpResponseStatus;
+import org.jboss.netty.handler.codec.http.QueryStringDecoder;
+import org.jboss.netty.handler.ssl.SslHandler;
+import org.jboss.netty.handler.stream.ChunkedStream;
+import org.jboss.netty.handler.stream.ChunkedWriteHandler;
+import org.jboss.netty.util.CharsetUtil;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+
+/**
+ * Netty-based HTTP server that serves sorted map output ("shuffle") data to
+ * consumers, running as a YARN auxiliary service.
+ */
+public class ShuffleHandler extends AuxiliaryService {
+
+  private static final Log LOG = LogFactory.getLog(ShuffleHandler.class);
+
+  public static final String SHUFFLE_MANAGE_OS_CACHE = "mapreduce.shuffle.manage.os.cache";
+  public static final boolean DEFAULT_SHUFFLE_MANAGE_OS_CACHE = true;
+
+  public static final String SHUFFLE_READAHEAD_BYTES = "mapreduce.shuffle.readahead.bytes";
+  public static final int DEFAULT_SHUFFLE_READAHEAD_BYTES = 4 * 1024 * 1024;
+
+  // Actual bound port (serviceStart binds port 0 and records the OS choice).
+  private int port;
+  private ChannelFactory selector;
+  // Every accepted channel; closed together on serviceStop().
+  private final ChannelGroup accepted = new DefaultChannelGroup();
+  private HttpPipelineFactory pipelineFact;
+  private int sslFileBufferSize;
+
+  public static final String MAPREDUCE_SHUFFLE_SERVICEID =
+    "mapreduce.shuffle";
+
+  // appId -> user. NOTE(review): static, so all handler instances in the JVM
+  // share these maps and the secret manager — confirm that is intended.
+  private static final Map<String,String> userRsrc =
+    new ConcurrentHashMap<String,String>();
+  private static final JobTokenSecretManager secretManager =
+    new JobTokenSecretManager();
+  private SecretKey tokenSecret;
+
+  public static final String SHUFFLE_PORT_CONFIG_KEY = "mapreduce.shuffle.port";
+  public static final int DEFAULT_SHUFFLE_PORT = 8080;
+
+  // NOTE: the "SUFFLE" spelling is part of the public constant name; do not
+  // rename without a deprecation path.
+  public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY =
+    "mapreduce.shuffle.ssl.file.buffer.size";
+
+  public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024;
+
+  // Source of the shuffle headers/streams served to clients.
+  private ExternalSorter sorter;
+
+  /** Netty listener that folds per-transfer outcomes into shuffle metrics. */
+  @Metrics(about="Shuffle output metrics", context="mapred")
+  static class ShuffleMetrics implements ChannelFutureListener {
+    @Metric("Shuffle output in bytes")
+    MutableCounterLong shuffleOutputBytes;
+    @Metric("# of failed shuffle outputs")
+    MutableCounterInt shuffleOutputsFailed;
+    @Metric("# of succeeded shuffle outputs") // FIX: typo "succeeeded"
+    MutableCounterInt shuffleOutputsOK;
+    @Metric("# of current shuffle connections")
+    MutableGaugeInt shuffleConnections;
+
+    @Override
+    public void operationComplete(ChannelFuture future) throws Exception {
+      if (future.isSuccess()) {
+        shuffleOutputsOK.incr();
+      } else {
+        shuffleOutputsFailed.incr();
+      }
+      shuffleConnections.decr();
+    }
+  }
+
+  final ShuffleMetrics metrics;
+
+  // Visible for testing: registers metrics against the supplied system.
+  ShuffleHandler(MetricsSystem ms) {
+    super("httpshuffle");
+    metrics = ms.register(new ShuffleMetrics());
+  }
+
+  /** Standard constructor: serve the data produced by the given sorter. */
+  public ShuffleHandler(ExternalSorter sorter) {
+    this(DefaultMetricsSystem.instance());
+    this.sorter = sorter;
+  }
+
+  /**
+   * Serialize the shuffle port into a ByteBuffer for use later on.
+   * @param port the port to be sent to the ApplicationMaster
+   * @return the serialized form of the port.
+   * @throws IOException on serialization failure
+   */
+  public static ByteBuffer serializeMetaData(int port) throws IOException {
+    //TODO these bytes should be versioned
+    DataOutputBuffer port_dob = new DataOutputBuffer();
+    port_dob.writeInt(port);
+    return ByteBuffer.wrap(port_dob.getData(), 0, port_dob.getLength());
+  }
+
+  /**
+   * A helper function to deserialize the metadata returned by ShuffleHandler.
+   * @param meta the metadata returned by the ShuffleHandler
+   * @return the port the Shuffle Handler is listening on to serve shuffle data.
+   */
+  public static int deserializeMetaData(ByteBuffer meta) throws IOException {
+    //TODO this should be returning a class not just an int
+    DataInputByteBuffer in = new DataInputByteBuffer();
+    in.reset(meta);
+    int port = in.readInt();
+    return port;
+  }
+
+  /**
+   * A helper function to serialize the JobTokenIdentifier to be sent to the
+   * ShuffleHandler as ServiceData.
+   * @param jobToken the job token to be used for authentication of
+   * shuffle data requests.
+   * @return the serialized version of the jobToken.
+   */
+  public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken) throws IOException {
+    //TODO these bytes should be versioned
+    DataOutputBuffer jobToken_dob = new DataOutputBuffer();
+    jobToken.write(jobToken_dob);
+    return ByteBuffer.wrap(jobToken_dob.getData(), 0, jobToken_dob.getLength());
+  }
+
+  /** Inverse of {@link #serializeServiceData}: read a job token back out. */
+  static Token<JobTokenIdentifier> deserializeServiceData(ByteBuffer secret) throws IOException {
+    DataInputByteBuffer in = new DataInputByteBuffer();
+    in.reset(secret);
+    Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>();
+    jt.readFields(in);
+    return jt;
+  }
+
+
+  /**
+   * Register a newly-initialized application: record its user and install its
+   * job token so subsequent shuffle requests can be authenticated.
+   */
+  @Override
+  public void initializeApplication(
+      ApplicationInitializationContext initAppContext) {
+    // TODO these bytes should be versioned
+    try {
+      String user = initAppContext.getUser();
+      ApplicationId appId = initAppContext.getApplicationId();
+      ByteBuffer secret = initAppContext.getApplicationDataForService();
+      Token<JobTokenIdentifier> jt = deserializeServiceData(secret);
+      // TODO: Once SHuffle is out of NM, this can use MR APIs
+      userRsrc.put(appId.toString(), user);
+      LOG.info("Added token for " + appId.toString());
+      secretManager.addTokenForJob(appId.toString(), jt);
+    } catch (IOException e) {
+      LOG.error("Error during initApp", e);
+      // TODO add API to AuxiliaryServices to report failures
+    }
+  }
+
+  /** Drop the token and user mapping for a finished application. */
+  @Override
+  public void stopApplication(ApplicationTerminationContext context) {
+    ApplicationId appId = context.getApplicationId();
+    secretManager.removeTokenForJob(appId.toString());
+    userRsrc.remove(appId.toString());
+  }
+
+  /**
+   * Wire the handler to a running output: init with a copy of the conf and
+   * pull the job-token secret from the service consumer metadata.
+   */
+  public void initialize(TezOutputContext outputContext, Configuration conf) throws IOException {
+    this.init(new Configuration(conf));
+    tokenSecret = ShuffleUtils.getJobTokenSecretFromTokenBytes(outputContext.getServiceConsumerMetaData(MAPREDUCE_SHUFFLE_SERVICEID));
+  }
+
+  @Override
+  public synchronized void serviceInit(Configuration conf) {
+    // Boss threads accept connections; worker threads perform the I/O.
+    ThreadFactory bossFactory = new ThreadFactoryBuilder()
+      .setNameFormat("ShuffleHandler Netty Boss #%d")
+      .build();
+    ThreadFactory workerFactory = new ThreadFactoryBuilder()
+      .setNameFormat("ShuffleHandler Netty Worker #%d")
+      .build();
+
+    selector = new NioServerSocketChannelFactory(
+        Executors.newCachedThreadPool(bossFactory),
+        Executors.newCachedThreadPool(workerFactory));
+  }
+
+  // TODO change AbstractService to throw InterruptedException
+  @Override
+  public synchronized void serviceStart() {
+    Configuration conf = getConfig();
+    ServerBootstrap bootstrap = new ServerBootstrap(selector);
+    try {
+      pipelineFact = new HttpPipelineFactory(conf);
+    } catch (Exception ex) {
+      throw new RuntimeException(ex);
+    }
+    bootstrap.setPipelineFactory(pipelineFact);
+    // Let OS pick the port
+    Channel ch = bootstrap.bind(new InetSocketAddress(0));
+    accepted.add(ch);
+    port = ((InetSocketAddress)ch.getLocalAddress()).getPort();
+    // Publish the actual port back into the conf and the shuffle handler.
+    conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port));
+    pipelineFact.SHUFFLE.setPort(port);
+    LOG.info(getName() + " listening on port " + port);
+
+    sslFileBufferSize = conf.getInt(SUFFLE_SSL_FILE_BUFFER_SIZE_KEY,
+                                    DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE);
+  }
+
+  /** Close all accepted channels, release Netty resources, destroy the SSL factory. */
+  @Override
+  public synchronized void serviceStop() {
+    accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+    // FIX: release the channel factory directly instead of allocating a
+    // throwaway ServerBootstrap purely to release the selector's resources;
+    // guard against a partially-failed start.
+    if (selector != null) {
+      selector.releaseExternalResources();
+    }
+    if (pipelineFact != null) {
+      pipelineFact.destroy();
+    }
+  }
+
+  /** @return the serialized shuffle port, or null if serialization fails. */
+  @Override
+  public synchronized ByteBuffer getMetaData() {
+    try {
+      return serializeMetaData(port);
+    } catch (IOException e) {
+      LOG.error("Error during getMeta", e);
+      // TODO add API to AuxiliaryServices to report failures
+      return null;
+    }
+  }
+
+  /** Builds the per-connection Netty pipeline: [SSL] -> HTTP codec -> shuffle. */
+  class HttpPipelineFactory implements ChannelPipelineFactory {
+
+    final Shuffle SHUFFLE;
+    private SSLFactory sslFactory;
+
+    public HttpPipelineFactory(Configuration conf) throws Exception {
+      SHUFFLE = new Shuffle(conf);
+      if (conf.getBoolean(TezJobConfig.TEZ_RUNTIME_SHUFFLE_ENABLE_SSL,
+          TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_ENABLE_SSL)) {
+        sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf);
+        sslFactory.init();
+      }
+    }
+
+    public void destroy() {
+      if (sslFactory != null) {
+        sslFactory.destroy();
+      }
+    }
+
+    @Override
+    public ChannelPipeline getPipeline() throws Exception {
+      ChannelPipeline pipeline = Channels.pipeline();
+      if (sslFactory != null) {
+        pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine()));
+      }
+      pipeline.addLast("decoder", new HttpRequestDecoder());
+      // Aggregate chunked requests up to 64 KiB before they reach SHUFFLE.
+      pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16));
+      pipeline.addLast("encoder", new HttpResponseEncoder());
+      pipeline.addLast("chunking", new ChunkedWriteHandler());
+      pipeline.addLast("shuffle", SHUFFLE);
+      return pipeline;
+      // TODO factor security manager into pipeline
+      // TODO factor out encode/decode to permit binary shuffle
+      // TODO factor out decode of index to permit alt. models
+    }
+
+  }
+
+  /** HTTP request handler that authenticates requests and streams map outputs. */
+  class Shuffle extends SimpleChannelUpstreamHandler {
+
+    private final Configuration conf;
+    // Updated via setPort() by serviceStart() once the real port is known.
+    private int port;
+
+    public Shuffle(Configuration conf) {
+      this.conf = conf;
+      this.port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT);
+    }
+
+    public void setPort(int port) {
+      this.port = port;
+    }
+
+    /** Split comma-separated "map" query values into individual map ids. */
+    private List<String> splitMaps(List<String> mapq) {
+      if (null == mapq) {
+        return null;
+      }
+      final List<String> ret = new ArrayList<String>();
+      for (String s : mapq) {
+        Collections.addAll(ret, s.split(","));
+      }
+      return ret;
+    }
+
+ @Override
+ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
+ throws Exception {
+ HttpRequest request = (HttpRequest) evt.getMessage();
+ if (request.getMethod() != GET) {
+ sendError(ctx, METHOD_NOT_ALLOWED);
+ return;
+ }
+ // Check whether the shuffle version is compatible
+ if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(
+ request.getHeader(ShuffleHeader.HTTP_HEADER_NAME))
+ || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(
+ request.getHeader(ShuffleHeader.HTTP_HEADER_VERSION))) {
+ sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST);
+ }
+ final Map<String,List<String>> q =
+ new QueryStringDecoder(request.getUri()).getParameters();
+ final List<String> mapIds = splitMaps(q.get("map"));
+ final List<String> reduceQ = q.get("reduce");
+ final List<String> jobQ = q.get("job");
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("RECV: " + request.getUri() +
+ "\n mapId: " + mapIds +
+ "\n reduceId: " + reduceQ +
+ "\n jobId: " + jobQ);
+ }
+
+ if (mapIds == null || reduceQ == null || jobQ == null) {
+ sendError(ctx, "Required param job, map and reduce", BAD_REQUEST);
+ return;
+ }
+ if (reduceQ.size() != 1 || jobQ.size() != 1) {
+ sendError(ctx, "Too many job/reduce parameters", BAD_REQUEST);
+ return;
+ }
+ int reduceId;
+ String jobId;
+ try {
+ reduceId = Integer.parseInt(reduceQ.get(0));
+ jobId = jobQ.get(0);
+ } catch (NumberFormatException e) {
+ sendError(ctx, "Bad reduce parameter", BAD_REQUEST);
+ return;
+ } catch (IllegalArgumentException e) {
+ sendError(ctx, "Bad job parameter", BAD_REQUEST);
+ return;
+ }
+
+ final String reqUri = request.getUri();
+ if (null == reqUri) {
+ // TODO? add upstream?
+ sendError(ctx, FORBIDDEN);
+ return;
+ }
+ HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK);
+ try {
+ verifyRequest(jobId, ctx, request, response,
+ new URL("http", "", this.port, reqUri));
+ } catch (IOException e) {
+ LOG.warn("Shuffle failure ", e);
+ sendError(ctx, e.getMessage(), UNAUTHORIZED);
+ return;
+ }
+
+ Channel ch = evt.getChannel();
+ ch.write(response);
+ // TODO refactor the following into the pipeline
+ ChannelFuture lastMap = null;
+ for (String mapId : mapIds) {
+ try {
+ // TODO: Error handling - validate mapId via TezTaskAttemptId.forName
+
+ // TODO NEWTEZ Fix this. TaskAttemptId is no longer valid. mapId validation will not work anymore.
+// if (!mapId.equals(sorter.getTaskAttemptId().toString())) {
+// String errorMessage =
+// "Illegal shuffle request mapId: " + mapId
+// + " while actual mapId is " + sorter.getTaskAttemptId();
+// LOG.warn(errorMessage);
+// sendError(ctx, errorMessage, BAD_REQUEST);
+// return;
+// }
+
+ lastMap =
+ sendMapOutput(ctx, ch, userRsrc.get(jobId), jobId, mapId, reduceId);
+ if (null == lastMap) {
+ sendError(ctx, NOT_FOUND);
+ return;
+ }
+ } catch (IOException e) {
+ LOG.error("Shuffle error ", e);
+ sendError(ctx, e.getMessage(), INTERNAL_SERVER_ERROR);
+ return;
+ }
+ }
+ lastMap.addListener(metrics);
+ lastMap.addListener(ChannelFutureListener.CLOSE);
+ }
+
+    /**
+     * Verify the request's URL hash against the job token secret and set the
+     * reply hash header on the response.
+     * @throws IOException if the token is unknown or the hash fails to verify
+     */
+    private void verifyRequest(String appid, ChannelHandlerContext ctx,
+        HttpRequest request, HttpResponse response, URL requestUri)
+        throws IOException {
+      if (null == tokenSecret) {
+        LOG.info("Request for unknown token " + appid);
+        throw new IOException("could not find jobid");
+      }
+      // string to encrypt
+      String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri);
+      // hash from the fetcher
+      String urlHashStr =
+        request.getHeader(SecureShuffleUtils.HTTP_HEADER_URL_HASH);
+      if (urlHashStr == null) {
+        LOG.info("Missing header hash for " + appid);
+        throw new IOException("fetcher cannot be authenticated");
+      }
+      if (LOG.isDebugEnabled()) {
+        int len = urlHashStr.length();
+        LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." +
+            urlHashStr.substring(len-len/2, len-1));
+      }
+      // verify - throws exception
+      SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret);
+      // verification passed - encode the reply
+      String reply =
+        SecureShuffleUtils.generateHash(urlHashStr.getBytes(), tokenSecret);
+      response.setHeader(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply);
+      addVersionToHeader(response);
+      if (LOG.isDebugEnabled()) {
+        int len = reply.length();
+        LOG.debug("Fetcher request verfied. enc_str=" + enc_str + ";reply=" +
+            reply.substring(len-len/2, len-1));
+      }
+    }
+
+    /**
+     * Write one map output onto the channel: the ShuffleHeader followed by a
+     * chunked stream of the sorted data for the given reduce partition.
+     * @return the write future of the data stream (used for metrics/close)
+     */
+    protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch,
+        String user, String jobId, String mapId, int reduce)
+        throws IOException {
+      final ShuffleHeader header = sorter.getShuffleHeader(reduce);
+      final DataOutputBuffer dob = new DataOutputBuffer();
+      header.write(dob);
+      ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+
+      ChannelFuture writeFuture =
+          ch.write(
+              new ChunkedStream(
+                  sorter.getSortedStream(reduce), sslFileBufferSize
+                  )
+              );
+      metrics.shuffleConnections.incr();
+      metrics.shuffleOutputBytes.incr(header.getCompressedLength()); // optimistic
+      return writeFuture;
+    }
+
+    /** Send an error response with an empty message body. */
+    private void sendError(ChannelHandlerContext ctx,
+        HttpResponseStatus status) {
+      sendError(ctx, "", status);
+    }
+
+    /** Send an HTTP error response and close the connection afterwards. */
+    private void sendError(ChannelHandlerContext ctx, String message,
+        HttpResponseStatus status) {
+      HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status);
+      response.setHeader(CONTENT_TYPE, "text/plain; charset=UTF-8");
+      addVersionToHeader(response);
+      response.setContent(
+        ChannelBuffers.copiedBuffer(message, CharsetUtil.UTF_8));
+      // Close the connection as soon as the error message is sent.
+      ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE);
+    }
+
+    /** Stamp the shuffle protocol name/version headers onto a response. */
+    private void addVersionToHeader(HttpResponse response) {
+      // Put shuffle version into http header
+      response.setHeader(ShuffleHeader.HTTP_HEADER_NAME,
+          ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+      response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION,
+          ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+    }
+
+    /** Map pipeline exceptions to HTTP errors where the channel still allows it. */
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
+        throws Exception {
+      Channel ch = e.getChannel();
+      Throwable cause = e.getCause();
+      if (cause instanceof TooLongFrameException) {
+        sendError(ctx, BAD_REQUEST);
+        return;
+      }
+
+      LOG.error("Shuffle error: ", cause);
+      if (ch.isConnected()) {
+        // NOTE(review): the error is logged twice (above and here).
+        LOG.error("Shuffle error " + e);
+        sendError(ctx, INTERNAL_SERVER_ERROR);
+      }
+    }
+
+  }
+
+  /** @return the port the shuffle server is listening on. */
+  public int getPort() {
+    return port;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/b212ca1d/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/sort/impl/ExternalSorter.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/sort/impl/ExternalSorter.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/sort/impl/ExternalSorter.java
new file mode 100644
index 0000000..c362d98
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/common/sort/impl/ExternalSorter.java
@@ -0,0 +1,194 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.tez.runtime.library.common.sort.impl;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalFileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.RawLocalFileSystem;
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import org.apache.hadoop.io.serializer.SerializationFactory;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.hadoop.util.IndexedSorter;
+import org.apache.hadoop.util.Progressable;
+import org.apache.hadoop.util.QuickSort;
+import org.apache.hadoop.util.ReflectionUtils;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.common.counters.TaskCounter;
+import org.apache.tez.common.counters.TezCounter;
+import org.apache.tez.runtime.api.TezOutputContext;
+import org.apache.tez.runtime.library.api.Partitioner;
+import org.apache.tez.runtime.library.common.ConfigUtils;
+import org.apache.tez.runtime.library.common.TezRuntimeUtils;
+import org.apache.tez.runtime.library.common.combine.Combiner;
+import org.apache.tez.runtime.library.common.shuffle.impl.ShuffleHeader;
+import org.apache.tez.runtime.library.common.sort.impl.IFile.Writer;
+import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutput;
+import org.apache.tez.runtime.library.hadoop.compat.NullProgressable;
+
+/**
+ * Base class for sort implementations that partition, sort and spill
+ * key/value pairs to local disk for consumption by the shuffle. Concrete
+ * subclasses supply the buffering/spill strategy behind write(), flush()
+ * and close(); this class wires up the shared machinery (serializers,
+ * comparator, counters, compression codec, partitioner, combiner).
+ */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public abstract class ExternalSorter {
+
+ private static final Log LOG = LogFactory.getLog(ExternalSorter.class);
+
+ /** Releases all resources held by the sorter. */
+ public abstract void close() throws IOException;
+
+ /** Forces any buffered records to be sorted and persisted. */
+ public abstract void flush() throws IOException;
+
+ /** Accepts a single key/value record for partitioning and sorting. */
+ public abstract void write(Object key, Object value) throws IOException;
+
+ protected Progressable nullProgressable = new NullProgressable();
+ protected TezOutputContext outputContext;
+ // NOTE(review): combiner may be null when none is configured (see
+ // initialize()); runCombineProcessor() assumes callers checked first.
+ protected Combiner combiner;
+ protected Partitioner partitioner;
+ protected Configuration conf;
+ // Raw local filesystem (no checksum layer) used for spill files.
+ protected FileSystem rfs;
+ protected TezTaskOutput mapOutputFile;
+ // Number of downstream partitions, fixed at initialize() time.
+ protected int partitions;
+ protected Class keyClass;
+ protected Class valClass;
+ protected RawComparator comparator;
+ protected SerializationFactory serializationFactory;
+ protected Serializer keySerializer;
+ protected Serializer valSerializer;
+
+ // Pluggable in-memory sort algorithm; defaults to QuickSort.
+ protected IndexedSorter sorter;
+
+ // Compression for map-outputs
+ // (null when intermediate-output compression is disabled -- see initialize())
+ protected CompressionCodec codec;
+
+ // Counters
+ // TODO TEZ Rename all counter variables [Mapping of counter to MR for compatibility in the MR layer]
+ protected TezCounter mapOutputByteCounter;
+ protected TezCounter mapOutputRecordCounter;
+ protected TezCounter fileOutputByteCounter;
+ protected TezCounter spilledRecordsCounter;
+
+ /**
+ * Wires up the sorter from configuration: local FS, sort algorithm,
+ * key comparator, key/value serializers, counters, optional compression
+ * codec, task-output manager, partitioner and optional combiner.
+ * Must complete before any write()/flush()/close() call.
+ */
+ public void initialize(TezOutputContext outputContext, Configuration conf, int numOutputs) throws IOException {
+ this.outputContext = outputContext;
+ this.conf = conf;
+ this.partitions = numOutputs;
+
+ rfs = ((LocalFileSystem)FileSystem.getLocal(this.conf)).getRaw();
+
+ // sorter
+ sorter = ReflectionUtils.newInstance(this.conf.getClass(
+ TezJobConfig.TEZ_RUNTIME_INTERNAL_SORTER_CLASS, QuickSort.class,
+ IndexedSorter.class), this.conf);
+
+ comparator = ConfigUtils.getIntermediateOutputKeyComparator(this.conf);
+
+ // k/v serialization
+ keyClass = ConfigUtils.getIntermediateOutputKeyClass(this.conf);
+ valClass = ConfigUtils.getIntermediateOutputValueClass(this.conf);
+ serializationFactory = new SerializationFactory(this.conf);
+ keySerializer = serializationFactory.getSerializer(keyClass);
+ valSerializer = serializationFactory.getSerializer(valClass);
+
+ // counters
+ mapOutputByteCounter =
+ outputContext.getCounters().findCounter(TaskCounter.MAP_OUTPUT_BYTES);
+ mapOutputRecordCounter =
+ outputContext.getCounters().findCounter(TaskCounter.MAP_OUTPUT_RECORDS);
+ fileOutputByteCounter =
+ outputContext.getCounters().findCounter(TaskCounter.MAP_OUTPUT_MATERIALIZED_BYTES);
+ spilledRecordsCounter =
+ outputContext.getCounters().findCounter(TaskCounter.SPILLED_RECORDS);
+ // compression
+ if (ConfigUtils.shouldCompressIntermediateOutput(this.conf)) {
+ Class<? extends CompressionCodec> codecClass =
+ ConfigUtils.getIntermediateOutputCompressorClass(this.conf, DefaultCodec.class);
+ codec = ReflectionUtils.newInstance(codecClass, this.conf);
+ } else {
+ codec = null;
+ }
+
+ // Task outputs
+ mapOutputFile = TezRuntimeUtils.instantiateTaskOutputManager(conf, outputContext);
+
+ LOG.info("Instantiating Partitioner: [" + conf.get(TezJobConfig.TEZ_RUNTIME_PARTITIONER_CLASS) + "]");
+ // Expose the partition count to the partitioner via the config before
+ // instantiating it ('conf' and 'this.conf' alias the same object here).
+ this.conf.setInt(TezJobConfig.TEZ_RUNTIME_NUM_EXPECTED_PARTITIONS, this.partitions);
+ this.partitioner = TezRuntimeUtils.instantiatePartitioner(this.conf);
+ this.combiner = TezRuntimeUtils.instantiateCombiner(this.conf, outputContext);
+ }
+
+ /**
+ * Exception indicating that the allocated sort buffer is insufficient to hold
+ * the current record.
+ */
+ @SuppressWarnings("serial")
+ public static class MapBufferTooSmallException extends IOException {
+ public MapBufferTooSmallException(String s) {
+ super(s);
+ }
+ }
+
+ /** Accessor for the task-output file manager set up in initialize(). */
+ @Private
+ public TezTaskOutput getMapOutput() {
+ return mapOutputFile;
+ }
+
+ /**
+ * Runs the configured combiner over kvIter, writing combined records to
+ * writer. An InterruptedException from the combiner is rethrown as an
+ * IOException so callers only deal with one checked exception type.
+ */
+ protected void runCombineProcessor(TezRawKeyValueIterator kvIter,
+ Writer writer) throws IOException {
+ try {
+ combiner.combine(kvIter, writer);
+ } catch (InterruptedException e) {
+ throw new IOException(e);
+ }
+ }
+
+ /**
+ * Rename srcPath to dstPath on the same volume. This is the same as
+ * RawLocalFileSystem's rename method, except that it will not fall back to a
+ * copy, and it will create the target directory if it doesn't exist.
+ */
+ protected void sameVolRename(Path srcPath, Path dstPath) throws IOException {
+ // NOTE(review): this local 'rfs' deliberately shadows the field of the
+ // same name with a RawLocalFileSystem view of it.
+ RawLocalFileSystem rfs = (RawLocalFileSystem) this.rfs;
+ File src = rfs.pathToFile(srcPath);
+ File dst = rfs.pathToFile(dstPath);
+ if (!dst.getParentFile().exists()) {
+ if (!dst.getParentFile().mkdirs()) {
+ throw new IOException("Unable to rename " + src + " to " + dst
+ + ": couldn't create parent directory");
+ }
+ }
+
+ if (!src.renameTo(dst)) {
+ throw new IOException("Unable to rename " + src + " to " + dst);
+ }
+ }
+
+ // Optional operation: only subclasses that can serve sorted output
+ // directly override this -- TODO confirm which implementations do.
+ public InputStream getSortedStream(int partition) {
+ throw new UnsupportedOperationException("getSortedStream isn't supported!");
+ }
+
+ // Optional operation, paired with getSortedStream(); default rejects it.
+ public ShuffleHeader getShuffleHeader(int reduce) {
+ throw new UnsupportedOperationException("getShuffleHeader isn't supported!");
+ }
+}