You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2023/05/01 09:41:08 UTC

[arrow] branch main updated: GH-35375: [C++][FlightRPC] Add `arrow::flight::ServerCallContext::incoming_headers()` (#35376)

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 75439f09a5 GH-35375: [C++][FlightRPC] Add `arrow::flight::ServerCallContext::incoming_headers()` (#35376)
75439f09a5 is described below

commit 75439f09a5eb6fbe21622dfb31c2be37a8cc0afd
Author: Sutou Kouhei <ko...@clear-code.com>
AuthorDate: Mon May 1 18:40:56 2023 +0900

    GH-35375: [C++][FlightRPC] Add `arrow::flight::ServerCallContext::incoming_headers()` (#35376)
    
    ### Rationale for this change
    
    The new method returns the headers sent by the client for the current call.
    
    Currently, they are only available in `arrow::flight::ServerMiddlewareFactory::StartCall()`. But they're also useful within each RPC handler.
    
    ### What changes are included in this PR?
    
    Add the `incoming_headers()` method to `arrow::flight::ServerCallContext`.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    Yes.
    * Closes: #35375
    
    Authored-by: Sutou Kouhei <ko...@clear-code.com>
    Signed-off-by: Sutou Kouhei <ko...@clear-code.com>
---
 cpp/src/arrow/flight/client.h                      |  2 +-
 cpp/src/arrow/flight/flight_test.cc                | 24 ++++++++++++++++++++++
 cpp/src/arrow/flight/middleware.h                  |  8 +-------
 cpp/src/arrow/flight/server.h                      |  2 ++
 cpp/src/arrow/flight/test_util.cc                  | 19 +++++++++++++++++
 cpp/src/arrow/flight/transport/grpc/grpc_server.cc | 19 +++++++++--------
 cpp/src/arrow/flight/transport/ucx/ucx_server.cc   |  2 ++
 cpp/src/arrow/flight/types.h                       |  6 ++++++
 8 files changed, 66 insertions(+), 16 deletions(-)

diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 61fa6e9d0c..1085855250 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -223,7 +223,7 @@ class ARROW_FLIGHT_EXPORT FlightClient {
   /// \param[in] username Username to use
   /// \param[in] password Password to use
   /// \return Arrow result with bearer token and status OK if client authenticated
-  /// sucessfully
+  /// successfully
   arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
       const FlightCallOptions& options, const std::string& username,
       const std::string& password);
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 5520dfc48f..a2b69494b8 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -214,6 +214,30 @@ TEST(TestFlight, DISABLED_IpV6Port) {
   ASSERT_OK(client->ListFlights());
 }
 
+TEST(TestFlight, ServerCallContextIncomingHeaders) {
+  auto server = ExampleTestServer();
+  ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
+  FlightServerOptions options(location);
+  ASSERT_OK(server->Init(options));
+
+  ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(server->location()));
+  Action action;
+  action.type = "list-incoming-headers";
+  action.body = Buffer::FromString("test-header");
+  FlightCallOptions call_options;
+  call_options.headers.emplace_back("test-header1", "value1");
+  call_options.headers.emplace_back("test-header2", "value2");
+  ASSERT_OK_AND_ASSIGN(auto stream, client->DoAction(call_options, action));
+  ASSERT_OK_AND_ASSIGN(auto result, stream->Next());
+  ASSERT_NE(result.get(), nullptr);
+  ASSERT_EQ(result->body->ToString(), "test-header1: value1");
+  ASSERT_OK_AND_ASSIGN(result, stream->Next());
+  ASSERT_NE(result.get(), nullptr);
+  ASSERT_EQ(result->body->ToString(), "test-header2: value2");
+  ASSERT_OK_AND_ASSIGN(result, stream->Next());
+  ASSERT_EQ(result.get(), nullptr);
+}
+
 // ----------------------------------------------------------------------
 // Client tests
 
diff --git a/cpp/src/arrow/flight/middleware.h b/cpp/src/arrow/flight/middleware.h
index dc1ad24bc5..e936b9f020 100644
--- a/cpp/src/arrow/flight/middleware.h
+++ b/cpp/src/arrow/flight/middleware.h
@@ -20,23 +20,17 @@
 
 #pragma once
 
-#include <map>
 #include <memory>
 #include <string>
 #include <string_view>
 #include <utility>
 
-#include "arrow/flight/visibility.h"  // IWYU pragma: keep
+#include "arrow/flight/types.h"
 #include "arrow/status.h"
 
 namespace arrow {
 namespace flight {
 
-/// \brief Headers sent from the client or server.
-///
-/// Header values are ordered.
-using CallHeaders = std::multimap<std::string_view, std::string_view>;
-
 /// \brief A write-only wrapper around headers for an RPC call.
 class ARROW_FLIGHT_EXPORT AddCallHeaders {
  public:
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index 1d1b1a50f3..6fb8ab1213 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -137,6 +137,8 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
   /// \brief Check if the current RPC has been cancelled (by the client, by
   /// a network error, etc.).
   virtual bool is_cancelled() const = 0;
+  /// \brief The headers sent by the client for this call.
+  virtual const CallHeaders& incoming_headers() const = 0;
 };
 
 class ARROW_FLIGHT_EXPORT FlightServerOptions {
diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc
index 0d6c28b296..7430a9b7de 100644
--- a/cpp/src/arrow/flight/test_util.cc
+++ b/cpp/src/arrow/flight/test_util.cc
@@ -474,12 +474,31 @@ class FlightTestServer : public FlightServerBase {
     return Status::OK();
   }
 
+  Status ListIncomingHeaders(const ServerCallContext& context, const Action& action,
+                             std::unique_ptr<ResultStream>* out) {
+    std::vector<Result> results;
+    std::string_view prefix(*action.body);
+    for (const auto& header : context.incoming_headers()) {
+      if (header.first.substr(0, prefix.size()) != prefix) {
+        continue;
+      }
+      Result result;
+      result.body = Buffer::FromString(std::string(header.first) + ": " +
+                                       std::string(header.second));
+      results.push_back(result);
+    }
+    *out = std::make_unique<SimpleResultStream>(std::move(results));
+    return Status::OK();
+  }
+
   Status DoAction(const ServerCallContext& context, const Action& action,
                   std::unique_ptr<ResultStream>* out) override {
     if (action.type == "action1") {
       return RunAction1(action, out);
     } else if (action.type == "action2") {
       return RunAction2(out);
+    } else if (action.type == "list-incoming-headers") {
+      return ListIncomingHeaders(context, action, out);
     } else {
       return Status::NotImplemented(action.type);
     }
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
index a643111e3b..acf80462f1 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
@@ -117,11 +117,18 @@ class GrpcServerAuthSender : public ServerAuthSender {
 
 class GrpcServerCallContext : public ServerCallContext {
   explicit GrpcServerCallContext(::grpc::ServerContext* context)
-      : context_(context), peer_(context_->peer()) {}
+      : context_(context), peer_(context_->peer()) {
+    for (const auto& entry : context->client_metadata()) {
+      incoming_headers_.insert(
+          {std::string_view(entry.first.data(), entry.first.length()),
+           std::string_view(entry.second.data(), entry.second.length())});
+    }
+  }
 
   const std::string& peer_identity() const override { return peer_identity_; }
   const std::string& peer() const override { return peer_; }
   bool is_cancelled() const override { return context_->IsCancelled(); }
+  const CallHeaders& incoming_headers() const override { return incoming_headers_; }
 
   // Helper method that runs interceptors given the result of an RPC,
   // then returns the final gRPC status to send to the client
@@ -156,6 +163,7 @@ class GrpcServerCallContext : public ServerCallContext {
   std::string peer_identity_;
   std::vector<std::shared_ptr<ServerMiddleware>> middleware_;
   std::unordered_map<std::string, std::shared_ptr<ServerMiddleware>> middleware_map_;
+  CallHeaders incoming_headers_;
 };
 
 class GrpcAddServerHeaders : public AddCallHeaders {
@@ -310,17 +318,12 @@ class GrpcServiceHandler final : public FlightService::Service {
                                  GrpcServerCallContext& flight_context) {
     // Run server middleware
     const CallInfo info{method};
-    CallHeaders incoming_headers;
-    for (const auto& entry : context->client_metadata()) {
-      incoming_headers.insert(
-          {std::string_view(entry.first.data(), entry.first.length()),
-           std::string_view(entry.second.data(), entry.second.length())});
-    }
 
     GrpcAddServerHeaders outgoing_headers(context);
     for (const auto& factory : middleware_) {
       std::shared_ptr<ServerMiddleware> instance;
-      Status result = factory.second->StartCall(info, incoming_headers, &instance);
+      Status result =
+          factory.second->StartCall(info, flight_context.incoming_headers(), &instance);
       if (!result.ok()) {
         // Interceptor rejected call, end the request on all existing
         // interceptors
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
index 946b29383b..4a573d7429 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
@@ -76,9 +76,11 @@ class UcxServerCallContext : public flight::ServerCallContext {
     return nullptr;
   }
   bool is_cancelled() const override { return false; }
+  const CallHeaders& incoming_headers() const override { return incoming_headers_; }
 
  private:
   std::string peer_;
+  CallHeaders incoming_headers_;
 };
 
 class UcxServerStream : public internal::ServerDataStream {
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 39353bcb99..9d92f0be95 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -21,6 +21,7 @@
 
 #include <cstddef>
 #include <cstdint>
+#include <map>
 #include <memory>
 #include <string>
 #include <string_view>
@@ -123,6 +124,11 @@ ARROW_FLIGHT_EXPORT
 Status MakeFlightError(FlightStatusCode code, std::string message,
                        std::string extra_info = {});
 
+/// \brief Headers sent from the client or server.
+///
+/// Header values are ordered.
+using CallHeaders = std::multimap<std::string_view, std::string_view>;
+
 /// \brief A TLS certificate plus key.
 struct ARROW_FLIGHT_EXPORT CertKeyPair {
   /// \brief The certificate in PEM format.