You are viewing a plain text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@arrow.apache.org by yi...@apache.org on 2022/02/10 03:05:11 UTC

[arrow] branch master updated: ARROW-15487: [FlightRPC][C++][GLib][Python][R] Implement FlightClient::Close

This is an automated email from the ASF dual-hosted git repository.

yibocai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 647371b  ARROW-15487: [FlightRPC][C++][GLib][Python][R] Implement FlightClient::Close
647371b is described below

commit 647371b504df166860bd33346dcbd962c85e046f
Author: David Li <li...@gmail.com>
AuthorDate: Thu Feb 10 03:03:22 2022 +0000

    ARROW-15487: [FlightRPC][C++][GLib][Python][R] Implement FlightClient::Close
    
    Add a method to explicitly close FlightClient, in anticipation of alternative transports that may need it and to provide an interface for features such as ARROW-15473. Because existing applications do not call this new method, the destructor calls it implicitly for backwards compatibility. For gRPC there is no explicit shutdown, so closing only releases the client's resources; subsequent calls on a closed client fail.
    
    Closes #12302 from lidavidm/arrow-15487
    
    Authored-by: David Li <li...@gmail.com>
    Signed-off-by: Yibo Cai <yi...@arm.com>
---
 c_glib/arrow-flight-glib/client.cpp         |  20 ++
 c_glib/arrow-flight-glib/client.h           |   5 +
 c_glib/test/flight/test-client.rb           |   7 +
 cpp/src/arrow/flight/client.cc              |  33 +++-
 cpp/src/arrow/flight/client.h               |  11 ++
 cpp/src/arrow/flight/flight_test.cc         |  72 +++++--
 cpp/src/arrow/flight/sql/client.cc          |   2 +
 cpp/src/arrow/flight/sql/client.h           |   3 +
 cpp/src/arrow/flight/sql/server_test.cc     |   1 +
 python/pyarrow/_flight.pyx                  |  14 ++
 python/pyarrow/includes/libarrow_flight.pxd |   1 +
 python/pyarrow/tests/test_flight.py         | 289 +++++++++++++++-------------
 r/NAMESPACE                                 |   1 +
 r/R/flight.R                                |   8 +
 r/_pkgdown.yml                              |   1 +
 r/man/flight_disconnect.Rd                  |  14 ++
 r/tests/testthat/test-python-flight.R       |   6 +
 17 files changed, 339 insertions(+), 149 deletions(-)

diff --git a/c_glib/arrow-flight-glib/client.cpp b/c_glib/arrow-flight-glib/client.cpp
index 7610fc9..c0be5b8 100644
--- a/c_glib/arrow-flight-glib/client.cpp
+++ b/c_glib/arrow-flight-glib/client.cpp
@@ -266,6 +266,29 @@ gaflight_client_new(GAFlightLocation *location,
 }
 
 /**
+ * gaflight_client_close:
+ * @client: A #GAFlightClient.
+ * @error: (nullable): Return location for a #GError or %NULL.
+ *
+ * Closes @client. Calling this more than once is allowed. Other
+ * operations on a closed client fail with an error.
+ *
+ * Returns: %TRUE on success, %FALSE if there was an error.
+ *
+ * Since: 8.0.0
+ */
+gboolean
+gaflight_client_close(GAFlightClient *client,
+                      GError **error)
+{
+  auto flight_client = gaflight_client_get_raw(client);
+  auto status = flight_client->Close();
+  return garrow::check(error,
+                       status,
+                       "[flight-client][close]");
+}
+
+/**
  * gaflight_client_list_flights:
  * @client: A #GAFlightClient.
  * @criteria: (nullable): A #GAFlightCriteria.
diff --git a/c_glib/arrow-flight-glib/client.h b/c_glib/arrow-flight-glib/client.h
index bc29711..f601e66 100644
--- a/c_glib/arrow-flight-glib/client.h
+++ b/c_glib/arrow-flight-glib/client.h
@@ -86,6 +86,11 @@ gaflight_client_new(GAFlightLocation *location,
                     GAFlightClientOptions *options,
                     GError **error);
 
+GARROW_AVAILABLE_IN_8_0
+gboolean
+gaflight_client_close(GAFlightClient *client,
+                      GError **error);
+
 GARROW_AVAILABLE_IN_5_0
 GList *
 gaflight_client_list_flights(GAFlightClient *client,
diff --git a/c_glib/test/flight/test-client.rb b/c_glib/test/flight/test-client.rb
index f6660a4..48f0322 100644
--- a/c_glib/test/flight/test-client.rb
+++ b/c_glib/test/flight/test-client.rb
@@ -36,6 +36,13 @@ class TestFlightClient < Test::Unit::TestCase
     @server.shutdown
   end
 
+  def test_close
+    client = ArrowFlight::Client.new(@location)
+    client.close
+    # Close is idempotent: closing again must also succeed
+    client.close
+  end
+
   def test_list_flights
     client = ArrowFlight::Client.new(@location)
     generator = Helper::FlightInfoGenerator.new
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 6cafcf1..14fcc6a 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -1290,7 +1290,13 @@ class FlightClient::FlightClientImpl {
 
 FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); }
 
-FlightClient::~FlightClient() {}
+FlightClient::~FlightClient() {
+  auto st = Close();
+  if (!st.ok()) {
+    ARROW_LOG(WARNING) << "FlightClient::~FlightClient(): Close() failed: "
+                       << st.ToString();
+  }
+}
 
 Status FlightClient::Connect(const Location& location,
                              std::unique_ptr<FlightClient>* client) {
@@ -1305,49 +1311,58 @@ Status FlightClient::Connect(const Location& location, const FlightClientOptions
 
 Status FlightClient::Authenticate(const FlightCallOptions& options,
                                   std::unique_ptr<ClientAuthHandler> auth_handler) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->Authenticate(options, std::move(auth_handler));
 }
 
 arrow::Result<std::pair<std::string, std::string>> FlightClient::AuthenticateBasicToken(
     const FlightCallOptions& options, const std::string& username,
     const std::string& password) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->AuthenticateBasicToken(options, username, password);
 }
 
 Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action,
                               std::unique_ptr<ResultStream>* results) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->DoAction(options, action, results);
 }
 
 Status FlightClient::ListActions(const FlightCallOptions& options,
                                  std::vector<ActionType>* actions) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->ListActions(options, actions);
 }
 
 Status FlightClient::GetFlightInfo(const FlightCallOptions& options,
                                    const FlightDescriptor& descriptor,
                                    std::unique_ptr<FlightInfo>* info) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->GetFlightInfo(options, descriptor, info);
 }
 
 Status FlightClient::GetSchema(const FlightCallOptions& options,
                                const FlightDescriptor& descriptor,
                                std::unique_ptr<SchemaResult>* schema_result) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->GetSchema(options, descriptor, schema_result);
 }
 
 Status FlightClient::ListFlights(std::unique_ptr<FlightListing>* listing) {
+  // No CheckOpen() here: the overload this delegates to performs it.
   return ListFlights({}, {}, listing);
 }
 
 Status FlightClient::ListFlights(const FlightCallOptions& options,
                                  const Criteria& criteria,
                                  std::unique_ptr<FlightListing>* listing) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->ListFlights(options, criteria, listing);
 }
 
 Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
                            std::unique_ptr<FlightStreamReader>* stream) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->DoGet(options, ticket, stream);
 }
 
@@ -1356,6 +1371,7 @@ Status FlightClient::DoPut(const FlightCallOptions& options,
                            const std::shared_ptr<Schema>& schema,
                            std::unique_ptr<FlightStreamWriter>* stream,
                            std::unique_ptr<FlightMetadataReader>* reader) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->DoPut(options, descriptor, schema, stream, reader);
 }
 
@@ -1363,8 +1379,27 @@ Status FlightClient::DoExchange(const FlightCallOptions& options,
                                 const FlightDescriptor& descriptor,
                                 std::unique_ptr<FlightStreamWriter>* writer,
                                 std::unique_ptr<FlightStreamReader>* reader) {
+  RETURN_NOT_OK(CheckOpen());
   return impl_->DoExchange(options, descriptor, writer, reader);
 }
 
+Status FlightClient::Close() {
+  // gRPC doesn't offer an explicit shutdown, so closing just releases the
+  // implementation. Resetting an already-null impl_ is harmless, so Close()
+  // is idempotent (the destructor calls it unconditionally).
+  impl_.reset(nullptr);
+  // TODO(ARROW-15473): if we track ongoing RPCs, we can cancel them first
+  return Status::OK();
+}
+
+// Guard used by the RPC entry points so that calls made after Close() fail
+// with an Invalid status instead of dereferencing a null impl_.
+Status FlightClient::CheckOpen() const {
+  if (!impl_) {
+    return Status::Invalid("FlightClient is closed");
+  }
+  return Status::OK();
+}
+
 }  // namespace flight
 }  // namespace arrow
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index fecc510..15c6705 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -323,8 +323,23 @@ class ARROW_FLIGHT_EXPORT FlightClient {
     return DoExchange({}, descriptor, writer, reader);
   }
 
+  /// \brief Explicitly shut down and clean up the client.
+  ///
+  /// For backwards compatibility, the destructor also calls Close()
+  /// and logs any error it returns, but that gives the application
+  /// no chance to handle errors, so it is recommended to explicitly
+  /// close the client.
+  ///
+  /// Calling Close() more than once is allowed. Once the client is
+  /// closed, other methods return Status::Invalid.
+  ///
+  /// \since 8.0.0
+  Status Close();
+
  private:
   FlightClient();
+  /// \brief Return Status::Invalid if Close() has already been called.
+  Status CheckOpen() const;
   class FlightClientImpl;
   std::unique_ptr<FlightClientImpl> impl_;
 };
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 339f7b4..df3f602 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -194,7 +194,9 @@ TEST(TestFlight, ConnectUri) {
   ASSERT_OK(Location::Parse(uri, &location1));
   ASSERT_OK(Location::Parse(uri, &location2));
   ASSERT_OK(FlightClient::Connect(location1, &client));
+  ASSERT_OK(client->Close());
   ASSERT_OK(FlightClient::Connect(location2, &client));
+  ASSERT_OK(client->Close());
 }
 
 #ifndef _WIN32
@@ -213,7 +215,9 @@ TEST(TestFlight, ConnectUriUnix) {
   ASSERT_OK(Location::Parse(uri, &location1));
   ASSERT_OK(Location::Parse(uri, &location2));
   ASSERT_OK(FlightClient::Connect(location1, &client));
+  ASSERT_OK(client->Close());
   ASSERT_OK(FlightClient::Connect(location2, &client));
+  ASSERT_OK(client->Close());
 }
 #endif
 
@@ -405,7 +409,10 @@ class TestFlightClient : public ::testing::Test {
     ASSERT_OK(ConnectClient());
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
   Status ConnectClient() {
     Location location;
@@ -631,7 +638,10 @@ class TestMetadata : public ::testing::Test {
         [](FlightClientOptions* options) { return Status::OK(); }));
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::unique_ptr<FlightClient> client_;
@@ -646,7 +656,10 @@ class TestOptions : public ::testing::Test {
         [](FlightClientOptions* options) { return Status::OK(); }));
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::unique_ptr<FlightClient> client_;
@@ -666,7 +679,10 @@ class TestAuthHandler : public ::testing::Test {
         [](FlightClientOptions* options) { return Status::OK(); }));
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::unique_ptr<FlightClient> client_;
@@ -686,7 +702,10 @@ class TestBasicAuthHandler : public ::testing::Test {
         [](FlightClientOptions* options) { return Status::OK(); }));
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::unique_ptr<FlightClient> client_;
@@ -702,7 +721,10 @@ class TestDoPut : public ::testing::Test {
     do_put_server_ = (DoPutTestServer*)server_.get();
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
   void CheckBatches(FlightDescriptor expected_descriptor,
                     const BatchVector& expected_batches) {
@@ -758,6 +780,7 @@ class TestTls : public ::testing::Test {
   }
 
   void TearDown() {
+    ASSERT_OK(client_->Close());
     ASSERT_OK(server_->Shutdown());
     grpc_shutdown();
   }
@@ -1070,7 +1093,10 @@ class TestRejectServerMiddleware : public ::testing::Test {
         [](FlightClientOptions* options) { return Status::OK(); }));
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::unique_ptr<FlightClient> client_;
@@ -1090,7 +1116,10 @@ class TestCountingServerMiddleware : public ::testing::Test {
         [](FlightClientOptions* options) { return Status::OK(); }));
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::shared_ptr<CountingServerMiddlewareFactory> request_counter_;
@@ -1144,6 +1173,7 @@ class TestPropagatingMiddleware : public ::testing::Test {
   }
 
   void TearDown() {
+    ASSERT_OK(client_->Close());
     ASSERT_OK(first_server_->Shutdown());
     ASSERT_OK(second_server_->Shutdown());
   }
@@ -1174,7 +1204,10 @@ class TestErrorMiddleware : public ::testing::Test {
         [](FlightClientOptions* options) { return Status::OK(); }));
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::unique_ptr<FlightClient> client_;
@@ -1222,7 +1255,10 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test {
                 ::testing::HasSubstr("Invalid credentials"));
   }
 
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::unique_ptr<FlightClient> client_;
@@ -1892,6 +1928,17 @@ TEST_F(TestFlightClient, NoTimeout) {
   ASSERT_NE(nullptr, info);
 }
 
+TEST_F(TestFlightClient, Close) {
+  // gRPC has no explicit shutdown; Close() just releases the implementation
+  ASSERT_OK(client_->Close());
+  // Close() is idempotent
+  ASSERT_OK(client_->Close());
+
+  std::unique_ptr<FlightListing> listing;
+  EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("FlightClient is closed"),
+                                  client_->ListFlights(&listing));
+}
+
 TEST_F(TestDoPut, DoPutInts) {
   auto descr = FlightDescriptor::Path({"ints"});
   BatchVector batches;
@@ -2807,7 +2854,10 @@ class TestCancel : public ::testing::Test {
         &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
         [](FlightClientOptions* options) { return Status::OK(); }));
   }
-  void TearDown() { ASSERT_OK(server_->Shutdown()); }
+  void TearDown() {
+    ASSERT_OK(client_->Close());
+    ASSERT_OK(server_->Shutdown());
+  }
 
  protected:
   std::unique_ptr<FlightClient> client_;
diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc
index addb90a..e72935c 100644
--- a/cpp/src/arrow/flight/sql/client.cc
+++ b/cpp/src/arrow/flight/sql/client.cc
@@ -411,6 +411,8 @@ Status PreparedStatement::Close() {
   return Status::OK();
 }
 
+Status FlightSqlClient::Close() { return impl_->Close(); }
+
 arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetSqlInfo(
     const FlightCallOptions& options, const std::vector<int>& sql_info) {
   flight_sql_pb::CommandGetSqlInfo command;
diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h
index 5bf1b3e..fd2234c 100644
--- a/cpp/src/arrow/flight/sql/client.h
+++ b/cpp/src/arrow/flight/sql/client.h
@@ -163,6 +163,13 @@ class ARROW_EXPORT FlightSqlClient {
     return info;
   }
 
+  /// \brief Explicitly shut down and clean up the client.
+  ///
+  /// Forwards to Close() on the underlying client implementation.
+  ///
+  /// \since 8.0.0
+  Status Close();
+
  protected:
   virtual Status DoPut(const FlightCallOptions& options,
                        const FlightDescriptor& descriptor,
diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc
index 507745c..47a03bd 100644
--- a/cpp/src/arrow/flight/sql/server_test.cc
+++ b/cpp/src/arrow/flight/sql/server_test.cc
@@ -168,6 +168,7 @@ class TestFlightSqlServer : public ::testing::Test {
   }
 
   void TearDown() override {
+    ASSERT_OK(sql_client->Close());
     sql_client.reset();
 
     ASSERT_OK(server->Shutdown());
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 5b00c53..f8c5856 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -1454,6 +1454,30 @@ cdef class FlightClient(_Weakrefable):
         py_reader.reader.reset(c_reader.release())
         return py_writer, py_reader
 
+    def close(self):
+        """Close the client.
+
+        Calling this more than once is allowed. Once the client is
+        closed, RPC methods on it raise ArrowInvalid.
+        """
+        if self.client.get() == NULL:
+            # No C++ client to close (e.g. construction did not complete).
+            # __del__ calls close() unconditionally, so never dereference
+            # a NULL pointer here.
+            return
+        check_flight_status(self.client.get().Close())
+
+    def __del__(self):
+        # Not ideal, but close() wasn't originally present so
+        # applications may not be calling it
+        self.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 cdef class FlightDataStream(_Weakrefable):
     """Abstract base class for Flight data streams."""
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 2ac737a..364821c 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -347,6 +347,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
                            CFlightDescriptor& descriptor,
                            unique_ptr[CFlightStreamWriter]* writer,
                            unique_ptr[CFlightStreamReader]* reader)
+        CStatus Close()
 
     cdef cppclass CFlightStatusCode" arrow::flight::FlightStatusCode":
         bint operator==(CFlightStatusCode)
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index f3b6b91..12f815d 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -882,38 +882,48 @@ def test_client_wait_for_available():
         server = FlightServerBase(location)
         server.serve()
 
-    client = FlightClient(location)
-    thread = threading.Thread(target=serve, daemon=True)
-    thread.start()
+    with FlightClient(location) as client:
+        thread = threading.Thread(target=serve, daemon=True)
+        thread.start()
 
-    started = time.time()
-    client.wait_for_available(timeout=5)
-    elapsed = time.time() - started
-    assert elapsed >= 0.5
+        started = time.time()
+        client.wait_for_available(timeout=5)
+        elapsed = time.time() - started
+        assert elapsed >= 0.5
 
 
 def test_flight_list_flights():
     """Try a simple list_flights call."""
-    with ConstantFlightServer() as server:
-        client = flight.connect(('localhost', server.port))
+    with ConstantFlightServer() as server, \
+            flight.connect(('localhost', server.port)) as client:
         assert list(client.list_flights()) == []
         flights = client.list_flights(ConstantFlightServer.CRITERIA)
         assert len(list(flights)) == 1
 
 
+def test_flight_client_close():
+    with ConstantFlightServer() as server, \
+            flight.connect(('localhost', server.port)) as client:
+        assert list(client.list_flights()) == []
+        client.close()
+        client.close()  # Idempotent
+        with pytest.raises(pa.ArrowInvalid):
+            list(client.list_flights())
+
+
 def test_flight_do_get_ints():
     """Try a simple do_get call."""
     table = simple_ints_table()
 
-    with ConstantFlightServer() as server:
-        client = flight.connect(('localhost', server.port))
+    with ConstantFlightServer() as server, \
+            flight.connect(('localhost', server.port)) as client:
         data = client.do_get(flight.Ticket(b'ints')).read_all()
         assert data.equals(table)
 
     options = pa.ipc.IpcWriteOptions(
         metadata_version=pa.ipc.MetadataVersion.V4)
-    with ConstantFlightServer(options=options) as server:
-        client = flight.connect(('localhost', server.port))
+    with ConstantFlightServer(options=options) as server, \
+            flight.connect(('localhost', server.port)) as client:
         data = client.do_get(flight.Ticket(b'ints')).read_all()
         assert data.equals(table)
 
@@ -923,8 +933,8 @@ def test_flight_do_get_ints():
 
     with pytest.raises(flight.FlightServerError,
                        match="expected IpcWriteOptions, got <class 'int'>"):
-        with ConstantFlightServer(options=42) as server:
-            client = flight.connect(('localhost', server.port))
+        with ConstantFlightServer(options=42) as server, \
+                flight.connect(('localhost', server.port)) as client:
             data = client.do_get(flight.Ticket(b'ints')).read_all()
 
 
@@ -933,8 +943,8 @@ def test_do_get_ints_pandas():
     """Try a simple do_get call."""
     table = simple_ints_table()
 
-    with ConstantFlightServer() as server:
-        client = flight.connect(('localhost', server.port))
+    with ConstantFlightServer() as server, \
+            flight.connect(('localhost', server.port)) as client:
         data = client.do_get(flight.Ticket(b'ints')).read_pandas()
         assert list(data['some_ints']) == table.column(0).to_pylist()
 
@@ -942,8 +952,8 @@ def test_do_get_ints_pandas():
 def test_flight_do_get_dicts():
     table = simple_dicts_table()
 
-    with ConstantFlightServer() as server:
-        client = flight.connect(('localhost', server.port))
+    with ConstantFlightServer() as server, \
+            flight.connect(('localhost', server.port)) as client:
         data = client.do_get(flight.Ticket(b'dicts')).read_all()
         assert data.equals(table)
 
@@ -952,8 +962,8 @@ def test_flight_do_get_ticket():
     """Make sure Tickets get passed to the server."""
     data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
     table = pa.Table.from_arrays(data1, names=['a'])
-    with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server:
-        client = flight.connect(('localhost', server.port))
+    with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server, \
+            flight.connect(('localhost', server.port)) as client:
         data = client.do_get(flight.Ticket(b'the-ticket')).read_all()
         assert data.equals(table)
 
@@ -975,8 +985,8 @@ def test_flight_get_info():
 
 def test_flight_get_schema():
     """Make sure GetSchema returns correct schema."""
-    with GetInfoFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with GetInfoFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         info = client.get_schema(flight.FlightDescriptor.for_command(b''))
         assert info.schema == pa.schema([('a', pa.int32())])
 
@@ -984,8 +994,8 @@ def test_flight_get_schema():
 def test_list_actions():
     """Make sure the return type of ListActions is validated."""
     # ARROW-6392
-    with ListActionsErrorFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ListActionsErrorFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         with pytest.raises(
                 flight.FlightServerError,
                 match=("Results of list_actions must be "
@@ -993,8 +1003,8 @@ def test_list_actions():
         ):
             list(client.list_actions())
 
-    with ListActionsFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ListActionsFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         assert list(client.list_actions()) == \
             ListActionsFlightServer.expected_actions()
 
@@ -1020,8 +1030,8 @@ class ConvenienceServer(FlightServerBase):
 
 
 def test_do_action_result_convenience():
-    with ConvenienceServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ConvenienceServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
 
         # do_action as action type without body
         results = [x.body for x in client.do_action('simple-action')]
@@ -1034,8 +1044,8 @@ def test_do_action_result_convenience():
 
 
 def test_nicer_server_exceptions():
-    with ConvenienceServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ConvenienceServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         with pytest.raises(flight.FlightServerError,
                            match="a bytes-like object is required"):
             list(client.do_action('bad-action'))
@@ -1064,8 +1074,8 @@ def test_flight_domain_socket():
     with tempfile.NamedTemporaryFile() as sock:
         sock.close()
         location = flight.Location.for_grpc_unix(sock.name)
-        with ConstantFlightServer(location=location):
-            client = FlightClient(location)
+        with ConstantFlightServer(location=location), \
+                FlightClient(location) as client:
 
             reader = client.do_get(flight.Ticket(b'ints'))
             table = simple_ints_table()
@@ -1091,8 +1101,8 @@ def test_flight_large_message():
         pa.array(range(0, 10 * 1024 * 1024))
     ], names=['a'])
 
-    with EchoFlightServer(expected_schema=data.schema) as server:
-        client = FlightClient(('localhost', server.port))
+    with EchoFlightServer(expected_schema=data.schema) as server, \
+            FlightClient(('localhost', server.port)) as client:
         writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
                                   data.schema)
         # Write a single giant chunk
@@ -1108,8 +1118,8 @@ def test_flight_generator_stream():
         pa.array(range(0, 10 * 1024))
     ], names=['a'])
 
-    with EchoStreamFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with EchoStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
                                   data.schema)
         writer.write_table(data)
@@ -1120,8 +1130,8 @@ def test_flight_generator_stream():
 
 def test_flight_invalid_generator_stream():
     """Try streaming data with mismatched schemas."""
-    with InvalidStreamFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with InvalidStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         with pytest.raises(pa.ArrowException):
             client.do_get(flight.Ticket(b'')).read_all()
 
@@ -1130,8 +1140,8 @@ def test_timeout_fires():
     """Make sure timeouts fire on slow requests."""
     # Do this in a separate thread so that if it fails, we don't hang
     # the entire test process
-    with SlowFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with SlowFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         action = flight.Action("", b"")
         options = flight.FlightCallOptions(timeout=0.2)
         # gRPC error messages change based on version, so don't look
@@ -1142,8 +1152,8 @@ def test_timeout_fires():
 
 def test_timeout_passes():
     """Make sure timeouts do not fire on fast requests."""
-    with ConstantFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ConstantFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         options = flight.FlightCallOptions(timeout=5.0)
         client.do_get(flight.Ticket(b'ints'), options=options).read_all()
 
@@ -1160,8 +1170,8 @@ token_auth_handler = TokenServerAuthHandler(creds={
 @pytest.mark.slow
 def test_http_basic_unauth():
     """Test that auth fails when not authenticated."""
-    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
-        client = FlightClient(('localhost', server.port))
+    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
+            FlightClient(('localhost', server.port)) as client:
         action = flight.Action("who-am-i", b"")
         with pytest.raises(flight.FlightUnauthenticatedError,
                            match=".*unauthenticated.*"):
@@ -1172,8 +1182,8 @@ def test_http_basic_unauth():
                     reason="ARROW-10013: gRPC on Windows corrupts peer()")
 def test_http_basic_auth():
     """Test a Python implementation of HTTP basic authentication."""
-    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
-        client = FlightClient(('localhost', server.port))
+    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
+            FlightClient(('localhost', server.port)) as client:
         action = flight.Action("who-am-i", b"")
         client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd'))
         results = client.do_action(action)
@@ -1185,8 +1195,8 @@ def test_http_basic_auth():
 
 def test_http_basic_auth_invalid_password():
     """Test that auth fails with the wrong password."""
-    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
-        client = FlightClient(('localhost', server.port))
+    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
+            FlightClient(('localhost', server.port)) as client:
         action = flight.Action("who-am-i", b"")
         with pytest.raises(flight.FlightUnauthenticatedError,
                            match=".*wrong password.*"):
@@ -1196,8 +1206,8 @@ def test_http_basic_auth_invalid_password():
 
 def test_token_auth():
     """Test an auth mechanism that uses a handshake."""
-    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
-        client = FlightClient(('localhost', server.port))
+    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server, \
+            FlightClient(('localhost', server.port)) as client:
         action = flight.Action("who-am-i", b"")
         client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd'))
         identity = next(client.do_action(action))
@@ -1206,8 +1216,8 @@ def test_token_auth():
 
 def test_token_auth_invalid():
     """Test an auth mechanism that uses a handshake."""
-    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
-        client = FlightClient(('localhost', server.port))
+    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server, \
+            FlightClient(('localhost', server.port)) as client:
         with pytest.raises(flight.FlightUnauthenticatedError):
             client.authenticate(TokenClientAuthHandler('test', 'wrong'))
 
@@ -1220,8 +1230,8 @@ def test_authenticate_basic_token():
     """Test authenticate_basic_token with bearer token and auth headers."""
     with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
         "auth": HeaderAuthServerMiddlewareFactory()
-    }) as server:
-        client = FlightClient(('localhost', server.port))
+    }) as server, \
+            FlightClient(('localhost', server.port)) as client:
         token_pair = client.authenticate_basic_token(b'test', b'password')
         assert token_pair[0] == b'authorization'
         assert token_pair[1] == b'Bearer token1234'
@@ -1231,8 +1241,8 @@ def test_authenticate_basic_token_invalid_password():
     """Test authenticate_basic_token with an invalid password."""
     with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
         "auth": HeaderAuthServerMiddlewareFactory()
-    }) as server:
-        client = FlightClient(('localhost', server.port))
+    }) as server, \
+            FlightClient(('localhost', server.port)) as client:
         with pytest.raises(flight.FlightUnauthenticatedError):
             client.authenticate_basic_token(b'test', b'badpassword')
 
@@ -1241,8 +1251,8 @@ def test_authenticate_basic_token_and_action():
     """Test authenticate_basic_token and doAction after authentication."""
     with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
         "auth": HeaderAuthServerMiddlewareFactory()
-    }) as server:
-        client = FlightClient(('localhost', server.port))
+    }) as server, \
+            FlightClient(('localhost', server.port)) as client:
         token_pair = client.authenticate_basic_token(b'test', b'password')
         assert token_pair[0] == b'authorization'
         assert token_pair[1] == b'Bearer token1234'
@@ -1281,17 +1291,18 @@ def test_authenticate_basic_token_with_client_middleware():
         assert client_auth_middleware.call_credential[0] == b'authorization'
         assert client_auth_middleware.call_credential[1] == \
             b'Bearer ' + b'token1234'
+        client.close()
 
 
 def test_arbitrary_headers_in_flight_call_options():
     """Test passing multiple arbitrary headers to the middleware."""
     with ArbitraryHeadersFlightServer(
-            auth_handler=no_op_auth_handler,
-            middleware={
-                "auth": HeaderAuthServerMiddlewareFactory(),
-                "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
-            }) as server:
-        client = FlightClient(('localhost', server.port))
+        auth_handler=no_op_auth_handler,
+        middleware={
+            "auth": HeaderAuthServerMiddlewareFactory(),
+            "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
+        }) as server, \
+            FlightClient(('localhost', server.port)) as client:
         token_pair = client.authenticate_basic_token(b'test', b'password')
         assert token_pair[0] == b'authorization'
         assert token_pair[1] == b'Bearer token1234'
@@ -1328,11 +1339,10 @@ def test_tls_fails():
     """Make sure clients cannot connect when cert verification fails."""
     certs = example_tls_certs()
 
-    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
-        # Ensure client doesn't connect when certificate verification
-        # fails (this is a slow test since gRPC does retry a few times)
-        client = FlightClient("grpc+tls://localhost:" + str(s.port))
-
+    # Ensure client doesn't connect when certificate verification
+    # fails (this is a slow test since gRPC does retry a few times)
+    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s, \
+            FlightClient("grpc+tls://localhost:" + str(s.port)) as client:
         # gRPC error messages change based on version, so don't look
         # for a particular error
         with pytest.raises(flight.FlightUnavailableError):
@@ -1345,9 +1355,9 @@ def test_tls_do_get():
     table = simple_ints_table()
     certs = example_tls_certs()
 
-    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
-        client = FlightClient(('localhost', s.port),
-                              tls_root_certs=certs["root_cert"])
+    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s, \
+        FlightClient(('localhost', s.port),
+                     tls_root_certs=certs["root_cert"]) as client:
         data = client.do_get(flight.Ticket(b'ints')).read_all()
         assert data.equals(table)
 
@@ -1366,6 +1376,7 @@ def test_tls_disable_server_verification():
             pytest.skip('disable_server_verification feature is not available')
         data = client.do_get(flight.Ticket(b'ints')).read_all()
         assert data.equals(table)
+        client.close()
 
 
 @pytest.mark.requires_testing_data
@@ -1373,10 +1384,10 @@ def test_tls_override_hostname():
     """Check that incorrectly overriding the hostname fails."""
     certs = example_tls_certs()
 
-    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
-        client = flight.connect(('localhost', s.port),
-                                tls_root_certs=certs["root_cert"],
-                                override_hostname="fakehostname")
+    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s,\
+        flight.connect(('localhost', s.port),
+                       tls_root_certs=certs["root_cert"],
+                       override_hostname="fakehostname") as client:
         with pytest.raises(flight.FlightUnavailableError):
             client.do_get(flight.Ticket(b'ints'))
 
@@ -1389,8 +1400,8 @@ def test_flight_do_get_metadata():
     table = pa.Table.from_arrays(data, names=['a'])
 
     batches = []
-    with MetadataFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with MetadataFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         reader = client.do_get(flight.Ticket(b''))
         idx = 0
         while True:
@@ -1412,8 +1423,8 @@ def test_flight_do_get_metadata_v4():
         [pa.array([-10, -5, 0, 5, 10])], names=['a'])
     options = pa.ipc.IpcWriteOptions(
         metadata_version=pa.ipc.MetadataVersion.V4)
-    with MetadataFlightServer(options=options) as server:
-        client = FlightClient(('localhost', server.port))
+    with MetadataFlightServer(options=options) as server, \
+            FlightClient(('localhost', server.port)) as client:
         reader = client.do_get(flight.Ticket(b''))
         data = reader.read_all()
         assert data.equals(table)
@@ -1426,8 +1437,8 @@ def test_flight_do_put_metadata():
     ]
     table = pa.Table.from_arrays(data, names=['a'])
 
-    with MetadataFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with MetadataFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         writer, metadata_reader = client.do_put(
             flight.FlightDescriptor.for_path(''),
             table.schema)
@@ -1447,9 +1458,9 @@ def test_flight_do_put_limit():
         pa.array(np.ones(768, dtype=np.int64())),
     ], names=['a'])
 
-    with EchoFlightServer() as server:
-        client = FlightClient(('localhost', server.port),
-                              write_size_limit_bytes=4096)
+    with EchoFlightServer() as server, \
+        FlightClient(('localhost', server.port),
+                     write_size_limit_bytes=4096) as client:
         writer, metadata_reader = client.do_put(
             flight.FlightDescriptor.for_path(''),
             large_batch.schema)
@@ -1472,8 +1483,8 @@ def test_flight_do_put_limit():
 @pytest.mark.slow
 def test_cancel_do_get():
     """Test canceling a DoGet operation on the client side."""
-    with ConstantFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ConstantFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         reader = client.do_get(flight.Ticket(b'ints'))
         reader.cancel()
         with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
@@ -1483,8 +1494,8 @@ def test_cancel_do_get():
 @pytest.mark.slow
 def test_cancel_do_get_threaded():
     """Test canceling a DoGet operation from another thread."""
-    with SlowFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with SlowFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         reader = client.do_get(flight.Ticket(b'ints'))
 
         read_first_message = threading.Event()
@@ -1547,8 +1558,8 @@ def test_roundtrip_types():
 
 def test_roundtrip_errors():
     """Ensure that Flight errors propagate from server to client."""
-    with ErrorFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ErrorFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
 
         with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
             list(client.do_action(flight.Action("internal", b"")))
@@ -1600,8 +1611,8 @@ def test_do_put_independent_read_write():
     ]
     table = pa.Table.from_arrays(data, names=['a'])
 
-    with MetadataFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with MetadataFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         writer, metadata_reader = client.do_put(
             flight.FlightDescriptor.for_path(''),
             table.schema)
@@ -1633,8 +1644,8 @@ def test_server_middleware_same_thread():
     """Ensure that server middleware run on the same thread as the RPC."""
     with HeaderFlightServer(middleware={
         "test": HeaderServerMiddlewareFactory(),
-    }) as server:
-        client = FlightClient(('localhost', server.port))
+    }) as server, \
+            FlightClient(('localhost', server.port)) as client:
         results = list(client.do_action(flight.Action(b"test", b"")))
         assert len(results) == 1
         value = results[0].body.to_pybytes()
@@ -1645,8 +1656,8 @@ def test_middleware_reject():
     """Test rejecting an RPC with server middleware."""
     with HeaderFlightServer(middleware={
         "test": SelectiveAuthServerMiddlewareFactory(),
-    }) as server:
-        client = FlightClient(('localhost', server.port))
+    }) as server, \
+            FlightClient(('localhost', server.port)) as client:
         # The middleware allows this through without auth.
         with pytest.raises(pa.ArrowNotImplementedError):
             list(client.list_actions())
@@ -1667,11 +1678,11 @@ def test_middleware_mapping():
     """Test that middleware records methods correctly."""
     server_middleware = RecordingServerMiddlewareFactory()
     client_middleware = RecordingClientMiddlewareFactory()
-    with FlightServerBase(middleware={"test": server_middleware}) as server:
-        client = FlightClient(
+    with FlightServerBase(middleware={"test": server_middleware}) as server, \
+        FlightClient(
             ('localhost', server.port),
             middleware=[client_middleware]
-        )
+    ) as client:
 
         descriptor = flight.FlightDescriptor.for_command(b"")
         with pytest.raises(NotImplementedError):
@@ -1708,8 +1719,8 @@ def test_middleware_mapping():
 
 
 def test_extra_info():
-    with ErrorFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ErrorFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         try:
             list(client.do_action(flight.Action("protobuf", b"")))
             assert False
@@ -1728,12 +1739,12 @@ def test_mtls():
     with ConstantFlightServer(
             tls_certificates=[certs["certificates"][0]],
             verify_client=True,
-            root_certificates=certs["root_cert"]) as s:
-        client = FlightClient(
+            root_certificates=certs["root_cert"]) as s, \
+        FlightClient(
             ('localhost', s.port),
             tls_root_certs=certs["root_cert"],
             cert_chain=certs["certificates"][0].cert,
-            private_key=certs["certificates"][0].key)
+            private_key=certs["certificates"][0].key) as client:
         data = client.do_get(flight.Ticket(b'ints')).read_all()
         assert data.equals(table)
 
@@ -1744,8 +1755,8 @@ def test_doexchange_get():
         pa.array(range(0, 10 * 1024))
     ], names=["a"])
 
-    with ExchangeFlightServer() as server:
-        client = FlightClient(("localhost", server.port))
+    with ExchangeFlightServer() as server, \
+            FlightClient(("localhost", server.port)) as client:
         descriptor = flight.FlightDescriptor.for_command(b"get")
         writer, reader = client.do_exchange(descriptor)
         with writer:
@@ -1760,8 +1771,8 @@ def test_doexchange_put():
     ], names=["a"])
     batches = data.to_batches(max_chunksize=512)
 
-    with ExchangeFlightServer() as server:
-        client = FlightClient(("localhost", server.port))
+    with ExchangeFlightServer() as server, \
+            FlightClient(("localhost", server.port)) as client:
         descriptor = flight.FlightDescriptor.for_command(b"put")
         writer, reader = client.do_exchange(descriptor)
         with writer:
@@ -1782,8 +1793,8 @@ def test_doexchange_echo():
     ], names=["a"])
     batches = data.to_batches(max_chunksize=512)
 
-    with ExchangeFlightServer() as server:
-        client = FlightClient(("localhost", server.port))
+    with ExchangeFlightServer() as server, \
+            FlightClient(("localhost", server.port)) as client:
         descriptor = flight.FlightDescriptor.for_command(b"echo")
         writer, reader = client.do_exchange(descriptor)
         with writer:
@@ -1822,8 +1833,8 @@ def test_doexchange_echo_v4():
 
     options = pa.ipc.IpcWriteOptions(
         metadata_version=pa.ipc.MetadataVersion.V4)
-    with ExchangeFlightServer(options=options) as server:
-        client = FlightClient(("localhost", server.port))
+    with ExchangeFlightServer(options=options) as server, \
+            FlightClient(("localhost", server.port)) as client:
         descriptor = flight.FlightDescriptor.for_command(b"echo")
         writer, reader = client.do_exchange(descriptor)
         with writer:
@@ -1848,8 +1859,8 @@ def test_doexchange_transform():
         pa.array(range(3, 1024 * 3 + 3, 3)),
     ], names=["sum"])
 
-    with ExchangeFlightServer() as server:
-        client = FlightClient(("localhost", server.port))
+    with ExchangeFlightServer() as server, \
+            FlightClient(("localhost", server.port)) as client:
         descriptor = flight.FlightDescriptor.for_command(b"transform")
         writer, reader = client.do_exchange(descriptor)
         with writer:
@@ -1866,15 +1877,17 @@ def test_middleware_multi_header():
         "test": MultiHeaderServerMiddlewareFactory(),
     }) as server:
         headers = MultiHeaderClientMiddlewareFactory()
-        client = FlightClient(('localhost', server.port), middleware=[headers])
-        response = next(client.do_action(flight.Action(b"", b"")))
-        # The server echoes the headers it got back to us.
-        raw_headers = response.body.to_pybytes().decode("utf-8")
-        client_headers = ast.literal_eval(raw_headers)
-        # Don't directly compare; gRPC may add headers like User-Agent.
-        for header, values in MultiHeaderClientMiddleware.EXPECTED.items():
-            assert client_headers.get(header) == values
-            assert headers.last_headers.get(header) == values
+        with FlightClient(
+                ('localhost', server.port),
+                middleware=[headers]) as client:
+            response = next(client.do_action(flight.Action(b"", b"")))
+            # The server echoes the headers it got back to us.
+            raw_headers = response.body.to_pybytes().decode("utf-8")
+            client_headers = ast.literal_eval(raw_headers)
+            # Don't directly compare; gRPC may add headers like User-Agent.
+            for header, values in MultiHeaderClientMiddleware.EXPECTED.items():
+                assert client_headers.get(header) == values
+                assert headers.last_headers.get(header) == values
 
 
 @pytest.mark.requires_testing_data
@@ -1890,6 +1903,7 @@ def test_generic_options():
                                 generic_options=options)
         with pytest.raises(flight.FlightUnavailableError):
             client.do_get(flight.Ticket(b'ints'))
+        client.close()
         # Try setting an int argument that will make requests fail
         options = [("grpc.max_receive_message_length", 32)]
         client = flight.connect(('localhost', s.port),
@@ -1897,6 +1911,7 @@ def test_generic_options():
                                 generic_options=options)
         with pytest.raises(pa.ArrowInvalid):
             client.do_get(flight.Ticket(b'ints'))
+        client.close()
 
 
 class CancelFlightServer(FlightServerBase):
@@ -1946,8 +1961,8 @@ def test_interrupt():
         assert isinstance(e, (pa.ArrowCancelled, KeyboardInterrupt)) or \
             isinstance(e.__context__, (pa.ArrowCancelled, KeyboardInterrupt))
 
-    with CancelFlightServer() as server:
-        client = FlightClient(("localhost", server.port))
+    with CancelFlightServer() as server, \
+            FlightClient(("localhost", server.port)) as client:
 
         reader = client.do_get(flight.Ticket(b""))
         test(reader.read_all)
@@ -1960,8 +1975,8 @@ def test_interrupt():
 def test_never_sends_data():
     # Regression test for ARROW-12779
     match = "application server implementation error"
-    with NeverSendsDataFlightServer() as server:
-        client = flight.connect(('localhost', server.port))
+    with NeverSendsDataFlightServer() as server, \
+            flight.connect(('localhost', server.port)) as client:
         with pytest.raises(flight.FlightServerError, match=match):
             client.do_get(flight.Ticket(b'')).read_all()
 
@@ -1978,8 +1993,8 @@ def test_large_descriptor():
     # since some CI pipelines can't run the C++ equivalent
     large_descriptor = flight.FlightDescriptor.for_command(
         b' ' * (2 ** 31 + 1))
-    with FlightServerBase() as server:
-        client = flight.connect(('localhost', server.port))
+    with FlightServerBase() as server, \
+            flight.connect(('localhost', server.port)) as client:
         with pytest.raises(OSError,
                            match="Failed to serialize Flight descriptor"):
             writer, _ = client.do_put(large_descriptor, pa.schema([]))
@@ -1995,8 +2010,8 @@ def test_large_metadata_client():
     # Regression test for ARROW-13253
     descriptor = flight.FlightDescriptor.for_command(b'')
     metadata = b' ' * (2 ** 31 + 1)
-    with EchoFlightServer() as server:
-        client = flight.connect(('localhost', server.port))
+    with EchoFlightServer() as server, \
+            flight.connect(('localhost', server.port)) as client:
         with pytest.raises(pa.ArrowCapacityError,
                            match="app_metadata size overflow"):
             writer, _ = client.do_put(descriptor, pa.schema([]))
@@ -2010,8 +2025,8 @@ def test_large_metadata_client():
                 writer.write_metadata(metadata)
 
     del metadata
-    with LargeMetadataFlightServer() as server:
-        client = flight.connect(('localhost', server.port))
+    with LargeMetadataFlightServer() as server, \
+            flight.connect(('localhost', server.port)) as client:
         with pytest.raises(flight.FlightServerError,
                            match="app_metadata size overflow"):
             reader = client.do_get(flight.Ticket(b''))
@@ -2042,8 +2057,8 @@ def test_none_action_side_effect():
     See https://issues.apache.org/jira/browse/ARROW-14255
     """
 
-    with ActionNoneFlightServer() as server:
-        client = FlightClient(('localhost', server.port))
+    with ActionNoneFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
         client.do_action(flight.Action("append", b""))
         r = client.do_action(flight.Action("get_value", b""))
         assert json.loads(next(r).body.to_pybytes()) == [True]
diff --git a/r/NAMESPACE b/r/NAMESPACE
index 6233b4f..d01f66d 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -227,6 +227,7 @@ export(field)
 export(fixed_size_binary)
 export(fixed_size_list_of)
 export(flight_connect)
+export(flight_disconnect)
 export(flight_get)
 export(flight_path_exists)
 export(flight_put)
diff --git a/r/R/flight.R b/r/R/flight.R
index 4d190de..f56308f 100644
--- a/r/R/flight.R
+++ b/r/R/flight.R
@@ -40,6 +40,14 @@ flight_connect <- function(host = "localhost", port, scheme = "grpc+tcp") {
   pa$flight$FlightClient(location)
 }
 
+#' Explicitly close a Flight client
+#'
+#' @param client The client to disconnect
+#' @export
+flight_disconnect <- function(client) {
+  client$close()
+}
+
 #' Send data to a Flight server
 #'
 #' @param client `pyarrow.flight.FlightClient`, as returned by [flight_connect()]
diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml
index 38e808a..ba3162b 100644
--- a/r/_pkgdown.yml
+++ b/r/_pkgdown.yml
@@ -151,6 +151,7 @@ reference:
     contents:
       - load_flight_server
       - flight_connect
+      - flight_disconnect
       - flight_get
       - flight_put
       - list_flights
diff --git a/r/man/flight_disconnect.Rd b/r/man/flight_disconnect.Rd
new file mode 100644
index 0000000..83c5edf
--- /dev/null
+++ b/r/man/flight_disconnect.Rd
@@ -0,0 +1,14 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/flight.R
+\name{flight_disconnect}
+\alias{flight_disconnect}
+\title{Explicitly close a Flight client}
+\usage{
+flight_disconnect(client)
+}
+\arguments{
+\item{client}{The client to disconnect}
+}
+\description{
+Explicitly close a Flight client
+}
diff --git a/r/tests/testthat/test-python-flight.R b/r/tests/testthat/test-python-flight.R
index 0ffc7e4..db67cd6 100644
--- a/r/tests/testthat/test-python-flight.R
+++ b/r/tests/testthat/test-python-flight.R
@@ -57,6 +57,12 @@ if (process_is_running("demo_flight_server")) {
     flight_put(client, example_with_times, path = flight_obj)
     expect_identical(as.data.frame(flight_get(client, flight_obj)), example_with_times)
   })
+
+  test_that("flight_disconnect", {
+    flight_disconnect(client)
+    # Idempotent
+    flight_disconnect(client)
+  })
 } else {
   # Kinda hacky, let's put a skipped test here, just so we note that the tests
   # didn't run