You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2023/05/01 09:41:08 UTC
[arrow] branch main updated: GH-35375: [C++][FlightRPC] Add `arrow::flight::ServerCallContext::incoming_headers()` (#35376)
This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 75439f09a5 GH-35375: [C++][FlightRPC] Add `arrow::flight::ServerCallContext::incoming_headers()` (#35376)
75439f09a5 is described below
commit 75439f09a5eb6fbe21622dfb31c2be37a8cc0afd
Author: Sutou Kouhei <ko...@clear-code.com>
AuthorDate: Mon May 1 18:40:56 2023 +0900
GH-35375: [C++][FlightRPC] Add `arrow::flight::ServerCallContext::incoming_headers()` (#35376)
### Rationale for this change
It returns the headers sent by a client.
For now, we can get them only in `arrow::flight::ServerMiddlewareFactory::StartCall()`. But they're also useful in each RPC call.
### What changes are included in this PR?
Add the method.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
Yes.
* Closes: #35375
Authored-by: Sutou Kouhei <ko...@clear-code.com>
Signed-off-by: Sutou Kouhei <ko...@clear-code.com>
---
cpp/src/arrow/flight/client.h | 2 +-
cpp/src/arrow/flight/flight_test.cc | 24 ++++++++++++++++++++++
cpp/src/arrow/flight/middleware.h | 8 +-------
cpp/src/arrow/flight/server.h | 2 ++
cpp/src/arrow/flight/test_util.cc | 19 +++++++++++++++++
cpp/src/arrow/flight/transport/grpc/grpc_server.cc | 19 +++++++++--------
cpp/src/arrow/flight/transport/ucx/ucx_server.cc | 2 ++
cpp/src/arrow/flight/types.h | 6 ++++++
8 files changed, 66 insertions(+), 16 deletions(-)
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 61fa6e9d0c..1085855250 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -223,7 +223,7 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[in] username Username to use
/// \param[in] password Password to use
/// \return Arrow result with bearer token and status OK if client authenticated
- /// sucessfully
+ /// successfully
arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
const FlightCallOptions& options, const std::string& username,
const std::string& password);
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 5520dfc48f..a2b69494b8 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -214,6 +214,30 @@ TEST(TestFlight, DISABLED_IpV6Port) {
ASSERT_OK(client->ListFlights());
}
+TEST(TestFlight, ServerCallContextIncomingHeaders) {
+ auto server = ExampleTestServer();
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
+
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(server->location()));
+ Action action;
+ action.type = "list-incoming-headers";
+ action.body = Buffer::FromString("test-header");
+ FlightCallOptions call_options;
+ call_options.headers.emplace_back("test-header1", "value1");
+ call_options.headers.emplace_back("test-header2", "value2");
+ ASSERT_OK_AND_ASSIGN(auto stream, client->DoAction(call_options, action));
+ ASSERT_OK_AND_ASSIGN(auto result, stream->Next());
+ ASSERT_NE(result.get(), nullptr);
+ ASSERT_EQ(result->body->ToString(), "test-header1: value1");
+ ASSERT_OK_AND_ASSIGN(result, stream->Next());
+ ASSERT_NE(result.get(), nullptr);
+ ASSERT_EQ(result->body->ToString(), "test-header2: value2");
+ ASSERT_OK_AND_ASSIGN(result, stream->Next());
+ ASSERT_EQ(result.get(), nullptr);
+}
+
// ----------------------------------------------------------------------
// Client tests
diff --git a/cpp/src/arrow/flight/middleware.h b/cpp/src/arrow/flight/middleware.h
index dc1ad24bc5..e936b9f020 100644
--- a/cpp/src/arrow/flight/middleware.h
+++ b/cpp/src/arrow/flight/middleware.h
@@ -20,23 +20,17 @@
#pragma once
-#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
-#include "arrow/flight/visibility.h" // IWYU pragma: keep
+#include "arrow/flight/types.h"
#include "arrow/status.h"
namespace arrow {
namespace flight {
-/// \brief Headers sent from the client or server.
-///
-/// Header values are ordered.
-using CallHeaders = std::multimap<std::string_view, std::string_view>;
-
/// \brief A write-only wrapper around headers for an RPC call.
class ARROW_FLIGHT_EXPORT AddCallHeaders {
public:
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index 1d1b1a50f3..6fb8ab1213 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -137,6 +137,8 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
/// \brief Check if the current RPC has been cancelled (by the client, by
/// a network error, etc.).
virtual bool is_cancelled() const = 0;
+ /// \brief The headers sent by the client for this call.
+ virtual const CallHeaders& incoming_headers() const = 0;
};
class ARROW_FLIGHT_EXPORT FlightServerOptions {
diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc
index 0d6c28b296..7430a9b7de 100644
--- a/cpp/src/arrow/flight/test_util.cc
+++ b/cpp/src/arrow/flight/test_util.cc
@@ -474,12 +474,31 @@ class FlightTestServer : public FlightServerBase {
return Status::OK();
}
+ Status ListIncomingHeaders(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* out) {
+ std::vector<Result> results;
+ std::string_view prefix(*action.body);
+ for (const auto& header : context.incoming_headers()) {
+ if (header.first.substr(0, prefix.size()) != prefix) {
+ continue;
+ }
+ Result result;
+ result.body = Buffer::FromString(std::string(header.first) + ": " +
+ std::string(header.second));
+ results.push_back(result);
+ }
+ *out = std::make_unique<SimpleResultStream>(std::move(results));
+ return Status::OK();
+ }
+
Status DoAction(const ServerCallContext& context, const Action& action,
std::unique_ptr<ResultStream>* out) override {
if (action.type == "action1") {
return RunAction1(action, out);
} else if (action.type == "action2") {
return RunAction2(out);
+ } else if (action.type == "list-incoming-headers") {
+ return ListIncomingHeaders(context, action, out);
} else {
return Status::NotImplemented(action.type);
}
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
index a643111e3b..acf80462f1 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc
@@ -117,11 +117,18 @@ class GrpcServerAuthSender : public ServerAuthSender {
class GrpcServerCallContext : public ServerCallContext {
explicit GrpcServerCallContext(::grpc::ServerContext* context)
- : context_(context), peer_(context_->peer()) {}
+ : context_(context), peer_(context_->peer()) {
+ for (const auto& entry : context->client_metadata()) {
+ incoming_headers_.insert(
+ {std::string_view(entry.first.data(), entry.first.length()),
+ std::string_view(entry.second.data(), entry.second.length())});
+ }
+ }
const std::string& peer_identity() const override { return peer_identity_; }
const std::string& peer() const override { return peer_; }
bool is_cancelled() const override { return context_->IsCancelled(); }
+ const CallHeaders& incoming_headers() const override { return incoming_headers_; }
// Helper method that runs interceptors given the result of an RPC,
// then returns the final gRPC status to send to the client
@@ -156,6 +163,7 @@ class GrpcServerCallContext : public ServerCallContext {
std::string peer_identity_;
std::vector<std::shared_ptr<ServerMiddleware>> middleware_;
std::unordered_map<std::string, std::shared_ptr<ServerMiddleware>> middleware_map_;
+ CallHeaders incoming_headers_;
};
class GrpcAddServerHeaders : public AddCallHeaders {
@@ -310,17 +318,12 @@ class GrpcServiceHandler final : public FlightService::Service {
GrpcServerCallContext& flight_context) {
// Run server middleware
const CallInfo info{method};
- CallHeaders incoming_headers;
- for (const auto& entry : context->client_metadata()) {
- incoming_headers.insert(
- {std::string_view(entry.first.data(), entry.first.length()),
- std::string_view(entry.second.data(), entry.second.length())});
- }
GrpcAddServerHeaders outgoing_headers(context);
for (const auto& factory : middleware_) {
std::shared_ptr<ServerMiddleware> instance;
- Status result = factory.second->StartCall(info, incoming_headers, &instance);
+ Status result =
+ factory.second->StartCall(info, flight_context.incoming_headers(), &instance);
if (!result.ok()) {
// Interceptor rejected call, end the request on all existing
// interceptors
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
index 946b29383b..4a573d7429 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
@@ -76,9 +76,11 @@ class UcxServerCallContext : public flight::ServerCallContext {
return nullptr;
}
bool is_cancelled() const override { return false; }
+ const CallHeaders& incoming_headers() const override { return incoming_headers_; }
private:
std::string peer_;
+ CallHeaders incoming_headers_;
};
class UcxServerStream : public internal::ServerDataStream {
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 39353bcb99..9d92f0be95 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -21,6 +21,7 @@
#include <cstddef>
#include <cstdint>
+#include <map>
#include <memory>
#include <string>
#include <string_view>
@@ -123,6 +124,11 @@ ARROW_FLIGHT_EXPORT
Status MakeFlightError(FlightStatusCode code, std::string message,
std::string extra_info = {});
+/// \brief Headers sent from the client or server.
+///
+/// Header values are ordered.
+using CallHeaders = std::multimap<std::string_view, std::string_view>;
+
/// \brief A TLS certificate plus key.
struct ARROW_FLIGHT_EXPORT CertKeyPair {
/// \brief The certificate in PEM format.