You are viewing a plain-text version of this content. (The canonical link to it was a hyperlink, which is not preserved in this plain-text version.)
Posted to commits@arrow.apache.org by li...@apache.org on 2022/04/07 12:41:31 UTC
[arrow] branch master updated: ARROW-15706: [C++][FlightRPC] Implement a UCX transport
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 542158fa08 ARROW-15706: [C++][FlightRPC] Implement a UCX transport
542158fa08 is described below
commit 542158fa0847810f375189a36173693d1fe507b8
Author: David Li <li...@gmail.com>
AuthorDate: Thu Apr 7 08:40:13 2022 -0400
ARROW-15706: [C++][FlightRPC] Implement a UCX transport
This PR implements a UCX-based transport for Arrow Flight in C++ based on the transport interfaces added in ARROW-15282. Currently, it supports DoExchange/DoGet/DoPut (i.e., all the data methods) and GetFlightInfo. It supports multiple concurrent calls on a single client, like gRPC (though with caveats; see the documentation) and can run the Flight benchmark.
Closes #12442 from lidavidm/flight-ucx
Authored-by: David Li <li...@gmail.com>
Signed-off-by: David Li <li...@gmail.com>
---
cpp/cmake_modules/DefineOptions.cmake | 4 +
cpp/src/arrow/CMakeLists.txt | 4 +
cpp/src/arrow/flight/CMakeLists.txt | 10 +
cpp/src/arrow/flight/flight_benchmark.cc | 35 +-
cpp/src/arrow/flight/perf_server.cc | 36 +-
cpp/src/arrow/flight/test_definitions.cc | 26 +-
cpp/src/arrow/flight/test_util.h | 2 +-
cpp/src/arrow/flight/transport/ucx/CMakeLists.txt | 77 ++
.../transport/ucx/flight_transport_ucx_test.cc | 386 +++++++
cpp/src/arrow/flight/transport/ucx/ucx.cc | 45 +
cpp/src/arrow/flight/transport/ucx/ucx.h | 35 +
cpp/src/arrow/flight/transport/ucx/ucx_client.cc | 733 ++++++++++++
cpp/src/arrow/flight/transport/ucx/ucx_internal.cc | 1171 ++++++++++++++++++++
cpp/src/arrow/flight/transport/ucx/ucx_internal.h | 354 ++++++
cpp/src/arrow/flight/transport/ucx/ucx_server.cc | 628 +++++++++++
.../arrow/flight/transport/ucx/util_internal.cc | 289 +++++
cpp/src/arrow/flight/transport/ucx/util_internal.h | 83 ++
cpp/src/arrow/flight/transport_server.cc | 5 +-
cpp/src/arrow/util/config.h.cmake | 1 +
docs/source/cpp/flight.rst | 35 +
docs/source/status.rst | 78 +-
21 files changed, 4001 insertions(+), 36 deletions(-)
diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake
index 05fc14bbc7..ec1e0b6352 100644
--- a/cpp/cmake_modules/DefineOptions.cmake
+++ b/cpp/cmake_modules/DefineOptions.cmake
@@ -391,6 +391,10 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
define_option(ARROW_WITH_ZLIB "Build with zlib compression" OFF)
define_option(ARROW_WITH_ZSTD "Build with zstd compression" OFF)
+ define_option(ARROW_WITH_UCX
+ "Build with UCX transport for Arrow Flight;(only used if ARROW_FLIGHT is ON)"
+ OFF)
+
define_option(ARROW_WITH_UTF8PROC
"Build with support for Unicode properties using the utf8proc library;(only used if ARROW_COMPUTE is ON or ARROW_GANDIVA is ON)"
ON)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index e9e826097b..b6f1e2481f 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -747,6 +747,10 @@ endif()
if(ARROW_FLIGHT)
add_subdirectory(flight)
+
+ if(ARROW_WITH_UCX)
+ add_subdirectory(flight/transport/ucx)
+ endif()
endif()
if(ARROW_FLIGHT_SQL)
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 7447e675e0..f9d135654b 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -313,4 +313,14 @@ if(ARROW_BUILD_BENCHMARKS)
add_dependencies(arrow-flight-benchmark arrow-flight-perf-server)
add_dependencies(arrow_flight arrow-flight-benchmark)
+
+ if(ARROW_WITH_UCX)
+ if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
+ target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_static)
+ target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_static)
+ else()
+ target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_shared)
+ target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_shared)
+ endif()
+ endif()
endif(ARROW_BUILD_BENCHMARKS)
diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc
index 872c67c80b..fa0cc9a3d5 100644
--- a/cpp/src/arrow/flight/flight_benchmark.cc
+++ b/cpp/src/arrow/flight/flight_benchmark.cc
@@ -40,12 +40,20 @@
#include "arrow/flight/test_util.h"
#ifdef ARROW_CUDA
+#include <cuda.h>
#include "arrow/gpu/cuda_api.h"
#endif
+#ifdef ARROW_WITH_UCX
+#include "arrow/flight/transport/ucx/ucx.h"
+#endif
DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
- "The network transport to use. Supported: \"grpc\" (default).");
+ "The network transport to use. Supported: \"grpc\" (default)"
+#ifdef ARROW_WITH_UCX
+ ", \"ucx\""
+#endif // ARROW_WITH_UCX
+ ".");
DEFINE_string(server_host, "",
"An existing performance server to benchmark against (leave blank to spawn "
"one automatically)");
@@ -497,6 +505,21 @@ int main(int argc, char** argv) {
options.disable_server_verification = true;
}
}
+ } else if (FLAGS_transport == "ucx") {
+#ifdef ARROW_WITH_UCX
+ arrow::flight::transport::ucx::InitializeFlightUcx();
+ if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
+ std::cerr << "Transport does not support domain sockets: " << FLAGS_transport
+ << std::endl;
+ return EXIT_FAILURE;
+ }
+ ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
+ std::to_string(FLAGS_server_port))
+ .Value(&location));
+#else
+ std::cerr << "Not built with transport: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+#endif
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
@@ -514,6 +537,16 @@ int main(int argc, char** argv) {
ABORT_NOT_OK(arrow::cuda::CudaDeviceManager::Instance().Value(&manager));
ABORT_NOT_OK(manager->GetDevice(0).Value(&device));
call_options.memory_manager = device->default_memory_manager();
+
+ // Needed to prevent UCX warning
+ // cuda_md.c:162 UCX ERROR cuMemGetAddressRange(0x7f2ab5dc0000) error: invalid
+ // device context
+ std::shared_ptr<arrow::cuda::CudaContext> context;
+ ABORT_NOT_OK(device->GetContext().Value(&context));
+ auto cuda_status = cuCtxPushCurrent(reinterpret_cast<CUcontext>(context->handle()));
+ if (cuda_status != CUDA_SUCCESS) {
+ ARROW_LOG(WARNING) << "CUDA error " << cuda_status;
+ }
#else
std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl;
return 1;
diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc
index cc42ffedd6..37e3ec4d77 100644
--- a/cpp/src/arrow/flight/perf_server.cc
+++ b/cpp/src/arrow/flight/perf_server.cc
@@ -19,6 +19,7 @@
#include <signal.h>
#include <cstdint>
+#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
@@ -43,10 +44,17 @@
#ifdef ARROW_CUDA
#include "arrow/gpu/cuda_api.h"
#endif
+#ifdef ARROW_WITH_UCX
+#include "arrow/flight/transport/ucx/ucx.h"
+#endif
DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
- "The network transport to use. Supported: \"grpc\" (default).");
+ "The network transport to use. Supported: \"grpc\" (default)"
+#ifdef ARROW_WITH_UCX
+ ", \"ucx\""
+#endif // ARROW_WITH_UCX
+ ".");
DEFINE_string(server_host, "localhost", "Host where the server is running on");
DEFINE_int32(port, 31337, "Server port to listen on");
DEFINE_string(server_unix, "", "Unix socket path where the server is running on");
@@ -97,7 +105,7 @@ class PerfDataStream : public FlightDataStream {
if (records_sent_ >= total_records_) {
// Signal that iteration is over
payload.ipc_message.metadata = nullptr;
- return Status::OK();
+ return payload;
}
if (verify_) {
@@ -274,6 +282,29 @@ int main(int argc, char** argv) {
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix)
.Value(&connect_location));
}
+ } else if (FLAGS_transport == "ucx") {
+#ifdef ARROW_WITH_UCX
+ arrow::flight::transport::ucx::InitializeFlightUcx();
+ if (FLAGS_server_unix.empty()) {
+ if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
+ std::cerr << "Transport does not support TLS: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+ }
+ ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
+ std::to_string(FLAGS_port))
+ .Value(&bind_location));
+ ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
+ std::to_string(FLAGS_port))
+ .Value(&connect_location));
+ } else {
+ std::cerr << "Transport does not support domain sockets: " << FLAGS_transport
+ << std::endl;
+ return EXIT_FAILURE;
+ }
+#else
+ std::cerr << "Not built with transport: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+#endif
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
@@ -308,6 +339,7 @@ int main(int argc, char** argv) {
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
std::cout << "Server transport: " << FLAGS_transport << std::endl;
+ std::cout << "Server location: " << connect_location.ToString() << std::endl;
if (FLAGS_server_unix.empty()) {
std::cout << "Server host: " << FLAGS_server_host << std::endl;
std::cout << "Server port: " << FLAGS_port << std::endl;
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 2cfac64144..1ec06a1f00 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -45,7 +45,7 @@ using arrow::internal::checked_cast;
void ConnectivityTest::TestGetPort() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
ASSERT_GT(server->port(), 0);
@@ -53,7 +53,7 @@ void ConnectivityTest::TestGetPort() {
void ConnectivityTest::TestBuilderHook() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
bool builder_hook_run = false;
options.builder_hook = [&builder_hook_run](void* builder) {
@@ -68,7 +68,7 @@ void ConnectivityTest::TestBuilderHook() {
void ConnectivityTest::TestShutdown() {
// Regression test for ARROW-15181
constexpr int kIterations = 10;
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
for (int i = 0; i < kIterations; i++) {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
@@ -84,7 +84,7 @@ void ConnectivityTest::TestShutdown() {
void ConnectivityTest::TestShutdownWithDeadline() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
ASSERT_GT(server->port(), 0);
@@ -96,13 +96,13 @@ void ConnectivityTest::TestShutdownWithDeadline() {
}
void ConnectivityTest::TestBrokenConnection() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
std::unique_ptr<FlightClient> client;
ASSERT_OK_AND_ASSIGN(location,
- Location::ForScheme(transport(), "localhost", server->port()));
+ Location::ForScheme(transport(), "127.0.0.1", server->port()));
ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location));
ASSERT_OK(server->Shutdown());
@@ -117,7 +117,7 @@ void ConnectivityTest::TestBrokenConnection() {
void DataTest::SetUp() {
server_ = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server_->Init(options));
@@ -129,7 +129,7 @@ void DataTest::TearDown() {
}
Status DataTest::ConnectClient() {
ARROW_ASSIGN_OR_RAISE(auto location,
- Location::ForScheme(transport(), "localhost", server_->port()));
+ Location::ForScheme(transport(), "127.0.0.1", server_->port()));
ARROW_ASSIGN_OR_RAISE(client_, FlightClient::Connect(location));
return Status::OK();
}
@@ -638,7 +638,7 @@ class DoPutTestServer : public FlightServerBase {
};
void DoPutTest::SetUp() {
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<DoPutTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
@@ -766,7 +766,7 @@ void DoPutTest::TestLargeBatch() {
void DoPutTest::TestSizeLimit() {
const int64_t size_limit = 4096;
ASSERT_OK_AND_ASSIGN(auto location,
- Location::ForScheme(transport(), "localhost", server_->port()));
+ Location::ForScheme(transport(), "127.0.0.1", server_->port()));
auto client_options = FlightClientOptions::Defaults();
client_options.write_size_limit_bytes = size_limit;
ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location, client_options));
@@ -866,7 +866,7 @@ Status AppMetadataTestServer::DoPut(const ServerCallContext& context,
}
void AppMetadataTest::SetUp() {
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<AppMetadataTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
@@ -1045,7 +1045,7 @@ class IpcOptionsTestServer : public FlightServerBase {
};
void IpcOptionsTest::SetUp() {
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<IpcOptionsTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
@@ -1241,7 +1241,7 @@ void CudaDataTest::SetUp() {
impl_->device = std::move(device);
impl_->context = std::move(context);
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<CudaTestServer>(
location, &server_, &client_,
[this](FlightServerOptions* options) {
diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h
index 5320b958d5..d5b774b4a3 100644
--- a/cpp/src/arrow/flight/test_util.h
+++ b/cpp/src/arrow/flight/test_util.h
@@ -113,7 +113,7 @@ Status MakeServer(const Location& location, std::unique_ptr<FlightServerBase>* s
RETURN_NOT_OK(make_server_options(&server_options));
RETURN_NOT_OK((*server)->Init(server_options));
std::string uri =
- location.scheme() + "://localhost:" + std::to_string((*server)->port());
+ location.scheme() + "://127.0.0.1:" + std::to_string((*server)->port());
ARROW_ASSIGN_OR_RAISE(auto real_location, Location::Parse(uri));
FlightClientOptions client_options = FlightClientOptions::Defaults();
RETURN_NOT_OK(make_client_options(&client_options));
diff --git a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt
new file mode 100644
index 0000000000..6e315b68d6
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(arrow_flight_transport_ucx)
+arrow_install_all_headers("arrow/flight/transport/ucx")
+
+find_package(PkgConfig REQUIRED)
+pkg_check_modules(UCX REQUIRED IMPORTED_TARGET ucx)
+
+set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS
+ ucx_client.cc
+ ucx_server.cc
+ ucx.cc
+ ucx_internal.cc
+ util_internal.cc)
+set(ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS)
+
+include_directories(SYSTEM ${UCX_INCLUDE_DIRS})
+list(APPEND ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS PkgConfig::UCX)
+
+add_arrow_lib(arrow_flight_transport_ucx
+ # CMAKE_PACKAGE_NAME
+ # ArrowFlightTransportUcx
+ # PKG_CONFIG_NAME
+ # arrow-flight-transport-ucx
+ SOURCES
+ ${ARROW_FLIGHT_TRANSPORT_UCX_SRCS}
+ # No PRECOMPILED_HEADERS: this directory adds no pch.h, and referencing
+ # a missing header would break builds that enable precompiled headers.
+ DEPENDENCIES
+ SHARED_LINK_FLAGS
+ ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
+ SHARED_LINK_LIBS
+ arrow_shared
+ arrow_flight_shared
+ ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS}
+ STATIC_LINK_LIBS
+ arrow_static
+ arrow_flight_static
+ ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS})
+
+if(ARROW_BUILD_TESTS)
+ if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
+ set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS
+ arrow_static
+ arrow_flight_static
+ arrow_flight_testing_static
+ arrow_flight_transport_ucx_static
+ ${ARROW_TEST_LINK_LIBS})
+ else()
+ set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS
+ arrow_shared
+ arrow_flight_shared
+ arrow_flight_testing_shared
+ arrow_flight_transport_ucx_shared
+ ${ARROW_TEST_LINK_LIBS})
+ endif()
+ add_arrow_test(flight_transport_ucx_test
+ STATIC_LINK_LIBS
+ ${ARROW_FLIGHT_UCX_TEST_LINK_LIBS}
+ LABELS
+ "arrow_flight")
+endif()
diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
new file mode 100644
index 0000000000..6a580af92f
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
@@ -0,0 +1,386 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/flight/test_definitions.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/flight/transport/ucx/ucx.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/config.h"
+
+#ifdef UCP_API_VERSION
+#error "UCX headers should not be in public API"
+#endif
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#ifdef ARROW_CUDA
+#include "arrow/gpu/cuda_api.h"
+#endif
+
+namespace arrow {
+namespace flight {
+
+class UcxEnvironment : public ::testing::Environment {
+ public:
+ void SetUp() override { transport::ucx::InitializeFlightUcx(); }
+};
+
+testing::Environment* const kUcxEnvironment =
+ testing::AddGlobalTestEnvironment(new UcxEnvironment());
+
+//------------------------------------------------------------
+// Common transport tests
+
+class UcxConnectivityTest : public ConnectivityTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_CONNECTIVITY(UcxConnectivityTest);
+
+class UcxDataTest : public DataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_DATA(UcxDataTest);
+
+class UcxDoPutTest : public DoPutTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_DO_PUT(UcxDoPutTest);
+
+class UcxAppMetadataTest : public AppMetadataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_APP_METADATA(UcxAppMetadataTest);
+
+class UcxIpcOptionsTest : public IpcOptionsTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_IPC_OPTIONS(UcxIpcOptionsTest);
+
+class UcxCudaDataTest : public CudaDataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_CUDA_DATA(UcxCudaDataTest);
+
+//------------------------------------------------------------
+// UCX internals tests
+
+constexpr std::initializer_list<StatusCode> kStatusCodes = {
+ StatusCode::OK,
+ StatusCode::OutOfMemory,
+ StatusCode::KeyError,
+ StatusCode::TypeError,
+ StatusCode::Invalid,
+ StatusCode::IOError,
+ StatusCode::CapacityError,
+ StatusCode::IndexError,
+ StatusCode::Cancelled,
+ StatusCode::UnknownError,
+ StatusCode::NotImplemented,
+ StatusCode::SerializationError,
+ StatusCode::RError,
+ StatusCode::CodeGenError,
+ StatusCode::ExpressionValidationError,
+ StatusCode::ExecutionError,
+ StatusCode::AlreadyExists,
+};
+
+constexpr std::initializer_list<FlightStatusCode> kFlightStatusCodes = {
+ FlightStatusCode::Internal, FlightStatusCode::TimedOut,
+ FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated,
+ FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable,
+ FlightStatusCode::Failed,
+};
+
+class TestStatusDetail : public StatusDetail {
+ public:
+ const char* type_id() const override { return "test-status-detail"; }
+ std::string ToString() const override { return "Custom status detail"; }
+};
+
+namespace transport {
+namespace ucx {
+
+static constexpr std::initializer_list<FrameType> kFrameTypes = {
+ FrameType::kHeaders, FrameType::kBuffer, FrameType::kPayloadHeader,
+ FrameType::kPayloadBody, FrameType::kDisconnect,
+};
+
+TEST(FrameHeader, Basics) {
+ for (const auto frame_type : kFrameTypes) {
+ FrameHeader header;
+ ASSERT_OK(header.Set(frame_type, /*counter=*/42, /*body_size=*/65535));
+ if (frame_type == FrameType::kDisconnect) {
+ ASSERT_RAISES(Cancelled, Frame::ParseHeader(header.data(), header.size()));
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto frame, Frame::ParseHeader(header.data(), header.size()));
+ ASSERT_EQ(frame->type, frame_type);
+ ASSERT_EQ(frame->counter, 42);
+ ASSERT_EQ(frame->size, 65535);
+ }
+ }
+}
+
+TEST(FrameHeader, FrameType) {
+ for (const auto frame_type : kFrameTypes) {
+ ASSERT_LE(static_cast<int>(frame_type), static_cast<int>(FrameType::kMaxFrameType));
+ }
+}
+
+TEST(HeadersFrame, Parse) {
+ const char* data =
+ ("\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x03x-foobar"
+ "\x00\x00\x00\x05\x00\x00\x00\x01x-bin\x01");
+ constexpr int64_t size = 34;
+
+ {
+ std::unique_ptr<Buffer> buffer(
+ new Buffer(reinterpret_cast<const uint8_t*>(data), size));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Parse(std::move(buffer)));
+ ASSERT_OK_AND_ASSIGN(auto foo, headers.Get("x-foo"));
+ ASSERT_EQ(foo, "bar");
+ ASSERT_OK_AND_ASSIGN(auto bin, headers.Get("x-bin"));
+ ASSERT_EQ(bin, "\x01");
+ }
+ {
+ std::unique_ptr<Buffer> buffer(new Buffer(reinterpret_cast<const uint8_t*>(data), 3));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected number of headers"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr<Buffer> buffer(new Buffer(reinterpret_cast<const uint8_t*>(data), 7));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected length of key 1"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr<Buffer> buffer(
+ new Buffer(reinterpret_cast<const uint8_t*>(data), 10));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected length of value 1"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr<Buffer> buffer(
+ new Buffer(reinterpret_cast<const uint8_t*>(data), 12));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("expected key 1 to have length 5, but only 0 bytes remain"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr<Buffer> buffer(
+ new Buffer(reinterpret_cast<const uint8_t*>(data), 17));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "expected value 1 to have length 3, but only 0 bytes remain"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+}
+
+TEST(HeadersFrame, RoundTripStatus) {
+ for (const auto code : kStatusCodes) {
+ {
+ Status expected = code == StatusCode::OK ? Status() : Status(code, "foo");
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+
+ if (code == StatusCode::OK) continue;
+
+ // A non-Flight detail must come back as FlightStatusDetail(Internal, ToString())
+ {
+ auto detail = std::make_shared<TestStatusDetail>();
+ Status original(code, "foo", detail);
+ Status expected(code, "foo",
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal,
+ detail->ToString()));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(original, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+
+ // Attach a Flight status detail
+ for (const auto flight_code : kFlightStatusCodes) {
+ Status expected(code, "foo",
+ std::make_shared<FlightStatusDetail>(flight_code, "extra"));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+ }
+}
+} // namespace ucx
+} // namespace transport
+
+//------------------------------------------------------------
+// Ad-hoc UCX-specific tests
+
+class SimpleTestServer : public FlightServerBase {
+ public:
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* info) override {
+ if (request.path.size() > 0 && request.path[0] == "error") {
+ return status_;
+ }
+ auto examples = ExampleFlightInfo();
+ info->reset(new FlightInfo(examples[0]));
+ return Status::OK();
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ auto batch_reader = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
+ return Status::OK();
+ }
+
+ void set_error_status(Status st) { status_ = std::move(st); }
+
+ private:
+ Status status_;
+};
+
+class TestUcx : public ::testing::Test {
+ public:
+ void SetUp() override {
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "127.0.0.1", 0));
+ ASSERT_OK(MakeServer<SimpleTestServer>(
+ location, &server_, &client_,
+ [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+ // gtest runs TearDown even after a failed SetUp, so members may be null here
+ void TearDown() override {
+ if (client_) ASSERT_OK(client_->Close());
+ if (server_) ASSERT_OK(server_->Shutdown());
+ }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+TEST_F(TestUcx, GetFlightInfo) {
+ auto descriptor = FlightDescriptor::Path({"foo", "bar"});
+ std::unique_ptr<FlightInfo> info;
+ ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor));
+ // Test that we can reuse the connection
+ ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor));
+}
+
+TEST_F(TestUcx, SequentialClients) {
+ ASSERT_OK_AND_ASSIGN(
+ auto client2,
+ FlightClient::Connect(server_->location(), FlightClientOptions::Defaults()));
+
+ Ticket ticket{"a"};
+
+ ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket));
+ ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable());
+
+ ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket));
+ ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable());
+
+ AssertTablesEqual(*table1, *table2);
+}
+
+TEST_F(TestUcx, ConcurrentClients) {
+ ASSERT_OK_AND_ASSIGN(
+ auto client2,
+ FlightClient::Connect(server_->location(), FlightClientOptions::Defaults()));
+
+ Ticket ticket{"a"};
+
+ ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket));
+ ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket));
+
+ ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable());
+ ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable());
+
+ AssertTablesEqual(*table1, *table2);
+}
+
+TEST_F(TestUcx, Errors) {
+ auto descriptor = FlightDescriptor::Path({"error", "bar"});
+ auto* server = static_cast<SimpleTestServer*>(server_.get()); // made in SetUp
+ for (const auto code : kStatusCodes) {
+ if (code == StatusCode::OK) continue;
+
+ Status expected(code, "Error message");
+ server->set_error_status(expected);
+ Status actual = client_->GetFlightInfo(descriptor).status();
+ ASSERT_EQ(actual, expected);
+
+ // A non-Flight detail must arrive as FlightStatusDetail(Internal, ToString())
+ {
+ auto detail = std::make_shared<TestStatusDetail>();
+ server->set_error_status(Status(code, "foo", detail));
+ Status expected(code, "foo",
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal,
+ detail->ToString()));
+ Status actual = client_->GetFlightInfo(descriptor).status();
+ ASSERT_EQ(actual, expected);
+ }
+
+ // Attach a Flight status detail
+ for (const auto flight_code : kFlightStatusCodes) {
+ Status expected(code, "Error message",
+ std::make_shared<FlightStatusDetail>(flight_code, "extra"));
+ server->set_error_status(expected);
+ Status actual = client_->GetFlightInfo(descriptor).status();
+ ASSERT_EQ(actual, expected);
+ }
+ }
+}
+
+TEST(TestUcxIpV6, DISABLED_IpV6Port) {
+ // Round trip over an IPv6 literal; DISABLED_ by default as CI machines lack IPv6
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "[::1]", 0));
+
+ std::unique_ptr<FlightServerBase> server(new SimpleTestServer());
+ FlightServerOptions server_options(location);
+ ASSERT_OK(server->Init(server_options));
+
+ FlightClientOptions client_options = FlightClientOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(auto client,
+ FlightClient::Connect(server->location(), client_options));
+ // NOTE(review): client is never Close()d and server never Shutdown() explicitly
+ auto descriptor = FlightDescriptor::Path({"foo", "bar"});
+ ASSERT_OK_AND_ASSIGN(auto info, client->GetFlightInfo(descriptor));
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.cc b/cpp/src/arrow/flight/transport/ucx/ucx.cc
new file mode 100644
index 0000000000..0e3daf6021
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx.cc
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx.h"
+
+#include <mutex>
+
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+#include "arrow/flight/transport_server.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+namespace {
+// Guards the one-time registration performed by InitializeFlightUcx().
+std::once_flag ucx_registration_once;
+
+// Register the UCX client and server factories under the "ucx" scheme.
+void RegisterUcxTransports() {
+  auto* registry = flight::internal::GetDefaultTransportRegistry();
+  DCHECK_OK(registry->RegisterClient("ucx", MakeUcxClientImpl));
+  DCHECK_OK(registry->RegisterServer("ucx", MakeUcxServerImpl));
+}
+}  // namespace
+
+void InitializeFlightUcx() {
+  // Only the first call performs the registration.
+  std::call_once(ucx_registration_once, RegisterUcxTransports);
+}
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.h b/cpp/src/arrow/flight/transport/ucx/ucx.h
new file mode 100644
index 0000000000..dda2c83035
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Experimental UCX-based transport for Flight.
+
+#pragma once
+
+#include "arrow/flight/visibility.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+/// \brief Register the experimental UCX transport with Flight.
+///
+/// Registers the UCX client and server implementations with Flight's default
+/// transport registry under the "ucx" scheme. Only the first call has any
+/// effect; later calls do nothing.
+/// NOTE(review): presumably this must be called before connecting to or
+/// serving on a "ucx" Location - confirm.
+ARROW_FLIGHT_EXPORT
+void InitializeFlightUcx();
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
new file mode 100644
index 0000000000..173132062e
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
@@ -0,0 +1,733 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// The client-side implementation of a UCX-based transport for
+/// Flight.
+///
+/// Each UCX driver is used to support one call at a time. This gives
+/// the greatest throughput for data plane methods, but is relatively
+/// expensive in terms of other resources, both for the server and the
+/// client. (UCX drivers have multiple threading modes: single-thread
+/// access, serialized access, and multi-thread access. Testing found
+/// that multi-thread access incurred high synchronization costs.)
+/// Hence, for concurrent calls in a single client, we must maintain
+/// multiple drivers, and so unlike gRPC, there is no real difference
+/// between using one client concurrently and using multiple
+/// independent clients.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <thread>
+
+#include <arpa/inet.h>
+#include <ucp/api/ucp.h>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+namespace {
+class UcxClientImpl;
+
+/// \brief Combine the status reported by the server with the local transport status.
+///
+/// If at most one of the two is an error, that one (or OK) is returned. If
+/// both failed, the server's code, detail and message are kept and the
+/// transport error is appended to the message as context.
+Status MergeStatuses(Status server_status, Status transport_status) {
+  if (transport_status.ok()) return server_status;
+  if (server_status.ok()) return transport_status;
+  return Status::FromDetailAndArgs(server_status.code(), server_status.detail(),
+                                   server_status.message(),
+                                   ". Transport context: ", transport_status.ToString());
+}
+
+/// \brief An individual connection to the server.
+///
+/// Owns a UCP worker, an endpoint to the remote server, and the UcpCallDriver
+/// layered on top of them; each driver serves one call at a time. An
+/// initialized connection must be Close()d before destruction (checked by a
+/// DCHECK). A moved-from or failed-to-initialize connection holds no driver.
+class ClientConnection {
+ public:
+  ClientConnection() = default;
+  ARROW_DISALLOW_COPY_AND_ASSIGN(ClientConnection);
+  ARROW_DEFAULT_MOVE_AND_ASSIGN(ClientConnection);
+  ~ClientConnection() { DCHECK(!driver_) << "Connection was not closed!"; }
+
+  /// \brief Create a worker and an endpoint connected to `uri`.
+  ///
+  /// On failure, everything created so far is torn down again, so the
+  /// connection may be destroyed without calling Close().
+  Status Init(std::shared_ptr<UcpContext> ucp_context, const arrow::internal::Uri& uri) {
+    auto status = InitImpl(std::move(ucp_context), uri);
+    if (!status.ok()) {
+      // Clean up after-the-fact if we fail to initialize
+      if (driver_) {
+        status = MergeStatuses(std::move(status), driver_->Close());
+        driver_.reset();
+      }
+      remote_endpoint_ = nullptr;
+      ucp_worker_.reset();
+    }
+    return status;
+  }
+
+  /// \brief Create the worker, the endpoint and the driver, and install the
+  /// Active Message handler that forwards incoming messages to the driver.
+  Status InitImpl(std::shared_ptr<UcpContext> ucp_context,
+                  const arrow::internal::Uri& uri) {
+    {
+      ucs_status_t status;
+      ucp_worker_params_t worker_params;
+      std::memset(&worker_params, 0, sizeof(worker_params));
+      worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
+      worker_params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
+
+      ucp_worker_h ucp_worker;
+      status = ucp_worker_create(ucp_context->get(), &worker_params, &ucp_worker);
+      RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status));
+      ucp_worker_.reset(new UcpWorker(std::move(ucp_context), ucp_worker));
+    }
+    {
+      // Create endpoint for remote worker
+      struct sockaddr_storage connect_addr;
+      ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &connect_addr));
+      std::string peer;
+      ARROW_UNUSED(SockaddrToString(connect_addr).Value(&peer));
+      ARROW_LOG(DEBUG) << "Connecting to " << peer;
+
+      ucp_ep_params_t params;
+      // Zero the struct (as for worker_params) so no field holds an
+      // indeterminate value; field_mask selects the fields that are read.
+      std::memset(&params, 0, sizeof(params));
+      params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_NAME |
+                          UCP_EP_PARAM_FIELD_SOCK_ADDR;
+      params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER;
+      params.name = "UcxClientImpl";
+      params.sockaddr.addr = reinterpret_cast<const sockaddr*>(&connect_addr);
+      params.sockaddr.addrlen = addrlen;
+
+      auto status = ucp_ep_create(ucp_worker_->get(), &params, &remote_endpoint_);
+      RETURN_NOT_OK(FromUcsStatus("ucp_ep_create", status));
+    }
+
+    driver_.reset(new UcpCallDriver(ucp_worker_, remote_endpoint_));
+    ARROW_LOG(DEBUG) << "Connected to " << driver_->peer();
+
+    {
+      // Set up Active Message (AM) handler
+      ucp_am_handler_param_t handler_params;
+      std::memset(&handler_params, 0, sizeof(handler_params));
+      handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID |
+                                  UCP_AM_HANDLER_PARAM_FIELD_CB |
+                                  UCP_AM_HANDLER_PARAM_FIELD_ARG;
+      handler_params.id = kUcpAmHandlerId;
+      handler_params.cb = HandleIncomingActiveMessage;
+      // The driver is passed as the opaque callback argument.
+      handler_params.arg = driver_.get();
+      ucs_status_t status =
+          ucp_worker_set_am_recv_handler(ucp_worker_->get(), &handler_params);
+      RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status));
+    }
+
+    return Status::OK();
+  }
+
+  /// \brief Send a disconnect frame, close the driver, and release the
+  /// endpoint and worker. A no-op if there is no driver.
+  Status Close() {
+    if (!driver_) return Status::OK();
+
+    auto status = driver_->SendFrame(FrameType::kDisconnect, nullptr, 0);
+    // Some disconnect errors are deemed ignorable (see IsIgnorableDisconnectError).
+    const auto ucs_status = FlightUcxStatusDetail::Unwrap(status);
+    if (IsIgnorableDisconnectError(ucs_status)) {
+      status = Status::OK();
+    }
+    status = MergeStatuses(std::move(status), driver_->Close());
+
+    driver_.reset();
+    remote_endpoint_ = nullptr;
+    ucp_worker_.reset();
+    return status;
+  }
+
+  /// \brief The driver for this connection; only valid after a successful Init().
+  UcpCallDriver* driver() {
+    DCHECK(driver_);
+    return driver_.get();
+  }
+
+ private:
+  // Active Message callback: `self` is the UcpCallDriver registered in InitImpl().
+  static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header,
+                                                  size_t header_length, void* data,
+                                                  size_t data_length,
+                                                  const ucp_am_recv_param_t* param) {
+    auto* driver = reinterpret_cast<UcpCallDriver*>(self);
+    return driver->RecvActiveMessage(header, header_length, data, data_length, param);
+  }
+
+  std::shared_ptr<UcpWorker> ucp_worker_;
+  ucp_ep_h remote_endpoint_ = nullptr;
+  std::unique_ptr<UcpCallDriver> driver_;
+};
+
+/// \brief Common state for the client side of a UCX data stream.
+///
+/// The stream holds a connection checked out from a UcxClientImpl for the
+/// duration of the call; DoFinish() hands it back to that client.
+class UcxClientStream : public internal::ClientDataStream {
+ public:
+  UcxClientStream(UcxClientImpl* impl, ClientConnection conn)
+      : impl_(impl), conn_(std::move(conn)), driver_(conn_.driver()) {}
+
+ protected:
+  Status DoFinish() override;
+
+  // Client that lent us the connection; cleared once it has been returned.
+  UcxClientImpl* impl_;
+  ClientConnection conn_;
+  // Driver owned by conn_, cached for convenience.
+  UcpCallDriver* driver_;
+  // Set once the client will write no more data.
+  bool writes_done_ = false;
+  // Set once the stream is over (trailers received or an error occurred).
+  bool finished_ = false;
+  // Local (transport/protocol) error, if any.
+  Status io_status_;
+  // Status reported by the server in its trailers.
+  Status server_status_;
+};
+
+/// \brief Client stream for DoGet: read-only, so writes are done from the start.
+class GetClientStream : public UcxClientStream {
+ public:
+  GetClientStream(UcxClientImpl* impl, ClientConnection conn)
+      : UcxClientStream(impl, std::move(conn)) {
+    writes_done_ = true;
+  }
+
+  /// \brief Read the next message into `data`.
+  ///
+  /// Returns false at the end of the stream or on error (the error is kept in
+  /// io_status_); further calls then keep returning false.
+  bool ReadData(internal::FlightData* data) override {
+    if (finished_) return false;
+
+    // Must start out false: Result::Value(&out) does not assign `out` on
+    // error, and a failed read must not be reported as a successful one.
+    bool success = false;
+    io_status_ = ReadImpl(data).Value(&success);
+
+    if (!io_status_.ok() || !success) {
+      finished_ = true;
+    }
+    return success;
+  }
+
+ private:
+  // Read one payload (header frame plus optional body frame). Returns false
+  // when trailers arrive instead, after recording the server's status.
+  ::arrow::Result<bool> ReadImpl(internal::FlightData* data) {
+    ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame());
+
+    if (frame->type == FrameType::kHeaders) {
+      // Trailers, stream is over
+      ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer)));
+      RETURN_NOT_OK(headers.GetStatus(&server_status_));
+      return false;
+    }
+
+    RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader));
+    PayloadHeaderFrame payload_header(std::move(frame->buffer));
+    RETURN_NOT_OK(payload_header.ToFlightData(data));
+
+    // DoGet does not support metadata-only messages, so we can always
+    // assume we have an IPC payload
+    ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr));
+
+    if (ipc::Message::HasBody(message->type())) {
+      ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame());
+      RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody));
+      data->body = std::move(frame->buffer);
+    }
+    return true;
+  }
+};
+
+// Base class for client streams that write (DoPut/DoExchange).
+//
+// All UCX work for the call runs on a dedicated worker thread (DriveWorker)
+// started in the constructor. Caller threads publish a pending read
+// (incoming_) or write (outgoing_) future under driver_mutex_, wake the
+// worker via working_cv_, and block on completed_cv_ until the worker reports
+// that the future completed. The worker exits after an iteration in which
+// finished_ became set.
+// NOTE(review): driver_thread_ is only joined via JoinThread(), which the
+// subclasses call once they observe finished_; no destructor joins it, so
+// destroying a stream whose thread is still joinable would presumably call
+// std::terminate - confirm that streams are always finished first.
+class WriteClientStream : public UcxClientStream {
+ public:
+ WriteClientStream(UcxClientImpl* impl, ClientConnection conn)
+ : UcxClientStream(impl, std::move(conn)) {
+ std::thread t(&WriteClientStream::DriveWorker, this);
+ driver_thread_.swap(t);
+ }
+ // Send one payload and block until the worker sees the send complete.
+ // Fails with Invalid once the stream is finished or writes are done.
+ arrow::Result<bool> WriteData(const FlightPayload& payload) override {
+ std::unique_lock<std::mutex> guard(driver_mutex_);
+ if (finished_ || writes_done_) return Status::Invalid("Already done writing");
+ outgoing_ = driver_->SendFlightPayload(payload);
+ working_cv_.notify_all();
+ completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
+
+ auto status = outgoing_.status();
+ // Clear the slot so the worker stops treating it as pending.
+ outgoing_ = Future<>();
+ RETURN_NOT_OK(status);
+ return true;
+ }
+ // Send an empty headers frame to mark the end of client writes. Only the
+ // first call sends anything; writes_done_ is set even if that send fails.
+ // NOTE(review): unlike WriteData, this does not check finished_; if the
+ // worker thread has already exited, the wait below relies on the send
+ // future completing without MakeProgress() - confirm this cannot hang.
+ Status WritesDone() override {
+ std::unique_lock<std::mutex> guard(driver_mutex_);
+ if (!writes_done_) {
+ ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make({}));
+ outgoing_ =
+ driver_->SendFrameAsync(FrameType::kHeaders, std::move(headers).GetBuffer());
+ working_cv_.notify_all();
+ completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
+
+ writes_done_ = true;
+ auto status = outgoing_.status();
+ outgoing_ = Future<>();
+ RETURN_NOT_OK(status);
+ }
+ return Status::OK();
+ }
+
+ protected:
+ // Join the worker thread; std::system_error (e.g. when the thread is not
+ // joinable because it was already joined) is swallowed.
+ void JoinThread() {
+ try {
+ driver_thread_.join();
+ } catch (const std::system_error&) {
+ // Ignore
+ }
+ }
+ // Flight's API allows concurrent reads/writes, but the UCX driver
+ // here is single-threaded, so push all UCX work onto a single
+ // worker thread
+ void DriveWorker() {
+ while (true) {
+ // Sleep until a caller has published a pending read or write.
+ {
+ std::unique_lock<std::mutex> guard(driver_mutex_);
+ working_cv_.wait(guard,
+ [this] { return incoming_.is_valid() || outgoing_.is_valid(); });
+ }
+
+ // Pump the driver until a pending future completes. The lock is
+ // re-acquired on every iteration so callers can interleave.
+ while (true) {
+ std::unique_lock<std::mutex> guard(driver_mutex_);
+ if (!incoming_.is_valid() && !outgoing_.is_valid()) break;
+ if (incoming_.is_valid() && incoming_.is_finished()) {
+ if (!incoming_.status().ok()) {
+ // A failed read ends the stream.
+ io_status_ = incoming_.status();
+ finished_ = true;
+ } else {
+ HandleIncomingMessage(*incoming_.result());
+ }
+ incoming_ = Future<std::shared_ptr<Frame>>();
+ completed_cv_.notify_all();
+ break;
+ }
+ if (outgoing_.is_valid() && outgoing_.is_finished()) {
+ // The writer resets outgoing_ itself after waking up.
+ completed_cv_.notify_all();
+ break;
+ }
+ driver_->MakeProgress();
+ }
+ // finished_ is only written on this thread while it runs.
+ if (finished_) return;
+ }
+ }
+
+ // Hook for subclasses to consume a received frame. Called by DriveWorker()
+ // with driver_mutex_ held; the default ignores the frame.
+ virtual void HandleIncomingMessage(const std::shared_ptr<Frame>& frame) {}
+
+ // Protects all stream state shared between caller threads and the worker.
+ std::mutex driver_mutex_;
+ std::thread driver_thread_;
+ // Signalled by the worker when a pending future completes.
+ std::condition_variable completed_cv_;
+ // Signalled by callers when they publish a pending read or write.
+ std::condition_variable working_cv_;
+ // Pending read, if valid.
+ Future<std::shared_ptr<Frame>> incoming_;
+ // Pending write, if valid.
+ Future<> outgoing_;
+};
+
+/// \brief Client stream for DoPut: writes data and reads back application
+/// metadata buffers sent by the server.
+class PutClientStream : public WriteClientStream {
+ public:
+  using WriteClientStream::WriteClientStream;
+
+  /// \brief Block until the server sends a metadata buffer or the stream ends.
+  ///
+  /// Once the stream is finished, sets *out to null, joins the worker thread
+  /// and returns false.
+  bool ReadPutMetadata(std::shared_ptr<Buffer>* out) override {
+    std::unique_lock<std::mutex> guard(driver_mutex_);
+    if (!finished_) {
+      // Hand a read to the worker and wait for a metadata buffer or the end.
+      next_metadata_ = nullptr;
+      incoming_ = driver_->ReadFrameAsync();
+      working_cv_.notify_all();
+      completed_cv_.wait(guard, [this] { return next_metadata_ != nullptr || finished_; });
+      if (!finished_) {
+        *out = std::move(next_metadata_);
+        return true;
+      }
+    }
+    *out = nullptr;
+    guard.unlock();
+    JoinThread();
+    return false;
+  }
+
+ private:
+  void HandleIncomingMessage(const std::shared_ptr<Frame>& frame) override {
+    // Called from DriveWorker(), which already holds driver_mutex_.
+    switch (frame->type) {
+      case FrameType::kBuffer:
+        next_metadata_ = std::move(frame->buffer);
+        break;
+      case FrameType::kHeaders: {
+        // Trailers: the stream is over; record the server's status.
+        finished_ = true;
+        HeadersFrame headers;
+        io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers);
+        if (io_status_.ok()) {
+          io_status_ = headers.GetStatus(&server_status_);
+        }
+        break;
+      }
+      default:
+        finished_ = true;
+        io_status_ =
+            Status::IOError("Unexpected frame type ", static_cast<int>(frame->type));
+        break;
+    }
+  }
+
+  // Latest metadata buffer received from the server, handed out by
+  // ReadPutMetadata().
+  std::shared_ptr<Buffer> next_metadata_;
+};
+
+/// \brief Client stream for DoExchange: full-duplex reads and writes of FlightData.
+///
+/// A read consumes a payload header frame and, if the IPC message it carries
+/// has a body, a following payload body frame. A headers frame (trailers)
+/// ends the stream.
+class ExchangeClientStream : public WriteClientStream {
+ public:
+  ExchangeClientStream(UcxClientImpl* impl, ClientConnection conn)
+      : WriteClientStream(impl, std::move(conn)), read_state_(ReadState::kFinished) {}
+
+  /// \brief Read the next message into `data`.
+  ///
+  /// Returns false once the stream is finished (end of stream or error),
+  /// after joining the worker thread.
+  bool ReadData(internal::FlightData* data) override {
+    std::unique_lock<std::mutex> guard(driver_mutex_);
+    if (finished_) {
+      guard.unlock();
+      JoinThread();
+      return false;
+    }
+
+    // Drive the read loop here. (We can't recursively call
+    // ReadFrameAsync below since the internal mutex is not
+    // recursive.)
+    //
+    // Each wait also checks finished_: when a read fails or a frame is
+    // rejected, the worker sets finished_ and exits, so waiting on
+    // read_state_ alone could block forever.
+    read_state_ = ReadState::kExpectHeader;
+    incoming_ = driver_->ReadFrameAsync();
+    working_cv_.notify_all();
+    completed_cv_.wait(guard, [this] {
+      return finished_ || read_state_ != ReadState::kExpectHeader;
+    });
+    if (!finished_ && read_state_ != ReadState::kFinished) {
+      // The payload header announced a body; read it as well.
+      incoming_ = driver_->ReadFrameAsync();
+      working_cv_.notify_all();
+      completed_cv_.wait(guard, [this] {
+        return finished_ || read_state_ == ReadState::kFinished;
+      });
+    }
+
+    if (finished_) {
+      guard.unlock();
+      JoinThread();
+      return false;
+    }
+    *data = std::move(next_data_);
+    return true;
+  }
+
+ private:
+  enum class ReadState {
+    kFinished,
+    kExpectHeader,
+    kExpectBody,
+  };
+
+  std::string DebugExpectingString() {
+    switch (read_state_) {
+      case ReadState::kFinished:
+        return "(not expecting a frame)";
+      case ReadState::kExpectHeader:
+        return "payload header frame";
+      case ReadState::kExpectBody:
+        return "payload body frame";
+    }
+    return "(unknown or invalid state)";
+  }
+
+  // Record a terminal error: the stream is over and no more frames are read.
+  void SetFailed(Status status) {
+    io_status_ = std::move(status);
+    finished_ = true;
+    read_state_ = ReadState::kFinished;
+  }
+
+  void HandleIncomingMessage(const std::shared_ptr<Frame>& frame) override {
+    // No lock here, since this is called from DriveWorker(), which
+    // already holds driver_mutex_
+    if (frame->type == FrameType::kPayloadHeader) {
+      if (read_state_ != ReadState::kExpectHeader) {
+        SetFailed(Status::IOError("Got unexpected payload header frame, expected: ",
+                                  DebugExpectingString()));
+        return;
+      }
+
+      PayloadHeaderFrame payload_header(std::move(frame->buffer));
+      Status status = payload_header.ToFlightData(&next_data_);
+      if (!status.ok()) {
+        SetFailed(std::move(status));
+        return;
+      }
+
+      if (next_data_.metadata) {
+        std::unique_ptr<ipc::Message> message;
+        status = ipc::Message::Open(next_data_.metadata, nullptr).Value(&message);
+        if (!status.ok()) {
+          SetFailed(std::move(status));
+          return;
+        }
+        if (ipc::Message::HasBody(message->type())) {
+          read_state_ = ReadState::kExpectBody;
+          return;
+        }
+      }
+      read_state_ = ReadState::kFinished;
+    } else if (frame->type == FrameType::kPayloadBody) {
+      // A body is only valid right after a header that announced one.
+      if (read_state_ != ReadState::kExpectBody) {
+        SetFailed(Status::IOError("Got unexpected payload body frame, expected: ",
+                                  DebugExpectingString()));
+        return;
+      }
+      next_data_.body = std::move(frame->buffer);
+      read_state_ = ReadState::kFinished;
+    } else if (frame->type == FrameType::kHeaders) {
+      // Trailers, stream is over
+      finished_ = true;
+      read_state_ = ReadState::kFinished;
+      HeadersFrame headers;
+      io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers);
+      if (!io_status_.ok()) return;
+      io_status_ = headers.GetStatus(&server_status_);
+    } else {
+      SetFailed(
+          Status::IOError("Unexpected frame type ", static_cast<int>(frame->type)));
+    }
+  }
+
+  // Message being assembled by HandleIncomingMessage(), handed out by ReadData().
+  internal::FlightData next_data_;
+  // Which frame the current read expects next.
+  ReadState read_state_;
+};
+
+/// \brief Flight client transport over UCX.
+///
+/// Keeps a pool of ClientConnections guarded by connections_mutex_. Each call
+/// checks out a connection (opening a new one if the pool is empty) and
+/// returns it when done. At most kMaxOpenConnections idle connections are
+/// kept; any extra returned connection is closed.
+class UcxClientImpl : public arrow::flight::internal::ClientTransport {
+ public:
+  UcxClientImpl() {}
+
+  virtual ~UcxClientImpl() {
+    if (!ucp_context_) return;
+    auto status = Close();
+    if (!status.ok()) {
+      ARROW_LOG(WARNING) << "UcxClientImpl errored in Close() in destructor: "
+                         << status.ToString();
+    }
+  }
+
+  /// \brief Create the UCP context for `uri` and open one initial connection.
+  Status Init(const FlightClientOptions& options, const Location& location,
+              const arrow::internal::Uri& uri) override {
+    RETURN_NOT_OK(uri_.Parse(uri.ToString()));
+    {
+      ucp_config_t* ucp_config;
+      ucp_params_t ucp_params;
+      ucs_status_t status;
+
+      status = ucp_config_read(nullptr, nullptr, &ucp_config);
+      RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status));
+      // Release the config on every exit path, including the early error
+      // returns below.
+      std::unique_ptr<ucp_config_t, decltype(&ucp_config_release)> config_guard(
+          ucp_config, &ucp_config_release);
+
+      // If location is IPv6, must adjust UCX config
+      // XXX: we assume locations always resolve to IPv6 or IPv4 but
+      // that is not necessarily true.
+      {
+        struct sockaddr_storage connect_addr;
+        RETURN_NOT_OK(UriToSockaddr(uri, &connect_addr));
+        if (connect_addr.ss_family == AF_INET6) {
+          status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6");
+          RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status));
+        }
+      }
+
+      std::memset(&ucp_params, 0, sizeof(ucp_params));
+      ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES;
+      ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP;
+
+      ucp_context_h ucp_context;
+      status = ucp_init(&ucp_params, ucp_config, &ucp_context);
+      config_guard.reset();
+      RETURN_NOT_OK(FromUcsStatus("ucp_init", status));
+      ucp_context_.reset(new UcpContext(ucp_context));
+    }
+
+    RETURN_NOT_OK(MakeConnection());
+    return Status::OK();
+  }
+
+  /// \brief Close every pooled connection.
+  ///
+  /// All connections are closed even if some fail, so that none is destroyed
+  /// while still open. Errors are combined with MergeStatuses().
+  Status Close() override {
+    std::lock_guard<std::mutex> guard(connections_mutex_);
+    Status status;
+    while (!connections_.empty()) {
+      status = MergeStatuses(std::move(status), connections_.front().Close());
+      connections_.pop_front();
+    }
+    return status;
+  }
+
+  Status GetFlightInfo(const FlightCallOptions& options,
+                       const FlightDescriptor& descriptor,
+                       std::unique_ptr<FlightInfo>* info) override {
+    ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+    UcpCallDriver* driver = connection.driver();
+
+    auto impl = [&]() {
+      RETURN_NOT_OK(driver->StartCall(kMethodGetFlightInfo));
+
+      ARROW_ASSIGN_OR_RAISE(std::string payload, descriptor.SerializeToString());
+
+      RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer,
+                                      reinterpret_cast<const uint8_t*>(payload.data()),
+                                      static_cast<int64_t>(payload.size())));
+
+      // Response: an optional FlightInfo buffer frame, then trailers.
+      ARROW_ASSIGN_OR_RAISE(auto incoming_message, driver->ReadNextFrame());
+      if (incoming_message->type == FrameType::kBuffer) {
+        ARROW_ASSIGN_OR_RAISE(
+            *info, FlightInfo::Deserialize(util::string_view(*incoming_message->buffer)));
+        ARROW_ASSIGN_OR_RAISE(incoming_message, driver->ReadNextFrame());
+      }
+      RETURN_NOT_OK(driver->ExpectFrameType(*incoming_message, FrameType::kHeaders));
+      ARROW_ASSIGN_OR_RAISE(auto headers,
+                            HeadersFrame::Parse(std::move(incoming_message->buffer)));
+      Status status;
+      RETURN_NOT_OK(headers.GetStatus(&status));
+      return status;
+    };
+    auto status = impl();
+    return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+  }
+
+  // On success the stream owns the connection and returns it in DoFinish().
+  Status DoExchange(const FlightCallOptions& options,
+                    std::unique_ptr<internal::ClientDataStream>* out) override {
+    ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+    UcpCallDriver* driver = connection.driver();
+
+    auto status = driver->StartCall(kMethodDoExchange);
+    if (ARROW_PREDICT_TRUE(status.ok())) {
+      *out =
+          arrow::internal::make_unique<ExchangeClientStream>(this, std::move(connection));
+      return Status::OK();
+    }
+    return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+  }
+
+  // On success the stream owns the connection and returns it in DoFinish().
+  Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
+               std::unique_ptr<internal::ClientDataStream>* stream) override {
+    ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+    UcpCallDriver* driver = connection.driver();
+
+    auto impl = [&]() {
+      RETURN_NOT_OK(driver->StartCall(kMethodDoGet));
+      ARROW_ASSIGN_OR_RAISE(std::string payload, ticket.SerializeToString());
+      RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer,
+                                      reinterpret_cast<const uint8_t*>(payload.data()),
+                                      static_cast<int64_t>(payload.size())));
+      *stream =
+          arrow::internal::make_unique<GetClientStream>(this, std::move(connection));
+      return Status::OK();
+    };
+
+    auto status = impl();
+    if (ARROW_PREDICT_TRUE(status.ok())) return status;
+    return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+  }
+
+  // On success the stream owns the connection and returns it in DoFinish().
+  Status DoPut(const FlightCallOptions& options,
+               std::unique_ptr<internal::ClientDataStream>* out) override {
+    ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+    UcpCallDriver* driver = connection.driver();
+
+    auto status = driver->StartCall(kMethodDoPut);
+    if (ARROW_PREDICT_TRUE(status.ok())) {
+      *out = arrow::internal::make_unique<PutClientStream>(this, std::move(connection));
+      return Status::OK();
+    }
+    return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+  }
+
+  Status DoAction(const FlightCallOptions& options, const Action& action,
+                  std::unique_ptr<ResultStream>* results) override {
+    // XXX: fake this for now to get the perf test to work
+    // NOTE(review): *results is left untouched, so callers get OK without a
+    // result stream; do not rely on this outside the benchmark.
+    return Status::OK();
+  }
+
+  /// \brief Open a new connection and append it to the pool.
+  ///
+  /// Does not lock connections_mutex_ itself: the caller must hold it (as
+  /// CheckoutConnection does) or, as Init() presumably does, run before the
+  /// transport is used concurrently.
+  Status MakeConnection() {
+    ClientConnection conn;
+    RETURN_NOT_OK(conn.Init(ucp_context_, uri_));
+    connections_.push_back(std::move(conn));
+    return Status::OK();
+  }
+
+  /// \brief Take a connection from the pool (opening one if needed) and
+  /// configure its driver with the call's memory manager and pools.
+  arrow::Result<ClientConnection> CheckoutConnection(const FlightCallOptions& options) {
+    std::lock_guard<std::mutex> guard(connections_mutex_);
+    if (connections_.empty()) RETURN_NOT_OK(MakeConnection());
+    ClientConnection conn = std::move(connections_.front());
+    conn.driver()->set_memory_manager(options.memory_manager);
+    conn.driver()->set_read_memory_pool(options.read_options.memory_pool);
+    conn.driver()->set_write_memory_pool(options.write_options.memory_pool);
+    connections_.pop_front();
+    return conn;
+  }
+
+  /// \brief Put a connection back into the pool, or close it if the pool
+  /// already holds kMaxOpenConnections connections.
+  Status ReturnConnection(ClientConnection conn) {
+    std::lock_guard<std::mutex> guard(connections_mutex_);
+    // TODO(ARROW-16127): for future improvement: reclaim clients
+    // asynchronously in the background (try to avoid issues like
+    // constantly opening/closing clients because the application is
+    // just barely over the limit of open connections)
+    if (connections_.size() >= kMaxOpenConnections) {
+      RETURN_NOT_OK(conn.Close());
+      return Status::OK();
+    }
+    connections_.push_back(std::move(conn));
+    return Status::OK();
+  }
+
+ private:
+  static constexpr size_t kMaxOpenConnections = 3;
+
+  arrow::internal::Uri uri_;
+  std::shared_ptr<UcpContext> ucp_context_;
+  // Guards connections_.
+  std::mutex connections_mutex_;
+  // Idle connections available for checkout.
+  std::deque<ClientConnection> connections_;
+};
+
+/// \brief Finish the call: end writes, drain remaining responses, return the
+/// connection to the client, and report the combined server/transport status.
+/// NOTE(review): if WritesDone() fails, this returns early without returning
+/// the connection or joining a writer's worker thread - confirm that is safe.
+Status UcxClientStream::DoFinish() {
+  RETURN_NOT_OK(WritesDone());
+  if (!finished_) {
+    // Consume whatever the server still sends so its trailers (and with them
+    // server_status_) are received.
+    internal::FlightData ignored_data;
+    while (ReadData(&ignored_data)) {
+    }
+    std::shared_ptr<Buffer> ignored_metadata;
+    while (ReadPutMetadata(&ignored_metadata)) {
+    }
+    finished_ = true;
+  }
+  if (impl_ != nullptr) {
+    Status return_status = impl_->ReturnConnection(std::move(conn_));
+    impl_ = nullptr;
+    driver_ = nullptr;
+    io_status_ = MergeStatuses(std::move(io_status_), std::move(return_status));
+  }
+  return MergeStatuses(server_status_, io_status_);
+}
+} // namespace
+
+/// \brief Factory for the UCX client transport.
+///
+/// No connection is opened here; that happens in UcxClientImpl::Init().
+std::unique_ptr<arrow::flight::internal::ClientTransport> MakeUcxClientImpl() {
+  std::unique_ptr<arrow::flight::internal::ClientTransport> transport =
+      arrow::internal::make_unique<UcxClientImpl>();
+  return transport;
+}
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
new file mode 100644
index 0000000000..ab4cc323f4
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
@@ -0,0 +1,1171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include <array>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+// Defines to test different implementation strategies
+// Enable the CONTIG path for CPU-only data
+// #define ARROW_FLIGHT_UCX_SEND_CONTIG
+// Enable ucp_mem_map in IOV path
+// #define ARROW_FLIGHT_UCX_SEND_IOV_MAP
+
+// Header key ":method:". Judging by its name it presumably carries the name
+// of the Flight RPC being invoked — TODO confirm against its users.
+constexpr char kHeaderMethod[] = ":method:";
+
+namespace {
+/// Write `in` to `out` as a 4-byte big-endian unsigned integer.
+/// Returns Invalid (and writes nothing) if `in` is negative or larger than
+/// UINT32_MAX.
+Status SizeToUInt32BytesBe(const int64_t in, uint8_t* out) {
+  constexpr int64_t kUpperBound =
+      static_cast<int64_t>(std::numeric_limits<uint32_t>::max());
+  if (ARROW_PREDICT_FALSE(in < 0)) {
+    return Status::Invalid("Length cannot be negative");
+  }
+  if (ARROW_PREDICT_FALSE(in > kUpperBound)) {
+    return Status::Invalid("Length cannot exceed uint32_t");
+  }
+  const auto narrowed = static_cast<uint32_t>(in);
+  UInt32ToBytesBe(narrowed, out);
+  return Status::OK();
+}
+/// Choose the UCS memory-type hint for a buffer: non-CPU buffers are reported
+/// as CUDA memory, CPU buffers as UCS_MEMORY_TYPE_UNKNOWN.
+ucs_memory_type InferMemoryType(const Buffer& buffer) {
+  return buffer.is_cpu() ? UCS_MEMORY_TYPE_UNKNOWN : UCS_MEMORY_TYPE_CUDA;
+}
+/// Best-effort registration of [buffer, buffer + size) with the UCP context.
+/// On failure a warning is logged and *memh_p is set to nullptr, which
+/// TryUnmapBuffer treats as "nothing mapped".
+void TryMapBuffer(ucp_context_h context, const void* buffer, const size_t size,
+                  ucs_memory_type memory_type, ucp_mem_h* memh_p) {
+  ucp_mem_map_params_t params;
+  std::memset(&params, 0, sizeof(params));
+  params.address = const_cast<void*>(buffer);
+  params.length = size;
+  params.memory_type = memory_type;
+  params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH |
+                      UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
+
+  const ucs_status_t map_status = ucp_mem_map(context, &params, memh_p);
+  if (map_status == UCS_OK) return;
+
+  *memh_p = nullptr;
+  ARROW_LOG(WARNING) << "Could not map memory: "
+                     << FromUcsStatus("ucp_mem_map", map_status);
+}
+/// Map the whole of `buffer`, inferring its memory type from the buffer.
+void TryMapBuffer(ucp_context_h context, const Buffer& buffer, ucp_mem_h* memh_p) {
+  void* const start = reinterpret_cast<void*>(buffer.address());
+  const size_t length = static_cast<size_t>(buffer.size());
+  TryMapBuffer(context, start, length, InferMemoryType(buffer), memh_p);
+}
+/// Release a registration made by TryMapBuffer. A null handle is ignored;
+/// unmap failures are logged, not returned.
+void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) {
+  if (!memh_p) return;
+  const ucs_status_t unmap_status = ucp_mem_unmap(context, memh_p);
+  if (unmap_status != UCS_OK) {
+    ARROW_LOG(WARNING) << "Could not unmap memory: "
+                       << FromUcsStatus("ucp_mem_unmap", unmap_status);
+  }
+}
+
+/// \brief Arrow Buffer viewing data that UCX delivered for zero-copy receipt
+/// (a host-memory DATA buffer).
+///
+/// The memory is handed back to UCX with ucp_am_data_release() when this
+/// object is destroyed. It keeps a reference to the worker so that the
+/// release never happens against a destroyed worker.
+class UcxDataBuffer : public Buffer {
+ public:
+  explicit UcxDataBuffer(std::shared_ptr<UcpWorker> worker, void* data, size_t size)
+      : Buffer(static_cast<uint8_t*>(data), static_cast<int64_t>(size)),
+        worker_(std::move(worker)) {}
+
+  ~UcxDataBuffer() override {
+    void* const ucx_data = const_cast<uint8_t*>(data());
+    ucp_am_data_release(worker_->get(), ucx_data);
+  }
+
+ private:
+  std::shared_ptr<UcpWorker> worker_;
+};
+}; // namespace
+
+// Out-of-line definitions of FrameHeader's static constexpr data members.
+// Before C++17 these are required whenever the constants are odr-used (e.g.
+// bound to a const reference by variadic Status helpers); since C++17 they are
+// redundant (deprecated) but still accepted.
+constexpr size_t FrameHeader::kFrameHeaderBytes;
+constexpr uint8_t FrameHeader::kFrameVersion;
+
+/// Fill in the on-the-wire frame header.
+///
+/// Layout: [0] frame version, [1] frame type, [2..3] unused (written as 0),
+/// [4..7] counter (big-endian), [8..11] body size (big-endian uint32).
+/// Returns Invalid, leaving the header unmodified, if body_size is negative
+/// or does not fit in a uint32_t.
+Status FrameHeader::Set(FrameType frame_type, uint32_t counter, int64_t body_size) {
+  // Validate (and write) the size first: SizeToUInt32BytesBe writes nothing on
+  // failure, so an invalid size leaves the whole header untouched.
+  RETURN_NOT_OK(SizeToUInt32BytesBe(body_size, header.data() + 8));
+  header[0] = kFrameVersion;
+  header[1] = static_cast<uint8_t>(frame_type);
+  // Bytes 2-3 are not interpreted by Frame::ParseHeader; zero them explicitly
+  // so stale or uninitialized bytes are never put on the wire.
+  header[2] = 0;
+  header[3] = 0;
+  UInt32ToBytesBe(counter, header.data() + 4);
+  return Status::OK();
+}
+
+/// Parse and validate a frame header received from the peer.
+///
+/// Returns IOError for a short header, a version mismatch or an unknown frame
+/// type, and Cancelled for a kDisconnect frame. Otherwise returns a Frame
+/// carrying type, size and counter, with no buffer attached yet.
+arrow::Result<std::shared_ptr<Frame>> Frame::ParseHeader(const void* header,
+                                                         size_t header_length) {
+  if (header_length < FrameHeader::kFrameHeaderBytes) {
+    return Status::IOError("Header is too short, must be at least ",
+                           FrameHeader::kFrameHeaderBytes, " bytes, got ", header_length);
+  }
+
+  const auto* bytes = static_cast<const uint8_t*>(header);
+  const uint8_t version = bytes[0];
+  const uint8_t raw_type = bytes[1];
+  if (version != FrameHeader::kFrameVersion) {
+    return Status::IOError("Expected frame version ",
+                           static_cast<int>(FrameHeader::kFrameVersion), " but got ",
+                           static_cast<int>(version));
+  }
+  if (raw_type > static_cast<uint8_t>(FrameType::kMaxFrameType)) {
+    return Status::IOError("Unknown frame type ", static_cast<int>(raw_type));
+  }
+
+  const auto frame_type = static_cast<FrameType>(raw_type);
+  if (frame_type == FrameType::kDisconnect) {
+    return Status::Cancelled("Client initiated disconnect");
+  }
+
+  const uint32_t frame_counter = BytesToUInt32Be(bytes + 4);
+  const uint32_t frame_size = BytesToUInt32Be(bytes + 8);
+  return std::make_shared<Frame>(frame_type, frame_size, frame_counter, nullptr);
+}
+
+/// Parse a serialized headers frame.
+///
+/// Wire format (every integer a big-endian uint32): [num_headers], then per
+/// header [key_length][value_length][key bytes][value bytes]. The parsed keys
+/// and values are views into `buffer`, whose ownership moves into the result.
+/// Returns Invalid if the buffer ends before the declared content.
+arrow::Result<HeadersFrame> HeadersFrame::Parse(std::unique_ptr<Buffer> buffer) {
+  HeadersFrame result;
+  const uint8_t* cursor = buffer->data();
+  const uint8_t* const limit = cursor + buffer->size();
+
+  if (ARROW_PREDICT_FALSE((limit - cursor) < 4)) {
+    return Status::Invalid("Buffer underflow, expected number of headers");
+  }
+  const uint32_t num_headers = BytesToUInt32Be(cursor);
+  cursor += 4;
+
+  for (uint32_t index = 0; index < num_headers; ++index) {
+    const uint32_t ordinal = index + 1;  // 1-based position for error messages
+
+    if (ARROW_PREDICT_FALSE((limit - cursor) < 4)) {
+      return Status::Invalid("Buffer underflow, expected length of key ", ordinal);
+    }
+    const uint32_t key_length = BytesToUInt32Be(cursor);
+    cursor += 4;
+
+    if (ARROW_PREDICT_FALSE((limit - cursor) < 4)) {
+      return Status::Invalid("Buffer underflow, expected length of value ", ordinal);
+    }
+    const uint32_t value_length = BytesToUInt32Be(cursor);
+    cursor += 4;
+
+    const auto key_remaining = limit - cursor;
+    if (ARROW_PREDICT_FALSE(key_remaining < key_length)) {
+      return Status::Invalid("Buffer underflow, expected key ", ordinal, " to have length ",
+                             key_length, ", but only ", key_remaining, " bytes remain");
+    }
+    const util::string_view key(reinterpret_cast<const char*>(cursor), key_length);
+    cursor += key_length;
+
+    const auto value_remaining = limit - cursor;
+    if (ARROW_PREDICT_FALSE(value_remaining < value_length)) {
+      return Status::Invalid("Buffer underflow, expected value ", ordinal,
+                             " to have length ", value_length, ", but only ",
+                             value_remaining, " bytes remain");
+    }
+    const util::string_view value(reinterpret_cast<const char*>(cursor), value_length);
+    cursor += value_length;
+
+    result.headers_.emplace_back(key, value);
+  }
+
+  result.buffer_ = std::move(buffer);
+  return result;
+}
+/// Serialize `headers` into a headers frame and parse it back, so the
+/// returned frame owns its storage.
+///
+/// Wire format (every integer a big-endian uint32): [num_headers], then per
+/// header [key_length][value_length][key bytes][value bytes]. Returns Invalid
+/// if the header count or any key/value length does not fit in a uint32_t;
+/// allocation failures are propagated.
+arrow::Result<HeadersFrame> HeadersFrame::Make(
+    const std::vector<std::pair<std::string, std::string>>& headers) {
+  // Accumulate in 64 bits: a 32-bit total can wrap for very large headers,
+  // under-allocating the buffer that the memcpy calls below then overrun.
+  int64_t total_length = 4 /* # of headers */;
+  for (const auto& header : headers) {
+    total_length += 4 /* key length */ + 4 /* value length */ +
+                    static_cast<int64_t>(header.first.size()) /* key */ +
+                    static_cast<int64_t>(header.second.size()) /* value */;
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_length));
+  uint8_t* payload = buffer->mutable_data();
+
+  RETURN_NOT_OK(SizeToUInt32BytesBe(static_cast<int64_t>(headers.size()), payload));
+  payload += 4;
+  for (const auto& header : headers) {
+    RETURN_NOT_OK(
+        SizeToUInt32BytesBe(static_cast<int64_t>(header.first.size()), payload));
+    payload += 4;
+    RETURN_NOT_OK(
+        SizeToUInt32BytesBe(static_cast<int64_t>(header.second.size()), payload));
+    payload += 4;
+    std::memcpy(payload, header.first.data(), header.first.size());
+    payload += header.first.size();
+    std::memcpy(payload, header.second.data(), header.second.size());
+    payload += header.second.size();
+  }
+  return Parse(std::move(buffer));
+}
+/// Build a headers frame holding `headers` followed by an encoding of
+/// `status`: its numeric code and message, plus detail headers if the status
+/// has a detail. A FlightStatusDetail contributes its Flight code and
+/// extra_info; any other detail is stored as its ToString().
+arrow::Result<HeadersFrame> HeadersFrame::Make(
+    const Status& status,
+    const std::vector<std::pair<std::string, std::string>>& headers) {
+  std::vector<std::pair<std::string, std::string>> all_headers(headers);
+  all_headers.emplace_back(kHeaderStatusCode,
+                           std::to_string(static_cast<int32_t>(status.code())));
+  all_headers.emplace_back(kHeaderStatusMessage, status.message());
+
+  if (!status.detail()) {
+    return Make(all_headers);
+  }
+
+  const auto flight_detail = FlightStatusDetail::UnwrapStatus(status);
+  if (flight_detail) {
+    all_headers.emplace_back(
+        kHeaderStatusDetailCode,
+        std::to_string(static_cast<int32_t>(flight_detail->code())));
+    all_headers.emplace_back(kHeaderStatusDetail, flight_detail->extra_info());
+  } else {
+    all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString());
+  }
+  return Make(all_headers);
+}
+
+/// Look up the first header whose key equals `key`. The returned view points
+/// into this frame's buffer. Returns KeyError(key) if no header matches.
+arrow::Result<util::string_view> HeadersFrame::Get(const std::string& key) {
+  for (auto it = headers_.begin(); it != headers_.end(); ++it) {
+    if (it->first == key) {
+      return it->second;
+    }
+  }
+  return Status::KeyError(key);
+}
+
+/// Decode the status carried in this frame's headers into *out.
+///
+/// Returns KeyError if the status-code header is absent. Otherwise returns OK
+/// and stores the decoded status in *out:
+/// - Unrecognized numeric codes map to StatusCode::UnknownError.
+/// - A missing message header yields a placeholder message.
+/// - A FlightStatusDetail is attached only when a non-empty detail header is
+///   present; its Flight code defaults to Internal.
+Status HeadersFrame::GetStatus(Status* out) {
+  // Header values are string_views into the frame buffer and are NOT
+  // NUL-terminated. Calling strtol on them directly would keep reading past the
+  // value, and past the end of the buffer if it is the last header. Parse from
+  // an owned, terminated copy instead.
+  const auto parse_integer = [](util::string_view text) {
+    return std::strtol(std::string(text).c_str(), nullptr, /*base=*/10);
+  };
+
+  util::string_view code_str, message_str;
+  auto status = Get(kHeaderStatusCode).Value(&code_str);
+  if (!status.ok()) {
+    return Status::KeyError("Server did not send status code header ", kHeaderStatusCode);
+  }
+
+  StatusCode status_code = StatusCode::OK;
+  auto code = parse_integer(code_str);
+  switch (code) {
+    case 0:
+      status_code = StatusCode::OK;
+      break;
+    case 1:
+      status_code = StatusCode::OutOfMemory;
+      break;
+    case 2:
+      status_code = StatusCode::KeyError;
+      break;
+    case 3:
+      status_code = StatusCode::TypeError;
+      break;
+    case 4:
+      status_code = StatusCode::Invalid;
+      break;
+    case 5:
+      status_code = StatusCode::IOError;
+      break;
+    case 6:
+      status_code = StatusCode::CapacityError;
+      break;
+    case 7:
+      status_code = StatusCode::IndexError;
+      break;
+    case 8:
+      status_code = StatusCode::Cancelled;
+      break;
+    case 9:
+      status_code = StatusCode::UnknownError;
+      break;
+    case 10:
+      status_code = StatusCode::NotImplemented;
+      break;
+    case 11:
+      status_code = StatusCode::SerializationError;
+      break;
+    case 13:
+      status_code = StatusCode::RError;
+      break;
+    case 40:
+      status_code = StatusCode::CodeGenError;
+      break;
+    case 41:
+      status_code = StatusCode::ExpressionValidationError;
+      break;
+    case 42:
+      status_code = StatusCode::ExecutionError;
+      break;
+    case 45:
+      status_code = StatusCode::AlreadyExists;
+      break;
+    default:
+      status_code = StatusCode::UnknownError;
+      break;
+  }
+  if (status_code == StatusCode::OK) {
+    *out = Status::OK();
+    return Status::OK();
+  }
+
+  status = Get(kHeaderStatusMessage).Value(&message_str);
+  if (!status.ok()) {
+    *out = Status(status_code, "Server did not send status message header", nullptr);
+    return Status::OK();
+  }
+
+  util::string_view detail_code_str, detail_str;
+  FlightStatusCode detail_code = FlightStatusCode::Internal;
+
+  if (Get(kHeaderStatusDetailCode).Value(&detail_code_str).ok()) {
+    auto detail_code_int = parse_integer(detail_code_str);
+    switch (detail_code_int) {
+      case 1:
+        detail_code = FlightStatusCode::TimedOut;
+        break;
+      case 2:
+        detail_code = FlightStatusCode::Cancelled;
+        break;
+      case 3:
+        detail_code = FlightStatusCode::Unauthenticated;
+        break;
+      case 4:
+        detail_code = FlightStatusCode::Unauthorized;
+        break;
+      case 5:
+        detail_code = FlightStatusCode::Unavailable;
+        break;
+      case 6:
+        detail_code = FlightStatusCode::Failed;
+        break;
+      case 0:
+      default:
+        detail_code = FlightStatusCode::Internal;
+        break;
+    }
+  }
+  // The detail header is optional; a missing one leaves detail_str empty.
+  ARROW_UNUSED(Get(kHeaderStatusDetail).Value(&detail_str));
+
+  std::shared_ptr<StatusDetail> detail = nullptr;
+  if (!detail_str.empty()) {
+    detail = std::make_shared<FlightStatusDetail>(detail_code, std::string(detail_str));
+  }
+  *out = Status(status_code, std::string(message_str), std::move(detail));
+  return Status::OK();
+}
+
+namespace {
+// Length written in place of a field's size when the field is absent. This
+// distinguishes "missing" from a present but zero-length field.
+constexpr uint32_t kMissingFieldSentinel = std::numeric_limits<uint32_t>::max();
+// Largest accepted size of a single payload-header field (INT32_MAX).
+constexpr uint32_t kInt32Max =
+    static_cast<uint32_t>(std::numeric_limits<int32_t>::max());
+// Compute the serialized size of one optional payload-header field and add it
+// to *total_size.
+//
+// Returns kMissingFieldSentinel, leaving *total_size unchanged, when `data` is
+// null. Returns Invalid if the field is 2 GiB or larger, or if adding it would
+// overflow the uint32_t running total.
+arrow::Result<uint32_t> PayloadHeaderFieldSize(const std::string& field,
+                                               const std::shared_ptr<Buffer>& data,
+                                               uint32_t* total_size) {
+  if (!data) return kMissingFieldSentinel;
+  if (data->size() > kInt32Max) {
+    return Status::Invalid(field, " must be less than 2 GiB, was: ", data->size());
+  }
+  const uint32_t field_size = static_cast<uint32_t>(data->size());
+  // Unsigned addition wraps silently. Check before adding: a wrapped total would
+  // under-allocate the header buffer that PackField copies every field into.
+  if (field_size > std::numeric_limits<uint32_t>::max() - *total_size) {
+    return Status::Invalid("Payload header must fit in a uint32_t");
+  }
+  *total_size += field_size;
+  return field_size;
+}
+// Write one field at `out` as [4-byte big-endian length][bytes] and return the
+// position just past it. For an absent field (size == kMissingFieldSentinel)
+// only the length is written and `data` is not dereferenced.
+uint8_t* PackField(uint32_t size, const std::shared_ptr<Buffer>& data, uint8_t* out) {
+  UInt32ToBytesBe(size, out);
+  uint8_t* cursor = out + 4;
+  if (size == kMissingFieldSentinel) {
+    return cursor;
+  }
+  std::memcpy(cursor, data->data(), size);
+  return cursor + size;
+}
+} // namespace
+
+arrow::Result<PayloadHeaderFrame> PayloadHeaderFrame::Make(const FlightPayload& payload,
+                                                           MemoryPool* memory_pool) {
+  // Packs the non-body parts of the payload (descriptor, app_metadata, IPC
+  // metadata) into one buffer, copying them. They are expected to be small
+  // relative to the body, so the copy is accepted.
+  //
+  // Each field is [4-byte length][data]. An absent field is written as
+  // UINT32_MAX with no data, because a zero-length field is legitimate.
+  constexpr uint32_t kLengthPrefixBytes = 4;
+  constexpr uint32_t kNumFields = 3;
+  uint32_t header_size = kNumFields * kLengthPrefixBytes;
+
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t descriptor_size,
+      PayloadHeaderFieldSize("descriptor", payload.descriptor, &header_size));
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t app_metadata_size,
+      PayloadHeaderFieldSize("app_metadata", payload.app_metadata, &header_size));
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t ipc_metadata_size,
+      PayloadHeaderFieldSize("ipc_message.metadata", payload.ipc_message.metadata,
+                             &header_size));
+
+  ARROW_ASSIGN_OR_RAISE(auto header_buffer, AllocateBuffer(header_size, memory_pool));
+  uint8_t* cursor = header_buffer->mutable_data();
+  cursor = PackField(descriptor_size, payload.descriptor, cursor);
+  cursor = PackField(app_metadata_size, payload.app_metadata, cursor);
+  PackField(ipc_metadata_size, payload.ipc_message.metadata, cursor);
+
+  return PayloadHeaderFrame(std::move(header_buffer));
+}
+/// Decode this payload header into the descriptor, app_metadata and IPC
+/// metadata fields of *data, and set data->body to null. Consumes this
+/// frame's buffer.
+///
+/// The buffer comes from the remote peer, so every 4-byte length prefix and
+/// every field is bounds-checked. Malformed input yields Invalid.
+Status PayloadHeaderFrame::ToFlightData(internal::FlightData* data) {
+  std::shared_ptr<Buffer> buffer = std::move(buffer_);
+  const int64_t buffer_size = buffer->size();
+  // 64-bit offset so that offset + (uint32) field size cannot wrap around.
+  int64_t offset = 0;
+
+  // Read the next 4-byte big-endian length prefix and advance past it.
+  auto read_length = [&](uint32_t* length) -> Status {
+    if (offset + 4 > buffer_size) {
+      return Status::Invalid("Buffer is too small: expected ", offset + 4,
+                             " bytes but have ", buffer_size);
+    }
+    *length = BytesToUInt32Be(buffer->data() + offset);
+    offset += 4;
+    return Status::OK();
+  };
+  // Check that `length` bytes starting at the current offset are in bounds.
+  auto check_field = [&](uint32_t length) -> Status {
+    const int64_t field_end = offset + static_cast<int64_t>(length);
+    if (field_end > buffer_size) {
+      return Status::Invalid("Buffer is too small: expected ", field_end,
+                             " bytes but have ", buffer_size);
+    }
+    return Status::OK();
+  };
+
+  uint32_t size = 0;
+
+  // Unpack the descriptor
+  RETURN_NOT_OK(read_length(&size));
+  if (size != kMissingFieldSentinel) {
+    RETURN_NOT_OK(check_field(size));
+    util::string_view desc(reinterpret_cast<const char*>(buffer->data() + offset), size);
+    data->descriptor.reset(new FlightDescriptor());
+    ARROW_ASSIGN_OR_RAISE(*data->descriptor, FlightDescriptor::Deserialize(desc));
+    offset += size;
+  } else {
+    data->descriptor = nullptr;
+  }
+
+  // Unpack app_metadata. While we properly handle zero-size vs nullptr
+  // metadata here, gRPC doesn't (Protobuf doesn't differentiate between the two)
+  RETURN_NOT_OK(read_length(&size));
+  if (size != kMissingFieldSentinel) {
+    RETURN_NOT_OK(check_field(size));
+    data->app_metadata = SliceBuffer(buffer, offset, size);
+    offset += size;
+  } else {
+    data->app_metadata = nullptr;
+  }
+
+  // Unpack the IPC header
+  RETURN_NOT_OK(read_length(&size));
+  if (size != kMissingFieldSentinel) {
+    RETURN_NOT_OK(check_field(size));
+    data->metadata = SliceBuffer(std::move(buffer), offset, size);
+  } else {
+    data->metadata = nullptr;
+  }
+  data->body = nullptr;
+  return Status::OK();
+}
+
+// pImpl the driver since async methods require a stable address
+class UcpCallDriver::Impl {
+ public:
+  // Compile-time switch for the CONTIG send path, in which all-CPU payload
+  // bodies are concatenated into one buffer. It is on only when
+  // ARROW_FLIGHT_UCX_SEND_CONTIG is defined.
+#ifdef ARROW_FLIGHT_UCX_SEND_CONTIG
+  constexpr static bool kEnableContigSend = true;
+#else
+  constexpr static bool kEnableContigSend = false;
+#endif
+
+ /// Wrap `endpoint` on `worker`. Both memory pools start as
+ /// default_memory_pool() and the memory manager as the CPU device's default.
+ /// The peer name is "(unknown remote)" unless ucp_ep_query() succeeds, in
+ /// which case it becomes "local:<addr>;remote:<addr>".
+ Impl(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint)
+ : padding_bytes_({0, 0, 0, 0, 0, 0, 0, 0}),
+ worker_(std::move(worker)),
+ endpoint_(endpoint),
+ read_memory_pool_(default_memory_pool()),
+ write_memory_pool_(default_memory_pool()),
+ memory_manager_(CPUDevice::Instance()->default_memory_manager()),
+ name_("(unknown remote)"),
+ counter_(0) {
+ // Experimental (ARROW_FLIGHT_UCX_SEND_IOV_MAP): register the zero padding
+ // bytes with UCX; released again in ~Impl().
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ TryMapBuffer(worker_->context().get(), padding_bytes_.data(), padding_bytes_.size(),
+ UCS_MEMORY_TYPE_HOST, &padding_memh_p_);
+#endif
+
+ // Best effort: derive a human-readable peer name from the endpoint's socket
+ // addresses. An address that fails to convert is left as an empty string.
+ ucp_ep_attr_t attrs;
+ std::memset(&attrs, 0, sizeof(attrs));
+ attrs.field_mask =
+ UCP_EP_ATTR_FIELD_LOCAL_SOCKADDR | UCP_EP_ATTR_FIELD_REMOTE_SOCKADDR;
+ if (ucp_ep_query(endpoint_, &attrs) == UCS_OK) {
+ std::string local_addr, remote_addr;
+ ARROW_UNUSED(SockaddrToString(attrs.local_sockaddr).Value(&local_addr));
+ ARROW_UNUSED(SockaddrToString(attrs.remote_sockaddr).Value(&remote_addr));
+ name_ = "local:" + local_addr + ";remote:" + remote_addr;
+ }
+ }
+
+  ~Impl() {
+    // Release the padding-bytes registration made by the constructor; only
+    // present when the IOV_MAP experiment is compiled in.
+#ifdef ARROW_FLIGHT_UCX_SEND_IOV_MAP
+    TryUnmapBuffer(worker_->context().get(), padding_memh_p_);
+#endif
+  }
+
+  /// Block until the next frame (in counter order) is available, calling
+  /// MakeProgress() on this thread while waiting.
+  arrow::Result<std::shared_ptr<Frame>> ReadNextFrame() {
+    Future<std::shared_ptr<Frame>> next = ReadFrameAsync();
+    while (!next.is_finished()) {
+      MakeProgress();
+    }
+    RETURN_NOT_OK(next.status());
+    return next.MoveResult();
+  }
+
+  /// Return a future for the frame with the next expected counter value.
+  ///
+  /// If that frame was already pushed, its stored future is handed out and
+  /// forgotten. Otherwise a pending future is registered for Push() to
+  /// complete. Fails immediately if the driver is closed or has a recorded
+  /// error.
+  Future<std::shared_ptr<Frame>> ReadFrameAsync() {
+    RETURN_NOT_OK(CheckClosed());
+
+    std::unique_lock<std::mutex> guard(frame_mutex_);
+    if (ARROW_PREDICT_FALSE(!status_.ok())) return status_;
+
+    const uint32_t expected_counter = next_counter_++;
+    auto existing = frames_.find(expected_counter);
+    if (existing == frames_.end()) {
+      // Not delivered yet: park a pending future for Push() to fulfil.
+      auto inserted =
+          frames_.insert({expected_counter, Future<std::shared_ptr<Frame>>::Make()});
+      DCHECK(inserted.second);
+      return inserted.first->second;
+    }
+    Future<std::shared_ptr<Frame>> delivered = existing->second;
+    frames_.erase(existing);
+    return delivered;
+  }
+
+  /// Send one frame synchronously: post the active message (header + data),
+  /// then wait for the request via CompleteRequestBlocking().
+  Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size) {
+    RETURN_NOT_OK(CheckClosed());
+
+    ucp_request_param_t request_param;
+    request_param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
+    request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+    FrameHeader header;
+    RETURN_NOT_OK(header.Set(frame_type, counter_++, size));
+
+    // UCX appears to crash on zero-byte payloads, so an empty frame carries a
+    // single padding byte; the header still records a size of 0.
+    const bool is_empty = (size == 0);
+    const void* send_data = data;
+    size_t send_size = static_cast<size_t>(size);
+    if (is_empty) {
+      send_data = padding_bytes_.data();
+      send_size = 1;
+    }
+    void* request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(),
+                                    header.size(), send_data, send_size, &request_param);
+    RETURN_NOT_OK(CompleteRequestBlocking("ucp_am_send_nbx", request));
+
+    return Status::OK();
+  }
+
+  /// Send `buffer` as a frame of the given type without blocking.
+  ///
+  /// The buffer is kept alive in a PendingContigSend record until the send
+  /// completes. The returned future is already finished if the send completed
+  /// or failed immediately; otherwise AmSendCallback completes it.
+  Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr<Buffer> buffer) {
+    RETURN_NOT_OK(CheckClosed());
+
+    ucp_request_param_t request_param;
+    std::memset(&request_param, 0, sizeof(request_param));
+    request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE |
+                                 UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA;
+    request_param.cb.send = AmSendCallback;
+    request_param.datatype = ucp_dt_make_contig(1);
+    request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+    // The header always records the real size, even if the payload is padded.
+    const int64_t size = buffer->size();
+    if (size == 0) {
+      // UCX appears to crash on zero-byte payloads: send a single zero byte
+      // instead. Assign to `buffer` itself (not a new local) so that the padded
+      // buffer is what actually gets sent.
+      ARROW_ASSIGN_OR_RAISE(buffer, AllocateBuffer(1, write_memory_pool_));
+      buffer->mutable_data()[0] = 0;
+    }
+
+    std::unique_ptr<PendingContigSend> pending_send(new PendingContigSend());
+    // Initialize the fields ~PendingContigSend() reads before anything that can
+    // fail, so an early return never destroys a half-initialized record.
+    pending_send->driver = this;
+    pending_send->memh_p = nullptr;
+    RETURN_NOT_OK(pending_send->header.Set(frame_type, counter_++, size));
+    pending_send->ipc_message = std::move(buffer);
+    pending_send->completed = Future<>::Make();
+
+    // Take our own handle on the completion future before handing the record
+    // to UCX, since AmSendCallback deletes the record once the send completes.
+    Future<> completed = pending_send->completed;
+
+    request_param.user_data = pending_send.release();
+    auto* pending = static_cast<PendingContigSend*>(request_param.user_data);
+    void* request = ucp_am_send_nbx(
+        endpoint_, kUcpAmHandlerId, pending->header.data(), pending->header.size(),
+        reinterpret_cast<void*>(pending->ipc_message->mutable_data()),
+        static_cast<size_t>(pending->ipc_message->size()), &request_param);
+    if (!request) {
+      // Request completed immediately
+      delete pending;
+      return Status::OK();
+    }
+    if (UCS_PTR_IS_ERR(request)) {
+      delete pending;
+      return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request));
+    }
+    return completed;
+  }
+
+ /// Send a FlightPayload as up to two frames.
+ ///
+ /// 1. The descriptor, app_metadata and IPC metadata are packed into a
+ ///    kPayloadHeader frame and sent synchronously via SendFrame().
+ /// 2. If the IPC message type has a body, its buffers are sent
+ ///    asynchronously as a kPayloadBody frame, each padded to a multiple of
+ ///    8 bytes. The returned future completes when that send does.
+ ///
+ /// Bodies larger than INT32_MAX bytes are rejected with Invalid.
+ Future<> SendFlightPayload(const FlightPayload& payload) {
+ static const int64_t kMaxBatchSize = std::numeric_limits<int32_t>::max();
+ RETURN_NOT_OK(CheckClosed());
+
+ if (payload.ipc_message.body_length > kMaxBatchSize) {
+ return Status::Invalid("Cannot send record batches exceeding 2GiB yet");
+ }
+
+ {
+ ARROW_ASSIGN_OR_RAISE(auto frame,
+ PayloadHeaderFrame::Make(payload, write_memory_pool_));
+ RETURN_NOT_OK(SendFrame(FrameType::kPayloadHeader, frame.data(), frame.size()));
+ }
+
+ if (!ipc::Message::HasBody(payload.ipc_message.type)) {
+ return Status::OK();
+ }
+
+ // While IOV (scatter-gather) might seem like it avoids a memcpy,
+ // profiling shows that at least for the TCP/SHM/RDMA transports,
+ // UCX just does a memcpy internally. Furthermore, on the receiver
+ // side, a sender-side IOV send prevents optimizations based on
+ // mapped buffers (UCX will memcpy to the destination buffer
+ // regardless of whether it's mapped or not).
+
+ // If all buffers are on the CPU, concatenate them ourselves and
+ // do a regular send to avoid this. Else, use IOV and let UCX
+ // figure out what to do.
+
+ // Weirdness: UCX prefers TCP over shared memory for CONTIG? We
+ // can avoid this by setting UCX_RNDV_THRESH=inf, this will make
+ // UCX prefer shared memory again. However, we still want to avoid
+ // the CONTIG path when shared memory is available, because the
+ // total amount of time spent in memcpy is greater than using IOV
+ // and letting UCX handle it.
+
+ // Consider: if we can figure out how to make IOV always as fast
+ // as CONTIG, we can just send the metadata fields as part of the
+ // IOV payload and avoid having to send two distinct messages.
+
+ // Count the iov entries needed: one per non-empty buffer, plus one padding
+ // entry for each buffer whose size is not a multiple of 8. Also record
+ // whether every non-empty buffer is CPU memory.
+ bool all_cpu = true;
+ int32_t total_buffers = 0;
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+ all_cpu = all_cpu && buffer->is_cpu();
+ total_buffers++;
+
+ // Arrow IPC requires that we align buffers to 8 byte boundary
+ const auto remainder = static_cast<int>(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) total_buffers++;
+ }
+
+ ucp_request_param_t request_param;
+ std::memset(&request_param, 0, sizeof(request_param));
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE |
+ UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA;
+ request_param.cb.send = AmSendCallback;
+ request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+ std::unique_ptr<PendingAmSend> pending_send;
+ void* send_data = nullptr;
+ size_t send_size = 0;
+
+ if (!all_cpu) {
+ request_param.op_attr_mask =
+ request_param.op_attr_mask | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
+ // XXX: UCX doesn't appear to autodetect this correctly if we
+ // use UNKNOWN
+ request_param.memory_type = UCS_MEMORY_TYPE_CUDA;
+ }
+
+ if (kEnableContigSend && all_cpu) {
+ // CONTIG - concatenate buffers into one before sending
+
+ // TODO(ARROW-16126): this needs to be pipelined since it can be expensive.
+ // Preliminary profiling shows ~5% overhead just from mapping the buffer
+ // alone (on Infiniband; it seems to be trivial for shared memory)
+ request_param.datatype = ucp_dt_make_contig(1);
+ pending_send = arrow::internal::make_unique<PendingContigSend>();
+ auto* pending_contig = reinterpret_cast<PendingContigSend*>(pending_send.get());
+
+ // Allocate at least one byte: an empty body is sent as a single zero byte.
+ const int64_t body_length = std::max<int64_t>(payload.ipc_message.body_length, 1);
+ ARROW_ASSIGN_OR_RAISE(pending_contig->ipc_message,
+ AllocateBuffer(body_length, write_memory_pool_));
+ TryMapBuffer(worker_->context().get(), *pending_contig->ipc_message,
+ &pending_contig->memh_p);
+
+ uint8_t* ipc_message = pending_contig->ipc_message->mutable_data();
+ if (payload.ipc_message.body_length == 0) {
+ std::memset(ipc_message, '\0', 1);
+ }
+
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+
+ std::memcpy(ipc_message, buffer->data(), buffer->size());
+ ipc_message += buffer->size();
+
+ const auto remainder = static_cast<int>(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) {
+ std::memset(ipc_message, 0, remainder);
+ ipc_message += remainder;
+ }
+ }
+
+ send_data = reinterpret_cast<void*>(pending_contig->ipc_message->mutable_data());
+ send_size = static_cast<size_t>(pending_contig->ipc_message->size());
+ } else {
+ // IOV - let UCX use scatter-gather path
+ request_param.datatype = UCP_DATATYPE_IOV;
+ pending_send = arrow::internal::make_unique<PendingIovSend>();
+ auto* pending_iov = reinterpret_cast<PendingIovSend*>(pending_send.get());
+
+ // Keep a copy of the payload so its buffers stay alive until UCX is done.
+ pending_iov->payload = payload;
+ pending_iov->iovs.resize(total_buffers);
+ ucp_dt_iov_t* iov = pending_iov->iovs.data();
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ // XXX: this seems to have no benefits in tests so far
+ pending_iov->memh_ps.resize(total_buffers);
+ ucp_mem_h* memh_p = pending_iov->memh_ps.data();
+#endif
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+
+ iov->buffer = const_cast<void*>(reinterpret_cast<const void*>(buffer->address()));
+ iov->length = buffer->size();
+ ++iov;
+
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ TryMapBuffer(worker_->context().get(), *buffer, memh_p);
+ memh_p++;
+#endif
+
+ const auto remainder = static_cast<int>(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) {
+ iov->buffer =
+ const_cast<void*>(reinterpret_cast<const void*>(padding_bytes_.data()));
+ iov->length = remainder;
+ ++iov;
+ }
+ }
+
+ if (total_buffers == 0) {
+ // UCX cannot handle zero-byte payloads
+ pending_iov->iovs.resize(1);
+ pending_iov->iovs[0].buffer =
+ const_cast<void*>(reinterpret_cast<const void*>(padding_bytes_.data()));
+ pending_iov->iovs[0].length = 1;
+ }
+
+ send_data = pending_iov->iovs.data();
+ send_size = pending_iov->iovs.size();
+ }
+
+ DCHECK(send_data) << "Payload cannot be nullptr";
+ DCHECK_GT(send_size, 0) << "Payload cannot be empty";
+
+ // NOTE(review): `driver` is assigned only after this call, and after the
+ // CONTIG allocation above. An early return at either point destroys
+ // pending_send while `driver` is still unassigned, yet the Pending*
+ // destructors read `driver`. This relies on `driver` being null-initialized
+ // and checked — confirm.
+ RETURN_NOT_OK(pending_send->header.Set(FrameType::kPayloadBody, counter_++,
+ payload.ipc_message.body_length));
+ pending_send->driver = this;
+ pending_send->completed = Future<>::Make();
+
+ // Ownership passes to UCX as user_data. AmSendCallback deletes the record
+ // on completion; we delete it here if the send completes or fails
+ // immediately.
+ request_param.user_data = pending_send.release();
+ {
+ auto* pending_send = reinterpret_cast<PendingAmSend*>(request_param.user_data);
+
+ void* request = ucp_am_send_nbx(
+ endpoint_, kUcpAmHandlerId, pending_send->header.data(),
+ pending_send->header.size(), send_data, send_size, &request_param);
+ if (!request) {
+ // Request completed immediately
+ delete pending_send;
+ return Status::OK();
+ } else if (UCS_PTR_IS_ERR(request)) {
+ delete pending_send;
+ return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request));
+ }
+ return pending_send->completed;
+ }
+ }
+
+  /// Cancel every frame still being waited on, then close the endpoint in
+  /// flush mode, calling MakeProgress() until the close request finishes.
+  /// A second call is a no-op. Errors accepted by IsIgnorableDisconnectError()
+  /// are not reported.
+  Status Close() {
+    if (endpoint_ == nullptr) {
+      return Status::OK();
+    }
+
+    for (auto& waiting : frames_) {
+      waiting.second.MarkFinished(Status::Cancelled("UcpCallDriver is being closed"));
+    }
+    frames_.clear();
+
+    void* close_request = ucp_ep_close_nb(endpoint_, UCP_EP_CLOSE_MODE_FLUSH);
+    std::string origin = "ucp_ep_close_nb";
+    ucs_status_t close_status = UCS_OK;
+    if (UCS_PTR_IS_PTR(close_request)) {
+      // Close is in flight: drive progress until it finishes.
+      origin = "ucp_request_check_status";
+      for (;;) {
+        close_status = ucp_request_check_status(close_request);
+        if (close_status != UCS_INPROGRESS) break;
+        MakeProgress();
+      }
+      ucp_request_free(close_request);
+    } else if (UCS_PTR_IS_ERR(close_request)) {
+      close_status = UCS_PTR_STATUS(close_request);
+    }
+    // A null request means the endpoint closed immediately (close_status stays OK).
+
+    endpoint_ = nullptr;
+    if (close_status == UCS_OK || IsIgnorableDisconnectError(close_status)) {
+      return Status::OK();
+    }
+    return FromUcsStatus(origin, close_status);
+  }
+
+ /// Drive the worker's progress engine once via ucp_worker_progress().
+ void MakeProgress() { ucp_worker_progress(worker_->get()); }
+
+  /// Deliver a received frame, keyed by its counter value.
+  ///
+  /// If ReadFrameAsync() already registered a future for that counter, the
+  /// future is completed and its entry dropped. Otherwise the frame is stored
+  /// for a later read. Frames are ignored once an error status was recorded.
+  void Push(std::shared_ptr<Frame> frame) {
+    std::unique_lock<std::mutex> guard(frame_mutex_);
+    if (ARROW_PREDICT_FALSE(!status_.ok())) return;
+
+    const auto counter = frame->counter;
+    auto waiting = frames_.find(counter);
+    if (waiting == frames_.end()) {
+      // Nobody is waiting yet: store the frame for ReadFrameAsync().
+      frames_.insert({counter, frame});
+      return;
+    }
+    waiting->second.MarkFinished(std::move(frame));
+    frames_.erase(waiting);
+  }
+
+ void Push(Status status) {
+ std::unique_lock<std::mutex> guard(frame_mutex_);
+ status_ = std::move(status);
+ for (auto& item : frames_) {
+ item.second.MarkFinished(status_);
+ }
+ frames_.clear();
+ }
+
+  /// Entry point for incoming active messages.
+  ///
+  /// An error from RecvActiveMessageImpl is pushed as the driver's status,
+  /// failing pending readers, and UCS_OK is returned to UCX. Otherwise the
+  /// UCS status chosen by the implementation is returned.
+  ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data,
+                                 const size_t data_length,
+                                 const ucp_am_recv_param_t* param) {
+    arrow::Result<ucs_status_t> outcome =
+        RecvActiveMessageImpl(header, header_length, data, data_length, param);
+    if (outcome.ok()) {
+      return outcome.MoveValueUnsafe();
+    }
+    Push(outcome.status());
+    return UCS_OK;
+  }
+
+ /// Current memory manager (the CPU device's default unless overridden).
+ /// RecvActiveMessageImpl consults it to choose the receive path.
+ const std::shared_ptr<MemoryManager>& memory_manager() const { return memory_manager_; }
+  /// Replace the memory manager; passing null restores the CPU device's
+  /// default memory manager.
+  void set_memory_manager(std::shared_ptr<MemoryManager> memory_manager) {
+    if (!memory_manager) {
+      memory_manager_ = CPUDevice::Instance()->default_memory_manager();
+      return;
+    }
+    memory_manager_ = std::move(memory_manager);
+  }
+  /// Set the read-side memory pool; null selects default_memory_pool().
+  void set_read_memory_pool(MemoryPool* pool) {
+    if (pool == nullptr) pool = default_memory_pool();
+    read_memory_pool_ = pool;
+  }
+  /// Set the pool used for send-side allocations; null selects
+  /// default_memory_pool().
+  void set_write_memory_pool(MemoryPool* pool) {
+    if (pool == nullptr) pool = default_memory_pool();
+    write_memory_pool_ = pool;
+  }
+ /// Human-readable peer name: "local:<addr>;remote:<addr>" if the endpoint
+ /// could be queried at construction, otherwise "(unknown remote)".
+ const std::string& peer() const { return name_; }
+
+ private:
+  /// Bookkeeping for an in-flight active-message send. It is handed to UCX as
+  /// the request's user_data. AmSendCallback deletes it on completion, or the
+  /// sender deletes it if the send completes or fails immediately.
+  class PendingAmSend {
+   public:
+    virtual ~PendingAmSend() = default;
+    // Null until the sender fills it in. Destructors below must tolerate
+    // that, because a record can be destroyed on an early-error path before
+    // the sender gets that far.
+    UcpCallDriver::Impl* driver = nullptr;
+    Future<> completed;
+    FrameHeader header;
+  };
+
+  /// Send of one contiguous buffer, possibly registered with UCX (memh_p).
+  class PendingContigSend : public PendingAmSend {
+   public:
+    std::unique_ptr<Buffer> ipc_message;
+    ucp_mem_h memh_p = nullptr;
+
+    ~PendingContigSend() override {
+      if (driver) TryUnmapBuffer(driver->worker_->context().get(), memh_p);
+    }
+  };
+
+  /// Scatter-gather send: keeps the payload (and so its buffers) alive and
+  /// owns the iovec array passed to UCX.
+  class PendingIovSend : public PendingAmSend {
+   public:
+    FlightPayload payload;
+    std::vector<ucp_dt_iov_t> iovs;
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+    std::vector<ucp_mem_h> memh_ps;
+
+    ~PendingIovSend() override {
+      if (!driver) return;
+      for (ucp_mem_h memh_p : memh_ps) {
+        TryUnmapBuffer(driver->worker_->context().get(), memh_p);
+      }
+    }
+#endif
+  };
+
+  /// State for a receive that completes in AmRecvCallback.
+  struct PendingAmRecv {
+    UcpCallDriver::Impl* driver;
+    std::shared_ptr<Frame> frame;
+    // Null unless a mapping is registered; TryUnmapBuffer ignores null.
+    ucp_mem_h memh_p = nullptr;
+
+    PendingAmRecv(UcpCallDriver::Impl* driver_, std::shared_ptr<Frame> frame_)
+        : driver(driver_), frame(std::move(frame_)) {}
+
+    ~PendingAmRecv() { TryUnmapBuffer(driver->worker_->context().get(), memh_p); }
+  };
+
+  /// UCX send-completion callback. It resolves the sender's future with the
+  /// outcome, then frees the pending-send record and the UCX request.
+  static void AmSendCallback(void* request, ucs_status_t status, void* user_data) {
+    auto* pending_send = static_cast<PendingAmSend*>(user_data);
+    Status outcome = Status::OK();
+    if (status != UCS_OK) {
+      outcome = FromUcsStatus("ucp_am_send_nbx", status);
+    }
+    pending_send->completed.MarkFinished(std::move(outcome));
+    // TODO(ARROW-16126): delete should occur on a background thread if there's
+    // mapped buffers, since unmapping can be nontrivial and we don't want to block
+    // the thread doing UCX work. (Borrow the Rust transfer-and-drop pattern.)
+    delete pending_send;
+    ucp_request_free(request);
+  }
+
+ /// \brief UCX completion callback for ucp_am_recv_data_nbx.
+ ///
+ /// Frees the UCX request, then hands either the received frame or the
+ /// error to the driver's reader. Finally it destroys the PendingAmRecv that
+ /// was passed as user_data.
+ static void AmRecvCallback(void* request, ucs_status_t status, size_t length,
+ void* user_data) {
+ PendingAmRecv* recv_state = static_cast<PendingAmRecv*>(user_data);
+ ucp_request_free(request);
+ if (status == UCS_OK) {
+ recv_state->driver->Push(std::move(recv_state->frame));
+ } else {
+ recv_state->driver->Push(
+ FromUcsStatus("ucp_am_recv_data_nbx (callback)", status));
+ }
+ delete recv_state;
+ }
+
+ /// \brief Turn one incoming UCP active message into a Frame for the reader.
+ ///
+ /// The header is parsed with Frame::ParseHeader. The body is then handled
+ /// on one of three paths:
+ /// 1. The message carries UCX-owned data (UCP_AM_RECV_ATTR_FLAG_DATA), and it
+ /// is usable in place (CPU memory manager, or not a payload body). The data
+ /// is wrapped zero-copy in a UcxDataBuffer and UCS_INPROGRESS is returned.
+ /// The UCX-allocated data must be freed later, presumably by that buffer.
+ /// 2. The message carries FLAG_DATA or rendezvous (FLAG_RNDV) data that must
+ /// land in an Arrow buffer. A buffer is allocated via memory_manager_ for
+ /// payload bodies, or read_memory_pool_ otherwise. It is mapped if possible
+ /// and filled by ucp_am_recv_data_nbx. The frame is pushed immediately if
+ /// that completes synchronously, else from AmRecvCallback.
+ /// 3. Otherwise the data is only valid during this call, so it is copied.
+ ///
+ /// \return The ucs_status_t to report back to UCX, or an error Status.
+ arrow::Result<ucs_status_t> RecvActiveMessageImpl(const void* header,
+ size_t header_length, void* data,
+ const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ DCHECK(param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP);
+
+ // Buffer sizes are limited to what fits in an int32_t.
+ if (data_length > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+ return Status::Invalid("Cannot allocate buffer greater than 2 GiB, requested: ",
+ data_length);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto frame, Frame::ParseHeader(header, header_length));
+ // The header's declared size must not exceed the bytes actually delivered.
+ // (data_length may exceed frame->size; buffers below are sized by data_length.)
+ if (data_length < frame->size) {
+ return Status::IOError("Expected frame of ", frame->size, " bytes, but got only ",
+ data_length);
+ }
+
+ if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) &&
+ (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody)) {
+ // Zero-copy path. UCX-allocated buffer must be freed later.
+
+ // XXX: this buffer can NOT be freed until AFTER we return from
+ // this handler. Otherwise, UCX won't have fully set up its
+ // internal data structures (allocated just before the buffer)
+ // and we'll crash when we free the buffer. Effectively: we can
+ // never use Then/AddCallback on a Future<> from ReadFrameAsync,
+ // because we might run the callback synchronously (which might
+ // free the buffer) when we call Push here.
+ frame->buffer =
+ arrow::internal::make_unique<UcxDataBuffer>(worker_, data, data_length);
+ Push(std::move(frame));
+ return UCS_INPROGRESS;
+ }
+
+ if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) ||
+ (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV)) {
+ // Rendezvous protocol (RNDV), or unpack to destination (DATA).
+
+ // We want to map/pin/register the buffer for faster transfer
+ // where possible. (It gets unmapped in ~PendingAmRecv.)
+ // TODO(ARROW-16126): This takes non-trivial time, so return
+ // UCS_INPROGRESS, kick off the allocation in the background,
+ // and recv the data later (is it allowed to call
+ // ucp_am_recv_data_nbx asynchronously?).
+ if (frame->type == FrameType::kPayloadBody) {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ memory_manager_->AllocateBuffer(data_length));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ AllocateBuffer(data_length, read_memory_pool_));
+ }
+
+ // Ownership of the frame moves into pending_recv; it is deleted on error,
+ // on immediate completion, or by AmRecvCallback.
+ PendingAmRecv* pending_recv = new PendingAmRecv(this, std::move(frame));
+ TryMapBuffer(worker_->context().get(), *pending_recv->frame->buffer,
+ &pending_recv->memh_p);
+
+ ucp_request_param_t recv_param;
+ recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
+ UCP_OP_ATTR_FIELD_MEMORY_TYPE |
+ UCP_OP_ATTR_FIELD_USER_DATA;
+ recv_param.cb.recv_am = AmRecvCallback;
+ recv_param.user_data = pending_recv;
+ recv_param.memory_type = InferMemoryType(*pending_recv->frame->buffer);
+
+ void* dest =
+ reinterpret_cast<void*>(pending_recv->frame->buffer->mutable_address());
+ void* request =
+ ucp_am_recv_data_nbx(worker_->get(), data, dest, data_length, &recv_param);
+ if (UCS_PTR_IS_ERR(request)) {
+ delete pending_recv;
+ return FromUcsStatus("ucp_am_recv_data_nbx", UCS_PTR_STATUS(request));
+ } else if (!request) {
+ // Request completed instantly
+ Push(std::move(pending_recv->frame));
+ delete pending_recv;
+ }
+ return UCS_OK;
+ } else {
+ // Data will be freed after callback returns - copy to buffer
+ if (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody) {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ AllocateBuffer(data_length, read_memory_pool_));
+ std::memcpy(frame->buffer->mutable_data(), data, data_length);
+ } else {
+ // Payload body destined for a non-CPU device: copy via the memory manager.
+ ARROW_ASSIGN_OR_RAISE(
+ frame->buffer,
+ MemoryManager::CopyNonOwned(Buffer(reinterpret_cast<uint8_t*>(data),
+ static_cast<int64_t>(data_length)),
+ memory_manager_));
+ }
+ Push(std::move(frame));
+ return UCS_OK;
+ }
+ }
+
+ /// \brief Drive a UCX request to completion by polling this driver.
+ ///
+ /// \param[in] context Name of the UCX call that produced `request`; used in
+ /// the error message when `request` itself encodes an error.
+ /// \param[in] request Result of a UCX non-blocking call: an error pointer,
+ /// NULL (completed immediately), or a request handle to poll and free.
+ /// \return OK once the request completes, otherwise the UCX error.
+ Status CompleteRequestBlocking(const std::string& context, void* request) {
+ if (UCS_PTR_IS_ERR(request)) {
+ return FromUcsStatus(context, UCS_PTR_STATUS(request));
+ }
+ if (!UCS_PTR_IS_PTR(request)) {
+ // Send was completed instantly; UCX returned NULL, nothing to free.
+ DCHECK(!request);
+ return Status::OK();
+ }
+ ucs_status_t status;
+ while ((status = ucp_request_check_status(request)) == UCS_INPROGRESS) {
+ MakeProgress();
+ }
+ // Free the request on both success and failure. ucp_request_release is
+ // deprecated; use ucp_request_free consistently.
+ ucp_request_free(request);
+ if (status != UCS_OK) {
+ return FromUcsStatus("ucp_request_check_status", status);
+ }
+ return Status::OK();
+ }
+
+ /// Return Invalid when there is no endpoint (the driver is closed), OK otherwise.
+ Status CheckClosed() {
+ return endpoint_ ? Status::OK() : Status::Invalid("UcpCallDriver is closed");
+ }
+
+ // NOTE(review): presumably filler bytes sent in place of otherwise empty
+ // bodies (UCX rejects zero-size payloads; see FrameHeader) - confirm.
+ const std::array<uint8_t, 8> padding_bytes_;
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ ucp_mem_h padding_memh_p_;
+#endif
+
+ std::shared_ptr<UcpWorker> worker_;
+ // Null means the driver is closed (see CheckClosed).
+ ucp_ep_h endpoint_;
+ // Scratch pools; the setters substitute default_memory_pool() for null.
+ MemoryPool* read_memory_pool_;
+ MemoryPool* write_memory_pool_;
+ // Decides where received payload bodies are allocated (RecvActiveMessageImpl).
+ std::shared_ptr<MemoryManager> memory_manager_;
+
+ // Internal name for logging/tracing
+ std::string name_;
+ // Counter used to reorder messages
+ uint32_t counter_ = 0;
+
+ // NOTE(review): presumably guards status_, frames_ and next_counter_ - confirm.
+ std::mutex frame_mutex_;
+ Status status_;
+ std::unordered_map<uint32_t, Future<std::shared_ptr<Frame>>> frames_;
+ uint32_t next_counter_ = 0;
+};
+
+// Pimpl boilerplate: the special members are defined here, where Impl is a
+// complete type, so std::unique_ptr<Impl> can be moved and destroyed.
+UcpCallDriver::UcpCallDriver(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint)
+ : impl_(new Impl(std::move(worker), endpoint)) {}
+UcpCallDriver::UcpCallDriver(UcpCallDriver&&) = default;
+UcpCallDriver& UcpCallDriver::operator=(UcpCallDriver&&) = default;
+UcpCallDriver::~UcpCallDriver() = default;
+
+// Reading is implemented entirely by Impl; these forward to it.
+arrow::Result<std::shared_ptr<Frame>> UcpCallDriver::ReadNextFrame() {
+ return impl_->ReadNextFrame();
+}
+
+Future<std::shared_ptr<Frame>> UcpCallDriver::ReadFrameAsync() {
+ return impl_->ReadFrameAsync();
+}
+
+Status UcpCallDriver::ExpectFrameType(const Frame& frame, FrameType type) {
+ if (frame.type != type) {
+ return Status::IOError("Expected frame type ", static_cast<int32_t>(type),
+ ", but got frame type ", static_cast<int32_t>(frame.type));
+ }
+ return Status::OK();
+}
+
+// Opens a call by sending a kHeaders frame whose only header names the method.
+Status UcpCallDriver::StartCall(const std::string& method) {
+ const std::vector<std::pair<std::string, std::string>> headers = {
+ {kHeaderMethod, method}};
+ ARROW_ASSIGN_OR_RAISE(HeadersFrame frame, HeadersFrame::Make(headers));
+ std::unique_ptr<Buffer> encoded = std::move(frame).GetBuffer();
+ return impl_->SendFrame(FrameType::kHeaders, encoded->data(), encoded->size());
+}
+
+// The remaining public methods forward directly to Impl.
+
+// Asynchronous; the caller must poll MakeProgress() until the future resolves.
+Future<> UcpCallDriver::SendFlightPayload(const FlightPayload& payload) {
+ return impl_->SendFlightPayload(payload);
+}
+
+Status UcpCallDriver::SendFrame(FrameType frame_type, const uint8_t* data,
+ const int64_t size) {
+ return impl_->SendFrame(frame_type, data, size);
+}
+
+Future<> UcpCallDriver::SendFrameAsync(FrameType frame_type,
+ std::unique_ptr<Buffer> buffer) {
+ return impl_->SendFrameAsync(frame_type, std::move(buffer));
+}
+
+Status UcpCallDriver::Close() { return impl_->Close(); }
+
+void UcpCallDriver::MakeProgress() { impl_->MakeProgress(); }
+
+// Entry point for the active-message handler; see Impl::RecvActiveMessageImpl.
+ucs_status_t UcpCallDriver::RecvActiveMessage(const void* header, size_t header_length,
+ void* data, const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ return impl_->RecvActiveMessage(header, header_length, data, data_length, param);
+}
+
+const std::shared_ptr<MemoryManager>& UcpCallDriver::memory_manager() const {
+ return impl_->memory_manager();
+}
+
+void UcpCallDriver::set_memory_manager(std::shared_ptr<MemoryManager> memory_manager) {
+ impl_->set_memory_manager(std::move(memory_manager));
+}
+void UcpCallDriver::set_read_memory_pool(MemoryPool* pool) {
+ impl_->set_read_memory_pool(pool);
+}
+void UcpCallDriver::set_write_memory_pool(MemoryPool* pool) {
+ impl_->set_write_memory_pool(pool);
+}
+const std::string& UcpCallDriver::peer() const { return impl_->peer(); }
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
new file mode 100644
index 0000000000..bd176e2369
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Common implementation of UCX communication primitives.
+
+#pragma once
+
+#include <array>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <ucp/api/ucp.h>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/visibility.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+//------------------------------------------------------------
+// Protocol Constants
+
+// RPC method names. NOTE(review): presumably passed as `method` to
+// UcpCallDriver::StartCall, which sends them under the method header key.
+static constexpr char kMethodDoExchange[] = "DoExchange";
+static constexpr char kMethodDoGet[] = "DoGet";
+static constexpr char kMethodDoPut[] = "DoPut";
+static constexpr char kMethodGetFlightInfo[] = "GetFlightInfo";
+
+// Header keys for a call's final status. NOTE(review): presumably written by
+// HeadersFrame::Make(status, ...) and read by HeadersFrame::GetStatus - confirm.
+static constexpr char kHeaderStatusCode[] = "flight-status-code";
+static constexpr char kHeaderStatusMessage[] = "flight-status-message";
+static constexpr char kHeaderStatusDetail[] = "flight-status-detail";
+static constexpr char kHeaderStatusDetailCode[] = "flight-status-detail-code";
+
+//------------------------------------------------------------
+// UCX Helpers
+
+/// \brief A wrapper around a ucp_context_h.
+///
+/// Used so that multiple resources can share ownership of the
+/// context. UCX has zero-copy optimizations where an application can
+/// directly use a UCX buffer, but the lifetime of such buffers is
+/// tied to the UCX context and worker, so ownership needs to be
+/// preserved.
+///
+/// The wrapper owns the handle and calls ucp_cleanup on destruction, so it
+/// is non-copyable (a copy would clean up the same context twice). Share it
+/// through std::shared_ptr instead.
+class UcpContext final {
+ public:
+ UcpContext() = default;
+ explicit UcpContext(ucp_context_h context) : ucp_context_(context) {}
+ ~UcpContext() {
+ if (ucp_context_) ucp_cleanup(ucp_context_);
+ ucp_context_ = nullptr;
+ }
+ UcpContext(const UcpContext&) = delete;
+ UcpContext& operator=(const UcpContext&) = delete;
+
+ /// \brief Get the raw handle; must not be null (DCHECKed).
+ ucp_context_h get() const {
+ DCHECK(ucp_context_);
+ return ucp_context_;
+ }
+
+ private:
+ ucp_context_h ucp_context_ = nullptr;
+};
+
+/// \brief A wrapper around a ucp_worker_h.
+///
+/// Owns the worker handle and destroys it with ucp_worker_destroy. It also
+/// holds a shared reference to its UcpContext; members are destroyed after
+/// the destructor body runs, so the context outlives the worker. The wrapper
+/// is non-copyable, since a copy would destroy the same worker twice.
+class UcpWorker final {
+ public:
+ UcpWorker() = default;
+ UcpWorker(std::shared_ptr<UcpContext> context, ucp_worker_h worker)
+ : ucp_context_(std::move(context)), ucp_worker_(worker) {}
+ ~UcpWorker() {
+ if (ucp_worker_) ucp_worker_destroy(ucp_worker_);
+ ucp_worker_ = nullptr;
+ }
+ UcpWorker(const UcpWorker&) = delete;
+ UcpWorker& operator=(const UcpWorker&) = delete;
+
+ /// \brief Get the raw handle; must not be null (DCHECKed).
+ ucp_worker_h get() const {
+ DCHECK(ucp_worker_);
+ return ucp_worker_;
+ }
+ /// \brief Get the owning context. Not valid on a default-constructed
+ /// worker, whose context pointer is null.
+ const UcpContext& context() const { return *ucp_context_; }
+
+ private:
+ std::shared_ptr<UcpContext> ucp_context_;
+ ucp_worker_h ucp_worker_ = nullptr;
+};
+
+//------------------------------------------------------------
+// Message Framing
+
+/// \brief The message type.
+///
+/// Transmitted as byte 1 of the frame header (see FrameHeader), so existing
+/// values must not be renumbered.
+enum class FrameType : uint8_t {
+ /// Key-value headers. Sent at the beginning (client->server) and
+ /// end (server->client) of a call. Also, for client-streaming calls
+ /// (e.g. DoPut), the client should send a headers frame to signal
+ /// end-of-stream.
+ kHeaders = 0,
+ /// Binary blob, does not contain Arrow data.
+ kBuffer,
+ /// Binary blob. Contains IPC metadata, app metadata.
+ kPayloadHeader,
+ /// Binary blob. Contains IPC body. Body is sent separately since it
+ /// may use a different memory type.
+ kPayloadBody,
+ /// Ask server to disconnect (to avoid client/server waiting on each
+ /// other and getting stuck).
+ kDisconnect,
+ /// Keep at end.
+ kMaxFrameType = kDisconnect,
+};
+
+/// \brief The header of a message frame. Used when sending only.
+///
+/// A frame is expected to be sent over UCP Active Messages and
+/// consists of a header (of kFrameHeaderBytes bytes) and a body.
+///
+/// The header is as follows:
+/// +-------+---------------------------------+
+/// | Bytes | Function |
+/// +=======+=================================+
+/// | 0 | Version tag (see kFrameVersion) |
+/// | 1 | Frame type (see FrameType) |
+/// | 2-3 | Unused, reserved |
+/// | 4-7 | Frame counter (big-endian) |
+/// | 8-11 | Body size (big-endian) |
+/// +-------+---------------------------------+
+///
+/// The frame counter lets the receiver ensure messages are processed
+/// in-order. (The message receive callback may use
+/// ucp_am_recv_data_nbx which is asynchronous.)
+///
+/// The body size reports the expected message size (UCX chokes on
+/// zero-size payloads which we occasionally want to send, so the size
+/// field in the header lets us know when a payload was meant to be
+/// empty).
+struct FrameHeader {
+ /// \brief The size of a frame header.
+ static constexpr size_t kFrameHeaderBytes = 12;
+ /// \brief The expected version tag in the header.
+ static constexpr uint8_t kFrameVersion = 0x01;
+
+ FrameHeader() = default;
+ /// \brief Initialize the frame header.
+ ///
+ /// NOTE(review): body_size is int64_t but the wire field is 4 bytes;
+ /// presumably Set() rejects sizes that do not fit - confirm in the .cc.
+ Status Set(FrameType frame_type, uint32_t counter, int64_t body_size);
+ /// \brief Pointer to the encoded header bytes (for passing to UCX).
+ void* data() const { return header.data(); }
+ /// \brief Always kFrameHeaderBytes.
+ size_t size() const { return kFrameHeaderBytes; }
+
+ // mutable since UCX expects void* not const void*
+ mutable std::array<uint8_t, kFrameHeaderBytes> header = {0};
+};
+
+/// \brief A single message received via UCX. Used when receiving only.
+struct Frame {
+ /// \brief The message type.
+ FrameType type = FrameType::kHeaders;
+ /// \brief The message length as declared in the header. The buffer may be
+ /// larger than this (it is sized by the bytes delivered), so use view().
+ uint32_t size = 0;
+ /// \brief An incrementing message counter (may wrap over).
+ uint32_t counter = 0;
+ /// \brief The message contents.
+ std::unique_ptr<Buffer> buffer;
+
+ // Members have default initializers so a default-constructed Frame never
+ // exposes indeterminate values.
+ Frame() = default;
+ Frame(FrameType type_, uint32_t size_, uint32_t counter_,
+ std::unique_ptr<Buffer> buffer_)
+ : type(type_), size(size_), counter(counter_), buffer(std::move(buffer_)) {}
+
+ /// \brief View the first `size` bytes of the buffer (buffer must be set).
+ util::string_view view() const {
+ return util::string_view(reinterpret_cast<const char*>(buffer->data()), size);
+ }
+
+ /// \brief Parse a UCX active message header. This will not
+ /// initialize the buffer field.
+ static arrow::Result<std::shared_ptr<Frame>> ParseHeader(const void* header,
+ size_t header_length);
+};
+
+/// \brief The active message handler callback ID.
+///
+/// NOTE(review): presumably the id passed both to ucp_am_send_nbx and to the
+/// worker's AM receive-handler registration, so both peers must agree on it.
+static constexpr uint32_t kUcpAmHandlerId = 0x1024;
+
+/// \brief A collection of key-value headers.
+///
+/// This should be stored in a frame of type kHeaders.
+///
+/// Format:
+/// +-------+----------------------------------+
+/// | Bytes | Contents |
+/// +=======+==================================+
+/// | 0-4 | # of headers (big-endian) |
+/// | 4-8 | Header key length (big-endian) |
+/// | 2-3 | Header value length (big-endian) |
+/// | (...) | Header key |
+/// | (...) | Header value |
+/// | (...) | (repeat from row 2) |
+/// +-------+----------------------------------+
+class HeadersFrame {
+ public:
+ /// \brief Get a header value (or an error if it was not found)
+ arrow::Result<util::string_view> Get(const std::string& key);
+ /// \brief Extract the server-sent status.
+ Status GetStatus(Status* out);
+ /// \brief Parse the headers from the buffer.
+ static arrow::Result<HeadersFrame> Parse(std::unique_ptr<Buffer> buffer);
+ /// \brief Create a new frame with the given headers.
+ static arrow::Result<HeadersFrame> Make(
+ const std::vector<std::pair<std::string, std::string>>& headers);
+ /// \brief Create a new frame with the given headers and the given status.
+ static arrow::Result<HeadersFrame> Make(
+ const Status& status,
+ const std::vector<std::pair<std::string, std::string>>& headers);
+
+ /// \brief Take ownership of the underlying buffer.
+ std::unique_ptr<Buffer> GetBuffer() && { return std::move(buffer_); }
+
+ private:
+ std::unique_ptr<Buffer> buffer_;
+ std::vector<std::pair<util::string_view, util::string_view>> headers_;
+};
+
+/// \brief A representation of a kPayloadHeader frame (i.e. all of the
+/// metadata in a FlightPayload/FlightData).
+///
+/// Data messages are sent in two parts: one containing all metadata
+/// (the Flatbuffers header, FlightDescriptor, and app_metadata
+/// fields) and one containing the actual data. This was done to avoid
+/// having to concatenate these fields with the data itself (in the
+/// cases where we are not using IOV).
+///
+/// Format:
+/// +--------+----------------------------------+
+/// | Bytes | Contents |
+/// +========+==================================+
+/// | 0-4 | Descriptor length (big-endian) |
+/// | 4..a | Descriptor bytes |
+/// | a-a+4 | app_metadata length (big-endian) |
+/// | a+4..b | app_metadata bytes |
+/// | b-b+4 | ipc_metadata length (big-endian) |
+/// | b+4..c | ipc_metadata bytes |
+/// +--------+----------------------------------+
+///
+/// If a field is not present, its length is still there, but is set
+/// to UINT32_MAX.
+class PayloadHeaderFrame {
+ public:
+ /// \brief Wrap (and take ownership of) an encoded payload header.
+ explicit PayloadHeaderFrame(std::unique_ptr<Buffer> buffer)
+ : buffer_(std::move(buffer)) {}
+ /// \brief Unpack the internal buffer into a FlightData.
+ Status ToFlightData(internal::FlightData* data);
+ /// \brief Pack a payload into the internal buffer.
+ static arrow::Result<PayloadHeaderFrame> Make(const FlightPayload& payload,
+ MemoryPool* memory_pool);
+ /// \brief Encoded bytes; requires a non-null buffer.
+ const uint8_t* data() const { return buffer_->data(); }
+ int64_t size() const { return buffer_->size(); }
+
+ private:
+ std::unique_ptr<Buffer> buffer_;
+};
+
+/// \brief Manage the state of a UCX connection.
+///
+/// Move-only. The implementation lives in a private Impl (see the .cc file).
+class UcpCallDriver {
+ public:
+ UcpCallDriver(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint);
+
+ UcpCallDriver(const UcpCallDriver&) = delete;
+ UcpCallDriver(UcpCallDriver&&);
+ UcpCallDriver& operator=(const UcpCallDriver&) = delete;
+ UcpCallDriver& operator=(UcpCallDriver&&);
+
+ ~UcpCallDriver();
+
+ /// \brief Start a call by sending a headers frame. Client side only.
+ ///
+ /// \param[in] method The RPC method.
+ Status StartCall(const std::string& method);
+
+ /// \brief Synchronously send a generic message with binary payload.
+ Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size);
+ /// \brief Asynchronously send a generic message with binary payload.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr<Buffer> buffer);
+ /// \brief Asynchronously send a data message.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future<> SendFlightPayload(const FlightPayload& payload);
+
+ /// \brief Synchronously read the next frame.
+ arrow::Result<std::shared_ptr<Frame>> ReadNextFrame();
+ /// \brief Asynchronously read the next frame.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future<std::shared_ptr<Frame>> ReadFrameAsync();
+
+ /// \brief Validate that the frame is of the given type.
+ Status ExpectFrameType(const Frame& frame, FrameType type);
+
+ /// \brief Disconnect the other side of the connection. Note, this
+ /// can cause deadlock.
+ Status Close();
+
+ /// \brief Synchronously make progress (to adapt async to sync APIs)
+ void MakeProgress();
+
+ /// \brief Get the associated memory manager.
+ const std::shared_ptr<MemoryManager>& memory_manager() const;
+ /// \brief Set the associated memory manager.
+ void set_memory_manager(std::shared_ptr<MemoryManager> memory_manager);
+ /// \brief Set memory pool for scratch space used during reading.
+ void set_read_memory_pool(MemoryPool* memory_pool);
+ /// \brief Set memory pool for scratch space used during writing.
+ void set_write_memory_pool(MemoryPool* memory_pool);
+ /// \brief Get a debug string naming the peer.
+ const std::string& peer() const;
+
+ /// \brief Process an incoming active message. This will unblock the
+ /// corresponding call to ReadFrameAsync/ReadNextFrame.
+ ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data,
+ const size_t data_length,
+ const ucp_am_recv_param_t* param);
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+/// \brief Create the UCX client transport.
+ARROW_FLIGHT_EXPORT
+std::unique_ptr<arrow::flight::internal::ClientTransport> MakeUcxClientImpl();
+
+/// \brief Create the UCX server transport serving `base`, using
+/// `memory_manager` for received data.
+ARROW_FLIGHT_EXPORT
+std::unique_ptr<arrow::flight::internal::ServerTransport> MakeUcxServerImpl(
+ FlightServerBase* base, std::shared_ptr<MemoryManager> memory_manager);
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
new file mode 100644
index 0000000000..74a9311d0c
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
@@ -0,0 +1,628 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include <atomic>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <unordered_map>
+
+#include <arpa/inet.h>
+#include <ucp/api/ucp.h>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/transport_server.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+// Send an error to the client and return OK.
+// Statuses returned up to the main server loop trigger a kReset instead.
+//
+// If `status` is not OK, it is encoded with HeadersFrame::Make and sent via
+// driver->SendFrame as a kHeaders frame, and the enclosing function returns
+// OK. Only usable in functions returning Status. If building or sending that
+// frame fails, that failure is returned instead.
+#define SERVER_RETURN_NOT_OK(driver, status) \
+ do { \
+ ::arrow::Status s = (status); \
+ if (!s.ok()) { \
+ ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make(s, {})); \
+ auto payload = std::move(headers).GetBuffer(); \
+ RETURN_NOT_OK( \
+ driver->SendFrame(FrameType::kHeaders, payload->data(), payload->size())); \
+ return ::arrow::Status::OK(); \
+ } \
+ } while (false)
+
+// Logging helpers that prefix server messages with [server] (and the peer).
+#define FLIGHT_LOG(LEVEL) (ARROW_LOG(LEVEL) << "[server] ")
+#define FLIGHT_LOG_PEER(LEVEL, PEER) \
+ (ARROW_LOG(LEVEL) << "[server]" \
+ << "[peer=" << (PEER) << "] ")
+
+namespace {
+/// Minimal ServerCallContext for UCX calls: no middleware, never cancelled.
+/// NOTE(review): peer_ is never assigned, so peer() and peer_identity()
+/// always return an empty string.
+class UcxServerCallContext : public flight::ServerCallContext {
+ public:
+ const std::string& peer() const override { return peer_; }
+ const std::string& peer_identity() const override { return peer_; }
+ bool is_cancelled() const override { return false; }
+ ServerMiddleware* GetMiddleware(const std::string& key) const override {
+ ARROW_UNUSED(key);
+ return nullptr;
+ }
+
+ private:
+ std::string peer_;
+};
+
+/// Common base for the server-side data streams. Holds a non-owning pointer
+/// to the call driver and a copy of the peer name for log messages.
+class UcxServerStream : public internal::ServerDataStream {
+ public:
+ explicit UcxServerStream(UcpCallDriver* driver)
+ : peer_(driver->peer()), driver_(driver) {}
+
+ /// Mark the write side as finished; subclasses' WriteData() then returns false.
+ Status WritesDone() override {
+ writes_done_ = true;
+ return Status::OK();
+ }
+
+ protected:
+ std::string peer_;
+ UcpCallDriver* driver_;
+ bool writes_done_ = false;
+};
+
+/// Write-only stream used by HandleDoGet.
+class GetServerStream : public UcxServerStream {
+ public:
+ using UcxServerStream::UcxServerStream;
+
+ /// Send one payload, polling the driver until the send resolves.
+ /// Returns false without sending once WritesDone() has been called.
+ arrow::Result<bool> WriteData(const FlightPayload& payload) override {
+ if (writes_done_) {
+ return false;
+ }
+ Future<> in_flight = driver_->SendFlightPayload(payload);
+ while (!in_flight.is_finished()) driver_->MakeProgress();
+ RETURN_NOT_OK(in_flight.status());
+ return true;
+ }
+};
+
+/// Read stream used by HandleDoPut: yields client FlightData and can send
+/// put-metadata replies.
+class PutServerStream : public UcxServerStream {
+ public:
+ explicit PutServerStream(UcpCallDriver* driver)
+ : UcxServerStream(driver), finished_(false) {}
+
+ /// Read the next FlightData. Returns false, and stops reading for good, once
+ /// the client's trailing headers frame arrives or an I/O error occurs.
+ /// Errors are logged rather than propagated.
+ bool ReadData(internal::FlightData* data) override {
+ if (finished_) return false;
+
+ arrow::Result<bool> read = ReadImpl(data);
+ if (!read.ok()) {
+ finished_ = true;
+ FLIGHT_LOG_PEER(WARNING, peer_) << "I/O error in DoPut: " << read.status().ToString();
+ return false;
+ }
+ const bool more = read.ValueOrDie();
+ if (!more) finished_ = true;
+ return more;
+ }
+
+ /// Send an application metadata reply. Sent synchronously because the
+ /// caller owns `payload`; silently dropped once reading has finished.
+ Status WritePutMetadata(const Buffer& payload) override {
+ if (finished_) return Status::OK();
+ return driver_->SendFrame(FrameType::kBuffer, payload.data(), payload.size());
+ }
+
+ private:
+ // Returns false at end of stream, true when `data` was filled.
+ ::arrow::Result<bool> ReadImpl(internal::FlightData* data) {
+ ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame());
+ // A headers frame is the client's trailer: no more data follows.
+ if (frame->type == FrameType::kHeaders) return false;
+
+ RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader));
+ PayloadHeaderFrame header_frame(std::move(frame->buffer));
+ RETURN_NOT_OK(header_frame.ToFlightData(data));
+
+ // Only IPC messages that carry a body are followed by a kPayloadBody frame.
+ if (!data->metadata) return true;
+ ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr));
+ if (!ipc::Message::HasBody(message->type())) return true;
+
+ ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame());
+ RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody));
+ data->body = std::move(frame->buffer);
+ return true;
+ }
+
+ bool finished_;
+};
+
+/// Bidirectional stream (presumably for DoExchange): reads like the DoPut
+/// stream and also writes FlightData back. Put-metadata replies are unsupported.
+class ExchangeServerStream : public PutServerStream {
+ public:
+ using PutServerStream::PutServerStream;
+
+ arrow::Result<bool> WriteData(const FlightPayload& payload) override {
+ if (writes_done_) {
+ return false;
+ }
+ auto in_flight = driver_->SendFlightPayload(payload);
+ // Poll the driver until the asynchronous send resolves.
+ while (!in_flight.is_finished()) driver_->MakeProgress();
+ RETURN_NOT_OK(in_flight.status());
+ return true;
+ }
+
+ Status WritePutMetadata(const Buffer& payload) override {
+ ARROW_UNUSED(payload);
+ return Status::NotImplemented("Not supported on this stream");
+ }
+};
+
+class UcxServerImpl : public arrow::flight::internal::ServerTransport {
+ public:
+ using arrow::flight::internal::ServerTransport::ServerTransport;
+
+ /// Shuts the server down if it is still listening. Failures are logged,
+ /// since a destructor cannot report them.
+ virtual ~UcxServerImpl() {
+ if (!listening_.load()) return;
+ const Status st = Shutdown();
+ if (!st.ok()) {
+ ARROW_LOG(WARNING) << "Server did not shut down properly: " << st.ToString();
+ }
+ }
+
+ /// \brief Initialize UCX and start listening on `uri`.
+ ///
+ /// Steps:
+ /// - create the RPC thread pool, a UCP context and a worker for incoming
+ /// connections;
+ /// - bind a UCP listener to `uri` and record the actually bound location
+ /// (the real port is read back from the listener);
+ /// - start the connection-driving thread.
+ ///
+ /// \param[in] options Server options. If options.builder_hook is set, it is
+ /// called with the ucp_config_t* before ucp_init.
+ /// \param[in] uri The URI to listen on.
+ Status Init(const FlightServerOptions& options,
+ const arrow::internal::Uri& uri) override {
+ const auto max_threads = std::max<uint32_t>(8, std::thread::hardware_concurrency());
+ ARROW_ASSIGN_OR_RAISE(rpc_pool_, arrow::internal::ThreadPool::Make(max_threads));
+
+ struct sockaddr_storage listen_addr;
+ ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &listen_addr));
+
+ // Init UCX
+ {
+ ucp_config_t* raw_config;
+ ucs_status_t status = ucp_config_read(nullptr, nullptr, &raw_config);
+ RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status));
+ // Own the config so that an early return below cannot leak it.
+ std::unique_ptr<ucp_config_t, decltype(&ucp_config_release)> ucp_config(
+ raw_config, &ucp_config_release);
+
+ // If location is IPv6, must adjust UCX config
+ if (listen_addr.ss_family == AF_INET6) {
+ status = ucp_config_modify(ucp_config.get(), "AF_PRIO", "inet6");
+ RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status));
+ }
+
+ // Allow application to override UCP config
+ if (options.builder_hook) options.builder_hook(ucp_config.get());
+
+ ucp_params_t ucp_params;
+ std::memset(&ucp_params, 0, sizeof(ucp_params));
+ ucp_params.field_mask =
+ UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_MT_WORKERS_SHARED;
+ ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP;
+ ucp_params.mt_workers_shared = UCS_THREAD_MODE_MULTI;
+
+ ucp_context_h ucp_context;
+ status = ucp_init(&ucp_params, ucp_config.get(), &ucp_context);
+ // The config is no longer needed once ucp_init has run.
+ ucp_config.reset();
+ RETURN_NOT_OK(FromUcsStatus("ucp_init", status));
+ ucp_context_.reset(new UcpContext(ucp_context));
+ }
+
+ {
+ // Create one worker to listen for incoming connections.
+ ucp_worker_params_t worker_params;
+ ucs_status_t status;
+
+ std::memset(&worker_params, 0, sizeof(worker_params));
+ worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
+ worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
+ ucp_worker_h worker;
+ status = ucp_worker_create(ucp_context_->get(), &worker_params, &worker);
+ RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status));
+ worker_conn_.reset(new UcpWorker(ucp_context_, worker));
+ }
+
+ // Start listening for connections.
+ {
+ ucp_listener_params_t params;
+ ucs_status_t status;
+
+ std::memset(&params, 0, sizeof(params));
+ params.field_mask =
+ UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER;
+ params.sockaddr.addr = reinterpret_cast<const sockaddr*>(&listen_addr);
+ params.sockaddr.addrlen = addrlen;
+ params.conn_handler.cb = HandleIncomingConnection;
+ params.conn_handler.arg = this;
+
+ status = ucp_listener_create(worker_conn_->get(), &params, &listener_);
+ RETURN_NOT_OK(FromUcsStatus("ucp_listener_create", status));
+
+ // Get the real address/port
+ ucp_listener_attr_t attr;
+ attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR;
+ status = ucp_listener_query(listener_, &attr);
+ RETURN_NOT_OK(FromUcsStatus("ucp_listener_query", status));
+
+ std::string raw_uri = "ucx://";
+ if (uri.host().find(':') != std::string::npos) {
+ // IPv6 host
+ raw_uri += '[';
+ raw_uri += uri.host();
+ raw_uri += ']';
+ } else {
+ raw_uri += uri.host();
+ }
+ raw_uri += ":";
+ // sin_port and sin6_port share the same offset, so this works for both families.
+ raw_uri += std::to_string(
+ ntohs(reinterpret_cast<const sockaddr_in*>(&attr.sockaddr)->sin_port));
+ std::string listen_str;
+ ARROW_UNUSED(SockaddrToString(attr.sockaddr).Value(&listen_str));
+ FLIGHT_LOG(DEBUG) << "Listening on " << listen_str;
+ ARROW_ASSIGN_OR_RAISE(location_, Location::Parse(raw_uri));
+ }
+
+ {
+ listening_.store(true);
+ std::thread listener_thread(&UcxServerImpl::DriveConnections, this);
+ listener_thread_.swap(listener_thread);
+ }
+
+ return Status::OK();
+ }
+
+ /// \brief Stop listening, reject pending connections and release UCX resources.
+ ///
+ /// No-op if the server is not listening. Errors from the individual steps
+ /// are accumulated with Status::operator&= and returned at the end.
+ Status Shutdown() override {
+ if (!listening_.load()) return Status::OK();
+ Status status;
+
+ // Wait for current RPCs to finish
+ listening_.store(false);
+ // Unstick the listener thread from ucp_worker_wait
+ // NOTE(review): if ucp_worker_signal fails, this returns before Wait() joins
+ // listener_thread_. listening_ is already false, so the destructor will not
+ // retry; destroying a still-joinable std::thread calls std::terminate.
+ RETURN_NOT_OK(
+ FromUcsStatus("ucp_worker_signal", ucp_worker_signal(worker_conn_->get())));
+ status &= Wait();
+
+ {
+ // Reject all pending connections
+ std::unique_lock<std::mutex> guard(pending_connections_mutex_);
+ while (!pending_connections_.empty()) {
+ status &=
+ FromUcsStatus("ucp_listener_reject",
+ ucp_listener_reject(listener_, pending_connections_.front()));
+ pending_connections_.pop();
+ }
+ // The listener is destroyed before its worker (worker_conn_).
+ ucp_listener_destroy(listener_);
+ worker_conn_.reset();
+ }
+
+ status &= rpc_pool_->Shutdown();
+ rpc_pool_.reset();
+
+ ucp_context_.reset();
+ return status;
+ }
+
+ /// Deadline-based shutdown is not implemented yet: the deadline is ignored
+ /// and this performs the regular, blocking Shutdown().
+ Status Shutdown(const std::chrono::system_clock::time_point& deadline) override {
+ // TODO(ARROW-16125): implement shutdown with deadline
+ ARROW_UNUSED(deadline);
+ return Shutdown();
+ }
+
+ /// \brief Block until the connection-driving thread has exited.
+ ///
+ /// Returns OK immediately if the thread is not running (never started or
+ /// already joined).
+ Status Wait() override {
+ std::lock_guard<std::mutex> guard(join_mutex_);
+ // Check joinable() rather than relying on join() throwing
+ // std::errc::invalid_argument when the server isn't running. join_mutex_
+ // serializes joiners, so the check cannot race with another join.
+ if (!listener_thread_.joinable()) return Status::OK();
+ try {
+ listener_thread_.join();
+ } catch (const std::system_error& e) {
+ return Status::UnknownError("Could not Wait(): ", e.what());
+ }
+ return Status::OK();
+ }
+
+  /// \brief The ucx://host:port location the listener is bound to.
+  Location location() const override {
+    return this->location_;
+  }
+
+ private:
+ /// \brief Per-client state: a dedicated UCP worker plus the call driver
+ /// bound to that client's endpoint.
+ ///
+ /// A raw pointer to this struct is registered as the Active Message
+ /// handler argument in CreateWorker().
+ struct ClientWorker {
+ // UCP worker created for this client by CreateWorker().
+ std::shared_ptr<UcpWorker> worker;
+ // Created in WorkerLoop() once the endpoint to the client exists.
+ // Declared after `worker`, so it is destroyed before `worker`.
+ std::unique_ptr<UcpCallDriver> driver;
+ };
+
+  /// \brief Send a call's final status to the client as a headers frame
+  /// (built from the Status, with no additional headers).
+  Status SendStatus(UcpCallDriver* driver, const Status& status) {
+    ARROW_ASSIGN_OR_RAISE(auto status_frame, HeadersFrame::Make(status, {}));
+    const auto buffer = std::move(status_frame).GetBuffer();
+    RETURN_NOT_OK(driver->SendFrame(FrameType::kHeaders, buffer->data(), buffer->size()));
+    return Status::OK();
+  }
+
+  /// \brief Serve GetFlightInfo: read a descriptor, reply with the
+  /// serialized FlightInfo followed by an OK status.
+  Status HandleGetFlightInfo(UcpCallDriver* driver) {
+    UcxServerCallContext context;
+
+    // Request: one buffer frame holding a serialized FlightDescriptor.
+    ARROW_ASSIGN_OR_RAISE(auto request, driver->ReadNextFrame());
+    SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*request, FrameType::kBuffer));
+    FlightDescriptor descriptor;
+    SERVER_RETURN_NOT_OK(driver,
+                         FlightDescriptor::Deserialize(util::string_view(*request->buffer))
+                             .Value(&descriptor));
+
+    // Ask the application for the FlightInfo and serialize it.
+    std::unique_ptr<FlightInfo> info;
+    SERVER_RETURN_NOT_OK(driver, base_->GetFlightInfo(context, descriptor, &info));
+    std::string serialized;
+    SERVER_RETURN_NOT_OK(driver, info->SerializeToString().Value(&serialized));
+
+    // Response: the serialized FlightInfo, then the trailing OK status.
+    const auto* bytes = reinterpret_cast<const uint8_t*>(serialized.data());
+    const auto length = static_cast<int64_t>(serialized.size());
+    RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer, bytes, length));
+    RETURN_NOT_OK(SendStatus(driver, Status::OK()));
+    return Status::OK();
+  }
+
+  /// \brief Serve DoGet: read a Ticket, stream data back, then send the
+  /// handler's status as the trailer.
+  Status HandleDoGet(UcpCallDriver* driver) {
+    UcxServerCallContext context;
+
+    // Request: one buffer frame carrying the serialized Ticket.
+    ARROW_ASSIGN_OR_RAISE(auto request, driver->ReadNextFrame());
+    SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*request, FrameType::kBuffer));
+    Ticket ticket;
+    SERVER_RETURN_NOT_OK(driver, Ticket::Deserialize(request->view()).Value(&ticket));
+
+    GetServerStream stream(driver);
+    const auto result = DoGet(context, std::move(ticket), &stream);
+    RETURN_NOT_OK(SendStatus(driver, result));
+    return Status::OK();
+  }
+
+  /// \brief Serve DoPut: run the handler over the incoming stream and send
+  /// its status, then discard whatever the handler did not consume.
+  Status HandleDoPut(UcpCallDriver* driver) {
+    UcxServerCallContext context;
+
+    PutServerStream stream(driver);
+    const auto result = DoPut(context, &stream);
+    RETURN_NOT_OK(SendStatus(driver, result));
+    // Must drain any unread messages, or the next call will get confused
+    for (internal::FlightData unread; stream.ReadData(&unread);) {
+    }
+    return Status::OK();
+  }
+
+  /// \brief Serve DoExchange: run the handler over the bidirectional stream
+  /// and send its status, then discard whatever it did not consume.
+  Status HandleDoExchange(UcpCallDriver* driver) {
+    UcxServerCallContext context;
+
+    ExchangeServerStream stream(driver);
+    const auto result = DoExchange(context, &stream);
+    RETURN_NOT_OK(SendStatus(driver, result));
+    // Must drain any unread messages, or the next call will get confused
+    for (internal::FlightData unread; stream.ReadData(&unread);) {
+    }
+    return Status::OK();
+  }
+
+  /// \brief Dispatch one call, given its opening headers frame.
+  ///
+  /// The ":method:" header selects the handler. Unknown methods get a
+  /// NotImplemented status sent to the client, and this returns OK so the
+  /// connection keeps serving calls.
+  Status HandleOneCall(UcpCallDriver* driver, Frame* frame) {
+    SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kHeaders));
+    ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer)));
+    ARROW_ASSIGN_OR_RAISE(auto method, headers.Get(":method:"));
+
+    if (method == kMethodGetFlightInfo) return HandleGetFlightInfo(driver);
+    if (method == kMethodDoExchange) return HandleDoExchange(driver);
+    if (method == kMethodDoGet) return HandleDoGet(driver);
+    if (method == kMethodDoPut) return HandleDoPut(driver);
+
+    RETURN_NOT_OK(SendStatus(driver, Status::NotImplemented(method)));
+    return Status::OK();
+  }
+
+  /// \brief Serve one client connection for its whole lifetime.
+  ///
+  /// Creates a per-client worker, accepts the connection request by creating
+  /// an endpoint on it, then handles calls one at a time until shutdown or an
+  /// error. Runs as a task on rpc_pool_.
+  void WorkerLoop(ucp_conn_request_h request) {
+    // Fallback peer name, used if UCX cannot report the client's address.
+    std::string peer = "unknown:" + std::to_string(counter_++);
+    {
+      ucp_conn_request_attr_t request_attr;
+      std::memset(&request_attr, 0, sizeof(request_attr));
+      request_attr.field_mask = UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR;
+      if (ucp_conn_request_query(request, &request_attr) == UCS_OK) {
+        ARROW_UNUSED(SockaddrToString(request_attr.client_address).Value(&peer));
+      }
+    }
+    FLIGHT_LOG_PEER(DEBUG, peer) << "Received connection request";
+
+    auto maybe_worker = CreateWorker();
+    if (!maybe_worker.ok()) {
+      FLIGHT_LOG_PEER(WARNING, peer)
+          << "Failed to create worker: " << maybe_worker.status().ToString();
+      // We cannot serve this client, so reject its connection request.
+      auto status = ucp_listener_reject(listener_, request);
+      if (status != UCS_OK) {
+        FLIGHT_LOG_PEER(WARNING, peer)
+            << FromUcsStatus("ucp_listener_reject", status).ToString();
+      }
+      return;
+    }
+    auto worker = maybe_worker.MoveValueUnsafe();
+
+    // Create an endpoint to the client, using the data worker
+    {
+      ucp_ep_params_t params;
+      std::memset(&params, 0, sizeof(params));
+      params.field_mask = UCP_EP_PARAM_FIELD_CONN_REQUEST;
+      params.conn_request = request;
+
+      ucp_ep_h client_endpoint;
+      const ucs_status_t status =
+          ucp_ep_create(worker->worker->get(), &params, &client_endpoint);
+      if (status != UCS_OK) {
+        FLIGHT_LOG_PEER(WARNING, peer)
+            << "Failed to create endpoint: "
+            << FromUcsStatus("ucp_ep_create", status).ToString();
+        return;
+      }
+      worker->driver.reset(new UcpCallDriver(worker->worker, client_endpoint));
+      worker->driver->set_memory_manager(memory_manager_);
+      peer = worker->driver->peer();
+    }
+
+    // Handle calls sequentially; stop on shutdown, read failure, or call failure.
+    while (listening_.load()) {
+      auto maybe_frame = worker->driver->ReadNextFrame();
+      if (!maybe_frame.ok()) {
+        // A Cancelled read is not logged as a warning.
+        if (!maybe_frame.status().IsCancelled()) {
+          FLIGHT_LOG_PEER(WARNING, peer)
+              << "Failed to read next message: " << maybe_frame.status().ToString();
+        }
+        break;
+      }
+
+      auto status = HandleOneCall(worker->driver.get(), maybe_frame->get());
+      if (!status.ok()) {
+        FLIGHT_LOG_PEER(WARNING, peer) << "Call failed: " << status.ToString();
+        break;
+      }
+    }
+
+    // Clean up: close the driver, then drop this loop's reference to the worker.
+    auto status = worker->driver->Close();
+    if (!status.ok()) {
+      FLIGHT_LOG_PEER(WARNING, peer) << "Failed to close worker: " << status.ToString();
+    }
+    worker->worker.reset();
+    FLIGHT_LOG_PEER(DEBUG, peer) << "Disconnected";
+  }
+
+  /// \brief Listener thread body: progress the listener worker and hand
+  /// queued connection requests to rpc_pool_ until listening_ is cleared.
+  void DriveConnections() {
+    while (listening_.load()) {
+      // Drain all pending progress on the listener worker.
+      while (ucp_worker_progress(worker_conn_->get())) {
+      }
+      {
+        // Check for connect requests in queue
+        std::unique_lock<std::mutex> guard(pending_connections_mutex_);
+        while (!pending_connections_.empty()) {
+          ucp_conn_request_h request = pending_connections_.front();
+          pending_connections_.pop();
+
+          auto submitted = rpc_pool_->Submit([this, request]() { WorkerLoop(request); });
+          if (!submitted.ok()) {
+            FLIGHT_LOG(WARNING) << "Failed to submit task to handle client "
+                                << submitted.status().ToString();
+            // The request has been dequeued and no WorkerLoop will accept it.
+            // Reject it, as WorkerLoop() and Shutdown() do for requests they
+            // cannot serve, instead of leaving it unanswered.
+            auto reject_status = ucp_listener_reject(listener_, request);
+            if (reject_status != UCS_OK) {
+              FLIGHT_LOG(WARNING)
+                  << FromUcsStatus("ucp_listener_reject", reject_status).ToString();
+            }
+          }
+        }
+      }
+
+      // Check listening_ in case we're shutting down. It is possible
+      // that Shutdown() was called while we were in
+      // ucp_worker_progress above, in which case if we don't check
+      // listening_ here, we'll enter ucp_worker_wait and get stuck.
+      if (!listening_.load()) break;
+      auto status = ucp_worker_wait(worker_conn_->get());
+      if (status != UCS_OK) {
+        FLIGHT_LOG(WARNING) << FromUcsStatus("ucp_worker_wait", status).ToString();
+      }
+    }
+  }
+
+  /// \brief Queue a connection request for the listener thread to dispatch.
+  void EnqueueClient(ucp_conn_request_h connection_request) {
+    std::lock_guard<std::mutex> lock(pending_connections_mutex_);
+    pending_connections_.push(connection_request);
+  }
+
+  /// \brief Create a single-threaded UCP worker for one client.
+  ///
+  /// Registers HandleIncomingActiveMessage as the Active Message handler,
+  /// with the new ClientWorker as its argument. The call driver is attached
+  /// later by WorkerLoop().
+  arrow::Result<std::shared_ptr<ClientWorker>> CreateWorker() {
+    auto client = std::make_shared<ClientWorker>();
+
+    ucp_worker_h raw_worker;
+    {
+      ucp_worker_params_t params;
+      std::memset(&params, 0, sizeof(params));
+      params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
+      params.thread_mode = UCS_THREAD_MODE_SINGLE;
+      RETURN_NOT_OK(
+          FromUcsStatus("ucp_worker_create",
+                        ucp_worker_create(ucp_context_->get(), &params, &raw_worker)));
+    }
+    client->worker.reset(new UcpWorker(ucp_context_, raw_worker));
+
+    // Set up Active Message (AM) handler
+    ucp_am_handler_param_t am_params;
+    std::memset(&am_params, 0, sizeof(am_params));
+    am_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | UCP_AM_HANDLER_PARAM_FIELD_CB |
+                           UCP_AM_HANDLER_PARAM_FIELD_ARG;
+    am_params.id = kUcpAmHandlerId;
+    am_params.cb = HandleIncomingActiveMessage;
+    am_params.arg = client.get();
+    RETURN_NOT_OK(FromUcsStatus(
+        "ucp_worker_set_am_recv_handler",
+        ucp_worker_set_am_recv_handler(client->worker->get(), &am_params)));
+    return client;
+  }
+
+  // Callback handler. A new client has connected to the server.
+  // `data` is expected to be the UcxServerImpl. The request is only queued
+  // here; DriveConnections() dispatches it to a WorkerLoop.
+  static void HandleIncomingConnection(ucp_conn_request_h connection_request,
+                                       void* data) {
+    // TODO(ARROW-16124): enable shedding load above some threshold
+    // (which is a pitfall with gRPC/Java)
+    static_cast<UcxServerImpl*>(data)->EnqueueClient(connection_request);
+  }
+
+  // Active Message receive callback registered in CreateWorker(); `self` is
+  // that client's ClientWorker. Forwards the message to its call driver.
+  static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header,
+                                                  size_t header_length, void* data,
+                                                  size_t data_length,
+                                                  const ucp_am_recv_param_t* param) {
+    auto* client = static_cast<ClientWorker*>(self);
+    DCHECK(client->driver);
+    return client->driver->RecvActiveMessage(header, header_length, data, data_length,
+                                             param);
+  }
+
+  std::shared_ptr<UcpContext> ucp_context_;
+  // Listen for and handle incoming connections
+  std::shared_ptr<UcpWorker> worker_conn_;
+  ucp_listener_h listener_ = nullptr;
+  Location location_;
+
+  // Counter for identifying peers when UCX doesn't give us a way.
+  // Explicitly zeroed: a default-constructed std::atomic is indeterminate
+  // before C++20, and WorkerLoop() increments it.
+  std::atomic<size_t> counter_{0};
+
+  std::shared_ptr<arrow::internal::ThreadPool> rpc_pool_;
+  // True while serving; read by Shutdown() even if the server never started.
+  std::atomic<bool> listening_{false};
+  std::thread listener_thread_;
+  // std::thread::join cannot be called concurrently
+  std::mutex join_mutex_;
+
+  // Connection requests queued by HandleIncomingConnection, drained by
+  // DriveConnections() and Shutdown().
+  std::mutex pending_connections_mutex_;
+  std::queue<ucp_conn_request_h> pending_connections_;
+};
+} // namespace
+
+/// \brief Construct the UCX server transport for a Flight server.
+std::unique_ptr<arrow::flight::internal::ServerTransport> MakeUcxServerImpl(
+    FlightServerBase* base, std::shared_ptr<MemoryManager> memory_manager) {
+  // memory_manager is a by-value sink; move it instead of copying, which
+  // would cost an extra atomic refcount increment/decrement.
+  return arrow::internal::make_unique<UcxServerImpl>(base, std::move(memory_manager));
+}
+
+#undef SERVER_RETURN_NOT_OK
+#undef FLIGHT_LOG
+#undef FLIGHT_LOG_PEER
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.cc b/cpp/src/arrow/flight/transport/ucx/util_internal.cc
new file mode 100644
index 0000000000..ca4df21a05
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/util_internal.cc
@@ -0,0 +1,289 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/util_internal.h"
+
+#include <netdb.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include <cerrno>
+#include <mutex>
+#include <unordered_map>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+// Out-of-line definition of the static constexpr member, required when it
+// is ODR-used before C++17.
+constexpr char FlightUcxStatusDetail::kTypeId[];
+
+// The description comes straight from UCX.
+std::string FlightUcxStatusDetail::ToString() const {
+  return std::string(ucs_status_string(status_));
+}
+
+// Statuses without a FlightUcxStatusDetail (including OK) map to UCS_OK.
+ucs_status_t FlightUcxStatusDetail::Unwrap(const Status& status) {
+  const auto& detail = status.detail();
+  if (detail == nullptr || detail->type_id() != kTypeId) {
+    return UCS_OK;
+  }
+  return dynamic_cast<const FlightUcxStatusDetail*>(detail.get())->status_;
+}
+
+// Resolve the URI's host and write the first IPv4/IPv6 result, with the URI's
+// port filled in, into *addr. Returns the number of address bytes written.
+arrow::Result<size_t> UriToSockaddr(const arrow::internal::Uri& uri,
+                                    struct sockaddr_storage* addr) {
+  const std::string host = uri.host();
+  if (host.empty()) return Status::Invalid("Must provide a host");
+  if (uri.port() < 0) return Status::Invalid("Must provide a port");
+
+  std::memset(addr, 0, sizeof(*addr));
+
+  struct addrinfo* results = nullptr;
+  const int rc =
+      getaddrinfo(host.c_str(), /*service=*/nullptr, /*hints=*/nullptr, &results);
+  if (rc == EAI_SYSTEM) {
+    return arrow::internal::IOErrorFromErrno(errno, "[getaddrinfo] Failure resolving ",
+                                             host);
+  }
+  if (rc != 0) {
+    return Status::IOError("[getaddrinfo] Failure resolving ", host, ": ",
+                           gai_strerror(rc));
+  }
+  // Frees the result list on every return path below.
+  std::unique_ptr<struct addrinfo, decltype(&freeaddrinfo)> owner(results,
+                                                                  &freeaddrinfo);
+
+  for (const struct addrinfo* entry = results; entry != nullptr; entry = entry->ai_next) {
+    if (entry->ai_family == AF_INET) {
+      std::memcpy(addr, entry->ai_addr, entry->ai_addrlen);
+      reinterpret_cast<sockaddr_in*>(addr)->sin_port = htons(uri.port());
+      return static_cast<size_t>(entry->ai_addrlen);
+    }
+    if (entry->ai_family == AF_INET6) {
+      std::memcpy(addr, entry->ai_addr, entry->ai_addrlen);
+      reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = htons(uri.port());
+      return static_cast<size_t>(entry->ai_addrlen);
+    }
+    // Other families are skipped.
+  }
+
+  return Status::IOError("[getaddrinfo] Failure resolving ", host,
+                         ": no results of a supported family returned");
+}
+
+// Format an IPv4 or IPv6 socket address as "host:port". IPv6 hosts are
+// bracketed ("[::1]:8080") because the address itself contains ':', which
+// would otherwise make the port ambiguous. Other families yield
+// NotImplemented.
+arrow::Result<std::string> SockaddrToString(const struct sockaddr_storage& address) {
+  // INET6_ADDRSTRLEN covers the textual form of both families, incl. the NUL.
+  char host[INET6_ADDRSTRLEN] = {0};
+  uint16_t port = 0;
+  bool bracket_host = false;
+
+  if (address.ss_family == AF_INET) {
+    const auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(&address);
+    if (!inet_ntop(AF_INET, &in_addr->sin_addr, host, sizeof(host))) {
+      return arrow::internal::IOErrorFromErrno(errno,
+                                               "Could not convert address to string");
+    }
+    port = ntohs(in_addr->sin_port);
+  } else if (address.ss_family == AF_INET6) {
+    const auto* in6_addr = reinterpret_cast<const struct sockaddr_in6*>(&address);
+    if (!inet_ntop(AF_INET6, &in6_addr->sin6_addr, host, sizeof(host))) {
+      return arrow::internal::IOErrorFromErrno(errno,
+                                               "Could not convert address to string");
+    }
+    port = ntohs(in6_addr->sin6_port);
+    bracket_host = true;
+  } else {
+    return Status::NotImplemented("Unknown address family");
+  }
+
+  std::string result;
+  if (bracket_host) {
+    result += '[';
+    result += host;
+    result += ']';
+  } else {
+    result += host;
+  }
+  result += ':';
+  result += std::to_string(port);
+  return result;
+}
+
+namespace {
+
+// Arrow status code and symbolic UCS name for a non-OK UCS status.
+struct UcsErrorInfo {
+  StatusCode code;
+  const char* name;  // nullptr if the UCS status is not recognized
+};
+
+// The single table mapping each known UCS status onto an Arrow StatusCode.
+UcsErrorInfo ClassifyUcsStatus(ucs_status_t ucs_status) {
+  switch (ucs_status) {
+    case UCS_INPROGRESS:
+      return {StatusCode::IOError, "UCS_INPROGRESS"};
+    case UCS_ERR_NO_MESSAGE:
+      return {StatusCode::IOError, "UCS_ERR_NO_MESSAGE"};
+    case UCS_ERR_NO_RESOURCE:
+      return {StatusCode::IOError, "UCS_ERR_NO_RESOURCE"};
+    case UCS_ERR_IO_ERROR:
+      return {StatusCode::IOError, "UCS_ERR_IO_ERROR"};
+    case UCS_ERR_NO_MEMORY:
+      return {StatusCode::OutOfMemory, "UCS_ERR_NO_MEMORY"};
+    case UCS_ERR_INVALID_PARAM:
+      return {StatusCode::Invalid, "UCS_ERR_INVALID_PARAM"};
+    case UCS_ERR_UNREACHABLE:
+      return {StatusCode::IOError, "UCS_ERR_UNREACHABLE"};
+    case UCS_ERR_INVALID_ADDR:
+      return {StatusCode::Invalid, "UCS_ERR_INVALID_ADDR"};
+    case UCS_ERR_NOT_IMPLEMENTED:
+      return {StatusCode::NotImplemented, "UCS_ERR_NOT_IMPLEMENTED"};
+    case UCS_ERR_MESSAGE_TRUNCATED:
+      return {StatusCode::IOError, "UCS_ERR_MESSAGE_TRUNCATED"};
+    case UCS_ERR_NO_PROGRESS:
+      return {StatusCode::IOError, "UCS_ERR_NO_PROGRESS"};
+    case UCS_ERR_BUFFER_TOO_SMALL:
+      return {StatusCode::Invalid, "UCS_ERR_BUFFER_TOO_SMALL"};
+    case UCS_ERR_NO_ELEM:
+      return {StatusCode::IOError, "UCS_ERR_NO_ELEM"};
+    case UCS_ERR_SOME_CONNECTS_FAILED:
+      return {StatusCode::IOError, "UCS_ERR_SOME_CONNECTS_FAILED"};
+    case UCS_ERR_NO_DEVICE:
+      return {StatusCode::IOError, "UCS_ERR_NO_DEVICE"};
+    case UCS_ERR_BUSY:
+      return {StatusCode::IOError, "UCS_ERR_BUSY"};
+    case UCS_ERR_CANCELED:
+      return {StatusCode::Cancelled, "UCS_ERR_CANCELED"};
+    case UCS_ERR_SHMEM_SEGMENT:
+      return {StatusCode::IOError, "UCS_ERR_SHMEM_SEGMENT"};
+    case UCS_ERR_ALREADY_EXISTS:
+      return {StatusCode::AlreadyExists, "UCS_ERR_ALREADY_EXISTS"};
+    case UCS_ERR_OUT_OF_RANGE:
+      return {StatusCode::IOError, "UCS_ERR_OUT_OF_RANGE"};
+    case UCS_ERR_TIMED_OUT:
+      return {StatusCode::Cancelled, "UCS_ERR_TIMED_OUT"};
+    case UCS_ERR_EXCEEDS_LIMIT:
+      return {StatusCode::IOError, "UCS_ERR_EXCEEDS_LIMIT"};
+    case UCS_ERR_UNSUPPORTED:
+      return {StatusCode::NotImplemented, "UCS_ERR_UNSUPPORTED"};
+    case UCS_ERR_REJECTED:
+      return {StatusCode::IOError, "UCS_ERR_REJECTED"};
+    case UCS_ERR_NOT_CONNECTED:
+      return {StatusCode::IOError, "UCS_ERR_NOT_CONNECTED"};
+    case UCS_ERR_CONNECTION_RESET:
+      return {StatusCode::IOError, "UCS_ERR_CONNECTION_RESET"};
+    case UCS_ERR_FIRST_LINK_FAILURE:
+      return {StatusCode::IOError, "UCS_ERR_FIRST_LINK_FAILURE"};
+    case UCS_ERR_LAST_LINK_FAILURE:
+      return {StatusCode::IOError, "UCS_ERR_LAST_LINK_FAILURE"};
+    case UCS_ERR_FIRST_ENDPOINT_FAILURE:
+      return {StatusCode::IOError, "UCS_ERR_FIRST_ENDPOINT_FAILURE"};
+    case UCS_ERR_LAST_ENDPOINT_FAILURE:
+      return {StatusCode::IOError, "UCS_ERR_LAST_ENDPOINT_FAILURE"};
+    case UCS_ERR_ENDPOINT_TIMEOUT:
+      return {StatusCode::IOError, "UCS_ERR_ENDPOINT_TIMEOUT"};
+    case UCS_ERR_LAST:
+      return {StatusCode::IOError, "UCS_ERR_LAST"};
+    default:
+      return {StatusCode::UnknownError, nullptr};
+  }
+}
+
+}  // namespace
+
+// Convert a UCS status to an Arrow Status. UCS_OK maps to Status::OK(). Any
+// other code produces an error whose message starts with `context` and names
+// the UCS code. The Status carries a FlightUcxStatusDetail holding the code.
+Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status) {
+  if (ucs_status == UCS_OK) return Status::OK();
+
+  const UcsErrorInfo info = ClassifyUcsStatus(ucs_status);
+  const auto code_number = static_cast<int32_t>(ucs_status);
+  auto detail = std::make_shared<FlightUcxStatusDetail>(ucs_status);
+  if (info.name == nullptr) {
+    return Status::FromArgs(StatusCode::UnknownError, context,
+                            ": Unknown UCX error: ", code_number, " ",
+                            ucs_status_string(ucs_status))
+        .WithDetail(std::move(detail));
+  }
+  return Status::FromArgs(info.code, context, ": UCX error ", code_number, ": ",
+                          info.name, " ", ucs_status_string(ucs_status))
+      .WithDetail(std::move(detail));
+}
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/util_internal.h b/cpp/src/arrow/flight/transport/ucx/util_internal.h
new file mode 100644
index 0000000000..84e84ba071
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/util_internal.h
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <arpa/inet.h>
+#include <ucp/api/ucp.h>
+#include <string>
+
+#include "arrow/flight/visibility.h"
+#include "arrow/status.h"
+#include "arrow/util/endian.h"
+#include "arrow/util/ubsan.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+/// \brief Store \p in at \p out as four big-endian bytes (via util::SafeStore).
+static inline void UInt32ToBytesBe(const uint32_t in, uint8_t* out) {
+  const uint32_t be_value = bit_util::ToBigEndian(in);
+  util::SafeStore(out, be_value);
+}
+
+/// \brief Load four big-endian bytes at \p in (via util::SafeLoadAs) as a
+/// host-order uint32_t.
+static inline uint32_t BytesToUInt32Be(const uint8_t* in) {
+  const uint32_t be_value = util::SafeLoadAs<uint32_t>(in);
+  return bit_util::FromBigEndian(be_value);
+}
+
+/// \brief A StatusDetail carrying the raw UCS status code behind an Arrow
+/// Status; FromUcsStatus() attaches one to every non-OK result.
+class ARROW_FLIGHT_EXPORT FlightUcxStatusDetail : public StatusDetail {
+ public:
+ explicit FlightUcxStatusDetail(ucs_status_t status) : status_(status) {}
+ /// Identifier returned by type_id(); Unwrap() checks a detail against it.
+ static constexpr char const kTypeId[] = "flight::transport::ucx::FlightUcxStatusDetail";
+
+ const char* type_id() const override { return kTypeId; }
+ /// \brief The UCX description of the wrapped code (from ucs_status_string).
+ std::string ToString() const override;
+ /// \brief Extract the UCS code from \p status, or UCS_OK if it carries no
+ /// FlightUcxStatusDetail.
+ static ucs_status_t Unwrap(const Status& status);
+
+ private:
+ ucs_status_t status_;
+};
+
+/// \brief Convert a UCS status to an Arrow Status.
+///
+/// UCS_OK maps to Status::OK(). Any other code yields an error Status whose
+/// message starts with \p context and which carries a FlightUcxStatusDetail
+/// holding the original code (see FlightUcxStatusDetail::Unwrap).
+ARROW_FLIGHT_EXPORT
+Status FromUcsStatus(const std::string& context, ucs_status_t ucs_status);
+
+/// \brief Check if a UCS error code can be ignored in the context of
+/// a disconnect.
+///
+/// UCS_OK, not-connected and connection-reset mean we are already
+/// disconnected. An endpoint timeout most likely means the same, though
+/// that cannot be confirmed from this end.
+static inline bool IsIgnorableDisconnectError(ucs_status_t ucs_status) {
+  switch (ucs_status) {
+    case UCS_OK:
+    case UCS_ERR_ENDPOINT_TIMEOUT:
+    case UCS_ERR_NOT_CONNECTED:
+    case UCS_ERR_CONNECTION_RESET:
+      return true;
+    default:
+      return false;
+  }
+}
+
+/// \brief Helper to convert a Uri to a struct sockaddr (used in
+/// ucp_listener_params_t)
+///
+/// Resolves the URI's host with getaddrinfo() and fills \p addr with the
+/// first IPv4 or IPv6 result, with the URI's port applied. Returns Invalid
+/// if the host or port is missing, and IOError if resolution fails or
+/// yields no IPv4/IPv6 address.
+///
+/// \return The length of the sockaddr
+ARROW_FLIGHT_EXPORT
+arrow::Result<size_t> UriToSockaddr(const arrow::internal::Uri& uri,
+ struct sockaddr_storage* addr);
+
+/// \brief Format an IPv4/IPv6 socket address, including its port, as a
+/// human-readable string (used for logging and naming peers).
+///
+/// Returns NotImplemented for other address families.
+ARROW_FLIGHT_EXPORT
+arrow::Result<std::string> SockaddrToString(const struct sockaddr_storage& address);
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport_server.cc b/cpp/src/arrow/flight/transport_server.cc
index fa5bf82710..4944a79b8f 100644
--- a/cpp/src/arrow/flight/transport_server.cc
+++ b/cpp/src/arrow/flight/transport_server.cc
@@ -54,8 +54,7 @@ class TransportIpcMessageReader : public ipc::MessageReader {
stream_finished_ = true;
return nullptr;
}
- if (data->body &&
- ARROW_PREDICT_FALSE(!data->body->device()->Equals(*memory_manager_->device()))) {
+ if (data->body) {
ARROW_ASSIGN_OR_RAISE(data->body, Buffer::ViewOrCopy(data->body, memory_manager_));
}
*app_metadata_ = std::move(data->app_metadata);
@@ -111,7 +110,7 @@ class TransportMessageReader final : public FlightMessageReader {
arrow::Result<FlightStreamChunk> Next() override {
FlightStreamChunk out;
- internal::FlightData* data;
+ internal::FlightData* data = nullptr;
peekable_reader_->Peek(&data);
if (!data) {
out.app_metadata = nullptr;
diff --git a/cpp/src/arrow/util/config.h.cmake b/cpp/src/arrow/util/config.h.cmake
index 7d7c83185e..55bc2d0100 100644
--- a/cpp/src/arrow/util/config.h.cmake
+++ b/cpp/src/arrow/util/config.h.cmake
@@ -48,5 +48,6 @@
#cmakedefine ARROW_S3
#cmakedefine ARROW_USE_NATIVE_INT128
#cmakedefine ARROW_WITH_OPENTELEMETRY
+#cmakedefine ARROW_WITH_UCX
#cmakedefine GRPCPP_PP_INCLUDE
diff --git a/docs/source/cpp/flight.rst b/docs/source/cpp/flight.rst
index c1d2e43b9f..75aea3c47c 100644
--- a/docs/source/cpp/flight.rst
+++ b/docs/source/cpp/flight.rst
@@ -117,3 +117,38 @@ success/failure of the request. Any other return values are specified
through out parameters. They also take an optional :class:`options
<arrow::flight::FlightCallOptions>` parameter that allows specifying a
timeout for the call.
+
+Alternative Transports
+======================
+
+The standard transport for Arrow Flight is gRPC_. The C++
+implementation also has experimental support for a transport based on
+UCX_. To use it, specify the ``ucx:`` URI scheme when starting a
+server or creating a client.
+
+UCX Transport
+-------------
+
+Not all features of the gRPC transport are supported. See
+:ref:`status-flight-rpc` for details. Also note these specific
+caveats:
+
+- The server creates an independent UCP worker for each client. This
+ consumes more resources but provides better throughput.
+- The client creates an independent UCP worker for each RPC
+  call. Again, this trades resource consumption for
+  performance. It also means that, unlike with gRPC, making all
+  calls through a single client is essentially equivalent to
+  spreading them across multiple clients.
+- The UCX transport attempts to avoid copies where possible. In some
+  cases, it can directly reuse UCX-allocated buffers to back
+  :class:`arrow::Buffer` objects; however, this also extends the
+  lifetime of the associated UCX resources beyond the lifetime of the
+  Flight client or server object.
+- Depending on the transport that UCX itself selects, you may find
+ that increasing ``UCX_MM_SEG_SIZE`` from the default (around 8KB) to
+ around 60KB improves performance (UCX will copy more data in a
+ single call).
+
+.. _gRPC: https://grpc.io/
+.. _UCX: https://openucx.org/
diff --git a/docs/source/status.rst b/docs/source/status.rst
index 7c6157357a..c30caed2f8 100644
--- a/docs/source/status.rst
+++ b/docs/source/status.rst
@@ -144,35 +144,81 @@ Notes:
.. seealso::
The :ref:`format-ipc` specification.
+.. _status-flight-rpc:
Flight RPC
==========
-+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
-| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia |
-| | | | | | | | |
-+=============================+=======+=======+=======+============+=======+=======+=======+
-| gRPC transport | ✓ | ✓ | ✓ | | ✓ (1) | | |
-+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
-| gRPC + TLS transport | ✓ | ✓ | ✓ | | ✓ | | |
-+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
-| RPC error codes | ✓ | ✓ | ✓ | | ✓ | | |
-+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
-| Authentication handlers | ✓ | ✓ | ✓ | | ✓ (2) | | |
-+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
-| Custom client middleware | ✓ | ✓ | ✓ | | | | |
-+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
-| Custom server middleware | ✓ | ✓ | ✓ | | | | |
-+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
+.. note:: Flight RPC is still experimental.
+
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Flight RPC Transport | C++ | Java | Go | JavaScript | C# | Rust | Julia |
++============================================+=======+=======+=======+============+=======+=======+=======+
+| gRPC_ transport (grpc:, grpc+tcp:) | ✓ | ✓ | ✓ | | ✓ | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| gRPC domain socket transport (grpc+unix:) | ✓ | ✓ | ✓ | | ✓ | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| gRPC + TLS transport (grpc+tls:) | ✓ | ✓ | ✓ | | ✓ | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| UCX_ transport (ucx:) | ✓ | | | | | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+
+Supported features in the gRPC transport:
+
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia |
++============================================+=======+=======+=======+============+=======+=======+=======+
+| All RPC methods | ✓ | ✓ | ✓ | | × (1) | ✓ | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Authentication handlers | ✓ | ✓ | ✓ | | ✓ (2) | ✓ | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Call timeouts | ✓ | ✓ | ✓ | | | ✓ | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Call cancellation | ✓ | ✓ | ✓ | | | ✓ | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Concurrent client calls (3) | ✓ | ✓ | ✓ | | ✓ | ✓ | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Custom middleware | ✓ | ✓ | ✓ | | | ✓ | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| RPC error codes | ✓ | ✓ | ✓ | | ✓ | ✓ | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+
+Supported features in the UCX transport:
+
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Flight RPC Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia |
++============================================+=======+=======+=======+============+=======+=======+=======+
+| All RPC methods | × (4) | | | | | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Authentication handlers | | | | | | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Call timeouts | | | | | | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Call cancellation | | | | | | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Concurrent client calls | ✓ (5) | | | | | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| Custom middleware | | | | | | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
+| RPC error codes | ✓ | | | | | | |
++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+
Notes:
* \(1) No support for handshake or DoExchange.
* \(2) Support using AspNetCore authentication handlers.
+* \(3) Whether a single client can support multiple concurrent calls.
+* \(4) Only DoExchange, DoGet, DoPut, and GetFlightInfo are supported.
+* \(5) Each concurrent call is a separate connection to the server
+ (unlike gRPC, where concurrent calls are multiplexed over a single
+ connection). This will generally provide better throughput but
+ consumes more resources both on the server and the client.
.. seealso::
The :ref:`flight-rpc` specification.
+.. _gRPC: https://grpc.io/
+.. _UCX: https://openucx.org/
C Data Interface
================