You are viewing a plain-text version of this content. The canonical version is available from the mailing-list archive; its hyperlink was not preserved in this plain-text copy.
Posted to commits@kudu.apache.org by da...@apache.org on 2017/01/26 02:19:49 UTC

[1/2] kudu git commit: TLS-negotiation [6/n]: Refactor RPC negotiation

Repository: kudu
Updated Branches:
  refs/heads/master b9aa5dd31 -> dc8525358


http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/reactor.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/reactor.cc b/src/kudu/rpc/reactor.cc
index f178a2d..c8f523c 100644
--- a/src/kudu/rpc/reactor.cc
+++ b/src/kudu/rpc/reactor.cc
@@ -18,17 +18,19 @@
 #include "kudu/rpc/reactor.h"
 
 #include <arpa/inet.h>
-#include <boost/intrusive/list.hpp>
-#include <ev++.h>
-#include <glog/logging.h>
-#include <mutex>
 #include <netinet/in.h>
 #include <stdlib.h>
-#include <string>
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <unistd.h>
 
+#include <mutex>
+#include <string>
+
+#include <boost/intrusive/list.hpp>
+#include <ev++.h>
+#include <glog/logging.h>
+
 #include "kudu/gutil/ref_counted.h"
 #include "kudu/gutil/stringprintf.h"
 #include "kudu/rpc/client_negotiation.h"
@@ -85,7 +87,7 @@ Status ShutdownError(bool aborted) {
 }
 } // anonymous namespace
 
-ReactorThread::ReactorThread(Reactor *reactor, const MessengerBuilder &bld)
+ReactorThread::ReactorThread(Reactor *reactor, const MessengerBuilder& bld)
   : loop_(kDefaultLibEvFlags),
     cur_time_(MonoTime::Now()),
     last_unused_tcp_scan_(cur_time_),
@@ -191,7 +193,7 @@ void ReactorThread::WakeThread() {
 // threads that want to bring something to our attention, like the fact that
 // we're shutting down, or the fact that there is a new outbound Transfer
 // ready to send.
-void ReactorThread::AsyncHandler(ev::async &watcher, int revents) {
+void ReactorThread::AsyncHandler(ev::async& /*watcher*/, int /*revents*/) {
   DCHECK(IsCurrentThread());
 
   if (PREDICT_FALSE(reactor_->closing())) {
@@ -204,13 +206,13 @@ void ReactorThread::AsyncHandler(ev::async &watcher, int revents) {
   reactor_->DrainTaskQueue(&tasks);
 
   while (!tasks.empty()) {
-    ReactorTask &task = tasks.front();
+    ReactorTask& task = tasks.front();
     tasks.pop_front();
     task.Run(this);
   }
 }
 
-void ReactorThread::RegisterConnection(const scoped_refptr<Connection>& conn) {
+void ReactorThread::RegisterConnection(scoped_refptr<Connection> conn) {
   DCHECK(IsCurrentThread());
 
   Status s = StartConnectionNegotiation(conn);
@@ -219,10 +221,10 @@ void ReactorThread::RegisterConnection(const scoped_refptr<Connection>& conn) {
     DestroyConnection(conn.get(), s);
     return;
   }
-  server_conns_.push_back(conn);
+  server_conns_.emplace_back(std::move(conn));
 }
 
-void ReactorThread::AssignOutboundCall(const shared_ptr<OutboundCall> &call) {
+void ReactorThread::AssignOutboundCall(const shared_ptr<OutboundCall>& call) {
   DCHECK(IsCurrentThread());
   scoped_refptr<Connection> conn;
 
@@ -242,7 +244,7 @@ void ReactorThread::AssignOutboundCall(const shared_ptr<OutboundCall> &call) {
 // 2. every tcp_conn_timeo_ seconds, close down connections older than
 //    tcp_conn_timeo_ seconds.
 //
-void ReactorThread::TimerHandler(ev::timer &watcher, int revents) {
+void ReactorThread::TimerHandler(ev::timer& /*watcher*/, int revents) {
   DCHECK(IsCurrentThread());
   if (EV_ERROR & revents) {
     LOG(WARNING) << "Reactor " << name() << " got an error in "
@@ -293,7 +295,7 @@ void ReactorThread::ScanIdleConnections() {
   VLOG_IF(1, timed_out > 0) << name() << ": timed out " << timed_out << " TCP connections.";
 }
 
-const std::string &ReactorThread::name() const {
+const std::string& ReactorThread::name() const {
   return reactor_->name();
 }
 
@@ -321,7 +323,7 @@ void ReactorThread::RunThread() {
   reactor_->messenger_.reset();
 }
 
-Status ReactorThread::FindOrStartConnection(const ConnectionId &conn_id,
+Status ReactorThread::FindOrStartConnection(const ConnectionId& conn_id,
                                             scoped_refptr<Connection>* conn) {
   DCHECK(IsCurrentThread());
   conn_map_t::const_iterator c = client_conns_.find(conn_id);
@@ -348,7 +350,7 @@ Status ReactorThread::FindOrStartConnection(const ConnectionId &conn_id,
   }
 
   // Register the new connection in our map.
-  *conn = new Connection(this, conn_id.remote(), new_socket.release(), Connection::CLIENT);
+  *conn = new Connection(this, conn_id.remote(), std::move(new_socket), Connection::CLIENT);
   (*conn)->set_user_credentials(conn_id.user_credentials());
 
   // Kick off blocking client connection negotiation.
@@ -377,12 +379,12 @@ Status ReactorThread::StartConnectionNegotiation(const scoped_refptr<Connection>
   ADOPT_TRACE(trace.get());
   TRACE("Submitting negotiation task for $0", conn->ToString());
   RETURN_NOT_OK(reactor()->messenger()->negotiation_pool()->SubmitClosure(
-      Bind(&Negotiation::RunNegotiation, conn, deadline)));
+        Bind(&Negotiation::RunNegotiation, conn, deadline)));
   return Status::OK();
 }
 
 void ReactorThread::CompleteConnectionNegotiation(const scoped_refptr<Connection>& conn,
-      const Status &status) {
+                                                  const Status& status) {
   DCHECK(IsCurrentThread());
   if (PREDICT_FALSE(!status.ok())) {
     DestroyConnection(conn.get(), status);
@@ -396,6 +398,7 @@ void ReactorThread::CompleteConnectionNegotiation(const scoped_refptr<Connection
     DestroyConnection(conn.get(), s);
     return;
   }
+
   conn->MarkNegotiationComplete();
   conn->EpollRegister(loop_);
 }
@@ -405,13 +408,13 @@ Status ReactorThread::CreateClientSocket(Socket *sock) {
   if (ret.ok()) {
     ret = sock->SetNoDelay(true);
   }
-  LOG_IF(WARNING, !ret.ok()) << "failed to create an "
-    "outbound connection because a new socket could not "
-    "be created: " << ret.ToString();
+  LOG_IF(WARNING, !ret.ok())
+      << "failed to create an outbound connection because a new socket could not be created: "
+      << ret.ToString();
   return ret;
 }
 
-Status ReactorThread::StartConnect(Socket *sock, const Sockaddr &remote, bool *in_progress) {
+Status ReactorThread::StartConnect(Socket *sock, const Sockaddr& remote, bool *in_progress) {
   Status ret = sock->Connect(remote);
   if (ret.ok()) {
     VLOG(3) << "StartConnect: connect finished immediately for " << remote.ToString();
@@ -432,7 +435,7 @@ Status ReactorThread::StartConnect(Socket *sock, const Sockaddr &remote, bool *i
 }
 
 void ReactorThread::DestroyConnection(Connection *conn,
-                                      const Status &conn_status) {
+                                      const Status& conn_status) {
   DCHECK(IsCurrentThread());
 
   conn->Shutdown(conn_status);
@@ -455,9 +458,12 @@ void ReactorThread::DestroyConnection(Connection *conn,
   }
 }
 
-DelayedTask::DelayedTask(boost::function<void(const Status &)> func,
+DelayedTask::DelayedTask(boost::function<void(const Status&)> func,
                          MonoDelta when)
-    : func_(std::move(func)), when_(std::move(when)), thread_(nullptr) {}
+    : func_(std::move(func)),
+      when_(when),
+      thread_(nullptr) {
+}
 
 void DelayedTask::Run(ReactorThread* thread) {
   DCHECK(thread_ == nullptr) << "Task has already been scheduled";
@@ -492,7 +498,7 @@ void DelayedTask::TimerHandler(ev::timer& watcher, int revents) {
 }
 
 Reactor::Reactor(const shared_ptr<Messenger>& messenger,
-                 int index, const MessengerBuilder &bld)
+                 int index, const MessengerBuilder& bld)
   : messenger_(messenger),
     name_(StringPrintf("%s_R%03d", messenger->name().c_str(), index)),
     closing_(false),
@@ -529,7 +535,7 @@ Reactor::~Reactor() {
   Shutdown();
 }
 
-const std::string &Reactor::name() const {
+const std::string& Reactor::name() const {
   return name_;
 }
 
@@ -544,11 +550,11 @@ class RunFunctionTask : public ReactorTask {
   explicit RunFunctionTask(boost::function<Status()> f)
       : function_(std::move(f)), latch_(1) {}
 
-  virtual void Run(ReactorThread *reactor) OVERRIDE {
+  void Run(ReactorThread* /*reactor*/) override {
     status_ = function_();
     latch_.CountDown();
   }
-  virtual void Abort(const Status &status) OVERRIDE {
+  void Abort(const Status& status) override {
     status_ = status;
     latch_.CountDown();
   }
@@ -584,16 +590,16 @@ Status Reactor::DumpRunningRpcs(const DumpRunningRpcsRequestPB& req,
 
 class RegisterConnectionTask : public ReactorTask {
  public:
-  explicit RegisterConnectionTask(const scoped_refptr<Connection>& conn) :
-    conn_(conn)
-  {}
+  explicit RegisterConnectionTask(scoped_refptr<Connection> conn)
+      : conn_(std::move(conn)) {
+  }
 
-  virtual void Run(ReactorThread *thread) OVERRIDE {
-    thread->RegisterConnection(conn_);
+  void Run(ReactorThread* reactor) override {
+    reactor->RegisterConnection(std::move(conn_));
     delete this;
   }
 
-  virtual void Abort(const Status &status) OVERRIDE {
+  void Abort(const Status& /*status*/) override {
     // We don't need to Shutdown the connection since it was never registered.
     // This is only used for inbound connections, and inbound connections will
     // never have any calls added to them until they've been registered.
@@ -604,7 +610,7 @@ class RegisterConnectionTask : public ReactorTask {
   scoped_refptr<Connection> conn_;
 };
 
-void Reactor::RegisterInboundSocket(Socket *socket, const Sockaddr &remote) {
+void Reactor::RegisterInboundSocket(Socket *socket, const Sockaddr& remote) {
   VLOG(3) << name_ << ": new inbound connection to " << remote.ToString();
   std::unique_ptr<Socket> new_socket;
   if (messenger()->ssl_enabled()) {
@@ -612,9 +618,8 @@ void Reactor::RegisterInboundSocket(Socket *socket, const Sockaddr &remote) {
   } else {
     new_socket.reset(new Socket(socket->Release()));
   }
-  scoped_refptr<Connection> conn(
-    new Connection(&thread_, remote, new_socket.release(), Connection::SERVER));
-  auto task = new RegisterConnectionTask(conn);
+  auto task = new RegisterConnectionTask(
+      new Connection(&thread_, remote, std::move(new_socket), Connection::SERVER));
   ScheduleReactorTask(task);
 }
 
@@ -625,12 +630,12 @@ class AssignOutboundCallTask : public ReactorTask {
   explicit AssignOutboundCallTask(shared_ptr<OutboundCall> call)
       : call_(std::move(call)) {}
 
-  virtual void Run(ReactorThread *reactor) OVERRIDE {
+  void Run(ReactorThread* reactor) override {
     reactor->AssignOutboundCall(call_);
     delete this;
   }
 
-  virtual void Abort(const Status &status) OVERRIDE {
+  void Abort(const Status& status) override {
     call_->SetFailed(status);
     delete this;
   }
@@ -639,7 +644,7 @@ class AssignOutboundCallTask : public ReactorTask {
   shared_ptr<OutboundCall> call_;
 };
 
-void Reactor::QueueOutboundCall(const shared_ptr<OutboundCall> &call) {
+void Reactor::QueueOutboundCall(const shared_ptr<OutboundCall>& call) {
   DVLOG(3) << name_ << ": queueing outbound call "
            << call->ToString() << " to remote " << call->conn_id().remote().ToString();
   AssignOutboundCallTask *task = new AssignOutboundCallTask(call);

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/reactor.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/reactor.h b/src/kudu/rpc/reactor.h
index 54f3332..f9f5662 100644
--- a/src/kudu/rpc/reactor.h
+++ b/src/kudu/rpc/reactor.h
@@ -17,16 +17,18 @@
 #ifndef KUDU_RPC_REACTOR_H
 #define KUDU_RPC_REACTOR_H
 
-#include <boost/function.hpp>
-#include <boost/intrusive/list.hpp>
-#include <ev++.h>
+#include <stdint.h>
+
 #include <list>
 #include <map>
 #include <memory>
 #include <set>
-#include <stdint.h>
 #include <string>
 
+#include <boost/function.hpp>
+#include <boost/intrusive/list.hpp>
+#include <ev++.h>
+
 #include "kudu/gutil/ref_counted.h"
 #include "kudu/rpc/connection.h"
 #include "kudu/rpc/transfer.h"
@@ -37,9 +39,12 @@
 #include "kudu/util/status.h"
 
 namespace kudu {
+
+class Socket;
+
 namespace rpc {
 
-typedef std::list<scoped_refptr<Connection> > conn_list_t;
+typedef std::list<scoped_refptr<Connection>> conn_list_t;
 
 class DumpRunningRpcsRequestPB;
 class DumpRunningRpcsResponsePB;
@@ -89,10 +94,10 @@ class DelayedTask : public ReactorTask {
   DelayedTask(boost::function<void(const Status &)> func, MonoDelta when);
 
   // Schedules the task for running later but doesn't actually run it yet.
-  virtual void Run(ReactorThread* reactor) OVERRIDE;
+  void Run(ReactorThread* thread) override;
 
   // Behaves like ReactorTask::Abort.
-  virtual void Abort(const Status& abort_status) OVERRIDE;
+  void Abort(const Status& abort_status) override;
 
  private:
   // libev callback for when the registered timer fires.
@@ -172,7 +177,7 @@ class ReactorThread {
   // Transition back from negotiating to processing requests.
   // Must be called from the reactor thread.
   void CompleteConnectionNegotiation(const scoped_refptr<Connection>& conn,
-      const Status& status);
+                                     const Status& status);
 
   // Collect metrics.
   // Must be called from the reactor thread.
@@ -217,7 +222,7 @@ class ReactorThread {
   void AssignOutboundCall(const std::shared_ptr<OutboundCall> &call);
 
   // Register a new connection.
-  void RegisterConnection(const scoped_refptr<Connection>& conn);
+  void RegisterConnection(scoped_refptr<Connection> conn);
 
   // Actually perform shutdown of the thread, tearing down any connections,
   // etc. This is called from within the thread.

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/rpc-test.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/rpc-test.cc b/src/kudu/rpc/rpc-test.cc
index ea9d266..9769ffe 100644
--- a/src/kudu/rpc/rpc-test.cc
+++ b/src/kudu/rpc/rpc-test.cc
@@ -365,8 +365,8 @@ TEST_F(TestRpc, TestNegotiationTimeout) {
   // Create another thread to accept the connection on the fake server.
   scoped_refptr<Thread> acceptor_thread;
   ASSERT_OK(Thread::Create("test", "acceptor",
-                                  AcceptAndReadForever, &listen_sock,
-                                  &acceptor_thread));
+                           AcceptAndReadForever, &listen_sock,
+                           &acceptor_thread));
 
   // Set up client.
   shared_ptr<Messenger> client_messenger(CreateMessenger("Client"));

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/sasl_common.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/sasl_common.cc b/src/kudu/rpc/sasl_common.cc
index 359dc80..5e4aae6 100644
--- a/src/kudu/rpc/sasl_common.cc
+++ b/src/kudu/rpc/sasl_common.cc
@@ -30,6 +30,7 @@
 #include "kudu/gutil/macros.h"
 #include "kudu/gutil/once.h"
 #include "kudu/gutil/stringprintf.h"
+#include "kudu/rpc/constants.h"
 #include "kudu/util/flag_tags.h"
 #include "kudu/util/mutex.h"
 #include "kudu/util/net/sockaddr.h"
@@ -46,11 +47,8 @@ const char* const kSaslMechGSSAPI = "GSSAPI";
 static __thread string* g_auth_failure_capture = nullptr;
 
 // Determine whether initialization was ever called
-struct InitializationData {
-  Status status;
-  string app_name;
-};
-static struct InitializationData* sasl_init_data;
+static Status sasl_init_status = Status::OK();
+static bool sasl_is_initialized = false;
 
 // If true, then we expect someone else has initialized SASL.
 static bool g_disable_sasl_init = false;
@@ -203,14 +201,9 @@ static bool SaslMutexImplementationProvided() {
 #endif
 
 // Actually perform the initialization for the SASL subsystem.
-// Meant to be called via GoogleOnceInitArg().
-static void DoSaslInit(void* app_name_char_array) {
-  // Explicitly cast from void* here so GoogleOnce doesn't have to deal with it.
-  // We were getting Clang 3.4 UBSAN errors when letting GoogleOnce cast.
-  const char* const app_name = reinterpret_cast<const char* const>(app_name_char_array);
+// Meant to be called via GoogleOnceInit().
+static void DoSaslInit() {
   VLOG(3) << "Initializing SASL library";
-  sasl_init_data = new InitializationData();
-  sasl_init_data->app_name = app_name;
 
   bool sasl_initialized = SaslIsInitialized();
   if (sasl_initialized && !g_disable_sasl_init) {
@@ -222,7 +215,7 @@ static void DoSaslInit(void* app_name_char_array) {
 
   if (g_disable_sasl_init) {
     if (!sasl_initialized) {
-      sasl_init_data->status = Status::RuntimeError(
+      sasl_init_status = Status::RuntimeError(
           "SASL initialization was disabled, but SASL was not externally initialized.");
       return;
     }
@@ -232,30 +225,29 @@ static void DoSaslInit(void* app_name_char_array) {
           << "but was not provided with a mutex implementation! If "
           << "manually initializing SASL, use sasl_set_mutex(3).";
     }
-    sasl_init_data->status = Status::OK();
     return;
   }
   internal::SaslSetMutex();
   int result = sasl_client_init(&callbacks[0]);
   if (result != SASL_OK) {
-    sasl_init_data->status = Status::RuntimeError("Could not initialize SASL client",
-        sasl_errstring(result, nullptr, nullptr));
+    sasl_init_status = Status::RuntimeError("Could not initialize SASL client",
+                                            sasl_errstring(result, nullptr, nullptr));
     return;
   }
 
-  result = sasl_server_init(&callbacks[0], sasl_init_data->app_name.c_str());
+  result = sasl_server_init(&callbacks[0], kSaslAppName);
   if (result != SASL_OK) {
-    sasl_init_data->status = Status::RuntimeError("Could not initialize SASL server",
-        sasl_errstring(result, nullptr, nullptr));
+    sasl_init_status = Status::RuntimeError("Could not initialize SASL server",
+                                            sasl_errstring(result, nullptr, nullptr));
     return;
   }
 
-  sasl_init_data->status = Status::OK();
+  sasl_is_initialized = true;
 }
 
 Status DisableSaslInitialization() {
   if (g_disable_sasl_init) return Status::OK();
-  if (sasl_init_data != nullptr) {
+  if (sasl_is_initialized) {
     return Status::IllegalState("SASL already initialized. Initialization can only be disabled "
                                 "before first usage.");
   }
@@ -263,18 +255,11 @@ Status DisableSaslInitialization() {
   return Status::OK();
 }
 
-Status SaslInit(const char* app_name) {
+Status SaslInit() {
   // Only execute SASL initialization once
   static GoogleOnceType once = GOOGLE_ONCE_INIT;
-  GoogleOnceInitArg(&once,
-                    &DoSaslInit,
-                    // This is a bit ugly, but Clang 3.4 UBSAN complains otherwise.
-                    reinterpret_cast<void*>(const_cast<char*>(app_name)));
-  if (PREDICT_FALSE(sasl_init_data->app_name != app_name)) {
-    return Status::InvalidArgument("SaslInit called successively with different arguments",
-        StringPrintf("Previous: %s, current: %s", sasl_init_data->app_name.c_str(), app_name));
-  }
-  return sasl_init_data->status;
+  GoogleOnceInit(&once, &DoSaslInit);
+  return sasl_init_status;
 }
 
 static string CleanSaslError(const string& err) {

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/sasl_common.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/sasl_common.h b/src/kudu/rpc/sasl_common.h
index 419fc5f..53b713e 100644
--- a/src/kudu/rpc/sasl_common.h
+++ b/src/kudu/rpc/sasl_common.h
@@ -53,7 +53,7 @@ extern const char* const kSaslMechGSSAPI;
 //
 // This function is thread safe and uses a static lock.
 // This function should NOT be called during static initialization.
-Status SaslInit(const char* app_name);
+Status SaslInit();
 
 // Disable Kudu's initialization of SASL. See equivalent method in client.h.
 Status DisableSaslInitialization();

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/sasl_helper.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/sasl_helper.cc b/src/kudu/rpc/sasl_helper.cc
index 4e408f7..0614f3e 100644
--- a/src/kudu/rpc/sasl_helper.cc
+++ b/src/kudu/rpc/sasl_helper.cc
@@ -17,28 +17,24 @@
 
 #include "kudu/rpc/sasl_helper.h"
 
-#include <set>
 #include <string>
 
 #include <glog/logging.h>
 #include <google/protobuf/message_lite.h>
 
-#include "kudu/gutil/endian.h"
-#include "kudu/gutil/gscoped_ptr.h"
 #include "kudu/gutil/macros.h"
 #include "kudu/gutil/map-util.h"
 #include "kudu/gutil/port.h"
 #include "kudu/gutil/strings/join.h"
 #include "kudu/gutil/strings/substitute.h"
-#include "kudu/rpc/blocking_ops.h"
 #include "kudu/rpc/constants.h"
 #include "kudu/rpc/rpc_header.pb.h"
 #include "kudu/rpc/sasl_common.h"
 #include "kudu/rpc/serialization.h"
-#include "kudu/util/faststring.h"
-#include "kudu/util/monotime.h"
 #include "kudu/util/status.h"
 
+using std::string;
+
 namespace kudu {
 namespace rpc {
 
@@ -46,13 +42,10 @@ using google::protobuf::MessageLite;
 
 SaslHelper::SaslHelper(PeerType peer_type)
   : peer_type_(peer_type),
-    conn_header_exchanged_(false),
+    global_mechs_(SaslListAvailableMechs()),
     plain_enabled_(false),
     gssapi_enabled_(false) {
-  tag_ = (peer_type_ == SERVER) ? "Sasl Server" : "Sasl Client";
-}
-
-SaslHelper::~SaslHelper() {
+  tag_ = (peer_type_ == SERVER) ? "Server" : "Client";
 }
 
 void SaslHelper::set_local_addr(const Sockaddr& addr) {
@@ -76,27 +69,11 @@ const char* SaslHelper::server_fqdn() const {
   return server_fqdn_.empty() ? nullptr : server_fqdn_.c_str();
 }
 
-const std::set<std::string>& SaslHelper::GlobalMechs() const {
-  if (!global_mechs_) {
-    global_mechs_.reset(new set<string>(SaslListAvailableMechs()));
-  }
-  return *global_mechs_;
-}
-
-void SaslHelper::AddToLocalMechList(const string& mech) {
-  mechs_.insert(mech);
-}
-
-const std::set<std::string>& SaslHelper::LocalMechs() const {
-  return mechs_;
+const char* SaslHelper::EnabledMechsString() const {
+  JoinStrings(enabled_mechs_, " ", &enabled_mechs_string_);
+  return enabled_mechs_string_.c_str();
 }
 
-const char* SaslHelper::LocalMechListString() const {
-  JoinStrings(mechs_, " ", &mech_list_);
-  return mech_list_.c_str();
-}
-
-
 int SaslHelper::GetOptionCb(const char* plugin_name, const char* option,
                             const char** result, unsigned* len) {
   DVLOG(4) << tag_ << ": GetOption Callback called. ";
@@ -112,7 +89,7 @@ int SaslHelper::GetOptionCb(const char* plugin_name, const char* option,
   if (plugin_name == nullptr) {
     // SASL library option, not a plugin option
     if (strcmp(option, "mech_list") == 0) {
-      *result = LocalMechListString();
+      *result = EnabledMechsString();
       if (len != nullptr) *len = strlen(*result);
       VLOG(4) << tag_ << ": Enabled mech list: " << *result;
       return SASL_OK;
@@ -137,10 +114,10 @@ Status SaslHelper::EnableGSSAPI() {
 }
 
 Status SaslHelper::EnableMechanism(const string& mech) {
-  if (PREDICT_FALSE(!ContainsKey(GlobalMechs(), mech))) {
+  if (PREDICT_FALSE(!ContainsKey(global_mechs_, mech))) {
     return Status::InvalidArgument("unable to find SASL plugin", mech);
   }
-  AddToLocalMechList(mech);
+  enabled_mechs_.insert(mech);
   return Status::OK();
 }
 
@@ -148,11 +125,11 @@ bool SaslHelper::IsPlainEnabled() const {
   return plain_enabled_;
 }
 
-Status SaslHelper::SanityCheckNegotiateCallId(int32_t call_id) const {
+Status SaslHelper::CheckNegotiateCallId(int32_t call_id) const {
   if (call_id != kNegotiateCallId) {
     Status s = Status::IllegalState(strings::Substitute(
-          "Non-Negotiate request during negotiation. Expected callId: $0, received callId: $1",
-          kNegotiateCallId, call_id));
+        "Received illegal call-id during negotiation; expected: $0, received: $1",
+        kNegotiateCallId, call_id));
     LOG(DFATAL) << tag_ << ": " << s.ToString();
     return s;
   }
@@ -167,27 +144,5 @@ Status SaslHelper::ParseNegotiatePB(const Slice& param_buf, NegotiatePB* msg) {
   return Status::OK();
 }
 
-Status SaslHelper::SendNegotiatePB(Socket* sock,
-                                   const MessageLite& header,
-                                   const MessageLite& msg,
-                                   const MonoTime& deadline) {
-  DCHECK(sock != nullptr);
-  DCHECK(header.IsInitialized()) << tag_ << ": Header must be initialized";
-  DCHECK(msg.IsInitialized()) << tag_ << ": Message must be initialized";
-
-  // Write connection header, if needed
-  if (PREDICT_FALSE(peer_type_ == CLIENT && !conn_header_exchanged_)) {
-    const uint8_t buflen = kMagicNumberLength + kHeaderFlagsLength;
-    uint8_t buf[buflen];
-    serialization::SerializeConnHeader(buf);
-    size_t nsent;
-    RETURN_NOT_OK(sock->BlockingWrite(buf, buflen, &nsent, deadline));
-    conn_header_exchanged_ = true;
-  }
-
-  RETURN_NOT_OK(SendFramedMessageBlocking(sock, header, msg, deadline));
-  return Status::OK();
-}
-
 } // namespace rpc
 } // namespace kudu

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/sasl_helper.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/sasl_helper.h b/src/kudu/rpc/sasl_helper.h
index f0a676e..05d6904 100644
--- a/src/kudu/rpc/sasl_helper.h
+++ b/src/kudu/rpc/sasl_helper.h
@@ -23,30 +23,19 @@
 
 #include <sasl/sasl.h>
 
-#include "kudu/gutil/gscoped_ptr.h"
-#include "kudu/gutil/macros.h"
-#include "kudu/util/net/socket.h"
-
-namespace google {
-namespace protobuf {
-class MessageLite;
-} // namespace protobuf
-} // namespace google
+#include "kudu/util/status.h"
 
 namespace kudu {
 
-class MonoTime;
 class Sockaddr;
-class Status;
 
 namespace rpc {
 
-using std::string;
-
 class NegotiatePB;
 
-// Helper class which contains functionality that is common to SaslClient & SaslServer.
-// Most of these methods are convenience methods for interacting with the libsasl2 library.
+// Helper class which contains functionality that is common to client and server
+// SASL negotiations. Most of these methods are convenience methods for
+// interacting with the libsasl2 library.
 class SaslHelper {
  public:
   enum PeerType {
@@ -55,7 +44,7 @@ class SaslHelper {
   };
 
   explicit SaslHelper(PeerType peer_type);
-  ~SaslHelper();
+  ~SaslHelper() = default;
 
   // Specify IP:port of local side of connection.
   void set_local_addr(const Sockaddr& addr);
@@ -66,20 +55,18 @@ class SaslHelper {
   const char* remote_addr_string() const;
 
   // Specify the fully-qualified domain name of the remote server.
-  void set_server_fqdn(const string& domain_name);
+  void set_server_fqdn(const std::string& domain_name);
   const char* server_fqdn() const;
 
   // Globally-registered available SASL plugins.
-  const std::set<string>& GlobalMechs() const;
+  const std::set<std::string>& GlobalMechs() const {
+    return global_mechs_;
+  }
 
   // Helper functions for managing the list of active SASL mechanisms.
-  void AddToLocalMechList(const string& mech);
-  const std::set<string>& LocalMechs() const;
-
-  // Returns space-delimited local mechanism list string suitable for passing
-  // to libsasl2, such as via "mech_list" callbacks.
-  // The returned pointer is valid only until the next call to LocalMechListString().
-  const char* LocalMechListString() const;
+  const std::set<std::string>& EnabledMechs() const {
+    return enabled_mechs_;
+  }
 
   // Implements the client_mech_list / mech_list callbacks.
   int GetOptionCb(const char* plugin_name, const char* option, const char** result, unsigned* len);
@@ -95,31 +82,29 @@ class SaslHelper {
 
   // Sanity check that the call ID is the negotiation call ID.
   // Logs DFATAL if call_id does not match.
-  Status SanityCheckNegotiateCallId(int32_t call_id) const;
+  Status CheckNegotiateCallId(int32_t call_id) const;
 
   // Parse msg from the given Slice.
   Status ParseNegotiatePB(const Slice& param_buf, NegotiatePB* msg);
 
-  // Encode and send a message over a socket, sending the connection header if necessary.
-  Status SendNegotiatePB(Socket* sock,
-                         const google::protobuf::MessageLite& header,
-                         const google::protobuf::MessageLite& msg,
-                         const MonoTime& deadline);
-
  private:
   Status EnableMechanism(const std::string& mech);
 
-  string local_addr_;
-  string remote_addr_;
-  string server_fqdn_;
+  // Returns space-delimited local mechanism list string suitable for passing
+  // to libsasl2, such as via "mech_list" callbacks.
+  // The returned pointer is valid only until the next call to EnabledMechsString().
+  const char* EnabledMechsString() const;
+
+  std::string local_addr_;
+  std::string remote_addr_;
+  std::string server_fqdn_;
 
   // Authentication types and data.
   const PeerType peer_type_;
-  bool conn_header_exchanged_;
-  string tag_;
-  mutable gscoped_ptr< std::set<string> > global_mechs_;  // Cache of global mechanisms.
-  std::set<string> mechs_;    // Active mechanisms.
-  mutable string mech_list_;  // Mechanism list string returned by callbacks.
+  std::string tag_;
+  std::set<std::string> global_mechs_;       // Cache of global mechanisms.
+  std::set<std::string> enabled_mechs_;      // Active mechanisms.
+  mutable std::string enabled_mechs_string_; // Mechanism list string returned by callbacks.
 
   bool plain_enabled_;
   bool gssapi_enabled_;

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/server_negotiation.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/server_negotiation.cc b/src/kudu/rpc/server_negotiation.cc
index 9e91481..837fb47 100644
--- a/src/kudu/rpc/server_negotiation.cc
+++ b/src/kudu/rpc/server_negotiation.cc
@@ -17,17 +17,18 @@
 
 #include "kudu/rpc/server_negotiation.h"
 
-#include <glog/logging.h>
-#include <google/protobuf/message_lite.h>
 #include <limits>
-#include <sasl/sasl.h>
+#include <memory>
 #include <set>
 #include <string>
 
+#include <glog/logging.h>
+#include <google/protobuf/message_lite.h>
+#include <sasl/sasl.h>
+
 #include "kudu/gutil/endian.h"
 #include "kudu/gutil/map-util.h"
-#include "kudu/gutil/stringprintf.h"
-#include "kudu/gutil/strings/split.h"
+#include "kudu/gutil/strings/substitute.h"
 #include "kudu/rpc/blocking_ops.h"
 #include "kudu/rpc/constants.h"
 #include "kudu/rpc/serialization.h"
@@ -36,242 +37,197 @@
 #include "kudu/util/scoped_cleanup.h"
 #include "kudu/util/trace.h"
 
+using std::set;
+using std::string;
+using std::unique_ptr;
+
 namespace kudu {
 namespace rpc {
 
-static int SaslServerGetoptCb(void* sasl_server, const char* plugin_name, const char* option,
-                       const char** result, unsigned* len) {
-  return static_cast<SaslServer*>(sasl_server)
-    ->GetOptionCb(plugin_name, option, result, len);
+static int ServerNegotiationGetoptCb(ServerNegotiation* server_negotiation,
+                                     const char* plugin_name,
+                                     const char* option,
+                                     const char** result,
+                                     unsigned* len) {
+  return server_negotiation->GetOptionCb(plugin_name, option, result, len);
 }
 
-static int SaslServerPlainAuthCb(sasl_conn_t *conn, void *sasl_server, const char *user,
-    const char *pass, unsigned passlen, struct propctx *propctx) {
-  return static_cast<SaslServer*>(sasl_server)
-    ->PlainAuthCb(conn, user, pass, passlen, propctx);
+static int ServerNegotiationPlainAuthCb(sasl_conn_t* conn,
+                                        ServerNegotiation* server_negotiation,
+                                        const char* user,
+                                        const char* pass,
+                                        unsigned passlen,
+                                        struct propctx* propctx) {
+  return server_negotiation->PlainAuthCb(conn, user, pass, passlen, propctx);
 }
 
-SaslServer::SaslServer(string app_name, Socket* socket)
-    : app_name_(std::move(app_name)),
-      sock_(socket),
+ServerNegotiation::ServerNegotiation(unique_ptr<Socket> socket)
+    : socket_(std::move(socket)),
       helper_(SaslHelper::SERVER),
-      server_state_(SaslNegotiationState::NEW),
       negotiated_mech_(SaslMechanism::INVALID),
       deadline_(MonoTime::Max()) {
   callbacks_.push_back(SaslBuildCallback(SASL_CB_GETOPT,
-      reinterpret_cast<int (*)()>(&SaslServerGetoptCb), this));
+      reinterpret_cast<int (*)()>(&ServerNegotiationGetoptCb), this));
   callbacks_.push_back(SaslBuildCallback(SASL_CB_SERVER_USERDB_CHECKPASS,
-      reinterpret_cast<int (*)()>(&SaslServerPlainAuthCb), this));
+      reinterpret_cast<int (*)()>(&ServerNegotiationPlainAuthCb), this));
   callbacks_.push_back(SaslBuildCallback(SASL_CB_LIST_END, nullptr, nullptr));
 }
 
-Status SaslServer::EnablePlain() {
-  DCHECK_EQ(server_state_, SaslNegotiationState::NEW);
+Status ServerNegotiation::EnablePlain() {
   RETURN_NOT_OK(helper_.EnablePlain());
   return Status::OK();
 }
 
-Status SaslServer::EnableGSSAPI() {
-  DCHECK_EQ(server_state_, SaslNegotiationState::NEW);
+Status ServerNegotiation::EnableGSSAPI() {
   return helper_.EnableGSSAPI();
 }
 
-SaslMechanism::Type SaslServer::negotiated_mechanism() const {
-  DCHECK_EQ(server_state_, SaslNegotiationState::NEGOTIATED);
+SaslMechanism::Type ServerNegotiation::negotiated_mechanism() const {
   return negotiated_mech_;
 }
 
-const std::string& SaslServer::authenticated_user() const {
-  DCHECK_EQ(server_state_, SaslNegotiationState::NEGOTIATED);
+const string& ServerNegotiation::authenticated_user() const {
   return authenticated_user_;
 }
 
-void SaslServer::set_local_addr(const Sockaddr& addr) {
-  DCHECK_EQ(server_state_, SaslNegotiationState::NEW);
+void ServerNegotiation::set_local_addr(const Sockaddr& addr) {
   helper_.set_local_addr(addr);
 }
 
-void SaslServer::set_remote_addr(const Sockaddr& addr) {
-  DCHECK_EQ(server_state_, SaslNegotiationState::NEW);
+void ServerNegotiation::set_remote_addr(const Sockaddr& addr) {
   helper_.set_remote_addr(addr);
 }
 
-void SaslServer::set_server_fqdn(const string& domain_name) {
-  DCHECK_EQ(server_state_, SaslNegotiationState::NEW);
+void ServerNegotiation::set_server_fqdn(const string& domain_name) {
   helper_.set_server_fqdn(domain_name);
 }
 
-void SaslServer::set_deadline(const MonoTime& deadline) {
-  DCHECK_NE(server_state_, SaslNegotiationState::NEGOTIATED);
+void ServerNegotiation::set_deadline(const MonoTime& deadline) {
   deadline_ = deadline;
 }
 
-// calls sasl_server_init() and sasl_server_new()
-Status SaslServer::Init(const string& service_type) {
-  RETURN_NOT_OK(SaslInit(app_name_.c_str()));
-
-  // Ensure we are not called more than once.
-  if (server_state_ != SaslNegotiationState::NEW) {
-    return Status::IllegalState("Init() may only be called once per SaslServer object.");
-  }
-
-  // TODO(unknown): Support security flags.
-  unsigned secflags = 0;
-
-  sasl_conn_t* sasl_conn = nullptr;
-  Status s = WrapSaslCall(nullptr /* no conn */, [&]() {
-      return sasl_server_new(
-          // Registered name of the service using SASL. Required.
-          service_type.c_str(),
-          // The fully qualified domain name of this server.
-          helper_.server_fqdn(),
-          // Permits multiple user realms on server. NULL == use default.
-          nullptr,
-          // Local and remote IP address strings. (NULL disables
-          // mechanisms which require this info.)
-          helper_.local_addr_string(),
-          helper_.remote_addr_string(),
-          // Connection-specific callbacks.
-          &callbacks_[0],
-          // Security flags.
-          secflags,
-          &sasl_conn);
-    });
-
-  if (PREDICT_FALSE(!s.ok())) {
-    return Status::RuntimeError("Unable to create new SASL server",
-                                s.message());
-  }
-  sasl_conn_.reset(sasl_conn);
-
-  server_state_ = SaslNegotiationState::INITIALIZED;
-  return Status::OK();
-}
-
-Status SaslServer::Negotiate() {
-  // After negotiation, we no longer need the SASL library object, so
-  // may as well free its memory since the connection may be long-lived.
-  auto cleanup = MakeScopedCleanup([&]() {
-      sasl_conn_.reset();
-    });
-  DVLOG(4) << "Called SaslServer::Negotiate()";
-
-  // Ensure we are called exactly once, and in the right order.
-  if (server_state_ == SaslNegotiationState::NEW) {
-    return Status::IllegalState("SaslServer: Init() must be called before calling Negotiate()");
-  }
-  if (server_state_ == SaslNegotiationState::NEGOTIATED) {
-    return Status::IllegalState("SaslServer: Negotiate() may only be called once per object.");
-  }
+Status ServerNegotiation::Negotiate() {
+  TRACE("Beginning negotiation");
 
   // Ensure we can use blocking calls on the socket during negotiation.
-  RETURN_NOT_OK(EnsureBlockingMode(sock_));
+  RETURN_NOT_OK(EnsureBlockingMode(socket_.get()));
 
   faststring recv_buf;
 
-  // Read connection header
+  // Step 1: Read the connection header.
   RETURN_NOT_OK(ValidateConnectionHeader(&recv_buf));
 
-  nego_ok_ = false;
-  while (!nego_ok_) {
-    TRACE("Waiting for next Negotiation message...");
-    RequestHeader header;
-    Slice param_buf;
-    RETURN_NOT_OK(ReceiveFramedMessageBlocking(sock_, &recv_buf, &header, &param_buf, deadline_));
+  { // Step 2: Receive and respond to the NEGOTIATE step message.
+    NegotiatePB request;
+    RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf));
+    RETURN_NOT_OK(HandleNegotiate(request));
+  }
 
+  // Step 3: SASL negotiation.
+  RETURN_NOT_OK(InitSaslServer());
+  {
     NegotiatePB request;
-    RETURN_NOT_OK(ParseNegotiatePB(header, param_buf, &request));
-
-    switch (request.step()) {
-      // NEGOTIATE: They want a list of available mechanisms.
-      case NegotiatePB::NEGOTIATE:
-        RETURN_NOT_OK(HandleNegotiateRequest(request));
-        break;
-
-      // INITIATE: They want to initiate negotiation based on their specified mechanism.
-      case NegotiatePB::SASL_INITIATE:
-        RETURN_NOT_OK(HandleInitiateRequest(request));
-        break;
-
-      // RESPONSE: Client sent a new request as a follow-up to a SASL_CHALLENGE response.
-      case NegotiatePB::SASL_RESPONSE:
-        RETURN_NOT_OK(HandleResponseRequest(request));
-        break;
-
-      // Client sent us an unsupported Negotiation request.
-      default: {
-        TRACE("SASL Server: Received unsupported request from client");
-        Status s = Status::InvalidArgument("RPC server doesn't support negotiation step in request",
-                                           NegotiatePB::NegotiateStep_Name(request.step()));
-        RETURN_NOT_OK(SendRpcError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
-        return s;
-      }
+    RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf));
+    Status s = HandleSaslInitiate(request);
+
+    while (s.IsIncomplete()) {
+      RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf));
+      s = HandleSaslResponse(request);
     }
+    RETURN_NOT_OK(s);
   }
-
   const char* username = nullptr;
   int rc = sasl_getprop(sasl_conn_.get(), SASL_USERNAME,
                         reinterpret_cast<const void**>(&username));
   // We expect that SASL_USERNAME will always get set.
-  CHECK(rc == SASL_OK && username != nullptr)
-      << "No username on authenticated connection";
+  CHECK(rc == SASL_OK && username != nullptr) << "No username on authenticated connection";
   authenticated_user_ = username;
 
-  TRACE("SASL Server: Successful negotiation");
-  server_state_ = SaslNegotiationState::NEGOTIATED;
+  // Step 4: Receive connection context.
+  RETURN_NOT_OK(RecvConnectionContext(&recv_buf));
+
+  TRACE("Negotiation successful");
   return Status::OK();
 }
 
-Status SaslServer::ValidateConnectionHeader(faststring* recv_buf) {
-  TRACE("Waiting for connection header");
-  size_t num_read;
-  const size_t conn_header_len = kMagicNumberLength + kHeaderFlagsLength;
-  recv_buf->resize(conn_header_len);
-  RETURN_NOT_OK(sock_->BlockingRecv(recv_buf->data(), conn_header_len, &num_read, deadline_));
-  DCHECK_EQ(conn_header_len, num_read);
+Status ServerNegotiation::PreflightCheckGSSAPI() {
+  // TODO(todd): the error messages that come from this function on el6
+  // are relatively useless due to the following krb5 bug:
+  // http://krbdev.mit.edu/rt/Ticket/Display.html?id=6973
+  // This may not be useful anymore given the keytab login that happens
+  // in security/init.cc.
 
-  RETURN_NOT_OK(serialization::ValidateConnHeader(*recv_buf));
-  TRACE("Connection header received");
-  return Status::OK();
+  // Initialize a ServerNegotiation with a null socket, and enable
+  // only GSSAPI.
+  //
+  // We aren't going to actually send/receive any messages, but
+  // this makes it easier to reuse the initialization code.
+  ServerNegotiation server(nullptr);
+  Status s = server.EnableGSSAPI();
+  if (!s.ok()) {
+    return Status::RuntimeError(s.message());
+  }
+
+  RETURN_NOT_OK(server.InitSaslServer());
+
+  // Start the SASL server as if we were accepting a connection.
+  const char* server_out = nullptr; // ignored
+  uint32_t server_out_len = 0;
+  s = WrapSaslCall(server.sasl_conn_.get(), [&]() {
+      return sasl_server_start(
+          server.sasl_conn_.get(),
+          kSaslMechGSSAPI,
+          "", 0,  // Pass a 0-length token.
+          &server_out, &server_out_len);
+    });
+
+  // We expect 'Incomplete' status to indicate that the first step of negotiation
+  // was correct.
+  if (s.IsIncomplete()) return Status::OK();
+
+  string err_msg = s.message().ToString();
+  if (err_msg == "Permission denied") {
+    // For bad keytab permissions, we get a rather vague message. So,
+    // we make it more specific for better usability.
+    err_msg = "error accessing keytab: " + err_msg;
+  }
+  return Status::RuntimeError(err_msg);
 }
 
-Status SaslServer::ParseNegotiatePB(const RequestHeader& header,
-                                    const Slice& param_buf,
-                                    NegotiatePB* request) {
-  Status s = helper_.SanityCheckNegotiateCallId(header.call_id());
+Status ServerNegotiation::RecvNegotiatePB(NegotiatePB* msg, faststring* recv_buf) {
+  RequestHeader header;
+  Slice param_buf;
+  RETURN_NOT_OK(ReceiveFramedMessageBlocking(socket(), recv_buf, &header, &param_buf, deadline_));
+  Status s = helper_.CheckNegotiateCallId(header.call_id());
   if (!s.ok()) {
-    RETURN_NOT_OK(SendRpcError(ErrorStatusPB::FATAL_INVALID_RPC_HEADER, s));
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_RPC_HEADER, s));
+    return s;
   }
 
-  s = helper_.ParseNegotiatePB(param_buf, request);
+  s = helper_.ParseNegotiatePB(param_buf, msg);
   if (!s.ok()) {
-    RETURN_NOT_OK(SendRpcError(ErrorStatusPB::FATAL_DESERIALIZING_REQUEST, s));
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_DESERIALIZING_REQUEST, s));
     return s;
   }
 
+  TRACE("Received $0 NegotiatePB request", NegotiatePB::NegotiateStep_Name(msg->step()));
   return Status::OK();
 }
 
-Status SaslServer::SendNegotiatePB(const NegotiatePB& msg) {
-  DCHECK_NE(server_state_, SaslNegotiationState::NEW)
-      << "Must not send Negotiate messages before calling Init()";
-  DCHECK_NE(server_state_, SaslNegotiationState::NEGOTIATED)
-      << "Must not send Negotiate messages after Negotiate() succeeds";
-
-  // Create header with negotiation-specific callId
+Status ServerNegotiation::SendNegotiatePB(const NegotiatePB& msg) {
   ResponseHeader header;
   header.set_call_id(kNegotiateCallId);
-  return helper_.SendNegotiatePB(sock_, header, msg, deadline_);
+
+  DCHECK(socket_);
+  DCHECK(msg.IsInitialized()) << "message must be initialized";
+  DCHECK(msg.has_step()) << "message must have a step";
+
+  TRACE("Sending $0 NegotiatePB response", NegotiatePB::NegotiateStep_Name(msg.step()));
+  return SendFramedMessageBlocking(socket(), header, msg, deadline_);
 }
 
-Status SaslServer::SendRpcError(ErrorStatusPB::RpcErrorCodePB code, const Status& err) {
-  DCHECK_NE(server_state_, SaslNegotiationState::NEW)
-      << "Must not send SASL messages before calling Init()";
-  DCHECK_NE(server_state_, SaslNegotiationState::NEGOTIATED)
-      << "Must not send SASL messages after Negotiate() succeeds";
-  if (err.ok()) {
-    return Status::InvalidArgument("Cannot send error message using OK status");
-  }
+Status ServerNegotiation::SendError(ErrorStatusPB::RpcErrorCodePB code, const Status& err) {
+  DCHECK(!err.ok());
 
   // Create header with negotiation-specific callId
   ResponseHeader header;
@@ -283,39 +239,87 @@ Status SaslServer::SendRpcError(ErrorStatusPB::RpcErrorCodePB code, const Status
   msg.set_code(code);
   msg.set_message(err.ToString());
 
-  RETURN_NOT_OK(helper_.SendNegotiatePB(sock_, header, msg, deadline_));
-  TRACE("Sent SASL error: $0", ErrorStatusPB::RpcErrorCodePB_Name(code));
+  TRACE("Sending RPC error: $0", ErrorStatusPB::RpcErrorCodePB_Name(code));
+  RETURN_NOT_OK(SendFramedMessageBlocking(socket(), header, msg, deadline_));
+
+  return Status::OK();
+}
+
+Status ServerNegotiation::ValidateConnectionHeader(faststring* recv_buf) {
+  TRACE("Waiting for connection header");
+  size_t num_read;
+  const size_t conn_header_len = kMagicNumberLength + kHeaderFlagsLength;
+  recv_buf->resize(conn_header_len);
+  RETURN_NOT_OK(socket_->BlockingRecv(recv_buf->data(), conn_header_len, &num_read, deadline_));
+  DCHECK_EQ(conn_header_len, num_read);
+
+  RETURN_NOT_OK(serialization::ValidateConnHeader(*recv_buf));
+  TRACE("Connection header received");
   return Status::OK();
 }
 
-Status SaslServer::HandleNegotiateRequest(const NegotiatePB& request) {
-  TRACE("SASL Server: Received NEGOTIATE request from client");
+// Initializes SASL (via SaslInit()) and creates sasl_conn_ using sasl_server_new().
+Status ServerNegotiation::InitSaslServer() {
+  RETURN_NOT_OK(SaslInit());
+
+  // TODO(unknown): Support security flags.
+  unsigned secflags = 0;
+
+  sasl_conn_t* sasl_conn = nullptr;
+  RETURN_NOT_OK_PREPEND(WrapSaslCall(nullptr /* no conn */, [&]() {
+      return sasl_server_new(
+          // Registered name of the service using SASL. Required.
+          kSaslProtoName,
+          // The fully qualified domain name of this server.
+          helper_.server_fqdn(),
+          // Permits multiple user realms on server. NULL == use default.
+          nullptr,
+          // Local and remote IP address strings. (NULL disables
+          // mechanisms which require this info.)
+          helper_.local_addr_string(),
+          helper_.remote_addr_string(),
+          // Connection-specific callbacks.
+          &callbacks_[0],
+          // Security flags.
+          secflags,
+          &sasl_conn);
+    }), "Unable to create new SASL server");
+  sasl_conn_.reset(sasl_conn);
+  return Status::OK();
+}
+
+Status ServerNegotiation::HandleNegotiate(const NegotiatePB& request) {
+  if (request.step() != NegotiatePB::NEGOTIATE) {
+    return Status::NotAuthorized("expected NEGOTIATE step",
+                                 NegotiatePB::NegotiateStep_Name(request.step()));
+  }
+  TRACE("Received NEGOTIATE request from client");
 
   // Fill in the set of features supported by the client.
   for (int flag : request.supported_features()) {
     // We only add the features that our local build knows about.
     RpcFeatureFlag feature_flag = RpcFeatureFlag_IsValid(flag) ?
                                   static_cast<RpcFeatureFlag>(flag) : UNKNOWN;
-    if (ContainsKey(kSupportedServerRpcFeatureFlags, feature_flag)) {
+    if (feature_flag != UNKNOWN) {
       client_features_.insert(feature_flag);
     }
   }
 
-  set<string> server_mechs = helper_.LocalMechs();
+  set<string> server_mechs = helper_.EnabledMechs();
   if (PREDICT_FALSE(server_mechs.empty())) {
     // This will happen if no mechanisms are enabled before calling Init()
-    Status s = Status::IllegalState("SASL server mechanism list is empty!");
+    Status s = Status::NotAuthorized("SASL server mechanism list is empty!");
     LOG(ERROR) << s.ToString();
-    TRACE("SASL Server: Sending FATAL_UNAUTHORIZED response to client");
-    RETURN_NOT_OK(SendRpcError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    TRACE("Sending FATAL_UNAUTHORIZED response to client");
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
     return s;
   }
 
-  RETURN_NOT_OK(SendNegotiateResponse(server_mechs));
+  RETURN_NOT_OK(SendNegotiate(server_mechs));
   return Status::OK();
 }
 
-Status SaslServer::SendNegotiateResponse(const set<string>& server_mechs) {
+Status ServerNegotiation::SendNegotiate(const set<string>& server_mechs) {
   NegotiatePB response;
   response.set_step(NegotiatePB::NEGOTIATE);
 
@@ -329,31 +333,32 @@ Status SaslServer::SendNegotiateResponse(const set<string>& server_mechs) {
   }
 
   RETURN_NOT_OK(SendNegotiatePB(response));
-  TRACE("Sent NEGOTIATE response");
   return Status::OK();
 }
 
-
-Status SaslServer::HandleInitiateRequest(const NegotiatePB& request) {
-  TRACE("SASL Server: Received INITIATE request from client");
+Status ServerNegotiation::HandleSaslInitiate(const NegotiatePB& request) {
+  if (PREDICT_FALSE(request.step() != NegotiatePB::SASL_INITIATE)) {
+    Status s =  Status::NotAuthorized("expected SASL_INITIATE step",
+                                      NegotiatePB::NegotiateStep_Name(request.step()));
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    return s;
+  }
+  TRACE("Received SASL_INITIATE request from client");
 
   if (request.auths_size() != 1) {
-    Status s = Status::NotAuthorized(StringPrintf(
-          "SASL INITIATE request must include exactly one SaslAuth section, found: %d",
-          request.auths_size()));
-    RETURN_NOT_OK(SendRpcError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    Status s = Status::NotAuthorized(
+        "SASL_INITIATE request must include exactly one SaslAuth section, found",
+        std::to_string(request.auths_size()));
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
     return s;
   }
 
   const NegotiatePB::SaslAuth& auth = request.auths(0);
-  TRACE("SASL Server: Client requested to use mechanism: $0", auth.mechanism());
-
-  // Security issue to display this. Commented out but left for debugging purposes.
-  //DVLOG(3) << "SASL server: Client token: " << request.token();
+  TRACE("Client requested to use mechanism: $0", auth.mechanism());
 
   const char* server_out = nullptr;
   uint32_t server_out_len = 0;
-  TRACE("SASL Server: Calling sasl_server_start()");
+  TRACE("Calling sasl_server_start()");
 
   Status s = WrapSaslCall(sasl_conn_.get(), [&]() {
       return sasl_server_start(
@@ -364,55 +369,41 @@ Status SaslServer::HandleInitiateRequest(const NegotiatePB& request) {
           &server_out,              // The output of the SASL library, might not be NULL terminated
           &server_out_len);         // Output len.
     });
+
   if (PREDICT_FALSE(!s.ok() && !s.IsIncomplete())) {
-    RETURN_NOT_OK(SendRpcError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
     return s;
   }
+
   negotiated_mech_ = SaslMechanism::value_of(auth.mechanism());
 
   // We have a valid mechanism match
   if (s.ok()) {
-    nego_ok_ = true;
-    RETURN_NOT_OK(SendSuccessResponse(server_out, server_out_len));
+    RETURN_NOT_OK(SendSaslSuccess(server_out, server_out_len));
   } else { // s.IsIncomplete() (equivalent to SASL_CONTINUE)
-    RETURN_NOT_OK(SendChallengeResponse(server_out, server_out_len));
+    RETURN_NOT_OK(SendSaslChallenge(server_out, server_out_len));
   }
-  return Status::OK();
-}
-
-Status SaslServer::SendChallengeResponse(const char* challenge, unsigned clen) {
-  NegotiatePB response;
-  response.set_step(NegotiatePB::SASL_CHALLENGE);
-  response.mutable_token()->assign(challenge, clen);
-  TRACE("SASL Server: Sending SASL_CHALLENGE response to client");
-  RETURN_NOT_OK(SendNegotiatePB(response));
-  return Status::OK();
+  return s;
 }
 
-Status SaslServer::SendSuccessResponse(const char* token, unsigned tlen) {
-  NegotiatePB response;
-  response.set_step(NegotiatePB::SASL_SUCCESS);
-  if (PREDICT_FALSE(tlen > 0)) {
-    response.mutable_token()->assign(token, tlen);
+Status ServerNegotiation::HandleSaslResponse(const NegotiatePB& request) {
+  if (PREDICT_FALSE(request.step() != NegotiatePB::SASL_RESPONSE)) {
+    Status s =  Status::NotAuthorized("expected SASL_RESPONSE step",
+                                      NegotiatePB::NegotiateStep_Name(request.step()));
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    return s;
   }
-  TRACE("SASL Server: Sending SASL_SUCCESS response to client");
-  RETURN_NOT_OK(SendNegotiatePB(response));
-  return Status::OK();
-}
-
-
-Status SaslServer::HandleResponseRequest(const NegotiatePB& request) {
-  TRACE("SASL Server: Received RESPONSE request from client");
+  TRACE("Received SASL_RESPONSE request from client");
 
   if (!request.has_token()) {
-    Status s = Status::InvalidArgument("No token in SASL_RESPONSE from client");
-    RETURN_NOT_OK(SendRpcError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    Status s = Status::NotAuthorized("no token in SASL_RESPONSE from client");
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
     return s;
   }
 
   const char* server_out = nullptr;
   uint32_t server_out_len = 0;
-  TRACE("SASL Server: Calling sasl_server_step()");
+  TRACE("Calling sasl_server_step()");
   Status s = WrapSaslCall(sasl_conn_.get(), [&]() {
       return sasl_server_step(
           sasl_conn_.get(),         // The SASL connection context created by init()
@@ -421,78 +412,76 @@ Status SaslServer::HandleResponseRequest(const NegotiatePB& request) {
           &server_out,              // The output of the SASL library, might not be NULL terminated
           &server_out_len);         // Output len
     });
-  if (!s.ok() && !s.IsIncomplete()) {
-    RETURN_NOT_OK(SendRpcError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
-    return s;
-  }
 
   if (s.ok()) {
-    nego_ok_ = true;
-    RETURN_NOT_OK(SendSuccessResponse(server_out, server_out_len));
-  } else { // s.IsIncomplete() (equivalent to SASL_CONTINUE)
-    RETURN_NOT_OK(SendChallengeResponse(server_out, server_out_len));
+    return SendSaslSuccess(server_out, server_out_len);
   }
-  return Status::OK();
+  if (s.IsIncomplete()) {
+    return SendSaslChallenge(server_out, server_out_len);
+  }
+  RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+  return s;
 }
 
-int SaslServer::GetOptionCb(const char* plugin_name, const char* option,
-                            const char** result, unsigned* len) {
-  return helper_.GetOptionCb(plugin_name, option, result, len);
+Status ServerNegotiation::SendSaslChallenge(const char* challenge, unsigned clen) {
+  NegotiatePB response;
+  response.set_step(NegotiatePB::SASL_CHALLENGE);
+  response.mutable_token()->assign(challenge, clen);
+  RETURN_NOT_OK(SendNegotiatePB(response));
+  return Status::Incomplete("");
 }
 
-int SaslServer::PlainAuthCb(sasl_conn_t * /*conn*/, const char * /*user*/, const char * /*pass*/,
-                            unsigned /*passlen*/, struct propctx * /*propctx*/) {
-  TRACE("SASL Server: Received PLAIN auth.");
-  if (PREDICT_FALSE(!helper_.IsPlainEnabled())) {
-    LOG(DFATAL) << "Password authentication callback called while PLAIN auth disabled";
-    return SASL_BADPARAM;
+Status ServerNegotiation::SendSaslSuccess(const char* token, unsigned tlen) {
+  NegotiatePB response;
+  response.set_step(NegotiatePB::SASL_SUCCESS);
+  if (PREDICT_FALSE(tlen > 0)) {
+    response.mutable_token()->assign(token, tlen);
   }
-  // We always allow PLAIN authentication to succeed.
-  return SASL_OK;
+  RETURN_NOT_OK(SendNegotiatePB(response));
+  return Status::OK();
 }
 
-Status SaslServer::PreflightCheckGSSAPI(const string& app_name) {
-  // TODO(todd): the error messages that come from this function on el6
-  // are relatively useless due to the following krb5 bug:
-  // http://krbdev.mit.edu/rt/Ticket/Display.html?id=6973
-  // This may not be useful anymore given the keytab login that happens
-  // in security/init.cc.
+Status ServerNegotiation::RecvConnectionContext(faststring* recv_buf) {
+  TRACE("Waiting for connection context");
+  RequestHeader header;
+  Slice param_buf;
+  RETURN_NOT_OK(ReceiveFramedMessageBlocking(socket(), recv_buf, &header, &param_buf, deadline_));
+  DCHECK(header.IsInitialized());
 
-  // Initialize a SaslServer with a null socket, and enable
-  // only GSSAPI.
-  //
-  // We aren't going to actually send/receive any messages, but
-  // this makes it easier to reuse the initialization code.
-  SaslServer server(app_name, nullptr);
-  Status s = server.EnableGSSAPI();
-  if (!s.ok()) {
-    return Status::RuntimeError(s.message());
+  if (header.call_id() != kConnectionContextCallId) {
+    return Status::NotAuthorized("expected ConnectionContext callid, received",
+                                 std::to_string(header.call_id()));
   }
 
-  RETURN_NOT_OK(server.Init(app_name));
+  ConnectionContextPB conn_context;
+  if (!conn_context.ParseFromArray(param_buf.data(), param_buf.size())) {
+    return Status::NotAuthorized("invalid ConnectionContextPB message, missing fields",
+                                 conn_context.InitializationErrorString());
+  }
 
-  // Start the SASL server as if we were accepting a connection.
-  const char* server_out = nullptr; // ignored
-  uint32_t server_out_len = 0;
-  s = WrapSaslCall(server.sasl_conn_.get(), [&]() {
-      return sasl_server_start(
-          server.sasl_conn_.get(),
-          kSaslMechGSSAPI,
-          "", 0,  // Pass a 0-length token.
-          &server_out, &server_out_len);
-    });
+  // Currently none of the fields of the connection context are used.
+  return Status::OK();
+}
 
-  // We expect 'Incomplete' status to indicate that the first step of negotiation
-  // was correct.
-  if (s.IsIncomplete()) return Status::OK();
+int ServerNegotiation::GetOptionCb(const char* plugin_name,
+                                   const char* option,
+                                   const char** result,
+                                   unsigned* len) {
+  return helper_.GetOptionCb(plugin_name, option, result, len);
+}
 
-  string err_msg = s.message().ToString();
-  if (err_msg == "Permission denied") {
-    // For bad keytab permissions, we get a rather vague message. So,
-    // we make it more specific for better usability.
-    err_msg = "error accessing keytab: " + err_msg;
+int ServerNegotiation::PlainAuthCb(sasl_conn_t* /*conn*/,
+                                   const char*  /*user*/,
+                                   const char*  /*pass*/,
+                                   unsigned /*passlen*/,
+                                   struct propctx*  /*propctx*/) {
+  TRACE("Received PLAIN auth.");
+  if (PREDICT_FALSE(!helper_.IsPlainEnabled())) {
+    LOG(DFATAL) << "Password authentication callback called while PLAIN auth disabled";
+    return SASL_BADPARAM;
   }
-  return Status::RuntimeError(err_msg);
+  // We always allow PLAIN authentication to succeed.
+  return SASL_OK;
 }
 
 } // namespace rpc

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/server_negotiation.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/server_negotiation.h b/src/kudu/rpc/server_negotiation.h
index 53f674d..089bc0b 100644
--- a/src/kudu/rpc/server_negotiation.h
+++ b/src/kudu/rpc/server_negotiation.h
@@ -15,15 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#ifndef KUDU_RPC_SASL_SERVER_H
-#define KUDU_RPC_SASL_SERVER_H
+#pragma once
 
+#include <memory>
 #include <set>
 #include <string>
 #include <vector>
 
 #include <sasl/sasl.h>
 
+#include "kudu/gutil/gscoped_ptr.h"
 #include "kudu/rpc/rpc_header.pb.h"
 #include "kudu/rpc/sasl_common.h"
 #include "kudu/rpc/sasl_helper.h"
@@ -37,23 +38,24 @@ class Slice;
 
 namespace rpc {
 
-using std::string;
-
-// Class for doing SASL negotiation with a SaslClient over a bidirectional socket.
+// Class for doing KRPC negotiation with a remote client over a bidirectional socket.
 // Operations on this class are NOT thread-safe.
-class SaslServer {
+class ServerNegotiation {
  public:
-  // Does not take ownership of 'socket'.
-  SaslServer(string app_name, Socket* socket);
+  // Creates a new server negotiation instance, taking ownership of the
+  // provided socket. After completing the negotiation process by setting the
+  // desired options and calling Negotiate(), the socket can be retrieved with
+  // release_socket().
+  explicit ServerNegotiation(std::unique_ptr<Socket> socket);
 
   // Enable PLAIN authentication.
   // Despite PLAIN authentication taking a username and password, we disregard
   // the password and use this as a "unauthenticated" mode.
-  // Must be called after Init().
+  // Must be called before Negotiate().
   Status EnablePlain();
 
   // Enable GSSAPI (Kerberos) authentication.
-  // Call after Init().
+  // Must be called before Negotiate().
   Status EnableGSSAPI();
 
   // Returns mechanism negotiated by this connection.
@@ -62,41 +64,49 @@ class SaslServer {
 
   // Returns the set of RPC system features supported by the remote client.
   // Must be called after Negotiate().
-  const std::set<RpcFeatureFlag>& client_features() const {
+  std::set<RpcFeatureFlag> client_features() const {
     return client_features_;
   }
 
+  // Returns the set of RPC system features supported by the remote client.
+  // Must be called after Negotiate().
+  // Subsequent calls to this method or client_features() will return an empty set.
+  std::set<RpcFeatureFlag> take_client_features() {
+    return std::move(client_features_);
+  }
+
   // Name of the user that was authenticated.
   // Must be called after a successful Negotiate().
   const std::string& authenticated_user() const;
 
   // Specify IP:port of local side of connection.
-  // Must be called before Init(). Required for some mechanisms.
+  // Must be called before Negotiate(). Required for some mechanisms.
   void set_local_addr(const Sockaddr& addr);
 
   // Specify IP:port of remote side of connection.
-  // Must be called before Init(). Required for some mechanisms.
+  // Must be called before Negotiate(). Required for some mechanisms.
   void set_remote_addr(const Sockaddr& addr);
 
   // Specify the fully-qualified domain name of the remote server.
-  // Must be called before Init(). Required for some mechanisms.
-  void set_server_fqdn(const string& domain_name);
+  // Must be called before Negotiate(). Required for some mechanisms.
+  void set_server_fqdn(const std::string& domain_name);
 
   // Set deadline for connection negotiation.
   void set_deadline(const MonoTime& deadline);
 
-  // Get deadline for connection negotiation.
-  const MonoTime& deadline() const { return deadline_; }
+  Socket* socket() const { return socket_.get(); }
 
-  // Initialize a new SASL server. Must be called before Negotiate().
-  // Returns OK on success, otherwise RuntimeError.
-  Status Init(const string& service_type);
+  // Returns the socket owned by this server negotiation. The caller will own
+  // the socket after this call, and the negotiation instance should no longer
+  // be used. Must be called after Negotiate().
+  std::unique_ptr<Socket> release_socket() { return std::move(socket_); }
 
-  // Begin negotiation with the SASL client on the other side of the fd socket
-  // that this server was constructed with.
-  // Returns OK on success.
-  // Otherwise, it may return NotAuthorized, NotSupported, or another non-OK status.
-  Status Negotiate();
+  // Negotiate with the remote client. Should only be called once per
+  // ServerNegotiation and socket instance, after all options have been set.
+  //
+  // Returns OK on success, otherwise may return NotAuthorized, NotSupported, or
+  // another non-OK status.
+  Status Negotiate() WARN_UNUSED_RESULT;
 
   // SASL callback for plugin options, supported mechanisms, etc.
   // Returns SASL_FAIL if the option is not handled, which does not fail the handshake.
@@ -109,72 +119,71 @@ class SaslServer {
 
   // Perform a "pre-flight check" that everything required to act as a Kerberos
   // server is properly set up.
-  static Status PreflightCheckGSSAPI(const std::string& app_name);
+  static Status PreflightCheckGSSAPI() WARN_UNUSED_RESULT;
 
  private:
-  // Parse and validate connection header.
-  Status ValidateConnectionHeader(faststring* recv_buf);
 
-  // Parse request body. If malformed, sends an error message to the client.
-  Status ParseNegotiatePB(const RequestHeader& header,
-                          const Slice& param_buf,
-                          NegotiatePB* request);
+  // Parse a negotiate request from the client, deserializing it into 'msg'.
+  // If the request is malformed, sends an error message to the client.
+  Status RecvNegotiatePB(NegotiatePB* msg, faststring* recv_buf) WARN_UNUSED_RESULT;
 
-  // Encode and send the specified SASL message to the client.
-  Status SendNegotiatePB(const NegotiatePB& msg);
+  // Encode and send the specified negotiate response message to the client.
+  Status SendNegotiatePB(const NegotiatePB& msg) WARN_UNUSED_RESULT;
 
   // Encode and send the specified RPC error message to the client.
   // Calls Status.ToString() for the embedded error message.
-  Status SendRpcError(ErrorStatusPB::RpcErrorCodePB code, const Status& err);
+  Status SendError(ErrorStatusPB::RpcErrorCodePB code, const Status& err) WARN_UNUSED_RESULT;
+
+  // Parse and validate connection header.
+  Status ValidateConnectionHeader(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+  // Initialize the SASL server negotiation instance.
+  Status InitSaslServer() WARN_UNUSED_RESULT;
 
   // Handle case when client sends NEGOTIATE request.
-  Status HandleNegotiateRequest(const NegotiatePB& request);
+  Status HandleNegotiate(const NegotiatePB& request) WARN_UNUSED_RESULT;
 
   // Send a NEGOTIATE response to the client with the list of available mechanisms.
-  Status SendNegotiateResponse(const std::set<string>& server_mechs);
+  Status SendNegotiate(const std::set<std::string>& server_mechs) WARN_UNUSED_RESULT;
+
+  // Handle case when client sends SASL_INITIATE request.
+  // Returns Status::OK if the SASL negotiation is complete, or
+  // Status::Incomplete if a SASL_RESPONSE step is expected.
+  Status HandleSaslInitiate(const NegotiatePB& request) WARN_UNUSED_RESULT;
 
-  // Handle case when client sends INITIATE request.
-  Status HandleInitiateRequest(const NegotiatePB& request);
+  // Handle case when client sends SASL_RESPONSE request.
+  Status HandleSaslResponse(const NegotiatePB& request) WARN_UNUSED_RESULT;
 
-  // Send a CHALLENGE response to the client with a challenge token.
-  Status SendChallengeResponse(const char* challenge, unsigned clen);
+  // Send a SASL_CHALLENGE response to the client with a challenge token.
+  Status SendSaslChallenge(const char* challenge, unsigned clen) WARN_UNUSED_RESULT;
 
-  // Send a SUCCESS response to the client with an token (typically empty).
-  Status SendSuccessResponse(const char* token, unsigned tlen);
+  // Send a SASL_SUCCESS response to the client with a token (typically empty).
+  Status SendSaslSuccess(const char* token, unsigned tlen) WARN_UNUSED_RESULT;
 
-  // Handle case when client sends RESPONSE request.
-  Status HandleResponseRequest(const NegotiatePB& request);
+  // Receive and validate the ConnectionContextPB.
+  Status RecvConnectionContext(faststring* recv_buf) WARN_UNUSED_RESULT;
 
-  string app_name_;
-  Socket* sock_;
+  // The socket to the remote client.
+  std::unique_ptr<Socket> socket_;
+
+  // SASL state.
   std::vector<sasl_callback_t> callbacks_;
-  // The SASL connection object. This is initialized in Init() and
-  // freed after Negotiate() completes (regardless whether it was successful).
   gscoped_ptr<sasl_conn_t, SaslDeleter> sasl_conn_;
   SaslHelper helper_;
 
-  // The set of features that the client supports. Filled in
-  // after we receive the NEGOTIATE request from the client.
+  // The set of features supported by the client. Filled in during negotiation.
   std::set<RpcFeatureFlag> client_features_;
 
-  // The successfully-authenticated user, if applicable.
-  string authenticated_user_;
-
-  SaslNegotiationState::Type server_state_;
+  // The successfully-authenticated user, if applicable. Filled in during
+  // negotiation.
+  std::string authenticated_user_;
 
-  // The mechanism we negotiated with the client.
+  // The SASL mechanism. Filled in during negotiation.
   SaslMechanism::Type negotiated_mech_;
 
-  // Intra-negotiation state.
-  bool nego_ok_;  // During negotiation: did we get a SASL_OK response from the SASL library?
-
   // Negotiation timeout deadline.
   MonoTime deadline_;
-
-  DISALLOW_COPY_AND_ASSIGN(SaslServer);
 };
 
 } // namespace rpc
 } // namespace kudu
-
-#endif  // KUDU_RPC_SASL_SERVER_H

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/security/ssl_socket.cc
----------------------------------------------------------------------
diff --git a/src/kudu/security/ssl_socket.cc b/src/kudu/security/ssl_socket.cc
index 126d2f1..c2725c2 100644
--- a/src/kudu/security/ssl_socket.cc
+++ b/src/kudu/security/ssl_socket.cc
@@ -44,6 +44,7 @@ SSLSocket::SSLSocket(int fd, SSL* ssl, bool is_server) :
 }
 
 SSLSocket::~SSLSocket() {
+  WARN_NOT_OK(Close(), "unable to close SSL socket in destructor");
 }
 
 Status SSLSocket::DoHandshake() {
@@ -154,7 +155,10 @@ Status SSLSocket::Recv(uint8_t *buf, int32_t amt, int32_t *nread) {
 }
 
 Status SSLSocket::Close() {
-  CHECK(ssl_);
+  if (!ssl_) {
+    // Socket is already closed.
+    return Status::OK();
+  }
   ERR_clear_error();
   errno = 0;
   int32_t ret = SSL_shutdown(ssl_);

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/security/ssl_socket.h
----------------------------------------------------------------------
diff --git a/src/kudu/security/ssl_socket.h b/src/kudu/security/ssl_socket.h
index 4f67d48..f7570fb 100644
--- a/src/kudu/security/ssl_socket.h
+++ b/src/kudu/security/ssl_socket.h
@@ -32,9 +32,10 @@ class Sockaddr;
 
 class SSLSocket : public Socket {
  public:
+
   SSLSocket(int fd, SSL* ssl, bool is_server);
 
-  ~SSLSocket();
+  ~SSLSocket() override;
 
   // Do the SSL handshake as a client or a server and verify that the credentials were correctly
   // verified.


[2/2] kudu git commit: TLS-negotiation [6/n]: Refactor RPC negotiation

Posted by da...@apache.org.
TLS-negotiation [6/n]: Refactor RPC negotiation

The KRPC negotiation process now encompasses more than just SASL
authentication. Feature flags are negotiated, and soon TLS encryption
will be negotiated as well. In anticipation of TLS negotiation, this
commit refactors the SaslClient and SaslServer classes into the
ClientNegotiation and ServerNegotiation classes to better reflect their
roles.  Additionally, the Connection class now has less
responsibility and knowledge of negotiation. Instead, the existing logic
in negotiation.cc has been beefed up to serve as the glue between the
ClientNegotiation and ServerNegotiation classes (which have no knowledge
of Connection), and Connection (which now has no knowledge of
ClientNegotiation/ServerNegotiation). Hopefully this will lead to more
maintainable code in the long term. It is expected that this refactor
will make TLS-negotiation an easy add-on.

Change-Id: I567b62f3341a1e74342c30c76b63f2ca5d7990bd
Reviewed-on: http://gerrit.cloudera.org:8080/5760
Tested-by: Kudu Jenkins
Reviewed-by: Todd Lipcon <to...@apache.org>
Reviewed-by: Alexey Serbin <as...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/kudu/repo
Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/dc852535
Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/dc852535
Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/dc852535

Branch: refs/heads/master
Commit: dc8525358f48df4142e3347b389612704a9312d1
Parents: b9aa5dd
Author: Dan Burkert <da...@apache.org>
Authored: Fri Jan 20 15:17:44 2017 -0800
Committer: Dan Burkert <da...@apache.org>
Committed: Thu Jan 26 02:19:25 2017 +0000

----------------------------------------------------------------------
 src/kudu/rpc/client_negotiation.cc | 408 +++++++++++-------------
 src/kudu/rpc/client_negotiation.h  | 144 +++++----
 src/kudu/rpc/connection.cc         |  88 ++----
 src/kudu/rpc/connection.h          |  58 ++--
 src/kudu/rpc/messenger.cc          |   8 +-
 src/kudu/rpc/negotiation-test.cc   | 218 ++++++-------
 src/kudu/rpc/negotiation.cc        | 166 +++++-----
 src/kudu/rpc/negotiation.h         |   6 +-
 src/kudu/rpc/reactor.cc            |  87 +++---
 src/kudu/rpc/reactor.h             |  23 +-
 src/kudu/rpc/rpc-test.cc           |   4 +-
 src/kudu/rpc/sasl_common.cc        |  47 +--
 src/kudu/rpc/sasl_common.h         |   2 +-
 src/kudu/rpc/sasl_helper.cc        |  71 +----
 src/kudu/rpc/sasl_helper.h         |  65 ++--
 src/kudu/rpc/server_negotiation.cc | 535 ++++++++++++++++----------------
 src/kudu/rpc/server_negotiation.h  | 135 ++++----
 src/kudu/security/ssl_socket.cc    |   6 +-
 src/kudu/security/ssl_socket.h     |   3 +-
 19 files changed, 951 insertions(+), 1123 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/client_negotiation.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/client_negotiation.cc b/src/kudu/rpc/client_negotiation.cc
index 0eb96b8..a7c6c63 100644
--- a/src/kudu/rpc/client_negotiation.cc
+++ b/src/kudu/rpc/client_negotiation.cc
@@ -20,6 +20,7 @@
 #include <string.h>
 
 #include <map>
+#include <memory>
 #include <set>
 #include <string>
 
@@ -29,7 +30,6 @@
 #include "kudu/gutil/endian.h"
 #include "kudu/gutil/map-util.h"
 #include "kudu/gutil/stl_util.h"
-#include "kudu/gutil/stringprintf.h"
 #include "kudu/gutil/strings/join.h"
 #include "kudu/gutil/strings/substitute.h"
 #include "kudu/gutil/strings/util.h"
@@ -51,21 +51,28 @@ namespace rpc {
 using std::map;
 using std::set;
 using std::string;
-
-static int SaslClientGetoptCb(void* sasl_client, const char* plugin_name, const char* option,
-                       const char** result, unsigned* len) {
-  return static_cast<SaslClient*>(sasl_client)
-    ->GetOptionCb(plugin_name, option, result, len);
+using std::unique_ptr;
+
+static int ClientNegotiationGetoptCb(ClientNegotiation* client_negotiation,
+                                     const char* plugin_name,
+                                     const char* option,
+                                     const char** result,
+                                     unsigned* len) {
+  return client_negotiation->GetOptionCb(plugin_name, option, result, len);
 }
 
-static int SaslClientSimpleCb(void *sasl_client, int id,
-                       const char **result, unsigned *len) {
-  return static_cast<SaslClient*>(sasl_client)->SimpleCb(id, result, len);
+static int ClientNegotiationSimpleCb(ClientNegotiation* client_negotiation,
+                                     int id,
+                                     const char** result,
+                                     unsigned* len) {
+  return client_negotiation->SimpleCb(id, result, len);
 }
 
-static int SaslClientSecretCb(sasl_conn_t* conn, void *sasl_client, int id,
-                       sasl_secret_t** psecret) {
-  return static_cast<SaslClient*>(sasl_client)->SecretCb(conn, id, psecret);
+static int ClientNegotiationSecretCb(sasl_conn_t* conn,
+                                     ClientNegotiation* client_negotiation,
+                                     int id,
+                                     sasl_secret_t** psecret) {
+  return client_negotiation->SecretCb(conn, id, psecret);
 }
 
 // Return an appropriately-typed Status object based on an ErrorStatusPB returned
@@ -85,187 +92,166 @@ static Status StatusFromRpcError(const ErrorStatusPB& error) {
   }
 }
 
-SaslClient::SaslClient(string app_name, Socket* socket)
-    : app_name_(std::move(app_name)),
-      sock_(socket),
+ClientNegotiation::ClientNegotiation(unique_ptr<Socket> socket)
+    : socket_(std::move(socket)),
       helper_(SaslHelper::CLIENT),
-      client_state_(SaslNegotiationState::NEW),
       negotiated_mech_(SaslMechanism::INVALID),
       deadline_(MonoTime::Max()) {
   callbacks_.push_back(SaslBuildCallback(SASL_CB_GETOPT,
-      reinterpret_cast<int (*)()>(&SaslClientGetoptCb), this));
+      reinterpret_cast<int (*)()>(&ClientNegotiationGetoptCb), this));
   callbacks_.push_back(SaslBuildCallback(SASL_CB_AUTHNAME,
-      reinterpret_cast<int (*)()>(&SaslClientSimpleCb), this));
+      reinterpret_cast<int (*)()>(&ClientNegotiationSimpleCb), this));
   callbacks_.push_back(SaslBuildCallback(SASL_CB_PASS,
-      reinterpret_cast<int (*)()>(&SaslClientSecretCb), this));
+      reinterpret_cast<int (*)()>(&ClientNegotiationSecretCb), this));
   callbacks_.push_back(SaslBuildCallback(SASL_CB_LIST_END, nullptr, nullptr));
 }
 
-Status SaslClient::EnablePlain(const string& user, const string& pass) {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+Status ClientNegotiation::EnablePlain(const string& user, const string& pass) {
   RETURN_NOT_OK(helper_.EnablePlain());
   plain_auth_user_ = user;
   plain_pass_ = pass;
   return Status::OK();
 }
 
-Status SaslClient::EnableGSSAPI() {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+Status ClientNegotiation::EnableGSSAPI() {
   return helper_.EnableGSSAPI();
 }
 
-SaslMechanism::Type SaslClient::negotiated_mechanism() const {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEGOTIATED);
+SaslMechanism::Type ClientNegotiation::negotiated_mechanism() const {
   return negotiated_mech_;
 }
 
-void SaslClient::set_local_addr(const Sockaddr& addr) {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+void ClientNegotiation::set_local_addr(const Sockaddr& addr) {
   helper_.set_local_addr(addr);
 }
 
-void SaslClient::set_remote_addr(const Sockaddr& addr) {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+void ClientNegotiation::set_remote_addr(const Sockaddr& addr) {
   helper_.set_remote_addr(addr);
 }
 
-void SaslClient::set_server_fqdn(const string& domain_name) {
-  DCHECK_EQ(client_state_, SaslNegotiationState::NEW);
+void ClientNegotiation::set_server_fqdn(const string& domain_name) {
   helper_.set_server_fqdn(domain_name);
 }
 
-void SaslClient::set_deadline(const MonoTime& deadline) {
-  DCHECK_NE(client_state_, SaslNegotiationState::NEGOTIATED);
+void ClientNegotiation::set_deadline(const MonoTime& deadline) {
   deadline_ = deadline;
 }
 
-// calls sasl_client_init() and sasl_client_new()
-Status SaslClient::Init(const string& service_type) {
-  RETURN_NOT_OK(SaslInit(app_name_.c_str()));
-
-  // Ensure we are not called more than once.
-  if (client_state_ != SaslNegotiationState::NEW) {
-    return Status::IllegalState("Init() may only be called once per SaslClient object.");
-  }
-
-  // TODO(unknown): Support security flags.
-  unsigned secflags = 0;
-
-  sasl_conn_t* sasl_conn = nullptr;
-  Status s = WrapSaslCall(nullptr /* no conn */, [&]() {
-      return sasl_client_new(
-          service_type.c_str(),         // Registered name of the service using SASL. Required.
-          helper_.server_fqdn(),        // The fully qualified domain name of the remote server.
-          helper_.local_addr_string(),  // Local and remote IP address strings. (NULL disables
-          helper_.remote_addr_string(), //   mechanisms which require this info.)
-          &callbacks_[0],               // Connection-specific callbacks.
-          secflags,                     // Security flags.
-          &sasl_conn);
-    });
-  if (!s.ok()) {
-    return Status::RuntimeError("Unable to create new SASL client",
-                                s.message());
-  }
-  sasl_conn_.reset(sasl_conn);
-
-  client_state_ = SaslNegotiationState::INITIALIZED;
-  return Status::OK();
-}
-
-Status SaslClient::Negotiate() {
-  // After negotiation, we no longer need the SASL library object, so
-  // may as well free its memory since the connection may be long-lived.
-  // Additionally, this works around a SEGV seen at process shutdown time:
-  // if we still have SASL objects retained by Reactor when the process
-  // is exiting, the SASL libraries may start destructing global state
-  // and cause a crash when we sasl_dispose the connection.
-  auto cleanup = MakeScopedCleanup([&]() {
-      sasl_conn_.reset();
-    });
-  TRACE("Called SaslClient::Negotiate()");
-
-  // Ensure we called exactly once, and in the right order.
-  if (client_state_ == SaslNegotiationState::NEW) {
-    return Status::IllegalState("SaslClient: Init() must be called before calling Negotiate()");
-  }
-  if (client_state_ == SaslNegotiationState::NEGOTIATED) {
-    return Status::IllegalState("SaslClient: Negotiate() may only be called once per object.");
-  }
+Status ClientNegotiation::Negotiate() {
+  TRACE("Beginning negotiation");
 
   // Ensure we can use blocking calls on the socket during negotiation.
-  RETURN_NOT_OK(EnsureBlockingMode(sock_));
+  RETURN_NOT_OK(EnsureBlockingMode(socket_.get()));
 
-  // Start by asking the server for a list of available auth mechanisms.
-  RETURN_NOT_OK(SendNegotiateMessage());
+  // Step 1: send the connection header.
+  RETURN_NOT_OK(SendConnectionHeader());
 
   faststring recv_buf;
-  nego_ok_ = false;
-
-  // We set nego_ok_ = true when the SASL library returns SASL_OK to us.
-  // We set nego_response_expected_ = true each time we send a request to the server.
-  while (!nego_ok_ || nego_response_expected_) {
-    ResponseHeader header;
-    Slice param_buf;
-    RETURN_NOT_OK(ReceiveFramedMessageBlocking(sock_, &recv_buf, &header, &param_buf, deadline_));
-    nego_response_expected_ = false;
 
+  { // Step 2: send and receive the NEGOTIATE step messages.
+    RETURN_NOT_OK(SendNegotiate());
     NegotiatePB response;
-    RETURN_NOT_OK(ParseNegotiatePB(header, param_buf, &response));
+    RETURN_NOT_OK(RecvNegotiatePB(&response, &recv_buf));
+    RETURN_NOT_OK(HandleNegotiate(response));
+  }
 
+  // Step 3: SASL negotiation.
+  RETURN_NOT_OK(InitSaslClient());
+  RETURN_NOT_OK(SendSaslInitiate());
+  for (bool cont = true; cont; ) {
+    NegotiatePB response;
+    RETURN_NOT_OK(RecvNegotiatePB(&response, &recv_buf));
+    Status s;
     switch (response.step()) {
-      // NEGOTIATE: Server has sent us its list of supported SASL mechanisms.
-      case NegotiatePB::NEGOTIATE:
-        RETURN_NOT_OK(HandleNegotiateResponse(response));
-        break;
-
       // SASL_CHALLENGE: Server sent us a follow-up to an SASL_INITIATE or SASL_RESPONSE request.
       case NegotiatePB::SASL_CHALLENGE:
-        RETURN_NOT_OK(HandleChallengeResponse(response));
+        RETURN_NOT_OK(HandleSaslChallenge(response));
         break;
-
       // SASL_SUCCESS: Server has accepted our authentication request. Negotiation successful.
       case NegotiatePB::SASL_SUCCESS:
-        RETURN_NOT_OK(HandleSuccessResponse(response));
+        cont = false;
         break;
-
-      // Client sent us some unsupported SASL response.
       default:
-        LOG(ERROR) << "SASL Client: Received unsupported response from server";
-        return Status::InvalidArgument("RPC client doesn't support Negotiate step",
-                                       NegotiatePB::NegotiateStep_Name(response.step()));
+        return Status::NotAuthorized("expected SASL_CHALLENGE or SASL_SUCCESS step",
+                                     NegotiatePB::NegotiateStep_Name(response.step()));
     }
   }
 
-  TRACE("SASL Client: Successful negotiation");
-  client_state_ = SaslNegotiationState::NEGOTIATED;
+  // Step 4: Send connection context.
+  RETURN_NOT_OK(SendConnectionContext());
+
+  TRACE("Negotiation successful");
   return Status::OK();
 }
 
-Status SaslClient::SendNegotiatePB(const NegotiatePB& msg) {
-  DCHECK_NE(client_state_, SaslNegotiationState::NEW)
-      << "Must not send Negotiate messages before calling Init()";
-  DCHECK_NE(client_state_, SaslNegotiationState::NEGOTIATED)
-      << "Must not send Negotiate messages after negotiation succeeds";
-
-  // Create header with SASL-specific callId
+Status ClientNegotiation::SendNegotiatePB(const NegotiatePB& msg) {
   RequestHeader header;
   header.set_call_id(kNegotiateCallId);
-  return helper_.SendNegotiatePB(sock_, header, msg, deadline_);
+
+  DCHECK(socket_);
+  DCHECK(msg.IsInitialized()) << "message must be initialized";
+  DCHECK(msg.has_step()) << "message must have a step";
+
+  TRACE("Sending $0 NegotiatePB request", NegotiatePB::NegotiateStep_Name(msg.step()));
+  return SendFramedMessageBlocking(socket(), header, msg, deadline_);
 }
 
-Status SaslClient::ParseNegotiatePB(const ResponseHeader& header,
-                                    const Slice& param_buf,
-                                    NegotiatePB* response) {
-  RETURN_NOT_OK(helper_.SanityCheckNegotiateCallId(header.call_id()));
+Status ClientNegotiation::RecvNegotiatePB(NegotiatePB* msg, faststring* buffer) {
+  ResponseHeader header;
+  Slice param_buf;
+  RETURN_NOT_OK(ReceiveFramedMessageBlocking(socket(), buffer, &header, &param_buf, deadline_));
+  RETURN_NOT_OK(helper_.CheckNegotiateCallId(header.call_id()));
 
   if (header.is_error()) {
     return ParseError(param_buf);
   }
 
-  return helper_.ParseNegotiatePB(param_buf, response);
+  RETURN_NOT_OK(helper_.ParseNegotiatePB(param_buf, msg));
+  TRACE("Received $0 NegotiatePB response", NegotiatePB::NegotiateStep_Name(msg->step()));
+  return Status::OK();
+}
+
+Status ClientNegotiation::ParseError(const Slice& err_data) {
+  ErrorStatusPB error;
+  if (!error.ParseFromArray(err_data.data(), err_data.size())) {
+    return Status::IOError("invalid error response, missing fields",
+                           error.InitializationErrorString());
+  }
+  Status s = StatusFromRpcError(error);
+  TRACE("Received error response from server: $0", s.ToString());
+  return s;
+}
+
+Status ClientNegotiation::SendConnectionHeader() {
+  const uint8_t buflen = kMagicNumberLength + kHeaderFlagsLength;
+  uint8_t buf[buflen];
+  serialization::SerializeConnHeader(buf);
+  size_t nsent;
+  return socket()->BlockingWrite(buf, buflen, &nsent, deadline_);
 }
 
-Status SaslClient::SendNegotiateMessage() {
+Status ClientNegotiation::InitSaslClient() {
+  RETURN_NOT_OK(SaslInit());
+
+  // TODO(unknown): Support security flags.
+  unsigned secflags = 0;
+
+  sasl_conn_t* sasl_conn = nullptr;
+  RETURN_NOT_OK_PREPEND(WrapSaslCall(nullptr /* no conn */, [&]() {
+      return sasl_client_new(
+          kSaslProtoName,               // Registered name of the service using SASL. Required.
+          helper_.server_fqdn(),        // The fully qualified domain name of the remote server.
+          helper_.local_addr_string(),  // Local and remote IP address strings. (NULL disables
+          helper_.remote_addr_string(), //   mechanisms which require this info.)
+          &callbacks_[0],               // Connection-specific callbacks.
+          secflags,                     // Security flags.
+          &sasl_conn);
+    }), "Unable to create new SASL client");
+  sasl_conn_.reset(sasl_conn);
+  return Status::OK();
+}
+
+Status ClientNegotiation::SendNegotiate() {
   NegotiatePB msg;
   msg.set_step(NegotiatePB::NEGOTIATE);
 
@@ -274,59 +260,29 @@ Status SaslClient::SendNegotiateMessage() {
     msg.add_supported_features(feature);
   }
 
-  TRACE("SASL Client: Sending NEGOTIATE request to server.");
-  RETURN_NOT_OK(SendNegotiatePB(msg));
-  nego_response_expected_ = true;
-  return Status::OK();
-}
-
-Status SaslClient::SendInitiateMessage(const NegotiatePB_SaslAuth& auth,
-    const char* init_msg, unsigned init_msg_len) {
-  NegotiatePB msg;
-  msg.set_step(NegotiatePB::SASL_INITIATE);
-  msg.mutable_token()->assign(init_msg, init_msg_len);
-  msg.add_auths()->CopyFrom(auth);
-  TRACE("SASL Client: Sending SASL_INITIATE request to server.");
   RETURN_NOT_OK(SendNegotiatePB(msg));
-  nego_response_expected_ = true;
-  return Status::OK();
-}
-
-Status SaslClient::SendResponseMessage(const char* resp_msg, unsigned resp_msg_len) {
-  NegotiatePB reply;
-  reply.set_step(NegotiatePB::SASL_RESPONSE);
-  reply.mutable_token()->assign(resp_msg, resp_msg_len);
-  TRACE("SASL Client: Sending SASL_RESPONSE request to server.");
-  RETURN_NOT_OK(SendNegotiatePB(reply));
-  nego_response_expected_ = true;
   return Status::OK();
 }
 
-Status SaslClient::DoSaslStep(const string& in, const char** out, unsigned* out_len) {
-  TRACE("SASL Client: Calling sasl_client_step()");
-  Status s = WrapSaslCall(sasl_conn_.get(), [&]() {
-      return sasl_client_step(sasl_conn_.get(), in.c_str(), in.length(), nullptr, out, out_len);
-    });
-  if (s.ok()) {
-    nego_ok_ = true;
+Status ClientNegotiation::HandleNegotiate(const NegotiatePB& response) {
+  if (PREDICT_FALSE(response.step() != NegotiatePB::NEGOTIATE)) {
+    return Status::NotAuthorized("expected NEGOTIATE step",
+                                 NegotiatePB::NegotiateStep_Name(response.step()));
   }
-  return s;
-}
+  TRACE("Received NEGOTIATE response from server");
 
-Status SaslClient::HandleNegotiateResponse(const NegotiatePB& response) {
-  TRACE("SASL Client: Received NEGOTIATE response from server");
   // Fill in the set of features supported by the server.
   for (int flag : response.supported_features()) {
     // We only add the features that our local build knows about.
     RpcFeatureFlag feature_flag = RpcFeatureFlag_IsValid(flag) ?
                                   static_cast<RpcFeatureFlag>(flag) : UNKNOWN;
-    if (ContainsKey(kSupportedClientRpcFeatureFlags, feature_flag)) {
+    if (feature_flag != UNKNOWN) {
       server_features_.insert(feature_flag);
     }
   }
 
-  // Build a map of the mechanisms offered by the server.
-  const set<string>& local_mechs = helper_.LocalMechs();
+  // Build a map of the SASL mechanisms offered by the server.
+  const set<string>& enabled_mechs = helper_.EnabledMechs();
   set<string> server_mechs;
   map<string, NegotiatePB::SaslAuth> server_mech_map;
   for (const NegotiatePB::SaslAuth& auth : response.auths()) {
@@ -334,21 +290,26 @@ Status SaslClient::HandleNegotiateResponse(const NegotiatePB& response) {
     server_mech_map[mech] = auth;
     server_mechs.insert(mech);
   }
+
   // Determine which server mechs are also enabled by the client.
   // Cyrus SASL 2.1.25 and later supports doing this set intersection via
   // the 'client_mech_list' option, but that version is not available on
   // RHEL 6, so we have to do it manually.
-  set<string> matching_mechs = STLSetIntersection(local_mechs, server_mechs);
+  common_mechs_ = STLSetIntersection(enabled_mechs, server_mechs);
 
-  if (matching_mechs.empty() &&
+  if (common_mechs_.empty() &&
       ContainsKey(server_mechs, kSaslMechGSSAPI) &&
-      !ContainsKey(local_mechs, kSaslMechGSSAPI)) {
+      !ContainsKey(enabled_mechs, kSaslMechGSSAPI)) {
     return Status::NotAuthorized("server requires GSSAPI (Kerberos) authentication and "
                                  "client was missing the required SASL module");
   }
 
-  string matching_mechs_str = JoinElements(matching_mechs, " ");
-  TRACE("SASL Client: Matching mech list: $0", matching_mechs_str);
+  return Status::OK();
+}
+
+Status ClientNegotiation::SendSaslInitiate() {
+  string matching_mechs_str = JoinElements(common_mechs_, " ");
+  TRACE("Matching mech list: $0", matching_mechs_str);
 
   const char* init_msg = nullptr;
   unsigned init_msg_len = 0;
@@ -368,7 +329,7 @@ Status SaslClient::HandleNegotiateResponse(const NegotiatePB& response) {
    *  SASL_NOMECH   -- no mechanism meets requested properties
    *  SASL_INTERACT -- user interaction needed to fill in prompt_need list
    */
-  TRACE("SASL Client: Calling sasl_client_start()");
+  TRACE("Calling sasl_client_start()");
   Status s = WrapSaslCall(sasl_conn_.get(), [&]() {
       return sasl_client_start(
           sasl_conn_.get(),           // The SASL connection context created by init()
@@ -377,112 +338,101 @@ Status SaslClient::HandleNegotiateResponse(const NegotiatePB& response) {
           &init_msg,                  // Filled in on success.
           &init_msg_len,              // Filled in on success.
           &negotiated_mech);          // Filled in on success.
-    });
+  });
 
-  if (s.ok()) {
-    nego_ok_ = true;
-  } else if (!s.IsIncomplete()) {
+  if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) {
     return s;
   }
 
   // The server matched one of our mechanisms.
-  NegotiatePB::SaslAuth* auth = FindOrNull(server_mech_map, negotiated_mech);
-  if (PREDICT_FALSE(auth == nullptr)) {
-    return Status::IllegalState("Unable to find auth in map, unexpected error", negotiated_mech);
-  }
   negotiated_mech_ = SaslMechanism::value_of(negotiated_mech);
 
-  RETURN_NOT_OK(SendInitiateMessage(*auth, init_msg, init_msg_len));
-  return Status::OK();
+  NegotiatePB msg;
+  msg.set_step(NegotiatePB::SASL_INITIATE);
+  msg.mutable_token()->assign(init_msg, init_msg_len);
+  msg.add_auths()->set_mechanism(negotiated_mech);
+  return SendNegotiatePB(msg);
 }
 
-Status SaslClient::HandleChallengeResponse(const NegotiatePB& response) {
-  TRACE("SASL Client: Received SASL_CHALLENGE response from server");
-  if (PREDICT_FALSE(nego_ok_)) {
-    LOG(DFATAL) << "Server sent SASL_CHALLENGE response after client library returned SASL_OK";
-  }
+Status ClientNegotiation::SendSaslResponse(const char* resp_msg, unsigned resp_msg_len) {
+  NegotiatePB reply;
+  reply.set_step(NegotiatePB::SASL_RESPONSE);
+  reply.mutable_token()->assign(resp_msg, resp_msg_len);
+  return SendNegotiatePB(reply);
+}
 
+Status ClientNegotiation::HandleSaslChallenge(const NegotiatePB& response) {
+  TRACE("Received SASL_CHALLENGE response from server");
   if (PREDICT_FALSE(!response.has_token())) {
-    return Status::InvalidArgument("No token in SASL_CHALLENGE response from server");
+    return Status::NotAuthorized("no token in SASL_CHALLENGE response from server");
   }
 
   const char* out = nullptr;
   unsigned out_len = 0;
   Status s = DoSaslStep(response.token(), &out, &out_len);
-  if (!s.ok() && !s.IsIncomplete()) {
+  if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) {
     return s;
   }
-  RETURN_NOT_OK(SendResponseMessage(out, out_len));
-  return Status::OK();
+
+  return SendSaslResponse(out, out_len);
 }
 
-Status SaslClient::HandleSuccessResponse(const NegotiatePB& response) {
-  TRACE("SASL Client: Received SASL_SUCCESS response from server");
-  if (!nego_ok_) {
-    const char* out = nullptr;
-    unsigned out_len = 0;
-    Status s = DoSaslStep(response.token(), &out, &out_len);
-    if (s.IsIncomplete()) {
-      return Status::IllegalState("Server indicated successful authentication, but client "
-                                  "was not complete");
-    }
-    RETURN_NOT_OK(s);
-    if (out_len > 0) {
-      return Status::IllegalState("SASL client library generated spurious token after SASL_SUCCESS",
-          string(out, out_len));
-    }
-    CHECK(nego_ok_);
-  }
-  return Status::OK();
+Status ClientNegotiation::DoSaslStep(const string& in, const char** out, unsigned* out_len) {
+  TRACE("Calling sasl_client_step()");
+
+  return WrapSaslCall(sasl_conn_.get(), [&]() {
+      return sasl_client_step(sasl_conn_.get(), in.c_str(), in.length(), nullptr, out, out_len);
+  });
 }
 
-// Parse error status message from raw bytes of an ErrorStatusPB.
-Status SaslClient::ParseError(const Slice& err_data) {
-  ErrorStatusPB error;
-  if (!error.ParseFromArray(err_data.data(), err_data.size())) {
-    return Status::IOError("Invalid error response, missing fields",
-        error.InitializationErrorString());
-  }
-  Status s = StatusFromRpcError(error);
-  TRACE("SASL Client: Received error response from server: $0", s.ToString());
-  return s;
+Status ClientNegotiation::SendConnectionContext() {
+  TRACE("Sending connection context");
+  RequestHeader header;
+  header.set_call_id(kConnectionContextCallId);
+
+  ConnectionContextPB conn_context;
+  // This field is deprecated but used by servers <Kudu 1.1. Newer server versions ignore
+  // this and use the SASL-provided username instead.
+  conn_context.mutable_deprecated_user_info()->set_real_user(
+      plain_auth_user_.empty() ? "cpp-client" : plain_auth_user_);
+  return SendFramedMessageBlocking(socket(), header, conn_context, deadline_);
 }
 
-int SaslClient::GetOptionCb(const char* plugin_name, const char* option,
+int ClientNegotiation::GetOptionCb(const char* plugin_name, const char* option,
                             const char** result, unsigned* len) {
   return helper_.GetOptionCb(plugin_name, option, result, len);
 }
 
 // Used for PLAIN.
 // SASL callback for SASL_CB_USER, SASL_CB_AUTHNAME, SASL_CB_LANGUAGE
-int SaslClient::SimpleCb(int id, const char** result, unsigned* len) {
+int ClientNegotiation::SimpleCb(int id, const char** result, unsigned* len) {
   if (PREDICT_FALSE(!helper_.IsPlainEnabled())) {
-    LOG(DFATAL) << "SASL Client: Simple callback called, but PLAIN auth is not enabled";
+    LOG(DFATAL) << "Simple callback called, but PLAIN auth is not enabled";
     return SASL_FAIL;
   }
   if (PREDICT_FALSE(result == nullptr)) {
-    LOG(DFATAL) << "SASL Client: result outparam is NULL";
+    LOG(DFATAL) << "result outparam is NULL";
     return SASL_BADPARAM;
   }
   switch (id) {
     // TODO(unknown): Support impersonation?
     // For impersonation, USER is the impersonated user, AUTHNAME is the "sudoer".
     case SASL_CB_USER:
-      TRACE("SASL Client: callback for SASL_CB_USER");
+      TRACE("callback for SASL_CB_USER");
       *result = plain_auth_user_.c_str();
       if (len != nullptr) *len = plain_auth_user_.length();
       break;
     case SASL_CB_AUTHNAME:
-      TRACE("SASL Client: callback for SASL_CB_AUTHNAME");
+      TRACE("callback for SASL_CB_AUTHNAME");
       *result = plain_auth_user_.c_str();
       if (len != nullptr) *len = plain_auth_user_.length();
       break;
     case SASL_CB_LANGUAGE:
-      LOG(DFATAL) << "SASL Client: Unable to handle SASL callback type SASL_CB_LANGUAGE"
+      LOG(DFATAL) << "Unable to handle SASL callback type SASL_CB_LANGUAGE"
         << "(" << id << ")";
       return SASL_BADPARAM;
     default:
-      LOG(DFATAL) << "SASL Client: Unexpected SASL callback type: " << id;
+      LOG(DFATAL) << "Unexpected SASL callback type: " << id;
       return SASL_BADPARAM;
   }
 
@@ -491,9 +441,9 @@ int SaslClient::SimpleCb(int id, const char** result, unsigned* len) {
 
 // Used for PLAIN.
 // SASL callback for SASL_CB_PASS: User password.
-int SaslClient::SecretCb(sasl_conn_t* conn, int id, sasl_secret_t** psecret) {
+int ClientNegotiation::SecretCb(sasl_conn_t* conn, int id, sasl_secret_t** psecret) {
   if (PREDICT_FALSE(!helper_.IsPlainEnabled())) {
-    LOG(DFATAL) << "SASL Client: Plain secret callback called, but PLAIN auth is not enabled";
+    LOG(DFATAL) << "Plain secret callback called, but PLAIN auth is not enabled";
     return SASL_FAIL;
   }
   switch (id) {
@@ -511,7 +461,7 @@ int SaslClient::SecretCb(sasl_conn_t* conn, int id, sasl_secret_t** psecret) {
       break;
     }
     default:
-      LOG(DFATAL) << "SASL Client: Unexpected SASL callback type: " << id;
+      LOG(DFATAL) << "Unexpected SASL callback type: " << id;
       return SASL_BADPARAM;
   }
 

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/client_negotiation.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/client_negotiation.h b/src/kudu/rpc/client_negotiation.h
index 2f01e56..6677d33 100644
--- a/src/kudu/rpc/client_negotiation.h
+++ b/src/kudu/rpc/client_negotiation.h
@@ -15,9 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#ifndef KUDU_RPC_SASL_CLIENT_H
-#define KUDU_RPC_SASL_CLIENT_H
+#pragma once
 
+#include <memory>
 #include <set>
 #include <string>
 #include <vector>
@@ -33,65 +33,77 @@
 #include "kudu/util/status.h"
 
 namespace kudu {
-namespace rpc {
 
-using std::string;
+namespace rpc {
 
 class NegotiatePB;
 class NegotiatePB_SaslAuth;
 class ResponseHeader;
 
-// Class for doing SASL negotiation with a SaslServer over a bidirectional socket.
+// Class for doing KRPC negotiation with a remote server over a bidirectional socket.
 // Operations on this class are NOT thread-safe.
-class SaslClient {
+class ClientNegotiation {
  public:
-  // Does not take ownership of the socket indicated by the fd.
-  SaslClient(string app_name, Socket* socket);
+  // Creates a new client negotiation instance, taking ownership of the
+  // provided socket. After completing the negotiation process by setting the
+  // desired options and calling Negotiate(), the socket can be retrieved with
+  // 'release_socket'.
+  explicit ClientNegotiation(std::unique_ptr<Socket> socket);
 
   // Enable PLAIN authentication.
-  // Must be called after Init().
-  Status EnablePlain(const string& user, const string& pass);
+  // Must be called before Negotiate().
+  Status EnablePlain(const std::string& user,
+                     const std::string& pass);
 
   // Enable GSSAPI authentication.
-  // Call after Init().
+  // Must be called before Negotiate().
   Status EnableGSSAPI();
 
   // Returns mechanism negotiated by this connection.
-  // Must be called after Negotiate().
+  // Must be called after Negotiate().
   SaslMechanism::Type negotiated_mechanism() const;
 
   // Returns the set of RPC system features supported by the remote server.
-  // Must be called after Negotiate().
-  const std::set<RpcFeatureFlag>& server_features() const {
+  // Must be called after Negotiate().
+  std::set<RpcFeatureFlag> server_features() const {
     return server_features_;
   }
 
+  // Returns the set of RPC system features supported by the remote server.
+  // Must be called after Negotiate().
+  // Subsequent calls to this method or server_features() will return an empty set.
+  std::set<RpcFeatureFlag> take_server_features() {
+    return std::move(server_features_);
+  }
+
   // Specify IP:port of local side of connection.
-  // Must be called before Init(). Required for some mechanisms.
+  // Must be called before Negotiate(). Required for some mechanisms.
   void set_local_addr(const Sockaddr& addr);
 
   // Specify IP:port of remote side of connection.
-  // Must be called before Init(). Required for some mechanisms.
+  // Must be called before Negotiate(). Required for some mechanisms.
   void set_remote_addr(const Sockaddr& addr);
 
   // Specify the fully-qualified domain name of the remote server.
-  // Must be called before Init(). Required for some mechanisms.
-  void set_server_fqdn(const string& domain_name);
+  // Must be called before Negotiate(). Required for some mechanisms.
+  void set_server_fqdn(const std::string& domain_name);
 
   // Set deadline for connection negotiation.
   void set_deadline(const MonoTime& deadline);
 
-  // Get deadline for connection negotiation.
-  const MonoTime& deadline() const { return deadline_; }
+  Socket* socket() { return socket_.get(); }
 
-  // Initialize a new SASL client. Must be called before Negotiate().
-  // Returns OK on success, otherwise RuntimeError.
-  Status Init(const string& service_type);
+  // Takes and returns the socket owned by this client negotiation. The caller
+  // will own the socket after this call, and the negotiation instance should no
+  // longer be used. Must be called after Negotiate(). Subsequent calls to this
+  // method or socket() will return a null pointer.
+  std::unique_ptr<Socket> release_socket() { return std::move(socket_); }
 
-  // Begin negotiation with the SASL server on the other side of the fd socket
-  // that this client was constructed with.
-  // Returns OK on success.
-  // Otherwise, it may return NotAuthorized, NotSupported, or another non-OK status.
+  // Negotiate with the remote server. Should only be called once per
+  // ClientNegotiation and socket instance, after all options have been set.
+  //
+  // Returns OK on success, otherwise may return NotAuthorized, NotSupported, or
+  // another non-OK status.
   Status Negotiate();
 
   // SASL callback for plugin options, supported mechanisms, etc.
@@ -106,23 +118,36 @@ class SaslClient {
   int SecretCb(sasl_conn_t* conn, int id, sasl_secret_t** psecret);
 
  private:
-  // Encode and send the specified negotiate message to the server.
-  Status SendNegotiatePB(const NegotiatePB& msg);
 
-  // Validate that header does not indicate an error, parse param_buf into response.
-  Status ParseNegotiatePB(const ResponseHeader& header,
-                          const Slice& param_buf,
-                          NegotiatePB* response);
+  // Encode and send the specified negotiate request message to the server.
+  Status SendNegotiatePB(const NegotiatePB& msg) WARN_UNUSED_RESULT;
+
+  // Receive a negotiate response message from the server, deserializing it into 'msg'.
+  // Validates that the response is not an error.
+  Status RecvNegotiatePB(NegotiatePB* msg, faststring* buffer) WARN_UNUSED_RESULT;
+
+  // Parse error status message from raw bytes of an ErrorStatusPB.
+  Status ParseError(const Slice& err_data) WARN_UNUSED_RESULT;
 
-  // Send an NEGOTIATE message to the server.
-  Status SendNegotiateMessage();
+  Status SendConnectionHeader() WARN_UNUSED_RESULT;
+
+  // Initialize the SASL client negotiation instance.
+  Status InitSaslClient() WARN_UNUSED_RESULT;
+
+  // Send a NEGOTIATE step message to the server.
+  Status SendNegotiate() WARN_UNUSED_RESULT;
+
+  // Handle NEGOTIATE step response from the server.
+  Status HandleNegotiate(const NegotiatePB& response) WARN_UNUSED_RESULT;
 
   // Send an SASL_INITIATE message to the server.
-  Status SendInitiateMessage(const NegotiatePB_SaslAuth& auth,
-                             const char* init_msg, unsigned init_msg_len);
+  Status SendSaslInitiate() WARN_UNUSED_RESULT;
 
-  // Send a RESPONSE message to the server.
-  Status SendResponseMessage(const char* resp_msg, unsigned resp_msg_len);
+  // Send a SASL_RESPONSE message to the server.
+  Status SendSaslResponse(const char* resp_msg, unsigned resp_msg_len) WARN_UNUSED_RESULT;
+
+  // Handle case when server sends SASL_CHALLENGE response.
+  Status HandleSaslChallenge(const NegotiatePB& response) WARN_UNUSED_RESULT;
 
   // Perform a client-side step of the SASL negotiation.
   // Input is what came from the server. Output is what we will send back to the server.
@@ -130,51 +155,36 @@ class SaslClient {
   //   Status::OK if sasl_client_step returns SASL_OK.
   //   Status::Incomplete if sasl_client_step returns SASL_CONTINUE
   // otherwise returns an appropriate error status.
-  Status DoSaslStep(const string& in, const char** out, unsigned* out_len);
-
-  // Handle case when server sends NEGOTIATE response.
-  Status HandleNegotiateResponse(const NegotiatePB& response);
+  Status DoSaslStep(const std::string& in, const char** out, unsigned* out_len) WARN_UNUSED_RESULT;
 
-  // Handle case when server sends CHALLENGE response.
-  Status HandleChallengeResponse(const NegotiatePB& response);
+  Status SendConnectionContext() WARN_UNUSED_RESULT;
 
-  // Handle case when server sends SUCCESS response.
-  Status HandleSuccessResponse(const NegotiatePB& response);
+  // The socket to the remote server.
+  std::unique_ptr<Socket> socket_;
 
-  // Parse error status message from raw bytes of an ErrorStatusPB.
-  Status ParseError(const Slice& err_data);
-
-  string app_name_;
-  Socket* sock_;
+  // SASL state.
   std::vector<sasl_callback_t> callbacks_;
-  // The SASL connection object. This is initialized in Init() and
-  // freed after Negotiate() completes (regardless whether it was successful).
   gscoped_ptr<sasl_conn_t, SaslDeleter> sasl_conn_;
   SaslHelper helper_;
 
-  string plain_auth_user_;
-  string plain_pass_;
+  // Authentication state.
+  std::string plain_auth_user_;
+  std::string plain_pass_;
   gscoped_ptr<sasl_secret_t, FreeDeleter> psecret_;
 
-  // The set of features supported by the server.
+  // The set of features supported by the server. Filled in during negotiation.
   std::set<RpcFeatureFlag> server_features_;
 
-  SaslNegotiationState::Type client_state_;
+  // The set of SASL mechanisms supported by the server and the client. Filled
+  // in during negotiation.
+  std::set<std::string> common_mechs_;
 
-  // The mechanism we negotiated with the server.
+  // The SASL mechanism used by the connection. Filled in during negotiation.
   SaslMechanism::Type negotiated_mech_;
 
-  // Intra-negotiation state.
-  bool nego_ok_;  // During negotiation: did we get a SASL_OK response from the SASL library?
-  bool nego_response_expected_;  // During negotiation: Are we waiting for a server response?
-
   // Negotiation timeout deadline.
   MonoTime deadline_;
-
-  DISALLOW_COPY_AND_ASSIGN(SaslClient);
 };
 
 } // namespace rpc
 } // namespace kudu
-
-#endif  // KUDU_RPC_SASL_CLIENT_H

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/connection.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/connection.cc b/src/kudu/rpc/connection.cc
index 2abf6ca..30d3720 100644
--- a/src/kudu/rpc/connection.cc
+++ b/src/kudu/rpc/connection.cc
@@ -17,16 +17,19 @@
 
 #include "kudu/rpc/connection.h"
 
+#include <stdint.h>
+
 #include <algorithm>
-#include <boost/intrusive/list.hpp>
-#include <gflags/gflags.h>
-#include <glog/logging.h>
 #include <iostream>
 #include <set>
-#include <stdint.h>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
+#include <boost/intrusive/list.hpp>
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+
 #include "kudu/gutil/map-util.h"
 #include "kudu/gutil/strings/human_readable.h"
 #include "kudu/gutil/strings/substitute.h"
@@ -38,12 +41,12 @@
 #include "kudu/rpc/rpc_header.pb.h"
 #include "kudu/rpc/rpc_introspection.pb.h"
 #include "kudu/rpc/transfer.h"
-#include "kudu/security/ssl_factory.h"
 #include "kudu/security/ssl_socket.h"
 #include "kudu/util/debug-util.h"
 #include "kudu/util/flag_tags.h"
 #include "kudu/util/logging.h"
 #include "kudu/util/net/sockaddr.h"
+#include "kudu/util/net/socket.h"
 #include "kudu/util/status.h"
 #include "kudu/util/trace.h"
 
@@ -51,28 +54,27 @@ using std::function;
 using std::includes;
 using std::set;
 using std::shared_ptr;
+using std::unique_ptr;
 using std::vector;
 using strings::Substitute;
 
-DECLARE_bool(server_require_kerberos);
-
 namespace kudu {
 namespace rpc {
 
 ///
 /// Connection
 ///
-Connection::Connection(ReactorThread *reactor_thread, Sockaddr remote,
-                       Socket* socket, Direction direction)
+Connection::Connection(ReactorThread *reactor_thread,
+                       Sockaddr remote,
+                       unique_ptr<Socket> socket,
+                       Direction direction)
     : reactor_thread_(reactor_thread),
-      remote_(std::move(remote)),
-      socket_(socket),
+      remote_(remote),
+      socket_(std::move(socket)),
       direction_(direction),
       last_activity_time_(MonoTime::Now()),
       is_epoll_registered_(false),
       next_call_id_(1),
-      sasl_client_(kSaslAppName, socket),
-      sasl_server_(kSaslAppName, socket),
       negotiation_complete_(false) {
 }
 
@@ -169,7 +171,9 @@ void Connection::Shutdown(const Status &status) {
   read_io_.stop();
   write_io_.stop();
   is_epoll_registered_ = false;
-  WARN_NOT_OK(socket_->Close(), "Error closing socket");
+  if (socket_) {
+    WARN_NOT_OK(socket_->Close(), "Error closing socket");
+  }
 }
 
 void Connection::QueueOutbound(gscoped_ptr<OutboundTransfer> transfer) {
@@ -578,8 +582,7 @@ void Connection::WriteHandler(ev::io &watcher, int revents) {
         // order to ensure that the negotiation has taken place, so that the flags
         // are available.
         const set<RpcFeatureFlag>& required_features = car->call->required_rpc_features();
-        const set<RpcFeatureFlag>& server_features = sasl_client_.server_features();
-        if (!includes(server_features.begin(), server_features.end(),
+        if (!includes(remote_features_.begin(), remote_features_.end(),
                       required_features.begin(), required_features.end())) {
           outbound_transfers_.pop_front();
           Status s = Status::NotSupported("server does not support the required RPC features");
@@ -626,58 +629,13 @@ std::string Connection::ToString() const {
     remote_.ToString());
 }
 
-Status Connection::InitSSLIfNecessary() {
-  if (!reactor_thread_->reactor()->messenger()->ssl_enabled()) return Status::OK();
-  SSLSocket* ssl_socket = down_cast<SSLSocket*>(socket_.get());
-  RETURN_NOT_OK(ssl_socket->DoHandshake());
-  return Status::OK();
-}
-
-Status Connection::InitSaslClient() {
-  // Note that remote_.host() is an IP address here: we've already lost
-  // whatever DNS name the client was attempting to use. Unless krb5
-  // is configured with 'rdns = false', it will automatically take care
-  // of reversing this address to its canonical hostname to determine
-  // the expected server principal.
-  sasl_client().set_server_fqdn(remote_.host());
-  Status gssapi_status = sasl_client().EnableGSSAPI();
-  if (!gssapi_status.ok()) {
-    // If we can't enable GSSAPI, it's likely the client is just missing the
-    // appropriate SASL plugin. We don't want to require it to be installed
-    // if the user doesn't care about connecting to servers using Kerberos
-    // authentication. So, we'll just VLOG this here. If we try to connect
-    // to a server which requires Kerberos, we'll get a negotiation error
-    // at that point.
-    if (VLOG_IS_ON(1)) {
-      KLOG_FIRST_N(INFO, 1) << "Couldn't enable GSSAPI (Kerberos) SASL plugin: "
-                            << gssapi_status.message().ToString()
-                            << ". This process will be unable to connect to "
-                            << "servers requiring Kerberos authentication.";
-    }
-  }
-  RETURN_NOT_OK(sasl_client().EnablePlain(user_credentials().real_user(), ""));
-  RETURN_NOT_OK(sasl_client().Init(kSaslProtoName));
-  return Status::OK();
-}
-
-Status Connection::InitSaslServer() {
-  if (FLAGS_server_require_kerberos) {
-    RETURN_NOT_OK(sasl_server().EnableGSSAPI());
-  } else {
-    RETURN_NOT_OK(sasl_server().EnablePlain());
-  }
-  RETURN_NOT_OK(sasl_server().Init(kSaslProtoName));
-  return Status::OK();
-}
-
 // Reactor task that transitions this Connection from connection negotiation to
 // regular RPC handling. Destroys Connection on negotiation error.
 class NegotiationCompletedTask : public ReactorTask {
  public:
-  NegotiationCompletedTask(Connection* conn,
-      const Status& negotiation_status)
+  NegotiationCompletedTask(Connection* conn, Status negotiation_status)
     : conn_(conn),
-      negotiation_status_(negotiation_status) {
+      negotiation_status_(std::move(negotiation_status)) {
   }
 
   virtual void Run(ReactorThread *rthread) OVERRIDE {
@@ -687,8 +645,8 @@ class NegotiationCompletedTask : public ReactorTask {
 
   virtual void Abort(const Status &status) OVERRIDE {
     DCHECK(conn_->reactor_thread()->reactor()->closing());
-    VLOG(1) << "Failed connection negotiation due to shut down reactor thread: " <<
-        status.ToString();
+    VLOG(1) << "Failed connection negotiation due to shut down reactor thread: "
+            << status.ToString();
     delete this;
   }
 

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/connection.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/connection.h b/src/kudu/rpc/connection.h
index ef70483..ad5ffed 100644
--- a/src/kudu/rpc/connection.h
+++ b/src/kudu/rpc/connection.h
@@ -18,22 +18,22 @@
 #ifndef KUDU_RPC_CONNECTION_H
 #define KUDU_RPC_CONNECTION_H
 
-#include <boost/intrusive/list.hpp>
-#include <ev++.h>
-#include <memory>
 #include <stdint.h>
-#include <unordered_map>
 
 #include <limits>
+#include <memory>
+#include <set>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
+#include <boost/intrusive/list.hpp>
+#include <ev++.h>
+
 #include "kudu/gutil/gscoped_ptr.h"
 #include "kudu/gutil/ref_counted.h"
-#include "kudu/rpc/client_negotiation.h"
 #include "kudu/rpc/inbound_call.h"
 #include "kudu/rpc/outbound_call.h"
-#include "kudu/rpc/server_negotiation.h"
 #include "kudu/rpc/transfer.h"
 #include "kudu/util/monotime.h"
 #include "kudu/util/net/sockaddr.h"
@@ -80,7 +80,9 @@ class Connection : public RefCountedThreadSafe<Connection> {
   // remote: the address of the remote end
   // socket: the socket to take ownership of.
   // direction: whether we are the client or server side
-  Connection(ReactorThread *reactor_thread, Sockaddr remote, Socket* socket,
+  Connection(ReactorThread *reactor_thread,
+             Sockaddr remote,
+             std::unique_ptr<Socket> socket,
              Direction direction);
 
   // Set underlying socket to non-blocking (or blocking) mode.
@@ -143,23 +145,8 @@ class Connection : public RefCountedThreadSafe<Connection> {
 
   Socket* socket() { return socket_.get(); }
 
-  // Return SASL client instance for this connection.
-  SaslClient &sasl_client() { return sasl_client_; }
-
-  // Return SASL server instance for this connection.
-  SaslServer &sasl_server() { return sasl_server_; }
-
-  // Initialize underlying SSLSocket if SSL is enabled.
-  Status InitSSLIfNecessary();
-
-  // Initialize SASL client before negotiation begins.
-  Status InitSaslClient();
-
-  // Initialize SASL server before negotiation begins.
-  Status InitSaslServer();
-
   // Go through the process of transferring control of the underlying socket back to the Reactor.
-  void CompleteNegotiation(const Status &negotiation_status);
+  void CompleteNegotiation(const Status& negotiation_status);
 
   // Indicate that negotiation is complete and that the Reactor is now in control of the socket.
   void MarkNegotiationComplete();
@@ -167,7 +154,19 @@ class Connection : public RefCountedThreadSafe<Connection> {
   Status DumpPB(const DumpRunningRpcsRequestPB& req,
                 RpcConnectionPB* resp);
 
-  ReactorThread *reactor_thread() const { return reactor_thread_; }
+  ReactorThread* reactor_thread() const { return reactor_thread_; }
+
+  std::unique_ptr<Socket> release_socket() {
+    return std::move(socket_);
+  }
+
+  void adopt_socket(std::unique_ptr<Socket> socket) {
+    socket_ = std::move(socket);
+  }
+
+  void set_remote_features(std::set<RpcFeatureFlag> remote_features) {
+    remote_features_ = std::move(remote_features);
+  }
 
  private:
   friend struct CallAwaitingResponse;
@@ -226,7 +225,7 @@ class Connection : public RefCountedThreadSafe<Connection> {
   void QueueOutbound(gscoped_ptr<OutboundTransfer> transfer);
 
   // The reactor thread that created this connection.
-  ReactorThread * const reactor_thread_;
+  ReactorThread* const reactor_thread_;
 
   // The remote address we're talking to.
   const Sockaddr remote_;
@@ -277,17 +276,14 @@ class Connection : public RefCountedThreadSafe<Connection> {
   // when serializing calls.
   std::vector<Slice> slices_tmp_;
 
+  // RPC features supported by the remote end of the connection.
+  std::set<RpcFeatureFlag> remote_features_;
+
   // Pool from which CallAwaitingResponse objects are allocated.
   // Also a funny name.
   ObjectPool<CallAwaitingResponse> car_pool_;
   typedef ObjectPool<CallAwaitingResponse>::scoped_ptr scoped_car;
 
-  // SASL client instance used for connection negotiation when Direction == CLIENT.
-  SaslClient sasl_client_;
-
-  // SASL server instance used for connection negotiation when Direction == SERVER.
-  SaslServer sasl_server_;
-
   // Whether we completed connection negotiation.
   bool negotiation_complete_;
 };

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/messenger.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/messenger.cc b/src/kudu/rpc/messenger.cc
index 103f72c..5a5ed7a 100644
--- a/src/kudu/rpc/messenger.cc
+++ b/src/kudu/rpc/messenger.cc
@@ -41,6 +41,7 @@
 #include "kudu/rpc/rpc_service.h"
 #include "kudu/rpc/rpcz_store.h"
 #include "kudu/rpc/sasl_common.h"
+#include "kudu/rpc/server_negotiation.h"
 #include "kudu/rpc/transfer.h"
 #include "kudu/security/ssl_factory.h"
 #include "kudu/util/errno.h"
@@ -56,7 +57,6 @@ using std::string;
 using std::shared_ptr;
 using strings::Substitute;
 
-
 DEFINE_string(rpc_ssl_server_certificate, "", "Path to the SSL certificate to be used for the RPC "
     "layer.");
 DEFINE_string(rpc_ssl_private_key, "",
@@ -130,7 +130,7 @@ MessengerBuilder &MessengerBuilder::set_metric_entity(
 }
 
 Status MessengerBuilder::Build(shared_ptr<Messenger> *msgr) {
-  RETURN_NOT_OK(SaslInit(kSaslAppName)); // Initialize SASL library before we start making requests
+  RETURN_NOT_OK(SaslInit()); // Initialize SASL library before we start making requests
   Messenger* new_msgr(new Messenger(*this));
   Status build_status = new_msgr->Init();
   if (!build_status.ok()) {
@@ -192,7 +192,7 @@ Status Messenger::AddAcceptorPool(const Sockaddr &accept_addr,
   // that everything is set up correctly. This way we'll generate errors on
   // startup rather than later on when we first receive a client connection.
   if (FLAGS_server_require_kerberos) {
-    RETURN_NOT_OK_PREPEND(SaslServer::PreflightCheckGSSAPI(kSaslAppName),
+    RETURN_NOT_OK_PREPEND(ServerNegotiation::PreflightCheckGSSAPI(),
                           "GSSAPI/Kerberos not properly configured");
   }
 
@@ -296,11 +296,11 @@ Reactor* Messenger::RemoteToReactor(const Sockaddr &remote) {
   return reactors_[reactor_idx];
 }
 
-
 Status Messenger::Init() {
   Status status;
   ssl_enabled_ = !FLAGS_rpc_ssl_server_certificate.empty() || !FLAGS_rpc_ssl_private_key.empty()
                    || !FLAGS_rpc_ssl_certificate_authority.empty();
+
   if (ssl_enabled_) {
     ssl_factory_.reset(new SSLFactory());
     RETURN_NOT_OK(ssl_factory_->Init());

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/negotiation-test.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/negotiation-test.cc b/src/kudu/rpc/negotiation-test.cc
index 2bb73e4..984dff1 100644
--- a/src/kudu/rpc/negotiation-test.cc
+++ b/src/kudu/rpc/negotiation-test.cc
@@ -43,6 +43,7 @@
 
 using std::string;
 using std::thread;
+using std::unique_ptr;
 
 // HACK: MIT Kerberos doesn't have any way of determining its version number,
 // but the error messages in krb5-1.10 and earlier are broken due to
@@ -62,35 +63,25 @@ DEFINE_bool(is_test_child, false,
 namespace kudu {
 namespace rpc {
 
-class TestSaslRpc : public RpcTestBase {
+class TestNegotiation : public RpcTestBase {
  public:
   virtual void SetUp() OVERRIDE {
     RpcTestBase::SetUp();
-    ASSERT_OK(SaslInit(kSaslAppName));
+    ASSERT_OK(SaslInit());
   }
 };
 
-// Test basic initialization of the objects.
-TEST_F(TestSaslRpc, TestBasicInit) {
-  SaslServer server(kSaslAppName, nullptr);
-  server.EnablePlain();
-  ASSERT_OK(server.Init(kSaslAppName));
-  SaslClient client(kSaslAppName, nullptr);
-  client.EnablePlain("test", "test");
-  ASSERT_OK(client.Init(kSaslAppName));
-}
-
-// A "Callable" that takes a Socket* param, for use with starting a thread.
-// Can be used for SaslServer or SaslClient threads.
-typedef std::function<void(Socket*)> SocketCallable;
+// A "Callable" that takes a socket for use with starting a thread.
+// Can be used for ServerNegotiation or ClientNegotiation threads.
+typedef std::function<void(unique_ptr<Socket>)> SocketCallable;
 
 // Call Accept() on the socket, then pass the connection to the server runner
 static void RunAcceptingDelegator(Socket* acceptor,
                                   const SocketCallable& server_runner) {
-  Socket conn;
+  unique_ptr<Socket> conn(new Socket());
   Sockaddr remote;
-  CHECK_OK(acceptor->Accept(&conn, &remote, 0));
-  server_runner(&conn);
+  CHECK_OK(acceptor->Accept(conn.get(), &remote, 0));
+  server_runner(std::move(conn));
 }
 
 // Set up a socket and run a SASL negotiation.
@@ -103,45 +94,38 @@ static void RunNegotiationTest(const SocketCallable& server_runner,
   ASSERT_OK(server_sock.GetSocketAddress(&server_bind_addr));
   thread server(RunAcceptingDelegator, &server_sock, server_runner);
 
-  Socket client_sock;
-  CHECK_OK(client_sock.Init(0));
-  ASSERT_OK(client_sock.Connect(server_bind_addr));
-  thread client(client_runner, &client_sock);
+  unique_ptr<Socket> client_sock(new Socket());
+  CHECK_OK(client_sock->Init(0));
+  ASSERT_OK(client_sock->Connect(server_bind_addr));
+  thread client(client_runner, std::move(client_sock));
 
   LOG(INFO) << "Waiting for test threads to terminate...";
   client.join();
   LOG(INFO) << "Client thread terminated.";
 
-  // TODO(todd): if the client fails to negotiate, it doesn't
-  // always result in sending a nice error message to the
-  // other side.
-  client_sock.Close();
-
   server.join();
   LOG(INFO) << "Server thread terminated.";
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static void RunPlainNegotiationServer(Socket* conn) {
-  SaslServer sasl_server(kSaslAppName, conn);
-  CHECK_OK(sasl_server.EnablePlain());
-  CHECK_OK(sasl_server.Init(kSaslAppName));
-  CHECK_OK(sasl_server.Negotiate());
-  CHECK(ContainsKey(sasl_server.client_features(), APPLICATION_FEATURE_FLAGS));
-  CHECK_EQ("my-username", sasl_server.authenticated_user());
+static void RunPlainNegotiationServer(unique_ptr<Socket> socket) {
+  ServerNegotiation server_negotiation(std::move(socket));
+  CHECK_OK(server_negotiation.EnablePlain());
+  CHECK_OK(server_negotiation.Negotiate());
+  CHECK(ContainsKey(server_negotiation.client_features(), APPLICATION_FEATURE_FLAGS));
+  CHECK_EQ("my-username", server_negotiation.authenticated_user());
 }
 
-static void RunPlainNegotiationClient(Socket* conn) {
-  SaslClient sasl_client(kSaslAppName, conn);
-  CHECK_OK(sasl_client.EnablePlain("my-username", "ignored password"));
-  CHECK_OK(sasl_client.Init(kSaslAppName));
-  CHECK_OK(sasl_client.Negotiate());
-  CHECK(ContainsKey(sasl_client.server_features(), APPLICATION_FEATURE_FLAGS));
+static void RunPlainNegotiationClient(unique_ptr<Socket> socket) {
+  ClientNegotiation client_negotiation(std::move(socket));
+  CHECK_OK(client_negotiation.EnablePlain("my-username", "ignored password"));
+  CHECK_OK(client_negotiation.Negotiate());
+  CHECK(ContainsKey(client_negotiation.server_features(), APPLICATION_FEATURE_FLAGS));
 }
 
 // Test SASL negotiation using the PLAIN mechanism over a socket.
-TEST_F(TestSaslRpc, TestPlainNegotiation) {
+TEST_F(TestNegotiation, TestPlainNegotiation) {
   RunNegotiationTest(RunPlainNegotiationServer, RunPlainNegotiationClient);
 }
 
@@ -153,33 +137,29 @@ using CheckerFunction = std::function<void(const Status&, T&)>;
 
 // Run GSSAPI negotiation from the server side. Runs
 // 'post_check' after negotiation to verify the result.
-static void RunGSSAPINegotiationServer(
-    Socket* conn,
-    const CheckerFunction<SaslServer>& post_check) {
-  SaslServer sasl_server(kSaslAppName, conn);
-  sasl_server.set_server_fqdn("127.0.0.1");
-  CHECK_OK(sasl_server.EnableGSSAPI());
-  CHECK_OK(sasl_server.Init(kSaslAppName));
-  post_check(sasl_server.Negotiate(), sasl_server);
+static void RunGSSAPINegotiationServer(unique_ptr<Socket> socket,
+                                       const CheckerFunction<ServerNegotiation>& post_check) {
+  ServerNegotiation server_negotiation(std::move(socket));
+  server_negotiation.set_server_fqdn("127.0.0.1");
+  CHECK_OK(server_negotiation.EnableGSSAPI());
+  post_check(server_negotiation.Negotiate(), server_negotiation);
 }
 
 // Run GSSAPI negotiation from the client side. Runs
 // 'post_check' after negotiation to verify the result.
-static void RunGSSAPINegotiationClient(
-    Socket* conn,
-    const CheckerFunction<SaslClient>& post_check) {
-  SaslClient sasl_client(kSaslAppName, conn);
-  sasl_client.set_server_fqdn("127.0.0.1");
-  CHECK_OK(sasl_client.EnableGSSAPI());
-  CHECK_OK(sasl_client.Init(kSaslAppName));
-  post_check(sasl_client.Negotiate(), sasl_client);
+static void RunGSSAPINegotiationClient(unique_ptr<Socket> conn,
+                                       const CheckerFunction<ClientNegotiation>& post_check) {
+  ClientNegotiation client_negotiation(std::move(conn));
+  client_negotiation.set_server_fqdn("127.0.0.1");
+  CHECK_OK(client_negotiation.EnableGSSAPI());
+  post_check(client_negotiation.Negotiate(), client_negotiation);
 }
 
 // Test configuring a client to allow but not require Kerberos/GSSAPI,
 // and connect to a server which requires Kerberos/GSSAPI.
 //
 // They should negotiate to use Kerberos/GSSAPI.
-TEST_F(TestSaslRpc, TestRestrictiveServer_NonRestrictiveClient) {
+TEST_F(TestNegotiation, TestRestrictiveServer_NonRestrictiveClient) {
   MiniKdc kdc;
   ASSERT_OK(kdc.Start());
 
@@ -196,27 +176,26 @@ TEST_F(TestSaslRpc, TestRestrictiveServer_NonRestrictiveClient) {
   // Authentication should now succeed on both sides.
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   CHECK_OK(s);
                   CHECK_EQ(SaslMechanism::GSSAPI, server.negotiated_mechanism());
                   CHECK_EQ("testuser", server.authenticated_user());
                 }),
-      [](Socket* conn) {
-        SaslClient sasl_client(kSaslAppName, conn);
-        sasl_client.set_server_fqdn("127.0.0.1");
+      [](unique_ptr<Socket> socket) {
+        ClientNegotiation client_negotiation(std::move(socket));
+        client_negotiation.set_server_fqdn("127.0.0.1");
         // The client enables both PLAIN and GSSAPI.
-        CHECK_OK(sasl_client.EnablePlain("foo", "bar"));
-        CHECK_OK(sasl_client.EnableGSSAPI());
-        CHECK_OK(sasl_client.Init(kSaslAppName));
-        CHECK_OK(sasl_client.Negotiate());
-        CHECK_EQ(SaslMechanism::GSSAPI, sasl_client.negotiated_mechanism());
+        CHECK_OK(client_negotiation.EnablePlain("foo", "bar"));
+        CHECK_OK(client_negotiation.EnableGSSAPI());
+        CHECK_OK(client_negotiation.Negotiate());
+        CHECK_EQ(SaslMechanism::GSSAPI, client_negotiation.negotiated_mechanism());
       });
 }
 
 // Test configuring a client to only support PLAIN, and a server which
 // only supports GSSAPI. This would happen, for example, if an old Kudu
 // client tries to talk to a secure-only cluster.
-TEST_F(TestSaslRpc, TestNoMatchingMechanisms) {
+TEST_F(TestNegotiation, TestNoMatchingMechanisms) {
   MiniKdc kdc;
   ASSERT_OK(kdc.Start());
 
@@ -227,7 +206,7 @@ TEST_F(TestSaslRpc, TestNoMatchingMechanisms) {
 
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   // The client fails to find a matching mechanism and
                   // doesn't send any failure message to the server.
                   // Instead, it just disconnects.
@@ -235,19 +214,18 @@ TEST_F(TestSaslRpc, TestNoMatchingMechanisms) {
                   // TODO(todd): this could produce a better message!
                   ASSERT_STR_CONTAINS(s.ToString(), "got EOF from remote");
                 }),
-      [](Socket* conn) {
-        SaslClient sasl_client(kSaslAppName, conn);
-        sasl_client.set_server_fqdn("127.0.0.1");
-        // The client enables both PLAIN and GSSAPI.
-        CHECK_OK(sasl_client.EnablePlain("foo", "bar"));
-        CHECK_OK(sasl_client.Init(kSaslAppName));
-        Status s = sasl_client.Negotiate();
+      [](unique_ptr<Socket> socket) {
+        ClientNegotiation client_negotiation(std::move(socket));
+        client_negotiation.set_server_fqdn("127.0.0.1");
+        // The client enables only PLAIN, so no mechanism matches the GSSAPI-only server.
+        CHECK_OK(client_negotiation.EnablePlain("foo", "bar"));
+        Status s = client_negotiation.Negotiate();
         ASSERT_STR_CONTAINS(s.ToString(), "client was missing the required SASL module");
       });
 }
 
 // Test SASL negotiation using the GSSAPI (kerberos) mechanism over a socket.
-TEST_F(TestSaslRpc, TestGSSAPINegotiation) {
+TEST_F(TestNegotiation, TestGSSAPINegotiation) {
   MiniKdc kdc;
   ASSERT_OK(kdc.Start());
 
@@ -264,13 +242,13 @@ TEST_F(TestSaslRpc, TestGSSAPINegotiation) {
   // Authentication should succeed on both sides.
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   CHECK_OK(s);
                   CHECK_EQ(SaslMechanism::GSSAPI, server.negotiated_mechanism());
                   CHECK_EQ("testuser", server.authenticated_user());
                 }),
       std::bind(RunGSSAPINegotiationClient, std::placeholders::_1,
-                [](const Status& s, SaslClient& client) {
+                [](const Status& s, ClientNegotiation& client) {
                   CHECK_OK(s);
                   CHECK_EQ(SaslMechanism::GSSAPI, client.negotiated_mechanism());
                 }));
@@ -281,7 +259,7 @@ TEST_F(TestSaslRpc, TestGSSAPINegotiation) {
 // This test is ignored on macOS because the system Kerberos implementation
-// (Heimdal) caches the non-existence of client credentials, which causes futher
+// (Heimdal) caches the non-existence of client credentials, which causes further
 // tests to fail.
-TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
+TEST_F(TestNegotiation, TestGSSAPIInvalidNegotiation) {
   MiniKdc kdc;
   ASSERT_OK(kdc.Start());
 
@@ -289,7 +267,7 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
   // sides.
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   // The client notices there are no credentials and
                   // doesn't send any failure message to the server.
                   // Instead, it just disconnects.
@@ -299,7 +277,7 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
                   CHECK(s.IsNetworkError());
                 }),
       std::bind(RunGSSAPINegotiationClient, std::placeholders::_1,
-                [](const Status& s, SaslClient& client) {
+                [](const Status& s, ClientNegotiation& client) {
                   CHECK(s.IsNotAuthorized());
 #ifndef KRB5_VERSION_LE_1_10
-                  CHECK_GT(s.ToString().find("No Kerberos credentials available"), 0);
+                  ASSERT_STR_CONTAINS(s.ToString(), "No Kerberos credentials available");
@@ -316,14 +294,14 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
   // sides.
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   // The client notices there are no credentials and
                   // doesn't send any failure message to the server.
                   // Instead, it just disconnects.
                   CHECK(s.IsNetworkError());
                 }),
       std::bind(RunGSSAPINegotiationClient, std::placeholders::_1,
-                [](const Status& s, SaslClient& client) {
+                [](const Status& s, ClientNegotiation& client) {
                   CHECK(s.IsNotAuthorized());
 #ifndef KRB5_VERSION_LE_1_10
                   CHECK_EQ(s.message().ToString(), "No Kerberos credentials available");
@@ -343,7 +321,7 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
 
   RunNegotiationTest(
       std::bind(RunGSSAPINegotiationServer, std::placeholders::_1,
-                [](const Status& s, SaslServer& server) {
+                [](const Status& s, ServerNegotiation& server) {
                   CHECK(s.IsNotAuthorized());
 #ifndef KRB5_VERSION_LE_1_10
                   ASSERT_STR_CONTAINS(s.ToString(),
@@ -351,7 +329,7 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
 #endif
                 }),
       std::bind(RunGSSAPINegotiationClient, std::placeholders::_1,
-                [](const Status& s, SaslClient& client) {
+                [](const Status& s, ClientNegotiation& client) {
                   CHECK(s.IsNotAuthorized());
 #ifndef KRB5_VERSION_LE_1_10
                   ASSERT_STR_CONTAINS(s.ToString(),
@@ -368,9 +346,9 @@ TEST_F(TestSaslRpc, TestGSSAPIInvalidNegotiation) {
 // This is ignored on macOS because the system Kerberos implementation does not
 // fail the preflight check when the keytab is inaccessible, probably because
 // the preflight check passes a 0-length token.
-TEST_F(TestSaslRpc, TestPreflight) {
+TEST_F(TestNegotiation, TestPreflight) {
   // Try pre-flight with no keytab.
-  Status s = SaslServer::PreflightCheckGSSAPI(kSaslAppName);
+  Status s = ServerNegotiation::PreflightCheckGSSAPI();
   ASSERT_FALSE(s.ok());
 #ifndef KRB5_VERSION_LE_1_10
   ASSERT_STR_MATCHES(s.ToString(), "Key table file.*not found");
@@ -383,11 +361,11 @@ TEST_F(TestSaslRpc, TestPreflight) {
   ASSERT_OK(kdc.CreateServiceKeytab("kudu/127.0.0.1", &kt_path));
   CHECK_ERR(setenv("KRB5_KTNAME", kt_path.c_str(), 1 /*replace*/));
 
-  ASSERT_OK(SaslServer::PreflightCheckGSSAPI(kSaslAppName));
+  ASSERT_OK(ServerNegotiation::PreflightCheckGSSAPI());
 
   // Try with an inaccessible keytab.
   CHECK_ERR(chmod(kt_path.c_str(), 0000));
-  s = SaslServer::PreflightCheckGSSAPI(kSaslAppName);
+  s = ServerNegotiation::PreflightCheckGSSAPI();
   ASSERT_FALSE(s.ok());
 #ifndef KRB5_VERSION_LE_1_10
   ASSERT_STR_MATCHES(s.ToString(), "error accessing keytab: Permission denied");
@@ -397,7 +375,7 @@ TEST_F(TestSaslRpc, TestPreflight) {
   // Try with a keytab that has the wrong credentials.
   ASSERT_OK(kdc.CreateServiceKeytab("wrong-service/127.0.0.1", &kt_path));
   CHECK_ERR(setenv("KRB5_KTNAME", kt_path.c_str(), 1 /*replace*/));
-  s = SaslServer::PreflightCheckGSSAPI(kSaslAppName);
+  s = ServerNegotiation::PreflightCheckGSSAPI();
   ASSERT_FALSE(s.ok());
 #ifndef KRB5_VERSION_LE_1_10
   ASSERT_STR_MATCHES(s.ToString(), "No key table entry found matching kudu/.*");
@@ -407,55 +385,51 @@ TEST_F(TestSaslRpc, TestPreflight) {
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static void RunTimeoutExpectingServer(Socket* conn) {
-  SaslServer sasl_server(kSaslAppName, conn);
-  CHECK_OK(sasl_server.EnablePlain());
-  CHECK_OK(sasl_server.Init(kSaslAppName));
-  Status s = sasl_server.Negotiate();
+static void RunTimeoutExpectingServer(unique_ptr<Socket> socket) {
+  ServerNegotiation server_negotiation(std::move(socket));
+  CHECK_OK(server_negotiation.EnablePlain());
+  Status s = server_negotiation.Negotiate();
   ASSERT_TRUE(s.IsNetworkError()) << "Expected client to time out and close the connection. Got: "
-      << s.ToString();
+                                  << s.ToString();
 }
 
-static void RunTimeoutNegotiationClient(Socket* sock) {
-  SaslClient sasl_client(kSaslAppName, sock);
-  CHECK_OK(sasl_client.EnablePlain("test", "test"));
-  CHECK_OK(sasl_client.Init(kSaslAppName));
+static void RunTimeoutNegotiationClient(unique_ptr<Socket> sock) {
+  ClientNegotiation client_negotiation(std::move(sock));
+  CHECK_OK(client_negotiation.EnablePlain("test", "test"));
   MonoTime deadline = MonoTime::Now() - MonoDelta::FromMilliseconds(100L);
-  sasl_client.set_deadline(deadline);
-  Status s = sasl_client.Negotiate();
+  client_negotiation.set_deadline(deadline);
+  Status s = client_negotiation.Negotiate();
   ASSERT_TRUE(s.IsTimedOut()) << "Expected timeout! Got: " << s.ToString();
-  CHECK_OK(sock->Shutdown(true, true));
+  CHECK_OK(client_negotiation.socket()->Shutdown(true, true));
 }
 
 // Ensure that the client times out.
-TEST_F(TestSaslRpc, TestClientTimeout) {
+TEST_F(TestNegotiation, TestClientTimeout) {
   RunNegotiationTest(RunTimeoutExpectingServer, RunTimeoutNegotiationClient);
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static void RunTimeoutNegotiationServer(Socket* sock) {
-  SaslServer sasl_server(kSaslAppName, sock);
-  CHECK_OK(sasl_server.EnablePlain());
-  CHECK_OK(sasl_server.Init(kSaslAppName));
+static void RunTimeoutNegotiationServer(unique_ptr<Socket> socket) {
+  ServerNegotiation server_negotiation(std::move(socket));
+  CHECK_OK(server_negotiation.EnablePlain());
   MonoTime deadline = MonoTime::Now() - MonoDelta::FromMilliseconds(100L);
-  sasl_server.set_deadline(deadline);
-  Status s = sasl_server.Negotiate();
+  server_negotiation.set_deadline(deadline);
+  Status s = server_negotiation.Negotiate();
   ASSERT_TRUE(s.IsTimedOut()) << "Expected timeout! Got: " << s.ToString();
-  CHECK_OK(sock->Close());
+  CHECK_OK(server_negotiation.socket()->Close());
 }
 
-static void RunTimeoutExpectingClient(Socket* conn) {
-  SaslClient sasl_client(kSaslAppName, conn);
-  CHECK_OK(sasl_client.EnablePlain("test", "test"));
-  CHECK_OK(sasl_client.Init(kSaslAppName));
-  Status s = sasl_client.Negotiate();
+static void RunTimeoutExpectingClient(unique_ptr<Socket> socket) {
+  ClientNegotiation client_negotiation(std::move(socket));
+  CHECK_OK(client_negotiation.EnablePlain("test", "test"));
+  Status s = client_negotiation.Negotiate();
   ASSERT_TRUE(s.IsNetworkError()) << "Expected server to time out and close the connection. Got: "
       << s.ToString();
 }
 
 // Ensure that the server times out.
-TEST_F(TestSaslRpc, TestServerTimeout) {
+TEST_F(TestNegotiation, TestServerTimeout) {
   RunNegotiationTest(RunTimeoutNegotiationServer, RunTimeoutExpectingClient);
 }
 
@@ -498,7 +472,7 @@ class TestDisableInit : public KuduTest {
 TEST_F(TestDisableInit, TestDisableSasl_NotInitialized) {
   DoTest([]() {
       CHECK_OK(DisableSaslInitialization());
-      Status s = SaslInit("kudu");
+      Status s = SaslInit();
       ASSERT_STR_CONTAINS(s.ToString(), "was disabled, but SASL was not externally initialized");
     });
 }
@@ -509,7 +483,7 @@ TEST_F(TestDisableInit, TestDisableSasl_Good) {
       rpc::internal::SaslSetMutex();
       sasl_client_init(NULL);
       CHECK_OK(DisableSaslInitialization());
-      ASSERT_OK(SaslInit("kudu"));
+      ASSERT_OK(SaslInit());
     });
 }
 
@@ -520,7 +494,7 @@ TEST_F(TestDisableInit, TestMultipleSaslInit) {
   DoTest([]() {
       rpc::internal::SaslSetMutex();
       sasl_client_init(NULL);
-      ASSERT_OK(SaslInit("kudu"));
+      ASSERT_OK(SaslInit());
     }, &stderr);
   // If we are the parent, we should see the warning from the child that it automatically
   // skipped initialization because it detected that it was already initialized.
@@ -538,7 +512,7 @@ TEST_F(TestDisableInit, TestDisableSasl_NoMutexImpl) {
   DoTest([]() {
       sasl_client_init(NULL);
       CHECK_OK(DisableSaslInitialization());
-      ASSERT_OK(SaslInit("kudu"));
+      ASSERT_OK(SaslInit());
     }, &stderr);
   // If we are the parent, we should see the warning from the child.
   if (!FLAGS_is_test_child) {
@@ -552,7 +526,7 @@ TEST_F(TestDisableInit, TestMultipleSaslInit_NoMutexImpl) {
   string stderr;
   DoTest([]() {
       sasl_client_init(NULL);
-      ASSERT_OK(SaslInit("kudu"));
+      ASSERT_OK(SaslInit());
     }, &stderr);
   // If we are the parent, we should see the warning from the child that it automatically
   // skipped initialization because it detected that it was already initialized.

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/negotiation.cc
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/negotiation.cc b/src/kudu/rpc/negotiation.cc
index 859016c..4b387ca 100644
--- a/src/kudu/rpc/negotiation.cc
+++ b/src/kudu/rpc/negotiation.cc
@@ -30,12 +30,15 @@
 #include "kudu/rpc/blocking_ops.h"
 #include "kudu/rpc/client_negotiation.h"
 #include "kudu/rpc/connection.h"
+#include "kudu/rpc/messenger.h"
 #include "kudu/rpc/reactor.h"
 #include "kudu/rpc/rpc_header.pb.h"
 #include "kudu/rpc/sasl_common.h"
 #include "kudu/rpc/server_negotiation.h"
+#include "kudu/security/ssl_socket.h"
 #include "kudu/util/errno.h"
 #include "kudu/util/flag_tags.h"
+#include "kudu/util/logging.h"
 #include "kudu/util/status.h"
 #include "kudu/util/trace.h"
 
@@ -50,62 +53,17 @@ DEFINE_int32(rpc_negotiation_inject_delay_ms, 0,
              "the RPC negotiation process on the server side.");
 TAG_FLAG(rpc_negotiation_inject_delay_ms, unsafe);
 
-namespace kudu {
-namespace rpc {
+DECLARE_bool(server_require_kerberos);
 
-using std::shared_ptr;
 using strings::Substitute;
 
-// Client: Send ConnectionContextPB message based on information stored in the Connection object.
-static Status SendConnectionContext(Connection* conn, const MonoTime& deadline) {
-  TRACE("Sending connection context");
-  RequestHeader header;
-  header.set_call_id(kConnectionContextCallId);
-
-  ConnectionContextPB conn_context;
-  // This field is deprecated but used by servers <Kudu 1.1. Newer server versions ignore
-  // this and use the SASL-provided username instead.
-  conn_context.mutable_deprecated_user_info()->set_real_user(
-      conn->user_credentials().real_user());
-  return SendFramedMessageBlocking(conn->socket(), header, conn_context, deadline);
-}
-
-// Server: Receive ConnectionContextPB message and update the corresponding fields in the
-// associated Connection object. Perform validation against SASL-negotiated information
-// as needed.
-static Status RecvConnectionContext(Connection* conn, const MonoTime& deadline) {
-  TRACE("Waiting for connection context");
-  faststring recv_buf(1024); // Should be plenty for a ConnectionContextPB message.
-  RequestHeader header;
-  Slice param_buf;
-  RETURN_NOT_OK(ReceiveFramedMessageBlocking(conn->socket(), &recv_buf,
-                                             &header, &param_buf, deadline));
-  DCHECK(header.IsInitialized());
-
-  if (header.call_id() != kConnectionContextCallId) {
-    return Status::IllegalState("Expected ConnectionContext callid, received",
-        Substitute("$0", header.call_id()));
-  }
-
-  ConnectionContextPB conn_context;
-  if (!conn_context.ParseFromArray(param_buf.data(), param_buf.size())) {
-    return Status::InvalidArgument("Invalid ConnectionContextPB message, missing fields",
-                                   conn_context.InitializationErrorString());
-  }
-
-  if (conn->sasl_server().authenticated_user().empty()) {
-    return Status::NotAuthorized("No user was authenticated");
-  }
-
-  conn->mutable_user_credentials()->set_real_user(conn->sasl_server().authenticated_user());
-
-  return Status::OK();
-}
+namespace kudu {
+namespace rpc {
 
 // Wait for the client connection to be established and become ready for writing.
-static Status WaitForClientConnect(Connection* conn, const MonoTime& deadline) {
+static Status WaitForClientConnect(Socket* socket, const MonoTime& deadline) {
   TRACE("Waiting for socket to connect");
-  int fd = conn->socket()->GetFd();
+  int fd = socket->GetFd();
   struct pollfd poll_fd;
   poll_fd.fd = fd;
   poll_fd.events = POLLOUT;
@@ -164,57 +122,99 @@ static Status WaitForClientConnect(Connection* conn, const MonoTime& deadline) {
 }
 
 // Disable / reset socket timeouts.
-static Status DisableSocketTimeouts(Connection* conn) {
-  RETURN_NOT_OK(conn->socket()->SetSendTimeout(MonoDelta::FromNanoseconds(0L)));
-  RETURN_NOT_OK(conn->socket()->SetRecvTimeout(MonoDelta::FromNanoseconds(0L)));
+static Status DisableSocketTimeouts(Socket* socket) {
+  RETURN_NOT_OK(socket->SetSendTimeout(MonoDelta::FromNanoseconds(0L)));
+  RETURN_NOT_OK(socket->SetRecvTimeout(MonoDelta::FromNanoseconds(0L)));
   return Status::OK();
 }
 
 // Perform client negotiation. We don't LOG() anything, we leave that to our caller.
-static Status DoClientNegotiation(Connection* conn,
-                                  const MonoTime& deadline) {
-  // The SASL initialization on the client side can be relatively heavy-weight
-  // (it may result in DNS queries in the case of GSSAPI).
-  // So, we do it while the connect() is still in progress to reduce latency.
-  //
-  // TODO(todd): we should consider doing this even before connecting, since as soon
-  // as we connect, we are tying up a negotiation thread on the server side.
-
-  RETURN_NOT_OK(conn->InitSaslClient());
-
-  RETURN_NOT_OK(WaitForClientConnect(conn, deadline));
-  RETURN_NOT_OK(conn->SetNonBlocking(false));
-  RETURN_NOT_OK(conn->InitSSLIfNecessary());
-  conn->sasl_client().set_deadline(deadline);
-  RETURN_NOT_OK(conn->sasl_client().Negotiate());
-  RETURN_NOT_OK(SendConnectionContext(conn, deadline));
-  RETURN_NOT_OK(DisableSocketTimeouts(conn));
+static Status DoClientNegotiation(Connection* conn, MonoTime deadline) {
+  ClientNegotiation client_negotiation(conn->release_socket());
+
+  // Note that the fqdn is an IP address here: we've already lost whatever DNS
+  // name the client was attempting to use. Unless krb5 is configured with 'rdns
+  // = false', it will automatically take care of reversing this address to its
+  // canonical hostname to determine the expected server principal.
+  client_negotiation.set_server_fqdn(conn->remote().host());
+
+  Status s = client_negotiation.EnableGSSAPI();
+  if (!s.ok()) {
+    // If we can't enable GSSAPI, it's likely the client is just missing the
+    // appropriate SASL plugin. We don't want to require it to be installed
+    // if the user doesn't care about connecting to servers using Kerberos
+    // authentication. So, we'll just VLOG this here. If we try to connect
+    // to a server which requires Kerberos, we'll get a negotiation error
+    // at that point.
+    if (VLOG_IS_ON(1)) {
+      KLOG_FIRST_N(INFO, 1) << "Couldn't enable GSSAPI (Kerberos) SASL plugin: "
+                            << s.message().ToString()
+                            << ". This process will be unable to connect to "
+                            << "servers requiring Kerberos authentication.";
+    }
+  }
+
+  RETURN_NOT_OK(client_negotiation.EnablePlain(conn->user_credentials().real_user(), ""));
+  client_negotiation.set_deadline(deadline);
+
+  RETURN_NOT_OK(WaitForClientConnect(client_negotiation.socket(), deadline));
+  RETURN_NOT_OK(client_negotiation.socket()->SetNonBlocking(false));
+
+  // Do SSL handshake.
+  // TODO(dan): This is a messy place to do this.
+  if (conn->reactor_thread()->reactor()->messenger()->ssl_enabled()) {
+    SSLSocket* ssl_socket = down_cast<SSLSocket*>(client_negotiation.socket());
+    RETURN_NOT_OK(ssl_socket->DoHandshake());
+  }
+
+  RETURN_NOT_OK(client_negotiation.Negotiate());
+  RETURN_NOT_OK(DisableSocketTimeouts(client_negotiation.socket()));
+
+  // Transfer the negotiated socket and state back to the connection.
+  conn->adopt_socket(client_negotiation.release_socket());
+  conn->set_remote_features(client_negotiation.take_server_features());
 
   return Status::OK();
 }
 
 // Perform server negotiation. We don't LOG() anything, we leave that to our caller.
-static Status DoServerNegotiation(Connection* conn,
-                                  const MonoTime& deadline) {
+static Status DoServerNegotiation(Connection* conn, MonoTime deadline) {
   if (FLAGS_rpc_negotiation_inject_delay_ms > 0) {
     LOG(WARNING) << "Injecting " << FLAGS_rpc_negotiation_inject_delay_ms
                  << "ms delay in negotiation";
     SleepFor(MonoDelta::FromMilliseconds(FLAGS_rpc_negotiation_inject_delay_ms));
   }
-  RETURN_NOT_OK(conn->SetNonBlocking(false));
-  RETURN_NOT_OK(conn->InitSSLIfNecessary());
-  RETURN_NOT_OK(conn->InitSaslServer());
-  conn->sasl_server().set_deadline(deadline);
-  RETURN_NOT_OK(conn->sasl_server().Negotiate());
-  RETURN_NOT_OK(RecvConnectionContext(conn, deadline));
-  RETURN_NOT_OK(DisableSocketTimeouts(conn));
+
+  // Create a new ServerNegotiation to handle the synchronous negotiation.
+  ServerNegotiation server_negotiation(conn->release_socket());
+  if (FLAGS_server_require_kerberos) {
+    RETURN_NOT_OK(server_negotiation.EnableGSSAPI());
+  } else {
+    RETURN_NOT_OK(server_negotiation.EnablePlain());
+  }
+
+  server_negotiation.set_deadline(deadline);
+  RETURN_NOT_OK(server_negotiation.socket()->SetNonBlocking(false));
+
+  // Do SSL handshake.
+  // TODO(dan): This is a messy place to do this.
+  if (conn->reactor_thread()->reactor()->messenger()->ssl_enabled()) {
+    SSLSocket* ssl_socket = down_cast<SSLSocket*>(server_negotiation.socket());
+    RETURN_NOT_OK(ssl_socket->DoHandshake());
+  }
+
+  RETURN_NOT_OK(server_negotiation.Negotiate());
+  RETURN_NOT_OK(DisableSocketTimeouts(server_negotiation.socket()));
+
+  // Transfer the negotiated socket and state back to the connection.
+  conn->adopt_socket(server_negotiation.release_socket());
+  conn->set_remote_features(server_negotiation.take_client_features());
+  conn->mutable_user_credentials()->set_real_user(server_negotiation.authenticated_user());
 
   return Status::OK();
 }
 
-// Perform negotiation for a connection (either server or client)
-void Negotiation::RunNegotiation(const scoped_refptr<Connection>& conn,
-                                 const MonoTime& deadline) {
+void Negotiation::RunNegotiation(const scoped_refptr<Connection>& conn, MonoTime deadline) {
   Status s;
   if (conn->direction() == Connection::SERVER) {
     s = DoServerNegotiation(conn.get(), deadline);

http://git-wip-us.apache.org/repos/asf/kudu/blob/dc852535/src/kudu/rpc/negotiation.h
----------------------------------------------------------------------
diff --git a/src/kudu/rpc/negotiation.h b/src/kudu/rpc/negotiation.h
index 0562555..64d009b 100644
--- a/src/kudu/rpc/negotiation.h
+++ b/src/kudu/rpc/negotiation.h
@@ -21,14 +21,14 @@
 #include "kudu/util/monotime.h"
 
 namespace kudu {
 namespace rpc {
 
 class Connection;
 
 class Negotiation {
  public:
-  static void RunNegotiation(const scoped_refptr<Connection>& conn,
-                             const MonoTime &deadline);
+  // Perform negotiation for a connection (either server or client).
+  static void RunNegotiation(const scoped_refptr<Connection>& conn, MonoTime deadline);
  private:
   DISALLOW_IMPLICIT_CONSTRUCTORS(Negotiation);
 };