You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2019/07/25 19:09:07 UTC
[arrow] branch master updated: ARROW-5681: [FlightRPC] Add
Flight-specific error APIs
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new a8abbe3 ARROW-5681: [FlightRPC] Add Flight-specific error APIs
a8abbe3 is described below
commit a8abbe3214bb8810b5e6988c55c36c5fa20da1e6
Author: David Li <li...@gmail.com>
AuthorDate: Thu Jul 25 21:08:49 2019 +0200
ARROW-5681: [FlightRPC] Add Flight-specific error APIs
This adds a set of exceptions (in Java/Python) and status codes (in C++) for Flight, to convey Flight-specific issues and to let server implementations provide more semantic error messages to clients.
The error codes defined are a subset of the gRPC ones. In Flight/Java, you could always raise `io.grpc.StatusRuntimeException` yourself, but we want to abstract away from the gRPC APIs.
There are no tests yet for conveying these statuses cross-language. That should be done as part of ARROW-5875.
Travis: https://travis-ci.com/lihalite/arrow/builds/118452703
AppVeyor: https://ci.appveyor.com/project/lihalite/arrow/builds/25849820
Closes #4840 from lihalite/flight-status and squashes the following commits:
40bc687ac <David Li> Implement Flight-specific status codes in Python
fb51b45d0 <David Li> Add Flight-specific status codes in C++
9af9606e3 <David Li> Add statuses to Flight in Java
Authored-by: David Li <li...@gmail.com>
Signed-off-by: Antoine Pitrou <an...@python.org>
---
cpp/src/arrow/flight/flight-test.cc | 41 ++-
cpp/src/arrow/flight/internal.cc | 116 +++++++-
cpp/src/arrow/flight/test-util.cc | 6 +-
cpp/src/arrow/flight/types.cc | 40 +++
cpp/src/arrow/flight/types.h | 54 ++++
cpp/src/arrow/python/flight.cc | 62 ++--
cpp/src/arrow/python/flight.h | 50 ++--
cpp/src/arrow/status.h | 6 +
.../java/org/apache/arrow/flight/ActionType.java | 8 +
.../org/apache/arrow/flight/AsyncPutListener.java | 4 +-
.../java/org/apache/arrow/flight/CallStatus.java | 115 +++++++
.../apache/arrow/flight/FlightBindingService.java | 5 +-
.../java/org/apache/arrow/flight/FlightClient.java | 90 +++---
...uthHandler.java => FlightRuntimeException.java} | 37 +--
.../java/org/apache/arrow/flight/FlightServer.java | 14 +-
.../org/apache/arrow/flight/FlightService.java | 22 +-
.../org/apache/arrow/flight/FlightStatusCode.java | 78 +++++
.../apache/arrow/flight/NoOpFlightProducer.java | 12 +-
.../java/org/apache/arrow/flight/StreamPipe.java | 3 +-
.../org/apache/arrow/flight/SyncPutListener.java | 4 +-
.../arrow/flight/auth/ClientAuthHandler.java | 8 +-
.../arrow/flight/auth/ClientAuthWrapper.java | 33 +-
.../arrow/flight/auth/ServerAuthHandler.java | 2 +-
.../arrow/flight/auth/ServerAuthInterceptor.java | 2 +-
.../arrow/flight/auth/ServerAuthWrapper.java | 17 +-
.../apache/arrow/flight/example/InMemoryStore.java | 3 +-
.../org/apache/arrow/flight/grpc/StatusUtils.java | 192 ++++++++++++
.../org/apache/arrow/flight/FlightTestUtil.java | 16 +
.../arrow/flight/TestApplicationMetadata.java | 5 +-
.../java/org/apache/arrow/flight/TestAuth.java | 2 +-
.../apache/arrow/flight/TestBasicOperation.java | 12 +-
.../test/java/org/apache/arrow/flight/TestTls.java | 14 +-
.../apache/arrow/flight/auth/TestBasicAuth.java | 22 +-
python/pyarrow/_flight.pyx | 331 ++++++++++++++-------
python/pyarrow/flight.py | 7 +
python/pyarrow/includes/libarrow_flight.pxd | 75 +++--
python/pyarrow/tests/test_flight.py | 73 ++++-
37 files changed, 1237 insertions(+), 344 deletions(-)
diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc
index 901f626..68c9146 100644
--- a/cpp/src/arrow/flight/flight-test.cc
+++ b/cpp/src/arrow/flight/flight-test.cc
@@ -221,6 +221,41 @@ TEST(TestFlight, RoundTripTypes) {
ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes());
}
+TEST(TestFlight, RoundtripStatus) {
+ // Make sure status codes round trip through our conversions
+
+ std::shared_ptr<FlightStatusDetail> detail;
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Internal, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Internal, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::TimedOut, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::TimedOut, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Cancelled, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Cancelled, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unauthenticated, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unauthenticated, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unauthorized, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unauthorized, detail->code());
+
+ detail = FlightStatusDetail::UnwrapStatus(
+ MakeFlightError(FlightStatusCode::Unavailable, "Test message"));
+ ASSERT_NE(nullptr, detail);
+ ASSERT_EQ(FlightStatusCode::Unavailable, detail->code());
+}
+
// ----------------------------------------------------------------------
// Client tests
@@ -527,7 +562,7 @@ TEST_F(TestFlightClient, GetFlightInfoNotFound) {
// XXX Ideally should be Invalid (or KeyError), but gRPC doesn't support
// multiple error codes.
auto st = client_->GetFlightInfo(descr, &info);
- ASSERT_RAISES(IOError, st);
+ ASSERT_RAISES(Invalid, st);
ASSERT_NE(st.message().find("Flight not found"), std::string::npos);
}
@@ -603,12 +638,12 @@ TEST_F(TestFlightClient, Issue5095) {
Ticket ticket1{"ARROW-5095-fail"};
std::unique_ptr<FlightStreamReader> stream;
Status status = client_->DoGet(ticket1, &stream);
- ASSERT_RAISES(IOError, status);
+ ASSERT_RAISES(UnknownError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error"));
Ticket ticket2{"ARROW-5095-success"};
status = client_->DoGet(ticket2, &stream);
- ASSERT_RAISES(IOError, status);
+ ASSERT_RAISES(KeyError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("No data"));
}
diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc
index 56fc862..ae8e819 100644
--- a/cpp/src/arrow/flight/internal.cc
+++ b/cpp/src/arrow/flight/internal.cc
@@ -37,6 +37,7 @@
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/util/logging.h"
+#include "arrow/util/string_builder.h"
namespace arrow {
namespace flight {
@@ -49,27 +50,118 @@ Status FromGrpcStatus(const grpc::Status& grpc_status) {
return Status::OK();
}
- if (grpc_status.error_code() == grpc::StatusCode::UNIMPLEMENTED) {
- return Status::NotImplemented("gRPC returned unimplemented error, with message: ",
+ switch (grpc_status.error_code()) {
+ case grpc::StatusCode::OK:
+ return Status::OK();
+ case grpc::StatusCode::CANCELLED:
+ return Status::IOError("gRPC cancelled call, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Cancelled));
+ case grpc::StatusCode::UNKNOWN:
+ return Status::UnknownError("gRPC returned unknown error, with message: ",
grpc_status.error_message());
- } else {
- return Status::IOError("gRPC failed with error code ", grpc_status.error_code(),
- " and message: ", grpc_status.error_message());
+ case grpc::StatusCode::INVALID_ARGUMENT:
+ return Status::Invalid("gRPC returned invalid argument error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::DEADLINE_EXCEEDED:
+ return Status::IOError("gRPC returned deadline exceeded error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::TimedOut));
+ case grpc::StatusCode::NOT_FOUND:
+ return Status::KeyError("gRPC returned not found error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::ALREADY_EXISTS:
+ return Status::AlreadyExists("gRPC returned already exists error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::PERMISSION_DENIED:
+ return Status::IOError("gRPC returned permission denied error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Unauthorized));
+ case grpc::StatusCode::RESOURCE_EXHAUSTED:
+ return Status::Invalid("gRPC returned resource exhausted error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::FAILED_PRECONDITION:
+ return Status::Invalid("gRPC returned precondition failed error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::ABORTED:
+ return Status::IOError("gRPC returned aborted error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal));
+ case grpc::StatusCode::OUT_OF_RANGE:
+ return Status::Invalid("gRPC returned out-of-range error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::UNIMPLEMENTED:
+ return Status::NotImplemented("gRPC returned unimplemented error, with message: ",
+ grpc_status.error_message());
+ case grpc::StatusCode::INTERNAL:
+ return Status::IOError("gRPC returned internal error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal));
+ case grpc::StatusCode::UNAVAILABLE:
+ return Status::IOError("gRPC returned unavailable error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Unavailable));
+ case grpc::StatusCode::DATA_LOSS:
+ return Status::IOError("gRPC returned data loss error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal));
+ case grpc::StatusCode::UNAUTHENTICATED:
+ return Status::IOError("gRPC returned unauthenticated error, with message: ",
+ grpc_status.error_message())
+ .WithDetail(
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Unauthenticated));
+ default:
+ return Status::UnknownError("gRPC failed with error code ",
+ grpc_status.error_code(),
+ " and message: ", grpc_status.error_message());
}
}
grpc::Status ToGrpcStatus(const Status& arrow_status) {
if (arrow_status.ok()) {
return grpc::Status::OK;
- } else {
- grpc::StatusCode grpc_code = grpc::StatusCode::UNKNOWN;
- if (arrow_status.IsNotImplemented()) {
- grpc_code = grpc::StatusCode::UNIMPLEMENTED;
- } else if (arrow_status.IsInvalid()) {
- grpc_code = grpc::StatusCode::INVALID_ARGUMENT;
+ }
+
+ grpc::StatusCode grpc_code = grpc::StatusCode::UNKNOWN;
+ std::string message = arrow_status.message();
+ if (arrow_status.detail()) {
+ message += ". Detail: ";
+ message += arrow_status.detail()->ToString();
+ }
+
+ std::shared_ptr<FlightStatusDetail> flight_status =
+ FlightStatusDetail::UnwrapStatus(arrow_status);
+ if (flight_status) {
+ switch (flight_status->code()) {
+ case FlightStatusCode::Internal:
+ grpc_code = grpc::StatusCode::INTERNAL;
+ break;
+ case FlightStatusCode::TimedOut:
+ grpc_code = grpc::StatusCode::DEADLINE_EXCEEDED;
+ break;
+ case FlightStatusCode::Cancelled:
+ grpc_code = grpc::StatusCode::CANCELLED;
+ break;
+ case FlightStatusCode::Unauthenticated:
+ grpc_code = grpc::StatusCode::UNAUTHENTICATED;
+ break;
+ case FlightStatusCode::Unauthorized:
+ grpc_code = grpc::StatusCode::PERMISSION_DENIED;
+ break;
+ case FlightStatusCode::Unavailable:
+ grpc_code = grpc::StatusCode::UNAVAILABLE;
+ break;
+ default:
+ break;
}
- return grpc::Status(grpc_code, arrow_status.message());
+ } else if (arrow_status.IsNotImplemented()) {
+ grpc_code = grpc::StatusCode::UNIMPLEMENTED;
+ } else if (arrow_status.IsInvalid()) {
+ grpc_code = grpc::StatusCode::INVALID_ARGUMENT;
}
+ return grpc::Status(grpc_code, message);
}
// ActionType
diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc
index e93e83f..c8d67e1 100644
--- a/cpp/src/arrow/flight/test-util.cc
+++ b/cpp/src/arrow/flight/test-util.cc
@@ -363,7 +363,7 @@ Status TestServerAuthHandler::Authenticate(ServerAuthSender* outgoing,
std::string token;
RETURN_NOT_OK(incoming->Read(&token));
if (token != password_) {
- return Status::Invalid("Invalid password");
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
RETURN_NOT_OK(outgoing->Write(username_));
return Status::OK();
@@ -372,7 +372,7 @@ Status TestServerAuthHandler::Authenticate(ServerAuthSender* outgoing,
Status TestServerAuthHandler::IsValid(const std::string& token,
std::string* peer_identity) {
if (token != password_) {
- return Status::Invalid("Invalid token");
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
*peer_identity = username_;
return Status::OK();
@@ -390,7 +390,7 @@ Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing,
std::string username;
RETURN_NOT_OK(incoming->Read(&username));
if (username != username_) {
- return Status::Invalid("Invalid username");
+ return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
return Status::OK();
}
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 89ebd82..e69a5ca 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -37,6 +37,46 @@ const char* kSchemeGrpcTcp = "grpc+tcp";
const char* kSchemeGrpcUnix = "grpc+unix";
const char* kSchemeGrpcTls = "grpc+tls";
+const char* kErrorDetailTypeId = "flight::FlightStatusDetail";
+
+const char* FlightStatusDetail::type_id() const { return kErrorDetailTypeId; }
+
+std::string FlightStatusDetail::ToString() const { return CodeAsString(); }
+
+FlightStatusCode FlightStatusDetail::code() const { return code_; }
+
+std::string FlightStatusDetail::CodeAsString() const {
+ switch (code()) {
+ case FlightStatusCode::Internal:
+ return "Internal";
+ case FlightStatusCode::TimedOut:
+ return "TimedOut";
+ case FlightStatusCode::Cancelled:
+ return "Cancelled";
+ case FlightStatusCode::Unauthenticated:
+ return "Unauthenticated";
+ case FlightStatusCode::Unauthorized:
+ return "Unauthorized";
+ case FlightStatusCode::Unavailable:
+ return "Unavailable";
+ default:
+ return "Unknown";
+ }
+}
+
+std::shared_ptr<FlightStatusDetail> FlightStatusDetail::UnwrapStatus(
+ const arrow::Status& status) {
+ if (!status.detail() || status.detail()->type_id() != kErrorDetailTypeId) {
+ return nullptr;
+ }
+ return std::dynamic_pointer_cast<FlightStatusDetail>(status.detail());
+}
+
+Status MakeFlightError(FlightStatusCode code, const std::string& message) {
+ StatusCode arrow_code = arrow::StatusCode::IOError;
+ return arrow::Status(arrow_code, message, std::make_shared<FlightStatusDetail>(code));
+}
+
bool FlightDescriptor::Equals(const FlightDescriptor& other) const {
if (type != other.type) {
return false;
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index a80b697..152d888 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -51,6 +51,60 @@ class Uri;
namespace flight {
+/// \brief A Flight-specific status code.
+enum class FlightStatusCode : int8_t {
+ /// An implementation error has occurred.
+ Internal,
+ /// A request timed out.
+ TimedOut,
+ /// A request was cancelled.
+ Cancelled,
+ /// We are not authenticated to the remote service.
+ Unauthenticated,
+ /// We do not have permission to make this request.
+ Unauthorized,
+ /// The remote service cannot handle this request at the moment.
+ Unavailable,
+};
+
+// Silence warning
+// "non dll-interface class RecordBatchReader used as base for dll-interface class"
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4275)
+#endif
+
+/// \brief Flight-specific information in a Status.
+class ARROW_FLIGHT_EXPORT FlightStatusDetail : public arrow::StatusDetail {
+ public:
+ explicit FlightStatusDetail(FlightStatusCode code) : code_{code} {}
+ const char* type_id() const override;
+ std::string ToString() const override;
+
+ /// \brief Get the Flight status code.
+ FlightStatusCode code() const;
+ /// \brief Get the human-readable name of the status code.
+ std::string CodeAsString() const;
+
+ /// \brief Try to extract a \a FlightStatusDetail from any Arrow
+ /// status.
+ ///
+ /// \return a \a FlightStatusDetail if it could be unwrapped, \a
+ /// nullptr otherwise
+ static std::shared_ptr<FlightStatusDetail> UnwrapStatus(const arrow::Status& status);
+
+ private:
+ FlightStatusCode code_;
+};
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+/// \brief Make an appropriate Arrow status for the given Flight status.
+ARROW_FLIGHT_EXPORT
+Status MakeFlightError(FlightStatusCode code, const std::string& message);
+
/// \brief A TLS certificate plus key.
struct ARROW_FLIGHT_EXPORT CertKeyPair {
/// \brief The certificate in PEM format.
diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index da5026a..c5b2fe2 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -39,16 +39,18 @@ PyServerAuthHandler::PyServerAuthHandler(PyObject* handler,
Status PyServerAuthHandler::Authenticate(arrow::flight::ServerAuthSender* outgoing,
arrow::flight::ServerAuthReader* incoming) {
return SafeCallIntoPython([=] {
- vtable_.authenticate(handler_.obj(), outgoing, incoming);
- return CheckPyError();
+ const Status status = vtable_.authenticate(handler_.obj(), outgoing, incoming);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
Status PyServerAuthHandler::IsValid(const std::string& token,
std::string* peer_identity) {
return SafeCallIntoPython([=] {
- vtable_.is_valid(handler_.obj(), token, peer_identity);
- return CheckPyError();
+ const Status status = vtable_.is_valid(handler_.obj(), token, peer_identity);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
@@ -62,15 +64,17 @@ PyClientAuthHandler::PyClientAuthHandler(PyObject* handler,
Status PyClientAuthHandler::Authenticate(arrow::flight::ClientAuthSender* outgoing,
arrow::flight::ClientAuthReader* incoming) {
return SafeCallIntoPython([=] {
- vtable_.authenticate(handler_.obj(), outgoing, incoming);
- return CheckPyError();
+ const Status status = vtable_.authenticate(handler_.obj(), outgoing, incoming);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
Status PyClientAuthHandler::GetToken(std::string* token) {
return SafeCallIntoPython([=] {
- vtable_.get_token(handler_.obj(), token);
- return CheckPyError();
+ const Status status = vtable_.get_token(handler_.obj(), token);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
@@ -85,8 +89,10 @@ Status PyFlightServer::ListFlights(
const arrow::flight::Criteria* criteria,
std::unique_ptr<arrow::flight::FlightListing>* listings) {
return SafeCallIntoPython([&] {
- vtable_.list_flights(server_.obj(), context, criteria, listings);
- return CheckPyError();
+ const Status status =
+ vtable_.list_flights(server_.obj(), context, criteria, listings);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
@@ -94,8 +100,9 @@ Status PyFlightServer::GetFlightInfo(const arrow::flight::ServerCallContext& con
const arrow::flight::FlightDescriptor& request,
std::unique_ptr<arrow::flight::FlightInfo>* info) {
return SafeCallIntoPython([&] {
- vtable_.get_flight_info(server_.obj(), context, request, info);
- return CheckPyError();
+ const Status status = vtable_.get_flight_info(server_.obj(), context, request, info);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
@@ -103,8 +110,9 @@ Status PyFlightServer::DoGet(const arrow::flight::ServerCallContext& context,
const arrow::flight::Ticket& request,
std::unique_ptr<arrow::flight::FlightDataStream>* stream) {
return SafeCallIntoPython([&] {
- vtable_.do_get(server_.obj(), context, request, stream);
- return CheckPyError();
+ const Status status = vtable_.do_get(server_.obj(), context, request, stream);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
@@ -113,8 +121,10 @@ Status PyFlightServer::DoPut(
std::unique_ptr<arrow::flight::FlightMessageReader> reader,
std::unique_ptr<arrow::flight::FlightMetadataWriter> writer) {
return SafeCallIntoPython([&] {
- vtable_.do_put(server_.obj(), context, std::move(reader), std::move(writer));
- return CheckPyError();
+ const Status status =
+ vtable_.do_put(server_.obj(), context, std::move(reader), std::move(writer));
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
@@ -122,16 +132,18 @@ Status PyFlightServer::DoAction(const arrow::flight::ServerCallContext& context,
const arrow::flight::Action& action,
std::unique_ptr<arrow::flight::ResultStream>* result) {
return SafeCallIntoPython([&] {
- vtable_.do_action(server_.obj(), context, action, result);
- return CheckPyError();
+ const Status status = vtable_.do_action(server_.obj(), context, action, result);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
Status PyFlightServer::ListActions(const arrow::flight::ServerCallContext& context,
std::vector<arrow::flight::ActionType>* actions) {
return SafeCallIntoPython([&] {
- vtable_.list_actions(server_.obj(), context, actions);
- return CheckPyError();
+ const Status status = vtable_.list_actions(server_.obj(), context, actions);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
@@ -173,8 +185,9 @@ PyFlightResultStream::PyFlightResultStream(PyObject* generator,
Status PyFlightResultStream::Next(std::unique_ptr<arrow::flight::Result>* result) {
return SafeCallIntoPython([=] {
- callback_(generator_.obj(), result);
- return CheckPyError();
+ const Status status = callback_(generator_.obj(), result);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
@@ -210,8 +223,9 @@ Status PyGeneratorFlightDataStream::GetSchemaPayload(FlightPayload* payload) {
Status PyGeneratorFlightDataStream::Next(FlightPayload* payload) {
return SafeCallIntoPython([=] {
- callback_(generator_.obj(), payload);
- return CheckPyError();
+ const Status status = callback_(generator_.obj(), payload);
+ RETURN_NOT_OK(CheckPyError());
+ return status;
});
}
diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h
index 5aea7e8..fe224f0 100644
--- a/cpp/src/arrow/python/flight.h
+++ b/cpp/src/arrow/python/flight.h
@@ -37,45 +37,45 @@ namespace flight {
/// Python.
class ARROW_PYTHON_EXPORT PyFlightServerVtable {
public:
- std::function<void(PyObject*, const arrow::flight::ServerCallContext&,
- const arrow::flight::Criteria*,
- std::unique_ptr<arrow::flight::FlightListing>*)>
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::Criteria*,
+ std::unique_ptr<arrow::flight::FlightListing>*)>
list_flights;
- std::function<void(PyObject*, const arrow::flight::ServerCallContext&,
- const arrow::flight::FlightDescriptor&,
- std::unique_ptr<arrow::flight::FlightInfo>*)>
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::FlightDescriptor&,
+ std::unique_ptr<arrow::flight::FlightInfo>*)>
get_flight_info;
- std::function<void(PyObject*, const arrow::flight::ServerCallContext&,
- const arrow::flight::Ticket&,
- std::unique_ptr<arrow::flight::FlightDataStream>*)>
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::Ticket&,
+ std::unique_ptr<arrow::flight::FlightDataStream>*)>
do_get;
- std::function<void(PyObject*, const arrow::flight::ServerCallContext&,
- std::unique_ptr<arrow::flight::FlightMessageReader>,
- std::unique_ptr<arrow::flight::FlightMetadataWriter>)>
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ std::unique_ptr<arrow::flight::FlightMessageReader>,
+ std::unique_ptr<arrow::flight::FlightMetadataWriter>)>
do_put;
- std::function<void(PyObject*, const arrow::flight::ServerCallContext&,
- const arrow::flight::Action&,
- std::unique_ptr<arrow::flight::ResultStream>*)>
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ const arrow::flight::Action&,
+ std::unique_ptr<arrow::flight::ResultStream>*)>
do_action;
- std::function<void(PyObject*, const arrow::flight::ServerCallContext&,
- std::vector<arrow::flight::ActionType>*)>
+ std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
+ std::vector<arrow::flight::ActionType>*)>
list_actions;
};
class ARROW_PYTHON_EXPORT PyServerAuthHandlerVtable {
public:
- std::function<void(PyObject*, arrow::flight::ServerAuthSender*,
- arrow::flight::ServerAuthReader*)>
+ std::function<Status(PyObject*, arrow::flight::ServerAuthSender*,
+ arrow::flight::ServerAuthReader*)>
authenticate;
- std::function<void(PyObject*, const std::string&, std::string*)> is_valid;
+ std::function<Status(PyObject*, const std::string&, std::string*)> is_valid;
};
class ARROW_PYTHON_EXPORT PyClientAuthHandlerVtable {
public:
- std::function<void(PyObject*, arrow::flight::ClientAuthSender*,
- arrow::flight::ClientAuthReader*)>
+ std::function<Status(PyObject*, arrow::flight::ClientAuthSender*,
+ arrow::flight::ClientAuthReader*)>
authenticate;
- std::function<void(PyObject*, std::string*)> get_token;
+ std::function<Status(PyObject*, std::string*)> get_token;
};
/// \brief A helper to implement an auth mechanism in Python.
@@ -138,7 +138,7 @@ class ARROW_PYTHON_EXPORT PyFlightServer : public arrow::flight::FlightServerBas
};
/// \brief A callback that obtains the next result from a Flight action.
-typedef std::function<void(PyObject*, std::unique_ptr<arrow::flight::Result>*)>
+typedef std::function<Status(PyObject*, std::unique_ptr<arrow::flight::Result>*)>
PyFlightResultStreamCallback;
/// \brief A ResultStream built around a Python callback.
@@ -174,7 +174,7 @@ class ARROW_PYTHON_EXPORT PyFlightDataStream : public arrow::flight::FlightDataS
};
/// \brief A callback that obtains the next payload from a Flight result stream.
-typedef std::function<void(PyObject*, arrow::flight::FlightPayload*)>
+typedef std::function<Status(PyObject*, arrow::flight::FlightPayload*)>
PyGeneratorFlightDataStreamCallback;
/// \brief A FlightDataStream built around a Python callback.
diff --git a/cpp/src/arrow/status.h b/cpp/src/arrow/status.h
index b690409..6a3269f 100644
--- a/cpp/src/arrow/status.h
+++ b/cpp/src/arrow/status.h
@@ -249,6 +249,12 @@ class ARROW_EXPORT Status {
util::StringBuilder(std::forward<Args>(args)...));
}
+ template <typename... Args>
+ static Status AlreadyExists(Args&&... args) {
+ return Status(StatusCode::AlreadyExists,
+ util::StringBuilder(std::forward<Args>(args)...));
+ }
+
/// Return true iff the status indicates success.
bool ok() const { return (state_ == NULLPTR); }
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/ActionType.java b/java/flight/src/main/java/org/apache/arrow/flight/ActionType.java
index ed8fec7..d893656 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/ActionType.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/ActionType.java
@@ -59,4 +59,12 @@ public class ActionType {
.setDescription(description)
.build();
}
+
+ @Override
+ public String toString() {
+ return "ActionType{" +
+ "type='" + type + '\'' +
+ ", description='" + description + '\'' +
+ '}';
+ }
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java b/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
index c8214e3..c2182fc 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
@@ -20,6 +20,8 @@ package org.apache.arrow.flight;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
+import org.apache.arrow.flight.grpc.StatusUtils;
+
/**
* A handler for server-sent application metadata messages during a Flight DoPut operation.
*
@@ -53,7 +55,7 @@ public class AsyncPutListener implements FlightClient.PutListener {
@Override
public final void onError(Throwable t) {
- completed.completeExceptionally(t);
+ completed.completeExceptionally(StatusUtils.fromThrowable(t));
}
@Override
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/CallStatus.java b/java/flight/src/main/java/org/apache/arrow/flight/CallStatus.java
new file mode 100644
index 0000000..39bd034
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/CallStatus.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Objects;
+
+import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
+import org.apache.arrow.flight.FlightProducer.StreamListener;
+
+/**
+ * The result of a Flight RPC, consisting of a status code with an optional description and/or exception that led
+ * to the status.
+ *
+ * <p>If raised or sent through {@link StreamListener#onError(Throwable)} or
+ * {@link ServerStreamListener#error(Throwable)}, the client call will raise the same error (a
+ * {@link FlightRuntimeException} with the same {@link FlightStatusCode} and description). The exception within, if
+ * present, will not be sent to the client.
+ */
+public class CallStatus {
+
+ private final FlightStatusCode code;
+ private final Throwable cause;
+ private final String description;
+
+ public static final CallStatus UNKNOWN = FlightStatusCode.UNKNOWN.toStatus();
+ public static final CallStatus INTERNAL = FlightStatusCode.INTERNAL.toStatus();
+ public static final CallStatus INVALID_ARGUMENT = FlightStatusCode.INVALID_ARGUMENT.toStatus();
+ public static final CallStatus TIMED_OUT = FlightStatusCode.TIMED_OUT.toStatus();
+ public static final CallStatus NOT_FOUND = FlightStatusCode.NOT_FOUND.toStatus();
+ public static final CallStatus ALREADY_EXISTS = FlightStatusCode.ALREADY_EXISTS.toStatus();
+ public static final CallStatus CANCELLED = FlightStatusCode.CANCELLED.toStatus();
+ public static final CallStatus UNAUTHENTICATED = FlightStatusCode.UNAUTHENTICATED.toStatus();
+ public static final CallStatus UNAUTHORIZED = FlightStatusCode.UNAUTHORIZED.toStatus();
+ public static final CallStatus UNIMPLEMENTED = FlightStatusCode.UNIMPLEMENTED.toStatus();
+ public static final CallStatus UNAVAILABLE = FlightStatusCode.UNAVAILABLE.toStatus();
+
+ /**
+ * Create a new status.
+ *
+ * @param code The status code.
+ * @param cause An exception that resulted in this status (or null).
+ * @param description A description of the status (or null).
+ */
+ public CallStatus(FlightStatusCode code, Throwable cause, String description) {
+ this.code = Objects.requireNonNull(code);
+ this.cause = cause;
+ this.description = description == null ? "" : description;
+ }
+
+ /**
+ * Create a new status with no cause or description.
+ *
+ * @param code The status code.
+ */
+ public CallStatus(FlightStatusCode code) {
+ this(code, /* no cause */ null, /* no description */ null);
+ }
+
+ /**
+ * The status code describing the result of the RPC.
+ */
+ public FlightStatusCode code() {
+ return code;
+ }
+
+ /**
+ * The exception that led to this result. May be null.
+ */
+ public Throwable cause() {
+ return cause;
+ }
+
+ /**
+ * A description of the result.
+ */
+ public String description() {
+ return description;
+ }
+
+ /**
+ * Return a copy of this status with an error message.
+ */
+ public CallStatus withDescription(String message) {
+ return new CallStatus(code, cause, message);
+ }
+
+ /**
+ * Return a copy of this status with the given exception as the cause. This will not be sent over the wire.
+ */
+ public CallStatus withCause(Throwable t) {
+ return new CallStatus(code, t, description);
+ }
+
+ /**
+ * Convert the status to an equivalent exception.
+ */
+ public FlightRuntimeException toRuntimeException() {
+ return new FlightRuntimeException(this);
+ }
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java
index 13a28f9..94928f9 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightBindingService.java
@@ -18,6 +18,7 @@
package org.apache.arrow.flight;
import java.util.Set;
+import java.util.concurrent.ExecutorService;
import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.flight.impl.Flight;
@@ -48,9 +49,9 @@ class FlightBindingService implements BindableService {
private final BufferAllocator allocator;
public FlightBindingService(BufferAllocator allocator, FlightProducer producer,
- ServerAuthHandler authHandler) {
+ ServerAuthHandler authHandler, ExecutorService executor) {
this.allocator = allocator;
- this.delegate = new FlightService(allocator, producer, authHandler);
+ this.delegate = new FlightService(allocator, producer, authHandler, executor);
}
public static MethodDescriptor<Flight.Ticket, ArrowMessage> getDoGetDescriptor(BufferAllocator allocator) {
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
index 9ac3686..4caaac1 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -21,7 +21,6 @@ import java.io.InputStream;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
import javax.net.ssl.SSLException;
@@ -30,6 +29,7 @@ import org.apache.arrow.flight.auth.BasicClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthInterceptor;
import org.apache.arrow.flight.auth.ClientAuthWrapper;
+import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.Empty;
import org.apache.arrow.flight.impl.FlightServiceGrpc;
@@ -43,12 +43,11 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvid
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterators;
import io.grpc.ClientCall;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
+import io.grpc.StatusRuntimeException;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.ClientCallStreamObserver;
@@ -96,18 +95,22 @@ public class FlightClient implements AutoCloseable {
* @return FlightInfo Iterable
*/
public Iterable<FlightInfo> listFlights(Criteria criteria, CallOption... options) {
- return ImmutableList.copyOf(CallOptions.wrapStub(blockingStub, options).listFlights(criteria.asCriteria()))
- .stream()
- .map(t -> {
- try {
- return new FlightInfo(t);
- } catch (URISyntaxException e) {
- // We don't expect this will happen for conforming Flight implementations. For instance, a Java server
- // itself wouldn't be able to construct an invalid Location.
- throw new RuntimeException(e);
- }
- })
- .collect(Collectors.toList());
+ final Iterator<Flight.FlightInfo> flights;
+ try {
+ flights = CallOptions.wrapStub(blockingStub, options)
+ .listFlights(criteria.asCriteria());
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ return () -> StatusUtils.wrapIterator(flights, t -> {
+ try {
+ return new FlightInfo(t);
+ } catch (URISyntaxException e) {
+ // We don't expect this will happen for conforming Flight implementations. For instance, a Java server
+ // itself wouldn't be able to construct an invalid Location.
+ throw new RuntimeException(e);
+ }
+ });
}
/**
@@ -116,11 +119,14 @@ public class FlightClient implements AutoCloseable {
* @param options RPC-layer hints for the call.
*/
public Iterable<ActionType> listActions(CallOption... options) {
- return ImmutableList.copyOf(CallOptions.wrapStub(blockingStub, options)
- .listActions(Empty.getDefaultInstance()))
- .stream()
- .map(ActionType::new)
- .collect(Collectors.toList());
+ final Iterator<Flight.ActionType> actions;
+ try {
+ actions = CallOptions.wrapStub(blockingStub, options)
+ .listActions(Empty.getDefaultInstance());
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ return () -> StatusUtils.wrapIterator(actions, ActionType::new);
}
/**
@@ -131,8 +137,8 @@ public class FlightClient implements AutoCloseable {
* @return An iterator of results.
*/
public Iterator<Result> doAction(Action action, CallOption... options) {
- return Iterators
- .transform(CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol()), Result::new);
+ return StatusUtils
+ .wrapIterator(CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol()), Result::new);
}
/**
@@ -183,16 +189,20 @@ public class FlightClient implements AutoCloseable {
Preconditions.checkNotNull(descriptor);
Preconditions.checkNotNull(root);
- SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener);
- final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
- ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>)
- ClientCalls.asyncBidiStreamingCall(
- authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver);
- // send the schema to start.
- DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext);
- return new PutObserver(new VectorUnloader(
- root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */),
- observer, metadataListener);
+ try {
+ SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener);
+ final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
+ ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>)
+ ClientCalls.asyncBidiStreamingCall(
+ authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver);
+ // send the schema to start.
+ DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext);
+ return new PutObserver(new VectorUnloader(
+ root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */),
+ observer, metadataListener);
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
}
/**
@@ -207,6 +217,8 @@ public class FlightClient implements AutoCloseable {
// We don't expect this will happen for conforming Flight implementations. For instance, a Java server
// itself wouldn't be able to construct an invalid Location.
throw new RuntimeException(e);
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
}
}
@@ -241,7 +253,7 @@ public class FlightClient implements AutoCloseable {
@Override
public void onError(Throwable t) {
- delegate.onError(t);
+ delegate.onError(StatusUtils.toGrpcException(t));
}
@Override
@@ -274,7 +286,7 @@ public class FlightClient implements AutoCloseable {
@Override
public void onError(Throwable t) {
- listener.onError(t);
+ listener.onError(StatusUtils.fromThrowable(t));
}
@Override
@@ -307,13 +319,17 @@ public class FlightClient implements AutoCloseable {
while (!observer.isReady()) {
/* busy wait */
}
- // Takes ownership of appMetadata
- observer.onNext(new ArrowMessage(batch, appMetadata));
+ try {
+ // Takes ownership of appMetadata
+ observer.onNext(new ArrowMessage(batch, appMetadata));
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
}
@Override
public void error(Throwable ex) {
- observer.onError(ex);
+ observer.onError(StatusUtils.toGrpcException(ex));
}
@Override
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java
similarity index 54%
copy from java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java
copy to java/flight/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java
index 845de2e..5ac8176 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java
@@ -15,35 +15,26 @@
* limitations under the License.
*/
-package org.apache.arrow.flight.auth;
-
-import java.util.Iterator;
+package org.apache.arrow.flight;
/**
- * Implement authentication for Flight on the client side.
+ * An exception raised from a Flight RPC.
+ *
+ * <p>In service implementations, raising an instance of this exception will provide clients with a more detailed
+ * message and error code.
*/
-public interface ClientAuthHandler {
- /**
- * Handle the initial handshake with the server.
- * @param outgoing A channel to send data to the server.
- * @param incoming An iterator of incoming data from the server.
- */
- void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming);
+public class FlightRuntimeException extends RuntimeException {
+ private final CallStatus status;
/**
- * Get the per-call authentication token.
+ * Create a new exception from the given status.
*/
- byte[] getCallToken();
-
- /**
- * A communication channel to the server during initial connection.
- */
- interface ClientAuthSender {
-
- void send(byte[] payload);
-
- void onError(String message, Throwable cause);
-
+ FlightRuntimeException(CallStatus status) {
+ super(status.description(), status.cause());
+ this.status = status;
}
+ public CallStatus status() {
+ return status;
+ }
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
index 3f02dd5..8ec665c 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -25,8 +25,8 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;
-import java.util.concurrent.Executor;
-import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.arrow.flight.auth.ServerAuthHandler;
@@ -147,7 +147,7 @@ public class FlightServer implements AutoCloseable {
private FlightProducer producer;
private final Map<String, Object> builderOptions;
private ServerAuthHandler authHandler = ServerAuthHandler.NO_OP;
- private Executor executor = null;
+ private ExecutorService executor = null;
private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE;
private InputStream certChain;
private InputStream key;
@@ -214,12 +214,14 @@ public class FlightServer implements AutoCloseable {
builder.useTransportSecurity(certChain, key);
}
+ // Share one executor between the gRPC service, DoPut, and Handshake
+ final ExecutorService exec = executor != null ? executor : Executors.newCachedThreadPool();
builder
- .executor(executor != null ? executor : new ForkJoinPool())
+ .executor(exec)
.maxInboundMessageSize(maxInboundMessageSize)
.addService(
ServerInterceptors.intercept(
- new FlightBindingService(allocator, producer, authHandler),
+ new FlightBindingService(allocator, producer, authHandler, exec),
new ServerAuthInterceptor(authHandler)));
// Allow setting some Netty-specific options
@@ -268,7 +270,7 @@ public class FlightServer implements AutoCloseable {
/**
* Set the executor used by the server.
*/
- public Builder executor(Executor executor) {
+ public Builder executor(ExecutorService executor) {
this.executor = executor;
return this;
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
index e805917..bd83cc6 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java
@@ -18,13 +18,13 @@
package org.apache.arrow.flight;
import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
import java.util.function.BooleanSupplier;
import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
import org.apache.arrow.flight.auth.AuthConstants;
import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.flight.auth.ServerAuthWrapper;
+import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.ActionType;
import org.apache.arrow.flight.impl.Flight.Empty;
@@ -57,12 +57,14 @@ class FlightService extends FlightServiceImplBase {
private final BufferAllocator allocator;
private final FlightProducer producer;
private final ServerAuthHandler authHandler;
- private final ExecutorService executors = Executors.newCachedThreadPool();
+ private final ExecutorService executors;
- public FlightService(BufferAllocator allocator, FlightProducer producer, ServerAuthHandler authHandler) {
+ FlightService(BufferAllocator allocator, FlightProducer producer, ServerAuthHandler authHandler,
+ ExecutorService executors) {
this.allocator = allocator;
this.producer = producer;
this.authHandler = authHandler;
+ this.executors = executors;
}
@Override
@@ -76,7 +78,7 @@ class FlightService extends FlightServiceImplBase {
producer.listFlights(makeContext((ServerCallStreamObserver<?>) responseObserver), new Criteria(criteria),
StreamPipe.wrap(responseObserver, FlightInfo::toProtocol));
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
@@ -90,7 +92,7 @@ class FlightService extends FlightServiceImplBase {
producer.getStream(makeContext((ServerCallStreamObserver<?>) responseObserver), new Ticket(ticket),
new GetListener(responseObserver));
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
@@ -100,7 +102,7 @@ class FlightService extends FlightServiceImplBase {
producer.doAction(makeContext((ServerCallStreamObserver<?>) responseObserver), new Action(request),
StreamPipe.wrap(responseObserver, Result::toProtocol));
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
@@ -110,7 +112,7 @@ class FlightService extends FlightServiceImplBase {
producer.listActions(makeContext((ServerCallStreamObserver<?>) responseObserver),
StreamPipe.wrap(responseObserver, t -> t.toProtocol()));
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
@@ -164,7 +166,7 @@ class FlightService extends FlightServiceImplBase {
@Override
public void error(Throwable ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
@Override
@@ -192,7 +194,7 @@ class FlightService extends FlightServiceImplBase {
responseObserver.onCompleted();
} catch (Exception ex) {
logger.error("Failed to process custom put.", ex);
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
// The client may have terminated, so the exception here is effectively swallowed.
// Log the error as well so -something- makes it to the developer.
logger.error("Exception handling DoPut", ex);
@@ -216,7 +218,7 @@ class FlightService extends FlightServiceImplBase {
responseObserver.onNext(info.toProtocol());
responseObserver.onCompleted();
} catch (Exception ex) {
- responseObserver.onError(ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightStatusCode.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightStatusCode.java
new file mode 100644
index 0000000..3f7f78d
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightStatusCode.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * A status code describing the result of a Flight call.
+ */
+public enum FlightStatusCode {
+ /**
+ * An unknown error occurred. This may also be the result of an implementation error on the server-side; by default,
+ * unhandled server exceptions result in this code.
+ */
+ UNKNOWN,
+ /**
+ * An internal/implementation error occurred.
+ */
+ INTERNAL,
+ /**
+ * One or more of the given arguments was invalid.
+ */
+ INVALID_ARGUMENT,
+ /**
+ * The operation timed out.
+ */
+ TIMED_OUT,
+ /**
+ * The operation describes a resource that does not exist.
+ */
+ NOT_FOUND,
+ /**
+ * The operation creates a resource that already exists.
+ */
+ ALREADY_EXISTS,
+ /**
+ * The operation was cancelled.
+ */
+ CANCELLED,
+ /**
+ * The client was not authenticated.
+ */
+ UNAUTHENTICATED,
+ /**
+ * The client did not have permission to make the call.
+ */
+ UNAUTHORIZED,
+ /**
+ * The requested operation is not implemented.
+ */
+ UNIMPLEMENTED,
+ /**
+ * The server cannot currently handle the request. This should be used for retriable requests, i.e. the server
+ * should send this code only if it has not done any work.
+ */
+ UNAVAILABLE,
+ ;
+
+ /**
+ * Create a blank {@link CallStatus} with this code.
+ */
+ public CallStatus toStatus() {
+ return new CallStatus(this);
+ }
+}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java b/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
index eca32e1..d1432f5 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
@@ -25,37 +25,37 @@ public class NoOpFlightProducer implements FlightProducer {
@Override
public void getStream(CallContext context, Ticket ticket,
ServerStreamListener listener) {
- listener.error(new UnsupportedOperationException("NYI"));
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
}
@Override
public void listFlights(CallContext context, Criteria criteria,
StreamListener<FlightInfo> listener) {
- listener.onError(new UnsupportedOperationException("NYI"));
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
}
@Override
public FlightInfo getFlightInfo(CallContext context,
FlightDescriptor descriptor) {
- throw new UnsupportedOperationException("NYI");
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
}
@Override
public Runnable acceptPut(CallContext context,
FlightStream flightStream, StreamListener<PutResult> ackStream) {
- throw new UnsupportedOperationException("NYI");
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
}
@Override
public void doAction(CallContext context, Action action,
StreamListener<Result> listener) {
- throw new UnsupportedOperationException("NYI");
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
}
@Override
public void listActions(CallContext context,
StreamListener<ActionType> listener) {
- listener.onError(new UnsupportedOperationException("NYI"));
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/StreamPipe.java b/java/flight/src/main/java/org/apache/arrow/flight/StreamPipe.java
index 3563277..d107ec5 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/StreamPipe.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/StreamPipe.java
@@ -20,6 +20,7 @@ package org.apache.arrow.flight;
import java.util.function.Function;
import org.apache.arrow.flight.FlightProducer.StreamListener;
+import org.apache.arrow.flight.grpc.StatusUtils;
import io.grpc.stub.StreamObserver;
@@ -51,7 +52,7 @@ class StreamPipe<FROM, TO> implements StreamListener<FROM> {
@Override
public void onError(Throwable t) {
- delegate.onError(t);
+ delegate.onError(StatusUtils.toGrpcException(t));
}
@Override
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java b/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java
index f1246a1..690e774 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/SyncPutListener.java
@@ -22,6 +22,8 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
+import org.apache.arrow.flight.grpc.StatusUtils;
+
import io.netty.buffer.ArrowBuf;
/**
@@ -93,7 +95,7 @@ public final class SyncPutListener implements FlightClient.PutListener, AutoClos
@Override
public void onError(Throwable t) {
- completed.completeExceptionally(t);
+ completed.completeExceptionally(StatusUtils.fromThrowable(t));
queue.add(DONE_WITH_EXCEPTION);
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java
index 845de2e..985e10a 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java
@@ -40,9 +40,15 @@ public interface ClientAuthHandler {
*/
interface ClientAuthSender {
+ /**
+ * Send the server a message.
+ */
void send(byte[] payload);
- void onError(String message, Throwable cause);
+ /**
+ * Signal an error to the server and abort the authentication attempt.
+ */
+ void onError(Throwable cause);
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
index 9b80340..e86dc16 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
@@ -18,12 +18,12 @@
package org.apache.arrow.flight.auth;
import java.util.Iterator;
-import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import org.apache.arrow.flight.auth.ClientAuthHandler.ClientAuthSender;
+import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
@@ -46,18 +46,24 @@ public class ClientAuthWrapper {
*/
public static void doClientAuth(ClientAuthHandler authHandler, FlightServiceStub stub) {
AuthObserver observer = new AuthObserver();
- observer.responseObserver = stub.handshake(observer);
- authHandler.authenticate(observer.sender, observer.iter);
- if (!observer.sender.errored) {
- observer.responseObserver.onCompleted();
+ try {
+ observer.responseObserver = stub.handshake(observer);
+ authHandler.authenticate(observer.sender, observer.iter);
+ if (!observer.sender.errored) {
+ observer.responseObserver.onCompleted();
+ }
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
}
try {
if (!observer.completed.get()) {
// TODO: ARROW-5681
throw new RuntimeException("Unauthenticated");
}
- } catch (InterruptedException | ExecutionException e) {
+ } catch (InterruptedException e) {
throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+ throw StatusUtils.fromThrowable(e.getCause());
}
}
@@ -130,16 +136,19 @@ public class ClientAuthWrapper {
@Override
public void send(byte[] payload) {
- responseObserver.onNext(HandshakeRequest.newBuilder()
- .setPayload(ByteString.copyFrom(payload))
- .build());
+ try {
+ responseObserver.onNext(HandshakeRequest.newBuilder()
+ .setPayload(ByteString.copyFrom(payload))
+ .build());
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
}
@Override
- public void onError(String message, Throwable cause) {
+ public void onError(Throwable cause) {
this.errored = true;
- Objects.requireNonNull(cause);
- responseObserver.onError(cause);
+ responseObserver.onError(StatusUtils.toGrpcException(cause));
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
index 0507d3b..5195959 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
@@ -50,7 +50,7 @@ public interface ServerAuthHandler {
void send(byte[] payload);
- void onError(String message, Throwable cause);
+ void onError(Throwable cause);
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
index f38dee7..4ebd742 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
@@ -45,7 +45,7 @@ public class ServerAuthInterceptor implements ServerInterceptor {
if (!call.getMethodDescriptor().getFullMethodName().equals(AuthConstants.HANDSHAKE_DESCRIPTOR_NAME)) {
final Optional<String> peerIdentity = isValid(headers);
if (!peerIdentity.isPresent()) {
- call.close(Status.PERMISSION_DENIED, new Metadata());
+ call.close(Status.UNAUTHENTICATED, new Metadata());
// TODO: we should actually terminate here instead of causing an exception below.
return new NoopServerCallListener<>();
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
index a3c698b..6678aea 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
@@ -22,19 +22,24 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
+import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.auth.ServerAuthHandler.ServerAuthSender;
+import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import com.google.protobuf.ByteString;
-import io.grpc.Status;
import io.grpc.stub.StreamObserver;
/**
* Contains utility methods for integrating authorization into a GRPC stream.
*/
public class ServerAuthWrapper {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ServerAuthWrapper.class);
/**
* Wrap the auth handler for handshake purposes.
@@ -56,10 +61,10 @@ public class ServerAuthWrapper {
return;
}
- responseObserver.onError(Status.PERMISSION_DENIED.asException());
+ responseObserver.onError(StatusUtils.toGrpcException(CallStatus.UNAUTHENTICATED.toRuntimeException()));
} catch (Exception ex) {
- ex.printStackTrace();
- responseObserver.onError(ex);
+ LOGGER.error("Error during authentication", ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
}
};
observer.future = executors.submit(r);
@@ -130,8 +135,8 @@ public class ServerAuthWrapper {
}
@Override
- public void onError(String message, Throwable cause) {
- responseObserver.onError(cause);
+ public void onError(Throwable cause) {
+ responseObserver.onError(StatusUtils.toGrpcException(cause));
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
index 59324b3..5508399 100644
--- a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
+++ b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
@@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentMap;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.ActionType;
+import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
@@ -143,7 +144,7 @@ public class InMemoryStore implements FlightProducer, AutoCloseable {
break;
}
default: {
- listener.onError(new UnsupportedOperationException());
+ listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException());
}
}
}
diff --git a/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java b/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
new file mode 100644
index 0000000..062b9ab
--- /dev/null
+++ b/java/flight/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.Iterator;
+import java.util.Objects;
+import java.util.function.Function;
+
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.FlightStatusCode;
+
+import io.grpc.Status;
+import io.grpc.Status.Code;
+import io.grpc.StatusException;
+import io.grpc.StatusRuntimeException;
+
+/**
+ * Utilities to adapt gRPC and Flight status objects.
+ *
+ * <p>NOT A PUBLIC CLASS, interface is not guaranteed to remain stable.
+ */
+public class StatusUtils {
+
+ private StatusUtils() {
+ throw new AssertionError("Do not instantiate this class.");
+ }
+
+ /**
+ * Convert from a Flight status code to a gRPC status code.
+ */
+ public static Status.Code toGrpcStatusCode(FlightStatusCode code) {
+ switch (code) {
+ case UNKNOWN:
+ return Code.UNKNOWN;
+ case INTERNAL:
+ return Code.INTERNAL;
+ case INVALID_ARGUMENT:
+ return Code.INVALID_ARGUMENT;
+ case TIMED_OUT:
+ return Code.DEADLINE_EXCEEDED;
+ case NOT_FOUND:
+ return Code.NOT_FOUND;
+ case ALREADY_EXISTS:
+ return Code.ALREADY_EXISTS;
+ case CANCELLED:
+ return Code.CANCELLED;
+ case UNAUTHENTICATED:
+ return Code.UNAUTHENTICATED;
+ case UNAUTHORIZED:
+ return Code.PERMISSION_DENIED;
+ case UNIMPLEMENTED:
+ return Code.UNIMPLEMENTED;
+ case UNAVAILABLE:
+ return Code.UNAVAILABLE;
+ default:
+ return Code.UNKNOWN;
+ }
+ }
+
+ /**
+ * Convert from a gRPC status code to a Flight status code.
+ */
+ public static FlightStatusCode fromGrpcStatusCode(Status.Code code) {
+ switch (code) {
+ case CANCELLED:
+ return FlightStatusCode.CANCELLED;
+ case UNKNOWN:
+ return FlightStatusCode.UNKNOWN;
+ case INVALID_ARGUMENT:
+ return FlightStatusCode.INVALID_ARGUMENT;
+ case DEADLINE_EXCEEDED:
+ return FlightStatusCode.TIMED_OUT;
+ case NOT_FOUND:
+ return FlightStatusCode.NOT_FOUND;
+ case ALREADY_EXISTS:
+ return FlightStatusCode.ALREADY_EXISTS;
+ case PERMISSION_DENIED:
+ return FlightStatusCode.UNAUTHORIZED;
+ case RESOURCE_EXHAUSTED:
+ return FlightStatusCode.INVALID_ARGUMENT;
+ case FAILED_PRECONDITION:
+ return FlightStatusCode.INVALID_ARGUMENT;
+ case ABORTED:
+ return FlightStatusCode.INTERNAL;
+ case OUT_OF_RANGE:
+ return FlightStatusCode.INVALID_ARGUMENT;
+ case UNIMPLEMENTED:
+ return FlightStatusCode.UNIMPLEMENTED;
+ case INTERNAL:
+ return FlightStatusCode.INTERNAL;
+ case UNAVAILABLE:
+ return FlightStatusCode.UNAVAILABLE;
+ case DATA_LOSS:
+ return FlightStatusCode.INTERNAL;
+ case UNAUTHENTICATED:
+ return FlightStatusCode.UNAUTHENTICATED;
+ default:
+ return FlightStatusCode.UNKNOWN;
+ }
+ }
+
+ /** Convert from a gRPC status to a Flight status. */
+ public static CallStatus fromGrpcStatus(Status status) {
+ return new CallStatus(fromGrpcStatusCode(status.getCode()), status.getCause(), status.getDescription());
+ }
+
+ /** Convert from a Flight status to a gRPC status. */
+ public static Status toGrpcStatus(CallStatus status) {
+ return toGrpcStatusCode(status.code()).toStatus().withDescription(status.description()).withCause(status.cause());
+ }
+
+ /** Convert from a gRPC exception to a Flight exception. */
+ public static FlightRuntimeException fromGrpcRuntimeException(StatusRuntimeException sre) {
+ return fromGrpcStatus(sre.getStatus()).toRuntimeException();
+ }
+
+ /**
+ * Convert arbitrary exceptions to a {@link FlightRuntimeException}.
+ */
+ public static FlightRuntimeException fromThrowable(Throwable t) {
+ if (t instanceof StatusRuntimeException) {
+ return fromGrpcRuntimeException((StatusRuntimeException) t);
+ } else if (t instanceof FlightRuntimeException) {
+ return (FlightRuntimeException) t;
+ }
+ return CallStatus.UNKNOWN.withCause(t).withDescription(t.getMessage()).toRuntimeException();
+ }
+
+ /**
+ * Convert arbitrary exceptions to a {@link StatusRuntimeException} or {@link StatusException}.
+ *
+ * <p>Such exceptions can be passed to {@link io.grpc.stub.StreamObserver#onError(Throwable)} and will give the client
+ * a reasonable error message.
+ */
+ public static Throwable toGrpcException(Throwable ex) {
+ if (ex instanceof StatusRuntimeException) {
+ return ex;
+ } else if (ex instanceof StatusException) {
+ return ex;
+ } else if (ex instanceof FlightRuntimeException) {
+ final FlightRuntimeException fre = (FlightRuntimeException) ex;
+ return toGrpcStatus(fre.status()).asRuntimeException();
+ }
+ return Status.INTERNAL.withCause(ex).withDescription("There was an error servicing your request.")
+ .asRuntimeException();
+ }
+
+ /**
+ * Maps a transformation function to the elements of an iterator, while wrapping exceptions in {@link
+ * FlightRuntimeException}.
+ */
+ public static <FROM, TO> Iterator<TO> wrapIterator(Iterator<FROM> fromIterator,
+ Function<? super FROM, ? extends TO> transformer) {
+ Objects.requireNonNull(fromIterator);
+ Objects.requireNonNull(transformer);
+ return new Iterator<TO>() {
+ @Override
+ public boolean hasNext() {
+ try {
+ return fromIterator.hasNext();
+ } catch (StatusRuntimeException e) {
+ throw fromGrpcRuntimeException(e);
+ }
+ }
+
+ @Override
+ public TO next() {
+ try {
+ return transformer.apply(fromIterator.next());
+ } catch (StatusRuntimeException e) {
+ throw fromGrpcRuntimeException(e);
+ }
+ }
+ };
+ }
+}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
index a10d490..cd043b6 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
@@ -28,6 +28,10 @@ import java.util.Objects;
import java.util.Random;
import java.util.function.Function;
+import org.junit.Assert;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.function.Executable;
+
/**
* Utility methods and constants for testing flight servers.
*/
@@ -118,6 +122,18 @@ public class FlightTestUtil {
return isEpollAvailable() || isKqueueAvailable();
}
+  /**
+   * Run the given code and check that it fails with a {@link FlightRuntimeException} that carries the
+   * expected status code.
+   *
+   * @param code The Flight status code the failure must carry.
+   * @param r The code that is expected to throw.
+   * @return The status taken from the thrown exception.
+   */
+  public static CallStatus assertCode(FlightStatusCode code, Executable r) {
+    final CallStatus thrown = Assertions.assertThrows(FlightRuntimeException.class, r).status();
+    Assert.assertEquals(code, thrown.code());
+    return thrown;
+  }
+
public static class CertKeyPair {
public final File cert;
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java b/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
index ad2c58f..e19bff0 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
@@ -36,7 +36,6 @@ import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
-import io.grpc.Status;
import io.netty.buffer.ArrowBuf;
/**
@@ -225,9 +224,9 @@ public class TestApplicationMetadata {
while (stream.next()) {
final ArrowBuf metadata = stream.getLatestMetadata();
if (current != metadata.getByte(0)) {
- ackStream.onError(Status.INVALID_ARGUMENT.withDescription(String
+ ackStream.onError(CallStatus.INVALID_ARGUMENT.withDescription(String
.format("Metadata does not match expected value; got %d but expected %d.", metadata.getByte(0),
- current)).asRuntimeException());
+ current)).toRuntimeException());
return;
}
ackStream.onNext(PutResult.metadata(metadata));
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
index bfaf660..161a9c4 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestAuth.java
@@ -66,7 +66,7 @@ public class TestAuth {
outgoing.send(new byte[0]);
// Ensure the server-side runs
incoming.next();
- outgoing.onError("test", new RuntimeException("test"));
+ outgoing.onError(new RuntimeException("test"));
}
@Override
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
index c764502..634e38f 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -38,6 +38,7 @@ import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
+
import org.junit.Assert;
import org.junit.Test;
@@ -194,6 +195,14 @@ public class TestBasicOperation {
});
}
+  /**
+   * Check that calling the action "invalid-action" fails on the client with a Flight UNIMPLEMENTED status,
+   * and that no result is delivered before the failure.
+   */
+  @Test
+  public void propagateErrors() throws Exception {
+    test(flightClient -> {
+      FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> {
+        flightClient.doAction(new Action("invalid-action")).forEachRemaining(ignored -> Assert.fail());
+      });
+    });
+  }
@Test
public void getStream() throws Exception {
@@ -345,7 +354,8 @@ public class TestBasicOperation {
break;
}
default:
- listener.onError(new UnsupportedOperationException());
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Action not implemented: " + action.getType())
+ .toRuntimeException());
}
}
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
index b9d4dea..255a00b 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/TestTls.java
@@ -57,13 +57,12 @@ public class TestTls {
/**
* Make sure that connections are rejected when the root certificate isn't trusted.
*/
- @Test(expected = io.grpc.StatusRuntimeException.class)
+ @Test
public void rejectInvalidCert() {
test((builder) -> {
try (final FlightClient client = builder.build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
- responses.next().getBody();
- Assert.fail("Call should have failed");
+ FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> responses.next().getBody());
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
@@ -73,15 +72,14 @@ public class TestTls {
/**
* Make sure that connections are rejected when the hostname doesn't match.
*/
- @Test(expected = io.grpc.StatusRuntimeException.class)
+ @Test
public void rejectHostname() {
test((builder) -> {
try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
final FlightClient client = builder.trustedCertificates(roots).overrideHostname("fakehostname")
.build()) {
final Iterator<Result> responses = client.doAction(new Action("hello-world"));
- responses.next().getBody();
- Assert.fail("Call should have failed");
+ FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> responses.next().getBody());
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
@@ -119,8 +117,10 @@ public class TestTls {
if (action.getType().equals("hello-world")) {
listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8)));
listener.onCompleted();
+ return;
}
- listener.onError(new UnsupportedOperationException("Invalid action " + action.getType()));
+ listener
+ .onError(CallStatus.UNIMPLEMENTED.withDescription("Invalid action " + action.getType()).toRuntimeException());
}
@Override
diff --git a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
index 9fe6b04..5a2cae6 100644
--- a/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
+++ b/java/flight/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
@@ -17,8 +17,6 @@
package org.apache.arrow.flight.auth;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-
import java.io.IOException;
import java.util.Arrays;
import java.util.Optional;
@@ -27,6 +25,7 @@ import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.NoOpFlightProducer;
@@ -46,10 +45,7 @@ import org.junit.Test;
import com.google.common.collect.ImmutableList;
-import io.grpc.StatusRuntimeException;
-
public class TestBasicAuth {
- final String PERMISSION_DENIED = "PERMISSION_DENIED";
private static final String USERNAME = "flight";
private static final String PASSWORD = "woohoo";
@@ -79,20 +75,20 @@ public class TestBasicAuth {
@Test
public void invalidAuth() {
- assertThrows(StatusRuntimeException.class, () -> {
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> {
client.authenticateBasic(USERNAME, "WRONG");
- }, PERMISSION_DENIED);
+ });
- assertThrows(StatusRuntimeException.class, () -> {
- client.listFlights(Criteria.ALL);
- }, PERMISSION_DENIED);
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> {
+ client.listFlights(Criteria.ALL).forEach(action -> Assert.fail());
+ });
}
@Test
public void didntAuth() {
- assertThrows(StatusRuntimeException.class, () -> {
- client.listFlights(Criteria.ALL);
- }, PERMISSION_DENIED);
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> {
+ client.listFlights(Criteria.ALL).forEach(action -> Assert.fail());
+ });
}
@Before
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 8c51c14..8502471 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -28,6 +28,7 @@ from cython.operator cimport dereference as deref
from pyarrow.compat import frombytes, tobytes
from pyarrow.lib cimport *
+from pyarrow.lib import ArrowException
from pyarrow.lib import as_buffer
from pyarrow.includes.libarrow_flight cimport *
from pyarrow.ipc import _ReadPandasOption
@@ -37,6 +38,32 @@ import pyarrow.lib as lib
cdef CFlightCallOptions DEFAULT_CALL_OPTIONS
+# Check an Arrow status and raise a Flight-specific exception where one applies.
+#
+# Returns 0 when the status is OK. Otherwise the FlightStatusDetail attached
+# to the status (if any) is inspected and a matching Flight*Error is raised
+# with the status message. Statuses that are not handled here are passed on
+# to check_status(). The function is declared `except -1`, so a raised
+# exception is reported to Cython callers through that return value.
+cdef int check_flight_status(const CStatus& status) nogil except -1:
+ cdef shared_ptr[FlightStatusDetail] detail
+
+ if status.ok():
+ return 0
+
+ detail = FlightStatusDetail.UnwrapStatus(status)
+ if detail:
+ # The function is declared nogil; the GIL is taken here before Python
+ # objects (the decoded message, the exception) are created.
+ with gil:
+ message = frombytes(status.message())
+ if detail.get().code() == CFlightStatusInternal:
+ raise FlightInternalError(message)
+ elif detail.get().code() == CFlightStatusTimedOut:
+ raise FlightTimedOutError(message)
+ elif detail.get().code() == CFlightStatusCancelled:
+ raise FlightCancelledError(message)
+ elif detail.get().code() == CFlightStatusUnauthenticated:
+ raise FlightUnauthenticatedError(message)
+ elif detail.get().code() == CFlightStatusUnauthorized:
+ raise FlightUnauthorizedError(message)
+ elif detail.get().code() == CFlightStatusUnavailable:
+ raise FlightUnavailableError(message)
+
+ # NOTE(review): indentation was lost in this archive; presumably this line
+ # is at function level, so it also handles a detail whose code matched none
+ # of the branches above - confirm against the real file.
+ return check_status(status)
+
+
cdef class FlightCallOptions:
"""RPC-layer options for a Flight call."""
@@ -73,6 +100,45 @@ class CertKeyPair(_CertKeyPair):
"""A TLS certificate and key for use in Flight."""
+cdef class FlightError(Exception):
+ """Base class of the Flight-specific exceptions defined in this module."""
+ cdef dict __dict__
+
+ # Convert this exception to an Arrow status. This base implementation
+ # builds the status with CStatus_UnknownError and prefixes str(self) with
+ # "Flight error: "; the subclasses in this module override it.
+ cdef CStatus to_status(self):
+ message = tobytes("Flight error: {}".format(str(self)))
+ return CStatus_UnknownError(message)
+
+
+cdef class FlightInternalError(FlightError, ArrowException):
+ """Flight error associated with the Flight 'Internal' status code."""
+ # Overrides FlightError.to_status: the status is built by MakeFlightError
+ # from CFlightStatusInternal and str(self).
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusInternal, tobytes(str(self)))
+
+
+cdef class FlightTimedOutError(FlightError, ArrowException):
+ """Flight error associated with the Flight 'TimedOut' status code."""
+ # Overrides FlightError.to_status: the status is built by MakeFlightError
+ # from CFlightStatusTimedOut and str(self).
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusTimedOut, tobytes(str(self)))
+
+
+cdef class FlightCancelledError(FlightError, ArrowException):
+ """Flight error associated with the Flight 'Cancelled' status code."""
+ # Overrides FlightError.to_status: the status is built by MakeFlightError
+ # from CFlightStatusCancelled and str(self).
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusCancelled, tobytes(str(self)))
+
+
+cdef class FlightUnauthenticatedError(FlightError, ArrowException):
+ """Flight error associated with the Flight 'Unauthenticated' status code."""
+ # Overrides FlightError.to_status: the status is built by MakeFlightError
+ # from CFlightStatusUnauthenticated and str(self).
+ cdef CStatus to_status(self):
+ return MakeFlightError(
+ CFlightStatusUnauthenticated, tobytes(str(self)))
+
+
+cdef class FlightUnauthorizedError(FlightError, ArrowException):
+ """Flight error associated with the Flight 'Unauthorized' status code."""
+ # Overrides FlightError.to_status: the status is built by MakeFlightError
+ # from CFlightStatusUnauthorized and str(self).
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusUnauthorized, tobytes(str(self)))
+
+
+cdef class FlightUnavailableError(FlightError, ArrowException):
+ """Flight error associated with the Flight 'Unavailable' status code."""
+ # Overrides FlightError.to_status: the status is built by MakeFlightError
+ # from CFlightStatusUnavailable and str(self).
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusUnavailable, tobytes(str(self)))
+
+
cdef class Action:
"""An action executable on a Flight service."""
cdef:
@@ -245,7 +311,7 @@ cdef class FlightDescriptor:
"""
cdef c_string out
- check_status(self.descriptor.SerializeToString(&out))
+ check_flight_status(self.descriptor.SerializeToString(&out))
return out
@classmethod
@@ -258,7 +324,7 @@ cdef class FlightDescriptor:
"""
cdef FlightDescriptor descriptor = \
FlightDescriptor.__new__(FlightDescriptor)
- check_status(CFlightDescriptor.Deserialize(
+ check_flight_status(CFlightDescriptor.Deserialize(
tobytes(serialized), &descriptor.descriptor))
return descriptor
@@ -287,7 +353,7 @@ cdef class Ticket:
"""
cdef c_string out
- check_status(self.ticket.SerializeToString(&out))
+ check_flight_status(self.ticket.SerializeToString(&out))
return out
@classmethod
@@ -301,7 +367,8 @@ cdef class Ticket:
cdef:
CTicket c_ticket
Ticket ticket
- check_status(CTicket.Deserialize(tobytes(serialized), &c_ticket))
+ check_flight_status(
+ CTicket.Deserialize(tobytes(serialized), &c_ticket))
ticket = Ticket.__new__(Ticket)
ticket.ticket = c_ticket
return ticket
@@ -319,7 +386,7 @@ cdef class Location:
CLocation location
def __init__(self, uri):
- check_status(CLocation.Parse(tobytes(uri), &self.location))
+ check_flight_status(CLocation.Parse(tobytes(uri), &self.location))
def __repr__(self):
return '<Location {}>'.format(self.location.ToString())
@@ -343,7 +410,8 @@ cdef class Location:
c_string c_host = tobytes(host)
int c_port = port
Location result = Location.__new__(Location)
- check_status(CLocation.ForGrpcTcp(c_host, c_port, &result.location))
+ check_flight_status(
+ CLocation.ForGrpcTcp(c_host, c_port, &result.location))
return result
@staticmethod
@@ -353,7 +421,8 @@ cdef class Location:
c_string c_host = tobytes(host)
int c_port = port
Location result = Location.__new__(Location)
- check_status(CLocation.ForGrpcTls(c_host, c_port, &result.location))
+ check_flight_status(
+ CLocation.ForGrpcTls(c_host, c_port, &result.location))
return result
@staticmethod
@@ -362,7 +431,7 @@ cdef class Location:
cdef:
c_string c_path = tobytes(path)
Location result = Location.__new__(Location)
- check_status(CLocation.ForGrpcUnix(c_path, &result.location))
+ check_flight_status(CLocation.ForGrpcUnix(c_path, &result.location))
return result
@staticmethod
@@ -375,7 +444,8 @@ cdef class Location:
cdef CLocation unwrap(object location) except *:
cdef CLocation c_location
if isinstance(location, six.text_type):
- check_status(CLocation.Parse(tobytes(location), &c_location))
+ check_flight_status(
+ CLocation.Parse(tobytes(location), &c_location))
return c_location
elif not isinstance(location, Location):
raise TypeError("Must provide a Location, not '{}'".format(
@@ -416,7 +486,8 @@ cdef class FlightEndpoint:
c_location = (<Location> location).location
else:
c_location = CLocation()
- check_status(CLocation.Parse(tobytes(location), &c_location))
+ check_flight_status(
+ CLocation.Parse(tobytes(location), &c_location))
self.endpoint.locations.push_back(c_location)
@property
@@ -470,11 +541,11 @@ cdef class FlightInfo:
raise TypeError('Endpoint {} is not instance of'
' FlightEndpoint'.format(endpoint))
- check_status(CreateFlightInfo(c_schema,
- descriptor.descriptor,
- c_endpoints,
- total_records,
- total_bytes, &self.info))
+ check_flight_status(CreateFlightInfo(c_schema,
+ descriptor.descriptor,
+ c_endpoints,
+ total_records,
+ total_bytes, &self.info))
@property
def total_records(self):
@@ -493,7 +564,7 @@ cdef class FlightInfo:
shared_ptr[CSchema] schema
CDictionaryMemo dummy_memo
- check_status(self.info.get().GetSchema(&dummy_memo, &schema))
+ check_flight_status(self.info.get().GetSchema(&dummy_memo, &schema))
return pyarrow_wrap_schema(schema)
@property
@@ -527,7 +598,7 @@ cdef class FlightInfo:
"""
cdef c_string out
- check_status(self.info.get().SerializeToString(&out))
+ check_flight_status(self.info.get().SerializeToString(&out))
return out
@classmethod
@@ -539,7 +610,7 @@ cdef class FlightInfo:
"""
cdef FlightInfo info = FlightInfo.__new__(FlightInfo)
- check_status(CFlightInfo.Deserialize(
+ check_flight_status(CFlightInfo.Deserialize(
tobytes(serialized), &info.info))
return info
@@ -591,7 +662,7 @@ cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader,
cdef:
shared_ptr[CTable] c_table
with nogil:
- check_status(self.reader.get().ReadAll(&c_table))
+ check_flight_status(self.reader.get().ReadAll(&c_table))
return pyarrow_wrap_table(c_table)
def read_chunk(self):
@@ -614,7 +685,7 @@ cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader,
FlightStreamChunk chunk = FlightStreamChunk()
with nogil:
- check_status(self.reader.get().Next(&chunk.chunk))
+ check_flight_status(self.reader.get().Next(&chunk.chunk))
if chunk.chunk.data == NULL:
raise StopIteration
@@ -647,7 +718,7 @@ cdef class FlightStreamWriter(_CRecordBatchWriter):
"""
cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf))
with nogil:
- check_status(
+ check_flight_status(
(<CFlightStreamWriter*> self.writer.get())
.WriteWithMetadata(deref(batch.batch),
c_buf,
@@ -664,7 +735,7 @@ cdef class FlightMetadataReader:
"""Read the next metadata message."""
cdef shared_ptr[CBuffer] buf
with nogil:
- check_status(self.reader.get().ReadMetadata(&buf))
+ check_flight_status(self.reader.get().ReadMetadata(&buf))
if buf == NULL:
return None
return pyarrow_wrap_buffer(buf)
@@ -686,7 +757,7 @@ cdef class FlightMetadataWriter:
cdef shared_ptr[CBuffer] buf = \
pyarrow_unwrap_buffer(as_buffer(message))
with nogil:
- check_status(self.writer.get().WriteMetadata(deref(buf)))
+ check_flight_status(self.writer.get().WriteMetadata(deref(buf)))
cdef class FlightClient:
@@ -726,8 +797,8 @@ cdef class FlightClient:
c_options.override_hostname = tobytes(override_hostname)
with nogil:
- check_status(CFlightClient.Connect(c_location, c_options,
- &result.client))
+ check_flight_status(CFlightClient.Connect(c_location, c_options,
+ &result.client))
return result
@@ -751,8 +822,9 @@ cdef class FlightClient:
"not '{}'".format(type(auth_handler)))
handler.reset((<ClientAuthHandler> auth_handler).to_handler())
with nogil:
- check_status(self.client.get().Authenticate(deref(c_options),
- move(handler)))
+ check_flight_status(
+ self.client.get().Authenticate(deref(c_options),
+ move(handler)))
def list_actions(self, options: FlightCallOptions = None):
"""List the actions available on a service."""
@@ -761,7 +833,7 @@ cdef class FlightClient:
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
- check_status(
+ check_flight_status(
self.client.get().ListActions(deref(c_options), &results))
result = []
@@ -780,14 +852,14 @@ cdef class FlightClient:
CAction c_action = Action.unwrap(action)
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
- check_status(
+ check_flight_status(
self.client.get().DoAction(deref(c_options), c_action,
&results))
while True:
result = Result.__new__(Result)
with nogil:
- check_status(results.get().Next(&result.result))
+ check_flight_status(results.get().Next(&result.result))
if result.result == NULL:
break
yield result
@@ -801,13 +873,14 @@ cdef class FlightClient:
CCriteria c_criteria
with nogil:
- check_status(self.client.get().ListFlights(deref(c_options),
- c_criteria, &listing))
+ check_flight_status(
+ self.client.get().ListFlights(deref(c_options),
+ c_criteria, &listing))
while True:
result = FlightInfo.__new__(FlightInfo)
with nogil:
- check_status(listing.get().Next(&result.info))
+ check_flight_status(listing.get().Next(&result.info))
if result.info == NULL:
break
yield result
@@ -822,7 +895,7 @@ cdef class FlightClient:
FlightDescriptor.unwrap(descriptor)
with nogil:
- check_status(self.client.get().GetFlightInfo(
+ check_flight_status(self.client.get().GetFlightInfo(
deref(c_options), c_descriptor, &result.info))
return result
@@ -839,7 +912,7 @@ cdef class FlightClient:
CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
with nogil:
- check_status(
+ check_flight_status(
self.client.get().DoGet(
deref(c_options), ticket.ticket, &reader))
result = FlightStreamReader()
@@ -866,7 +939,7 @@ cdef class FlightClient:
FlightMetadataReader reader = FlightMetadataReader()
with nogil:
- check_status(self.client.get().DoPut(
+ check_flight_status(self.client.get().DoPut(
deref(c_options),
c_descriptor,
c_schema,
@@ -982,7 +1055,7 @@ cdef class ServerAuthReader:
raise ValueError("Cannot use ServerAuthReader outside "
"ServerAuthHandler.authenticate")
with nogil:
- check_status(self.reader.Read(&token))
+ check_flight_status(self.reader.Read(&token))
return token
cdef void poison(self):
@@ -1013,7 +1086,7 @@ cdef class ServerAuthSender:
raise ValueError("Cannot use ServerAuthSender outside "
"ServerAuthHandler.authenticate")
with nogil:
- check_status(self.sender.Write(c_message))
+ check_flight_status(self.sender.Write(c_message))
cdef void poison(self):
"""Prevent further usage of this object.
@@ -1043,7 +1116,7 @@ cdef class ClientAuthReader:
raise ValueError("Cannot use ClientAuthReader outside "
"ClientAuthHandler.authenticate")
with nogil:
- check_status(self.reader.Read(&token))
+ check_flight_status(self.reader.Read(&token))
return token
cdef void poison(self):
@@ -1074,7 +1147,7 @@ cdef class ClientAuthSender:
raise ValueError("Cannot use ClientAuthSender outside "
"ClientAuthHandler.authenticate")
with nogil:
- check_status(self.sender.Write(c_message))
+ check_flight_status(self.sender.Write(c_message))
cdef void poison(self):
"""Prevent further usage of this object.
@@ -1093,7 +1166,7 @@ cdef class ClientAuthSender:
return result
-cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
+cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *:
"""Callback for implementing FlightDataStream in Python."""
cdef:
unique_ptr[CFlightDataStream] data_stream
@@ -1104,18 +1177,20 @@ cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
stream = <GeneratorStream> py_stream
if stream.current_stream != nullptr:
- check_status(stream.current_stream.get().Next(payload))
+ check_flight_status(stream.current_stream.get().Next(payload))
# If the stream ended, see if there's another stream from the
# generator
if payload.ipc_message.metadata != nullptr:
- return
+ return CStatus_OK()
stream.current_stream.reset(nullptr)
try:
result = next(stream.generator)
except StopIteration:
payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
- return
+ return CStatus_OK()
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
if isinstance(result, (list, tuple)):
result, metadata = result
@@ -1144,7 +1219,7 @@ cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
stream_schema))
stream.current_stream.reset(
new CPyFlightDataStream(result, move(data_stream)))
- _data_stream_next(self, payload)
+ return _data_stream_next(self, payload)
elif isinstance(result, RecordBatch):
batch = <RecordBatch> result
if batch.schema != stream_schema:
@@ -1153,7 +1228,7 @@ cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
"GeneratorStream. "
"Got: {}\nExpected: {}".format(batch.schema,
stream_schema))
- check_status(_GetRecordBatchPayload(
+ check_flight_status(_GetRecordBatchPayload(
deref(batch.batch),
c_default_memory_pool(),
&payload.ipc_message))
@@ -1164,45 +1239,56 @@ cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
"an iterator of FlightDataStream, Table, "
"RecordBatch, or RecordBatchStreamReader objects, "
"not {}.".format(type(result)))
+ return CStatus_OK()
-cdef void _list_flights(void* self, const CServerCallContext& context,
- const CCriteria* c_criteria,
- unique_ptr[CFlightListing]* listing) except *:
+cdef CStatus _list_flights(void* self, const CServerCallContext& context,
+ const CCriteria* c_criteria,
+ unique_ptr[CFlightListing]* listing) except *:
"""Callback for implementing ListFlights in Python."""
cdef:
vector[CFlightInfo] flights
- result = (<object> self).list_flights(ServerCallContext.wrap(context),
- c_criteria.expression)
- for info in result:
- if not isinstance(info, FlightInfo):
- raise TypeError("FlightServerBase.list_flights must return "
- "FlightInfo instances, but got {}".format(
- type(info)))
- flights.push_back(deref((<FlightInfo> info).info.get()))
- listing.reset(new CSimpleFlightListing(flights))
-
-
-cdef void _get_flight_info(void* self, const CServerCallContext& context,
- CFlightDescriptor c_descriptor,
- unique_ptr[CFlightInfo]* info) except *:
+
+ try:
+ result = (<object> self).list_flights(ServerCallContext.wrap(context),
+ c_criteria.expression)
+ for info in result:
+ if not isinstance(info, FlightInfo):
+ raise TypeError("FlightServerBase.list_flights must return "
+ "FlightInfo instances, but got {}".format(
+ type(info)))
+ flights.push_back(deref((<FlightInfo> info).info.get()))
+ listing.reset(new CSimpleFlightListing(flights))
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef CStatus _get_flight_info(void* self, const CServerCallContext& context,
+ CFlightDescriptor c_descriptor,
+ unique_ptr[CFlightInfo]* info) except *:
"""Callback for implementing Flight servers in Python."""
cdef:
FlightDescriptor py_descriptor = \
FlightDescriptor.__new__(FlightDescriptor)
py_descriptor.descriptor = c_descriptor
- result = (<object> self).get_flight_info(ServerCallContext.wrap(context),
- py_descriptor)
+ try:
+ result = (<object> self).get_flight_info(
+ ServerCallContext.wrap(context),
+ py_descriptor)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
if not isinstance(result, FlightInfo):
raise TypeError("FlightServerBase.get_flight_info must return "
"a FlightInfo instance, but got {}".format(
type(result)))
info.reset(new CFlightInfo(deref((<FlightInfo> result).info.get())))
+ return CStatus_OK()
-cdef void _do_put(void* self, const CServerCallContext& context,
- unique_ptr[CFlightMessageReader] reader,
- unique_ptr[CFlightMetadataWriter] writer) except *:
+cdef CStatus _do_put(void* self, const CServerCallContext& context,
+ unique_ptr[CFlightMessageReader] reader,
+ unique_ptr[CFlightMetadataWriter] writer) except *:
"""Callback for implementing Flight servers in Python."""
cdef:
MetadataRecordBatchReader py_reader = MetadataRecordBatchReader()
@@ -1215,20 +1301,27 @@ cdef void _do_put(void* self, const CServerCallContext& context,
py_reader.schema = pyarrow_wrap_schema(
py_reader.reader.get().schema())
py_writer.writer.reset(writer.release())
- (<object> self).do_put(ServerCallContext.wrap(context), descriptor,
- py_reader, py_writer)
+ try:
+ (<object> self).do_put(ServerCallContext.wrap(context), descriptor,
+ py_reader, py_writer)
+ return CStatus_OK()
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
-cdef void _do_get(void* self, const CServerCallContext& context,
- CTicket ticket,
- unique_ptr[CFlightDataStream]* stream) except *:
+cdef CStatus _do_get(void* self, const CServerCallContext& context,
+ CTicket ticket,
+ unique_ptr[CFlightDataStream]* stream) except *:
"""Callback for implementing Flight servers in Python."""
cdef:
unique_ptr[CFlightDataStream] data_stream
py_ticket = Ticket(ticket.ticket)
- result = (<object> self).do_get(ServerCallContext.wrap(context),
- py_ticket)
+ try:
+ result = (<object> self).do_get(ServerCallContext.wrap(context),
+ py_ticket)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
if not isinstance(result, FlightDataStream):
raise TypeError("FlightServerBase.do_get must return "
"a FlightDataStream")
@@ -1236,10 +1329,13 @@ cdef void _do_get(void* self, const CServerCallContext& context,
(<FlightDataStream> result).to_stream())
stream[0] = unique_ptr[CFlightDataStream](
new CPyFlightDataStream(result, move(data_stream)))
+ return CStatus_OK()
-cdef void _do_action_result_next(void* self,
- unique_ptr[CFlightResult]* result) except *:
+cdef CStatus _do_action_result_next(
+ void* self,
+ unique_ptr[CFlightResult]* result
+) except *:
"""Callback for implementing Flight servers in Python."""
cdef:
CFlightResult* c_result
@@ -1253,70 +1349,95 @@ cdef void _do_action_result_next(void* self,
result.reset(new CFlightResult(deref(c_result)))
except StopIteration:
result.reset(nullptr)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
-cdef void _do_action(void* self, const CServerCallContext& context,
- const CAction& action,
- unique_ptr[CResultStream]* result) except *:
+cdef CStatus _do_action(void* self, const CServerCallContext& context,
+ const CAction& action,
+ unique_ptr[CResultStream]* result) except *:
"""Callback for implementing Flight servers in Python."""
cdef:
function[cb_result_next] ptr = &_do_action_result_next
py_action = Action(action.type, pyarrow_wrap_buffer(action.body))
- responses = (<object> self).do_action(ServerCallContext.wrap(context),
- py_action)
+ try:
+ responses = (<object> self).do_action(ServerCallContext.wrap(context),
+ py_action)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
result.reset(new CPyFlightResultStream(responses, ptr))
+ return CStatus_OK()
-cdef void _list_actions(void* self, const CServerCallContext& context,
- vector[CActionType]* actions) except *:
+cdef CStatus _list_actions(void* self, const CServerCallContext& context,
+ vector[CActionType]* actions) except *:
"""Callback for implementing Flight servers in Python."""
cdef:
CActionType action_type
# Method should return a list of ActionTypes or similar tuple
- result = (<object> self).list_actions(ServerCallContext.wrap(context))
- for action in result:
- action_type.type = tobytes(action[0])
- action_type.description = tobytes(action[1])
- actions.push_back(action_type)
-
-
-cdef void _server_authenticate(void* self, CServerAuthSender* outgoing,
- CServerAuthReader* incoming) except *:
+ try:
+ result = (<object> self).list_actions(ServerCallContext.wrap(context))
+ for action in result:
+ action_type.type = tobytes(action[0])
+ action_type.description = tobytes(action[1])
+ actions.push_back(action_type)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef CStatus _server_authenticate(void* self, CServerAuthSender* outgoing,
+ CServerAuthReader* incoming) except *:
"""Callback for implementing authentication in Python."""
sender = ServerAuthSender.wrap(outgoing)
reader = ServerAuthReader.wrap(incoming)
try:
(<object> self).authenticate(sender, reader)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
finally:
sender.poison()
reader.poison()
+ return CStatus_OK()
-cdef void _is_valid(void* self, const c_string& token,
- c_string* peer_identity) except *:
+cdef CStatus _is_valid(void* self, const c_string& token,
+ c_string* peer_identity) except *:
"""Callback for implementing authentication in Python."""
cdef c_string c_result
- c_result = tobytes((<object> self).is_valid(token))
- peer_identity[0] = c_result
+ try:
+ c_result = tobytes((<object> self).is_valid(token))
+ peer_identity[0] = c_result
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
-cdef void _client_authenticate(void* self, CClientAuthSender* outgoing,
- CClientAuthReader* incoming) except *:
+cdef CStatus _client_authenticate(void* self, CClientAuthSender* outgoing,
+ CClientAuthReader* incoming) except *:
"""Callback for implementing authentication in Python."""
sender = ClientAuthSender.wrap(outgoing)
reader = ClientAuthReader.wrap(incoming)
try:
(<object> self).authenticate(sender, reader)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
finally:
sender.poison()
reader.poison()
+ return CStatus_OK()
-cdef void _get_token(void* self, c_string* token) except *:
+cdef CStatus _get_token(void* self, c_string* token) except *:
"""Callback for implementing authentication in Python."""
cdef c_string c_result
- c_result = tobytes((<object> self).get_token())
- token[0] = c_result
+ try:
+ c_result = tobytes((<object> self).get_token())
+ token[0] = c_result
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
cdef class ServerAuthHandler:
@@ -1441,7 +1562,7 @@ cdef class FlightServerBase:
c_server = new PyFlightServer(self, vtable)
self.server.reset(c_server)
with nogil:
- check_status(c_server.Init(deref(c_options)))
+ check_flight_status(c_server.Init(deref(c_options)))
def run(self):
"""
@@ -1453,7 +1574,7 @@ cdef class FlightServerBase:
if self.server.get() == nullptr:
raise ValueError("run() on uninitialized FlightServerBase")
with nogil:
- check_status(self.server.get().ServeWithSignals())
+ check_flight_status(self.server.get().ServeWithSignals())
def list_flights(self, context, criteria):
raise NotImplementedError
@@ -1489,4 +1610,4 @@ cdef class FlightServerBase:
if self.server.get() == nullptr:
raise ValueError("shutdown() on uninitialized FlightServerBase")
with nogil:
- check_status(self.server.get().Shutdown())
+ check_flight_status(self.server.get().Shutdown())
diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py
index 05198e4..2d037f1 100644
--- a/python/pyarrow/flight.py
+++ b/python/pyarrow/flight.py
@@ -33,6 +33,13 @@ from pyarrow._flight import ( # noqa
FlightEndpoint,
FlightInfo,
FlightServerBase,
+ FlightError,
+ FlightInternalError,
+ FlightTimedOutError,
+ FlightCancelledError,
+ FlightUnauthenticatedError,
+ FlightUnauthorizedError,
+ FlightUnavailableError,
GeneratorStream,
Location,
Ticket,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 7ac744a..7373e0b 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -243,33 +243,58 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
unique_ptr[CFlightStreamWriter]* stream,
unique_ptr[CFlightMetadataReader]* reader)
+ cdef cppclass CFlightStatusCode" arrow::flight::FlightStatusCode":
+ bint operator==(CFlightStatusCode)
+
+ CFlightStatusCode CFlightStatusInternal \
+ " arrow::flight::FlightStatusCode::Internal"
+ CFlightStatusCode CFlightStatusTimedOut \
+ " arrow::flight::FlightStatusCode::TimedOut"
+ CFlightStatusCode CFlightStatusCancelled \
+ " arrow::flight::FlightStatusCode::Cancelled"
+ CFlightStatusCode CFlightStatusUnauthenticated \
+ " arrow::flight::FlightStatusCode::Unauthenticated"
+ CFlightStatusCode CFlightStatusUnauthorized \
+ " arrow::flight::FlightStatusCode::Unauthorized"
+ CFlightStatusCode CFlightStatusUnavailable \
+ " arrow::flight::FlightStatusCode::Unavailable"
+
+ cdef cppclass FlightStatusDetail" arrow::flight::FlightStatusDetail":
+ CFlightStatusCode code()
+ @staticmethod
+ shared_ptr[FlightStatusDetail] UnwrapStatus(const CStatus& status)
+
+ cdef CStatus MakeFlightError" arrow::flight::MakeFlightError" \
+ (CFlightStatusCode code, const c_string& message)
+
# Callbacks for implementing Flight servers
-# Use typedef to emulate syntax for std::function<void(...)>
-ctypedef void cb_list_flights(object, const CServerCallContext&,
- const CCriteria*,
- unique_ptr[CFlightListing]*)
-ctypedef void cb_get_flight_info(object, const CServerCallContext&,
- const CFlightDescriptor&,
- unique_ptr[CFlightInfo]*)
-ctypedef void cb_do_put(object, const CServerCallContext&,
- unique_ptr[CFlightMessageReader],
- unique_ptr[CFlightMetadataWriter])
-ctypedef void cb_do_get(object, const CServerCallContext&,
- const CTicket&,
- unique_ptr[CFlightDataStream]*)
-ctypedef void cb_do_action(object, const CServerCallContext&, const CAction&,
- unique_ptr[CResultStream]*)
-ctypedef void cb_list_actions(object, const CServerCallContext&,
- vector[CActionType]*)
-ctypedef void cb_result_next(object, unique_ptr[CFlightResult]*)
-ctypedef void cb_data_stream_next(object, CFlightPayload*)
-ctypedef void cb_server_authenticate(object, CServerAuthSender*,
- CServerAuthReader*)
-ctypedef void cb_is_valid(object, const c_string&, c_string*)
-ctypedef void cb_client_authenticate(object, CClientAuthSender*,
- CClientAuthReader*)
-ctypedef void cb_get_token(object, c_string*)
+# Use typedef to emulate syntax for std::function<void(..)>
+ctypedef CStatus cb_list_flights(object, const CServerCallContext&,
+ const CCriteria*,
+ unique_ptr[CFlightListing]*)
+ctypedef CStatus cb_get_flight_info(object, const CServerCallContext&,
+ const CFlightDescriptor&,
+ unique_ptr[CFlightInfo]*)
+ctypedef CStatus cb_do_put(object, const CServerCallContext&,
+ unique_ptr[CFlightMessageReader],
+ unique_ptr[CFlightMetadataWriter])
+ctypedef CStatus cb_do_get(object, const CServerCallContext&,
+ const CTicket&,
+ unique_ptr[CFlightDataStream]*)
+ctypedef CStatus cb_do_action(object, const CServerCallContext&,
+ const CAction&,
+ unique_ptr[CResultStream]*)
+ctypedef CStatus cb_list_actions(object, const CServerCallContext&,
+ vector[CActionType]*)
+ctypedef CStatus cb_result_next(object, unique_ptr[CFlightResult]*)
+ctypedef CStatus cb_data_stream_next(object, CFlightPayload*)
+ctypedef CStatus cb_server_authenticate(object, CServerAuthSender*,
+ CServerAuthReader*)
+ctypedef CStatus cb_is_valid(object, const c_string&, c_string*)
+ctypedef CStatus cb_client_authenticate(object, CClientAuthSender*,
+ CClientAuthReader*)
+ctypedef CStatus cb_get_token(object, c_string*)
cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
cdef cppclass PyFlightServerVtable:
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index ec8c52c..3e21369 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -276,6 +276,32 @@ class SlowFlightServer(FlightServerBase):
yield pa.Table.from_arrays(data1, names=['a'])
+class ErrorFlightServer(FlightServerBase):
+ """A Flight server that uses all the Flight-specific errors."""
+
+ def do_action(self, context, action):
+ if action.type == "internal":
+ raise flight.FlightInternalError("foo")
+ elif action.type == "timedout":
+ raise flight.FlightTimedOutError("foo")
+ elif action.type == "cancel":
+ raise flight.FlightCancelledError("foo")
+ elif action.type == "unauthenticated":
+ raise flight.FlightUnauthenticatedError("foo")
+ elif action.type == "unauthorized":
+ raise flight.FlightUnauthorizedError("foo")
+ raise NotImplementedError
+
+ def list_flights(self, context, criteria):
+ yield flight.FlightInfo(
+ pa.schema([]),
+ flight.FlightDescriptor.for_path('/foo'),
+ [],
+ -1, -1
+ )
+ raise flight.FlightInternalError("foo")
+
+
class HttpBasicServerAuthHandler(ServerAuthHandler):
"""An example implementation of HTTP basic authentication."""
@@ -288,13 +314,13 @@ class HttpBasicServerAuthHandler(ServerAuthHandler):
def is_valid(self, token):
if not token:
- raise ValueError("unauthenticated: token not provided")
+ raise flight.FlightUnauthenticatedError("token not provided")
token = base64.b64decode(token)
username, password = token.split(b':')
if username not in self.creds:
- raise ValueError("unknown user")
+ raise flight.FlightUnauthenticatedError("unknown user")
if self.creds[username] != password:
- raise ValueError("wrong password")
+ raise flight.FlightUnauthenticatedError("wrong password")
return username
@@ -326,12 +352,13 @@ class TokenServerAuthHandler(ServerAuthHandler):
if username in self.creds and self.creds[username] == password:
outgoing.write(base64.b64encode(b'secret:' + username))
else:
- raise ValueError("unauthenticated: invalid username/password")
+ raise flight.FlightUnauthenticatedError(
+ "invalid username/password")
def is_valid(self, token):
token = base64.b64decode(token)
if not token.startswith(b'secret:'):
- raise ValueError("unauthenticated: invalid token")
+ raise flight.FlightUnauthenticatedError("invalid token")
return token[7:]
@@ -552,7 +579,7 @@ def test_timeout_fires():
options = flight.FlightCallOptions(timeout=0.2)
# gRPC error messages change based on version, so don't look
# for a particular error
- with pytest.raises(pa.ArrowIOError):
+ with pytest.raises(flight.FlightTimedOutError):
list(client.do_action(action, options=options))
@@ -580,7 +607,8 @@ def test_http_basic_unauth():
auth_handler=basic_auth_handler) as server_location:
client = flight.FlightClient.connect(server_location)
action = flight.Action("who-am-i", b"")
- with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"):
+ with pytest.raises(flight.FlightUnauthenticatedError,
+ match=".*unauthenticated.*"):
list(client.do_action(action))
@@ -602,7 +630,8 @@ def test_http_basic_auth_invalid_password():
client = flight.FlightClient.connect(server_location)
action = flight.Action("who-am-i", b"")
client.authenticate(HttpBasicClientAuthHandler('test', 'wrong'))
- with pytest.raises(pa.ArrowException, match=".*wrong password.*"):
+ with pytest.raises(flight.FlightUnauthenticatedError,
+ match=".*wrong password.*"):
next(client.do_action(action))
@@ -622,7 +651,7 @@ def test_token_auth_invalid():
with flight_server(EchoStreamFlightServer,
auth_handler=token_auth_handler) as server_location:
client = flight.FlightClient.connect(server_location)
- with pytest.raises(pa.ArrowException, match=".*unauthenticated.*"):
+ with pytest.raises(flight.FlightUnauthenticatedError):
client.authenticate(TokenClientAuthHandler('test', 'wrong'))
@@ -658,7 +687,7 @@ def test_tls_fails():
client = flight.FlightClient.connect(server_location)
# gRPC error messages change based on version, so don't look
# for a particular error
- with pytest.raises(pa.ArrowIOError):
+ with pytest.raises(flight.FlightUnavailableError):
client.do_get(flight.Ticket(b'ints'))
@@ -690,7 +719,7 @@ def test_tls_override_hostname():
client = flight.FlightClient.connect(
server_location, tls_root_certs=certs["root_cert"],
override_hostname="fakehostname")
- with pytest.raises(pa.ArrowIOError):
+ with pytest.raises(flight.FlightUnavailableError):
client.do_get(flight.Ticket(b'ints'))
@@ -748,7 +777,7 @@ def test_cancel_do_get():
client = flight.FlightClient.connect(server_location)
reader = client.do_get(flight.Ticket(b'ints'))
reader.cancel()
- with pytest.raises(pa.ArrowIOError, match=".*Cancel.*"):
+ with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
reader.read_chunk()
@@ -770,7 +799,7 @@ def test_cancel_do_get_threaded():
stream_canceled.wait(timeout=5)
try:
reader.read_chunk()
- except pa.ArrowIOError:
+ except flight.FlightCancelledError:
with result_lock:
raised_proper_exception.set()
@@ -815,3 +844,21 @@ def test_roundtrip_types():
assert info.total_bytes == info2.total_bytes
assert info.total_records == info2.total_records
assert info.endpoints == info2.endpoints
+
+
+def test_roundtrip_errors():
+ """Ensure that Flight errors propagate from server to client."""
+ with flight_server(ErrorFlightServer) as server_location:
+ client = flight.FlightClient.connect(server_location)
+ with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
+ list(client.do_action(flight.Action("internal", b"")))
+ with pytest.raises(flight.FlightTimedOutError, match=".*foo.*"):
+ list(client.do_action(flight.Action("timedout", b"")))
+ with pytest.raises(flight.FlightCancelledError, match=".*foo.*"):
+ list(client.do_action(flight.Action("cancel", b"")))
+ with pytest.raises(flight.FlightUnauthenticatedError, match=".*foo.*"):
+ list(client.do_action(flight.Action("unauthenticated", b"")))
+ with pytest.raises(flight.FlightUnauthorizedError, match=".*foo.*"):
+ list(client.do_action(flight.Action("unauthorized", b"")))
+ with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
+ list(client.list_flights())