You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kudu.apache.org by da...@apache.org on 2017/01/26 02:19:50 UTC

[2/2] kudu git commit: TLS-negotiation [6/n]: Refactor RPC negotiation

TLS-negotiation [6/n]: Refactor RPC negotiation

The KRPC negotiation process now encompasses more than just SASL
authentication. Feature flags are negotiated, and soon TLS encryption
will be negotiated as well. In anticipation of TLS negotiation, this
commit renames the SaslClient and SaslServer classes to ClientNegotiation
and ServerNegotiation, and refactors them to better reflect their broader
role. Additionally, the Connection class now has fewer responsibilities
and less knowledge of negotiation. Instead, the existing logic
in negotiation.cc has been beefed up to serve as the glue between the
ClientNegotiation and ServerNegotiation classes (which have no knowledge
of Connection), and Connection (which now has no knowledge of
ClientNegotiation/ServerNegotiation). Hopefully this will lead to more
maintainable code in the long term. This refactor is also expected to make
TLS negotiation straightforward to add.

Change-Id: I567b62f3341a1e74342c30c76b63f2ca5d7990bd
Reviewed-on: http://gerrit.cloudera.org:8080/5760
Tested-by: Kudu Jenkins
Reviewed-by: Todd Lipcon <to...@apache.org>
Reviewed-by: Alexey Serbin <as...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/kudu/repo
Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/dc852535
Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/dc852535
Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/dc852535

Branch: refs/heads/master
Commit: dc8525358f48df4142e3347b389612704a9312d1
Parents: b9aa5dd
Author: Dan Burkert <da...@apache.org>
Authored: Fri Jan 20 15:17:44 2017 -0800
Committer: Dan Burkert <da...@apache.org>
Committed: Thu Jan 26 02:19:25 2017 +0000

----------------------------------------------------------------------
 src/kudu/rpc/client_negotiation.cc | 408 +++++++++++-------------
 src/kudu/rpc/client_negotiation.h  | 144 +++++----
 src/kudu/rpc/connection.cc         |  88 ++----
 src/kudu/rpc/connection.h          |  58 ++--
 src/kudu/rpc/messenger.cc          |   8 +-
 src/kudu/rpc/negotiation-test.cc   | 218 ++++++-------
 src/kudu/rpc/negotiation.cc        | 166 +++++-----
 src/kudu/rpc/negotiation.h         |   6 +-
 src/kudu/rpc/reactor.cc            |  87 +++---
 src/kudu/rpc/reactor.h             |  23 +-
 src/kudu/rpc/rpc-test.cc           |   4 +-
 src/kudu/rpc/sasl_common.cc        |  47 +--
 src/kudu/rpc/sasl_common.h         |   2 +-
 src/kudu/rpc/sasl_helper.cc        |  71 +----
 src/kudu/rpc/sasl_helper.h         |  65 ++--
 src/kudu/rpc/server_negotiation.cc | 535 ++++++++++++++++----------------
 src/kudu/rpc/server_negotiation.h  | 135 ++++----
 src/kudu/security/ssl_socket.cc    |   6 +-
 src/kudu/security/ssl_socket.h     |   3 +-
 19 files changed, 951 insertions(+), 1123 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/client_negotiation.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/client_negotiation.cc b/src/kudu/rpc/client_negotiation.cc
index 0eb96b8..a7c6c63 100644
--- a/src/kudu/rpc/client_negotiation.cc
+++ b/src/kudu/rpc/client_negotiation.cc
@@ -20,6 +20,7 @@
 #include <string.h>
 
 #include <map>
+#include <memory>
 #include <set>
 #include <string>
 
@@ -29,7 +30,6 @@
 #include "kudu/gutil/endian.h"
 #include "kudu/gutil/map-util.h"
 #include "kudu/gutil/stl_util.h"
-#include "kudu/gutil/stringprintf.h"
 #include "kudu/gutil/strings/join.h"
 #include "kudu/gutil/strings/substitute.h"
 #include "kudu/gutil/strings/util.h"
@@ -51,21 +51,28 @@ namespace rpc {
 using std::map;
 using std::set;
 using std::string;
-
-static int SaslClientGetoptCb(void* sasl_client, const char* plugin_name, const char* option,
-                       const char** result, unsigned* len) {
-  return static_cast<SaslClient*>(sasl_client)
-    ->GetOptionCb(plugin_name, option, result, len);
+using std::unique_ptr;
+
+static int ClientNegotiationGetoptCb(ClientNegotiation* client_negotiation,
+                                     const char* plugin_name,
+                                     const char* option,
+                                     const char** result,
+                                     unsigned* len) {
+  return client_negotiation->GetOptionCb(plugin_name, option, result, len);
 }
 
-static int SaslClientSimpleCb(void *sasl_client, int id,
-                       const char **result, unsigned *len) {
-  return static_cast<SaslClient*>(sasl_client)->SimpleCb(id, result, len);
+static int ClientNegotiationSimpleCb(ClientNegotiation* client_negotiation,
+                                     int id,
+                                     const char** result,
+                                     unsigned* len) {
+  return client_negotiation->SimpleCb(id, result, len);
 }
 
-static int SaslClientSecretCb(sasl_conn_t* conn, void *sasl_client, int id,
-                       sasl_secret_t** psecret) {
-  return static_cast<SaslClient*>(sasl_client)->SecretCb(conn, id, psecret);
+static int ClientNegotiationSecretCb(sasl_conn_t* conn,
+                                     ClientNegotiation* client_negotiation,
+                                     int id,
+                                     sasl_secret_t** psecret) {
+  return client_negotiation->SecretCb(conn, id, psecret);
 }
 
 // Return an appropriately-typed Status object based on an ErrorStatusPB returned
@@ -85,187 +92,166 @@ static Status StatusFromRpcError(const ErrorStatusPB& error) {
   }
 }
 
-SaslClient::SaslClient(string app_name, Socket* socket)
-    : app_name_(std::move(app_name)),
-      sock_(socket),
+ClientNegotiation::ClientNegotiation(unique_ptr<Socket> socket)
+    : socket_(std::move(socket)),
       helper_(SaslHelper::CLIENT),
-      client_state_(SaslNegotiationState::NEW),
       negotiated_mech_(SaslMechanism::INVALID),
       deadline_(MonoTime::Max()) {
   callbacks_.push_back(SaslBuildCallback(SASL_CB_GETOPT,
-      reinterpret_cast<int (*)()>(&SaslClientGetoptCb), this));
+      reinterpret_cast<int (*)()>(&ClientNegotiationGetoptCb), this));
   callbacks_.push_back(SaslBuildCallback(SASL_CB_AUTHNAME,
-      reinterpret_cast<int (*)()>(&SaslClientSimpleCb), this));
+      reinterpret_cast<int (*)()>(&ClientNegotiationSimpleCb), this));
   callbacks_.push_back(SaslBuildCallback(SASL_CB_PASS,
-      reinterpret_cast<int (*)()>(&SaslClientSecretCb), this));
+      reinterpret_cast<int (*)()>(&ClientNegotiationSecretCb), this));
   callbacks_.push_back(SaslBuildCallback(SASL_CB_LIST_END, nullptr, nullptr));
 }
 
-Status SaslClient::EnablePlain(const string& user, const string& pass) {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+Status ClientNegotiation::EnablePlain(const string& user, const string& pass) {
   RETURN_NOT_OK(helper_.EnablePlain());
   plain_auth_user_ = user;
   plain_pass_ = pass;
   return Status::OK();
 }
 
-Status SaslClient::EnableGSSAPI() {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+Status ClientNegotiation::EnableGSSAPI() {
   return helper_.EnableGSSAPI();
 }
 
-SaslMechanism::Type SaslClient::negotiated_mechanism() const {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEGOTIATED);
+SaslMechanism::Type ClientNegotiation::negotiated_mechanism() const {
   return negotiated_mech_;
 }
 
-void SaslClient::set_local_addr(const Sockaddr& addr) {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+void ClientNegotiation::set_local_addr(const Sockaddr& addr) {
   helper_.set_local_addr(addr);
 }
 
-void SaslClient::set_remote_addr(const Sockaddr& addr) {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+void ClientNegotiation::set_remote_addr(const Sockaddr& addr) {
   helper_.set_remote_addr(addr);
 }
 
-void SaslClient::set_server_fqdn(const string& domain_name) {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+void ClientNegotiation::set_server_fqdn(const string& domain_name) {
   helper_.set_server_fqdn(domain_name);
 }
 
-void SaslClient::set_deadline(const MonoTime& deadline) {
-  DCHECK_NE(client_state_, SaslNegotiationState::NEGOTIATED);
+void ClientNegotiation::set_deadline(const MonoTime& deadline) {
   deadline_ = deadline;
 }
 
-// calls sasl_client_init() and sasl_client_new()
-Status SaslClient::Init(const string& service_type) {
-  RETURN_NOT_OK(SaslInit(app_name_.c_str()));
-
-  // Ensure we are not called more than once.
-  if (client_state_ != SaslNegotiationState::NEW) {
-    return Status::IllegalState("Init() may only be called once per SaslClient object.");
-  }
-
-  // TODO(unknown): Support security flags.
-  unsigned secflags = 0;
-
-  sasl_conn_t* sasl_conn = nullptr;
-  Status s = WrapSaslCall(nullptr /* no conn */, [&]() {
-      return sasl_client_new(
-          service_type.c_str(),         // Registered name of the service using SASL. Required.
-          helper_.server_fqdn(),        // The fully qualified domain name of the remote server.
-          helper_.local_addr_string(),  // Local and remote IP address strings. (NULL disables
-          helper_.remote_addr_string(), //   mechanisms which require this info.)
-          &callbacks_[0],               // Connection-specific callbacks.
-          secflags,                     // Security flags.
-          &sasl_conn);
-    });
-  if (!s.ok()) {
-    return Status::RuntimeError("Unable to create new SASL client",
-                                s.message());
-  }
-  sasl_conn_.reset(sasl_conn);
-
-  client_state_ = SaslNegotiationState::INITIALIZED;
-  return Status::OK();
-}
-
-Status SaslClient::Negotiate() {
-  // After negotiation, we no longer need the SASL library object, so
-  // may as well free its memory since the connection may be long-lived.
-  // Additionally, this works around a SEGV seen at process shutdown time:
-  // if we still have SASL objects retained by Reactor when the process
-  // is exiting, the SASL libraries may start destructing global state
-  // and cause a crash when we sasl_dispose the connection.
-  auto cleanup = MakeScopedCleanup([&]() {
-      sasl_conn_.reset();
-    });
-  TRACE("Called SaslClient::Negotiate()");
-
-  // Ensure we called exactly once, and in the right order.
-  if (client_state_ == SaslNegotiationState::NEW) {
-    return Status::IllegalState("SaslClient: Init() must be called before calling Negotiate()");
-  }
-  if (client_state_ == SaslNegotiationState::NEGOTIATED) {
-    return Status::IllegalState("SaslClient: Negotiate() may only be called once per object.");
-  }
+Status ClientNegotiation::Negotiate() {
+  TRACE("Beginning negotiation");
 
   // Ensure we can use blocking calls on the socket during negotiation.
-  RETURN_NOT_OK(EnsureBlockingMode(sock_));
+  RETURN_NOT_OK(EnsureBlockingMode(socket_.get()));
 
-  // Start by asking the server for a list of available auth mechanisms.
-  RETURN_NOT_OK(SendNegotiateMessage());
+  // Step 1: send the connection header.
+  RETURN_NOT_OK(SendConnectionHeader());
 
   faststring recv_buf;
-  nego_ok_ = false;
-
-  // We set nego_ok_ = true when the SASL library returns SASL_OK to us.
-  // We set nego_response_expected_ = true each time we send a request to the server.
-  while (!nego_ok_ || nego_response_expected_) {
-    ResponseHeader header;
-    Slice param_buf;
-    RETURN_NOT_OK(ReceiveFramedMessageBlocking(sock_, &recv_buf, &header, &param_buf, deadline_));
-    nego_response_expected_ = false;
 
+  { // Step 2: send and receive the NEGOTIATE step messages.
+    RETURN_NOT_OK(SendNegotiate());
     NegotiatePB response;
-    RETURN_NOT_OK(ParseNegotiatePB(header, param_buf, &response));
+    RETURN_NOT_OK(RecvNegotiatePB(&response, &recv_buf));
+    RETURN_NOT_OK(HandleNegotiate(response));
+  }
 
+  // Step 3: SASL negotiation.
+  RETURN_NOT_OK(InitSaslClient());
+  RETURN_NOT_OK(SendSaslInitiate());
+  for (bool cont = true; cont; ) {
+    NegotiatePB response;
+    RETURN_NOT_OK(RecvNegotiatePB(&response, &recv_buf));
+    Status s;
     switch (response.step()) {
-      // NEGOTIATE: Server has sent us its list of supported SASL mechanisms.
-      case NegotiatePB::NEGOTIATE:
-        RETURN_NOT_OK(HandleNegotiateResponse(response));
-        break;
-
       // SASL_CHALLENGE: Server sent us a follow-up to an SASL_INITIATE or SASL_RESPONSE request.
       case NegotiatePB::SASL_CHALLENGE:
-        RETURN_NOT_OK(HandleChallengeResponse(response));
+        RETURN_NOT_OK(HandleSaslChallenge(response));
         break;
-
       // SASL_SUCCESS: Server has accepted our authentication request. Negotiation successful.
       case NegotiatePB::SASL_SUCCESS:
-        RETURN_NOT_OK(HandleSuccessResponse(response));
+        cont = false;
         break;
-
-      // Client sent us some unsupported SASL response.
       default:
-        LOG(ERROR) << "SASL Client: Received unsupported response from server";
-        return Status::InvalidArgument("RPC client doesn't support Negotiate step",
-                                       NegotiatePB::NegotiateStep_Name(response.step()));
+        return Status::NotAuthorized("expected SASL_CHALLENGE or SASL_SUCCESS step",
+                                     NegotiatePB::NegotiateStep_Name(response.step()));
     }
   }
 
-  TRACE("SASL Client: Successful negotiation");
-  client_state_ = SaslNegotiationState::NEGOTIATED;
+  // Step 4: Send connection context.
+  RETURN_NOT_OK(SendConnectionContext());
+
+  TRACE("Negotiation successful");
   return Status::OK();
 }
 
-Status SaslClient::SendNegotiatePB(const NegotiatePB& msg) {
-  DCHECK_NE(client_state_, SaslNegotiationState::NEW)
-      << "Must not send Negotiate messages before calling Init()";
-  DCHECK_NE(client_state_, SaslNegotiationState::NEGOTIATED)
-      << "Must not send Negotiate messages after negotiation succeeds";
-
-  // Create header with SASL-specific callId
+Status ClientNegotiation::SendNegotiatePB(const NegotiatePB& msg) {
   RequestHeader header;
   header.set_call_id(kNegotiateCallId);
-  return helper_.SendNegotiatePB(sock_, header, msg, deadline_);
+
+  DCHECK(socket_);
+  DCHECK(msg.IsInitialized()) << "message must be initialized";
+  DCHECK(msg.has_step()) << "message must have a step";
+
+  TRACE("Sending $0 NegotiatePB request", NegotiatePB::NegotiateStep_Name(msg.step()));
+  return SendFramedMessageBlocking(socket(), header, msg, deadline_);
 }
 
-Status SaslClient::ParseNegotiatePB(const ResponseHeader& header,
-                                    const Slice& param_buf,
-                                    NegotiatePB* response) {
-  RETURN_NOT_OK(helper_.SanityCheckNegotiateCallId(header.call_id()));
+Status ClientNegotiation::RecvNegotiatePB(NegotiatePB* msg, faststring* buffer) {
+  ResponseHeader header;
+  Slice param_buf;
+  RETURN_NOT_OK(ReceiveFramedMessageBlocking(socket(), buffer, &header, &param_buf, deadline_));
+  RETURN_NOT_OK(helper_.CheckNegotiateCallId(header.call_id()));
 
   if (header.is_error()) {
     return ParseError(param_buf);
   }
 
-  return helper_.ParseNegotiatePB(param_buf, response);
+  RETURN_NOT_OK(helper_.ParseNegotiatePB(param_buf, msg));
+  TRACE("Received $0 NegotiatePB response", NegotiatePB::NegotiateStep_Name(msg->step()));
+  return Status::OK();
+}
+
+Status ClientNegotiation::ParseError(const Slice& err_data) {
+  ErrorStatusPB error;
+  if (!error.ParseFromArray(err_data.data(), err_data.size())) {
+    return Status::IOError("invalid error response, missing fields",
+                           error.InitializationErrorString());
+  }
+  Status s = StatusFromRpcError(error);
+  TRACE("Received error response from server: $0", s.ToString());
+  return s;
+}
+
+Status ClientNegotiation::SendConnectionHeader() {
+  const uint8_t buflen = kMagicNumberLength + kHeaderFlagsLength;
+  uint8_t buf[buflen];
+  serialization::SerializeConnHeader(buf);
+  size_t nsent;
+  return socket()->BlockingWrite(buf, buflen, &nsent, deadline_);
 }
 
-Status SaslClient::SendNegotiateMessage() {
+Status ClientNegotiation::InitSaslClient() {
+  RETURN_NOT_OK(SaslInit());
+
+  // TODO(unknown): Support security flags.
+  unsigned secflags = 0;
+
+  sasl_conn_t* sasl_conn = nullptr;
+  RETURN_NOT_OK_PREPEND(WrapSaslCall(nullptr /* no conn */, [&]() {
+      return sasl_client_new(
+          kSaslProtoName,               // Registered name of the service using SASL. Required.
+          helper_.server_fqdn(),        // The fully qualified domain name of the remote server.
+          helper_.local_addr_string(),  // Local and remote IP address strings. (NULL disables
+          helper_.remote_addr_string(), //   mechanisms which require this info.)
+          &callbacks_[0],               // Connection-specific callbacks.
+          secflags,                     // Security flags.
+          &sasl_conn);
+    }), "Unable to create new SASL client");
+  sasl_conn_.reset(sasl_conn);
+  return Status::OK();
+}
+
+Status ClientNegotiation::SendNegotiate() {
   NegotiatePB msg;
   msg.set_step(NegotiatePB::NEGOTIATE);
 
@@ -274,59 +260,29 @@ Status SaslClient::SendNegotiateMessage() {
     msg.add_supported_features(feature);
   }
 
-  TRACE("SASL Client: Sending NEGOTIATE request to server.");
-  RETURN_NOT_OK(SendNegotiatePB(msg));
-  nego_response_expected_ = true;
-  return Status::OK();
-}
-
-Status SaslClient::SendInitiateMessage(const NegotiatePB_SaslAuth& auth,
-    const char* init_msg, unsigned init_msg_len) {
-  NegotiatePB msg;
-  msg.set_step(NegotiatePB::SASL_INITIATE);
-  msg.mutable_token()->assign(init_msg, init_msg_len);
-  msg.add_auths()->CopyFrom(auth);
-  TRACE("SASL Client: Sending SASL_INITIATE request to server.");
   RETURN_NOT_OK(SendNegotiatePB(msg));
-  nego_response_expected_ = true;
-  return Status::OK();
-}
-
-Status SaslClient::SendResponseMessage(const char* resp_msg, unsigned resp_msg_len) {
-  NegotiatePB reply;
-  reply.set_step(NegotiatePB::SASL_RESPONSE);
-  reply.mutable_token()->assign(resp_msg, resp_msg_len);
-  TRACE("SASL Client: Sending SASL_RESPONSE request to server.");
-  RETURN_NOT_OK(SendNegotiatePB(reply));
-  nego_response_expected_ = true;
   return Status::OK();
 }
 
-Status SaslClient::DoSaslStep(const string& in, const char** out, unsigned* out_len) {
-  TRACE("SASL Client: Calling sasl_client_step()");
-  Status s = WrapSaslCall(sasl_conn_.get(), [&]() {
-      return sasl_client_step(sasl_conn_.get(), in.c_str(), in.length(), nullptr, out, out_len);
-    });
-  if (s.ok()) {
-    nego_ok_ = true;
+Status ClientNegotiation::HandleNegotiate(const NegotiatePB& response) {
+  if (PREDICT_FALSE(response.step() != NegotiatePB::NEGOTIATE)) {
+    return Status::NotAuthorized("expected NEGOTIATE step",
+                                 NegotiatePB::NegotiateStep_Name(response.step()));
   }
-  return s;
-}
+  TRACE("Received NEGOTIATE response from server");
 
-Status SaslClient::HandleNegotiateResponse(const NegotiatePB& response) {
-  TRACE("SASL Client: Received NEGOTIATE response from server");
   // Fill in the set of features supported by the server.
   for (int flag : response.supported_features()) {
     // We only add the features that our local build knows about.
     RpcFeatureFlag feature_flag = RpcFeatureFlag_IsValid(flag) ?
                                   static_cast<RpcFeatureFlag>(flag) : UNKNOWN;
-    if (ContainsKey(kSupportedClientRpcFeatureFlags, feature_flag)) {
+    if (feature_flag != UNKNOWN) {
       server_features_.insert(feature_flag);
     }
   }
 
-  // Build a map of the mechanisms offered by the server.
-  const set<string>& local_mechs = helper_.LocalMechs();
+  // Build a map of the SASL mechanisms offered by the server.
+  const set<string>& enabled_mechs = helper_.EnabledMechs();
   set<string> server_mechs;
   map<string, NegotiatePB::SaslAuth> server_mech_map;
   for (const NegotiatePB::SaslAuth& auth : response.auths()) {
@@ -334,21 +290,26 @@ Status SaslClient::HandleNegotiateResponse(const NegotiatePB& response) {
     server_mech_map[mech] = auth;
     server_mechs.insert(mech);
   }
+
   // Determine which server mechs are also enabled by the client.
   // Cyrus SASL 2.1.25 and later supports doing this set intersection via
   // the 'client_mech_list' option, but that version is not available on
   // RHEL 6, so we have to do it manually.
-  set<string> matching_mechs = STLSetIntersection(local_mechs, server_mechs);
+  common_mechs_ = STLSetIntersection(enabled_mechs, server_mechs);
 
-  if (matching_mechs.empty() &&
+  if (common_mechs_.empty() &&
       ContainsKey(server_mechs, kSaslMechGSSAPI) &&
-      !ContainsKey(local_mechs, kSaslMechGSSAPI)) {
+      !ContainsKey(enabled_mechs, kSaslMechGSSAPI)) {
     return Status::NotAuthorized("server requires GSSAPI (Kerberos) authentication and "
                                  "client was missing the required SASL module");
   }
 
-  string matching_mechs_str = JoinElements(matching_mechs, " ");
-  TRACE("SASL Client: Matching mech list: $0", matching_mechs_str);
+  return Status::OK();
+}
+
+Status ClientNegotiation::SendSaslInitiate() {
+  string matching_mechs_str = JoinElements(common_mechs_, " ");
+  TRACE("Matching mech list: $0", matching_mechs_str);
 
   const char* init_msg = nullptr;
   unsigned init_msg_len = 0;
@@ -368,7 +329,7 @@ Status SaslClient::HandleNegotiateResponse(const NegotiatePB& response) {
    *  SASL_NOMECH   -- no mechanism meets requested properties
    *  SASL_INTERACT -- user interaction needed to fill in prompt_need list
    */
-  TRACE("SASL Client: Calling sasl_client_start()");
+  TRACE("Calling sasl_client_start()");
   Status s = WrapSaslCall(sasl_conn_.get(), [&]() {
       return sasl_client_start(
           sasl_conn_.get(),           // The SASL connection context created by init()
@@ -377,112 +338,101 @@ Status SaslClient::HandleNegotiateResponse(const NegotiatePB& response) {
           &init_msg,                  // Filled in on success.
           &init_msg_len,              // Filled in on success.
           &negotiated_mech);          // Filled in on success.
-    });
+  });
 
-  if (s.ok()) {
-    nego_ok_ = true;
-  } else if (!s.IsIncomplete()) {
+  if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) {
     return s;
   }
 
   // The server matched one of our mechanisms.
-  NegotiatePB::SaslAuth* auth = FindOrNull(server_mech_map, negotiated_mech);
-  if (PREDICT_FALSE(auth == nullptr)) {
-    return Status::IllegalState("Unable to find auth in map, unexpected error", negotiated_mech);
-  }
   negotiated_mech_ = SaslMechanism::value_of(negotiated_mech);
 
-  RETURN_NOT_OK(SendInitiateMessage(*auth, init_msg, init_msg_len));
-  return Status::OK();
+  NegotiatePB msg;
+  msg.set_step(NegotiatePB::SASL_INITIATE);
+  msg.mutable_token()->assign(init_msg, init_msg_len);
+  msg.add_auths()->set_mechanism(negotiated_mech);
+  return SendNegotiatePB(msg);
 }
 
-Status SaslClient::HandleChallengeResponse(const NegotiatePB& response) {
-  TRACE("SASL Client: Received SASL_CHALLENGE response from server");
-  if (PREDICT_FALSE(nego_ok_)) {
-    LOG(DFATAL) << "Server sent SASL_CHALLENGE response after client library returned SASL_OK";
-  }
+Status ClientNegotiation::SendSaslResponse(const char* resp_msg, unsigned resp_msg_len) {
+  NegotiatePB reply;
+  reply.set_step(NegotiatePB::SASL_RESPONSE);
+  reply.mutable_token()->assign(resp_msg, resp_msg_len);
+  return SendNegotiatePB(reply);
+}
 
+Status ClientNegotiation::HandleSaslChallenge(const NegotiatePB& response) {
+  TRACE("Received SASL_CHALLENGE response from server");
   if (PREDICT_FALSE(!response.has_token())) {
-    return Status::InvalidArgument("No token in SASL_CHALLENGE response from server");
+    return Status::NotAuthorized("no token in SASL_CHALLENGE response from server");
   }
 
   const char* out = nullptr;
   unsigned out_len = 0;
   Status s = DoSaslStep(response.token(), &out, &out_len);
-  if (!s.ok() && !s.IsIncomplete()) {
+  if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) {
     return s;
   }
-  RETURN_NOT_OK(SendResponseMessage(out, out_len));
-  return Status::OK();
+
+  return SendSaslResponse(out, out_len);
 }
 
-Status SaslClient::HandleSuccessResponse(const NegotiatePB& response) {
-  TRACE("SASL Client: Received SASL_SUCCESS response from server");
-  if (!nego_ok_) {
-    const char* out = nullptr;
-    unsigned out_len = 0;
-    Status s = DoSaslStep(response.token(), &out, &out_len);
-    if (s.IsIncomplete()) {
-      return Status::IllegalState("Server indicated successful authentication, but client "
-                                  "was not complete");
-    }
-    RETURN_NOT_OK(s);
-    if (out_len > 0) {
-      return Status::IllegalState("SASL client library generated spurious token after SASL_SUCCESS",
-          string(out, out_len));
-    }
-    CHECK(nego_ok_);
-  }
-  return Status::OK();
+Status ClientNegotiation::DoSaslStep(const string& in, const char** out, unsigned* out_len) {
+  TRACE("Calling sasl_client_step()");
+
+  return WrapSaslCall(sasl_conn_.get(), [&]() {
+      return sasl_client_step(sasl_conn_.get(), in.c_str(), in.length(), nullptr, out, out_len);
+  });
 }
 
-// Parse error status message from raw bytes of an ErrorStatusPB.
-Status SaslClient::ParseError(const Slice& err_data) {
-  ErrorStatusPB error;
-  if (!error.ParseFromArray(err_data.data(), err_data.size())) {
-    return Status::IOError("Invalid error response, missing fields",
-        error.InitializationErrorString());
-  }
-  Status s = StatusFromRpcError(error);
-  TRACE("SASL Client: Received error response from server: $0", s.ToString());
-  return s;
+Status ClientNegotiation::SendConnectionContext() {
+  TRACE("Sending connection context");
+  RequestHeader header;
+  header.set_call_id(kConnectionContextCallId);
+
+  ConnectionContextPB conn_context;
+  // This field is deprecated but used by servers <Kudu 1.1. Newer server versions ignore
+  // this and use the SASL-provided username instead.
+  conn_context.mutable_deprecated_user_info()->set_real_user(
+      plain_auth_user_.empty() ? "cpp-client" : plain_auth_user_);
+  return SendFramedMessageBlocking(socket(), header, conn_context, deadline_);
 }
 
-int SaslClient::GetOptionCb(const char* plugin_name, const char* option,
+int ClientNegotiation::GetOptionCb(const char* plugin_name, const char* option,
                             const char** result, unsigned* len) {
   return helper_.GetOptionCb(plugin_name, option, result, len);
 }
 
 // Used for PLAIN.
 // SASL callback for SASL_CB_USER, SASL_CB_AUTHNAME, SASL_CB_LANGUAGE
-int SaslClient::SimpleCb(int id, const char** result, unsigned* len) {
+int ClientNegotiation::SimpleCb(int id, const char** result, unsigned* len) {
   if (PREDICT_FALSE(!helper_.IsPlainEnabled())) {
-    LOG(DFATAL) << "SASL Client: Simple callback called, but PLAIN auth is not enabled";
+    LOG(DFATAL) << "Simple callback called, but PLAIN auth is not enabled";
     return SASL_FAIL;
   }
   if (PREDICT_FALSE(result == nullptr)) {
-    LOG(DFATAL) << "SASL Client: result outparam is NULL";
+    LOG(DFATAL) << "result outparam is NULL";
     return SASL_BADPARAM;
   }
   switch (id) {
     // TODO(unknown): Support impersonation?
     // For impersonation, USER is the impersonated user, AUTHNAME is the "sudoer".
     case SASL_CB_USER:
-      TRACE("SASL Client: callback for SASL_CB_USER");
+      TRACE("callback for SASL_CB_USER");
       *result = plain_auth_user_.c_str();
       if (len != nullptr) *len = plain_auth_user_.length();
       break;
     case SASL_CB_AUTHNAME:
-      TRACE("SASL Client: callback for SASL_CB_AUTHNAME");
+      TRACE("callback for SASL_CB_AUTHNAME");
       *result = plain_auth_user_.c_str();
       if (len != nullptr) *len = plain_auth_user_.length();
       break;
     case SASL_CB_LANGUAGE:
-      LOG(DFATAL) << "SASL Client: Unable to handle SASL callback type SASL_CB_LANGUAGE"
+      LOG(DFATAL) << "Unable to handle SASL callback type SASL_CB_LANGUAGE"
         << "(" << id << ")";
       return SASL_BADPARAM;
     default:
-      LOG(DFATAL) << "SASL Client: Unexpected SASL callback type: " << id;
+      LOG(DFATAL) << "Unexpected SASL callback type: " << id;
       return SASL_BADPARAM;
   }
 
@@ -491,9 +441,9 @@ int SaslClient::SimpleCb(int id, const char** result, unsigned* len) {
 
 // Used for PLAIN.
 // SASL callback for SASL_CB_PASS: User password.
-int SaslClient::SecretCb(sasl_conn_t* conn, int id, sasl_secret_t** psecret) {
+int ClientNegotiation::SecretCb(sasl_conn_t* conn, int id, sasl_secret_t** psecret) {
   if (PREDICT_FALSE(!helper_.IsPlainEnabled())) {
-    LOG(DFATAL) << "SASL Client: Plain secret callback called, but PLAIN auth is not enabled";
+    LOG(DFATAL) << "Plain secret callback called, but PLAIN auth is not enabled";
     return SASL_FAIL;
   }
   switch (id) {
@@ -511,7 +461,7 @@ int SaslClient::SecretCb(sasl_conn_t* conn, int id, sasl_secret_t** psecret) {
       break;
     }
     default:
-      LOG(DFATAL) << "SASL Client: Unexpected SASL callback type: " << id;
+      LOG(DFATAL) << "Unexpected SASL callback type: " << id;
       return SASL_BADPARAM;
   }
 

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/client_negotiation.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/client_negotiation.h b/src/kudu/rpc/client_negotiation.h
index 2f01e56..6677d33 100644
--- a/src/kudu/rpc/client_negotiation.h
+++ b/src/kudu/rpc/client_negotiation.h
@@ -15,9 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#ifndef KUDU_RPC_SASL_CLIENT_H
-#define KUDU_RPC_SASL_CLIENT_H
+#pragma once
 
+#include <memory>
 #include <set>
 #include <string>
 #include <vector>
@@ -33,65 +33,77 @@
 #include "kudu/util/status.h"
 
 namespace kudu {
-namespace rpc {
 
-using std::string;
+namespace rpc {
 
 class NegotiatePB;
 class NegotiatePB_SaslAuth;
 class ResponseHeader;
 
-// Class for doing SASL negotiation with a SaslServer over a bidirectional socket.
+// Class for doing KRPC negotiation with a remote server over a bidirectional socket.
 // Operations on this class are NOT thread-safe.
-class SaslClient {
+class ClientNegotiation {
  public:
-  // Does not take ownership of the socket indicated by the fd.
-  SaslClient(string app_name, Socket* socket);
+  // Creates a new client negotiation instance, taking ownership of the
+  // provided socket. After completing the negotiation process by setting the
+  // desired options and calling Negotiate(), the socket can be retrieved with
+  // 'release_socket'.
+  explicit ClientNegotiation(std::unique_ptr<Socket> socket);
 
   // Enable PLAIN authentication.
-  // Must be called after Init().
-  Status EnablePlain(const string& user, const string& pass);
+  // Must be called before Negotiate().
+  Status EnablePlain(const std::string& user,
+                     const std::string& pass);
 
   // Enable GSSAPI authentication.
-  // Call after Init().
+  // Must be called before Negotiate().
   Status EnableGSSAPI();
 
   // Returns mechanism negotiated by this connection.
-  // Must be called after Negotiate().
+  // Must be called before Negotiate().
   SaslMechanism::Type negotiated_mechanism() const;
 
   // Returns the set of RPC system features supported by the remote server.
-  // Must be called after Negotiate().
-  const std::set<RpcFeatureFlag>& server_features() const {
+  // Must be called before Negotiate().
+  std::set<RpcFeatureFlag> server_features() const {
     return server_features_;
   }
 
+  // Returns the set of RPC system features supported by the remote server.
+  // Must be called after Negotiate().
+  // Subsequent calls to this method or server_features() will return an empty set.
+  std::set<RpcFeatureFlag> take_server_features() {
+    return std::move(server_features_);
+  }
+
   // Specify IP:port of local side of connection.
-  // Must be called before Init(). Required for some mechanisms.
+  // Must be called before Negotiate(). Required for some mechanisms.
   void set_local_addr(const Sockaddr& addr);
 
   // Specify IP:port of remote side of connection.
-  // Must be called before Init(). Required for some mechanisms.
+  // Must be called before Negotiate(). Required for some mechanisms.
   void set_remote_addr(const Sockaddr& addr);
 
   // Specify the fully-qualified domain name of the remote server.
-  // Must be called before Init(). Required for some mechanisms.
-  void set_server_fqdn(const string& domain_name);
+  // Must be called before Negotiate(). Required for some mechanisms.
+  void set_server_fqdn(const std::string& domain_name);
 
   // Set deadline for connection negotiation.
   void set_deadline(const MonoTime& deadline);
 
-  // Get deadline for connection negotiation.
-  const MonoTime& deadline() const { return deadline_; }
+  Socket* socket() { return socket_.get(); }
 
-  // Initialize a new SASL client. Must be called before Negotiate().
-  // Returns OK on success, otherwise RuntimeError.
-  Status Init(const string& service_type);
+  // Takes and returns the socket owned by this client negotiation. The caller
+  // will own the socket after this call, and the negotiation instance should no
+  // longer be used. Must be called after Negotiate(). Subsequent calls to this
+  // method or socket() will return a null pointer.
+  std::unique_ptr<Socket> release_socket() { return std::move(socket_); }
 
-  // Begin negotiation with the SASL server on the other side of the fd socket
-  // that this client was constructed with.
-  // Returns OK on success.
-  // Otherwise, it may return NotAuthorized, NotSupported, or another non-OK status.
+  // Negotiate with the remote server. Should only be called once per
+  // ClientNegotiation and socket instance, after all options have been set.
+  //
+  // Returns OK on success, otherwise may return NotAuthorized, NotSupported, or
+  // another non-OK status.
   Status Negotiate();
 
   // SASL callback for plugin options, supported mechanisms, etc.
@@ -106,23 +118,36 @@ class SaslClient {
   int SecretCb(sasl_conn_t* conn, int id, sasl_secret_t** psecret);
 
  private:
-  // Encode and send the specified negotiate message to the server.
-  Status SendNegotiatePB(const NegotiatePB& msg);
 
-  // Validate that header does not indicate an error, parse param_buf into response.
-  Status ParseNegotiatePB(const ResponseHeader& header,
-                          const Slice& param_buf,
-                          NegotiatePB* response);
+  // Encode and send the specified negotiate request message to the server.
+  Status SendNegotiatePB(const NegotiatePB& msg) WARN_UNUSED_RESULT;
+
+  // Receive a negotiate response message from the server, deserializing it into 'msg'.
+  // Validates that the response is not an error.
+  Status RecvNegotiatePB(NegotiatePB* msg, faststring* buffer) WARN_UNUSED_RESULT;
+
+  // Parse error status message from raw bytes of an ErrorStatusPB.
+  Status ParseError(const Slice& err_data) WARN_UNUSED_RESULT;
 
-  // Send an NEGOTIATE message to the server.
-  Status SendNegotiateMessage();
+  Status SendConnectionHeader() WARN_UNUSED_RESULT;
+
+  // Initialize the SASL client negotiation instance.
+  Status InitSaslClient() WARN_UNUSED_RESULT;
+
+  // Send a NEGOTIATE step message to the server.
+  Status SendNegotiate() WARN_UNUSED_RESULT;
+
+  // Handle NEGOTIATE step response from the server.
+  Status HandleNegotiate(const NegotiatePB& response) WARN_UNUSED_RESULT;
 
   // Send an SASL_INITIATE message to the server.
-  Status SendInitiateMessage(const NegotiatePB_SaslAuth& auth,
-                             const char* init_msg, unsigned init_msg_len);
+  Status SendSaslInitiate() WARN_UNUSED_RESULT;
 
-  // Send a RESPONSE message to the server.
-  Status SendResponseMessage(const char* resp_msg, unsigned resp_msg_len);
+  // Send a SASL_RESPONSE message to the server.
+  Status SendSaslResponse(const char* resp_msg, unsigned resp_msg_len) WARN_UNUSED_RESULT;
+
+  // Handle case when server sends SASL_CHALLENGE response.
+  Status HandleSaslChallenge(const NegotiatePB& response) WARN_UNUSED_RESULT;
 
   // Perform a client-side step of the SASL negotiation.
   // Input is what came from the server. Output is what we will send back to the server.
@@ -130,51 +155,36 @@ class SaslClient {
   //   Status::OK if sasl_client_step returns SASL_OK.
   //   Status::Incomplete if sasl_client_step returns SASL_CONTINUE
   // otherwise returns an appropriate error status.
-  Status DoSaslStep(const string& in, const char** out, unsigned* out_len);
-
-  // Handle case when server sends NEGOTIATE response.
-  Status HandleNegotiateResponse(const NegotiatePB& response);
+  Status DoSaslStep(const std::string& in, const char** out, unsigned* out_len) WARN_UNUSED_RESULT;
 
-  // Handle case when server sends CHALLENGE response.
-  Status HandleChallengeResponse(const NegotiatePB& response);
+  Status SendConnectionContext() WARN_UNUSED_RESULT;
 
-  // Handle case when server sends SUCCESS response.
-  Status HandleSuccessResponse(const NegotiatePB& response);
+  // The socket to the remote server.
+  std::unique_ptr<Socket> socket_;
 
-  // Parse error status message from raw bytes of an ErrorStatusPB.
-  Status ParseError(const Slice& err_data);
-
-  string app_name_;
-  Socket* sock_;
+  // SASL state.
   std::vector<sasl_callback_t> callbacks_;
-  // The SASL connection object. This is initialized in Init() and
-  // freed after Negotiate() completes (regardless whether it was successful).
   gscoped_ptr<sasl_conn_t, SaslDeleter> sasl_conn_;
   SaslHelper helper_;
 
-  string plain_auth_user_;
-  string plain_pass_;
+  // Authentication state.
+  std::string plain_auth_user_;
+  std::string plain_pass_;
   gscoped_ptr<sasl_secret_t, FreeDeleter> psecret_;
 
-  // The set of features supported by the server.
+  // The set of features supported by the server. Filled in during negotiation.
   std::set<RpcFeatureFlag> server_features_;
 
-  SaslNegotiationState::Type client_state_;
+  // The set of SASL mechanisms supported by the server and the client. Filled
+  // in during negotiation.
+  std::set<std::string> common_mechs_;
 
-  // The mechanism we negotiated with the server.
+  // The SASL mechanism used by the connection. Filled in during negotiation.
   SaslMechanism::Type negotiated_mech_;
 
-  // Intra-negotiation state.
-  bool nego_ok_;  // During negotiation: did we get a SASL_OK response from the SASL library?
-  bool nego_response_expected_;  // During negotiation: Are we waiting for a server response?
-
   // Negotiation timeout deadline.
   MonoTime deadline_;
-
-  DISALLOW_COPY_AND_ASSIGN(SaslClient);
 };
 
 } // namespace rpc
 } // namespace kudu
-
-#endif  // KUDU_RPC_SASL_CLIENT_H

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/connection.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/connection.cc b/src/kudu/rpc/connection.cc
index 2abf6ca..30d3720 100644
--- a/src/kudu/rpc/connection.cc
+++ b/src/kudu/rpc/connection.cc
@@ -17,16 +17,19 @@
 
 #include "kudu/rpc/connection.h"
 
+#include <stdint.h>
+
 #include <algorithm>
-#include <boost/intrusive/list.hpp>
-#include <gflags/gflags.h>
-#include <glog/logging.h>
 #include <iostream>
 #include <set>
-#include <stdint.h>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
+#include <boost/intrusive/list.hpp>
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+
 #include "kudu/gutil/map-util.h"
 #include "kudu/gutil/strings/human_readable.h"
 #include "kudu/gutil/strings/substitute.h"
@@ -38,12 +41,12 @@
 #include "kudu/rpc/rpc_header.pb.h"
 #include "kudu/rpc/rpc_introspection.pb.h"
 #include "kudu/rpc/transfer.h"
-#include "kudu/security/ssl_factory.h"
 #include "kudu/security/ssl_socket.h"
 #include "kudu/util/debug-util.h"
 #include "kudu/util/flag_tags.h"
 #include "kudu/util/logging.h"
 #include "kudu/util/net/sockaddr.h"
+#include "kudu/util/net/socket.h"
 #include "kudu/util/status.h"
 #include "kudu/util/trace.h"
 
@@ -51,28 +54,27 @@ using std::function;
 using std::includes;
 using std::set;
 using std::shared_ptr;
+using std::unique_ptr;
 using std::vector;
 using strings::Substitute;
 
-DECLARE_bool(server_require_kerberos);
-
 namespace kudu {
 namespace rpc {
 
 ///
 /// Connection
 ///
-Connection::Connection(ReactorThread *reactor_thread, Sockaddr remote,
-                       Socket* socket, Direction direction)
+Connection::Connection(ReactorThread *reactor_thread,
+                       Sockaddr remote,
+                       unique_ptr<Socket> socket,
+                       Direction direction)
     : reactor_thread_(reactor_thread),
-      remote_(std::move(remote)),
-      socket_(socket),
+      remote_(remote),
+      socket_(std::move(socket)),
       direction_(direction),
       last_activity_time_(MonoTime::Now()),
       is_epoll_registered_(false),
       next_call_id_(1),
-      sasl_client_(kSaslAppName, socket),
-      sasl_server_(kSaslAppName, socket),
       negotiation_complete_(false) {
 }
 
@@ -169,7 +171,9 @@ void Connection::Shutdown(const Status &status) {
   read_io_.stop();
   write_io_.stop();
   is_epoll_registered_ = false;
-  WARN_NOT_OK(socket_->Close(), "Error closing socket");
+  if (socket_) {
+    WARN_NOT_OK(socket_->Close(), "Error closing socket");
+  }
 }
 
 void Connection::QueueOutbound(gscoped_ptr<OutboundTransfer> transfer) {
@@ -578,8 +582,7 @@ void Connection::WriteHandler(ev::io &watcher, int revents) {
         // order to ensure that the negotiation has taken place, so that the flags
         // are available.
         const set<RpcFeatureFlag>& required_features = car->call->required_rpc_features();
-        const set<RpcFeatureFlag>& server_features = sasl_client_.server_features();
-        if (!includes(server_features.begin(), server_features.end(),
+        if (!includes(remote_features_.begin(), remote_features_.end(),
                       required_features.begin(), required_features.end())) {
           outbound_transfers_.pop_front();
           Status s = Status::NotSupported("server does not support the required RPC features");
@@ -626,58 +629,13 @@ std::string Connection::ToString() const {
     remote_.ToString());
 }
 
-Status Connection::InitSSLIfNecessary() {
-  if (!reactor_thread_->reactor()->messenger()->ssl_enabled()) return Status::OK();
-  SSLSocket* ssl_socket = down_cast<SSLSocket*>(socket_.get());
-  RETURN_NOT_OK(ssl_socket->DoHandshake());
-  return Status::OK();
-}
-
-Status Connection::InitSaslClient() {
-  // Note that remote_.host() is an IP address here: we've already lost
-  // whatever DNS name the client was attempting to use. Unless krb5
-  // is configured with 'rdns = false', it will automatically take care
-  // of reversing this address to its canonical hostname to determine
-  // the expected server principal.
-  sasl_client().set_server_fqdn(remote_.host());
-  Status gssapi_status = sasl_client().EnableGSSAPI();
-  if (!gssapi_status.ok()) {
-    // If we can't enable GSSAPI, it's likely the client is just missing the
-    // appropriate SASL plugin. We don't want to require it to be installed
-    // if the user doesn't care about connecting to servers using Kerberos
-    // authentication. So, we'll just VLOG this here. If we try to connect
-    // to a server which requires Kerberos, we'll get a negotiation error
-    // at that point.
-    if (VLOG_IS_ON(1)) {
-      KLOG_FIRST_N(INFO, 1) << "Couldn't enable GSSAPI (Kerberos) SASL plugin: "
-                            << gssapi_status.message().ToString()
-                            << ". This process will be unable to connect to "
-                            << "servers requiring Kerberos authentication.";
-    }
-  }
-  RETURN_NOT_OK(sasl_client().EnablePlain(user_credentials().real_user(), ""));
-  RETURN_NOT_OK(sasl_client().Init(kSaslProtoName));
-  return Status::OK();
-}
-
-Status Connection::InitSaslServer() {
-  if (FLAGS_server_require_kerberos) {
-    RETURN_NOT_OK(sasl_server().EnableGSSAPI());
-  } else {
-    RETURN_NOT_OK(sasl_server().EnablePlain());
-  }
-  RETURN_NOT_OK(sasl_server().Init(kSaslProtoName));
-  return Status::OK();
-}
-
 // Reactor task that transitions this Connection from connection negotiation to
 // regular RPC handling. Destroys Connection on negotiation error.
 class NegotiationCompletedTask : public ReactorTask {
  public:
-  NegotiationCompletedTask(Connection* conn,
-      const Status& negotiation_status)
+  NegotiationCompletedTask(Connection* conn, Status negotiation_status)
     : conn_(conn),
-      negotiation_status_(negotiation_status) {
+      negotiation_status_(std::move(negotiation_status)) {
   }
 
   virtual void Run(ReactorThread *rthread) OVERRIDE {
@@ -687,8 +645,8 @@ class NegotiationCompletedTask : public ReactorTask {
 
   virtual void Abort(const Status &status) OVERRIDE {
     DCHECK(conn_->reactor_thread()->reactor()->closing());
-    VLOG(1) << "Failed connection negotiation due to shut down reactor thread: " <<
-        status.ToString();
+    VLOG(1) << "Failed connection negotiation due to shut down reactor thread: "
+            << status.ToString();
     delete this;
   }
 

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/connection.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/connection.h b/src/kudu/rpc/connection.h
index ef70483..ad5ffed 100644
--- a/src/kudu/rpc/connection.h
+++ b/src/kudu/rpc/connection.h
@@ -18,22 +18,22 @@
 #ifndef KUDU_RPC_CONNECTION_H
 #define KUDU_RPC_CONNECTION_H
 
-#include <boost/intrusive/list.hpp>
-#include <ev++.h>
-#include <memory>
 #include <stdint.h>
-#include <unordered_map>
 
 #include <limits>
+#include <memory>
+#include <set>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
+#include <boost/intrusive/list.hpp>
+#include <ev++.h>
+
 #include "kudu/gutil/gscoped_ptr.h"
 #include "kudu/gutil/ref_counted.h"
-#include "kudu/rpc/client_negotiation.h"
 #include "kudu/rpc/inbound_call.h"
 #include "kudu/rpc/outbound_call.h"
-#include "kudu/rpc/server_negotiation.h"
 #include "kudu/rpc/transfer.h"
 #include "kudu/util/monotime.h"
 #include "kudu/util/net/sockaddr.h"
@@ -80,7 +80,9 @@ class Connection : public RefCountedThreadSafe<Connection> {
   // remote: the address of the remote end
   // socket: the socket to take ownership of.
   // direction: whether we are the client or server side
-  Connection(ReactorThread *reactor_thread, Sockaddr remote, Socket* socket,
+  Connection(ReactorThread *reactor_thread,
+             Sockaddr remote,
+             std::unique_ptr<Socket> socket,
              Direction direction);
 
   // Set underlying socket to non-blocking (or blocking) mode.
@@ -143,23 +145,8 @@ class Connection : public RefCountedThreadSafe<Connection> {
 
   Socket* socket() { return socket_.get(); }
 
-  // Return SASL client instance for this connection.
-  SaslClient &sasl_client() { return sasl_client_; }
-
-  // Return SASL server instance for this connection.
-  SaslServer &sasl_server() { return sasl_server_; }
-
-  // Initialize underlying SSLSocket if SSL is enabled.
-  Status InitSSLIfNecessary();
-
-  // Initialize SASL client before negotiation begins.
-  Status InitSaslClient();
-
-  // Initialize SASL server before negotiation begins.
-  Status InitSaslServer();
-
   // Go through the process of transferring control of the underlying socket back to the Reactor.
-  void CompleteNegotiation(const Status &negotiation_status);
+  void CompleteNegotiation(const Status& negotiation_status);
 
   // Indicate that negotiation is complete and that the Reactor is now in control of the socket.
   void MarkNegotiationComplete();
@@ -167,7 +154,26 @@ class Connection : public RefCountedThreadSafe<Connection> {
   Status DumpPB(const DumpRunningRpcsRequestPB& req,
                 RpcConnectionPB* resp);
 
-  ReactorThread *reactor_thread() const { return reactor_thread_; }
+  ReactorThread* reactor_thread() const { return reactor_thread_; }
+
+  // Transfers ownership of the socket out of this connection. Afterwards,
+  // socket() returns nullptr until a socket is handed back via adopt_socket().
+  std::unique_ptr<Socket> release_socket() {
+    return std::move(socket_);
+  }
+
+  // Takes ownership of 'socket', replacing (and destroying) any socket this
+  // connection currently holds.
+  void adopt_socket(std::unique_ptr<Socket> socket) {
+    socket_ = std::move(socket);
+  }
+
+  // Records the RPC features supported by the remote end of the connection.
+  // Outbound calls whose required features are not all present in this set
+  // are rejected with a NotSupported status instead of being sent.
+  void set_remote_features(std::set<RpcFeatureFlag> remote_features) {
+    remote_features_ = std::move(remote_features);
+  }
 
  private:
   friend struct CallAwaitingResponse;
@@ -226,7 +232,7 @@ class Connection : public RefCountedThreadSafe<Connection> {
   void QueueOutbound(gscoped_ptr<OutboundTransfer> transfer);
 
   // The reactor thread that created this connection.
-  ReactorThread * const reactor_thread_;
+  ReactorThread* const reactor_thread_;
 
   // The remote address we're talking to.
   const Sockaddr remote_;
@@ -277,17 +283,14 @@ class Connection : public RefCountedThreadSafe<Connection> {
   // when serializing calls.
   std::vector<Slice> slices_tmp_;
 
+  // RPC features supported by the remote end of the connection.
+  std::set<RpcFeatureFlag> remote_features_;
+
   // Pool from which CallAwaitingResponse objects are allocated.
   // Also a funny name.
   ObjectPool<CallAwaitingResponse> car_pool_;
   typedef ObjectPool<CallAwaitingResponse>::scoped_ptr scoped_car;
 
-  // SASL client instance used for connection negotiation when Direction == CLIENT.
-  SaslClient sasl_client_;
-
-  // SASL server instance used for connection negotiation when Direction == SERVER.
-  SaslServer sasl_server_;
-
   // Whether we completed connection negotiation.
   bool negotiation_complete_;
 };

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/messenger.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/messenger.cc b/src/kudu/rpc/messenger.cc
index 103f72c..5a5ed7a 100644
--- a/src/kudu/rpc/messenger.cc
+++ b/src/kudu/rpc/messenger.cc
@@ -41,6 +41,7 @@
 #include "kudu/rpc/rpc_service.h"
 #include "kudu/rpc/rpcz_store.h"
 #include "kudu/rpc/sasl_common.h"
+#include "kudu/rpc/server_negotiation.h"
 #include "kudu/rpc/transfer.h"
 #include "kudu/security/ssl_factory.h"
 #include "kudu/util/errno.h"
@@ -56,7 +57,6 @@ using std::string;
 using std::shared_ptr;
 using strings::Substitute;
 
-
 DEFINE_string(rpc_ssl_server_certificate, "", "Path to the SSL certificate to be used for the RPC "
     "layer.");
 DEFINE_string(rpc_ssl_private_key, "",
@@ -130,7 +130,7 @@ MessengerBuilder &MessengerBuilder::set_metric_entity(
 }
 
 Status MessengerBuilder::Build(shared_ptr<Messenger> *msgr) {
-  RETURN_NOT_OK(SaslInit(kSaslAppName)); // Initialize SASL library before we start making requests
+  RETURN_NOT_OK(SaslInit()); // Initialize SASL library before we start making requests
   Messenger* new_msgr(new Messenger(*this));
   Status build_status = new_msgr->Init();
   if (!build_status.ok()) {
@@ -192,7 +192,7 @@ Status Messenger::AddAcceptorPool(const Sockaddr &accept_addr,
   // that everything is set up correctly. This way we'll generate errors on
   // startup rather than later on when we first receive a client connection.
   if (FLAGS_server_require_kerberos) {
-    RETURN_NOT_OK_PREPEND(SaslServer::PreflightCheckGSSAPI(kSaslAppName),
+    RETURN_NOT_OK_PREPEND(ServerNegotiation::PreflightCheckGSSAPI(),
                           "GSSAPI/Kerberos not properly configured");
   }
 
@@ -296,11 +296,11 @@ Reactor* Messenger::RemoteToReactor(const Sockaddr &remote) {
   return reactors_[reactor_idx];
 }
 
-
 Status Messenger::Init() {
   Status status;
   ssl_enabled_ = !FLAGS_rpc_ssl_server_certificate.empty() || !FLAGS_rpc_ssl_private_key.empty()
                    || !FLAGS_rpc_ssl_certificate_authority.empty();
+
   if (ssl_enabled_) {
     ssl_factory_.reset(new SSLFactory());
     RETURN_NOT_OK(ssl_factory_->Init());

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/negotiation-test.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/negotiation-test.cc b/src/kudu/rpc/negotiation-test.cc
index 2bb73e4..984dff1 100644
--- a/src/kudu/rpc/negotiation-test.cc
+++ b/src/kudu/rpc/negotiation-test.cc
@@ -43,6 +43,7 @@
 
 using std::string;
 using std::thread;
+using std::unique_ptr;
 
 // HACK: MIT Kerberos doesn't have any way of determining its version number,
 // but the error messages in krb5-1.10 and earlier are broken due to
@@ -62,35 +63,25 @@ DEFINE_bool(is_test_child, false,
 namespace kudu {
 namespace rpc {
 
-class TestSaslRpc : public RpcTestBase {
+class TestNegotiation : public RpcTestBase {
  public:
   virtual void SetUp() OVERRIDE {
     RpcTestBase::SetUp();
-    ASSERT_OK(SaslInit(kSaslAppName));
+    ASSERT_OK(SaslInit());
   }
 };
 
-// Test basic initialization of the objects.
-TEST_F(TestSaslRpc, TestBasicInit) {
-  SaslServer server(kSaslAppName, nullptr);
-  server.EnablePlain();
-  ASSERT_OK(server.Init(kSaslAppName));
-  SaslClient client(kSaslAppName, nullptr);
-  client.EnablePlain("test", "test");
-  ASSERT_OK(client.Init(kSaslAppName));
-}
-
-// A "Callable" that takes a Socket* param, for use with starting a thread.
-// Can be used for SaslServer or SaslClient threads.
-typedef std::function<void(Socket*)> SocketCallable;
+// A "Callable" that takes a socket for use with starting a thread.
+// Can be used for ServerNegotiation or ClientNegotiation threads.
+typedef std::function<void(unique_ptr<Socket>)> SocketCallable;
 
 // Call Accept() on the socket, then pass the connection to the server runner
 static void RunAcceptingDelegator(Socket* acceptor,
                                   const SocketCallable& server_runner) {
-  Socket conn;
+  unique_ptr<Socket> conn(new Socket());
   Sockaddr remote;
-  CHECK_OK(acceptor->Accept(&conn, &remote, 0));
-  server_runner(&conn);
+  CHECK_OK(acceptor->Accept(conn.get(), &remote, 0));
+  server_runner(std::move(conn));
 }
 
 // Set up a socket and run a SASL negotiation.
@@ -103,45 +94,38 @@ static void RunNegotiationTest(const SocketCallable& server_runner,
   ASSERT_OK(server_sock.GetSocketAddress(&server_bind_addr));
   thread server(RunAcceptingDelegator, &server_sock, server_runner);
 
-  Socket client_sock;
-  CHECK_OK(client_sock.Init(0));
-  ASSERT_OK(client_sock.Connect(server_bind_addr));
-  thread client(client_runner, &client_sock);
+  unique_ptr<Socket> client_sock(new Socket());
+  CHECK_OK(client_sock->Init(0));
+  ASSERT_OK(client_sock->Connect(server_bind_addr));
+  thread client(client_runner, std::move(client_sock));
 
   LOG(INFO) << "Waiting for test threads to terminate...";
   client.join();
   LOG(INFO) << "Client thread terminated.";
 
-  // TODO(todd): if the client fails to negotiate, it doesn't
-  // always result in sending a nice error message to the
-  // other side.
-  client_sock.Close();
-
   server.join();
   LOG(INFO) << "Server thread terminated.";
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static void RunPlainNegotiationServer(Socket* conn) {
-  SaslServer sasl_server(kSaslAppName, conn);
-  CHECK_OK(sasl_server.EnablePlain());
-  CHECK_OK(sasl_server.Init(kSaslAppName));
-  CHECK_OK(sasl_server.Negotiate());
-  CHECK(ContainsKey(sasl_server.client_features(), APPLICATION_FEATURE_FLAGS));
-  CHECK_EQ("my-username", sasl_server.authenticated_user());
+static void RunPlainNegotiationServer(unique_ptr<Socket> socket) {
+  ServerNegotiation server_negotiation(std::move(socket));
+  CHECK_OK(server_negotiation.EnablePlain());
+  CHECK_OK(server_negotiation.Negotiate());
+  CHECK(ContainsKey(server_negotiation.client_features(), APPLICATION_FEATURE_FLAGS));
+  CHECK_EQ("my-username", server_negotiation.authenticated_user());
 }
 
-static void RunPlainNegotiationClient(Socket* conn) {
-  SaslClient sasl_client(kSaslAppName, conn);
-  CHECK_OK(sasl_client.EnablePlain("my-username", "ignored password"));
-  CHECK_OK(sasl_client.Init(kSaslAppName));
-  CHECK_OK(sasl_client.Negotiate());
-  CHECK(ContainsKey(sasl_client.server_features(), APPLICATION_FEATURE_FLAGS));
+static void RunPlainNegotiationClient(unique_ptr<Socket> socket) {
+  ClientNegotiation client_negotiation(std::move(socket));
+  CHECK_OK(client_negotiation.EnablePlain("my-username", "ignored password"));
+  CHECK_OK(client_negotiation.Negotiate());
+  CHECK(ContainsKey(client_negotiation.server_features(), APPLICATION_FEATURE_FLAGS));
 }
 
 // Test SASL negotiation using the PLAIN mechanism over a socket.
-TEST_F(TestSaslRpc, TestPlainNegotiation) {
+TEST_F(TestNegotiation, TestPlainNegotiation) {
   RunNegotiationTest(RunPlainNegotiationServer, RunPlainNegotiationClient);
 }
 
@@ -153,33 +137,29 @@ using CheckerFunction = std::function<void(const Status&, T&)>;
 
 // Run GSSAPI negotiation from the server side. Runs
 // 'post_check' after negotiation to verify the result.
-static void RunGSSAPINegotiationServer(
-    Socket* conn,
-    const CheckerFunction<SaslServer>& post_check) {
-  SaslServer sasl_server(kSaslAppName, conn);
-  sasl_server.set_server_fqdn("127.0.0.1");
-  CHECK_OK(sasl_server.EnableGSSAPI());
-  CHECK_OK(sasl_server.Init(kSaslAppName));
-  post_check(sasl_server.Negotiate(), sasl_server);
+static void RunGSSAPINegotiationServer(unique_ptr<Socket> socket,
+                                       const CheckerFunction<ServerNegotiation>& post_check) {
+  ServerNegotiation server_negotiation(std::move(socket));
+  server_negotiation.set_server_fqdn("127.0.0.1");
+  CHECK_OK(server_negotiation.EnableGSSAPI());
+  post_check(server_negotiation.Negotiate(), server_negotiation);
 }
 
 // Run GSSAPI negotiation from the client side. Runs
 // 'post_check' after negotiation to verify the result.
-static void RunGSSAPINegotiationClient(
-    Socket* conn,
-    const CheckerFunction<SaslClient>& post_check) {
-  SaslClient sasl_client(kSaslAppName, conn);
-  sasl_client.set_server_fqdn("127.0.0.1");
-  CHECK_OK(sasl_client.EnableGSSAPI());
-  CHECK_OK(sasl_client.Init(kSaslAppName));
-  post_check(sasl_client.Negotiate(), sasl_client);
+static void RunGSSAPINegotiationClient(unique_ptr<Socket> conn,
+                                       const CheckerFunction<ClientNegotiation>& post_check) {
+  ClientNegotiation client_negotiation(std::move(conn));
+  client_negotiation.set_server_fqdn("127.0.0.1");
+  CHECK_OK(client_negotiation.EnableGSSAPI());
+  post_check(client_negotiation.Negotiate(), client_negotiation);
 }
 
 // Test configuring a client to allow but not require Kerberos/GSSAPI,
 // and connect to a server which requires Kerberos/GSSAPI.
 //
 // They should negotiate to use Kerberos/GSSAPI.
-TEST_F(TestSaslRpc, TestRestrictiveServer_NonRestrictiveClient) {
+TEST_F(TestNegotiation, TestRestrictiveServer_NonRestrictiveClient) {
   MiniKdc kdc;
   ASSERT_OK(kdc.Start());
 
@@ -196,27 +176,26 @@ TEST_F(TestSaslRpc, TestRestrictiveServer_NonRestrictiveClient) {
   // Authentication should now succeed on both sides.
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   CHECK_OK(s);
                   CHECK_EQ(SaslMechanism::GSSAPI, server.negotiated_mechanism());
                   CHECK_EQ("testuser", server.authenticated_user());
                 }),
-      [](Socket* conn) {
-        SaslClient sasl_client(kSaslAppName, conn);
-        sasl_client.set_server_fqdn("127.0.0.1");
+      [](unique_ptr<Socket> socket) {
+        ClientNegotiation client_negotiation(std::move(socket));
+        client_negotiation.set_server_fqdn("127.0.0.1");
         // The client enables both PLAIN and GSSAPI.
-        CHECK_OK(sasl_client.EnablePlain("foo", "bar"));
-        CHECK_OK(sasl_client.EnableGSSAPI());
-        CHECK_OK(sasl_client.Init(kSaslAppName));
-        CHECK_OK(sasl_client.Negotiate());
-        CHECK_EQ(SaslMechanism::GSSAPI, sasl_client.negotiated_mechanism());
+        CHECK_OK(client_negotiation.EnablePlain("foo", "bar"));
+        CHECK_OK(client_negotiation.EnableGSSAPI());
+        CHECK_OK(client_negotiation.Negotiate());
+        CHECK_EQ(SaslMechanism::GSSAPI, client_negotiation.negotiated_mechanism());
       });
 }
 
 // Test configuring a client to only support PLAIN, and a server which
 // only supports GSSAPI. This would happen, for example, if an old Kudu
 // client tries to talk to a secure-only cluster.
-TEST_F(TestSaslRpc, TestNoMatchingMechanisms) {
+TEST_F(TestNegotiation, TestNoMatchingMechanisms) {
   MiniKdc kdc;
   ASSERT_OK(kdc.Start());
 
@@ -227,7 +206,7 @@ TEST_F(TestSaslRpc, TestNoMatchingMechanisms) {
 
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   // The client fails to find a matching mechanism and
                   // doesn't send any failure message to the server.
                   // Instead, it just disconnects.
@@ -235,19 +214,18 @@ TEST_F(TestSaslRpc, TestNoMatchingMechanisms) {
                   // TODO(todd): this could produce a better message!
                   ASSERT_STR_CONTAINS(s.ToString(), "got EOF from remote");
                 }),
-      [](Socket* conn) {
-        SaslClient sasl_client(kSaslAppName, conn);
-        sasl_client.set_server_fqdn("127.0.0.1");
+      [](unique_ptr<Socket> socket) {
+        ClientNegotiation client_negotiation(std::move(socket));
+        client_negotiation.set_server_fqdn("127.0.0.1");
-        // The client enables both PLAIN and GSSAPI.
+        // The client enables only PLAIN, which the GSSAPI-only server rejects.
-        CHECK_OK(sasl_client.EnablePlain("foo", "bar"));
-        CHECK_OK(sasl_client.Init(kSaslAppName));
-        Status s = sasl_client.Negotiate();
+        CHECK_OK(client_negotiation.EnablePlain("foo", "bar"));
+        Status s = client_negotiation.Negotiate();
         ASSERT_STR_CONTAINS(s.ToString(), "client was missing the required SASL module");
       });
 }
 
 // Test SASL negotiation using the GSSAPI (kerberos) mechanism over a socket.
-TEST_F(TestSaslRpc, TestGSSAPINegotiation) {
+TEST_F(TestNegotiation, TestGSSAPINegotiation) {
   MiniKdc kdc;
   ASSERT_OK(kdc.Start());
 
@@ -264,13 +242,13 @@ TEST_F(TestSaslRpc, TestGSSAPINegotiation) {
   // Authentication should succeed on both sides.
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   CHECK_OK(s);
                   CHECK_EQ(SaslMechanism::GSSAPI, server.negotiated_mechanism());
                   CHECK_EQ("testuser", server.authenticated_user());
                 }),
       std::bind(RunGSSAPINegotiationClient, std::placeholders::_1,
-                [](const Status& s, SaslClient& client) {
+                [](const Status& s, ClientNegotiation& client) {
                   CHECK_OK(s);
                   CHECK_EQ(SaslMechanism::GSSAPI, client.negotiated_mechanism());
                 }));
@@ -281,7 +259,7 @@ TEST_F(TestSaslRpc, TestGSSAPINegotiation) {
 // This test is ignored on macOS because the system Kerberos implementation
-// (Heimdal) caches the non-existence of client credentials, which causes futher
+// (Heimdal) caches the non-existence of client credentials, which causes further
 // tests to fail.
-TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
+TEST_F(TestNegotiation, TestGSSAPIInvalidNegotiation) {
   MiniKdc kdc;
   ASSERT_OK(kdc.Start());
 
@@ -289,7 +267,7 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
   // sides.
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   // The client notices there are no credentials and
                   // doesn't send any failure message to the server.
                   // Instead, it just disconnects.
@@ -299,7 +277,7 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
                   CHECK(s.IsNetworkError());
                 }),
       std::bind(RunGSSAPINegotiationClient, std::placeholders::_1,
-                [](const Status& s, SaslClient& client) {
+                [](const Status& s, ClientNegotiation& client) {
                   CHECK(s.IsNotAuthorized());
 #ifndef KRB5_VERSION_LE_1_10
-                  CHECK_GT(s.ToString().find("No Kerberos credentials available"), 0);
+                  CHECK_NE(s.ToString().find("No Kerberos credentials available"), string::npos);
@@ -316,14 +294,14 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
   // sides.
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   // The client notices there are no credentials and
                   // doesn't send any failure message to the server.
                   // Instead, it just disconnects.
                   CHECK(s.IsNetworkError());
                 }),
       std::bind(RunGSSAPINegotiationClient, std::placeholders::_1,
-                [](const Status& s, SaslClient& client) {
+                [](const Status& s, ClientNegotiation& client) {
                   CHECK(s.IsNotAuthorized());
 #ifndef KRB5_VERSION_LE_1_10
                   CHECK_EQ(s.message().ToString(), "No Kerberos credentials available");
@@ -343,7 +321,7 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
 
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   CHECK(s.IsNotAuthorized());
 #ifndef KRB5_VERSION_LE_1_10
                   ASSERT_STR_CONTAINS(s.ToString(),
@@ -351,7 +329,7 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
 #endif
                 }),
       std::bind(RunGSSAPINegotiationClient, std::placeholders::_1,
-                [](const Status& s, SaslClient& client) {
+                [](const Status& s, ClientNegotiation& client) {
                   CHECK(s.IsNotAuthorized());
 #ifndef KRB5_VERSION_LE_1_10
                   ASSERT_STR_CONTAINS(s.ToString(),
@@ -368,9 +346,9 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
 // This is ignored on macOS because the system Kerberos implementation does not
 // fail the preflight check when the keytab is inaccessible, probably because
 // the preflight check passes a 0-length token.
-TEST_F(TestSaslRpc, TestPreflight) {
+TEST_F(TestNegotiation, TestPreflight) {
   // Try pre-flight with no keytab.
-  Status s = SaslServer::PreflightCheckGSSAPI(kSaslAppName);
+  Status s = ServerNegotiation::PreflightCheckGSSAPI();
   ASSERT_FALSE(s.ok());
 #ifndef KRB5_VERSION_LE_1_10
   ASSERT_STR_MATCHES(s.ToString(), "Key table file.*not found");
@@ -383,11 +361,11 @@ TEST_F(TestSaslRpc, TestPreflight) {
   ASSERT_OK(kdc.CreateServiceKeytab("kudu/127.0.0.1", &kt_path));
   CHECK_ERR(setenv("KRB5_KTNAME", kt_path.c_str(), 1 /*replace*/));
 
-  ASSERT_OK(SaslServer::PreflightCheckGSSAPI(kSaslAppName));
+  ASSERT_OK(ServerNegotiation::PreflightCheckGSSAPI());
 
   // Try with an inaccessible keytab.
   CHECK_ERR(chmod(kt_path.c_str(), 0000));
-  s = SaslServer::PreflightCheckGSSAPI(kSaslAppName);
+  s = ServerNegotiation::PreflightCheckGSSAPI();
   ASSERT_FALSE(s.ok());
 #ifndef KRB5_VERSION_LE_1_10
   ASSERT_STR_MATCHES(s.ToString(), "error accessing keytab: Permission denied");
@@ -397,7 +375,7 @@ TEST_F(TestSaslRpc, TestPreflight) {
   // Try with a keytab that has the wrong credentials.
   ASSERT_OK(kdc.CreateServiceKeytab("wrong-service/127.0.0.1", &kt_path));
   CHECK_ERR(setenv("KRB5_KTNAME", kt_path.c_str(), 1 /*replace*/));
-  s = SaslServer::PreflightCheckGSSAPI(kSaslAppName);
+  s = ServerNegotiation::PreflightCheckGSSAPI();
   ASSERT_FALSE(s.ok());
 #ifndef KRB5_VERSION_LE_1_10
   ASSERT_STR_MATCHES(s.ToString(), "No key table entry found matching kudu/.*");
@@ -407,55 +385,51 @@ TEST_F(TestSaslRpc, TestPreflight) {
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static void RunTimeoutExpectingServer(Socket* conn) {
-  SaslServer sasl_server(kSaslAppName, conn);
-  CHECK_OK(sasl_server.EnablePlain());
-  CHECK_OK(sasl_server.Init(kSaslAppName));
-  Status s = sasl_server.Negotiate();
+static void RunTimeoutExpectingServer(unique_ptr<Socket> socket) {
+  ServerNegotiation server_negotiation(std::move(socket));
+  CHECK_OK(server_negotiation.EnablePlain());
+  Status s = server_negotiation.Negotiate();
   ASSERT_TRUE(s.IsNetworkError()) << "Expected client to time out and close the connection. Got: "
-      << s.ToString();
+                                  << s.ToString();
 }
 
-static void RunTimeoutNegotiationClient(Socket* sock) {
-  SaslClient sasl_client(kSaslAppName, sock);
-  CHECK_OK(sasl_client.EnablePlain("test", "test"));
-  CHECK_OK(sasl_client.Init(kSaslAppName));
+static void RunTimeoutNegotiationClient(unique_ptr<Socket> sock) {
+  ClientNegotiation client_negotiation(std::move(sock));
+  CHECK_OK(client_negotiation.EnablePlain("test", "test"));
   MonoTime deadline = MonoTime::Now() - MonoDelta::FromMilliseconds(100L);
-  sasl_client.set_deadline(deadline);
-  Status s = sasl_client.Negotiate();
+  client_negotiation.set_deadline(deadline);
+  Status s = client_negotiation.Negotiate();
   ASSERT_TRUE(s.IsTimedOut()) << "Expected timeout! Got: " << s.ToString();
-  CHECK_OK(sock->Shutdown(true, true));
+  CHECK_OK(client_negotiation.socket()->Shutdown(true, true));
 }
 
 // Ensure that the client times out.
-TEST_F(TestSaslRpc, TestClientTimeout) {
+TEST_F(TestNegotiation, TestClientTimeout) {
   RunNegotiationTest(RunTimeoutExpectingServer, RunTimeoutNegotiationClient);
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static void RunTimeoutNegotiationServer(Socket* sock) {
-  SaslServer sasl_server(kSaslAppName, sock);
-  CHECK_OK(sasl_server.EnablePlain());
-  CHECK_OK(sasl_server.Init(kSaslAppName));
+static void RunTimeoutNegotiationServer(unique_ptr<Socket> socket) {
+  ServerNegotiation server_negotiation(std::move(socket));
+  CHECK_OK(server_negotiation.EnablePlain());
   MonoTime deadline = MonoTime::Now() - MonoDelta::FromMilliseconds(100L);
-  sasl_server.set_deadline(deadline);
-  Status s = sasl_server.Negotiate();
+  server_negotiation.set_deadline(deadline);
+  Status s = server_negotiation.Negotiate();
   ASSERT_TRUE(s.IsTimedOut()) << "Expected timeout! Got: " << s.ToString();
-  CHECK_OK(sock->Close());
+  CHECK_OK(server_negotiation.socket()->Close());
 }
 
-static void RunTimeoutExpectingClient(Socket* conn) {
-  SaslClient sasl_client(kSaslAppName, conn);
-  CHECK_OK(sasl_client.EnablePlain("test", "test"));
-  CHECK_OK(sasl_client.Init(kSaslAppName));
-  Status s = sasl_client.Negotiate();
+static void RunTimeoutExpectingClient(unique_ptr<Socket> socket) {
+  ClientNegotiation client_negotiation(std::move(socket));
+  CHECK_OK(client_negotiation.EnablePlain("test", "test"));
+  Status s = client_negotiation.Negotiate();
   ASSERT_TRUE(s.IsNetworkError()) << "Expected server to time out and close the connection. Got: "
       << s.ToString();
 }
 
 // Ensure that the server times out.
-TEST_F(TestSaslRpc, TestServerTimeout) {
+TEST_F(TestNegotiation, TestServerTimeout) {
   RunNegotiationTest(RunTimeoutNegotiationServer, RunTimeoutExpectingClient);
 }
 
@@ -498,7 +472,7 @@ class TestDisableInit : public KuduTest {
 TEST_F(TestDisableInit, TestDisableSasl_NotInitialized) {
   DoTest([]() {
       CHECK_OK(DisableSaslInitialization());
-      Status s = SaslInit("kudu");
+      Status s = SaslInit();
       ASSERT_STR_CONTAINS(s.ToString(), "was disabled, but SASL was not externally initialized");
     });
 }
@@ -509,7 +483,7 @@ TEST_F(TestDisableInit, TestDisableSasl_Good) {
       rpc::internal::SaslSetMutex();
       sasl_client_init(NULL);
       CHECK_OK(DisableSaslInitialization());
-      ASSERT_OK(SaslInit("kudu"));
+      ASSERT_OK(SaslInit());
     });
 }
 
@@ -520,7 +494,7 @@ TEST_F(TestDisableInit, TestMultipleSaslInit) {
   DoTest([]() {
       rpc::internal::SaslSetMutex();
       sasl_client_init(NULL);
-      ASSERT_OK(SaslInit("kudu"));
+      ASSERT_OK(SaslInit());
     }, &stderr);
   // If we are the parent, we should see the warning from the child that it automatically
   // skipped initialization because it detected that it was already initialized.
@@ -538,7 +512,7 @@ TEST_F(TestDisableInit, TestDisableSasl_NoMutexImpl) {
   DoTest([]() {
       sasl_client_init(NULL);
       CHECK_OK(DisableSaslInitialization());
-      ASSERT_OK(SaslInit("kudu"));
+      ASSERT_OK(SaslInit());
     }, &stderr);
   // If we are the parent, we should see the warning from the child.
   if (!FLAGS_is_test_child) {
@@ -552,7 +526,7 @@ TEST_F(TestDisableInit, TestMultipleSaslInit_NoMutexImpl) {
   string stderr;
   DoTest([]() {
       sasl_client_init(NULL);
-      ASSERT_OK(SaslInit("kudu"));
+      ASSERT_OK(SaslInit());
     }, &stderr);
   // If we are the parent, we should see the warning from the child that it automatically
   // skipped initialization because it detected that it was already initialized.

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/negotiation.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/negotiation.cc b/src/kudu/rpc/negotiation.cc
index 859016c..4b387ca 100644
--- a/src/kudu/rpc/negotiation.cc
+++ b/src/kudu/rpc/negotiation.cc
@@ -30,12 +30,15 @@
 #include "kudu/rpc/blocking_ops.h"
 #include "kudu/rpc/client_negotiation.h"
 #include "kudu/rpc/connection.h"
+#include "kudu/rpc/messenger.h"
 #include "kudu/rpc/reactor.h"
 #include "kudu/rpc/rpc_header.pb.h"
 #include "kudu/rpc/sasl_common.h"
 #include "kudu/rpc/server_negotiation.h"
+#include "kudu/security/ssl_socket.h"
 #include "kudu/util/errno.h"
 #include "kudu/util/flag_tags.h"
+#include "kudu/util/logging.h"
 #include "kudu/util/status.h"
 #include "kudu/util/trace.h"
 
@@ -50,62 +53,17 @@ DEFINE_int32(rpc_negotiation_inject_delay_ms, 0,
              "the RPC negotiation process on the server side.");
 TAG_FLAG(rpc_negotiation_inject_delay_ms, unsafe);
 
-namespace kudu {
-namespace rpc {
+DECLARE_bool(server_require_kerberos);
 
-using std::shared_ptr;
 using strings::Substitute;
 
-// Client: Send ConnectionContextPB message based on information stored in the Connection object.
-static Status SendConnectionContext(Connection* conn, const MonoTime& deadline) {
-  TRACE("Sending connection context");
-  RequestHeader header;
-  header.set_call_id(kConnectionContextCallId);
-
-  ConnectionContextPB conn_context;
-  // This field is deprecated but used by servers <Kudu 1.1. Newer server versions ignore
-  // this and use the SASL-provided username instead.
-  conn_context.mutable_deprecated_user_info()->set_real_user(
-      conn->user_credentials().real_user());
-  return SendFramedMessageBlocking(conn->socket(), header, conn_context, deadline);
-}
-
-// Server: Receive ConnectionContextPB message and update the corresponding fields in the
-// associated Connection object. Perform validation against SASL-negotiated information
-// as needed.
-static Status RecvConnectionContext(Connection* conn, const MonoTime& deadline) {
-  TRACE("Waiting for connection context");
-  faststring recv_buf(1024); // Should be plenty for a ConnectionContextPB message.
-  RequestHeader header;
-  Slice param_buf;
-  RETURN_NOT_OK(ReceiveFramedMessageBlocking(conn->socket(), &recv_buf,
-                                             &header, &param_buf, deadline));
-  DCHECK(header.IsInitialized());
-
-  if (header.call_id() != kConnectionContextCallId) {
-    return Status::IllegalState("Expected ConnectionContext callid, received",
-        Substitute("$0", header.call_id()));
-  }
-
-  ConnectionContextPB conn_context;
-  if (!conn_context.ParseFromArray(param_buf.data(), param_buf.size())) {
-    return Status::InvalidArgument("Invalid ConnectionContextPB message, missing fields",
-                                   conn_context.InitializationErrorString());
-  }
-
-  if (conn->sasl_server().authenticated_user().empty()) {
-    return Status::NotAuthorized("No user was authenticated");
-  }
-
-  conn->mutable_user_credentials()->set_real_user(conn->sasl_server().authenticated_user());
-
-  return Status::OK();
-}
+namespace kudu {
+namespace rpc {
 
 // Wait for the client connection to be established and become ready for writing.
-static Status WaitForClientConnect(Connection* conn, const MonoTime& deadline) {
+static Status WaitForClientConnect(Socket* socket, const MonoTime& deadline) {
   TRACE("Waiting for socket to connect");
-  int fd = conn->socket()->GetFd();
+  int fd = socket->GetFd();
   struct pollfd poll_fd;
   poll_fd.fd = fd;
   poll_fd.events = POLLOUT;
@@ -164,57 +122,99 @@ static Status WaitForClientConnect(Connection* conn, const MonoTime& deadline) {
 }
 
 // Disable / reset socket timeouts.
-static Status DisableSocketTimeouts(Connection* conn) {
-  RETURN_NOT_OK(conn->socket()->SetSendTimeout(MonoDelta::FromNanoseconds(0L)));
-  RETURN_NOT_OK(conn->socket()->SetRecvTimeout(MonoDelta::FromNanoseconds(0L)));
+static Status DisableSocketTimeouts(Socket* socket) {
+  RETURN_NOT_OK(socket->SetSendTimeout(MonoDelta::FromNanoseconds(0L)));
+  RETURN_NOT_OK(socket->SetRecvTimeout(MonoDelta::FromNanoseconds(0L)));
   return Status::OK();
 }
 
 // Perform client negotiation. We don't LOG() anything, we leave that to our caller.
-static Status DoClientNegotiation(Connection* conn,
-                                  const MonoTime& deadline) {
-  // The SASL initialization on the client side can be relatively heavy-weight
-  // (it may result in DNS queries in the case of GSSAPI).
-  // So, we do it while the connect() is still in progress to reduce latency.
-  //
-  // TODO(todd): we should consider doing this even before connecting, since as soon
-  // as we connect, we are tying up a negotiation thread on the server side.
-
-  RETURN_NOT_OK(conn->InitSaslClient());
-
-  RETURN_NOT_OK(WaitForClientConnect(conn, deadline));
-  RETURN_NOT_OK(conn->SetNonBlocking(false));
-  RETURN_NOT_OK(conn->InitSSLIfNecessary());
-  conn->sasl_client().set_deadline(deadline);
-  RETURN_NOT_OK(conn->sasl_client().Negotiate());
-  RETURN_NOT_OK(SendConnectionContext(conn, deadline));
-  RETURN_NOT_OK(DisableSocketTimeouts(conn));
+static Status DoClientNegotiation(Connection* conn, MonoTime deadline) {
+  ClientNegotiation client_negotiation(conn->release_socket());
+
+  // Note that the fqdn is an IP address here: we've already lost whatever DNS
+  // name the client was attempting to use. Unless krb5 is configured with 'rdns
+  // = false', it will automatically take care of reversing this address to its
+  // canonical hostname to determine the expected server principal.
+  client_negotiation.set_server_fqdn(conn->remote().host());
+
+  Status s = client_negotiation.EnableGSSAPI();
+  if (!s.ok()) {
+    // If we can't enable GSSAPI, it's likely the client is just missing the
+    // appropriate SASL plugin. We don't want to require it to be installed
+    // if the user doesn't care about connecting to servers using Kerberos
+    // authentication. So, we'll just VLOG this here. If we try to connect
+    // to a server which requires Kerberos, we'll get a negotiation error
+    // at that point.
+    if (VLOG_IS_ON(1)) {
+      KLOG_FIRST_N(INFO, 1) << "Couldn't enable GSSAPI (Kerberos) SASL plugin: "
+                            << s.message().ToString()
+                            << ". This process will be unable to connect to "
+                            << "servers requiring Kerberos authentication.";
+    }
+  }
+
+  RETURN_NOT_OK(client_negotiation.EnablePlain(conn->user_credentials().real_user(), ""));
+  client_negotiation.set_deadline(deadline);
+
+  RETURN_NOT_OK(WaitForClientConnect(client_negotiation.socket(), deadline));
+  RETURN_NOT_OK(client_negotiation.socket()->SetNonBlocking(false));
+
+  // Do SSL handshake.
+  // TODO(dan): This is a messy place to do this.
+  if (conn->reactor_thread()->reactor()->messenger()->ssl_enabled()) {
+    SSLSocket* ssl_socket = down_cast<SSLSocket*>(client_negotiation.socket());
+    RETURN_NOT_OK(ssl_socket->DoHandshake());
+  }
+
+  RETURN_NOT_OK(client_negotiation.Negotiate());
+  RETURN_NOT_OK(DisableSocketTimeouts(client_negotiation.socket()));
+
+  // Transfer the negotiated socket and state back to the connection.
+  conn->adopt_socket(client_negotiation.release_socket());
+  conn->set_remote_features(client_negotiation.take_server_features());
 
   return Status::OK();
 }
 
 // Perform server negotiation. We don't LOG() anything, we leave that to our caller.
-static Status DoServerNegotiation(Connection* conn,
-                                  const MonoTime& deadline) {
+static Status DoServerNegotiation(Connection* conn, const MonoTime& deadline) {
   if (FLAGS_rpc_negotiation_inject_delay_ms > 0) {
     LOG(WARNING) << "Injecting " << FLAGS_rpc_negotiation_inject_delay_ms
                  << "ms delay in negotiation";
     SleepFor(MonoDelta::FromMilliseconds(FLAGS_rpc_negotiation_inject_delay_ms));
   }
-  RETURN_NOT_OK(conn->SetNonBlocking(false));
-  RETURN_NOT_OK(conn->InitSSLIfNecessary());
-  RETURN_NOT_OK(conn->InitSaslServer());
-  conn->sasl_server().set_deadline(deadline);
-  RETURN_NOT_OK(conn->sasl_server().Negotiate());
-  RETURN_NOT_OK(RecvConnectionContext(conn, deadline));
-  RETURN_NOT_OK(DisableSocketTimeouts(conn));
+
+  // Create a new ServerNegotiation to handle the synchronous negotiation.
+  ServerNegotiation server_negotiation(conn->release_socket());
+  if (FLAGS_server_require_kerberos) {
+    RETURN_NOT_OK(server_negotiation.EnableGSSAPI());
+  } else {
+    RETURN_NOT_OK(server_negotiation.EnablePlain());
+  }
+
+  server_negotiation.set_deadline(deadline);
+  RETURN_NOT_OK(server_negotiation.socket()->SetNonBlocking(false));
+
+  // Do SSL handshake.
+  // TODO(dan): This is a messy place to do this.
+  if (conn->reactor_thread()->reactor()->messenger()->ssl_enabled()) {
+    SSLSocket* ssl_socket = down_cast<SSLSocket*>(server_negotiation.socket());
+    RETURN_NOT_OK(ssl_socket->DoHandshake());
+  }
+
+  RETURN_NOT_OK(server_negotiation.Negotiate());
+  RETURN_NOT_OK(DisableSocketTimeouts(server_negotiation.socket()));
+
+  // Transfer the negotiated socket and state back to the connection.
+  conn->adopt_socket(server_negotiation.release_socket());
+  conn->set_remote_features(server_negotiation.take_client_features());
+  conn->mutable_user_credentials()->set_real_user(server_negotiation.authenticated_user());
 
   return Status::OK();
 }
 
-// Perform negotiation for a connection (either server or client)
-void Negotiation::RunNegotiation(const scoped_refptr<Connection>& conn,
-                                 const MonoTime& deadline) {
+void Negotiation::RunNegotiation(const scoped_refptr<Connection>& conn, MonoTime deadline) {
   Status s;
   if (conn->direction() == Connection::SERVER) {
     s = DoServerNegotiation(conn.get(), deadline);

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/negotiation.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/negotiation.h b/src/kudu/rpc/negotiation.h
index 0562555..64d009b 100644
--- a/src/kudu/rpc/negotiation.h
+++ b/src/kudu/rpc/negotiation.h
@@ -21,14 +21,16 @@
 #include "kudu/util/monotime.h"
 
 namespace kudu {
+
 namespace rpc {
 
 class Connection;
 
 class Negotiation {
  public:
-  static void RunNegotiation(const scoped_refptr<Connection>& conn,
-                             const MonoTime &deadline);
+
+  // Perform negotiation for a connection (either server or client)
+  static void RunNegotiation(const scoped_refptr<Connection>& conn, MonoTime deadline);
  private:
   DISALLOW_IMPLICIT_CONSTRUCTORS(Negotiation);
 };