You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by xu...@apache.org on 2015/01/30 04:27:28 UTC

svn commit: r1655926 - in /hive/branches/spark: common/src/java/org/apache/hadoop/hive/conf/ spark-client/src/main/java/org/apache/hive/spark/client/ spark-client/src/main/java/org/apache/hive/spark/client/rpc/ spark-client/src/test/java/org/apache/hive/spark/client/rpc/

Author: xuefu
Date: Fri Jan 30 03:27:28 2015
New Revision: 1655926

URL: http://svn.apache.org/r1655926
Log:
HIVE-9487: Make Remote Spark Context secure [Spark Branch] (Marcelo via Xuefu)

Added:
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java
Modified:
    hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java
    hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
    hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java
    hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java

Modified: hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java (original)
+++ hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java Fri Jan 30 03:27:28 2015
@@ -2018,7 +2018,9 @@ public class HiveConf extends Configurat
     SPARK_RPC_MAX_MESSAGE_SIZE("hive.spark.client.rpc.max.size", 50 * 1024 * 1024,
       "Maximum message size in bytes for communication between Hive client and remote Spark driver. Default is 50MB."),
     SPARK_RPC_CHANNEL_LOG_LEVEL("hive.spark.client.channel.log.level", null,
-      "Channel logging level for remote Spark driver.  One of {DEBUG, ERROR, INFO, TRACE, WARN}.");
+      "Channel logging level for remote Spark driver.  One of {DEBUG, ERROR, INFO, TRACE, WARN}."),
+    SPARK_RPC_SASL_MECHANISM("hive.spark.client.rpc.sasl.mechanisms", "DIGEST-MD5",
+      "Name of the SASL mechanism to use for authentication.");
 
     public final String varname;
     private final String defaultExpr;

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java Fri Jan 30 03:27:28 2015
@@ -106,6 +106,8 @@ public class RemoteDriver {
         serverAddress = getArg(args, idx);
       } else if (key.equals("--remote-port")) {
         serverPort = Integer.parseInt(getArg(args, idx));
+      } else if (key.equals("--client-id")) {
+        conf.set(SparkClientFactory.CONF_CLIENT_ID, getArg(args, idx));
       } else if (key.equals("--secret")) {
         conf.set(SparkClientFactory.CONF_KEY_SECRET, getArg(args, idx));
       } else if (key.equals("--conf")) {
@@ -127,6 +129,8 @@ public class RemoteDriver {
       LOG.debug("Remote Driver configured with: " + e._1() + "=" + e._2());
     }
 
+    String clientId = mapConf.get(SparkClientFactory.CONF_CLIENT_ID);
+    Preconditions.checkArgument(clientId != null, "No client ID provided.");
     String secret = mapConf.get(SparkClientFactory.CONF_KEY_SECRET);
     Preconditions.checkArgument(secret != null, "No secret provided.");
 
@@ -140,8 +144,8 @@ public class RemoteDriver {
     this.protocol = new DriverProtocol();
 
     // The RPC library takes care of timing out this.
-    this.clientRpc = Rpc.createClient(mapConf, egroup, serverAddress, serverPort, secret, protocol)
-      .get();
+    this.clientRpc = Rpc.createClient(mapConf, egroup, serverAddress, serverPort,
+      clientId, secret, protocol).get();
     this.running = true;
 
     this.clientRpc.addListener(new Rpc.Listener() {

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java Fri Jan 30 03:27:28 2015
@@ -37,6 +37,9 @@ public final class SparkClientFactory {
   /** Used to run the driver in-process, mostly for testing. */
   static final String CONF_KEY_IN_PROCESS = "spark.client.do_not_use.run_driver_in_process";
 
+  /** Used by client and driver to share a client ID for establishing an RPC session. */
+  static final String CONF_CLIENT_ID = "spark.client.authentication.client_id";
+
   /** Used by client and driver to share a secret for establishing an RPC session. */
   static final String CONF_KEY_SECRET = "spark.client.authentication.secret";
 

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java Fri Jan 30 03:27:28 2015
@@ -84,13 +84,14 @@ class SparkClientImpl implements SparkCl
     this.childIdGenerator = new AtomicInteger();
     this.jobs = Maps.newConcurrentMap();
 
+    String clientId = UUID.randomUUID().toString();
     String secret = rpcServer.createSecret();
-    this.driverThread = startDriver(rpcServer, secret);
+    this.driverThread = startDriver(rpcServer, clientId, secret);
     this.protocol = new ClientProtocol();
 
     try {
       // The RPC server will take care of timeouts here.
-      this.driverRpc = rpcServer.registerClient(secret, protocol).get();
+      this.driverRpc = rpcServer.registerClient(clientId, secret, protocol).get();
     } catch (Exception e) {
       LOG.warn("Error while waiting for client to connect.", e);
       driverThread.interrupt();
@@ -174,7 +175,8 @@ class SparkClientImpl implements SparkCl
     protocol.cancel(jobId);
   }
 
-  private Thread startDriver(RpcServer rpcServer, final String secret) throws IOException {
+  private Thread startDriver(RpcServer rpcServer, final String clientId, final String secret)
+      throws IOException {
     Runnable runnable;
     final String serverAddress = rpcServer.getAddress();
     final String serverPort = String.valueOf(rpcServer.getPort());
@@ -190,6 +192,8 @@ class SparkClientImpl implements SparkCl
           args.add(serverAddress);
           args.add("--remote-port");
           args.add(serverPort);
+          args.add("--client-id");
+          args.add(clientId);
           args.add("--secret");
           args.add(secret);
 
@@ -243,6 +247,7 @@ class SparkClientImpl implements SparkCl
       for (Map.Entry<String, String> e : conf.entrySet()) {
         allProps.put(e.getKey(), conf.get(e.getKey()));
       }
+      allProps.put(SparkClientFactory.CONF_CLIENT_ID, clientId);
       allProps.put(SparkClientFactory.CONF_KEY_SECRET, secret);
       allProps.put(DRIVER_OPTS_KEY, driverJavaOpts);
       allProps.put(EXECUTOR_OPTS_KEY, executorJavaOpts);

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java Fri Jan 30 03:27:28 2015
@@ -18,6 +18,7 @@
 package org.apache.hive.spark.client.rpc;
 
 import java.io.ByteArrayOutputStream;
+import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.List;
@@ -63,9 +64,12 @@ class KryoMessageCodec extends ByteToMes
     }
   };
 
+  private volatile EncryptionHandler encryptionHandler;
+
   public KryoMessageCodec(int maxMessageSize, Class<?>... messages) {
     this.maxMessageSize = maxMessageSize;
     this.messages = Arrays.asList(messages);
+    this.encryptionHandler = null;
   }
 
   @Override
@@ -86,7 +90,7 @@ class KryoMessageCodec extends ByteToMes
     }
 
     try {
-      ByteBuffer nioBuffer = in.nioBuffer(in.readerIndex(), msgSize);
+      ByteBuffer nioBuffer = maybeDecrypt(in.nioBuffer(in.readerIndex(), msgSize));
       Input kryoIn = new Input(new ByteBufferInputStream(nioBuffer));
 
       Object msg = kryos.get().readClassAndObject(kryoIn);
@@ -106,7 +110,7 @@ class KryoMessageCodec extends ByteToMes
     kryos.get().writeClassAndObject(kryoOut, msg);
     kryoOut.flush();
 
-    byte[] msgData = bytes.toByteArray();
+    byte[] msgData = maybeEncrypt(bytes.toByteArray());
     LOG.debug("Encoded message of type {} ({} bytes)", msg.getClass().getName(), msgData.length);
     checkSize(msgData.length);
 
@@ -115,10 +119,56 @@ class KryoMessageCodec extends ByteToMes
     buf.writeBytes(msgData);
   }
 
+  @Override
+  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+    if (encryptionHandler != null) {
+      encryptionHandler.dispose();
+    }
+    super.channelInactive(ctx);
+  }
+
   private void checkSize(int msgSize) {
     Preconditions.checkArgument(msgSize > 0, "Message size (%s bytes) must be positive.", msgSize);
     Preconditions.checkArgument(maxMessageSize <= 0 || msgSize <= maxMessageSize,
         "Message (%s bytes) exceeds maximum allowed size (%s bytes).", msgSize, maxMessageSize);
   }
 
+  private byte[] maybeEncrypt(byte[] data) throws Exception {
+    return (encryptionHandler != null) ? encryptionHandler.wrap(data, 0, data.length) : data;
+  }
+
+  private ByteBuffer maybeDecrypt(ByteBuffer data) throws Exception {
+    if (encryptionHandler != null) {
+      byte[] encrypted;
+      int len = data.limit() - data.position();
+      int offset;
+      if (data.hasArray()) {
+        encrypted = data.array();
+        offset = data.position() + data.arrayOffset();
+        data.position(data.limit());
+      } else {
+        encrypted = new byte[len];
+        offset = 0;
+        data.get(encrypted);
+      }
+      return ByteBuffer.wrap(encryptionHandler.unwrap(encrypted, offset, len));
+    } else {
+      return data;
+    }
+  }
+
+  void setEncryptionHandler(EncryptionHandler handler) {
+    this.encryptionHandler = handler;
+  }
+
+  interface EncryptionHandler {
+
+    byte[] wrap(byte[] data, int offset, int len) throws IOException;
+
+    byte[] unwrap(byte[] data, int offset, int len) throws IOException;
+
+    void dispose() throws IOException;
+
+  }
+
 }

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md Fri Jan 30 03:27:28 2015
@@ -6,7 +6,7 @@ Basic flow of events:
 - Client side creates an RPC server
 - Client side spawns RemoteDriver, which manages the SparkContext, and provides a secret
 - Client side sets up a timer to wait for RemoteDriver to connect back
-- RemoteDriver connects back to client, sends Hello message with secret
+- RemoteDriver connects back to client, SASL handshake ensues
 - Connection is established and now there's a session between the client and the driver.
 
 Features of the RPC layer:

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java Fri Jan 30 03:27:28 2015
@@ -17,6 +17,29 @@
 
 package org.apache.hive.spark.client.rpc;
 
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
 import io.netty.bootstrap.Bootstrap;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
@@ -28,32 +51,18 @@ import io.netty.channel.embedded.Embedde
 import io.netty.channel.nio.NioEventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.channel.socket.nio.NioSocketChannel;
-import io.netty.handler.logging.LogLevel;
 import io.netty.handler.logging.LoggingHandler;
+import io.netty.handler.logging.LogLevel;
 import io.netty.util.concurrent.EventExecutorGroup;
 import io.netty.util.concurrent.Future;
 import io.netty.util.concurrent.GenericFutureListener;
 import io.netty.util.concurrent.ImmediateEventExecutor;
 import io.netty.util.concurrent.Promise;
 import io.netty.util.concurrent.ScheduledFuture;
-
-import java.io.Closeable;
-import java.util.Collection;
-import java.util.Map;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.concurrent.atomic.AtomicReference;
-
-import org.apache.hadoop.hive.common.classification.InterfaceAudience;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
-import com.google.common.base.Throwables;
-import com.google.common.collect.Lists;
+import org.apache.hadoop.hive.common.classification.InterfaceAudience;
 
 /**
  * Encapsulates the RPC functionality. Provides higher-level methods to talk to the remote
@@ -62,9 +71,13 @@ import com.google.common.collect.Lists;
 @InterfaceAudience.Private
 public class Rpc implements Closeable {
 
-  private static final String DISPATCHER_HANDLER_NAME = "dispatcher";
   private static final Logger LOG = LoggerFactory.getLogger(Rpc.class);
 
+  static final String SASL_REALM = "rsc";
+  static final String SASL_USER = "rsc";
+  static final String SASL_PROTOCOL = "rsc";
+  static final String SASL_AUTH_CONF = "auth-conf";
+
   /**
    * Creates an RPC client for a server running on the given remote host and port.
    *
@@ -72,7 +85,8 @@ public class Rpc implements Closeable {
    * @param eloop Event loop for managing the connection.
    * @param host Host name or IP address to connect to.
    * @param port Port where server is listening.
-   * @param secret Secret for identifying the client with the server.
+   * @param clientId The client ID that identifies the connection.
+   * @param secret Secret for authenticating the client with the server.
    * @param dispatcher Dispatcher used to handle RPC calls.
    * @return A future that can be used to monitor the creation of the RPC object.
    */
@@ -81,6 +95,7 @@ public class Rpc implements Closeable {
       final NioEventLoopGroup eloop,
       String host,
       int port,
+      final String clientId,
       final String secret,
       final RpcDispatcher dispatcher) throws Exception {
     final RpcConfiguration rpcConf = new RpcConfiguration(config);
@@ -107,28 +122,17 @@ public class Rpc implements Closeable {
     final ScheduledFuture<?> timeoutFuture = eloop.schedule(timeoutTask,
         rpcConf.getServerConnectTimeoutMs(), TimeUnit.MILLISECONDS);
 
-    // The channel listener instantiates the Rpc instance when the connection is established, and
-    // sends the "Hello" message to complete the handshake.
+    // The channel listener instantiates the Rpc instance when the connection is established,
+    // and initiates the SASL handshake.
     cf.addListener(new ChannelFutureListener() {
       @Override
       public void operationComplete(ChannelFuture cf) throws Exception {
         if (cf.isSuccess()) {
-          rpc.set(createRpc(rpcConf, (SocketChannel) cf.channel(), dispatcher, eloop));
-          // The RPC listener waits for confirmation from the server that the "Hello" was good.
-          // Once it's finished, the Rpc object is provided to the caller by completing the
-          // promise.
-          Future<Void> hello = rpc.get().call(new Rpc.Hello(secret));
-          hello.addListener(new GenericFutureListener<Future<Void>>() {
-            @Override
-            public void operationComplete(Future<Void> p) {
-              timeoutFuture.cancel(true);
-              if (p.isSuccess()) {
-                promise.setSuccess(rpc.get());
-              } else {
-                promise.setFailure(p.cause());
-              }
-            }
-          });
+          SaslClientHandler saslHandler = new SaslClientHandler(rpcConf, clientId, promise,
+            timeoutFuture, secret, dispatcher);
+          Rpc rpc = createRpc(rpcConf, saslHandler, (SocketChannel) cf.channel(), eloop);
+          saslHandler.rpc = rpc;
+          saslHandler.sendHello(cf.channel());
         } else {
           promise.setFailure(cf.cause());
         }
@@ -148,16 +152,16 @@ public class Rpc implements Closeable {
     return promise;
   }
 
-  /**
-   * Creates an RPC handler for a connected socket channel.
-   *
-   * @param config RpcConfiguration object.
-   * @param client Socket channel connected to the RPC remote end.
-   * @param dispatcher Dispatcher used to handle RPC calls.
-   * @param egroup Event executor for handling futures.
-   */
-  static Rpc createRpc(RpcConfiguration config, SocketChannel client,
-      RpcDispatcher dispatcher, EventExecutorGroup egroup) {
+  static Rpc createServer(SaslHandler saslHandler, RpcConfiguration config, SocketChannel channel,
+      EventExecutorGroup egroup) throws IOException {
+    return createRpc(config, saslHandler, channel, egroup);
+  }
+
+  private static Rpc createRpc(RpcConfiguration config,
+      SaslHandler saslHandler,
+      SocketChannel client,
+      EventExecutorGroup egroup)
+      throws IOException {
     LogLevel logLevel = LogLevel.TRACE;
     if (config.getRpcChannelLogLevel() != null) {
       try {
@@ -187,15 +191,16 @@ public class Rpc implements Closeable {
     }
 
     if (logEnabled) {
-      client.pipeline()
-          .addLast("logger", new LoggingHandler(Rpc.class, logLevel));
+      client.pipeline().addLast("logger", new LoggingHandler(Rpc.class, logLevel));
     }
 
+    KryoMessageCodec kryo = new KryoMessageCodec(config.getMaxMessageSize(),
+        MessageHeader.class, NullMessage.class, SaslMessage.class);
+    saslHandler.setKryoMessageCodec(kryo);
     client.pipeline()
-        .addLast("codec", new KryoMessageCodec(config.getMaxMessageSize(),
-            MessageHeader.class, NullMessage.class))
-        .addLast(DISPATCHER_HANDLER_NAME, dispatcher);
-    return new Rpc(client, dispatcher, egroup);
+        .addLast("codec", kryo)
+        .addLast("sasl", saslHandler);
+    return new Rpc(config, client, egroup);
   }
 
   @VisibleForTesting
@@ -204,25 +209,28 @@ public class Rpc implements Closeable {
         new LoggingHandler(Rpc.class),
         new KryoMessageCodec(0, MessageHeader.class, NullMessage.class),
         dispatcher);
-    return new Rpc(c, dispatcher, ImmediateEventExecutor.INSTANCE);
+    Rpc rpc = new Rpc(new RpcConfiguration(Collections.<String, String>emptyMap()),
+      c, ImmediateEventExecutor.INSTANCE);
+    rpc.dispatcher = dispatcher;
+    return rpc;
   }
 
+  private final RpcConfiguration config;
   private final AtomicBoolean rpcClosed;
   private final AtomicLong rpcId;
-  private final AtomicReference<RpcDispatcher> dispatcher;
   private final Channel channel;
   private final Collection<Listener> listeners;
   private final EventExecutorGroup egroup;
   private final Object channelLock;
+  private volatile RpcDispatcher dispatcher;
 
-  @SuppressWarnings({ "rawtypes", "unchecked" })
-  private Rpc(Channel channel, RpcDispatcher dispatcher, EventExecutorGroup egroup) {
+  private Rpc(RpcConfiguration config, Channel channel, EventExecutorGroup egroup) {
     Preconditions.checkArgument(channel != null);
-    Preconditions.checkArgument(dispatcher != null);
     Preconditions.checkArgument(egroup != null);
+    this.config = config;
     this.channel = channel;
     this.channelLock = new Object();
-    this.dispatcher = new AtomicReference(dispatcher);
+    this.dispatcher = null;
     this.egroup = egroup;
     this.listeners = Lists.newLinkedList();
     this.rpcClosed = new AtomicBoolean();
@@ -271,13 +279,13 @@ public class Rpc implements Closeable {
             if (!cf.isSuccess() && !promise.isDone()) {
               LOG.warn("Failed to send RPC, closing connection.", cf.cause());
               promise.setFailure(cf.cause());
-              dispatcher.get().discardRpc(id);
+              dispatcher.discardRpc(id);
               close();
             }
           }
       };
 
-      dispatcher.get().registerRpc(id, promise, msg.getClass().getName());
+      dispatcher.registerRpc(id, promise, msg.getClass().getName());
       synchronized (channelLock) {
         channel.write(new MessageHeader(id, Rpc.MessageType.CALL)).addListener(listener);
         channel.writeAndFlush(msg).addListener(listener);
@@ -300,14 +308,11 @@ public class Rpc implements Closeable {
     return channel;
   }
 
-  /**
-   * This is only used by RpcServer after the handshake is successful. It shouldn't be called in
-   * any other situation. It particularly will not work for embedded channels used for testing.
-   */
-  void replaceDispatcher(RpcDispatcher newDispatcher) {
-    channel.pipeline().remove(DISPATCHER_HANDLER_NAME);
-    channel.pipeline().addLast(DISPATCHER_HANDLER_NAME, newDispatcher);
-    dispatcher.set(newDispatcher);
+  void setDispatcher(RpcDispatcher dispatcher) {
+    Preconditions.checkNotNull(dispatcher);
+    Preconditions.checkState(this.dispatcher == null);
+    this.dispatcher = dispatcher;
+    channel.pipeline().addLast("dispatcher", dispatcher);
   }
 
   @Override
@@ -359,20 +364,129 @@ public class Rpc implements Closeable {
 
   }
 
-  static class Hello {
-    final String secret;
+  static class NullMessage {
+
+  }
+
+  static class SaslMessage {
+    final String clientId;
+    final byte[] payload;
 
-    Hello() {
-      this(null);
+    SaslMessage() {
+      this(null, null);
     }
 
-    Hello(String secret) {
-      this.secret = secret;
+    SaslMessage(byte[] payload) {
+      this(null, payload);
+    }
+
+    SaslMessage(String clientId, byte[] payload) {
+      this.clientId = clientId;
+      this.payload = payload;
     }
 
   }
 
-  static class NullMessage {
+  private static class SaslClientHandler extends SaslHandler implements CallbackHandler {
+
+    private final SaslClient client;
+    private final String clientId;
+    private final String secret;
+    private final RpcDispatcher dispatcher;
+    private Promise<Rpc> promise;
+    private ScheduledFuture<?> timeout;
+
+    // Can't be set in constructor due to circular dependency.
+    private Rpc rpc;
+
+    SaslClientHandler(
+        RpcConfiguration config,
+        String clientId,
+        Promise<Rpc> promise,
+        ScheduledFuture<?> timeout,
+        String secret,
+        RpcDispatcher dispatcher)
+        throws IOException {
+      super(config);
+      this.clientId = clientId;
+      this.promise = promise;
+      this.timeout = timeout;
+      this.secret = secret;
+      this.dispatcher = dispatcher;
+      this.client = Sasl.createSaslClient(new String[] { config.getSaslMechanism() },
+        null, SASL_PROTOCOL, SASL_REALM, config.getSaslOptions(), this);
+    }
+
+    @Override
+    protected boolean isComplete() {
+      return client.isComplete();
+    }
+
+    @Override
+    protected String getNegotiatedProperty(String name) {
+      return (String) client.getNegotiatedProperty(name);
+    }
+
+    @Override
+    protected SaslMessage update(SaslMessage challenge) throws IOException {
+      byte[] response = client.evaluateChallenge(challenge.payload);
+      return response != null ? new SaslMessage(response) : null;
+    }
+
+    @Override
+    public byte[] wrap(byte[] data, int offset, int len) throws IOException {
+      return client.wrap(data, offset, len);
+    }
+
+    @Override
+    public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
+      return client.unwrap(data, offset, len);
+    }
+
+    @Override
+    public void dispose() throws IOException {
+      if (!client.isComplete()) {
+        onError(new SaslException("Client closed before SASL negotiation finished."));
+      }
+      client.dispose();
+    }
+
+    @Override
+    protected void onComplete() throws Exception {
+      timeout.cancel(true);
+      rpc.setDispatcher(dispatcher);
+      promise.setSuccess(rpc);
+      timeout = null;
+      promise = null;
+    }
+
+    @Override
+    protected void onError(Throwable error) {
+      timeout.cancel(true);
+      if (!promise.isDone()) {
+        promise.setFailure(error);
+      }
+    }
+
+    @Override
+    public void handle(Callback[] callbacks) {
+      for (Callback cb : callbacks) {
+        if (cb instanceof NameCallback) {
+          ((NameCallback)cb).setName(clientId);
+        } else if (cb instanceof PasswordCallback) {
+          ((PasswordCallback)cb).setPassword(secret.toCharArray());
+        } else if (cb instanceof RealmCallback) {
+          RealmCallback rb = (RealmCallback) cb;
+          rb.setText(rb.getDefaultText());
+        }
+      }
+    }
+
+    void sendHello(Channel c) throws Exception {
+      byte[] hello = client.hasInitialResponse() ?
+        client.evaluateChallenge(new byte[0]) : new byte[0];
+      c.writeAndFlush(new SaslMessage(clientId, hello));
+    }
 
   }
 

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java Fri Jan 30 03:27:28 2015
@@ -21,18 +21,19 @@ import java.io.IOException;
 import java.net.Inet4Address;
 import java.net.InetAddress;
 import java.net.NetworkInterface;
-import java.util.Arrays;
 import java.util.Enumeration;
-import java.util.List;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import javax.security.sasl.Sasl;
 
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
-import org.apache.hadoop.hive.conf.HiveConf;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.hadoop.hive.common.classification.InterfaceAudience;
+import org.apache.hadoop.hive.conf.HiveConf;
 
 /**
  * Definitions of configuration keys and default values for the RPC layer.
@@ -57,6 +58,9 @@ public final class RpcConfiguration {
 
   public static final String SERVER_LISTEN_ADDRESS_KEY = "hive.spark.client.server.address";
 
+  /** Prefix for other SASL options. */
+  public static final String RPC_SASL_OPT_PREFIX = "hive.spark.client.rpc.sasl.";
+
   private final Map<String, String> config;
 
   private static final HiveConf DEFAULT_CONF = new HiveConf();
@@ -96,8 +100,7 @@ public final class RpcConfiguration {
     InetAddress address = InetAddress.getLocalHost();
     if (address.isLoopbackAddress()) {
       // Address resolves to something like 127.0.1.1, which happens on Debian;
-      // try to find
-      // a better address using the local network interfaces
+      // try to find a better address using the local network interfaces
       Enumeration<NetworkInterface> ifaces = NetworkInterface.getNetworkInterfaces();
       while (ifaces.hasMoreElements()) {
         NetworkInterface ni = ifaces.nextElement();
@@ -132,7 +135,6 @@ public final class RpcConfiguration {
     return value != null ? Integer.parseInt(value) : HiveConf.ConfVars.SPARK_RPC_MAX_THREADS.defaultIntVal;
   }
 
-
   /**
    * Utility method for a given RpcConfiguration key, to convert value to millisecond if it is a time value,
    * and return as string in either case.
@@ -148,4 +150,41 @@ public final class RpcConfiguration {
       return conf.get(key);
     }
   }
+
+  String getSaslMechanism() {  // Configured SASL mechanism, or the HiveConf default if unset.
+    String value = config.get(HiveConf.ConfVars.SPARK_RPC_SASL_MECHANISM.varname);
+    return value != null ? value : HiveConf.ConfVars.SPARK_RPC_SASL_MECHANISM.defaultStrVal;
+  }
+
+  /**
+   * SASL options are namespaced under "hive.spark.client.rpc.sasl.*"; each option is the
+   * lower-case name of a "javax.security.sasl.Sasl" constant (e.g. "qop" for Sasl.QOP). Returns
+   * a new map holding only the options that are set, keyed by the matching Sasl property name.
+   */
+  Map<String, String> getSaslOptions() {
+    Map<String, String> opts = new HashMap<String, String>();
+    Map<String, String> saslOpts = ImmutableMap.<String, String>builder()
+      .put(Sasl.CREDENTIALS, "credentials")
+      .put(Sasl.MAX_BUFFER, "max_buffer")
+      .put(Sasl.POLICY_FORWARD_SECRECY, "policy_forward_secrecy")
+      .put(Sasl.POLICY_NOACTIVE, "policy_noactive")
+      .put(Sasl.POLICY_NOANONYMOUS, "policy_noanonymous")
+      .put(Sasl.POLICY_NODICTIONARY, "policy_nodictionary")
+      .put(Sasl.POLICY_NOPLAINTEXT, "policy_noplaintext")
+      .put(Sasl.POLICY_PASS_CREDENTIALS, "policy_pass_credentials")
+      .put(Sasl.QOP, "qop")
+      .put(Sasl.RAW_SEND_SIZE, "raw_send_size")
+      .put(Sasl.REUSE, "reuse")
+      .put(Sasl.SERVER_AUTH, "server_auth")
+      .put(Sasl.STRENGTH, "strength")
+      .build();
+    for (Map.Entry<String, String> e : saslOpts.entrySet()) {
+      String value = config.get(RPC_SASL_OPT_PREFIX + e.getValue());
+      if (value != null) {
+        opts.put(e.getKey(), value);
+      }
+    }
+    return opts;
+  }
+
 }

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java Fri Jan 30 03:27:28 2015
@@ -17,9 +17,29 @@
 
 package org.apache.hive.spark.client.rpc;
 
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.security.SecureRandom;
+import java.util.Map;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.TimeUnit;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.channel.Channel;
-import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInitializer;
 import io.netty.channel.ChannelOption;
 import io.netty.channel.EventLoopGroup;
@@ -30,23 +50,10 @@ import io.netty.util.concurrent.Future;
 import io.netty.util.concurrent.GenericFutureListener;
 import io.netty.util.concurrent.Promise;
 import io.netty.util.concurrent.ScheduledFuture;
-
-import java.io.Closeable;
-import java.io.IOException;
-import java.net.InetSocketAddress;
-import java.security.SecureRandom;
-import java.util.Collection;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-
-import org.apache.hadoop.hive.common.classification.InterfaceAudience;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.hadoop.hive.common.classification.InterfaceAudience;
 
 /**
  * An RPC server. The server matches remote clients based on a secret that is generated on
@@ -63,11 +70,11 @@ public class RpcServer implements Closea
   private final Channel channel;
   private final EventLoopGroup group;
   private final int port;
-  private final Collection<ClientInfo> pendingClients;
+  private final ConcurrentMap<String, ClientInfo> pendingClients;
   private final RpcConfiguration config;
 
-  public RpcServer(Map<String, String> config) throws IOException, InterruptedException {
-    this.config = new RpcConfiguration(config);
+  public RpcServer(Map<String, String> mapConf) throws IOException, InterruptedException {
+    this.config = new RpcConfiguration(mapConf);
     this.group = new NioEventLoopGroup(
         this.config.getRpcThreadCount(),
         new ThreadFactoryBuilder()
@@ -80,9 +87,9 @@ public class RpcServer implements Closea
       .childHandler(new ChannelInitializer<SocketChannel>() {
           @Override
           public void initChannel(SocketChannel ch) throws Exception {
-            HelloDispatcher dispatcher = new HelloDispatcher();
-            final Rpc newRpc = Rpc.createRpc(RpcServer.this.config, ch, dispatcher, group);
-            dispatcher.rpc = newRpc;
+            SaslServerHandler saslHandler = new SaslServerHandler(config);
+            final Rpc newRpc = Rpc.createServer(saslHandler, config, ch, group);
+            saslHandler.rpc = newRpc;
 
             Runnable cancelTask = new Runnable() {
                 @Override
@@ -91,7 +98,7 @@ public class RpcServer implements Closea
                   newRpc.close();
                 }
             };
-            dispatcher.cancelTask = group.schedule(cancelTask,
+            saslHandler.cancelTask = group.schedule(cancelTask,
                 RpcServer.this.config.getServerConnectTimeoutMs(),
                 TimeUnit.MILLISECONDS);
 
@@ -104,19 +111,21 @@ public class RpcServer implements Closea
       .sync()
       .channel();
     this.port = ((InetSocketAddress) channel.localAddress()).getPort();
-    this.pendingClients = new ConcurrentLinkedQueue<ClientInfo>();
+    this.pendingClients = Maps.newConcurrentMap();
     this.address = this.config.getServerAddress();
   }
 
   /**
    * Tells the RPC server to expect a connection from a new client.
    *
+   * @param clientId An identifier for the client. Must be unique.
    * @param secret The secret the client will send to the server to identify itself.
    * @param serverDispatcher The dispatcher to use when setting up the RPC instance.
    * @return A future that can be used to wait for the client connection, which also provides the
    *         secret needed for the client to connect.
    */
-  public Future<Rpc> registerClient(String secret, RpcDispatcher serverDispatcher) {
+  public Future<Rpc> registerClient(final String clientId, String secret,
+      RpcDispatcher serverDispatcher) {
     final Promise<Rpc> promise = group.next().newPromise();
 
     Runnable timeout = new Runnable() {
@@ -128,15 +137,18 @@ public class RpcServer implements Closea
     ScheduledFuture<?> timeoutFuture = group.schedule(timeout,
         config.getServerConnectTimeoutMs(),
         TimeUnit.MILLISECONDS);
-    final ClientInfo client = new ClientInfo(promise, secret, serverDispatcher, timeoutFuture);
-    pendingClients.add(client);
-
+    final ClientInfo client = new ClientInfo(clientId, promise, secret, serverDispatcher,
+        timeoutFuture);
+    if (pendingClients.putIfAbsent(clientId, client) != null) {
+      throw new IllegalStateException(
+          String.format("Client '%s' already registered.", clientId));
+    }
 
     promise.addListener(new GenericFutureListener<Promise<Rpc>>() {
       @Override
       public void operationComplete(Promise<Rpc> p) {
         if (p.isCancelled()) {
-          pendingClients.remove(client);
+          pendingClients.remove(clientId);
         }
       }
     });
@@ -173,50 +185,121 @@ public class RpcServer implements Closea
   public void close() {
     try {
       channel.close();
-      for (Iterator<ClientInfo> clients = pendingClients.iterator(); clients.hasNext();) {
-        ClientInfo client = clients.next();
-        clients.remove();
+      for (ClientInfo client : pendingClients.values()) {
         client.promise.cancel(true);
       }
+      pendingClients.clear();
     } finally {
       group.shutdownGracefully();
     }
   }
 
-  private class HelloDispatcher extends RpcDispatcher {
+  private class SaslServerHandler extends SaslHandler implements CallbackHandler {
 
+    private final SaslServer server;
     private Rpc rpc;
     private ScheduledFuture<?> cancelTask;
+    private String clientId;
+    private ClientInfo client;
 
-    protected void handle(ChannelHandlerContext ctx, Rpc.Hello msg) {
+    SaslServerHandler(RpcConfiguration config) throws IOException {
+      super(config);
+      this.server = Sasl.createSaslServer(config.getSaslMechanism(), Rpc.SASL_PROTOCOL,
+        Rpc.SASL_REALM, config.getSaslOptions(), this);
+    }
+
+    @Override
+    protected boolean isComplete() {
+      return server.isComplete();
+    }
+
+    @Override
+    protected String getNegotiatedProperty(String name) {
+      return (String) server.getNegotiatedProperty(name);
+    }
+
+    @Override
+    protected Rpc.SaslMessage update(Rpc.SaslMessage challenge) throws IOException {
+      if (clientId == null) {
+        Preconditions.checkArgument(challenge.clientId != null,
+          "Missing client ID in SASL handshake.");
+        clientId = challenge.clientId;
+        client = pendingClients.get(clientId);
+        Preconditions.checkArgument(client != null,
+          "Unexpected client ID '%s' in SASL handshake.", clientId);
+      }
+
+      return new Rpc.SaslMessage(server.evaluateResponse(challenge.payload));
+    }
+
+    @Override
+    public byte[] wrap(byte[] data, int offset, int len) throws IOException {
+      return server.wrap(data, offset, len);
+    }
+
+    @Override
+    public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
+      return server.unwrap(data, offset, len);
+    }
+
+    @Override
+    public void dispose() throws IOException {
+      if (!server.isComplete()) {
+        onError(new SaslException("Server closed before SASL negotiation finished."));
+      }
+      server.dispose();
+    }
+
+    @Override
+    protected void onComplete() throws Exception {
       cancelTask.cancel(true);
+      client.timeoutFuture.cancel(true);
+      rpc.setDispatcher(client.dispatcher);
+      client.promise.setSuccess(rpc);
+      pendingClients.remove(client.id);
+    }
 
-      for (Iterator<ClientInfo> clients = pendingClients.iterator(); clients.hasNext();) {
-        ClientInfo client = clients.next();
-        if (client.secret.equals(msg.secret)) {
-          rpc.replaceDispatcher(client.dispatcher);
-          client.timeoutFuture.cancel(true);
-          client.promise.setSuccess(rpc);
-          return;
+    @Override
+    protected void onError(Throwable error) {
+      cancelTask.cancel(true);
+      if (client != null) {
+        client.timeoutFuture.cancel(true);
+        if (!client.promise.isDone()) {
+          client.promise.setFailure(error);
         }
       }
+    }
 
-      LOG.debug("Closing channel because secret '{}' does not match any pending client.",
-          msg.secret);
-      ctx.close();
+    @Override
+    public void handle(Callback[] callbacks) {
+      Preconditions.checkState(client != null, "Handshake not initialized yet.");
+      for (Callback cb : callbacks) {
+        if (cb instanceof NameCallback) {
+          ((NameCallback)cb).setName(clientId);
+        } else if (cb instanceof PasswordCallback) {
+          ((PasswordCallback)cb).setPassword(client.secret.toCharArray());
+        } else if (cb instanceof AuthorizeCallback) {
+          ((AuthorizeCallback) cb).setAuthorized(true);
+        } else if (cb instanceof RealmCallback) {
+          RealmCallback rb = (RealmCallback) cb;
+          rb.setText(rb.getDefaultText());
+        }
+      }
     }
 
   }
 
   private static class ClientInfo {
 
+    final String id;
     final Promise<Rpc> promise;
     final String secret;
     final RpcDispatcher dispatcher;
     final ScheduledFuture<?> timeoutFuture;
 
-    private ClientInfo(Promise<Rpc> promise, String secret, RpcDispatcher dispatcher,
+    private ClientInfo(String id, Promise<Rpc> promise, String secret, RpcDispatcher dispatcher,
         ScheduledFuture<?> timeoutFuture) {
+      this.id = id;
       this.promise = promise;
       this.secret = secret;
       this.dispatcher = dispatcher;

Added: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java?rev=1655926&view=auto
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java (added)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java Fri Jan 30 03:27:28 2015
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hive.spark.client.rpc;
+
+import java.io.IOException;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Abstract SASL handler: drives the challenge/response exchange carried in Rpc.SaslMessage and,
+ * when the negotiated QOP equals Rpc.SASL_AUTH_CONF, becomes the codec's encryption handler.
+ * Subclasses provide the actual SASL implementation (client or server) via the abstract hooks.
+ */
+abstract class SaslHandler extends SimpleChannelInboundHandler<Rpc.SaslMessage>
+    implements KryoMessageCodec.EncryptionHandler {
+
+  // LOG is not static to make debugging easier (being able to identify which sub-class
+  // generated the log message).
+  private final Logger LOG;
+  private final boolean requiresEncryption;
+  private KryoMessageCodec kryo;
+  private boolean hasAuthResponse = false;  // NOTE(review): written below but never read.
+
+  protected SaslHandler(RpcConfiguration config) {
+    this.requiresEncryption = Rpc.SASL_AUTH_CONF.equals(config.getSaslOptions().get(Sasl.QOP));
+    this.LOG = LoggerFactory.getLogger(getClass());
+  }
+
+  // Use a separate method to make it easier to create a SaslHandler without having to
+  // plumb the KryoMessageCodec instance through the constructors.
+  void setKryoMessageCodec(KryoMessageCodec kryo) {
+    this.kryo = kryo;
+  }
+
+  @Override
+  protected final void channelRead0(ChannelHandlerContext ctx, Rpc.SaslMessage msg)
+      throws Exception {
+    LOG.debug("Handling SASL challenge message...");
+    Rpc.SaslMessage response = update(msg);
+    if (response != null) {
+      LOG.debug("Sending SASL challenge response...");
+      hasAuthResponse = true;
+      ctx.channel().writeAndFlush(response).sync();
+    }
+
+    if (!isComplete()) {
+      return;
+    }
+
+    // Negotiation is done: leave the pipeline and, if confidentiality was negotiated, act as
+    // Kryo's encryption handler; otherwise dispose the SASL state now (even on failure below).
+    ctx.channel().pipeline().remove(this);
+    String qop = getNegotiatedProperty(Sasl.QOP);
+    LOG.debug("SASL negotiation finished with QOP {}.", qop);
+    if (Rpc.SASL_AUTH_CONF.equals(qop)) {
+      LOG.info("SASL confidentiality enabled.");
+      kryo.setEncryptionHandler(this);
+    } else {
+      dispose();
+      if (requiresEncryption) {
+        throw new SaslException("Encryption required, but SASL negotiation did not set it up.");
+      }
+    }
+
+    onComplete();
+  }
+
+  @Override
+  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+    dispose();
+    super.channelInactive(ctx);
+  }
+
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+    if (!isComplete()) {
+      LOG.info("Exception in SASL negotiation.", cause);
+      onError(cause);
+      ctx.close();
+    }
+    ctx.fireExceptionCaught(cause);
+  }
+
+  /** Returns whether the underlying SASL exchange has finished. */
+  protected abstract boolean isComplete();
+  /** Returns a negotiated SASL property (this class queries Sasl.QOP). */
+  protected abstract String getNegotiatedProperty(String name);
+  /** Processes a peer message; returns the reply to send, or null if there is nothing to send. */
+  protected abstract Rpc.SaslMessage update(Rpc.SaslMessage challenge) throws IOException;
+  /** Invoked once negotiation has completed and this handler has left the pipeline. */
+  protected abstract void onComplete() throws Exception;
+  /** Invoked when an exception is caught before negotiation has completed. */
+  protected abstract void onError(Throwable t);
+}

Modified: hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java (original)
+++ hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java Fri Jan 30 03:27:28 2015
@@ -17,37 +17,27 @@
 
 package org.apache.hive.spark.client.rpc;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
 
+import com.google.common.collect.Lists;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.UnpooledByteBufAllocator;
 import io.netty.channel.embedded.EmbeddedChannel;
 import io.netty.handler.logging.LoggingHandler;
-
-import java.util.List;
-
 import org.junit.Test;
-
-import com.google.common.collect.Lists;
+import static org.junit.Assert.*;
 
 public class TestKryoMessageCodec {
 
+  private static final String MESSAGE = "Hello World!";
+
   @Test
   public void testKryoCodec() throws Exception {
-    ByteBuf buf = newBuffer();
-    Object message = "Hello World!";
-
-    KryoMessageCodec codec = new KryoMessageCodec(0);
-    codec.encode(null, message, buf);
-
-    List<Object> objects = Lists.newArrayList();
-    codec.decode(null, buf, objects);
-
+    List<Object> objects = encodeAndDecode(MESSAGE, null);
     assertEquals(1, objects.size());
-    assertEquals(message, objects.get(0));
+    assertEquals(MESSAGE, objects.get(0));
   }
 
   @Test
@@ -76,16 +66,15 @@ public class TestKryoMessageCodec {
 
   @Test
   public void testEmbeddedChannel() throws Exception {
-    Object message = "Hello World!";
     EmbeddedChannel c = new EmbeddedChannel(
       new LoggingHandler(getClass()),
       new KryoMessageCodec(0));
-    c.writeAndFlush(message);
+    c.writeAndFlush(MESSAGE);
     assertEquals(1, c.outboundMessages().size());
-    assertFalse(message.getClass().equals(c.outboundMessages().peek().getClass()));
+    assertFalse(MESSAGE.getClass().equals(c.outboundMessages().peek().getClass()));
     c.writeInbound(c.readOutbound());
     assertEquals(1, c.inboundMessages().size());
-    assertEquals(message, c.readInbound());
+    assertEquals(MESSAGE, c.readInbound());
     c.close();
   }
 
@@ -143,6 +132,53 @@ public class TestKryoMessageCodec {
     }
   }
 
+  @Test
+  public void testEncryptionOnly() throws Exception {
+    List<Object> objects = Collections.<Object>emptyList();
+    try {
+      objects = encodeAndDecode(MESSAGE, new TestEncryptionHandler(true, false));
+    } catch (Exception expected) {
+      // Decoding ciphertext that is never decrypted is expected to fail; that counts as a pass.
+    }
+    // If decoding happened to succeed anyway, no decoded object may equal the plaintext.
+    for (Object msg : objects) {
+      assertFalse(MESSAGE.equals(msg));
+    }
+  }
+
+  @Test
+  public void testDecryptionOnly() throws Exception {
+    List<Object> objects = Collections.<Object>emptyList();
+    try {
+      objects = encodeAndDecode(MESSAGE, new TestEncryptionHandler(false, true));
+    } catch (Exception expected) {
+      // "Decrypting" plain bytes is expected to break decoding; that counts as a pass.
+    }
+    // If decoding happened to succeed anyway, no decoded object may equal the plaintext.
+    for (Object msg : objects) {
+      assertFalse(MESSAGE.equals(msg));
+    }
+  }
+
+  @Test
+  public void testEncryptDecrypt() throws Exception {
+    // Matching wrap/unwrap must round-trip the message exactly, as a single decoded object.
+    List<Object> decoded = encodeAndDecode(MESSAGE, new TestEncryptionHandler(true, true));
+    assertEquals(Collections.<Object>singletonList(MESSAGE), decoded);
+  }
+
+  private List<Object> encodeAndDecode(Object message, KryoMessageCodec.EncryptionHandler eh)
+      throws Exception {
+    // One codec does both directions, so encoding and decoding share the same handler.
+    KryoMessageCodec codec = new KryoMessageCodec(0);
+    codec.setEncryptionHandler(eh);
+    ByteBuf wire = newBuffer();
+    codec.encode(null, message, wire);
+    List<Object> decoded = Lists.newArrayList();
+    codec.decode(null, wire, decoded);
+    return decoded;
+  }
+
   private ByteBuf newBuffer() {
     return UnpooledByteBufAllocator.DEFAULT.buffer(1024);
   }
@@ -159,4 +195,38 @@ public class TestKryoMessageCodec {
     }
   }
 
+  private static class TestEncryptionHandler implements KryoMessageCodec.EncryptionHandler {
+    // Toy "cipher": XOR every byte with KEY. Each direction can be enabled independently.
+    private static final byte KEY = 0x42;
+
+    private final boolean encrypt;
+    private final boolean decrypt;
+
+    TestEncryptionHandler(boolean encrypt, boolean decrypt) {
+      this.encrypt = encrypt;
+      this.decrypt = decrypt;
+    }
+
+    @Override public byte[] wrap(byte[] data, int offset, int len) throws IOException {
+      return transform(data, offset, len, encrypt);
+    }
+
+    @Override public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
+      return transform(data, offset, len, decrypt);
+    }
+
+    @Override public void dispose() throws IOException {
+      // Nothing to release.
+    }
+
+    // Copies data[offset, offset + len), XOR-ing each byte with KEY only when 'xor' is set.
+    private byte[] transform(byte[] data, int offset, int len, boolean xor) {
+      byte[] dest = new byte[len];
+      for (int i = 0; i < len; i++) {
+        dest[i] = xor ? (byte) (data[offset + i] ^ KEY) : data[offset + i];
+      }
+      return dest;
+    }
+  }
+
 }

Modified: hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java (original)
+++ hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java Fri Jan 30 03:27:28 2015
@@ -24,6 +24,7 @@ import java.util.concurrent.Cancellation
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import javax.security.sasl.SaslException;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
@@ -120,20 +121,21 @@ public class TestRpc {
   public void testBadHello() throws Exception {
     RpcServer server = autoClose(new RpcServer(emptyConfig));
 
-    Future<Rpc> serverRpcFuture = server.registerClient("newClient", new TestDispatcher());
+    Future<Rpc> serverRpcFuture = server.registerClient("client", "newClient",
+        new TestDispatcher());
     NioEventLoopGroup eloop = new NioEventLoopGroup();
 
     Future<Rpc> clientRpcFuture = Rpc.createClient(emptyConfig, eloop,
-        "localhost", server.getPort(), "wrongClient", new TestDispatcher());
+        "localhost", server.getPort(), "client", "wrongClient", new TestDispatcher());
 
     try {
       autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));
       fail("Should have failed to create client with wrong secret.");
     } catch (ExecutionException ee) {
-      // On failure, the server will close the channel. This will cause the client's promise
-      // to be cancelled.
+      // On failure, the SASL handler will throw an exception indicating that the SASL
+      // negotiation failed.
       assertTrue("Unexpected exception: " + ee.getCause(),
-        ee.getCause() instanceof CancellationException);
+        ee.getCause() instanceof SaslException);
     }
 
     serverRpcFuture.cancel(true);
@@ -172,6 +174,22 @@ public class TestRpc {
     }
   }
 
+  @Test
+  public void testEncryption() throws Exception {
+    // Both ends request Rpc.SASL_AUTH_CONF as the SASL QOP.
+    String qopKey = RpcConfiguration.RPC_SASL_OPT_PREFIX + "qop";
+    Map<String, String> encryptedConf = ImmutableMap.<String, String>builder()
+      .putAll(emptyConfig).put(qopKey, Rpc.SASL_AUTH_CONF).build();
+    RpcServer server = autoClose(new RpcServer(encryptedConf));
+    Rpc clientRpc = createRpcConnection(server, encryptedConf)[1];
+
+    // A round trip over the negotiated connection must deliver the payload unchanged.
+    TestMessage request = new TestMessage("Hello World!");
+    TestMessage reply = clientRpc.call(request, TestMessage.class)
+        .get(10, TimeUnit.SECONDS);
+    assertEquals(request.message, reply.message);
+  }
+
   private void transfer(Rpc serverRpc, Rpc clientRpc) {
     EmbeddedChannel client = (EmbeddedChannel) clientRpc.getChannel();
     EmbeddedChannel server = (EmbeddedChannel) serverRpc.getChannel();
@@ -199,11 +217,16 @@ public class TestRpc {
    * @return two-tuple (server rpc, client rpc)
    */
   private Rpc[] createRpcConnection(RpcServer server) throws Exception {
+    return createRpcConnection(server, emptyConfig);
+  }
+
+  private Rpc[] createRpcConnection(RpcServer server, Map<String, String> clientConf)
+      throws Exception {
     String secret = server.createSecret();
-    Future<Rpc> serverRpcFuture = server.registerClient(secret, new TestDispatcher());
+    Future<Rpc> serverRpcFuture = server.registerClient("client", secret, new TestDispatcher());
     NioEventLoopGroup eloop = new NioEventLoopGroup();
-    Future<Rpc> clientRpcFuture = Rpc.createClient(emptyConfig, eloop,
-        "localhost", server.getPort(), secret, new TestDispatcher());
+    Future<Rpc> clientRpcFuture = Rpc.createClient(clientConf, eloop,
+        "localhost", server.getPort(), "client", secret, new TestDispatcher());
 
     Rpc serverRpc = autoClose(serverRpcFuture.get(10, TimeUnit.SECONDS));
     Rpc clientRpc = autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));