Posted to commits@tez.apache.org by je...@apache.org on 2016/07/25 15:31:36 UTC
[1/2] tez git commit: TEZ-3355. Tez Custom Shuffle Handler POC (jeagles)
Repository: tez
Updated Branches:
refs/heads/TEZ-3334 [created] 077dd88e0
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/src/test/java/org/apache/tez/auxservices/TestShuffleHandler.java
----------------------------------------------------------------------
diff --git a/tez-plugins/tez-aux-services/src/test/java/org/apache/tez/auxservices/TestShuffleHandler.java b/tez-plugins/tez-aux-services/src/test/java/org/apache/tez/auxservices/TestShuffleHandler.java
new file mode 100644
index 0000000..ffab7dd
--- /dev/null
+++ b/tez-plugins/tez-aux-services/src/test/java/org/apache/tez/auxservices/TestShuffleHandler.java
@@ -0,0 +1,1127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.auxservices;
+
+//import static org.apache.hadoop.test.MetricsAsserts.assertCounter;
+//import static org.apache.hadoop.test.MetricsAsserts.assertGauge;
+//import static org.apache.hadoop.test.MetricsAsserts.getMetrics;
+import static org.junit.Assert.assertTrue;
+import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK;
+import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.DataInputStream;
+import java.io.EOFException;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.net.HttpURLConnection;
+import java.net.SocketException;
+import java.net.URL;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.zip.CheckedOutputStream;
+import java.util.zip.Checksum;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.nativeio.NativeIO;
+import org.apache.hadoop.mapred.JobID;
+import org.apache.hadoop.mapred.MapTask;
+import org.apache.hadoop.mapreduce.TypeConverter;
+import org.apache.tez.runtime.library.common.security.SecureShuffleUtils;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ShuffleHeader;
+import org.apache.hadoop.metrics2.MetricsRecordBuilder;
+import org.apache.hadoop.metrics2.MetricsSource;
+import org.apache.hadoop.metrics2.MetricsSystem;
+import org.apache.hadoop.metrics2.impl.MetricsSystemImpl;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.service.ServiceStateException;
+import org.apache.hadoop.util.PureJavaCrc32;
+import org.apache.hadoop.util.StringUtils;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
+import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext;
+import org.apache.hadoop.yarn.server.records.Version;
+import org.jboss.netty.channel.Channel;
+import org.jboss.netty.channel.ChannelFuture;
+import org.jboss.netty.channel.ChannelHandlerContext;
+import org.jboss.netty.channel.socket.SocketChannel;
+import org.jboss.netty.channel.MessageEvent;
+import org.jboss.netty.channel.AbstractChannel;
+import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
+import org.jboss.netty.handler.codec.http.HttpRequest;
+import org.jboss.netty.handler.codec.http.HttpResponse;
+import org.jboss.netty.handler.codec.http.HttpResponseStatus;
+import org.jboss.netty.handler.codec.http.HttpMethod;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.mockito.Mockito;
+import org.mortbay.jetty.HttpHeaders;
+
+public class TestShuffleHandler {
+ static final long MiB = 1024 * 1024;
+ private static final Log LOG = LogFactory.getLog(TestShuffleHandler.class);
+
+ class MockShuffleHandler extends org.apache.tez.auxservices.ShuffleHandler {
+ @Override
+ protected Shuffle getShuffle(final Configuration conf) {
+ return new Shuffle(conf) {
+ @Override
+ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+ HttpRequest request, HttpResponse response, URL requestUri)
+ throws IOException {
+ }
+ @Override
+ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce,
+ String jobId, String user) throws IOException {
+ // Do nothing.
+ return null;
+ }
+ @Override
+ protected void populateHeaders(List<String> mapIds, String jobId,
+ String user, int reduce, HttpRequest request,
+ HttpResponse response, boolean keepAliveParam,
+ Map<String, MapOutputInfo> infoMap) throws IOException {
+ // Do nothing.
+ }
+ @Override
+ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx,
+ Channel ch, String user, String mapId, int reduce,
+ MapOutputInfo info) throws IOException {
+
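+ // ShuffleHeader fields: mapId, compressedLength, uncompressedLength, forReduce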
+ ShuffleHeader header =
+ new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1);
+ DataOutputBuffer dob = new DataOutputBuffer();
+ header.write(dob);
+ ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ dob = new DataOutputBuffer();
+ for (int i = 0; i < 100; ++i) {
+ header.write(dob);
+ }
+ return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ }
+ };
+ }
+ }
+
+ private static class MockShuffleHandler2 extends org.apache.tez.auxservices.ShuffleHandler {
+ boolean socketKeepAlive = false;
+
+ @Override
+ protected Shuffle getShuffle(final Configuration conf) {
+ return new Shuffle(conf) {
+ @Override
+ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+ HttpRequest request, HttpResponse response, URL requestUri)
+ throws IOException {
+ SocketChannel channel = (SocketChannel)(ctx.getChannel());
+ socketKeepAlive = channel.getConfig().isKeepAlive();
+ }
+ };
+ }
+
+ protected boolean isSocketKeepAlive() {
+ return socketKeepAlive;
+ }
+ }
+
+ /**
+ * Test the validation of ShuffleHandler's meta-data's serialization and
+ * de-serialization.
+ *
+ * @throws Exception exception
+ */
+ @Test (timeout = 10000)
+ public void testSerializeMeta() throws Exception {
+ assertEquals(1, ShuffleHandler.deserializeMetaData(
+ ShuffleHandler.serializeMetaData(1)));
+ assertEquals(-1, ShuffleHandler.deserializeMetaData(
+ ShuffleHandler.serializeMetaData(-1)));
+ assertEquals(8080, ShuffleHandler.deserializeMetaData(
+ ShuffleHandler.serializeMetaData(8080)));
+ }
+
+ /**
+ * Validate shuffle connection and input/output metrics.
+ *
+ * @throws Exception exception
+ */
+ @Test (timeout = 10000)
+ public void testShuffleMetrics() throws Exception {
+ MetricsSystem ms = new MetricsSystemImpl();
+ ShuffleHandler sh = new ShuffleHandler(ms);
+ ChannelFuture cf = mock(ChannelFuture.class);
+ when(cf.isSuccess()).thenReturn(true, false);
+
+ sh.metrics.shuffleConnections.incr();
+ sh.metrics.shuffleOutputBytes.incr(1*MiB);
+ sh.metrics.shuffleConnections.incr();
+ sh.metrics.shuffleOutputBytes.incr(2*MiB);
+
+ checkShuffleMetrics(ms, 3*MiB, 0, 0, 2);
+
+ sh.metrics.operationComplete(cf);
+ sh.metrics.operationComplete(cf);
+
+ checkShuffleMetrics(ms, 3*MiB, 1, 1, 0);
+ }
+
+ static void checkShuffleMetrics(MetricsSystem ms, long bytes, int failed,
+ int succeeded, int connections) {
+ /* TODO
+ MetricsSource source = ms.getSource("ShuffleMetrics");
+ MetricsRecordBuilder rb = getMetrics(source);
+ assertCounter("ShuffleOutputBytes", bytes, rb);
+ assertCounter("ShuffleOutputsFailed", failed, rb);
+ assertCounter("ShuffleOutputsOK", succeeded, rb);
+ assertGauge("ShuffleConnections", connections, rb);
+ */
+ }
+
+ /**
+ * Verify client prematurely closing a connection.
+ *
+ * @throws Exception exception.
+ */
+ @Test (timeout = 10000)
+ public void testClientClosesConnection() throws Exception {
+ final ArrayList<Throwable> failures = new ArrayList<Throwable>(1);
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ ShuffleHandler shuffleHandler = new ShuffleHandler() {
+ @Override
+ protected Shuffle getShuffle(Configuration conf) {
+ // replace the shuffle handler with one stubbed for testing
+ return new Shuffle(conf) {
+ @Override
+ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce,
+ String jobId, String user) throws IOException {
+ return null;
+ }
+ @Override
+ protected void populateHeaders(List<String> mapIds, String jobId,
+ String user, int reduce, HttpRequest request,
+ HttpResponse response, boolean keepAliveParam,
+ Map<String, MapOutputInfo> infoMap) throws IOException {
+ // Only set response headers and skip everything else
+ // send some dummy value for content-length
+ super.setResponseHeaders(response, keepAliveParam, 100);
+ }
+ @Override
+ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+ HttpRequest request, HttpResponse response, URL requestUri)
+ throws IOException {
+ }
+ @Override
+ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx,
+ Channel ch, String user, String mapId, int reduce,
+ MapOutputInfo info)
+ throws IOException {
+ // send a shuffle header and a lot of data down the channel
+ // to trigger a broken pipe
+ ShuffleHeader header =
+ new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1);
+ DataOutputBuffer dob = new DataOutputBuffer();
+ header.write(dob);
+ ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ dob = new DataOutputBuffer();
+ for (int i = 0; i < 100000; ++i) {
+ header.write(dob);
+ }
+ return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ }
+ @Override
+ protected void sendError(ChannelHandlerContext ctx,
+ HttpResponseStatus status) {
+ if (failures.size() == 0) {
+ failures.add(new Error());
+ ctx.getChannel().close();
+ }
+ }
+ @Override
+ protected void sendError(ChannelHandlerContext ctx, String message,
+ HttpResponseStatus status) {
+ if (failures.size() == 0) {
+ failures.add(new Error());
+ ctx.getChannel().close();
+ }
+ }
+ };
+ }
+ };
+ shuffleHandler.init(conf);
+ shuffleHandler.start();
+
+ // simulate a reducer that closes early by reading a single shuffle header
+ // then closing the connection
+ URL url = new URL("http://127.0.0.1:"
+ + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY)
+ + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_1_0");
+ HttpURLConnection conn = (HttpURLConnection)url.openConnection();
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ conn.connect();
+ DataInputStream input = new DataInputStream(conn.getInputStream());
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode());
+ Assert.assertEquals("close", conn.getHeaderField(HttpHeaders.CONNECTION));
+ ShuffleHeader header = new ShuffleHeader();
+ header.readFields(input);
+ input.close();
+
+ shuffleHandler.stop();
+ Assert.assertTrue("sendError called when client closed connection",
+ failures.size() == 0);
+ }
+
+ @Test(timeout = 10000)
+ public void testKeepAlive() throws Exception {
+ final ArrayList<Throwable> failures = new ArrayList<Throwable>(1);
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, true);
+ // try setting a negative keep-alive timeout.
+ conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, -100);
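+ // the handler is expected to clamp this to a 1 second minimum, which the
+ // "Keep-Alive: timeout=1" assertion below verifies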
+ ShuffleHandler shuffleHandler = new ShuffleHandler() {
+ @Override
+ protected Shuffle getShuffle(final Configuration conf) {
+ // replace the shuffle handler with one stubbed for testing
+ return new Shuffle(conf) {
+ @Override
+ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce,
+ String jobId, String user) throws IOException {
+ return null;
+ }
+ @Override
+ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+ HttpRequest request, HttpResponse response, URL requestUri)
+ throws IOException {
+ }
+
+ @Override
+ protected void populateHeaders(List<String> mapIds, String jobId,
+ String user, int reduce, HttpRequest request,
+ HttpResponse response, boolean keepAliveParam,
+ Map<String, MapOutputInfo> infoMap) throws IOException {
+ // Send some dummy data (populate content length details)
+ ShuffleHeader header =
+ new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1);
+ DataOutputBuffer dob = new DataOutputBuffer();
+ header.write(dob);
+ dob = new DataOutputBuffer();
+ for (int i = 0; i < 100000; ++i) {
+ header.write(dob);
+ }
+
+ long contentLength = dob.getLength();
+ // for testing purposes:
+ // disable connectionKeepAliveEnabled if the keepAlive URL param is present
+ if (keepAliveParam) {
+ connectionKeepAliveEnabled = false;
+ }
+
+ super.setResponseHeaders(response, keepAliveParam, contentLength);
+ }
+
+ @Override
+ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx,
+ Channel ch, String user, String mapId, int reduce,
+ MapOutputInfo info) throws IOException {
+ HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK);
+
+ // send a shuffle header and a lot of data down the channel
+ // to trigger a broken pipe
+ ShuffleHeader header =
+ new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1);
+ DataOutputBuffer dob = new DataOutputBuffer();
+ header.write(dob);
+ ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ dob = new DataOutputBuffer();
+ for (int i = 0; i < 100000; ++i) {
+ header.write(dob);
+ }
+ return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ }
+
+ @Override
+ protected void sendError(ChannelHandlerContext ctx,
+ HttpResponseStatus status) {
+ if (failures.size() == 0) {
+ failures.add(new Error());
+ ctx.getChannel().close();
+ }
+ }
+
+ @Override
+ protected void sendError(ChannelHandlerContext ctx, String message,
+ HttpResponseStatus status) {
+ if (failures.size() == 0) {
+ failures.add(new Error());
+ ctx.getChannel().close();
+ }
+ }
+ };
+ }
+ };
+ shuffleHandler.init(conf);
+ shuffleHandler.start();
+
+ String shuffleBaseURL = "http://127.0.0.1:"
+ + shuffleHandler.getConfig().get(
+ ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY);
+ URL url =
+ new URL(shuffleBaseURL + "/mapOutput?job=job_12345_1&reduce=1&"
+ + "map=attempt_12345_1_m_1_0");
+ HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ conn.connect();
+ DataInputStream input = new DataInputStream(conn.getInputStream());
+ Assert.assertEquals(HttpHeaders.KEEP_ALIVE,
+ conn.getHeaderField(HttpHeaders.CONNECTION));
+ Assert.assertEquals("timeout=1",
+ conn.getHeaderField(HttpHeaders.KEEP_ALIVE));
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode());
+ ShuffleHeader header = new ShuffleHeader();
+ header.readFields(input);
+ input.close();
+
+ // For keepAlive via URL
+ url =
+ new URL(shuffleBaseURL + "/mapOutput?job=job_12345_1&reduce=1&"
+ + "map=attempt_12345_1_m_1_0&keepAlive=true");
+ conn = (HttpURLConnection) url.openConnection();
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ conn.connect();
+ input = new DataInputStream(conn.getInputStream());
+ Assert.assertEquals(HttpHeaders.KEEP_ALIVE,
+ conn.getHeaderField(HttpHeaders.CONNECTION));
+ Assert.assertEquals("timeout=1",
+ conn.getHeaderField(HttpHeaders.KEEP_ALIVE));
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, conn.getResponseCode());
+ header = new ShuffleHeader();
+ header.readFields(input);
+ input.close();
+ }
+
+ @Test
+ public void testSocketKeepAlive() throws Exception {
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ conf.setBoolean(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED, true);
+ // try setting a negative keep-alive timeout.
+ conf.setInt(ShuffleHandler.SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT, -100);
+ HttpURLConnection conn = null;
+ MockShuffleHandler2 shuffleHandler = new MockShuffleHandler2();
+ try {
+ shuffleHandler.init(conf);
+ shuffleHandler.start();
+
+ String shuffleBaseURL = "http://127.0.0.1:"
+ + shuffleHandler.getConfig().get(
+ ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY);
+ URL url =
+ new URL(shuffleBaseURL + "/mapOutput?job=job_12345_1&reduce=1&"
+ + "map=attempt_12345_1_m_1_0");
+ conn = (HttpURLConnection) url.openConnection();
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ conn.connect();
+ conn.getInputStream();
+ Assert.assertTrue("socket should be set KEEP_ALIVE",
+ shuffleHandler.isSocketKeepAlive());
+ } finally {
+ if (conn != null) {
+ conn.disconnect();
+ }
+ shuffleHandler.stop();
+ }
+ }
+
+ /**
+ * Simulate a reducer that sends an invalid shuffle header: sometimes a
+ * wrong header name and sometimes a wrong version.
+ *
+ * @throws Exception exception
+ */
+ @Test (timeout = 10000)
+ public void testIncompatibleShuffleVersion() throws Exception {
+ final int failureNum = 3;
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ ShuffleHandler shuffleHandler = new ShuffleHandler();
+ shuffleHandler.init(conf);
+ shuffleHandler.start();
+
+ // simulate reducers that send incompatible shuffle header names and
+ // versions; each request should be rejected with a 400
+ URL url = new URL("http://127.0.0.1:"
+ + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY)
+ + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_1_0");
+ for (int i = 0; i < failureNum; ++i) {
+ HttpURLConnection conn = (HttpURLConnection)url.openConnection();
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ i == 0 ? "mapreduce" : "other");
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ i == 1 ? "1.0.0" : "1.0.1");
+ conn.connect();
+ Assert.assertEquals(
+ HttpURLConnection.HTTP_BAD_REQUEST, conn.getResponseCode());
+ }
+
+ shuffleHandler.stop();
+ shuffleHandler.close();
+ }
+
+ /**
+ * Validate the limit on number of shuffle connections.
+ *
+ * @throws Exception exception
+ */
+ @Test (timeout = 10000)
+ public void testMaxConnections() throws Exception {
+
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3);
+ ShuffleHandler shuffleHandler = new ShuffleHandler() {
+ @Override
+ protected Shuffle getShuffle(Configuration conf) {
+ // replace the shuffle handler with one stubbed for testing
+ return new Shuffle(conf) {
+ @Override
+ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce,
+ String jobId, String user) throws IOException {
+ // Do nothing.
+ return null;
+ }
+ @Override
+ protected void populateHeaders(List<String> mapIds, String jobId,
+ String user, int reduce, HttpRequest request,
+ HttpResponse response, boolean keepAliveParam,
+ Map<String, MapOutputInfo> infoMap) throws IOException {
+ // Do nothing.
+ }
+ @Override
+ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+ HttpRequest request, HttpResponse response, URL requestUri)
+ throws IOException {
+ // Do nothing.
+ }
+ @Override
+ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx,
+ Channel ch, String user, String mapId, int reduce,
+ MapOutputInfo info)
+ throws IOException {
+ // send a shuffle header and a lot of data down the channel
+ // to trigger a broken pipe
+ ShuffleHeader header =
+ new ShuffleHeader("dummy_header", 5678, 5678, 1);
+ DataOutputBuffer dob = new DataOutputBuffer();
+ header.write(dob);
+ ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ dob = new DataOutputBuffer();
+ for (int i=0; i<100000; ++i) {
+ header.write(dob);
+ }
+ return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ }
+ };
+ }
+ };
+ shuffleHandler.init(conf);
+ shuffleHandler.start();
+
+ // setup connections
+ int connAttempts = 3;
+ HttpURLConnection[] conns = new HttpURLConnection[connAttempts];
+
+ for (int i = 0; i < connAttempts; i++) {
+ String URLstring = "http://127.0.0.1:"
+ + shuffleHandler.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY)
+ + "/mapOutput?job=job_12345_1&reduce=1&map=attempt_12345_1_m_"
+ + i + "_0";
+ URL url = new URL(URLstring);
+ conns[i] = (HttpURLConnection)url.openConnection();
+ conns[i].setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ conns[i].setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ }
+
+ // Try to open numerous connections
+ for (int i = 0; i < connAttempts; i++) {
+ conns[i].connect();
+ }
+
+ // Ensure the first connections are okay
+ conns[0].getInputStream();
+ int rc = conns[0].getResponseCode();
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+
+ conns[1].getInputStream();
+ rc = conns[1].getResponseCode();
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+
+ // This connection should be closed because it is above the limit
+ try {
+ conns[2].getInputStream();
+ rc = conns[2].getResponseCode();
+ Assert.fail("Expected a SocketException");
+ } catch (SocketException se) {
+ LOG.info("Expected - connection should not be open");
+ } catch (Exception e) {
+ Assert.fail("Expected a SocketException");
+ }
+
+ shuffleHandler.stop();
+ }
+
+ /**
+ * Validate the ownership of the map-output files being pulled in. The
+ * local-file-system owner of the file should match the user component in
+ * the shuffle request.
+ *
+ * @throws IOException exception
+ */
+ @Test(timeout = 100000)
+ public void testMapFileAccess() throws IOException {
+ // This will run only if NativeIO is enabled, as SecureIOUtils needs it
+ assumeTrue(NativeIO.isAvailable());
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3);
+ conf.set(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION,
+ "kerberos");
+ UserGroupInformation.setConfiguration(conf);
+ File absLogDir = new File("target",
+ TestShuffleHandler.class.getSimpleName() + "LocDir").getAbsoluteFile();
+ conf.set(YarnConfiguration.NM_LOCAL_DIRS, absLogDir.getAbsolutePath());
+ ApplicationId appId = ApplicationId.newInstance(12345, 1);
+ LOG.info(appId.toString());
+ String appAttemptId = "attempt_12345_1_m_1_0";
+ String user = "randomUser";
+ String reducerId = "0";
+ List<File> fileMap = new ArrayList<File>();
+ createShuffleHandlerFiles(absLogDir, user, appId.toString(), appAttemptId,
+ conf, fileMap);
+ ShuffleHandler shuffleHandler = new ShuffleHandler() {
+
+ @Override
+ protected Shuffle getShuffle(Configuration conf) {
+ // replace the shuffle handler with one stubbed for testing
+ return new Shuffle(conf) {
+
+ @Override
+ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+ HttpRequest request, HttpResponse response, URL requestUri)
+ throws IOException {
+ // Do nothing.
+ }
+
+ };
+ }
+ };
+ shuffleHandler.init(conf);
+ try {
+ shuffleHandler.start();
+ DataOutputBuffer outputBuffer = new DataOutputBuffer();
+ outputBuffer.reset();
+ Token<JobTokenIdentifier> jt =
+ new Token<JobTokenIdentifier>("identifier".getBytes(),
+ "password".getBytes(), new Text(user), new Text("shuffleService"));
+ jt.write(outputBuffer);
+ shuffleHandler
+ .initializeApplication(new ApplicationInitializationContext(user,
+ appId, ByteBuffer.wrap(outputBuffer.getData(), 0,
+ outputBuffer.getLength())));
+ URL url =
+ new URL(
+ "http://127.0.0.1:"
+ + shuffleHandler.getConfig().get(
+ ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY)
+ + "/mapOutput?job=job_12345_0001&reduce=" + reducerId
+ + "&map=attempt_12345_1_m_1_0");
+ HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ conn.connect();
+ byte[] byteArr = new byte[10000];
+ try {
+ DataInputStream is = new DataInputStream(conn.getInputStream());
+ is.readFully(byteArr);
+ } catch (EOFException e) {
+ // ignore
+ }
+ // Retrieve file owner name
+ FileInputStream is = new FileInputStream(fileMap.get(0));
+ String owner = NativeIO.POSIX.getFstat(is.getFD()).getOwner();
+ is.close();
+
+ String message =
+ "Owner '" + owner + "' for path " + fileMap.get(0).getAbsolutePath()
+ + " did not match expected owner '" + user + "'";
+ Assert.assertTrue((new String(byteArr)).contains(message));
+ } finally {
+ shuffleHandler.stop();
+ FileUtil.fullyDelete(absLogDir);
+ }
+ }
+
+ private static void createShuffleHandlerFiles(File logDir, String user,
+ String appId, String appAttemptId, Configuration conf,
+ List<File> fileMap) throws IOException {
+ String attemptDir =
+ StringUtils.join(Path.SEPARATOR,
+ new String[] { logDir.getAbsolutePath(),
+ ShuffleHandler.USERCACHE, user,
+ ShuffleHandler.APPCACHE, appId, "output", appAttemptId });
+ File appAttemptDir = new File(attemptDir);
+ appAttemptDir.mkdirs();
+ System.out.println(appAttemptDir.getAbsolutePath());
+ File indexFile = new File(appAttemptDir, "file.out.index");
+ fileMap.add(indexFile);
+ createIndexFile(indexFile, conf);
+ File mapOutputFile = new File(appAttemptDir, "file.out");
+ fileMap.add(mapOutputFile);
+ createMapOutputFile(mapOutputFile, conf);
+ }
+
+ private static void
+ createMapOutputFile(File mapOutputFile, Configuration conf)
+ throws IOException {
+ FileOutputStream out = new FileOutputStream(mapOutputFile);
+ out.write("Creating new dummy map output file. Used only for testing"
+ .getBytes());
+ out.flush();
+ out.close();
+ }
+
+ private static void createIndexFile(File indexFile, Configuration conf)
+ throws IOException {
+ if (indexFile.exists()) {
+ System.out.println("Deleting existing file");
+ indexFile.delete();
+ }
+ indexFile.createNewFile();
+ FSDataOutputStream output = FileSystem.getLocal(conf).getRaw().append(
+ new Path(indexFile.getAbsolutePath()));
+ Checksum crc = new PureJavaCrc32();
+ crc.reset();
+ CheckedOutputStream chk = new CheckedOutputStream(output, crc);
+ String msg = "Writing new index file. This file will be used only " +
+ "for the testing.";
+ chk.write(Arrays.copyOf(msg.getBytes(),
+ MapTask.MAP_OUTPUT_INDEX_RECORD_LENGTH));
+ output.writeLong(chk.getChecksum().getValue());
+ output.close();
+ }
+
+ @Test
+ public void testRecovery() throws IOException {
+ final String user = "someuser";
+ final ApplicationId appId = ApplicationId.newInstance(12345, 1);
+ final JobID jobId = JobID.downgrade(TypeConverter.fromYarn(appId));
+ final File tmpDir = new File(System.getProperty("test.build.data",
+ System.getProperty("java.io.tmpdir")),
+ TestShuffleHandler.class.getName());
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3);
+ ShuffleHandler shuffle = new ShuffleHandler();
+ // emulate aux services startup with recovery enabled
+ shuffle.setRecoveryPath(new Path(tmpDir.toString()));
+ tmpDir.mkdirs();
+ try {
+ shuffle.init(conf);
+ shuffle.start();
+
+ // setup a shuffle token for an application
+ DataOutputBuffer outputBuffer = new DataOutputBuffer();
+ outputBuffer.reset();
+ Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>(
+ "identifier".getBytes(), "password".getBytes(), new Text(user),
+ new Text("shuffleService"));
+ jt.write(outputBuffer);
+ shuffle.initializeApplication(new ApplicationInitializationContext(user,
+ appId, ByteBuffer.wrap(outputBuffer.getData(), 0,
+ outputBuffer.getLength())));
+
+ // verify we are authorized to shuffle
+ int rc = getShuffleResponseCode(shuffle, jt);
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+
+ // emulate shuffle handler restart
+ shuffle.close();
+ shuffle = new ShuffleHandler();
+ shuffle.setRecoveryPath(new Path(tmpDir.toString()));
+ shuffle.init(conf);
+ shuffle.start();
+
+ // verify we are still authorized to shuffle to the old application
+ rc = getShuffleResponseCode(shuffle, jt);
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+
+ // shutdown app and verify access is lost
+ shuffle.stopApplication(new ApplicationTerminationContext(appId));
+ rc = getShuffleResponseCode(shuffle, jt);
+ Assert.assertEquals(HttpURLConnection.HTTP_UNAUTHORIZED, rc);
+
+ // emulate shuffle handler restart
+ shuffle.close();
+ shuffle = new ShuffleHandler();
+ shuffle.setRecoveryPath(new Path(tmpDir.toString()));
+ shuffle.init(conf);
+ shuffle.start();
+
+ // verify we still don't have access
+ rc = getShuffleResponseCode(shuffle, jt);
+ Assert.assertEquals(HttpURLConnection.HTTP_UNAUTHORIZED, rc);
+ } finally {
+ if (shuffle != null) {
+ shuffle.close();
+ }
+ FileUtil.fullyDelete(tmpDir);
+ }
+ }
+
+ @Test
+ public void testRecoveryFromOtherVersions() throws IOException {
+ final String user = "someuser";
+ final ApplicationId appId = ApplicationId.newInstance(12345, 1);
+ final File tmpDir = new File(System.getProperty("test.build.data",
+ System.getProperty("java.io.tmpdir")),
+ TestShuffleHandler.class.getName());
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3);
+ ShuffleHandler shuffle = new ShuffleHandler();
+ // emulate aux services startup with recovery enabled
+ shuffle.setRecoveryPath(new Path(tmpDir.toString()));
+ tmpDir.mkdirs();
+ try {
+ shuffle.init(conf);
+ shuffle.start();
+
+ // setup a shuffle token for an application
+ DataOutputBuffer outputBuffer = new DataOutputBuffer();
+ outputBuffer.reset();
+ Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>(
+ "identifier".getBytes(), "password".getBytes(), new Text(user),
+ new Text("shuffleService"));
+ jt.write(outputBuffer);
+ shuffle.initializeApplication(new ApplicationInitializationContext(user,
+ appId, ByteBuffer.wrap(outputBuffer.getData(), 0,
+ outputBuffer.getLength())));
+
+ // verify we are authorized to shuffle
+ int rc = getShuffleResponseCode(shuffle, jt);
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+
+ // emulate shuffle handler restart
+ shuffle.close();
+ shuffle = new ShuffleHandler();
+ shuffle.setRecoveryPath(new Path(tmpDir.toString()));
+ shuffle.init(conf);
+ shuffle.start();
+
+ // verify we are still authorized to shuffle to the old application
+ rc = getShuffleResponseCode(shuffle, jt);
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+ Version version = Version.newInstance(1, 0);
+ Assert.assertEquals(version, shuffle.getCurrentVersion());
+
+ // emulate shuffle handler restart with compatible version
+ Version version11 = Version.newInstance(1, 1);
+ // update version info before closing the shuffle handler
+ shuffle.storeVersion(version11);
+ Assert.assertEquals(version11, shuffle.loadVersion());
+ shuffle.close();
+ shuffle = new ShuffleHandler();
+ shuffle.setRecoveryPath(new Path(tmpDir.toString()));
+ shuffle.init(conf);
+ shuffle.start();
+ // the stored shuffle version will be overwritten by CURRENT_VERSION_INFO
+ // after a successful restart
+ Assert.assertEquals(version, shuffle.loadVersion());
+ // verify we are still authorized to shuffle to the old application
+ rc = getShuffleResponseCode(shuffle, jt);
+ Assert.assertEquals(HttpURLConnection.HTTP_OK, rc);
+
+ // emulate shuffle handler restart with incompatible version
+ Version version21 = Version.newInstance(2, 1);
+ shuffle.storeVersion(version21);
+ Assert.assertEquals(version21, shuffle.loadVersion());
+ shuffle.close();
+ shuffle = new ShuffleHandler();
+ shuffle.setRecoveryPath(new Path(tmpDir.toString()));
+ shuffle.init(conf);
+
+ try {
+ shuffle.start();
+ Assert.fail("Incompatible version, should expect fail here.");
+ } catch (ServiceStateException e) {
+ Assert.assertTrue("Exception message mismatch",
+ e.getMessage().contains("Incompatible version for state DB schema:"));
+ }
+
+ } finally {
+ if (shuffle != null) {
+ shuffle.close();
+ }
+ FileUtil.fullyDelete(tmpDir);
+ }
+ }
+
+ private static int getShuffleResponseCode(ShuffleHandler shuffle,
+ Token<JobTokenIdentifier> jt) throws IOException {
+ URL url = new URL("http://127.0.0.1:"
+ + shuffle.getConfig().get(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY)
+ + "/mapOutput?job=job_12345_0001&reduce=0&map=attempt_12345_1_m_1_0");
+ HttpURLConnection conn = (HttpURLConnection) url.openConnection();
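+ // sign the request URL with the job token so the handler's verifyRequest
+ // accepts this client as authorized to shuffle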
+ String encHash = SecureShuffleUtils.hashFromString(
+ SecureShuffleUtils.buildMsgFrom(url),
+ new JobTokenSecretManager(JobTokenSecretManager.createSecretKey(jt.getPassword())));
+ conn.addRequestProperty(
+ SecureShuffleUtils.HTTP_HEADER_URL_HASH, encHash);
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ conn.connect();
+ int rc = conn.getResponseCode();
+ conn.disconnect();
+ return rc;
+ }
+
+ @Test(timeout = 100000)
+ public void testGetMapOutputInfo() throws Exception {
+ final ArrayList<Throwable> failures = new ArrayList<Throwable>(1);
+ Configuration conf = new Configuration();
+ conf.setInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, 0);
+ conf.setInt(ShuffleHandler.MAX_SHUFFLE_CONNECTIONS, 3);
+ conf.set(CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHENTICATION,
+ "simple");
+ UserGroupInformation.setConfiguration(conf);
+ File absLogDir = new File("target", TestShuffleHandler.class.
+ getSimpleName() + "LocDir").getAbsoluteFile();
+ conf.set(YarnConfiguration.NM_LOCAL_DIRS, absLogDir.getAbsolutePath());
+ ApplicationId appId = ApplicationId.newInstance(12345, 1);
+ String appAttemptId = "attempt_12345_1_m_1_0";
+ String user = "randomUser";
+ String reducerId = "0";
+ List<File> fileMap = new ArrayList<File>();
+ createShuffleHandlerFiles(absLogDir, user, appId.toString(), appAttemptId,
+ conf, fileMap);
+ ShuffleHandler shuffleHandler = new ShuffleHandler() {
+ @Override
+ protected Shuffle getShuffle(Configuration conf) {
+ // replace the shuffle handler with one stubbed for testing
+ return new Shuffle(conf) {
+ @Override
+ protected void populateHeaders(List<String> mapIds,
+ String outputBaseStr, String user, int reduce,
+ HttpRequest request, HttpResponse response,
+ boolean keepAliveParam, Map<String, MapOutputInfo> infoMap)
+ throws IOException {
+ // Only set response headers and skip everything else
+ // send some dummy value for content-length
+ super.setResponseHeaders(response, keepAliveParam, 100);
+ }
+ @Override
+ protected void verifyRequest(String appid,
+ ChannelHandlerContext ctx, HttpRequest request,
+ HttpResponse response, URL requestUri) throws IOException {
+ // Do nothing.
+ }
+ @Override
+ protected void sendError(ChannelHandlerContext ctx, String message,
+ HttpResponseStatus status) {
+ if (failures.size() == 0) {
+ failures.add(new Error(message));
+ ctx.getChannel().close();
+ }
+ }
+ @Override
+ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx,
+ Channel ch, String user, String mapId, int reduce,
+ MapOutputInfo info) throws IOException {
+ // send a shuffle header
+ ShuffleHeader header =
+ new ShuffleHeader("attempt_12345_1_m_1_0", 5678, 5678, 1);
+ DataOutputBuffer dob = new DataOutputBuffer();
+ header.write(dob);
+ return ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ }
+ };
+ }
+ };
+ shuffleHandler.init(conf);
+ try {
+ shuffleHandler.start();
+ DataOutputBuffer outputBuffer = new DataOutputBuffer();
+ outputBuffer.reset();
+ Token<JobTokenIdentifier> jt =
+ new Token<JobTokenIdentifier>("identifier".getBytes(),
+ "password".getBytes(), new Text(user), new Text("shuffleService"));
+ jt.write(outputBuffer);
+ shuffleHandler
+ .initializeApplication(new ApplicationInitializationContext(user,
+ appId, ByteBuffer.wrap(outputBuffer.getData(), 0,
+ outputBuffer.getLength())));
+ URL url =
+ new URL(
+ "http://127.0.0.1:"
+ + shuffleHandler.getConfig().get(
+ ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY)
+ + "/mapOutput?job=job_12345_0001&reduce=" + reducerId
+ + "&map=attempt_12345_1_m_1_0");
+ HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ conn.connect();
+ try {
+ DataInputStream is = new DataInputStream(conn.getInputStream());
+ ShuffleHeader header = new ShuffleHeader();
+ header.readFields(is);
+ is.close();
+ } catch (EOFException e) {
+ // ignore
+ }
+ Assert.assertEquals("sendError called due to shuffle error",
+ 0, failures.size());
+ } finally {
+ shuffleHandler.stop();
+ FileUtil.fullyDelete(absLogDir);
+ }
+ }
+
+ @Test(timeout = 4000)
+ public void testSendMapCount() throws Exception {
+ final List<ShuffleHandler.ReduceMapFileCount> listenerList =
+ new ArrayList<ShuffleHandler.ReduceMapFileCount>();
+
+ final ChannelHandlerContext mockCtx =
+ mock(ChannelHandlerContext.class);
+ final MessageEvent mockEvt = mock(MessageEvent.class);
+ final Channel mockCh = mock(AbstractChannel.class);
+
+ // Mock HttpRequest and ChannelFuture
+ final HttpRequest mockHttpRequest = createMockHttpRequest();
+ final ChannelFuture mockFuture = createMockChannelFuture(mockCh,
+ listenerList);
+
+ // Mock Netty Channel Context and Channel behavior
+ Mockito.doReturn(mockCh).when(mockCtx).getChannel();
+ Mockito.doReturn(mockFuture).when(mockCh).write(Mockito.any(Object.class));
+
+ // Mock MessageEvent behavior
+ Mockito.doReturn(mockCh).when(mockEvt).getChannel();
+ Mockito.doReturn(mockHttpRequest).when(mockEvt).getMessage();
+
+ final ShuffleHandler sh = new MockShuffleHandler();
+ Configuration conf = new Configuration();
+ sh.init(conf);
+ sh.start();
+ int maxOpenFiles = conf.getInt(ShuffleHandler.SHUFFLE_MAX_SESSION_OPEN_FILES,
+ ShuffleHandler.DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES);
+ sh.getShuffle(conf).messageReceived(mockCtx, mockEvt);
+ assertTrue("Number of Open files should not exceed the configured " +
+ "value!-Not Expected",
+ listenerList.size() <= maxOpenFiles);
+ while(!listenerList.isEmpty()) {
+ listenerList.remove(0).operationComplete(mockFuture);
+ assertTrue("Number of Open files should not exceed the configured " +
+ "value!-Not Expected",
+ listenerList.size() <= maxOpenFiles);
+ }
+ sh.close();
+ }
+
+ public ChannelFuture createMockChannelFuture(Channel mockCh,
+ final List<ShuffleHandler.ReduceMapFileCount> listenerList) {
+ final ChannelFuture mockFuture = mock(ChannelFuture.class);
+ when(mockFuture.getChannel()).thenReturn(mockCh);
+ Mockito.doReturn(true).when(mockFuture).isSuccess();
+ Mockito.doAnswer(new Answer() {
+ @Override
+ public Object answer(InvocationOnMock invocation) throws Throwable {
+ //Add ReduceMapFileCount listener to a list
+ if (invocation.getArguments()[0].getClass() ==
+ ShuffleHandler.ReduceMapFileCount.class)
+ listenerList.add((ShuffleHandler.ReduceMapFileCount)
+ invocation.getArguments()[0]);
+ return null;
+ }
+ }).when(mockFuture).addListener(Mockito.any(
+ ShuffleHandler.ReduceMapFileCount.class));
+ return mockFuture;
+ }
+
+ public HttpRequest createMockHttpRequest() {
+ HttpRequest mockHttpRequest = mock(HttpRequest.class);
+ Mockito.doReturn(HttpMethod.GET).when(mockHttpRequest).getMethod();
+ Mockito.doAnswer(new Answer() {
+ @Override
+ public Object answer(InvocationOnMock invocation) throws Throwable {
+ String uri = "/mapOutput?job=job_12345_1&reduce=1";
+ for (int i = 0; i < 100; i++)
+ uri = uri.concat("&map=attempt_12345_1_m_" + i + "_0");
+ return uri;
+ }
+ }).when(mockHttpRequest).getUri();
+ return mockHttpRequest;
+ }
+}
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-history-parser/pom.xml
----------------------------------------------------------------------
diff --git a/tez-plugins/tez-history-parser/pom.xml b/tez-plugins/tez-history-parser/pom.xml
index e151b07..3a6185a 100644
--- a/tez-plugins/tez-history-parser/pom.xml
+++ b/tez-plugins/tez-history-parser/pom.xml
@@ -70,11 +70,6 @@
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-hdfs</artifactId>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-api</artifactId>
</dependency>
<dependency>
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-yarn-timeline-history-with-acls/pom.xml
----------------------------------------------------------------------
diff --git a/tez-plugins/tez-yarn-timeline-history-with-acls/pom.xml b/tez-plugins/tez-yarn-timeline-history-with-acls/pom.xml
index 5dcfb99..c85adca 100644
--- a/tez-plugins/tez-yarn-timeline-history-with-acls/pom.xml
+++ b/tez-plugins/tez-yarn-timeline-history-with-acls/pom.xml
@@ -68,11 +68,6 @@
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-hdfs</artifactId>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-api</artifactId>
</dependency>
<dependency>
[2/2] tez git commit: TEZ-3355. Tez Custom Shuffle Handler POC (jeagles)
Posted by je...@apache.org.
TEZ-3355. Tez Custom Shuffle Handler POC (jeagles)
Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/077dd88e
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/077dd88e
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/077dd88e
Branch: refs/heads/TEZ-3334
Commit: 077dd88e03b6e055b2c0bd8b7cb1986c7775658d
Parents: 97fa44f
Author: Jonathan Eagles <je...@yahoo-inc.com>
Authored: Mon Jul 25 10:29:31 2016 -0500
Committer: Jonathan Eagles <je...@yahoo-inc.com>
Committed: Mon Jul 25 10:29:31 2016 -0500
----------------------------------------------------------------------
TEZ-3334-CHANGES.txt | 7 +
pom.xml | 25 +
tez-dist/src/main/assembly/tez-dist-minimal.xml | 3 +
tez-dist/src/main/assembly/tez-dist.xml | 3 +
tez-plugins/pom.xml | 2 +
.../tez-aux-services/findbugs-exclude.xml | 16 +
tez-plugins/tez-aux-services/pom.xml | 108 ++
.../org/apache/tez/auxservices/IndexCache.java | 195 +++
.../apache/tez/auxservices/ShuffleHandler.java | 1343 ++++++++++++++++++
.../tez/auxservices/TestShuffleHandler.java | 1127 +++++++++++++++
tez-plugins/tez-history-parser/pom.xml | 5 -
.../tez-yarn-timeline-history-with-acls/pom.xml | 5 -
12 files changed, 2829 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
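In deployment terms, a handler like this runs as a NodeManager auxiliary
service rather than inside the Tez task. A minimal yarn-site.xml sketch of
that wiring, assuming the service is registered under the hypothetical name
tez_shuffle (only the class org.apache.tez.auxservices.ShuffleHandler is
fixed by this patch):

  <!-- service name is hypothetical; the class comes from this patch -->
  <property>
    <name>yarn.nodemanager.aux-services</name>
    <value>mapreduce_shuffle,tez_shuffle</value>
  </property>
  <property>
    <name>yarn.nodemanager.aux-services.tez_shuffle.class</name>
    <value>org.apache.tez.auxservices.ShuffleHandler</value>
  </property>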
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/TEZ-3334-CHANGES.txt
----------------------------------------------------------------------
diff --git a/TEZ-3334-CHANGES.txt b/TEZ-3334-CHANGES.txt
new file mode 100644
index 0000000..f779000
--- /dev/null
+++ b/TEZ-3334-CHANGES.txt
@@ -0,0 +1,7 @@
+Apache Tez Change Log
+=====================
+
+INCOMPATIBLE CHANGES:
+
+ALL CHANGES:
+ TEZ-3355. Tez Custom Shuffle Handler POC
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 336f5cb..6e4fe40 100644
--- a/pom.xml
+++ b/pom.xml
@@ -580,6 +580,31 @@
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-server-common</artifactId>
+ <version>${hadoop.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-shuffle</artifactId>
+ <scope>provided</scope>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-server-common</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-server-nodemanager</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-common</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-mapreduce-client-jobclient</artifactId>
<scope>test</scope>
<type>test-jar</type>
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-dist/src/main/assembly/tez-dist-minimal.xml
----------------------------------------------------------------------
diff --git a/tez-dist/src/main/assembly/tez-dist-minimal.xml b/tez-dist/src/main/assembly/tez-dist-minimal.xml
index 869e5b0..80633ff 100644
--- a/tez-dist/src/main/assembly/tez-dist-minimal.xml
+++ b/tez-dist/src/main/assembly/tez-dist-minimal.xml
@@ -22,6 +22,9 @@
<moduleSets>
<moduleSet>
<useAllReactorProjects>true</useAllReactorProjects>
+ <excludes>
+ <exclude>org.apache.tez:tez-aux-services</exclude>
+ </excludes>
<binaries>
<outputDirectory>/</outputDirectory>
<unpack>false</unpack>
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-dist/src/main/assembly/tez-dist.xml
----------------------------------------------------------------------
diff --git a/tez-dist/src/main/assembly/tez-dist.xml b/tez-dist/src/main/assembly/tez-dist.xml
index a181546..b8834a8 100644
--- a/tez-dist/src/main/assembly/tez-dist.xml
+++ b/tez-dist/src/main/assembly/tez-dist.xml
@@ -22,6 +22,9 @@
<moduleSets>
<moduleSet>
<useAllReactorProjects>true</useAllReactorProjects>
+ <excludes>
+ <exclude>org.apache.tez:tez-aux-services</exclude>
+ </excludes>
<binaries>
<outputDirectory>/</outputDirectory>
<unpack>false</unpack>
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/pom.xml
----------------------------------------------------------------------
diff --git a/tez-plugins/pom.xml b/tez-plugins/pom.xml
index 27707a8..ffe59b9 100644
--- a/tez-plugins/pom.xml
+++ b/tez-plugins/pom.xml
@@ -48,6 +48,7 @@
<module>tez-yarn-timeline-history</module>
<module>tez-yarn-timeline-history-with-acls</module>
<module>tez-history-parser</module>
+ <module>tez-aux-services</module>
</modules>
</profile>
<profile>
@@ -61,6 +62,7 @@
<module>tez-yarn-timeline-cache-plugin</module>
<module>tez-yarn-timeline-history-with-fs</module>
<module>tez-history-parser</module>
+ <module>tez-aux-services</module>
</modules>
</profile>
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/findbugs-exclude.xml
----------------------------------------------------------------------
diff --git a/tez-plugins/tez-aux-services/findbugs-exclude.xml b/tez-plugins/tez-aux-services/findbugs-exclude.xml
new file mode 100644
index 0000000..5b11308
--- /dev/null
+++ b/tez-plugins/tez-aux-services/findbugs-exclude.xml
@@ -0,0 +1,16 @@
+<!--
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. See accompanying LICENSE file.
+-->
+<FindBugsFilter>
+
+</FindBugsFilter>
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/pom.xml
----------------------------------------------------------------------
diff --git a/tez-plugins/tez-aux-services/pom.xml b/tez-plugins/tez-aux-services/pom.xml
new file mode 100644
index 0000000..c30555b
--- /dev/null
+++ b/tez-plugins/tez-aux-services/pom.xml
@@ -0,0 +1,108 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed under the Apache License, Version 2.0 (the "License");
+ ~ you may not use this file except in compliance with the License.
+ ~ You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <artifactId>tez-plugins</artifactId>
+ <groupId>org.apache.tez</groupId>
+ <version>0.9.0-SNAPSHOT</version>
+ </parent>
+
+ <artifactId>tez-aux-services</artifactId>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-common</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <!-- Required for the ShuffleHandler -->
+ <groupId>org.apache.tez</groupId>
+ <artifactId>tez-runtime-library</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-server-common</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-core</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-shuffle</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.tez</groupId>
+ <artifactId>tez-mapreduce</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.mortbay.jetty</groupId>
+ <artifactId>jetty</artifactId>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <!--
+ Include all files in src/main/resources. By default, do not apply property
+ substitution (filtering=false), but do apply property substitution to
+ version-info.properties (filtering=true). This will substitute the
+ version information correctly, but prevent Maven from altering other files.
+ -->
+ <resources>
+ <resource>
+ <directory>${basedir}/src/main/resources</directory>
+ <excludes>
+ <exclude>tez-api-version-info.properties</exclude>
+ </excludes>
+ <filtering>false</filtering>
+ </resource>
+ <resource>
+ <directory>${basedir}/src/main/resources</directory>
+ <includes>
+ <include>tez-api-version-info.properties</include>
+ </includes>
+ <filtering>true</filtering>
+ </resource>
+ </resources>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.rat</groupId>
+ <artifactId>apache-rat-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+
+</project>
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/IndexCache.java
----------------------------------------------------------------------
diff --git a/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/IndexCache.java b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/IndexCache.java
new file mode 100644
index 0000000..532187e
--- /dev/null
+++ b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/IndexCache.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.auxservices;
+
+import java.io.IOException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.tez.runtime.library.common.Constants;
+import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord;
+import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+class IndexCache {
+
+ private final Configuration conf;
+ private final int totalMemoryAllowed;
+ private AtomicInteger totalMemoryUsed = new AtomicInteger();
+ private static final Logger LOG = LoggerFactory.getLogger(IndexCache.class);
+
+ private final ConcurrentHashMap<String,IndexInformation> cache =
+ new ConcurrentHashMap<String,IndexInformation>();
+
+ private final LinkedBlockingQueue<String> queue =
+ new LinkedBlockingQueue<String>();
+
+ public IndexCache(Configuration conf) {
+ this.conf = conf;
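+ // cache budget fixed at 10 MB here (not read from conf)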
+ totalMemoryAllowed = 10 * 1024 * 1024;
+ LOG.info("IndexCache created with max memory = " + totalMemoryAllowed);
+ }
+
+ /**
+ * This method gets the index information for the given mapId and reduce.
+ * It reads the index file into cache if it is not already present.
+ * @param mapId The map attempt whose index is requested
+ * @param reduce The reduce partition to look up
+ * @param fileName The file to read the index information from if it is not
+ * already present in the cache
+ * @param expectedIndexOwner The expected owner of the index file
+ * @return The Index Information
+ * @throws IOException
+ */
+ public TezIndexRecord getIndexInformation(String mapId, int reduce,
+ Path fileName, String expectedIndexOwner)
+ throws IOException {
+
+ IndexInformation info = cache.get(mapId);
+
+ if (info == null) {
+ info = readIndexFileToCache(fileName, mapId, expectedIndexOwner);
+ } else {
+ synchronized(info) {
+ while (isUnderConstruction(info)) {
+ try {
+ info.wait();
+ } catch (InterruptedException e) {
+ throw new IOException("Interrupted waiting for construction", e);
+ }
+ }
+ }
+ LOG.debug("IndexCache HIT: MapId " + mapId + " found");
+ }
+
+ if (info.mapSpillRecord.size() == 0 ||
+ info.mapSpillRecord.size() <= reduce) {
+ throw new IOException("Invalid request " +
+ " Map Id = " + mapId + " Reducer = " + reduce +
+ " Index Info Length = " + info.mapSpillRecord.size());
+ }
+ return info.mapSpillRecord.getIndex(reduce);
+ }
+
+ private boolean isUnderConstruction(IndexInformation info) {
+ synchronized(info) {
+ return (null == info.mapSpillRecord);
+ }
+ }
+
+ private IndexInformation readIndexFileToCache(Path indexFileName,
+ String mapId,
+ String expectedIndexOwner)
+ throws IOException {
+ IndexInformation info;
+ IndexInformation newInd = new IndexInformation();
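+ // atomically claim the cache slot; if another thread claimed it first,
+ // wait below for that thread's load to finish instead of reading the
+ // file twice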
+ if ((info = cache.putIfAbsent(mapId, newInd)) != null) {
+ synchronized(info) {
+ while (isUnderConstruction(info)) {
+ try {
+ info.wait();
+ } catch (InterruptedException e) {
+ throw new IOException("Interrupted waiting for construction", e);
+ }
+ }
+ }
+ LOG.debug("IndexCache HIT: MapId " + mapId + " found");
+ return info;
+ }
+ LOG.debug("IndexCache MISS: MapId " + mapId + " not found") ;
+ TezSpillRecord tmp = null;
+ try {
+ tmp = new TezSpillRecord(indexFileName, conf, expectedIndexOwner);
+ } catch (Throwable e) {
+ tmp = new TezSpillRecord(0);
+ cache.remove(mapId);
+ throw new IOException("Error Reading IndexFile", e);
+ } finally {
+ synchronized (newInd) {
+ newInd.mapSpillRecord = tmp;
+ newInd.notifyAll();
+ }
+ }
+ queue.add(mapId);
+
+ if (totalMemoryUsed.addAndGet(newInd.getSize()) > totalMemoryAllowed) {
+ freeIndexInformation();
+ }
+ return newInd;
+ }
+
+ /**
+   * This method removes the map from the cache if the index information for
+   * this map is loaded (size > 0); the entry is not removed while it is
+   * still in the loading phase (size = 0), which prevents corruption of
+   * totalMemoryUsed. It should be called when a map output on this tracker
+   * is discarded.
+ * @param mapId The taskID of this map.
+ */
+ public void removeMap(String mapId) {
+ IndexInformation info = cache.get(mapId);
+    if (info == null || isUnderConstruction(info)) {
+ return;
+ }
+ info = cache.remove(mapId);
+ if (info != null) {
+ totalMemoryUsed.addAndGet(-info.getSize());
+ if (!queue.remove(mapId)) {
+        LOG.warn("Map ID " + mapId + " not found in queue!");
+ }
+ } else {
+ LOG.info("Map ID " + mapId + " not found in cache");
+ }
+ }
+
+ /**
+   * This method checks that the cache and totalMemoryUsed are consistent.
+   * It is only used in unit tests.
+   * @return True if the cache and totalMemoryUsed are consistent
+ */
+ boolean checkTotalMemoryUsed() {
+ int totalSize = 0;
+ for (IndexInformation info : cache.values()) {
+ totalSize += info.getSize();
+ }
+ return totalSize == totalMemoryUsed.get();
+ }
+
+ /**
+ * Bring memory usage below totalMemoryAllowed.
+ */
+ private synchronized void freeIndexInformation() {
+ while (totalMemoryUsed.get() > totalMemoryAllowed) {
+ String s = queue.remove();
+ IndexInformation info = cache.remove(s);
+ if (info != null) {
+ totalMemoryUsed.addAndGet(-info.getSize());
+ }
+ }
+ }
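+
+  // Note: eviction follows the insertion-ordered queue above, so entries
+  // are discarded oldest-first (FIFO) rather than least-recently used; a
+  // cache hit does not move an entry to the back of the queue.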
+
+ private static class IndexInformation {
+ TezSpillRecord mapSpillRecord;
+
+ int getSize() {
+ return mapSpillRecord == null
+ ? 0
+ : mapSpillRecord.size() * Constants.MAP_OUTPUT_INDEX_RECORD_LENGTH;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/tez/blob/077dd88e/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java
----------------------------------------------------------------------
diff --git a/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java
new file mode 100644
index 0000000..c8eb238
--- /dev/null
+++ b/tez-plugins/tez-aux-services/src/main/java/org/apache/tez/auxservices/ShuffleHandler.java
@@ -0,0 +1,1343 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.tez.auxservices;
+
+import static org.fusesource.leveldbjni.JniDBFactory.asString;
+import static org.fusesource.leveldbjni.JniDBFactory.bytes;
+import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer;
+import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE;
+import static org.jboss.netty.handler.codec.http.HttpMethod.GET;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK;
+import static org.jboss.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;
+import static org.jboss.netty.handler.codec.http.HttpVersion.HTTP_1_1;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.RandomAccessFile;
+import java.net.InetSocketAddress;
+import java.net.URL;
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.regex.Pattern;
+
+import javax.crypto.SecretKey;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.LocalDirAllocator;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DataInputByteBuffer;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.ReadaheadPool;
+import org.apache.hadoop.io.SecureIOUtils;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.FadvisedChunkedFile;
+import org.apache.hadoop.mapred.FadvisedFileRegion;
+import org.apache.hadoop.mapred.proto.ShuffleHandlerRecoveryProtos.JobShuffleInfoProto;
+import org.apache.hadoop.mapreduce.JobID;
+import org.apache.tez.mapreduce.hadoop.MRConfig;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.runtime.library.common.security.SecureShuffleUtils;
+import org.apache.tez.runtime.library.common.shuffle.orderedgrouped.ShuffleHeader;
+import org.apache.hadoop.metrics2.MetricsSystem;
+import org.apache.hadoop.metrics2.annotation.Metric;
+import org.apache.hadoop.metrics2.annotation.Metrics;
+import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem;
+import org.apache.hadoop.metrics2.lib.MutableCounterInt;
+import org.apache.hadoop.metrics2.lib.MutableCounterLong;
+import org.apache.hadoop.metrics2.lib.MutableGaugeInt;
+import org.apache.hadoop.security.proto.SecurityProtos.TokenProto;
+import org.apache.hadoop.security.ssl.SSLFactory;
+import org.apache.hadoop.security.token.Token;
+import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord;
+import org.apache.hadoop.util.Shell;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.proto.YarnServerCommonProtos.VersionProto;
+import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext;
+import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext;
+import org.apache.hadoop.yarn.server.api.AuxiliaryService;
+import org.apache.hadoop.yarn.server.records.Version;
+import org.apache.hadoop.yarn.server.records.impl.pb.VersionPBImpl;
+import org.apache.hadoop.yarn.server.utils.LeveldbIterator;
+import org.fusesource.leveldbjni.JniDBFactory;
+import org.fusesource.leveldbjni.internal.NativeDB;
+import org.iq80.leveldb.DB;
+import org.iq80.leveldb.DBException;
+import org.iq80.leveldb.Logger;
+import org.iq80.leveldb.Options;
+import org.jboss.netty.bootstrap.ServerBootstrap;
+import org.jboss.netty.buffer.ChannelBuffers;
+import org.jboss.netty.channel.Channel;
+import org.jboss.netty.channel.ChannelFactory;
+import org.jboss.netty.channel.ChannelFuture;
+import org.jboss.netty.channel.ChannelFutureListener;
+import org.jboss.netty.channel.ChannelHandlerContext;
+import org.jboss.netty.channel.ChannelPipeline;
+import org.jboss.netty.channel.ChannelPipelineFactory;
+import org.jboss.netty.channel.ChannelStateEvent;
+import org.jboss.netty.channel.Channels;
+import org.jboss.netty.channel.ExceptionEvent;
+import org.jboss.netty.channel.MessageEvent;
+import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
+import org.jboss.netty.channel.group.ChannelGroup;
+import org.jboss.netty.channel.group.DefaultChannelGroup;
+import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
+import org.jboss.netty.handler.codec.frame.TooLongFrameException;
+import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
+import org.jboss.netty.handler.codec.http.HttpChunkAggregator;
+import org.jboss.netty.handler.codec.http.HttpRequest;
+import org.jboss.netty.handler.codec.http.HttpRequestDecoder;
+import org.jboss.netty.handler.codec.http.HttpResponse;
+import org.jboss.netty.handler.codec.http.HttpResponseEncoder;
+import org.jboss.netty.handler.codec.http.HttpResponseStatus;
+import org.jboss.netty.handler.codec.http.QueryStringDecoder;
+import org.jboss.netty.handler.ssl.SslHandler;
+import org.jboss.netty.handler.stream.ChunkedWriteHandler;
+import org.jboss.netty.util.CharsetUtil;
+import org.mortbay.jetty.HttpHeaders;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Charsets;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.common.cache.RemovalListener;
+import com.google.common.cache.RemovalNotification;
+import com.google.common.cache.Weigher;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import com.google.protobuf.ByteString;
+
+public class ShuffleHandler extends AuxiliaryService {
+
+ private static final Log LOG = LogFactory.getLog(ShuffleHandler.class);
+ private static final Log AUDITLOG =
+ LogFactory.getLog(ShuffleHandler.class.getName()+".audit");
+ public static final String SHUFFLE_MANAGE_OS_CACHE = "mapreduce.shuffle.manage.os.cache";
+ public static final boolean DEFAULT_SHUFFLE_MANAGE_OS_CACHE = true;
+
+ public static final String SHUFFLE_READAHEAD_BYTES = "mapreduce.shuffle.readahead.bytes";
+ public static final int DEFAULT_SHUFFLE_READAHEAD_BYTES = 4 * 1024 * 1024;
+ public static final String USERCACHE = "usercache";
+ public static final String APPCACHE = "appcache";
+
+ // pattern to identify errors related to the client closing the socket early
+ // idea borrowed from Netty SslHandler
+ private static final Pattern IGNORABLE_ERROR_MESSAGE = Pattern.compile(
+ "^.*(?:connection.*reset|connection.*closed|broken.*pipe).*$",
+ Pattern.CASE_INSENSITIVE);
+
+ private static final String STATE_DB_NAME = "mapreduce_shuffle_state";
+ private static final String STATE_DB_SCHEMA_VERSION_KEY = "shuffle-schema-version";
+ protected static final Version CURRENT_VERSION_INFO =
+ Version.newInstance(1, 0);
+
+ private static final String DATA_FILE_NAME = "file.out";
+ private static final String INDEX_FILE_NAME = "file.out.index";
+
+ private int port;
+ private ChannelFactory selector;
+ private final ChannelGroup accepted = new DefaultChannelGroup();
+ protected HttpPipelineFactory pipelineFact;
+ private int sslFileBufferSize;
+
+ /**
+ * Should the shuffle use posix_fadvise calls to manage the OS cache during
+   * sendfile.
+ */
+ private boolean manageOsCache;
+ private int readaheadLength;
+ private int maxShuffleConnections;
+ private int shuffleBufferSize;
+ private boolean shuffleTransferToAllowed;
+ private int maxSessionOpenFiles;
+ private ReadaheadPool readaheadPool = ReadaheadPool.getInstance();
+
+ private Map<String,String> userRsrc;
+ private JobTokenSecretManager secretManager;
+
+ private DB stateDb = null;
+
+ public static final String MAPREDUCE_SHUFFLE_SERVICEID =
+ "mapreduce_shuffle";
+
+ public static final String SHUFFLE_PORT_CONFIG_KEY = "mapreduce.shuffle.port";
+ public static final int DEFAULT_SHUFFLE_PORT = 13562;
+
+ public static final String SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED =
+ "mapreduce.shuffle.connection-keep-alive.enable";
+ public static final boolean DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED = false;
+
+ public static final String SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT =
+ "mapreduce.shuffle.connection-keep-alive.timeout";
+ public static final int DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT = 5; //seconds
+
+ public static final String SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE =
+ "mapreduce.shuffle.mapoutput-info.meta.cache.size";
+ public static final int DEFAULT_SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE =
+ 1000;
+
+ public static final String CONNECTION_CLOSE = "close";
+
+ public static final String SUFFLE_SSL_FILE_BUFFER_SIZE_KEY =
+ "mapreduce.shuffle.ssl.file.buffer.size";
+
+ public static final int DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE = 60 * 1024;
+
+ public static final String MAX_SHUFFLE_CONNECTIONS = "mapreduce.shuffle.max.connections";
+ public static final int DEFAULT_MAX_SHUFFLE_CONNECTIONS = 0; // 0 implies no limit
+
+ public static final String MAX_SHUFFLE_THREADS = "mapreduce.shuffle.max.threads";
+ // 0 implies Netty default of 2 * number of available processors
+ public static final int DEFAULT_MAX_SHUFFLE_THREADS = 0;
+
+ public static final String SHUFFLE_BUFFER_SIZE =
+ "mapreduce.shuffle.transfer.buffer.size";
+ public static final int DEFAULT_SHUFFLE_BUFFER_SIZE = 128 * 1024;
+
+ public static final String SHUFFLE_TRANSFERTO_ALLOWED =
+ "mapreduce.shuffle.transferTo.allowed";
+ public static final boolean DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED = true;
+ public static final boolean WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED =
+ false;
+
+  /**
+   * The maximum number of files a single GET request can open
+   * simultaneously during shuffle.
+   */
+ public static final String SHUFFLE_MAX_SESSION_OPEN_FILES =
+ "mapreduce.shuffle.max.session-open-files";
+ public static final int DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES = 3;
+
+ boolean connectionKeepAliveEnabled = false;
+ int connectionKeepAliveTimeOut;
+ int mapOutputMetaInfoCacheSize;
+
+ @Metrics(about="Shuffle output metrics", context="mapred")
+ static class ShuffleMetrics implements ChannelFutureListener {
+ @Metric("Shuffle output in bytes")
+ MutableCounterLong shuffleOutputBytes;
+ @Metric("# of failed shuffle outputs")
+ MutableCounterInt shuffleOutputsFailed;
+ @Metric("# of succeeeded shuffle outputs")
+ MutableCounterInt shuffleOutputsOK;
+ @Metric("# of current shuffle connections")
+ MutableGaugeInt shuffleConnections;
+
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ shuffleOutputsOK.incr();
+ } else {
+ shuffleOutputsFailed.incr();
+ }
+ shuffleConnections.decr();
+ }
+ }
+
+ final ShuffleMetrics metrics;
+
+ class ReduceMapFileCount implements ChannelFutureListener {
+
+ private ReduceContext reduceContext;
+
+ public ReduceMapFileCount(ReduceContext rc) {
+ this.reduceContext = rc;
+ }
+
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (!future.isSuccess()) {
+ future.getChannel().close();
+ return;
+ }
+ int waitCount = this.reduceContext.getMapsToWait().decrementAndGet();
+ if (waitCount == 0) {
+ metrics.operationComplete(future);
+ future.getChannel().close();
+ } else {
+ pipelineFact.getSHUFFLE().sendMap(reduceContext);
+ }
+ }
+ }
+
+ /**
+   * Maintains the parameters of a single messageReceived() Netty context so
+   * that sendMapOutput() can be invoked from operationComplete() callbacks.
+ */
+ private static class ReduceContext {
+
+ private List<String> mapIds;
+ private AtomicInteger mapsToWait;
+ private AtomicInteger mapsToSend;
+ private int reduceId;
+ private ChannelHandlerContext ctx;
+ private String user;
+ private Map<String, Shuffle.MapOutputInfo> infoMap;
+ private String jobId;
+
+ public ReduceContext(List<String> mapIds, int rId,
+ ChannelHandlerContext context, String usr,
+ Map<String, Shuffle.MapOutputInfo> mapOutputInfoMap,
+ String jobId) {
+
+ this.mapIds = mapIds;
+ this.reduceId = rId;
+ /**
+ * Atomic count for tracking the no. of map outputs that are yet to
+ * complete. Multiple futureListeners' operationComplete() can decrement
+ * this value asynchronously. It is used to decide when the channel should
+ * be closed.
+ */
+ this.mapsToWait = new AtomicInteger(mapIds.size());
+ /**
+ * Atomic count for tracking the no. of map outputs that have been sent.
+ * Multiple sendMap() calls can increment this value
+ * asynchronously. Used to decide which mapId should be sent next.
+ */
+ this.mapsToSend = new AtomicInteger(0);
+ this.ctx = context;
+ this.user = usr;
+ this.infoMap = mapOutputInfoMap;
+ this.jobId = jobId;
+ }
+
+ public int getReduceId() {
+ return reduceId;
+ }
+
+ public ChannelHandlerContext getCtx() {
+ return ctx;
+ }
+
+ public String getUser() {
+ return user;
+ }
+
+ public Map<String, Shuffle.MapOutputInfo> getInfoMap() {
+ return infoMap;
+ }
+
+ public String getJobId() {
+ return jobId;
+ }
+
+ public List<String> getMapIds() {
+ return mapIds;
+ }
+
+ public AtomicInteger getMapsToSend() {
+ return mapsToSend;
+ }
+
+ public AtomicInteger getMapsToWait() {
+ return mapsToWait;
+ }
+ }
+
+ ShuffleHandler(MetricsSystem ms) {
+ super(MAPREDUCE_SHUFFLE_SERVICEID);
+ metrics = ms.register(new ShuffleMetrics());
+ }
+
+ public ShuffleHandler() {
+ this(DefaultMetricsSystem.instance());
+ }
+
+ /**
+ * Serialize the shuffle port into a ByteBuffer for use later on.
+   * @param port the port to be sent to the ApplicationMaster
+ * @return the serialized form of the port.
+ */
+ public static ByteBuffer serializeMetaData(int port) throws IOException {
+ //TODO these bytes should be versioned
+ DataOutputBuffer port_dob = new DataOutputBuffer();
+ port_dob.writeInt(port);
+ return ByteBuffer.wrap(port_dob.getData(), 0, port_dob.getLength());
+ }
+
+ /**
+ * A helper function to deserialize the metadata returned by ShuffleHandler.
+ * @param meta the metadata returned by the ShuffleHandler
+ * @return the port the Shuffle Handler is listening on to serve shuffle data.
+ */
+ public static int deserializeMetaData(ByteBuffer meta) throws IOException {
+ //TODO this should be returning a class not just an int
+ DataInputByteBuffer in = new DataInputByteBuffer();
+ in.reset(meta);
+ int port = in.readInt();
+ return port;
+ }
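+
+  // Illustrative sketch (not part of the patch): the two helpers above form
+  // a round trip, where the NM serializes the port it bound as service
+  // metadata and the AM side reads it back. The default port is used here
+  // for illustration only.
+  static void exampleMetaDataRoundTrip() throws IOException {
+    ByteBuffer meta = serializeMetaData(DEFAULT_SHUFFLE_PORT);
+    assert deserializeMetaData(meta) == DEFAULT_SHUFFLE_PORT;
+  }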
+
+ /**
+ * A helper function to serialize the JobTokenIdentifier to be sent to the
+ * ShuffleHandler as ServiceData.
+ * @param jobToken the job token to be used for authentication of
+ * shuffle data requests.
+ * @return the serialized version of the jobToken.
+ */
+ public static ByteBuffer serializeServiceData(Token<JobTokenIdentifier> jobToken) throws IOException {
+ //TODO these bytes should be versioned
+ DataOutputBuffer jobToken_dob = new DataOutputBuffer();
+ jobToken.write(jobToken_dob);
+ return ByteBuffer.wrap(jobToken_dob.getData(), 0, jobToken_dob.getLength());
+ }
+
+ static Token<JobTokenIdentifier> deserializeServiceData(ByteBuffer secret) throws IOException {
+ DataInputByteBuffer in = new DataInputByteBuffer();
+ in.reset(secret);
+ Token<JobTokenIdentifier> jt = new Token<JobTokenIdentifier>();
+ jt.readFields(in);
+ return jt;
+ }
+
+ @Override
+ public void initializeApplication(ApplicationInitializationContext context) {
+
+ String user = context.getUser();
+ ApplicationId appId = context.getApplicationId();
+ ByteBuffer secret = context.getApplicationDataForService();
+ // TODO these bytes should be versioned
+ try {
+ Token<JobTokenIdentifier> jt = deserializeServiceData(secret);
+      // TODO: Once Shuffle is out of the NM, this can use MR APIs
+ JobID jobId = new JobID(Long.toString(appId.getClusterTimestamp()), appId.getId());
+ recordJobShuffleInfo(jobId, user, jt);
+ } catch (IOException e) {
+ LOG.error("Error during initApp", e);
+ // TODO add API to AuxiliaryServices to report failures
+ }
+ }
+
+ @Override
+ public void stopApplication(ApplicationTerminationContext context) {
+ ApplicationId appId = context.getApplicationId();
+ JobID jobId = new JobID(Long.toString(appId.getClusterTimestamp()), appId.getId());
+ try {
+ removeJobShuffleInfo(jobId);
+ } catch (IOException e) {
+ LOG.error("Error during stopApp", e);
+ // TODO add API to AuxiliaryServices to report failures
+ }
+ }
+
+ @Override
+ protected void serviceInit(Configuration conf) throws Exception {
+ manageOsCache = conf.getBoolean(SHUFFLE_MANAGE_OS_CACHE,
+ DEFAULT_SHUFFLE_MANAGE_OS_CACHE);
+
+ readaheadLength = conf.getInt(SHUFFLE_READAHEAD_BYTES,
+ DEFAULT_SHUFFLE_READAHEAD_BYTES);
+
+ maxShuffleConnections = conf.getInt(MAX_SHUFFLE_CONNECTIONS,
+ DEFAULT_MAX_SHUFFLE_CONNECTIONS);
+ int maxShuffleThreads = conf.getInt(MAX_SHUFFLE_THREADS,
+ DEFAULT_MAX_SHUFFLE_THREADS);
+ if (maxShuffleThreads == 0) {
+ maxShuffleThreads = 2 * Runtime.getRuntime().availableProcessors();
+ }
+
+ shuffleBufferSize = conf.getInt(SHUFFLE_BUFFER_SIZE,
+ DEFAULT_SHUFFLE_BUFFER_SIZE);
+
+    shuffleTransferToAllowed = conf.getBoolean(SHUFFLE_TRANSFERTO_ALLOWED,
+        Shell.WINDOWS ? WINDOWS_DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED
+                      : DEFAULT_SHUFFLE_TRANSFERTO_ALLOWED);
+
+ maxSessionOpenFiles = conf.getInt(SHUFFLE_MAX_SESSION_OPEN_FILES,
+ DEFAULT_SHUFFLE_MAX_SESSION_OPEN_FILES);
+
+ ThreadFactory bossFactory = new ThreadFactoryBuilder()
+ .setNameFormat("ShuffleHandler Netty Boss #%d")
+ .build();
+ ThreadFactory workerFactory = new ThreadFactoryBuilder()
+ .setNameFormat("ShuffleHandler Netty Worker #%d")
+ .build();
+
+ selector = new NioServerSocketChannelFactory(
+ Executors.newCachedThreadPool(bossFactory),
+ Executors.newCachedThreadPool(workerFactory),
+ maxShuffleThreads);
+ super.serviceInit(new YarnConfiguration(conf));
+ }
+
+ // TODO change AbstractService to throw InterruptedException
+ @Override
+ protected void serviceStart() throws Exception {
+ Configuration conf = getConfig();
+ userRsrc = new ConcurrentHashMap<String,String>();
+ secretManager = new JobTokenSecretManager();
+ recoverState(conf);
+ ServerBootstrap bootstrap = new ServerBootstrap(selector);
+ try {
+ pipelineFact = new HttpPipelineFactory(conf);
+ } catch (Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ bootstrap.setOption("child.keepAlive", true);
+ bootstrap.setPipelineFactory(pipelineFact);
+ port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT);
+ Channel ch = bootstrap.bind(new InetSocketAddress(port));
+ accepted.add(ch);
+ port = ((InetSocketAddress)ch.getLocalAddress()).getPort();
+ conf.set(SHUFFLE_PORT_CONFIG_KEY, Integer.toString(port));
+ pipelineFact.SHUFFLE.setPort(port);
+ LOG.info(getName() + " listening on port " + port);
+ super.serviceStart();
+
+ sslFileBufferSize = conf.getInt(SUFFLE_SSL_FILE_BUFFER_SIZE_KEY,
+ DEFAULT_SUFFLE_SSL_FILE_BUFFER_SIZE);
+ connectionKeepAliveEnabled =
+ conf.getBoolean(SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED,
+ DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_ENABLED);
+ connectionKeepAliveTimeOut =
+ Math.max(1, conf.getInt(SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT,
+ DEFAULT_SHUFFLE_CONNECTION_KEEP_ALIVE_TIME_OUT));
+ mapOutputMetaInfoCacheSize =
+ Math.max(1, conf.getInt(SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE,
+ DEFAULT_SHUFFLE_MAPOUTPUT_META_INFO_CACHE_SIZE));
+ }
+
+ @Override
+ protected void serviceStop() throws Exception {
+ accepted.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ if (selector != null) {
+ ServerBootstrap bootstrap = new ServerBootstrap(selector);
+ bootstrap.releaseExternalResources();
+ }
+ if (pipelineFact != null) {
+ pipelineFact.destroy();
+ }
+ if (stateDb != null) {
+ stateDb.close();
+ }
+ super.serviceStop();
+ }
+
+ @Override
+ public synchronized ByteBuffer getMetaData() {
+ try {
+ return serializeMetaData(port);
+ } catch (IOException e) {
+ LOG.error("Error during getMeta", e);
+ // TODO add API to AuxiliaryServices to report failures
+ return null;
+ }
+ }
+
+ protected Shuffle getShuffle(Configuration conf) {
+ return new Shuffle(conf);
+ }
+
+ private void recoverState(Configuration conf) throws IOException {
+ Path recoveryRoot = getRecoveryPath();
+ if (recoveryRoot != null) {
+ startStore(recoveryRoot);
+ Pattern jobPattern = Pattern.compile(JobID.JOBID_REGEX);
+ LeveldbIterator iter = null;
+ try {
+ iter = new LeveldbIterator(stateDb);
+ iter.seek(bytes(JobID.JOB));
+ while (iter.hasNext()) {
+ Map.Entry<byte[],byte[]> entry = iter.next();
+ String key = asString(entry.getKey());
+ if (!jobPattern.matcher(key).matches()) {
+ break;
+ }
+ recoverJobShuffleInfo(key, entry.getValue());
+ }
+ } catch (DBException e) {
+ throw new IOException("Database error during recovery", e);
+ } finally {
+ if (iter != null) {
+ iter.close();
+ }
+ }
+ }
+ }
+
+ private void startStore(Path recoveryRoot) throws IOException {
+ Options options = new Options();
+ options.createIfMissing(false);
+ options.logger(new LevelDBLogger());
+ Path dbPath = new Path(recoveryRoot, STATE_DB_NAME);
+ LOG.info("Using state database at " + dbPath + " for recovery");
+ File dbfile = new File(dbPath.toString());
+ try {
+ stateDb = JniDBFactory.factory.open(dbfile, options);
+ } catch (NativeDB.DBException e) {
+ if (e.isNotFound() || e.getMessage().contains(" does not exist ")) {
+ LOG.info("Creating state database at " + dbfile);
+ options.createIfMissing(true);
+ try {
+ stateDb = JniDBFactory.factory.open(dbfile, options);
+ storeVersion();
+ } catch (DBException dbExc) {
+ throw new IOException("Unable to create state store", dbExc);
+ }
+ } else {
+ throw e;
+ }
+ }
+ checkVersion();
+ }
+
+ @VisibleForTesting
+ Version loadVersion() throws IOException {
+ byte[] data = stateDb.get(bytes(STATE_DB_SCHEMA_VERSION_KEY));
+ // if version is not stored previously, treat it as CURRENT_VERSION_INFO.
+ if (data == null || data.length == 0) {
+ return getCurrentVersion();
+ }
+ Version version =
+ new VersionPBImpl(VersionProto.parseFrom(data));
+ return version;
+ }
+
+ private void storeSchemaVersion(Version version) throws IOException {
+ String key = STATE_DB_SCHEMA_VERSION_KEY;
+ byte[] data =
+ ((VersionPBImpl) version).getProto().toByteArray();
+ try {
+ stateDb.put(bytes(key), data);
+ } catch (DBException e) {
+ throw new IOException(e.getMessage(), e);
+ }
+ }
+
+ private void storeVersion() throws IOException {
+ storeSchemaVersion(CURRENT_VERSION_INFO);
+ }
+
+ // Only used for test
+ @VisibleForTesting
+ void storeVersion(Version version) throws IOException {
+ storeSchemaVersion(version);
+ }
+
+ protected Version getCurrentVersion() {
+ return CURRENT_VERSION_INFO;
+ }
+
+ /**
+   * 1) Versioning scheme: major.minor, e.g. 1.0, 1.1, 1.2 ... 1.25, 2.0 etc.
+   * 2) Any incompatible change to the DB schema is a major upgrade, and any
+   *    compatible change is a minor upgrade.
+   * 3) Within a minor upgrade, say 1.1 to 1.2:
+   *    overwrite the version info and proceed as normal.
+   * 4) Within a major upgrade, say 1.2 to 2.0:
+   *    throw an exception and direct the user to a separate upgrade tool to
+   *    upgrade the shuffle info or remove the incompatible old state.
+ */
+ private void checkVersion() throws IOException {
+ Version loadedVersion = loadVersion();
+ LOG.info("Loaded state DB schema version info " + loadedVersion);
+ if (loadedVersion.equals(getCurrentVersion())) {
+ return;
+ }
+ if (loadedVersion.isCompatibleTo(getCurrentVersion())) {
+ LOG.info("Storing state DB schedma version info " + getCurrentVersion());
+ storeVersion();
+ } else {
+ throw new IOException(
+ "Incompatible version for state DB schema: expecting DB schema version "
+ + getCurrentVersion() + ", but loading version " + loadedVersion);
+ }
+ }
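+
+  // Illustrative sketch (not part of the patch) of the rules above, using
+  // hypothetical versions: a minor bump is compatible and is simply
+  // overwritten, while a major bump aborts recovery.
+  static void exampleVersionCheck() {
+    Version stored = Version.newInstance(1, 1);
+    assert stored.isCompatibleTo(Version.newInstance(1, 2));  // minor: proceed
+    assert !stored.isCompatibleTo(Version.newInstance(2, 0)); // major: throw
+  }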
+
+ private void addJobToken(JobID jobId, String user,
+ Token<JobTokenIdentifier> jobToken) {
+ userRsrc.put(jobId.toString(), user);
+ secretManager.addTokenForJob(jobId.toString(), jobToken);
+ LOG.info("Added token for " + jobId.toString());
+ }
+
+ private void recoverJobShuffleInfo(String jobIdStr, byte[] data)
+ throws IOException {
+ JobID jobId;
+ try {
+ jobId = JobID.forName(jobIdStr);
+ } catch (IllegalArgumentException e) {
+ throw new IOException("Bad job ID " + jobIdStr + " in state store", e);
+ }
+
+ JobShuffleInfoProto proto = JobShuffleInfoProto.parseFrom(data);
+ String user = proto.getUser();
+ TokenProto tokenProto = proto.getJobToken();
+ Token<JobTokenIdentifier> jobToken = new Token<JobTokenIdentifier>(
+ tokenProto.getIdentifier().toByteArray(),
+ tokenProto.getPassword().toByteArray(),
+ new Text(tokenProto.getKind()), new Text(tokenProto.getService()));
+ addJobToken(jobId, user, jobToken);
+ }
+
+ private void recordJobShuffleInfo(JobID jobId, String user,
+ Token<JobTokenIdentifier> jobToken) throws IOException {
+ if (stateDb != null) {
+ TokenProto tokenProto = TokenProto.newBuilder()
+ .setIdentifier(ByteString.copyFrom(jobToken.getIdentifier()))
+ .setPassword(ByteString.copyFrom(jobToken.getPassword()))
+ .setKind(jobToken.getKind().toString())
+ .setService(jobToken.getService().toString())
+ .build();
+ JobShuffleInfoProto proto = JobShuffleInfoProto.newBuilder()
+ .setUser(user).setJobToken(tokenProto).build();
+ try {
+ stateDb.put(bytes(jobId.toString()), proto.toByteArray());
+ } catch (DBException e) {
+ throw new IOException("Error storing " + jobId, e);
+ }
+ }
+ addJobToken(jobId, user, jobToken);
+ }
+
+ private void removeJobShuffleInfo(JobID jobId) throws IOException {
+ String jobIdStr = jobId.toString();
+ secretManager.removeTokenForJob(jobIdStr);
+ userRsrc.remove(jobIdStr);
+ if (stateDb != null) {
+ try {
+ stateDb.delete(bytes(jobIdStr));
+ } catch (DBException e) {
+ throw new IOException("Unable to remove " + jobId
+ + " from state store", e);
+ }
+ }
+ }
+
+ private static class LevelDBLogger implements Logger {
+ private static final Log LOG = LogFactory.getLog(LevelDBLogger.class);
+
+ @Override
+ public void log(String message) {
+ LOG.info(message);
+ }
+ }
+
+ class HttpPipelineFactory implements ChannelPipelineFactory {
+
+ final Shuffle SHUFFLE;
+ private SSLFactory sslFactory;
+
+ public HttpPipelineFactory(Configuration conf) throws Exception {
+ SHUFFLE = getShuffle(conf);
+ if (conf.getBoolean(MRConfig.SHUFFLE_SSL_ENABLED_KEY,
+ MRConfig.SHUFFLE_SSL_ENABLED_DEFAULT)) {
+ LOG.info("Encrypted shuffle is enabled.");
+ sslFactory = new SSLFactory(SSLFactory.Mode.SERVER, conf);
+ sslFactory.init();
+ }
+ }
+
+ public Shuffle getSHUFFLE() {
+ return SHUFFLE;
+ }
+
+ public void destroy() {
+ if (sslFactory != null) {
+ sslFactory.destroy();
+ }
+ }
+
+ @Override
+ public ChannelPipeline getPipeline() throws Exception {
+ ChannelPipeline pipeline = Channels.pipeline();
+ if (sslFactory != null) {
+ pipeline.addLast("ssl", new SslHandler(sslFactory.createSSLEngine()));
+ }
+ pipeline.addLast("decoder", new HttpRequestDecoder());
+ pipeline.addLast("aggregator", new HttpChunkAggregator(1 << 16));
+ pipeline.addLast("encoder", new HttpResponseEncoder());
+ pipeline.addLast("chunking", new ChunkedWriteHandler());
+ pipeline.addLast("shuffle", SHUFFLE);
+ return pipeline;
+ // TODO factor security manager into pipeline
+ // TODO factor out encode/decode to permit binary shuffle
+ // TODO factor out decode of index to permit alt. models
+ }
+
+ }
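+
+  // Note: handler order in the pipeline above matters. TLS, when enabled,
+  // must wrap everything; HTTP decoding/encoding sits in the middle; and
+  // the Shuffle handler is added last so it receives fully decoded requests
+  // and its file-region writes pass back down through the chunking and
+  // encoding handlers.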
+
+ class Shuffle extends SimpleChannelUpstreamHandler {
+
+ private static final int MAX_WEIGHT = 10 * 1024 * 1024;
+ private static final int EXPIRE_AFTER_ACCESS_MINUTES = 5;
+ private static final int ALLOWED_CONCURRENCY = 16;
+ private final Configuration conf;
+ private final IndexCache indexCache;
+ private final LocalDirAllocator lDirAlloc =
+ new LocalDirAllocator(YarnConfiguration.NM_LOCAL_DIRS);
+ private int port;
+ private final LoadingCache<AttemptPathIdentifier, AttemptPathInfo> pathCache =
+ CacheBuilder.newBuilder().expireAfterAccess(EXPIRE_AFTER_ACCESS_MINUTES,
+ TimeUnit.MINUTES).softValues().concurrencyLevel(ALLOWED_CONCURRENCY).
+ removalListener(
+ new RemovalListener<AttemptPathIdentifier, AttemptPathInfo>() {
+ @Override
+ public void onRemoval(RemovalNotification<AttemptPathIdentifier,
+ AttemptPathInfo> notification) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("PathCache Eviction: " + notification.getKey() +
+ ", Reason=" + notification.getCause());
+ }
+ }
+ }
+ ).maximumWeight(MAX_WEIGHT).weigher(
+ new Weigher<AttemptPathIdentifier, AttemptPathInfo>() {
+ @Override
+ public int weigh(AttemptPathIdentifier key,
+ AttemptPathInfo value) {
+ return key.jobId.length() + key.user.length() +
+ key.attemptId.length()+
+ value.indexPath.toString().length() +
+ value.dataPath.toString().length();
+ }
+ }
+ ).build(new CacheLoader<AttemptPathIdentifier, AttemptPathInfo>() {
+ @Override
+ public AttemptPathInfo load(AttemptPathIdentifier key) throws
+ Exception {
+ String base = getBaseLocation(key.jobId, key.user);
+ String attemptBase = base + key.attemptId;
+ Path indexFileName = lDirAlloc.getLocalPathToRead(
+ attemptBase + "/" + INDEX_FILE_NAME, conf);
+ Path mapOutputFileName = lDirAlloc.getLocalPathToRead(
+ attemptBase + "/" + DATA_FILE_NAME, conf);
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Loaded : " + key + " via loader");
+ }
+ return new AttemptPathInfo(indexFileName, mapOutputFileName);
+ }
+ });
+
+ public Shuffle(Configuration conf) {
+ this.conf = conf;
+ indexCache = new IndexCache(conf);
+ this.port = conf.getInt(SHUFFLE_PORT_CONFIG_KEY, DEFAULT_SHUFFLE_PORT);
+ }
+
+ public void setPort(int port) {
+ this.port = port;
+ }
+
+ private List<String> splitMaps(List<String> mapq) {
+ if (null == mapq) {
+ return null;
+ }
+ final List<String> ret = new ArrayList<String>();
+ for (String s : mapq) {
+ Collections.addAll(ret, s.split(","));
+ }
+ return ret;
+ }
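+
+    // Note: e.g. splitMaps(Arrays.asList("m1,m2", "m3")) yields
+    // [m1, m2, m3]; the fetcher may pack several map ids into a single
+    // "map" query parameter.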
+
+ @Override
+ public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent evt)
+ throws Exception {
+
+ if ((maxShuffleConnections > 0) && (accepted.size() >= maxShuffleConnections)) {
+ LOG.info(String.format("Current number of shuffle connections (%d) is " +
+ "greater than or equal to the max allowed shuffle connections (%d)",
+ accepted.size(), maxShuffleConnections));
+ evt.getChannel().close();
+ return;
+ }
+ accepted.add(evt.getChannel());
+ super.channelOpen(ctx, evt);
+ }
+
+ @Override
+ public void messageReceived(ChannelHandlerContext ctx, MessageEvent evt)
+ throws Exception {
+ HttpRequest request = (HttpRequest) evt.getMessage();
+ if (request.getMethod() != GET) {
+ sendError(ctx, METHOD_NOT_ALLOWED);
+ return;
+ }
+ // Check whether the shuffle version is compatible
+ if (!ShuffleHeader.DEFAULT_HTTP_HEADER_NAME.equals(
+ request.getHeader(ShuffleHeader.HTTP_HEADER_NAME))
+ || !ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION.equals(
+ request.getHeader(ShuffleHeader.HTTP_HEADER_VERSION))) {
+ sendError(ctx, "Incompatible shuffle request version", BAD_REQUEST);
+ }
+ final Map<String,List<String>> q =
+ new QueryStringDecoder(request.getUri()).getParameters();
+ final List<String> keepAliveList = q.get("keepAlive");
+ boolean keepAliveParam = false;
+ if (keepAliveList != null && keepAliveList.size() == 1) {
+ keepAliveParam = Boolean.valueOf(keepAliveList.get(0));
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("KeepAliveParam : " + keepAliveList
+ + " : " + keepAliveParam);
+ }
+ }
+ final List<String> mapIds = splitMaps(q.get("map"));
+ final List<String> reduceQ = q.get("reduce");
+ final List<String> jobQ = q.get("job");
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("RECV: " + request.getUri() +
+ "\n mapId: " + mapIds +
+ "\n reduceId: " + reduceQ +
+ "\n jobId: " + jobQ +
+ "\n keepAlive: " + keepAliveParam);
+ }
+
+ if (mapIds == null || reduceQ == null || jobQ == null) {
+ sendError(ctx, "Required param job, map and reduce", BAD_REQUEST);
+ return;
+ }
+ if (reduceQ.size() != 1 || jobQ.size() != 1) {
+ sendError(ctx, "Too many job/reduce parameters", BAD_REQUEST);
+ return;
+ }
+
+      // This audit log is disabled by default. To enable it, uncomment the
+      // corresponding setting in log4j.properties.
+ if (AUDITLOG.isDebugEnabled()) {
+ AUDITLOG.debug("shuffle for " + jobQ.get(0) +
+ " reducer " + reduceQ.get(0));
+ }
+ int reduceId;
+ String jobId;
+ try {
+ reduceId = Integer.parseInt(reduceQ.get(0));
+ jobId = jobQ.get(0);
+ } catch (NumberFormatException e) {
+ sendError(ctx, "Bad reduce parameter", BAD_REQUEST);
+ return;
+ } catch (IllegalArgumentException e) {
+ sendError(ctx, "Bad job parameter", BAD_REQUEST);
+ return;
+ }
+ final String reqUri = request.getUri();
+ if (null == reqUri) {
+ // TODO? add upstream?
+ sendError(ctx, FORBIDDEN);
+ return;
+ }
+ HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK);
+ try {
+ verifyRequest(jobId, ctx, request, response,
+ new URL("http", "", this.port, reqUri));
+ } catch (IOException e) {
+ LOG.warn("Shuffle failure ", e);
+ sendError(ctx, e.getMessage(), UNAUTHORIZED);
+ return;
+ }
+
+ Map<String, MapOutputInfo> mapOutputInfoMap =
+ new HashMap<String, MapOutputInfo>();
+ Channel ch = evt.getChannel();
+ String user = userRsrc.get(jobId);
+
+ try {
+ populateHeaders(mapIds, jobId, user, reduceId, request,
+ response, keepAliveParam, mapOutputInfoMap);
+      } catch (IOException e) {
+        ch.write(response);
+        LOG.error("Shuffle error in populating headers:", e);
+        String errorMessage = getErrorMessage(e);
+        sendError(ctx, errorMessage, INTERNAL_SERVER_ERROR);
+ return;
+ }
+ ch.write(response);
+      // Initialize one ReduceContext object per messageReceived call
+ ReduceContext reduceContext = new ReduceContext(mapIds, reduceId, ctx,
+ user, mapOutputInfoMap, jobId);
+ for (int i = 0; i < Math.min(maxSessionOpenFiles, mapIds.size()); i++) {
+ ChannelFuture nextMap = sendMap(reduceContext);
+        if (nextMap == null) {
+ return;
+ }
+ }
+ }
+
+ /**
+     * Calls sendMapOutput for the mapId pointed to by ReduceContext.mapsToSend
+     * and increments it. This method is first called by messageReceived()
+     * maxSessionOpenFiles times and then on the completion of every
+     * sendMapOutput operation. This limits the number of open files on a node,
+     * which can get very large (exhausting file descriptors on the NM) if all
+     * sendMapOutputs are issued in one go, as was done before this change.
+     * @param reduceContext used to call sendMapOutput with the correct params.
+     * @return the ChannelFuture of the sendMapOutput; can be null.
+ */
+ public ChannelFuture sendMap(ReduceContext reduceContext)
+ throws Exception {
+
+ ChannelFuture nextMap = null;
+ if (reduceContext.getMapsToSend().get() <
+ reduceContext.getMapIds().size()) {
+ int nextIndex = reduceContext.getMapsToSend().getAndIncrement();
+ String mapId = reduceContext.getMapIds().get(nextIndex);
+
+ try {
+ MapOutputInfo info = reduceContext.getInfoMap().get(mapId);
+ if (info == null) {
+ info = getMapOutputInfo(mapId, reduceContext.getReduceId(),
+ reduceContext.getJobId(), reduceContext.getUser());
+ }
+ nextMap = sendMapOutput(
+ reduceContext.getCtx(),
+ reduceContext.getCtx().getChannel(),
+ reduceContext.getUser(), mapId,
+ reduceContext.getReduceId(), info);
+ if (null == nextMap) {
+ sendError(reduceContext.getCtx(), NOT_FOUND);
+ return null;
+ }
+ nextMap.addListener(new ReduceMapFileCount(reduceContext));
+ } catch (IOException e) {
+ LOG.error("Shuffle error :", e);
+ String errorMessage = getErrorMessage(e);
+ sendError(reduceContext.getCtx(), errorMessage,
+ INTERNAL_SERVER_ERROR);
+ return null;
+ }
+ }
+ return nextMap;
+ }
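+
+    // Note: a worked example of the throttling above, assuming
+    // maxSessionOpenFiles=3 and a request for 10 maps: messageReceived()
+    // primes the window with three sendMap() calls; each completed transfer
+    // (ReduceMapFileCount.operationComplete) sends one more map, and the
+    // channel is closed once mapsToWait drops to zero.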
+
+ private String getErrorMessage(Throwable t) {
+      // Guard against throwables whose message is null while walking the
+      // cause chain.
+      StringBuilder sb = new StringBuilder(String.valueOf(t.getMessage()));
+      while (t.getCause() != null) {
+        t = t.getCause();
+        sb.append(String.valueOf(t.getMessage()));
+      }
+ return sb.toString();
+ }
+
+ private String getBaseLocation(String jobId, String user) {
+ final JobID jobID = JobID.forName(jobId);
+ final ApplicationId appID =
+ ApplicationId.newInstance(Long.parseLong(jobID.getJtIdentifier()),
+ jobID.getId());
+ final String baseStr =
+ USERCACHE + "/" + user + "/"
+ + APPCACHE + "/"
+ + appID.toString() + "/output" + "/";
+ return baseStr;
+ }
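+
+    // Note: for user "alice" and job_1466526238000_0001 (ids are
+    // illustrative only), the base above resolves to
+    // usercache/alice/appcache/application_1466526238000_0001/output/
+    // relative to each NM local dir.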
+
+ protected MapOutputInfo getMapOutputInfo(String mapId, int reduce,
+ String jobId, String user) throws IOException {
+ AttemptPathInfo pathInfo;
+ try {
+ AttemptPathIdentifier identifier = new AttemptPathIdentifier(
+ jobId, user, mapId);
+ pathInfo = pathCache.get(identifier);
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Retrieved pathInfo for " + identifier +
+ " check for corresponding loaded messages to determine whether" +
+ " it was loaded or cached");
+ }
+ } catch (ExecutionException e) {
+ if (e.getCause() instanceof IOException) {
+ throw (IOException) e.getCause();
+ } else {
+ throw new RuntimeException(e.getCause());
+ }
+ }
+
+ TezIndexRecord info =
+ indexCache.getIndexInformation(mapId, reduce, pathInfo.indexPath, user);
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("getMapOutputInfo: jobId=" + jobId + ", mapId=" + mapId +
+ ",dataFile=" + pathInfo.dataPath + ", indexFile=" +
+ pathInfo.indexPath);
+ }
+
+ MapOutputInfo outputInfo = new MapOutputInfo(pathInfo.dataPath, info);
+ return outputInfo;
+ }
+
+ protected void populateHeaders(List<String> mapIds, String jobId,
+ String user, int reduce, HttpRequest request, HttpResponse response,
+ boolean keepAliveParam, Map<String, MapOutputInfo> mapOutputInfoMap)
+ throws IOException {
+
+ long contentLength = 0;
+ for (String mapId : mapIds) {
+ MapOutputInfo outputInfo = getMapOutputInfo(mapId, reduce, jobId, user);
+ if (mapOutputInfoMap.size() < mapOutputMetaInfoCacheSize) {
+ mapOutputInfoMap.put(mapId, outputInfo);
+ }
+
+ ShuffleHeader header =
+ new ShuffleHeader(mapId, outputInfo.indexRecord.getPartLength(),
+ outputInfo.indexRecord.getRawLength(), reduce);
+ DataOutputBuffer dob = new DataOutputBuffer();
+ header.write(dob);
+
+ contentLength += outputInfo.indexRecord.getPartLength();
+ contentLength += dob.getLength();
+ }
+
+ // Now set the response headers.
+ setResponseHeaders(response, keepAliveParam, contentLength);
+ }
+
+ protected void setResponseHeaders(HttpResponse response,
+ boolean keepAliveParam, long contentLength) {
+ if (!connectionKeepAliveEnabled && !keepAliveParam) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Setting connection close header...");
+ }
+ response.setHeader(HttpHeaders.CONNECTION, CONNECTION_CLOSE);
+ } else {
+ response.setHeader(HttpHeaders.CONTENT_LENGTH,
+ String.valueOf(contentLength));
+ response.setHeader(HttpHeaders.CONNECTION, HttpHeaders.KEEP_ALIVE);
+ response.setHeader(HttpHeaders.KEEP_ALIVE, "timeout="
+ + connectionKeepAliveTimeOut);
+ LOG.info("Content Length in shuffle : " + contentLength);
+ }
+ }
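+
+    // Note: with keep-alive in effect the response carries, e.g.
+    //   Content-Length: <shuffle headers + partition bytes>
+    //   Connection: keep-alive
+    //   Keep-Alive: timeout=5
+    // otherwise only "Connection: close" is set.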
+
+ class MapOutputInfo {
+ final Path mapOutputFileName;
+ final TezIndexRecord indexRecord;
+
+ MapOutputInfo(Path mapOutputFileName, TezIndexRecord indexRecord) {
+ this.mapOutputFileName = mapOutputFileName;
+ this.indexRecord = indexRecord;
+ }
+ }
+
+ protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+ HttpRequest request, HttpResponse response, URL requestUri)
+ throws IOException {
+ SecretKey tokenSecret = secretManager.retrieveTokenSecret(appid);
+ if (null == tokenSecret) {
+ LOG.info("Request for unknown token " + appid);
+ throw new IOException("could not find jobid");
+ }
+      // message to verify, rebuilt from the request URI
+ String enc_str = SecureShuffleUtils.buildMsgFrom(requestUri);
+ // hash from the fetcher
+ String urlHashStr =
+ request.getHeader(SecureShuffleUtils.HTTP_HEADER_URL_HASH);
+ if (urlHashStr == null) {
+ LOG.info("Missing header hash for " + appid);
+ throw new IOException("fetcher cannot be authenticated");
+ }
+ if (LOG.isDebugEnabled()) {
+ int len = urlHashStr.length();
+ LOG.debug("verifying request. enc_str=" + enc_str + "; hash=..." +
+ urlHashStr.substring(len-len/2, len-1));
+ }
+ // verify - throws exception
+ SecureShuffleUtils.verifyReply(urlHashStr, enc_str, tokenSecret);
+ // verification passed - encode the reply
+ String reply =
+ SecureShuffleUtils.generateHash(urlHashStr.getBytes(Charsets.UTF_8),
+ tokenSecret);
+ response.setHeader(SecureShuffleUtils.HTTP_HEADER_REPLY_URL_HASH, reply);
+ // Put shuffle version into http header
+ response.setHeader(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ if (LOG.isDebugEnabled()) {
+ int len = reply.length();
+ LOG.debug("Fetcher request verfied. enc_str=" + enc_str + ";reply=" +
+ reply.substring(len-len/2, len-1));
+ }
+ }
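+
+    // Note: the handshake above in brief. The fetcher sends
+    // UrlHash = HMAC(job token secret, request URI); the server recomputes
+    // and verifies it, then answers with
+    // ReplyHash = HMAC(job token secret, UrlHash) so the fetcher can
+    // authenticate the server in turn.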
+
+ protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx, Channel ch,
+ String user, String mapId, int reduce, MapOutputInfo mapOutputInfo)
+ throws IOException {
+ final TezIndexRecord info = mapOutputInfo.indexRecord;
+ final ShuffleHeader header =
+ new ShuffleHeader(mapId, info.getPartLength(), info.getRawLength(), reduce);
+ final DataOutputBuffer dob = new DataOutputBuffer();
+ header.write(dob);
+ ch.write(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+ final File spillfile =
+ new File(mapOutputInfo.mapOutputFileName.toString());
+ RandomAccessFile spill;
+ try {
+ spill = SecureIOUtils.openForRandomRead(spillfile, "r", user, null);
+ } catch (FileNotFoundException e) {
+ LOG.info(spillfile + " not found");
+ return null;
+ }
+ ChannelFuture writeFuture;
+ if (ch.getPipeline().get(SslHandler.class) == null) {
+ final FadvisedFileRegion partition = new FadvisedFileRegion(spill,
+ info.getStartOffset(), info.getPartLength(), manageOsCache, readaheadLength,
+ readaheadPool, spillfile.getAbsolutePath(),
+ shuffleBufferSize, shuffleTransferToAllowed);
+ writeFuture = ch.write(partition);
+ writeFuture.addListener(new ChannelFutureListener() {
+ // TODO error handling; distinguish IO/connection failures,
+ // attribute to appropriate spill output
+ @Override
+ public void operationComplete(ChannelFuture future) {
+ if (future.isSuccess()) {
+ partition.transferSuccessful();
+ }
+ partition.releaseExternalResources();
+ }
+ });
+ } else {
+ // HTTPS cannot be done with zero copy.
+ final FadvisedChunkedFile chunk = new FadvisedChunkedFile(spill,
+ info.getStartOffset(), info.getPartLength(), sslFileBufferSize,
+ manageOsCache, readaheadLength, readaheadPool,
+ spillfile.getAbsolutePath());
+ writeFuture = ch.write(chunk);
+ }
+ metrics.shuffleConnections.incr();
+ metrics.shuffleOutputBytes.incr(info.getPartLength()); // optimistic
+ return writeFuture;
+ }
+
+ protected void sendError(ChannelHandlerContext ctx,
+ HttpResponseStatus status) {
+ sendError(ctx, "", status);
+ }
+
+ protected void sendError(ChannelHandlerContext ctx, String message,
+ HttpResponseStatus status) {
+ HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status);
+ response.setHeader(CONTENT_TYPE, "text/plain; charset=UTF-8");
+ // Put shuffle version into http header
+ response.setHeader(ShuffleHeader.HTTP_HEADER_NAME,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+ response.setHeader(ShuffleHeader.HTTP_HEADER_VERSION,
+ ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+ response.setContent(
+ ChannelBuffers.copiedBuffer(message, CharsetUtil.UTF_8));
+
+ // Close the connection as soon as the error message is sent.
+ ctx.getChannel().write(response).addListener(ChannelFutureListener.CLOSE);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
+ throws Exception {
+ Channel ch = e.getChannel();
+ Throwable cause = e.getCause();
+ if (cause instanceof TooLongFrameException) {
+ sendError(ctx, BAD_REQUEST);
+ return;
+ } else if (cause instanceof IOException) {
+ if (cause instanceof ClosedChannelException) {
+ LOG.debug("Ignoring closed channel error", cause);
+ return;
+ }
+ String message = String.valueOf(cause.getMessage());
+ if (IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) {
+ LOG.debug("Ignoring client socket close", cause);
+ return;
+ }
+ }
+
+ LOG.error("Shuffle error: ", cause);
+ if (ch.isConnected()) {
+ LOG.error("Shuffle error " + e);
+ sendError(ctx, INTERNAL_SERVER_ERROR);
+ }
+ }
+ }
+
+ static class AttemptPathInfo {
+ // TODO Change this over to just store local dir indices, instead of the
+ // entire path. Far more efficient.
+ private final Path indexPath;
+ private final Path dataPath;
+
+ public AttemptPathInfo(Path indexPath, Path dataPath) {
+ this.indexPath = indexPath;
+ this.dataPath = dataPath;
+ }
+ }
+
+ static class AttemptPathIdentifier {
+ private final String jobId;
+ private final String user;
+ private final String attemptId;
+
+ public AttemptPathIdentifier(String jobId, String user, String attemptId) {
+ this.jobId = jobId;
+ this.user = user;
+ this.attemptId = attemptId;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ AttemptPathIdentifier that = (AttemptPathIdentifier) o;
+
+ if (!attemptId.equals(that.attemptId)) {
+ return false;
+ }
+ if (!jobId.equals(that.jobId)) {
+ return false;
+ }
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = jobId.hashCode();
+ result = 31 * result + attemptId.hashCode();
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ return "AttemptPathIdentifier{" +
+ "attemptId='" + attemptId + '\'' +
+ ", jobId='" + jobId + '\'' +
+ '}';
+ }
+ }
+}