You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2020/11/25 21:02:19 UTC
[arrow] branch master updated: ARROW-10487 [FlightRPC][C++]
Header-based auth in clients
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 65aa527 ARROW-10487 [FlightRPC][C++] Header-based auth in clients
65aa527 is described below
commit 65aa527707d7067bff833bd744ebfd97f1c2c909
Author: Lyndon Bauto <ly...@bitquilltech.com>
AuthorDate: Wed Nov 25 16:01:03 2020 -0500
ARROW-10487 [FlightRPC][C++] Header-based auth in clients
Added support for header based authentication in clients.
- Added support for base64-encoded username/password authentication, matching the Java implementation
- Added receiving of the bearer token and populating of call options so it can be sent back
- Added unit tests that connect a C++ client to a mock C++ server to test authentication
Closes #8724 from lyndonb-bq/jduo/lyndon/flight-auth-cpp-redesign-client
Authored-by: Lyndon Bauto <ly...@bitquilltech.com>
Signed-off-by: David Li <li...@gmail.com>
---
cpp/src/arrow/flight/CMakeLists.txt | 1 +
cpp/src/arrow/flight/client.cc | 34 +++++
cpp/src/arrow/flight/client.h | 14 ++
cpp/src/arrow/flight/client_header_internal.cc | 92 +++++++++++++
cpp/src/arrow/flight/client_header_internal.h | 57 ++++++++
cpp/src/arrow/flight/flight_test.cc | 180 +++++++++++++++++++++++++
6 files changed, 378 insertions(+)
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index f0c23ef..86e3c51 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -118,6 +118,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}")
# protobuf-internal.cc
set(ARROW_FLIGHT_SRCS
client.cc
+ client_header_internal.cc
internal.cc
protocol_internal.cc
serialization_internal.cc
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index cdffa7f..5c56e64 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -50,6 +50,7 @@
#include "arrow/util/uri.h"
#include "arrow/flight/client_auth.h"
+#include "arrow/flight/client_header_internal.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/internal.h"
#include "arrow/flight/middleware.h"
@@ -104,6 +105,9 @@ struct ClientRpc {
std::chrono::system_clock::now() + options.timeout);
context.set_deadline(deadline);
}
+ for (auto header : options.headers) {
+ context.AddMetadata(header.first, header.second);
+ }
}
/// \brief Add an auth token via an auth handler
@@ -994,6 +998,30 @@ class FlightClient::FlightClientImpl {
return Status::OK();
}
+  /// \brief Authenticate with HTTP basic credentials carried in request headers.
+  ///
+  /// Opens a Handshake stream whose outgoing metadata carries the basic
+  /// credentials, closes the client side without sending any messages, waits
+  /// for the call to finish, and then extracts a bearer token (if any) from the
+  /// server's response metadata.
+  ///
+  /// \return the (header name, header value) pair of the bearer token, or a
+  ///     pair of empty strings if the server did not send one; an error status
+  ///     if the call failed or the client could not finish writing.
+  arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
+      const FlightCallOptions& options, const std::string& username,
+      const std::string& password) {
+    // Add basic auth headers to outgoing headers.
+    ClientRpc rpc(options);
+    internal::AddBasicAuthHeaders(&rpc.context, username, password);
+
+    auto stream = stub_->Handshake(&rpc.context);
+
+    // No handshake messages are exchanged: the credentials travel in the
+    // headers, so explicitly close our side of the stream right away.
+    const bool finished_writes = stream->WritesDone();
+    // Finish() completes the call; the server's trailing metadata read below is
+    // only available afterwards.
+    RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context));
+    if (!finished_writes) {
+      return MakeFlightError(FlightStatusCode::Internal,
+                             "Could not finish writing before closing");
+    }
+
+    // Grab bearer token from incoming headers.
+    return internal::GetBearerTokenHeader(rpc.context);
+  }
+
Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
std::unique_ptr<FlightListing>* listing) {
pb::Criteria pb_criteria;
@@ -1198,6 +1226,12 @@ Status FlightClient::Authenticate(const FlightCallOptions& options,
return impl_->Authenticate(options, std::move(auth_handler));
}
+// Public entry point; the work is delegated to the pimpl implementation.
+arrow::Result<std::pair<std::string, std::string>> FlightClient::AuthenticateBasicToken(
+    const FlightCallOptions& options, const std::string& username,
+    const std::string& password) {
+  FlightClientImpl& impl = *impl_;
+  return impl.AuthenticateBasicToken(options, username, password);
+}
+
Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action,
std::unique_ptr<ResultStream>* results) {
return impl_->DoAction(options, action, results);
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 935e8fb..441f114 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -29,6 +29,7 @@
#include "arrow/ipc/options.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
+#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/variant.h"
@@ -65,6 +66,9 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions {
/// \brief IPC writer options, if applicable for the call.
ipc::IpcWriteOptions write_options;
+
+ /// \brief Headers for client to add to context.
+ std::vector<std::pair<std::string, std::string>> headers;
};
/// \brief Indicate that the client attempted to write a message
@@ -191,6 +195,16 @@ class ARROW_FLIGHT_EXPORT FlightClient {
Status Authenticate(const FlightCallOptions& options,
std::unique_ptr<ClientAuthHandler> auth_handler);
+  /// \brief Authenticate to the server using basic HTTP style authentication.
+  ///
+  /// The username and password are base64-encoded and sent in an
+  /// "authorization" header of a Handshake call; no handshake messages are
+  /// exchanged.
+  ///
+  /// \param[in] options Per-RPC options
+  /// \param[in] username Username to use
+  /// \param[in] password Password to use
+  /// \return Arrow result holding, if the client authenticated successfully, a
+  ///     (header name, header value) pair for the bearer token sent back by the
+  ///     server (e.g. {"authorization", "Bearer <token>"}). Both strings are
+  ///     empty if the server did not return a bearer token. The pair can be
+  ///     appended to FlightCallOptions::headers to authenticate later calls.
+  arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
+      const FlightCallOptions& options, const std::string& username,
+      const std::string& password);
+
/// \brief Perform the indicated action, returning an iterator to the stream
/// of results, if any
/// \param[in] options Per-RPC options
diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc
new file mode 100644
index 0000000..2112b41
--- /dev/null
+++ b/cpp/src/arrow/flight/client_header_internal.cc
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Internal helpers for header-based (basic/bearer) authentication in Flight
+// clients. Currently experimental.
+
+#include "arrow/flight/client_header_internal.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/client_auth.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/make_unique.h"
+
+#include <algorithm>
+#include <cctype>
+#include <memory>
+#include <string>
+
+// Metadata key and value prefixes used for header-based authentication.
+constexpr char kAuthHeader[] = "authorization";
+constexpr char kBearerPrefix[] = "Bearer ";
+constexpr char kBasicPrefix[] = "Basic ";
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+/// \brief Attach HTTP basic credentials to an outgoing call.
+///
+/// Joins the credentials as "username:password", base64-encodes the result,
+/// and adds it to the context as the "authorization" metadata entry with a
+/// "Basic " prefix.
+///
+/// \param context Context object to add the headers to.
+/// \param username Username to format and encode.
+/// \param password Password to format and encode.
+void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username,
+                         const std::string& password) {
+  std::string credentials(username);
+  credentials += ':';
+  credentials += password;
+
+  const auto* raw_bytes = reinterpret_cast<const unsigned char*>(credentials.c_str());
+  const auto raw_length = static_cast<unsigned int>(credentials.size());
+
+  std::string header_value(kBasicPrefix);
+  header_value += arrow::util::base64_encode(raw_bytes, raw_length);
+
+  context->AddMetadata(kAuthHeader, header_value);
+}
+
+/// \brief Get bearer token from inbound headers.
+///
+/// Looks for an "authorization" entry in the server's trailing metadata first,
+/// then in its initial metadata. gRPC only makes trailing metadata available
+/// once the call has completed, so call this after Finish().
+///
+/// \param context ClientContext of a finished call, holding the server's headers.
+/// \return Arrow result with the {"authorization", value} pair when the value
+///     starts with "Bearer " (case-insensitively) followed by a non-empty token;
+///     otherwise a pair of empty strings.
+arrow::Result<std::pair<std::string, std::string>> GetBearerTokenHeader(
+    grpc::ClientContext& context) {
+  // Case-insensitive character comparison. The casts to unsigned char avoid
+  // undefined behavior when ::toupper receives a negative char value.
+  auto char_compare = [](char char1, char char2) {
+    return ::toupper(static_cast<unsigned char>(char1)) ==
+           ::toupper(static_cast<unsigned char>(char2));
+  };
+
+  // Get the auth token if it exists, this can be in the initial or the trailing
+  // metadata. The maps are owned by the context, so bind by reference instead
+  // of copying them.
+  const auto& trailing_headers = context.GetServerTrailingMetadata();
+  const auto& initial_headers = context.GetServerInitialMetadata();
+  auto bearer_iter = trailing_headers.find(kAuthHeader);
+  if (bearer_iter == trailing_headers.end()) {
+    bearer_iter = initial_headers.find(kAuthHeader);
+    if (bearer_iter == initial_headers.end()) {
+      return std::make_pair("", "");
+    }
+  }
+
+  // Check if the value of the auth token starts with the bearer prefix and
+  // latch it. sizeof counts the terminating NUL, hence the - 1.
+  constexpr std::size_t kBearerPrefixLength = sizeof(kBearerPrefix) - 1;
+  const std::string bearer_val(bearer_iter->second.data(), bearer_iter->second.size());
+  if (bearer_val.size() > kBearerPrefixLength &&
+      std::equal(bearer_val.begin(), bearer_val.begin() + kBearerPrefixLength,
+                 kBearerPrefix, char_compare)) {
+    return std::make_pair(kAuthHeader, bearer_val);
+  }
+
+  // The server is not required to provide a bearer token.
+  return std::make_pair("", "");
+}
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/client_header_internal.h b/cpp/src/arrow/flight/client_header_internal.h
new file mode 100644
index 0000000..718848a
--- /dev/null
+++ b/cpp/src/arrow/flight/client_header_internal.h
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Internal helpers for header-based (basic/bearer) authentication in Flight
+// clients. Currently experimental.
+
+#pragma once
+
+#include "arrow/flight/client_middleware.h"
+#include "arrow/result.h"
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/grpcpp.h>
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+#include <grpcpp/security/tls_credentials_options.h>
+#endif
+#else
+#include <grpc++/grpc++.h>
+#endif
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+/// \brief Add basic authentication header key value pair to context.
+///
+/// Adds an "authorization" entry whose value is "Basic " followed by the
+/// base64 encoding of "username:password".
+///
+/// \param context grpc context variable to add header to.
+/// \param username username to encode into header.
+/// \param password password to encode into header.
+void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context,
+                                             const std::string& username,
+                                             const std::string& password);
+
+/// \brief Get bearer token from incoming headers.
+///
+/// Checks the server's trailing metadata, then its initial metadata, for an
+/// "authorization" entry whose value starts with "Bearer ".
+///
+/// \param context context of a finished call (trailing metadata is only
+///     available once the call has completed) holding the server's headers.
+/// \return (header name, header value) pair for the bearer token, or a pair
+///     of empty strings if no bearer token is present.
+arrow::Result<std::pair<std::string, std::string>> ARROW_FLIGHT_EXPORT
+GetBearerTokenHeader(grpc::ClientContext& context);
+
+} // namespace internal
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 95048f6..2868f84 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -36,6 +36,7 @@
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/util.h"
+#include "arrow/util/base64.h"
#include "arrow/util/logging.h"
#include "arrow/util/make_unique.h"
@@ -43,6 +44,7 @@
#error "gRPC headers should not be in public API"
#endif
+#include "arrow/flight/client_header_internal.h"
#include "arrow/flight/internal.h"
#include "arrow/flight/middleware_internal.h"
#include "arrow/flight/test_util.h"
@@ -52,6 +54,15 @@ namespace pb = arrow::flight::protocol;
namespace arrow {
namespace flight {
+// Credentials, token and header strings used by the header-based auth tests.
+constexpr char kValidUsername[] = "flight_username";
+constexpr char kValidPassword[] = "flight_password";
+constexpr char kInvalidUsername[] = "invalid_flight_username";
+constexpr char kInvalidPassword[] = "invalid_flight_password";
+constexpr char kBearerToken[] = "bearertoken";
+constexpr char kBasicPrefix[] = "Basic ";
+constexpr char kBearerPrefix[] = "Bearer ";
+constexpr char kAuthHeader[] = "authorization";
+
void AssertEqual(const ActionType& expected, const ActionType& actual) {
ASSERT_EQ(expected.type, actual.type);
ASSERT_EQ(expected.description, actual.description);
@@ -559,6 +570,14 @@ class OptionsTestServer : public FlightServerBase {
}
};
+// Minimal server for the header-auth tests: ListFlights succeeds without
+// populating `listings`; the tests only need the call to reach the middleware.
+class HeaderAuthTestServer : public FlightServerBase {
+ public:
+  Status ListFlights(const ServerCallContext& /*context*/, const Criteria* /*criteria*/,
+                     std::unique_ptr<FlightListing>* /*listings*/) override {
+    return Status::OK();
+  }
+};
+
class TestMetadata : public ::testing::Test {
public:
void SetUp() {
@@ -791,6 +810,113 @@ class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
}
};
+// Look up `key` in `incoming_headers`. If its value starts with `prefix`
+// (compared case-insensitively) and has at least one character after it,
+// return the part of the value after the prefix; otherwise, including when the
+// key is absent, return an empty string.
+std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers,
+                                          const std::string& key,
+                                          const std::string& prefix) {
+  // Case-insensitive character comparison; the casts to unsigned char avoid
+  // undefined behavior when ::toupper receives a negative char value.
+  auto char_compare = [](char char1, char char2) {
+    return ::toupper(static_cast<unsigned char>(char1)) ==
+           ::toupper(static_cast<unsigned char>(char2));
+  };
+
+  auto iter = incoming_headers.find(key);
+  if (iter == incoming_headers.end()) {
+    return "";
+  }
+  const std::string val = iter->second.to_string();
+  if (val.size() > prefix.length() &&
+      std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(),
+                 char_compare)) {
+    return val.substr(prefix.length());
+  }
+  return "";
+}
+
+// Server middleware installed for calls with valid basic credentials; it sends
+// "authorization: Bearer <kBearerToken>" back in the response headers.
+class HeaderAuthServerMiddleware : public ServerMiddleware {
+ public:
+  void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+    std::string bearer_value(kBearerPrefix);
+    bearer_value.append(kBearerToken);
+    outgoing_headers->AddHeader(kAuthHeader, bearer_value);
+  }
+
+  void CallCompleted(const Status& /*status*/) override {}
+
+  std::string name() const override { return "HeaderAuthServerMiddleware"; }
+};
+
+// Extract the username and password from a "Basic" authorization header.
+// The decoded credentials have the form "username:password". RFC 7617 forbids
+// a colon in the user-id but allows one in the password, so split on the first
+// colon only. Without any colon, the whole decoded value becomes the username
+// and the password is cleared.
+void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username,
+                      std::string& password) {
+  const std::string encoded_credentials =
+      FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
+  const std::string decoded_credentials = arrow::util::base64_decode(encoded_credentials);
+  const std::string::size_type colon = decoded_credentials.find(':');
+  if (colon == std::string::npos) {
+    username = decoded_credentials;
+    password.clear();
+    return;
+  }
+  username = decoded_credentials.substr(0, colon);
+  password = decoded_credentials.substr(colon + 1);
+}
+
+// Factory for base64 header authentication testing. It decodes the basic
+// credentials of each call. The valid test credentials get a
+// HeaderAuthServerMiddleware, and the known-invalid ones are rejected with
+// Unauthenticated. Any other call proceeds without middleware.
+class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+  HeaderAuthServerMiddlewareFactory() = default;
+
+  Status StartCall(const CallInfo& /*info*/, const CallHeaders& incoming_headers,
+                   std::shared_ptr<ServerMiddleware>* middleware) override {
+    std::string username;
+    std::string password;
+    ParseBasicHeader(incoming_headers, username, password);
+
+    const bool is_valid = username == kValidUsername && password == kValidPassword;
+    if (is_valid) {
+      *middleware = std::make_shared<HeaderAuthServerMiddleware>();
+      return Status::OK();
+    }
+
+    const bool is_invalid =
+        username == kInvalidUsername && password == kInvalidPassword;
+    if (is_invalid) {
+      return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials");
+    }
+    return Status::OK();
+  }
+};
+
+// A server middleware for validating incoming bearer header authentication.
+// When response headers are about to be sent, it writes through `isValid`
+// whether the request carried "authorization: Bearer <kBearerToken>".
+class BearerAuthServerMiddleware : public ServerMiddleware {
+ public:
+  explicit BearerAuthServerMiddleware(const CallHeaders& incoming_headers, bool* isValid)
+      : incoming_headers_(incoming_headers), is_valid_(isValid) {}
+
+  void SendingHeaders(AddCallHeaders* /*outgoing_headers*/) override {
+    const std::string bearer_token =
+        FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBearerPrefix);
+    *is_valid_ = bearer_token == kBearerToken;
+  }
+
+  void CallCompleted(const Status& /*status*/) override {}
+
+  std::string name() const override { return "BearerAuthServerMiddleware"; }
+
+ private:
+  // Copy of the request headers. NOTE(review): if CallHeaders holds string
+  // views, this assumes the underlying header storage outlives the middleware;
+  // confirm.
+  CallHeaders incoming_headers_;
+  bool* is_valid_;
+};
+
+// Factory for bearer token authentication testing. It installs a
+// BearerAuthServerMiddleware on every call that carries an "authorization"
+// header and exposes the validation result that middleware records.
+class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+  BearerAuthServerMiddlewareFactory() : isValid_(false) {}
+
+  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+                   std::shared_ptr<ServerMiddleware>* middleware) override {
+    // Only calls presenting an "authorization" header get the middleware.
+    if (incoming_headers.find(kAuthHeader) != incoming_headers.end()) {
+      *middleware =
+          std::make_shared<BearerAuthServerMiddleware>(incoming_headers, &isValid_);
+    }
+    return Status::OK();
+  }
+
+  // Result written by the last BearerAuthServerMiddleware to run
+  // SendingHeaders; false if none has run yet. NOTE(review): presumably
+  // written on a server thread and read by the test thread without explicit
+  // synchronization; confirm this is acceptable.
+  bool GetIsValid() const { return isValid_; }
+
+ private:
+  bool isValid_;
+};
+
// A client middleware that adds a thread-local "request ID" to
// outgoing calls as a header, and keeps track of the status of
// completed calls. NOT thread-safe.
@@ -1010,6 +1136,56 @@ class TestErrorMiddleware : public ::testing::Test {
std::unique_ptr<FlightServerBase> server_;
};
+// Fixture running a server with basic-auth and bearer-auth middleware. It
+// checks that a client can exchange basic credentials for a bearer token and
+// then use that token on a subsequent call.
+class TestBasicHeaderAuthMiddleware : public ::testing::Test {
+ public:
+  void SetUp() override {
+    header_middleware_ = std::make_shared<HeaderAuthServerMiddlewareFactory>();
+    bearer_middleware_ = std::make_shared<BearerAuthServerMiddlewareFactory>();
+    ASSERT_OK(MakeServer<HeaderAuthTestServer>(
+        &server_, &client_,
+        [&](FlightServerOptions* options) {
+          options->auth_handler =
+              std::unique_ptr<ServerAuthHandler>(new NoOpAuthHandler());
+          options->middleware.push_back({"header-auth-server", header_middleware_});
+          options->middleware.push_back({"bearer-auth-server", bearer_middleware_});
+          return Status::OK();
+        },
+        [&](FlightClientOptions* options) { return Status::OK(); }));
+  }
+
+  // Valid credentials yield a bearer token header that authenticates ListFlights.
+  void RunValidClientAuth() {
+    arrow::Result<std::pair<std::string, std::string>> bearer_result =
+        client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword);
+    ASSERT_OK(bearer_result.status());
+    const std::pair<std::string, std::string>& bearer = bearer_result.ValueOrDie();
+    ASSERT_EQ(bearer.first, kAuthHeader);
+    ASSERT_EQ(bearer.second, std::string(kBearerPrefix) + kBearerToken);
+
+    // Send the token back on the next call and check the server accepted it.
+    std::unique_ptr<FlightListing> listing;
+    FlightCallOptions call_options;
+    call_options.headers.push_back(bearer);
+    ASSERT_OK(client_->ListFlights(call_options, {}, &listing));
+    ASSERT_TRUE(bearer_middleware_->GetIsValid());
+  }
+
+  // Invalid credentials are rejected by the server middleware.
+  void RunInvalidClientAuth() {
+    arrow::Result<std::pair<std::string, std::string>> bearer_result =
+        client_->AuthenticateBasicToken({}, kInvalidUsername, kInvalidPassword);
+    ASSERT_RAISES(IOError, bearer_result.status());
+    ASSERT_THAT(bearer_result.status().message(),
+                ::testing::HasSubstr("Invalid credentials"));
+  }
+
+  void TearDown() override { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+  std::unique_ptr<FlightClient> client_;
+  std::unique_ptr<FlightServerBase> server_;
+  std::shared_ptr<HeaderAuthServerMiddlewareFactory> header_middleware_;
+  std::shared_ptr<BearerAuthServerMiddlewareFactory> bearer_middleware_;
+};
+
TEST_F(TestErrorMiddleware, TestMetadata) {
Action action;
std::unique_ptr<ResultStream> stream;
@@ -2193,5 +2369,9 @@ TEST_F(TestPropagatingMiddleware, DoPut) {
ValidateStatus(status, FlightMethod::DoPut);
}
+// Valid basic credentials are exchanged for a bearer token the server accepts.
+TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) {
+  RunValidClientAuth();
+}
+
+// Invalid basic credentials are rejected with an "Invalid credentials" error.
+TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) {
+  RunInvalidClientAuth();
+}
+
} // namespace flight
} // namespace arrow