You are viewing a plain text version of this content. The canonical link for it is here.
Posted to common-commits@hadoop.apache.org by ha...@apache.org on 2010/09/02 02:35:30 UTC

svn commit: r991780 - in /hadoop/common/trunk: ./ src/java/org/apache/hadoop/ipc/ src/test/core/org/apache/hadoop/ipc/

Author: hairong
Date: Thu Sep  2 00:35:30 2010
New Revision: 991780

URL: http://svn.apache.org/viewvc?rev=991780&view=rev
Log:
HADOOP-6907. Rpc client doesn't use the per-connection conf to figure out server's Kerberos principal. Contributed by Kan Zhang.

Modified:
    hadoop/common/trunk/CHANGES.txt
    hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java
    hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java
    hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java
    hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java

Modified: hadoop/common/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/CHANGES.txt?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/CHANGES.txt (original)
+++ hadoop/common/trunk/CHANGES.txt Thu Sep  2 00:35:30 2010
@@ -226,6 +226,9 @@ Trunk (unreleased changes)
     HADOOP-6913. Circular initialization between UserGroupInformation and 
     KerberosName (Kan Zhang via boryas)
 
+    HADOOP-6907. Rpc client doesn't use the per-connection conf to figure
+    out server's Kerberos principal (Kan Zhang via hairong)
+
 Release 0.21.0 - Unreleased
 
   INCOMPATIBLE CHANGES

Modified: hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java (original)
+++ hadoop/common/trunk/src/java/org/apache/hadoop/ipc/Client.java Thu Sep  2 00:35:30 2010
@@ -37,6 +37,7 @@ import java.security.PrivilegedException
 import java.util.Hashtable;
 import java.util.Iterator;
 import java.util.Random;
+import java.util.Set;
 import java.util.Map.Entry;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
@@ -45,6 +46,8 @@ import javax.net.SocketFactory;
 
 import org.apache.commons.logging.*;
 
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.IOUtils;
 import org.apache.hadoop.io.Text;
@@ -80,12 +83,6 @@ public class Client {
   private int counter;                            // counter for call ids
   private AtomicBoolean running = new AtomicBoolean(true); // if client runs
   final private Configuration conf;
-  final private int maxIdleTime; //connections will be culled if it was idle for 
-                           //maxIdleTime msecs
-  final private int maxRetries; //the max. no. of retries for socket connections
-  private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
-  private int pingInterval; // how often sends ping to the server in msecs
-  final private boolean doPing; //do we need to send ping message
 
   private SocketFactory socketFactory;           // how to create sockets
   private int refCount = 1;
@@ -220,6 +217,12 @@ public class Client {
     private DataInputStream in;
     private DataOutputStream out;
     private int rpcTimeout;
+    private int maxIdleTime; // a connection is culled once it has been idle
+    // for maxIdleTime msecs
+    private int maxRetries; //the max. no. of retries for socket connections
+    private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
+    private boolean doPing; //do we need to send ping message
+    private int pingInterval; // how often to send a ping to the server, in msecs
     
     // currently active calls
     private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
@@ -235,6 +238,15 @@ public class Client {
                                        remoteId.getAddress().getHostName());
       }
       this.rpcTimeout = remoteId.getRpcTimeout();
+      this.maxIdleTime = remoteId.getMaxIdleTime();
+      this.maxRetries = remoteId.getMaxRetries();
+      this.tcpNoDelay = remoteId.getTcpNoDelay();
+      this.doPing = remoteId.getDoPing();
+      this.pingInterval = remoteId.getPingInterval();
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("The ping interval is" + this.pingInterval + "ms.");
+      }
+
       UserGroupInformation ticket = remoteId.getTicket();
       Class<?> protocol = remoteId.getProtocol();
       this.useSasl = UserGroupInformation.isSecurityEnabled();
@@ -256,15 +268,9 @@ public class Client {
         }
         KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
         if (krbInfo != null) {
-          String serverKey = krbInfo.serverPrincipal();
-          if (serverKey == null) {
-            throw new IOException(
-                "Can't obtain server Kerberos config key from KerberosInfo");
-          }
-          serverPrincipal = SecurityUtil.getServerPrincipal(
-              conf.get(serverKey), server.getAddress().getCanonicalHostName());
+          serverPrincipal = remoteId.getServerPrincipal();
           if (LOG.isDebugEnabled()) {
-            LOG.debug("RPC Server Kerberos principal name for protocol="
+            LOG.debug("RPC Server's Kerberos principal name for protocol="
                 + protocol.getCanonicalName() + " is " + serverPrincipal);
           }
         }
@@ -882,15 +888,6 @@ public class Client {
   public Client(Class<? extends Writable> valueClass, Configuration conf, 
       SocketFactory factory) {
     this.valueClass = valueClass;
-    this.maxIdleTime = 
-      conf.getInt("ipc.client.connection.maxidletime", 10000); //10s
-    this.maxRetries = conf.getInt("ipc.client.connect.max.retries", 10);
-    this.tcpNoDelay = conf.getBoolean("ipc.client.tcpnodelay", false);
-    this.doPing = conf.getBoolean("ipc.client.ping", true);
-    this.pingInterval = getPingInterval(conf);
-    if (LOG.isDebugEnabled()) {
-      LOG.debug("The ping interval is" + this.pingInterval + "ms.");
-    }
     this.conf = conf;
     this.socketFactory = factory;
   }
@@ -942,7 +939,7 @@ public class Client {
   /** Make a call, passing <code>param</code>, to the IPC server running at
    * <code>address</code>, returning the value.  Throws exceptions if there are
    * network problems or if the remote code threw an exception.
-   * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead 
+   * @deprecated Use {@link #call(Writable, ConnectionId)} instead 
    */
   @Deprecated
   public Writable call(Writable param, InetSocketAddress address)
@@ -955,27 +952,60 @@ public class Client {
    * the value.  
    * Throws exceptions if there are network problems or if the remote code 
    * threw an exception.
-   * @deprecated Use {@link #call(Writable, InetSocketAddress, Class, UserGroupInformation, int)} instead 
+   * @deprecated Use {@link #call(Writable, ConnectionId)} instead 
    */
   @Deprecated
   public Writable call(Writable param, InetSocketAddress addr, 
       UserGroupInformation ticket)  
       throws InterruptedException, IOException {
-    return call(param, addr, null, ticket, 0);
+    ConnectionId remoteId = ConnectionId.getConnectionId(addr, null, ticket, 0,
+        conf);
+    return call(param, remoteId);
   }
   
   /** Make a call, passing <code>param</code>, to the IPC server running at
    * <code>address</code> which is servicing the <code>protocol</code> protocol, 
-   * with the <code>ticket</code> credentials, returning the value.  
+   * with the <code>ticket</code> credentials and <code>rpcTimeout</code> as 
+   * timeout, returning the value.  
    * Throws exceptions if there are network problems or if the remote code 
-   * threw an exception. */
+   * threw an exception. 
+   * @deprecated Use {@link #call(Writable, ConnectionId)} instead 
+   */
+  @Deprecated
   public Writable call(Writable param, InetSocketAddress addr, 
                        Class<?> protocol, UserGroupInformation ticket,
                        int rpcTimeout)  
                        throws InterruptedException, IOException {
+    ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
+        ticket, rpcTimeout, conf);
+    return call(param, remoteId);
+  }
+
+  /**
+   * Make a call, passing <code>param</code>, to the IPC server running at
+   * <code>addr</code> which is servicing the <code>protocol</code> protocol,
+   * with the <code>ticket</code> credentials, <code>rpcTimeout</code> as the
+   * timeout and <code>conf</code> as the configuration for this connection,
+   * returning the value. Throws exceptions if there are network problems or
+   * if the remote code threw an exception.
+   */
+  public Writable call(Writable param, InetSocketAddress addr, 
+                       Class<?> protocol, UserGroupInformation ticket,
+                       int rpcTimeout, Configuration conf)  
+                       throws InterruptedException, IOException {
+    ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol,
+        ticket, rpcTimeout, conf);
+    return call(param, remoteId);
+  }
+  
+  /** Make a call, passing <code>param</code>, to the IPC server defined by
+   * <code>remoteId</code>, returning the value.  
+   * Throws exceptions if there are network problems or if the remote code 
+   * threw an exception. */
+  public Writable call(Writable param, ConnectionId remoteId)  
+      throws InterruptedException, IOException {
     Call call = new Call(param);
-    Connection connection = getConnection(
-        addr, protocol, ticket, rpcTimeout, call);
+    Connection connection = getConnection(remoteId, call);
     connection.sendParam(call);                 // send the parameter
     boolean interrupted = false;
     synchronized (call) {
@@ -998,7 +1028,7 @@ public class Client {
           call.error.fillInStackTrace();
           throw call.error;
         } else { // local exception
-          throw wrapException(addr, call.error);
+          throw wrapException(remoteId.getAddress(), call.error);
         }
       } else {
         return call.value;
@@ -1038,25 +1068,34 @@ public class Client {
   }
 
   /** 
-   * Makes a set of calls in parallel.  Each parameter is sent to the
-   * corresponding address.  When all values are available, or have timed out
-   * or errored, the collected results are returned in an array.  The array
-   * contains nulls for calls that timed out or errored.
-   * @deprecated Use {@link #call(Writable[], InetSocketAddress[], Class, UserGroupInformation)} instead 
+   * @deprecated Use {@link #call(Writable[], InetSocketAddress[], 
+   * Class, UserGroupInformation, Configuration)} instead 
    */
   @Deprecated
   public Writable[] call(Writable[] params, InetSocketAddress[] addresses)
     throws IOException, InterruptedException {
-    return call(params, addresses, null, null);
+    return call(params, addresses, null, null, conf);
+  }
+  
+  /**  
+   * @deprecated Use {@link #call(Writable[], InetSocketAddress[], 
+   * Class, UserGroupInformation, Configuration)} instead 
+   */
+  @Deprecated
+  public Writable[] call(Writable[] params, InetSocketAddress[] addresses, 
+                         Class<?> protocol, UserGroupInformation ticket)
+    throws IOException, InterruptedException {
+    return call(params, addresses, protocol, ticket, conf);
   }
   
+
   /** Makes a set of calls in parallel.  Each parameter is sent to the
    * corresponding address.  When all values are available, or have timed out
    * or errored, the collected results are returned in an array.  The array
    * contains nulls for calls that timed out or errored.  */
-  public Writable[] call(Writable[] params, InetSocketAddress[] addresses, 
-                         Class<?> protocol, UserGroupInformation ticket)
-    throws IOException, InterruptedException {
+  public Writable[] call(Writable[] params, InetSocketAddress[] addresses,
+      Class<?> protocol, UserGroupInformation ticket, Configuration conf)
+      throws IOException, InterruptedException {
     if (addresses.length == 0) return new Writable[0];
 
     ParallelResults results = new ParallelResults(params.length);
@@ -1064,8 +1103,9 @@ public class Client {
       for (int i = 0; i < params.length; i++) {
         ParallelCall call = new ParallelCall(params[i], results, i);
         try {
-          Connection connection = 
-            getConnection(addresses[i], protocol, ticket, 0, call);
+          ConnectionId remoteId = ConnectionId.getConnectionId(addresses[i],
+              protocol, ticket, 0, conf);
+          Connection connection = getConnection(remoteId, call);
           connection.sendParam(call);             // send each parameter
         } catch (IOException e) {
           // log errors
@@ -1084,12 +1124,18 @@ public class Client {
     }
   }
 
+  // for unit testing only
+  @InterfaceAudience.Private
+  @InterfaceStability.Unstable
+  Set<ConnectionId> getConnectionIds() {
+    synchronized (connections) {
+      return connections.keySet();
+    }
+  }
+  
   /** Get a connection from the pool, or create a new one and add it to the
-   * pool.  Connections to a given host/port are reused. */
-  private Connection getConnection(InetSocketAddress addr,
-                                   Class<?> protocol,
-                                   UserGroupInformation ticket,
-                                   int rpcTimeout,
+   * pool.  Connections to a given ConnectionId are reused. */
+  private Connection getConnection(ConnectionId remoteId,
                                    Call call)
                                    throws IOException, InterruptedException {
     if (!running.get()) {
@@ -1101,8 +1147,6 @@ public class Client {
      * connectionsId object and with set() method. We need to manage the
      * refs for keys in HashMap properly. For now its ok.
      */
-    ConnectionId remoteId = new ConnectionId(
-        addr, protocol, ticket, rpcTimeout);
     do {
       synchronized (connections) {
         connection = connections.get(remoteId);
@@ -1120,24 +1164,40 @@ public class Client {
     connection.setupIOstreams();
     return connection;
   }
-
+  
   /**
    * This class holds the address and the user ticket. The client connections
    * to servers are uniquely identified by <remoteAddress, protocol, ticket>
    */
-  private static class ConnectionId {
+  static class ConnectionId {
     InetSocketAddress address;
     UserGroupInformation ticket;
     Class<?> protocol;
     private static final int PRIME = 16777619;
     private int rpcTimeout;
+    private String serverPrincipal;
+    private int maxIdleTime; // a connection is culled once it has been idle
+    // for maxIdleTime msecs
+    private int maxRetries; //the max. no. of retries for socket connections
+    private boolean tcpNoDelay; // if T then disable Nagle's Algorithm
+    private boolean doPing; //do we need to send ping message
+    private int pingInterval; // how often to send a ping to the server, in msecs
     
     ConnectionId(InetSocketAddress address, Class<?> protocol, 
-                 UserGroupInformation ticket, int rpcTimeout) {
+                 UserGroupInformation ticket, int rpcTimeout,
+                 String serverPrincipal, int maxIdleTime, 
+                 int maxRetries, boolean tcpNoDelay,
+                 boolean doPing, int pingInterval) {
       this.protocol = protocol;
       this.address = address;
       this.ticket = ticket;
       this.rpcTimeout = rpcTimeout;
+      this.serverPrincipal = serverPrincipal;
+      this.maxIdleTime = maxIdleTime;
+      this.maxRetries = maxRetries;
+      this.tcpNoDelay = tcpNoDelay;
+      this.doPing = doPing;
+      this.pingInterval = pingInterval;
     }
     
     InetSocketAddress getAddress() {
@@ -1156,25 +1216,102 @@ public class Client {
       return rpcTimeout;
     }
     
+    String getServerPrincipal() {
+      return serverPrincipal;
+    }
+    
+    int getMaxIdleTime() {
+      return maxIdleTime;
+    }
+    
+    int getMaxRetries() {
+      return maxRetries;
+    }
+    
+    boolean getTcpNoDelay() {
+      return tcpNoDelay;
+    }
+    
+    boolean getDoPing() {
+      return doPing;
+    }
+    
+    int getPingInterval() {
+      return pingInterval;
+    }
+    
+    static ConnectionId getConnectionId(InetSocketAddress addr,
+        Class<?> protocol, UserGroupInformation ticket, int rpcTimeout,
+        Configuration conf) throws IOException {
+      String remotePrincipal = getRemotePrincipal(conf, addr, protocol);
+      return new ConnectionId(addr, protocol, ticket,
+          rpcTimeout, remotePrincipal,
+          conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s
+          conf.getInt("ipc.client.connect.max.retries", 10),
+          conf.getBoolean("ipc.client.tcpnodelay", false),
+          conf.getBoolean("ipc.client.ping", true),
+          Client.getPingInterval(conf));
+    }
+    
+    private static String getRemotePrincipal(Configuration conf,
+        InetSocketAddress address, Class<?> protocol) throws IOException {
+      if (protocol == null) {
+        return null;
+      }
+      KerberosInfo krbInfo = protocol.getAnnotation(KerberosInfo.class);
+      if (krbInfo != null) {
+        String serverKey = krbInfo.serverPrincipal();
+        if (serverKey == null) {
+          throw new IOException(
+              "Can't obtain server Kerberos config key from protocol="
+                  + protocol.getCanonicalName());
+        }
+        return SecurityUtil.getServerPrincipal(conf.get(serverKey), address
+            .getAddress().getCanonicalHostName());
+      }
+      return null;
+    }
+    
+    static boolean isEqual(Object a, Object b) {
+      return a == null ? b == null : a.equals(b);
+    }
+
     @Override
     public boolean equals(Object obj) {
-     if (obj instanceof ConnectionId) {
-       ConnectionId id = (ConnectionId) obj;
-       return address.equals(id.address) && protocol == id.protocol && 
-              ((ticket != null && ticket.equals(id.ticket)) ||
-               (ticket == id.ticket)) && rpcTimeout == id.rpcTimeout;
-     }
-     return false;
+      if (obj == this) {
+        return true;
+      }
+      if (obj instanceof ConnectionId) {
+        ConnectionId that = (ConnectionId) obj;
+        return isEqual(this.address, that.address)
+            && this.doPing == that.doPing
+            && this.maxIdleTime == that.maxIdleTime
+            && this.maxRetries == that.maxRetries
+            && this.pingInterval == that.pingInterval
+            && isEqual(this.protocol, that.protocol)
+            && this.rpcTimeout == that.rpcTimeout
+            && isEqual(this.serverPrincipal, that.serverPrincipal)
+            && this.tcpNoDelay == that.tcpNoDelay
+            && isEqual(this.ticket, that.ticket);
+      }
+      return false;
     }
     
-    @Override  // simply use the default Object#hashcode() ?
+    @Override
     public int hashCode() {
-      return (address.hashCode() + PRIME * (
-                PRIME * (
-                  PRIME * System.identityHashCode(protocol) ^
-                  System.identityHashCode(ticket)
-                ) ^ System.identityHashCode(rpcTimeout)
-              ));
+      int result = 1;
+      result = PRIME * result + ((address == null) ? 0 : address.hashCode());
+      result = PRIME * result + (doPing ? 1231 : 1237);
+      result = PRIME * result + maxIdleTime;
+      result = PRIME * result + maxRetries;
+      result = PRIME * result + pingInterval;
+      result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode());
+      result = PRIME * result + rpcTimeout;
+      result = PRIME * result
+          + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode());
+      result = PRIME * result + (tcpNoDelay ? 1231 : 1237);
+      result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode());
+      return result;
     }
   }  
 }

Modified: hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java (original)
+++ hadoop/common/trunk/src/java/org/apache/hadoop/ipc/WritableRpcEngine.java Thu Sep  2 00:35:30 2010
@@ -38,6 +38,8 @@ import org.apache.hadoop.security.UserGr
 import org.apache.hadoop.security.authorize.ServiceAuthorizationManager;
 import org.apache.hadoop.security.token.SecretManager;
 import org.apache.hadoop.security.token.TokenIdentifier;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.conf.*;
 import org.apache.hadoop.metrics.util.MetricsTimeVaryingRate;
 
@@ -172,21 +174,16 @@ class WritableRpcEngine implements RpcEn
   private static ClientCache CLIENTS=new ClientCache();
   
   private static class Invoker implements InvocationHandler {
-    private Class<?> protocol;
-    private InetSocketAddress address;
-    private UserGroupInformation ticket;
-    private int rpcTimeout;
+    private Client.ConnectionId remoteId;
     private Client client;
     private boolean isClosed = false;
 
     public Invoker(Class<?> protocol,
                    InetSocketAddress address, UserGroupInformation ticket,
                    Configuration conf, SocketFactory factory,
-                   int rpcTimeout) {
-      this.protocol = protocol;
-      this.address = address;
-      this.ticket = ticket;
-      this.rpcTimeout = rpcTimeout;
+                   int rpcTimeout) throws IOException {
+      this.remoteId = Client.ConnectionId.getConnectionId(address, protocol,
+          ticket, rpcTimeout, conf);
       this.client = CLIENTS.getClient(conf, factory);
     }
 
@@ -198,8 +195,7 @@ class WritableRpcEngine implements RpcEn
       }
 
       ObjectWritable value = (ObjectWritable)
-        client.call(new Invocation(method, args), address, 
-                    protocol, ticket, rpcTimeout);
+        client.call(new Invocation(method, args), remoteId);
       if (LOG.isDebugEnabled()) {
         long callTime = System.currentTimeMillis() - startTime;
         LOG.debug("Call: " + method.getName() + " " + callTime);
@@ -216,6 +212,13 @@ class WritableRpcEngine implements RpcEn
     }
   }
   
+  // for unit testing only
+  @InterfaceAudience.Private
+  @InterfaceStability.Unstable
+  static Client getClient(Configuration conf) {
+    return CLIENTS.getClient(conf);
+  }
+  
   /** Construct a client-side proxy object that implements the named protocol,
    * talking to a server at the named address. */
   public Object getProxy(Class<?> protocol, long clientVersion,
@@ -259,7 +262,7 @@ class WritableRpcEngine implements RpcEn
     Client client = CLIENTS.getClient(conf);
     try {
     Writable[] wrappedValues = 
-      client.call(invocations, addrs, method.getDeclaringClass(), ticket);
+      client.call(invocations, addrs, method.getDeclaringClass(), ticket, conf);
     
     if (method.getReturnType() == Void.TYPE) {
       return null;

Modified: hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java (original)
+++ hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestIPC.java Thu Sep  2 00:35:30 2010
@@ -94,7 +94,7 @@ public class TestIPC extends TestCase {
         try {
           LongWritable param = new LongWritable(RANDOM.nextLong());
           LongWritable value =
-            (LongWritable)client.call(param, server, null, null, 0);
+            (LongWritable)client.call(param, server, null, null, 0, conf);
           if (!param.equals(value)) {
             LOG.fatal("Call failed!");
             failed = true;
@@ -127,7 +127,7 @@ public class TestIPC extends TestCase {
           Writable[] params = new Writable[addresses.length];
           for (int j = 0; j < addresses.length; j++)
             params[j] = new LongWritable(RANDOM.nextLong());
-          Writable[] values = client.call(params, addresses, null, null);
+          Writable[] values = client.call(params, addresses, null, null, conf);
           for (int j = 0; j < addresses.length; j++) {
             if (!params[j].equals(values[j])) {
               LOG.fatal("Call failed!");
@@ -223,7 +223,7 @@ public class TestIPC extends TestCase {
     InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
     try {
       client.call(new LongWritable(RANDOM.nextLong()),
-              address, null, null, 0);
+              address, null, null, 0, conf);
       fail("Expected an exception to have been thrown");
     } catch (IOException e) {
       String message = e.getMessage();
@@ -280,7 +280,7 @@ public class TestIPC extends TestCase {
     Client client = new Client(LongErrorWritable.class, conf);
     try {
       client.call(new LongErrorWritable(RANDOM.nextLong()),
-              addr, null, null, 0);
+              addr, null, null, 0, conf);
       fail("Expected an exception to have been thrown");
     } catch (IOException e) {
       // check error
@@ -300,7 +300,7 @@ public class TestIPC extends TestCase {
     Client client = new Client(LongRTEWritable.class, conf);
     try {
       client.call(new LongRTEWritable(RANDOM.nextLong()),
-              addr, null, null, 0);
+              addr, null, null, 0, conf);
       fail("Expected an exception to have been thrown");
     } catch (IOException e) {
       // check error
@@ -326,7 +326,7 @@ public class TestIPC extends TestCase {
     InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
     try {
       client.call(new LongWritable(RANDOM.nextLong()),
-              address, null, null, 0);
+              address, null, null, 0, conf);
       fail("Expected an exception to have been thrown");
     } catch (IOException e) {
       assertTrue(e.getMessage().contains("Injected fault"));
@@ -344,14 +344,14 @@ public class TestIPC extends TestCase {
     // set timeout to be less than MIN_SLEEP_TIME
     try {
       client.call(new LongWritable(RANDOM.nextLong()),
-              addr, null, null, MIN_SLEEP_TIME/2);
+              addr, null, null, MIN_SLEEP_TIME/2, conf);
       fail("Expected an exception to have been thrown");
     } catch (SocketTimeoutException e) {
       LOG.info("Get a SocketTimeoutException ", e);
     }
     // set timeout to be bigger than 3*ping interval
     client.call(new LongWritable(RANDOM.nextLong()),
-        addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME);
+        addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME, conf);
   }
   
   public static void main(String[] args) throws Exception {

Modified: hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java
URL: http://svn.apache.org/viewvc/hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java?rev=991780&r1=991779&r2=991780&view=diff
==============================================================================
--- hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java (original)
+++ hadoop/common/trunk/src/test/core/org/apache/hadoop/ipc/TestSaslRPC.java Thu Sep  2 00:35:30 2010
@@ -27,6 +27,7 @@ import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.security.PrivilegedExceptionAction;
 import java.util.Collection;
+import java.util.Set;
 
 import javax.security.sasl.Sasl;
 
@@ -37,6 +38,7 @@ import org.apache.commons.logging.impl.L
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.Text;
+import org.apache.hadoop.ipc.Client.ConnectionId;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.security.KerberosInfo;
 import org.apache.hadoop.security.token.SecretManager;
@@ -66,6 +68,9 @@ public class TestSaslRPC {
   static final String ERROR_MESSAGE = "Token is invalid";
   static final String SERVER_PRINCIPAL_KEY = "test.ipc.server.principal";
   static final String SERVER_KEYTAB_KEY = "test.ipc.server.keytab";
+  static final String SERVER_PRINCIPAL_1 = "p1/foo@BAR";
+  static final String SERVER_PRINCIPAL_2 = "p2/foo@BAR";
+  
   private static Configuration conf;
   static {
     conf = new Configuration();
@@ -249,6 +254,63 @@ public class TestSaslRPC {
     }
   }
   
+  @Test
+  public void testPerConnectionConf() throws Exception {
+    TestTokenSecretManager sm = new TestTokenSecretManager();
+    final Server server = RPC.getServer(TestSaslProtocol.class,
+        new TestSaslImpl(), ADDRESS, 0, 5, true, conf, sm);
+    server.start();
+    final UserGroupInformation current = UserGroupInformation.getCurrentUser();
+    final InetSocketAddress addr = NetUtils.getConnectAddress(server);
+    TestTokenIdentifier tokenId = new TestTokenIdentifier(new Text(current
+        .getUserName()));
+    Token<TestTokenIdentifier> token = new Token<TestTokenIdentifier>(tokenId,
+        sm);
+    Text host = new Text(addr.getAddress().getHostAddress() + ":"
+        + addr.getPort());
+    token.setService(host);
+    LOG.info("Service IP address for token is " + host);
+    current.addToken(token);
+
+    Configuration newConf = new Configuration(conf);
+    newConf.set("hadoop.rpc.socket.factory.class.default", "");
+    newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_1);
+
+    TestSaslProtocol proxy1 = null;
+    TestSaslProtocol proxy2 = null;
+    TestSaslProtocol proxy3 = null;
+    try {
+      proxy1 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+          TestSaslProtocol.versionID, addr, newConf);
+      Client client = WritableRpcEngine.getClient(conf);
+      Set<ConnectionId> conns = client.getConnectionIds();
+      assertEquals("number of connections in cache is wrong", 1, conns.size());
+      // same conf, connection should be re-used
+      proxy2 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+          TestSaslProtocol.versionID, addr, newConf);
+      assertEquals("number of connections in cache is wrong", 1, conns.size());
+      // different conf, new connection should be set up
+      newConf.set(SERVER_PRINCIPAL_KEY, SERVER_PRINCIPAL_2);
+      proxy3 = (TestSaslProtocol) RPC.getProxy(TestSaslProtocol.class,
+          TestSaslProtocol.versionID, addr, newConf);
+      ConnectionId[] connsArray = conns.toArray(new ConnectionId[0]);
+      assertEquals("number of connections in cache is wrong", 2,
+          connsArray.length);
+      String p1 = connsArray[0].getServerPrincipal();
+      String p2 = connsArray[1].getServerPrincipal();
+      assertFalse("should have different principals", p1.equals(p2));
+      assertTrue("principal not as expected", p1.equals(SERVER_PRINCIPAL_1)
+          || p1.equals(SERVER_PRINCIPAL_2));
+      assertTrue("principal not as expected", p2.equals(SERVER_PRINCIPAL_1)
+          || p2.equals(SERVER_PRINCIPAL_2));
+    } finally {
+      server.stop();
+      RPC.stopProxy(proxy1);
+      RPC.stopProxy(proxy2);
+      RPC.stopProxy(proxy3);
+    }
+  }
+  
   static void testKerberosRpc(String principal, String keytab) throws Exception {
     final Configuration newConf = new Configuration(conf);
     newConf.set(SERVER_PRINCIPAL_KEY, principal);