You are viewing a plain text version of this content. The canonical link for it (a hyperlink on the original page) is not preserved in this plain-text version.
Posted to commits@hive.apache.org by xu...@apache.org on 2015/01/30 04:27:28 UTC
svn commit: r1655926 - in /hive/branches/spark:
common/src/java/org/apache/hadoop/hive/conf/
spark-client/src/main/java/org/apache/hive/spark/client/
spark-client/src/main/java/org/apache/hive/spark/client/rpc/
spark-client/src/test/java/org/apache/hive/spark/client/rpc/
Author: xuefu
Date: Fri Jan 30 03:27:28 2015
New Revision: 1655926
URL: http://svn.apache.org/r1655926
Log:
HIVE-9487: Make Remote Spark Context secure [Spark Branch] (Marcelo via Xuefu)
Added:
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java
Modified:
hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java
hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java
hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
Modified: hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java (original)
+++ hive/branches/spark/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java Fri Jan 30 03:27:28 2015
@@ -2018,7 +2018,9 @@ public class HiveConf extends Configurat
SPARK_RPC_MAX_MESSAGE_SIZE("hive.spark.client.rpc.max.size", 50 * 1024 * 1024,
"Maximum message size in bytes for communication between Hive client and remote Spark driver. Default is 50MB."),
SPARK_RPC_CHANNEL_LOG_LEVEL("hive.spark.client.channel.log.level", null,
- "Channel logging level for remote Spark driver. One of {DEBUG, ERROR, INFO, TRACE, WARN}.");
+ "Channel logging level for remote Spark driver. One of {DEBUG, ERROR, INFO, TRACE, WARN}."),
+ SPARK_RPC_SASL_MECHANISM("hive.spark.client.rpc.sasl.mechanisms", "DIGEST-MD5",
+ "Name of the SASL mechanism to use for authentication.");
public final String varname;
private final String defaultExpr;
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java Fri Jan 30 03:27:28 2015
@@ -106,6 +106,8 @@ public class RemoteDriver {
serverAddress = getArg(args, idx);
} else if (key.equals("--remote-port")) {
serverPort = Integer.parseInt(getArg(args, idx));
+ } else if (key.equals("--client-id")) {
+ conf.set(SparkClientFactory.CONF_CLIENT_ID, getArg(args, idx));
} else if (key.equals("--secret")) {
conf.set(SparkClientFactory.CONF_KEY_SECRET, getArg(args, idx));
} else if (key.equals("--conf")) {
@@ -127,6 +129,8 @@ public class RemoteDriver {
LOG.debug("Remote Driver configured with: " + e._1() + "=" + e._2());
}
+ String clientId = mapConf.get(SparkClientFactory.CONF_CLIENT_ID);
+ Preconditions.checkArgument(clientId != null, "No client ID provided.");
String secret = mapConf.get(SparkClientFactory.CONF_KEY_SECRET);
Preconditions.checkArgument(secret != null, "No secret provided.");
@@ -140,8 +144,8 @@ public class RemoteDriver {
this.protocol = new DriverProtocol();
// The RPC library takes care of timing out this.
- this.clientRpc = Rpc.createClient(mapConf, egroup, serverAddress, serverPort, secret, protocol)
- .get();
+ this.clientRpc = Rpc.createClient(mapConf, egroup, serverAddress, serverPort,
+ clientId, secret, protocol).get();
this.running = true;
this.clientRpc.addListener(new Rpc.Listener() {
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java Fri Jan 30 03:27:28 2015
@@ -37,6 +37,9 @@ public final class SparkClientFactory {
/** Used to run the driver in-process, mostly for testing. */
static final String CONF_KEY_IN_PROCESS = "spark.client.do_not_use.run_driver_in_process";
+ /** Used by client and driver to share a client ID for establishing an RPC session. */
+ static final String CONF_CLIENT_ID = "spark.client.authentication.client_id";
+
/** Used by client and driver to share a secret for establishing an RPC session. */
static final String CONF_KEY_SECRET = "spark.client.authentication.secret";
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java Fri Jan 30 03:27:28 2015
@@ -84,13 +84,14 @@ class SparkClientImpl implements SparkCl
this.childIdGenerator = new AtomicInteger();
this.jobs = Maps.newConcurrentMap();
+ String clientId = UUID.randomUUID().toString();
String secret = rpcServer.createSecret();
- this.driverThread = startDriver(rpcServer, secret);
+ this.driverThread = startDriver(rpcServer, clientId, secret);
this.protocol = new ClientProtocol();
try {
// The RPC server will take care of timeouts here.
- this.driverRpc = rpcServer.registerClient(secret, protocol).get();
+ this.driverRpc = rpcServer.registerClient(clientId, secret, protocol).get();
} catch (Exception e) {
LOG.warn("Error while waiting for client to connect.", e);
driverThread.interrupt();
@@ -174,7 +175,8 @@ class SparkClientImpl implements SparkCl
protocol.cancel(jobId);
}
- private Thread startDriver(RpcServer rpcServer, final String secret) throws IOException {
+ private Thread startDriver(RpcServer rpcServer, final String clientId, final String secret)
+ throws IOException {
Runnable runnable;
final String serverAddress = rpcServer.getAddress();
final String serverPort = String.valueOf(rpcServer.getPort());
@@ -190,6 +192,8 @@ class SparkClientImpl implements SparkCl
args.add(serverAddress);
args.add("--remote-port");
args.add(serverPort);
+ args.add("--client-id");
+ args.add(clientId);
args.add("--secret");
args.add(secret);
@@ -243,6 +247,7 @@ class SparkClientImpl implements SparkCl
for (Map.Entry<String, String> e : conf.entrySet()) {
allProps.put(e.getKey(), conf.get(e.getKey()));
}
+ allProps.put(SparkClientFactory.CONF_CLIENT_ID, clientId);
allProps.put(SparkClientFactory.CONF_KEY_SECRET, secret);
allProps.put(DRIVER_OPTS_KEY, driverJavaOpts);
allProps.put(EXECUTOR_OPTS_KEY, executorJavaOpts);
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/KryoMessageCodec.java Fri Jan 30 03:27:28 2015
@@ -18,6 +18,7 @@
package org.apache.hive.spark.client.rpc;
import java.io.ByteArrayOutputStream;
+import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
@@ -63,9 +64,12 @@ class KryoMessageCodec extends ByteToMes
}
};
+ private volatile EncryptionHandler encryptionHandler;
+
public KryoMessageCodec(int maxMessageSize, Class<?>... messages) {
this.maxMessageSize = maxMessageSize;
this.messages = Arrays.asList(messages);
+ this.encryptionHandler = null;
}
@Override
@@ -86,7 +90,7 @@ class KryoMessageCodec extends ByteToMes
}
try {
- ByteBuffer nioBuffer = in.nioBuffer(in.readerIndex(), msgSize);
+ ByteBuffer nioBuffer = maybeDecrypt(in.nioBuffer(in.readerIndex(), msgSize));
Input kryoIn = new Input(new ByteBufferInputStream(nioBuffer));
Object msg = kryos.get().readClassAndObject(kryoIn);
@@ -106,7 +110,7 @@ class KryoMessageCodec extends ByteToMes
kryos.get().writeClassAndObject(kryoOut, msg);
kryoOut.flush();
- byte[] msgData = bytes.toByteArray();
+ byte[] msgData = maybeEncrypt(bytes.toByteArray());
LOG.debug("Encoded message of type {} ({} bytes)", msg.getClass().getName(), msgData.length);
checkSize(msgData.length);
@@ -115,10 +119,56 @@ class KryoMessageCodec extends ByteToMes
buf.writeBytes(msgData);
}
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ if (encryptionHandler != null) {
+ encryptionHandler.dispose();
+ }
+ super.channelInactive(ctx);
+ }
+
private void checkSize(int msgSize) {
Preconditions.checkArgument(msgSize > 0, "Message size (%s bytes) must be positive.", msgSize);
Preconditions.checkArgument(maxMessageSize <= 0 || msgSize <= maxMessageSize,
"Message (%s bytes) exceeds maximum allowed size (%s bytes).", msgSize, maxMessageSize);
}
+ private byte[] maybeEncrypt(byte[] data) throws Exception {
+ return (encryptionHandler != null) ? encryptionHandler.wrap(data, 0, data.length) : data;
+ }
+
+ private ByteBuffer maybeDecrypt(ByteBuffer data) throws Exception {
+ if (encryptionHandler != null) {
+ byte[] encrypted;
+ int len = data.limit() - data.position();
+ int offset;
+ if (data.hasArray()) {
+ encrypted = data.array();
+ offset = data.position() + data.arrayOffset();
+ data.position(data.limit());
+ } else {
+ encrypted = new byte[len];
+ offset = 0;
+ data.get(encrypted);
+ }
+ return ByteBuffer.wrap(encryptionHandler.unwrap(encrypted, offset, len));
+ } else {
+ return data;
+ }
+ }
+
+ void setEncryptionHandler(EncryptionHandler handler) {
+ this.encryptionHandler = handler;
+ }
+
+ interface EncryptionHandler {
+
+ byte[] wrap(byte[] data, int offset, int len) throws IOException;
+
+ byte[] unwrap(byte[] data, int offset, int len) throws IOException;
+
+ void dispose() throws IOException;
+
+ }
+
}
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/README.md Fri Jan 30 03:27:28 2015
@@ -6,7 +6,7 @@ Basic flow of events:
- Client side creates an RPC server
- Client side spawns RemoteDriver, which manages the SparkContext, and provides a secret
- Client side sets up a timer to wait for RemoteDriver to connect back
-- RemoteDriver connects back to client, sends Hello message with secret
+- RemoteDriver connects back to client, SASL handshake ensues
- Connection is established and now there's a session between the client and the driver.
Features of the RPC layer:
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/Rpc.java Fri Jan 30 03:27:28 2015
@@ -17,6 +17,29 @@
package org.apache.hive.spark.client.rpc;
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
@@ -28,32 +51,18 @@ import io.netty.channel.embedded.Embedde
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
-import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
+import io.netty.handler.logging.LogLevel;
import io.netty.util.concurrent.EventExecutorGroup;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
-
-import java.io.Closeable;
-import java.util.Collection;
-import java.util.Map;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.concurrent.atomic.AtomicReference;
-
-import org.apache.hadoop.hive.common.classification.InterfaceAudience;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
-import com.google.common.base.Throwables;
-import com.google.common.collect.Lists;
+import org.apache.hadoop.hive.common.classification.InterfaceAudience;
/**
* Encapsulates the RPC functionality. Provides higher-level methods to talk to the remote
@@ -62,9 +71,13 @@ import com.google.common.collect.Lists;
@InterfaceAudience.Private
public class Rpc implements Closeable {
- private static final String DISPATCHER_HANDLER_NAME = "dispatcher";
private static final Logger LOG = LoggerFactory.getLogger(Rpc.class);
+ static final String SASL_REALM = "rsc";
+ static final String SASL_USER = "rsc";
+ static final String SASL_PROTOCOL = "rsc";
+ static final String SASL_AUTH_CONF = "auth-conf";
+
/**
* Creates an RPC client for a server running on the given remote host and port.
*
@@ -72,7 +85,8 @@ public class Rpc implements Closeable {
* @param eloop Event loop for managing the connection.
* @param host Host name or IP address to connect to.
* @param port Port where server is listening.
- * @param secret Secret for identifying the client with the server.
+ * @param clientId The client ID that identifies the connection.
+ * @param secret Secret for authenticating the client with the server.
* @param dispatcher Dispatcher used to handle RPC calls.
* @return A future that can be used to monitor the creation of the RPC object.
*/
@@ -81,6 +95,7 @@ public class Rpc implements Closeable {
final NioEventLoopGroup eloop,
String host,
int port,
+ final String clientId,
final String secret,
final RpcDispatcher dispatcher) throws Exception {
final RpcConfiguration rpcConf = new RpcConfiguration(config);
@@ -107,28 +122,17 @@ public class Rpc implements Closeable {
final ScheduledFuture<?> timeoutFuture = eloop.schedule(timeoutTask,
rpcConf.getServerConnectTimeoutMs(), TimeUnit.MILLISECONDS);
- // The channel listener instantiates the Rpc instance when the connection is established, and
- // sends the "Hello" message to complete the handshake.
+ // The channel listener instantiates the Rpc instance when the connection is established,
+ // and initiates the SASL handshake.
cf.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture cf) throws Exception {
if (cf.isSuccess()) {
- rpc.set(createRpc(rpcConf, (SocketChannel) cf.channel(), dispatcher, eloop));
- // The RPC listener waits for confirmation from the server that the "Hello" was good.
- // Once it's finished, the Rpc object is provided to the caller by completing the
- // promise.
- Future<Void> hello = rpc.get().call(new Rpc.Hello(secret));
- hello.addListener(new GenericFutureListener<Future<Void>>() {
- @Override
- public void operationComplete(Future<Void> p) {
- timeoutFuture.cancel(true);
- if (p.isSuccess()) {
- promise.setSuccess(rpc.get());
- } else {
- promise.setFailure(p.cause());
- }
- }
- });
+ SaslClientHandler saslHandler = new SaslClientHandler(rpcConf, clientId, promise,
+ timeoutFuture, secret, dispatcher);
+ Rpc rpc = createRpc(rpcConf, saslHandler, (SocketChannel) cf.channel(), eloop);
+ saslHandler.rpc = rpc;
+ saslHandler.sendHello(cf.channel());
} else {
promise.setFailure(cf.cause());
}
@@ -148,16 +152,16 @@ public class Rpc implements Closeable {
return promise;
}
- /**
- * Creates an RPC handler for a connected socket channel.
- *
- * @param config RpcConfiguration object.
- * @param client Socket channel connected to the RPC remote end.
- * @param dispatcher Dispatcher used to handle RPC calls.
- * @param egroup Event executor for handling futures.
- */
- static Rpc createRpc(RpcConfiguration config, SocketChannel client,
- RpcDispatcher dispatcher, EventExecutorGroup egroup) {
+ static Rpc createServer(SaslHandler saslHandler, RpcConfiguration config, SocketChannel channel,
+ EventExecutorGroup egroup) throws IOException {
+ return createRpc(config, saslHandler, channel, egroup);
+ }
+
+ private static Rpc createRpc(RpcConfiguration config,
+ SaslHandler saslHandler,
+ SocketChannel client,
+ EventExecutorGroup egroup)
+ throws IOException {
LogLevel logLevel = LogLevel.TRACE;
if (config.getRpcChannelLogLevel() != null) {
try {
@@ -187,15 +191,16 @@ public class Rpc implements Closeable {
}
if (logEnabled) {
- client.pipeline()
- .addLast("logger", new LoggingHandler(Rpc.class, logLevel));
+ client.pipeline().addLast("logger", new LoggingHandler(Rpc.class, logLevel));
}
+ KryoMessageCodec kryo = new KryoMessageCodec(config.getMaxMessageSize(),
+ MessageHeader.class, NullMessage.class, SaslMessage.class);
+ saslHandler.setKryoMessageCodec(kryo);
client.pipeline()
- .addLast("codec", new KryoMessageCodec(config.getMaxMessageSize(),
- MessageHeader.class, NullMessage.class))
- .addLast(DISPATCHER_HANDLER_NAME, dispatcher);
- return new Rpc(client, dispatcher, egroup);
+ .addLast("codec", kryo)
+ .addLast("sasl", saslHandler);
+ return new Rpc(config, client, egroup);
}
@VisibleForTesting
@@ -204,25 +209,28 @@ public class Rpc implements Closeable {
new LoggingHandler(Rpc.class),
new KryoMessageCodec(0, MessageHeader.class, NullMessage.class),
dispatcher);
- return new Rpc(c, dispatcher, ImmediateEventExecutor.INSTANCE);
+ Rpc rpc = new Rpc(new RpcConfiguration(Collections.<String, String>emptyMap()),
+ c, ImmediateEventExecutor.INSTANCE);
+ rpc.dispatcher = dispatcher;
+ return rpc;
}
+ private final RpcConfiguration config;
private final AtomicBoolean rpcClosed;
private final AtomicLong rpcId;
- private final AtomicReference<RpcDispatcher> dispatcher;
private final Channel channel;
private final Collection<Listener> listeners;
private final EventExecutorGroup egroup;
private final Object channelLock;
+ private volatile RpcDispatcher dispatcher;
- @SuppressWarnings({ "rawtypes", "unchecked" })
- private Rpc(Channel channel, RpcDispatcher dispatcher, EventExecutorGroup egroup) {
+ private Rpc(RpcConfiguration config, Channel channel, EventExecutorGroup egroup) {
Preconditions.checkArgument(channel != null);
- Preconditions.checkArgument(dispatcher != null);
Preconditions.checkArgument(egroup != null);
+ this.config = config;
this.channel = channel;
this.channelLock = new Object();
- this.dispatcher = new AtomicReference(dispatcher);
+ this.dispatcher = null;
this.egroup = egroup;
this.listeners = Lists.newLinkedList();
this.rpcClosed = new AtomicBoolean();
@@ -271,13 +279,13 @@ public class Rpc implements Closeable {
if (!cf.isSuccess() && !promise.isDone()) {
LOG.warn("Failed to send RPC, closing connection.", cf.cause());
promise.setFailure(cf.cause());
- dispatcher.get().discardRpc(id);
+ dispatcher.discardRpc(id);
close();
}
}
};
- dispatcher.get().registerRpc(id, promise, msg.getClass().getName());
+ dispatcher.registerRpc(id, promise, msg.getClass().getName());
synchronized (channelLock) {
channel.write(new MessageHeader(id, Rpc.MessageType.CALL)).addListener(listener);
channel.writeAndFlush(msg).addListener(listener);
@@ -300,14 +308,11 @@ public class Rpc implements Closeable {
return channel;
}
- /**
- * This is only used by RpcServer after the handshake is successful. It shouldn't be called in
- * any other situation. It particularly will not work for embedded channels used for testing.
- */
- void replaceDispatcher(RpcDispatcher newDispatcher) {
- channel.pipeline().remove(DISPATCHER_HANDLER_NAME);
- channel.pipeline().addLast(DISPATCHER_HANDLER_NAME, newDispatcher);
- dispatcher.set(newDispatcher);
+ void setDispatcher(RpcDispatcher dispatcher) {
+ Preconditions.checkNotNull(dispatcher);
+ Preconditions.checkState(this.dispatcher == null);
+ this.dispatcher = dispatcher;
+ channel.pipeline().addLast("dispatcher", dispatcher);
}
@Override
@@ -359,20 +364,129 @@ public class Rpc implements Closeable {
}
- static class Hello {
- final String secret;
+ static class NullMessage {
+
+ }
+
+ static class SaslMessage {
+ final String clientId;
+ final byte[] payload;
- Hello() {
- this(null);
+ SaslMessage() {
+ this(null, null);
}
- Hello(String secret) {
- this.secret = secret;
+ SaslMessage(byte[] payload) {
+ this(null, payload);
+ }
+
+ SaslMessage(String clientId, byte[] payload) {
+ this.clientId = clientId;
+ this.payload = payload;
}
}
- static class NullMessage {
+ private static class SaslClientHandler extends SaslHandler implements CallbackHandler {
+
+ private final SaslClient client;
+ private final String clientId;
+ private final String secret;
+ private final RpcDispatcher dispatcher;
+ private Promise<Rpc> promise;
+ private ScheduledFuture<?> timeout;
+
+ // Can't be set in constructor due to circular dependency.
+ private Rpc rpc;
+
+ SaslClientHandler(
+ RpcConfiguration config,
+ String clientId,
+ Promise<Rpc> promise,
+ ScheduledFuture<?> timeout,
+ String secret,
+ RpcDispatcher dispatcher)
+ throws IOException {
+ super(config);
+ this.clientId = clientId;
+ this.promise = promise;
+ this.timeout = timeout;
+ this.secret = secret;
+ this.dispatcher = dispatcher;
+ this.client = Sasl.createSaslClient(new String[] { config.getSaslMechanism() },
+ null, SASL_PROTOCOL, SASL_REALM, config.getSaslOptions(), this);
+ }
+
+ @Override
+ protected boolean isComplete() {
+ return client.isComplete();
+ }
+
+ @Override
+ protected String getNegotiatedProperty(String name) {
+ return (String) client.getNegotiatedProperty(name);
+ }
+
+ @Override
+ protected SaslMessage update(SaslMessage challenge) throws IOException {
+ byte[] response = client.evaluateChallenge(challenge.payload);
+ return response != null ? new SaslMessage(response) : null;
+ }
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int len) throws IOException {
+ return client.wrap(data, offset, len);
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
+ return client.unwrap(data, offset, len);
+ }
+
+ @Override
+ public void dispose() throws IOException {
+ if (!client.isComplete()) {
+ onError(new SaslException("Client closed before SASL negotiation finished."));
+ }
+ client.dispose();
+ }
+
+ @Override
+ protected void onComplete() throws Exception {
+ timeout.cancel(true);
+ rpc.setDispatcher(dispatcher);
+ promise.setSuccess(rpc);
+ timeout = null;
+ promise = null;
+ }
+
+ @Override
+ protected void onError(Throwable error) {
+ timeout.cancel(true);
+ if (!promise.isDone()) {
+ promise.setFailure(error);
+ }
+ }
+
+ @Override
+ public void handle(Callback[] callbacks) {
+ for (Callback cb : callbacks) {
+ if (cb instanceof NameCallback) {
+ ((NameCallback)cb).setName(clientId);
+ } else if (cb instanceof PasswordCallback) {
+ ((PasswordCallback)cb).setPassword(secret.toCharArray());
+ } else if (cb instanceof RealmCallback) {
+ RealmCallback rb = (RealmCallback) cb;
+ rb.setText(rb.getDefaultText());
+ }
+ }
+ }
+
+ void sendHello(Channel c) throws Exception {
+ byte[] hello = client.hasInitialResponse() ?
+ client.evaluateChallenge(new byte[0]) : new byte[0];
+ c.writeAndFlush(new SaslMessage(clientId, hello));
+ }
}
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcConfiguration.java Fri Jan 30 03:27:28 2015
@@ -21,18 +21,19 @@ import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.NetworkInterface;
-import java.util.Arrays;
import java.util.Enumeration;
-import java.util.List;
+import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
+import javax.security.sasl.Sasl;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
-import org.apache.hadoop.hive.conf.HiveConf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.hadoop.hive.common.classification.InterfaceAudience;
+import org.apache.hadoop.hive.conf.HiveConf;
/**
* Definitions of configuration keys and default values for the RPC layer.
@@ -57,6 +58,9 @@ public final class RpcConfiguration {
public static final String SERVER_LISTEN_ADDRESS_KEY = "hive.spark.client.server.address";
+ /** Prefix for other SASL options. */
+ public static final String RPC_SASL_OPT_PREFIX = "hive.spark.client.rpc.sasl.";
+
private final Map<String, String> config;
private static final HiveConf DEFAULT_CONF = new HiveConf();
@@ -96,8 +100,7 @@ public final class RpcConfiguration {
InetAddress address = InetAddress.getLocalHost();
if (address.isLoopbackAddress()) {
// Address resolves to something like 127.0.1.1, which happens on Debian;
- // try to find
- // a better address using the local network interfaces
+ // try to find a better address using the local network interfaces
Enumeration<NetworkInterface> ifaces = NetworkInterface.getNetworkInterfaces();
while (ifaces.hasMoreElements()) {
NetworkInterface ni = ifaces.nextElement();
@@ -132,7 +135,6 @@ public final class RpcConfiguration {
return value != null ? Integer.parseInt(value) : HiveConf.ConfVars.SPARK_RPC_MAX_THREADS.defaultIntVal;
}
-
/**
* Utility method for a given RpcConfiguration key, to convert value to millisecond if it is a time value,
* and return as string in either case.
@@ -148,4 +150,41 @@ public final class RpcConfiguration {
return conf.get(key);
}
}
+
+ String getSaslMechanism() {
+ String value = config.get(HiveConf.ConfVars.SPARK_RPC_SASL_MECHANISM.varname);
+ return value != null ? value : HiveConf.ConfVars. SPARK_RPC_SASL_MECHANISM.defaultStrVal;
+ }
+
+ /**
+ * SASL options are namespaced under "hive.spark.client.rpc.sasl.*"; each option is the
+ * lower-case version of the constant in the "javax.security.sasl.Sasl" class (e.g. "strength"
+ * for cipher strength).
+ */
+ Map<String, String> getSaslOptions() {
+ Map<String, String> opts = new HashMap<String, String>();
+ Map<String, String> saslOpts = ImmutableMap.<String, String>builder()
+ .put(Sasl.CREDENTIALS, "credentials")
+ .put(Sasl.MAX_BUFFER, "max_buffer")
+ .put(Sasl.POLICY_FORWARD_SECRECY, "policy_forward_secrecy")
+ .put(Sasl.POLICY_NOACTIVE, "policy_noactive")
+ .put(Sasl.POLICY_NOANONYMOUS, "policy_noanonymous")
+ .put(Sasl.POLICY_NODICTIONARY, "policy_nodictionary")
+ .put(Sasl.POLICY_NOPLAINTEXT, "policy_noplaintext")
+ .put(Sasl.POLICY_PASS_CREDENTIALS, "policy_pass_credentials")
+ .put(Sasl.QOP, "qop")
+ .put(Sasl.RAW_SEND_SIZE, "raw_send_size")
+ .put(Sasl.REUSE, "reuse")
+ .put(Sasl.SERVER_AUTH, "server_auth")
+ .put(Sasl.STRENGTH, "strength")
+ .build();
+ for (Map.Entry<String, String> e : saslOpts.entrySet()) {
+ String value = config.get(RPC_SASL_OPT_PREFIX + e.getValue());
+ if (value != null) {
+ opts.put(e.getKey(), value);
+ }
+ }
+ return opts;
+ }
+
}
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java Fri Jan 30 03:27:28 2015
@@ -17,9 +17,29 @@
package org.apache.hive.spark.client.rpc;
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.security.SecureRandom;
+import java.util.Map;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.TimeUnit;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
-import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
@@ -30,23 +50,10 @@ import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
-
-import java.io.Closeable;
-import java.io.IOException;
-import java.net.InetSocketAddress;
-import java.security.SecureRandom;
-import java.util.Collection;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-
-import org.apache.hadoop.hive.common.classification.InterfaceAudience;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.hadoop.hive.common.classification.InterfaceAudience;
/**
* An RPC server. The server matches remote clients based on a secret that is generated on
@@ -63,11 +70,11 @@ public class RpcServer implements Closea
private final Channel channel;
private final EventLoopGroup group;
private final int port;
- private final Collection<ClientInfo> pendingClients;
+ private final ConcurrentMap<String, ClientInfo> pendingClients;
private final RpcConfiguration config;
- public RpcServer(Map<String, String> config) throws IOException, InterruptedException {
- this.config = new RpcConfiguration(config);
+ public RpcServer(Map<String, String> mapConf) throws IOException, InterruptedException {
+ this.config = new RpcConfiguration(mapConf);
this.group = new NioEventLoopGroup(
this.config.getRpcThreadCount(),
new ThreadFactoryBuilder()
@@ -80,9 +87,9 @@ public class RpcServer implements Closea
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
- HelloDispatcher dispatcher = new HelloDispatcher();
- final Rpc newRpc = Rpc.createRpc(RpcServer.this.config, ch, dispatcher, group);
- dispatcher.rpc = newRpc;
+ SaslServerHandler saslHandler = new SaslServerHandler(config);
+ final Rpc newRpc = Rpc.createServer(saslHandler, config, ch, group);
+ saslHandler.rpc = newRpc;
Runnable cancelTask = new Runnable() {
@Override
@@ -91,7 +98,7 @@ public class RpcServer implements Closea
newRpc.close();
}
};
- dispatcher.cancelTask = group.schedule(cancelTask,
+ saslHandler.cancelTask = group.schedule(cancelTask,
RpcServer.this.config.getServerConnectTimeoutMs(),
TimeUnit.MILLISECONDS);
@@ -104,19 +111,21 @@ public class RpcServer implements Closea
.sync()
.channel();
this.port = ((InetSocketAddress) channel.localAddress()).getPort();
- this.pendingClients = new ConcurrentLinkedQueue<ClientInfo>();
+ this.pendingClients = Maps.newConcurrentMap();
this.address = this.config.getServerAddress();
}
/**
* Tells the RPC server to expect a connection from a new client.
*
+ * @param clientId An identifier for the client. Must be unique.
* @param secret The secret the client will send to the server to identify itself.
* @param serverDispatcher The dispatcher to use when setting up the RPC instance.
* @return A future that can be used to wait for the client connection, which also provides the
* secret needed for the client to connect.
*/
- public Future<Rpc> registerClient(String secret, RpcDispatcher serverDispatcher) {
+ public Future<Rpc> registerClient(final String clientId, String secret,
+ RpcDispatcher serverDispatcher) {
final Promise<Rpc> promise = group.next().newPromise();
Runnable timeout = new Runnable() {
@@ -128,15 +137,18 @@ public class RpcServer implements Closea
ScheduledFuture<?> timeoutFuture = group.schedule(timeout,
config.getServerConnectTimeoutMs(),
TimeUnit.MILLISECONDS);
- final ClientInfo client = new ClientInfo(promise, secret, serverDispatcher, timeoutFuture);
- pendingClients.add(client);
-
+ final ClientInfo client = new ClientInfo(clientId, promise, secret, serverDispatcher,
+ timeoutFuture);
+ if (pendingClients.putIfAbsent(clientId, client) != null) {
+ throw new IllegalStateException(
+ String.format("Client '%s' already registered.", clientId));
+ }
promise.addListener(new GenericFutureListener<Promise<Rpc>>() {
@Override
public void operationComplete(Promise<Rpc> p) {
if (p.isCancelled()) {
- pendingClients.remove(client);
+ pendingClients.remove(clientId);
}
}
});
@@ -173,50 +185,121 @@ public class RpcServer implements Closea
public void close() {
+ // Close the listening channel, cancel the promise of every client that registered but has
+ // not finished connecting, and always shut down the event loop group.
try {
channel.close();
- for (Iterator<ClientInfo> clients = pendingClients.iterator(); clients.hasNext();) {
- ClientInfo client = clients.next();
- clients.remove();
+ for (ClientInfo client : pendingClients.values()) {
client.promise.cancel(true);
}
+ // The cancellation listener added in registerClient() also removes entries; clear()
+ // drops anything still left in the map.
+ pendingClients.clear();
} finally {
group.shutdownGracefully();
}
}
- private class HelloDispatcher extends RpcDispatcher {
+ /**
+ * Server side of the SASL handshake for a single incoming connection. The first SASL
+ * message carries the client ID, which selects the pending ClientInfo whose registered
+ * secret is handed to the SaslServer through the CallbackHandler interface.
+ */
+ private class SaslServerHandler extends SaslHandler implements CallbackHandler {
+ private final SaslServer server;
private Rpc rpc;
private ScheduledFuture<?> cancelTask;
+ // Set from the first handshake message; null until then.
+ private String clientId;
+ private ClientInfo client;
- protected void handle(ChannelHandlerContext ctx, Rpc.Hello msg) {
+ // Creates the SaslServer with the configured mechanism and options, using this object as
+ // the callback handler.
+ SaslServerHandler(RpcConfiguration config) throws IOException {
+ super(config);
+ this.server = Sasl.createSaslServer(config.getSaslMechanism(), Rpc.SASL_PROTOCOL,
+ Rpc.SASL_REALM, config.getSaslOptions(), this);
+ }
+
+ @Override
+ protected boolean isComplete() {
+ return server.isComplete();
+ }
+
+ @Override
+ protected String getNegotiatedProperty(String name) {
+ return (String) server.getNegotiatedProperty(name);
+ }
+
+ // On the first message, binds this connection to a registered client (failing if the ID
+ // is missing or unknown); then feeds the payload to the SaslServer and returns its reply.
+ @Override
+ protected Rpc.SaslMessage update(Rpc.SaslMessage challenge) throws IOException {
+ if (clientId == null) {
+ Preconditions.checkArgument(challenge.clientId != null,
+ "Missing client ID in SASL handshake.");
+ clientId = challenge.clientId;
+ client = pendingClients.get(clientId);
+ Preconditions.checkArgument(client != null,
+ "Unexpected client ID '%s' in SASL handshake.", clientId);
+ }
+
+ return new Rpc.SaslMessage(server.evaluateResponse(challenge.payload));
+ }
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int len) throws IOException {
+ return server.wrap(data, offset, len);
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
+ return server.unwrap(data, offset, len);
+ }
+
+ // Fails the pending client if negotiation never finished, then releases the SaslServer.
+ @Override
+ public void dispose() throws IOException {
+ if (!server.isComplete()) {
+ onError(new SaslException("Server closed before SASL negotiation finished."));
+ }
+ server.dispose();
+ }
+
+ // Handshake succeeded: stop both timers, hand the Rpc over to the client's dispatcher,
+ // complete the registration promise and unregister the client.
+ @Override
+ protected void onComplete() throws Exception {
cancelTask.cancel(true);
+ client.timeoutFuture.cancel(true);
+ rpc.setDispatcher(client.dispatcher);
+ client.promise.setSuccess(rpc);
+ pendingClients.remove(client.id);
+ }
- for (Iterator<ClientInfo> clients = pendingClients.iterator(); clients.hasNext();) {
- ClientInfo client = clients.next();
- if (client.secret.equals(msg.secret)) {
- rpc.replaceDispatcher(client.dispatcher);
- client.timeoutFuture.cancel(true);
- client.promise.setSuccess(rpc);
- return;
+ // Stops the timers and, if a client was identified, fails its promise unless it is
+ // already done.
+ @Override
+ protected void onError(Throwable error) {
+ cancelTask.cancel(true);
+ if (client != null) {
+ client.timeoutFuture.cancel(true);
+ if (!client.promise.isDone()) {
+ client.promise.setFailure(error);
}
}
+ }
- LOG.debug("Closing channel because secret '{}' does not match any pending client.",
- msg.secret);
- ctx.close();
+ // SASL callbacks: the client ID is the user name and the secret registered for that client
+ // is the password. NOTE(review): every AuthorizeCallback is approved unconditionally
+ // (no authentication-ID vs. authorization-ID check) — confirm this is intended.
+ @Override
+ public void handle(Callback[] callbacks) {
+ Preconditions.checkState(client != null, "Handshake not initialized yet.");
+ for (Callback cb : callbacks) {
+ if (cb instanceof NameCallback) {
+ ((NameCallback)cb).setName(clientId);
+ } else if (cb instanceof PasswordCallback) {
+ ((PasswordCallback)cb).setPassword(client.secret.toCharArray());
+ } else if (cb instanceof AuthorizeCallback) {
+ ((AuthorizeCallback) cb).setAuthorized(true);
+ } else if (cb instanceof RealmCallback) {
+ RealmCallback rb = (RealmCallback) cb;
+ rb.setText(rb.getDefaultText());
+ }
+ }
}
}
private static class ClientInfo {
+ final String id;
final Promise<Rpc> promise;
final String secret;
final RpcDispatcher dispatcher;
final ScheduledFuture<?> timeoutFuture;
- private ClientInfo(Promise<Rpc> promise, String secret, RpcDispatcher dispatcher,
+ private ClientInfo(String id, Promise<Rpc> promise, String secret, RpcDispatcher dispatcher,
ScheduledFuture<?> timeoutFuture) {
+ this.id = id;
this.promise = promise;
this.secret = secret;
this.dispatcher = dispatcher;
Added: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java?rev=1655926&view=auto
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java (added)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/SaslHandler.java Fri Jan 30 03:27:28 2015
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hive.spark.client.rpc;
+
+import java.io.IOException;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Abstract SASL handler. Abstracts the auth protocol handling and encryption, if it's enabled.
+ * Needs subclasses to provide access to the actual underlying SASL implementation (client or
+ * server).
+ */
+abstract class SaslHandler extends SimpleChannelInboundHandler<Rpc.SaslMessage>
+ implements KryoMessageCodec.EncryptionHandler {
+
+ // LOG is not static to make debugging easier (being able to identify which sub-class
+ // generated the log message).
+ private final Logger LOG;
+ private final boolean requiresEncryption;
+ private KryoMessageCodec kryo;
+ private boolean hasAuthResponse = false;
+
+ protected SaslHandler(RpcConfiguration config) {
+ this.requiresEncryption = Rpc.SASL_AUTH_CONF.equals(config.getSaslOptions().get(Sasl.QOP));
+ this.LOG = LoggerFactory.getLogger(getClass());
+ }
+
+ // Use a separate method to make it easier to create a SaslHandler without having to
+ // plumb the KryoMessageCodec instance through the constructors.
+ void setKryoMessageCodec(KryoMessageCodec kryo) {
+ this.kryo = kryo;
+ }
+
+ @Override
+ protected final void channelRead0(ChannelHandlerContext ctx, Rpc.SaslMessage msg)
+ throws Exception {
+ LOG.debug("Handling SASL challenge message...");
+ Rpc.SaslMessage response = update(msg);
+ if (response != null) {
+ LOG.debug("Sending SASL challenge response...");
+ hasAuthResponse = true;
+ ctx.channel().writeAndFlush(response).sync();
+ }
+
+ if (!isComplete()) {
+ return;
+ }
+
+ // If negotiation is complete, remove this handler from the pipeline, and register it with
+ // the Kryo instance to handle encryption if needed.
+ ctx.channel().pipeline().remove(this);
+ String qop = getNegotiatedProperty(Sasl.QOP);
+ LOG.debug("SASL negotiation finished with QOP {}.", qop);
+ if (Rpc.SASL_AUTH_CONF.equals(qop)) {
+ LOG.info("SASL confidentiality enabled.");
+ kryo.setEncryptionHandler(this);
+ } else {
+ if (requiresEncryption) {
+ throw new SaslException("Encryption required, but SASL negotiation did not set it up.");
+ }
+ dispose();
+ }
+
+ onComplete();
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ dispose();
+ super.channelInactive(ctx);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ if (!isComplete()) {
+ LOG.info("Exception in SASL negotiation.", cause);
+ onError(cause);
+ ctx.close();
+ }
+ ctx.fireExceptionCaught(cause);
+ }
+
+ protected abstract boolean isComplete();
+
+ protected abstract String getNegotiatedProperty(String name);
+
+ protected abstract Rpc.SaslMessage update(Rpc.SaslMessage challenge) throws IOException;
+
+ protected abstract void onComplete() throws Exception;
+
+ protected abstract void onError(Throwable t);
+
+}
Modified: hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java (original)
+++ hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java Fri Jan 30 03:27:28 2015
@@ -17,37 +17,27 @@
package org.apache.hive.spark.client.rpc;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import com.google.common.collect.Lists;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.logging.LoggingHandler;
-
-import java.util.List;
-
import org.junit.Test;
-
-import com.google.common.collect.Lists;
+import static org.junit.Assert.*;
public class TestKryoMessageCodec {
+ private static final String MESSAGE = "Hello World!";
+
@Test
public void testKryoCodec() throws Exception {
+ // Round-trips a plain string through the codec with no encryption handler installed.
- ByteBuf buf = newBuffer();
- Object message = "Hello World!";
-
- KryoMessageCodec codec = new KryoMessageCodec(0);
- codec.encode(null, message, buf);
-
- List<Object> objects = Lists.newArrayList();
- codec.decode(null, buf, objects);
-
+ List<Object> objects = encodeAndDecode(MESSAGE, null);
assertEquals(1, objects.size());
- assertEquals(message, objects.get(0));
+ assertEquals(MESSAGE, objects.get(0));
}
@Test
@@ -76,16 +66,15 @@ public class TestKryoMessageCodec {
@Test
public void testEmbeddedChannel() throws Exception {
+ // Sends the message through a real channel pipeline: the outbound item must no longer be a
+ // String, and writing it back inbound must decode to the original message.
- Object message = "Hello World!";
EmbeddedChannel c = new EmbeddedChannel(
new LoggingHandler(getClass()),
new KryoMessageCodec(0));
- c.writeAndFlush(message);
+ c.writeAndFlush(MESSAGE);
assertEquals(1, c.outboundMessages().size());
- assertFalse(message.getClass().equals(c.outboundMessages().peek().getClass()));
+ assertFalse(MESSAGE.getClass().equals(c.outboundMessages().peek().getClass()));
c.writeInbound(c.readOutbound());
assertEquals(1, c.inboundMessages().size());
- assertEquals(message, c.readInbound());
+ assertEquals(MESSAGE, c.readInbound());
c.close();
}
@@ -143,6 +132,53 @@ public class TestKryoMessageCodec {
}
}
+ @Test
+ public void testEncryptionOnly() throws Exception {
+ // Encrypt on the way out but not on the way in: decoding should fail or yield garbage.
+ List<Object> objects = Collections.<Object>emptyList();
+ try {
+ objects = encodeAndDecode(MESSAGE, new TestEncryptionHandler(true, false));
+ } catch (Exception expected) {
+ // Pass: failing to decode the ciphertext is the expected outcome.
+ }
+ // Do this check in case the ciphertext actually makes sense in some way: no decoded
+ // object may equal the original plaintext.
+ for (Object msg : objects) {
+ assertFalse(MESSAGE.equals(msg));
+ }
+ }
+
+ @Test
+ public void testDecryptionOnly() throws Exception {
+ // Decrypt on the way in without encrypting on the way out: decoding should fail or yield
+ // garbage.
+ List<Object> objects = Collections.<Object>emptyList();
+ try {
+ objects = encodeAndDecode(MESSAGE, new TestEncryptionHandler(false, true));
+ } catch (Exception expected) {
+ // Pass: failing to decode the scrambled plaintext is the expected outcome.
+ }
+ // Do this check in case the decrypted plaintext actually makes sense in some way: no
+ // decoded object may equal the original message.
+ for (Object msg : objects) {
+ assertFalse(MESSAGE.equals(msg));
+ }
+ }
+
+ @Test
+ public void testEncryptDecrypt() throws Exception {
+ List<Object> objects = encodeAndDecode(MESSAGE, new TestEncryptionHandler(true, true));
+ assertEquals(1, objects.size());
+ assertEquals(MESSAGE, objects.get(0));
+ }
+
+ /**
+ * Encodes {@code message} with a fresh codec that uses {@code eh} for encryption, then
+ * decodes the resulting buffer with the same codec and returns every decoded object.
+ */
+ private List<Object> encodeAndDecode(Object message, KryoMessageCodec.EncryptionHandler eh)
+ throws Exception {
+ KryoMessageCodec codec = new KryoMessageCodec(0);
+ codec.setEncryptionHandler(eh);
+
+ ByteBuf wire = newBuffer();
+ codec.encode(null, message, wire);
+
+ List<Object> decoded = Lists.newArrayList();
+ codec.decode(null, wire, decoded);
+ return decoded;
+ }
+
private ByteBuf newBuffer() {
return UnpooledByteBufAllocator.DEFAULT.buffer(1024);
}
@@ -159,4 +195,38 @@ public class TestKryoMessageCodec {
}
}
+ /**
+ * Trivial reversible "cipher" for tests: XORs every byte with a fixed key. Encryption and
+ * decryption can be switched on independently to simulate mismatched peers.
+ */
+ private static class TestEncryptionHandler implements KryoMessageCodec.EncryptionHandler {
+
+ private static final byte KEY = 0x42;
+
+ private final boolean encrypt;
+ private final boolean decrypt;
+
+ /**
+ * @param encrypt whether {@link #wrap} transforms outgoing data.
+ * @param decrypt whether {@link #unwrap} transforms incoming data.
+ */
+ TestEncryptionHandler(boolean encrypt, boolean decrypt) {
+ this.encrypt = encrypt;
+ this.decrypt = decrypt;
+ }
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int len) throws IOException {
+ return encrypt ? transform(data, offset, len) : copy(data, offset, len);
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
+ return decrypt ? transform(data, offset, len) : copy(data, offset, len);
+ }
+
+ @Override
+ public void dispose() throws IOException {
+ // Nothing to release: this handler holds no resources.
+ }
+
+ // XORs data[offset, offset + len) with KEY into a new array.
+ private byte[] transform(byte[] data, int offset, int len) {
+ byte[] dest = new byte[len];
+ for (int i = 0; i < len; i++) {
+ dest[i] = (byte) (data[offset + i] ^ KEY);
+ }
+ return dest;
+ }
+
+ // Pass-through path: returns exactly data[offset, offset + len), honouring offset/len just
+ // like transform() does instead of handing back the whole input array.
+ private static byte[] copy(byte[] data, int offset, int len) {
+ byte[] dest = new byte[len];
+ System.arraycopy(data, offset, dest, 0, len);
+ return dest;
+ }
+
+ }
+
}
Modified: hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java?rev=1655926&r1=1655925&r2=1655926&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java (original)
+++ hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java Fri Jan 30 03:27:28 2015
@@ -24,6 +24,7 @@ import java.util.concurrent.Cancellation
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import javax.security.sasl.SaslException;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
@@ -120,20 +121,21 @@ public class TestRpc {
public void testBadHello() throws Exception {
RpcServer server = autoClose(new RpcServer(emptyConfig));
- Future<Rpc> serverRpcFuture = server.registerClient("newClient", new TestDispatcher());
+ Future<Rpc> serverRpcFuture = server.registerClient("client", "newClient",
+ new TestDispatcher());
NioEventLoopGroup eloop = new NioEventLoopGroup();
Future<Rpc> clientRpcFuture = Rpc.createClient(emptyConfig, eloop,
- "localhost", server.getPort(), "wrongClient", new TestDispatcher());
+ "localhost", server.getPort(), "client", "wrongClient", new TestDispatcher());
try {
autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));
fail("Should have failed to create client with wrong secret.");
} catch (ExecutionException ee) {
- // On failure, the server will close the channel. This will cause the client's promise
- // to be cancelled.
+ // On failure, the SASL handler will throw an exception indicating that the SASL
+ // negotiation failed.
assertTrue("Unexpected exception: " + ee.getCause(),
- ee.getCause() instanceof CancellationException);
+ ee.getCause() instanceof SaslException);
}
serverRpcFuture.cancel(true);
@@ -172,6 +174,22 @@ public class TestRpc {
}
}
+ @Test
+ public void testEncryption() throws Exception {
+ // Both server and client are configured to require SASL confidentiality (auth-conf).
+ Map<String, String> eConf = ImmutableMap.<String,String>builder()
+ .putAll(emptyConfig)
+ .put(RpcConfiguration.RPC_SASL_OPT_PREFIX + "qop", Rpc.SASL_AUTH_CONF)
+ .build();
+ RpcServer server = autoClose(new RpcServer(eConf));
+ Rpc client = createRpcConnection(server, eConf)[1];
+
+ // A call must still round-trip over the encrypted connection.
+ TestMessage outbound = new TestMessage("Hello World!");
+ TestMessage reply = client.call(outbound, TestMessage.class).get(10, TimeUnit.SECONDS);
+ assertEquals(outbound.message, reply.message);
+ }
+
private void transfer(Rpc serverRpc, Rpc clientRpc) {
EmbeddedChannel client = (EmbeddedChannel) clientRpc.getChannel();
EmbeddedChannel server = (EmbeddedChannel) serverRpc.getChannel();
@@ -199,11 +217,16 @@ public class TestRpc {
* @return two-tuple (server rpc, client rpc)
*/
private Rpc[] createRpcConnection(RpcServer server) throws Exception {
+ // Convenience overload: connects a client that uses emptyConfig.
+ return createRpcConnection(server, emptyConfig);
+ }
+
+ private Rpc[] createRpcConnection(RpcServer server, Map<String, String> clientConf)
+ throws Exception {
String secret = server.createSecret();
- Future<Rpc> serverRpcFuture = server.registerClient(secret, new TestDispatcher());
+ Future<Rpc> serverRpcFuture = server.registerClient("client", secret, new TestDispatcher());
NioEventLoopGroup eloop = new NioEventLoopGroup();
- Future<Rpc> clientRpcFuture = Rpc.createClient(emptyConfig, eloop,
- "localhost", server.getPort(), secret, new TestDispatcher());
+ Future<Rpc> clientRpcFuture = Rpc.createClient(clientConf, eloop,
+ "localhost", server.getPort(), "client", secret, new TestDispatcher());
Rpc serverRpc = autoClose(serverRpcFuture.get(10, TimeUnit.SECONDS));
Rpc clientRpc = autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));