You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@impala.apache.org by ta...@apache.org on 2018/07/13 06:03:48 UTC
[36/51] [abbrv] impala git commit: IMPALA-7006: Add KRPC folders from
kudu@334ecafd
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/serialization.h
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/serialization.h b/be/src/kudu/rpc/serialization.h
new file mode 100644
index 0000000..8406a1f
--- /dev/null
+++ b/be/src/kudu/rpc/serialization.h
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef KUDU_RPC_SERIALIZATION_H
+#define KUDU_RPC_SERIALIZATION_H
+
+#include <cstdint>
+#include <cstring>
+
+namespace google {
+namespace protobuf {
+class MessageLite;
+} // namespace protobuf
+} // namespace google
+
+namespace kudu {
+
+class Status;
+class faststring;
+class Slice;
+
+namespace rpc {
+namespace serialization {
+
+// Serialize the request param into a buffer which is allocated by this function.
+// By default the size is computed by calling MessageLite::ByteSize(); see
+// 'use_cached_size' below to use MessageLite::GetCachedSize() instead.
+// In : 'message' Protobuf Message to serialize
+// 'additional_size' Optional argument which increases the recorded size
+// within param_buf. This argument is necessary if there will be
+// additional sidecars appended onto the message (that aren't part of
+// the protobuf itself).
+// 'use_cached_size' Additional optional argument whether to use the cached
+// or explicit byte size by calling MessageLite::GetCachedSize() or
+// MessageLite::ByteSize(), respectively.
+// Out: The faststring 'param_buf' to be populated with the serialized bytes.
+// The faststring's length is only determined by the amount that
+// needs to be serialized for the protobuf (i.e., no additional space
+// is reserved for 'additional_size', which only affects the
+// size indicator prefix in 'param_buf').
+void SerializeMessage(const google::protobuf::MessageLite& message,
+ faststring* param_buf, int additional_size = 0,
+ bool use_cached_size = false);
+
+// Serialize the request or response header into a buffer which is allocated
+// by this function.
+// Includes leading 32-bit length of the buffer.
+// In: Protobuf Header to serialize,
+// Length of the message param following this header in the frame.
+// Out: faststring to be populated with the serialized bytes.
+void SerializeHeader(const google::protobuf::MessageLite& header,
+ size_t param_len,
+ faststring* header_buf);
+
+// Deserialize the request.
+// In: data buffer Slice.
+// Out: parsed_header PB initialized,
+// parsed_main_message pointing to offset in original buffer containing
+// the main payload.
+Status ParseMessage(const Slice& buf,
+ google::protobuf::MessageLite* parsed_header,
+ Slice* parsed_main_message);
+
+// Serialize the RPC connection header (magic number + flags).
+// buf must have 7 bytes available (kMagicNumberLength + kHeaderFlagsLength).
+void SerializeConnHeader(uint8_t* buf);
+
+// Validate the entire rpc header (magic number + flags).
+Status ValidateConnHeader(const Slice& slice);
+
+
+} // namespace serialization
+} // namespace rpc
+} // namespace kudu
+#endif // KUDU_RPC_SERIALIZATION_H
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/server_negotiation.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/server_negotiation.cc b/be/src/kudu/rpc/server_negotiation.cc
new file mode 100644
index 0000000..612701f
--- /dev/null
+++ b/be/src/kudu/rpc/server_negotiation.cc
@@ -0,0 +1,989 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "kudu/rpc/server_negotiation.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <mutex>
+#include <ostream>
+#include <set>
+#include <string>
+
+#include <boost/optional/optional.hpp>
+#include <gflags/gflags.h>
+#include <gflags/gflags_declare.h>
+#include <glog/logging.h>
+#include <sasl/sasl.h>
+
+#include "kudu/gutil/macros.h"
+#include "kudu/gutil/map-util.h"
+#include "kudu/gutil/strings/split.h"
+#include "kudu/gutil/strings/stringpiece.h"
+#include "kudu/gutil/strings/substitute.h"
+#include "kudu/rpc/blocking_ops.h"
+#include "kudu/rpc/constants.h"
+#include "kudu/rpc/messenger.h"
+#include "kudu/rpc/serialization.h"
+#include "kudu/security/cert.h"
+#include "kudu/security/crypto.h"
+#include "kudu/security/init.h"
+#include "kudu/security/tls_context.h"
+#include "kudu/security/tls_handshake.h"
+#include "kudu/security/token.pb.h"
+#include "kudu/security/token_verifier.h"
+#include "kudu/util/faststring.h"
+#include "kudu/util/fault_injection.h"
+#include "kudu/util/flag_tags.h"
+#include "kudu/util/logging.h"
+#include "kudu/util/net/net_util.h"
+#include "kudu/util/net/sockaddr.h"
+#include "kudu/util/net/socket.h"
+#include "kudu/util/slice.h"
+#include "kudu/util/trace.h"
+
+using std::set;
+using std::string;
+using std::unique_ptr;
+using std::vector;
+
+// Fault injection flag, read by ServerNegotiation::AuthenticateByToken() (test-only).
+DEFINE_double(rpc_inject_invalid_authn_token_ratio, 0,
+ "If set higher than 0, AuthenticateByToken() randomly injects "
+ "errors replying with FATAL_INVALID_AUTHENTICATION_TOKEN code. "
+ "The flag's value corresponds to the probability of the fault "
+ "injection event. Used for only for tests.");
+TAG_FLAG(rpc_inject_invalid_authn_token_ratio, runtime);
+TAG_FLAG(rpc_inject_invalid_authn_token_ratio, unsafe);
+
+DECLARE_bool(rpc_encrypt_loopback_connections);  // Defined elsewhere; read in HandleNegotiate().
+
+DEFINE_string(trusted_subnets,  // Every value is checked by ValidateTrustedSubnets().
+ "127.0.0.0/8,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,169.254.0.0/16",
+ "A trusted subnet whitelist. If set explicitly, all unauthenticated "
+ "or unencrypted connections are prohibited except the ones from the "
+ "specified address blocks. Otherwise, private network (127.0.0.0/8, etc.) "
+ "and local subnets of all local network interfaces will be used. Set it "
+ "to '0.0.0.0/0' to allow unauthenticated/unencrypted connections from all "
+ "remote IP addresses. However, if network access is not otherwise restricted "
+ "by a firewall, malicious users may be able to gain unauthorized access.");
+TAG_FLAG(trusted_subnets, advanced);
+TAG_FLAG(trusted_subnets, evolving);
+
+// gflags validator for --trusted_subnets. An empty value is accepted;
+// otherwise every non-empty, comma-separated entry must parse as a CIDR
+// block. The first entry that fails to parse is logged and the value is
+// rejected.
+static bool ValidateTrustedSubnets(const char* /*flagname*/, const string& value) {
+  if (!value.empty()) {
+    for (const auto& cidr : strings::Split(value, ",", strings::SkipEmpty())) {
+      kudu::Network parsed;
+      const kudu::Status parse_status = parsed.ParseCIDRString(cidr.ToString());
+      if (parse_status.ok()) {
+        continue;
+      }
+      LOG(ERROR) << "Invalid subnet address: " << cidr
+                 << ". Subnet must be specified in CIDR notation.";
+      return false;
+    }
+  }
+  return true;
+}
+
+DEFINE_validator(trusted_subnets, &ValidateTrustedSubnets);
+
+namespace kudu {
+namespace rpc {
+
+namespace {
+vector<Network>* g_trusted_subnets = nullptr;  // File-local, starts null. NOTE(review): presumably populated from --trusted_subnets by code outside this view -- confirm.
+} // anonymous namespace
+
+static int ServerNegotiationGetoptCb(ServerNegotiation* server_negotiation,  // Free-function trampoline registered for SASL_CB_GETOPT by the constructor.
+ const char* plugin_name,
+ const char* option,
+ const char** result,
+ unsigned* len) {
+ return server_negotiation->GetOptionCb(plugin_name, option, result, len);  // Forward to the member; the context pointer is the 'this' passed at registration.
+}
+
+static int ServerNegotiationPlainAuthCb(sasl_conn_t* conn,  // Free-function trampoline registered for SASL_CB_SERVER_USERDB_CHECKPASS by the constructor.
+ ServerNegotiation* server_negotiation,
+ const char* user,
+ const char* pass,
+ unsigned passlen,
+ struct propctx* propctx) {
+ return server_negotiation->PlainAuthCb(conn, user, pass, passlen, propctx);  // Forward to the member function.
+}
+
+// Builds a server-side negotiation object that takes ownership of 'socket'.
+// 'tls_context' and 'token_verifier' are stored as given; they may be null
+// when the object is only used by PreflightCheckGSSAPI(). The deadline starts
+// out unbounded and no authentication type or SASL mechanism is selected yet.
+// The SASL callback table is populated here and terminated with
+// SASL_CB_LIST_END.
+ServerNegotiation::ServerNegotiation(unique_ptr<Socket> socket,
+                                     const security::TlsContext* tls_context,
+                                     const security::TokenVerifier* token_verifier,
+                                     RpcEncryption encryption,
+                                     std::string sasl_proto_name)
+    : socket_(std::move(socket)),
+      helper_(SaslHelper::SERVER),
+      tls_context_(tls_context),
+      encryption_(encryption),
+      tls_negotiated_(false),
+      token_verifier_(token_verifier),
+      negotiated_authn_(AuthenticationType::INVALID),
+      negotiated_mech_(SaslMechanism::INVALID),
+      sasl_proto_name_(std::move(sasl_proto_name)),
+      deadline_(MonoTime::Max()) {
+  // The callback table stores every callback as an untyped 'int (*)()', so the
+  // strongly typed trampolines have to be cast before registration.
+  using UntypedSaslCb = int (*)();
+  const auto getopt_cb = reinterpret_cast<UntypedSaslCb>(&ServerNegotiationGetoptCb);
+  const auto plain_auth_cb = reinterpret_cast<UntypedSaslCb>(&ServerNegotiationPlainAuthCb);
+
+  callbacks_.push_back(SaslBuildCallback(SASL_CB_GETOPT, getopt_cb, this));
+  callbacks_.push_back(
+      SaslBuildCallback(SASL_CB_SERVER_USERDB_CHECKPASS, plain_auth_cb, this));
+  callbacks_.push_back(SaslBuildCallback(SASL_CB_LIST_END, nullptr, nullptr));
+}
+
+Status ServerNegotiation::EnablePlain() {  // Delegates to helper_.EnablePlain() and returns its status.
+ return helper_.EnablePlain();
+}
+
+Status ServerNegotiation::EnableGSSAPI() {  // Delegates to helper_.EnableGSSAPI() and returns its status.
+ return helper_.EnableGSSAPI();
+}
+
+SaslMechanism::Type ServerNegotiation::negotiated_mechanism() const {  // SaslMechanism::INVALID until HandleSaslInitiate() records the client's choice.
+ return negotiated_mech_;
+}
+
+void ServerNegotiation::set_server_fqdn(const string& domain_name) {  // Stored on helper_; InitSaslServer() later passes helper_.server_fqdn() to sasl_server_new().
+ helper_.set_server_fqdn(domain_name);
+}
+
+void ServerNegotiation::set_deadline(const MonoTime& deadline) {  // Deadline handed to every blocking send/receive; defaults to MonoTime::Max().
+ deadline_ = deadline;
+}
+
+Status ServerNegotiation::Negotiate() {  // Runs the full server-side negotiation: connection header, NEGOTIATE, optional TLS, authentication, connection context.
+ TRACE("Beginning negotiation");
+
+ // Wait until starting negotiation to check that the socket, tls_context, and
+ // token_verifier are not null, since they do not need to be set for
+ // PreflightCheckGSSAPI.
+ DCHECK(socket_);
+ DCHECK(tls_context_);
+ DCHECK(token_verifier_);
+
+ // Ensure we can use blocking calls on the socket during negotiation.
+ RETURN_NOT_OK(CheckInBlockingMode(socket_.get()));
+
+ faststring recv_buf;  // Passed to every receive call below as the shared receive buffer.
+
+ // Step 1: Read the connection header.
+ RETURN_NOT_OK(ValidateConnectionHeader(&recv_buf));
+
+ { // Step 2: Receive and respond to the NEGOTIATE step message.
+ NegotiatePB request;
+ RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf));
+ RETURN_NOT_OK(HandleNegotiate(request));
+ TRACE("Negotiated authn=$0", AuthenticationTypeToString(negotiated_authn_));
+ }
+
+ // Step 3: if both ends support TLS, do a TLS handshake.
+ if (encryption_ != RpcEncryption::DISABLED &&
+ tls_context_->has_cert() &&
+ ContainsKey(client_features_, TLS)) {
+ RETURN_NOT_OK(tls_context_->InitiateHandshake(security::TlsHandshakeType::SERVER,
+ &tls_handshake_));
+
+ if (negotiated_authn_ != AuthenticationType::CERTIFICATE) {
+ // The server does not need to verify the client's certificate unless it's
+ // being used for authentication.
+ tls_handshake_.set_verification_mode(security::TlsVerificationMode::VERIFY_NONE);
+ }
+
+ while (true) {  // One iteration per TLS_HANDSHAKE message from the client.
+ NegotiatePB request;
+ RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf));
+ Status s = HandleTlsHandshake(request);
+ if (s.ok()) break;  // Handshake finished.
+ if (!s.IsIncomplete()) return s;  // Anything but Incomplete is a hard failure.
+ }
+ tls_negotiated_ = true;
+ }
+
+ // Rejects any connection from public routable IPs if encryption
+ // is disabled. See KUDU-1875.
+ if (!tls_negotiated_) {
+ Sockaddr addr;
+ RETURN_NOT_OK(socket_->GetPeerAddress(&addr));
+
+ if (!IsTrustedConnection(addr)) {
+ // Receive the client's next message before sending the error
+ // message, even though that message is never used,
+ // to avoid the race condition where the connection gets
+ // closed before the client receives the server's error
+ // message.
+ NegotiatePB request;
+ RETURN_NOT_OK(RecvNegotiatePB(&request, &recv_buf));
+
+ Status s = Status::NotAuthorized("unencrypted connections from publicly routable "
+ "IPs are prohibited. See --trusted_subnets flag "
+ "for more information.",
+ addr.ToString());
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+ }
+
+ // Step 4: Authentication
+ switch (negotiated_authn_) {
+ case AuthenticationType::SASL:
+ RETURN_NOT_OK(AuthenticateBySasl(&recv_buf));
+ break;
+ case AuthenticationType::TOKEN:
+ RETURN_NOT_OK(AuthenticateByToken(&recv_buf));
+ break;
+ case AuthenticationType::CERTIFICATE:
+ RETURN_NOT_OK(AuthenticateByCertificate());
+ break;
+ case AuthenticationType::INVALID: LOG(FATAL) << "unreachable";  // A successful HandleNegotiate() always selects a concrete type.
+ }
+
+ // Step 5: Receive connection context.
+ RETURN_NOT_OK(RecvConnectionContext(&recv_buf));
+
+ TRACE("Negotiation successful");
+ return Status::OK();
+}
+
+// Checks, without any network traffic, that a GSSAPI-enabled SASL server can
+// be started for 'sasl_proto_name'. Returns OK when sasl_server_start()
+// reports 'Incomplete' for an empty client token, i.e. the first step of a
+// negotiation worked. Every other outcome of that call becomes a RuntimeError;
+// a bare "Permission denied" message is prefixed to point at the keytab.
+//
+// TODO(todd): the error messages that come from this function on el6
+// are relatively useless due to the following krb5 bug:
+// http://krbdev.mit.edu/rt/Ticket/Display.html?id=6973
+// This may not be useful anymore given the keytab login that happens
+// in security/init.cc.
+Status ServerNegotiation::PreflightCheckGSSAPI(const std::string& sasl_proto_name) {
+  // A throwaway negotiation object with a null socket, TLS context and token
+  // verifier. No messages are exchanged; the object exists only so that the
+  // regular SASL initialization code can be reused.
+  ServerNegotiation probe(
+      nullptr, nullptr, nullptr, RpcEncryption::OPTIONAL, sasl_proto_name);
+
+  const Status enable_status = probe.EnableGSSAPI();
+  if (!enable_status.ok()) {
+    return Status::RuntimeError(enable_status.message());
+  }
+  RETURN_NOT_OK(probe.InitSaslServer());
+
+  // Start the SASL server as if a connection were being accepted, handing it
+  // a zero-length client token. The server's output is ignored.
+  const char* ignored_out = nullptr;
+  uint32_t ignored_out_len = 0;
+  auto* conn = probe.sasl_conn_.get();
+  const Status start_status = WrapSaslCall(conn, [&]() {
+    return sasl_server_start(conn, kSaslMechGSSAPI, "", 0, &ignored_out, &ignored_out_len);
+  });
+
+  // 'Incomplete' means the first negotiation step was accepted.
+  if (start_status.IsIncomplete()) {
+    return Status::OK();
+  }
+
+  string err_msg = start_status.message().ToString();
+  if (err_msg == "Permission denied") {
+    // Bad keytab permissions surface as this rather vague message, so make it
+    // more specific for better usability.
+    err_msg.insert(0, "error accessing keytab: ");
+  }
+  return Status::RuntimeError(err_msg);
+}
+
+// Blocks until one framed negotiation request arrives (or 'deadline_' passes)
+// and parses it into 'msg'.
+// In:  'recv_buf' scratch buffer handed to the framed-message receive call.
+// Out: 'msg' the parsed NegotiatePB.
+// Returns the receive error as-is if the read fails. If the call id is not
+// the negotiation call id, or the payload cannot be parsed, an error response
+// is sent to the client first (FATAL_INVALID_RPC_HEADER or
+// FATAL_DESERIALIZING_REQUEST respectively) and the failure is returned.
+Status ServerNegotiation::RecvNegotiatePB(NegotiatePB* msg, faststring* recv_buf) {
+  RequestHeader header;
+  Slice param_buf;
+  RETURN_NOT_OK(ReceiveFramedMessageBlocking(socket(), recv_buf, &header, &param_buf, deadline_));
+  Status s = helper_.CheckNegotiateCallId(header.call_id());
+  if (!s.ok()) {
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_RPC_HEADER, s));
+    return s;
+  }
+
+  s = helper_.ParseNegotiatePB(param_buf, msg);
+  if (!s.ok()) {
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_DESERIALIZING_REQUEST, s));
+    return s;
+  }
+
+  TRACE("Received $0 NegotiatePB request", NegotiatePB::NegotiateStep_Name(msg->step()));
+  return Status::OK();
+}
+
+// Sends 'msg' to the client as a framed negotiation response, using the
+// negotiation call id and honouring 'deadline_'. 'msg' must be initialized and
+// carry a step (checked in debug builds).
+Status ServerNegotiation::SendNegotiatePB(const NegotiatePB& msg) {
+  DCHECK(socket_);
+  DCHECK(msg.IsInitialized()) << "message must be initialized";
+  DCHECK(msg.has_step()) << "message must have a step";
+
+  ResponseHeader response_header;
+  response_header.set_call_id(kNegotiateCallId);
+
+  TRACE("Sending $0 NegotiatePB response", NegotiatePB::NegotiateStep_Name(msg.step()));
+  return SendFramedMessageBlocking(socket(), response_header, msg, deadline_);
+}
+
+// Reports a negotiation failure to the client.
+// In: 'code' RPC error code to place in the ErrorStatusPB,
+//     'err'  the (non-OK) status whose string form becomes the error message.
+// Returns the status of the blocking send, which honours 'deadline_'.
+Status ServerNegotiation::SendError(ErrorStatusPB::RpcErrorCodePB code, const Status& err) {
+  DCHECK(!err.ok());
+
+  // Error payload: the caller-supplied code plus the textual form of 'err'.
+  ErrorStatusPB error_pb;
+  error_pb.set_code(code);
+  error_pb.set_message(err.ToString());
+
+  // Header carrying the negotiation-specific call id, flagged as an error.
+  ResponseHeader error_header;
+  error_header.set_call_id(kNegotiateCallId);
+  error_header.set_is_error(true);
+
+  TRACE("Sending RPC error: $0: $1", ErrorStatusPB::RpcErrorCodePB_Name(code), err.ToString());
+  return SendFramedMessageBlocking(socket(), error_header, error_pb, deadline_);
+}
+
+// Reads the fixed-size connection header (magic number + flags) from the
+// socket into 'recv_buf' and validates it with
+// serialization::ValidateConnHeader(). 'recv_buf' is resized to exactly the
+// header length. Returns the receive or validation error on failure.
+Status ServerNegotiation::ValidateConnectionHeader(faststring* recv_buf) {
+  TRACE("Waiting for connection header");
+
+  const size_t conn_header_len = kMagicNumberLength + kHeaderFlagsLength;
+  recv_buf->resize(conn_header_len);
+
+  size_t num_read;
+  const Status recv_status =
+      socket_->BlockingRecv(recv_buf->data(), conn_header_len, &num_read, deadline_);
+  RETURN_NOT_OK(recv_status);
+  DCHECK_EQ(conn_header_len, num_read);
+
+  const Status header_status = serialization::ValidateConnHeader(*recv_buf);
+  RETURN_NOT_OK(header_status);
+
+  TRACE("Connection header received");
+  return Status::OK();
+}
+
+// Creates the SASL server context for this connection via sasl_server_new()
+// and stores it in sasl_conn_. On failure the SASL error is returned with
+// "Unable to create new SASL server" prepended.
+Status ServerNegotiation::InitSaslServer() {
+  // TODO(unknown): Support security flags.
+  unsigned secflags = 0;
+
+  sasl_conn_t* new_conn = nullptr;
+  const Status create_status = WrapSaslCall(nullptr /* no conn */, [&]() {
+    return sasl_server_new(
+        sasl_proto_name_.c_str(),  // Registered name of the service using SASL. Required.
+        helper_.server_fqdn(),     // The fully qualified domain name of this server.
+        nullptr,                   // User realm. NULL == use default.
+        nullptr,                   // Local IP address string; no enabled mechanism needs it.
+        nullptr,                   // Remote IP address string; likewise not needed.
+        &callbacks_[0],            // Connection-specific callbacks.
+        secflags,                  // Security flags.
+        &new_conn);
+  });
+  RETURN_NOT_OK_PREPEND(create_status, "Unable to create new SASL server");
+
+  sasl_conn_.reset(new_conn);
+  return Status::OK();
+}
+
+Status ServerNegotiation::HandleNegotiate(const NegotiatePB& request) {  // Validates the client's NEGOTIATE request, selects negotiated_authn_, and replies with the server's features.
+ if (request.step() != NegotiatePB::NEGOTIATE) {
+ Status s = Status::NotAuthorized("expected NEGOTIATE step",
+ NegotiatePB::NegotiateStep_Name(request.step()));
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+ TRACE("Received NEGOTIATE request from client");
+
+ // Fill in the set of features supported by the client.
+ for (int flag : request.supported_features()) {
+ // We only add the features that our local build knows about.
+ RpcFeatureFlag feature_flag = RpcFeatureFlag_IsValid(flag) ?
+ static_cast<RpcFeatureFlag>(flag) : UNKNOWN;
+ if (feature_flag != UNKNOWN) {
+ client_features_.insert(feature_flag);
+ }
+ }
+
+ if (encryption_ == RpcEncryption::REQUIRED &&
+ !ContainsKey(client_features_, RpcFeatureFlag::TLS)) {
+ Status s = Status::NotAuthorized("client does not support required TLS encryption");
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+
+ // Find the set of mutually supported authentication types.
+ set<AuthenticationType> authn_types;
+ if (request.authn_types().empty()) {
+ // If the client doesn't send any supported authentication types, we assume
+ // support for SASL. This preserves backwards compatibility with clients who
+ // don't support security features.
+ authn_types.insert(AuthenticationType::SASL);
+ } else {
+ for (const auto& type : request.authn_types()) {
+ switch (type.type_case()) {
+ case AuthenticationTypePB::kSasl:
+ authn_types.insert(AuthenticationType::SASL);
+ break;
+ case AuthenticationTypePB::kToken:
+ authn_types.insert(AuthenticationType::TOKEN);
+ break;
+ case AuthenticationTypePB::kCertificate:
+ // We only provide authenticated TLS if the certificates are generated
+ // by the internal CA.
+ if (!tls_context_->is_external_cert()) {
+ authn_types.insert(AuthenticationType::CERTIFICATE);
+ }
+ break;
+ case AuthenticationTypePB::TYPE_NOT_SET: {  // The client listed a type this build does not recognize; it is skipped.
+ Sockaddr addr;
+ RETURN_NOT_OK(socket_->GetPeerAddress(&addr));
+ KLOG_EVERY_N_SECS(WARNING, 60)
+ << "client supports unknown authentication type, consider updating server, address: "
+ << addr.ToString();
+ break;
+ }
+ }
+ }
+
+ if (authn_types.empty()) {  // E.g. the client offered only certificate authn while our cert is external.
+ Status s = Status::NotSupported("no mutually supported authentication types");
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+ }
+
+ if (encryption_ != RpcEncryption::DISABLED &&
+ ContainsKey(authn_types, AuthenticationType::CERTIFICATE) &&
+ tls_context_->has_signed_cert()) {
+ // If the client supports it and we are locally configured with TLS and have
+ // a CA-signed cert, choose cert authn.
+ // TODO(KUDU-1924): consider adding the fingerprint of the CA cert which signed
+ // the client's cert to the authentication message.
+ negotiated_authn_ = AuthenticationType::CERTIFICATE;
+ } else if (ContainsKey(authn_types, AuthenticationType::TOKEN) &&
+ token_verifier_->GetMaxKnownKeySequenceNumber() >= 0 &&
+ encryption_ != RpcEncryption::DISABLED &&
+ tls_context_->has_signed_cert()) {
+ // If the client supports it, we have a TSK to verify the client's token,
+ // and we have a signed-cert so the client can verify us, choose token authn.
+ // TODO(KUDU-1924): consider adding the TSK sequence number to the authentication
+ // message.
+ negotiated_authn_ = AuthenticationType::TOKEN;
+ } else {
+ // Otherwise we always can fallback to SASL.
+ DCHECK(ContainsKey(authn_types, AuthenticationType::SASL));  // NOTE(review): debug-only check; SASL is selected even if the client did not list it.
+ negotiated_authn_ = AuthenticationType::SASL;
+ }
+
+ // Fill in the NEGOTIATE step response for the client.
+ NegotiatePB response;
+ response.set_step(NegotiatePB::NEGOTIATE);
+
+ // Tell the client which features we support.
+ server_features_ = kSupportedServerRpcFeatureFlags;
+ if (tls_context_->has_cert() && encryption_ != RpcEncryption::DISABLED) {
+ server_features_.insert(TLS);
+ // If the remote peer is local, then we allow using TLS for authentication
+ // without encryption or integrity.
+ if (socket_->IsLoopbackConnection() && !FLAGS_rpc_encrypt_loopback_connections) {
+ server_features_.insert(TLS_AUTHENTICATION_ONLY);
+ }
+ }
+
+ for (RpcFeatureFlag feature : server_features_) {
+ response.add_supported_features(feature);
+ }
+
+ switch (negotiated_authn_) {  // Advertise only the single authentication type that was selected above.
+ case AuthenticationType::CERTIFICATE:
+ response.add_authn_types()->mutable_certificate();
+ break;
+ case AuthenticationType::TOKEN:
+ response.add_authn_types()->mutable_token();
+ break;
+ case AuthenticationType::SASL: {
+ response.add_authn_types()->mutable_sasl();
+ const set<SaslMechanism::Type>& server_mechs = helper_.EnabledMechs();
+ if (PREDICT_FALSE(server_mechs.empty())) {
+ // This will happen if no mechanisms are enabled before calling Init()
+ Status s = Status::NotAuthorized("SASL server mechanism list is empty!");
+ LOG(ERROR) << s.ToString();
+ TRACE("Sending FATAL_UNAUTHORIZED response to client");
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+
+ for (auto mechanism : server_mechs) {
+ response.add_sasl_mechanisms()->set_mechanism(SaslMechanism::name_of(mechanism));
+ }
+ break;
+ }
+ case AuthenticationType::INVALID: LOG(FATAL) << "unreachable";  // One of the three branches above always assigns a concrete type.
+ }
+
+ return SendNegotiatePB(response);
+}
+
+Status ServerNegotiation::HandleTlsHandshake(const NegotiatePB& request) {  // Processes one TLS_HANDSHAKE message; returns Incomplete while more round trips are needed.
+ if (PREDICT_FALSE(request.step() != NegotiatePB::TLS_HANDSHAKE)) {
+ Status s = Status::NotAuthorized("expected TLS_HANDSHAKE step",
+ NegotiatePB::NegotiateStep_Name(request.step()));
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+
+ if (PREDICT_FALSE(!request.has_tls_handshake())) {
+ Status s = Status::NotAuthorized(
+ "No TLS handshake token in TLS_HANDSHAKE request from client");
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+
+ string token;  // Receives the server's handshake bytes to send back to the client.
+ Status s = tls_handshake_.Continue(request.tls_handshake(), &token);
+
+ if (PREDICT_FALSE(!s.IsIncomplete() && !s.ok())) {
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+
+ // Regardless of whether this is the final handshake roundtrip (in which case
+ // Continue would have returned OK), we still need to return a response.
+ RETURN_NOT_OK(SendTlsHandshake(std::move(token)));
+ RETURN_NOT_OK(s);  // Incomplete propagates here; Negotiate() then waits for the next handshake message.
+
+ // TLS handshake is finished.
+ if (ContainsKey(server_features_, TLS_AUTHENTICATION_ONLY) &&
+ ContainsKey(client_features_, TLS_AUTHENTICATION_ONLY)) {
+ TRACE("Negotiated auth-only $0 with cipher $1",
+ tls_handshake_.GetProtocol(), tls_handshake_.GetCipherDescription());
+ return tls_handshake_.FinishNoWrap(*socket_);  // Both sides agreed to authentication-only TLS.
+ }
+
+ TRACE("Negotiated $0 with cipher $1",
+ tls_handshake_.GetProtocol(), tls_handshake_.GetCipherDescription());
+ return tls_handshake_.Finish(&socket_);  // NOTE(review): takes &socket_, so presumably swaps in a TLS-wrapping socket -- confirm.
+}
+
+// Wraps 'tls_token' in a TLS_HANDSHAKE NegotiatePB and sends it to the
+// client. The token is consumed by this call.
+Status ServerNegotiation::SendTlsHandshake(string tls_token) {
+  NegotiatePB handshake_pb;
+  handshake_pb.set_step(NegotiatePB::TLS_HANDSHAKE);
+  *handshake_pb.mutable_tls_handshake() = std::move(tls_token);
+  return SendNegotiatePB(handshake_pb);
+}
+
+// Authenticates the client with SASL. Creates the SASL server context, handles
+// the SASL_INITIATE message, then keeps handling SASL_RESPONSE messages for as
+// long as the exchange reports 'Incomplete'. On success the authenticated user
+// is recorded: for GSSAPI the canonicalized Kerberos principal is mapped to a
+// local user name, for any other mechanism the SASL user name is stored as
+// unauthenticated.
+Status ServerNegotiation::AuthenticateBySasl(faststring* recv_buf) {
+  RETURN_NOT_OK(InitSaslServer());
+
+  NegotiatePB request;
+  RETURN_NOT_OK(RecvNegotiatePB(&request, recv_buf));
+  Status exchange_status = HandleSaslInitiate(request);
+  while (exchange_status.IsIncomplete()) {
+    RETURN_NOT_OK(RecvNegotiatePB(&request, recv_buf));
+    exchange_status = HandleSaslResponse(request);
+  }
+  RETURN_NOT_OK(exchange_status);
+
+  const char* c_username = nullptr;
+  int rc = sasl_getprop(sasl_conn_.get(), SASL_USERNAME,
+                        reinterpret_cast<const void**>(&c_username));
+  // We expect that SASL_USERNAME will always get set.
+  CHECK(rc == SASL_OK && c_username != nullptr) << "No username on authenticated connection";
+
+  if (negotiated_mech_ != SaslMechanism::GSSAPI) {
+    authenticated_user_.SetUnauthenticated(c_username);
+    return Status::OK();
+  }
+
+  // The SASL library doesn't include the user's realm in the username if it's the
+  // same realm as the default realm of the server. So, we pass it back through the
+  // Kerberos library to add back the realm if necessary.
+  string principal = c_username;
+  RETURN_NOT_OK_PREPEND(security::CanonicalizeKrb5Principal(&principal),
+                        "could not canonicalize krb5 principal");
+
+  // Map the principal to the corresponding local username. For example, admins
+  // can set up mappings so that joe@REMOTEREALM becomes something like 'remote-joe'
+  // locally for the purposes of group mapping, ACLs, etc.
+  string local_name;
+  RETURN_NOT_OK_PREPEND(security::MapPrincipalToLocalName(principal, &local_name),
+                        strings::Substitute("could not map krb5 principal '$0' to username",
+                                            principal));
+  authenticated_user_.SetAuthenticatedByKerberos(std::move(local_name), std::move(principal));
+  return Status::OK();
+}
+
+// Authenticates the client using the authentication token carried in a
+// TOKEN_EXCHANGE message. Requires that TLS has already been negotiated.
+// In: 'recv_buf' scratch buffer used to receive the client's message.
+// On success the token's username is recorded as the authenticated user and an
+// empty TOKEN_EXCHANGE response is sent. On any failure an error response is
+// sent to the client and the corresponding non-OK status is returned:
+//  - wrong step or missing token: FATAL_UNAUTHORIZED
+//  - invalid/expired token or signature: FATAL_INVALID_AUTHENTICATION_TOKEN
+//  - unknown signing key: ERROR_UNAVAILABLE
+//  - incompatible feature or non-authn token: FATAL_UNAUTHORIZED
+Status ServerNegotiation::AuthenticateByToken(faststring* recv_buf) {
+  // Sanity check that TLS has been negotiated. Receiving the token on an
+  // unencrypted channel is a big no-no.
+  CHECK(tls_negotiated_);
+
+  // Receive the token from the client.
+  NegotiatePB pb;
+  RETURN_NOT_OK(RecvNegotiatePB(&pb, recv_buf));
+
+  // Reject malformed requests before looking at the token. These statuses
+  // must actually be sent and returned; otherwise verification would proceed
+  // on a message of the wrong step or without a token.
+  if (pb.step() != NegotiatePB::TOKEN_EXCHANGE) {
+    Status s = Status::NotAuthorized("expected TOKEN_EXCHANGE step",
+                                     NegotiatePB::NegotiateStep_Name(pb.step()));
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    return s;
+  }
+  if (!pb.has_authn_token()) {
+    Status s = Status::NotAuthorized("TOKEN_EXCHANGE message must include an authentication token");
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    return s;
+  }
+
+  // TODO(KUDU-1924): propagate the specific token verification failure back to the client,
+  // so it knows how to intelligently retry.
+  security::TokenPB token;
+  auto verification_result = token_verifier_->VerifyTokenSignature(pb.authn_token(), &token);
+  switch (verification_result) {
+    case security::VerificationResult::VALID: break;
+
+    case security::VerificationResult::INVALID_TOKEN:
+    case security::VerificationResult::INVALID_SIGNATURE:
+    case security::VerificationResult::EXPIRED_TOKEN:
+    case security::VerificationResult::EXPIRED_SIGNING_KEY: {
+      // These errors indicate the client should get a new token and try again.
+      Status s = Status::NotAuthorized(VerificationResultToString(verification_result));
+      RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_AUTHENTICATION_TOKEN, s));
+      return s;
+    }
+
+    case security::VerificationResult::UNKNOWN_SIGNING_KEY: {
+      // The server doesn't recognize the signing key. This indicates that the
+      // server has not been updated with the most recent TSKs, so tell the
+      // client to try again later.
+      Status s = Status::NotAuthorized(VerificationResultToString(verification_result));
+      RETURN_NOT_OK(SendError(ErrorStatusPB::ERROR_UNAVAILABLE, s));
+      return s;
+    }
+    case security::VerificationResult::INCOMPATIBLE_FEATURE: {
+      Status s = Status::NotAuthorized(VerificationResultToString(verification_result));
+      // These error types aren't recoverable by having the client get a new token.
+      RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+      return s;
+    }
+  }
+
+  if (!token.has_authn()) {
+    Status s = Status::NotAuthorized("non-authentication token presented for authentication");
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+    return s;
+  }
+  if (!token.authn().has_username()) {
+    // This is a runtime error because there should be no way a client could
+    // get a signed authn token without a subject.
+    Status s = Status::RuntimeError("authentication token has no username");
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_AUTHENTICATION_TOKEN, s));
+    return s;
+  }
+
+  // Test-only fault injection: with the configured probability, pretend the
+  // (otherwise valid) token failed verification with a randomly chosen
+  // "get a new token" result.
+  if (PREDICT_FALSE(FLAGS_rpc_inject_invalid_authn_token_ratio > 0)) {
+    static const security::VerificationResult kInjectedResults[] = {
+      security::VerificationResult::INVALID_TOKEN,
+      security::VerificationResult::INVALID_SIGNATURE,
+      security::VerificationResult::EXPIRED_TOKEN,
+      security::VerificationResult::EXPIRED_SIGNING_KEY,
+    };
+    // Table lookup instead of a default-less switch, so 'res' is always
+    // initialized.
+    const security::VerificationResult res =
+        kInjectedResults[rand() % (sizeof(kInjectedResults) / sizeof(kInjectedResults[0]))];
+    if (kudu::fault_injection::MaybeTrue(FLAGS_rpc_inject_invalid_authn_token_ratio)) {
+      Status s = Status::NotAuthorized(VerificationResultToString(res));
+      RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_AUTHENTICATION_TOKEN, s));
+      return s;
+    }
+  }
+
+  authenticated_user_.SetAuthenticatedByToken(token.authn().username());
+
+  // Respond with success message.
+  pb.Clear();
+  pb.set_step(NegotiatePB::TOKEN_EXCHANGE);
+  return SendNegotiatePB(pb);
+}
+
+// Authenticates the client from the certificate it presented during the TLS
+// handshake. Requires that TLS has already been negotiated. The certificate
+// must carry a user id; otherwise FATAL_INVALID_AUTHENTICATION_TOKEN is sent
+// to the client and NotAuthorized is returned. The optional Kerberos principal
+// from the certificate is recorded alongside the user id.
+Status ServerNegotiation::AuthenticateByCertificate() {
+  // Sanity check that TLS has been negotiated. Cert-based authentication is
+  // only possible with TLS.
+  CHECK(tls_negotiated_);
+
+  // Grab the subject from the client's cert.
+  security::Cert client_cert;
+  RETURN_NOT_OK(tls_handshake_.GetRemoteCert(&client_cert));
+
+  boost::optional<string> cert_user_id = client_cert.UserId();
+  boost::optional<string> cert_principal = client_cert.KuduKerberosPrincipal();
+
+  if (cert_user_id) {
+    authenticated_user_.SetAuthenticatedByClientCert(*cert_user_id, std::move(cert_principal));
+    return Status::OK();
+  }
+
+  Status s = Status::NotAuthorized("did not find expected X509 userId extension in cert");
+  RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_INVALID_AUTHENTICATION_TOKEN, s));
+  return s;
+}
+
+Status ServerNegotiation::HandleSaslInitiate(const NegotiatePB& request) {  // Handles the client's SASL_INITIATE message and starts the SASL exchange; returns OK or Incomplete on success.
+ if (PREDICT_FALSE(request.step() != NegotiatePB::SASL_INITIATE)) {
+ Status s = Status::NotAuthorized("expected SASL_INITIATE step",
+ NegotiatePB::NegotiateStep_Name(request.step()));
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+ TRACE("Received SASL_INITIATE request from client");
+
+ if (request.sasl_mechanisms_size() != 1) {
+ Status s = Status::NotAuthorized(
+ "SASL_INITIATE request must include exactly one SASL mechanism, found",
+ std::to_string(request.sasl_mechanisms_size()));
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+
+ const string& mechanism = request.sasl_mechanisms(0).mechanism();
+ TRACE("Client requested to use mechanism: $0", mechanism);
+
+ negotiated_mech_ = SaslMechanism::value_of(mechanism);  // NOTE(review): the result for an unrecognized mechanism name is not visible here.
+
+ // Rejects any connection from public routable IPs if authentication mechanism
+ // is plain. See KUDU-1875.
+ if (negotiated_mech_ == SaslMechanism::PLAIN) {
+ Sockaddr addr;
+ RETURN_NOT_OK(socket_->GetPeerAddress(&addr));
+
+ if (!IsTrustedConnection(addr)) {
+ Status s = Status::NotAuthorized("unauthenticated connections from publicly "
+ "routable IPs are prohibited. See "
+ "--trusted_subnets flag for more information.",
+ addr.ToString());
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+ }
+
+ // If the negotiated mechanism is GSSAPI (Kerberos), configure SASL to use
+ // integrity protection so that the channel bindings and nonce can be
+ // verified.
+ if (negotiated_mech_ == SaslMechanism::GSSAPI) {
+ RETURN_NOT_OK(EnableProtection(sasl_conn_.get(), SaslProtection::kIntegrity));
+ }
+
+ const char* server_out = nullptr;
+ uint32_t server_out_len = 0;
+ TRACE("Calling sasl_server_start()");
+
+ Status s = WrapSaslCall(sasl_conn_.get(), [&]() {
+ return sasl_server_start(
+ sasl_conn_.get(), // The SASL connection context created by init()
+ mechanism.c_str(), // The mechanism requested by the client.
+ request.token().c_str(), // Optional string the client gave us.
+ request.token().length(), // Client string len.
+ &server_out, // The output of the SASL library, might not be NULL terminated
+ &server_out_len); // Output len.
+ });
+
+ if (PREDICT_FALSE(!s.ok() && !s.IsIncomplete())) {
+ RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, s));
+ return s;
+ }
+
+ // We have a valid mechanism match
+ if (s.ok()) {
+ DCHECK(server_out_len == 0);
+ RETURN_NOT_OK(SendSaslSuccess());
+ } else { // s.IsIncomplete() (equivalent to SASL_CONTINUE)
+ RETURN_NOT_OK(SendSaslChallenge(server_out, server_out_len));
+ }
+ return s;  // OK or Incomplete; AuthenticateBySasl() keeps reading responses while this is Incomplete.
+}
+
+// Handles a SASL_RESPONSE message by feeding the client's token to
+// sasl_server_step(). Depending on the outcome this sends SASL_SUCCESS, a
+// further SASL_CHALLENGE, or an error to the client.
+Status ServerNegotiation::HandleSaslResponse(const NegotiatePB& request) {
+  if (PREDICT_FALSE(request.step() != NegotiatePB::SASL_RESPONSE)) {
+    Status bad_step = Status::NotAuthorized("expected SASL_RESPONSE step",
+                                            NegotiatePB::NegotiateStep_Name(request.step()));
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, bad_step));
+    return bad_step;
+  }
+  TRACE("Received SASL_RESPONSE request from client");
+
+  if (!request.has_token()) {
+    Status no_token = Status::NotAuthorized("no token in SASL_RESPONSE from client");
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, no_token));
+    return no_token;
+  }
+
+  const string& client_token = request.token();
+  const char* sasl_out = nullptr;  // Owned by the SASL library; may not be NUL-terminated.
+  uint32_t sasl_out_len = 0;
+  TRACE("Calling sasl_server_step()");
+  Status step_status = WrapSaslCall(sasl_conn_.get(), [&]() {
+    return sasl_server_step(sasl_conn_.get(),       // SASL connection context.
+                            client_token.c_str(),   // Data the client sent us.
+                            client_token.length(),  // Length of the client data.
+                            &sasl_out,              // Output produced by the SASL library.
+                            &sasl_out_len);         // Length of that output.
+  });
+
+  if (step_status.IsIncomplete()) {
+    // More round trips are needed: pass the challenge on to the client.
+    return SendSaslChallenge(sasl_out, sasl_out_len);
+  }
+  if (!step_status.ok()) {
+    RETURN_NOT_OK(SendError(ErrorStatusPB::FATAL_UNAUTHORIZED, step_status));
+    return step_status;
+  }
+
+  DCHECK(sasl_out_len == 0);
+  return SendSaslSuccess();
+}
+
+// Sends a SASL_CHALLENGE message carrying the 'clen' bytes at 'challenge'.
+// Returns Incomplete once the message has been sent, since the exchange is
+// not finished yet; a send failure is returned as-is.
+Status ServerNegotiation::SendSaslChallenge(const char* challenge, unsigned clen) {
+  NegotiatePB challenge_msg;
+  challenge_msg.set_step(NegotiatePB::SASL_CHALLENGE);
+  // The challenge is not necessarily NUL-terminated, so copy it by length.
+  challenge_msg.mutable_token()->assign(challenge, clen);
+
+  RETURN_NOT_OK(SendNegotiatePB(challenge_msg));
+  return Status::Incomplete("");
+}
+
+// Sends the SASL_SUCCESS message. When the negotiated mechanism is GSSAPI the
+// message also carries a freshly generated nonce and, if TLS was negotiated,
+// the SASL-encoded channel bindings of the local certificate.
+Status ServerNegotiation::SendSaslSuccess() {
+  NegotiatePB success_msg;
+  success_msg.set_step(NegotiatePB::SASL_SUCCESS);
+
+  const bool is_gssapi = negotiated_mech_ == SaslMechanism::GSSAPI;
+
+  if (is_gssapi) {
+    // Generate a nonce and hand it to the client; it is remembered in 'nonce_'.
+    nonce_ = string();
+    RETURN_NOT_OK(security::GenerateNonce(nonce_.get_ptr()));
+    success_msg.set_nonce(*nonce_);
+  }
+
+  if (is_gssapi && tls_negotiated_) {
+    // Also send the channel bindings, derived from our own certificate and
+    // encoded with the SASL connection.
+    security::Cert local_cert;
+    RETURN_NOT_OK(tls_handshake_.GetLocalCert(&local_cert));
+
+    string plaintext_bindings;
+    RETURN_NOT_OK(local_cert.GetServerEndPointChannelBindings(&plaintext_bindings));
+
+    Slice encoded_bindings;
+    RETURN_NOT_OK(SaslEncode(sasl_conn_.get(), plaintext_bindings, &encoded_bindings));
+    *success_msg.mutable_channel_bindings() = encoded_bindings.ToString();
+  }
+
+  return SendNegotiatePB(success_msg);
+}
+
+// Receives the client's ConnectionContextPB and validates it.
+//
+// 'recv_buf' is the caller-provided receive buffer handed through to
+// ReceiveFramedMessageBlocking().
+//
+// Returns NotAuthorized if the message does not carry the connection-context
+// call id, cannot be parsed, or -- when a nonce was generated for this
+// connection -- does not echo that nonce back correctly. Failures to receive
+// the message or to look up the peer address are returned as-is.
+Status ServerNegotiation::RecvConnectionContext(faststring* recv_buf) {
+  TRACE("Waiting for connection context");
+  RequestHeader header;
+  Slice param_buf;
+  RETURN_NOT_OK(ReceiveFramedMessageBlocking(socket(), recv_buf, &header, &param_buf, deadline_));
+  DCHECK(header.IsInitialized());
+
+  if (header.call_id() != kConnectionContextCallId) {
+    return Status::NotAuthorized("expected ConnectionContext callid, received",
+                                 std::to_string(header.call_id()));
+  }
+
+  ConnectionContextPB conn_context;
+  if (!conn_context.ParseFromArray(param_buf.data(), param_buf.size())) {
+    return Status::NotAuthorized("invalid ConnectionContextPB message, missing fields",
+                                 conn_context.InitializationErrorString());
+  }
+
+  // 'nonce_' is only set when one was sent to the client; in that case the
+  // client must return it, wrapped with the SASL connection.
+  if (nonce_) {
+    if (!conn_context.has_encoded_nonce()) {
+      return Status::NotAuthorized("ConnectionContextPB wrapped nonce missing");
+    }
+
+    Slice decoded_nonce;
+    Status s = SaslDecode(sasl_conn_.get(), conn_context.encoded_nonce(), &decoded_nonce);
+    if (!s.ok()) {
+      return Status::NotAuthorized("failed to decode nonce", s.message());
+    }
+
+    if (*nonce_ != decoded_nonce) {
+      Sockaddr addr;
+      RETURN_NOT_OK(socket_->GetPeerAddress(&addr));
+      LOG(WARNING) << "Received an invalid connection nonce from client "
+                   << addr.ToString()
+                   << ", this could indicate a replay attack";
+      return Status::NotAuthorized("nonce mismatch");
+    }
+  }
+
+  return Status::OK();
+}
+
+// SASL option callback: delegates the lookup to 'helper_' and returns its
+// result code unchanged.
+int ServerNegotiation::GetOptionCb(const char* plugin_name,
+                                   const char* option,
+                                   const char** result,
+                                   unsigned* len) {
+  const int rc = helper_.GetOptionCb(plugin_name, option, result, len);
+  return rc;
+}
+
+// SASL password-check callback for PLAIN. The password is never inspected:
+// as long as PLAIN is enabled, authentication always succeeds.
+int ServerNegotiation::PlainAuthCb(sasl_conn_t* /*conn*/,
+                                   const char* user,
+                                   const char* /*pass*/,
+                                   unsigned /*passlen*/,
+                                   struct propctx* /*propctx*/) {
+  TRACE("Received PLAIN auth, user=$0", user);
+  if (helper_.IsPlainEnabled()) {
+    // We always allow PLAIN authentication to succeed.
+    return SASL_OK;
+  }
+
+  // This callback should not be invoked at all while PLAIN is disabled.
+  LOG(DFATAL) << "Password authentication callback called while PLAIN auth disabled";
+  return SASL_BADPARAM;
+}
+
+// Returns true if 'addr' lies within one of the trusted subnets. The subnet
+// list is built once, on first use, from --trusted_subnets.
+bool ServerNegotiation::IsTrustedConnection(const Sockaddr& addr) {
+  static std::once_flag init_flag;
+  std::call_once(init_flag, [] {
+    g_trusted_subnets = new vector<Network>();
+    CHECK_OK(Network::ParseCIDRStrings(FLAGS_trusted_subnets, g_trusted_subnets));
+
+    // An explicitly configured --trusted_subnets is used exactly as given.
+    if (!google::GetCommandLineFlagInfoOrDie("trusted_subnets").is_default) {
+      return;
+    }
+
+    // With the flag left at its default, the subnets of all local network
+    // interfaces are trusted in addition to the default private subnets.
+    std::vector<Network> local_networks;
+    WARN_NOT_OK(GetLocalNetworks(&local_networks),
+                "Unable to get local networks.");
+    g_trusted_subnets->insert(g_trusted_subnets->end(),
+                              local_networks.begin(),
+                              local_networks.end());
+  });
+
+  for (const Network& subnet : *g_trusted_subnets) {
+    if (subnet.WithinNetwork(addr)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+} // namespace rpc
+} // namespace kudu
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/server_negotiation.h
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/server_negotiation.h b/be/src/kudu/rpc/server_negotiation.h
new file mode 100644
index 0000000..2582af1
--- /dev/null
+++ b/be/src/kudu/rpc/server_negotiation.h
@@ -0,0 +1,259 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <boost/optional/optional.hpp>
+#include <glog/logging.h>
+#include <sasl/sasl.h>
+
+#include "kudu/gutil/port.h"
+#include "kudu/rpc/messenger.h"
+#include "kudu/rpc/negotiation.h"
+#include "kudu/rpc/remote_user.h"
+#include "kudu/rpc/rpc_header.pb.h"
+#include "kudu/rpc/sasl_common.h"
+#include "kudu/rpc/sasl_helper.h"
+#include "kudu/security/security_flags.h"
+#include "kudu/security/tls_handshake.h"
+#include "kudu/util/monotime.h"
+#include "kudu/util/net/socket.h"
+#include "kudu/util/status.h"
+
+namespace kudu {
+
+class Sockaddr;
+class faststring;
+
+namespace security {
+class TlsContext;
+class TokenVerifier;
+}
+
+namespace rpc {
+
+// Class for doing KRPC negotiation with a remote client over a bidirectional socket.
+// Operations on this class are NOT thread-safe.
+class ServerNegotiation {
+ public:
+ // Creates a new server negotiation instance, taking ownership of the
+ // provided socket. After completing the negotiation process by setting the
+ // desired options and calling Negotiate(), the socket can be retrieved with
+ // release_socket().
+ //
+ // The provided TlsContext must outlive this negotiation instance.
+ ServerNegotiation(std::unique_ptr<Socket> socket,
+ const security::TlsContext* tls_context,
+ const security::TokenVerifier* token_verifier,
+ RpcEncryption encryption,
+ std::string sasl_proto_name);
+
+ // Enable PLAIN authentication.
+ // Despite PLAIN authentication taking a username and password, we disregard
+ // the password and use this as an "unauthenticated" mode.
+ // Must be called before Negotiate().
+ Status EnablePlain();
+
+ // Enable GSSAPI (Kerberos) authentication.
+ // Must be called before Negotiate().
+ Status EnableGSSAPI();
+
+ // Returns mechanism negotiated by this connection.
+ // Must be called after Negotiate().
+ SaslMechanism::Type negotiated_mechanism() const;
+
+ // Returns the negotiated authentication type for the connection.
+ // Must be called after Negotiate().
+ AuthenticationType negotiated_authn() const {
+ DCHECK_NE(negotiated_authn_, AuthenticationType::INVALID);
+ return negotiated_authn_;
+ }
+
+ // Returns true if TLS was negotiated.
+ // Must be called after Negotiate().
+ bool tls_negotiated() const {
+ return tls_negotiated_;
+ }
+
+ // Returns the set of RPC system features supported by the remote client.
+ // Must be called after Negotiate().
+ std::set<RpcFeatureFlag> client_features() const {
+ return client_features_;
+ }
+
+ // Returns the set of RPC system features supported by the remote client.
+ // Must be called after Negotiate().
+ // Subsequent calls to this method or client_features() will return an empty set.
+ std::set<RpcFeatureFlag> take_client_features() {
+ return std::move(client_features_);
+ }
+
+ // Name of the user that was authenticated.
+ // Must be called after a successful Negotiate().
+ //
+ // Subsequent calls will return bogus data.
+ RemoteUser take_authenticated_user() {
+ return std::move(authenticated_user_);
+ }
+
+ // Specify the fully-qualified domain name of the remote server.
+ // Must be called before Negotiate(). Required for some mechanisms.
+ void set_server_fqdn(const std::string& domain_name);
+
+ // Set deadline for connection negotiation.
+ void set_deadline(const MonoTime& deadline);
+
+ Socket* socket() const { return socket_.get(); }
+
+ // Returns the socket owned by this server negotiation. The caller will own
+ // the socket after this call, and the negotiation instance should no longer
+ // be used. Must be called after Negotiate().
+ std::unique_ptr<Socket> release_socket() { return std::move(socket_); }
+
+ // Negotiate with the remote client. Should only be called once per
+ // ServerNegotiation and socket instance, after all options have been set.
+ //
+ // Returns OK on success, otherwise may return NotAuthorized, NotSupported, or
+ // another non-OK status.
+ Status Negotiate() WARN_UNUSED_RESULT;
+
+ // SASL callback for plugin options, supported mechanisms, etc.
+ // Returns SASL_FAIL if the option is not handled, which does not fail the handshake.
+ int GetOptionCb(const char* plugin_name, const char* option,
+ const char** result, unsigned* len);
+
+ // SASL callback for PLAIN authentication via SASL_CB_SERVER_USERDB_CHECKPASS.
+ int PlainAuthCb(sasl_conn_t* conn, const char* user, const char* pass,
+ unsigned passlen, struct propctx* propctx);
+
+ // Perform a "pre-flight check" that everything required to act as a Kerberos
+ // server is properly set up.
+ static Status PreflightCheckGSSAPI(const std::string& sasl_proto_name) WARN_UNUSED_RESULT;
+
+ private:
+
+ // Parse a negotiate request from the client, deserializing it into 'msg'.
+ // If the request is malformed, sends an error message to the client.
+ Status RecvNegotiatePB(NegotiatePB* msg, faststring* recv_buf) WARN_UNUSED_RESULT;
+
+ // Encode and send the specified negotiate response message to the client.
+ Status SendNegotiatePB(const NegotiatePB& msg) WARN_UNUSED_RESULT;
+
+ // Encode and send the specified RPC error message to the client.
+ // Calls Status.ToString() for the embedded error message.
+ Status SendError(ErrorStatusPB::RpcErrorCodePB code, const Status& err) WARN_UNUSED_RESULT;
+
+ // Parse and validate connection header.
+ Status ValidateConnectionHeader(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+ // Initialize the SASL server negotiation instance.
+ Status InitSaslServer() WARN_UNUSED_RESULT;
+
+ // Handle case when client sends NEGOTIATE request. Builds the set of
+ // client-supported RPC features, determines a mutually supported
+ // authentication type to use for the connection, and sends a NEGOTIATE
+ // response.
+ Status HandleNegotiate(const NegotiatePB& request) WARN_UNUSED_RESULT;
+
+ // Handle a TLS_HANDSHAKE request message from the client.
+ Status HandleTlsHandshake(const NegotiatePB& request) WARN_UNUSED_RESULT;
+
+ // Send a TLS_HANDSHAKE response message to the client with the provided token.
+ Status SendTlsHandshake(std::string tls_token) WARN_UNUSED_RESULT;
+
+ // Authenticate the client using SASL. Populates the 'authenticated_user_'
+ // field with the SASL principal.
+ // 'recv_buf' allows a receive buffer to be reused.
+ Status AuthenticateBySasl(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+ // Authenticate the client using a token. Populates the
+ // 'authenticated_user_' field with the token's principal.
+ // 'recv_buf' allows a receive buffer to be reused.
+ Status AuthenticateByToken(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+ // Authenticate the client using the client's TLS certificate. Populates the
+ // 'authenticated_user_' field with the certificate's subject.
+ Status AuthenticateByCertificate() WARN_UNUSED_RESULT;
+
+ // Handle case when client sends SASL_INITIATE request.
+ // Returns Status::OK if the SASL negotiation is complete, or
+ // Status::Incomplete if a SASL_RESPONSE step is expected.
+ Status HandleSaslInitiate(const NegotiatePB& request) WARN_UNUSED_RESULT;
+
+ // Handle case when client sends SASL_RESPONSE request.
+ Status HandleSaslResponse(const NegotiatePB& request) WARN_UNUSED_RESULT;
+
+ // Send a SASL_CHALLENGE response to the client with a challenge token.
+ Status SendSaslChallenge(const char* challenge, unsigned clen) WARN_UNUSED_RESULT;
+
+ // Send a SASL_SUCCESS response to the client.
+ Status SendSaslSuccess() WARN_UNUSED_RESULT;
+
+ // Receive and validate the ConnectionContextPB.
+ Status RecvConnectionContext(faststring* recv_buf) WARN_UNUSED_RESULT;
+
+ // Returns true if connection is from trusted subnets or local networks.
+ bool IsTrustedConnection(const Sockaddr& addr);
+
+ // The socket to the remote client.
+ std::unique_ptr<Socket> socket_;
+
+ // SASL state.
+ std::vector<sasl_callback_t> callbacks_;
+ std::unique_ptr<sasl_conn_t, SaslDeleter> sasl_conn_;
+ SaslHelper helper_;
+ boost::optional<std::string> nonce_;
+
+ // TLS state.
+ const security::TlsContext* tls_context_;
+ security::TlsHandshake tls_handshake_;
+ const RpcEncryption encryption_;
+ bool tls_negotiated_;
+
+ // TSK state.
+ const security::TokenVerifier* token_verifier_;
+
+ // The set of features supported by the client and server. Filled in during negotiation.
+ std::set<RpcFeatureFlag> client_features_;
+ std::set<RpcFeatureFlag> server_features_;
+
+ // The successfully-authenticated user, if applicable. Filled in during
+ // negotiation.
+ RemoteUser authenticated_user_;
+
+ // The authentication type. Filled in during negotiation.
+ AuthenticationType negotiated_authn_;
+
+ // The SASL mechanism. Filled in during negotiation if the negotiated
+ // authentication type is SASL.
+ SaslMechanism::Type negotiated_mech_;
+
+ // The SASL protocol name that is used for the SASL negotiation.
+ const std::string sasl_proto_name_;
+
+ // Negotiation timeout deadline.
+ MonoTime deadline_;
+};
+
+} // namespace rpc
+} // namespace kudu
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/service_if.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/service_if.cc b/be/src/kudu/rpc/service_if.cc
new file mode 100644
index 0000000..008c478
--- /dev/null
+++ b/be/src/kudu/rpc/service_if.cc
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "kudu/rpc/service_if.h"
+
+#include <memory>
+#include <ostream>
+#include <string>
+#include <utility>
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+
+#include "kudu/gutil/macros.h"
+#include "kudu/gutil/port.h"
+#include "kudu/gutil/strings/substitute.h"
+#include "kudu/rpc/connection.h"
+#include "kudu/rpc/inbound_call.h"
+#include "kudu/rpc/remote_method.h"
+#include "kudu/rpc/result_tracker.h"
+#include "kudu/rpc/rpc_context.h"
+#include "kudu/rpc/rpc_header.pb.h"
+#include "kudu/util/flag_tags.h"
+#include "kudu/util/net/sockaddr.h"
+#include "kudu/util/net/socket.h"
+#include "kudu/util/slice.h"
+#include "kudu/util/status.h"
+
+// TODO: remove this flag once exactly-once semantics have been fully cluster-tested.
+// Despite being on by default, it is kept so that, should we discover any
+// issues in 0.10.0, there is an easy workaround to disable the feature.
+DEFINE_bool(enable_exactly_once, true, "Whether to enable exactly once semantics.");
+TAG_FLAG(enable_exactly_once, hidden);
+
+using google::protobuf::Message;
+using std::string;
+using std::unique_ptr;
+using strings::Substitute;
+
+namespace kudu {
+namespace rpc {
+
+// Nothing to release; defined out of line for the virtual destructor.
+ServiceIf::~ServiceIf() = default;
+
+// Default shutdown hook: does nothing. Subclasses may override it.
+void ServiceIf::Shutdown() {}
+
+// Default implementation: no application-specific feature flag is supported.
+bool ServiceIf::SupportsFeature(uint32_t /*feature*/) const {
+  return false;
+}
+
+// Parses the serialized request of 'call' into 'message'.
+// Returns true on success. On failure the problem is logged, the call is
+// answered with ERROR_INVALID_REQUEST, and false is returned.
+bool ServiceIf::ParseParam(InboundCall *call, google::protobuf::Message *message) {
+  Slice serialized(call->serialized_request());
+  const bool parsed = message->ParseFromArray(serialized.data(), serialized.size());
+  if (PREDICT_FALSE(!parsed)) {
+    string err = Substitute("invalid parameter for call $0: missing fields: $1",
+                            call->remote_method().ToString(),
+                            message->InitializationErrorString().c_str());
+    LOG(WARNING) << err;
+    call->RespondFailure(ErrorStatusPB::ERROR_INVALID_REQUEST,
+                         Status::InvalidArgument(err));
+    return false;
+  }
+  return true;
+}
+
+// Answers 'call' with ERROR_NO_SUCH_METHOD and logs a warning naming the
+// service, the local and remote addresses, and the unknown method.
+//
+// The addresses are only used to build the message, so a failure to look them
+// up must not bring the server down: the local address then degrades to
+// "<unknown>" and the remote address falls back to the one recorded on the
+// call itself.
+void ServiceIf::RespondBadMethod(InboundCall *call) {
+  Sockaddr local_addr;
+  const Status local_status =
+      call->connection()->socket()->GetSocketAddress(&local_addr);
+  const string local_str =
+      local_status.ok() ? local_addr.ToString() : string("<unknown>");
+
+  Sockaddr remote_addr;
+  const Status remote_status =
+      call->connection()->socket()->GetPeerAddress(&remote_addr);
+  const string remote_str =
+      remote_status.ok() ? remote_addr.ToString() : call->remote_address().ToString();
+
+  string err = Substitute("Call on service $0 received at $1 from $2 with an "
+                          "invalid method name: $3",
+                          call->remote_method().service_name(),
+                          local_str,
+                          remote_str,
+                          call->remote_method().method_name());
+  LOG(WARNING) << err;
+  call->RespondFailure(ErrorStatusPB::ERROR_NO_SUCH_METHOD,
+                       Status::InvalidArgument(err));
+}
+
+// Nothing to do beyond destroying the members.
+GeneratedServiceIf::~GeneratedServiceIf() = default;
+
+
+// Dispatches 'call' to the handler registered for the method it names.
+//
+// Steps, in order:
+//  1. If the call has no method info, respond with a bad-method error.
+//  2. Parse the request into a fresh copy of the request prototype; on
+//     failure ParseParam() has already responded to the call.
+//  3. Run the method's authorization function; if it returns false it has
+//     already responded to the RPC.
+//  4. If result tracking applies, register the RPC with the result tracker,
+//     which may answer it immediately.
+//  5. Otherwise invoke the actual handler function.
+void GeneratedServiceIf::Handle(InboundCall *call) {
+  const RpcMethodInfo* method_info = call->method_info();
+  if (!method_info) {
+    RespondBadMethod(call);
+    return;
+  }
+  unique_ptr<Message> req(method_info->req_prototype->New());
+  if (PREDICT_FALSE(!ParseParam(call, req.get()))) {
+    return;
+  }
+  Message* resp = method_info->resp_prototype->New();
+
+  // Results are tracked only when the request carries an id, the method opted
+  // in, and the feature flag is enabled.
+  bool track_result = call->header().has_request_id()
+                      && method_info->track_result
+                      && FLAGS_enable_exactly_once;
+  RpcContext* ctx = new RpcContext(call,
+                                   req.release(),
+                                   resp,
+                                   track_result ? result_tracker_ : nullptr);
+  if (!method_info->authz_method(ctx->request_pb(), resp, ctx)) {
+    // The authz_method itself should have responded to the RPC.
+    return;
+  }
+
+  if (track_result) {
+    ResultTracker::RpcState state = ctx->result_tracker()->TrackRpc(
+        call->header().request_id(),
+        resp,
+        ctx);
+    switch (state) {
+      case ResultTracker::NEW:
+        // Fall out of the 'if' statement to the normal path.
+        break;
+      case ResultTracker::COMPLETED:
+      case ResultTracker::IN_PROGRESS:
+      case ResultTracker::STALE:
+        // ResultTracker has already responded to the RPC and deleted
+        // 'ctx'.
+        return;
+      default:
+        LOG(FATAL) << "Unknown state: " << state;
+    }
+  }
+  method_info->func(ctx->request_pb(), resp, ctx);
+}
+
+
+// Returns the info registered under the method name of 'method', or nullptr
+// if this service has no method with that name.
+RpcMethodInfo* GeneratedServiceIf::LookupMethod(const RemoteMethod& method) {
+  DCHECK_EQ(method.service_name(), service_name());
+  const auto entry = methods_by_name_.find(method.method_name());
+  const bool found = entry != methods_by_name_.end();
+  if (PREDICT_FALSE(!found)) {
+    return nullptr;
+  }
+  return entry->second.get();
+}
+
+
+} // namespace rpc
+} // namespace kudu
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/service_if.h
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/service_if.h b/be/src/kudu/rpc/service_if.h
new file mode 100644
index 0000000..9156b4a
--- /dev/null
+++ b/be/src/kudu/rpc/service_if.h
@@ -0,0 +1,134 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#ifndef KUDU_RPC_SERVICE_IF_H
+#define KUDU_RPC_SERVICE_IF_H
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include <google/protobuf/message.h>
+
+#include "kudu/gutil/ref_counted.h"
+#include "kudu/util/metrics.h"
+
+namespace kudu {
+namespace rpc {
+
+class InboundCall;
+class RemoteMethod;
+class ResultTracker;
+class RpcContext;
+
+// Generated services define an instance of this class for each
+// method that they implement. The generic server code implemented
+// by GeneratedServiceIf looks up the RpcMethodInfo in order to handle
+// each RPC.
+struct RpcMethodInfo : public RefCountedThreadSafe<RpcMethodInfo> {
+  // Prototype protobufs for requests and responses.
+  // These are empty protobufs which are cloned in order to provide an
+  // instance for each request.
+  std::unique_ptr<google::protobuf::Message> req_prototype;
+  std::unique_ptr<google::protobuf::Message> resp_prototype;
+
+  scoped_refptr<Histogram> handler_latency_histogram;
+
+  // Whether we should track this method's result, using ResultTracker.
+  // Defaults to false so that an instance whose creator never sets this field
+  // is not read uninitialized.
+  bool track_result = false;
+
+  // The authorization function for this RPC. If this function
+  // returns false, the RPC has already been handled (i.e. rejected)
+  // by the authorization function.
+  std::function<bool(const google::protobuf::Message* req,
+                     google::protobuf::Message* resp,
+                     RpcContext* ctx)> authz_method;
+
+  // The actual function to be called.
+  std::function<void(const google::protobuf::Message* req,
+                     google::protobuf::Message* resp,
+                     RpcContext* ctx)> func;
+};
+
+// Interface for objects that handle the incoming messages which initiate an RPC.
+class ServiceIf {
+ public:
+  virtual ~ServiceIf();
+
+  // Handles one inbound call.
+  virtual void Handle(InboundCall* incoming) = 0;
+
+  // Shutdown hook; the base implementation does nothing.
+  virtual void Shutdown();
+
+  // Name of the service.
+  virtual std::string service_name() const = 0;
+
+  // Returns true if the service supports the provided application-specific
+  // feature flag.
+  virtual bool SupportsFeature(uint32_t feature) const;
+
+  // Looks up the method requested by the remote call.
+  //
+  // When nullptr is returned, certain functionality such as metrics
+  // collection is not performed for the call.
+  virtual RpcMethodInfo* LookupMethod(const RemoteMethod& /*method*/) {
+    return nullptr;
+  }
+
+  // Default authorization method: allows every RPC.
+  //
+  // See docs/design-docs/rpc.md for details on how to add custom
+  // authorization checks to a service.
+  bool AuthorizeAllowAll(const google::protobuf::Message* /*req*/,
+                         google::protobuf::Message* /*resp*/,
+                         RpcContext* /*ctx*/) {
+    return true;
+  }
+
+ protected:
+  // Helpers for subclasses (see service_if.cc): parse the request of a call,
+  // and reject a call that names an unknown method.
+  bool ParseParam(InboundCall* call, google::protobuf::Message* message);
+  void RespondBadMethod(InboundCall* call);
+};
+
+
+// Common base of the service classes emitted by the code generator.
+class GeneratedServiceIf : public ServiceIf {
+ public:
+  // Mapping from method names to method infos.
+  using MethodInfoMap = std::unordered_map<std::string, scoped_refptr<RpcMethodInfo>>;
+
+  ~GeneratedServiceIf() override;
+
+  // Finds the matching method in 'methods_by_name_' and runs it on the
+  // current thread.
+  //
+  // Responds with an error when no such method exists.
+  void Handle(InboundCall* incoming) override;
+
+  RpcMethodInfo* LookupMethod(const RemoteMethod& method) override;
+
+  // Read-only access to the method table.
+  const MethodInfoMap& methods_by_name() const { return methods_by_name_; }
+
+ protected:
+  // Per-method information on how to handle a call. Entries are inserted by
+  // the constructor of the generated subclass. After construction the map is
+  // read by multiple threads and therefore must not be modified.
+  MethodInfoMap methods_by_name_;
+
+  // The result tracker for this service's methods.
+  scoped_refptr<ResultTracker> result_tracker_;
+};
+
+} // namespace rpc
+} // namespace kudu
+#endif
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/service_pool.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/service_pool.cc b/be/src/kudu/rpc/service_pool.cc
new file mode 100644
index 0000000..62d46d6
--- /dev/null
+++ b/be/src/kudu/rpc/service_pool.cc
@@ -0,0 +1,234 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "kudu/rpc/service_pool.h"
+
+#include <cstdint>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <boost/optional/optional.hpp>
+#include <glog/logging.h>
+
+#include "kudu/gutil/basictypes.h"
+#include "kudu/gutil/gscoped_ptr.h"
+#include "kudu/gutil/ref_counted.h"
+#include "kudu/gutil/strings/join.h"
+#include "kudu/gutil/strings/substitute.h"
+#include "kudu/rpc/inbound_call.h"
+#include "kudu/rpc/remote_method.h"
+#include "kudu/rpc/rpc_header.pb.h"
+#include "kudu/rpc/service_if.h"
+#include "kudu/rpc/service_queue.h"
+#include "kudu/util/logging.h"
+#include "kudu/util/metrics.h"
+#include "kudu/util/net/sockaddr.h"
+#include "kudu/util/status.h"
+#include "kudu/util/thread.h"
+#include "kudu/util/trace.h"
+
+using std::shared_ptr;
+using std::string;
+using std::vector;
+using strings::Substitute;
+
+// Metric prototypes used by ServicePool; its constructor instantiates each of
+// them against the metric entity it is given.
+METRIC_DEFINE_histogram(server, rpc_incoming_queue_time,
+ "RPC Queue Time",
+ kudu::MetricUnit::kMicroseconds,
+ "Number of microseconds incoming RPC requests spend in the worker queue",
+ 60000000LU, 3);
+
+METRIC_DEFINE_counter(server, rpcs_timed_out_in_queue,
+ "RPC Queue Timeouts",
+ kudu::MetricUnit::kRequests,
+ "Number of RPCs whose timeout elapsed while waiting "
+ "in the service queue, and thus were not processed.");
+
+METRIC_DEFINE_counter(server, rpcs_queue_overflow,
+ "RPC Queue Overflows",
+ kudu::MetricUnit::kRequests,
+ "Number of RPCs dropped because the service queue "
+ "was full.");
+
+namespace kudu {
+namespace rpc {
+
+// Takes ownership of 'service', sizes the service queue to
+// 'service_queue_length', and instantiates the pool's metrics on 'entity'.
+// No worker threads are started here; that happens in Init().
+ServicePool::ServicePool(gscoped_ptr<ServiceIf> service,
+ const scoped_refptr<MetricEntity>& entity,
+ size_t service_queue_length)
+ : service_(std::move(service)),
+ service_queue_(service_queue_length),
+ incoming_queue_time_(METRIC_rpc_incoming_queue_time.Instantiate(entity)),
+ rpcs_timed_out_in_queue_(METRIC_rpcs_timed_out_in_queue.Instantiate(entity)),
+ rpcs_queue_overflow_(METRIC_rpcs_queue_overflow.Instantiate(entity)),
+ closing_(false) {
+}
+
+// Shuts the pool down; Shutdown() returns early if it has already run.
+ServicePool::~ServicePool() {
+ Shutdown();
+}
+
+// Starts 'num_threads' worker threads, each running RunThread().
+//
+// Returns OK once all threads have been created. If creating a thread fails,
+// that error is returned to the caller instead of aborting the process; the
+// threads created before the failure stay in 'threads_' and are joined by
+// Shutdown().
+Status ServicePool::Init(int num_threads) {
+  for (int i = 0; i < num_threads; i++) {
+    scoped_refptr<kudu::Thread> new_thread;
+    RETURN_NOT_OK(kudu::Thread::Create("service pool", "rpc worker",
+                                       &ServicePool::RunThread, this, &new_thread));
+    threads_.push_back(new_thread);
+  }
+  return Status::OK();
+}
+
+// Stops the pool: shuts down the queue, joins the worker threads, fails every
+// call still waiting in the queue, and finally shuts down the service. Only
+// the first invocation does the work; later ones return right after shutting
+// down the queue again.
+void ServicePool::Shutdown() {
+  service_queue_.Shutdown();
+
+  MutexLock guard(shutdown_lock_);
+  if (closing_) {
+    return;
+  }
+  closing_ = true;
+
+  // TODO: Use a proper thread pool implementation.
+  for (const auto& worker : threads_) {
+    CHECK_OK(ThreadJoiner(worker.get()).Join());
+  }
+
+  // With the workers gone, whatever is left in the queue has to be drained
+  // and answered here.
+  const Status shutting_down = Status::ServiceUnavailable("Service is shutting down");
+  std::unique_ptr<InboundCall> pending;
+  while (service_queue_.BlockingGet(&pending)) {
+    pending.release()->RespondFailure(ErrorStatusPB::FATAL_SERVER_SHUTTING_DOWN,
+                                      shutting_down);
+  }
+
+  service_->Shutdown();
+}
+
+// Rejects 'c' with ERROR_SERVER_TOO_BUSY because the service queue is full.
+// Bumps the overflow counter, logs the rejection (rate-limited), and finally
+// invokes the too-busy hook if one is set.
+void ServicePool::RejectTooBusy(InboundCall* c) {
+  const string msg = Substitute("$0 request on $1 from $2 dropped due to backpressure. "
+                                "The service queue is full; it has $3 items.",
+                                c->remote_method().method_name(),
+                                service_->service_name(),
+                                c->remote_address().ToString(),
+                                service_queue_.max_size());
+  rpcs_queue_overflow_->Increment();
+  KLOG_EVERY_N_SECS(WARNING, 1) << msg;
+  c->RespondFailure(ErrorStatusPB::ERROR_SERVER_TOO_BUSY,
+                    Status::ServiceUnavailable(msg));
+  DLOG(INFO) << msg << " Contents of service queue:\n"
+             << service_queue_.ToString();
+
+  if (!too_busy_hook_) {
+    return;
+  }
+  too_busy_hook_();
+}
+
+// Forwards the lookup to the wrapped service and returns its result unchanged.
+RpcMethodInfo* ServicePool::LookupMethod(const RemoteMethod& method) {
+ return service_->LookupMethod(method);
+}
+
+// Places 'call' on the service queue for a worker thread to pick up.
+//
+// Returns NotSupported if the call requires application feature flags the
+// service lacks, ServiceUnavailable if the queue has been shut down, and
+// RuntimeError for an unrecognized queue status. A call rejected because the
+// queue is full is answered with a too-busy error, yet OK is returned.
+Status ServicePool::QueueInboundCall(gscoped_ptr<InboundCall> call) {
+  InboundCall* c = call.release();
+
+  // Collect every required feature the service does not support.
+  vector<uint32_t> missing_features;
+  for (uint32_t required : c->GetRequiredFeatures()) {
+    if (service_->SupportsFeature(required)) {
+      continue;
+    }
+    missing_features.push_back(required);
+  }
+
+  if (!missing_features.empty()) {
+    c->RespondUnsupportedFeature(missing_features);
+    return Status::NotSupported("call requires unsupported application feature flags",
+                                JoinMapped(missing_features,
+                                           [] (uint32_t flag) { return std::to_string(flag); },
+                                           ", "));
+  }
+
+  TRACE_TO(c->trace(), "Inserting onto call queue");
+
+  // Queue message on service queue.
+  boost::optional<InboundCall*> evicted;
+  auto queue_status = service_queue_.Put(c, &evicted);
+  if (queue_status == QUEUE_FULL) {
+    RejectTooBusy(c);
+    return Status::OK();
+  }
+
+  // Putting the new call may have pushed another call out of the queue.
+  if (PREDICT_FALSE(evicted != boost::none)) {
+    RejectTooBusy(*evicted);
+  }
+
+  switch (queue_status) {
+    case QUEUE_SUCCESS:
+      // NB: do not do anything with 'c' after it is successfully queued --
+      // a service thread may have already dequeued it, processed it, and
+      // responded by this point, in which case the pointer would be invalid.
+      return Status::OK();
+    case QUEUE_SHUTDOWN: {
+      Status shutting_down = Status::ServiceUnavailable("Service is shutting down");
+      c->RespondFailure(ErrorStatusPB::FATAL_SERVER_SHUTTING_DOWN, shutting_down);
+      return shutting_down;
+    }
+    default: {
+      Status unknown = Status::RuntimeError(
+          Substitute("Unknown error from BlockingQueue: $0", queue_status));
+      c->RespondFailure(ErrorStatusPB::FATAL_UNKNOWN, unknown);
+      return unknown;
+    }
+  }
+}
+
+// Body of each worker thread: repeatedly takes a call off the service queue
+// and hands it to the service, until the queue reports that no more calls
+// will arrive. Calls whose client has already timed out are failed instead of
+// being handled.
+void ServicePool::RunThread() {
+  for (;;) {
+    std::unique_ptr<InboundCall> call;
+    if (!service_queue_.BlockingGet(&call)) {
+      VLOG(1) << "ServicePool: messenger shutting down.";
+      return;
+    }
+
+    call->RecordHandlingStarted(incoming_queue_time_.get());
+    ADOPT_TRACE(call->trace());
+
+    if (PREDICT_FALSE(call->ClientTimedOut())) {
+      TRACE_TO(call->trace(), "Skipping call since client already timed out");
+      rpcs_timed_out_in_queue_->Increment();
+
+      // Respond as a failure, even though the client will probably ignore
+      // the response anyway. The pointer is released up front because
+      // RespondFailure ends up taking ownership of the object.
+      call.release()->RespondFailure(
+          ErrorStatusPB::ERROR_SERVER_TOO_BUSY,
+          Status::TimedOut("Call waited in the queue past client deadline"));
+      continue;
+    }
+
+    TRACE_TO(call->trace(), "Handling call");
+
+    // Release the InboundCall pointer -- when the call is responded to,
+    // it will get deleted at that point.
+    service_->Handle(call.release());
+  }
+}
+
+// Returns the name reported by the wrapped service.
+const string ServicePool::service_name() const {
+ return service_->service_name();
+}
+
+} // namespace rpc
+} // namespace kudu
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/service_pool.h
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/service_pool.h b/be/src/kudu/rpc/service_pool.h
new file mode 100644
index 0000000..2bc8873
--- /dev/null
+++ b/be/src/kudu/rpc/service_pool.h
@@ -0,0 +1,117 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef KUDU_SERVICE_POOL_H
+#define KUDU_SERVICE_POOL_H
+
+#include <cstddef>
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "kudu/gutil/gscoped_ptr.h"
+#include "kudu/gutil/macros.h"
+#include "kudu/gutil/port.h"
+#include "kudu/gutil/ref_counted.h"
+#include "kudu/rpc/rpc_service.h"
+#include "kudu/rpc/service_queue.h"
+#include "kudu/util/mutex.h"
+#include "kudu/util/status.h"
+
+namespace kudu {
+
+class Counter;
+class Histogram;
+class MetricEntity;
+class Thread;
+
+namespace rpc {
+
+class InboundCall;
+class RemoteMethod;
+class ServiceIf;
+
+struct RpcMethodInfo;
+
+// A pool of threads that handle new incoming RPC calls.
+// Also includes a queue that calls get pushed onto for handling by the pool.
+class ServicePool : public RpcService {
+ public:
+ ServicePool(gscoped_ptr<ServiceIf> service,  // Service to wrap; passed as an owning pointer.
+ const scoped_refptr<MetricEntity>& metric_entity,
+ size_t service_queue_length);  // NOTE(review): presumably the capacity of service_queue_ -- confirm in the .cc.
+ virtual ~ServicePool();
+
+ // Set a hook function to be called when any RPC gets rejected because
+ // the service queue is full.
+ //
+ // NOTE: This hook runs on a reactor thread so must execute quickly.
+ // Additionally, if a service queue is overflowing, the server is likely
+ // under a lot of load, so hooks should be careful to throttle their own
+ // execution.
+ void set_too_busy_hook(std::function<void(void)> hook) {
+ too_busy_hook_ = std::move(hook);
+ }
+
+ // Start up the thread pool.
+ virtual Status Init(int num_threads);
+
+ // Shut down the queue and the thread pool.
+ virtual void Shutdown();
+
+ RpcMethodInfo* LookupMethod(const RemoteMethod& method) override;
+
+ virtual Status QueueInboundCall(gscoped_ptr<InboundCall> call) OVERRIDE;  // NOTE(review): presumably pushes 'call' onto service_queue_ -- confirm in the .cc.
+
+ const Counter* RpcsTimedOutInQueueMetricForTests() const {  // Counter incremented by RunThread() when a dequeued call's client has already timed out.
+ return rpcs_timed_out_in_queue_.get();
+ }
+
+ const Histogram* IncomingQueueTimeMetricForTests() const {  // Histogram that RunThread() passes to InboundCall::RecordHandlingStarted().
+ return incoming_queue_time_.get();
+ }
+
+ const Counter* RpcsQueueOverflowMetric() const {  // Exposes rpcs_queue_overflow_.
+ return rpcs_queue_overflow_.get();
+ }
+
+ const std::string service_name() const;  // Forwards to service_->service_name().
+
+ private:
+ void RunThread();  // Worker loop: takes calls from service_queue_ and passes them to service_->Handle().
+ void RejectTooBusy(InboundCall* c);  // NOTE(review): definition not visible here; presumably fails 'c' and runs too_busy_hook_ -- confirm.
+
+ gscoped_ptr<ServiceIf> service_;  // The wrapped service, held by owning pointer.
+ std::vector<scoped_refptr<kudu::Thread> > threads_;  // NOTE(review): presumably the worker threads started by Init() -- confirm.
+ LifoServiceQueue service_queue_;  // Queue that RunThread() blocks on for incoming calls.
+ scoped_refptr<Histogram> incoming_queue_time_;
+ scoped_refptr<Counter> rpcs_timed_out_in_queue_;
+ scoped_refptr<Counter> rpcs_queue_overflow_;
+
+ mutable Mutex shutdown_lock_;  // NOTE(review): presumably guards closing_ -- confirm in the .cc.
+ bool closing_;
+
+ std::function<void(void)> too_busy_hook_;  // Installed via set_too_busy_hook().
+
+ DISALLOW_COPY_AND_ASSIGN(ServicePool);
+};
+
+} // namespace rpc
+} // namespace kudu
+
+#endif
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/service_queue-test.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/service_queue-test.cc b/be/src/kudu/rpc/service_queue-test.cc
new file mode 100644
index 0000000..f1450fd
--- /dev/null
+++ b/be/src/kudu/rpc/service_queue-test.cc
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <atomic>
+#include <cstdint>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include <boost/optional/optional.hpp>
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "kudu/gutil/atomicops.h"
+#include "kudu/gutil/port.h"
+#include "kudu/rpc/inbound_call.h"
+#include "kudu/rpc/service_queue.h"
+#include "kudu/util/monotime.h"
+#include "kudu/util/stopwatch.h"
+#include "kudu/util/test_util.h"
+
+using std::shared_ptr;
+using std::string;
+using std::unique_ptr;
+using std::vector;
+
+DEFINE_int32(num_producers, 4,
+ "Number of producer threads");
+
+DEFINE_int32(num_consumers, 20,
+ "Number of consumer threads");
+
+DEFINE_int32(max_queue_size, 50,
+ "Max queue length");
+
+namespace kudu {
+namespace rpc {
+
+static std::atomic<uint32_t> inprogress;
+
+static std::atomic<uint32_t> total;
+
+// Load-generating thread for the queue benchmark.
+//
+// Repeatedly allocates an InboundCall and Put()s it on 'queue'. The shared
+// 'inprogress' counter tracks calls that were produced but not yet consumed;
+// producers spin while it exceeds (max_queue_size - num_producers) so that
+// the queue is normally never overfilled.
+//
+// The thread exits as soon as Put() reports QUEUE_FULL, reports an evicted
+// call, or reports QUEUE_SHUTDOWN. In each of those cases the producer itself
+// deletes a call that no consumer will ever see, so it also removes that call
+// from 'inprogress' to keep the counter accurate for the remaining producers.
+template <typename Queue>
+void ProducerThread(Queue* queue) {
+  const int max_inprogress = FLAGS_max_queue_size - FLAGS_num_producers;
+  // 'inprogress' is unsigned, so the comparison below is performed on
+  // unsigned values; spell the conversion out instead of relying on the
+  // implicit one. A negative limit therefore wraps and disables throttling.
+  const uint32_t inprogress_limit = static_cast<uint32_t>(max_inprogress);
+  while (true) {
+    while (inprogress > inprogress_limit) {
+      base::subtle::PauseCPU();
+    }
+    inprogress++;
+    InboundCall* call = new InboundCall(nullptr);
+    boost::optional<InboundCall*> evicted;
+    auto status = queue->Put(call, &evicted);
+    if (status == QUEUE_FULL) {
+      LOG(INFO) << "queue full: producer exiting";
+      // 'call' never entered the queue.
+      delete call;
+      inprogress--;
+      break;
+    }
+
+    if (PREDICT_FALSE(evicted != boost::none)) {
+      LOG(INFO) << "call evicted: producer exiting";
+      // The evicted call was removed from the queue and will not be consumed.
+      delete evicted.get();
+      inprogress--;
+      break;
+    }
+
+    // Shutdown happens at most once per producer, so this is the cold path.
+    if (PREDICT_FALSE(status == QUEUE_SHUTDOWN)) {
+      delete call;
+      inprogress--;
+      break;
+    }
+  }
+}
+
+// Draining thread for the queue benchmark: takes calls off 'queue' until
+// BlockingGet() returns false, updating the shared counters for each call
+// and then destroying it.
+template <typename Queue>
+void ConsumerThread(Queue* queue) {
+  unique_ptr<InboundCall> received;
+  for (;;) {
+    if (!queue->BlockingGet(&received)) {
+      return;
+    }
+    --inprogress;
+    ++total;
+    // Free the call now rather than when the next one overwrites it.
+    received.reset();
+  }
+}
+
+// Throughput benchmark for LifoServiceQueue: runs producer and consumer
+// threads against one queue for a fixed wall-clock period, sampling the
+// queue's estimated length and idle-worker count every 20ms, then logs
+// requests/sec, CPU cost per request and the sampled averages.
+TEST(TestServiceQueue, LifoServiceQueuePerf) {
+  LifoServiceQueue queue(FLAGS_max_queue_size);
+
+  vector<std::thread> producers;
+  for (int p = 0; p < FLAGS_num_producers; ++p) {
+    producers.emplace_back(&ProducerThread<LifoServiceQueue>, &queue);
+  }
+
+  vector<std::thread> consumers;
+  for (int c = 0; c < FLAGS_num_consumers; ++c) {
+    consumers.emplace_back(&ConsumerThread<LifoServiceQueue>, &queue);
+  }
+
+  // 50 samples per second, 20ms apart.
+  int seconds = AllowSlowTests() ? 10 : 1;
+  const int num_samples = seconds * 50;
+
+  uint64_t total_sample = 0;
+  uint64_t total_queue_len = 0;
+  uint64_t total_idle_workers = 0;
+
+  Stopwatch sw(Stopwatch::ALL_THREADS);
+  sw.start();
+  int32_t before = total;
+
+  for (int sample = 0; sample < num_samples; ++sample) {
+    SleepFor(MonoDelta::FromMilliseconds(20));
+    ++total_sample;
+    total_queue_len += queue.estimated_queue_length();
+    total_idle_workers += queue.estimated_idle_worker_count();
+  }
+
+  sw.stop();
+  int32_t delta = total - before;
+
+  // Shutting the queue down is what lets both thread groups exit.
+  queue.Shutdown();
+  for (std::thread& producer : producers) {
+    producer.join();
+  }
+  for (std::thread& consumer : consumers) {
+    consumer.join();
+  }
+
+  float reqs_per_second =
+      static_cast<float>(delta / sw.elapsed().wall_seconds());
+  float user_cpu_micros_per_req =
+      static_cast<float>(sw.elapsed().user / 1000.0 / delta);
+  float sys_cpu_micros_per_req =
+      static_cast<float>(sw.elapsed().system / 1000.0 / delta);
+
+  const double samples = static_cast<double>(total_sample);
+  LOG(INFO) << "Reqs/sec: " << static_cast<int32_t>(reqs_per_second);
+  LOG(INFO) << "User CPU per req: " << user_cpu_micros_per_req << "us";
+  LOG(INFO) << "Sys CPU per req: " << sys_cpu_micros_per_req << "us";
+  LOG(INFO) << "Avg rpc queue length: " << total_queue_len / samples;
+  LOG(INFO) << "Avg idle workers: " << total_idle_workers / samples;
+}
+
+} // namespace rpc
+} // namespace kudu
http://git-wip-us.apache.org/repos/asf/impala/blob/fcf190c4/be/src/kudu/rpc/service_queue.cc
----------------------------------------------------------------------
diff --git a/be/src/kudu/rpc/service_queue.cc b/be/src/kudu/rpc/service_queue.cc
new file mode 100644
index 0000000..29c0516
--- /dev/null
+++ b/be/src/kudu/rpc/service_queue.cc
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "kudu/rpc/service_queue.h"
+
+#include <mutex>
+#include <ostream>
+
+#include <boost/optional/optional.hpp>
+
+#include "kudu/gutil/port.h"
+
+namespace kudu {
+namespace rpc {
+
+__thread LifoServiceQueue::ConsumerState* LifoServiceQueue::tl_consumer_ = nullptr;  // Per-thread consumer state; created on a thread's first BlockingGet() call.
+
+// Creates a queue that is not shut down and whose size limit is 'max_size'.
+// A non-positive 'max_size' fails the CHECK below.
+LifoServiceQueue::LifoServiceQueue(int max_size)
+    : shutdown_(false),
+      max_queue_size_(max_size) {
+  CHECK_GT(max_queue_size_, 0);
+}
+
+// The queue stores raw InboundCall pointers, so debug builds assert that it
+// has been fully drained before it is destroyed.
+LifoServiceQueue::~LifoServiceQueue() {
+  DCHECK(queue_.empty())
+      << "ServiceQueue holds bare pointers at destruction time";
+}
+
+// Blocks until a call is available or the queue is shut down.
+//
+// On success stores the call in '*out' (transferring ownership to the
+// caller) and returns true. Returns false once the queue is shut down and
+// nothing is left in it; '*out' is not modified in that case.
+bool LifoServiceQueue::BlockingGet(std::unique_ptr<InboundCall>* out) {
+  // Lazily create and register this thread's consumer state on first use.
+  ConsumerState* me = tl_consumer_;
+  if (PREDICT_FALSE(me == nullptr)) {
+    me = new ConsumerState(this);
+    tl_consumer_ = me;
+    std::lock_guard<simple_spinlock> guard(lock_);
+    consumers_.emplace_back(me);
+  }
+
+  for (;;) {
+    {
+      std::lock_guard<simple_spinlock> guard(lock_);
+      if (!queue_.empty()) {
+        // Take the first queued call.
+        const auto first = queue_.begin();
+        out->reset(*first);
+        queue_.erase(first);
+        return true;
+      }
+      if (PREDICT_FALSE(shutdown_)) {
+        return false;
+      }
+      // Nothing queued: register as a waiter so Put() can hand a call over
+      // directly.
+      me->DCheckBoundInstance(this);
+      waiting_consumers_.push_back(me);
+    }
+
+    // Wait outside the lock for a producer (or Shutdown()) to post to us.
+    if (InboundCall* handed_off = me->Wait()) {
+      out->reset(handed_off);
+      return true;
+    }
+    // A null post means the queue is being shut down; go around again and
+    // re-check 'shutdown_' under the lock.
+  }
+}
+
+// Offers 'call' to the queue.
+//
+// Returns QUEUE_SHUTDOWN if the queue has been shut down, QUEUE_FULL if the
+// queue is at capacity and 'call' does not displace the last queued entry,
+// and QUEUE_SUCCESS otherwise. When an entry is displaced to make room it is
+// removed from the queue and stored in '*evicted'.
+QueueStatus LifoServiceQueue::Put(InboundCall* call,
+                                  boost::optional<InboundCall*>* evicted) {
+  std::unique_lock<simple_spinlock> guard(lock_);
+  if (PREDICT_FALSE(shutdown_)) {
+    return QUEUE_SHUTDOWN;
+  }
+
+  // Consumers only wait when nothing is queued, so both can't be non-empty.
+  DCHECK(waiting_consumers_.empty() || queue_.empty());
+
+  // Fast path: a consumer is idle, so give it the call directly.
+  if (queue_.empty() && !waiting_consumers_.empty()) {
+    auto* idle_consumer = waiting_consumers_.back();
+    waiting_consumers_.pop_back();
+    // Signalling the condition variable and waking the consumer thread is
+    // slow, so do it after dropping the spinlock.
+    guard.unlock();
+    idle_consumer->Post(call);
+    return QUEUE_SUCCESS;
+  }
+
+  if (PREDICT_FALSE(queue_.size() >= max_queue_size_)) {
+    // At capacity: the last entry is the eviction candidate.
+    DCHECK_EQ(queue_.size(), max_queue_size_);
+    auto last = queue_.end();
+    --last;
+    if (DeadlineLess(*last, call)) {
+      // The queued entry orders before the new call, so reject the new call.
+      return QUEUE_FULL;
+    }
+
+    *evicted = *last;
+    queue_.erase(last);
+  }
+
+  queue_.insert(call);
+  return QUEUE_SUCCESS;
+}
+
+// Marks the queue as shut down and wakes every consumer currently waiting
+// in BlockingGet() by posting a null call to it.
+void LifoServiceQueue::Shutdown() {
+  std::lock_guard<simple_spinlock> guard(lock_);
+  shutdown_ = true;
+
+  // A null post tells the waiter to go back and observe 'shutdown_'.
+  for (auto it = waiting_consumers_.begin(); it != waiting_consumers_.end(); ++it) {
+    (*it)->Post(nullptr);
+  }
+  waiting_consumers_.clear();
+}
+
+// Reports whether no calls are currently queued. Takes the queue spinlock.
+bool LifoServiceQueue::empty() const {
+  std::lock_guard<simple_spinlock> guard(lock_);
+  const bool nothing_queued = queue_.empty();
+  return nothing_queued;
+}
+
+// Returns the size limit this queue was constructed with.
+int LifoServiceQueue::max_size() const {
+  const int limit = max_queue_size_;
+  return limit;
+}
+
+// Returns the ToString() of every queued call, one per line, in the queue's
+// iteration order. Takes the queue spinlock while building the string.
+std::string LifoServiceQueue::ToString() const {
+  std::string dump;
+
+  std::lock_guard<simple_spinlock> guard(lock_);
+  for (auto it = queue_.begin(); it != queue_.end(); ++it) {
+    dump += (*it)->ToString();
+    dump += "\n";
+  }
+  return dump;
+}
+
+} // namespace rpc
+} // namespace kudu