You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by ap...@apache.org on 2019/06/13 16:42:46 UTC

[arrow] branch master updated: ARROW-5397: [FlightRPC] Add TLS certificates for testing Flight

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new b7e8ed7  ARROW-5397: [FlightRPC] Add TLS certificates for testing Flight
b7e8ed7 is described below

commit b7e8ed7fe9613c899a2181bcf48996466b31d9f8
Author: David Li <li...@gmail.com>
AuthorDate: Thu Jun 13 18:42:38 2019 +0200

    ARROW-5397: [FlightRPC] Add TLS certificates for testing Flight
    
    This needs https://github.com/apache/arrow-testing/pull/2.
    
    Author: David Li <li...@gmail.com>
    
    Closes #4510 from lihalite/flight-tls and squashes the following commits:
    
    5eff72470 <David Li> Don't set wait_for_ready in Flight
    776b9d01e <David Li> Add tests for TLS in Flight (C++, Python)
    9d2efa20a <David Li> Allow multiple TLS certificates in Flight
---
 cpp/src/arrow/flight/client.cc              |   3 -
 cpp/src/arrow/flight/flight-test.cc         |  89 ++++++++++++---
 cpp/src/arrow/flight/server.cc              |   7 +-
 cpp/src/arrow/flight/server.h               |   3 +-
 cpp/src/arrow/flight/test-server.cc         | 109 +------------------
 cpp/src/arrow/flight/test-util.cc           | 161 ++++++++++++++++++++++++++++
 cpp/src/arrow/flight/test-util.h            |  10 ++
 cpp/src/arrow/flight/types.cc               |   6 ++
 cpp/src/arrow/flight/types.h                |  16 +++
 python/pyarrow/_flight.pyx                  |  32 ++++--
 python/pyarrow/flight.py                    |   1 +
 python/pyarrow/includes/libarrow_flight.pxd |  10 +-
 python/pyarrow/tests/test_flight.py         |  92 +++++++++++++++-
 testing                                     |   2 +-
 14 files changed, 398 insertions(+), 143 deletions(-)

diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 1c927da..2b7c699 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -60,9 +60,6 @@ struct ClientRpc {
   grpc::ClientContext context;
 
   explicit ClientRpc(const FlightCallOptions& options) {
-    /// XXX workaround until we have a handshake in Connect
-    context.set_wait_for_ready(true);
-
     if (options.timeout.count() >= 0) {
       std::chrono::system_clock::time_point deadline =
           std::chrono::time_point_cast<std::chrono::system_clock::time_point::duration>(
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index cb7e57c..b295878 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -176,29 +176,22 @@ TEST(TestFlight, ConnectUri) {
 
 class TestFlightClient : public ::testing::Test {
  public:
-  // Uncomment these when you want to run the server separately for
-  // debugging/valgrind/gdb
+  void SetUp() {
+    Location location;
+    std::unique_ptr<FlightServerBase> server = ExampleTestServer();
 
-  // void SetUp() {
-  //   port_ = 92358;
-  //   ASSERT_OK(ConnectClient());
-  // }
-  // void TearDown() {}
+    ASSERT_OK(Location::ForGrpcTcp("localhost", GetListenPort(), &location));
+    FlightServerOptions options(location);
+    ASSERT_OK(server->Init(options));
 
-  void SetUp() {
-    server_.reset(new TestServer("flight-test-server"));
-    server_->Start();
-    port_ = server_->port();
+    server_.reset(new InProcessTestServer(std::move(server), location));
+    ASSERT_OK(server_->Start());
     ASSERT_OK(ConnectClient());
   }
 
   void TearDown() { server_->Stop(); }
 
-  Status ConnectClient() {
-    Location location;
-    RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
-    return FlightClient::Connect(location, &client_);
-  }
+  Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }
 
   template <typename EndpointCheckFunc>
   void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
@@ -236,7 +229,7 @@ class TestFlightClient : public ::testing::Test {
  protected:
   int port_;
   std::unique_ptr<FlightClient> client_;
-  std::unique_ptr<TestServer> server_;
+  std::unique_ptr<InProcessTestServer> server_;
 };
 
 class AuthTestServer : public FlightServerBase {
@@ -249,6 +242,16 @@ class AuthTestServer : public FlightServerBase {
   }
 };
 
+class TlsTestServer : public FlightServerBase {
+  Status DoAction(const ServerCallContext& context, const Action& action,
+                  std::unique_ptr<ResultStream>* result) override {
+    std::shared_ptr<Buffer> buf;
+    RETURN_NOT_OK(Buffer::FromString("Hello, world!", &buf));
+    *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
+    return Status::OK();
+  }
+};
+
 class DoPutTestServer : public FlightServerBase {
  public:
   Status DoPut(const ServerCallContext& context,
@@ -336,6 +339,42 @@ class TestDoPut : public ::testing::Test {
   DoPutTestServer* do_put_server_;
 };
 
+class TestTls : public ::testing::Test {
+ public:
+  void SetUp() {
+    Location location;
+    std::unique_ptr<FlightServerBase> server(new TlsTestServer);
+
+    ASSERT_OK(Location::ForGrpcTls("localhost", GetListenPort(), &location));
+    FlightServerOptions options(location);
+    ASSERT_RAISES(UnknownError, server->Init(options));
+    ASSERT_OK(ExampleTlsCertificates(&options.tls_certificates));
+    ASSERT_OK(server->Init(options));
+
+    server_.reset(new InProcessTestServer(std::move(server), location));
+    ASSERT_OK(server_->Start());
+    ASSERT_OK(ConnectClient());
+  }
+
+  void TearDown() {
+    if (server_) {
+      server_->Stop();
+    }
+  }
+
+  Status ConnectClient() {
+    auto options = FlightClientOptions();
+    CertKeyPair root_cert;
+    RETURN_NOT_OK(ExampleTlsCertificateRoot(&root_cert));
+    options.tls_root_certs = root_cert.pem_cert;
+    return FlightClient::Connect(server_->location(), options, &client_);
+  }
+
+ protected:
+  std::unique_ptr<FlightClient> client_;
+  std::unique_ptr<InProcessTestServer> server_;
+};
+
 TEST_F(TestFlightClient, ListFlights) {
   std::unique_ptr<FlightListing> listing;
   ASSERT_OK(client_->ListFlights(&listing));
@@ -620,5 +659,21 @@ TEST_F(TestAuthHandler, CheckPeerIdentity) {
   ASSERT_EQ(result->body->ToString(), "user");
 }
 
+TEST_F(TestTls, DoAction) {
+  FlightCallOptions options;
+  options.timeout = TimeoutDuration{5.0};
+  Action action;
+  action.type = "test";
+  action.body = Buffer::FromString("");
+  std::unique_ptr<ResultStream> results;
+  ASSERT_OK(client_->DoAction(options, action, &results));
+  ASSERT_NE(results, nullptr);
+
+  std::unique_ptr<Result> result;
+  ASSERT_OK(results->Next(&result));
+  ASSERT_NE(result, nullptr);
+  ASSERT_EQ(result->body->ToString(), "Hello, world!");
+}
+
 }  // namespace flight
 }  // namespace arrow
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 9b6bf6c..6f3c466 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -460,7 +460,7 @@ thread_local std::atomic<FlightServerBase::Impl*>
 #endif
 
 FlightServerOptions::FlightServerOptions(const Location& location_)
-    : location(location_), auth_handler(nullptr) {}
+    : location(location_), auth_handler(nullptr), tls_certificates() {}
 
 FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }
 
@@ -483,8 +483,9 @@ Status FlightServerBase::Init(FlightServerOptions& options) {
     std::shared_ptr<grpc::ServerCredentials> creds;
     if (scheme == kSchemeGrpcTls) {
       grpc::SslServerCredentialsOptions ssl_options;
-      ssl_options.pem_key_cert_pairs.push_back(
-          {options.tls_private_key, options.tls_cert_chain});
+      for (const auto& pair : options.tls_certificates) {
+        ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert});
+      }
       creds = grpc::SslServerCredentials(ssl_options);
     } else {
       creds = grpc::InsecureServerCredentials();
diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h
index 7164b64..c1bcb5c 100644
--- a/cpp/src/arrow/flight/server.h
+++ b/cpp/src/arrow/flight/server.h
@@ -106,8 +106,7 @@ class ARROW_FLIGHT_EXPORT FlightServerOptions {
 
   Location location;
   std::unique_ptr<ServerAuthHandler> auth_handler;
-  std::string tls_cert_chain;
-  std::string tls_private_key;
+  std::vector<CertKeyPair> tls_certificates;
 };
 
 /// \brief Skeleton RPC server implementation which can be used to create
diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc
index f72fd3c..87ef62f 100644
--- a/cpp/src/arrow/flight/test-server.cc
+++ b/cpp/src/arrow/flight/test-server.cc
@@ -25,120 +25,19 @@
 
 #include <gflags/gflags.h>
 
-#include "arrow/buffer.h"
-#include "arrow/io/test-common.h"
-#include "arrow/record_batch.h"
-#include "arrow/util/logging.h"
-
 #include "arrow/flight/server.h"
-#include "arrow/flight/server_auth.h"
 #include "arrow/flight/test-util.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/logging.h"
 
 DEFINE_int32(port, 31337, "Server port to listen on");
 
-namespace arrow {
-namespace flight {
-
-Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
-  if (ticket.ticket == "ticket-ints-1") {
-    BatchVector batches;
-    RETURN_NOT_OK(ExampleIntBatches(&batches));
-    *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
-    return Status::OK();
-  } else if (ticket.ticket == "ticket-dicts-1") {
-    BatchVector batches;
-    RETURN_NOT_OK(ExampleDictBatches(&batches));
-    *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
-    return Status::OK();
-  } else {
-    return Status::NotImplemented("no stream implemented for this ticket");
-  }
-}
-
-class FlightTestServer : public FlightServerBase {
-  Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
-                     std::unique_ptr<FlightListing>* listings) override {
-    std::vector<FlightInfo> flights = ExampleFlightInfo();
-    *listings = std::unique_ptr<FlightListing>(new SimpleFlightListing(flights));
-    return Status::OK();
-  }
-
-  Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
-                       std::unique_ptr<FlightInfo>* out) override {
-    std::vector<FlightInfo> flights = ExampleFlightInfo();
-
-    for (const auto& info : flights) {
-      if (info.descriptor().Equals(request)) {
-        *out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
-        return Status::OK();
-      }
-    }
-    return Status::Invalid("Flight not found: ", request.ToString());
-  }
-
-  Status DoGet(const ServerCallContext& context, const Ticket& request,
-               std::unique_ptr<FlightDataStream>* data_stream) override {
-    // Test for ARROW-5095
-    if (request.ticket == "ARROW-5095-fail") {
-      return Status::UnknownError("Server-side error");
-    }
-    if (request.ticket == "ARROW-5095-success") {
-      return Status::OK();
-    }
-
-    std::shared_ptr<RecordBatchReader> batch_reader;
-    RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
-
-    *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
-    return Status::OK();
-  }
-
-  Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
-    std::vector<Result> results;
-    for (int i = 0; i < 3; ++i) {
-      Result result;
-      std::string value = action.body->ToString() + "-part" + std::to_string(i);
-      RETURN_NOT_OK(Buffer::FromString(value, &result.body));
-      results.push_back(result);
-    }
-    *out = std::unique_ptr<ResultStream>(new SimpleResultStream(std::move(results)));
-    return Status::OK();
-  }
-
-  Status RunAction2(std::unique_ptr<ResultStream>* out) {
-    // Empty
-    *out = std::unique_ptr<ResultStream>(new SimpleResultStream({}));
-    return Status::OK();
-  }
-
-  Status DoAction(const ServerCallContext& context, const Action& action,
-                  std::unique_ptr<ResultStream>* out) override {
-    if (action.type == "action1") {
-      return RunAction1(action, out);
-    } else if (action.type == "action2") {
-      return RunAction2(out);
-    } else {
-      return Status::NotImplemented(action.type);
-    }
-  }
-
-  Status ListActions(const ServerCallContext& context,
-                     std::vector<ActionType>* out) override {
-    std::vector<ActionType> actions = ExampleActionTypes();
-    *out = std::move(actions);
-    return Status::OK();
-  }
-};
-
-}  // namespace flight
-}  // namespace arrow
-
-std::unique_ptr<arrow::flight::FlightTestServer> g_server;
+std::unique_ptr<arrow::flight::FlightServerBase> g_server;
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
-  g_server.reset(new arrow::flight::FlightTestServer);
+  g_server = arrow::flight::ExampleTestServer();
 
   arrow::flight::Location location;
   ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc
index b20a4cb..7dd78fd 100644
--- a/cpp/src/arrow/flight/test-util.cc
+++ b/cpp/src/arrow/flight/test-util.cc
@@ -22,6 +22,7 @@
 #include <mach-o/dyld.h>
 #endif
 
+#include <cstdlib>
 #include <sstream>
 
 #include <boost/filesystem.hpp>
@@ -154,6 +155,101 @@ InProcessTestServer::~InProcessTestServer() {
   }
 }
 
+Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
+  if (ticket.ticket == "ticket-ints-1") {
+    BatchVector batches;
+    RETURN_NOT_OK(ExampleIntBatches(&batches));
+    *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+    return Status::OK();
+  } else if (ticket.ticket == "ticket-dicts-1") {
+    BatchVector batches;
+    RETURN_NOT_OK(ExampleDictBatches(&batches));
+    *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+    return Status::OK();
+  } else {
+    return Status::NotImplemented("no stream implemented for this ticket");
+  }
+}
+
+class FlightTestServer : public FlightServerBase {
+  Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
+                     std::unique_ptr<FlightListing>* listings) override {
+    std::vector<FlightInfo> flights = ExampleFlightInfo();
+    *listings = std::unique_ptr<FlightListing>(new SimpleFlightListing(flights));
+    return Status::OK();
+  }
+
+  Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+                       std::unique_ptr<FlightInfo>* out) override {
+    std::vector<FlightInfo> flights = ExampleFlightInfo();
+
+    for (const auto& info : flights) {
+      if (info.descriptor().Equals(request)) {
+        *out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
+        return Status::OK();
+      }
+    }
+    return Status::Invalid("Flight not found: ", request.ToString());
+  }
+
+  Status DoGet(const ServerCallContext& context, const Ticket& request,
+               std::unique_ptr<FlightDataStream>* data_stream) override {
+    // Test for ARROW-5095
+    if (request.ticket == "ARROW-5095-fail") {
+      return Status::UnknownError("Server-side error");
+    }
+    if (request.ticket == "ARROW-5095-success") {
+      return Status::OK();
+    }
+
+    std::shared_ptr<RecordBatchReader> batch_reader;
+    RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
+
+    *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
+    return Status::OK();
+  }
+
+  Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
+    std::vector<Result> results;
+    for (int i = 0; i < 3; ++i) {
+      Result result;
+      std::string value = action.body->ToString() + "-part" + std::to_string(i);
+      RETURN_NOT_OK(Buffer::FromString(value, &result.body));
+      results.push_back(result);
+    }
+    *out = std::unique_ptr<ResultStream>(new SimpleResultStream(std::move(results)));
+    return Status::OK();
+  }
+
+  Status RunAction2(std::unique_ptr<ResultStream>* out) {
+    // Empty
+    *out = std::unique_ptr<ResultStream>(new SimpleResultStream({}));
+    return Status::OK();
+  }
+
+  Status DoAction(const ServerCallContext& context, const Action& action,
+                  std::unique_ptr<ResultStream>* out) override {
+    if (action.type == "action1") {
+      return RunAction1(action, out);
+    } else if (action.type == "action2") {
+      return RunAction2(out);
+    } else {
+      return Status::NotImplemented(action.type);
+    }
+  }
+
+  Status ListActions(const ServerCallContext& context,
+                     std::vector<ActionType>* out) override {
+    std::vector<ActionType> actions = ExampleActionTypes();
+    *out = std::move(actions);
+    return Status::OK();
+  }
+};
+
+std::unique_ptr<FlightServerBase> ExampleTestServer() {
+  return std::unique_ptr<FlightServerBase>(new FlightTestServer);
+}
+
 Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
                       const std::vector<FlightEndpoint>& endpoints, int64_t total_records,
                       int64_t total_bytes, FlightInfo::Data* out) {
@@ -286,5 +382,70 @@ Status TestClientAuthHandler::GetToken(std::string* token) {
   return Status::OK();
 }
 
+Status GetTestResourceRoot(std::string* out) {
+  const char* c_root = std::getenv("ARROW_TEST_DATA");
+  if (!c_root) {
+    return Status::IOError("Test resources not found, set ARROW_TEST_DATA");
+  }
+  *out = std::string(c_root);
+  return Status::OK();
+}
+
+// Loads the example server cert/key pairs from $ARROW_TEST_DATA/flight.
+Status ExampleTlsCertificates(std::vector<CertKeyPair>* out) {
+  std::string root;
+  RETURN_NOT_OK(GetTestResourceRoot(&root));
+
+  // The test data provides two pairs: cert0.{pem,key} and cert1.{pem,key}.
+  *out = std::vector<CertKeyPair>();
+  for (int i = 0; i < 2; i++) {
+    std::stringstream cert_path;
+    cert_path << root << "/flight/cert" << i << ".pem";
+    std::stringstream key_path;
+    key_path << root << "/flight/cert" << i << ".key";
+
+    // Stream exceptions are not enabled, so failures surface through the
+    // stream state (checked below) rather than as std::ios_base::failure.
+    std::ifstream cert_file(cert_path.str());
+    if (!cert_file) {
+      return Status::IOError("Could not open certificate: " + cert_path.str());
+    }
+    std::stringstream cert;
+    cert << cert_file.rdbuf();
+
+    std::ifstream key_file(key_path.str());
+    if (!key_file) {
+      return Status::IOError("Could not open key: " + key_path.str());
+    }
+    std::stringstream key;
+    key << key_file.rdbuf();
+
+    out->push_back(CertKeyPair{cert.str(), key.str()});
+  }
+  return Status::OK();
+}
+
+// Loads the example root CA certificate from $ARROW_TEST_DATA/flight.
+Status ExampleTlsCertificateRoot(CertKeyPair* out) {
+  std::string root;
+  RETURN_NOT_OK(GetTestResourceRoot(&root));
+
+  std::stringstream path;
+  path << root << "/flight/root-ca.pem";
+
+  // Stream exceptions are not enabled, so an open failure is detected
+  // through the stream state rather than via std::ios_base::failure.
+  std::ifstream cert_file(path.str());
+  if (!cert_file) {
+    return Status::IOError("Could not open certificate: " + path.str());
+  }
+  std::stringstream cert;
+  cert << cert_file.rdbuf();
+  // Only the CA certificate is loaded; pem_key is deliberately left empty.
+  out->pem_cert = cert.str();
+  out->pem_key = "";
+  return Status::OK();
+}
+
 }  // namespace flight
 }  // namespace arrow
diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h
index 2e1f4b0..5b02630 100644
--- a/cpp/src/arrow/flight/test-util.h
+++ b/cpp/src/arrow/flight/test-util.h
@@ -86,6 +86,10 @@ class ARROW_FLIGHT_EXPORT InProcessTestServer {
   std::thread thread_;
 };
 
+/// \brief Create a simple Flight server for testing
+ARROW_FLIGHT_EXPORT
+std::unique_ptr<FlightServerBase> ExampleTestServer();
+
 // ----------------------------------------------------------------------
 // A RecordBatchReader for serving a sequence of in-memory record batches
 
@@ -184,5 +188,11 @@ class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler {
   std::string password_;
 };
 
+ARROW_FLIGHT_EXPORT
+Status ExampleTlsCertificates(std::vector<CertKeyPair>* out);
+
+ARROW_FLIGHT_EXPORT
+Status ExampleTlsCertificateRoot(CertKeyPair* out);
+
 }  // namespace flight
 }  // namespace arrow
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index dadb510..d982efc 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -96,6 +96,12 @@ Status Location::ForGrpcTcp(const std::string& host, const int port, Location* l
   return Location::Parse(uri_string.str(), location);
 }
 
+Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) {
+  std::stringstream uri_string;
+  uri_string << "grpc+tls://" << host << ':' << port;
+  return Location::Parse(uri_string.str(), location);
+}
+
 Status Location::ForGrpcUnix(const std::string& path, Location* location) {
   std::stringstream uri_string;
   uri_string << "grpc+unix://" << path;
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 8d37225..e5f7bcd 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -49,6 +49,15 @@ class Uri;
 
 namespace flight {
 
+/// \brief A TLS certificate (chain) and its private key.
+struct ARROW_FLIGHT_EXPORT CertKeyPair {
+  /// \brief The certificate (or certificate chain) in PEM format.
+  std::string pem_cert;
+
+  /// \brief The private key in PEM format; may be empty if only the cert is needed.
+  std::string pem_key;
+};
+
 /// \brief A type of action that can be performed with the DoAction RPC
 struct ARROW_FLIGHT_EXPORT ActionType {
   /// Name of action
@@ -145,6 +154,13 @@ struct ARROW_FLIGHT_EXPORT Location {
   /// \param[out] location The resulting location
   static Status ForGrpcTcp(const std::string& host, const int port, Location* location);
 
+  /// \brief Initialize a location for a TLS-enabled, gRPC-based Flight
+  /// service from a host and port
+  /// \param[in] host The hostname to connect to
+  /// \param[in] port The port
+  /// \param[out] location The resulting location
+  static Status ForGrpcTls(const std::string& host, const int port, Location* location);
+
   /// \brief Initialize a location for a domain socket-based Flight
   /// service
   /// \param[in] path The path to the domain socket
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index c682635..c916e6b 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -57,6 +57,13 @@ cdef class FlightCallOptions:
                         "'{}'".format(type(obj)))
 
 
+_CertKeyPair = collections.namedtuple('_CertKeyPair', ['cert', 'key'])
+
+
+class CertKeyPair(_CertKeyPair):
+    """A TLS certificate and key for use in Flight."""
+
+
 cdef class Action:
     """An action executable on a Flight service."""
     cdef:
@@ -228,6 +235,16 @@ cdef class Location:
         return result
 
     @staticmethod
+    def for_grpc_tls(host, port):
+        """Create a Location for a TLS-based gRPC service."""
+        cdef:
+            c_string c_host = tobytes(host)
+            int c_port = port
+            Location result = Location.__new__(Location)
+        check_status(CLocation.ForGrpcTls(c_host, c_port, &result.location))
+        return result
+
+    @staticmethod
     def for_grpc_unix(path):
         """Create a Location for a domain socket-based gRPC service."""
         cdef:
@@ -1016,12 +1033,12 @@ cdef class FlightServerBase:
     cdef:
         unique_ptr[PyFlightServer] server
 
-    def run(self, location, auth_handler=None,
-            tls_cert_chain=None, tls_private_key=None):
+    def run(self, location, auth_handler=None, tls_certificates=None):
         cdef:
             PyFlightServerVtable vtable = PyFlightServerVtable()
             PyFlightServer* c_server
             unique_ptr[CFlightServerOptions] c_options
+            CCertKeyPair c_cert
 
         c_options.reset(new CFlightServerOptions(Location.unwrap(location)))
 
@@ -1032,12 +1049,11 @@ cdef class FlightServerBase:
             c_options.get().auth_handler.reset(
                 (<ServerAuthHandler> auth_handler).to_handler())
 
-        if tls_cert_chain:
-            if not tls_private_key:
-                raise ValueError(
-                    "Must provide both cert chain and private key")
-            c_options.get().tls_cert_chain = tobytes(tls_cert_chain)
-            c_options.get().tls_private_key = tobytes(tls_private_key)
+        if tls_certificates:
+            for cert, key in tls_certificates:
+                c_cert.pem_cert = tobytes(cert)
+                c_cert.pem_key = tobytes(key)
+                c_options.get().tls_certificates.push_back(c_cert)
 
         vtable.list_flights = &_list_flights
         vtable.get_flight_info = &_get_flight_info
diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py
index 37a21e4..05198e4 100644
--- a/python/pyarrow/flight.py
+++ b/python/pyarrow/flight.py
@@ -25,6 +25,7 @@ if sys.version_info < (3,):
 from pyarrow._flight import (  # noqa
     Action,
     ActionType,
+    CertKeyPair,
     DescriptorType,
     FlightCallOptions,
     FlightClient,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 4b74990..14d1ed1 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -88,6 +88,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
         @staticmethod
         CStatus ForGrpcTcp(c_string& host, int port, CLocation* location)
         @staticmethod
+        CStatus ForGrpcTls(c_string& host, int port, CLocation* location)
+        @staticmethod
         CStatus ForGrpcUnix(c_string& path, CLocation* location)
 
     cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint":
@@ -154,12 +156,16 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
         CFlightCallOptions()
         CTimeoutDuration timeout
 
+    cdef cppclass CCertKeyPair" arrow::flight::CertKeyPair":
+        CCertKeyPair()
+        c_string pem_cert
+        c_string pem_key
+
     cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions":
         CFlightServerOptions(const CLocation& location)
         CLocation location
         unique_ptr[CServerAuthHandler] auth_handler
-        c_string tls_cert_chain
-        c_string tls_private_key
+        vector[CCertKeyPair] tls_certificates
 
     cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions":
         CFlightClientOptions()
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 9ce2264..a7e6e34 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -28,12 +28,52 @@ import pytest
 
 import pyarrow as pa
 
+from pathlib import Path
 from pyarrow.compat import tobytes
 
 
 flight = pytest.importorskip("pyarrow.flight")
 
 
+def resource_root():
+    """Return $ARROW_TEST_DATA/flight; raise RuntimeError if it is unset."""
+    if not os.environ.get("ARROW_TEST_DATA"):
+        raise RuntimeError("Test resources not found; set "
+                           "ARROW_TEST_DATA to <repo root>/testing")
+    return Path(os.environ["ARROW_TEST_DATA"]) / "flight"
+
+
+def read_flight_resource(path):
+    """Return the raw bytes of a Flight test resource file."""
+    root = resource_root()
+    # resource_root() raises RuntimeError when ARROW_TEST_DATA is unset and
+    # otherwise returns a Path, so there is no None case to handle here.
+    try:
+        with (root / path).open("rb") as f:
+            return f.read()
+    except FileNotFoundError as e:
+        raise RuntimeError(
+            "Test resource {} not found; did you initialize the "
+            "test resource submodule?".format(root / path)) from e
+
+
+def example_tls_certs():
+    """Load the test TLS root CA and server cert/key pairs as bytes."""
+    return {
+        "root_cert": read_flight_resource("root-ca.pem"),
+        "certificates": [
+            flight.CertKeyPair(
+                cert=read_flight_resource("cert0.pem"),
+                key=read_flight_resource("cert0.key"),
+            ),
+            flight.CertKeyPair(
+                cert=read_flight_resource("cert1.pem"),
+                key=read_flight_resource("cert1.key"),
+            ),
+        ]
+    }
+
+
 def simple_ints_table():
     data = [
         pa.array([-10, -5, 0, 5, 10])
@@ -245,6 +285,7 @@ class TokenClientAuthHandler(flight.ClientAuthHandler):
 def flight_server(server_base, *args, **kwargs):
     """Spawn a Flight server on a free port, shutting it down when done."""
     auth_handler = kwargs.pop('auth_handler', None)
+    tls_certificates = kwargs.pop('tls_certificates', None)
     location = kwargs.pop('location', None)
 
     if location is None:
@@ -254,7 +295,10 @@ def flight_server(server_base, *args, **kwargs):
             sock.bind(('', 0))
             sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             port = sock.getsockname()[1]
-        location = flight.Location.for_grpc_tcp("localhost", port)
+        ctor = flight.Location.for_grpc_tcp
+        if tls_certificates:
+            ctor = flight.Location.for_grpc_tls
+        location = ctor("localhost", port)
     else:
         port = None
 
@@ -262,11 +306,26 @@ def flight_server(server_base, *args, **kwargs):
     server_instance = server_base(*args, **ctor_kwargs)
 
     def _server_thread():
-        server_instance.run(location, auth_handler=auth_handler)
+        server_instance.run(
+            location,
+            auth_handler=auth_handler,
+            tls_certificates=tls_certificates,
+        )
 
     thread = threading.Thread(target=_server_thread, daemon=True)
     thread.start()
 
+    # Wait for server to start: only a 'Connect Failed' error means "not yet".
+    client = flight.FlightClient.connect(location)
+    while True:
+        try:
+            list(client.list_flights())
+        except Exception as e:
+            if 'Connect Failed' in str(e):
+                time.sleep(0.025)
+                continue
+        break
+
     yield location
 
     server_instance.shutdown()
@@ -471,3 +530,32 @@ def test_location_invalid():
     server = ConstantFlightServer()
     with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"):
         server.run("%")
+
+
+@pytest.mark.slow
+def test_tls_fails():
+    """Make sure clients cannot connect when cert verification fails."""
+    certs = example_tls_certs()
+
+    with flight_server(
+            ConstantFlightServer, tls_certificates=certs["certificates"]
+    ) as server_location:
+        # Ensure client doesn't connect when certificate verification
+        # fails (this is a slow test since gRPC does retry a few times)
+        client = flight.FlightClient.connect(server_location)
+        with pytest.raises(pa.ArrowIOError, match="Connect Failed"):
+            client.do_get(flight.Ticket(b'ints'))
+
+
+def test_tls_do_get():
+    """Try a simple do_get call over TLS."""
+    table = simple_ints_table()
+    certs = example_tls_certs()
+
+    with flight_server(
+            ConstantFlightServer, tls_certificates=certs["certificates"]
+    ) as server_location:
+        client = flight.FlightClient.connect(
+            server_location, tls_root_certs=certs["root_cert"])
+        data = client.do_get(flight.Ticket(b'ints')).read_all()
+        assert data.equals(table)
diff --git a/testing b/testing
index bf0abe4..12f9dbd 160000
--- a/testing
+++ b/testing
@@ -1 +1 @@
-Subproject commit bf0abe442bf7e313380452c8972692940f4e56b6
+Subproject commit 12f9dbd2a37eea6fa370e108a1d797ee1167724a