You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by yi...@apache.org on 2022/02/10 03:05:11 UTC
[arrow] branch master updated: ARROW-15487: [FlightRPC][C++][GLib][Python][R] Implement FlightClient::Close
This is an automated email from the ASF dual-hosted git repository.
yibocai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 647371b ARROW-15487: [FlightRPC][C++][GLib][Python][R] Implement FlightClient::Close
647371b is described below
commit 647371b504df166860bd33346dcbd962c85e046f
Author: David Li <li...@gmail.com>
AuthorDate: Thu Feb 10 03:03:22 2022 +0000
ARROW-15487: [FlightRPC][C++][GLib][Python][R] Implement FlightClient::Close
Add a method to explicitly close a FlightClient. This anticipates alternative transports that may need an explicit shutdown, and provides an interface for features such as ARROW-15473. Because existing applications do not call this method, the FlightClient destructor calls it implicitly. For gRPC, closing is effectively a no-op (gRPC offers no explicit client shutdown).
Closes #12302 from lidavidm/arrow-15487
Authored-by: David Li <li...@gmail.com>
Signed-off-by: Yibo Cai <yi...@arm.com>
---
c_glib/arrow-flight-glib/client.cpp | 20 ++
c_glib/arrow-flight-glib/client.h | 5 +
c_glib/test/flight/test-client.rb | 7 +
cpp/src/arrow/flight/client.cc | 33 +++-
cpp/src/arrow/flight/client.h | 11 ++
cpp/src/arrow/flight/flight_test.cc | 72 +++++--
cpp/src/arrow/flight/sql/client.cc | 2 +
cpp/src/arrow/flight/sql/client.h | 3 +
cpp/src/arrow/flight/sql/server_test.cc | 1 +
python/pyarrow/_flight.pyx | 14 ++
python/pyarrow/includes/libarrow_flight.pxd | 1 +
python/pyarrow/tests/test_flight.py | 289 +++++++++++++++-------------
r/NAMESPACE | 1 +
r/R/flight.R | 8 +
r/_pkgdown.yml | 1 +
r/man/flight_disconnect.Rd | 14 ++
r/tests/testthat/test-python-flight.R | 6 +
17 files changed, 339 insertions(+), 149 deletions(-)
diff --git a/c_glib/arrow-flight-glib/client.cpp b/c_glib/arrow-flight-glib/client.cpp
index 7610fc9..c0be5b8 100644
--- a/c_glib/arrow-flight-glib/client.cpp
+++ b/c_glib/arrow-flight-glib/client.cpp
@@ -266,6 +266,26 @@ gaflight_client_new(GAFlightLocation *location,
}
/**
+ * gaflight_client_close:
+ * @client: A #GAFlightClient.
+ * @error: (nullable): Return location for a #GError or %NULL.
+ *
+ * Returns: %TRUE on success, %FALSE if there was an error.
+ *
+ * Since: 8.0.0
+ */
+gboolean
+gaflight_client_close(GAFlightClient *client,
+ GError **error)
+{
+ auto flight_client = gaflight_client_get_raw(client);
+ auto status = flight_client->Close();
+ return garrow::check(error,
+ status,
+ "[flight-client][close]");
+}
+
+/**
* gaflight_client_list_flights:
* @client: A #GAFlightClient.
* @criteria: (nullable): A #GAFlightCriteria.
diff --git a/c_glib/arrow-flight-glib/client.h b/c_glib/arrow-flight-glib/client.h
index bc29711..f601e66 100644
--- a/c_glib/arrow-flight-glib/client.h
+++ b/c_glib/arrow-flight-glib/client.h
@@ -86,6 +86,11 @@ gaflight_client_new(GAFlightLocation *location,
GAFlightClientOptions *options,
GError **error);
+GARROW_AVAILABLE_IN_8_0
+gboolean
+gaflight_client_close(GAFlightClient *client,
+ GError **error);
+
GARROW_AVAILABLE_IN_5_0
GList *
gaflight_client_list_flights(GAFlightClient *client,
diff --git a/c_glib/test/flight/test-client.rb b/c_glib/test/flight/test-client.rb
index f6660a4..48f0322 100644
--- a/c_glib/test/flight/test-client.rb
+++ b/c_glib/test/flight/test-client.rb
@@ -36,6 +36,13 @@ class TestFlightClient < Test::Unit::TestCase
@server.shutdown
end
+ def test_close
+ client = ArrowFlight::Client.new(@location)
+ client.close
+ # Idempotent
+ client.close
+ end
+
def test_list_flights
client = ArrowFlight::Client.new(@location)
generator = Helper::FlightInfoGenerator.new
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 6cafcf1..14fcc6a 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -1290,7 +1290,13 @@ class FlightClient::FlightClientImpl {
FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); }
-FlightClient::~FlightClient() {}
+FlightClient::~FlightClient() {
+ auto st = Close();
+ if (!st.ok()) {
+ ARROW_LOG(WARNING) << "FlightClient::~FlightClient(): Close() failed: "
+ << st.ToString();
+ }
+}
Status FlightClient::Connect(const Location& location,
std::unique_ptr<FlightClient>* client) {
@@ -1305,49 +1311,58 @@ Status FlightClient::Connect(const Location& location, const FlightClientOptions
Status FlightClient::Authenticate(const FlightCallOptions& options,
std::unique_ptr<ClientAuthHandler> auth_handler) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->Authenticate(options, std::move(auth_handler));
}
arrow::Result<std::pair<std::string, std::string>> FlightClient::AuthenticateBasicToken(
const FlightCallOptions& options, const std::string& username,
const std::string& password) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->AuthenticateBasicToken(options, username, password);
}
Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action,
std::unique_ptr<ResultStream>* results) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->DoAction(options, action, results);
}
Status FlightClient::ListActions(const FlightCallOptions& options,
std::vector<ActionType>* actions) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->ListActions(options, actions);
}
Status FlightClient::GetFlightInfo(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->GetFlightInfo(options, descriptor, info);
}
Status FlightClient::GetSchema(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::unique_ptr<SchemaResult>* schema_result) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->GetSchema(options, descriptor, schema_result);
}
Status FlightClient::ListFlights(std::unique_ptr<FlightListing>* listing) {
+ RETURN_NOT_OK(CheckOpen());
return ListFlights({}, {}, listing);
}
Status FlightClient::ListFlights(const FlightCallOptions& options,
const Criteria& criteria,
std::unique_ptr<FlightListing>* listing) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->ListFlights(options, criteria, listing);
}
Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
std::unique_ptr<FlightStreamReader>* stream) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->DoGet(options, ticket, stream);
}
@@ -1356,6 +1371,7 @@ Status FlightClient::DoPut(const FlightCallOptions& options,
const std::shared_ptr<Schema>& schema,
std::unique_ptr<FlightStreamWriter>* stream,
std::unique_ptr<FlightMetadataReader>* reader) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->DoPut(options, descriptor, schema, stream, reader);
}
@@ -1363,8 +1379,23 @@ Status FlightClient::DoExchange(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightStreamReader>* reader) {
+ RETURN_NOT_OK(CheckOpen());
return impl_->DoExchange(options, descriptor, writer, reader);
}
+Status FlightClient::Close() {
+ // gRPC doesn't offer an explicit shutdown
+ impl_.reset(nullptr);
+ // TODO(ARROW-15473): if we track ongoing RPCs, we can cancel them first
+ return Status::OK();
+}
+
+Status FlightClient::CheckOpen() const {
+ if (!impl_) {
+ return Status::Invalid("FlightClient is closed");
+ }
+ return Status::OK();
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index fecc510..15c6705 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -323,8 +323,19 @@ class ARROW_FLIGHT_EXPORT FlightClient {
return DoExchange({}, descriptor, writer, reader);
}
+ /// \brief Explicitly shut down and clean up the client.
+ ///
+ /// For backwards compatibility, this will be implicitly called by
+ /// the destructor if not already called, but this gives the
+ /// application no chance to handle errors, so it is recommended to
+ /// explicitly close the client.
+ ///
+ /// \since 8.0.0
+ Status Close();
+
private:
FlightClient();
+ Status CheckOpen() const;
class FlightClientImpl;
std::unique_ptr<FlightClientImpl> impl_;
};
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 339f7b4..df3f602 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -194,7 +194,9 @@ TEST(TestFlight, ConnectUri) {
ASSERT_OK(Location::Parse(uri, &location1));
ASSERT_OK(Location::Parse(uri, &location2));
ASSERT_OK(FlightClient::Connect(location1, &client));
+ ASSERT_OK(client->Close());
ASSERT_OK(FlightClient::Connect(location2, &client));
+ ASSERT_OK(client->Close());
}
#ifndef _WIN32
@@ -213,7 +215,9 @@ TEST(TestFlight, ConnectUriUnix) {
ASSERT_OK(Location::Parse(uri, &location1));
ASSERT_OK(Location::Parse(uri, &location2));
ASSERT_OK(FlightClient::Connect(location1, &client));
+ ASSERT_OK(client->Close());
ASSERT_OK(FlightClient::Connect(location2, &client));
+ ASSERT_OK(client->Close());
}
#endif
@@ -405,7 +409,10 @@ class TestFlightClient : public ::testing::Test {
ASSERT_OK(ConnectClient());
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
Status ConnectClient() {
Location location;
@@ -631,7 +638,10 @@ class TestMetadata : public ::testing::Test {
[](FlightClientOptions* options) { return Status::OK(); }));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::unique_ptr<FlightClient> client_;
@@ -646,7 +656,10 @@ class TestOptions : public ::testing::Test {
[](FlightClientOptions* options) { return Status::OK(); }));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::unique_ptr<FlightClient> client_;
@@ -666,7 +679,10 @@ class TestAuthHandler : public ::testing::Test {
[](FlightClientOptions* options) { return Status::OK(); }));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::unique_ptr<FlightClient> client_;
@@ -686,7 +702,10 @@ class TestBasicAuthHandler : public ::testing::Test {
[](FlightClientOptions* options) { return Status::OK(); }));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::unique_ptr<FlightClient> client_;
@@ -702,7 +721,10 @@ class TestDoPut : public ::testing::Test {
do_put_server_ = (DoPutTestServer*)server_.get();
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
void CheckBatches(FlightDescriptor expected_descriptor,
const BatchVector& expected_batches) {
@@ -758,6 +780,7 @@ class TestTls : public ::testing::Test {
}
void TearDown() {
+ ASSERT_OK(client_->Close());
ASSERT_OK(server_->Shutdown());
grpc_shutdown();
}
@@ -1070,7 +1093,10 @@ class TestRejectServerMiddleware : public ::testing::Test {
[](FlightClientOptions* options) { return Status::OK(); }));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::unique_ptr<FlightClient> client_;
@@ -1090,7 +1116,10 @@ class TestCountingServerMiddleware : public ::testing::Test {
[](FlightClientOptions* options) { return Status::OK(); }));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::shared_ptr<CountingServerMiddlewareFactory> request_counter_;
@@ -1144,6 +1173,7 @@ class TestPropagatingMiddleware : public ::testing::Test {
}
void TearDown() {
+ ASSERT_OK(client_->Close());
ASSERT_OK(first_server_->Shutdown());
ASSERT_OK(second_server_->Shutdown());
}
@@ -1174,7 +1204,10 @@ class TestErrorMiddleware : public ::testing::Test {
[](FlightClientOptions* options) { return Status::OK(); }));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::unique_ptr<FlightClient> client_;
@@ -1222,7 +1255,10 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test {
::testing::HasSubstr("Invalid credentials"));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::unique_ptr<FlightClient> client_;
@@ -1892,6 +1928,17 @@ TEST_F(TestFlightClient, NoTimeout) {
ASSERT_NE(nullptr, info);
}
+TEST_F(TestFlightClient, Close) {
+ // For gRPC, this is always effectively a no-op
+ ASSERT_OK(client_->Close());
+ // Idempotent
+ ASSERT_OK(client_->Close());
+
+ std::unique_ptr<FlightListing> listing;
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("FlightClient is closed"),
+ client_->ListFlights(&listing));
+}
+
TEST_F(TestDoPut, DoPutInts) {
auto descr = FlightDescriptor::Path({"ints"});
BatchVector batches;
@@ -2807,7 +2854,10 @@ class TestCancel : public ::testing::Test {
&server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
[](FlightClientOptions* options) { return Status::OK(); }));
}
- void TearDown() { ASSERT_OK(server_->Shutdown()); }
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
protected:
std::unique_ptr<FlightClient> client_;
diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc
index addb90a..e72935c 100644
--- a/cpp/src/arrow/flight/sql/client.cc
+++ b/cpp/src/arrow/flight/sql/client.cc
@@ -411,6 +411,8 @@ Status PreparedStatement::Close() {
return Status::OK();
}
+Status FlightSqlClient::Close() { return impl_->Close(); }
+
arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetSqlInfo(
const FlightCallOptions& options, const std::vector<int>& sql_info) {
flight_sql_pb::CommandGetSqlInfo command;
diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h
index 5bf1b3e..fd2234c 100644
--- a/cpp/src/arrow/flight/sql/client.h
+++ b/cpp/src/arrow/flight/sql/client.h
@@ -163,6 +163,9 @@ class ARROW_EXPORT FlightSqlClient {
return info;
}
+ /// \brief Explicitly shut down and clean up the client.
+ Status Close();
+
protected:
virtual Status DoPut(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc
index 507745c..47a03bd 100644
--- a/cpp/src/arrow/flight/sql/server_test.cc
+++ b/cpp/src/arrow/flight/sql/server_test.cc
@@ -168,6 +168,7 @@ class TestFlightSqlServer : public ::testing::Test {
}
void TearDown() override {
+ ASSERT_OK(sql_client->Close());
sql_client.reset();
ASSERT_OK(server->Shutdown());
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 5b00c53..f8c5856 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -1454,6 +1454,20 @@ cdef class FlightClient(_Weakrefable):
py_reader.reader.reset(c_reader.release())
return py_writer, py_reader
+ def close(self):
+ check_flight_status(self.client.get().Close())
+
+ def __del__(self):
+ # Not ideal, but close() wasn't originally present so
+ # applications may not be calling it
+ self.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
cdef class FlightDataStream(_Weakrefable):
"""Abstract base class for Flight data streams."""
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 2ac737a..364821c 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -347,6 +347,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
CFlightDescriptor& descriptor,
unique_ptr[CFlightStreamWriter]* writer,
unique_ptr[CFlightStreamReader]* reader)
+ CStatus Close()
cdef cppclass CFlightStatusCode" arrow::flight::FlightStatusCode":
bint operator==(CFlightStatusCode)
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index f3b6b91..12f815d 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -882,38 +882,48 @@ def test_client_wait_for_available():
server = FlightServerBase(location)
server.serve()
- client = FlightClient(location)
- thread = threading.Thread(target=serve, daemon=True)
- thread.start()
+ with FlightClient(location) as client:
+ thread = threading.Thread(target=serve, daemon=True)
+ thread.start()
- started = time.time()
- client.wait_for_available(timeout=5)
- elapsed = time.time() - started
- assert elapsed >= 0.5
+ started = time.time()
+ client.wait_for_available(timeout=5)
+ elapsed = time.time() - started
+ assert elapsed >= 0.5
def test_flight_list_flights():
"""Try a simple list_flights call."""
- with ConstantFlightServer() as server:
- client = flight.connect(('localhost', server.port))
+ with ConstantFlightServer() as server, \
+ flight.connect(('localhost', server.port)) as client:
assert list(client.list_flights()) == []
flights = client.list_flights(ConstantFlightServer.CRITERIA)
assert len(list(flights)) == 1
+def test_flight_client_close():
+ with ConstantFlightServer() as server, \
+ flight.connect(('localhost', server.port)) as client:
+ assert list(client.list_flights()) == []
+ client.close()
+ client.close() # Idempotent
+ with pytest.raises(pa.ArrowInvalid):
+ list(client.list_flights())
+
+
def test_flight_do_get_ints():
"""Try a simple do_get call."""
table = simple_ints_table()
- with ConstantFlightServer() as server:
- client = flight.connect(('localhost', server.port))
+ with ConstantFlightServer() as server, \
+ flight.connect(('localhost', server.port)) as client:
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)
options = pa.ipc.IpcWriteOptions(
metadata_version=pa.ipc.MetadataVersion.V4)
- with ConstantFlightServer(options=options) as server:
- client = flight.connect(('localhost', server.port))
+ with ConstantFlightServer(options=options) as server, \
+ flight.connect(('localhost', server.port)) as client:
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)
@@ -923,8 +933,8 @@ def test_flight_do_get_ints():
with pytest.raises(flight.FlightServerError,
match="expected IpcWriteOptions, got <class 'int'>"):
- with ConstantFlightServer(options=42) as server:
- client = flight.connect(('localhost', server.port))
+ with ConstantFlightServer(options=42) as server, \
+ flight.connect(('localhost', server.port)) as client:
data = client.do_get(flight.Ticket(b'ints')).read_all()
@@ -933,8 +943,8 @@ def test_do_get_ints_pandas():
"""Try a simple do_get call."""
table = simple_ints_table()
- with ConstantFlightServer() as server:
- client = flight.connect(('localhost', server.port))
+ with ConstantFlightServer() as server, \
+ flight.connect(('localhost', server.port)) as client:
data = client.do_get(flight.Ticket(b'ints')).read_pandas()
assert list(data['some_ints']) == table.column(0).to_pylist()
@@ -942,8 +952,8 @@ def test_do_get_ints_pandas():
def test_flight_do_get_dicts():
table = simple_dicts_table()
- with ConstantFlightServer() as server:
- client = flight.connect(('localhost', server.port))
+ with ConstantFlightServer() as server, \
+ flight.connect(('localhost', server.port)) as client:
data = client.do_get(flight.Ticket(b'dicts')).read_all()
assert data.equals(table)
@@ -952,8 +962,8 @@ def test_flight_do_get_ticket():
"""Make sure Tickets get passed to the server."""
data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
table = pa.Table.from_arrays(data1, names=['a'])
- with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server:
- client = flight.connect(('localhost', server.port))
+ with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server, \
+ flight.connect(('localhost', server.port)) as client:
data = client.do_get(flight.Ticket(b'the-ticket')).read_all()
assert data.equals(table)
@@ -975,8 +985,8 @@ def test_flight_get_info():
def test_flight_get_schema():
"""Make sure GetSchema returns correct schema."""
- with GetInfoFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with GetInfoFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
info = client.get_schema(flight.FlightDescriptor.for_command(b''))
assert info.schema == pa.schema([('a', pa.int32())])
@@ -984,8 +994,8 @@ def test_flight_get_schema():
def test_list_actions():
"""Make sure the return type of ListActions is validated."""
# ARROW-6392
- with ListActionsErrorFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ListActionsErrorFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
with pytest.raises(
flight.FlightServerError,
match=("Results of list_actions must be "
@@ -993,8 +1003,8 @@ def test_list_actions():
):
list(client.list_actions())
- with ListActionsFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ListActionsFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
assert list(client.list_actions()) == \
ListActionsFlightServer.expected_actions()
@@ -1020,8 +1030,8 @@ class ConvenienceServer(FlightServerBase):
def test_do_action_result_convenience():
- with ConvenienceServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ConvenienceServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
# do_action as action type without body
results = [x.body for x in client.do_action('simple-action')]
@@ -1034,8 +1044,8 @@ def test_do_action_result_convenience():
def test_nicer_server_exceptions():
- with ConvenienceServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ConvenienceServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
with pytest.raises(flight.FlightServerError,
match="a bytes-like object is required"):
list(client.do_action('bad-action'))
@@ -1064,8 +1074,8 @@ def test_flight_domain_socket():
with tempfile.NamedTemporaryFile() as sock:
sock.close()
location = flight.Location.for_grpc_unix(sock.name)
- with ConstantFlightServer(location=location):
- client = FlightClient(location)
+ with ConstantFlightServer(location=location), \
+ FlightClient(location) as client:
reader = client.do_get(flight.Ticket(b'ints'))
table = simple_ints_table()
@@ -1091,8 +1101,8 @@ def test_flight_large_message():
pa.array(range(0, 10 * 1024 * 1024))
], names=['a'])
- with EchoFlightServer(expected_schema=data.schema) as server:
- client = FlightClient(('localhost', server.port))
+ with EchoFlightServer(expected_schema=data.schema) as server, \
+ FlightClient(('localhost', server.port)) as client:
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
data.schema)
# Write a single giant chunk
@@ -1108,8 +1118,8 @@ def test_flight_generator_stream():
pa.array(range(0, 10 * 1024))
], names=['a'])
- with EchoStreamFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with EchoStreamFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
data.schema)
writer.write_table(data)
@@ -1120,8 +1130,8 @@ def test_flight_generator_stream():
def test_flight_invalid_generator_stream():
"""Try streaming data with mismatched schemas."""
- with InvalidStreamFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with InvalidStreamFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
with pytest.raises(pa.ArrowException):
client.do_get(flight.Ticket(b'')).read_all()
@@ -1130,8 +1140,8 @@ def test_timeout_fires():
"""Make sure timeouts fire on slow requests."""
# Do this in a separate thread so that if it fails, we don't hang
# the entire test process
- with SlowFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with SlowFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
action = flight.Action("", b"")
options = flight.FlightCallOptions(timeout=0.2)
# gRPC error messages change based on version, so don't look
@@ -1142,8 +1152,8 @@ def test_timeout_fires():
def test_timeout_passes():
"""Make sure timeouts do not fire on fast requests."""
- with ConstantFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ConstantFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
options = flight.FlightCallOptions(timeout=5.0)
client.do_get(flight.Ticket(b'ints'), options=options).read_all()
@@ -1160,8 +1170,8 @@ token_auth_handler = TokenServerAuthHandler(creds={
@pytest.mark.slow
def test_http_basic_unauth():
"""Test that auth fails when not authenticated."""
- with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
- client = FlightClient(('localhost', server.port))
+ with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
+ FlightClient(('localhost', server.port)) as client:
action = flight.Action("who-am-i", b"")
with pytest.raises(flight.FlightUnauthenticatedError,
match=".*unauthenticated.*"):
@@ -1172,8 +1182,8 @@ def test_http_basic_unauth():
reason="ARROW-10013: gRPC on Windows corrupts peer()")
def test_http_basic_auth():
"""Test a Python implementation of HTTP basic authentication."""
- with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
- client = FlightClient(('localhost', server.port))
+ with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
+ FlightClient(('localhost', server.port)) as client:
action = flight.Action("who-am-i", b"")
client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd'))
results = client.do_action(action)
@@ -1185,8 +1195,8 @@ def test_http_basic_auth():
def test_http_basic_auth_invalid_password():
"""Test that auth fails with the wrong password."""
- with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
- client = FlightClient(('localhost', server.port))
+ with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
+ FlightClient(('localhost', server.port)) as client:
action = flight.Action("who-am-i", b"")
with pytest.raises(flight.FlightUnauthenticatedError,
match=".*wrong password.*"):
@@ -1196,8 +1206,8 @@ def test_http_basic_auth_invalid_password():
def test_token_auth():
"""Test an auth mechanism that uses a handshake."""
- with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
- client = FlightClient(('localhost', server.port))
+ with EchoStreamFlightServer(auth_handler=token_auth_handler) as server, \
+ FlightClient(('localhost', server.port)) as client:
action = flight.Action("who-am-i", b"")
client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd'))
identity = next(client.do_action(action))
@@ -1206,8 +1216,8 @@ def test_token_auth():
def test_token_auth_invalid():
"""Test an auth mechanism that uses a handshake."""
- with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
- client = FlightClient(('localhost', server.port))
+ with EchoStreamFlightServer(auth_handler=token_auth_handler) as server, \
+ FlightClient(('localhost', server.port)) as client:
with pytest.raises(flight.FlightUnauthenticatedError):
client.authenticate(TokenClientAuthHandler('test', 'wrong'))
@@ -1220,8 +1230,8 @@ def test_authenticate_basic_token():
"""Test authenticate_basic_token with bearer token and auth headers."""
with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
"auth": HeaderAuthServerMiddlewareFactory()
- }) as server:
- client = FlightClient(('localhost', server.port))
+ }) as server, \
+ FlightClient(('localhost', server.port)) as client:
token_pair = client.authenticate_basic_token(b'test', b'password')
assert token_pair[0] == b'authorization'
assert token_pair[1] == b'Bearer token1234'
@@ -1231,8 +1241,8 @@ def test_authenticate_basic_token_invalid_password():
"""Test authenticate_basic_token with an invalid password."""
with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
"auth": HeaderAuthServerMiddlewareFactory()
- }) as server:
- client = FlightClient(('localhost', server.port))
+ }) as server, \
+ FlightClient(('localhost', server.port)) as client:
with pytest.raises(flight.FlightUnauthenticatedError):
client.authenticate_basic_token(b'test', b'badpassword')
@@ -1241,8 +1251,8 @@ def test_authenticate_basic_token_and_action():
"""Test authenticate_basic_token and doAction after authentication."""
with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
"auth": HeaderAuthServerMiddlewareFactory()
- }) as server:
- client = FlightClient(('localhost', server.port))
+ }) as server, \
+ FlightClient(('localhost', server.port)) as client:
token_pair = client.authenticate_basic_token(b'test', b'password')
assert token_pair[0] == b'authorization'
assert token_pair[1] == b'Bearer token1234'
@@ -1281,17 +1291,18 @@ def test_authenticate_basic_token_with_client_middleware():
assert client_auth_middleware.call_credential[0] == b'authorization'
assert client_auth_middleware.call_credential[1] == \
b'Bearer ' + b'token1234'
+ client.close()
def test_arbitrary_headers_in_flight_call_options():
"""Test passing multiple arbitrary headers to the middleware."""
with ArbitraryHeadersFlightServer(
- auth_handler=no_op_auth_handler,
- middleware={
- "auth": HeaderAuthServerMiddlewareFactory(),
- "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
- }) as server:
- client = FlightClient(('localhost', server.port))
+ auth_handler=no_op_auth_handler,
+ middleware={
+ "auth": HeaderAuthServerMiddlewareFactory(),
+ "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
+ }) as server, \
+ FlightClient(('localhost', server.port)) as client:
token_pair = client.authenticate_basic_token(b'test', b'password')
assert token_pair[0] == b'authorization'
assert token_pair[1] == b'Bearer token1234'
@@ -1328,11 +1339,10 @@ def test_tls_fails():
"""Make sure clients cannot connect when cert verification fails."""
certs = example_tls_certs()
- with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
- # Ensure client doesn't connect when certificate verification
- # fails (this is a slow test since gRPC does retry a few times)
- client = FlightClient("grpc+tls://localhost:" + str(s.port))
-
+ # Ensure client doesn't connect when certificate verification
+ # fails (this is a slow test since gRPC does retry a few times)
+ with ConstantFlightServer(tls_certificates=certs["certificates"]) as s, \
+ FlightClient("grpc+tls://localhost:" + str(s.port)) as client:
# gRPC error messages change based on version, so don't look
# for a particular error
with pytest.raises(flight.FlightUnavailableError):
@@ -1345,9 +1355,9 @@ def test_tls_do_get():
table = simple_ints_table()
certs = example_tls_certs()
- with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
- client = FlightClient(('localhost', s.port),
- tls_root_certs=certs["root_cert"])
+ with ConstantFlightServer(tls_certificates=certs["certificates"]) as s, \
+ FlightClient(('localhost', s.port),
+ tls_root_certs=certs["root_cert"]) as client:
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)
@@ -1366,6 +1376,7 @@ def test_tls_disable_server_verification():
pytest.skip('disable_server_verification feature is not available')
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)
+ client.close()
@pytest.mark.requires_testing_data
@@ -1373,10 +1384,10 @@ def test_tls_override_hostname():
"""Check that incorrectly overriding the hostname fails."""
certs = example_tls_certs()
- with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
- client = flight.connect(('localhost', s.port),
- tls_root_certs=certs["root_cert"],
- override_hostname="fakehostname")
+ with ConstantFlightServer(tls_certificates=certs["certificates"]) as s,\
+ flight.connect(('localhost', s.port),
+ tls_root_certs=certs["root_cert"],
+ override_hostname="fakehostname") as client:
with pytest.raises(flight.FlightUnavailableError):
client.do_get(flight.Ticket(b'ints'))
@@ -1389,8 +1400,8 @@ def test_flight_do_get_metadata():
table = pa.Table.from_arrays(data, names=['a'])
batches = []
- with MetadataFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with MetadataFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
reader = client.do_get(flight.Ticket(b''))
idx = 0
while True:
@@ -1412,8 +1423,8 @@ def test_flight_do_get_metadata_v4():
[pa.array([-10, -5, 0, 5, 10])], names=['a'])
options = pa.ipc.IpcWriteOptions(
metadata_version=pa.ipc.MetadataVersion.V4)
- with MetadataFlightServer(options=options) as server:
- client = FlightClient(('localhost', server.port))
+ with MetadataFlightServer(options=options) as server, \
+ FlightClient(('localhost', server.port)) as client:
reader = client.do_get(flight.Ticket(b''))
data = reader.read_all()
assert data.equals(table)
@@ -1426,8 +1437,8 @@ def test_flight_do_put_metadata():
]
table = pa.Table.from_arrays(data, names=['a'])
- with MetadataFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with MetadataFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
writer, metadata_reader = client.do_put(
flight.FlightDescriptor.for_path(''),
table.schema)
@@ -1447,9 +1458,9 @@ def test_flight_do_put_limit():
pa.array(np.ones(768, dtype=np.int64())),
], names=['a'])
- with EchoFlightServer() as server:
- client = FlightClient(('localhost', server.port),
- write_size_limit_bytes=4096)
+ with EchoFlightServer() as server, \
+ FlightClient(('localhost', server.port),
+ write_size_limit_bytes=4096) as client:
writer, metadata_reader = client.do_put(
flight.FlightDescriptor.for_path(''),
large_batch.schema)
@@ -1472,8 +1483,8 @@ def test_flight_do_put_limit():
@pytest.mark.slow
def test_cancel_do_get():
"""Test canceling a DoGet operation on the client side."""
- with ConstantFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ConstantFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
reader = client.do_get(flight.Ticket(b'ints'))
reader.cancel()
with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
@@ -1483,8 +1494,8 @@ def test_cancel_do_get():
@pytest.mark.slow
def test_cancel_do_get_threaded():
"""Test canceling a DoGet operation from another thread."""
- with SlowFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with SlowFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
reader = client.do_get(flight.Ticket(b'ints'))
read_first_message = threading.Event()
@@ -1547,8 +1558,8 @@ def test_roundtrip_types():
def test_roundtrip_errors():
"""Ensure that Flight errors propagate from server to client."""
- with ErrorFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ErrorFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
list(client.do_action(flight.Action("internal", b"")))
@@ -1600,8 +1611,8 @@ def test_do_put_independent_read_write():
]
table = pa.Table.from_arrays(data, names=['a'])
- with MetadataFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with MetadataFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
writer, metadata_reader = client.do_put(
flight.FlightDescriptor.for_path(''),
table.schema)
@@ -1633,8 +1644,8 @@ def test_server_middleware_same_thread():
"""Ensure that server middleware run on the same thread as the RPC."""
with HeaderFlightServer(middleware={
"test": HeaderServerMiddlewareFactory(),
- }) as server:
- client = FlightClient(('localhost', server.port))
+ }) as server, \
+ FlightClient(('localhost', server.port)) as client:
results = list(client.do_action(flight.Action(b"test", b"")))
assert len(results) == 1
value = results[0].body.to_pybytes()
@@ -1645,8 +1656,8 @@ def test_middleware_reject():
"""Test rejecting an RPC with server middleware."""
with HeaderFlightServer(middleware={
"test": SelectiveAuthServerMiddlewareFactory(),
- }) as server:
- client = FlightClient(('localhost', server.port))
+ }) as server, \
+ FlightClient(('localhost', server.port)) as client:
# The middleware allows this through without auth.
with pytest.raises(pa.ArrowNotImplementedError):
list(client.list_actions())
@@ -1667,11 +1678,11 @@ def test_middleware_mapping():
"""Test that middleware records methods correctly."""
server_middleware = RecordingServerMiddlewareFactory()
client_middleware = RecordingClientMiddlewareFactory()
- with FlightServerBase(middleware={"test": server_middleware}) as server:
- client = FlightClient(
+ with FlightServerBase(middleware={"test": server_middleware}) as server, \
+ FlightClient(
('localhost', server.port),
middleware=[client_middleware]
- )
+ ) as client:
descriptor = flight.FlightDescriptor.for_command(b"")
with pytest.raises(NotImplementedError):
@@ -1708,8 +1719,8 @@ def test_middleware_mapping():
def test_extra_info():
- with ErrorFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ErrorFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
try:
list(client.do_action(flight.Action("protobuf", b"")))
assert False
@@ -1728,12 +1739,12 @@ def test_mtls():
with ConstantFlightServer(
tls_certificates=[certs["certificates"][0]],
verify_client=True,
- root_certificates=certs["root_cert"]) as s:
- client = FlightClient(
+ root_certificates=certs["root_cert"]) as s, \
+ FlightClient(
('localhost', s.port),
tls_root_certs=certs["root_cert"],
cert_chain=certs["certificates"][0].cert,
- private_key=certs["certificates"][0].key)
+ private_key=certs["certificates"][0].key) as client:
data = client.do_get(flight.Ticket(b'ints')).read_all()
assert data.equals(table)
@@ -1744,8 +1755,8 @@ def test_doexchange_get():
pa.array(range(0, 10 * 1024))
], names=["a"])
- with ExchangeFlightServer() as server:
- client = FlightClient(("localhost", server.port))
+ with ExchangeFlightServer() as server, \
+ FlightClient(("localhost", server.port)) as client:
descriptor = flight.FlightDescriptor.for_command(b"get")
writer, reader = client.do_exchange(descriptor)
with writer:
@@ -1760,8 +1771,8 @@ def test_doexchange_put():
], names=["a"])
batches = data.to_batches(max_chunksize=512)
- with ExchangeFlightServer() as server:
- client = FlightClient(("localhost", server.port))
+ with ExchangeFlightServer() as server, \
+ FlightClient(("localhost", server.port)) as client:
descriptor = flight.FlightDescriptor.for_command(b"put")
writer, reader = client.do_exchange(descriptor)
with writer:
@@ -1782,8 +1793,8 @@ def test_doexchange_echo():
], names=["a"])
batches = data.to_batches(max_chunksize=512)
- with ExchangeFlightServer() as server:
- client = FlightClient(("localhost", server.port))
+ with ExchangeFlightServer() as server, \
+ FlightClient(("localhost", server.port)) as client:
descriptor = flight.FlightDescriptor.for_command(b"echo")
writer, reader = client.do_exchange(descriptor)
with writer:
@@ -1822,8 +1833,8 @@ def test_doexchange_echo_v4():
options = pa.ipc.IpcWriteOptions(
metadata_version=pa.ipc.MetadataVersion.V4)
- with ExchangeFlightServer(options=options) as server:
- client = FlightClient(("localhost", server.port))
+ with ExchangeFlightServer(options=options) as server, \
+ FlightClient(("localhost", server.port)) as client:
descriptor = flight.FlightDescriptor.for_command(b"echo")
writer, reader = client.do_exchange(descriptor)
with writer:
@@ -1848,8 +1859,8 @@ def test_doexchange_transform():
pa.array(range(3, 1024 * 3 + 3, 3)),
], names=["sum"])
- with ExchangeFlightServer() as server:
- client = FlightClient(("localhost", server.port))
+ with ExchangeFlightServer() as server, \
+ FlightClient(("localhost", server.port)) as client:
descriptor = flight.FlightDescriptor.for_command(b"transform")
writer, reader = client.do_exchange(descriptor)
with writer:
@@ -1866,15 +1877,17 @@ def test_middleware_multi_header():
"test": MultiHeaderServerMiddlewareFactory(),
}) as server:
headers = MultiHeaderClientMiddlewareFactory()
- client = FlightClient(('localhost', server.port), middleware=[headers])
- response = next(client.do_action(flight.Action(b"", b"")))
- # The server echoes the headers it got back to us.
- raw_headers = response.body.to_pybytes().decode("utf-8")
- client_headers = ast.literal_eval(raw_headers)
- # Don't directly compare; gRPC may add headers like User-Agent.
- for header, values in MultiHeaderClientMiddleware.EXPECTED.items():
- assert client_headers.get(header) == values
- assert headers.last_headers.get(header) == values
+ with FlightClient(
+ ('localhost', server.port),
+ middleware=[headers]) as client:
+ response = next(client.do_action(flight.Action(b"", b"")))
+ # The server echoes the headers it got back to us.
+ raw_headers = response.body.to_pybytes().decode("utf-8")
+ client_headers = ast.literal_eval(raw_headers)
+ # Don't directly compare; gRPC may add headers like User-Agent.
+ for header, values in MultiHeaderClientMiddleware.EXPECTED.items():
+ assert client_headers.get(header) == values
+ assert headers.last_headers.get(header) == values
@pytest.mark.requires_testing_data
@@ -1890,6 +1903,7 @@ def test_generic_options():
generic_options=options)
with pytest.raises(flight.FlightUnavailableError):
client.do_get(flight.Ticket(b'ints'))
+ client.close()
# Try setting an int argument that will make requests fail
options = [("grpc.max_receive_message_length", 32)]
client = flight.connect(('localhost', s.port),
@@ -1897,6 +1911,7 @@ def test_generic_options():
generic_options=options)
with pytest.raises(pa.ArrowInvalid):
client.do_get(flight.Ticket(b'ints'))
+ client.close()
class CancelFlightServer(FlightServerBase):
@@ -1946,8 +1961,8 @@ def test_interrupt():
assert isinstance(e, (pa.ArrowCancelled, KeyboardInterrupt)) or \
isinstance(e.__context__, (pa.ArrowCancelled, KeyboardInterrupt))
- with CancelFlightServer() as server:
- client = FlightClient(("localhost", server.port))
+ with CancelFlightServer() as server, \
+ FlightClient(("localhost", server.port)) as client:
reader = client.do_get(flight.Ticket(b""))
test(reader.read_all)
@@ -1960,8 +1975,8 @@ def test_interrupt():
def test_never_sends_data():
# Regression test for ARROW-12779
match = "application server implementation error"
- with NeverSendsDataFlightServer() as server:
- client = flight.connect(('localhost', server.port))
+ with NeverSendsDataFlightServer() as server, \
+ flight.connect(('localhost', server.port)) as client:
with pytest.raises(flight.FlightServerError, match=match):
client.do_get(flight.Ticket(b'')).read_all()
@@ -1978,8 +1993,8 @@ def test_large_descriptor():
# since some CI pipelines can't run the C++ equivalent
large_descriptor = flight.FlightDescriptor.for_command(
b' ' * (2 ** 31 + 1))
- with FlightServerBase() as server:
- client = flight.connect(('localhost', server.port))
+ with FlightServerBase() as server, \
+ flight.connect(('localhost', server.port)) as client:
with pytest.raises(OSError,
match="Failed to serialize Flight descriptor"):
writer, _ = client.do_put(large_descriptor, pa.schema([]))
@@ -1995,8 +2010,8 @@ def test_large_metadata_client():
# Regression test for ARROW-13253
descriptor = flight.FlightDescriptor.for_command(b'')
metadata = b' ' * (2 ** 31 + 1)
- with EchoFlightServer() as server:
- client = flight.connect(('localhost', server.port))
+ with EchoFlightServer() as server, \
+ flight.connect(('localhost', server.port)) as client:
with pytest.raises(pa.ArrowCapacityError,
match="app_metadata size overflow"):
writer, _ = client.do_put(descriptor, pa.schema([]))
@@ -2010,8 +2025,8 @@ def test_large_metadata_client():
writer.write_metadata(metadata)
del metadata
- with LargeMetadataFlightServer() as server:
- client = flight.connect(('localhost', server.port))
+ with LargeMetadataFlightServer() as server, \
+ flight.connect(('localhost', server.port)) as client:
with pytest.raises(flight.FlightServerError,
match="app_metadata size overflow"):
reader = client.do_get(flight.Ticket(b''))
@@ -2042,8 +2057,8 @@ def test_none_action_side_effect():
See https://issues.apache.org/jira/browse/ARROW-14255
"""
- with ActionNoneFlightServer() as server:
- client = FlightClient(('localhost', server.port))
+ with ActionNoneFlightServer() as server, \
+ FlightClient(('localhost', server.port)) as client:
client.do_action(flight.Action("append", b""))
r = client.do_action(flight.Action("get_value", b""))
assert json.loads(next(r).body.to_pybytes()) == [True]
diff --git a/r/NAMESPACE b/r/NAMESPACE
index 6233b4f..d01f66d 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -227,6 +227,7 @@ export(field)
export(fixed_size_binary)
export(fixed_size_list_of)
export(flight_connect)
+export(flight_disconnect)
export(flight_get)
export(flight_path_exists)
export(flight_put)
diff --git a/r/R/flight.R b/r/R/flight.R
index 4d190de..f56308f 100644
--- a/r/R/flight.R
+++ b/r/R/flight.R
@@ -40,6 +40,14 @@ flight_connect <- function(host = "localhost", port, scheme = "grpc+tcp") {
pa$flight$FlightClient(location)
}
+#' Explicitly close a Flight client
+#'
+#' @param client `pyarrow.flight.FlightClient`, as returned by [flight_connect()]
+#' @export
+flight_disconnect <- function(client) {
+  client$close()
+}
+
#' Send data to a Flight server
#'
#' @param client `pyarrow.flight.FlightClient`, as returned by [flight_connect()]
diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml
index 38e808a..ba3162b 100644
--- a/r/_pkgdown.yml
+++ b/r/_pkgdown.yml
@@ -151,6 +151,7 @@ reference:
contents:
- load_flight_server
- flight_connect
+ - flight_disconnect
- flight_get
- flight_put
- list_flights
diff --git a/r/man/flight_disconnect.Rd b/r/man/flight_disconnect.Rd
new file mode 100644
index 0000000..83c5edf
--- /dev/null
+++ b/r/man/flight_disconnect.Rd
@@ -0,0 +1,14 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/flight.R
+\name{flight_disconnect}
+\alias{flight_disconnect}
+\title{Explicitly close a Flight client}
+\usage{
+flight_disconnect(client)
+}
+\arguments{
+\item{client}{The client to disconnect}
+}
+\description{
+Explicitly close a Flight client
+}
diff --git a/r/tests/testthat/test-python-flight.R b/r/tests/testthat/test-python-flight.R
index 0ffc7e4..db67cd6 100644
--- a/r/tests/testthat/test-python-flight.R
+++ b/r/tests/testthat/test-python-flight.R
@@ -57,6 +57,12 @@ if (process_is_running("demo_flight_server")) {
flight_put(client, example_with_times, path = flight_obj)
expect_identical(as.data.frame(flight_get(client, flight_obj)), example_with_times)
})
+
+  test_that("flight_disconnect", {
+    flight_disconnect(client)
+    # Disconnecting an already-closed client again must not error (idempotent)
+    flight_disconnect(client)
+  })
} else {
# Kinda hacky, let's put a skipped test here, just so we note that the tests
# didn't run