You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2020/11/25 21:02:19 UTC

[arrow] branch master updated: ARROW-10487 [FlightRPC][C++] Header-based auth in clients

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 65aa527  ARROW-10487 [FlightRPC][C++] Header-based auth in clients
65aa527 is described below

commit 65aa527707d7067bff833bd744ebfd97f1c2c909
Author: Lyndon Bauto <ly...@bitquilltech.com>
AuthorDate: Wed Nov 25 16:01:03 2020 -0500

    ARROW-10487 [FlightRPC][C++] Header-based auth in clients
    
    Added support for header based authentication in clients.
    - Added support for base 64 encoded username / password auth to match Java implementation
    - Added bearer token receiving and populating of call options to send back
    - Added unit tests that connects C++ client and a mock C++ server to test authentication
    
    Closes #8724 from lyndonb-bq/jduo/lyndon/flight-auth-cpp-redesign-client
    
    Authored-by: Lyndon Bauto <ly...@bitquilltech.com>
    Signed-off-by: David Li <li...@gmail.com>
---
 cpp/src/arrow/flight/CMakeLists.txt            |   1 +
 cpp/src/arrow/flight/client.cc                 |  34 +++++
 cpp/src/arrow/flight/client.h                  |  14 ++
 cpp/src/arrow/flight/client_header_internal.cc |  92 +++++++++++++
 cpp/src/arrow/flight/client_header_internal.h  |  57 ++++++++
 cpp/src/arrow/flight/flight_test.cc            | 180 +++++++++++++++++++++++++
 6 files changed, 378 insertions(+)

diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index f0c23ef..86e3c51 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -118,6 +118,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}")
 # protobuf-internal.cc
 set(ARROW_FLIGHT_SRCS
     client.cc
+    client_header_internal.cc
     internal.cc
     protocol_internal.cc
     serialization_internal.cc
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index cdffa7f..5c56e64 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -50,6 +50,7 @@
 #include "arrow/util/uri.h"
 
 #include "arrow/flight/client_auth.h"
+#include "arrow/flight/client_header_internal.h"
 #include "arrow/flight/client_middleware.h"
 #include "arrow/flight/internal.h"
 #include "arrow/flight/middleware.h"
@@ -104,6 +105,9 @@ struct ClientRpc {
               std::chrono::system_clock::now() + options.timeout);
       context.set_deadline(deadline);
     }
+    for (auto header : options.headers) {
+      context.AddMetadata(header.first, header.second);
+    }
   }
 
   /// \brief Add an auth token via an auth handler
@@ -994,6 +998,30 @@ class FlightClient::FlightClientImpl {
     return Status::OK();
   }
 
+  /// \brief Run a handshake that carries HTTP-style basic credentials and
+  /// return whatever bearer token header the server answered with.
+  ///
+  /// \param[in] options Per-RPC options applied to the handshake call
+  /// \param[in] username Username placed in the basic auth header
+  /// \param[in] password Password placed in the basic auth header
+  /// \return The value produced by internal::GetBearerTokenHeader for this
+  /// call, or an error if the handshake did not complete
+  arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
+      const FlightCallOptions& options, const std::string& username,
+      const std::string& password) {
+    // The credentials travel as outgoing call metadata.
+    ClientRpc rpc(options);
+    internal::AddBasicAuthHeaders(&rpc.context, username, password);
+
+    using HandshakeStream =
+        grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>;
+    std::shared_ptr<HandshakeStream> handshake = stub_->Handshake(&rpc.context);
+    GrpcClientAuthSender outgoing{handshake};
+    GrpcClientAuthReader incoming{handshake};
+
+    // Half-close our side, then wait for the final status. A failed status is
+    // reported in preference to a failed half-close.
+    const bool writes_closed = handshake->WritesDone();
+    RETURN_NOT_OK(internal::FromGrpcStatus(handshake->Finish(), &rpc.context));
+    if (writes_closed) {
+      // The token, if any, is looked up in the metadata the server returned.
+      return internal::GetBearerTokenHeader(rpc.context);
+    }
+    return MakeFlightError(FlightStatusCode::Internal,
+                           "Could not finish writing before closing");
+  }
+
   Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
                      std::unique_ptr<FlightListing>* listing) {
     pb::Criteria pb_criteria;
@@ -1198,6 +1226,12 @@ Status FlightClient::Authenticate(const FlightCallOptions& options,
   return impl_->Authenticate(options, std::move(auth_handler));
 }
 
+// Public entry point for basic-auth handshakes; the call is forwarded
+// unchanged to the implementation object.
+arrow::Result<std::pair<std::string, std::string>> FlightClient::AuthenticateBasicToken(
+    const FlightCallOptions& options, const std::string& username,
+    const std::string& password) {
+  auto& impl = *impl_;
+  return impl.AuthenticateBasicToken(options, username, password);
+}
+
 Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action,
                               std::unique_ptr<ResultStream>* results) {
   return impl_->DoAction(options, action, results);
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 935e8fb..441f114 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -29,6 +29,7 @@
 #include "arrow/ipc/options.h"
 #include "arrow/ipc/reader.h"
 #include "arrow/ipc/writer.h"
+#include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/util/variant.h"
 
@@ -65,6 +66,9 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions {
 
   /// \brief IPC writer options, if applicable for the call.
   ipc::IpcWriteOptions write_options;
+
+  /// \brief Headers for client to add to context.
+  std::vector<std::pair<std::string, std::string>> headers;
 };
 
 /// \brief Indicate that the client attempted to write a message
@@ -191,6 +195,16 @@ class ARROW_FLIGHT_EXPORT FlightClient {
   Status Authenticate(const FlightCallOptions& options,
                       std::unique_ptr<ClientAuthHandler> auth_handler);
 
+  /// \brief Authenticate to the server using basic HTTP style authentication.
+  /// \param[in] options Per-RPC options
+  /// \param[in] username Username to use
+  /// \param[in] password Password to use
+  /// \return Arrow result with bearer token and status OK if client authenticated
+  /// successfully
+  arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
+      const FlightCallOptions& options, const std::string& username,
+      const std::string& password);
+
   /// \brief Perform the indicated action, returning an iterator to the stream
   /// of results, if any
   /// \param[in] options Per-RPC options
diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc
new file mode 100644
index 0000000..2112b41
--- /dev/null
+++ b/cpp/src/arrow/flight/client_header_internal.cc
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Helpers for header-based authentication in Flight clients. Currently
+// experimental.
+
+#include "arrow/flight/client_header_internal.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/client_auth.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/make_unique.h"
+
+#include <algorithm>
+#include <cctype>
+#include <memory>
+#include <string>
+
+// Name of the metadata entry that carries credentials.
+constexpr char kAuthHeader[] = "authorization";
+// Scheme prefix of a bearer token value (note the trailing space).
+constexpr char kBearerPrefix[] = "Bearer ";
+// Scheme prefix of base64 username:password credentials (note the trailing space).
+constexpr char kBasicPrefix[] = "Basic ";
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+/// \brief Attach an "authorization: Basic <base64>" entry to outgoing metadata.
+///
+/// The encoded payload is the text "username:password".
+///
+/// \param context Context object that receives the metadata entry.
+/// \param username Username part of the credentials.
+/// \param password Password part of the credentials.
+void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username,
+                         const std::string& password) {
+  const std::string user_pass = username + ":" + password;
+  // base64_encode works on raw unsigned bytes plus an explicit length.
+  const auto* raw_bytes = reinterpret_cast<const unsigned char*>(user_pass.c_str());
+  const auto raw_length = static_cast<unsigned int>(user_pass.size());
+  const std::string header_value =
+      kBasicPrefix + arrow::util::base64_encode(raw_bytes, raw_length);
+  context->AddMetadata(kAuthHeader, header_value);
+}
+
+/// \brief Get bearer token from inbound headers.
+///
+/// The "authorization" entry is looked up in the server's trailing metadata
+/// first and, if it is not there, in the initial metadata.
+///
+/// \param context ClientContext of the call whose server metadata is inspected.
+/// \return Pair of (header name, complete header value including the bearer
+/// prefix); a pair of empty strings if no bearer token was found.
+arrow::Result<std::pair<std::string, std::string>> GetBearerTokenHeader(
+    grpc::ClientContext& context) {
+  // Case-insensitive character comparison. Characters are converted to
+  // unsigned char first: passing a negative value to toupper is undefined.
+  auto char_compare = [](char char1, char char2) {
+    return std::toupper(static_cast<unsigned char>(char1)) ==
+           std::toupper(static_cast<unsigned char>(char2));
+  };
+
+  // Get the auth token if it exists, this can be in the initial or the trailing
+  // metadata. Binding by const reference avoids copying the maps when the
+  // accessors return references.
+  const auto& trailing_headers = context.GetServerTrailingMetadata();
+  const auto& initial_headers = context.GetServerInitialMetadata();
+  auto bearer_iter = trailing_headers.find(kAuthHeader);
+  if (bearer_iter == trailing_headers.end()) {
+    bearer_iter = initial_headers.find(kAuthHeader);
+    if (bearer_iter == initial_headers.end()) {
+      return std::make_pair("", "");
+    }
+  }
+
+  // Prefix length without the terminating NUL, known at compile time.
+  constexpr auto kBearerPrefixLength = sizeof(kBearerPrefix) - 1;
+
+  // Accept the value only if it starts with the bearer prefix (any letter case)
+  // and has at least one character after it.
+  std::string bearer_val(bearer_iter->second.data(), bearer_iter->second.size());
+  if (bearer_val.size() > kBearerPrefixLength &&
+      std::equal(bearer_val.begin(), bearer_val.begin() + kBearerPrefixLength,
+                 kBearerPrefix, char_compare)) {
+    return std::make_pair(kAuthHeader, bearer_val);
+  }
+
+  // The server is not required to provide a bearer token.
+  return std::make_pair("", "");
+}
+
+}  // namespace internal
+}  // namespace flight
+}  // namespace arrow
diff --git a/cpp/src/arrow/flight/client_header_internal.h b/cpp/src/arrow/flight/client_header_internal.h
new file mode 100644
index 0000000..718848a
--- /dev/null
+++ b/cpp/src/arrow/flight/client_header_internal.h
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Internal helpers for header-based authentication in Flight clients.
+// Currently experimental.
+
+#pragma once
+
+#include "arrow/flight/client_middleware.h"
+#include "arrow/result.h"
+
+#ifdef GRPCPP_PP_INCLUDE
+#include <grpcpp/grpcpp.h>
+#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
+#include <grpcpp/security/tls_credentials_options.h>
+#endif
+#else
+#include <grpc++/grpc++.h>
+#endif
+
+namespace arrow {
+namespace flight {
+namespace internal {
+
+/// \brief Add basic authentication header key value pair to context.
+///
+/// \param context grpc context variable to add header to.
+/// \param username username to encode into header.
+/// \param password password to encode into header.
+void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context,
+                                             const std::string& username,
+                                             const std::string& password);
+
+/// \brief Get bearer token from incoming headers.
+///
+/// \param context context that contains headers which hold the bearer token.
+/// \return Bearer token, parsed from headers, empty if one is not present.
+arrow::Result<std::pair<std::string, std::string>> ARROW_FLIGHT_EXPORT
+GetBearerTokenHeader(grpc::ClientContext& context);
+
+}  // namespace internal
+}  // namespace flight
+}  // namespace arrow
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 95048f6..2868f84 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -36,6 +36,7 @@
 #include "arrow/testing/generator.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/testing/util.h"
+#include "arrow/util/base64.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/make_unique.h"
 
@@ -43,6 +44,7 @@
 #error "gRPC headers should not be in public API"
 #endif
 
+#include "arrow/flight/client_header_internal.h"
 #include "arrow/flight/internal.h"
 #include "arrow/flight/middleware_internal.h"
 #include "arrow/flight/test_util.h"
@@ -52,6 +54,15 @@ namespace pb = arrow::flight::protocol;
 namespace arrow {
 namespace flight {
 
+// Credentials and header fragments shared by the header authentication tests.
+constexpr char kValidUsername[] = "flight_username";
+constexpr char kValidPassword[] = "flight_password";
+constexpr char kInvalidUsername[] = "invalid_flight_username";
+constexpr char kInvalidPassword[] = "invalid_flight_password";
+constexpr char kBearerToken[] = "bearertoken";
+// Scheme prefixes keep their trailing space.
+constexpr char kBasicPrefix[] = "Basic ";
+constexpr char kBearerPrefix[] = "Bearer ";
+constexpr char kAuthHeader[] = "authorization";
+
 void AssertEqual(const ActionType& expected, const ActionType& actual) {
   ASSERT_EQ(expected.type, actual.type);
   ASSERT_EQ(expected.description, actual.description);
@@ -559,6 +570,14 @@ class OptionsTestServer : public FlightServerBase {
   }
 };
 
+// Minimal Flight server for the header authentication tests: ListFlights
+// reports success and leaves the listing output untouched.
+class HeaderAuthTestServer : public FlightServerBase {
+ public:
+  Status ListFlights(const ServerCallContext& /*context*/, const Criteria* /*criteria*/,
+                     std::unique_ptr<FlightListing>* /*listings*/) override {
+    return Status::OK();
+  }
+};
+
 class TestMetadata : public ::testing::Test {
  public:
   void SetUp() {
@@ -791,6 +810,113 @@ class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
   }
 };
 
+// Look in CallHeaders for `key` and, if its value starts with `prefix`
+// (compared without case sensitivity) and has at least one character after it,
+// return the remainder of the value after the prefix. Returns an empty string
+// when the key is absent or the value does not match.
+std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers,
+                                          const std::string& key,
+                                          const std::string& prefix) {
+  // Case-insensitive character comparison. Characters are converted to
+  // unsigned char first: passing a negative value to toupper is undefined.
+  auto char_compare = [](char char1, char char2) {
+    return ::toupper(static_cast<unsigned char>(char1)) ==
+           ::toupper(static_cast<unsigned char>(char2));
+  };
+
+  auto iter = incoming_headers.find(key);
+  if (iter == incoming_headers.end()) {
+    return "";
+  }
+  const std::string val = iter->second.to_string();
+  if (val.size() > prefix.length() &&
+      std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(),
+                 char_compare)) {
+    return val.substr(prefix.length());
+  }
+  return "";
+}
+
+// Server middleware that adds a bearer token to the headers it sends.
+class HeaderAuthServerMiddleware : public ServerMiddleware {
+ public:
+  std::string name() const override { return "HeaderAuthServerMiddleware"; }
+
+  // Adds the entry "authorization: Bearer bearertoken" to the outgoing headers.
+  void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+    const std::string bearer_value = std::string(kBearerPrefix) + kBearerToken;
+    outgoing_headers->AddHeader(kAuthHeader, bearer_value);
+  }
+
+  // Nothing to clean up when the call ends.
+  void CallCompleted(const Status& /*status*/) override {}
+};
+
+// Extract the username and password from a "Basic" authorization header.
+//
+// The decoded credentials are split at the FIRST ':' only, so the password may
+// itself contain ':' characters (RFC 7617 forbids them in the user-id, not in
+// the password). When there is no separator the whole decoded text becomes the
+// username and the password is empty; when the header is missing both are empty.
+void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username,
+                      std::string& password) {
+  const std::string encoded_credentials =
+      FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
+  const std::string credentials = arrow::util::base64_decode(encoded_credentials);
+  const auto separator = credentials.find(':');
+  // substr(0, npos) yields the whole string when no ':' is present.
+  username = credentials.substr(0, separator);
+  password = (separator == std::string::npos) ? std::string()
+                                              : credentials.substr(separator + 1);
+}
+
+// Server middleware factory for testing base64 ("Basic") header authentication.
+class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+  HeaderAuthServerMiddlewareFactory() = default;
+
+  // The valid credentials install a HeaderAuthServerMiddleware; the known
+  // invalid credentials fail the call as Unauthenticated; any other input
+  // (including no credentials at all) succeeds without installing middleware.
+  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+                   std::shared_ptr<ServerMiddleware>* middleware) override {
+    std::string user;
+    std::string pass;
+    ParseBasicHeader(incoming_headers, user, pass);
+
+    if (user == kValidUsername && pass == kValidPassword) {
+      *middleware = std::make_shared<HeaderAuthServerMiddleware>();
+      return Status::OK();
+    }
+    if (user == kInvalidUsername && pass == kInvalidPassword) {
+      return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials");
+    }
+    return Status::OK();
+  }
+};
+
+// Server middleware that checks the bearer token of an incoming call and
+// records the verdict through a caller-supplied flag.
+class BearerAuthServerMiddleware : public ServerMiddleware {
+ public:
+  // Keeps a copy of the incoming headers; `isValid` receives the check result.
+  explicit BearerAuthServerMiddleware(const CallHeaders& incoming_headers, bool* isValid)
+      : incoming_headers_(incoming_headers), isValid_(isValid) {}
+
+  std::string name() const override { return "BearerAuthServerMiddleware"; }
+
+  // The comparison against the expected token happens here, when headers are sent.
+  void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+    const std::string received_token =
+        FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBearerPrefix);
+    *isValid_ = (received_token == kBearerToken);
+  }
+
+  void CallCompleted(const Status& /*status*/) override {}
+
+ private:
+  CallHeaders incoming_headers_;
+  bool* isValid_;
+};
+
+// Server middleware factory for testing bearer token authentication. It
+// remembers whether the most recently checked bearer token was the expected one.
+class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+  BearerAuthServerMiddlewareFactory() : isValid_(false) {}
+
+  // Installs a BearerAuthServerMiddleware only for calls that carry at least
+  // one "authorization" header; other calls proceed without middleware.
+  Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+                   std::shared_ptr<ServerMiddleware>* middleware) override {
+    const auto auth_range = incoming_headers.equal_range(kAuthHeader);
+    if (auth_range.first == auth_range.second) {
+      return Status::OK();
+    }
+    *middleware =
+        std::make_shared<BearerAuthServerMiddleware>(incoming_headers, &isValid_);
+    return Status::OK();
+  }
+
+  // Result of the last bearer token check (false until a check has run).
+  bool GetIsValid() { return isValid_; }
+
+ private:
+  bool isValid_;
+};
+
 // A client middleware that adds a thread-local "request ID" to
 // outgoing calls as a header, and keeps track of the status of
 // completed calls. NOT thread-safe.
@@ -1010,6 +1136,56 @@ class TestErrorMiddleware : public ::testing::Test {
   std::unique_ptr<FlightServerBase> server_;
 };
 
+// Fixture that starts a HeaderAuthTestServer with both the basic-auth and the
+// bearer-token middleware factories installed, and a client connected to it.
+class TestBasicHeaderAuthMiddleware : public ::testing::Test {
+ public:
+  void SetUp() {
+    header_middleware_ = std::make_shared<HeaderAuthServerMiddlewareFactory>();
+    bearer_middleware_ = std::make_shared<BearerAuthServerMiddlewareFactory>();
+    ASSERT_OK(MakeServer<HeaderAuthTestServer>(
+        &server_, &client_,
+        [&](FlightServerOptions* options) {
+          options->auth_handler =
+              std::unique_ptr<ServerAuthHandler>(new NoOpAuthHandler());
+          options->middleware.push_back({"header-auth-server", header_middleware_});
+          options->middleware.push_back({"bearer-auth-server", bearer_middleware_});
+          return Status::OK();
+        },
+        [&](FlightClientOptions* options) { return Status::OK(); }));
+  }
+
+  // Authenticate with the valid credentials, check the returned bearer header,
+  // then replay that header on a second call and check the server accepted it.
+  void RunValidClientAuth() {
+    arrow::Result<std::pair<std::string, std::string>> bearer_result =
+        client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword);
+    ASSERT_OK(bearer_result.status());
+    ASSERT_EQ(bearer_result.ValueOrDie().first, kAuthHeader);
+    ASSERT_EQ(bearer_result.ValueOrDie().second,
+              (std::string(kBearerPrefix) + kBearerToken));
+    std::unique_ptr<FlightListing> listing;
+    FlightCallOptions call_options;
+    call_options.headers.push_back(bearer_result.ValueOrDie());
+    ASSERT_OK(client_->ListFlights(call_options, {}, &listing));
+    ASSERT_TRUE(bearer_middleware_->GetIsValid());
+  }
+
+  // Authenticate with the known-invalid credentials and expect the server's
+  // "Invalid credentials" error to reach the client.
+  void RunInvalidClientAuth() {
+    arrow::Result<std::pair<std::string, std::string>> bearer_result =
+        client_->AuthenticateBasicToken({}, kInvalidUsername, kInvalidPassword);
+    ASSERT_RAISES(IOError, bearer_result.status());
+    ASSERT_THAT(bearer_result.status().message(),
+                ::testing::HasSubstr("Invalid credentials"));
+  }
+
+  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+  std::unique_ptr<FlightClient> client_;
+  std::unique_ptr<FlightServerBase> server_;
+  std::shared_ptr<HeaderAuthServerMiddlewareFactory> header_middleware_;
+  std::shared_ptr<BearerAuthServerMiddlewareFactory> bearer_middleware_;
+};
+
 TEST_F(TestErrorMiddleware, TestMetadata) {
   Action action;
   std::unique_ptr<ResultStream> stream;
@@ -2193,5 +2369,9 @@ TEST_F(TestPropagatingMiddleware, DoPut) {
   ValidateStatus(status, FlightMethod::DoPut);
 }
 
+// Basic credentials accepted: a bearer token comes back and is usable.
+TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) {
+  RunValidClientAuth();
+}
+
+// Basic credentials rejected: the handshake fails with the server's error.
+TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) {
+  RunInvalidClientAuth();
+}
+
 }  // namespace flight
 }  // namespace arrow