You are viewing a plain-text version of this content. The canonical link for it (the "here" hyperlink in the original page) was not preserved in this plain-text copy.
Posted to common-commits@hadoop.apache.org by ha...@apache.org on 2010/09/02 02:35:30 UTC
svn commit: r991780 - in /hadoop/common/trunk: ./
src/java/org/apache/hadoop/ipc/ src/test/core/org/apache/hadoop/ipc/
Author: hairong
Date: Thu Sep 2 00:35:30 2010
New Revision: 991780
URL: http://svn.apache.org/viewvc?rev=991780&view=rev
Log:
HADOOP-6907. RPC client doesn't use the per-connection conf to determine the server's Kerberos principal. Contributed by Kan Zhang.
Modified:
hadoop/common/trunk/CHANGES.txt
hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java
hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java
hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java
hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java
Modified: hadoop/common/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/CHANGES.txt?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/CHANGES.txt (original)
+++ hadoop/common/trunk/CHANGES.txt Thu Sep 2 00:35:30 2010
@@ -226,6 +226,9 @@ Trunk (unreleased changes)
HADOOP-6913. Circular initialization between UserGroupInformation and
KerberosName (Kan Zhang via boryas)
+ HADOOP-6907. Rpc client doesn't use the per-connection conf to figure
+ out server's Kerberos principal (Kan Zhang via hairong)
+
Release 0.21.0 - Unreleased
INCOMPATIBLE CHANGES
Modified: hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java (original)
+++ hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java Thu Sep 2 00:35:30 2010
@@ -37,6 +37,7 @@ import java.security.PrivilegedException
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Random;
+import java.util.Set;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
@@ -45,6 +46,8 @@ import javax.net.SocketFactory;
import org.apache.commons.logging.*;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
@@ -80,12 +83,6 @@ public class Client {
private int counter; // counter for call ids
private AtomicBoolean running = new AtomicBoolean(true); // if client runs
final private Configuration conf;
- final private int maxIdleTime; //connections will be culled if it was idle for
- //maxIdleTime msecs
- final private int maxRetries; //the max. no. of retries for socket connections
- private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
- private int pingInterval; // how often sends ping to the server in msecs
- final private boolean doPing; //do we need to send ping message
private SocketFactory socketFactory; // how to create sockets
private int refCount = 1;
@@ -220,6 +217,12 @@ public class Client {
private DataInputStream in;
private DataOutputStream out;
private int rpcTimeout;
+ private int maxIdleTime; //connections will be culled if it was idle for
+ //maxIdleTime msecs
+ private int maxRetries; //the max. no. of retries for socket connections
+ private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
+ private boolean doPing; //do we need to send ping message
+ private int pingInterval; // how often sends ping to the server in msecs
// currently active calls
private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
@@ -235,6 +238,15 @@ public class Client {
remoteId.getAddress().getHostName());
}
this.rpcTimeout = remoteId.getRpcTimeout();
+ this.maxIdleTime = remoteId.getMaxIdleTime();
+ this.maxRetries = remoteId.getMaxRetries();
+ this.tcpNoDelay = remoteId.getTcpNoDelay();
+ this.doPing = remoteId.getDoPing();
+ this.pingInterval = remoteId.getPingInterval();
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("The ping interval is" + this.pingInterval + "ms.");
+ }
+
UserGroupInformation ticket = remoteId.getTicket();
Class<?> protocol = remoteId.getProtocol();
this.useSasl = UserGroupInformation.isSecurityEnabled();
@@ -256,15 +268,9 @@ public class Client {
}
KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
if (krbInfo != null) {
- String serverKey = krbInfo.serverPrincipal();
- if (serverKey == null) {
- throw new IOException(
- "Can't obtain server Kerberos config key from KerberosInfo");
- }
- serverPrincipal = SecurityUtil.getServerPrincipal(
- conf.get(serverKey), server.getAddress().getCanonicalHostName());
+ serverPrincipal = remoteId.getServerPrincipal();
if (LOG.isDebugEnabled()) {
- LOG.debug("RPC Server Kerberos principal name for protocol="
+ LOG.debug("RPC Server's Kerberos principal name for protocol="
+ protocol.getCanonicalName() + " is " + serverPrincipal);
}
}
@@ -882,15 +888,6 @@ public class Client {
public Client(Class<? extends Writable> valueClass, Configuration conf,
SocketFactory factory) {
this.valueClass = valueClass;
- this.maxIdleTime =
- conf.getInt("ipc.client.connection.maxidletime", 10000); //10s
- this.maxRetries = conf.getInt("ipc.client.connect.max.retries", 10);
- this.tcpNoDelay = conf.getBoolean("ipc.client.tcpnodelay", false);
- this.doPing = conf.getBoolean("ipc.client.ping", true);
- this.pingInterval = getPingInterval(conf);
- if (LOG.isDebugEnabled()) {
- LOG.debug("The ping interval is" + this.pingInterval + "ms.");
- }
this.conf = conf;
this.socketFactory = factory;
}
@@ -942,7 +939,7 @@ public class Client {
/** Make a call, passing <code>param</code>, to the IPC server running at
* <code>address</code>, returning the value. Throws exceptions if there are
* network problems or if the remote code threw an exception.
- * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead
+ * @deprecated Use {@link #call(Writable, ConnectionId)} instead
*/
@Deprecated
public Writable call(Writable param, InetSocketAddress address)
@@ -955,27 +952,60 @@ public class Client {
* the value.
* Throws exceptions if there are network problems or if the remote code
* threw an exception.
- * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead
+ * @deprecated Use {@link #call(Writable, ConnectionId)} instead
*/
@Deprecated
public Writable call(Writable param, InetSocketAddress addr,
UserGroupInformation ticket)
throws InterruptedException, IOException {
- return call(param, addr, null, ticket, 0);
+ ConnectionId remoteId = ConnectionId.getConnectionId(addr, null, ticket, 0,
+ conf);
+ return call(param, remoteId);
}
/** Make a call, passing <code>param</code>, to the IPC server running at
* <code>address</code> which is servicing the <code>protocol</code> protocol,
- * with the <code>ticket</code> credentials, returning the value.
+ * with the <code>ticket</code> credentials and <code>rpcTimeout</code> as
+ * timeout, returning the value.
* Throws exceptions if there are network problems or if the remote code
- * threw an exception. */
+ * threw an exception.
+ * @deprecated Use {@link #call(Writable, ConnectionId)} instead
+ */
+ @Deprecated
public Writable call(Writable param, InetSocketAddress addr,
Class<?> protocol, UserGroupInformation ticket,
int rpcTimeout)
throws InterruptedException, IOException {
+ ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
+ ticket, rpcTimeout, conf);
+ return call(param, remoteId);
+ }
+
+ /**
+ * Make a call, passing <code>param</code>, to the IPC server running at
+ * <code>address</code> which is servicing the <code>protocol</code> protocol,
+ * with the <code>ticket</code> credentials, <code>rpcTimeout</code> as
+ * timeout and <code>conf</code> as conf for this connection, returning the
+ * value. Throws exceptions if there are network problems or if the remote
+ * code threw an exception.
+ */
+ public Writable call(Writable param, InetSocketAddress addr,
+ Class<?> protocol, UserGroupInformation ticket,
+ int rpcTimeout, Configuration conf)
+ throws InterruptedException, IOException {
+ ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
+ ticket, rpcTimeout, conf);
+ return call(param, remoteId);
+ }
+
+ /** Make a call, passing <code>param</code>, to the IPC server defined by
+ * <code>remoteId</code>, returning the value.
+ * Throws exceptions if there are network problems or if the remote code
+ * threw an exception. */
+ public Writable call(Writable param, ConnectionId remoteId)
+ throws InterruptedException, IOException {
Call call = new Call(param);
- Connection connection = getConnection(
- addr, protocol, ticket, rpcTimeout, call);
+ Connection connection = getConnection(remoteId, call);
connection.sendParam(call); // send the parameter
boolean interrupted = false;
synchronized (call) {
@@ -998,7 +1028,7 @@ public class Client {
call.error.fillInStackTrace();
throw call.error;
} else { // local exception
- throw wrapException(addr, call.error);
+ throw wrapException(remoteId.getAddress(), call.error);
}
} else {
return call.value;
@@ -1038,25 +1068,34 @@ public class Client {
}
/**
- * Makes a set of calls in parallel. Each parameter is sent to the
- * corresponding address. When all values are available, or have timed out
- * or errored, the collected results are returned in an array. The array
- * contains nulls for calls that timed out or errored.
- * @deprecated Use {@link #call(Writable[], InetSocketAddress[], Class, UserGroupInformation)} instead
+ * @deprecated Use {@link #call(Writable[], InetSocketAddress[],
+ * Class, UserGroupInformation, Configuration)} instead
*/
@Deprecated
public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
throws IOException, InterruptedException {
- return call(params, addresses, null, null);
+ return call(params, addresses, null, null, conf);
+ }
+
+ /**
+ * @deprecated Use {@link #call(Writable[], InetSocketAddress[],
+ * Class, UserGroupInformation, Configuration)} instead
+ */
+ @Deprecated
+ public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
+ Class<?> protocol, UserGroupInformation ticket)
+ throws IOException, InterruptedException {
+ return call(params, addresses, protocol, ticket, conf);
}
+
/** Makes a set of calls in parallel. Each parameter is sent to the
* corresponding address. When all values are available, or have timed out
* or errored, the collected results are returned in an array. The array
* contains nulls for calls that timed out or errored. */
- public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
- Class<?> protocol, UserGroupInformation ticket)
- throws IOException, InterruptedException {
+ public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
+ Class<?> protocol, UserGroupInformation ticket, Configuration conf)
+ throws IOException, InterruptedException {
if (addresses.length == 0) return new Writable[0];
ParallelResults results = new ParallelResults(params.length);
@@ -1064,8 +1103,9 @@ public class Client {
for (int i = 0; i < params.length; i++) {
ParallelCall call = new ParallelCall(params[i], results, i);
try {
- Connection connection =
- getConnection(addresses[i], protocol, ticket, 0, call);
+ ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i],
+ protocol, ticket, 0, conf);
+ Connection connection = getConnection(remoteId, call);
connection.sendParam(call); // send each parameter
} catch (IOException e) {
// log errors
@@ -1084,12 +1124,18 @@ public class Client {
}
}
+ // for unit testing only
+ @InterfaceAudience.Private
+ @InterfaceStability.Unstable
+ Set<ConnectionId> getConnectionIds() {
+ synchronized (connections) {
+ return connections.keySet();
+ }
+ }
+
/** Get a connection from the pool, or create a new one and add it to the
- * pool. Connections to a given host/port are reused. */
- private Connection getConnection(InetSocketAddress addr,
- Class<?> protocol,
- UserGroupInformation ticket,
- int rpcTimeout,
+ * pool. Connections to a given ConnectionId are reused. */
+ private Connection getConnection(ConnectionId remoteId,
Call call)
throws IOException, InterruptedException {
if (!running.get()) {
@@ -1101,8 +1147,6 @@ public class Client {
* connectionsId object and with set() method. We need to manage the
* refs for keys in HashMap properly. For now its ok.
*/
- ConnectionId remoteId = new ConnectionId(
- addr, protocol, ticket, rpcTimeout);
do {
synchronized (connections) {
connection = connections.get(remoteId);
@@ -1120,24 +1164,40 @@ public class Client {
connection.setupIOstreams();
return connection;
}
-
+
/**
* This class holds the address and the user ticket. The client connections
* to servers are uniquely identified by <remoteAddress, protocol, ticket>
*/
- private static class ConnectionId {
+ static class ConnectionId {
InetSocketAddress address;
UserGroupInformation ticket;
Class<?> protocol;
private static final int PRIME = 16777619;
private int rpcTimeout;
+ private String serverPrincipal;
+ private int maxIdleTime; //connections will be culled if it was idle for
+ //maxIdleTime msecs
+ private int maxRetries; //the max. no. of retries for socket connections
+ private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
+ private boolean doPing; //do we need to send ping message
+ private int pingInterval; // how often sends ping to the server in msecs
ConnectionId(InetSocketAddress address, Class<?> protocol,
- UserGroupInformation ticket, int rpcTimeout) {
+ UserGroupInformation ticket, int rpcTimeout,
+ String serverPrincipal, int maxIdleTime,
+ int maxRetries, boolean tcpNoDelay,
+ boolean doPing, int pingInterval) {
this.protocol = protocol;
this.address = address;
this.ticket = ticket;
this.rpcTimeout = rpcTimeout;
+ this.serverPrincipal = serverPrincipal;
+ this.maxIdleTime = maxIdleTime;
+ this.maxRetries = maxRetries;
+ this.tcpNoDelay = tcpNoDelay;
+ this.doPing = doPing;
+ this.pingInterval = pingInterval;
}
InetSocketAddress getAddress() {
@@ -1156,25 +1216,102 @@ public class Client {
return rpcTimeout;
}
+ String getServerPrincipal() {
+ return serverPrincipal;
+ }
+
+ int getMaxIdleTime() {
+ return maxIdleTime;
+ }
+
+ int getMaxRetries() {
+ return maxRetries;
+ }
+
+ boolean getTcpNoDelay() {
+ return tcpNoDelay;
+ }
+
+ boolean getDoPing() {
+ return doPing;
+ }
+
+ int getPingInterval() {
+ return pingInterval;
+ }
+
+ static ConnectionId getConnectionId(InetSocketAddress addr,
+ Class<?> protocol, UserGroupInformation ticket, int rpcTimeout,
+ Configuration conf) throws IOException {
+ String remotePrincipal = getRemotePrincipal(conf, addr, protocol);
+ return new ConnectionId(addr, protocol, ticket,
+ rpcTimeout, remotePrincipal,
+ conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s
+ conf.getInt("ipc.client.connect.max.retries", 10),
+ conf.getBoolean("ipc.client.tcpnodelay", false),
+ conf.getBoolean("ipc.client.ping", true),
+ Client.getPingInterval(conf));
+ }
+
+ private static String getRemotePrincipal(Configuration conf,
+ InetSocketAddress address, Class<?> protocol) throws IOException {
+ if (protocol == null) {
+ return null;
+ }
+ KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
+ if (krbInfo != null) {
+ String serverKey = krbInfo.serverPrincipal();
+ if (serverKey == null) {
+ throw new IOException(
+ "Can't obtain server Kerberos config key from protocol="
+ + protocol.getCanonicalName());
+ }
+ return SecurityUtil.getServerPrincipal(conf.get(serverKey), address
+ .getAddress().getCanonicalHostName());
+ }
+ return null;
+ }
+
+ static boolean isEqual(Object a, Object b) {
+ return a == null ? b == null : a.equals(b);
+ }
+
@Override
public boolean equals(Object obj) {
- if (obj instanceof ConnectionId) {
- ConnectionId id = (ConnectionId) obj;
- return address.equals(id.address) && protocol == id.protocol &&
- ((ticket != null && ticket.equals(id.ticket)) ||
- (ticket == id.ticket)) && rpcTimeout == id.rpcTimeout;
- }
- return false;
+ if (obj == this) {
+ return true;
+ }
+ if (obj instanceof ConnectionId) {
+ ConnectionId that = (ConnectionId) obj;
+ return isEqual(this.address, that.address)
+ && this.doPing == that.doPing
+ && this.maxIdleTime == that.maxIdleTime
+ && this.maxRetries == that.maxRetries
+ && this.pingInterval == that.pingInterval
+ && isEqual(this.protocol, that.protocol)
+ && this.rpcTimeout == that.rpcTimeout
+ && isEqual(this.serverPrincipal, that.serverPrincipal)
+ && this.tcpNoDelay == that.tcpNoDelay
+ && isEqual(this.ticket, that.ticket);
+ }
+ return false;
}
- @Override // simply use the default Object#hashcode() ?
+ @Override
public int hashCode() {
- return (address.hashCode() + PRIME * (
- PRIME * (
- PRIME * System.identityHashCode(protocol) ^
- System.identityHashCode(ticket)
- ) ^ System.identityHashCode(rpcTimeout)
- ));
+ int result = 1;
+ result = PRIME * result + ((address == null) ? 0 : address.hashCode());
+ result = PRIME * result + (doPing ? 1231 : 1237);
+ result = PRIME * result + maxIdleTime;
+ result = PRIME * result + maxRetries;
+ result = PRIME * result + pingInterval;
+ result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode());
+ result = PRIME * result + rpcTimeout;
+ result = PRIME * result
+ + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode());
+ result = PRIME * result + (tcpNoDelay ? 1231 : 1237);
+ result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode());
+ return result;
}
}
}
Modified: hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java (original)
+++ hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java Thu Sep 2 00:35:30 2010
@@ -38,6 +38,8 @@ import org.apache.hadoop.security.UserGr
import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.*;
import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
@@ -172,21 +174,16 @@ class WritableRpcEngine implements RpcEn
private static ClientCache CLIENTS=new ClientCache();
private static class Invoker implements InvocationHandler {
- private Class<?> protocol;
- private InetSocketAddress address;
- private UserGroupInformation ticket;
- private int rpcTimeout;
+ private Client.ConnectionId remoteId;
private Client client;
private boolean isClosed = false;
public Invoker(Class<?> protocol,
InetSocketAddress address, UserGroupInformation ticket,
Configuration conf, SocketFactory factory,
- int rpcTimeout) {
- this.protocol = protocol;
- this.address = address;
- this.ticket = ticket;
- this.rpcTimeout = rpcTimeout;
+ int rpcTimeout) throws IOException {
+ this.remoteId = Client.ConnectionId.getConnectionId(address, protocol,
+ ticket, rpcTimeout, conf);
this.client = CLIENTS.getClient(conf, factory);
}
@@ -198,8 +195,7 @@ class WritableRpcEngine implements RpcEn
}
ObjectWritable value = (ObjectWritable)
- client.call(new Invocation(method, args), address,
- protocol, ticket, rpcTimeout);
+ client.call(new Invocation(method, args), remoteId);
if (LOG.isDebugEnabled()) {
long callTime = System.currentTimeMillis() - startTime;
LOG.debug("Call: " + method.getName() + " " + callTime);
@@ -216,6 +212,13 @@ class WritableRpcEngine implements RpcEn
}
}
+ // for unit testing only
+ @InterfaceAudience.Private
+ @InterfaceStability.Unstable
+ static Client getClient(Configuration conf) {
+ return CLIENTS.getClient(conf);
+ }
+
/** Construct a client-side proxy object that implements the named protocol,
* talking to a server at the named address. */
public Object getProxy(Class<?> protocol, long clientVersion,
@@ -259,7 +262,7 @@ class WritableRpcEngine implements RpcEn
Client client = CLIENTS.getClient(conf);
try {
Writable[] wrappedValues =
- client.call(invocations, addrs, method.getDeclaringClass(), ticket);
+ client.call(invocations, addrs, method.getDeclaringClass(), ticket, conf);
if (method.getReturnType() == Void.TYPE) {
return null;
Modified: hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java (original)
+++ hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java Thu Sep 2 00:35:30 2010
@@ -94,7 +94,7 @@ public class TestIPC extends TestCase {
try {
LongWritable param = new LongWritable(RANDOM.nextLong());
LongWritable value =
- (LongWritable)client.call(param, server, null, null, 0);
+ (LongWritable)client.call(param, server, null, null, 0, conf);
if (!param.equals(value)) {
LOG.fatal("Call failed!");
failed = true;
@@ -127,7 +127,7 @@ public class TestIPC extends TestCase {
Writable[] params = new Writable[addresses.length];
for (int j = 0; j < addresses.length; j++)
params[j] = new LongWritable(RANDOM.nextLong());
- Writable[] values = client.call(params, addresses, null, null);
+ Writable[] values = client.call(params, addresses, null, null, conf);
for (int j = 0; j < addresses.length; j++) {
if (!params[j].equals(values[j])) {
LOG.fatal("Call failed!");
@@ -223,7 +223,7 @@ public class TestIPC extends TestCase {
InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
try {
client.call(new LongWritable(RANDOM.nextLong()),
- address, null, null, 0);
+ address, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
String message = e.getMessage();
@@ -280,7 +280,7 @@ public class TestIPC extends TestCase {
Client client = new Client(LongErrorWritable.class, conf);
try {
client.call(new LongErrorWritable(RANDOM.nextLong()),
- addr, null, null, 0);
+ addr, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
// check error
@@ -300,7 +300,7 @@ public class TestIPC extends TestCase {
Client client = new Client(LongRTEWritable.class, conf);
try {
client.call(new LongRTEWritable(RANDOM.nextLong()),
- addr, null, null, 0);
+ addr, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
// check error
@@ -326,7 +326,7 @@ public class TestIPC extends TestCase {
InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
try {
client.call(new LongWritable(RANDOM.nextLong()),
- address, null, null, 0);
+ address, null, null, 0, conf);
fail("Expected an exception to have been thrown");
} catch (IOException e) {
assertTrue(e.getMessage().contains("Injected fault"));
@@ -344,14 +344,14 @@ public class TestIPC extends TestCase {
// set timeout to be less than MIN_SLEEP_TIME
try {
client.call(new LongWritable(RANDOM.nextLong()),
- addr, null, null, MIN_SLEEP_TIME/2);
+ addr, null, null, MIN_SLEEP_TIME/2, conf);
fail("Expected an exception to have been thrown");
} catch (SocketTimeoutException e) {
LOG.info("Get a SocketTimeoutException ", e);
}
// set timeout to be bigger than 3*ping interval
client.call(new LongWritable(RANDOM.nextLong()),
- addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME);
+ addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME, conf);
}
public static void main(String[] args) throws Exception {
Modified: hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java (original)
+++ hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java Thu Sep 2 00:35:30 2010
@@ -27,6 +27,7 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.PrivilegedExceptionAction;
import java.util.Collection;
+import java.util.Set;
import javax.security.sasl.Sasl;
@@ -37,6 +38,7 @@ import org.apache.commons.logging.impl.L
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
+import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.token.SecretManager;
@@ -66,6 +68,9 @@ public class TestSaslRPC {
static final String ERROR_MESSAGE = "Token is invalid";
static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
+ static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
+ static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";
+
private static Configuration conf;
static {
conf = new Configuration();
@@ -249,6 +254,63 @@ public class TestSaslRPC {
}
}
+ @Test
+ public void testPerConnectionConf() throws Exception {
+ TestTokenSecretManager sm = new TestTokenSecretManager();
+ final Server server = RPC.getServer(TestSaslProtocol.class,
+ new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
+ server.start();
+ final UserGroupInformation current = UserGroupInformation.getCurrentUser();
+ final InetSocketAddress addr = NetUtils.getConnectAddress(server);
+ TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
+ .getUserName()));
+ Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
+ sm);
+ Text host = new Text(addr.getAddress().getHostAddress() + ":"
+ + addr.getPort());
+ token.setService(host);
+ LOG.info("Service IP address for token is " + host);
+ current.addToken(token);
+
+ Configuration newConf = new Configuration(conf);
+ newConf.set("hadoop.rpc.socket.factory.class.default", "");
+ newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
+
+ TestSaslProtocol proxy1 = null;
+ TestSaslProtocol proxy2 = null;
+ TestSaslProtocol proxy3 = null;
+ try {
+ proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+ TestSaslProtocol.versionID, addr, newConf);
+ Client client = WritableRpcEngine.getClient(conf);
+ Set<ConnectionId> conns = client.getConnectionIds();
+ assertEquals("number of connections in cache is wrong", 1, conns.size());
+ // same conf, connection should be re-used
+ proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+ TestSaslProtocol.versionID, addr, newConf);
+ assertEquals("number of connections in cache is wrong", 1, conns.size());
+ // different conf, new connection should be set up
+ newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
+ proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+ TestSaslProtocol.versionID, addr, newConf);
+ ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
+ assertEquals("number of connections in cache is wrong", 2,
+ connsArray.length);
+ String p1 = connsArray[0].getServerPrincipal();
+ String p2 = connsArray[1].getServerPrincipal();
+ assertFalse("should have different principals", p1.equals(p2));
+ assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1)
+ || p1.equals(SERVER_PRINCIPAL_2));
+ assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1)
+ || p2.equals(SERVER_PRINCIPAL_2));
+ } finally {
+ server.stop();
+ RPC.stopProxy(proxy1);
+ RPC.stopProxy(proxy2);
+ RPC.stopProxy(proxy3);
+ }
+ }
+
static void testKerberosRpc(String principal, String keytab) throws Exception {
final Configuration newConf = new Configuration(conf);
newConf.set(SERVER_PRINCIPAL_KEY, principal);