You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by yi...@apache.org on 2022/04/14 01:47:52 UTC
[arrow] branch master updated: ARROW-16069: [C++][FlightRPC] Refactor out gRPC error code handling
This is an automated email from the ASF dual-hosted git repository.
yibocai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new fc9af3cd3f ARROW-16069: [C++][FlightRPC] Refactor out gRPC error code handling
fc9af3cd3f is described below
commit fc9af3cd3fe9abc792a90217578d0def7b2a9a84
Author: David Li <li...@gmail.com>
AuthorDate: Thu Apr 14 01:47:30 2022 +0000
ARROW-16069: [C++][FlightRPC] Refactor out gRPC error code handling
Closes #12749 from lidavidm/arrow-16069
Authored-by: David Li <li...@gmail.com>
Signed-off-by: Yibo Cai <yi...@arm.com>
---
cpp/src/arrow/flight/flight_internals_test.cc | 56 +++++
cpp/src/arrow/flight/flight_test.cc | 6 +
cpp/src/arrow/flight/test_definitions.cc | 123 ++++++++++
cpp/src/arrow/flight/test_definitions.h | 19 ++
cpp/src/arrow/flight/transport.cc | 193 ++++++++++++++++
cpp/src/arrow/flight/transport.h | 50 ++++
.../arrow/flight/transport/grpc/util_internal.cc | 254 +++++++++------------
.../transport/ucx/flight_transport_ucx_test.cc | 56 ++---
cpp/src/arrow/flight/transport/ucx/ucx_internal.cc | 143 +++---------
cpp/src/arrow/flight/transport/ucx/ucx_internal.h | 10 +-
cpp/src/arrow/flight/types.cc | 2 +
python/pyarrow/_flight.pyx | 5 +-
python/pyarrow/tests/test_flight.py | 36 ++-
13 files changed, 639 insertions(+), 314 deletions(-)
diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc
index f7b731f01c..84040a1a47 100644
--- a/cpp/src/arrow/flight/flight_internals_test.cc
+++ b/cpp/src/arrow/flight/flight_internals_test.cc
@@ -482,5 +482,61 @@ TEST_F(TestCookieParsing, CookieCache) {
AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=0; id1=1; id2=2");
}
+// ----------------------------------------------------------------------
+// Transport abstraction tests
+
+TEST(TransportErrorHandling, ReconstructStatus) {
+  const Status current = Status::Invalid("Base error message");
+  const std::string already_exists =
+      std::to_string(static_cast<int>(StatusCode::AlreadyExists));
+
+  // Invalid code: the current status code and message are kept, and a note
+  // naming the bad code is added to the message
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid,
+      ::testing::AllOf(
+          ::testing::HasSubstr("Base error message"),
+          ::testing::HasSubstr(
+              ". Also, server sent unknown or invalid Arrow status code -1")),
+      internal::ReconstructStatus("-1", current, util::nullopt, util::nullopt,
+                                  util::nullopt, /*detail=*/nullptr));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid,
+      ::testing::AllOf(
+          ::testing::HasSubstr("Base error message"),
+          ::testing::HasSubstr(
+              ". Also, server sent unknown or invalid Arrow status code foobar")),
+      internal::ReconstructStatus("foobar", current, util::nullopt, util::nullopt,
+                                  util::nullopt, /*detail=*/nullptr));
+
+  // Override code
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      AlreadyExists, ::testing::HasSubstr("Base error message"),
+      internal::ReconstructStatus(already_exists, current, util::nullopt,
+                                  util::nullopt, util::nullopt, /*detail=*/nullptr));
+
+  // Override message
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      AlreadyExists, ::testing::HasSubstr("Custom error message"),
+      internal::ReconstructStatus(already_exists, current, "Custom error message",
+                                  util::nullopt, util::nullopt, /*detail=*/nullptr));
+
+  // With detail
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      AlreadyExists,
+      ::testing::AllOf(::testing::HasSubstr("Custom error message"),
+                       ::testing::HasSubstr(". Detail: Detail message")),
+      internal::ReconstructStatus(already_exists, current, "Custom error message",
+                                  "Detail message", util::nullopt,
+                                  /*detail=*/nullptr));
+
+  // Neither a binary detail nor an existing detail: nothing is attached
+  ASSERT_EQ(FlightStatusDetail::UnwrapStatus(internal::ReconstructStatus(
+                already_exists, current, "Custom error message", "Detail message",
+                util::nullopt, /*detail=*/nullptr)),
+            nullptr);
+
+  // With detail and bin: a FlightStatusDetail is created to hold the binary detail
+  auto reconstructed = internal::ReconstructStatus(
+      already_exists, current, "Custom error message", "Detail message",
+      "Binary error details", /*detail=*/nullptr);
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      AlreadyExists,
+      ::testing::AllOf(::testing::HasSubstr("Custom error message"),
+                       ::testing::HasSubstr(". Detail: Detail message")),
+      reconstructed);
+  auto detail = FlightStatusDetail::UnwrapStatus(reconstructed);
+  ASSERT_NE(detail, nullptr);
+  ASSERT_EQ(detail->extra_info(), "Binary error details");
+
+  // With bin and an existing detail: the existing detail's Flight code is kept
+  auto existing = std::make_shared<FlightStatusDetail>(FlightStatusCode::Unavailable);
+  auto with_existing =
+      internal::ReconstructStatus(already_exists, current, util::nullopt, util::nullopt,
+                                  "Binary error details", existing);
+  auto existing_detail = FlightStatusDetail::UnwrapStatus(with_existing);
+  ASSERT_NE(existing_detail, nullptr);
+  ASSERT_EQ(existing_detail->code(), FlightStatusCode::Unavailable);
+  ASSERT_EQ(existing_detail->extra_info(), "Binary error details");
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 3f0ed7114f..cf3c30358a 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -109,6 +109,12 @@ class GrpcCudaDataTest : public CudaDataTest {
};
ARROW_FLIGHT_TEST_CUDA_DATA(GrpcCudaDataTest);
+// Runs the ErrorHandlingTest cases with the server and client created for the
+// "grpc" transport (ErrorHandlingTest::SetUp builds its location from transport()).
+class GrpcErrorHandlingTest : public ErrorHandlingTest {
+ protected:
+ std::string transport() const override { return "grpc"; }
+};
+ARROW_FLIGHT_TEST_ERROR_HANDLING(GrpcErrorHandlingTest);
+
//------------------------------------------------------------
// Ad-hoc gRPC-specific tests
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 1ec06a1f00..a152c3c960 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -1363,5 +1363,128 @@ void CudaDataTest::TestDoExchange() {
#endif
+//------------------------------------------------------------
+// Test error handling
+
+namespace {
+// Every Arrow status code the tests round-trip through a transport.
+constexpr std::initializer_list<StatusCode> kStatusCodes = {
+    StatusCode::OutOfMemory,
+    StatusCode::KeyError,
+    StatusCode::TypeError,
+    StatusCode::Invalid,
+    StatusCode::IOError,
+    StatusCode::CapacityError,
+    StatusCode::IndexError,
+    StatusCode::Cancelled,
+    StatusCode::UnknownError,
+    StatusCode::NotImplemented,
+    StatusCode::SerializationError,
+    StatusCode::RError,
+    StatusCode::CodeGenError,
+    StatusCode::ExpressionValidationError,
+    StatusCode::ExecutionError,
+    StatusCode::AlreadyExists,
+};
+
+// Every Flight status code the tests round-trip through a transport.
+constexpr std::initializer_list<FlightStatusCode> kFlightStatusCodes = {
+    FlightStatusCode::Internal,     FlightStatusCode::TimedOut,
+    FlightStatusCode::Cancelled,    FlightStatusCode::Unauthenticated,
+    FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable,
+    FlightStatusCode::Failed,
+};
+
+/// Return the member of `codes` whose integer value is `raw_code`, or an
+/// Invalid status naming the unrecognized value.
+template <typename Code>
+arrow::Result<Code> TryConvertCode(int raw_code, std::initializer_list<Code> codes,
+                                   const char* kind) {
+  for (const Code code : codes) {
+    if (raw_code == static_cast<int>(code)) {
+      return code;
+    }
+  }
+  return Status::Invalid("Unknown ", kind, " code: ", raw_code);
+}
+
+arrow::Result<StatusCode> TryConvertStatusCode(int raw_code) {
+  return TryConvertCode(raw_code, kStatusCodes, "Arrow status");
+}
+
+arrow::Result<FlightStatusCode> TryConvertFlightStatusCode(int raw_code) {
+  return TryConvertCode(raw_code, kFlightStatusCodes, "Flight status");
+}
+
+/// A non-Flight StatusDetail, used to check how arbitrary details reach the client.
+class TestStatusDetail : public StatusDetail {
+ public:
+  const char* type_id() const override { return "test-status-detail"; }
+  std::string ToString() const override { return "Custom status detail"; }
+};
+
+/// Server whose GetFlightInfo fails with a status described by the descriptor path:
+///   {code, message}                     -> Status(code, message)
+///   {code, message, <ignored>}          -> same, plus a TestStatusDetail
+///   {code, message, flight_code, extra} -> same, plus FlightStatusDetail(flight_code,
+///                                          extra)
+/// Shorter paths fail with NotImplemented.
+class ErrorHandlingTestServer : public FlightServerBase {
+ public:
+  Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+                       std::unique_ptr<FlightInfo>* info) override {
+    if (request.path.size() < 2) {
+      return Status::NotImplemented("NYI");
+    }
+    const int raw_code = std::atoi(request.path[0].c_str());
+    ARROW_ASSIGN_OR_RAISE(StatusCode code, TryConvertStatusCode(raw_code));
+
+    if (request.path.size() == 2) {
+      return Status(code, request.path[1]);
+    }
+    if (request.path.size() == 3) {
+      return Status(code, request.path[1], std::make_shared<TestStatusDetail>());
+    }
+    const int raw_flight_code = std::atoi(request.path[2].c_str());
+    ARROW_ASSIGN_OR_RAISE(FlightStatusCode flight_code,
+                          TryConvertFlightStatusCode(raw_flight_code));
+    return Status(code, request.path[1],
+                  std::make_shared<FlightStatusDetail>(flight_code, request.path[3]));
+  }
+};
+}  // namespace
+
+void ErrorHandlingTest::SetUp() {
+  ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
+  ASSERT_OK(MakeServer<ErrorHandlingTestServer>(
+      location, &server_, &client_,
+      [](FlightServerOptions* options) { return Status::OK(); },
+      [](FlightClientOptions* options) { return Status::OK(); }));
+}
+
+void ErrorHandlingTest::TearDown() {
+  // gtest still runs TearDown() after a fatal failure in SetUp(), in which case
+  // the client and/or server may never have been created.
+  if (client_) {
+    ASSERT_OK(client_->Close());
+  }
+  if (server_) {
+    ASSERT_OK(server_->Shutdown());
+  }
+}
+
+void ErrorHandlingTest::TestGetFlightInfo() {
+  for (const auto code : kStatusCodes) {
+    ARROW_SCOPED_TRACE("C++ status code: ", static_cast<int>(code));
+    const std::string code_str = std::to_string(static_cast<int>(code));
+
+    // Plain status: code and message must survive the round trip
+    auto descr = FlightDescriptor::Path({code_str, "Expected message"});
+    auto status = client_->GetFlightInfo(descr).status();
+    EXPECT_EQ(status.code(), code);
+    EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message"));
+
+    // Custom status detail: expected to reach the client as text in the message
+    descr = FlightDescriptor::Path({code_str, "Expected message", ""});
+    status = client_->GetFlightInfo(descr).status();
+    EXPECT_EQ(status.code(), code);
+    EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message"));
+    EXPECT_THAT(status.message(), ::testing::HasSubstr("Detail: Custom status detail"));
+
+    // Flight status detail: its code and extra info must survive the round trip
+    for (const auto flight_code : kFlightStatusCodes) {
+      ARROW_SCOPED_TRACE("Flight status code: ", static_cast<int>(flight_code));
+      descr = FlightDescriptor::Path({code_str, "Expected message",
+                                      std::to_string(static_cast<int>(flight_code)),
+                                      "Expected detail message"});
+      status = client_->GetFlightInfo(descr).status();
+      // Don't check status code, since Flight code may override it
+      EXPECT_THAT(status.message(), ::testing::HasSubstr("Expected message"));
+      auto detail = FlightStatusDetail::UnwrapStatus(status);
+      ASSERT_NE(detail, nullptr);
+      EXPECT_EQ(detail->code(), flight_code);
+      EXPECT_THAT(detail->extra_info(), ::testing::HasSubstr("Expected detail message"));
+    }
+  }
+}
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_definitions.h b/cpp/src/arrow/flight/test_definitions.h
index 601e8d0b4b..464631455d 100644
--- a/cpp/src/arrow/flight/test_definitions.h
+++ b/cpp/src/arrow/flight/test_definitions.h
@@ -255,5 +255,24 @@ class ARROW_FLIGHT_EXPORT CudaDataTest : public FlightTest {
TEST_F(FIXTURE, TestDoPut) { TestDoPut(); } \
TEST_F(FIXTURE, TestDoExchange) { TestDoExchange(); }
+/// \brief Tests of error handling.
+///
+/// SetUp() starts an ErrorHandlingTest server on the transport named by
+/// transport() and connects a client to it. The test methods then check the
+/// status codes, messages, and details the client receives.
+class ARROW_FLIGHT_EXPORT ErrorHandlingTest : public FlightTest {
+ public:
+ void SetUp() override;
+ void TearDown() override;
+
+ // Test methods
+ void TestGetFlightInfo();
+
+ private:
+ // Both created by MakeServer() in SetUp().
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+/// Instantiate the ErrorHandlingTest cases for FIXTURE, which must derive
+/// from ErrorHandlingTest (enforced by the static_assert).
+#define ARROW_FLIGHT_TEST_ERROR_HANDLING(FIXTURE) \
+ static_assert(std::is_base_of<ErrorHandlingTest, FIXTURE>::value, \
+ ARROW_STRINGIFY(FIXTURE) " must inherit from ErrorHandlingTest"); \
+ TEST_F(FIXTURE, TestGetFlightInfo) { TestGetFlightInfo(); }
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport.cc b/cpp/src/arrow/flight/transport.cc
index 2ccdf82bd7..0da81a567e 100644
--- a/cpp/src/arrow/flight/transport.cc
+++ b/cpp/src/arrow/flight/transport.cc
@@ -17,6 +17,7 @@
#include "arrow/flight/transport.h"
+#include <sstream>
#include <unordered_map>
#include "arrow/flight/client_auth.h"
@@ -159,6 +160,198 @@ TransportRegistry* GetDefaultTransportRegistry() {
return &kRegistry;
}
+//------------------------------------------------------------
+// Error propagation helpers
+
+TransportStatus TransportStatus::FromStatus(const Status& arrow_status) {
+  if (arrow_status.ok()) return TransportStatus{TransportStatusCode::kOk, ""};
+
+  // A TransportStatus only carries a code and a string, so fold the text of
+  // any attached detail into the message.
+  std::string text = arrow_status.message();
+  if (const auto& attached = arrow_status.detail()) {
+    text += ". Detail: ";
+    text += attached->ToString();
+  }
+
+  TransportStatusCode mapped = TransportStatusCode::kUnknown;
+  const std::shared_ptr<FlightStatusDetail> flight_detail =
+      FlightStatusDetail::UnwrapStatus(arrow_status);
+  if (flight_detail == nullptr) {
+    // No Flight-specific detail: derive the code from the Arrow status code.
+    switch (arrow_status.code()) {
+      case StatusCode::KeyError:
+        mapped = TransportStatusCode::kNotFound;
+        break;
+      case StatusCode::Invalid:
+        mapped = TransportStatusCode::kInvalidArgument;
+        break;
+      case StatusCode::Cancelled:
+        mapped = TransportStatusCode::kCancelled;
+        break;
+      case StatusCode::NotImplemented:
+        mapped = TransportStatusCode::kUnimplemented;
+        break;
+      case StatusCode::AlreadyExists:
+        mapped = TransportStatusCode::kAlreadyExists;
+        break;
+      default:
+        break;
+    }
+  } else {
+    // A Flight-specific detail takes precedence over the Arrow status code.
+    switch (flight_detail->code()) {
+      case FlightStatusCode::Internal:
+        mapped = TransportStatusCode::kInternal;
+        break;
+      case FlightStatusCode::TimedOut:
+        mapped = TransportStatusCode::kTimedOut;
+        break;
+      case FlightStatusCode::Cancelled:
+        mapped = TransportStatusCode::kCancelled;
+        break;
+      case FlightStatusCode::Unauthenticated:
+        mapped = TransportStatusCode::kUnauthenticated;
+        break;
+      case FlightStatusCode::Unauthorized:
+        mapped = TransportStatusCode::kUnauthorized;
+        break;
+      case FlightStatusCode::Unavailable:
+        mapped = TransportStatusCode::kUnavailable;
+        break;
+      default:
+        break;
+    }
+  }
+  return TransportStatus{mapped, std::move(text)};
+}
+
+TransportStatus TransportStatus::FromCodeStringAndMessage(const std::string& code_str,
+                                                          std::string message) {
+  // Accept code_str only if it is entirely an integer. std::stoi on its own
+  // would accept trailing garbage (e.g. "3abc" -> 3), so also require that the
+  // whole string was consumed.
+  int code_int = 0;
+  bool parsed = false;
+  try {
+    size_t consumed = 0;
+    code_int = std::stoi(code_str, &consumed);
+    parsed = consumed == code_str.size();
+  } catch (...) {
+    // std::stoi throws when code_str is not an integer or overflows int;
+    // handled below like any other unknown code.
+  }
+
+  if (parsed) {
+    switch (code_int) {
+      case static_cast<int>(TransportStatusCode::kOk):
+      case static_cast<int>(TransportStatusCode::kUnknown):
+      case static_cast<int>(TransportStatusCode::kInternal):
+      case static_cast<int>(TransportStatusCode::kInvalidArgument):
+      case static_cast<int>(TransportStatusCode::kTimedOut):
+      case static_cast<int>(TransportStatusCode::kNotFound):
+      case static_cast<int>(TransportStatusCode::kAlreadyExists):
+      case static_cast<int>(TransportStatusCode::kCancelled):
+      case static_cast<int>(TransportStatusCode::kUnauthenticated):
+      case static_cast<int>(TransportStatusCode::kUnauthorized):
+      case static_cast<int>(TransportStatusCode::kUnimplemented):
+      case static_cast<int>(TransportStatusCode::kUnavailable):
+        return TransportStatus{static_cast<TransportStatusCode>(code_int),
+                               std::move(message)};
+      default:
+        break;
+    }
+  }
+
+  // Unknown or malformed code: report kUnknown, keeping the original message
+  // and recording what the server actually sent.
+  message += ". Also, server sent unknown or invalid Arrow status code ";
+  message += code_str;
+  return TransportStatus{TransportStatusCode::kUnknown, std::move(message)};
+}
+
+Status TransportStatus::ToStatus() const {
+  // Attach a FlightStatusDetail carrying `flight_code` to `status`.
+  const auto tag = [](const Status& status, FlightStatusCode flight_code) {
+    return status.WithDetail(std::make_shared<FlightStatusDetail>(flight_code));
+  };
+
+  switch (code) {
+    case TransportStatusCode::kOk:
+      return Status::OK();
+    case TransportStatusCode::kUnknown:
+      return tag(Status::UnknownError("Flight RPC failed with message: ", message),
+                 FlightStatusCode::Failed);
+    case TransportStatusCode::kInternal:
+      return tag(
+          Status::IOError("Flight returned internal error, with message: ", message),
+          FlightStatusCode::Internal);
+    case TransportStatusCode::kInvalidArgument:
+      return Status::Invalid("Flight returned invalid argument error, with message: ",
+                             message);
+    case TransportStatusCode::kTimedOut:
+      return tag(
+          Status::IOError("Flight returned timeout error, with message: ", message),
+          FlightStatusCode::TimedOut);
+    case TransportStatusCode::kNotFound:
+      return Status::KeyError("Flight returned not found error, with message: ", message);
+    case TransportStatusCode::kAlreadyExists:
+      return Status::AlreadyExists("Flight returned already exists error, with message: ",
+                                   message);
+    case TransportStatusCode::kCancelled:
+      return tag(Status::Cancelled("Flight cancelled call, with message: ", message),
+                 FlightStatusCode::Cancelled);
+    case TransportStatusCode::kUnauthenticated:
+      return tag(Status::IOError("Flight returned unauthenticated error, with message: ",
+                                 message),
+                 FlightStatusCode::Unauthenticated);
+    case TransportStatusCode::kUnauthorized:
+      return tag(
+          Status::IOError("Flight returned unauthorized error, with message: ", message),
+          FlightStatusCode::Unauthorized);
+    case TransportStatusCode::kUnimplemented:
+      return Status::NotImplemented("Flight returned unimplemented error, with message: ",
+                                    message);
+    case TransportStatusCode::kUnavailable:
+      return tag(
+          Status::IOError("Flight returned unavailable error, with message: ", message),
+          FlightStatusCode::Unavailable);
+    default:
+      return Status::UnknownError("Flight failed with error code ",
+                                  static_cast<int>(code), " and message: ", message);
+  }
+}
+
+Status ReconstructStatus(const std::string& code_str, const Status& current_status,
+                         util::optional<std::string> message,
+                         util::optional<std::string> detail_message,
+                         util::optional<std::string> detail_bin,
+                         std::shared_ptr<FlightStatusDetail> detail) {
+  // Keep the current code unless code_str is exactly the decimal form of a
+  // known StatusCode. std::stoi on its own would accept trailing garbage
+  // (e.g. "5abc" -> 5), so also require that the whole string was consumed.
+  StatusCode status_code = current_status.code();
+  bool known_code = false;
+  try {
+    size_t consumed = 0;
+    const int code_int = std::stoi(code_str, &consumed);
+    if (consumed == code_str.size()) {
+      switch (code_int) {
+        case static_cast<int>(StatusCode::OutOfMemory):
+        case static_cast<int>(StatusCode::KeyError):
+        case static_cast<int>(StatusCode::TypeError):
+        case static_cast<int>(StatusCode::Invalid):
+        case static_cast<int>(StatusCode::IOError):
+        case static_cast<int>(StatusCode::CapacityError):
+        case static_cast<int>(StatusCode::IndexError):
+        case static_cast<int>(StatusCode::Cancelled):
+        case static_cast<int>(StatusCode::UnknownError):
+        case static_cast<int>(StatusCode::NotImplemented):
+        case static_cast<int>(StatusCode::SerializationError):
+        case static_cast<int>(StatusCode::RError):
+        case static_cast<int>(StatusCode::CodeGenError):
+        case static_cast<int>(StatusCode::ExpressionValidationError):
+        case static_cast<int>(StatusCode::ExecutionError):
+        case static_cast<int>(StatusCode::AlreadyExists):
+          status_code = static_cast<StatusCode>(code_int);
+          known_code = true;
+          break;
+        default:
+          break;
+      }
+    }
+  } catch (...) {
+    // std::stoi throws when code_str is not an integer or overflows int;
+    // treated the same as an unknown code.
+  }
+
+  // Message layout: <message> [". Detail: " <detail>] [". Also, ..."].
+  // The note about a bad code is appended (not prepended) so the text still
+  // starts with the original error message.
+  std::stringstream status_message;
+  status_message << (message.has_value() ? *message : current_status.message());
+  if (detail_message.has_value()) {
+    status_message << ". Detail: " << *detail_message;
+  }
+  if (!known_code) {
+    status_message << ". Also, server sent unknown or invalid Arrow status code "
+                   << code_str;
+  }
+
+  // Binary details are carried as the extra info of a FlightStatusDetail.
+  if (detail_bin.has_value()) {
+    if (!detail) {
+      detail = std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal);
+    }
+    detail->set_extra_info(std::move(*detail_bin));
+  }
+  return Status(status_code, status_message.str(), std::move(detail));
+}
+
} // namespace internal
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h
index f02ab05157..66ded71fbe 100644
--- a/cpp/src/arrow/flight/transport.h
+++ b/cpp/src/arrow/flight/transport.h
@@ -65,12 +65,14 @@
#include "arrow/flight/type_fwd.h"
#include "arrow/flight/visibility.h"
#include "arrow/type_fwd.h"
+#include "arrow/util/optional.h"
namespace arrow {
namespace ipc {
class Message;
}
namespace flight {
+class FlightStatusDetail;
namespace internal {
/// Internal, not user-visible type used for memory-efficient reads
@@ -220,6 +222,54 @@ class ARROW_FLIGHT_EXPORT TransportRegistry {
ARROW_FLIGHT_EXPORT
TransportRegistry* GetDefaultTransportRegistry();
+//------------------------------------------------------------
+// Error propagation helpers
+
+/// \brief Abstract status code as per the Flight specification.
+///
+/// The integer values are significant: TransportStatus::FromCodeStringAndMessage
+/// parses a code from its integer string form and casts it back to this enum.
+enum class TransportStatusCode {
+ kOk = 0,
+ kUnknown = 1,
+ kInternal = 2,
+ kInvalidArgument = 3,
+ kTimedOut = 4,
+ kNotFound = 5,
+ kAlreadyExists = 6,
+ kCancelled = 7,
+ kUnauthenticated = 8,
+ kUnauthorized = 9,
+ kUnimplemented = 10,
+ kUnavailable = 11,
+};
+
+/// \brief Abstract error status.
+///
+/// Transport implementations may use side channels (e.g. HTTP
+/// trailers) to convey additional information to reconstruct the
+/// original C++ status for implementations that can use it.
+struct ARROW_FLIGHT_EXPORT TransportStatus {
+ /// The abstract status code.
+ TransportStatusCode code;
+ /// Human-readable message; FromStatus() also appends the status detail's text.
+ std::string message;
+
+ /// \brief Convert a C++ status to an abstract transport status.
+ ///
+ /// If the status carries a FlightStatusDetail, its Flight code determines the
+ /// result. Otherwise some Arrow status codes are mapped directly, and the
+ /// rest become kUnknown. The ToString() of any attached detail is appended
+ /// to the message.
+ static TransportStatus FromStatus(const Status& arrow_status);
+
+ /// \brief Reconstruct a string-encoded TransportStatus.
+ ///
+ /// code_str is the integer value of a TransportStatusCode. If it does not
+ /// parse or is not a known code, the result is kUnknown and a note naming
+ /// code_str is appended to message.
+ static TransportStatus FromCodeStringAndMessage(const std::string& code_str,
+ std::string message);
+
+ /// \brief Convert an abstract transport status to a C++ status.
+ ///
+ /// kOk yields Status::OK(). Every other code yields an error status, and
+ /// several of them also carry a FlightStatusDetail with the matching
+ /// FlightStatusCode.
+ Status ToStatus() const;
+};
+
+/// \brief Convert the string representation of an Arrow status code
+/// back to an Arrow status.
+///
+/// \param[in] code_str integer string form of a StatusCode. If it does not
+///   name a known code, current_status's code is kept and a note naming
+///   code_str is added to the message.
+/// \param[in] current_status supplies the fallback code and message
+/// \param[in] message if present, used instead of current_status.message()
+/// \param[in] detail_message if present, appended as ". Detail: <detail_message>"
+/// \param[in] detail_bin if present, stored as the extra info of `detail`
+///   (a FlightStatusDetail with FlightStatusCode::Internal is created if
+///   `detail` is null)
+/// \param[in] detail FlightStatusDetail to attach to the result; may be null
+ARROW_FLIGHT_EXPORT
+Status ReconstructStatus(const std::string& code_str, const Status& current_status,
+ util::optional<std::string> message,
+ util::optional<std::string> detail_message,
+ util::optional<std::string> detail_bin,
+ std::shared_ptr<FlightStatusDetail> detail);
+
} // namespace internal
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.cc b/cpp/src/arrow/flight/transport/grpc/util_internal.cc
index 5268df160e..0455dc119a 100644
--- a/cpp/src/arrow/flight/transport/grpc/util_internal.cc
+++ b/cpp/src/arrow/flight/transport/grpc/util_internal.cc
@@ -20,7 +20,6 @@
#include <cstdlib>
#include <map>
#include <memory>
-#include <sstream>
#include <string>
#ifdef GRPCPP_PP_INCLUDE
@@ -29,6 +28,7 @@
#include <grpc++/grpc++.h>
#endif
+#include "arrow/flight/transport.h"
#include "arrow/flight/types.h"
#include "arrow/status.h"
@@ -43,110 +43,77 @@ const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin";
const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin";
const char* kBinaryErrorDetailsKey = "grpc-status-details-bin";
-static Status StatusCodeFromString(const ::grpc::string_ref& code_ref, StatusCode* code) {
- // Bounce through std::string to get a proper null-terminated C string
- const auto code_int = std::atoi(std::string(code_ref.data(), code_ref.size()).c_str());
- switch (code_int) {
- case static_cast<int>(StatusCode::OutOfMemory):
- case static_cast<int>(StatusCode::KeyError):
- case static_cast<int>(StatusCode::TypeError):
- case static_cast<int>(StatusCode::Invalid):
- case static_cast<int>(StatusCode::IOError):
- case static_cast<int>(StatusCode::CapacityError):
- case static_cast<int>(StatusCode::IndexError):
- case static_cast<int>(StatusCode::UnknownError):
- case static_cast<int>(StatusCode::NotImplemented):
- case static_cast<int>(StatusCode::SerializationError):
- case static_cast<int>(StatusCode::RError):
- case static_cast<int>(StatusCode::CodeGenError):
- case static_cast<int>(StatusCode::ExpressionValidationError):
- case static_cast<int>(StatusCode::ExecutionError):
- case static_cast<int>(StatusCode::AlreadyExists): {
- *code = static_cast<StatusCode>(code_int);
- return Status::OK();
- }
- default:
- // Code is invalid
- return Status::UnknownError("Unknown Arrow status code", code_ref);
- }
-}
-
/// Try to extract a status from gRPC trailers.
-/// Return Status::OK if found, an error otherwise.
+/// Return true and set *status if the Arrow status code trailer is present;
+/// otherwise return false and leave *status unchanged. current_status supplies
+/// the fallback code and message; flight_status_detail (may be null) is
+/// attached to the reconstructed status.
-static Status FromGrpcContext(const ::grpc::ClientContext& ctx, Status* status,
- std::shared_ptr<FlightStatusDetail> flight_status_detail) {
+static bool FromGrpcContext(const ::grpc::ClientContext& ctx,
+ const Status& current_status, Status* status,
+ std::shared_ptr<FlightStatusDetail> flight_status_detail) {
const std::multimap<::grpc::string_ref, ::grpc::string_ref>& trailers =
ctx.GetServerTrailingMetadata();
- const auto code_val = trailers.find(kGrpcStatusCodeHeader);
- if (code_val == trailers.end()) {
- return Status::IOError("Status code header not found");
- }
- const ::grpc::string_ref code_ref = code_val->second;
- StatusCode code = {};
- RETURN_NOT_OK(StatusCodeFromString(code_ref, &code));
+ const auto code_val = trailers.find(kGrpcStatusCodeHeader);
+ if (code_val == trailers.end()) return false;
const auto message_val = trailers.find(kGrpcStatusMessageHeader);
- if (message_val == trailers.end()) {
- return Status::IOError("Status message header not found");
- }
+ // The optionals below are non-const so the std::move calls at the end
+ // actually move instead of silently copying.
+ util::optional<std::string> message =
+ message_val == trailers.end()
+ ? util::nullopt
+ : util::optional<std::string>(
+ std::string(message_val->second.data(), message_val->second.size()));
- const ::grpc::string_ref message_ref = message_val->second;
- std::string message = std::string(message_ref.data(), message_ref.size());
const auto detail_val = trailers.find(kGrpcStatusDetailHeader);
- if (detail_val != trailers.end()) {
- const ::grpc::string_ref detail_ref = detail_val->second;
- message += ". Detail: ";
- message += std::string(detail_ref.data(), detail_ref.size());
- }
+ util::optional<std::string> detail_message =
+ detail_val == trailers.end()
+ ? util::nullopt
+ : util::optional<std::string>(
+ std::string(detail_val->second.data(), detail_val->second.size()));
+
const auto grpc_detail_val = trailers.find(kBinaryErrorDetailsKey);
- if (grpc_detail_val != trailers.end()) {
- const ::grpc::string_ref detail_ref = grpc_detail_val->second;
- std::string bin_detail = std::string(detail_ref.data(), detail_ref.size());
- if (!flight_status_detail) {
- flight_status_detail =
- std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal);
- }
- flight_status_detail->set_extra_info(bin_detail);
- }
- *status = Status(code, message, flight_status_detail);
- return Status::OK();
+ util::optional<std::string> detail_bin =
+ grpc_detail_val == trailers.end()
+ ? util::nullopt
+ : util::optional<std::string>(std::string(grpc_detail_val->second.data(),
+ grpc_detail_val->second.size()));
+
+ std::string code_str(code_val->second.data(), code_val->second.size());
+ *status = internal::ReconstructStatus(code_str, current_status, std::move(message),
+ std::move(detail_message), std::move(detail_bin),
+ std::move(flight_status_detail));
+ return true;
}
/// Convert a gRPC status to an Arrow status, ignoring any
/// implementation-defined headers that encode further detail.
static Status FromGrpcCode(const ::grpc::Status& grpc_status) {
+ using internal::TransportStatus;
+ using internal::TransportStatusCode;
switch (grpc_status.error_code()) {
case ::grpc::StatusCode::OK:
return Status::OK();
case ::grpc::StatusCode::CANCELLED:
- return Status::IOError("gRPC cancelled call, with message: ",
- grpc_status.error_message())
- .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Cancelled));
- case ::grpc::StatusCode::UNKNOWN: {
- std::stringstream ss;
- ss << "Flight RPC failed with message: " << grpc_status.error_message();
- return Status::UnknownError(ss.str()).WithDetail(
- std::make_shared<FlightStatusDetail>(FlightStatusCode::Failed));
- }
+ return TransportStatus{TransportStatusCode::kCancelled, grpc_status.error_message()}
+ .ToStatus();
+ case ::grpc::StatusCode::UNKNOWN:
+ return TransportStatus{TransportStatusCode::kUnknown, grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::INVALID_ARGUMENT:
- return Status::Invalid("gRPC returned invalid argument error, with message: ",
- grpc_status.error_message());
+ return TransportStatus{TransportStatusCode::kInvalidArgument,
+ grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::DEADLINE_EXCEEDED:
- return Status::IOError("gRPC returned deadline exceeded error, with message: ",
- grpc_status.error_message())
- .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::TimedOut));
+ return TransportStatus{TransportStatusCode::kTimedOut, grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::NOT_FOUND:
- return Status::KeyError("gRPC returned not found error, with message: ",
- grpc_status.error_message());
+ return TransportStatus{TransportStatusCode::kNotFound, grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::ALREADY_EXISTS:
- return Status::AlreadyExists("gRPC returned already exists error, with message: ",
- grpc_status.error_message());
+ return TransportStatus{TransportStatusCode::kAlreadyExists,
+ grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::PERMISSION_DENIED:
- return Status::IOError("gRPC returned permission denied error, with message: ",
- grpc_status.error_message())
- .WithDetail(
- std::make_shared<FlightStatusDetail>(FlightStatusCode::Unauthorized));
+ return TransportStatus{TransportStatusCode::kUnauthorized,
+ grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::RESOURCE_EXHAUSTED:
return Status::Invalid("gRPC returned resource exhausted error, with message: ",
grpc_status.error_message());
@@ -161,26 +128,24 @@ static Status FromGrpcCode(const ::grpc::Status& grpc_status) {
return Status::Invalid("gRPC returned out-of-range error, with message: ",
grpc_status.error_message());
case ::grpc::StatusCode::UNIMPLEMENTED:
- return Status::NotImplemented("gRPC returned unimplemented error, with message: ",
- grpc_status.error_message());
+ return TransportStatus{TransportStatusCode::kUnimplemented,
+ grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::INTERNAL:
- return Status::IOError("gRPC returned internal error, with message: ",
- grpc_status.error_message())
- .WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal));
+ return TransportStatus{TransportStatusCode::kInternal, grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::UNAVAILABLE:
- return Status::IOError("gRPC returned unavailable error, with message: ",
- grpc_status.error_message())
- .WithDetail(
- std::make_shared<FlightStatusDetail>(FlightStatusCode::Unavailable));
+ return TransportStatus{TransportStatusCode::kUnavailable,
+ grpc_status.error_message()}
+ .ToStatus();
case ::grpc::StatusCode::DATA_LOSS:
return Status::IOError("gRPC returned data loss error, with message: ",
grpc_status.error_message())
.WithDetail(std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal));
case ::grpc::StatusCode::UNAUTHENTICATED:
- return Status::IOError("gRPC returned unauthenticated error, with message: ",
- grpc_status.error_message())
- .WithDetail(
- std::make_shared<FlightStatusDetail>(FlightStatusCode::Unauthenticated));
+ return TransportStatus{TransportStatusCode::kUnauthenticated,
+ grpc_status.error_message()}
+ .ToStatus();
default:
return Status::UnknownError("gRPC failed with error code ",
grpc_status.error_code(),
@@ -190,70 +155,67 @@ static Status FromGrpcCode(const ::grpc::Status& grpc_status) {
+/// Convert a gRPC status to an Arrow status. For a failed call with a non-null
+/// `ctx`, prefer the status that FromGrpcContext reconstructs from the
+/// server's trailing metadata; otherwise use FromGrpcCode's mapping alone.
Status FromGrpcStatus(const ::grpc::Status& grpc_status, ::grpc::ClientContext* ctx) {
const Status status = FromGrpcCode(grpc_status);
-
if (!status.ok() && ctx) {
Status arrow_status;
-
- if (!FromGrpcContext(*ctx, &arrow_status, FlightStatusDetail::UnwrapStatus(status))
- .ok()) {
- // If we fail to decode a more detailed status from the headers,
- // proceed normally
- return status;
+ if (FromGrpcContext(*ctx, status, &arrow_status,
+ FlightStatusDetail::UnwrapStatus(status))) {
+ return arrow_status;
}
-
- return arrow_status;
+ // No Arrow status code trailer was found: fall back to the status
+ // derived from the gRPC code alone.
}
return status;
}
/// Convert an Arrow status to a gRPC status.
static ::grpc::Status ToRawGrpcStatus(const Status& arrow_status) {
- if (arrow_status.ok()) {
- return ::grpc::Status::OK;
- }
+ using internal::TransportStatus;
+ using internal::TransportStatusCode;
+ if (arrow_status.ok()) return ::grpc::Status::OK;
+ TransportStatus transport_status = TransportStatus::FromStatus(arrow_status);
::grpc::StatusCode grpc_code = ::grpc::StatusCode::UNKNOWN;
- std::string message = arrow_status.message();
- if (arrow_status.detail()) {
- message += ". Detail: ";
- message += arrow_status.detail()->ToString();
- }
-
- std::shared_ptr<FlightStatusDetail> flight_status =
- FlightStatusDetail::UnwrapStatus(arrow_status);
- if (flight_status) {
- switch (flight_status->code()) {
- case FlightStatusCode::Internal:
- grpc_code = ::grpc::StatusCode::INTERNAL;
- break;
- case FlightStatusCode::TimedOut:
- grpc_code = ::grpc::StatusCode::DEADLINE_EXCEEDED;
- break;
- case FlightStatusCode::Cancelled:
- grpc_code = ::grpc::StatusCode::CANCELLED;
- break;
- case FlightStatusCode::Unauthenticated:
- grpc_code = ::grpc::StatusCode::UNAUTHENTICATED;
- break;
- case FlightStatusCode::Unauthorized:
- grpc_code = ::grpc::StatusCode::PERMISSION_DENIED;
- break;
- case FlightStatusCode::Unavailable:
- grpc_code = ::grpc::StatusCode::UNAVAILABLE;
- break;
- default:
- break;
- }
- } else if (arrow_status.IsNotImplemented()) {
- grpc_code = ::grpc::StatusCode::UNIMPLEMENTED;
- } else if (arrow_status.IsInvalid()) {
- grpc_code = ::grpc::StatusCode::INVALID_ARGUMENT;
- } else if (arrow_status.IsKeyError()) {
- grpc_code = ::grpc::StatusCode::NOT_FOUND;
- } else if (arrow_status.IsAlreadyExists()) {
- grpc_code = ::grpc::StatusCode::ALREADY_EXISTS;
+ switch (transport_status.code) {
+ case TransportStatusCode::kOk:
+ return ::grpc::Status::OK;
+ case TransportStatusCode::kUnknown:
+ grpc_code = ::grpc::StatusCode::UNKNOWN;
+ break;
+ case TransportStatusCode::kInternal:
+ grpc_code = ::grpc::StatusCode::INTERNAL;
+ break;
+ case TransportStatusCode::kInvalidArgument:
+ grpc_code = ::grpc::StatusCode::INVALID_ARGUMENT;
+ break;
+ case TransportStatusCode::kTimedOut:
+ grpc_code = ::grpc::StatusCode::DEADLINE_EXCEEDED;
+ break;
+ case TransportStatusCode::kNotFound:
+ grpc_code = ::grpc::StatusCode::NOT_FOUND;
+ break;
+ case TransportStatusCode::kAlreadyExists:
+ grpc_code = ::grpc::StatusCode::ALREADY_EXISTS;
+ break;
+ case TransportStatusCode::kCancelled:
+ grpc_code = ::grpc::StatusCode::CANCELLED;
+ break;
+ case TransportStatusCode::kUnauthenticated:
+ grpc_code = ::grpc::StatusCode::UNAUTHENTICATED;
+ break;
+ case TransportStatusCode::kUnauthorized:
+ grpc_code = ::grpc::StatusCode::PERMISSION_DENIED;
+ break;
+ case TransportStatusCode::kUnimplemented:
+ grpc_code = ::grpc::StatusCode::UNIMPLEMENTED;
+ break;
+ case TransportStatusCode::kUnavailable:
+ grpc_code = ::grpc::StatusCode::UNAVAILABLE;
+ break;
+ default:
+ grpc_code = ::grpc::StatusCode::UNKNOWN;
+ break;
}
- return ::grpc::Status(grpc_code, message);
+ return ::grpc::Status(grpc_code, std::move(transport_status.message));
}
/// Convert an Arrow status to a gRPC status, and add extra headers to
diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
index 6a580af92f..a29d498d0b 100644
--- a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
+++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
@@ -86,6 +86,12 @@ class UcxCudaDataTest : public CudaDataTest {
};
ARROW_FLIGHT_TEST_CUDA_DATA(UcxCudaDataTest);
+class UcxErrorHandlingTest : public ErrorHandlingTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_ERROR_HANDLING(UcxErrorHandlingTest);
+
//------------------------------------------------------------
// UCX internals tests
@@ -203,43 +209,6 @@ TEST(HeadersFrame, Parse) {
HeadersFrame::Parse(std::move(buffer)));
}
}
-
-TEST(HeadersFrame, RoundTripStatus) {
- for (const auto code : kStatusCodes) {
- {
- Status expected = code == StatusCode::OK ? Status() : Status(code, "foo");
- ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
- Status status;
- ASSERT_OK(headers.GetStatus(&status));
- ASSERT_EQ(status, expected);
- }
-
- if (code == StatusCode::OK) continue;
-
- // Attach a generic status detail
- {
- auto detail = std::make_shared<TestStatusDetail>();
- Status original(code, "foo", detail);
- Status expected(code, "foo",
- std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal,
- detail->ToString()));
- ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
- Status status;
- ASSERT_OK(headers.GetStatus(&status));
- ASSERT_EQ(status, expected);
- }
-
- // Attach a Flight status detail
- for (const auto flight_code : kFlightStatusCodes) {
- Status expected(code, "foo",
- std::make_shared<FlightStatusDetail>(flight_code, "extra"));
- ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
- Status status;
- ASSERT_OK(headers.GetStatus(&status));
- ASSERT_EQ(status, expected);
- }
- }
-}
} // namespace ucx
} // namespace transport
@@ -342,7 +311,9 @@ TEST_F(TestUcx, Errors) {
Status expected(code, "Error message");
server->set_error_status(expected);
Status actual = client_->GetFlightInfo(descriptor).status();
- ASSERT_EQ(actual, expected);
+ ASSERT_EQ(actual.code(), expected.code()) << actual.ToString();
+ ASSERT_THAT(actual.message(), ::testing::HasSubstr("Error message"))
+ << actual.ToString();
// Attach a generic status detail
{
@@ -352,7 +323,10 @@ TEST_F(TestUcx, Errors) {
std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal,
detail->ToString()));
Status actual = client_->GetFlightInfo(descriptor).status();
- ASSERT_EQ(actual, expected);
+ ASSERT_EQ(actual.code(), expected.code()) << actual.ToString();
+ ASSERT_THAT(actual.message(), ::testing::HasSubstr("foo")) << actual.ToString();
+ ASSERT_THAT(actual.message(), ::testing::HasSubstr("Custom status detail"))
+ << actual.ToString();
}
// Attach a Flight status detail
@@ -361,7 +335,9 @@ TEST_F(TestUcx, Errors) {
std::make_shared<FlightStatusDetail>(flight_code, "extra"));
server->set_error_status(expected);
Status actual = client_->GetFlightInfo(descriptor).status();
- ASSERT_EQ(actual, expected);
+ ASSERT_EQ(actual.code(), expected.code()) << actual.ToString();
+ ASSERT_THAT(actual.message(), ::testing::HasSubstr("Error message"))
+ << actual.ToString();
}
}
}
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
index ab4cc323f4..abcf791125 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
@@ -36,6 +36,9 @@ namespace flight {
namespace transport {
namespace ucx {
+using internal::TransportStatus;
+using internal::TransportStatusCode;
+
// Defines to test different implementation strategies
// Enable the CONTIG path for CPU-only data
// #define ARROW_FLIGHT_UCX_SEND_CONTIG
@@ -222,17 +225,19 @@ arrow::Result<HeadersFrame> HeadersFrame::Make(
const Status& status,
const std::vector<std::pair<std::string, std::string>>& headers) {
auto all_headers = headers;
+
+ TransportStatus transport_status = TransportStatus::FromStatus(status);
+ all_headers.emplace_back(kHeaderStatus,
+ std::to_string(static_cast<int32_t>(transport_status.code)));
+ all_headers.emplace_back(kHeaderMessage, std::move(transport_status.message));
all_headers.emplace_back(kHeaderStatusCode,
std::to_string(static_cast<int32_t>(status.code())));
all_headers.emplace_back(kHeaderStatusMessage, status.message());
if (status.detail()) {
+ all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString());
auto fsd = FlightStatusDetail::UnwrapStatus(status);
- if (fsd) {
- all_headers.emplace_back(kHeaderStatusDetailCode,
- std::to_string(static_cast<int32_t>(fsd->code())));
- all_headers.emplace_back(kHeaderStatusDetail, fsd->extra_info());
- } else {
- all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString());
+ if (fsd && !fsd->extra_info().empty()) {
+ all_headers.emplace_back(kHeaderStatusDetailBin, fsd->extra_info());
}
}
return Make(all_headers);
@@ -246,118 +251,46 @@ arrow::Result<util::string_view> HeadersFrame::Get(const std::string& key) {
}
Status HeadersFrame::GetStatus(Status* out) {
+ static const std::string kUnknownMessage = "Server did not send status message header";
util::string_view code_str, message_str;
- auto status = Get(kHeaderStatusCode).Value(&code_str);
+ auto status = Get(kHeaderStatus).Value(&code_str);
if (!status.ok()) {
return Status::KeyError("Server did not send status code header ", kHeaderStatusCode);
}
-
- StatusCode status_code = StatusCode::OK;
- auto code = std::strtol(code_str.data(), nullptr, /*base=*/10);
- switch (code) {
- case 0:
- status_code = StatusCode::OK;
- break;
- case 1:
- status_code = StatusCode::OutOfMemory;
- break;
- case 2:
- status_code = StatusCode::KeyError;
- break;
- case 3:
- status_code = StatusCode::TypeError;
- break;
- case 4:
- status_code = StatusCode::Invalid;
- break;
- case 5:
- status_code = StatusCode::IOError;
- break;
- case 6:
- status_code = StatusCode::CapacityError;
- break;
- case 7:
- status_code = StatusCode::IndexError;
- break;
- case 8:
- status_code = StatusCode::Cancelled;
- break;
- case 9:
- status_code = StatusCode::UnknownError;
- break;
- case 10:
- status_code = StatusCode::NotImplemented;
- break;
- case 11:
- status_code = StatusCode::SerializationError;
- break;
- case 13:
- status_code = StatusCode::RError;
- break;
- case 40:
- status_code = StatusCode::CodeGenError;
- break;
- case 41:
- status_code = StatusCode::ExpressionValidationError;
- break;
- case 42:
- status_code = StatusCode::ExecutionError;
- break;
- case 45:
- status_code = StatusCode::AlreadyExists;
- break;
- default:
- status_code = StatusCode::UnknownError;
- break;
- }
- if (status_code == StatusCode::OK) {
+ if (code_str == "0") { // == std::to_string(TransportStatusCode::kOk)
*out = Status::OK();
return Status::OK();
}
- status = Get(kHeaderStatusMessage).Value(&message_str);
- if (!status.ok()) {
- *out = Status(status_code, "Server did not send status message header", nullptr);
+ status = Get(kHeaderMessage).Value(&message_str);
+ if (!status.ok()) message_str = kUnknownMessage;
+
+ TransportStatus transport_status = TransportStatus::FromCodeStringAndMessage(
+ std::string(code_str), std::string(message_str));
+ if (transport_status.code == TransportStatusCode::kOk) {
+ *out = Status::OK();
return Status::OK();
}
+ *out = transport_status.ToStatus();
- util::string_view detail_code_str, detail_str;
- FlightStatusCode detail_code = FlightStatusCode::Internal;
-
- if (Get(kHeaderStatusDetailCode).Value(&detail_code_str).ok()) {
- auto detail_code_int = std::strtol(detail_code_str.data(), nullptr, /*base=*/10);
- switch (detail_code_int) {
- case 1:
- detail_code = FlightStatusCode::TimedOut;
- break;
- case 2:
- detail_code = FlightStatusCode::Cancelled;
- break;
- case 3:
- detail_code = FlightStatusCode::Unauthenticated;
- break;
- case 4:
- detail_code = FlightStatusCode::Unauthorized;
- break;
- case 5:
- detail_code = FlightStatusCode::Unavailable;
- break;
- case 6:
- detail_code = FlightStatusCode::Failed;
- break;
- case 0:
- default:
- detail_code = FlightStatusCode::Internal;
- break;
- }
+ util::string_view detail_str, bin_str;
+ util::optional<std::string> message, detail_message, detail_bin;
+ if (!Get(kHeaderStatusCode).Value(&code_str).ok()) {
+ // No Arrow status sent, go with the transport status
+ return Status::OK();
}
- ARROW_UNUSED(Get(kHeaderStatusDetail).Value(&detail_str));
-
- std::shared_ptr<StatusDetail> detail = nullptr;
- if (!detail_str.empty()) {
- detail = std::make_shared<FlightStatusDetail>(detail_code, std::string(detail_str));
+ if (Get(kHeaderStatusMessage).Value(&message_str).ok()) {
+ message = std::string(message_str);
+ }
+ if (Get(kHeaderStatusDetail).Value(&detail_str).ok()) {
+ detail_message = std::string(detail_str);
+ }
+ if (Get(kHeaderStatusDetailBin).Value(&bin_str).ok()) {
+ detail_bin = std::string(bin_str);
}
- *out = Status(status_code, std::string(message_str), std::move(detail));
+ *out = internal::ReconstructStatus(std::string(code_str), *out, std::move(message),
+ std::move(detail_message), std::move(detail_bin),
+ FlightStatusDetail::UnwrapStatus(*out));
return Status::OK();
}
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
index bd176e2369..f5b81ab414 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
@@ -50,10 +50,18 @@ static constexpr char kMethodDoGet[] = "DoGet";
static constexpr char kMethodDoPut[] = "DoPut";
static constexpr char kMethodGetFlightInfo[] = "GetFlightInfo";
+/// The header encoding the transport status.
+static constexpr char kHeaderStatus[] = "flight-status";
+/// The header encoding the transport status.
+static constexpr char kHeaderMessage[] = "flight-message";
+/// The header encoding the C++ status.
static constexpr char kHeaderStatusCode[] = "flight-status-code";
+/// The header encoding the C++ status message.
static constexpr char kHeaderStatusMessage[] = "flight-status-message";
+/// The header encoding the C++ status detail message.
static constexpr char kHeaderStatusDetail[] = "flight-status-detail";
-static constexpr char kHeaderStatusDetailCode[] = "flight-status-detail-code";
+/// The header encoding the C++ status detail binary data.
+static constexpr char kHeaderStatusDetailBin[] = "flight-status-detail-bin";
//------------------------------------------------------------
// UCX Helpers
diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc
index 4a169e985c..efc96bb775 100644
--- a/cpp/src/arrow/flight/types.cc
+++ b/cpp/src/arrow/flight/types.cc
@@ -67,6 +67,8 @@ std::string FlightStatusDetail::CodeAsString() const {
return "Unauthorized";
case FlightStatusCode::Unavailable:
return "Unavailable";
+ case FlightStatusCode::Failed:
+ return "Failed";
default:
return "Unknown";
}
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index fa6ef29c07..5821956b29 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -31,7 +31,8 @@ from cython.operator cimport postincrement
from libcpp cimport bool as c_bool
from pyarrow.lib cimport *
-from pyarrow.lib import ArrowException, ArrowInvalid, SignalStopHandler
+from pyarrow.lib import (ArrowCancelled, ArrowException, ArrowInvalid,
+ SignalStopHandler)
from pyarrow.lib import as_buffer, frombytes, tobytes
from pyarrow.includes.libarrow_flight cimport *
from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin
@@ -170,7 +171,7 @@ cdef class FlightTimedOutError(FlightError, ArrowException):
tobytes(str(self)), self.extra_info)
-cdef class FlightCancelledError(FlightError, ArrowException):
+cdef class FlightCancelledError(FlightError, ArrowCancelled):
cdef CStatus to_status(self):
return MakeFlightError(CFlightStatusCancelled, tobytes(str(self)),
self.extra_info)
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 12f815dbea..9c61097251 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -355,17 +355,20 @@ class SlowFlightServer(FlightServerBase):
class ErrorFlightServer(FlightServerBase):
"""A Flight server that uses all the Flight-specific errors."""
+ errors = {
+ "internal": flight.FlightInternalError,
+ "timedout": flight.FlightTimedOutError,
+ "cancel": flight.FlightCancelledError,
+ "unauthenticated": flight.FlightUnauthenticatedError,
+ "unauthorized": flight.FlightUnauthorizedError,
+ "notimplemented": NotImplementedError,
+ "invalid": pa.ArrowInvalid,
+ "key": KeyError,
+ }
+
def do_action(self, context, action):
- if action.type == "internal":
- raise flight.FlightInternalError("foo")
- elif action.type == "timedout":
- raise flight.FlightTimedOutError("foo")
- elif action.type == "cancel":
- raise flight.FlightCancelledError("foo")
- elif action.type == "unauthenticated":
- raise flight.FlightUnauthenticatedError("foo")
- elif action.type == "unauthorized":
- raise flight.FlightUnauthorizedError("foo")
+ if action.type in self.errors:
+ raise self.errors[action.type]("foo")
elif action.type == "protobuf":
err_msg = b'this is an error message'
raise flight.FlightUnauthorizedError("foo", err_msg)
@@ -1561,16 +1564,9 @@ def test_roundtrip_errors():
with ErrorFlightServer() as server, \
FlightClient(('localhost', server.port)) as client:
- with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
- list(client.do_action(flight.Action("internal", b"")))
- with pytest.raises(flight.FlightTimedOutError, match=".*foo.*"):
- list(client.do_action(flight.Action("timedout", b"")))
- with pytest.raises(flight.FlightCancelledError, match=".*foo.*"):
- list(client.do_action(flight.Action("cancel", b"")))
- with pytest.raises(flight.FlightUnauthenticatedError, match=".*foo.*"):
- list(client.do_action(flight.Action("unauthenticated", b"")))
- with pytest.raises(flight.FlightUnauthorizedError, match=".*foo.*"):
- list(client.do_action(flight.Action("unauthorized", b"")))
+ for arg, exc_type in ErrorFlightServer.errors.items():
+ with pytest.raises(exc_type, match=".*foo.*"):
+ list(client.do_action(flight.Action(arg, b"")))
with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
list(client.list_flights())