You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2019/06/13 16:42:46 UTC
[arrow] branch master updated: ARROW-5397: [FlightRPC] Add TLS certificates for testing Flight
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new b7e8ed7 ARROW-5397: [FlightRPC] Add TLS certificates for testing Flight
b7e8ed7 is described below
commit b7e8ed7fe9613c899a2181bcf48996466b31d9f8
Author: David Li <li...@gmail.com>
AuthorDate: Thu Jun 13 18:42:38 2019 +0200
ARROW-5397: [FlightRPC] Add TLS certificates for testing Flight
This needs https://github.com/apache/arrow-testing/pull/2.
Author: David Li <li...@gmail.com>
Closes #4510 from lihalite/flight-tls and squashes the following commits:
5eff72470 <David Li> Don't set wait_for_ready in Flight
776b9d01e <David Li> Add tests for TLS in Flight (C++, Python)
9d2efa20a <David Li> Allow multiple TLS certificates in Flight
---
cpp/src/arrow/flight/client.cc | 3 -
cpp/src/arrow/flight/flight-test.cc | 89 ++++++++++++---
cpp/src/arrow/flight/server.cc | 7 +-
cpp/src/arrow/flight/server.h | 3 +-
cpp/src/arrow/flight/test-server.cc | 109 +------------------
cpp/src/arrow/flight/test-util.cc | 161 ++++++++++++++++++++++++++++
cpp/src/arrow/flight/test-util.h | 10 ++
cpp/src/arrow/flight/types.cc | 6 ++
cpp/src/arrow/flight/types.h | 16 +++
python/pyarrow/_flight.pyx | 32 ++++--
python/pyarrow/flight.py | 1 +
python/pyarrow/includes/libarrow_flight.pxd | 10 +-
python/pyarrow/tests/test_flight.py | 92 +++++++++++++++-
testing | 2 +-
14 files changed, 398 insertions(+), 143 deletions(-)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 1c927da..2b7c699 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -60,9 +60,6 @@ struct ClientRpc {
grpc::ClientContext context;
explicit ClientRpc(const FlightCallOptions& options) {
- /// XXX workaround until we have a handshake in Connect
- context.set_wait_for_ready(true);
-
if (options.timeout.count() >= 0) {
std::chrono::system_clock::time_point deadline =
std::chrono::time_point_cast<std::chrono::system_clock::time_point::duration>(
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index cb7e57c..b295878 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -176,29 +176,22 @@ TEST(TestFlight, ConnectUri) {
class TestFlightClient : public ::testing::Test {
public:
- // Uncomment these when you want to run the server separately for
- // debugging/valgrind/gdb
+ void SetUp() {
+ Location location;
+ std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- // void SetUp() {
- // port_ = 92358;
- // ASSERT_OK(ConnectClient());
- // }
- // void TearDown() {}
+ ASSERT_OK(Location::ForGrpcTcp("localhost", GetListenPort(), &location));
+ FlightServerOptions options(location);
+ ASSERT_OK(server->Init(options));
- void SetUp() {
- server_.reset(new TestServer("flight-test-server"));
- server_->Start();
- port_ = server_->port();
+ server_.reset(new InProcessTestServer(std::move(server), location));
+ ASSERT_OK(server_->Start());
ASSERT_OK(ConnectClient());
}
void TearDown() { server_->Stop(); }
- Status ConnectClient() {
- Location location;
- RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
- return FlightClient::Connect(location, &client_);
- }
+ Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }
template <typename EndpointCheckFunc>
void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
@@ -236,7 +229,7 @@ class TestFlightClient : public ::testing::Test {
protected:
int port_;
std::unique_ptr<FlightClient> client_;
- std::unique_ptr<TestServer> server_;
+ std::unique_ptr<InProcessTestServer> server_;
};
class AuthTestServer : public FlightServerBase {
@@ -249,6 +242,16 @@ class AuthTestServer : public FlightServerBase {
}
};
+class TlsTestServer : public FlightServerBase {
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) override {
+ std::shared_ptr<Buffer> buf;
+ RETURN_NOT_OK(Buffer::FromString("Hello, world!", &buf));
+ *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
+ return Status::OK();
+ }
+};
+
class DoPutTestServer : public FlightServerBase {
public:
Status DoPut(const ServerCallContext& context,
@@ -336,6 +339,42 @@ class TestDoPut : public ::testing::Test {
DoPutTestServer* do_put_server_;
};
+class TestTls : public ::testing::Test {
+ public:
+ void SetUp() {
+ Location location;
+ std::unique_ptr<FlightServerBase> server(new TlsTestServer);
+
+ ASSERT_OK(Location::ForGrpcTls("localhost", GetListenPort(), &location));
+ FlightServerOptions options(location);
+ ASSERT_RAISES(UnknownError, server->Init(options));
+ ASSERT_OK(ExampleTlsCertificates(&options.tls_certificates));
+ ASSERT_OK(server->Init(options));
+
+ server_.reset(new InProcessTestServer(std::move(server), location));
+ ASSERT_OK(server_->Start());
+ ASSERT_OK(ConnectClient());
+ }
+
+ void TearDown() {
+ if (server_) {
+ server_->Stop();
+ }
+ }
+
+ Status ConnectClient() {
+ auto options = FlightClientOptions();
+ CertKeyPair root_cert;
+ RETURN_NOT_OK(ExampleTlsCertificateRoot(&root_cert));
+ options.tls_root_certs = root_cert.pem_cert;
+ return FlightClient::Connect(server_->location(), options, &client_);
+ }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<InProcessTestServer> server_;
+};
+
TEST_F(TestFlightClient, ListFlights) {
std::unique_ptr<FlightListing> listing;
ASSERT_OK(client_->ListFlights(&listing));
@@ -620,5 +659,21 @@ TEST_F(TestAuthHandler, CheckPeerIdentity) {
ASSERT_EQ(result->body->ToString(), "user");
}
+TEST_F(TestTls, DoAction) {
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{5.0};
+ Action action;
+ action.type = "test";
+ action.body = Buffer::FromString("");
+ std::unique_ptr<ResultStream> results;
+ ASSERT_OK(client_->DoAction(options, action, &results));
+ ASSERT_NE(results, nullptr);
+
+ std::unique_ptr<Result> result;
+ ASSERT_OK(results->Next(&result));
+ ASSERT_NE(result, nullptr);
+ ASSERT_EQ(result->body->ToString(), "Hello, world!");
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 9b6bf6c..6f3c466 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -460,7 +460,7 @@ thread_local std::atomic<FlightServerBase::Impl*>
#endif
FlightServerOptions::FlightServerOptions(const Location& location_)
- : location(location_), auth_handler(nullptr) {}
+ : location(location_), auth_handler(nullptr), tls_certificates() {}
FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }
@@ -483,8 +483,9 @@ Status FlightServerBase::Init(FlightServerOptions& options) {
std::shared_ptr<grpc::ServerCredentials> creds;
if (scheme == kSchemeGrpcTls) {
grpc::SslServerCredentialsOptions ssl_options;
- ssl_options.pem_key_cert_pairs.push_back(
- {options.tls_private_key, options.tls_cert_chain});
+ for (const auto& pair : options.tls_certificates) {
+ ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert});
+ }
creds = grpc::SslServerCredentials(ssl_options);
} else {
creds = grpc::InsecureServerCredentials();
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index 7164b64..c1bcb5c 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -106,8 +106,7 @@ class ARROW_FLIGHT_EXPORT FlightServerOptions {
Location location;
std::unique_ptr<ServerAuthHandler> auth_handler;
- std::string tls_cert_chain;
- std::string tls_private_key;
+ std::vector<CertKeyPair> tls_certificates;
};
/// \brief Skeleton RPC server implementation which can be used to create
diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc
index f72fd3c..87ef62f 100644
--- a/cpp/src/arrow/flight/test-server.cc
+++ b/cpp/src/arrow/flight/test-server.cc
@@ -25,120 +25,19 @@
#include <gflags/gflags.h>
-#include "arrow/buffer.h"
-#include "arrow/io/test-common.h"
-#include "arrow/record_batch.h"
-#include "arrow/util/logging.h"
-
#include "arrow/flight/server.h"
-#include "arrow/flight/server_auth.h"
#include "arrow/flight/test-util.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/logging.h"
DEFINE_int32(port, 31337, "Server port to listen on");
-namespace arrow {
-namespace flight {
-
-Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
- if (ticket.ticket == "ticket-ints-1") {
- BatchVector batches;
- RETURN_NOT_OK(ExampleIntBatches(&batches));
- *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
- return Status::OK();
- } else if (ticket.ticket == "ticket-dicts-1") {
- BatchVector batches;
- RETURN_NOT_OK(ExampleDictBatches(&batches));
- *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
- return Status::OK();
- } else {
- return Status::NotImplemented("no stream implemented for this ticket");
- }
-}
-
-class FlightTestServer : public FlightServerBase {
- Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
- std::unique_ptr<FlightListing>* listings) override {
- std::vector<FlightInfo> flights = ExampleFlightInfo();
- *listings = std::unique_ptr<FlightListing>(new SimpleFlightListing(flights));
- return Status::OK();
- }
-
- Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
- std::unique_ptr<FlightInfo>* out) override {
- std::vector<FlightInfo> flights = ExampleFlightInfo();
-
- for (const auto& info : flights) {
- if (info.descriptor().Equals(request)) {
- *out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
- return Status::OK();
- }
- }
- return Status::Invalid("Flight not found: ", request.ToString());
- }
-
- Status DoGet(const ServerCallContext& context, const Ticket& request,
- std::unique_ptr<FlightDataStream>* data_stream) override {
- // Test for ARROW-5095
- if (request.ticket == "ARROW-5095-fail") {
- return Status::UnknownError("Server-side error");
- }
- if (request.ticket == "ARROW-5095-success") {
- return Status::OK();
- }
-
- std::shared_ptr<RecordBatchReader> batch_reader;
- RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
-
- *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
- return Status::OK();
- }
-
- Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
- std::vector<Result> results;
- for (int i = 0; i < 3; ++i) {
- Result result;
- std::string value = action.body->ToString() + "-part" + std::to_string(i);
- RETURN_NOT_OK(Buffer::FromString(value, &result.body));
- results.push_back(result);
- }
- *out = std::unique_ptr<ResultStream>(new SimpleResultStream(std::move(results)));
- return Status::OK();
- }
-
- Status RunAction2(std::unique_ptr<ResultStream>* out) {
- // Empty
- *out = std::unique_ptr<ResultStream>(new SimpleResultStream({}));
- return Status::OK();
- }
-
- Status DoAction(const ServerCallContext& context, const Action& action,
- std::unique_ptr<ResultStream>* out) override {
- if (action.type == "action1") {
- return RunAction1(action, out);
- } else if (action.type == "action2") {
- return RunAction2(out);
- } else {
- return Status::NotImplemented(action.type);
- }
- }
-
- Status ListActions(const ServerCallContext& context,
- std::vector<ActionType>* out) override {
- std::vector<ActionType> actions = ExampleActionTypes();
- *out = std::move(actions);
- return Status::OK();
- }
-};
-
-} // namespace flight
-} // namespace arrow
-
-std::unique_ptr<arrow::flight::FlightTestServer> g_server;
+std::unique_ptr<arrow::flight::FlightServerBase> g_server;
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
- g_server.reset(new arrow::flight::FlightTestServer);
+ g_server = arrow::flight::ExampleTestServer();
arrow::flight::Location location;
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc
index b20a4cb..7dd78fd 100644
--- a/cpp/src/arrow/flight/test-util.cc
+++ b/cpp/src/arrow/flight/test-util.cc
@@ -22,6 +22,7 @@
#include <mach-o/dyld.h>
#endif
+#include <cstdlib>
#include <sstream>
#include <boost/filesystem.hpp>
@@ -154,6 +155,101 @@ InProcessTestServer::~InProcessTestServer() {
}
}
+Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
+ if (ticket.ticket == "ticket-ints-1") {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ return Status::OK();
+ } else if (ticket.ticket == "ticket-dicts-1") {
+ BatchVector batches;
+ RETURN_NOT_OK(ExampleDictBatches(&batches));
+ *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ return Status::OK();
+ } else {
+ return Status::NotImplemented("no stream implemented for this ticket");
+ }
+}
+
+class FlightTestServer : public FlightServerBase {
+ Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
+ std::unique_ptr<FlightListing>* listings) override {
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+ *listings = std::unique_ptr<FlightListing>(new SimpleFlightListing(flights));
+ return Status::OK();
+ }
+
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* out) override {
+ std::vector<FlightInfo> flights = ExampleFlightInfo();
+
+ for (const auto& info : flights) {
+ if (info.descriptor().Equals(request)) {
+ *out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
+ return Status::OK();
+ }
+ }
+ return Status::Invalid("Flight not found: ", request.ToString());
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ // Test for ARROW-5095
+ if (request.ticket == "ARROW-5095-fail") {
+ return Status::UnknownError("Server-side error");
+ }
+ if (request.ticket == "ARROW-5095-success") {
+ return Status::OK();
+ }
+
+ std::shared_ptr<RecordBatchReader> batch_reader;
+ RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
+
+ *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
+ return Status::OK();
+ }
+
+ Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
+ std::vector<Result> results;
+ for (int i = 0; i < 3; ++i) {
+ Result result;
+ std::string value = action.body->ToString() + "-part" + std::to_string(i);
+ RETURN_NOT_OK(Buffer::FromString(value, &result.body));
+ results.push_back(result);
+ }
+ *out = std::unique_ptr<ResultStream>(new SimpleResultStream(std::move(results)));
+ return Status::OK();
+ }
+
+ Status RunAction2(std::unique_ptr<ResultStream>* out) {
+ // Empty
+ *out = std::unique_ptr<ResultStream>(new SimpleResultStream({}));
+ return Status::OK();
+ }
+
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* out) override {
+ if (action.type == "action1") {
+ return RunAction1(action, out);
+ } else if (action.type == "action2") {
+ return RunAction2(out);
+ } else {
+ return Status::NotImplemented(action.type);
+ }
+ }
+
+ Status ListActions(const ServerCallContext& context,
+ std::vector<ActionType>* out) override {
+ std::vector<ActionType> actions = ExampleActionTypes();
+ *out = std::move(actions);
+ return Status::OK();
+ }
+};
+
+std::unique_ptr<FlightServerBase> ExampleTestServer() {
+ return std::unique_ptr<FlightServerBase>(new FlightTestServer);
+}
+
Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
const std::vector<FlightEndpoint>& endpoints, int64_t total_records,
int64_t total_bytes, FlightInfo::Data* out) {
@@ -286,5 +382,70 @@ Status TestClientAuthHandler::GetToken(std::string* token) {
return Status::OK();
}
+Status GetTestResourceRoot(std::string* out) {
+ const char* c_root = std::getenv("ARROW_TEST_DATA");
+ if (!c_root) {
+ return Status::IOError("Test resources not found, set ARROW_TEST_DATA");
+ }
+ *out = std::string(c_root);
+ return Status::OK();
+}
+
+Status ExampleTlsCertificates(std::vector<CertKeyPair>* out) {
+ std::string root;
+ RETURN_NOT_OK(GetTestResourceRoot(&root));
+
+ *out = std::vector<CertKeyPair>();
+ for (int i = 0; i < 2; i++) {
+ try {
+ std::stringstream cert_path;
+ cert_path << root << "/flight/cert" << i << ".pem";
+ std::stringstream key_path;
+ key_path << root << "/flight/cert" << i << ".key";
+
+ std::ifstream cert_file(cert_path.str());
+ if (!cert_file) {
+ return Status::IOError("Could not open certificate: " + cert_path.str());
+ }
+ std::stringstream cert;
+ cert << cert_file.rdbuf();
+
+ std::ifstream key_file(key_path.str());
+ if (!key_file) {
+ return Status::IOError("Could not open key: " + key_path.str());
+ }
+ std::stringstream key;
+ key << key_file.rdbuf();
+
+ out->push_back(CertKeyPair{cert.str(), key.str()});
+ } catch (const std::ifstream::failure& e) {
+ return Status::IOError(e.what());
+ }
+ }
+ return Status::OK();
+}
+
+Status ExampleTlsCertificateRoot(CertKeyPair* out) {
+ std::string root;
+ RETURN_NOT_OK(GetTestResourceRoot(&root));
+
+ std::stringstream path;
+ path << root << "/flight/root-ca.pem";
+
+ try {
+ std::ifstream cert_file(path.str());
+ if (!cert_file) {
+ return Status::IOError("Could not open certificate: " + path.str());
+ }
+ std::stringstream cert;
+ cert << cert_file.rdbuf();
+ out->pem_cert = cert.str();
+ out->pem_key = "";
+ return Status::OK();
+ } catch (const std::ifstream::failure& e) {
+ return Status::IOError(e.what());
+ }
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h
index 2e1f4b0..5b02630 100644
--- a/cpp/src/arrow/flight/test-util.h
+++ b/cpp/src/arrow/flight/test-util.h
@@ -86,6 +86,10 @@ class ARROW_FLIGHT_EXPORT InProcessTestServer {
std::thread thread_;
};
+/// \brief Create a simple Flight server for testing
+ARROW_FLIGHT_EXPORT
+std::unique_ptr<FlightServerBase> ExampleTestServer();
+
// ----------------------------------------------------------------------
// A RecordBatchReader for serving a sequence of in-memory record batches
@@ -184,5 +188,11 @@ class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler {
std::string password_;
};
+ARROW_FLIGHT_EXPORT
+Status ExampleTlsCertificates(std::vector<CertKeyPair>* out);
+
+ARROW_FLIGHT_EXPORT
+Status ExampleTlsCertificateRoot(CertKeyPair* out);
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index dadb510..d982efc 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -96,6 +96,12 @@ Status Location::ForGrpcTcp(const std::string& host, const int port, Location* l
return Location::Parse(uri_string.str(), location);
}
+Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) {
+ std::stringstream uri_string;
+ uri_string << "grpc+tls://" << host << ':' << port;
+ return Location::Parse(uri_string.str(), location);
+}
+
Status Location::ForGrpcUnix(const std::string& path, Location* location) {
std::stringstream uri_string;
uri_string << "grpc+unix://" << path;
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 8d37225..e5f7bcd 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -49,6 +49,15 @@ class Uri;
namespace flight {
+/// \brief A TLS certificate plus key.
+struct ARROW_FLIGHT_EXPORT CertKeyPair {
+ /// \brief The certificate in PEM format.
+ std::string pem_cert;
+
+ /// \brief The key in PEM format.
+ std::string pem_key;
+};
+
/// \brief A type of action that can be performed with the DoAction RPC
struct ARROW_FLIGHT_EXPORT ActionType {
/// Name of action
@@ -145,6 +154,13 @@ struct ARROW_FLIGHT_EXPORT Location {
/// \param[out] location The resulting location
static Status ForGrpcTcp(const std::string& host, const int port, Location* location);
+ /// \brief Initialize a location for a TLS-enabled, gRPC-based Flight
+ /// service from a host and port
+ /// \param[in] host The hostname to connect to
+ /// \param[in] port The port
+ /// \param[out] location The resulting location
+ static Status ForGrpcTls(const std::string& host, const int port, Location* location);
+
/// \brief Initialize a location for a domain socket-based Flight
/// service
/// \param[in] path The path to the domain socket
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index c682635..c916e6b 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -57,6 +57,13 @@ cdef class FlightCallOptions:
"'{}'".format(type(obj)))
+_CertKeyPair = collections.namedtuple('_CertKeyPair', ['cert', 'key'])
+
+
+class CertKeyPair(_CertKeyPair):
+ """A TLS certificate and key for use in Flight."""
+
+
cdef class Action:
"""An action executable on a Flight service."""
cdef:
@@ -228,6 +235,16 @@ cdef class Location:
return result
@staticmethod
+ def for_grpc_tls(host, port):
+ """Create a Location for a TLS-based gRPC service."""
+ cdef:
+ c_string c_host = tobytes(host)
+ int c_port = port
+ Location result = Location.__new__(Location)
+ check_status(CLocation.ForGrpcTls(c_host, c_port, &result.location))
+ return result
+
+ @staticmethod
def for_grpc_unix(path):
"""Create a Location for a domain socket-based gRPC service."""
cdef:
@@ -1016,12 +1033,12 @@ cdef class FlightServerBase:
cdef:
unique_ptr[PyFlightServer] server
- def run(self, location, auth_handler=None,
- tls_cert_chain=None, tls_private_key=None):
+ def run(self, location, auth_handler=None, tls_certificates=None):
cdef:
PyFlightServerVtable vtable = PyFlightServerVtable()
PyFlightServer* c_server
unique_ptr[CFlightServerOptions] c_options
+ CCertKeyPair c_cert
c_options.reset(new CFlightServerOptions(Location.unwrap(location)))
@@ -1032,12 +1049,11 @@ cdef class FlightServerBase:
c_options.get().auth_handler.reset(
(<ServerAuthHandler> auth_handler).to_handler())
- if tls_cert_chain:
- if not tls_private_key:
- raise ValueError(
- "Must provide both cert chain and private key")
- c_options.get().tls_cert_chain = tobytes(tls_cert_chain)
- c_options.get().tls_private_key = tobytes(tls_private_key)
+ if tls_certificates:
+ for cert, key in tls_certificates:
+ c_cert.pem_cert = tobytes(cert)
+ c_cert.pem_key = tobytes(key)
+ c_options.get().tls_certificates.push_back(c_cert)
vtable.list_flights = &_list_flights
vtable.get_flight_info = &_get_flight_info
diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py
index 37a21e4..05198e4 100644
--- a/python/pyarrow/flight.py
+++ b/python/pyarrow/flight.py
@@ -25,6 +25,7 @@ if sys.version_info < (3,):
from pyarrow._flight import ( # noqa
Action,
ActionType,
+ CertKeyPair,
DescriptorType,
FlightCallOptions,
FlightClient,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 4b74990..14d1ed1 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -88,6 +88,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
@staticmethod
CStatus ForGrpcTcp(c_string& host, int port, CLocation* location)
@staticmethod
+ CStatus ForGrpcTls(c_string& host, int port, CLocation* location)
+ @staticmethod
CStatus ForGrpcUnix(c_string& path, CLocation* location)
cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint":
@@ -154,12 +156,16 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
CFlightCallOptions()
CTimeoutDuration timeout
+ cdef cppclass CCertKeyPair" arrow::flight::CertKeyPair":
+ CCertKeyPair()
+ c_string pem_cert
+ c_string pem_key
+
cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions":
CFlightServerOptions(const CLocation& location)
CLocation location
unique_ptr[CServerAuthHandler] auth_handler
- c_string tls_cert_chain
- c_string tls_private_key
+ vector[CCertKeyPair] tls_certificates
cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions":
CFlightClientOptions()
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 9ce2264..a7e6e34 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -28,12 +28,52 @@ import pytest
import pyarrow as pa
+from pathlib import Path
from pyarrow.compat import tobytes
flight = pytest.importorskip("pyarrow.flight")
+def resource_root():
+ """Get the path to the test resources directory."""
+ if not os.environ.get("ARROW_TEST_DATA"):
+ raise RuntimeError("Test resources not found; set "
+ "ARROW_TEST_DATA to <repo root>/testing")
+ return Path(os.environ["ARROW_TEST_DATA"]) / "flight"
+
+
+def read_flight_resource(path):
+ """Get the contents of a test resource file."""
+ root = resource_root()
+ if not root:
+ return None
+ try:
+ with (root / path).open("rb") as f:
+ return f.read()
+ except FileNotFoundError as e:
+ raise RuntimeError(
+ "Test resource {} not found; did you initialize the "
+ "test resource submodule?".format(root / path)) from e
+
+
+def example_tls_certs():
+ """Get the paths to test TLS certificates."""
+ return {
+ "root_cert": read_flight_resource("root-ca.pem"),
+ "certificates": [
+ flight.CertKeyPair(
+ cert=read_flight_resource("cert0.pem"),
+ key=read_flight_resource("cert0.key"),
+ ),
+ flight.CertKeyPair(
+ cert=read_flight_resource("cert1.pem"),
+ key=read_flight_resource("cert1.key"),
+ ),
+ ]
+ }
+
+
def simple_ints_table():
data = [
pa.array([-10, -5, 0, 5, 10])
@@ -245,6 +285,7 @@ class TokenClientAuthHandler(flight.ClientAuthHandler):
def flight_server(server_base, *args, **kwargs):
"""Spawn a Flight server on a free port, shutting it down when done."""
auth_handler = kwargs.pop('auth_handler', None)
+ tls_certificates = kwargs.pop('tls_certificates', None)
location = kwargs.pop('location', None)
if location is None:
@@ -254,7 +295,10 @@ def flight_server(server_base, *args, **kwargs):
sock.bind(('', 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = sock.getsockname()[1]
- location = flight.Location.for_grpc_tcp("localhost", port)
+ ctor = flight.Location.for_grpc_tcp
+ if tls_certificates:
+ ctor = flight.Location.for_grpc_tls
+ location = ctor("localhost", port)
else:
port = None
@@ -262,11 +306,26 @@ def flight_server(server_base, *args, **kwargs):
server_instance = server_base(*args, **ctor_kwargs)
def _server_thread():
- server_instance.run(location, auth_handler=auth_handler)
+ server_instance.run(
+ location,
+ auth_handler=auth_handler,
+ tls_certificates=tls_certificates,
+ )
thread = threading.Thread(target=_server_thread, daemon=True)
thread.start()
+ # Wait for server to start
+ client = flight.FlightClient.connect(location)
+ while True:
+ try:
+ list(client.list_flights())
+ except Exception as e:
+ if 'Connect Failed' in str(e):
+ time.sleep(0.025)
+ continue
+ break
+
yield location
server_instance.shutdown()
@@ -471,3 +530,32 @@ def test_location_invalid():
server = ConstantFlightServer()
with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"):
server.run("%")
+
+
+@pytest.mark.slow
+def test_tls_fails():
+ """Make sure clients cannot connect when cert verification fails."""
+ certs = example_tls_certs()
+
+ with flight_server(
+ ConstantFlightServer, tls_certificates=certs["certificates"]
+ ) as server_location:
+ # Ensure client doesn't connect when certificate verification
+ # fails (this is a slow test since gRPC does retry a few times)
+ client = flight.FlightClient.connect(server_location)
+ with pytest.raises(pa.ArrowIOError, match="Connect Failed"):
+ client.do_get(flight.Ticket(b'ints'))
+
+
+def test_tls_do_get():
+ """Try a simple do_get call over TLS."""
+ table = simple_ints_table()
+ certs = example_tls_certs()
+
+ with flight_server(
+ ConstantFlightServer, tls_certificates=certs["certificates"]
+ ) as server_location:
+ client = flight.FlightClient.connect(
+ server_location, tls_root_certs=certs["root_cert"])
+ data = client.do_get(flight.Ticket(b'ints')).read_all()
+ assert data.equals(table)
diff --git a/testing b/testing
index bf0abe4..12f9dbd 160000
--- a/testing
+++ b/testing
@@ -1 +1 @@
-Subproject commit bf0abe442bf7e313380452c8972692940f4e56b6
+Subproject commit 12f9dbd2a37eea6fa370e108a1d797ee1167724a