You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2019/05/08 14:22:03 UTC
[arrow] branch master updated: ARROW-5136: [Flight] Call options
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 7f79416 ARROW-5136: [Flight] Call options
7f79416 is described below
commit 7f7941676e0825075aa7a8f22b2344523d316165
Author: David Li <li...@gmail.com>
AuthorDate: Wed May 8 16:21:49 2019 +0200
ARROW-5136: [Flight] Call options
Right now the only option is a timeout on the overall call, which is useful for limiting hangs when connecting to a Flight service that might have gone down. (gRPC is rather aggressive about retrying for a while before giving up.)
You could imagine having a health check action in DoAction, for instance, which fails if the call doesn't complete within a short time.
I would also like a stream timeout, where the timeout only fires if there is a long delay between messages on a DoGet or DoAction, but this is hard to implement with the current synchronous gRPC client. Implementing an async client interface might be a better way to achieve this.
Author: David Li <li...@gmail.com>
Closes #4144 from lihalite/arrow-5136-call-options and squashes the following commits:
f20181c90 <David Li> Use varargs for Java Flight CallOptions
a573d2c61 <David Li> Make FlightCallOptions a POD type
d4bef2781 <David Li> Add tests for Flight call options in Python
0e6111154 <David Li> Don't hold GIL during Flight server shutdown
a1829b32c <David Li> implement call options for Flight
81e3c7908 <David Li> add call options to Flight
---
cpp/src/arrow/flight/client.cc | 82 +++++++++------
cpp/src/arrow/flight/client.h | 61 +++++++++--
cpp/src/arrow/flight/flight-test.cc | 30 ++++++
.../java/org/apache/arrow/flight/CallOption.java | 24 +++++
.../java/org/apache/arrow/flight/CallOptions.java | 59 +++++++++++
.../java/org/apache/arrow/flight/FlightClient.java | 103 +++++++++++++------
.../org/apache/arrow/flight/TestCallOptions.java | 113 +++++++++++++++++++++
python/examples/flight/client.py | 9 ++
python/pyarrow/_flight.pyx | 79 ++++++++++----
python/pyarrow/flight.py | 1 +
python/pyarrow/includes/libarrow_flight.pxd | 27 +++--
python/pyarrow/tests/test_flight.py | 29 ++++++
12 files changed, 522 insertions(+), 95 deletions(-)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 69b7399..5f4d6dd 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -51,12 +51,21 @@ class MemoryPool;
namespace flight {
+FlightCallOptions::FlightCallOptions() : timeout(-1) {}
+
struct ClientRpc {
grpc::ClientContext context;
- ClientRpc() {
+ explicit ClientRpc(const FlightCallOptions& options) {
/// XXX workaround until we have a handshake in Connect
context.set_wait_for_ready(true);
+
+ if (options.timeout.count() >= 0) {
+ std::chrono::system_clock::time_point deadline =
+ std::chrono::time_point_cast<std::chrono::system_clock::time_point::duration>(
+ std::chrono::system_clock::now() + options.timeout);
+ context.set_deadline(deadline);
+ }
}
Status IOError(const std::string& error_message) {
@@ -237,11 +246,12 @@ class FlightClient::FlightClientImpl {
return Status::OK();
}
- Status Authenticate(std::unique_ptr<ClientAuthHandler> auth_handler) {
+ Status Authenticate(const FlightCallOptions& options,
+ std::unique_ptr<ClientAuthHandler> auth_handler) {
auth_handler_ = std::move(auth_handler);
- grpc::ClientContext context{};
+ ClientRpc rpc(options);
std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>>
- stream = stub_->Handshake(&context);
+ stream = stub_->Handshake(&rpc.context);
GrpcClientAuthSender outgoing{stream};
GrpcClientAuthReader incoming{stream};
RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming));
@@ -249,11 +259,12 @@ class FlightClient::FlightClientImpl {
return Status::OK();
}
- Status ListFlights(const Criteria& criteria, std::unique_ptr<FlightListing>* listing) {
+ Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
+ std::unique_ptr<FlightListing>* listing) {
// TODO(wesm): populate criteria
pb::Criteria pb_criteria;
- ClientRpc rpc;
+ ClientRpc rpc(options);
RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
std::unique_ptr<grpc::ClientReader<pb::FlightInfo>> stream(
stub_->ListFlights(&rpc.context, pb_criteria));
@@ -271,11 +282,12 @@ class FlightClient::FlightClientImpl {
return internal::FromGrpcStatus(stream->Finish());
}
- Status DoAction(const Action& action, std::unique_ptr<ResultStream>* results) {
+ Status DoAction(const FlightCallOptions& options, const Action& action,
+ std::unique_ptr<ResultStream>* results) {
pb::Action pb_action;
RETURN_NOT_OK(internal::ToProto(action, &pb_action));
- ClientRpc rpc;
+ ClientRpc rpc(options);
RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
std::unique_ptr<grpc::ClientReader<pb::Result>> stream(
stub_->DoAction(&rpc.context, pb_action));
@@ -294,10 +306,10 @@ class FlightClient::FlightClientImpl {
return internal::FromGrpcStatus(stream->Finish());
}
- Status ListActions(std::vector<ActionType>* types) {
+ Status ListActions(const FlightCallOptions& options, std::vector<ActionType>* types) {
pb::Empty empty;
- ClientRpc rpc;
+ ClientRpc rpc(options);
RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
std::unique_ptr<grpc::ClientReader<pb::ActionType>> stream(
stub_->ListActions(&rpc.context, empty));
@@ -311,14 +323,15 @@ class FlightClient::FlightClientImpl {
return internal::FromGrpcStatus(stream->Finish());
}
- Status GetFlightInfo(const FlightDescriptor& descriptor,
+ Status GetFlightInfo(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info) {
pb::FlightDescriptor pb_descriptor;
pb::FlightInfo pb_response;
RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor));
- ClientRpc rpc;
+ ClientRpc rpc(options);
RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
Status s = internal::FromGrpcStatus(
stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response));
@@ -330,11 +343,12 @@ class FlightClient::FlightClientImpl {
return Status::OK();
}
- Status DoGet(const Ticket& ticket, std::unique_ptr<RecordBatchReader>* out) {
+ Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
+ std::unique_ptr<RecordBatchReader>* out) {
pb::Ticket pb_ticket;
internal::ToProto(ticket, &pb_ticket);
- std::unique_ptr<ClientRpc> rpc(new ClientRpc);
+ std::unique_ptr<ClientRpc> rpc(new ClientRpc(options));
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
std::unique_ptr<grpc::ClientReader<pb::FlightData>> stream(
stub_->DoGet(&rpc->context, pb_ticket));
@@ -344,9 +358,10 @@ class FlightClient::FlightClientImpl {
return ipc::RecordBatchStreamReader::Open(std::move(message_reader), out);
}
- Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
+ Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ const std::shared_ptr<Schema>& schema,
std::unique_ptr<ipc::RecordBatchWriter>* out) {
- std::unique_ptr<ClientRpc> rpc(new ClientRpc);
+ std::unique_ptr<ClientRpc> rpc(new ClientRpc(options));
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get()));
std::unique_ptr<protocol::PutResult> response(new protocol::PutResult);
std::unique_ptr<grpc::ClientWriter<pb::FlightData>> writer(
@@ -374,42 +389,47 @@ Status FlightClient::Connect(const std::string& host, int port,
return (*client)->impl_->Connect(host, port);
}
-Status FlightClient::Authenticate(std::unique_ptr<ClientAuthHandler> auth_handler) {
- return impl_->Authenticate(std::move(auth_handler));
+Status FlightClient::Authenticate(const FlightCallOptions& options,
+ std::unique_ptr<ClientAuthHandler> auth_handler) {
+ return impl_->Authenticate(options, std::move(auth_handler));
}
-Status FlightClient::DoAction(const Action& action,
+Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action,
std::unique_ptr<ResultStream>* results) {
- return impl_->DoAction(action, results);
+ return impl_->DoAction(options, action, results);
}
-Status FlightClient::ListActions(std::vector<ActionType>* actions) {
- return impl_->ListActions(actions);
+Status FlightClient::ListActions(const FlightCallOptions& options,
+ std::vector<ActionType>* actions) {
+ return impl_->ListActions(options, actions);
}
-Status FlightClient::GetFlightInfo(const FlightDescriptor& descriptor,
+Status FlightClient::GetFlightInfo(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info) {
- return impl_->GetFlightInfo(descriptor, info);
+ return impl_->GetFlightInfo(options, descriptor, info);
}
Status FlightClient::ListFlights(std::unique_ptr<FlightListing>* listing) {
- return ListFlights({}, listing);
+ return ListFlights({}, {}, listing);
}
-Status FlightClient::ListFlights(const Criteria& criteria,
+Status FlightClient::ListFlights(const FlightCallOptions& options,
+ const Criteria& criteria,
std::unique_ptr<FlightListing>* listing) {
- return impl_->ListFlights(criteria, listing);
+ return impl_->ListFlights(options, criteria, listing);
}
-Status FlightClient::DoGet(const Ticket& ticket,
+Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
std::unique_ptr<RecordBatchReader>* stream) {
- return impl_->DoGet(ticket, stream);
+ return impl_->DoGet(options, ticket, stream);
}
-Status FlightClient::DoPut(const FlightDescriptor& descriptor,
+Status FlightClient::DoPut(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
std::unique_ptr<ipc::RecordBatchWriter>* stream) {
- return impl_->DoPut(descriptor, schema, stream);
+ return impl_->DoPut(options, descriptor, schema, stream);
}
} // namespace flight
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 682a6f1..3886360 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -20,6 +20,7 @@
#pragma once
+#include <chrono>
#include <memory>
#include <string>
#include <vector>
@@ -41,6 +42,21 @@ namespace flight {
class ClientAuthHandler;
+/// \brief A duration type for Flight call timeouts.
+typedef std::chrono::duration<double, std::chrono::seconds::period> TimeoutDuration;
+
+/// \brief Hints to the underlying RPC layer for Arrow Flight calls.
+class ARROW_EXPORT FlightCallOptions {
+ public:
+ /// Create a default set of call options.
+ FlightCallOptions();
+
+ /// \brief An optional timeout for this call. Negative durations
+ /// mean an implementation-defined default behavior will be used
+ /// instead. This is the default value.
+ TimeoutDuration timeout;
+};
+
/// \brief Client class for Arrow Flight RPC services (gRPC-based).
/// API experimental for now
class ARROW_EXPORT FlightClient {
@@ -57,29 +73,47 @@ class ARROW_EXPORT FlightClient {
std::unique_ptr<FlightClient>* client);
/// \brief Authenticate to the server using the given handler.
+ /// \param[in] options Per-RPC options
+ /// \param[in] auth_handler The authentication mechanism to use
/// \return Status OK if the client authenticated successfully
- Status Authenticate(std::unique_ptr<ClientAuthHandler> auth_handler);
+ Status Authenticate(const FlightCallOptions& options,
+ std::unique_ptr<ClientAuthHandler> auth_handler);
/// \brief Perform the indicated action, returning an iterator to the stream
/// of results, if any
+ /// \param[in] options Per-RPC options
/// \param[in] action the action to be performed
/// \param[out] results an iterator object for reading the returned results
/// \return Status
- Status DoAction(const Action& action, std::unique_ptr<ResultStream>* results);
+ Status DoAction(const FlightCallOptions& options, const Action& action,
+ std::unique_ptr<ResultStream>* results);
+ Status DoAction(const Action& action, std::unique_ptr<ResultStream>* results) {
+ return DoAction({}, action, results);
+ }
/// \brief Retrieve a list of available Action types
+ /// \param[in] options Per-RPC options
/// \param[out] actions the available actions
/// \return Status
- Status ListActions(std::vector<ActionType>* actions);
+ Status ListActions(const FlightCallOptions& options, std::vector<ActionType>* actions);
+ Status ListActions(std::vector<ActionType>* actions) {
+ return ListActions({}, actions);
+ }
/// \brief Request access plan for a single flight, which may be an existing
/// dataset or a command to be executed
+ /// \param[in] options Per-RPC options
/// \param[in] descriptor the dataset request, whether a named dataset or
/// command
/// \param[out] info the FlightInfo describing where to access the dataset
/// \return Status
- Status GetFlightInfo(const FlightDescriptor& descriptor,
+ Status GetFlightInfo(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info);
+ Status GetFlightInfo(const FlightDescriptor& descriptor,
+ std::unique_ptr<FlightInfo>* info) {
+ return GetFlightInfo({}, descriptor, info);
+ }
/// \brief List all available flights known to the server
/// \param[out] listing an iterator that returns a FlightInfo for each flight
@@ -87,27 +121,40 @@ class ARROW_EXPORT FlightClient {
Status ListFlights(std::unique_ptr<FlightListing>* listing);
/// \brief List available flights given indicated filter criteria
+ /// \param[in] options Per-RPC options
/// \param[in] criteria the filter criteria (opaque)
/// \param[out] listing an iterator that returns a FlightInfo for each flight
/// \return Status
- Status ListFlights(const Criteria& criteria, std::unique_ptr<FlightListing>* listing);
+ Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
+ std::unique_ptr<FlightListing>* listing);
/// \brief Given a flight ticket and schema, request to be sent the
/// stream. Returns record batch stream reader
+ /// \param[in] options Per-RPC options
/// \param[in] ticket The flight ticket to use
/// \param[out] stream the returned RecordBatchReader
/// \return Status
- Status DoGet(const Ticket& ticket, std::unique_ptr<RecordBatchReader>* stream);
+ Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
+ std::unique_ptr<RecordBatchReader>* stream);
+ Status DoGet(const Ticket& ticket, std::unique_ptr<RecordBatchReader>* stream) {
+ return DoGet({}, ticket, stream);
+ }
/// \brief Upload data to a Flight described by the given
/// descriptor. The caller must call Close() on the returned stream
/// once they are done writing.
+ /// \param[in] options Per-RPC options
/// \param[in] descriptor the descriptor of the stream
/// \param[in] schema the schema for the data to upload
/// \param[out] stream a writer to write record batches to
/// \return Status
- Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
+ Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ const std::shared_ptr<Schema>& schema,
std::unique_ptr<ipc::RecordBatchWriter>* stream);
+ Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
+ std::unique_ptr<ipc::RecordBatchWriter>* stream) {
+ return DoPut({}, descriptor, schema, stream);
+ }
private:
FlightClient();
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index 9fea1de..504779d 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -359,8 +359,37 @@ TEST_F(TestFlightClient, Issue5095) {
ASSERT_THAT(status.message(), ::testing::HasSubstr("No data"));
}
+TEST_F(TestFlightClient, TimeoutFires) {
+ // Server does not exist on this port, so call should fail
+ std::unique_ptr<FlightClient> client;
+ ASSERT_OK(FlightClient::Connect("localhost", 30001, &client));
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{0.2};
+ std::unique_ptr<FlightInfo> info;
+ auto start = std::chrono::system_clock::now();
+ Status status = client->GetFlightInfo(options, FlightDescriptor{}, &info);
+ auto end = std::chrono::system_clock::now();
+ EXPECT_LE(end - start, std::chrono::milliseconds{400});
+ ASSERT_RAISES(IOError, status);
+}
+
+TEST_F(TestFlightClient, NoTimeout) {
+ // Call should complete quickly, so timeout should not fire
+ FlightCallOptions options;
+ options.timeout = TimeoutDuration{0.5};
+ std::unique_ptr<FlightInfo> info;
+ auto start = std::chrono::system_clock::now();
+ auto descriptor = FlightDescriptor::Path({"examples", "ints"});
+ Status status = client_->GetFlightInfo(options, descriptor, &info);
+ auto end = std::chrono::system_clock::now();
+ EXPECT_LE(end - start, std::chrono::milliseconds{600});
+ ASSERT_OK(status);
+ ASSERT_NE(nullptr, info);
+}
+
TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
ASSERT_OK(client_->Authenticate(
+ {},
std::unique_ptr<ClientAuthHandler>(new TestClientAuthHandler("user", "p4ssw0rd"))));
Status status;
@@ -437,6 +466,7 @@ TEST_F(TestAuthHandler, FailUnauthenticatedCalls) {
TEST_F(TestAuthHandler, CheckPeerIdentity) {
ASSERT_OK(client_->Authenticate(
+ {},
std::unique_ptr<ClientAuthHandler>(new TestClientAuthHandler("user", "p4ssw0rd"))));
Action action;
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/CallOption.java b/java/flight/src/main/java/org/apache/arrow/flight/CallOption.java
new file mode 100644
index 0000000..d3ee3ab
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/CallOption.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * Per-call RPC options. These are hints to the underlying RPC layer and may not be respected.
+ */
+public interface CallOption {
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/CallOptions.java b/java/flight/src/main/java/org/apache/arrow/flight/CallOptions.java
new file mode 100644
index 0000000..946434b
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/CallOptions.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.concurrent.TimeUnit;
+
+import io.grpc.stub.AbstractStub;
+
+/**
+ * Common call options.
+ */
+public class CallOptions {
+ public static CallOption timeout(long duration, TimeUnit unit) {
+ return new Timeout(duration, unit);
+ }
+
+ static <T extends AbstractStub<T>> T wrapStub(T stub, CallOption[] options) {
+ for (CallOption option : options) {
+ if (option instanceof GrpcCallOption) {
+ stub = ((GrpcCallOption) option).wrapStub(stub);
+ }
+ }
+ return stub;
+ }
+
+ private static class Timeout implements GrpcCallOption {
+ long timeout;
+ TimeUnit timeoutUnit;
+
+ Timeout(long timeout, TimeUnit timeoutUnit) {
+ this.timeout = timeout;
+ this.timeoutUnit = timeoutUnit;
+ }
+
+ @Override
+ public <T extends AbstractStub<T>> T wrapStub(T stub) {
+ return stub.withDeadlineAfter(timeout, timeoutUnit);
+ }
+ }
+
+ interface GrpcCallOption extends CallOption {
+ <T extends AbstractStub<T>> T wrapStub(T stub);
+ }
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index b3addbd..1ca2273 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -86,25 +86,41 @@ public class FlightClient implements AutoCloseable {
/**
* Get a list of available flights.
+ *
* @param criteria Criteria for selecting flights
+ * @param options RPC-layer hints for the call.
* @return FlightInfo Iterable
*/
- public Iterable<FlightInfo> listFlights(Criteria criteria) {
- return ImmutableList.copyOf(blockingStub.listFlights(criteria.asCriteria()))
+ public Iterable<FlightInfo> listFlights(Criteria criteria, CallOption... options) {
+ return ImmutableList.copyOf(CallOptions.wrapStub(blockingStub, options).listFlights(criteria.asCriteria()))
.stream()
- .map(t -> new FlightInfo(t))
+ .map(FlightInfo::new)
.collect(Collectors.toList());
}
- public Iterable<ActionType> listActions() {
- return ImmutableList.copyOf(blockingStub.listActions(Empty.getDefaultInstance()))
+ /**
+ * List actions available on the Flight service.
+ *
+ * @param options RPC-layer hints for the call.
+ */
+ public Iterable<ActionType> listActions(CallOption... options) {
+ return ImmutableList.copyOf(CallOptions.wrapStub(blockingStub, options)
+ .listActions(Empty.getDefaultInstance()))
.stream()
- .map(t -> new ActionType(t))
+ .map(ActionType::new)
.collect(Collectors.toList());
}
- public Iterator<Result> doAction(Action action) {
- return Iterators.transform(blockingStub.doAction(action.toProtocol()), t -> new Result(t));
+ /**
+ * Perform an action on the Flight service.
+ *
+ * @param action The action to perform.
+ * @param options RPC-layer hints for this call.
+ * @return An iterator of results.
+ */
+ public Iterator<Result> doAction(Action action, CallOption... options) {
+ return Iterators
+ .transform(CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol()), Result::new);
}
public void authenticateBasic(String username, String password) {
@@ -112,9 +128,15 @@ public class FlightClient implements AutoCloseable {
authenticate(basicClient);
}
- public void authenticate(ClientAuthHandler handler) {
+ /**
+ * Authenticate against the Flight service.
+ *
+ * @param options RPC-layer hints for this call.
+ * @param handler The auth mechanism to use.
+ */
+ public void authenticate(ClientAuthHandler handler, CallOption... options) {
Preconditions.checkArgument(!authInterceptor.hasAuthHandler(), "Auth already completed.");
- ClientAuthWrapper.doClientAuth(handler, asyncStub);
+ ClientAuthWrapper.doClientAuth(handler, CallOptions.wrapStub(asyncStub, options));
authInterceptor.setAuthHandler(handler);
}
@@ -122,16 +144,19 @@ public class FlightClient implements AutoCloseable {
* Create or append a descriptor with another stream.
* @param descriptor FlightDescriptor
* @param root VectorSchemaRoot
+ * @param options RPC-layer hints for this call.
* @return ClientStreamListener
*/
- public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root) {
+ public ClientStreamListener startPut(
+ FlightDescriptor descriptor, VectorSchemaRoot root, CallOption... options) {
Preconditions.checkNotNull(descriptor);
Preconditions.checkNotNull(root);
SetStreamObserver<PutResult> resultObserver = new SetStreamObserver<>();
+ final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>)
asyncClientStreamingCall(
- authInterceptor.interceptCall(doPutDescriptor, asyncStub.getCallOptions(), channel), resultObserver);
+ authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver);
// send the schema to start.
ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema());
observer.onNext(message);
@@ -140,13 +165,24 @@ public class FlightClient implements AutoCloseable {
observer, resultObserver.getFuture());
}
- public FlightInfo getInfo(FlightDescriptor descriptor) {
- return new FlightInfo(blockingStub.getFlightInfo(descriptor.toProtocol()));
+ /**
+ * Get info on a stream.
+ * @param descriptor The descriptor for the stream.
+ * @param options RPC-layer hints for this call.
+ */
+ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) {
+ return new FlightInfo(CallOptions.wrapStub(blockingStub, options).getFlightInfo(descriptor.toProtocol()));
}
- public FlightStream getStream(Ticket ticket) {
+ /**
+ * Retrieve a stream from the server.
+ * @param ticket The ticket granting access to the data stream.
+ * @param options RPC-layer hints for this call.
+ */
+ public FlightStream getStream(Ticket ticket, CallOption... options) {
+ final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
ClientCall<Flight.Ticket, ArrowMessage> call =
- authInterceptor.interceptCall(doGetDescriptor, asyncStub.getCallOptions(), channel);
+ authInterceptor.interceptCall(doGetDescriptor, callOptions, channel);
FlightStream stream = new FlightStream(
allocator,
PENDING_REQUESTS,
@@ -157,27 +193,27 @@ public class FlightClient implements AutoCloseable {
ClientResponseObserver<Flight.Ticket, ArrowMessage> clientResponseObserver =
new ClientResponseObserver<Flight.Ticket, ArrowMessage>() {
- @Override
- public void beforeStart(ClientCallStreamObserver<org.apache.arrow.flight.impl.Flight.Ticket> requestStream) {
- requestStream.disableAutoInboundFlowControl();
- }
+ @Override
+ public void beforeStart(ClientCallStreamObserver<org.apache.arrow.flight.impl.Flight.Ticket> requestStream) {
+ requestStream.disableAutoInboundFlowControl();
+ }
- @Override
- public void onNext(ArrowMessage value) {
- delegate.onNext(value);
- }
+ @Override
+ public void onNext(ArrowMessage value) {
+ delegate.onNext(value);
+ }
- @Override
- public void onError(Throwable t) {
- delegate.onError(t);
- }
+ @Override
+ public void onError(Throwable t) {
+ delegate.onError(t);
+ }
- @Override
- public void onCompleted() {
- delegate.onCompleted();
- }
+ @Override
+ public void onCompleted() {
+ delegate.onCompleted();
+ }
- };
+ };
asyncServerStreamingCall(call, ticket.toProtocol(), clientResponseObserver);
return stream;
@@ -266,4 +302,5 @@ public class FlightClient implements AutoCloseable {
channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
allocator.close();
}
+
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
new file mode 100644
index 0000000..c95cb98
--- /dev/null
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestCallOptions.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Iterator;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestCallOptions {
+
+ @Test
+ public void timeoutFires() {
+ test((client) -> {
+ Instant start = Instant.now();
+ Iterator<Result> results = client.doAction(new Action("hang"), CallOptions.timeout(1, TimeUnit.SECONDS));
+ try {
+ results.next();
+ Assert.fail("Call should have failed");
+ } catch (RuntimeException e) {
+ Assert.assertTrue(e.getMessage(), e.getMessage().contains("deadline exceeded"));
+ }
+ Instant end = Instant.now();
+ Assert.assertTrue("Call took over 1500 ms despite timeout", Duration.between(start, end).toMillis() < 1500);
+ });
+ }
+
+ @Test
+ public void underTimeout() {
+ test((client) -> {
+ Instant start = Instant.now();
+ // This shouldn't fail and it should complete within the timeout
+ Iterator<Result> results = client.doAction(new Action("fast"), CallOptions.timeout(2, TimeUnit.SECONDS));
+ Assert.assertArrayEquals(new byte[]{42, 42}, results.next().getBody());
+ Instant end = Instant.now();
+ Assert.assertTrue("Call took over 2500 ms despite timeout", Duration.between(start, end).toMillis() < 2500);
+ });
+ }
+
+ void test(Consumer<FlightClient> testFn) {
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ Producer producer = new Producer(a);
+ FlightServer s =
+ FlightTestUtil.getStartedServer((port) -> new FlightServer(a, port, producer, ServerAuthHandler.NO_OP));
+ FlightClient client = new FlightClient(a, new Location(FlightTestUtil.LOCALHOST, s.getPort()))) {
+ testFn.accept(client);
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ static class Producer extends NoOpFlightProducer implements AutoCloseable {
+
+ private final BufferAllocator allocator;
+
+ Producer(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ public void close() {
+ }
+
+ @Override
+ public Result doAction(CallContext context, Action action) {
+ switch (action.getType()) {
+ case "hang": {
+ try {
+ Thread.sleep(25000);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return new Result(new byte[]{});
+ }
+ case "fast": {
+ try {
+ Thread.sleep(500);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return new Result(new byte[]{42, 42});
+ }
+ default: {
+ throw new UnsupportedOperationException(action.getType());
+ }
+ }
+ }
+ }
+}
diff --git a/python/examples/flight/client.py b/python/examples/flight/client.py
index 6d5b452..d1e60d2 100644
--- a/python/examples/flight/client.py
+++ b/python/examples/flight/client.py
@@ -132,6 +132,15 @@ def main():
host, port = args.host.split(':')
port = int(port)
client = pyarrow.flight.FlightClient.connect(host, port)
+ while True:
+ try:
+ action = pyarrow.flight.Action("healthcheck", b"")
+ options = pyarrow.flight.FlightCallOptions(timeout=1)
+ list(client.do_action(action, options=options))
+ break
+ except pyarrow.ArrowIOError as e:
+ if "Deadline" in str(e):
+ print("Server is not ready, waiting...")
commands[args.action](args, client)
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index e6d9574..271d135 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -30,6 +30,29 @@ from pyarrow.ipc import _ReadPandasOption
import pyarrow.lib as lib
+cdef CFlightCallOptions DEFAULT_CALL_OPTIONS
+
+
+cdef class FlightCallOptions:
+    """RPC-layer options for a Flight call.
+
+    Parameters
+    ----------
+    timeout : float, optional
+        Overall timeout for the call (presumably in seconds). If None, the
+        C++ default for ``FlightCallOptions.timeout`` is kept.
+    """
+
+    cdef:
+        CFlightCallOptions options
+
+    def __init__(self, timeout=None):
+        if timeout is not None:
+            self.options.timeout = CTimeoutDuration(timeout)
+
+    @staticmethod
+    cdef CFlightCallOptions* unwrap(obj) except NULL:
+        """Return a pointer to the C++ options wrapped by ``obj``.
+
+        None maps to the shared module-level defaults. ``except NULL`` is
+        required so the TypeError below propagates to the caller; without
+        it Cython would print and discard the exception and return NULL,
+        which every caller immediately dereferences.
+        """
+        if obj is None:
+            return &DEFAULT_CALL_OPTIONS
+        elif isinstance(obj, FlightCallOptions):
+            return &((<FlightCallOptions> obj).options)
+        raise TypeError("Expected a FlightCallOptions object, not "
+                        "'{}'".format(type(obj)))
+
+
cdef class Action:
"""An action executable on a Flight service."""
cdef:
@@ -312,24 +335,30 @@ cdef class FlightClient:
return result
- def authenticate(self, auth_handler):
+ def authenticate(self, auth_handler, options: FlightCallOptions = None):
"""Authenticate to the server."""
- cdef unique_ptr[CClientAuthHandler] handler
+ cdef:
+ unique_ptr[CClientAuthHandler] handler
+ # None selects the shared module-level default call options.
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+
if not isinstance(auth_handler, ClientAuthHandler):
raise TypeError(
"FlightClient.authenticate takes a ClientAuthHandler, "
"not '{}'".format(type(auth_handler)))
handler.reset((<ClientAuthHandler> auth_handler).to_handler())
+ # The GIL is released for the RPC; c_options points either into
+ # ``options`` (still referenced by this frame) or at the defaults.
with nogil:
- check_status(self.client.get().Authenticate(move(handler)))
+ check_status(self.client.get().Authenticate(deref(c_options),
+ move(handler)))
- def list_actions(self):
+ def list_actions(self, options: FlightCallOptions = None):
"""List the actions available on a service."""
cdef:
vector[CActionType] results
+ # None selects the shared module-level default call options.
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
- check_status(self.client.get().ListActions(&results))
+ check_status(
+ self.client.get().ListActions(deref(c_options), &results))
result = []
for action_type in results:
@@ -339,13 +368,15 @@ cdef class FlightClient:
return result
- def do_action(self, action: Action):
+ def do_action(self, action: Action, options: FlightCallOptions = None):
"""Execute an action on a service."""
cdef:
unique_ptr[CResultStream] results
Result result
+ # None selects the shared module-level default call options.
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ # This method is a generator, so the RPC is only issued once the
+ # caller starts iterating over the results.
with nogil:
- check_status(self.client.get().DoAction(action.action, &results))
+ check_status(self.client.get().DoAction(deref(c_options),
+ action.action, &results))
while True:
result = Result.__new__(Result)
@@ -355,14 +386,17 @@ cdef class FlightClient:
break
yield result
- def list_flights(self):
+ def list_flights(self, options: FlightCallOptions = None):
"""List the flights available on a service."""
cdef:
unique_ptr[CFlightListing] listing
FlightInfo result
+ # None selects the shared module-level default call options.
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ # Criteria are not exposed to Python here; a default-constructed
+ # value is passed.
+ CCriteria c_criteria
+ # This method is a generator, so the RPC is only issued once the
+ # caller starts iterating over the results.
with nogil:
- check_status(self.client.get().ListFlights(&listing))
+ check_status(self.client.get().ListFlights(deref(c_options),
+ c_criteria, &listing))
while True:
result = FlightInfo.__new__(FlightInfo)
@@ -372,40 +406,46 @@ cdef class FlightClient:
break
yield result
- def get_flight_info(self, descriptor: FlightDescriptor):
+ def get_flight_info(self, descriptor: FlightDescriptor,
+ options: FlightCallOptions = None):
"""Request information about an available flight."""
cdef:
FlightInfo result = FlightInfo.__new__(FlightInfo)
+ # None selects the shared module-level default call options.
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
check_status(self.client.get().GetFlightInfo(
- descriptor.descriptor, &result.info))
+ deref(c_options), descriptor.descriptor, &result.info))
return result
- def do_get(self, ticket: Ticket):
+ def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
"""Request the data for a flight."""
cdef:
# TODO: introduce unwrap
CTicket c_ticket
unique_ptr[CRecordBatchReader] reader
+ # None selects the shared module-level default call options.
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
c_ticket.ticket = ticket.ticket
with nogil:
- check_status(self.client.get().DoGet(c_ticket, &reader))
+ check_status(
+ self.client.get().DoGet(deref(c_options), c_ticket, &reader))
+ # Transfer ownership of the C++ reader to the Python wrapper.
result = FlightRecordBatchReader()
result.reader.reset(reader.release())
return result
- def do_put(self, descriptor: FlightDescriptor, schema: Schema):
+ def do_put(self, descriptor: FlightDescriptor, schema: Schema,
+ options: FlightCallOptions = None):
"""Upload data to a flight."""
cdef:
shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
unique_ptr[CRecordBatchWriter] writer
+ # None selects the shared module-level default call options.
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
check_status(self.client.get().DoPut(
- descriptor.descriptor, c_schema, &writer))
+ deref(c_options), descriptor.descriptor, c_schema, &writer))
+ # Transfer ownership of the C++ writer to the Python wrapper.
result = FlightRecordBatchWriter()
result.writer.reset(writer.release())
return result
@@ -932,5 +972,10 @@ cdef class FlightServerBase:
request to finish. Instead, call this method from a background
thread.
"""
- if self.server.get() != NULL:
- self.server.get().Shutdown()
+ # Must not hold the GIL: shutdown waits for pending RPCs to
+ # complete. Holding the GIL means Python-implemented Flight
+ # methods will never get to run, so this will hang
+ # indefinitely.
+ with nogil:
+ if self.server.get() != NULL:
+ self.server.get().Shutdown()
diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py
index 33bc168..9b881bd 100644
--- a/python/pyarrow/flight.py
+++ b/python/pyarrow/flight.py
@@ -19,6 +19,7 @@ from pyarrow._flight import ( # noqa
Action,
ActionType,
DescriptorType,
+ FlightCallOptions,
FlightClient,
FlightDescriptor,
FlightEndpoint,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index f1883a9..2d083e3 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -141,23 +141,36 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
cdef cppclass CServerCallContext" arrow::flight::ServerCallContext":
c_string& peer_identity()
+ # Binding for arrow::flight::TimeoutDuration, constructible from a double
+ # (FlightCallOptions.__init__ passes the Python ``timeout`` value here).
+ cdef cppclass CTimeoutDuration" arrow::flight::TimeoutDuration":
+ CTimeoutDuration(double)
+
+ # Per-call options accepted by the CFlightClient RPC methods.
+ cdef cppclass CFlightCallOptions" arrow::flight::FlightCallOptions":
+ CFlightCallOptions()
+ CTimeoutDuration timeout
+
+ # Every RPC method except the static Connect takes the per-call
+ # options as its first argument.
cdef cppclass CFlightClient" arrow::flight::FlightClient":
@staticmethod
CStatus Connect(const c_string& host, int port,
unique_ptr[CFlightClient]* client)
- CStatus Authenticate(unique_ptr[CClientAuthHandler] auth_handler)
+ CStatus Authenticate(CFlightCallOptions& options,
+ unique_ptr[CClientAuthHandler] auth_handler)
- CStatus DoAction(CAction& action, unique_ptr[CResultStream]* results)
- CStatus ListActions(vector[CActionType]* actions)
+ CStatus DoAction(CFlightCallOptions& options, CAction& action,
+ unique_ptr[CResultStream]* results)
+ CStatus ListActions(CFlightCallOptions& options,
+ vector[CActionType]* actions)
- CStatus ListFlights(unique_ptr[CFlightListing]* listing)
- CStatus GetFlightInfo(CFlightDescriptor& descriptor,
+ CStatus ListFlights(CFlightCallOptions& options, CCriteria criteria,
+ unique_ptr[CFlightListing]* listing)
+ CStatus GetFlightInfo(CFlightCallOptions& options,
+ CFlightDescriptor& descriptor,
unique_ptr[CFlightInfo]* info)
- CStatus DoGet(CTicket& ticket,
+ CStatus DoGet(CFlightCallOptions& options, CTicket& ticket,
unique_ptr[CRecordBatchReader]* stream)
- CStatus DoPut(CFlightDescriptor& descriptor,
+ CStatus DoPut(CFlightCallOptions& options,
+ CFlightDescriptor& descriptor,
shared_ptr[CSchema]& schema,
unique_ptr[CRecordBatchWriter]* stream)
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 74e3bc4..f4fc719 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -20,6 +20,7 @@ import base64
import contextlib
import socket
import threading
+import time
import pytest
@@ -93,6 +94,14 @@ class InvalidStreamFlightServer(flight.FlightServerBase):
return flight.GeneratorStream(self.schema, [table1, table2])
+class SlowFlightServer(flight.FlightServerBase):
+ """A Flight server that delays its responses to test timeouts."""
+
+ def do_action(self, context, action):
+ time.sleep(0.5)
+ return iter([])
+
+
class HttpBasicServerAuthHandler(flight.ServerAuthHandler):
"""An example implementation of HTTP basic authentication."""
@@ -257,6 +266,26 @@ def test_flight_invalid_generator_stream():
client.do_get(flight.Ticket(b'')).read_all()
+def test_timeout_fires():
+ """Make sure timeouts fire on slow requests."""
+ # Do this in a separate thread so that if it fails, we don't hang
+ # the entire test process
+ with flight_server(SlowFlightServer) as server_port:
+ client = flight.FlightClient.connect('localhost', server_port)
+ action = flight.Action("", b"")
+ options = flight.FlightCallOptions(timeout=0.2)
+ with pytest.raises(pa.ArrowIOError, match="Deadline Exceeded"):
+ list(client.do_action(action, options=options))
+
+
+def test_timeout_passes():
+ """Make sure timeouts do not fire on fast requests."""
+ with flight_server(ConstantFlightServer) as server_port:
+ client = flight.FlightClient.connect('localhost', server_port)
+ options = flight.FlightCallOptions(timeout=0.2)
+ client.do_get(flight.Ticket(b''), options=options).read_all()
+
+
basic_auth_handler = HttpBasicServerAuthHandler(creds={
b"test": b"p4ssw0rd",
})