You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/07/29 13:02:54 UTC
[arrow] branch flight-sql-jdbc updated: ARROW-15452: [FlightRPC][Java] JDBC driver for Flight SQL (#12830)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch flight-sql-jdbc
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/flight-sql-jdbc by this push:
new e7a72519cf ARROW-15452: [FlightRPC][Java] JDBC driver for Flight SQL (#12830)
e7a72519cf is described below
commit e7a72519cf37dfa9a736254572f4165e8838d364
Author: James Duong <du...@gmail.com>
AuthorDate: Fri Jul 29 06:02:42 2022 -0700
ARROW-15452: [FlightRPC][Java] JDBC driver for Flight SQL (#12830)
This implements a JDBC driver able to communicate to Flight SQL sources.
So far this covers:
Metadata retrieval by DatabaseMetadata, ResultSetMetadata, etc.
Query execution by statements and prepared statements
Yet to be done:
Parameter binding on prepared statements
Lead-authored-by: James Duong <du...@gmail.com>
Co-authored-by: iurysalino <iu...@gmail.com>
Co-authored-by: Jose Almeida <al...@gmail.com>
Co-authored-by: Gabriel Escobar <ga...@gmail.com>
Co-authored-by: Vinicius Fraga <sx...@gmail.com>
Co-authored-by: Vinicius Fraga <62...@users.noreply.github.com>
Co-authored-by: Sanjiban Sengupta <sa...@gmail.com>
Co-authored-by: Rafael Telles <ra...@telles.dev>
Co-authored-by: Jose Almeida <53...@users.noreply.github.com>
Co-authored-by: Yibo Cai <yi...@arm.com>
Signed-off-by: David Li <li...@gmail.com>
---
.dockerignore | 1 +
.env | 3 +-
.github/workflows/cpp.yml | 9 +
.github/workflows/go.yml | 8 +-
.github/workflows/r.yml | 36 +-
.github/workflows/ruby.yml | 3 +
c_glib/arrow-flight-glib/client.cpp | 19 +-
ci/docker/debian-10-go.dockerfile | 2 +-
ci/docker/debian-11-go.dockerfile | 2 +-
ci/docker/ubuntu-18.04-verify-rc.dockerfile | 40 +-
ci/docker/ubuntu-20.04-verify-rc.dockerfile | 26 +-
...kerfile => ubuntu-22.04-cpp-minimal.dockerfile} | 113 +-
...-cpp.dockerfile => ubuntu-22.04-cpp.dockerfile} | 61 +-
.../docker/ubuntu-22.04-verify-rc.dockerfile | 13 +-
ci/scripts/go_build.sh | 2 +-
ci/scripts/install_gcs_testbench.sh | 5 +-
ci/scripts/msys2_setup.sh | 3 -
ci/scripts/python_wheel_windows_test.bat | 24 +-
cpp/CMakeLists.txt | 1 +
cpp/build-support/lsan-suppressions.txt | 7 +
cpp/cmake_modules/BuildUtils.cmake | 1 +
cpp/cmake_modules/DefineOptions.cmake | 4 +
cpp/cmake_modules/FindArrow.cmake | 2 +-
cpp/cmake_modules/ThirdpartyToolchain.cmake | 32 +-
cpp/examples/arrow/engine_substrait_consumption.cc | 4 +
.../arrow/execution_plan_documentation_examples.cc | 51 +-
cpp/examples/arrow/flight_sql_example.cc | 2 +-
cpp/src/arrow/CMakeLists.txt | 4 +
cpp/src/arrow/compute/exec/options.h | 17 +-
cpp/src/arrow/compute/exec/plan_test.cc | 52 +-
cpp/src/arrow/compute/exec/sink_node.cc | 20 +-
cpp/src/arrow/compute/kernels/codegen_internal.cc | 24 +-
cpp/src/arrow/compute/kernels/codegen_internal.h | 3 +-
.../arrow/compute/kernels/codegen_internal_test.cc | 98 +-
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc | 10 +-
.../arrow/compute/kernels/scalar_cast_nested.cc | 58 +-
cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 358 +++-
.../arrow/compute/kernels/scalar_temporal_test.cc | 78 +
cpp/src/arrow/dataset/file_base.cc | 32 +-
cpp/src/arrow/dataset/file_base.h | 10 +-
cpp/src/arrow/dataset/file_csv.cc | 30 +-
cpp/src/arrow/dataset/file_ipc.cc | 32 +-
cpp/src/arrow/dataset/file_parquet.cc | 8 +-
cpp/src/arrow/dataset/scanner.cc | 34 +-
cpp/src/arrow/dataset/scanner_test.cc | 19 +-
cpp/src/arrow/engine/substrait/extension_set.cc | 2 +-
cpp/src/arrow/engine/substrait/serde_test.cc | 26 +
cpp/src/arrow/filesystem/gcsfs.cc | 3 +-
cpp/src/arrow/flight/CMakeLists.txt | 10 +
cpp/src/arrow/flight/client.cc | 147 +-
cpp/src/arrow/flight/client.h | 130 +-
cpp/src/arrow/flight/flight_benchmark.cc | 59 +-
cpp/src/arrow/flight/flight_test.cc | 242 +--
.../flight/integration_tests/test_integration.cc | 10 +-
.../integration_tests/test_integration_client.cc | 17 +-
cpp/src/arrow/flight/perf_server.cc | 36 +-
cpp/src/arrow/flight/sql/client.cc | 4 +-
cpp/src/arrow/flight/sql/client.h | 41 +-
cpp/src/arrow/flight/sql/column_metadata.cc | 11 +
cpp/src/arrow/flight/sql/column_metadata.h | 24 +-
cpp/src/arrow/flight/sql/server.h | 131 +-
cpp/src/arrow/flight/sql/server_test.cc | 3 +-
cpp/src/arrow/flight/sql/test_app_cli.cc | 3 +-
cpp/src/arrow/flight/sql/types.h | 1235 +++++++------
cpp/src/arrow/flight/test_definitions.cc | 268 ++-
cpp/src/arrow/flight/test_util.h | 4 +-
cpp/src/arrow/flight/transport/ucx/CMakeLists.txt | 77 +
.../transport/ucx/flight_transport_ucx_test.cc | 386 +++++
cpp/src/arrow/flight/transport/ucx/ucx.cc | 45 +
cpp/src/arrow/flight/transport/ucx/ucx.h | 35 +
cpp/src/arrow/flight/transport/ucx/ucx_client.cc | 733 ++++++++
cpp/src/arrow/flight/transport/ucx/ucx_internal.cc | 1171 +++++++++++++
cpp/src/arrow/flight/transport/ucx/ucx_internal.h | 354 ++++
cpp/src/arrow/flight/transport/ucx/ucx_server.cc | 628 +++++++
.../arrow/flight/transport/ucx/util_internal.cc | 289 ++++
cpp/src/arrow/flight/transport/ucx/util_internal.h | 83 +
cpp/src/arrow/flight/transport_server.cc | 5 +-
cpp/src/arrow/python/arrow_to_pandas.cc | 6 +-
cpp/src/arrow/util/config.h.cmake | 1 +
cpp/src/arrow/util/future.cc | 14 +
cpp/src/arrow/util/thread_pool.cc | 29 +
cpp/src/arrow/util/thread_pool_test.cc | 2 +-
cpp/src/arrow/util/tracing_internal.h | 84 +-
cpp/src/gandiva/engine.cc | 4 +
cpp/src/parquet/arrow/reader.cc | 13 +
cpp/thirdparty/versions.txt | 10 +-
cpp/valgrind.supp | 18 +
dev/archery/archery/crossbow/core.py | 9 +-
dev/archery/archery/crossbow/tests/test_core.py | 26 +-
dev/release/binary-task.rb | 1 +
dev/release/post-09-docs.sh | 30 +-
dev/release/setup-ubuntu.sh | 54 +-
dev/release/verify-release-candidate.sh | 4 +-
.../apt/ubuntu-jammy/Dockerfile | 36 +-
.../apache-arrow/apt/ubuntu-jammy-arm64/from | 5 +-
.../apache-arrow/apt/ubuntu-jammy/Dockerfile | 84 +
dev/tasks/linux-packages/package-task.rb | 2 +
dev/tasks/r/github.macos.brew.yml | 1 +
dev/tasks/tasks.yml | 14 +-
docker-compose.yml | 10 +-
docs/source/cpp/api.rst | 1 +
.../cpp/{getting_started.rst => api/flightsql.rst} | 58 +-
docs/source/cpp/flight.rst | 35 +
docs/source/cpp/getting_started.rst | 1 +
docs/source/cpp/orc.rst | 183 ++
docs/source/cpp/streaming_execution.rst | 21 +
docs/source/developers/cpp/windows.rst | 14 +
.../developers/guide/tutorials/r_tutorial.rst | 22 +-
docs/source/format/FlightSql.rst | 15 +
docs/source/format/FlightSql/CommandGetTables.mmd | 29 +
.../format/FlightSql/CommandGetTables.mmd.svg | 1 +
.../FlightSql/CommandPreparedStatementQuery.mmd | 39 +
.../CommandPreparedStatementQuery.mmd.svg | 1 +
.../format/FlightSql/CommandStatementQuery.mmd | 31 +
.../format/FlightSql/CommandStatementQuery.mmd.svg | 1 +
docs/source/index.rst | 2 +-
docs/source/java/cdata.rst | 223 +++
docs/source/java/index.rst | 3 +
docs/source/java/install.rst | 167 +-
docs/source/java/overview.rst | 98 ++
docs/source/java/quickstartguide.rst | 316 ++++
docs/source/java/vector_schema_root.rst | 44 +-
docs/source/python/api/dataset.rst | 1 +
docs/source/python/api/files.rst | 9 +-
docs/source/python/api/filesystems.rst | 4 +-
docs/source/python/api/flight.rst | 7 +
docs/source/python/api/tables.rst | 1 +
docs/source/python/data.rst | 24 +
docs/source/python/index.rst | 1 +
docs/source/python/orc.rst | 184 ++
docs/source/status.rst | 78 +-
format/FlightSql.proto | 59 +-
go/arrow/flight/server.go | 12 +-
.../flight-jdbc-driver/jdbc-spotbugs-exclude.xml | 40 +
java/flight/flight-jdbc-driver/pom.xml | 362 ++++
.../arrow/driver/jdbc/ArrowDatabaseMetadata.java | 1218 +++++++++++++
.../arrow/driver/jdbc/ArrowFlightConnection.java | 189 ++
.../driver/jdbc/ArrowFlightInfoStatement.java} | 38 +-
.../arrow/driver/jdbc/ArrowFlightJdbcArray.java | 178 ++
.../ArrowFlightJdbcConnectionPoolDataSource.java | 125 ++
.../arrow/driver/jdbc/ArrowFlightJdbcCursor.java | 102 ++
.../driver/jdbc/ArrowFlightJdbcDataSource.java | 134 ++
.../arrow/driver/jdbc/ArrowFlightJdbcDriver.java | 254 +++
.../arrow/driver/jdbc/ArrowFlightJdbcFactory.java | 124 ++
.../jdbc/ArrowFlightJdbcFlightStreamResultSet.java | 250 +++
.../jdbc/ArrowFlightJdbcPooledConnection.java | 112 ++
.../arrow/driver/jdbc/ArrowFlightJdbcTime.java | 106 ++
.../ArrowFlightJdbcVectorSchemaRootResultSet.java | 153 ++
.../arrow/driver/jdbc/ArrowFlightMetaImpl.java | 200 +++
.../driver/jdbc/ArrowFlightPreparedStatement.java | 98 ++
.../arrow/driver/jdbc/ArrowFlightStatement.java | 60 +
.../jdbc/accessor/ArrowFlightJdbcAccessor.java | 256 +++
.../accessor/ArrowFlightJdbcAccessorFactory.java | 214 +++
.../impl/ArrowFlightJdbcNullVectorAccessor.java} | 40 +-
.../ArrowFlightJdbcBinaryVectorAccessor.java | 137 ++
.../ArrowFlightJdbcDateVectorAccessor.java | 137 ++
.../calendar/ArrowFlightJdbcDateVectorGetter.java | 67 +
.../ArrowFlightJdbcDurationVectorAccessor.java | 54 +
.../ArrowFlightJdbcIntervalVectorAccessor.java | 126 ++
.../ArrowFlightJdbcTimeStampVectorAccessor.java | 185 ++
.../ArrowFlightJdbcTimeStampVectorGetter.java | 156 ++
.../ArrowFlightJdbcTimeVectorAccessor.java | 159 ++
.../calendar/ArrowFlightJdbcTimeVectorGetter.java | 89 +
.../AbstractArrowFlightJdbcListVectorAccessor.java | 73 +
...AbstractArrowFlightJdbcUnionVectorAccessor.java | 259 +++
.../ArrowFlightJdbcDenseUnionVectorAccessor.java | 66 +
...ArrowFlightJdbcFixedSizeListVectorAccessor.java | 70 +
.../ArrowFlightJdbcLargeListVectorAccessor.java | 70 +
.../complex/ArrowFlightJdbcListVectorAccessor.java | 70 +
.../complex/ArrowFlightJdbcMapVectorAccessor.java | 92 +
.../ArrowFlightJdbcStructVectorAccessor.java | 75 +
.../ArrowFlightJdbcUnionVectorAccessor.java | 64 +
.../ArrowFlightJdbcBaseIntVectorAccessor.java | 203 +++
.../numeric/ArrowFlightJdbcBitVectorAccessor.java | 117 ++
.../ArrowFlightJdbcDecimalVectorAccessor.java | 136 ++
.../ArrowFlightJdbcFloat4VectorAccessor.java | 133 ++
.../ArrowFlightJdbcFloat8VectorAccessor.java | 131 ++
.../impl/numeric/ArrowFlightJdbcNumericGetter.java | 216 +++
.../text/ArrowFlightJdbcVarCharVectorAccessor.java | 258 +++
.../jdbc/client/ArrowFlightSqlClientHandler.java | 582 +++++++
.../client/utils/ClientAuthenticationUtils.java | 251 +++
.../utils/ArrowFlightConnectionConfigImpl.java | 286 +++
.../arrow/driver/jdbc/utils/ConnectionWrapper.java | 344 ++++
.../arrow/driver/jdbc/utils/ConvertUtils.java | 116 ++
.../arrow/driver/jdbc/utils/DateTimeUtils.java | 76 +
.../arrow/driver/jdbc/utils/FlightStreamQueue.java | 237 +++
.../driver/jdbc/utils/IntervalStringUtils.java | 84 +
.../apache/arrow/driver/jdbc/utils/SqlTypes.java | 164 ++
.../apache/arrow/driver/jdbc/utils/UrlParser.java | 49 +
.../jdbc/utils/VectorSchemaRootTransformer.java | 154 ++
.../resources/META-INF/services/java.sql.Driver | 15 +
.../driver/jdbc/ArrowDatabaseMetadataTest.java | 1422 +++++++++++++++
.../driver/jdbc/ArrowFlightJdbcArrayTest.java | 173 ++
.../jdbc/ArrowFlightJdbcConnectionCookieTest.java | 54 +
...rrowFlightJdbcConnectionPoolDataSourceTest.java | 137 ++
.../driver/jdbc/ArrowFlightJdbcCursorTest.java | 251 +++
.../driver/jdbc/ArrowFlightJdbcDriverTest.java | 353 ++++
.../driver/jdbc/ArrowFlightJdbcFactoryTest.java | 88 +
.../arrow/driver/jdbc/ArrowFlightJdbcTimeTest.java | 80 +
.../jdbc/ArrowFlightPreparedStatementTest.java | 78 +
.../jdbc/ArrowFlightStatementExecuteTest.java | 173 ++
.../ArrowFlightStatementExecuteUpdateTest.java | 216 +++
.../apache/arrow/driver/jdbc/ConnectionTest.java | 552 ++++++
.../arrow/driver/jdbc/ConnectionTlsTest.java | 446 +++++
.../arrow/driver/jdbc/FlightServerTestRule.java | 365 ++++
.../arrow/driver/jdbc/ResultSetMetadataTest.java | 236 +++
.../apache/arrow/driver/jdbc/ResultSetTest.java | 372 ++++
.../arrow/driver/jdbc/TokenAuthenticationTest.java | 66 +
.../ArrowFlightJdbcAccessorFactoryTest.java | 496 ++++++
.../jdbc/accessor/ArrowFlightJdbcAccessorTest.java | 358 ++++
.../ArrowFlightJdbcNullVectorAccessorTest.java | 38 +
.../ArrowFlightJdbcBinaryVectorAccessorTest.java | 244 +++
.../ArrowFlightJdbcDateVectorAccessorTest.java | 254 +++
.../ArrowFlightJdbcDurationVectorAccessorTest.java | 115 ++
.../ArrowFlightJdbcIntervalVectorAccessorTest.java | 249 +++
...ArrowFlightJdbcTimeStampVectorAccessorTest.java | 322 ++++
.../ArrowFlightJdbcTimeVectorAccessorTest.java | 263 +++
.../AbstractArrowFlightJdbcListAccessorTest.java | 185 ++
...ractArrowFlightJdbcUnionVectorAccessorTest.java | 265 +++
...rrowFlightJdbcDenseUnionVectorAccessorTest.java | 126 ++
.../ArrowFlightJdbcMapVectorAccessorTest.java | 221 +++
.../ArrowFlightJdbcStructVectorAccessorTest.java | 209 +++
.../ArrowFlightJdbcUnionVectorAccessorTest.java | 118 ++
.../ArrowFlightJdbcBaseIntVectorAccessorTest.java | 171 ++
...rowFlightJdbcBaseIntVectorAccessorUnitTest.java | 213 +++
.../ArrowFlightJdbcBitVectorAccessorTest.java | 155 ++
.../ArrowFlightJdbcDecimalVectorAccessorTest.java | 248 +++
.../ArrowFlightJdbcFloat4VectorAccessorTest.java | 206 +++
.../ArrowFlightJdbcFloat8VectorAccessorTest.java | 187 ++
.../ArrowFlightJdbcVarCharVectorAccessorTest.java | 733 ++++++++
.../driver/jdbc/authentication/Authentication.java | 37 +
.../jdbc/authentication/TokenAuthentication.java | 73 +
.../authentication/UserPasswordAuthentication.java | 73 +
.../utils/ClientAuthenticationUtilsTest.java | 146 ++
.../arrow/driver/jdbc/utils/AccessorTestUtils.java | 141 ++
.../utils/ArrowFlightConnectionConfigImplTest.java | 96 +
.../utils/ArrowFlightConnectionPropertyTest.java | 90 +
.../driver/jdbc/utils/ConnectionWrapperTest.java | 443 +++++
.../arrow/driver/jdbc/utils/ConvertUtilsTest.java | 119 ++
.../driver/jdbc/utils/CoreMockedSqlProducers.java | 298 ++++
.../arrow/driver/jdbc/utils/DateTimeUtilsTest.java | 102 ++
.../jdbc/utils/FlightSqlTestCertificates.java | 77 +
.../driver/jdbc/utils/FlightStreamQueueTest.java | 86 +
.../driver/jdbc/utils/MockFlightSqlProducer.java | 539 ++++++
.../driver/jdbc/utils/ResultSetTestUtils.java | 213 +++
.../driver/jdbc/utils/RootAllocatorTestRule.java | 820 +++++++++
.../arrow/driver/jdbc/utils/SqlTypesTest.java | 123 ++
.../driver/jdbc/utils/ThrowableAssertionUtils.java | 57 +
.../utils/VectorSchemaRootTransformerTest.java | 119 ++
.../src/test/resources/keys/keyStore.jks | Bin 0 -> 1537 bytes
.../src/test/resources/keys/noCertificate.jks | Bin 0 -> 2545 bytes
.../arrow/flight/sql/FlightSqlColumnMetadata.java | 19 +
.../org/apache/arrow/flight/TestFlightSql.java | 7 +
.../arrow/flight/sql/example/FlightSqlExample.java | 2 +
java/flight/pom.xml | 1 +
java/pom.xml | 16 +-
java/vector/pom.xml | 4 +
.../org/apache/arrow/vector/types/pojo/Schema.java | 5 +
.../arrow/vector/util/JsonStringArrayList.java | 8 +-
.../arrow/vector/util/JsonStringHashMap.java | 8 +-
...StringHashMap.java => ObjectMapperFactory.java} | 34 +-
js/package.json | 2 +-
js/src/ipc/metadata/message.ts | 3 +
js/src/recordbatch.ts | 2 +-
js/src/table.ts | 2 +-
js/src/util/bn.ts | 22 +-
js/src/visitor/bytelength.ts | 53 +-
js/src/visitor/get.ts | 8 +-
js/test/unit/table-tests.ts | 12 +-
js/yarn.lock | 28 +-
python/CMakeLists.txt | 1 +
python/pyarrow/__init__.py | 3 +-
python/pyarrow/_dataset.pxd | 14 +
python/pyarrow/_dataset.pyx | 59 +-
python/pyarrow/_exec_plan.pyx | 77 +-
python/pyarrow/_flight.pyx | 227 ++-
python/pyarrow/array.pxi | 51 +
python/pyarrow/compute.py | 1 -
python/pyarrow/includes/libarrow.pxd | 1 +
python/pyarrow/includes/libarrow_dataset.pxd | 14 +-
python/pyarrow/includes/libarrow_flight.pxd | 44 +-
python/pyarrow/ipc.pxi | 49 +-
python/pyarrow/lib.pyx | 14 +-
python/pyarrow/orc.py | 5 +-
python/pyarrow/parquet.py | 49 +-
python/pyarrow/table.pxi | 1828 +++++++++++++++++++-
python/pyarrow/tests/parquet/test_dataset.py | 15 +
.../pyarrow/tests/read_record_batch.py | 12 +-
python/pyarrow/tests/test_cffi.py | 12 +-
python/pyarrow/tests/test_dataset.py | 45 +-
python/pyarrow/tests/test_exec_plan.py | 45 +-
python/pyarrow/tests/test_filesystem.py | 12 +
python/pyarrow/tests/test_ipc.py | 21 +-
python/pyarrow/tests/test_table.py | 3 +
python/pyarrow/types.pxi | 313 +++-
python/pyarrow/util.py | 5 +-
r/DESCRIPTION | 1 +
r/NAMESPACE | 10 +
r/NEWS.md | 7 +-
r/R/arrow-package.R | 21 +-
r/R/arrowExports.R | 49 +-
r/R/dataset-write.R | 4 +
r/R/dplyr-funcs-datetime.R | 42 +
r/R/dplyr-mutate.R | 17 +-
r/R/dplyr-select.R | 4 +-
r/R/extension.R | 545 ++++++
r/R/query-engine.R | 4 +-
r/_pkgdown.yml | 5 +
r/data-raw/codegen.R | 6 +-
r/man/ExtensionArray.Rd | 23 +
r/man/ExtensionType.Rd | 48 +
r/man/new_extension_type.Rd | 167 ++
r/man/vctrs_extension_array.Rd | 50 +
r/pkgdown/extra.js | 40 +-
r/src/array.cpp | 2 +
r/src/array_to_vector.cpp | 33 +
r/src/arrowExports.cpp | 206 ++-
r/src/compute-exec.cpp | 34 +-
r/src/datatype.cpp | 2 +
r/src/extension-impl.cpp | 198 +++
r/src/extension.h | 75 +
r/src/safe-call-into-r-impl.cpp | 89 +
r/src/safe-call-into-r.h | 145 ++
r/src/type_infer.cpp | 6 +-
r/tests/testthat/_snaps/extension.md | 10 +
r/tests/testthat/test-Array.R | 27 +-
r/tests/testthat/test-dataset-write.R | 18 +
r/tests/testthat/test-dplyr-funcs-datetime.R | 35 +-
r/tests/testthat/test-dplyr-mutate.R | 34 +-
r/tests/testthat/test-extension.R | 345 ++++
r/tests/testthat/test-query-engine.R | 7 +-
r/tests/testthat/test-safe-call-into-r.R | 60 +
.../red-arrow/test/values/test-dictionary-array.rb | 295 ++++
333 files changed, 37213 insertions(+), 2166 deletions(-)
diff --git a/.dockerignore b/.dockerignore
index 554cb34f1e..e062e582ff 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -28,6 +28,7 @@
!ci/**
!c_glib/Gemfile
!dev/archery/setup.py
+!dev/release/setup-*.sh
!docs/requirements*.txt
!python/requirements*.txt
!python/manylinux1/**
diff --git a/.env b/.env
index 6f0fa2808b..a972654497 100644
--- a/.env
+++ b/.env
@@ -61,7 +61,8 @@ GO=1.16
HDFS=3.2.1
JDK=8
KARTOTHEK=latest
-LLVM=12
+# LLVM 12 and GCC 11 report -Wmismatched-new-delete.
+LLVM=13
MAVEN=3.5.4
NODE=16
NUMPY=latest
diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml
index 7401fc489c..00f9e335f5 100644
--- a/.github/workflows/cpp.yml
+++ b/.github/workflows/cpp.yml
@@ -126,6 +126,7 @@ jobs:
ARROW_DATASET: ON
ARROW_FLIGHT: ON
ARROW_GANDIVA: ON
+ ARROW_GCS: ON
ARROW_HDFS: ON
ARROW_HOME: /usr/local
ARROW_JEMALLOC: ON
@@ -141,6 +142,8 @@ jobs:
ARROW_WITH_SNAPPY: ON
ARROW_WITH_ZLIB: ON
ARROW_WITH_ZSTD: ON
+ # System Abseil installed by Homebrew uses C++ 17
+ CMAKE_CXX_STANDARD: 17
steps:
- name: Checkout Arrow
uses: actions/checkout@v2
@@ -153,6 +156,9 @@ jobs:
rm -f /usr/local/bin/2to3
brew update --preinstall
brew bundle --file=cpp/Brewfile
+ - name: Install Google Cloud Storage Testbench
+ shell: bash
+ run: ci/scripts/install_gcs_testbench.sh default
- name: Setup ccache
run: |
ci/scripts/ccache_setup.sh
@@ -268,6 +274,9 @@ jobs:
ARROW_DATASET: ON
ARROW_FLIGHT: ON
ARROW_GANDIVA: ON
+ # google-cloud-cpp uses _dupenv_s() but it can't be used with msvcrt.
+ # We need to use ucrt to use _dupenv_s().
+ # ARROW_GCS: ON
ARROW_HDFS: OFF
ARROW_HOME: /mingw${{ matrix.mingw-n-bits }}
ARROW_JEMALLOC: OFF
diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index b4c9cfaa37..a35c6a3697 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -152,7 +152,7 @@ jobs:
fetch-depth: 0
submodules: recursive
- name: Install staticcheck
- run: go install honnef.co/go/tools/cmd/staticcheck@latest
+ run: go install honnef.co/go/tools/cmd/staticcheck@v0.2.2
- name: Build
shell: bash
run: ci/scripts/go_build.sh $(pwd)
@@ -180,7 +180,7 @@ jobs:
fetch-depth: 0
submodules: recursive
- name: Install staticcheck
- run: go install honnef.co/go/tools/cmd/staticcheck@latest
+ run: go install honnef.co/go/tools/cmd/staticcheck@v0.2.2
- name: Build
shell: bash
run: ci/scripts/go_build.sh $(pwd)
@@ -213,7 +213,7 @@ jobs:
shell: bash
run: brew install apache-arrow
- name: Install staticcheck
- run: go install honnef.co/go/tools/cmd/staticcheck@latest
+ run: go install honnef.co/go/tools/cmd/staticcheck@v0.2.2
- name: Build
shell: bash
run: ci/scripts/go_build.sh $(pwd)
@@ -268,7 +268,7 @@ jobs:
with:
go-version: '1.17'
- name: Install staticcheck
- run: go install honnef.co/go/tools/cmd/staticcheck@latest
+ run: go install honnef.co/go/tools/cmd/staticcheck@v0.2.2
- name: Build
shell: bash
run: ci/scripts/go_build.sh $(pwd)
diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml
index 16b36cfae6..1f27d4fa3b 100644
--- a/.github/workflows/r.yml
+++ b/.github/workflows/r.yml
@@ -236,7 +236,7 @@ jobs:
name: AMD64 Windows R ${{ matrix.config.rversion }} RTools ${{ matrix.config.rtools }}
runs-on: windows-2019
if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
- timeout-minutes: 60
+ timeout-minutes: 75
strategy:
fail-fast: false
matrix:
@@ -329,12 +329,37 @@ jobs:
if: ${{ matrix.config.rtools == 35 }}
shell: Rscript {0}
run: install.packages("cpp11", type = "source")
+ - name: Prune dependencies (on R 3.6)
+ if: ${{ matrix.config.rtools == 35 }}
+ shell: Rscript {0}
+ run: |
+ # To prevent the build from timing out, let's prune some optional deps (and their possible version requirements)
+ setwd("r")
+ # create a backup to use later
+ file.copy("DESCRIPTION", "DESCRIPTION.bak")
+ d <- read.dcf("DESCRIPTION")
+ to_prune <- c("duckdb", "DBI", "dbplyr", "decor", "knitr", "rmarkdown", "pkgload", "reticulate")
+ pattern <- paste0("\\n?", to_prune, "( \\([^,]*\\))?,?", collapse = "|")
+ d[,"Suggests"] <- gsub(pattern, "", d[,"Suggests"])
+ write.dcf(d, "DESCRIPTION")
+ - name: R package cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ${{ env.R_LIBS_USER }}/*
+ key: r-${{ matrix.config.rtools }}-R-LIBS-${{ hashFiles('r/DESCRIPTION') }}
+ restore-keys: r-${{ matrix.config.rtools }}-R-LIBS-
- name: Install R package dependencies
shell: Rscript {0}
run: |
# options(pkgType="win.binary") # processx doesn't have a binary for UCRT yet
install.packages(c("remotes", "rcmdcheck"))
remotes::install_deps("r", dependencies = TRUE)
+ - name: Restore DESCRIPTION for 3.6
+ if: ${{ matrix.config.rtools == 35 }}
+ run: |
+ rm r/DESCRIPTION
+ mv r/DESCRIPTION.bak r/DESCRIPTION
- name: Check
shell: Rscript {0}
run: |
@@ -357,6 +382,15 @@ jobs:
check_dir = 'check',
timeout = 3600
)
+ - name: Run lintr
+ env:
+ NOT_CRAN: "true"
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ shell: Rscript {0}
+ working-directory: r
+ run: |
+ remotes::install_github("jonkeane/lintr@arrow-branch")
+ lintr::expect_lint_free()
- name: Dump install logs
shell: cmd
run: cat r/check/arrow.Rcheck/00install.out
diff --git a/.github/workflows/ruby.yml b/.github/workflows/ruby.yml
index 0770f2f32a..54292eafc8 100644
--- a/.github/workflows/ruby.yml
+++ b/.github/workflows/ruby.yml
@@ -90,6 +90,7 @@ jobs:
ulimit -c unlimited
archery docker run \
-e ARROW_FLIGHT=ON \
+ -e ARROW_GCS=ON \
-e Protobuf_SOURCE=BUNDLED \
-e gRPC_SOURCE=BUNDLED \
ubuntu-ruby
@@ -110,6 +111,7 @@ jobs:
ARROW_BUILD_TESTS: OFF
ARROW_FLIGHT: ON
ARROW_GANDIVA: ON
+ ARROW_GCS: ON
ARROW_GLIB_GTK_DOC: true
ARROW_GLIB_WERROR: true
ARROW_HOME: /usr/local
@@ -188,6 +190,7 @@ jobs:
ARROW_BUILD_TYPE: release
ARROW_FLIGHT: ON
ARROW_GANDIVA: ON
+ ARROW_GCS: ON
ARROW_HDFS: OFF
ARROW_HOME: /ucrt${{ matrix.mingw-n-bits }}
ARROW_JEMALLOC: OFF
diff --git a/c_glib/arrow-flight-glib/client.cpp b/c_glib/arrow-flight-glib/client.cpp
index b4de6468c6..0d1961e6c6 100644
--- a/c_glib/arrow-flight-glib/client.cpp
+++ b/c_glib/arrow-flight-glib/client.cpp
@@ -251,12 +251,11 @@ gaflight_client_new(GAFlightLocation *location,
arrow::Status status;
if (options) {
const auto flight_options = gaflight_client_options_get_raw(options);
- status = arrow::flight::FlightClient::Connect(*flight_location,
- *flight_options,
- &flight_client);
+ auto result = arrow::flight::FlightClient::Connect(*flight_location, *flight_options);
+ status = std::move(result).Value(&flight_client);
} else {
- status = arrow::flight::FlightClient::Connect(*flight_location,
- &flight_client);
+ auto result = arrow::flight::FlightClient::Connect(*flight_location);
+ status = std::move(result).Value(&flight_client);
}
if (garrow::check(error, status, "[flight-client][new]")) {
return gaflight_client_new_raw(flight_client.release());
@@ -315,9 +314,8 @@ gaflight_client_list_flights(GAFlightClient *client,
flight_options = gaflight_call_options_get_raw(options);
}
std::unique_ptr<arrow::flight::FlightListing> flight_listing;
- auto status = flight_client->ListFlights(*flight_options,
- *flight_criteria,
- &flight_listing);
+ auto result = flight_client->ListFlights(*flight_options, *flight_criteria);
+ auto status = std::move(result).Value(&flight_listing);
if (!garrow::check(error,
status,
"[flight-client][list-flights]")) {
@@ -369,9 +367,8 @@ gaflight_client_do_get(GAFlightClient *client,
flight_options = gaflight_call_options_get_raw(options);
}
std::unique_ptr<arrow::flight::FlightStreamReader> flight_reader;
- auto status = flight_client->DoGet(*flight_options,
- *flight_ticket,
- &flight_reader);
+ auto result = flight_client->DoGet(*flight_options, *flight_ticket);
+ auto status = std::move(result).Value(&flight_reader);
if (garrow::check(error,
status,
"[flight-client][do-get]")) {
diff --git a/ci/docker/debian-10-go.dockerfile b/ci/docker/debian-10-go.dockerfile
index f12cf83db5..f0c0522081 100644
--- a/ci/docker/debian-10-go.dockerfile
+++ b/ci/docker/debian-10-go.dockerfile
@@ -19,7 +19,7 @@ ARG arch=amd64
ARG go=1.15
FROM ${arch}/golang:${go}-buster
-RUN GO111MODULE=on go install honnef.co/go/tools/cmd/staticcheck@latest
+RUN GO111MODULE=on go install honnef.co/go/tools/cmd/staticcheck@v0.2.2
# TODO(kszucs):
# 1. add the files required to install the dependencies to .dockerignore
diff --git a/ci/docker/debian-11-go.dockerfile b/ci/docker/debian-11-go.dockerfile
index 64271b49ce..33f523e36a 100644
--- a/ci/docker/debian-11-go.dockerfile
+++ b/ci/docker/debian-11-go.dockerfile
@@ -19,7 +19,7 @@ ARG arch=amd64
ARG go=1.16
FROM ${arch}/golang:${go}-bullseye
-RUN GO111MODULE=on go install honnef.co/go/tools/cmd/staticcheck@latest
+RUN GO111MODULE=on go install honnef.co/go/tools/cmd/staticcheck@v0.2.2
# TODO(kszucs):
# 1. add the files required to install the dependencies to .dockerignore
diff --git a/ci/docker/ubuntu-18.04-verify-rc.dockerfile b/ci/docker/ubuntu-18.04-verify-rc.dockerfile
index 88a74b6003..d210e01d9d 100644
--- a/ci/docker/ubuntu-18.04-verify-rc.dockerfile
+++ b/ci/docker/ubuntu-18.04-verify-rc.dockerfile
@@ -19,42 +19,8 @@ ARG arch=amd64
FROM ${arch}/ubuntu:18.04
ENV DEBIAN_FRONTEND=noninteractive
-
-ARG llvm=12
-RUN apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends \
- apt-transport-https \
- ca-certificates \
- gnupg \
- wget && \
- wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \
- echo "deb https://apt.llvm.org/bionic/ llvm-toolchain-bionic-${llvm} main" > \
- /etc/apt/sources.list.d/llvm.list && \
- apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends \
- build-essential \
- clang \
- cmake \
- curl \
- git \
- libcurl4-openssl-dev \
- libgirepository1.0-dev \
- libglib2.0-dev \
- libsqlite3-dev \
- libssl-dev \
- llvm-${llvm}-dev \
- maven \
- ninja-build \
- openjdk-11-jdk \
- pkg-config \
- python3-pip \
- python3.8-dev \
- python3.8-venv \
- ruby-dev \
- wget \
- tzdata && \
+COPY dev/release/setup-ubuntu.sh /
+RUN /setup-ubuntu.sh && \
+ rm /setup-ubuntu.sh && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
-
-RUN python3.8 -m pip install -U pip && \
- update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1
diff --git a/ci/docker/ubuntu-20.04-verify-rc.dockerfile b/ci/docker/ubuntu-20.04-verify-rc.dockerfile
index e3415bd745..cee1e50e08 100644
--- a/ci/docker/ubuntu-20.04-verify-rc.dockerfile
+++ b/ci/docker/ubuntu-20.04-verify-rc.dockerfile
@@ -19,28 +19,8 @@ ARG arch=amd64
FROM ${arch}/ubuntu:20.04
ENV DEBIAN_FRONTEND=noninteractive
-RUN apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends \
- build-essential \
- clang \
- cmake \
- curl \
- git \
- libcurl4-openssl-dev \
- libgirepository1.0-dev \
- libglib2.0-dev \
- libsqlite3-dev \
- libssl-dev \
- llvm-dev \
- maven \
- ninja-build \
- nlohmann-json3-dev \
- openjdk-11-jdk \
- pkg-config \
- python3-dev \
- python3-pip \
- python3-venv \
- ruby-dev \
- wget && \
+COPY dev/release/setup-ubuntu.sh /
+RUN /setup-ubuntu.sh && \
+ rm /setup-ubuntu.sh && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
diff --git a/ci/docker/ubuntu-21.04-cpp.dockerfile b/ci/docker/ubuntu-22.04-cpp-minimal.dockerfile
similarity index 50%
copy from ci/docker/ubuntu-21.04-cpp.dockerfile
copy to ci/docker/ubuntu-22.04-cpp-minimal.dockerfile
index ff0979ea64..8bc5ab3e48 100644
--- a/ci/docker/ubuntu-21.04-cpp.dockerfile
+++ b/ci/docker/ubuntu-22.04-cpp-minimal.dockerfile
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-ARG base=amd64/ubuntu:21.04
+ARG base=amd64/ubuntu:22.04
FROM ${base}
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
@@ -23,15 +23,26 @@ SHELL ["/bin/bash", "-o", "pipefail", "-c"]
RUN echo "debconf debconf/frontend select Noninteractive" | \
debconf-set-selections
+RUN apt-get update -y -q && \
+ apt-get install -y -q \
+ build-essential \
+ ccache \
+ cmake \
+ git \
+ libssl-dev \
+ libcurl4-openssl-dev \
+ python3-pip \
+ wget && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists*
+
# Installs LLVM toolchain, for Gandiva and testing other compilers
#
# Note that this is installed before the base packages to improve iteration
# while debugging package list with docker build.
-ARG clang_tools
ARG llvm
-RUN latest_system_llvm=12 && \
- if [ ${llvm} -gt ${latest_system_llvm} -o \
- ${clang_tools} -gt ${latest_system_llvm} ]; then \
+RUN latest_system_llvm=14 && \
+ if [ ${llvm} -gt ${latest_system_llvm} ]; then \
apt-get update -y -q && \
apt-get install -y -q --no-install-recommends \
apt-transport-https \
@@ -44,85 +55,26 @@ RUN latest_system_llvm=12 && \
if [ ${llvm} -gt 10 ]; then \
echo "deb https://apt.llvm.org/${code_name}/ llvm-toolchain-${code_name}-${llvm} main" > \
/etc/apt/sources.list.d/llvm.list; \
- fi && \
- if [ ${clang_tools} -ne ${llvm} -a \
- ${clang_tools} -gt ${latest_system_llvm} ]; then \
- echo "deb https://apt.llvm.org/${code_name}/ llvm-toolchain-${code_name}-${clang_tools} main" > \
- /etc/apt/sources.list.d/clang-tools.list; \
fi; \
fi && \
apt-get update -y -q && \
apt-get install -y -q --no-install-recommends \
- clang-${clang_tools} \
clang-${llvm} \
- clang-format-${clang_tools} \
- clang-tidy-${clang_tools} \
llvm-${llvm}-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
-# Installs C++ toolchain and dependencies
-RUN apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends \
- autoconf \
- ca-certificates \
- ccache \
- cmake \
- gdb \
- git \
- libbenchmark-dev \
- libboost-filesystem-dev \
- libboost-system-dev \
- libbrotli-dev \
- libbz2-dev \
- libc-ares-dev \
- libcurl4-openssl-dev \
- libgflags-dev \
- libgoogle-glog-dev \
- libgrpc++-dev \
- liblz4-dev \
- libprotobuf-dev \
- libprotoc-dev \
- libre2-dev \
- libsnappy-dev \
- libssl-dev \
- libsqlite3-dev \
- libthrift-dev \
- libutf8proc-dev \
- libzstd-dev \
- make \
- ninja-build \
- nlohmann-json3-dev \
- pkg-config \
- protobuf-compiler \
- protobuf-compiler-grpc \
- python3-pip \
- rapidjson-dev \
- rsync \
- tzdata \
- wget && \
- apt-get clean && \
- rm -rf /var/lib/apt/lists*
-
COPY ci/scripts/install_minio.sh /arrow/ci/scripts/
RUN /arrow/ci/scripts/install_minio.sh latest /usr/local
COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/
RUN /arrow/ci/scripts/install_gcs_testbench.sh default
-# Prioritize system packages and local installation
-# The following dependencies will be downloaded due to missing/invalid packages
-# provided by the distribution:
-# - libc-ares-dev does not install CMake config files
-# - flatbuffer is not packaged
-# - libgtest-dev only provide sources
-# - libprotobuf-dev only provide sources
ENV ARROW_BUILD_TESTS=ON \
- ARROW_DEPENDENCY_SOURCE=SYSTEM \
ARROW_DATASET=ON \
ARROW_FLIGHT=ON \
- ARROW_FLIGHT_SQL=ON \
ARROW_GANDIVA=ON \
+ ARROW_GCS=ON \
ARROW_HDFS=ON \
ARROW_HOME=/usr/local \
ARROW_INSTALL_NAME_RPATH=OFF \
@@ -131,9 +83,7 @@ ENV ARROW_BUILD_TESTS=ON \
ARROW_PARQUET=ON \
ARROW_PLASMA=ON \
ARROW_S3=ON \
- ARROW_USE_ASAN=OFF \
ARROW_USE_CCACHE=ON \
- ARROW_USE_UBSAN=OFF \
ARROW_WITH_BROTLI=ON \
ARROW_WITH_BZ2=ON \
ARROW_WITH_LZ4=ON \
@@ -141,35 +91,8 @@ ENV ARROW_BUILD_TESTS=ON \
ARROW_WITH_SNAPPY=ON \
ARROW_WITH_ZLIB=ON \
ARROW_WITH_ZSTD=ON \
- AWSSDK_SOURCE=BUNDLED \
- GTest_SOURCE=BUNDLED \
- ORC_SOURCE=BUNDLED \
+ CMAKE_GENERATOR="Unix Makefiles" \
PARQUET_BUILD_EXAMPLES=ON \
PARQUET_BUILD_EXECUTABLES=ON \
- Protobuf_SOURCE=BUNDLED \
PATH=/usr/lib/ccache/:$PATH \
PYTHON=python3
-
-ARG gcc_version=""
-RUN if [ "${gcc_version}" = "" ]; then \
- apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends \
- g++ \
- gcc; \
- else \
- if [ "${gcc_version}" -gt "10" ]; then \
- apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends software-properties-common && \
- add-apt-repository ppa:ubuntu-toolchain-r/volatile; \
- fi; \
- apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends \
- g++-${gcc_version} \
- gcc-${gcc_version} && \
- update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-${gcc_version} 100 && \
- update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-${gcc_version} 100 && \
- update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 100 && \
- update-alternatives --set cc /usr/bin/gcc && \
- update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 100 && \
- update-alternatives --set c++ /usr/bin/g++; \
- fi
diff --git a/ci/docker/ubuntu-21.04-cpp.dockerfile b/ci/docker/ubuntu-22.04-cpp.dockerfile
similarity index 91%
rename from ci/docker/ubuntu-21.04-cpp.dockerfile
rename to ci/docker/ubuntu-22.04-cpp.dockerfile
index ff0979ea64..92d802f876 100644
--- a/ci/docker/ubuntu-21.04-cpp.dockerfile
+++ b/ci/docker/ubuntu-22.04-cpp.dockerfile
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-ARG base=amd64/ubuntu:21.04
+ARG base=amd64/ubuntu:22.04
FROM ${base}
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
@@ -29,7 +29,7 @@ RUN echo "debconf debconf/frontend select Noninteractive" | \
# while debugging package list with docker build.
ARG clang_tools
ARG llvm
-RUN latest_system_llvm=12 && \
+RUN latest_system_llvm=14 && \
if [ ${llvm} -gt ${latest_system_llvm} -o \
${clang_tools} -gt ${latest_system_llvm} ]; then \
apt-get update -y -q && \
@@ -96,6 +96,7 @@ RUN apt-get update -y -q && \
pkg-config \
protobuf-compiler \
protobuf-compiler-grpc \
+ python3-dev \
python3-pip \
rapidjson-dev \
rsync \
@@ -104,6 +105,38 @@ RUN apt-get update -y -q && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
+ARG gcc_version=""
+RUN if [ "${gcc_version}" = "" ]; then \
+ apt-get update -y -q && \
+ apt-get install -y -q --no-install-recommends \
+ g++ \
+ gcc; \
+ else \
+ if [ "${gcc_version}" -gt "11" ]; then \
+ apt-get update -y -q && \
+ apt-get install -y -q --no-install-recommends software-properties-common && \
+ add-apt-repository ppa:ubuntu-toolchain-r/volatile; \
+ fi; \
+ apt-get update -y -q && \
+ apt-get install -y -q --no-install-recommends \
+ g++-${gcc_version} \
+ gcc-${gcc_version} && \
+ update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-${gcc_version} 100 && \
+ update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-${gcc_version} 100 && \
+ update-alternatives --install \
+ /usr/bin/$(uname --machine)-linux-gnu-gcc \
+ $(uname --machine)-linux-gnu-gcc \
+ /usr/bin/$(uname --machine)-linux-gnu-gcc-${gcc_version} 100 && \
+ update-alternatives --install \
+ /usr/bin/$(uname --machine)-linux-gnu-g++ \
+ $(uname --machine)-linux-gnu-g++ \
+ /usr/bin/$(uname --machine)-linux-gnu-g++-${gcc_version} 100 && \
+ update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 100 && \
+ update-alternatives --set cc /usr/bin/gcc && \
+ update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 100 && \
+ update-alternatives --set c++ /usr/bin/g++; \
+ fi
+
COPY ci/scripts/install_minio.sh /arrow/ci/scripts/
RUN /arrow/ci/scripts/install_minio.sh latest /usr/local
@@ -149,27 +182,3 @@ ENV ARROW_BUILD_TESTS=ON \
Protobuf_SOURCE=BUNDLED \
PATH=/usr/lib/ccache/:$PATH \
PYTHON=python3
-
-ARG gcc_version=""
-RUN if [ "${gcc_version}" = "" ]; then \
- apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends \
- g++ \
- gcc; \
- else \
- if [ "${gcc_version}" -gt "10" ]; then \
- apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends software-properties-common && \
- add-apt-repository ppa:ubuntu-toolchain-r/volatile; \
- fi; \
- apt-get update -y -q && \
- apt-get install -y -q --no-install-recommends \
- g++-${gcc_version} \
- gcc-${gcc_version} && \
- update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-${gcc_version} 100 && \
- update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-${gcc_version} 100 && \
- update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 100 && \
- update-alternatives --set cc /usr/bin/gcc && \
- update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 100 && \
- update-alternatives --set c++ /usr/bin/g++; \
- fi
diff --git a/cpp/build-support/lsan-suppressions.txt b/ci/docker/ubuntu-22.04-verify-rc.dockerfile
similarity index 78%
copy from cpp/build-support/lsan-suppressions.txt
copy to ci/docker/ubuntu-22.04-verify-rc.dockerfile
index 566857a9c0..8bc6f39b67 100644
--- a/cpp/build-support/lsan-suppressions.txt
+++ b/ci/docker/ubuntu-22.04-verify-rc.dockerfile
@@ -15,7 +15,12 @@
# specific language governing permissions and limitations
# under the License.
-# False positive from atexit() registration in libc
-leak:*__new_exitfn*
-# Leak at shutdown in OpenSSL
-leak:CRYPTO_zalloc
+ARG arch=amd64
+FROM ${arch}/ubuntu:22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+COPY dev/release/setup-ubuntu.sh /
+RUN /setup-ubuntu.sh && \
+ rm /setup-ubuntu.sh && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists*
diff --git a/ci/scripts/go_build.sh b/ci/scripts/go_build.sh
index 72ac7a9319..20879cc0e7 100755
--- a/ci/scripts/go_build.sh
+++ b/ci/scripts/go_build.sh
@@ -25,7 +25,7 @@ ARCH=`uname -m`
# Arm64 CI is triggered by travis and run in arm64v8/golang:1.16-bullseye
if [ "aarch64" == "$ARCH" ]; then
# Install `staticcheck`
- GO111MODULE=on go install honnef.co/go/tools/cmd/staticcheck@latest
+ GO111MODULE=on go install honnef.co/go/tools/cmd/staticcheck@v0.2.2
fi
pushd ${source_dir}/arrow
diff --git a/ci/scripts/install_gcs_testbench.sh b/ci/scripts/install_gcs_testbench.sh
index 16e3c4042c..0282e0fda5 100755
--- a/ci/scripts/install_gcs_testbench.sh
+++ b/ci/scripts/install_gcs_testbench.sh
@@ -31,7 +31,8 @@ fi
version=$1
if [[ "${version}" -eq "default" ]]; then
- version="v0.7.0"
+ version="v0.16.0"
fi
-pip install "https://github.com/googleapis/storage-testbench/archive/${version}.tar.gz"
+${PYTHON:-python3} -m pip install \
+ "https://github.com/googleapis/storage-testbench/archive/${version}.tar.gz"
diff --git a/ci/scripts/msys2_setup.sh b/ci/scripts/msys2_setup.sh
index cf2f8e17d9..b7401546ff 100755
--- a/ci/scripts/msys2_setup.sh
+++ b/ci/scripts/msys2_setup.sh
@@ -36,13 +36,10 @@ case "${target}" in
packages+=(${MINGW_PACKAGE_PREFIX}-gtest)
packages+=(${MINGW_PACKAGE_PREFIX}-libutf8proc)
packages+=(${MINGW_PACKAGE_PREFIX}-libxml2)
- packages+=(${MINGW_PACKAGE_PREFIX}-llvm)
packages+=(${MINGW_PACKAGE_PREFIX}-lz4)
packages+=(${MINGW_PACKAGE_PREFIX}-make)
- packages+=(${MINGW_PACKAGE_PREFIX}-mlir)
packages+=(${MINGW_PACKAGE_PREFIX}-ninja)
packages+=(${MINGW_PACKAGE_PREFIX}-nlohmann-json)
- packages+=(${MINGW_PACKAGE_PREFIX}-polly)
packages+=(${MINGW_PACKAGE_PREFIX}-protobuf)
packages+=(${MINGW_PACKAGE_PREFIX}-python3-numpy)
packages+=(${MINGW_PACKAGE_PREFIX}-rapidjson)
diff --git a/ci/scripts/python_wheel_windows_test.bat b/ci/scripts/python_wheel_windows_test.bat
index f2b46940af..498d08954b 100755
--- a/ci/scripts/python_wheel_windows_test.bat
+++ b/ci/scripts/python_wheel_windows_test.bat
@@ -35,21 +35,21 @@ set ARROW_TEST_DATA=C:\arrow\testing\data
set PARQUET_TEST_DATA=C:\arrow\submodules\parquet-testing\data
@REM Install testing dependencies
-pip install -r C:\arrow\python\requirements-wheel-test.txt || exit /B
+pip install -r C:\arrow\python\requirements-wheel-test.txt || exit /B 1
@REM Install the built wheels
-python -m pip install --no-index --find-links=C:\arrow\python\dist\ pyarrow || exit /B
+python -m pip install --no-index --find-links=C:\arrow\python\dist\ pyarrow || exit /B 1
@REM Test that the modules are importable
-python -c "import pyarrow" || exit /B
-python -c "import pyarrow._hdfs" || exit /B
-python -c "import pyarrow._s3fs" || exit /B
-python -c "import pyarrow.csv" || exit /B
-python -c "import pyarrow.dataset" || exit /B
-python -c "import pyarrow.flight" || exit /B
-python -c "import pyarrow.fs" || exit /B
-python -c "import pyarrow.json" || exit /B
-python -c "import pyarrow.parquet" || exit /B
+python -c "import pyarrow" || exit /B 1
+python -c "import pyarrow._hdfs" || exit /B 1
+python -c "import pyarrow._s3fs" || exit /B 1
+python -c "import pyarrow.csv" || exit /B 1
+python -c "import pyarrow.dataset" || exit /B 1
+python -c "import pyarrow.flight" || exit /B 1
+python -c "import pyarrow.fs" || exit /B 1
+python -c "import pyarrow.json" || exit /B 1
+python -c "import pyarrow.parquet" || exit /B 1
@REM Execute unittest
-pytest -r s --pyargs pyarrow || exit /B
+pytest -r s --pyargs pyarrow || exit /B 1
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index ecbc351641..ffa5cc5660 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -111,6 +111,7 @@ set(ARROW_CMAKE_INSTALL_DIR "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}")
set(ARROW_DOC_DIR "share/doc/${PROJECT_NAME}")
set(ARROW_LLVM_VERSIONS
+ "14.0"
"13.0"
"12.0"
"11.1"
diff --git a/cpp/build-support/lsan-suppressions.txt b/cpp/build-support/lsan-suppressions.txt
index 566857a9c0..a8918e10d9 100644
--- a/cpp/build-support/lsan-suppressions.txt
+++ b/cpp/build-support/lsan-suppressions.txt
@@ -19,3 +19,10 @@
leak:*__new_exitfn*
# Leak at shutdown in OpenSSL
leak:CRYPTO_zalloc
+
+# OpenTelemetry. These seem like false positives and go away if the
+# CPU thread pool is manually shut down before exit.
+# Note that ASan has trouble backtracing these and may not be able to
+# without LSAN_OPTIONS=fast_unwind_on_malloc=0:malloc_context_size=100
+leak:opentelemetry::v1::context::ThreadLocalContextStorage::GetStack
+leak:opentelemetry::v1::context::ThreadLocalContextStorage::Stack::Resize
diff --git a/cpp/cmake_modules/BuildUtils.cmake b/cpp/cmake_modules/BuildUtils.cmake
index 2bcaed43e5..174b1c515a 100644
--- a/cpp/cmake_modules/BuildUtils.cmake
+++ b/cpp/cmake_modules/BuildUtils.cmake
@@ -445,6 +445,7 @@ function(ADD_ARROW_LIB LIB_NAME)
if(ARROW_BUILD_STATIC AND WIN32)
target_compile_definitions(${LIB_NAME}_static PUBLIC ARROW_STATIC)
+ target_compile_definitions(${LIB_NAME}_static PUBLIC ARROW_FLIGHT_STATIC)
endif()
set_target_properties(${LIB_NAME}_static
diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake
index 05fc14bbc7..ec1e0b6352 100644
--- a/cpp/cmake_modules/DefineOptions.cmake
+++ b/cpp/cmake_modules/DefineOptions.cmake
@@ -391,6 +391,10 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
define_option(ARROW_WITH_ZLIB "Build with zlib compression" OFF)
define_option(ARROW_WITH_ZSTD "Build with zstd compression" OFF)
+ define_option(ARROW_WITH_UCX
+ "Build with UCX transport for Arrow Flight;(only used if ARROW_FLIGHT is ON)"
+ OFF)
+
define_option(ARROW_WITH_UTF8PROC
"Build with support for Unicode properties using the utf8proc library;(only used if ARROW_COMPUTE is ON or ARROW_GANDIVA is ON)"
ON)
diff --git a/cpp/cmake_modules/FindArrow.cmake b/cpp/cmake_modules/FindArrow.cmake
index 68024cc276..9d2faaf581 100644
--- a/cpp/cmake_modules/FindArrow.cmake
+++ b/cpp/cmake_modules/FindArrow.cmake
@@ -36,7 +36,7 @@ if(DEFINED ARROW_FOUND)
return()
endif()
-include(FindPkgConfig)
+find_package(PkgConfig)
include(FindPackageHandleStandardArgs)
if(WIN32 AND NOT MINGW)
diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake
index 35bd80be3e..3a0353bc7d 100644
--- a/cpp/cmake_modules/ThirdpartyToolchain.cmake
+++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake
@@ -2593,8 +2593,9 @@ endmacro()
# ----------------------------------------------------------------------
# Dependencies for Arrow Flight RPC
-macro(build_absl_once)
- if(NOT TARGET absl_ep)
+macro(resolve_dependency_absl)
+ # Choose one of built absl::* targets
+ if(NOT TARGET absl::algorithm)
message(STATUS "Building Abseil-cpp from source")
set(ABSL_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/absl_ep-install")
set(ABSL_INCLUDE_DIR "${ABSL_PREFIX}/include")
@@ -3451,6 +3452,7 @@ macro(build_absl_once)
# Work around https://gitlab.kitware.com/cmake/cmake/issues/15052
file(MAKE_DIRECTORY ${ABSL_INCLUDE_DIR})
+ set(ABSL_VENDORED TRUE)
endif()
endmacro()
@@ -3464,8 +3466,8 @@ macro(build_grpc)
get_target_property(c-ares_INCLUDE_DIR c-ares::cares INTERFACE_INCLUDE_DIRECTORIES)
include_directories(SYSTEM ${c-ares_INCLUDE_DIR})
- # First need to build Abseil
- build_absl_once()
+ # First need Abseil
+ resolve_dependency_absl()
message(STATUS "Building gRPC from source")
@@ -3498,7 +3500,9 @@ macro(build_grpc)
add_custom_target(grpc_dependencies)
- add_dependencies(grpc_dependencies absl_ep)
+ if(ABSL_VENDORED)
+ add_dependencies(grpc_dependencies absl_ep)
+ endif()
if(CARES_VENDORED)
add_dependencies(grpc_dependencies cares_ep)
endif()
@@ -3819,7 +3823,7 @@ macro(build_google_cloud_cpp_storage)
message(STATUS "Only building the google-cloud-cpp::storage component")
# List of dependencies taken from https://github.com/googleapis/google-cloud-cpp/blob/master/doc/packaging.md
- build_absl_once()
+ resolve_dependency_absl()
build_crc32c_once()
# Curl is required on all platforms, but building it internally might also trip over S3's copy.
@@ -3830,7 +3834,9 @@ macro(build_google_cloud_cpp_storage)
# Build google-cloud-cpp, with only storage_client
# Inject vendored packages via CMAKE_PREFIX_PATH
- list(APPEND GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST ${ABSL_PREFIX})
+ if(ABSL_VENDORED)
+ list(APPEND GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST ${ABSL_PREFIX})
+ endif()
list(APPEND GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST ${CRC32C_PREFIX})
list(APPEND GOOGLE_CLOUD_CPP_PREFIX_PATH_LIST ${NLOHMANN_JSON_PREFIX})
@@ -3852,14 +3858,19 @@ macro(build_google_cloud_cpp_storage)
# Compile only the storage library and its dependencies. To enable
# other services (Spanner, Bigtable, etc.) add them (as a list) to this
# parameter. Each has its own `google-cloud-cpp::*` library.
- -DGOOGLE_CLOUD_CPP_ENABLE=storage)
+ -DGOOGLE_CLOUD_CPP_ENABLE=storage
+ # We need this to build with OpenSSL 3.0.
+ # See also: https://github.com/googleapis/google-cloud-cpp/issues/8544
+ -DGOOGLE_CLOUD_CPP_ENABLE_WERROR=OFF)
if(OPENSSL_ROOT_DIR)
list(APPEND GOOGLE_CLOUD_CPP_CMAKE_ARGS -DOPENSSL_ROOT_DIR=${OPENSSL_ROOT_DIR})
endif()
add_custom_target(google_cloud_cpp_dependencies)
- add_dependencies(google_cloud_cpp_dependencies absl_ep)
+ if(ABSL_VENDORED)
+ add_dependencies(google_cloud_cpp_dependencies absl_ep)
+ endif()
add_dependencies(google_cloud_cpp_dependencies crc32c_ep)
add_dependencies(google_cloud_cpp_dependencies nlohmann_json::nlohmann_json)
@@ -3900,7 +3911,8 @@ macro(build_google_cloud_cpp_storage)
absl::memory
absl::optional
absl::time
- Threads::Threads)
+ Threads::Threads
+ OpenSSL::Crypto)
add_library(google-cloud-cpp::storage STATIC IMPORTED)
set_target_properties(google-cloud-cpp::storage
diff --git a/cpp/examples/arrow/engine_substrait_consumption.cc b/cpp/examples/arrow/engine_substrait_consumption.cc
index b0109b3688..d74f674965 100644
--- a/cpp/examples/arrow/engine_substrait_consumption.cc
+++ b/cpp/examples/arrow/engine_substrait_consumption.cc
@@ -40,6 +40,10 @@ class IgnoringConsumer : public cp::SinkNodeConsumer {
public:
explicit IgnoringConsumer(size_t tag) : tag_{tag} {}
+ arrow::Status Init(const std::shared_ptr<arrow::Schema>& schema) override {
+ return arrow::Status::OK();
+ }
+
arrow::Status Consume(cp::ExecBatch batch) override {
// Consume a batch of data
// (just print its row count to stdout)
diff --git a/cpp/examples/arrow/execution_plan_documentation_examples.cc b/cpp/examples/arrow/execution_plan_documentation_examples.cc
index 81cdcef530..1ca3d36a34 100644
--- a/cpp/examples/arrow/execution_plan_documentation_examples.cc
+++ b/cpp/examples/arrow/execution_plan_documentation_examples.cc
@@ -591,6 +591,10 @@ arrow::Status SourceConsumingSinkExample(cp::ExecContext& exec_context) {
CustomSinkNodeConsumer(std::atomic<uint32_t>* batches_seen, arrow::Future<> finish)
: batches_seen(batches_seen), finish(std::move(finish)) {}
+ arrow::Status Init(const std::shared_ptr<arrow::Schema>& schema) override {
+ return arrow::Status::OK();
+ }
+
arrow::Status Consume(cp::ExecBatch batch) override {
(*batches_seen)++;
return arrow::Status::OK();
@@ -794,7 +798,7 @@ arrow::Status ScanFilterWriteExample(cp::ExecContext& exec_context,
write_options.partitioning = partitioning;
write_options.basename_template = "part{i}.parquet";
- arrow::dataset::WriteNodeOptions write_node_options{write_options, dataset->schema()};
+ arrow::dataset::WriteNodeOptions write_node_options{write_options};
ARROW_RETURN_NOT_OK(cp::MakeExecNode("write", plan.get(), {scan}, write_node_options));
@@ -851,6 +855,46 @@ arrow::Status SourceUnionSinkExample(cp::ExecContext& exec_context) {
// (Doc section: Union Example)
+// (Doc section: Table Sink Example)
+
+/// \brief An example showing a table sink node
+/// \param exec_context The execution context to run the plan in
+///
+/// TableSink Example
+/// This example shows how a table_sink can be used
+/// in an execution plan. This includes a source node
+/// receiving data as batches and the table sink node
+/// which emits the output as a table.
+arrow::Status TableSinkExample(cp::ExecContext& exec_context) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<cp::ExecPlan> plan,
+ cp::ExecPlan::Make(&exec_context));
+
+ ARROW_ASSIGN_OR_RAISE(auto basic_data, MakeBasicBatches());
+
+ auto source_node_options = cp::SourceNodeOptions{basic_data.schema, basic_data.gen()};
+
+ ARROW_ASSIGN_OR_RAISE(cp::ExecNode * source,
+ cp::MakeExecNode("source", plan.get(), {}, source_node_options));
+
+ std::shared_ptr<arrow::Table> output_table;
+ auto table_sink_options = cp::TableSinkNodeOptions{&output_table};
+
+ ARROW_RETURN_NOT_OK(
+ cp::MakeExecNode("table_sink", plan.get(), {source}, table_sink_options));
+ // validate the ExecPlan
+ ARROW_RETURN_NOT_OK(plan->Validate());
+ std::cout << "ExecPlan created : " << plan->ToString() << std::endl;
+ // start the ExecPlan
+ ARROW_RETURN_NOT_OK(plan->StartProducing());
+
+ // Wait for the plan to finish
+ auto finished = plan->finished();
+ RETURN_NOT_OK(finished.status());
+ std::cout << "Results : " << output_table->ToString() << std::endl;
+ return arrow::Status::OK();
+}
+// (Doc section: Table Sink Example)
+
enum ExampleMode {
SOURCE_SINK = 0,
TABLE_SOURCE_SINK = 1,
@@ -865,6 +909,7 @@ enum ExampleMode {
KSELECT = 10,
WRITE = 11,
UNION = 12,
+ TABLE_SOURCE_TABLE_SINK = 13
};
int main(int argc, char** argv) {
@@ -933,6 +978,10 @@ int main(int argc, char** argv) {
PrintBlock("Union Example");
status = SourceUnionSinkExample(exec_context);
break;
+ case TABLE_SOURCE_TABLE_SINK:
+ PrintBlock("TableSink Example");
+ status = TableSinkExample(exec_context);
+ break;
default:
break;
}
diff --git a/cpp/examples/arrow/flight_sql_example.cc b/cpp/examples/arrow/flight_sql_example.cc
index 5dfd97dbf1..1201d78c5e 100644
--- a/cpp/examples/arrow/flight_sql_example.cc
+++ b/cpp/examples/arrow/flight_sql_example.cc
@@ -44,7 +44,7 @@ arrow::Status Main() {
// Set up the Flight SQL client
std::unique_ptr<flight::FlightClient> flight_client;
- ARROW_RETURN_NOT_OK(flight::FlightClient::Connect(location, &flight_client));
+ ARROW_ASSIGN_OR_RAISE(flight_client, flight::FlightClient::Connect(location));
std::unique_ptr<flightsql::FlightSqlClient> client(
new flightsql::FlightSqlClient(std::move(flight_client)));
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index e9e826097b..b6f1e2481f 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -747,6 +747,10 @@ endif()
if(ARROW_FLIGHT)
add_subdirectory(flight)
+
+ if(ARROW_WITH_UCX)
+ add_subdirectory(flight/transport/ucx)
+ endif()
endif()
if(ARROW_FLIGHT_SQL)
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index d578075325..9e99953e87 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -145,6 +145,12 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
class ARROW_EXPORT SinkNodeConsumer {
public:
virtual ~SinkNodeConsumer() = default;
+ /// \brief Prepare any consumer state
+ ///
+ /// This will be run once the schema is finalized as the plan is starting and
+ /// before any calls to Consume. A common use is to save off the schema so that
+ /// batches can be interpreted.
+ virtual Status Init(const std::shared_ptr<Schema>& schema) = 0;
/// \brief Consume a batch of data
virtual Status Consume(ExecBatch batch) = 0;
/// \brief Signal to the consumer that the last batch has been delivered
@@ -299,21 +305,18 @@ class ARROW_EXPORT SelectKSinkNodeOptions : public SinkNodeOptions {
/// SelectK options
SelectKOptions select_k_options;
};
-
/// @}
-/// \brief Adapt an Table as a sink node
+/// \brief Adapt a Table as a sink node
///
-/// obtains the output of a execution plan to
+/// obtains the output of an execution plan to
/// a table pointer.
class ARROW_EXPORT TableSinkNodeOptions : public ExecNodeOptions {
public:
- TableSinkNodeOptions(std::shared_ptr<Table>* output_table,
- std::shared_ptr<Schema> output_schema)
- : output_table(output_table), output_schema(std::move(output_schema)) {}
+ explicit TableSinkNodeOptions(std::shared_ptr<Table>* output_table)
+ : output_table(output_table) {}
std::shared_ptr<Table>* output_table;
- std::shared_ptr<Schema> output_schema;
};
} // namespace compute
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index e176c701b6..615dec33fa 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -493,6 +493,10 @@ TEST(ExecPlanExecution, SourceConsumingSink) {
TestConsumer(std::atomic<uint32_t>* batches_seen, Future<> finish)
: batches_seen(batches_seen), finish(std::move(finish)) {}
+ Status Init(const std::shared_ptr<Schema>& schema) override {
+ return Status::OK();
+ }
+
Status Consume(ExecBatch batch) override {
(*batches_seen)++;
return Status::OK();
@@ -539,7 +543,7 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) {
auto basic_data = MakeBasicBatches();
- TableSinkNodeOptions options{&out, basic_data.schema};
+ TableSinkNodeOptions options{&out};
ASSERT_OK_AND_ASSIGN(
auto source, MakeExecNode("source", plan.get(), {},
@@ -560,16 +564,26 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) {
}
TEST(ExecPlanExecution, ConsumingSinkError) {
+ struct InitErrorConsumer : public SinkNodeConsumer {
+ Status Init(const std::shared_ptr<Schema>& schema) override {
+ return Status::Invalid("XYZ");
+ }
+ Status Consume(ExecBatch batch) override { return Status::OK(); }
+ Future<> Finish() override { return Future<>::MakeFinished(); }
+ };
struct ConsumeErrorConsumer : public SinkNodeConsumer {
+ Status Init(const std::shared_ptr<Schema>& schema) override { return Status::OK(); }
Status Consume(ExecBatch batch) override { return Status::Invalid("XYZ"); }
Future<> Finish() override { return Future<>::MakeFinished(); }
};
struct FinishErrorConsumer : public SinkNodeConsumer {
+ Status Init(const std::shared_ptr<Schema>& schema) override { return Status::OK(); }
Status Consume(ExecBatch batch) override { return Status::OK(); }
Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); }
};
std::vector<std::shared_ptr<SinkNodeConsumer>> consumers{
- std::make_shared<ConsumeErrorConsumer>(), std::make_shared<FinishErrorConsumer>()};
+ std::make_shared<InitErrorConsumer>(), std::make_shared<ConsumeErrorConsumer>(),
+ std::make_shared<FinishErrorConsumer>()};
for (auto& consumer : consumers) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
@@ -585,35 +599,17 @@ TEST(ExecPlanExecution, ConsumingSinkError) {
SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))));
ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
ConsumingSinkNodeOptions(consumer)));
- ASSERT_OK(plan->StartProducing());
- ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished());
+ // If we fail at init we see it during StartProducing. Other
+ // failures are not seen until we start running.
+ if (std::dynamic_pointer_cast<InitErrorConsumer>(consumer)) {
+ ASSERT_RAISES(Invalid, plan->StartProducing());
+ } else {
+ ASSERT_OK(plan->StartProducing());
+ ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished());
+ }
}
}
-TEST(ExecPlanExecution, ConsumingSinkErrorFinish) {
- ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
- struct FinishErrorConsumer : public SinkNodeConsumer {
- Status Consume(ExecBatch batch) override { return Status::OK(); }
- Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); }
- };
- std::shared_ptr<FinishErrorConsumer> consumer = std::make_shared<FinishErrorConsumer>();
-
- auto basic_data = MakeBasicBatches();
- ASSERT_OK(
- Declaration::Sequence(
- {{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))},
- {"consuming_sink", ConsumingSinkNodeOptions(consumer)}})
- .AddToPlan(plan.get()));
- ASSERT_OK_AND_ASSIGN(
- auto source,
- MakeExecNode("source", plan.get(), {},
- SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))));
- ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
- ConsumingSinkNodeOptions(consumer)));
- ASSERT_OK(plan->StartProducing());
- ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished());
-}
-
TEST(ExecPlanExecution, StressSourceSink) {
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 13564c736b..e981de3899 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -178,6 +178,8 @@ class ConsumingSinkNode : public ExecNode {
{{"node.label", label()},
{"node.detail", ToString()},
{"node.kind", kind_name()}});
+ DCHECK_GT(inputs_.size(), 0);
+ RETURN_NOT_OK(consumer_->Init(inputs_[0]->output_schema()));
finished_ = Future<>::Make();
END_SPAN_ON_FUTURE_COMPLETION(span_, finished_, this);
return Status::OK();
@@ -268,13 +270,17 @@ class ConsumingSinkNode : public ExecNode {
struct TableSinkNodeConsumer : public arrow::compute::SinkNodeConsumer {
public:
- TableSinkNodeConsumer(std::shared_ptr<Table>* out,
- std::shared_ptr<Schema> output_schema, MemoryPool* pool)
- : out_(out), output_schema_(std::move(output_schema)), pool_(pool) {}
+ TableSinkNodeConsumer(std::shared_ptr<Table>* out, MemoryPool* pool)
+ : out_(out), pool_(pool) {}
+
+ Status Init(const std::shared_ptr<Schema>& schema) override {
+ schema_ = schema;
+ return Status::OK();
+ }
Status Consume(ExecBatch batch) override {
std::lock_guard<std::mutex> guard(consume_mutex_);
- ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(output_schema_, pool_));
+ ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_));
batches_.push_back(rb);
return Status::OK();
}
@@ -286,8 +292,8 @@ struct TableSinkNodeConsumer : public arrow::compute::SinkNodeConsumer {
private:
std::shared_ptr<Table>* out_;
- std::shared_ptr<Schema> output_schema_;
MemoryPool* pool_;
+ std::shared_ptr<Schema> schema_;
std::vector<std::shared_ptr<RecordBatch>> batches_;
std::mutex consume_mutex_;
};
@@ -298,8 +304,8 @@ static Result<ExecNode*> MakeTableConsumingSinkNode(
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "TableConsumingSinkNode"));
const auto& sink_options = checked_cast<const TableSinkNodeOptions&>(options);
MemoryPool* pool = plan->exec_context()->memory_pool();
- auto tb_consumer = std::make_shared<TableSinkNodeConsumer>(
- sink_options.output_table, sink_options.output_schema, pool);
+ auto tb_consumer =
+ std::make_shared<TableSinkNodeConsumer>(sink_options.output_table, pool);
auto consuming_sink_node_options = ConsumingSinkNodeOptions{tb_consumer};
return MakeExecNode("consuming_sink", plan, inputs, consuming_sink_node_options);
}
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index 26c12ebf59..b31ef408b1 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -214,45 +214,53 @@ std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
return int8();
}
-TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count) {
- TimeUnit::type finest_unit = TimeUnit::SECOND;
+bool CommonTemporalResolution(const ValueDescr* begin, size_t count,
+ TimeUnit::type* finest_unit) {
+ bool is_time_unit = false;
+ *finest_unit = TimeUnit::SECOND;
const ValueDescr* end = begin + count;
for (auto it = begin; it != end; it++) {
auto id = it->type->id();
switch (id) {
case Type::DATE32: {
// Date32's unit is days, but the coarsest we have is seconds
+ is_time_unit = true;
continue;
}
case Type::DATE64: {
- finest_unit = std::max(finest_unit, TimeUnit::MILLI);
+ *finest_unit = std::max(*finest_unit, TimeUnit::MILLI);
+ is_time_unit = true;
continue;
}
case Type::TIMESTAMP: {
const auto& ty = checked_cast<const TimestampType&>(*it->type);
- finest_unit = std::max(finest_unit, ty.unit());
+ *finest_unit = std::max(*finest_unit, ty.unit());
+ is_time_unit = true;
continue;
}
case Type::DURATION: {
const auto& ty = checked_cast<const DurationType&>(*it->type);
- finest_unit = std::max(finest_unit, ty.unit());
+ *finest_unit = std::max(*finest_unit, ty.unit());
+ is_time_unit = true;
continue;
}
case Type::TIME32: {
const auto& ty = checked_cast<const Time32Type&>(*it->type);
- finest_unit = std::max(finest_unit, ty.unit());
+ *finest_unit = std::max(*finest_unit, ty.unit());
+ is_time_unit = true;
continue;
}
case Type::TIME64: {
const auto& ty = checked_cast<const Time64Type&>(*it->type);
- finest_unit = std::max(finest_unit, ty.unit());
+ *finest_unit = std::max(*finest_unit, ty.unit());
+ is_time_unit = true;
continue;
}
default:
continue;
}
}
- return finest_unit;
+ return is_time_unit;
}
std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count) {
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index ff7b9161fe..fa50427bc3 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -1394,7 +1394,8 @@ ARROW_EXPORT
std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count);
ARROW_EXPORT
-TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count);
+bool CommonTemporalResolution(const ValueDescr* begin, size_t count,
+ TimeUnit::type* finest_unit);
ARROW_EXPORT
std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count);
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
index 46d31c8ae4..6fb84cf55b 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
@@ -137,6 +137,8 @@ TEST(TestDispatchBest, CommonTemporal) {
args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::NANO, "UTC")};
AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
CommonTemporal(args.data(), args.size()));
+ args = {timestamp(TimeUnit::SECOND), date32()};
+ AssertTypeEqual(timestamp(TimeUnit::SECOND), CommonTemporal(args.data(), args.size()));
args = {date32(), timestamp(TimeUnit::NANO)};
AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size()));
args = {date64(), timestamp(TimeUnit::SECOND)};
@@ -161,53 +163,80 @@ TEST(TestDispatchBest, CommonTemporal) {
TEST(TestDispatchBest, CommonTemporalResolution) {
std::vector<ValueDescr> args;
std::string tz = "Pacific/Marquesas";
+ TimeUnit::type ty;
args = {date32(), date32()};
- ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::SECOND, ty);
args = {date32(), date64()};
- ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MILLI, ty);
+ args = {date32(), timestamp(TimeUnit::SECOND)};
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::SECOND, ty);
args = {time32(TimeUnit::MILLI), date32()};
- ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MILLI, ty);
args = {time32(TimeUnit::MILLI), time32(TimeUnit::SECOND)};
- ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MILLI, ty);
args = {time32(TimeUnit::MILLI), time64(TimeUnit::MICRO)};
- ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MICRO, ty);
args = {time64(TimeUnit::NANO), time64(TimeUnit::MICRO)};
- ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::NANO, ty);
args = {duration(TimeUnit::MILLI), duration(TimeUnit::MICRO)};
- ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MICRO, ty);
args = {duration(TimeUnit::MILLI), date32()};
- ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MILLI, ty);
args = {date64(), duration(TimeUnit::SECOND)};
- ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MILLI, ty);
args = {duration(TimeUnit::SECOND), time32(TimeUnit::SECOND)};
- ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::SECOND, ty);
args = {duration(TimeUnit::SECOND), time64(TimeUnit::NANO)};
- ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::NANO, ty);
args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)};
- ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::NANO, ty);
args = {timestamp(TimeUnit::SECOND, tz), timestamp(TimeUnit::MICRO)};
- ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MICRO, ty);
args = {date32(), timestamp(TimeUnit::MICRO, tz)};
- ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MICRO, ty);
args = {timestamp(TimeUnit::MICRO, tz), date64()};
- ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MICRO, ty);
args = {time32(TimeUnit::MILLI), timestamp(TimeUnit::MICRO, tz)};
- ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MICRO, ty);
args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)};
- ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::NANO, ty);
args = {timestamp(TimeUnit::SECOND, tz), duration(TimeUnit::MILLI)};
- ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MILLI, ty);
args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)};
- ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::SECOND, ty);
args = {time32(TimeUnit::MILLI), duration(TimeUnit::SECOND)};
- ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MILLI, ty);
args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)};
- ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::NANO, ty);
args = {duration(TimeUnit::SECOND), int64()};
- ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::SECOND, ty);
args = {duration(TimeUnit::MILLI), timestamp(TimeUnit::SECOND, tz)};
- ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ASSERT_EQ(TimeUnit::MILLI, ty);
}
TEST(TestDispatchBest, ReplaceTemporalTypes) {
@@ -216,67 +245,68 @@ TEST(TestDispatchBest, ReplaceTemporalTypes) {
TimeUnit::type ty;
args = {date32(), date32()};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND));
args = {date64(), time32(TimeUnit::SECOND)};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI));
AssertTypeEqual(args[1].type, time32(TimeUnit::MILLI));
args = {duration(TimeUnit::SECOND), date64()};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, duration(TimeUnit::MILLI));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI));
args = {timestamp(TimeUnit::MICRO, tz), timestamp(TimeUnit::NANO)};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::NANO));
args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz));
AssertTypeEqual(args[1].type, time64(TimeUnit::NANO));
args = {timestamp(TimeUnit::SECOND, tz), date64()};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI, tz));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI));
args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND, "UTC"));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND, tz));
args = {time32(TimeUnit::SECOND), duration(TimeUnit::SECOND)};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time32(TimeUnit::SECOND));
AssertTypeEqual(args[1].type, duration(TimeUnit::SECOND));
args = {time64(TimeUnit::MICRO), duration(TimeUnit::SECOND)};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time64(TimeUnit::MICRO));
AssertTypeEqual(args[1].type, duration(TimeUnit::MICRO));
args = {time32(TimeUnit::SECOND), duration(TimeUnit::NANO)};
- ty = CommonTemporalResolution(args.data(), args.size());
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time64(TimeUnit::NANO));
AssertTypeEqual(args[1].type, duration(TimeUnit::NANO));
args = {duration(TimeUnit::SECOND), int64()};
- ReplaceTemporalTypes(CommonTemporalResolution(args.data(), args.size()), &args);
+ ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
+ ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, duration(TimeUnit::SECOND));
AssertTypeEqual(args[1].type, int64());
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index bfafb6fcc1..103f8e66c5 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -1818,11 +1818,11 @@ struct ArithmeticFunction : ScalarFunction {
// Only promote types for binary functions
if (values->size() == 2) {
ReplaceNullWithOtherType(values);
-
- if (auto type = CommonTemporalResolution(values->data(), values->size())) {
- ReplaceTemporalTypes(type, values);
- } else if (auto type = CommonNumeric(*values)) {
- ReplaceTypes(type, values);
+ TimeUnit::type finest_unit;
+ if (CommonTemporalResolution(values->data(), values->size(), &finest_unit)) {
+ ReplaceTemporalTypes(finest_unit, values);
+ } else if (auto numeric_type = CommonNumeric(*values)) {
+ ReplaceTypes(numeric_type, values);
}
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc b/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
index cc36c51036..d91bf032e5 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
@@ -153,27 +153,32 @@ void AddListCast(CastFunction* func) {
struct CastStruct {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const CastOptions& options = CastState::Get(ctx);
- const StructType& in_type = checked_cast<const StructType&>(*batch[0].type());
- const StructType& out_type = checked_cast<const StructType&>(*out->type());
- const auto in_field_count = in_type.num_fields();
-
- if (in_field_count != out_type.num_fields()) {
- return Status::TypeError("struct field sizes do not match: ", in_type.ToString(),
- " ", out_type.ToString());
- }
-
- for (int i = 0; i < in_field_count; ++i) {
- const auto in_field = in_type.field(i);
- const auto out_field = out_type.field(i);
- if (in_field->name() != out_field->name()) {
- return Status::TypeError("struct field names do not match: ", in_type.ToString(),
- " ", out_type.ToString());
+ const auto& in_type = checked_cast<const StructType&>(*batch[0].type());
+ const auto& out_type = checked_cast<const StructType&>(*out->type());
+ const int in_field_count = in_type.num_fields();
+ const int out_field_count = out_type.num_fields();
+
+ std::vector<int> fields_to_select(out_field_count, -1);
+
+ int out_field_index = 0;
+ for (int in_field_index = 0;
+ in_field_index < in_field_count && out_field_index < out_field_count;
+ ++in_field_index) {
+ const auto& in_field = in_type.field(in_field_index);
+ const auto& out_field = out_type.field(out_field_index);
+ if (in_field->name() == out_field->name()) {
+ if (in_field->nullable() && !out_field->nullable()) {
+ return Status::TypeError("cannot cast nullable field to non-nullable field: ",
+ in_type.ToString(), " ", out_type.ToString());
+ }
+ fields_to_select[out_field_index++] = in_field_index;
}
+ }
- if (in_field->nullable() && !out_field->nullable()) {
- return Status::TypeError("cannot cast nullable struct to non-nullable struct: ",
- in_type.ToString(), " ", out_type.ToString());
- }
+ if (out_field_index < out_field_count) {
+ return Status::TypeError(
+ "struct fields don't match or are in the wrong order: Input fields: ",
+ in_type.ToString(), " output fields: ", out_type.ToString());
}
if (out->kind() == Datum::SCALAR) {
@@ -182,9 +187,10 @@ struct CastStruct {
DCHECK(!out_scalar->is_valid);
if (in_scalar.is_valid) {
- for (int i = 0; i < in_field_count; i++) {
- auto values = in_scalar.value[i];
- auto target_type = out->type()->field(i)->type();
+ out_field_index = 0;
+ for (int field_index : fields_to_select) {
+ const auto& values = in_scalar.value[field_index];
+ const auto& target_type = out->type()->field(out_field_index++)->type();
ARROW_ASSIGN_OR_RAISE(Datum cast_values,
Cast(values, target_type, options, ctx->exec_context()));
DCHECK_EQ(Datum::SCALAR, cast_values.kind());
@@ -204,9 +210,11 @@ struct CastStruct {
in_array.offset, in_array.length));
}
- for (int i = 0; i < in_field_count; ++i) {
- auto values = in_array.child_data[i]->Slice(in_array.offset, in_array.length);
- auto target_type = out->type()->field(i)->type();
+ out_field_index = 0;
+ for (int field_index : fields_to_select) {
+ const auto& values =
+ in_array.child_data[field_index]->Slice(in_array.offset, in_array.length);
+ const auto& target_type = out->type()->field(out_field_index++)->type();
ARROW_ASSIGN_OR_RAISE(Datum cast_values,
Cast(values, target_type, options, ctx->exec_context()));
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
index 222b5ee88a..b800299658 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
@@ -2230,8 +2230,223 @@ static void CheckStructToStruct(
}
}
-TEST(Cast, StructToSameSizedAndNamedStruct) {
- CheckStructToStruct({int32(), float32(), int64()});
+static void CheckStructToStructSubset(
+ const std::vector<std::shared_ptr<DataType>>& value_types) {
+ for (const auto& src_value_type : value_types) {
+ ARROW_SCOPED_TRACE("From type: ", src_value_type->ToString());
+ for (const auto& dest_value_type : value_types) {
+ ARROW_SCOPED_TRACE("To type: ", dest_value_type->ToString());
+
+ std::vector<std::string> field_names = {"a", "b", "c", "d", "e"};
+
+ std::shared_ptr<Array> a1, b1, c1, d1, e1;
+ a1 = ArrayFromJSON(src_value_type, "[1, 2, 5]");
+ b1 = ArrayFromJSON(src_value_type, "[3, 4, 7]");
+ c1 = ArrayFromJSON(src_value_type, "[9, 11, 44]");
+ d1 = ArrayFromJSON(src_value_type, "[6, 51, 49]");
+ e1 = ArrayFromJSON(src_value_type, "[19, 17, 74]");
+
+ std::shared_ptr<Array> a2, b2, c2, d2, e2;
+ a2 = ArrayFromJSON(dest_value_type, "[1, 2, 5]");
+ b2 = ArrayFromJSON(dest_value_type, "[3, 4, 7]");
+ c2 = ArrayFromJSON(dest_value_type, "[9, 11, 44]");
+ d2 = ArrayFromJSON(dest_value_type, "[6, 51, 49]");
+ e2 = ArrayFromJSON(dest_value_type, "[19, 17, 74]");
+
+ ASSERT_OK_AND_ASSIGN(auto src,
+ StructArray::Make({a1, b1, c1, d1, e1}, field_names));
+ ASSERT_OK_AND_ASSIGN(auto dest1,
+ StructArray::Make({a2}, std::vector<std::string>{"a"}));
+ CheckCast(src, dest1);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest2, StructArray::Make({b2, c2}, std::vector<std::string>{"b", "c"}));
+ CheckCast(src, dest2);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest3,
+ StructArray::Make({c2, d2, e2}, std::vector<std::string>{"c", "d", "e"}));
+ CheckCast(src, dest3);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest4, StructArray::Make({a2, b2, c2, e2},
+ std::vector<std::string>{"a", "b", "c", "e"}));
+ CheckCast(src, dest4);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest5, StructArray::Make({a2, b2, c2, d2, e2}, {"a", "b", "c", "d", "e"}));
+ CheckCast(src, dest5);
+
+ // field does not exist
+ const auto dest6 = arrow::struct_({std::make_shared<Field>("a", int8()),
+ std::make_shared<Field>("d", int16()),
+ std::make_shared<Field>("f", int64())});
+ const auto options6 = CastOptions::Safe(dest6);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
+ Cast(src, options6));
+
+ // fields in wrong order
+ const auto dest7 = arrow::struct_({std::make_shared<Field>("a", int8()),
+ std::make_shared<Field>("c", int16()),
+ std::make_shared<Field>("b", int64())});
+ const auto options7 = CastOptions::Safe(dest7);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
+ Cast(src, options7));
+
+ // duplicate missing field names
+ const auto dest8 = arrow::struct_(
+ {std::make_shared<Field>("a", int8()), std::make_shared<Field>("c", int16()),
+ std::make_shared<Field>("d", int32()), std::make_shared<Field>("a", int64())});
+ const auto options8 = CastOptions::Safe(dest8);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
+ Cast(src, options8));
+
+ // duplicate present field names
+ ASSERT_OK_AND_ASSIGN(
+ auto src_duplicate_field_names,
+ StructArray::Make({a1, b1, c1}, std::vector<std::string>{"a", "a", "a"}));
+
+ ASSERT_OK_AND_ASSIGN(auto dest1_duplicate_field_names,
+ StructArray::Make({a2}, std::vector<std::string>{"a"}));
+ CheckCast(src_duplicate_field_names, dest1_duplicate_field_names);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest2_duplicate_field_names,
+ StructArray::Make({a2, b2}, std::vector<std::string>{"a", "a"}));
+ CheckCast(src_duplicate_field_names, dest2_duplicate_field_names);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest3_duplicate_field_names,
+ StructArray::Make({a2, b2, c2}, std::vector<std::string>{"a", "a", "a"}));
+ CheckCast(src_duplicate_field_names, dest3_duplicate_field_names);
+ }
+ }
+}
+
+static void CheckStructToStructSubsetWithNulls(
+ const std::vector<std::shared_ptr<DataType>>& value_types) {
+ for (const auto& src_value_type : value_types) {
+ ARROW_SCOPED_TRACE("From type: ", src_value_type->ToString());
+ for (const auto& dest_value_type : value_types) {
+ ARROW_SCOPED_TRACE("To type: ", dest_value_type->ToString());
+
+ std::vector<std::string> field_names = {"a", "b", "c", "d", "e"};
+
+ std::shared_ptr<Array> a1, b1, c1, d1, e1;
+ a1 = ArrayFromJSON(src_value_type, "[1, 2, 5]");
+ b1 = ArrayFromJSON(src_value_type, "[3, null, 7]");
+ c1 = ArrayFromJSON(src_value_type, "[9, 11, 44]");
+ d1 = ArrayFromJSON(src_value_type, "[6, 51, null]");
+ e1 = ArrayFromJSON(src_value_type, "[null, 17, 74]");
+
+ std::shared_ptr<Array> a2, b2, c2, d2, e2;
+ a2 = ArrayFromJSON(dest_value_type, "[1, 2, 5]");
+ b2 = ArrayFromJSON(dest_value_type, "[3, null, 7]");
+ c2 = ArrayFromJSON(dest_value_type, "[9, 11, 44]");
+ d2 = ArrayFromJSON(dest_value_type, "[6, 51, null]");
+ e2 = ArrayFromJSON(dest_value_type, "[null, 17, 74]");
+
+ std::shared_ptr<Buffer> null_bitmap;
+ BitmapFromVector<int>({0, 1, 0}, &null_bitmap);
+
+ ASSERT_OK_AND_ASSIGN(auto src_null, StructArray::Make({a1, b1, c1, d1, e1},
+ field_names, null_bitmap));
+ ASSERT_OK_AND_ASSIGN(
+ auto dest1_null,
+ StructArray::Make({a2}, std::vector<std::string>{"a"}, null_bitmap));
+ CheckCast(src_null, dest1_null);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest2_null,
+ StructArray::Make({b2, c2}, std::vector<std::string>{"b", "c"}, null_bitmap));
+ CheckCast(src_null, dest2_null);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest3_null,
+ StructArray::Make({a2, d2, e2}, std::vector<std::string>{"a", "d", "e"},
+ null_bitmap));
+ CheckCast(src_null, dest3_null);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest4_null,
+ StructArray::Make({a2, b2, c2, e2},
+ std::vector<std::string>{"a", "b", "c", "e"}, null_bitmap));
+ CheckCast(src_null, dest4_null);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest5_null,
+ StructArray::Make({a2, b2, c2, d2, e2},
+ std::vector<std::string>{"a", "b", "c", "d", "e"},
+ null_bitmap));
+ CheckCast(src_null, dest5_null);
+
+ // field does not exist
+ const auto dest6_null = arrow::struct_({std::make_shared<Field>("a", int8()),
+ std::make_shared<Field>("d", int16()),
+ std::make_shared<Field>("f", int64())});
+ const auto options6_null = CastOptions::Safe(dest6_null);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
+ Cast(src_null, options6_null));
+
+ // fields in wrong order
+ const auto dest7_null = arrow::struct_({std::make_shared<Field>("a", int8()),
+ std::make_shared<Field>("c", int16()),
+ std::make_shared<Field>("b", int64())});
+ const auto options7_null = CastOptions::Safe(dest7_null);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
+ Cast(src_null, options7_null));
+
+ // duplicate missing field names
+ const auto dest8_null = arrow::struct_(
+ {std::make_shared<Field>("a", int8()), std::make_shared<Field>("c", int16()),
+ std::make_shared<Field>("d", int32()), std::make_shared<Field>("a", int64())});
+ const auto options8_null = CastOptions::Safe(dest8_null);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
+ Cast(src_null, options8_null));
+
+      // duplicate present field names
+ ASSERT_OK_AND_ASSIGN(
+ auto src_duplicate_field_names_null,
+ StructArray::Make({a1, b1, c1}, std::vector<std::string>{"a", "a", "a"},
+ null_bitmap));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest1_duplicate_field_names_null,
+ StructArray::Make({a2}, std::vector<std::string>{"a"}, null_bitmap));
+ CheckCast(src_duplicate_field_names_null, dest1_duplicate_field_names_null);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest2_duplicate_field_names_null,
+ StructArray::Make({a2, b2}, std::vector<std::string>{"a", "a"}, null_bitmap));
+ CheckCast(src_duplicate_field_names_null, dest2_duplicate_field_names_null);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto dest3_duplicate_field_names_null,
+ StructArray::Make({a2, b2, c2}, std::vector<std::string>{"a", "a", "a"},
+ null_bitmap));
+ CheckCast(src_duplicate_field_names_null, dest3_duplicate_field_names_null);
+ }
+ }
+}
+
+TEST(Cast, StructToSameSizedAndNamedStruct) { CheckStructToStruct(NumericTypes()); }
+
+TEST(Cast, StructToStructSubset) { CheckStructToStructSubset(NumericTypes()); }
+
+TEST(Cast, StructToStructSubsetWithNulls) {
+ CheckStructToStructSubsetWithNulls(NumericTypes());
}
TEST(Cast, StructToSameSizedButDifferentNamedStruct) {
@@ -2247,12 +2462,11 @@ TEST(Cast, StructToSameSizedButDifferentNamedStruct) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("Type error: struct field names do not match: struct<a: int8, "
- "b: int8> struct<c: int8, d: int8>"),
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src, options));
}
-TEST(Cast, StructToDifferentSizeStruct) {
+TEST(Cast, StructToBiggerStruct) {
std::vector<std::string> field_names = {"a", "b"};
std::shared_ptr<Array> a, b;
a = ArrayFromJSON(int8(), "[1, 2]");
@@ -2266,52 +2480,100 @@ TEST(Cast, StructToDifferentSizeStruct) {
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
- ::testing::HasSubstr("Type error: struct field sizes do not match: struct<a: int8, "
- "b: int8> struct<a: int8, b: int8, c: int8>"),
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
Cast(src, options));
}
-TEST(Cast, StructToSameSizedButDifferentNullabilityStruct) {
- // OK to go from not-nullable to nullable...
- std::vector<std::shared_ptr<Field>> fields1 = {
- std::make_shared<Field>("a", int8(), false),
- std::make_shared<Field>("b", int8(), false)};
- std::shared_ptr<Array> a1, b1;
- a1 = ArrayFromJSON(int8(), "[1, 2]");
- b1 = ArrayFromJSON(int8(), "[3, 4]");
- ASSERT_OK_AND_ASSIGN(auto src1, StructArray::Make({a1, b1}, fields1));
-
- std::vector<std::shared_ptr<Field>> fields2 = {
- std::make_shared<Field>("a", int8(), true),
- std::make_shared<Field>("b", int8(), true)};
- std::shared_ptr<Array> a2, b2;
- a2 = ArrayFromJSON(int8(), "[1, 2]");
- b2 = ArrayFromJSON(int8(), "[3, 4]");
- ASSERT_OK_AND_ASSIGN(auto dest1, StructArray::Make({a2, b2}, fields2));
-
- CheckCast(src1, dest1);
-
- // But not the other way around
- std::vector<std::shared_ptr<Field>> fields3 = {
- std::make_shared<Field>("a", int8(), true),
- std::make_shared<Field>("b", int8(), true)};
- std::shared_ptr<Array> a3, b3;
- a3 = ArrayFromJSON(int8(), "[1, null]");
- b3 = ArrayFromJSON(int8(), "[3, 4]");
- ASSERT_OK_AND_ASSIGN(auto src2, StructArray::Make({a3, b3}, fields3));
-
- std::vector<std::shared_ptr<Field>> fields4 = {
- std::make_shared<Field>("a", int8(), false),
- std::make_shared<Field>("b", int8(), false)};
- const auto dest2 = arrow::struct_(fields4);
- const auto options = CastOptions::Safe(dest2);
-
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- TypeError,
- ::testing::HasSubstr(
- "Type error: cannot cast nullable struct to non-nullable "
- "struct: struct<a: int8, b: int8> struct<a: int8 not null, b: int8 not null>"),
- Cast(src2, options));
+TEST(Cast, StructToDifferentNullabilityStruct) {
+ {
+ // OK to go from non-nullable to nullable...
+ std::vector<std::shared_ptr<Field>> fields_src_non_nullable = {
+ std::make_shared<Field>("a", int8(), false),
+ std::make_shared<Field>("b", int8(), false),
+ std::make_shared<Field>("c", int8(), false)};
+ std::shared_ptr<Array> a_src_non_nullable, b_src_non_nullable, c_src_non_nullable;
+ a_src_non_nullable = ArrayFromJSON(int8(), "[11, 23, 56]");
+ b_src_non_nullable = ArrayFromJSON(int8(), "[32, 46, 37]");
+ c_src_non_nullable = ArrayFromJSON(int8(), "[95, 11, 44]");
+ ASSERT_OK_AND_ASSIGN(
+ auto src_non_nullable,
+ StructArray::Make({a_src_non_nullable, b_src_non_nullable, c_src_non_nullable},
+ fields_src_non_nullable));
+
+ std::shared_ptr<Array> a_dest_nullable, b_dest_nullable, c_dest_nullable;
+ a_dest_nullable = ArrayFromJSON(int64(), "[11, 23, 56]");
+ b_dest_nullable = ArrayFromJSON(int64(), "[32, 46, 37]");
+ c_dest_nullable = ArrayFromJSON(int64(), "[95, 11, 44]");
+
+ std::vector<std::shared_ptr<Field>> fields_dest1_nullable = {
+ std::make_shared<Field>("a", int64(), true),
+ std::make_shared<Field>("b", int64(), true),
+ std::make_shared<Field>("c", int64(), true)};
+ ASSERT_OK_AND_ASSIGN(
+ auto dest1_nullable,
+ StructArray::Make({a_dest_nullable, b_dest_nullable, c_dest_nullable},
+ fields_dest1_nullable));
+ CheckCast(src_non_nullable, dest1_nullable);
+
+ std::vector<std::shared_ptr<Field>> fields_dest2_nullable = {
+ std::make_shared<Field>("a", int64(), true),
+ std::make_shared<Field>("c", int64(), true)};
+ ASSERT_OK_AND_ASSIGN(
+ auto dest2_nullable,
+ StructArray::Make({a_dest_nullable, c_dest_nullable}, fields_dest2_nullable));
+ CheckCast(src_non_nullable, dest2_nullable);
+
+ std::vector<std::shared_ptr<Field>> fields_dest3_nullable = {
+ std::make_shared<Field>("b", int64(), true)};
+ ASSERT_OK_AND_ASSIGN(auto dest3_nullable,
+ StructArray::Make({b_dest_nullable}, fields_dest3_nullable));
+ CheckCast(src_non_nullable, dest3_nullable);
+ }
+ {
+ // But NOT OK to go from nullable to non-nullable...
+ std::vector<std::shared_ptr<Field>> fields_src_nullable = {
+ std::make_shared<Field>("a", int8(), true),
+ std::make_shared<Field>("b", int8(), true),
+ std::make_shared<Field>("c", int8(), true)};
+ std::shared_ptr<Array> a_src_nullable, b_src_nullable, c_src_nullable;
+ a_src_nullable = ArrayFromJSON(int8(), "[1, null, 5]");
+ b_src_nullable = ArrayFromJSON(int8(), "[3, 4, null]");
+ c_src_nullable = ArrayFromJSON(int8(), "[9, 11, 44]");
+ ASSERT_OK_AND_ASSIGN(
+ auto src_nullable,
+ StructArray::Make({a_src_nullable, b_src_nullable, c_src_nullable},
+ fields_src_nullable));
+
+ std::vector<std::shared_ptr<Field>> fields_dest1_non_nullable = {
+ std::make_shared<Field>("a", int64(), false),
+ std::make_shared<Field>("b", int64(), false),
+ std::make_shared<Field>("c", int64(), false)};
+ const auto dest1_non_nullable = arrow::struct_(fields_dest1_non_nullable);
+ const auto options1_non_nullable = CastOptions::Safe(dest1_non_nullable);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("cannot cast nullable field to non-nullable field"),
+ Cast(src_nullable, options1_non_nullable));
+
+ std::vector<std::shared_ptr<Field>> fields_dest2_non_nullble = {
+ std::make_shared<Field>("a", int64(), false),
+ std::make_shared<Field>("c", int64(), false)};
+ const auto dest2_non_nullable = arrow::struct_(fields_dest2_non_nullble);
+ const auto options2_non_nullable = CastOptions::Safe(dest2_non_nullable);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("cannot cast nullable field to non-nullable field"),
+ Cast(src_nullable, options2_non_nullable));
+
+ std::vector<std::shared_ptr<Field>> fields_dest3_non_nullble = {
+ std::make_shared<Field>("c", int64(), false)};
+ const auto dest3_non_nullable = arrow::struct_(fields_dest3_non_nullble);
+ const auto options3_non_nullable = CastOptions::Safe(dest3_non_nullable);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("cannot cast nullable field to non-nullable field"),
+ Cast(src_nullable, options3_non_nullable));
+ }
}
TEST(Cast, IdentityCasts) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
index f2915951cb..a52d69c36c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
@@ -1234,6 +1234,84 @@ TEST_F(ScalarTemporalTest, TestTemporalSubtractDate) {
}
}
+TEST_F(ScalarTemporalTest, TestTemporalSubtractTimestampAndDate) {
+ std::string seconds_between_date_and_time =
+ "[59, 84203, 3560, 12800, 3905, 7810, 11715, 15620, "
+ "19525, 23430, 27335, 31240, 35145, 0, 0, 3723, null]";
+ std::string milliseconds_between_date_and_time =
+ "[59000, 84203000, 3560000, 12800000, 3905000, 7810000, 11715000, 15620000, "
+ "19525000, 23430000, 27335000, 31240000, 35145000, 0, 0, 3723000, null]";
+ std::string microseconds_between_date_and_time =
+ "[59000000, 84203000000, 3560000000, 12800000000, 3905000000, 7810000000, "
+ "11715000000, 15620000000, 19525000000, 23430000000, 27335000000, 31240000000, "
+ "35145000000, 0, 0, 3723000000, null]";
+ std::string nanoseconds_between_date_and_time =
+ "[59000000000, 84203000000000, 3560000000000, 12800000000000, 3905000000000, "
+ "7810000000000, 11715000000000, 15620000000000, 19525000000000, 23430000000000, "
+ "27335000000000, 31240000000000, 35145000000000, 0, 0, 3723000000000, null]";
+ std::string seconds_between_date_and_time2 =
+ "[-59, -84203, -3560, -12800, -3905, -7810, -11715, -15620, "
+ "-19525, -23430, -27335, -31240, -35145, 0, 0, -3723, null]";
+ std::string milliseconds_between_date_and_time2 =
+ "[-59000, -84203000, -3560000, -12800000, -3905000, -7810000, -11715000,"
+ "-15620000, -19525000, -23430000, -27335000, -31240000, -35145000, 0, 0, "
+ "-3723000, null]";
+ std::string microseconds_between_date_and_time2 =
+ "[-59000000, -84203000000, -3560000000, -12800000000, -3905000000, -7810000000, "
+ "-11715000000, -15620000000, -19525000000, -23430000000, -27335000000,"
+ "-31240000000, -35145000000, 0, 0, -3723000000, null]";
+ std::string nanoseconds_between_date_and_time2 =
+ "[-59000000000, -84203000000000, -3560000000000, -12800000000000, "
+ "-3905000000000, -7810000000000, -11715000000000, -15620000000000, "
+ "-19525000000000, -23430000000000, -27335000000000, -31240000000000, "
+ "-35145000000000, 0, 0, -3723000000000, null]";
+
+ auto arr_date32s = ArrayFromJSON(date32(), date32s);
+ auto arr_date32s2 = ArrayFromJSON(date32(), date32s2);
+ auto arr_date64s = ArrayFromJSON(date64(), date64s);
+ auto arr_date64s2 = ArrayFromJSON(date64(), date64s2);
+ auto timestamp_s = ArrayFromJSON(timestamp(TimeUnit::SECOND), times_seconds_precision);
+ auto timestamp_ms = ArrayFromJSON(timestamp(TimeUnit::MILLI), times_seconds_precision);
+ auto timestamp_us = ArrayFromJSON(timestamp(TimeUnit::MICRO), times_seconds_precision);
+ auto timestamp_ns = ArrayFromJSON(timestamp(TimeUnit::NANO), times_seconds_precision);
+ auto between_s =
+ ArrayFromJSON(duration(TimeUnit::SECOND), seconds_between_date_and_time);
+ auto between_ms =
+ ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_date_and_time);
+ auto between_us =
+ ArrayFromJSON(duration(TimeUnit::MICRO), microseconds_between_date_and_time);
+ auto between_ns =
+ ArrayFromJSON(duration(TimeUnit::NANO), nanoseconds_between_date_and_time);
+ auto between_s2 =
+ ArrayFromJSON(duration(TimeUnit::SECOND), seconds_between_date_and_time2);
+ auto between_ms2 =
+ ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_date_and_time2);
+ auto between_us2 =
+ ArrayFromJSON(duration(TimeUnit::MICRO), microseconds_between_date_and_time2);
+ auto between_ns2 =
+ ArrayFromJSON(duration(TimeUnit::NANO), nanoseconds_between_date_and_time2);
+
+ for (auto op : {"subtract", "subtract_checked"}) {
+ CheckScalarBinary(op, timestamp_s, arr_date32s, between_s);
+ CheckScalarBinary(op, timestamp_ms, arr_date32s, between_ms);
+ CheckScalarBinary(op, timestamp_us, arr_date32s, between_us);
+ CheckScalarBinary(op, timestamp_ns, arr_date32s, between_ns);
+ CheckScalarBinary(op, timestamp_s, arr_date64s, between_ms);
+ CheckScalarBinary(op, timestamp_ms, arr_date64s, between_ms);
+ CheckScalarBinary(op, timestamp_us, arr_date64s, between_us);
+ CheckScalarBinary(op, timestamp_ns, arr_date64s, between_ns);
+
+ CheckScalarBinary(op, arr_date32s, timestamp_s, between_s2);
+ CheckScalarBinary(op, arr_date32s, timestamp_ms, between_ms2);
+ CheckScalarBinary(op, arr_date32s, timestamp_us, between_us2);
+ CheckScalarBinary(op, arr_date32s, timestamp_ns, between_ns2);
+ CheckScalarBinary(op, arr_date64s, timestamp_s, between_ms2);
+ CheckScalarBinary(op, arr_date64s, timestamp_ms, between_ms2);
+ CheckScalarBinary(op, arr_date64s, timestamp_us, between_us2);
+ CheckScalarBinary(op, arr_date64s, timestamp_ns, between_ns2);
+ }
+}
+
TEST_F(ScalarTemporalTest, TestTemporalSubtractTimestamp) {
for (auto op : {"subtract", "subtract_checked"}) {
for (auto tz : {"", "UTC", "Pacific/Marquesas"}) {
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index 4d5e8e59dc..4c66053e5b 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -271,22 +271,31 @@ namespace {
class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer {
public:
- DatasetWritingSinkNodeConsumer(std::shared_ptr<Schema> schema,
+ DatasetWritingSinkNodeConsumer(std::shared_ptr<const KeyValueMetadata> custom_metadata,
std::unique_ptr<internal::DatasetWriter> dataset_writer,
FileSystemDatasetWriteOptions write_options,
std::shared_ptr<util::AsyncToggle> backpressure_toggle)
- : schema_(std::move(schema)),
+ : custom_metadata_(std::move(custom_metadata)),
dataset_writer_(std::move(dataset_writer)),
write_options_(std::move(write_options)),
backpressure_toggle_(std::move(backpressure_toggle)) {}
- Status Consume(compute::ExecBatch batch) {
+ Status Init(const std::shared_ptr<Schema>& schema) override {
+ if (custom_metadata_) {
+ schema_ = schema->WithMetadata(custom_metadata_);
+ } else {
+ schema_ = schema;
+ }
+ return Status::OK();
+ }
+
+ Status Consume(compute::ExecBatch batch) override {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch,
batch.ToRecordBatch(schema_));
return WriteNextBatch(std::move(record_batch), batch.guarantee);
}
- Future<> Finish() {
+ Future<> Finish() override {
RETURN_NOT_OK(task_group_.AddTask([this] { return dataset_writer_->Finish(); }));
return task_group_.End();
}
@@ -327,11 +336,12 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer {
return Status::OK();
}
- std::shared_ptr<Schema> schema_;
+ std::shared_ptr<const KeyValueMetadata> custom_metadata_;
std::unique_ptr<internal::DatasetWriter> dataset_writer_;
FileSystemDatasetWriteOptions write_options_;
std::shared_ptr<util::AsyncToggle> backpressure_toggle_;
util::SerializedAsyncTaskGroup task_group_;
+ std::shared_ptr<Schema> schema_ = nullptr;
};
} // namespace
@@ -354,6 +364,10 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio
std::shared_ptr<util::AsyncToggle> backpressure_toggle =
std::make_shared<util::AsyncToggle>();
+ // The projected_schema is currently used by pyarrow to preserve the custom metadata
+ // when reading from a single input file.
+ const auto& custom_metadata = scanner->options()->projected_schema->metadata();
+
RETURN_NOT_OK(
compute::Declaration::Sequence(
{
@@ -362,8 +376,7 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio
{"project",
compute::ProjectNodeOptions{std::move(exprs), std::move(names)}},
{"write",
- WriteNodeOptions{write_options, scanner->options()->projected_schema,
- backpressure_toggle}},
+ WriteNodeOptions{write_options, custom_metadata, backpressure_toggle}},
})
.AddToPlan(plan.get()));
@@ -381,8 +394,9 @@ Result<compute::ExecNode*> MakeWriteNode(compute::ExecPlan* plan,
const WriteNodeOptions write_node_options =
checked_cast<const WriteNodeOptions&>(options);
+ const std::shared_ptr<const KeyValueMetadata>& custom_metadata =
+ write_node_options.custom_metadata;
const FileSystemDatasetWriteOptions& write_options = write_node_options.write_options;
- const std::shared_ptr<Schema>& schema = write_node_options.schema;
const std::shared_ptr<util::AsyncToggle>& backpressure_toggle =
write_node_options.backpressure_toggle;
@@ -391,7 +405,7 @@ Result<compute::ExecNode*> MakeWriteNode(compute::ExecPlan* plan,
std::shared_ptr<DatasetWritingSinkNodeConsumer> consumer =
std::make_shared<DatasetWritingSinkNodeConsumer>(
- schema, std::move(dataset_writer), write_options, backpressure_toggle);
+ custom_metadata, std::move(dataset_writer), write_options, backpressure_toggle);
ARROW_ASSIGN_OR_RAISE(
auto node,
diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h
index 07b156778f..ca8b7e6450 100644
--- a/cpp/src/arrow/dataset/file_base.h
+++ b/cpp/src/arrow/dataset/file_base.h
@@ -408,14 +408,18 @@ struct ARROW_DS_EXPORT FileSystemDatasetWriteOptions {
class ARROW_DS_EXPORT WriteNodeOptions : public compute::ExecNodeOptions {
public:
explicit WriteNodeOptions(
- FileSystemDatasetWriteOptions options, std::shared_ptr<Schema> schema,
+ FileSystemDatasetWriteOptions options,
+ std::shared_ptr<const KeyValueMetadata> custom_metadata = NULLPTR,
std::shared_ptr<util::AsyncToggle> backpressure_toggle = NULLPTR)
: write_options(std::move(options)),
- schema(std::move(schema)),
+ custom_metadata(std::move(custom_metadata)),
backpressure_toggle(std::move(backpressure_toggle)) {}
+ /// \brief Options to control how to write the dataset
FileSystemDatasetWriteOptions write_options;
- std::shared_ptr<Schema> schema;
+ /// \brief Optional metadata to attach to written batches
+ std::shared_ptr<const KeyValueMetadata> custom_metadata;
+ /// \brief Optional toggle that can be used to pause producers when the node is full
std::shared_ptr<util::AsyncToggle> backpressure_toggle;
};
diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc
index 1cc7957083..277bab29a0 100644
--- a/cpp/src/arrow/dataset/file_csv.cc
+++ b/cpp/src/arrow/dataset/file_csv.cc
@@ -39,6 +39,7 @@
#include "arrow/util/async_generator.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
+#include "arrow/util/tracing_internal.h"
#include "arrow/util/utf8.h"
namespace arrow {
@@ -167,9 +168,14 @@ static inline Result<csv::ReadOptions> GetReadOptions(
static inline Future<std::shared_ptr<csv::StreamingReader>> OpenReaderAsync(
const FileSource& source, const CsvFileFormat& format,
const std::shared_ptr<ScanOptions>& scan_options, Executor* cpu_executor) {
+#ifdef ARROW_WITH_OPENTELEMETRY
+ auto tracer = arrow::internal::tracing::GetTracer();
+ auto span = tracer->StartSpan("arrow::dataset::CsvFileFormat::OpenReaderAsync");
+#endif
ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options));
ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed());
+ const auto& path = source.path();
ARROW_ASSIGN_OR_RAISE(
input, io::BufferedInputStream::Create(reader_options.block_size,
default_memory_pool(), std::move(input)));
@@ -190,11 +196,20 @@ static inline Future<std::shared_ptr<csv::StreamingReader>> OpenReaderAsync(
}));
return reader_fut.Then(
// Adds the filename to the error
- [](const std::shared_ptr<csv::StreamingReader>& reader)
- -> Result<std::shared_ptr<csv::StreamingReader>> { return reader; },
- [source](const Status& err) -> Result<std::shared_ptr<csv::StreamingReader>> {
- return err.WithMessage("Could not open CSV input source '", source.path(),
- "': ", err);
+ [=](const std::shared_ptr<csv::StreamingReader>& reader)
+ -> Result<std::shared_ptr<csv::StreamingReader>> {
+#ifdef ARROW_WITH_OPENTELEMETRY
+ span->SetStatus(opentelemetry::trace::StatusCode::kOk);
+ span->End();
+#endif
+ return reader;
+ },
+ [=](const Status& err) -> Result<std::shared_ptr<csv::StreamingReader>> {
+#ifdef ARROW_WITH_OPENTELEMETRY
+ arrow::internal::tracing::MarkSpan(err, span.get());
+ span->End();
+#endif
+ return err.WithMessage("Could not open CSV input source '", path, "': ", err);
});
}
@@ -250,7 +265,10 @@ Result<RecordBatchGenerator> CsvFileFormat::ScanBatchesAsync(
auto source = file->source();
auto reader_fut =
OpenReaderAsync(source, *this, scan_options, ::arrow::internal::GetCpuThreadPool());
- return GeneratorFromReader(std::move(reader_fut), scan_options->batch_size);
+ auto generator = GeneratorFromReader(std::move(reader_fut), scan_options->batch_size);
+ WRAP_ASYNC_GENERATOR_WITH_CHILD_SPAN(
+ generator, "arrow::dataset::CsvFileFormat::ScanBatchesAsync::Next");
+ return generator;
}
Future<util::optional<int64_t>> CsvFileFormat::CountRows(
diff --git a/cpp/src/arrow/dataset/file_ipc.cc b/cpp/src/arrow/dataset/file_ipc.cc
index e386c7dea8..7c45a5d705 100644
--- a/cpp/src/arrow/dataset/file_ipc.cc
+++ b/cpp/src/arrow/dataset/file_ipc.cc
@@ -30,6 +30,7 @@
#include "arrow/util/checked_cast.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
+#include "arrow/util/tracing_internal.h"
namespace arrow {
@@ -62,16 +63,31 @@ static inline Result<std::shared_ptr<ipc::RecordBatchFileReader>> OpenReader(
static inline Future<std::shared_ptr<ipc::RecordBatchFileReader>> OpenReaderAsync(
const FileSource& source,
const ipc::IpcReadOptions& options = default_read_options()) {
+#ifdef ARROW_WITH_OPENTELEMETRY
+ auto tracer = arrow::internal::tracing::GetTracer();
+ auto span = tracer->StartSpan("arrow::dataset::IpcFileFormat::OpenReaderAsync");
+#endif
ARROW_ASSIGN_OR_RAISE(auto input, source.Open());
auto path = source.path();
return ipc::RecordBatchFileReader::OpenAsync(std::move(input), options)
- .Then([](const std::shared_ptr<ipc::RecordBatchFileReader>& reader)
- -> Result<std::shared_ptr<ipc::RecordBatchFileReader>> { return reader; },
- [path](const Status& status)
- -> Result<std::shared_ptr<ipc::RecordBatchFileReader>> {
- return status.WithMessage("Could not open IPC input source '", path,
- "': ", status.message());
- });
+ .Then(
+ [=](const std::shared_ptr<ipc::RecordBatchFileReader>& reader)
+ -> Result<std::shared_ptr<ipc::RecordBatchFileReader>> {
+#ifdef ARROW_WITH_OPENTELEMETRY
+ span->SetStatus(opentelemetry::trace::StatusCode::kOk);
+ span->End();
+#endif
+ return reader;
+ },
+ [=](const Status& status)
+ -> Result<std::shared_ptr<ipc::RecordBatchFileReader>> {
+#ifdef ARROW_WITH_OPENTELEMETRY
+ arrow::internal::tracing::MarkSpan(status, span.get());
+ span->End();
+#endif
+ return status.WithMessage("Could not open IPC input source '", path,
+ "': ", status.message());
+ });
}
static inline Result<std::vector<int>> GetIncludedFields(
@@ -151,6 +167,8 @@ Result<RecordBatchGenerator> IpcFileFormat::ScanBatchesAsync(
ARROW_ASSIGN_OR_RAISE(generator, reader->GetRecordBatchGenerator(
/*coalesce=*/false, options->io_context));
}
+ WRAP_ASYNC_GENERATOR_WITH_CHILD_SPAN(
+ generator, "arrow::dataset::IpcFileFormat::ScanBatchesAsync::Next");
auto batch_generator = MakeReadaheadGenerator(std::move(generator), readahead_level);
return MakeChunkedBatchGenerator(std::move(batch_generator), options->batch_size);
};
diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc
index 5fbd457ecc..4a8d409312 100644
--- a/cpp/src/arrow/dataset/file_parquet.cc
+++ b/cpp/src/arrow/dataset/file_parquet.cc
@@ -34,6 +34,7 @@
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
#include "arrow/util/range.h"
+#include "arrow/util/tracing_internal.h"
#include "parquet/arrow/reader.h"
#include "parquet/arrow/schema.h"
#include "parquet/arrow/writer.h"
@@ -415,8 +416,11 @@ Result<RecordBatchGenerator> ParquetFileFormat::ScanBatchesAsync(
::arrow::internal::GetCpuThreadPool(), row_group_readahead));
return generator;
};
- return MakeFromFuture(GetReaderAsync(parquet_fragment->source(), options)
- .Then(std::move(make_generator)));
+ auto generator = MakeFromFuture(GetReaderAsync(parquet_fragment->source(), options)
+ .Then(std::move(make_generator)));
+ WRAP_ASYNC_GENERATOR_WITH_CHILD_SPAN(
+ generator, "arrow::dataset::ParquetFileFormat::ScanBatchesAsync::Next");
+ return generator;
}
Future<util::optional<int64_t>> ParquetFileFormat::CountRows(
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index b958f7b9e6..537aa9609b 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -35,10 +35,12 @@
#include "arrow/dataset/plan.h"
#include "arrow/table.h"
#include "arrow/util/async_generator.h"
+#include "arrow/util/config.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
#include "arrow/util/task_group.h"
#include "arrow/util/thread_pool.h"
+#include "arrow/util/tracing_internal.h"
namespace arrow {
@@ -91,6 +93,7 @@ const FieldVector kAugmentedFields{
field("__fragment_index", int32()),
field("__batch_index", int32()),
field("__last_in_fragment", boolean()),
+ field("__filename", utf8()),
};
// Scan options has a number of options that we can infer from the dataset
@@ -206,6 +209,18 @@ class AsyncScanner : public Scanner, public std::enable_shared_from_this<AsyncSc
Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
const Enumerated<std::shared_ptr<Fragment>>& fragment,
const std::shared_ptr<ScanOptions>& options) {
+#ifdef ARROW_WITH_OPENTELEMETRY
+ auto tracer = arrow::internal::tracing::GetTracer();
+ auto span = tracer->StartSpan(
+ "arrow::dataset::FragmentToBatches",
+ {
+ {"arrow.dataset.fragment", fragment.value->ToString()},
+ {"arrow.dataset.fragment.index", fragment.index},
+ {"arrow.dataset.fragment.last", fragment.last},
+ {"arrow.dataset.fragment.type_name", fragment.value->type_name()},
+ });
+ auto scope = tracer->WithActiveSpan(span);
+#endif
ARROW_ASSIGN_OR_RAISE(auto batch_gen, fragment.value->ScanBatchesAsync(options));
ArrayVector columns;
for (const auto& field : options->dataset_schema->fields()) {
@@ -214,6 +229,7 @@ Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
MakeArrayOfNull(field->type(), /*length=*/0, options->pool));
columns.push_back(std::move(array));
}
+ WRAP_ASYNC_GENERATOR(batch_gen);
batch_gen = MakeDefaultIfEmptyGenerator(
std::move(batch_gen),
RecordBatch::Make(options->dataset_schema, /*num_rows=*/0, std::move(columns)));
@@ -230,10 +246,13 @@ Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
Result<AsyncGenerator<EnumeratedRecordBatchGenerator>> FragmentsToBatches(
FragmentGenerator fragment_gen, const std::shared_ptr<ScanOptions>& options) {
auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen));
- return MakeMappedGenerator(std::move(enumerated_fragment_gen),
- [=](const Enumerated<std::shared_ptr<Fragment>>& fragment) {
- return FragmentToBatches(fragment, options);
- });
+ auto batch_gen_gen =
+ MakeMappedGenerator(std::move(enumerated_fragment_gen),
+ [=](const Enumerated<std::shared_ptr<Fragment>>& fragment) {
+ return FragmentToBatches(fragment, options);
+ });
+ PROPAGATE_SPAN_TO_GENERATOR(std::move(batch_gen_gen));
+ return batch_gen_gen;
}
class OneShotFragment : public Fragment {
@@ -708,8 +727,12 @@ Result<ProjectionDescr> ProjectionDescr::FromNames(std::vector<std::string> name
for (size_t i = 0; i < exprs.size(); ++i) {
exprs[i] = compute::field_ref(names[i]);
}
+ auto fields = dataset_schema.fields();
+ for (const auto& aug_field : kAugmentedFields) {
+ fields.push_back(aug_field);
+ }
return ProjectionDescr::FromExpressions(std::move(exprs), std::move(names),
- dataset_schema);
+ Schema(fields, dataset_schema.metadata()));
}
Result<ProjectionDescr> ProjectionDescr::Default(const Schema& dataset_schema) {
@@ -877,6 +900,7 @@ Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan* plan,
batch->values.emplace_back(partial.fragment.index);
batch->values.emplace_back(partial.record_batch.index);
batch->values.emplace_back(partial.record_batch.last);
+ batch->values.emplace_back(partial.fragment.value->ToString());
return batch;
});
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 7fedb7d3c7..a3d1e6ac5e 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -128,6 +128,15 @@ class TestScanner : public DatasetFixtureMixinWithParam<TestScannerParams> {
AssertScanBatchesEquals(expected.get(), scanner.get());
}
+ void AssertNoAugmentedFields(std::shared_ptr<Scanner> scanner) {
+ ASSERT_OK_AND_ASSIGN(auto table, scanner.get()->ToTable());
+ auto columns = table.get()->ColumnNames();
+ EXPECT_TRUE(std::none_of(columns.begin(), columns.end(), [](std::string& x) {
+ return x == "__fragment_index" || x == "__batch_index" ||
+ x == "__last_in_fragment" || x == "__filename";
+ }));
+ }
+
void AssertScanBatchesUnorderedEqualRepetitionsOf(
std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
const int64_t total_batches = GetParam().num_child_datasets *
@@ -257,6 +266,7 @@ TEST_P(TestScanner, ProjectionDefaults) {
options_->projection = literal(true);
options_->projected_schema = nullptr;
AssertScanBatchesEqualRepetitionsOf(MakeScanner(batch_in), batch_in);
+ AssertNoAugmentedFields(MakeScanner(batch_in));
}
// If we only specify a projection expression then infer the projected schema
// from the projection expression
@@ -1386,6 +1396,7 @@ DatasetAndBatches DatasetAndBatchesFromJSON(
// ... and with the last-in-fragment flag
batches.back().values.emplace_back(batch_index ==
fragment_batch_strs[fragment_index].size() - 1);
+ batches.back().values.emplace_back(fragments[fragment_index]->ToString());
// each batch carries a guarantee inherited from its Fragment's partition expression
batches.back().guarantee = fragments[fragment_index]->partition_expression();
@@ -1472,7 +1483,8 @@ DatasetAndBatches MakeNestedDataset() {
compute::Expression Materialize(std::vector<std::string> names,
bool include_aug_fields = false) {
if (include_aug_fields) {
- for (auto aug_name : {"__fragment_index", "__batch_index", "__last_in_fragment"}) {
+ for (auto aug_name :
+ {"__fragment_index", "__batch_index", "__last_in_fragment", "__filename"}) {
names.emplace_back(aug_name);
}
}
@@ -1502,6 +1514,7 @@ TEST(ScanNode, Schema) {
fields.push_back(field("__fragment_index", int32()));
fields.push_back(field("__batch_index", int32()));
fields.push_back(field("__last_in_fragment", boolean()));
+ fields.push_back(field("__filename", utf8()));
// output_schema is *always* the full augmented dataset schema, regardless of
// projection (but some columns *may* be placeholder null Scalars if not projected)
AssertSchemaEqual(Schema(fields), *scan->output_schema());
@@ -1656,7 +1669,9 @@ TEST(ScanNode, MaterializationOfNestedVirtualColumn) {
// TODO(ARROW-1888): allow scanner to "patch up" structs with casts
EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(
- TypeError, ::testing::HasSubstr("struct field sizes do not match"), plan.Run());
+ TypeError,
+ ::testing::HasSubstr("struct fields don't match or are in the wrong order"),
+ plan.Run());
}
TEST(ScanNode, MinimalEndToEnd) {
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index fe43ab2879..80cdf59f49 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -164,7 +164,7 @@ Result<ExtensionSet> ExtensionSet::Make(std::vector<util::string_view> uris,
set.functions_[i] = {rec->id, rec->function_name};
continue;
}
- return Status::Invalid("Function ", function_ids[i].uri, "#", type_ids[i].name,
+ return Status::Invalid("Function ", function_ids[i].uri, "#", function_ids[i].name,
" not found");
}
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index 6af5d71521..300a6c528b 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -724,5 +724,31 @@ TEST(Substrait, ExtensionSetFromPlan) {
EXPECT_EQ(decoded_add_func.name, "add");
}
+TEST(Substrait, ExtensionSetFromPlanMissingFunc) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [],
+ "extension_uris": [
+ {
+ "extension_uri_anchor": 7,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }
+ ],
+ "extensions": [
+ {"extension_function": {
+ "extension_uri_reference": 7,
+ "function_anchor": 42,
+ "name": "does_not_exist"
+ }}
+ ]
+ })"));
+
+ ExtensionSet ext_set;
+ ASSERT_RAISES(
+ Invalid,
+ DeserializePlan(
+ *buf, [] { return std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+ &ext_set));
+}
+
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/filesystem/gcsfs.cc b/cpp/src/arrow/filesystem/gcsfs.cc
index 7a07dfe5d5..0b0bba6766 100644
--- a/cpp/src/arrow/filesystem/gcsfs.cc
+++ b/cpp/src/arrow/filesystem/gcsfs.cc
@@ -846,8 +846,7 @@ Result<std::shared_ptr<io::OutputStream>> GcsFileSystem::OpenAppendStream(
std::shared_ptr<GcsFileSystem> GcsFileSystem::Make(const GcsOptions& options,
const io::IOContext& context) {
// Cannot use `std::make_shared<>` as the constructor is private.
- return std::shared_ptr<GcsFileSystem>(
- new GcsFileSystem(options, io::default_io_context()));
+ return std::shared_ptr<GcsFileSystem>(new GcsFileSystem(options, context));
}
GcsFileSystem::GcsFileSystem(const GcsOptions& options, const io::IOContext& context)
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 7447e675e0..f9d135654b 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -313,4 +313,14 @@ if(ARROW_BUILD_BENCHMARKS)
add_dependencies(arrow-flight-benchmark arrow-flight-perf-server)
add_dependencies(arrow_flight arrow-flight-benchmark)
+
+ if(ARROW_WITH_UCX)
+ if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
+ target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_static)
+ target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_static)
+ else()
+ target_link_libraries(arrow-flight-benchmark arrow_flight_transport_ucx_shared)
+ target_link_libraries(arrow-flight-perf-server arrow_flight_transport_ucx_shared)
+ endif()
+ endif()
endif(ARROW_BUILD_BENCHMARKS)
diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc
index 160387b166..a5d6cd7c37 100644
--- a/cpp/src/arrow/flight/client.cc
+++ b/cpp/src/arrow/flight/client.cc
@@ -33,6 +33,7 @@
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
#include "arrow/flight/client_auth.h"
#include "arrow/flight/serialization_internal.h"
@@ -489,21 +490,32 @@ FlightClient::~FlightClient() {
}
}
+arrow::Result<std::unique_ptr<FlightClient>> FlightClient::Connect(
+ const Location& location) {
+ return Connect(location, FlightClientOptions::Defaults());
+}
+
Status FlightClient::Connect(const Location& location,
std::unique_ptr<FlightClient>* client) {
- return Connect(location, FlightClientOptions::Defaults(), client);
+ return Connect(location, FlightClientOptions::Defaults()).Value(client);
}
-Status FlightClient::Connect(const Location& location, const FlightClientOptions& options,
- std::unique_ptr<FlightClient>* client) {
+arrow::Result<std::unique_ptr<FlightClient>> FlightClient::Connect(
+ const Location& location, const FlightClientOptions& options) {
flight::transport::grpc::InitializeFlightGrpcClient();
- client->reset(new FlightClient);
- (*client)->write_size_limit_bytes_ = options.write_size_limit_bytes;
+ std::unique_ptr<FlightClient> client(new FlightClient());
+ client->write_size_limit_bytes_ = options.write_size_limit_bytes;
const auto scheme = location.scheme();
- ARROW_ASSIGN_OR_RAISE((*client)->transport_,
+ ARROW_ASSIGN_OR_RAISE(client->transport_,
internal::GetDefaultTransportRegistry()->MakeClient(scheme));
- return (*client)->transport_->Init(options, location, *location.uri_);
+ RETURN_NOT_OK(client->transport_->Init(options, location, *location.uri_));
+ return client;
+}
+
+Status FlightClient::Connect(const Location& location, const FlightClientOptions& options,
+ std::unique_ptr<FlightClient>* client) {
+ return Connect(location, options).Value(client);
}
Status FlightClient::Authenticate(const FlightCallOptions& options,
@@ -519,23 +531,44 @@ arrow::Result<std::pair<std::string, std::string>> FlightClient::AuthenticateBas
return transport_->AuthenticateBasicToken(options, username, password);
}
+arrow::Result<std::unique_ptr<ResultStream>> FlightClient::DoAction(
+ const FlightCallOptions& options, const Action& action) {
+ std::unique_ptr<ResultStream> results;
+ RETURN_NOT_OK(CheckOpen());
+ RETURN_NOT_OK(transport_->DoAction(options, action, &results));
+ return results;
+}
+
Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action,
std::unique_ptr<ResultStream>* results) {
+ return DoAction(options, action).Value(results);
+}
+
+arrow::Result<std::vector<ActionType>> FlightClient::ListActions(
+ const FlightCallOptions& options) {
+ std::vector<ActionType> actions;
RETURN_NOT_OK(CheckOpen());
- return transport_->DoAction(options, action, results);
+ RETURN_NOT_OK(transport_->ListActions(options, &actions));
+ return actions;
}
Status FlightClient::ListActions(const FlightCallOptions& options,
std::vector<ActionType>* actions) {
+ return ListActions(options).Value(actions);
+}
+
+arrow::Result<std::unique_ptr<FlightInfo>> FlightClient::GetFlightInfo(
+ const FlightCallOptions& options, const FlightDescriptor& descriptor) {
+ std::unique_ptr<FlightInfo> info;
RETURN_NOT_OK(CheckOpen());
- return transport_->ListActions(options, actions);
+ RETURN_NOT_OK(transport_->GetFlightInfo(options, descriptor, &info));
+ return info;
}
Status FlightClient::GetFlightInfo(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info) {
- RETURN_NOT_OK(CheckOpen());
- return transport_->GetFlightInfo(options, descriptor, info);
+ return GetFlightInfo(options, descriptor).Value(info);
}
arrow::Result<std::unique_ptr<SchemaResult>> FlightClient::GetSchema(
@@ -550,63 +583,99 @@ Status FlightClient::GetSchema(const FlightCallOptions& options,
return GetSchema(options, descriptor).Value(schema_result);
}
+arrow::Result<std::unique_ptr<FlightListing>> FlightClient::ListFlights() {
+ return ListFlights({}, {});
+}
+
Status FlightClient::ListFlights(std::unique_ptr<FlightListing>* listing) {
+ return ListFlights({}, {}).Value(listing);
+}
+
+arrow::Result<std::unique_ptr<FlightListing>> FlightClient::ListFlights(
+ const FlightCallOptions& options, const Criteria& criteria) {
+ std::unique_ptr<FlightListing> listing;
RETURN_NOT_OK(CheckOpen());
- return ListFlights({}, {}, listing);
+ RETURN_NOT_OK(transport_->ListFlights(options, criteria, &listing));
+ return listing;
}
Status FlightClient::ListFlights(const FlightCallOptions& options,
const Criteria& criteria,
std::unique_ptr<FlightListing>* listing) {
+ return ListFlights(options, criteria).Value(listing);
+}
+
+arrow::Result<std::unique_ptr<FlightStreamReader>> FlightClient::DoGet(
+ const FlightCallOptions& options, const Ticket& ticket) {
RETURN_NOT_OK(CheckOpen());
- return transport_->ListFlights(options, criteria, listing);
+ std::unique_ptr<internal::ClientDataStream> remote_stream;
+ RETURN_NOT_OK(transport_->DoGet(options, ticket, &remote_stream));
+ std::unique_ptr<FlightStreamReader> stream_reader =
+ arrow::internal::make_unique<ClientStreamReader>(
+ std::move(remote_stream), options.read_options, options.stop_token,
+ options.memory_manager);
+ // Eagerly read the schema
+ RETURN_NOT_OK(
+ static_cast<ClientStreamReader*>(stream_reader.get())->EnsureDataStarted());
+ return stream_reader;
}
Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket,
std::unique_ptr<FlightStreamReader>* stream) {
+ return DoGet(options, ticket).Value(stream);
+}
+
+arrow::Result<FlightClient::DoPutResult> FlightClient::DoPut(
+ const FlightCallOptions& options, const FlightDescriptor& descriptor,
+ const std::shared_ptr<Schema>& schema) {
RETURN_NOT_OK(CheckOpen());
std::unique_ptr<internal::ClientDataStream> remote_stream;
- RETURN_NOT_OK(transport_->DoGet(options, ticket, &remote_stream));
- *stream = std::unique_ptr<ClientStreamReader>(
- new ClientStreamReader(std::move(remote_stream), options.read_options,
- options.stop_token, options.memory_manager));
- // Eagerly read the schema
- return static_cast<ClientStreamReader*>(stream->get())->EnsureDataStarted();
+ RETURN_NOT_OK(transport_->DoPut(options, &remote_stream));
+ std::shared_ptr<internal::ClientDataStream> shared_stream = std::move(remote_stream);
+ DoPutResult result;
+ result.reader = arrow::internal::make_unique<ClientMetadataReader>(shared_stream);
+ result.writer = arrow::internal::make_unique<ClientStreamWriter>(
+ std::move(shared_stream), options.write_options, write_size_limit_bytes_,
+ descriptor);
+ RETURN_NOT_OK(result.writer->Begin(schema, options.write_options));
+ return result;
}
Status FlightClient::DoPut(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
- std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightMetadataReader>* reader) {
+ ARROW_ASSIGN_OR_RAISE(auto result, DoPut(options, descriptor, schema));
+ *writer = std::move(result.writer);
+ *reader = std::move(result.reader);
+ return Status::OK();
+}
+
+arrow::Result<FlightClient::DoExchangeResult> FlightClient::DoExchange(
+ const FlightCallOptions& options, const FlightDescriptor& descriptor) {
RETURN_NOT_OK(CheckOpen());
std::unique_ptr<internal::ClientDataStream> remote_stream;
- RETURN_NOT_OK(transport_->DoPut(options, &remote_stream));
+ RETURN_NOT_OK(transport_->DoExchange(options, &remote_stream));
std::shared_ptr<internal::ClientDataStream> shared_stream = std::move(remote_stream);
- *reader =
- std::unique_ptr<FlightMetadataReader>(new ClientMetadataReader(shared_stream));
- *stream = std::unique_ptr<FlightStreamWriter>(
- new ClientStreamWriter(std::move(shared_stream), options.write_options,
- write_size_limit_bytes_, descriptor));
- RETURN_NOT_OK((*stream)->Begin(schema, options.write_options));
- return Status::OK();
+ DoExchangeResult result;
+ result.reader = arrow::internal::make_unique<ClientStreamReader>(
+ shared_stream, options.read_options, options.stop_token, options.memory_manager);
+ auto stream_writer = arrow::internal::make_unique<ClientStreamWriter>(
+ std::move(shared_stream), options.write_options, write_size_limit_bytes_,
+ descriptor);
+ RETURN_NOT_OK(stream_writer->Begin());
+ result.writer = std::move(stream_writer);
+ return result;
}
Status FlightClient::DoExchange(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightStreamReader>* reader) {
- RETURN_NOT_OK(CheckOpen());
- std::unique_ptr<internal::ClientDataStream> remote_stream;
- RETURN_NOT_OK(transport_->DoExchange(options, &remote_stream));
- std::shared_ptr<internal::ClientDataStream> shared_stream = std::move(remote_stream);
- *reader = std::unique_ptr<FlightStreamReader>(new ClientStreamReader(
- shared_stream, options.read_options, options.stop_token, options.memory_manager));
- auto stream_writer = std::unique_ptr<ClientStreamWriter>(
- new ClientStreamWriter(std::move(shared_stream), options.write_options,
- write_size_limit_bytes_, descriptor));
- RETURN_NOT_OK(stream_writer->Begin());
- *writer = std::move(stream_writer);
+ ARROW_ASSIGN_OR_RAISE(auto result, DoExchange(options, descriptor));
+ *writer = std::move(result.writer);
+ *reader = std::move(result.reader);
return Status::OK();
}
diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h
index 06d87bb9ae..0298abe366 100644
--- a/cpp/src/arrow/flight/client.h
+++ b/cpp/src/arrow/flight/client.h
@@ -192,17 +192,22 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \brief Connect to an unauthenticated flight service
/// \param[in] location the URI
- /// \param[out] client the created FlightClient
- /// \return Status OK status may not indicate that the connection was
- /// successful
+ /// \return Arrow result with the created FlightClient, OK status may not indicate that
+ /// the connection was successful
+ static arrow::Result<std::unique_ptr<FlightClient>> Connect(const Location& location);
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
static Status Connect(const Location& location, std::unique_ptr<FlightClient>* client);
/// \brief Connect to an unauthenticated flight service
/// \param[in] location the URI
/// \param[in] options Other options for setting up the client
- /// \param[out] client the created FlightClient
- /// \return Status OK status may not indicate that the connection was
- /// successful
+ /// \return Arrow result with the created FlightClient, OK status may not indicate that
+ /// the connection was successful
+ static arrow::Result<std::unique_ptr<FlightClient>> Connect(
+ const Location& location, const FlightClientOptions& options);
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
static Status Connect(const Location& location, const FlightClientOptions& options,
std::unique_ptr<FlightClient>* client);
@@ -227,21 +232,34 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// of results, if any
/// \param[in] options Per-RPC options
/// \param[in] action the action to be performed
- /// \param[out] results an iterator object for reading the returned results
- /// \return Status
+ /// \return Arrow result with an iterator object for reading the returned results
+ arrow::Result<std::unique_ptr<ResultStream>> DoAction(const FlightCallOptions& options,
+ const Action& action);
+ arrow::Result<std::unique_ptr<ResultStream>> DoAction(const Action& action) {
+ return DoAction({}, action);
+ }
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status DoAction(const FlightCallOptions& options, const Action& action,
std::unique_ptr<ResultStream>* results);
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status DoAction(const Action& action, std::unique_ptr<ResultStream>* results) {
- return DoAction({}, action, results);
+ return DoAction({}, action).Value(results);
}
/// \brief Retrieve a list of available Action types
/// \param[in] options Per-RPC options
- /// \param[out] actions the available actions
- /// \return Status
+ /// \return Arrow result with the available actions
+ arrow::Result<std::vector<ActionType>> ListActions(const FlightCallOptions& options);
+ arrow::Result<std::vector<ActionType>> ListActions() {
+ return ListActions(FlightCallOptions());
+ }
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status ListActions(const FlightCallOptions& options, std::vector<ActionType>* actions);
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status ListActions(std::vector<ActionType>* actions) {
- return ListActions({}, actions);
+ return ListActions().Value(actions);
}
/// \brief Request access plan for a single flight, which may be an existing
@@ -249,14 +267,22 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[in] options Per-RPC options
/// \param[in] descriptor the dataset request, whether a named dataset or
/// command
- /// \param[out] info the FlightInfo describing where to access the dataset
- /// \return Status
+ /// \return Arrow result with the FlightInfo describing where to access the dataset
+ arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfo(
+ const FlightCallOptions& options, const FlightDescriptor& descriptor);
+ arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfo(
+ const FlightDescriptor& descriptor) {
+ return GetFlightInfo({}, descriptor);
+ }
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status GetFlightInfo(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info);
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status GetFlightInfo(const FlightDescriptor& descriptor,
std::unique_ptr<FlightInfo>* info) {
- return GetFlightInfo({}, descriptor, info);
+ return GetFlightInfo({}, descriptor).Value(info);
}
/// \brief Request schema for a single flight, which may be an existing
@@ -283,15 +309,20 @@ class ARROW_FLIGHT_EXPORT FlightClient {
}
/// \brief List all available flights known to the server
- /// \param[out] listing an iterator that returns a FlightInfo for each flight
- /// \return Status
+ /// \return Arrow result with an iterator that returns a FlightInfo for each flight
+ arrow::Result<std::unique_ptr<FlightListing>> ListFlights();
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status ListFlights(std::unique_ptr<FlightListing>* listing);
/// \brief List available flights given indicated filter criteria
/// \param[in] options Per-RPC options
/// \param[in] criteria the filter criteria (opaque)
- /// \param[out] listing an iterator that returns a FlightInfo for each flight
- /// \return Status
+ /// \return Arrow result with an iterator that returns a FlightInfo for each flight
+ arrow::Result<std::unique_ptr<FlightListing>> ListFlights(
+ const FlightCallOptions& options, const Criteria& criteria);
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status ListFlights(const FlightCallOptions& options, const Criteria& criteria,
std::unique_ptr<FlightListing>* listing);
@@ -299,14 +330,28 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// stream. Returns record batch stream reader
/// \param[in] options Per-RPC options
/// \param[in] ticket The flight ticket to use
- /// \param[out] stream the returned RecordBatchReader
- /// \return Status
+ /// \return Arrow result with the returned RecordBatchReader
+ arrow::Result<std::unique_ptr<FlightStreamReader>> DoGet(
+ const FlightCallOptions& options, const Ticket& ticket);
+ arrow::Result<std::unique_ptr<FlightStreamReader>> DoGet(const Ticket& ticket) {
+ return DoGet({}, ticket);
+ }
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
std::unique_ptr<FlightStreamReader>* stream);
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status DoGet(const Ticket& ticket, std::unique_ptr<FlightStreamReader>* stream) {
- return DoGet({}, ticket, stream);
+ return DoGet({}, ticket).Value(stream);
}
+ /// \brief DoPut return value
+ struct DoPutResult {
+ /// \brief a writer to write record batches to
+ std::unique_ptr<FlightStreamWriter> writer;
+ /// \brief a reader for application metadata from the server
+ std::unique_ptr<FlightMetadataReader> reader;
+ };
/// \brief Upload data to a Flight described by the given
/// descriptor. The caller must call Close() on the returned stream
/// once they are done writing.
@@ -318,26 +363,53 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[in] options Per-RPC options
/// \param[in] descriptor the descriptor of the stream
/// \param[in] schema the schema for the data to upload
- /// \param[out] stream a writer to write record batches to
- /// \param[out] reader a reader for application metadata from the server
- /// \return Status
+ /// \return Arrow result with a DoPutResult struct holding a reader and a writer
+ arrow::Result<DoPutResult> DoPut(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor,
+ const std::shared_ptr<Schema>& schema);
+
+ arrow::Result<DoPutResult> DoPut(const FlightDescriptor& descriptor,
+ const std::shared_ptr<Schema>& schema) {
+ return DoPut({}, descriptor, schema);
+ }
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
- std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightMetadataReader>* reader);
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status DoPut(const FlightDescriptor& descriptor, const std::shared_ptr<Schema>& schema,
- std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightMetadataReader>* reader) {
- return DoPut({}, descriptor, schema, stream, reader);
+ ARROW_ASSIGN_OR_RAISE(auto output, DoPut({}, descriptor, schema));
+ *writer = std::move(output.writer);
+ *reader = std::move(output.reader);
+ return Status::OK();
}
+ struct DoExchangeResult {
+ std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightStreamReader> reader;
+ };
+ arrow::Result<DoExchangeResult> DoExchange(const FlightCallOptions& options,
+ const FlightDescriptor& descriptor);
+ arrow::Result<DoExchangeResult> DoExchange(const FlightDescriptor& descriptor) {
+ return DoExchange({}, descriptor);
+ }
+
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status DoExchange(const FlightCallOptions& options, const FlightDescriptor& descriptor,
std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightStreamReader>* reader);
+ ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.")
Status DoExchange(const FlightDescriptor& descriptor,
std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightStreamReader>* reader) {
- return DoExchange({}, descriptor, writer, reader);
+ ARROW_ASSIGN_OR_RAISE(auto output, DoExchange({}, descriptor));
+ *writer = std::move(output.writer);
+ *reader = std::move(output.reader);
+ return Status::OK();
}
/// \brief Explicitly shut down and clean up the client.
diff --git a/cpp/src/arrow/flight/flight_benchmark.cc b/cpp/src/arrow/flight/flight_benchmark.cc
index aeb20407d7..fa0cc9a3d5 100644
--- a/cpp/src/arrow/flight/flight_benchmark.cc
+++ b/cpp/src/arrow/flight/flight_benchmark.cc
@@ -40,12 +40,20 @@
#include "arrow/flight/test_util.h"
#ifdef ARROW_CUDA
+#include <cuda.h>
#include "arrow/gpu/cuda_api.h"
#endif
+#ifdef ARROW_WITH_UCX
+#include "arrow/flight/transport/ucx/ucx.h"
+#endif
DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
- "The network transport to use. Supported: \"grpc\" (default).");
+ "The network transport to use. Supported: \"grpc\" (default)"
+#ifdef ARROW_WITH_UCX
+ ", \"ucx\""
+#endif // ARROW_WITH_UCX
+ ".");
DEFINE_string(server_host, "",
"An existing performance server to benchmark against (leave blank to spawn "
"one automatically)");
@@ -123,8 +131,7 @@ struct PerformanceStats {
Status WaitForReady(FlightClient* client, const FlightCallOptions& call_options) {
Action action{"ping", nullptr};
for (int attempt = 0; attempt < 10; attempt++) {
- std::unique_ptr<ResultStream> stream;
- if (client->DoAction(call_options, action, &stream).ok()) {
+ if (client->DoAction(call_options, action).ok()) {
return Status::OK();
}
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
@@ -138,7 +145,7 @@ arrow::Result<PerformanceResult> RunDoGetTest(FlightClient* client,
const FlightEndpoint& endpoint,
PerformanceStats* stats) {
std::unique_ptr<FlightStreamReader> reader;
- RETURN_NOT_OK(client->DoGet(call_options, endpoint.ticket, &reader));
+ ARROW_ASSIGN_OR_RAISE(reader, client->DoGet(call_options, endpoint.ticket));
FlightStreamChunk batch;
@@ -246,10 +253,10 @@ arrow::Result<PerformanceResult> RunDoPutTest(FlightClient* client,
StopWatch timer;
int64_t num_records = 0;
int64_t num_bytes = 0;
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightMetadataReader> reader;
- RETURN_NOT_OK(client->DoPut(call_options, FlightDescriptor{},
- batches[0].batch->schema(), &writer, &reader));
+ ARROW_ASSIGN_OR_RAISE(
+ auto do_put_result,
+ client->DoPut(call_options, FlightDescriptor{}, batches[0].batch->schema()));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_put_result.writer);
for (size_t i = 0; i < batches.size(); i++) {
auto batch = batches[i];
auto is_last = i == (batches.size() - 1);
@@ -283,8 +290,7 @@ Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_op
descriptor.type = FlightDescriptor::CMD;
perf.SerializeToString(&descriptor.cmd);
- std::unique_ptr<FlightInfo> plan;
- RETURN_NOT_OK(client->GetFlightInfo(call_options, descriptor, &plan));
+ ARROW_ASSIGN_OR_RAISE(auto plan, client->GetFlightInfo(call_options, descriptor));
// Read the streams in parallel
ipc::DictionaryMemo dict_memo;
@@ -300,8 +306,9 @@ Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_op
if (endpoint.locations.empty()) {
data_client = client;
} else {
- RETURN_NOT_OK(FlightClient::Connect(endpoint.locations.front(), client_options,
- &local_client));
+ ARROW_ASSIGN_OR_RAISE(
+ local_client,
+ FlightClient::Connect(endpoint.locations.front(), client_options));
data_client = local_client.get();
}
@@ -498,6 +505,21 @@ int main(int argc, char** argv) {
options.disable_server_verification = true;
}
}
+ } else if (FLAGS_transport == "ucx") {
+#ifdef ARROW_WITH_UCX
+ arrow::flight::transport::ucx::InitializeFlightUcx();
+ if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
+ std::cerr << "Transport does not support domain sockets: " << FLAGS_transport
+ << std::endl;
+ return EXIT_FAILURE;
+ }
+ ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
+ std::to_string(FLAGS_server_port))
+ .Value(&location));
+#else
+ std::cerr << "Not built with transport: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+#endif
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
@@ -515,14 +537,23 @@ int main(int argc, char** argv) {
ABORT_NOT_OK(arrow::cuda::CudaDeviceManager::Instance().Value(&manager));
ABORT_NOT_OK(manager->GetDevice(0).Value(&device));
call_options.memory_manager = device->default_memory_manager();
+
+ // Needed to prevent UCX warning
+ // cuda_md.c:162 UCX ERROR cuMemGetAddressRange(0x7f2ab5dc0000) error: invalid
+ // device context
+ std::shared_ptr<arrow::cuda::CudaContext> context;
+ ABORT_NOT_OK(device->GetContext().Value(&context));
+ auto cuda_status = cuCtxPushCurrent(reinterpret_cast<CUcontext>(context->handle()));
+ if (cuda_status != CUDA_SUCCESS) {
+ ARROW_LOG(WARNING) << "CUDA error " << cuda_status;
+ }
#else
std::cerr << "-cuda requires that Arrow is built with ARROW_CUDA" << std::endl;
return 1;
#endif
}
- std::unique_ptr<arrow::flight::FlightClient> client;
- ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client));
+ auto client = arrow::flight::FlightClient::Connect(location, options).ValueOrDie();
ABORT_NOT_OK(arrow::flight::WaitForReady(client.get(), call_options));
arrow::Status s = arrow::flight::RunPerformanceTest(client.get(), options, call_options,
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index b30a91268e..3f0ed7114f 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -124,9 +124,9 @@ TEST(TestFlight, ConnectUri) {
std::unique_ptr<FlightClient> client;
ASSERT_OK_AND_ASSIGN(auto location1, Location::Parse(uri));
ASSERT_OK_AND_ASSIGN(auto location2, Location::Parse(uri));
- ASSERT_OK(FlightClient::Connect(location1, &client));
+ ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location1));
ASSERT_OK(client->Close());
- ASSERT_OK(FlightClient::Connect(location2, &client));
+ ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location2));
ASSERT_OK(client->Close());
}
@@ -143,9 +143,9 @@ TEST(TestFlight, ConnectUriUnix) {
std::unique_ptr<FlightClient> client;
ASSERT_OK_AND_ASSIGN(auto location1, Location::Parse(uri));
ASSERT_OK_AND_ASSIGN(auto location2, Location::Parse(uri));
- ASSERT_OK(FlightClient::Connect(location1, &client));
+ ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location1));
ASSERT_OK(client->Close());
- ASSERT_OK(FlightClient::Connect(location2, &client));
+ ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location2));
ASSERT_OK(client->Close());
}
#endif
@@ -161,9 +161,8 @@ TEST(TestFlight, DISABLED_IpV6Port) {
ASSERT_OK_AND_ASSIGN(auto location2, Location::ForGrpcTcp("[::1]", server->port()));
std::unique_ptr<FlightClient> client;
- ASSERT_OK(FlightClient::Connect(location2, &client));
- std::unique_ptr<FlightListing> listing;
- ASSERT_OK(client->ListFlights(&listing));
+ ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location2));
+ ASSERT_OK(client->ListFlights());
}
// ----------------------------------------------------------------------
@@ -189,7 +188,7 @@ class TestFlightClient : public ::testing::Test {
Status ConnectClient() {
ARROW_ASSIGN_OR_RAISE(auto location,
Location::ForGrpcTcp("localhost", server_->port()));
- return FlightClient::Connect(location, &client_);
+ return FlightClient::Connect(location).Value(&client_);
}
template <typename EndpointCheckFunc>
@@ -198,8 +197,7 @@ class TestFlightClient : public ::testing::Test {
EndpointCheckFunc&& check_endpoints) {
auto expected_schema = expected_batches[0]->schema();
- std::unique_ptr<FlightInfo> info;
- ASSERT_OK(client_->GetFlightInfo(descr, &info));
+ ASSERT_OK_AND_ASSIGN(auto info, client_->GetFlightInfo(descr));
check_endpoints(info->endpoints());
ipc::DictionaryMemo dict_memo;
@@ -215,11 +213,8 @@ class TestFlightClient : public ::testing::Test {
auto num_batches = static_cast<int>(expected_batches.size());
ASSERT_GE(num_batches, 2);
- std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
-
- std::unique_ptr<FlightStreamReader> stream2;
- ASSERT_OK(client_->DoGet(ticket, &stream2));
+ ASSERT_OK_AND_ASSIGN(auto stream, client_->DoGet(ticket));
+ ASSERT_OK_AND_ASSIGN(auto stream2, client_->DoGet(ticket));
ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2)));
FlightStreamChunk chunk;
@@ -367,7 +362,7 @@ class TestTls : public ::testing::Test {
CertKeyPair root_cert;
RETURN_NOT_OK(ExampleTlsCertificateRoot(&root_cert));
options.tls_root_certs = root_cert.pem_cert;
- return FlightClient::Connect(location_, options, &client_);
+ return FlightClient::Connect(location_, options).Value(&client_);
}
protected:
@@ -650,7 +645,7 @@ class PropagatingTestServer : public FlightServerBase {
current_span_id = ((const TracingServerMiddleware*)middleware)->span_id;
}
- return client_->DoAction(action, result);
+ return client_->DoAction(action).Value(result);
}
private:
@@ -820,7 +815,7 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test {
std::unique_ptr<FlightListing> listing;
FlightCallOptions call_options;
call_options.headers.push_back(bearer_result.ValueOrDie());
- ASSERT_OK(client_->ListFlights(call_options, {}, &listing));
+ ASSERT_OK_AND_ASSIGN(listing, client_->ListFlights(call_options, {}));
ASSERT_TRUE(bearer_middleware_->GetIsValid());
}
@@ -846,13 +841,12 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test {
TEST_F(TestErrorMiddleware, TestMetadata) {
Action action;
- std::unique_ptr<ResultStream> stream;
// Run action1
action.type = "action1";
action.body = Buffer::FromString("action1-content");
- Status s = client_->DoAction(action, &stream);
+ Status s = client_->DoAction(action).status();
ASSERT_FALSE(s.ok());
std::shared_ptr<FlightStatusDetail> flightStatusDetail =
FlightStatusDetail::UnwrapStatus(s);
@@ -861,8 +855,7 @@ TEST_F(TestErrorMiddleware, TestMetadata) {
}
TEST_F(TestFlightClient, ListFlights) {
- std::unique_ptr<FlightListing> listing;
- ASSERT_OK(client_->ListFlights(&listing));
+ ASSERT_OK_AND_ASSIGN(auto listing, client_->ListFlights());
ASSERT_TRUE(listing != nullptr);
std::vector<FlightInfo> flights = ExampleFlightInfo();
@@ -880,8 +873,7 @@ TEST_F(TestFlightClient, ListFlights) {
}
TEST_F(TestFlightClient, ListFlightsWithCriteria) {
- std::unique_ptr<FlightListing> listing;
- ASSERT_OK(client_->ListFlights(FlightCallOptions(), {"foo"}, &listing));
+ ASSERT_OK_AND_ASSIGN(auto listing, client_->ListFlights(FlightCallOptions(), {"foo"}));
std::unique_ptr<FlightInfo> info;
ASSERT_OK_AND_ASSIGN(info, listing->Next());
ASSERT_TRUE(info == nullptr);
@@ -889,9 +881,7 @@ TEST_F(TestFlightClient, ListFlightsWithCriteria) {
TEST_F(TestFlightClient, GetFlightInfo) {
auto descr = FlightDescriptor::Path({"examples", "ints"});
- std::unique_ptr<FlightInfo> info;
-
- ASSERT_OK(client_->GetFlightInfo(descr, &info));
+ ASSERT_OK_AND_ASSIGN(auto info, client_->GetFlightInfo(descr));
ASSERT_NE(info, nullptr);
std::vector<FlightInfo> flights = ExampleFlightInfo();
@@ -909,17 +899,15 @@ TEST_F(TestFlightClient, GetSchema) {
TEST_F(TestFlightClient, GetFlightInfoNotFound) {
auto descr = FlightDescriptor::Path({"examples", "things"});
- std::unique_ptr<FlightInfo> info;
// XXX Ideally should be Invalid (or KeyError), but gRPC doesn't support
// multiple error codes.
- auto st = client_->GetFlightInfo(descr, &info);
+ auto st = client_->GetFlightInfo(descr).status();
ASSERT_RAISES(Invalid, st);
ASSERT_NE(st.message().find("Flight not found"), std::string::npos);
}
TEST_F(TestFlightClient, ListActions) {
- std::vector<ActionType> actions;
- ASSERT_OK(client_->ListActions(&actions));
+ ASSERT_OK_AND_ASSIGN(std::vector<ActionType> actions, client_->ListActions());
std::vector<ActionType> expected = ExampleActionTypes();
EXPECT_THAT(actions, ::testing::ContainerEq(expected));
@@ -927,7 +915,6 @@ TEST_F(TestFlightClient, ListActions) {
TEST_F(TestFlightClient, DoAction) {
Action action;
- std::unique_ptr<ResultStream> stream;
std::unique_ptr<Result> result;
// Run action1
@@ -935,7 +922,7 @@ TEST_F(TestFlightClient, DoAction) {
const std::string action1_value = "action1-content";
action.body = Buffer::FromString(action1_value);
- ASSERT_OK(client_->DoAction(action, &stream));
+ ASSERT_OK_AND_ASSIGN(auto stream, client_->DoAction(action));
for (int i = 0; i < 3; ++i) {
ASSERT_OK_AND_ASSIGN(result, stream->Next());
@@ -949,7 +936,7 @@ TEST_F(TestFlightClient, DoAction) {
// Run action2, no results
action.type = "action2";
- ASSERT_OK(client_->DoAction(action, &stream));
+ ASSERT_OK_AND_ASSIGN(stream, client_->DoAction(action));
ASSERT_OK_AND_ASSIGN(result, stream->Next());
ASSERT_EQ(nullptr, result);
@@ -957,20 +944,18 @@ TEST_F(TestFlightClient, DoAction) {
TEST_F(TestFlightClient, RoundTripStatus) {
const auto descr = FlightDescriptor::Command("status-outofmemory");
- std::unique_ptr<FlightInfo> info;
- const auto status = client_->GetFlightInfo(descr, &info);
+ const auto status = client_->GetFlightInfo(descr).status();
ASSERT_RAISES(OutOfMemory, status);
}
// Test setting generic transport options by configuring gRPC to fail
// all calls.
TEST_F(TestFlightClient, GenericOptions) {
- std::unique_ptr<FlightClient> client;
auto options = FlightClientOptions::Defaults();
// Set a very low limit at the gRPC layer to fail all calls
options.generic_options.emplace_back(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, 4);
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", server_->port()));
- ASSERT_OK(FlightClient::Connect(location, options, &client));
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location, options));
auto descr = FlightDescriptor::Path({"examples", "ints"});
std::shared_ptr<Schema> schema;
ipc::DictionaryMemo dict_memo;
@@ -981,14 +966,12 @@ TEST_F(TestFlightClient, GenericOptions) {
TEST_F(TestFlightClient, TimeoutFires) {
// Server does not exist on this port, so call should fail
- std::unique_ptr<FlightClient> client;
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 30001));
- ASSERT_OK(FlightClient::Connect(location, &client));
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location));
FlightCallOptions options;
options.timeout = TimeoutDuration{0.2};
- std::unique_ptr<FlightInfo> info;
auto start = std::chrono::system_clock::now();
- Status status = client->GetFlightInfo(options, FlightDescriptor{}, &info);
+ Status status = client->GetFlightInfo(options, FlightDescriptor{}).status();
auto end = std::chrono::system_clock::now();
#ifdef ARROW_WITH_TIMING_TESTS
EXPECT_LE(end - start, std::chrono::milliseconds{400});
@@ -1005,7 +988,7 @@ TEST_F(TestFlightClient, NoTimeout) {
std::unique_ptr<FlightInfo> info;
auto start = std::chrono::system_clock::now();
auto descriptor = FlightDescriptor::Path({"examples", "ints"});
- Status status = client_->GetFlightInfo(options, descriptor, &info);
+ Status status = client_->GetFlightInfo(options, descriptor).Value(&info);
auto end = std::chrono::system_clock::now();
#ifdef ARROW_WITH_TIMING_TESTS
EXPECT_LE(end - start, std::chrono::milliseconds{600});
@@ -1022,9 +1005,8 @@ TEST_F(TestFlightClient, Close) {
// Idempotent
ASSERT_OK(client_->Close());
- std::unique_ptr<FlightListing> listing;
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("FlightClient is closed"),
- client_->ListFlights(&listing));
+ client_->ListFlights());
}
TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
@@ -1033,65 +1015,52 @@ TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
std::unique_ptr<ClientAuthHandler>(new TestClientAuthHandler("user", "p4ssw0rd"))));
Status status;
- std::unique_ptr<FlightListing> listing;
- status = client_->ListFlights(&listing);
+ status = client_->ListFlights().status();
ASSERT_RAISES(NotImplemented, status);
std::unique_ptr<ResultStream> results;
Action action;
action.type = "";
action.body = Buffer::FromString("");
- status = client_->DoAction(action, &results);
- ASSERT_OK(status);
+ ASSERT_OK_AND_ASSIGN(results, client_->DoAction(action));
- std::vector<ActionType> actions;
- status = client_->ListActions(&actions);
+ status = client_->ListActions().status();
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<FlightInfo> info;
- status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ status = client_->GetFlightInfo(FlightDescriptor{}).status();
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<FlightStreamReader> stream;
- status = client_->DoGet(Ticket{}, &stream);
+ status = client_->DoGet(Ticket{}).status();
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightMetadataReader> reader;
std::shared_ptr<Schema> schema = arrow::schema({});
- status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
- ASSERT_OK(status);
- status = writer->Close();
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(FlightDescriptor{}, schema));
+ status = do_put_result.writer->Close();
ASSERT_RAISES(NotImplemented, status);
}
TEST_F(TestAuthHandler, FailUnauthenticatedCalls) {
Status status;
- std::unique_ptr<FlightListing> listing;
- status = client_->ListFlights(&listing);
+ status = client_->ListFlights().status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<ResultStream> results;
Action action;
action.type = "";
action.body = Buffer::FromString("");
- status = client_->DoAction(action, &results);
+ status = client_->DoAction(action).status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::vector<ActionType> actions;
- status = client_->ListActions(&actions);
+ status = client_->ListActions().status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<FlightInfo> info;
- status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ status = client_->GetFlightInfo(FlightDescriptor{}).status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<FlightStreamReader> stream;
- status = client_->DoGet(Ticket{}, &stream);
+ status = client_->DoGet(Ticket{}).status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
@@ -1099,9 +1068,10 @@ TEST_F(TestAuthHandler, FailUnauthenticatedCalls) {
std::unique_ptr<FlightMetadataReader> reader;
std::shared_ptr<Schema> schema(
(new arrow::Schema(std::vector<std::shared_ptr<Field>>())));
- status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
+ FlightClient::DoPutResult do_put_result;
+ status = client_->DoPut(FlightDescriptor{}, schema).Value(&do_put_result);
// ARROW-16053: gRPC may or may not fail the call immediately
- if (status.ok()) status = writer->Close();
+ if (status.ok()) status = do_put_result.writer->Close();
ASSERT_RAISES(IOError, status);
// ARROW-7583: don't check the error message here.
// Because gRPC reports errors in some paths with booleans, instead
@@ -1119,7 +1089,7 @@ TEST_F(TestAuthHandler, CheckPeerIdentity) {
action.type = "who-am-i";
action.body = Buffer::FromString("");
std::unique_ptr<ResultStream> results;
- ASSERT_OK(client_->DoAction(action, &results));
+ ASSERT_OK_AND_ASSIGN(results, client_->DoAction(action));
ASSERT_NE(results, nullptr);
std::unique_ptr<Result> result;
@@ -1144,76 +1114,62 @@ TEST_F(TestBasicAuthHandler, PassAuthenticatedCalls) {
new TestClientBasicAuthHandler("user", "p4ssw0rd"))));
Status status;
- std::unique_ptr<FlightListing> listing;
- status = client_->ListFlights(&listing);
+ status = client_->ListFlights().status();
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<ResultStream> results;
Action action;
action.type = "";
action.body = Buffer::FromString("");
- status = client_->DoAction(action, &results);
+ status = client_->DoAction(action).status();
ASSERT_OK(status);
- std::vector<ActionType> actions;
- status = client_->ListActions(&actions);
+ status = client_->ListActions().status();
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<FlightInfo> info;
- status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ status = client_->GetFlightInfo(FlightDescriptor{}).status();
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<FlightStreamReader> stream;
- status = client_->DoGet(Ticket{}, &stream);
+ status = client_->DoGet(Ticket{}).status();
ASSERT_RAISES(NotImplemented, status);
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightMetadataReader> reader;
std::shared_ptr<Schema> schema = arrow::schema({});
- status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
- ASSERT_OK(status);
- status = writer->Close();
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(FlightDescriptor{}, schema));
+ status = do_put_result.writer->Close();
ASSERT_RAISES(NotImplemented, status);
}
TEST_F(TestBasicAuthHandler, FailUnauthenticatedCalls) {
Status status;
- std::unique_ptr<FlightListing> listing;
- status = client_->ListFlights(&listing);
+ status = client_->ListFlights().status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<ResultStream> results;
Action action;
action.type = "";
action.body = Buffer::FromString("");
- status = client_->DoAction(action, &results);
+ status = client_->DoAction(action).status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::vector<ActionType> actions;
- status = client_->ListActions(&actions);
+ status = client_->ListActions().status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<FlightInfo> info;
- status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ status = client_->GetFlightInfo(FlightDescriptor{}).status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<FlightStreamReader> stream;
- status = client_->DoGet(Ticket{}, &stream);
+ status = client_->DoGet(Ticket{}).status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightMetadataReader> reader;
std::shared_ptr<Schema> schema(
(new arrow::Schema(std::vector<std::shared_ptr<Field>>())));
- status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
+ FlightClient::DoPutResult do_put_result;
+ status = client_->DoPut(FlightDescriptor{}, schema).Value(&do_put_result);
// May or may not succeed depending on if the transport buffers the write
ARROW_UNUSED(status);
- status = writer->Close();
+ status = do_put_result.writer->Close();
// But this should definitely fail
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
@@ -1227,8 +1183,7 @@ TEST_F(TestBasicAuthHandler, CheckPeerIdentity) {
Action action;
action.type = "who-am-i";
action.body = Buffer::FromString("");
- std::unique_ptr<ResultStream> results;
- ASSERT_OK(client_->DoAction(action, &results));
+ ASSERT_OK_AND_ASSIGN(auto results, client_->DoAction(action));
ASSERT_NE(results, nullptr);
std::unique_ptr<Result> result;
@@ -1244,8 +1199,7 @@ TEST_F(TestTls, DoAction) {
Action action;
action.type = "test";
action.body = Buffer::FromString("");
- std::unique_ptr<ResultStream> results;
- ASSERT_OK(client_->DoAction(options, action, &results));
+ ASSERT_OK_AND_ASSIGN(auto results, client_->DoAction(options, action));
ASSERT_NE(results, nullptr);
std::unique_ptr<Result> result;
@@ -1256,21 +1210,19 @@ TEST_F(TestTls, DoAction) {
#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
TEST_F(TestTls, DisableServerVerification) {
- std::unique_ptr<FlightClient> client;
auto client_options = FlightClientOptions::Defaults();
// For security reasons, if encryption is being used,
// the client should be configured to verify the server by default.
ASSERT_EQ(client_options.disable_server_verification, false);
client_options.disable_server_verification = true;
- ASSERT_OK(FlightClient::Connect(location_, client_options, &client));
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location_, client_options));
FlightCallOptions options;
options.timeout = TimeoutDuration{5.0};
Action action;
action.type = "test";
action.body = Buffer::FromString("");
- std::unique_ptr<ResultStream> results;
- ASSERT_OK(client->DoAction(options, action, &results));
+ ASSERT_OK_AND_ASSIGN(auto results, client->DoAction(options, action));
ASSERT_NE(results, nullptr);
std::unique_ptr<Result> result;
@@ -1281,60 +1233,53 @@ TEST_F(TestTls, DisableServerVerification) {
#endif
TEST_F(TestTls, OverrideHostname) {
- std::unique_ptr<FlightClient> client;
auto client_options = FlightClientOptions::Defaults();
client_options.override_hostname = "fakehostname";
CertKeyPair root_cert;
ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
client_options.tls_root_certs = root_cert.pem_cert;
- ASSERT_OK(FlightClient::Connect(location_, client_options, &client));
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location_, client_options));
FlightCallOptions options;
options.timeout = TimeoutDuration{5.0};
Action action;
action.type = "test";
action.body = Buffer::FromString("");
- std::unique_ptr<ResultStream> results;
- ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
+ ASSERT_RAISES(IOError, client->DoAction(options, action));
}
// Test the facility for setting generic transport options.
TEST_F(TestTls, OverrideHostnameGeneric) {
- std::unique_ptr<FlightClient> client;
auto client_options = FlightClientOptions::Defaults();
client_options.generic_options.emplace_back(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG,
"fakehostname");
CertKeyPair root_cert;
ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
client_options.tls_root_certs = root_cert.pem_cert;
- ASSERT_OK(FlightClient::Connect(location_, client_options, &client));
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location_, client_options));
FlightCallOptions options;
options.timeout = TimeoutDuration{5.0};
Action action;
action.type = "test";
action.body = Buffer::FromString("");
- std::unique_ptr<ResultStream> results;
- ASSERT_RAISES(IOError, client->DoAction(options, action, &results));
+ ASSERT_RAISES(IOError, client->DoAction(options, action));
// Could check error message for the gRPC error message but it isn't
// necessarily stable
}
TEST_F(TestRejectServerMiddleware, Rejected) {
- std::unique_ptr<FlightInfo> info;
- const auto& status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ const Status status = client_->GetFlightInfo(FlightDescriptor{}).status();
ASSERT_RAISES(IOError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("All calls are rejected"));
}
TEST_F(TestCountingServerMiddleware, Count) {
- std::unique_ptr<FlightInfo> info;
- const auto& status = client_->GetFlightInfo(FlightDescriptor{}, &info);
+ const Status status = client_->GetFlightInfo(FlightDescriptor{}).status();
ASSERT_RAISES(NotImplemented, status);
Ticket ticket{""};
- std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
+ ASSERT_OK_AND_ASSIGN(auto stream, client_->DoGet(ticket));
ASSERT_EQ(1, request_counter_->failed_);
@@ -1351,7 +1296,6 @@ TEST_F(TestCountingServerMiddleware, Count) {
TEST_F(TestPropagatingMiddleware, Propagate) {
Action action;
- std::unique_ptr<ResultStream> stream;
std::unique_ptr<Result> result;
current_span_id = "trace-id";
@@ -1359,7 +1303,7 @@ TEST_F(TestPropagatingMiddleware, Propagate) {
action.type = "action1";
action.body = Buffer::FromString("action1-content");
- ASSERT_OK(client_->DoAction(action, &stream));
+ ASSERT_OK_AND_ASSIGN(auto stream, client_->DoAction(action));
ASSERT_OK_AND_ASSIGN(result, stream->Next());
ASSERT_EQ("trace-id", result->body->ToString());
@@ -1371,8 +1315,7 @@ TEST_F(TestPropagatingMiddleware, Propagate) {
// passed to the interceptor
TEST_F(TestPropagatingMiddleware, ListFlights) {
client_middleware_->Reset();
- std::unique_ptr<FlightListing> listing;
- const Status status = client_->ListFlights(&listing);
+ const Status status = client_->ListFlights().status();
ASSERT_RAISES(NotImplemented, status);
ValidateStatus(status, FlightMethod::ListFlights);
}
@@ -1380,8 +1323,7 @@ TEST_F(TestPropagatingMiddleware, ListFlights) {
TEST_F(TestPropagatingMiddleware, GetFlightInfo) {
client_middleware_->Reset();
auto descr = FlightDescriptor::Path({"examples", "ints"});
- std::unique_ptr<FlightInfo> info;
- const Status status = client_->GetFlightInfo(descr, &info);
+ const Status status = client_->GetFlightInfo(descr).status();
ASSERT_RAISES(NotImplemented, status);
ValidateStatus(status, FlightMethod::GetFlightInfo);
}
@@ -1397,7 +1339,7 @@ TEST_F(TestPropagatingMiddleware, GetSchema) {
TEST_F(TestPropagatingMiddleware, ListActions) {
client_middleware_->Reset();
std::vector<ActionType> actions;
- const Status status = client_->ListActions(&actions);
+ const Status status = client_->ListActions().status();
ASSERT_RAISES(NotImplemented, status);
ValidateStatus(status, FlightMethod::ListActions);
}
@@ -1406,7 +1348,7 @@ TEST_F(TestPropagatingMiddleware, DoGet) {
client_middleware_->Reset();
Ticket ticket1{"ARROW-5095-fail"};
std::unique_ptr<FlightStreamReader> stream;
- Status status = client_->DoGet(ticket1, &stream);
+ Status status = client_->DoGet(ticket1).status();
ASSERT_RAISES(NotImplemented, status);
ValidateStatus(status, FlightMethod::DoGet);
}
@@ -1417,10 +1359,8 @@ TEST_F(TestPropagatingMiddleware, DoPut) {
auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]");
auto schema = arrow::schema({field("f1", a1->type())});
- std::unique_ptr<FlightStreamWriter> stream;
- std::unique_ptr<FlightMetadataReader> reader;
- ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
- const Status status = stream->Close();
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(descr, schema));
+ const Status status = do_put_result.writer->Close();
ASSERT_RAISES(NotImplemented, status);
ValidateStatus(status, FlightMethod::DoPut);
}
@@ -1515,20 +1455,18 @@ TEST_F(TestCancel, ListFlights) {
StopSource stop_source;
FlightCallOptions options;
options.stop_token = stop_source.token();
- std::unique_ptr<FlightListing> listing;
stop_source.RequestStop(Status::Cancelled("StopSource"));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
- client_->ListFlights(options, {}, &listing));
+ client_->ListFlights(options, {}));
}
TEST_F(TestCancel, DoAction) {
StopSource stop_source;
FlightCallOptions options;
options.stop_token = stop_source.token();
- std::unique_ptr<ResultStream> results;
stop_source.RequestStop(Status::Cancelled("StopSource"));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
- client_->DoAction(options, {}, &results));
+ client_->DoAction(options, {}));
}
TEST_F(TestCancel, ListActions) {
@@ -1538,7 +1476,7 @@ TEST_F(TestCancel, ListActions) {
std::vector<ActionType> results;
stop_source.RequestStop(Status::Cancelled("StopSource"));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
- client_->ListActions(options, &results));
+ client_->ListActions(options));
}
TEST_F(TestCancel, DoGet) {
@@ -1547,12 +1485,11 @@ TEST_F(TestCancel, DoGet) {
options.stop_token = stop_source.token();
std::unique_ptr<ResultStream> results;
stop_source.RequestStop(Status::Cancelled("StopSource"));
- std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(options, {}, &stream));
+ ASSERT_OK_AND_ASSIGN(auto stream, client_->DoGet(options, {}));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
stream->ToTable());
- ASSERT_OK(client_->DoGet({}, &stream));
+ ASSERT_OK_AND_ASSIGN(stream, client_->DoGet({}));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
stream->ToTable(options.stop_token));
}
@@ -1563,18 +1500,17 @@ TEST_F(TestCancel, DoExchange) {
options.stop_token = stop_source.token();
std::unique_ptr<ResultStream> results;
stop_source.RequestStop(Status::Cancelled("StopSource"));
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(
- client_->DoExchange(options, FlightDescriptor::Command(""), &writer, &stream));
+ ASSERT_OK_AND_ASSIGN(auto do_exchange_result,
+ client_->DoExchange(options, FlightDescriptor::Command("")));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
- stream->ToTable());
- ARROW_UNUSED(writer->Close());
+ do_exchange_result.reader->ToTable());
+ ARROW_UNUSED(do_exchange_result.writer->Close());
- ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &stream));
+ ASSERT_OK_AND_ASSIGN(do_exchange_result,
+ client_->DoExchange(FlightDescriptor::Command("")));
EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
- stream->ToTable(options.stop_token));
- ARROW_UNUSED(writer->Close());
+ do_exchange_result.reader->ToTable(options.stop_token));
+ ARROW_UNUSED(do_exchange_result.writer->Close());
}
} // namespace flight
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc
index 89dc191c65..e1d79f1a3e 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc
@@ -51,7 +51,7 @@ class AuthBasicProtoServer : public FlightServerBase {
Status CheckActionResults(FlightClient* client, const Action& action,
std::vector<std::string> results) {
std::unique_ptr<ResultStream> stream;
- RETURN_NOT_OK(client->DoAction(action, &stream));
+ ARROW_ASSIGN_OR_RAISE(stream, client->DoAction(action));
std::unique_ptr<Result> result;
for (const std::string& expected : results) {
ARROW_ASSIGN_OR_RAISE(result, stream->Next());
@@ -91,7 +91,7 @@ class AuthBasicProtoScenario : public Scenario {
Action action;
std::unique_ptr<ResultStream> stream;
std::shared_ptr<FlightStatusDetail> detail;
- const auto& status = client->DoAction(action, &stream);
+ const auto& status = client->DoAction(action).Value(&stream);
detail = FlightStatusDetail::UnwrapStatus(status);
// This client is unauthenticated and should fail.
if (detail == nullptr) {
@@ -231,12 +231,11 @@ class MiddlewareScenario : public Scenario {
}
Status RunClient(std::unique_ptr<FlightClient> client) override {
- std::unique_ptr<FlightInfo> info;
// This call is expected to fail. In gRPC/Java, this causes the
// server to combine headers and HTTP/2 trailers, so to read the
// expected header, Flight must check for both headers and
// trailers.
- if (client->GetFlightInfo(FlightDescriptor::Command(""), &info).ok()) {
+ if (client->GetFlightInfo(FlightDescriptor::Command("")).status().ok()) {
return Status::Invalid("Expected call to fail");
}
if (client_middleware_->received_header_ != "expected value") {
@@ -248,7 +247,8 @@ class MiddlewareScenario : public Scenario {
// This call should succeed
client_middleware_->received_header_ = "";
- RETURN_NOT_OK(client->GetFlightInfo(FlightDescriptor::Command("success"), &info));
+ ARROW_ASSIGN_OR_RAISE(auto info,
+ client->GetFlightInfo(FlightDescriptor::Command("success")));
if (client_middleware_->received_header_ != "expected value") {
return Status::Invalid(
"Expected to receive header 'x-middleware: expected value', but instead got '",
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration_client.cc b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc
index 08f80e9923..9c5c985c06 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration_client.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration_client.cc
@@ -93,8 +93,7 @@ Status UploadBatchesToFlight(const std::vector<std::shared_ptr<RecordBatch>>& ch
Status ConsumeFlightLocation(
FlightClient* read_client, const Ticket& ticket,
const std::vector<std::shared_ptr<RecordBatch>>& retrieved_data) {
- std::unique_ptr<FlightStreamReader> stream;
- RETURN_NOT_OK(read_client->DoGet(ticket, &stream));
+ ARROW_ASSIGN_OR_RAISE(auto stream, read_client->DoGet(ticket));
int counter = 0;
const int expected = static_cast<int>(retrieved_data.size());
@@ -161,14 +160,14 @@ class IntegrationTestScenario : public Scenario {
std::vector<std::shared_ptr<RecordBatch>> original_data;
ABORT_NOT_OK(ReadBatches(reader, &original_data));
- std::unique_ptr<FlightStreamWriter> write_stream;
- std::unique_ptr<FlightMetadataReader> metadata_reader;
- ABORT_NOT_OK(client->DoPut(descr, original_schema, &write_stream, &metadata_reader));
+ auto do_put_result = client->DoPut(descr, original_schema).ValueOrDie();
+ std::unique_ptr<FlightStreamWriter> write_stream = std::move(do_put_result.writer);
+ std::unique_ptr<FlightMetadataReader> metadata_reader =
+ std::move(do_put_result.reader);
ABORT_NOT_OK(UploadBatchesToFlight(original_data, *write_stream, *metadata_reader));
// 2. Get the ticket for the data.
- std::unique_ptr<FlightInfo> info;
- ABORT_NOT_OK(client->GetFlightInfo(descr, &info));
+ std::unique_ptr<FlightInfo> info = client->GetFlightInfo(descr).ValueOrDie();
std::shared_ptr<Schema> schema;
ipc::DictionaryMemo dict_memo;
@@ -189,7 +188,7 @@ class IntegrationTestScenario : public Scenario {
for (const auto& location : endpoint.locations) {
std::cout << "Verifying location " << location.ToString() << std::endl;
std::unique_ptr<FlightClient> read_client;
- RETURN_NOT_OK(FlightClient::Connect(location, &read_client));
+ ARROW_ASSIGN_OR_RAISE(read_client, FlightClient::Connect(location));
RETURN_NOT_OK(ConsumeFlightLocation(read_client.get(), ticket, original_data));
RETURN_NOT_OK(read_client->Close());
}
@@ -212,7 +211,7 @@ arrow::Status RunScenario(arrow::flight::integration_tests::Scenario* scenario)
RETURN_NOT_OK(scenario->MakeClient(&options));
ARROW_ASSIGN_OR_RAISE(auto location,
arrow::flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port));
- RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client));
+ ARROW_ASSIGN_OR_RAISE(client, arrow::flight::FlightClient::Connect(location, options));
return scenario->RunClient(std::move(client));
}
diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc
index cc42ffedd6..37e3ec4d77 100644
--- a/cpp/src/arrow/flight/perf_server.cc
+++ b/cpp/src/arrow/flight/perf_server.cc
@@ -19,6 +19,7 @@
#include <signal.h>
#include <cstdint>
+#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
@@ -43,10 +44,17 @@
#ifdef ARROW_CUDA
#include "arrow/gpu/cuda_api.h"
#endif
+#ifdef ARROW_WITH_UCX
+#include "arrow/flight/transport/ucx/ucx.h"
+#endif
DEFINE_bool(cuda, false, "Allocate results in CUDA memory");
DEFINE_string(transport, "grpc",
- "The network transport to use. Supported: \"grpc\" (default).");
+ "The network transport to use. Supported: \"grpc\" (default)"
+#ifdef ARROW_WITH_UCX
+ ", \"ucx\""
+#endif // ARROW_WITH_UCX
+ ".");
DEFINE_string(server_host, "localhost", "Host where the server is running on");
DEFINE_int32(port, 31337, "Server port to listen on");
DEFINE_string(server_unix, "", "Unix socket path where the server is running on");
@@ -97,7 +105,7 @@ class PerfDataStream : public FlightDataStream {
if (records_sent_ >= total_records_) {
// Signal that iteration is over
payload.ipc_message.metadata = nullptr;
- return Status::OK();
+ return payload;
}
if (verify_) {
@@ -274,6 +282,29 @@ int main(int argc, char** argv) {
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix)
.Value(&connect_location));
}
+ } else if (FLAGS_transport == "ucx") {
+#ifdef ARROW_WITH_UCX
+ arrow::flight::transport::ucx::InitializeFlightUcx();
+ if (FLAGS_server_unix.empty()) {
+ if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
+ std::cerr << "Transport does not support TLS: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+ }
+ ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
+ std::to_string(FLAGS_port))
+ .Value(&bind_location));
+ ARROW_CHECK_OK(arrow::flight::Location::Parse("ucx://" + FLAGS_server_host + ":" +
+ std::to_string(FLAGS_port))
+ .Value(&connect_location));
+ } else {
+ std::cerr << "Transport does not support domain sockets: " << FLAGS_transport
+ << std::endl;
+ return EXIT_FAILURE;
+ }
+#else
+ std::cerr << "Not built with transport: " << FLAGS_transport << std::endl;
+ return EXIT_FAILURE;
+#endif
} else {
std::cerr << "Unknown transport: " << FLAGS_transport << std::endl;
return EXIT_FAILURE;
@@ -308,6 +339,7 @@ int main(int argc, char** argv) {
// Exit with a clean error code (0) on SIGTERM
ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
std::cout << "Server transport: " << FLAGS_transport << std::endl;
+ std::cout << "Server location: " << connect_location.ToString() << std::endl;
if (FLAGS_server_unix.empty()) {
std::cout << "Server host: " << FLAGS_server_host << std::endl;
std::cout << "Server port: " << FLAGS_port << std::endl;
diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc
index b86cb8008e..89cfb3aad0 100644
--- a/cpp/src/arrow/flight/sql/client.cc
+++ b/cpp/src/arrow/flight/sql/client.cc
@@ -369,8 +369,8 @@ arrow::Result<int64_t> PreparedStatement::ExecuteUpdate() {
} else {
const std::shared_ptr<Schema> schema = arrow::schema({});
ARROW_RETURN_NOT_OK(client_->DoPut(options_, descriptor, schema, &writer, &reader));
- const auto& record_batch =
- arrow::RecordBatch::Make(schema, 0, (std::vector<std::shared_ptr<Array>>){});
+ const ArrayVector columns;
+ const auto& record_batch = arrow::RecordBatch::Make(schema, 0, columns);
ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch));
}
diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h
index 454f98a351..78c162e4cd 100644
--- a/cpp/src/arrow/flight/sql/client.h
+++ b/cpp/src/arrow/flight/sql/client.h
@@ -33,6 +33,8 @@ namespace sql {
class PreparedStatement;
/// \brief Flight client with Flight SQL semantics.
+///
+/// Wraps a Flight client to provide the Flight SQL RPC calls.
class ARROW_EXPORT FlightSqlClient {
friend class PreparedStatement;
@@ -170,10 +172,7 @@ class ARROW_EXPORT FlightSqlClient {
// function GetFlightInfoForCommand.
virtual arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfo(
const FlightCallOptions& options, const FlightDescriptor& descriptor) {
- std::unique_ptr<FlightInfo> info;
- ARROW_RETURN_NOT_OK(impl_->GetFlightInfo(options, descriptor, &info));
-
- return info;
+ return impl_->GetFlightInfo(options, descriptor);
}
/// \brief Explicitly shut down and clean up the client.
@@ -183,34 +182,31 @@ class ARROW_EXPORT FlightSqlClient {
virtual Status DoPut(const FlightCallOptions& options,
const FlightDescriptor& descriptor,
const std::shared_ptr<Schema>& schema,
- std::unique_ptr<FlightStreamWriter>* stream,
+ std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightMetadataReader>* reader) {
- return impl_->DoPut(options, descriptor, schema, stream, reader);
+ ARROW_ASSIGN_OR_RAISE(auto result, impl_->DoPut(options, descriptor, schema));
+ *writer = std::move(result.writer);
+ *reader = std::move(result.reader);
+ return Status::OK();
}
virtual Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
std::unique_ptr<FlightStreamReader>* stream) {
- return impl_->DoGet(options, ticket, stream);
+ return impl_->DoGet(options, ticket).Value(stream);
}
virtual Status DoAction(const FlightCallOptions& options, const Action& action,
std::unique_ptr<ResultStream>* results) {
- return impl_->DoAction(options, action, results);
+ return impl_->DoAction(options, action).Value(results);
}
};
-/// \brief PreparedStatement class from flight sql.
+/// \brief A prepared statement that can be executed.
class ARROW_EXPORT PreparedStatement {
- FlightSqlClient* client_;
- FlightCallOptions options_;
- std::string handle_;
- std::shared_ptr<Schema> dataset_schema_;
- std::shared_ptr<Schema> parameter_schema_;
- std::shared_ptr<RecordBatch> parameter_binding_;
- bool is_closed_;
-
public:
- /// \brief Constructor for the PreparedStatement class.
+ /// \brief Create a new prepared statement. However, applications
+ /// should generally use FlightSqlClient::Prepare.
+ ///
/// \param[in] client Client object used to make the RPC requests.
/// \param[in] handle Handle for this prepared statement.
/// \param[in] dataset_schema Schema of the resulting dataset.
@@ -256,6 +252,15 @@ class ARROW_EXPORT PreparedStatement {
/// \brief Check if the prepared statement is closed.
/// \return The state of the prepared statement.
bool IsClosed() const;
+
+ private:
+ FlightSqlClient* client_;
+ FlightCallOptions options_;
+ std::string handle_;
+ std::shared_ptr<Schema> dataset_schema_;
+ std::shared_ptr<Schema> parameter_schema_;
+ std::shared_ptr<RecordBatch> parameter_binding_;
+ bool is_closed_;
};
} // namespace sql
diff --git a/cpp/src/arrow/flight/sql/column_metadata.cc b/cpp/src/arrow/flight/sql/column_metadata.cc
index e98b29c292..30ef240105 100644
--- a/cpp/src/arrow/flight/sql/column_metadata.cc
+++ b/cpp/src/arrow/flight/sql/column_metadata.cc
@@ -42,6 +42,7 @@ bool StringToBoolean(const std::string& string_value) {
const char* ColumnMetadata::kCatalogName = "ARROW:FLIGHT:SQL:CATALOG_NAME";
const char* ColumnMetadata::kSchemaName = "ARROW:FLIGHT:SQL:SCHEMA_NAME";
const char* ColumnMetadata::kTableName = "ARROW:FLIGHT:SQL:TABLE_NAME";
+const char* ColumnMetadata::kTypeName = "ARROW:FLIGHT:SQL:TYPE_NAME";
const char* ColumnMetadata::kPrecision = "ARROW:FLIGHT:SQL:PRECISION";
const char* ColumnMetadata::kScale = "ARROW:FLIGHT:SQL:SCALE";
const char* ColumnMetadata::kIsAutoIncrement = "ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT";
@@ -65,6 +66,10 @@ arrow::Result<std::string> ColumnMetadata::GetTableName() const {
return metadata_map_->Get(kTableName);
}
+arrow::Result<std::string> ColumnMetadata::GetTypeName() const {
+ return metadata_map_->Get(kTypeName);
+}
+
arrow::Result<int32_t> ColumnMetadata::GetPrecision() const {
std::string precision_string;
ARROW_ASSIGN_OR_RAISE(precision_string, metadata_map_->Get(kPrecision));
@@ -130,6 +135,12 @@ ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::Ta
return *this;
}
+ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::TypeName(
+ std::string& type_name) {
+ metadata_map_->Append(ColumnMetadata::kTypeName, type_name);
+ return *this;
+}
+
ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::Precision(
int32_t precision) {
metadata_map_->Append(ColumnMetadata::kPrecision, std::to_string(precision));
diff --git a/cpp/src/arrow/flight/sql/column_metadata.h b/cpp/src/arrow/flight/sql/column_metadata.h
index 7a20e74d67..d347155fff 100644
--- a/cpp/src/arrow/flight/sql/column_metadata.h
+++ b/cpp/src/arrow/flight/sql/column_metadata.h
@@ -46,6 +46,9 @@ class ColumnMetadata {
static const char* kTableName;
/// \brief Constant variable to hold the value of the key that
/// will be used in the KeyValueMetadata class.
+ static const char* kTypeName;
+ /// \brief Constant variable to hold the value of the key that
+ /// will be used in the KeyValueMetadata class.
static const char* kPrecision;
/// \brief Constant variable to hold the value of the key that
/// will be used in the KeyValueMetadata class.
@@ -78,6 +81,10 @@ class ColumnMetadata {
/// \return The table name.
arrow::Result<std::string> GetTableName() const;
+ /// \brief Return the type name set in the KeyValueMetadata.
+ /// \return The type name.
+ arrow::Result<std::string> GetTypeName() const;
+
/// \brief Return the precision set in the KeyValueMetadata.
/// \return The precision.
arrow::Result<int32_t> GetPrecision() const;
@@ -117,15 +124,20 @@ class ColumnMetadata {
ColumnMetadataBuilder& CatalogName(std::string& catalog_name);
/// \brief Set the schema_name in the KeyValueMetadata object.
- /// \param[in] schema_name The schema_name.
+ /// \param[in] schema_name The schema_name.
/// \return A ColumnMetadataBuilder.
ColumnMetadataBuilder& SchemaName(std::string& schema_name);
/// \brief Set the table name in the KeyValueMetadata object.
- /// \param[in] table_name The table name.
+ /// \param[in] table_name The table name.
/// \return A ColumnMetadataBuilder.
ColumnMetadataBuilder& TableName(std::string& table_name);
+ /// \brief Set the type name in the KeyValueMetadata object.
+ /// \param[in] type_name The type name.
+ /// \return A ColumnMetadataBuilder.
+ ColumnMetadataBuilder& TypeName(std::string& type_name);
+
/// \brief Set the precision in the KeyValueMetadata object.
/// \param[in] precision The precision.
/// \return A ColumnMetadataBuilder.
@@ -138,22 +150,22 @@ class ColumnMetadata {
/// \brief Set the IsAutoIncrement in the KeyValueMetadata object.
/// \param[in] is_auto_increment The IsAutoIncrement.
- /// \return A ColumnMetadataBuilder.
+ /// \return A ColumnMetadataBuilder.
ColumnMetadataBuilder& IsAutoIncrement(bool is_auto_increment);
/// \brief Set the IsCaseSensitive in the KeyValueMetadata object.
/// \param[in] is_case_sensitive The IsCaseSensitive.
- /// \return A ColumnMetadataBuilder.
+ /// \return A ColumnMetadataBuilder.
ColumnMetadataBuilder& IsCaseSensitive(bool is_case_sensitive);
/// \brief Set the IsReadOnly in the KeyValueMetadata object.
/// \param[in] is_read_only The IsReadOnly.
- /// \return A ColumnMetadataBuilder.
+ /// \return A ColumnMetadataBuilder.
ColumnMetadataBuilder& IsReadOnly(bool is_read_only);
/// \brief Set the IsSearchable in the KeyValueMetadata object.
/// \param[in] is_searchable The IsSearchable.
- /// \return A ColumnMetadataBuilder.
+ /// \return A ColumnMetadataBuilder.
ColumnMetadataBuilder& IsSearchable(bool is_searchable);
ColumnMetadata Build() const;
diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h
index b5137fd9f4..4e6ddce239 100644
--- a/cpp/src/arrow/flight/sql/server.h
+++ b/cpp/src/arrow/flight/sql/server.h
@@ -33,117 +33,149 @@ namespace arrow {
namespace flight {
namespace sql {
+/// \defgroup flight-sql-protocol-messages Flight SQL Protocol Messages
+/// Simple struct wrappers for various protocol messages, used to
+/// avoid exposing Protobuf types in the API.
+/// @{
+
+/// \brief A SQL query.
struct StatementQuery {
+ /// \brief The SQL query.
std::string query;
};
+/// \brief A SQL update query.
struct StatementUpdate {
+ /// \brief The SQL query.
std::string query;
};
+/// \brief A request to execute a query.
struct StatementQueryTicket {
+ /// \brief The server-generated opaque identifier for the query.
std::string statement_handle;
};
+/// \brief A prepared query statement.
struct PreparedStatementQuery {
+ /// \brief The server-generated opaque identifier for the statement.
std::string prepared_statement_handle;
};
+/// \brief A prepared update statement.
struct PreparedStatementUpdate {
+ /// \brief The server-generated opaque identifier for the statement.
std::string prepared_statement_handle;
};
+/// \brief A request to fetch server metadata.
struct GetSqlInfo {
+ /// \brief A list of metadata IDs to fetch.
std::vector<int32_t> info;
};
+/// \brief A request to list database schemas.
struct GetDbSchemas {
+ /// \brief An optional database catalog to filter on.
util::optional<std::string> catalog;
+ /// \brief An optional database schema to filter on.
util::optional<std::string> db_schema_filter_pattern;
};
+/// \brief A request to list database tables.
struct GetTables {
+ /// \brief An optional database catalog to filter on.
util::optional<std::string> catalog;
+ /// \brief An optional database schema to filter on.
util::optional<std::string> db_schema_filter_pattern;
+ /// \brief An optional table name to filter on.
util::optional<std::string> table_name_filter_pattern;
+ /// \brief A list of table types to filter on.
std::vector<std::string> table_types;
+ /// \brief Whether to include the Arrow schema in the response.
bool include_schema;
};
+/// \brief A request to get SQL data type information.
struct GetXdbcTypeInfo {
+ /// \brief A specific SQL type ID to fetch information about.
util::optional<int> data_type;
};
+/// \brief A request to list primary keys of a table.
struct GetPrimaryKeys {
+ /// \brief The given table.
TableRef table_ref;
};
+/// \brief A request to list foreign key columns referencing primary key
+/// columns of a table.
struct GetExportedKeys {
+ /// \brief The given table.
TableRef table_ref;
};
+/// \brief A request to list foreign keys of a table.
struct GetImportedKeys {
+ /// \brief The given table.
TableRef table_ref;
};
+/// \brief A request to list foreign key columns of a table that
+/// reference columns in a given parent table.
struct GetCrossReference {
+ /// \brief The parent table (the one containing referenced columns).
TableRef pk_table_ref;
+ /// \brief The foreign table (for which foreign key columns will be listed).
TableRef fk_table_ref;
};
+/// \brief A request to create a new prepared statement.
struct ActionCreatePreparedStatementRequest {
+ /// \brief The SQL query.
std::string query;
};
+/// \brief A request to close a prepared statement.
struct ActionClosePreparedStatementRequest {
+ /// \brief The server-generated opaque identifier for the statement.
std::string prepared_statement_handle;
};
+/// \brief The result of creating a new prepared statement.
struct ActionCreatePreparedStatementResult {
+ /// \brief The schema of the query results, if applicable.
std::shared_ptr<Schema> dataset_schema;
+ /// \brief The schema of the query parameters, if applicable.
std::shared_ptr<Schema> parameter_schema;
+ /// \brief The server-generated opaque identifier for the statement.
std::string prepared_statement_handle;
};
-/// \brief A utility function to create a ticket (a opaque binary token that the server
-/// uses to identify this query) for a statement query.
-/// Intended for Flight SQL server implementations.
+/// @}
+
+/// \brief A utility function to create a ticket (an opaque binary
+/// token that the server uses to identify this query) for a statement
+/// query. Intended for Flight SQL server implementations.
+///
/// \param[in] statement_handle The statement handle that will originate the ticket.
/// \return The parsed ticket as an string.
arrow::Result<std::string> CreateStatementQueryTicket(
const std::string& statement_handle);
+/// \brief The base class for Flight SQL servers.
+///
+/// Applications should subclass this class and override the virtual
+/// methods declared on this class.
class ARROW_EXPORT FlightSqlServerBase : public FlightServerBase {
private:
SqlInfoResultMap sql_info_id_to_result_;
public:
- Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
- std::unique_ptr<FlightInfo>* info) override;
-
- Status DoGet(const ServerCallContext& context, const Ticket& request,
- std::unique_ptr<FlightDataStream>* stream) override;
-
- Status DoPut(const ServerCallContext& context,
- std::unique_ptr<FlightMessageReader> reader,
- std::unique_ptr<FlightMetadataWriter> writer) override;
-
- const ActionType kCreatePreparedStatementActionType =
- ActionType{"CreatePreparedStatement",
- "Creates a reusable prepared statement resource on the server.\n"
- "Request Message: ActionCreatePreparedStatementRequest\n"
- "Response Message: ActionCreatePreparedStatementResult"};
- const ActionType kClosePreparedStatementActionType =
- ActionType{"ClosePreparedStatement",
- "Closes a reusable prepared statement resource on the server.\n"
- "Request Message: ActionClosePreparedStatementRequest\n"
- "Response Message: N/A"};
-
- Status ListActions(const ServerCallContext& context,
- std::vector<ActionType>* actions) override;
-
- Status DoAction(const ServerCallContext& context, const Action& action,
- std::unique_ptr<ResultStream>* result) override;
+ /// \name Flight SQL methods
+ /// Applications should override these methods to implement the
+ /// Flight SQL endpoints.
+ /// @{
/// \brief Get a FlightInfo for executing a SQL query.
/// \param[in] context Per-call context.
@@ -408,10 +440,51 @@ class ARROW_EXPORT FlightSqlServerBase : public FlightServerBase {
const ServerCallContext& context, const PreparedStatementUpdate& command,
FlightMessageReader* reader);
+ /// @}
+
+ /// \name Utility methods
+ /// @{
+
/// \brief Register a new SqlInfo result, making it available when calling GetSqlInfo.
/// \param[in] id the SqlInfo identifier.
/// \param[in] result the result.
void RegisterSqlInfo(int32_t id, const SqlInfoResult& result);
+
+ /// @}
+
+ /// \name Flight RPC handlers
+ /// Applications should not override these methods; they implement
+ /// the Flight SQL protocol.
+ /// @{
+
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* info) final;
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* stream) final;
+
+ Status DoPut(const ServerCallContext& context,
+ std::unique_ptr<FlightMessageReader> reader,
+ std::unique_ptr<FlightMetadataWriter> writer) final;
+
+ const ActionType kCreatePreparedStatementActionType =
+ ActionType{"CreatePreparedStatement",
+ "Creates a reusable prepared statement resource on the server.\n"
+ "Request Message: ActionCreatePreparedStatementRequest\n"
+ "Response Message: ActionCreatePreparedStatementResult"};
+ const ActionType kClosePreparedStatementActionType =
+ ActionType{"ClosePreparedStatement",
+ "Closes a reusable prepared statement resource on the server.\n"
+ "Request Message: ActionClosePreparedStatementRequest\n"
+ "Response Message: N/A"};
+
+ Status ListActions(const ServerCallContext& context,
+ std::vector<ActionType>* actions) final;
+
+ Status DoAction(const ServerCallContext& context, const Action& action,
+ std::unique_ptr<ResultStream>* result) final;
+
+ /// @}
};
/// \brief Auxiliary class containing all Schemas used on Flight SQL.
diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc
index d2b41df8f9..746c91c102 100644
--- a/cpp/src/arrow/flight/sql/server_test.cc
+++ b/cpp/src/arrow/flight/sql/server_test.cc
@@ -162,9 +162,8 @@ class TestFlightSqlServer : public ::testing::Test {
ss << "grpc://localhost:" << port;
std::string uri = ss.str();
- std::unique_ptr<FlightClient> client;
ASSERT_OK_AND_ASSIGN(auto location, Location::Parse(uri));
- ASSERT_OK(FlightClient::Connect(location, &client));
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location));
sql_client.reset(new FlightSqlClient(std::move(client)));
}
diff --git a/cpp/src/arrow/flight/sql/test_app_cli.cc b/cpp/src/arrow/flight/sql/test_app_cli.cc
index 63924cc1c9..7989210dd0 100644
--- a/cpp/src/arrow/flight/sql/test_app_cli.cc
+++ b/cpp/src/arrow/flight/sql/test_app_cli.cc
@@ -101,9 +101,8 @@ Status PrintResults(FlightSqlClient& client, const FlightCallOptions& call_optio
}
Status RunMain() {
- std::unique_ptr<FlightClient> client;
ARROW_ASSIGN_OR_RAISE(auto location, Location::ForGrpcTcp(FLAGS_host, FLAGS_port));
- ARROW_RETURN_NOT_OK(FlightClient::Connect(location, &client));
+ ARROW_ASSIGN_OR_RAISE(auto client, FlightClient::Connect(location));
FlightCallOptions call_options;
diff --git a/cpp/src/arrow/flight/sql/types.h b/cpp/src/arrow/flight/sql/types.h
index 44b8bca471..ebfb2ef0ea 100644
--- a/cpp/src/arrow/flight/sql/types.h
+++ b/cpp/src/arrow/flight/sql/types.h
@@ -30,6 +30,10 @@ namespace arrow {
namespace flight {
namespace sql {
+/// \defgroup flight-sql-common-types Common protocol types for Flight SQL
+///
+/// @{
+
/// \brief Variant supporting all possible types on SQL info.
using SqlInfoResult =
arrow::util::Variant<std::string, bool, int64_t, int32_t, std::vector<std::string>,
@@ -40,813 +44,764 @@ using SqlInfoResultMap = std::unordered_map<int32_t, SqlInfoResult>;
/// \brief Options to be set in the SqlInfo.
struct SqlInfoOptions {
+ /// \brief Predefined info values for GetSqlInfo.
enum SqlInfo {
- // Server Information [0-500): Provides basic information about the Flight SQL Server.
+ /// \name Server Information
+ /// Values [0-500): Provides basic information about the Flight SQL Server.
+ /// @{
- // Retrieves a UTF-8 string with the name of the Flight SQL Server.
+ /// Retrieves a UTF-8 string with the name of the Flight SQL Server.
FLIGHT_SQL_SERVER_NAME = 0,
- // Retrieves a UTF-8 string with the native version of the Flight SQL Server.
+ /// Retrieves a UTF-8 string with the native version of the Flight SQL
+ /// Server.
FLIGHT_SQL_SERVER_VERSION = 1,
- // Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server.
+ /// Retrieves a UTF-8 string with the Arrow format version of the Flight
+ /// SQL Server.
FLIGHT_SQL_SERVER_ARROW_VERSION = 2,
- /*
- * Retrieves a boolean value indicating whether the Flight SQL Server is read only.
- *
- * Returns:
- * - false: if read-write
- * - true: if read only
- */
+ /// Retrieves a boolean value indicating whether the Flight SQL Server is
+ /// read only.
+ ///
+ /// Returns:
+ /// - false: if read-write
+ /// - true: if read only
FLIGHT_SQL_SERVER_READ_ONLY = 3,
- // SQL Syntax Information [500-1000): provides information about SQL syntax supported
- // by the Flight SQL Server.
-
- /*
- * Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE
- * and DROP of catalogs.
- *
- * Returns:
- * - false: if it doesn't support CREATE and DROP of catalogs.
- * - true: if it supports CREATE and DROP of catalogs.
- */
+ /// @}
+
+ /// \name SQL Syntax Information
+ /// Values [500-1000): provides information about SQL syntax supported
+ /// by the Flight SQL Server.
+ /// @{
+
+ /// Retrieves a boolean value indicating whether the Flight SQL
+ /// Server supports CREATE and DROP of catalogs.
+ ///
+ /// Returns:
+ /// - false: if it doesn't support CREATE and DROP of catalogs.
+ /// - true: if it supports CREATE and DROP of catalogs.
SQL_DDL_CATALOG = 500,
- /*
- * Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE
- * and DROP of schemas.
- *
- * Returns:
- * - false: if it doesn't support CREATE and DROP of schemas.
- * - true: if it supports CREATE and DROP of schemas.
- */
+ /// Retrieves a boolean value indicating whether the Flight SQL
+ /// Server supports CREATE and DROP of schemas.
+ ///
+ /// Returns:
+ /// - false: if it doesn't support CREATE and DROP of schemas.
+ /// - true: if it supports CREATE and DROP of schemas.
SQL_DDL_SCHEMA = 501,
- /*
- * Indicates whether the Flight SQL Server supports CREATE and DROP of tables.
- *
- * Returns:
- * - false: if it doesn't support CREATE and DROP of tables.
- * - true: if it supports CREATE and DROP of tables.
- */
+ /// Indicates whether the Flight SQL Server supports CREATE and DROP of
+ /// tables.
+ ///
+ /// Returns:
+ /// - false: if it doesn't support CREATE and DROP of tables.
+ /// - true: if it supports CREATE and DROP of tables.
SQL_DDL_TABLE = 502,
- /*
- * Retrieves a uint32 value representing the enu uint32 ordinal for the case
- * sensitivity of catalog, table and schema names.
- *
- * The possible values are listed in
- * `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`.
- */
+ /// Retrieves an int32 value representing the enum ordinal for the
+ /// case sensitivity of catalog, table and schema names.
+ ///
+ /// The possible values are listed in
+ /// `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`.
SQL_IDENTIFIER_CASE = 503,
- // Retrieves a UTF-8 string with the supported character(s) used to surround a
- // delimited identifier.
+ /// Retrieves a UTF-8 string with the supported character(s) used
+ /// to surround a delimited identifier.
SQL_IDENTIFIER_QUOTE_CHAR = 504,
- /*
- * Retrieves a uint32 value representing the enu uint32 ordinal for the case
- * sensitivity of quoted identifiers.
- *
- * The possible values are listed in
- * `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`.
- */
+ /// Retrieves an int32 value representing the enum ordinal
+ /// for the case sensitivity of quoted identifiers.
+ ///
+ /// The possible values are listed in
+ /// `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`.
SQL_QUOTED_IDENTIFIER_CASE = 505,
- /*
- * Retrieves a boolean value indicating whether all tables are selectable.
- *
- * Returns:
- * - false: if not all tables are selectable or if none are;
- * - true: if all tables are selectable.
- */
+ /// Retrieves a boolean value indicating whether all tables are
+ /// selectable.
+ ///
+ /// Returns:
+ /// - false: if not all tables are selectable or if none are;
+ /// - true: if all tables are selectable.
SQL_ALL_TABLES_ARE_SELECTABLE = 506,
- /*
- * Retrieves the null ordering.
- *
- * Returns a uint32 ordinal for the null ordering being used, as described in
- * `arrow.flight.protocol.sql.SqlNullOrdering`.
- */
+ /// Retrieves the null ordering used by the database as an int32
+ /// ordinal value.
+ ///
+ /// Returns an int32 ordinal for the null ordering being used, as
+ /// described in `arrow.flight.protocol.sql.SqlNullOrdering`.
SQL_NULL_ORDERING = 507,
- // Retrieves a UTF-8 string list with values of the supported keywords.
+ /// Retrieves a UTF-8 string list with values of the supported keywords.
SQL_KEYWORDS = 508,
- // Retrieves a UTF-8 string list with values of the supported numeric functions.
+ /// Retrieves a UTF-8 string list with values of the supported numeric functions.
SQL_NUMERIC_FUNCTIONS = 509,
- // Retrieves a UTF-8 string list with values of the supported string functions.
+ /// Retrieves a UTF-8 string list with values of the supported string functions.
SQL_STRING_FUNCTIONS = 510,
- // Retrieves a UTF-8 string list with values of the supported system functions.
+ /// Retrieves a UTF-8 string list with values of the supported system functions.
SQL_SYSTEM_FUNCTIONS = 511,
- // Retrieves a UTF-8 string list with values of the supported datetime functions.
+ /// Retrieves a UTF-8 string list with values of the supported datetime functions.
SQL_DATETIME_FUNCTIONS = 512,
- /*
- * Retrieves the UTF-8 string that can be used to escape wildcard characters.
- * This is the string that can be used to escape '_' or '%' in the catalog search
- * parameters that are a pattern (and therefore use one of the wildcard characters).
- * The '_' character represents any single character; the '%' character represents any
- * sequence of zero or more characters.
- */
+ /// Retrieves the UTF-8 string that can be used to escape wildcard characters.
+ /// This is the string that can be used to escape '_' or '%' in the catalog search
+ /// parameters that are a pattern (and therefore use one of the wildcard characters).
+ /// The '_' character represents any single character; the '%' character represents
+ /// any
+ /// sequence of zero or more characters.
SQL_SEARCH_STRING_ESCAPE = 513,
- /*
- * Retrieves a UTF-8 string with all the "extra" characters that can be used in
- * unquoted identifier names (those beyond a-z, A-Z, 0-9 and _).
- */
+ /// Retrieves a UTF-8 string with all the "extra" characters that can be used in
+ /// unquoted identifier names (those beyond a-z, A-Z, 0-9 and _).
SQL_EXTRA_NAME_CHARACTERS = 514,
- /*
- * Retrieves a boolean value indicating whether column aliasing is supported.
- * If so, the SQL AS clause can be used to provide names for computed columns or to
- * provide alias names for columns as required.
- *
- * Returns:
- * - false: if column aliasing is unsupported;
- * - true: if column aliasing is supported.
- */
+ /// Retrieves a boolean value indicating whether column aliasing is supported.
+ /// If so, the SQL AS clause can be used to provide names for computed columns or to
+ /// provide alias names for columns as required.
+ ///
+ /// Returns:
+ /// - false: if column aliasing is unsupported;
+ /// - true: if column aliasing is supported.
SQL_SUPPORTS_COLUMN_ALIASING = 515,
- /*
- * Retrieves a boolean value indicating whether concatenations between null and
- * non-null values being null are supported.
- *
- * - Returns:
- * - false: if concatenations between null and non-null values being null are
- * unsupported;
- * - true: if concatenations between null and non-null values being null are
- * supported.
- */
+ /// Retrieves a boolean value indicating whether concatenations between null and
+ /// non-null values being null are supported.
+ ///
+ /// Returns:
+ /// - false: if concatenations between null and non-null values being null are
+ /// unsupported;
+ /// - true: if concatenations between null and non-null values being null are
+ /// supported.
SQL_NULL_PLUS_NULL_IS_NULL = 516,
- /*
- * Retrieves a map where the key is the type to convert from and the value is a list
- * with the types to convert to, indicating the supported conversions. Each key and
- * each item on the list value is a value to a predefined type on SqlSupportsConvert
- * enum. The returned map will be: map<int32, list<int32>>
- */
+ /// Retrieves a map where the key is the type to convert from and the value is a list
+ /// with the types to convert to, indicating the supported conversions. Each key and
+ /// each item on the list value is a value to a predefined type on SqlSupportsConvert
+ /// enum. The returned map will be: map<int32, list<int32>>
SQL_SUPPORTS_CONVERT = 517,
- /*
- * Retrieves a boolean value indicating whether, when table correlation names are
- * supported, they are restricted to being different from the names of the tables.
- *
- * Returns:
- * - false: if table correlation names are unsupported;
- * - true: if table correlation names are supported.
- */
+ /// Retrieves a boolean value indicating whether, when table correlation names are
+ /// supported, they are restricted to being different from the names of the tables.
+ ///
+ /// Returns:
+ /// - false: if table correlation names are unsupported;
+ /// - true: if table correlation names are supported.
SQL_SUPPORTS_TABLE_CORRELATION_NAMES = 518,
- /*
- * Retrieves a boolean value indicating whether, when table correlation names are
- * supported, they are restricted to being different from the names of the tables.
- *
- * Returns:
- * - false: if different table correlation names are unsupported;
- * - true: if different table correlation names are supported
- */
+ /// Retrieves a boolean value indicating whether, when table correlation names are
+ /// supported, they are restricted to being different from the names of the tables.
+ ///
+ /// Returns:
+ /// - false: if different table correlation names are unsupported;
+ /// - true: if different table correlation names are supported
SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES = 519,
- /*
- * Retrieves a boolean value indicating whether expressions in ORDER BY lists are
- * supported.
- *
- * Returns:
- * - false: if expressions in ORDER BY are unsupported;
- * - true: if expressions in ORDER BY are supported;
- */
+ /// Retrieves a boolean value indicating whether expressions in ORDER BY lists are
+ /// supported.
+ ///
+ /// Returns:
+ /// - false: if expressions in ORDER BY are unsupported;
+ /// - true: if expressions in ORDER BY are supported;
SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY = 520,
- /*
- * Retrieves a boolean value indicating whether using a column that is not in the
- * SELECT statement in a GROUP BY clause is supported.
- *
- * Returns:
- * - false: if using a column that is not in the SELECT statement in a GROUP BY clause
- * is unsupported;
- * - true: if using a column that is not in the SELECT statement in a GROUP BY clause
- * is supported.
- */
+ /// Retrieves a boolean value indicating whether using a column that is not in the
+ /// SELECT statement in a GROUP BY clause is supported.
+ ///
+ /// Returns:
+ /// - false: if using a column that is not in the SELECT statement in a GROUP BY
+ /// clause
+ /// is unsupported;
+ /// - true: if using a column that is not in the SELECT statement in a GROUP BY clause
+ /// is supported.
SQL_SUPPORTS_ORDER_BY_UNRELATED = 521,
- /*
- * Retrieves the supported GROUP BY commands;
- *
- * Returns an int32 bitmask value representing the supported commands.
- * The returned bitmask should be parsed in order to retrieve the supported commands.
- *
- * For instance:
- * - return 0 (\b0) => [] (GROUP BY is unsupported);
- * - return 1 (\b1) => [SQL_GROUP_BY_UNRELATED];
- * - return 2 (\b10) => [SQL_GROUP_BY_BEYOND_SELECT];
- * - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT].
- * Valid GROUP BY types are described under
- * `arrow.flight.protocol.sql.SqlSupportedGroupBy`.
- */
+ /// Retrieves the supported GROUP BY commands as an int32 bitmask.
+ /// The returned bitmask should be parsed in order to retrieve the supported commands.
+ ///
+ /// - return 0 (0b0) => [] (GROUP BY is unsupported);
+ /// - return 1 (0b1) => [SQL_GROUP_BY_UNRELATED];
+ /// - return 2 (0b10) => [SQL_GROUP_BY_BEYOND_SELECT];
+ /// - return 3 (0b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT].
+ ///
+ /// Valid GROUP BY types are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedGroupBy`.
SQL_SUPPORTED_GROUP_BY = 522,
- /*
- * Retrieves a boolean value indicating whether specifying a LIKE escape clause is
- * supported.
- *
- * Returns:
- * - false: if specifying a LIKE escape clause is unsupported;
- * - true: if specifying a LIKE escape clause is supported.
- */
+ /// Retrieves a boolean value indicating whether specifying a LIKE escape clause is
+ /// supported.
+ ///
+ /// Returns:
+ /// - false: if specifying a LIKE escape clause is unsupported;
+ /// - true: if specifying a LIKE escape clause is supported.
SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE = 523,
- /*
- * Retrieves a boolean value indicating whether columns may be defined as
- * non-nullable.
- *
- * Returns:
- * - false: if columns cannot be defined as non-nullable;
- * - true: if columns may be defined as non-nullable.
- */
+ /// Retrieves a boolean value indicating whether columns may be defined as
+ /// non-nullable.
+ ///
+ /// Returns:
+ /// - false: if columns cannot be defined as non-nullable;
+ /// - true: if columns may be defined as non-nullable.
SQL_SUPPORTS_NON_NULLABLE_COLUMNS = 524,
- /*
- * Retrieves the supported SQL grammar level as per the ODBC specification.
- *
- * Returns an int32 bitmask value representing the supported SQL grammar level.
- * The returned bitmask should be parsed in order to retrieve the supported grammar
- * levels.
- *
- * For instance:
- * - return 0 (\b0) => [] (SQL grammar is unsupported);
- * - return 1 (\b1) => [SQL_MINIMUM_GRAMMAR];
- * - return 2 (\b10) => [SQL_CORE_GRAMMAR];
- * - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR];
- * - return 4 (\b100) => [SQL_EXTENDED_GRAMMAR];
- * - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR];
- * - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR];
- * - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR,
- * SQL_EXTENDED_GRAMMAR]. Valid SQL grammar levels are described under
- * `arrow.flight.protocol.sql.SupportedSqlGrammar`.
- */
+ /// Retrieves the supported SQL grammar level as per the ODBC
+ /// specification.
+ ///
+ /// Returns an int32 bitmask value representing the supported SQL grammar
+ /// level. The returned bitmask should be parsed in order to retrieve the
+ /// supported grammar levels.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (SQL grammar is unsupported);
+ /// - return 1 (0b1) => [SQL_MINIMUM_GRAMMAR];
+ /// - return 2 (0b10) => [SQL_CORE_GRAMMAR];
+ /// - return 3 (0b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR];
+ /// - return 4 (0b100) => [SQL_EXTENDED_GRAMMAR];
+ /// - return 5 (0b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR];
+ /// - return 6 (0b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR];
+ /// - return 7 (0b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR,
+ /// SQL_EXTENDED_GRAMMAR].
+ ///
+ /// Valid SQL grammar levels are described under
+ /// `arrow.flight.protocol.sql.SupportedSqlGrammar`.
SQL_SUPPORTED_GRAMMAR = 525,
- /*
- * Retrieves the supported ANSI92 SQL grammar level.
- *
- * Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level.
- * The returned bitmask should be parsed in order to retrieve the supported commands.
- *
- * For instance:
- * - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported);
- * - return 1 (\b1) => [ANSI92_ENTRY_SQL];
- * - return 2 (\b10) => [ANSI92_INTERMEDIATE_SQL];
- * - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL];
- * - return 4 (\b100) => [ANSI92_FULL_SQL];
- * - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL];
- * - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL];
- * - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL].
- * Valid ANSI92 SQL grammar levels are described under
- * `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`.
- */
+ /// Retrieves the supported ANSI92 SQL grammar level as per the ODBC
+ /// specification.
+ ///
+ /// Returns an int32 bitmask value representing the supported ANSI92 SQL
+ /// grammar level. The returned bitmask should be parsed in order to
+ /// retrieve the supported grammar levels.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (ANSI92 SQL grammar is unsupported);
+ /// - return 1 (0b1) => [ANSI92_ENTRY_SQL];
+ /// - return 2 (0b10) => [ANSI92_INTERMEDIATE_SQL];
+ /// - return 3 (0b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL];
+ /// - return 4 (0b100) => [ANSI92_FULL_SQL];
+ /// - return 5 (0b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL];
+ /// - return 6 (0b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL];
+ /// - return 7 (0b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL,
+ /// ANSI92_FULL_SQL].
+ ///
+ /// Valid ANSI92 SQL grammar levels are described under
+ /// `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`.
SQL_ANSI92_SUPPORTED_LEVEL = 526,
- /*
- * Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility
- * is supported.
- *
- * Returns:
- * - false: if the SQL Integrity Enhancement Facility is supported;
- * - true: if the SQL Integrity Enhancement Facility is supported.
- */
+ /// Retrieves a boolean value indicating whether the SQL Integrity
+ /// Enhancement Facility is supported.
+ ///
+ /// Returns:
+ /// - false: if the SQL Integrity Enhancement Facility is not supported;
+ /// - true: if the SQL Integrity Enhancement Facility is supported.
SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY = 527,
- /*
- * Retrieves the support level for SQL OUTER JOINs.
- *
- * Returns a uint3 uint32 ordinal for the SQL ordering being used, as described in
- * `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`.
- */
+ /// Retrieves the support level for SQL OUTER JOINs as an int32 ordinal, as
+ /// described in `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`.
SQL_OUTER_JOINS_SUPPORT_LEVEL = 528,
- // Retrieves a UTF-8 string with the preferred term for "schema".
+ /// Retrieves a UTF-8 string with the preferred term for "schema".
SQL_SCHEMA_TERM = 529,
- // Retrieves a UTF-8 string with the preferred term for "procedure".
+ /// Retrieves a UTF-8 string with the preferred term for "procedure".
SQL_PROCEDURE_TERM = 530,
- // Retrieves a UTF-8 string with the preferred term for "catalog".
+ /// Retrieves a UTF-8 string with the preferred term for "catalog".
SQL_CATALOG_TERM = 531,
- /*
- * Retrieves a boolean value indicating whether a catalog appears at the start of a
- * fully qualified table name.
- *
- * - false: if a catalog does not appear at the start of a fully qualified table name;
- * - true: if a catalog appears at the start of a fully qualified table name.
- */
+ /// Retrieves a boolean value indicating whether a catalog appears at the
+ /// start of a fully qualified table name.
+ ///
+ /// - false: if a catalog does not appear at the start of a fully qualified table
+ /// name;
+ /// - true: if a catalog appears at the start of a fully qualified table name.
SQL_CATALOG_AT_START = 532,
- /*
- * Retrieves the supported actions for a SQL schema.
- *
- * Returns an int32 bitmask value representing the supported actions for a SQL schema.
- * The returned bitmask should be parsed in order to retrieve the supported actions
- * for a SQL schema.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported actions for SQL schema);
- * - return 1 (\b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS];
- * - return 2 (\b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS];
- * - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
- * SQL_ELEMENT_IN_INDEX_DEFINITIONS];
- * - return 4 (\b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
- * - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
- * SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
- * - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS,
- * SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
- * - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
- * SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. Valid
- * actions for a SQL schema described under
- * `arrow.flight.protocol.sql.SqlSupportedElementActions`.
- */
+ /// Retrieves the supported actions for a SQL database schema as an int32
+ /// bitmask value.
+ ///
+ /// Returns an int32 bitmask value representing the supported actions for a
+ /// SQL schema. The returned bitmask should be parsed in order to retrieve
+ /// the supported actions for a SQL schema.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported actions for SQL schema);
+ /// - return 1 (0b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS];
+ /// - return 2 (0b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS];
+ /// - return 3 (0b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
+ /// SQL_ELEMENT_IN_INDEX_DEFINITIONS];
+ /// - return 4 (0b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
+ /// - return 5 (0b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
+ /// SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
+ /// - return 6 (0b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS,
+ /// SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
+ /// - return 7 (0b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
+ /// SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS].
+ ///
+ /// Valid actions for a SQL schema described under
+ /// `arrow.flight.protocol.sql.SqlSupportedElementActions`.
SQL_SCHEMAS_SUPPORTED_ACTIONS = 533,
- /*
- * Retrieves the supported actions for a SQL schema.
- *
- * Returns an int32 bitmask value representing the supported actions for a SQL
- * catalog. The returned bitmask should be parsed in order to retrieve the supported
- * actions for a SQL catalog.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported actions for SQL catalog);
- * - return 1 (\b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS];
- * - return 2 (\b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS];
- * - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
- * SQL_ELEMENT_IN_INDEX_DEFINITIONS];
- * - return 4 (\b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
- * - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
- * SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
- * - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS,
- * SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
- * - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
- * SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. Valid
- * actions for a SQL catalog are described under
- * `arrow.flight.protocol.sql.SqlSupportedElementActions`.
- */
+ /// Retrieves the supported actions for a SQL catalog as an int32 bitmask
+ /// value.
+ ///
+ /// Returns an int32 bitmask value representing the supported actions for a SQL
+ /// catalog. The returned bitmask should be parsed in order to retrieve the supported
+ /// actions for a SQL catalog.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported actions for SQL catalog);
+ /// - return 1 (0b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS];
+ /// - return 2 (0b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS];
+ /// - return 3 (0b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
+ /// SQL_ELEMENT_IN_INDEX_DEFINITIONS];
+ /// - return 4 (0b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
+ /// - return 5 (0b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
+ /// SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
+ /// - return 6 (0b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS,
+ /// SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS];
+ /// - return 7 (0b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS,
+ /// SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS].
+ ///
+ /// Valid actions for a SQL catalog are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedElementActions`.
SQL_CATALOGS_SUPPORTED_ACTIONS = 534,
- /*
- * Retrieves the supported SQL positioned commands.
- *
- * Returns an int32 bitmask value representing the supported SQL positioned commands.
- * The returned bitmask should be parsed in order to retrieve the supported SQL
- * positioned commands.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported SQL positioned commands);
- * - return 1 (\b1) => [SQL_POSITIONED_DELETE];
- * - return 2 (\b10) => [SQL_POSITIONED_UPDATE];
- * - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE].
- * Valid SQL positioned commands are described under
- * `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`.
- */
+ /// Retrieves the supported SQL positioned commands as an int32 bitmask
+ /// value.
+ ///
+ /// Returns an int32 bitmask value representing the supported SQL positioned commands.
+ /// The returned bitmask should be parsed in order to retrieve the supported SQL
+ /// positioned commands.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported SQL positioned commands);
+ /// - return 1 (0b1) => [SQL_POSITIONED_DELETE];
+ /// - return 2 (0b10) => [SQL_POSITIONED_UPDATE];
+ /// - return 3 (0b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE].
+ ///
+ /// Valid SQL positioned commands are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`.
SQL_SUPPORTED_POSITIONED_COMMANDS = 535,
- /*
- * Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are
- * supported.
- *
- * Returns:
- * - false: if SELECT FOR UPDATE statements are unsupported;
- * - true: if SELECT FOR UPDATE statements are supported.
- */
+ /// Retrieves a boolean value indicating whether SELECT FOR UPDATE
+ /// statements are supported.
+ ///
+ /// Returns:
+ /// - false: if SELECT FOR UPDATE statements are unsupported;
+ /// - true: if SELECT FOR UPDATE statements are supported.
SQL_SELECT_FOR_UPDATE_SUPPORTED = 536,
- /*
- * Retrieves a boolean value indicating whether stored procedure calls that use the
- * stored procedure escape syntax are supported.
- *
- * Returns:
- * - false: if stored procedure calls that use the stored procedure escape syntax are
- * unsupported;
- * - true: if stored procedure calls that use the stored procedure escape syntax are
- * supported.
- */
+ /// Retrieves a boolean value indicating whether stored procedure calls
+ /// that use the stored procedure escape syntax are supported.
+ ///
+ /// Returns:
+ /// - false: if stored procedure calls that use the stored procedure escape syntax are
+ /// unsupported;
+ /// - true: if stored procedure calls that use the stored procedure escape syntax are
+ /// supported.
SQL_STORED_PROCEDURES_SUPPORTED = 537,
- /*
- * Retrieves the supported SQL subqueries.
- *
- * Returns an int32 bitmask value representing the supported SQL subqueries.
- * The returned bitmask should be parsed in order to retrieve the supported SQL
- * subqueries.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported SQL subqueries);
- * - return 1 (\b1) => [SQL_SUBQUERIES_IN_COMPARISONS];
- * - return 2 (\b10) => [SQL_SUBQUERIES_IN_EXISTS];
- * - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS,
- * SQL_SUBQUERIES_IN_EXISTS];
- * - return 4 (\b100) => [SQL_SUBQUERIES_IN_INS];
- * - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS];
- * - return 6 (\b110) => [SQL_SUBQUERIES_IN_COMPARISONS,
- * SQL_SUBQUERIES_IN_EXISTS];
- * - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS,
- * SQL_SUBQUERIES_IN_INS];
- * - return 8 (\b1000) => [SQL_SUBQUERIES_IN_QUANTIFIEDS];
- * - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS,
- * SQL_SUBQUERIES_IN_QUANTIFIEDS];
- * - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS,
- * SQL_SUBQUERIES_IN_QUANTIFIEDS];
- * - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS,
- * SQL_SUBQUERIES_IN_QUANTIFIEDS];
- * - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS];
- * - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS,
- * SQL_SUBQUERIES_IN_QUANTIFIEDS];
- * - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS,
- * SQL_SUBQUERIES_IN_QUANTIFIEDS];
- * - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS,
- * SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS];
- * - ...
- * Valid SQL subqueries are described under
- * `arrow.flight.protocol.sql.SqlSupportedSubqueries`.
- */
+ /// Retrieves the types of supported SQL subqueries as an int32 bitmask
+ /// value.
+ ///
+ /// Returns an int32 bitmask value representing the supported SQL
+ /// subqueries. The returned bitmask should be parsed in order to retrieve
+ /// the supported SQL subqueries.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported SQL subqueries);
+ /// - return 1 (0b1) => [SQL_SUBQUERIES_IN_COMPARISONS];
+ /// - return 2 (0b10) => [SQL_SUBQUERIES_IN_EXISTS];
+ /// - return 3 (0b11) => [SQL_SUBQUERIES_IN_COMPARISONS,
+ /// SQL_SUBQUERIES_IN_EXISTS];
+ /// - return 4 (0b100) => [SQL_SUBQUERIES_IN_INS];
+ /// - return 5 (0b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS];
+ /// - return 6 (0b110) => [SQL_SUBQUERIES_IN_EXISTS,
+ /// SQL_SUBQUERIES_IN_INS];
+ /// - return 7 (0b111) => [SQL_SUBQUERIES_IN_COMPARISONS,
+ /// SQL_SUBQUERIES_IN_EXISTS,
+ /// SQL_SUBQUERIES_IN_INS];
+ /// - return 8 (0b1000) => [SQL_SUBQUERIES_IN_QUANTIFIEDS];
+ /// - return 9 (0b1001) => [SQL_SUBQUERIES_IN_COMPARISONS,
+ /// SQL_SUBQUERIES_IN_QUANTIFIEDS];
+ /// - return 10 (0b1010) => [SQL_SUBQUERIES_IN_EXISTS,
+ /// SQL_SUBQUERIES_IN_QUANTIFIEDS];
+ /// - return 11 (0b1011) => [SQL_SUBQUERIES_IN_COMPARISONS,
+ /// SQL_SUBQUERIES_IN_EXISTS,
+ /// SQL_SUBQUERIES_IN_QUANTIFIEDS];
+ /// - return 12 (0b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS];
+ /// - return 13 (0b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS,
+ /// SQL_SUBQUERIES_IN_QUANTIFIEDS];
+ /// - return 14 (0b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS,
+ /// SQL_SUBQUERIES_IN_QUANTIFIEDS];
+ /// - return 15 (0b1111) => [SQL_SUBQUERIES_IN_COMPARISONS,
+ /// SQL_SUBQUERIES_IN_EXISTS,
+ /// SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS];
+ /// - ...
+ ///
+ /// Valid SQL subqueries are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedSubqueries`.
SQL_SUPPORTED_SUBQUERIES = 538,
- /*
- * Retrieves a boolean value indicating whether correlated subqueries are supported.
- *
- * Returns:
- * - false: if correlated subqueries are unsupported;
- * - true: if correlated subqueries are supported.
- */
+ /// Retrieves a boolean value indicating whether correlated subqueries are
+ /// supported.
+ ///
+ /// Returns:
+ /// - false: if correlated subqueries are unsupported;
+ /// - true: if correlated subqueries are supported.
SQL_CORRELATED_SUBQUERIES_SUPPORTED = 539,
- /*
- * Retrieves the supported SQL UNIONs.
- *
- * Returns an int32 bitmask value representing the supported SQL UNIONs.
- * The returned bitmask should be parsed in order to retrieve the supported SQL
- * UNIONs.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported SQL positioned commands);
- * - return 1 (\b1) => [SQL_UNION];
- * - return 2 (\b10) => [SQL_UNION_ALL];
- * - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL].
- * Valid SQL positioned commands are described under
- * `arrow.flight.protocol.sql.SqlSupportedUnions`.
- */
+ /// Retrieves the supported SQL UNION features as an int32 bitmask
+ /// value.
+ ///
+ /// Returns an int32 bitmask value representing the supported SQL UNIONs.
+ /// The returned bitmask should be parsed in order to retrieve the supported SQL
+ /// UNIONs.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported SQL UNIONs);
+ /// - return 1 (0b1) => [SQL_UNION];
+ /// - return 2 (0b10) => [SQL_UNION_ALL];
+ /// - return 3 (0b11) => [SQL_UNION, SQL_UNION_ALL].
+ ///
+ /// Valid SQL union operators are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedUnions`.
SQL_SUPPORTED_UNIONS = 540,
- // Retrieves a uint32 value representing the maximum number of hex characters allowed
- // in an inline binary literal.
+ /// Retrieves an int64 value representing the maximum number of hex
+ /// characters allowed in an inline binary literal.
SQL_MAX_BINARY_LITERAL_LENGTH = 541,
- // Retrieves a uint32 value representing the maximum number of characters allowed for
- // a character literal.
+ /// Retrieves an int64 value representing the maximum number of characters
+ /// allowed for a character literal.
SQL_MAX_CHAR_LITERAL_LENGTH = 542,
- // Retrieves a uint32 value representing the maximum number of characters allowed for
- // a column name.
+ /// Retrieves an int64 value representing the maximum number of characters
+ /// allowed for a column name.
SQL_MAX_COLUMN_NAME_LENGTH = 543,
- // Retrieves a uint32 value representing the the maximum number of columns allowed in
- // a GROUP BY clause.
+ /// Retrieves an int64 value representing the maximum number of columns
+ /// allowed in a GROUP BY clause.
SQL_MAX_COLUMNS_IN_GROUP_BY = 544,
- // Retrieves a uint32 value representing the maximum number of columns allowed in an
- // index.
+ /// Retrieves an int64 value representing the maximum number of columns
+ /// allowed in an index.
SQL_MAX_COLUMNS_IN_INDEX = 545,
- // Retrieves a uint32 value representing the maximum number of columns allowed in an
- // ORDER BY clause.
+ /// Retrieves an int64 value representing the maximum number of columns
+ /// allowed in an ORDER BY clause.
SQL_MAX_COLUMNS_IN_ORDER_BY = 546,
- // Retrieves a uint32 value representing the maximum number of columns allowed in a
- // SELECT list.
+ /// Retrieves an int64 value representing the maximum number of columns
+ /// allowed in a SELECT list.
SQL_MAX_COLUMNS_IN_SELECT = 547,
- // Retrieves a uint32 value representing the maximum number of columns allowed in a
- // table.
+ /// Retrieves an int64 value representing the maximum number of columns
+ /// allowed in a table.
SQL_MAX_COLUMNS_IN_TABLE = 548,
- // Retrieves a uint32 value representing the maximum number of concurrent connections
- // possible.
+ /// Retrieves an int64 value representing the maximum number of concurrent
+ /// connections possible.
SQL_MAX_CONNECTIONS = 549,
- // Retrieves a uint32 value the maximum number of characters allowed in a cursor name.
+ /// Retrieves an int64 value representing the maximum number of characters
+ /// allowed in a cursor name.
SQL_MAX_CURSOR_NAME_LENGTH = 550,
- /*
- * Retrieves a uint32 value representing the maximum number of bytes allowed for an
- * index, including all of the parts of the index.
- */
+ /// Retrieves an int64 value representing the maximum number of bytes
+ /// allowed for an index, including all of the parts of the index.
SQL_MAX_INDEX_LENGTH = 551,
- // Retrieves a uint32 value representing the maximum number of characters allowed in a
- // procedure name.
+ /// Retrieves an int64 value representing the maximum number of characters
+ /// allowed in a schema name.
SQL_SCHEMA_NAME_LENGTH = 552,
- // Retrieves a uint32 value representing the maximum number of bytes allowed in a
- // single row.
+ /// Retrieves an int64 value representing the maximum number of characters
+ /// allowed in a procedure name.
SQL_MAX_PROCEDURE_NAME_LENGTH = 553,
- // Retrieves a uint32 value representing the maximum number of characters allowed in a
- // catalog name.
+ /// Retrieves an int64 value representing the maximum number of characters
+ /// allowed in a catalog name.
SQL_MAX_CATALOG_NAME_LENGTH = 554,
- // Retrieves a uint32 value representing the maximum number of bytes allowed in a
- // single row.
+ /// Retrieves an int64 value representing the maximum number of bytes
+ /// allowed in a single row.
SQL_MAX_ROW_SIZE = 555,
- /*
- * Retrieves a boolean indicating whether the return value for the JDBC method
- * getMaxRowSize includes the SQL data types LONGVARCHAR and LONGVARBINARY.
- *
- * Returns:
- * - false: if return value for the JDBC method getMaxRowSize does
- * not include the SQL data types LONGVARCHAR and LONGVARBINARY;
- * - true: if return value for the JDBC method getMaxRowSize includes
- * the SQL data types LONGVARCHAR and LONGVARBINARY.
- */
+ /// Retrieves a boolean indicating whether the return value for the JDBC
+ /// method getMaxRowSize includes the SQL data types LONGVARCHAR and
+ /// LONGVARBINARY.
+ ///
+ /// Returns:
+ /// - false: if return value for the JDBC method getMaxRowSize does
+ /// not include the SQL data types LONGVARCHAR and LONGVARBINARY;
+ /// - true: if return value for the JDBC method getMaxRowSize includes
+ /// the SQL data types LONGVARCHAR and LONGVARBINARY.
SQL_MAX_ROW_SIZE_INCLUDES_BLOBS = 556,
- /*
- * Retrieves a uint32 value representing the maximum number of characters allowed for
- * an SQL statement; a result of 0 (zero) means that there is no limit or the limit is
- * not known.
- */
+ /// Retrieves an int32 value representing the maximum number of characters
+ /// allowed for an SQL statement; a result of 0 (zero) means that there is
+ /// no limit or the limit is not known.
SQL_MAX_STATEMENT_LENGTH = 557,
- // Retrieves a uint32 value representing the maximum number of active statements that
- // can be open at the same time.
+ /// Retrieves an int32 value representing the maximum number of active
+ /// statements that can be open at the same time.
SQL_MAX_STATEMENTS = 558,
- // Retrieves a uint32 value representing the maximum number of characters allowed in a
- // table name.
+ /// Retrieves an int32 value representing the maximum number of characters
+ /// allowed in a table name.
SQL_MAX_TABLE_NAME_LENGTH = 559,
- // Retrieves a uint32 value representing the maximum number of tables allowed in a
- // SELECT statement.
+ /// Retrieves an int32 value representing the maximum number of tables
+ /// allowed in a SELECT statement.
SQL_MAX_TABLES_IN_SELECT = 560,
- // Retrieves a uint32 value representing the maximum number of characters allowed in a
- // user name.
+ /// Retrieves an int32 value representing the maximum number of characters
+ /// allowed in a user name.
SQL_MAX_USERNAME_LENGTH = 561,
- /*
- * Retrieves this database's default transaction isolation level as described in
- * `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`.
- *
- * Returns a uint32 ordinal for the SQL transaction isolation level.
- */
+ /// Retrieves this database's default transaction isolation level as
+ /// described in `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`.
+ ///
+ /// Returns an int32 ordinal for the SQL transaction isolation level.
SQL_DEFAULT_TRANSACTION_ISOLATION = 562,
- /*
- * Retrieves a boolean value indicating whether transactions are supported. If not,
- * invoking the method commit is a noop, and the isolation level is
- * `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`.
- *
- * Returns:
- * - false: if transactions are unsupported;
- * - true: if transactions are supported.
- */
+ /// Retrieves a boolean value indicating whether transactions are
+ /// supported. If not, invoking the method commit is a noop, and the
+ /// isolation level is
+ /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`.
+ ///
+ /// Returns:
+ /// - false: if transactions are unsupported;
+ /// - true: if transactions are supported.
SQL_TRANSACTIONS_SUPPORTED = 563,
- /*
- * Retrieves the supported transactions isolation levels.
- *
- * Returns an int32 bitmask value representing the supported transactions isolation
- * levels. The returned bitmask should be parsed in order to retrieve the supported
- * transactions isolation levels.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported SQL transactions isolation levels);
- * - return 1 (\b1) => [SQL_TRANSACTION_NONE];
- * - return 2 (\b10) => [SQL_TRANSACTION_READ_UNCOMMITTED];
- * - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED];
- * - return 4 (\b100) => [SQL_TRANSACTION_REPEATABLE_READ];
- * - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ];
- * - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED,
- * SQL_TRANSACTION_REPEATABLE_READ];
- * - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED,
- * SQL_TRANSACTION_REPEATABLE_READ];
- * - return 8 (\b1000) => [SQL_TRANSACTION_REPEATABLE_READ];
- * - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ];
- * - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED,
- * SQL_TRANSACTION_REPEATABLE_READ];
- * - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED,
- * SQL_TRANSACTION_REPEATABLE_READ];
- * - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ,
- * SQL_TRANSACTION_REPEATABLE_READ];
- * - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ,
- * SQL_TRANSACTION_REPEATABLE_READ];
- * - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED,
- * SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ];
- * - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED,
- * SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ];
- * - return 16 (\b10000) => [SQL_TRANSACTION_SERIALIZABLE];
- * - ...
- * Valid SQL positioned commands are described under
- * `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`.
- */
+ /// Retrieves the supported transaction isolation levels, if transactions
+ /// are supported.
+ ///
+ /// Returns an int32 bitmask value representing the supported transaction
+ /// isolation levels. The returned bitmask should be parsed in order to
+ /// retrieve the supported transaction isolation levels.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported SQL transaction isolation levels);
+ /// - return 1 (0b1) => [SQL_TRANSACTION_NONE];
+ /// - return 2 (0b10) => [SQL_TRANSACTION_READ_UNCOMMITTED];
+ /// - return 3 (0b11) => [SQL_TRANSACTION_NONE,
+ /// SQL_TRANSACTION_READ_UNCOMMITTED];
+ /// - return 4 (0b100) => [SQL_TRANSACTION_READ_COMMITTED];
+ /// - return 5 (0b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_COMMITTED];
+ /// - return 6 (0b110) => [SQL_TRANSACTION_READ_UNCOMMITTED,
+ /// SQL_TRANSACTION_READ_COMMITTED];
+ /// - return 7 (0b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED,
+ /// SQL_TRANSACTION_READ_COMMITTED];
+ /// - return 8 (0b1000) => [SQL_TRANSACTION_REPEATABLE_READ];
+ /// - return 9 (0b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ];
+ /// - return 10 (0b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED,
+ /// SQL_TRANSACTION_REPEATABLE_READ];
+ /// - return 11 (0b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED,
+ /// SQL_TRANSACTION_REPEATABLE_READ];
+ /// - return 12 (0b1100) => [SQL_TRANSACTION_READ_COMMITTED,
+ /// SQL_TRANSACTION_REPEATABLE_READ];
+ /// - return 13 (0b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_COMMITTED,
+ /// SQL_TRANSACTION_REPEATABLE_READ];
+ /// - return 14 (0b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED,
+ /// SQL_TRANSACTION_READ_COMMITTED, SQL_TRANSACTION_REPEATABLE_READ];
+ /// - return 15 (0b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED,
+ /// SQL_TRANSACTION_READ_COMMITTED, SQL_TRANSACTION_REPEATABLE_READ];
+ /// - return 16 (0b10000) => [SQL_TRANSACTION_SERIALIZABLE];
+ /// - ...
+ ///
+ /// Valid SQL transaction isolation levels are described under
+ /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`.
SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS = 564,
- /*
- * Retrieves a boolean value indicating whether a data definition statement within a
- * transaction forces the transaction to commit.
- *
- * Returns:
- * - false: if a data definition statement within a transaction does not force the
- * transaction to commit;
- * - true: if a data definition statement within a transaction forces the transaction
- * to commit.
- */
+ /// Retrieves a boolean value indicating whether a data definition
+ /// statement within a transaction forces the transaction to commit.
+ ///
+ /// Returns:
+ /// - false: if a data definition statement within a transaction does not force the
+ /// transaction to commit;
+ /// - true: if a data definition statement within a transaction forces the transaction
+ /// to commit.
SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT = 565,
- /*
- * Retrieves a boolean value indicating whether a data definition statement within a
- * transaction is ignored.
- *
- * Returns:
- * - false: if a data definition statement within a transaction is taken into account;
- * - true: a data definition statement within a transaction is ignored.
- */
+ /// Retrieves a boolean value indicating whether a data definition
+ /// statement within a transaction is ignored.
+ ///
+ /// Returns:
+ /// - false: if a data definition statement within a transaction is taken into
+ /// account;
+ /// - true: a data definition statement within a transaction is ignored.
SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED = 566,
- /*
- * Retrieves an int32 bitmask value representing the supported result set types.
- * The returned bitmask should be parsed in order to retrieve the supported result set
- * types.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported result set types);
- * - return 1 (\b1) => [SQL_RESULT_SET_TYPE_UNSPECIFIED];
- * - return 2 (\b10) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY];
- * - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED,
- * SQL_RESULT_SET_TYPE_FORWARD_ONLY];
- * - return 4 (\b100) => [SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE];
- * - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED,
- * SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE];
- * - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY,
- * SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE];
- * - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED,
- * SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE];
- * - return 8 (\b1000) => [SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE];
- * - ...
- * Valid result set types are described under
- * `arrow.flight.protocol.sql.SqlSupportedResultSetType`.
- */
+ /// Retrieves an int32 bitmask value representing the supported result set
+ /// types. The returned bitmask should be parsed in order to retrieve the
+ /// supported result set types.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported result set types);
+ /// - return 1 (0b1) => [SQL_RESULT_SET_TYPE_UNSPECIFIED];
+ /// - return 2 (0b10) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY];
+ /// - return 3 (0b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED,
+ /// SQL_RESULT_SET_TYPE_FORWARD_ONLY];
+ /// - return 4 (0b100) => [SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE];
+ /// - return 5 (0b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED,
+ /// SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE];
+ /// - return 6 (0b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY,
+ /// SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE];
+ /// - return 7 (0b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED,
+ /// SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE];
+ /// - return 8 (0b1000) => [SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE];
+ /// - ...
+ ///
+ /// Valid result set types are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedResultSetType`.
SQL_SUPPORTED_RESULT_SET_TYPES = 567,
- /*
- * Returns an int32 bitmask value concurrency types supported for
- * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported concurrency types for this result set type)
- * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED]
- * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
- * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
- * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY,
- * SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] Valid
- * result set types are described under
- * `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`.
- */
+ /// Returns an int32 bitmask value representing the concurrency types
+ /// supported by the server for
+ /// `SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported concurrency types for this result set type)
+ /// - return 1 (0b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED]
+ /// - return 2 (0b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
+ /// - return 3 (0b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
+ /// - return 4 (0b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 5 (0b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 6 (0b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY,
+ /// SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 7 (0b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ ///
+ /// Valid result set types are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`.
SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED = 568,
- /*
- * Returns an int32 bitmask value concurrency types supported for
- * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported concurrency types for this result set type)
- * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED]
- * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
- * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
- * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY,
- * SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] Valid
- * result set types are described under
- * `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`.
- */
+ /// Returns an int32 bitmask value representing the concurrency types
+ /// supported by the server for
+ /// `SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported concurrency types for this result set type)
+ /// - return 1 (0b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED]
+ /// - return 2 (0b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
+ /// - return 3 (0b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
+ /// - return 4 (0b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 5 (0b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 6 (0b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY,
+ /// SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 7 (0b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ ///
+ /// Valid result set types are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`.
SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY = 569,
- /*
- * Returns an int32 bitmask value concurrency types supported for
- * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported concurrency types for this result set type)
- * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED]
- * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
- * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
- * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY,
- * SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] Valid
- * result set types are described under
- * `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`.
- */
+ /// Returns an int32 bitmask value representing the concurrency types
+ /// supported by the server for
+ /// `SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported concurrency types for this result set type)
+ /// - return 1 (0b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED]
+ /// - return 2 (0b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
+ /// - return 3 (0b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
+ /// - return 4 (0b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 5 (0b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 6 (0b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY,
+ /// SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 7 (0b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ ///
+ /// Valid result set types are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`.
SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE = 570,
- /*
- * Returns an int32 bitmask value concurrency types supported for
- * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`.
- *
- * For instance:
- * - return 0 (\b0) => [] (no supported concurrency types for this result set type)
- * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED]
- * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
- * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
- * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY,
- * SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
- * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
- * SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] Valid
- * result set types are described under
- * `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`.
- */
+ /// Returns an int32 bitmask value representing concurrency types supported
+ /// by the server for
+ /// `SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`.
+ ///
+ /// For instance:
+ /// - return 0 (0b0) => [] (no supported concurrency types for this result set type)
+ /// - return 1 (0b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED]
+ /// - return 2 (0b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
+ /// - return 3 (0b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_READ_ONLY]
+ /// - return 4 (0b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 5 (0b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 6 (0b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY,
+ /// SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ /// - return 7 (0b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED,
+ /// SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE]
+ ///
+ /// Valid result set types are described under
+ /// `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`.
SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE = 571,
- /*
- * Retrieves a boolean value indicating whether this database supports batch updates.
- *
- * - false: if this database does not support batch updates;
- * - true: if this database supports batch updates.
- */
+ /// Retrieves a boolean value indicating whether this database supports batch updates.
+ ///
+ /// - false: if this database does not support batch updates;
+ /// - true: if this database supports batch updates.
SQL_BATCH_UPDATES_SUPPORTED = 572,
- /*
- * Retrieves a boolean value indicating whether this database supports savepoints.
- *
- * Returns:
- * - false: if this database does not support savepoints;
- * - true: if this database supports savepoints.
- */
+ /// Retrieves a boolean value indicating whether this database supports savepoints.
+ ///
+ /// Returns:
+ /// - false: if this database does not support savepoints;
+ /// - true: if this database supports savepoints.
SQL_SAVEPOINTS_SUPPORTED = 573,
- /*
- * Retrieves a boolean value indicating whether named parameters are supported in
- * callable statements.
- *
- * Returns:
- * - false: if named parameters in callable statements are unsupported;
- * - true: if named parameters in callable statements are supported.
- */
+ /// Retrieves a boolean value indicating whether named parameters are supported in
+ /// callable statements.
+ ///
+ /// Returns:
+ /// - false: if named parameters in callable statements are unsupported;
+ /// - true: if named parameters in callable statements are supported.
SQL_NAMED_PARAMETERS_SUPPORTED = 574,
- /*
- * Retrieves a boolean value indicating whether updates made to a LOB are made on a
- * copy or directly to the LOB.
- *
- * Returns:
- * - false: if updates made to a LOB are made directly to the LOB;
- * - true: if updates made to a LOB are made on a copy.
- */
+ /// Retrieves a boolean value indicating whether updates made to a LOB are made on a
+ /// copy or directly to the LOB.
+ ///
+ /// Returns:
+ /// - false: if updates made to a LOB are made directly to the LOB;
+ /// - true: if updates made to a LOB are made on a copy.
SQL_LOCATORS_UPDATE_COPY = 575,
- /*
- * Retrieves a boolean value indicating whether invoking user-defined or vendor
- * functions using the stored procedure escape syntax is supported.
- *
- * Returns:
- * - false: if invoking user-defined or vendor functions using the stored procedure
- * escape syntax is unsupported;
- * - true: if invoking user-defined or vendor functions using the stored procedure
- * escape syntax is supported.
- */
+ /// Retrieves a boolean value indicating whether invoking user-defined or vendor
+ /// functions using the stored procedure escape syntax is supported.
+ ///
+ /// Returns:
+ /// - false: if invoking user-defined or vendor functions using the stored procedure
+ /// escape syntax is unsupported;
+ /// - true: if invoking user-defined or vendor functions using the stored procedure
+ /// escape syntax is supported.
SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED = 576,
+
+ /// @}
};
+ /// Indicate whether something (e.g. an identifier) is case-sensitive.
enum SqlSupportedCaseSensitivity {
SQL_CASE_SENSITIVITY_UNKNOWN = 0,
SQL_CASE_SENSITIVITY_CASE_INSENSITIVE = 1,
SQL_CASE_SENSITIVITY_UPPERCASE = 2,
};
+ /// Indicate how nulls are sorted.
enum SqlNullOrdering {
SQL_NULLS_SORTED_HIGH = 0,
SQL_NULLS_SORTED_LOW = 1,
@@ -854,6 +809,7 @@ struct SqlInfoOptions {
SQL_NULLS_SORTED_AT_END = 3,
};
+ /// Type identifiers used to indicate support for converting between types.
enum SqlSupportsConvert {
SQL_CONVERT_BIGINT = 0,
SQL_CONVERT_BINARY = 1,
@@ -878,13 +834,18 @@ struct SqlInfoOptions {
};
};
-/// \brief Table reference, optionally containing table's catalog and db_schema.
+/// \brief A SQL %table reference, optionally containing table's catalog and db_schema.
struct TableRef {
+ /// \brief The table's catalog.
util::optional<std::string> catalog;
+ /// \brief The table's database schema.
util::optional<std::string> db_schema;
+ /// \brief The table name.
std::string table;
};
+/// @}
+
} // namespace sql
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc
index 5ead99f94b..1ec06a1f00 100644
--- a/cpp/src/arrow/flight/test_definitions.cc
+++ b/cpp/src/arrow/flight/test_definitions.cc
@@ -45,7 +45,7 @@ using arrow::internal::checked_cast;
void ConnectivityTest::TestGetPort() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
ASSERT_GT(server->port(), 0);
@@ -53,7 +53,7 @@ void ConnectivityTest::TestGetPort() {
void ConnectivityTest::TestBuilderHook() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
bool builder_hook_run = false;
options.builder_hook = [&builder_hook_run](void* builder) {
@@ -68,7 +68,7 @@ void ConnectivityTest::TestBuilderHook() {
void ConnectivityTest::TestShutdown() {
// Regression test for ARROW-15181
constexpr int kIterations = 10;
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
for (int i = 0; i < kIterations; i++) {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
@@ -84,7 +84,7 @@ void ConnectivityTest::TestShutdown() {
void ConnectivityTest::TestShutdownWithDeadline() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
ASSERT_GT(server->port(), 0);
@@ -96,20 +96,19 @@ void ConnectivityTest::TestShutdownWithDeadline() {
}
void ConnectivityTest::TestBrokenConnection() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
std::unique_ptr<FlightClient> client;
ASSERT_OK_AND_ASSIGN(location,
- Location::ForScheme(transport(), "localhost", server->port()));
- ASSERT_OK(FlightClient::Connect(location, &client));
+ Location::ForScheme(transport(), "127.0.0.1", server->port()));
+ ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(location));
ASSERT_OK(server->Shutdown());
ASSERT_OK(server->Wait());
- std::unique_ptr<FlightInfo> info;
- ASSERT_RAISES(IOError, client->GetFlightInfo(FlightDescriptor::Command(""), &info));
+ ASSERT_RAISES(IOError, client->GetFlightInfo(FlightDescriptor::Command("")));
}
//------------------------------------------------------------
@@ -118,7 +117,7 @@ void ConnectivityTest::TestBrokenConnection() {
void DataTest::SetUp() {
server_ = ExampleTestServer();
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server_->Init(options));
@@ -130,16 +129,16 @@ void DataTest::TearDown() {
}
Status DataTest::ConnectClient() {
ARROW_ASSIGN_OR_RAISE(auto location,
- Location::ForScheme(transport(), "localhost", server_->port()));
- return FlightClient::Connect(location, &client_);
+ Location::ForScheme(transport(), "127.0.0.1", server_->port()));
+ ARROW_ASSIGN_OR_RAISE(client_, FlightClient::Connect(location));
+ return Status::OK();
}
void DataTest::CheckDoGet(
const FlightDescriptor& descr, const RecordBatchVector& expected_batches,
std::function<void(const std::vector<FlightEndpoint>&)> check_endpoints) {
auto expected_schema = expected_batches[0]->schema();
- std::unique_ptr<FlightInfo> info;
- ASSERT_OK(client_->GetFlightInfo(descr, &info));
+ ASSERT_OK_AND_ASSIGN(auto info, client_->GetFlightInfo(descr));
check_endpoints(info->endpoints());
ipc::DictionaryMemo dict_memo;
@@ -155,11 +154,9 @@ void DataTest::CheckDoGet(const Ticket& ticket,
auto num_batches = static_cast<int>(expected_batches.size());
ASSERT_GE(num_batches, 2);
- std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
+ ASSERT_OK_AND_ASSIGN(auto stream, client_->DoGet(ticket));
- std::unique_ptr<FlightStreamReader> stream2;
- ASSERT_OK(client_->DoGet(ticket, &stream2));
+ ASSERT_OK_AND_ASSIGN(auto stream2, client_->DoGet(ticket));
ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2)));
std::shared_ptr<RecordBatch> batch;
@@ -247,7 +244,7 @@ void DataTest::TestOverflowServerBatch() {
// DoGet: check for overflow on large batch
Ticket ticket{"ARROW-13253-DoGet-Batch"};
std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
+ ASSERT_OK_AND_ASSIGN(stream, client_->DoGet(ticket));
FlightStreamChunk chunk;
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
@@ -256,35 +253,30 @@ void DataTest::TestOverflowServerBatch() {
{
// DoExchange: check for overflow on large batch from server
auto descr = FlightDescriptor::Command("large_batch");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_exchange_result, client_->DoExchange(descr));
RecordBatchVector batches;
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
- reader->ToRecordBatches().Value(&batches));
- ARROW_UNUSED(writer->Close());
+ do_exchange_result.reader->ToRecordBatches().Value(&batches));
+ ARROW_UNUSED(do_exchange_result.writer->Close());
}
}
void DataTest::TestOverflowClientBatch() {
ASSERT_OK_AND_ASSIGN(auto batch, VeryLargeBatch());
{
// DoPut: check for overflow on large batch
- std::unique_ptr<FlightStreamWriter> stream;
- std::unique_ptr<FlightMetadataReader> reader;
auto descr = FlightDescriptor::Path({""});
- ASSERT_OK(client_->DoPut(descr, batch->schema(), &stream, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(descr, batch->schema()));
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
- stream->WriteRecordBatch(*batch));
- ASSERT_OK(stream->Close());
+ do_put_result.writer->WriteRecordBatch(*batch));
+ ASSERT_OK(do_put_result.writer->Close());
}
{
// DoExchange: check for overflow on large batch from client
auto descr = FlightDescriptor::Command("counter");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ auto writer = std::move(exchange.writer);
ASSERT_OK(writer->Begin(batch->schema()));
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
@@ -298,9 +290,9 @@ void DataTest::TestDoExchange() {
auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]");
auto schema = arrow::schema({field("f1", a1->type())});
batches.push_back(RecordBatch::Make(schema, a1->length(), {a1}));
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamReader> reader = std::move(exchange.reader);
+ std::unique_ptr<FlightStreamWriter> writer = std::move(exchange.writer);
ASSERT_OK(writer->Begin(schema));
for (const auto& batch : batches) {
ASSERT_OK(writer->WriteRecordBatch(*batch));
@@ -322,9 +314,9 @@ void DataTest::TestDoExchange() {
// schema messages
void DataTest::TestDoExchangeNoData() {
auto descr = FlightDescriptor::Command("counter");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamReader> reader = std::move(exchange.reader);
+ std::unique_ptr<FlightStreamWriter> writer = std::move(exchange.writer);
ASSERT_OK(writer->DoneWriting());
ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next());
ASSERT_EQ(nullptr, chunk.data);
@@ -336,9 +328,9 @@ void DataTest::TestDoExchangeNoData() {
// in the client-side writer.
void DataTest::TestDoExchangeWriteOnlySchema() {
auto descr = FlightDescriptor::Command("counter");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamReader> reader = std::move(exchange.reader);
+ std::unique_ptr<FlightStreamWriter> writer = std::move(exchange.writer);
auto schema = arrow::schema({field("f1", arrow::int32())});
ASSERT_OK(writer->Begin(schema));
ASSERT_OK(writer->WriteMetadata(Buffer::FromString("foo")));
@@ -352,9 +344,9 @@ void DataTest::TestDoExchangeWriteOnlySchema() {
// Emulate DoGet
void DataTest::TestDoExchangeGet() {
auto descr = FlightDescriptor::Command("get");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamReader> reader = std::move(exchange.reader);
+ std::unique_ptr<FlightStreamWriter> writer = std::move(exchange.writer);
ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
AssertSchemaEqual(*ExampleIntSchema(), *server_schema);
RecordBatchVector batches;
@@ -372,9 +364,9 @@ void DataTest::TestDoExchangeGet() {
// Emulate DoPut
void DataTest::TestDoExchangePut() {
auto descr = FlightDescriptor::Command("put");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamReader> reader = std::move(exchange.reader);
+ std::unique_ptr<FlightStreamWriter> writer = std::move(exchange.writer);
ASSERT_OK(writer->Begin(ExampleIntSchema()));
RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
@@ -393,9 +385,9 @@ void DataTest::TestDoExchangePut() {
// Test the echo server
void DataTest::TestDoExchangeEcho() {
auto descr = FlightDescriptor::Command("echo");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamReader> reader = std::move(exchange.reader);
+ std::unique_ptr<FlightStreamWriter> writer = std::move(exchange.writer);
ASSERT_OK(writer->Begin(ExampleIntSchema()));
RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
@@ -445,7 +437,9 @@ void DataTest::TestDoExchangeTotal() {
// here.
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, ::testing::HasSubstr("Field is not INT64: f1"), ([&]() {
- RETURN_NOT_OK(client_->DoExchange(descr, &writer, &reader));
+ ARROW_ASSIGN_OR_RAISE(auto exchange, client_->DoExchange(descr));
+ reader = std::move(exchange.reader);
+ writer = std::move(exchange.writer);
RETURN_NOT_OK(writer->Begin(schema));
auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1});
RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
@@ -456,7 +450,9 @@ void DataTest::TestDoExchangeTotal() {
auto a1 = ArrayFromJSON(arrow::int64(), "[1, 2, null, 3]");
auto a2 = ArrayFromJSON(arrow::int64(), "[null, 4, 5, 6]");
auto schema = arrow::schema({field("f1", a1->type()), field("f2", a2->type())});
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ reader = std::move(exchange.reader);
+ writer = std::move(exchange.writer);
ASSERT_OK(writer->Begin(schema));
auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1, a2});
ASSERT_OK(writer->WriteRecordBatch(*batch));
@@ -484,22 +480,28 @@ void DataTest::TestDoExchangeTotal() {
// Ensure server errors get propagated no matter what we try
void DataTest::TestDoExchangeError() {
auto descr = FlightDescriptor::Command("error");
- std::unique_ptr<FlightStreamReader> reader;
std::unique_ptr<FlightStreamWriter> writer;
+ std::unique_ptr<FlightStreamReader> reader;
{
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ writer = std::move(exchange.writer);
+ reader = std::move(exchange.reader);
auto status = writer->Close();
EXPECT_RAISES_WITH_MESSAGE_THAT(
NotImplemented, ::testing::HasSubstr("Expected error"), writer->Close());
}
{
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ writer = std::move(exchange.writer);
+ reader = std::move(exchange.reader);
EXPECT_RAISES_WITH_MESSAGE_THAT(
NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next());
ARROW_UNUSED(writer->Close());
}
{
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ writer = std::move(exchange.writer);
+ reader = std::move(exchange.reader);
EXPECT_RAISES_WITH_MESSAGE_THAT(
NotImplemented, ::testing::HasSubstr("Expected error"), reader->GetSchema());
ARROW_UNUSED(writer->Close());
@@ -513,9 +515,9 @@ void DataTest::TestDoExchangeError() {
void DataTest::TestDoExchangeConcurrency() {
// Ensure that we can do reads/writes on separate threads
auto descr = FlightDescriptor::Command("echo");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamReader> reader = std::move(exchange.reader);
+ std::unique_ptr<FlightStreamWriter> writer = std::move(exchange.writer);
RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
@@ -546,9 +548,9 @@ void DataTest::TestDoExchangeUndrained() {
auto descr = FlightDescriptor::Command("TestUndrained");
auto schema = arrow::schema({arrow::field("ints", int64())});
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamReader> reader = std::move(exchange.reader);
+ std::unique_ptr<FlightStreamWriter> writer = std::move(exchange.writer);
auto batch = RecordBatchFromJSON(schema, "[[1], [2], [3], [4]]");
ASSERT_OK(writer->Begin(schema));
@@ -567,13 +569,12 @@ void DataTest::TestIssue5095() {
// Make sure the server-side error message is reflected to the
// client
Ticket ticket1{"ARROW-5095-fail"};
- std::unique_ptr<FlightStreamReader> stream;
- Status status = client_->DoGet(ticket1, &stream);
+ Status status = client_->DoGet(ticket1).status();
ASSERT_RAISES(UnknownError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("Server-side error"));
Ticket ticket2{"ARROW-5095-success"};
- status = client_->DoGet(ticket2, &stream);
+ status = client_->DoGet(ticket2).status();
ASSERT_RAISES(KeyError, status);
ASSERT_THAT(status.message(), ::testing::HasSubstr("No data"));
}
@@ -637,7 +638,7 @@ class DoPutTestServer : public FlightServerBase {
};
void DoPutTest::SetUp() {
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<DoPutTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
@@ -660,9 +661,9 @@ void DoPutTest::CheckBatches(const FlightDescriptor& expected_descriptor,
void DoPutTest::CheckDoPut(const FlightDescriptor& descr,
const std::shared_ptr<Schema>& schema,
const RecordBatchVector& batches) {
- std::unique_ptr<FlightStreamWriter> stream;
- std::unique_ptr<FlightMetadataReader> reader;
- ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(descr, schema));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_put_result.writer);
+ std::unique_ptr<FlightMetadataReader> reader = std::move(do_put_result.reader);
// Ensure that the reader can be used independently of the writer
std::thread reader_thread([&reader, &batches]() {
@@ -675,18 +676,18 @@ void DoPutTest::CheckDoPut(const FlightDescriptor& descr,
int64_t counter = 0;
for (const auto& batch : batches) {
if (counter % 2 == 0) {
- ASSERT_OK(stream->WriteRecordBatch(*batch));
+ ASSERT_OK(writer->WriteRecordBatch(*batch));
} else {
auto buffer = Buffer::FromString(std::to_string(counter));
- ASSERT_OK(stream->WriteWithMetadata(*batch, std::move(buffer)));
+ ASSERT_OK(writer->WriteWithMetadata(*batch, std::move(buffer)));
}
counter++;
}
// Write a metadata-only message
- ASSERT_OK(stream->WriteMetadata(Buffer::FromString(kExpectedMetadata)));
- ASSERT_OK(stream->DoneWriting());
+ ASSERT_OK(writer->WriteMetadata(Buffer::FromString(kExpectedMetadata)));
+ ASSERT_OK(writer->DoneWriting());
reader_thread.join();
- ASSERT_OK(stream->Close());
+ ASSERT_OK(writer->Close());
CheckBatches(descr, batches);
}
@@ -765,11 +766,10 @@ void DoPutTest::TestLargeBatch() {
void DoPutTest::TestSizeLimit() {
const int64_t size_limit = 4096;
ASSERT_OK_AND_ASSIGN(auto location,
- Location::ForScheme(transport(), "localhost", server_->port()));
+ Location::ForScheme(transport(), "127.0.0.1", server_->port()));
auto client_options = FlightClientOptions::Defaults();
client_options.write_size_limit_bytes = size_limit;
- std::unique_ptr<FlightClient> client;
- ASSERT_OK(FlightClient::Connect(location, client_options, &client));
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location, client_options));
auto descr = FlightDescriptor::Command("simple");
// Batch is too large to fit in one message
@@ -778,12 +778,12 @@ void DoPutTest::TestSizeLimit() {
auto batch1 = batch->Slice(0, 384);
auto batch2 = batch->Slice(384);
- std::unique_ptr<FlightStreamWriter> stream;
- std::unique_ptr<FlightMetadataReader> reader;
- ASSERT_OK(client->DoPut(descr, schema, &stream, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client->DoPut(descr, schema));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_put_result.writer);
+ std::unique_ptr<FlightMetadataReader> reader = std::move(do_put_result.reader);
// Large batch will exceed the limit
- const auto status = stream->WriteRecordBatch(*batch);
+ const auto status = writer->WriteRecordBatch(*batch);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("exceeded soft limit"),
status);
auto detail = FlightWriteSizeStatusDetail::UnwrapStatus(status);
@@ -792,14 +792,14 @@ void DoPutTest::TestSizeLimit() {
ASSERT_GT(detail->actual(), size_limit);
// But we can retry with smaller batches
- ASSERT_OK(stream->WriteRecordBatch(*batch1));
- ASSERT_OK(stream->WriteWithMetadata(*batch2, Buffer::FromString("1")));
+ ASSERT_OK(writer->WriteRecordBatch(*batch1));
+ ASSERT_OK(writer->WriteWithMetadata(*batch2, Buffer::FromString("1")));
// Write a metadata-only message
- ASSERT_OK(stream->WriteMetadata(Buffer::FromString(kExpectedMetadata)));
+ ASSERT_OK(writer->WriteMetadata(Buffer::FromString(kExpectedMetadata)));
- ASSERT_OK(stream->DoneWriting());
- ASSERT_OK(stream->Close());
+ ASSERT_OK(writer->DoneWriting());
+ ASSERT_OK(writer->Close());
CheckBatches(descr, {batch1, batch2});
}
void DoPutTest::TestUndrained() {
@@ -808,17 +808,17 @@ void DoPutTest::TestUndrained() {
auto descr = FlightDescriptor::Command("TestUndrained");
auto schema = arrow::schema({arrow::field("ints", int64())});
- std::unique_ptr<FlightStreamWriter> stream;
- std::unique_ptr<FlightMetadataReader> reader;
- ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(descr, schema));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_put_result.writer);
+ std::unique_ptr<FlightMetadataReader> reader = std::move(do_put_result.reader);
auto batch = RecordBatchFromJSON(schema, "[[1], [2], [3], [4]]");
// These calls may or may not fail depending on how quickly the
// transport reacts, whether it batches, writes, etc.
- ARROW_UNUSED(stream->WriteRecordBatch(*batch));
- ARROW_UNUSED(stream->WriteRecordBatch(*batch));
- ARROW_UNUSED(stream->WriteRecordBatch(*batch));
- ARROW_UNUSED(stream->WriteRecordBatch(*batch));
- ASSERT_OK(stream->Close());
+ ARROW_UNUSED(writer->WriteRecordBatch(*batch));
+ ARROW_UNUSED(writer->WriteRecordBatch(*batch));
+ ARROW_UNUSED(writer->WriteRecordBatch(*batch));
+ ARROW_UNUSED(writer->WriteRecordBatch(*batch));
+ ASSERT_OK(writer->Close());
// We should be able to make another call
CheckDoPut(FlightDescriptor::Command("foo"), schema, {batch, batch});
@@ -866,7 +866,7 @@ Status AppMetadataTestServer::DoPut(const ServerCallContext& context,
}
void AppMetadataTest::SetUp() {
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<AppMetadataTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
@@ -878,8 +878,8 @@ void AppMetadataTest::TearDown() {
}
void AppMetadataTest::TestDoGet() {
Ticket ticket{""};
- std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<FlightStreamReader> stream,
+ client_->DoGet(ticket));
RecordBatchVector expected_batches;
ASSERT_OK(ExampleIntBatches(&expected_batches));
@@ -901,8 +901,8 @@ void AppMetadataTest::TestDoGet() {
// from the record batch, and not one of the dictionary batches.
void AppMetadataTest::TestDoGetDictionaries() {
Ticket ticket{"dicts"};
- std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(ticket, &stream));
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr<FlightStreamReader> stream,
+ client_->DoGet(ticket));
RecordBatchVector expected_batches;
ASSERT_OK(ExampleDictBatches(&expected_batches));
@@ -919,10 +919,9 @@ void AppMetadataTest::TestDoGetDictionaries() {
ASSERT_EQ(nullptr, chunk.data);
}
void AppMetadataTest::TestDoPut() {
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightMetadataReader> reader;
std::shared_ptr<Schema> schema = ExampleIntSchema();
- ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(FlightDescriptor{}, schema));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_put_result.writer);
RecordBatchVector expected_batches;
ASSERT_OK(ExampleIntBatches(&expected_batches));
@@ -943,8 +942,6 @@ void AppMetadataTest::TestDoPut() {
// Test DoPut() with dictionaries. This tests a corner case in the
// server-side reader; see DoGetDictionaries above.
void AppMetadataTest::TestDoPutDictionaries() {
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightMetadataReader> reader;
RecordBatchVector expected_batches;
ASSERT_OK(ExampleDictBatches(&expected_batches));
// ARROW-8749: don't get the schema via ExampleDictSchema because
@@ -953,8 +950,10 @@ void AppMetadataTest::TestDoPutDictionaries() {
// (identity-wise) than the schema of the first batch we write,
// we'll end up generating a duplicate set of dictionaries that
// confuses the reader.
- ASSERT_OK(client_->DoPut(FlightDescriptor{}, expected_batches[0]->schema(), &writer,
- &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result,
+ client_->DoPut(FlightDescriptor{}, expected_batches[0]->schema()));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_put_result.writer);
+
std::shared_ptr<RecordBatch> chunk;
std::shared_ptr<Buffer> metadata;
auto num_batches = static_cast<int>(expected_batches.size());
@@ -965,10 +964,10 @@ void AppMetadataTest::TestDoPutDictionaries() {
ASSERT_OK(writer->Close());
}
void AppMetadataTest::TestDoPutReadMetadata() {
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightMetadataReader> reader;
std::shared_ptr<Schema> schema = ExampleIntSchema();
- ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(FlightDescriptor{}, schema));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_put_result.writer);
+ std::unique_ptr<FlightMetadataReader> reader = std::move(do_put_result.reader);
RecordBatchVector expected_batches;
ASSERT_OK(ExampleIntBatches(&expected_batches));
@@ -1046,7 +1045,7 @@ class IpcOptionsTestServer : public FlightServerBase {
};
void IpcOptionsTest::SetUp() {
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<IpcOptionsTestServer>(
location, &server_, &client_,
[](FlightServerOptions* options) { return Status::OK(); },
@@ -1062,7 +1061,7 @@ void IpcOptionsTest::TestDoGetReadOptions() {
auto options = FlightCallOptions();
options.read_options.max_recursion_depth = 1;
std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(options, ticket, &stream));
+ ASSERT_OK_AND_ASSIGN(stream, client_->DoGet(options, ticket));
ASSERT_RAISES(Invalid, stream->Next());
}
void IpcOptionsTest::TestDoPutWriteOptions() {
@@ -1074,10 +1073,10 @@ void IpcOptionsTest::TestDoPutWriteOptions() {
auto options = FlightCallOptions();
options.write_options.max_recursion_depth = 1;
- ASSERT_OK(client_->DoPut(options, FlightDescriptor{}, expected_batches[0]->schema(),
- &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result, client_->DoPut(options, FlightDescriptor{},
+ expected_batches[0]->schema()));
for (const auto& batch : expected_batches) {
- ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*batch));
+ ASSERT_RAISES(Invalid, do_put_result.writer->WriteRecordBatch(*batch));
}
}
void IpcOptionsTest::TestDoExchangeClientWriteOptions() {
@@ -1086,9 +1085,8 @@ void IpcOptionsTest::TestDoExchangeClientWriteOptions() {
auto options = FlightCallOptions();
options.write_options.max_recursion_depth = 1;
auto descr = FlightDescriptor::Command("");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(options, descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_exchange_result, client_->DoExchange(options, descr));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_exchange_result.writer);
RecordBatchVector batches;
ASSERT_OK(ExampleNestedBatches(&batches));
ASSERT_OK(writer->Begin(batches[0]->schema()));
@@ -1103,9 +1101,8 @@ void IpcOptionsTest::TestDoExchangeClientWriteOptionsBegin() {
// fail the call. Here the options are set explicitly when we write data and not in the
// call options.
auto descr = FlightDescriptor::Command("");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_exchange_result, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_exchange_result.writer);
RecordBatchVector batches;
ASSERT_OK(ExampleNestedBatches(&batches));
auto options = ipc::IpcWriteOptions::Defaults();
@@ -1121,9 +1118,8 @@ void IpcOptionsTest::TestDoExchangeServerWriteOptions() {
// Call DoExchange and write nested data, but with a very low nesting depth set to fail
// the call. (The low nesting depth is set on the server side.)
auto descr = FlightDescriptor::Command("");
- std::unique_ptr<FlightStreamReader> reader;
- std::unique_ptr<FlightStreamWriter> writer;
- ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_exchange_result, client_->DoExchange(descr));
+ std::unique_ptr<FlightStreamWriter> writer = std::move(do_exchange_result.writer);
RecordBatchVector batches;
ASSERT_OK(ExampleNestedBatches(&batches));
ASSERT_OK(writer->Begin(batches[0]->schema()));
@@ -1245,7 +1241,7 @@ void CudaDataTest::SetUp() {
impl_->device = std::move(device);
impl_->context = std::move(context);
- ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "localhost", 0));
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
ASSERT_OK(MakeServer<CudaTestServer>(
location, &server_, &client_,
[this](FlightServerOptions* options) {
@@ -1269,8 +1265,7 @@ void CudaDataTest::TestDoGet() {
checked_cast<CudaTestServer*>(server_.get())->batches();
Ticket ticket{""};
- std::unique_ptr<FlightStreamReader> stream;
- ASSERT_OK(client_->DoGet(options, ticket, &stream));
+ ASSERT_OK_AND_ASSIGN(auto stream, client_->DoGet(options, ticket));
size_t idx = 0;
while (true) {
@@ -1294,10 +1289,9 @@ void CudaDataTest::TestDoPut() {
RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightMetadataReader> reader;
auto descriptor = FlightDescriptor::Path({""});
- ASSERT_OK(client_->DoPut(descriptor, batches[0]->schema(), &writer, &reader));
+ ASSERT_OK_AND_ASSIGN(auto do_put_result,
+ client_->DoPut(descriptor, batches[0]->schema()));
ipc::DictionaryMemo memo;
for (const auto& batch : batches) {
@@ -1307,9 +1301,9 @@ void CudaDataTest::TestDoPut() {
cuda::ReadRecordBatch(batch->schema(), &memo, buffer));
ASSERT_OK(CheckBuffersOnDevice(*cuda_batch, *impl_->device));
- ASSERT_OK(writer->WriteRecordBatch(*cuda_batch));
+ ASSERT_OK(do_put_result.writer->WriteRecordBatch(*cuda_batch));
}
- ASSERT_OK(writer->Close());
+ ASSERT_OK(do_put_result.writer->Close());
ASSERT_OK(impl_->context->Synchronize());
const RecordBatchVector& written =
@@ -1332,11 +1326,9 @@ void CudaDataTest::TestDoExchange() {
RecordBatchVector batches;
ASSERT_OK(ExampleIntBatches(&batches));
- std::unique_ptr<FlightStreamWriter> writer;
- std::unique_ptr<FlightStreamReader> reader;
auto descriptor = FlightDescriptor::Path({""});
- ASSERT_OK(client_->DoExchange(options, descriptor, &writer, &reader));
- ASSERT_OK(writer->Begin(batches[0]->schema()));
+ ASSERT_OK_AND_ASSIGN(auto exchange, client_->DoExchange(options, descriptor));
+ ASSERT_OK(exchange.writer->Begin(batches[0]->schema()));
ipc::DictionaryMemo write_memo;
ipc::DictionaryMemo read_memo;
@@ -1347,16 +1339,16 @@ void CudaDataTest::TestDoExchange() {
cuda::ReadRecordBatch(batch->schema(), &write_memo, buffer));
ASSERT_OK(CheckBuffersOnDevice(*cuda_batch, *impl_->device));
- ASSERT_OK(writer->WriteRecordBatch(*cuda_batch));
+ ASSERT_OK(exchange.writer->WriteRecordBatch(*cuda_batch));
- ASSERT_OK_AND_ASSIGN(auto chunk, reader->Next());
+ ASSERT_OK_AND_ASSIGN(auto chunk, exchange.reader->Next());
ASSERT_OK(CheckBuffersOnDevice(*chunk.data, *impl_->device));
// Bounce record batch back to host memory
ASSERT_OK_AND_ASSIGN(auto host_batch, CopyBatchToHost(*chunk.data));
AssertBatchesEqual(*batch, *host_batch);
}
- ASSERT_OK(writer->Close());
+ ASSERT_OK(exchange.writer->Close());
}
#else
diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h
index e44ca001d2..d5b774b4a3 100644
--- a/cpp/src/arrow/flight/test_util.h
+++ b/cpp/src/arrow/flight/test_util.h
@@ -113,11 +113,11 @@ Status MakeServer(const Location& location, std::unique_ptr<FlightServerBase>* s
RETURN_NOT_OK(make_server_options(&server_options));
RETURN_NOT_OK((*server)->Init(server_options));
std::string uri =
- location.scheme() + "://localhost:" + std::to_string((*server)->port());
+ location.scheme() + "://127.0.0.1:" + std::to_string((*server)->port());
ARROW_ASSIGN_OR_RAISE(auto real_location, Location::Parse(uri));
FlightClientOptions client_options = FlightClientOptions::Defaults();
RETURN_NOT_OK(make_client_options(&client_options));
- return FlightClient::Connect(real_location, client_options, client);
+ return FlightClient::Connect(real_location, client_options).Value(client);
}
// Helper to initialize a server and matching client with callbacks to
diff --git a/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt
new file mode 100644
index 0000000000..6e315b68d6
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/CMakeLists.txt
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+add_custom_target(arrow_flight_transport_ucx)
+arrow_install_all_headers("arrow/flight/transport/ucx")
+
+find_package(PkgConfig REQUIRED)
+pkg_check_modules(UCX REQUIRED IMPORTED_TARGET ucx)
+
+set(ARROW_FLIGHT_TRANSPORT_UCX_SRCS
+ ucx_client.cc
+ ucx_server.cc
+ ucx.cc
+ ucx_internal.cc
+ util_internal.cc)
+set(ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS)
+
+include_directories(SYSTEM ${UCX_INCLUDE_DIRS})
+list(APPEND ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS PkgConfig::UCX)
+
+add_arrow_lib(arrow_flight_transport_ucx
+ # CMAKE_PACKAGE_NAME
+ # ArrowFlightTransportUcx
+ # PKG_CONFIG_NAME
+ # arrow-flight-transport-ucx
+ SOURCES
+ ${ARROW_FLIGHT_TRANSPORT_UCX_SRCS}
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:arrow/flight/transport/ucx/pch.h>"
+ DEPENDENCIES
+ SHARED_LINK_FLAGS
+ ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
+ SHARED_LINK_LIBS
+ arrow_shared
+ arrow_flight_shared
+ ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS}
+ STATIC_LINK_LIBS
+ arrow_static
+ arrow_flight_static
+ ${ARROW_FLIGHT_TRANSPORT_UCX_LINK_LIBS})
+
+if(ARROW_BUILD_TESTS)
+ if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
+ set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS
+ arrow_static
+ arrow_flight_static
+ arrow_flight_testing_static
+ arrow_flight_transport_ucx_static
+ ${ARROW_TEST_LINK_LIBS})
+ else()
+ set(ARROW_FLIGHT_UCX_TEST_LINK_LIBS
+ arrow_shared
+ arrow_flight_shared
+ arrow_flight_testing_shared
+ arrow_flight_transport_ucx_shared
+ ${ARROW_TEST_LINK_LIBS})
+ endif()
+ add_arrow_test(flight_transport_ucx_test
+ STATIC_LINK_LIBS
+ ${ARROW_FLIGHT_UCX_TEST_LINK_LIBS}
+ LABELS
+ "arrow_flight")
+endif()
diff --git a/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
new file mode 100644
index 0000000000..6a580af92f
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/flight_transport_ucx_test.cc
@@ -0,0 +1,386 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "arrow/array/array_base.h"
+#include "arrow/flight/test_definitions.h"
+#include "arrow/flight/test_util.h"
+#include "arrow/flight/transport/ucx/ucx.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/config.h"
+
+#ifdef UCP_API_VERSION
+#error "UCX headers should not be in public API"
+#endif
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#ifdef ARROW_CUDA
+#include "arrow/gpu/cuda_api.h"
+#endif
+
+namespace arrow {
+namespace flight {
+
+class UcxEnvironment : public ::testing::Environment {
+ public:
+ void SetUp() override { transport::ucx::InitializeFlightUcx(); }
+};
+
+testing::Environment* const kUcxEnvironment =
+ testing::AddGlobalTestEnvironment(new UcxEnvironment());
+
+//------------------------------------------------------------
+// Common transport tests
+
+class UcxConnectivityTest : public ConnectivityTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_CONNECTIVITY(UcxConnectivityTest);
+
+class UcxDataTest : public DataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_DATA(UcxDataTest);
+
+class UcxDoPutTest : public DoPutTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_DO_PUT(UcxDoPutTest);
+
+class UcxAppMetadataTest : public AppMetadataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_APP_METADATA(UcxAppMetadataTest);
+
+class UcxIpcOptionsTest : public IpcOptionsTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_IPC_OPTIONS(UcxIpcOptionsTest);
+
+class UcxCudaDataTest : public CudaDataTest {
+ protected:
+ std::string transport() const override { return "ucx"; }
+};
+ARROW_FLIGHT_TEST_CUDA_DATA(UcxCudaDataTest);
+
+//------------------------------------------------------------
+// UCX internals tests
+
+constexpr std::initializer_list<StatusCode> kStatusCodes = {
+ StatusCode::OK,
+ StatusCode::OutOfMemory,
+ StatusCode::KeyError,
+ StatusCode::TypeError,
+ StatusCode::Invalid,
+ StatusCode::IOError,
+ StatusCode::CapacityError,
+ StatusCode::IndexError,
+ StatusCode::Cancelled,
+ StatusCode::UnknownError,
+ StatusCode::NotImplemented,
+ StatusCode::SerializationError,
+ StatusCode::RError,
+ StatusCode::CodeGenError,
+ StatusCode::ExpressionValidationError,
+ StatusCode::ExecutionError,
+ StatusCode::AlreadyExists,
+};
+
+constexpr std::initializer_list<FlightStatusCode> kFlightStatusCodes = {
+ FlightStatusCode::Internal, FlightStatusCode::TimedOut,
+ FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated,
+ FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable,
+ FlightStatusCode::Failed,
+};
+
+class TestStatusDetail : public StatusDetail {
+ public:
+ const char* type_id() const override { return "test-status-detail"; }
+ std::string ToString() const override { return "Custom status detail"; }
+};
+
+namespace transport {
+namespace ucx {
+
+static constexpr std::initializer_list<FrameType> kFrameTypes = {
+ FrameType::kHeaders, FrameType::kBuffer, FrameType::kPayloadHeader,
+ FrameType::kPayloadBody, FrameType::kDisconnect,
+};
+
+TEST(FrameHeader, Basics) {
+ for (const auto frame_type : kFrameTypes) {
+ FrameHeader header;
+ ASSERT_OK(header.Set(frame_type, /*counter=*/42, /*body_size=*/65535));
+ if (frame_type == FrameType::kDisconnect) {
+ ASSERT_RAISES(Cancelled, Frame::ParseHeader(header.data(), header.size()));
+ } else {
+ ASSERT_OK_AND_ASSIGN(auto frame, Frame::ParseHeader(header.data(), header.size()));
+ ASSERT_EQ(frame->type, frame_type);
+ ASSERT_EQ(frame->counter, 42);
+ ASSERT_EQ(frame->size, 65535);
+ }
+ }
+}
+
+TEST(FrameHeader, FrameType) {
+ for (const auto frame_type : kFrameTypes) {
+ ASSERT_LE(static_cast<int>(frame_type), static_cast<int>(FrameType::kMaxFrameType));
+ }
+}
+
+TEST(HeadersFrame, Parse) {
+ const char* data =
+ ("\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x03x-foobar"
+ "\x00\x00\x00\x05\x00\x00\x00\x01x-bin\x01");
+ constexpr int64_t size = 34;
+
+ {
+ std::unique_ptr<Buffer> buffer(
+ new Buffer(reinterpret_cast<const uint8_t*>(data), size));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Parse(std::move(buffer)));
+ ASSERT_OK_AND_ASSIGN(auto foo, headers.Get("x-foo"));
+ ASSERT_EQ(foo, "bar");
+ ASSERT_OK_AND_ASSIGN(auto bin, headers.Get("x-bin"));
+ ASSERT_EQ(bin, "\x01");
+ }
+ {
+ std::unique_ptr<Buffer> buffer(new Buffer(reinterpret_cast<const uint8_t*>(data), 3));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected number of headers"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr<Buffer> buffer(new Buffer(reinterpret_cast<const uint8_t*>(data), 7));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected length of key 1"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr<Buffer> buffer(
+ new Buffer(reinterpret_cast<const uint8_t*>(data), 10));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("expected length of value 1"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr<Buffer> buffer(
+ new Buffer(reinterpret_cast<const uint8_t*>(data), 12));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr("expected key 1 to have length 5, but only 0 bytes remain"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+ {
+ std::unique_ptr<Buffer> buffer(
+ new Buffer(reinterpret_cast<const uint8_t*>(data), 17));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ ::testing::HasSubstr(
+ "expected value 1 to have length 3, but only 0 bytes remain"),
+ HeadersFrame::Parse(std::move(buffer)));
+ }
+}
+
+TEST(HeadersFrame, RoundTripStatus) {
+ for (const auto code : kStatusCodes) {
+ {
+ Status expected = code == StatusCode::OK ? Status() : Status(code, "foo");
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+
+ if (code == StatusCode::OK) continue;
+
+ // Attach a generic status detail
+ {
+ auto detail = std::make_shared<TestStatusDetail>();
+ Status original(code, "foo", detail);
+ Status expected(code, "foo",
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal,
+ detail->ToString()));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+
+ // Attach a Flight status detail
+ for (const auto flight_code : kFlightStatusCodes) {
+ Status expected(code, "foo",
+ std::make_shared<FlightStatusDetail>(flight_code, "extra"));
+ ASSERT_OK_AND_ASSIGN(auto headers, HeadersFrame::Make(expected, {}));
+ Status status;
+ ASSERT_OK(headers.GetStatus(&status));
+ ASSERT_EQ(status, expected);
+ }
+ }
+}
+} // namespace ucx
+} // namespace transport
+
+//------------------------------------------------------------
+// Ad-hoc UCX-specific tests
+
+class SimpleTestServer : public FlightServerBase {
+ public:
+ Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
+ std::unique_ptr<FlightInfo>* info) override {
+ if (request.path.size() > 0 && request.path[0] == "error") {
+ return status_;
+ }
+ auto examples = ExampleFlightInfo();
+ info->reset(new FlightInfo(examples[0]));
+ return Status::OK();
+ }
+
+ Status DoGet(const ServerCallContext& context, const Ticket& request,
+ std::unique_ptr<FlightDataStream>* data_stream) override {
+ RecordBatchVector batches;
+ RETURN_NOT_OK(ExampleIntBatches(&batches));
+ auto batch_reader = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
+ *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
+ return Status::OK();
+ }
+
+ void set_error_status(Status st) { status_ = std::move(st); }
+
+ private:
+ Status status_;
+};
+
+class TestUcx : public ::testing::Test {
+ public:
+ void SetUp() {
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "127.0.0.1", 0));
+ ASSERT_OK(MakeServer<SimpleTestServer>(
+ location, &server_, &client_,
+ [](FlightServerOptions* options) { return Status::OK(); },
+ [](FlightClientOptions* options) { return Status::OK(); }));
+ }
+
+ void TearDown() {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
+
+ protected:
+ std::unique_ptr<FlightClient> client_;
+ std::unique_ptr<FlightServerBase> server_;
+};
+
+TEST_F(TestUcx, GetFlightInfo) {
+ auto descriptor = FlightDescriptor::Path({"foo", "bar"});
+ std::unique_ptr<FlightInfo> info;
+ ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor));
+ // Test that we can reuse the connection
+ ASSERT_OK_AND_ASSIGN(info, client_->GetFlightInfo(descriptor));
+}
+
+TEST_F(TestUcx, SequentialClients) {
+ ASSERT_OK_AND_ASSIGN(
+ auto client2,
+ FlightClient::Connect(server_->location(), FlightClientOptions::Defaults()));
+
+ Ticket ticket{"a"};
+
+ ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket));
+ ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable());
+
+ ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket));
+ ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable());
+
+ AssertTablesEqual(*table1, *table2);
+}
+
+TEST_F(TestUcx, ConcurrentClients) {
+ ASSERT_OK_AND_ASSIGN(
+ auto client2,
+ FlightClient::Connect(server_->location(), FlightClientOptions::Defaults()));
+
+ Ticket ticket{"a"};
+
+ ASSERT_OK_AND_ASSIGN(auto stream1, client_->DoGet(ticket));
+ ASSERT_OK_AND_ASSIGN(auto stream2, client2->DoGet(ticket));
+
+ ASSERT_OK_AND_ASSIGN(auto table1, stream1->ToTable());
+ ASSERT_OK_AND_ASSIGN(auto table2, stream2->ToTable());
+
+ AssertTablesEqual(*table1, *table2);
+}
+
+TEST_F(TestUcx, Errors) {
+ auto descriptor = FlightDescriptor::Path({"error", "bar"});
+ auto* server = reinterpret_cast<SimpleTestServer*>(server_.get());
+ for (const auto code : kStatusCodes) {
+ if (code == StatusCode::OK) continue;
+
+ Status expected(code, "Error message");
+ server->set_error_status(expected);
+ Status actual = client_->GetFlightInfo(descriptor).status();
+ ASSERT_EQ(actual, expected);
+
+ // Attach a generic status detail
+ {
+ auto detail = std::make_shared<TestStatusDetail>();
+ server->set_error_status(Status(code, "foo", detail));
+ Status expected(code, "foo",
+ std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal,
+ detail->ToString()));
+ Status actual = client_->GetFlightInfo(descriptor).status();
+ ASSERT_EQ(actual, expected);
+ }
+
+ // Attach a Flight status detail
+ for (const auto flight_code : kFlightStatusCodes) {
+ Status expected(code, "Error message",
+ std::make_shared<FlightStatusDetail>(flight_code, "extra"));
+ server->set_error_status(expected);
+ Status actual = client_->GetFlightInfo(descriptor).status();
+ ASSERT_EQ(actual, expected);
+ }
+ }
+}
+
+TEST(TestUcxIpV6, DISABLED_IpV6Port) {
+ // Also, disabled in CI as machines lack an IPv6 interface
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme("ucx", "[::1]", 0));
+
+ std::unique_ptr<FlightServerBase> server(new SimpleTestServer());
+ FlightServerOptions server_options(location);
+ ASSERT_OK(server->Init(server_options));
+
+ FlightClientOptions client_options = FlightClientOptions::Defaults();
+ ASSERT_OK_AND_ASSIGN(auto client,
+ FlightClient::Connect(server->location(), client_options));
+
+ auto descriptor = FlightDescriptor::Path({"foo", "bar"});
+ ASSERT_OK_AND_ASSIGN(auto info, client->GetFlightInfo(descriptor));
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.cc b/cpp/src/arrow/flight/transport/ucx/ucx.cc
new file mode 100644
index 0000000000..0e3daf6021
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx.cc
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx.h"
+
+#include <mutex>
+
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+#include "arrow/flight/transport_server.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+namespace {
+std::once_flag kInitializeOnce;
+}
+void InitializeFlightUcx() {
+ std::call_once(kInitializeOnce, []() {
+ auto* registry = flight::internal::GetDefaultTransportRegistry();
+ DCHECK_OK(registry->RegisterClient("ucx", MakeUcxClientImpl));
+ DCHECK_OK(registry->RegisterServer("ucx", MakeUcxServerImpl));
+ });
+}
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx.h b/cpp/src/arrow/flight/transport/ucx/ucx.h
new file mode 100644
index 0000000000..dda2c83035
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx.h
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Experimental UCX-based transport for Flight.
+
+#pragma once
+
+#include "arrow/flight/visibility.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+ARROW_FLIGHT_EXPORT
+void InitializeFlightUcx();
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
new file mode 100644
index 0000000000..173132062e
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc
@@ -0,0 +1,733 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// The client-side implementation of a UCX-based transport for
+/// Flight.
+///
+/// Each UCX driver is used to support one call at a time. This gives
+/// the greatest throughput for data plane methods, but is relatively
+/// expensive in terms of other resources, both for the server and the
+/// client. (UCX drivers have multiple threading modes: single-thread
+/// access, serialized access, and multi-thread access. Testing found
+/// that multi-thread access incurred high synchronization costs.)
+/// Hence, for concurrent calls in a single client, we must maintain
+/// multiple drivers, and so unlike gRPC, there is no real difference
+/// between using one client concurrently and using multiple
+/// independent clients.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+#include <thread>
+
+#include <arpa/inet.h>
+#include <ucp/api/ucp.h>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/client.h"
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+namespace {
+class UcxClientImpl;
+
+Status MergeStatuses(Status server_status, Status transport_status) {
+ if (server_status.ok()) {
+ if (transport_status.ok()) return server_status;
+ return transport_status;
+ } else if (transport_status.ok()) {
+ return server_status;
+ }
+ return Status::FromDetailAndArgs(server_status.code(), server_status.detail(),
+ server_status.message(),
+ ". Transport context: ", transport_status.ToString());
+}
+
+/// \brief An individual connection to the server.
+class ClientConnection {
+ public:
+ ClientConnection() = default;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(ClientConnection);
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(ClientConnection);
+ ~ClientConnection() { DCHECK(!driver_) << "Connection was not closed!"; }
+
+ Status Init(std::shared_ptr<UcpContext> ucp_context, const arrow::internal::Uri& uri) {
+ auto status = InitImpl(std::move(ucp_context), uri);
+ // Clean up after-the-fact if we fail to initialize
+ if (!status.ok()) {
+ if (driver_) {
+ status = MergeStatuses(std::move(status), driver_->Close());
+ driver_.reset();
+ remote_endpoint_ = nullptr;
+ }
+ if (ucp_worker_) ucp_worker_.reset();
+ }
+ return status;
+ }
+
+ Status InitImpl(std::shared_ptr<UcpContext> ucp_context,
+ const arrow::internal::Uri& uri) {
+ {
+ ucs_status_t status;
+ ucp_worker_params_t worker_params;
+ std::memset(&worker_params, 0, sizeof(worker_params));
+ worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
+ worker_params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
+
+ ucp_worker_h ucp_worker;
+ status = ucp_worker_create(ucp_context->get(), &worker_params, &ucp_worker);
+ RETURN_NOT_OK(FromUcsStatus("ucp_worker_create", status));
+ ucp_worker_.reset(new UcpWorker(std::move(ucp_context), ucp_worker));
+ }
+ {
+ // Create endpoint for remote worker
+ struct sockaddr_storage connect_addr;
+ ARROW_ASSIGN_OR_RAISE(auto addrlen, UriToSockaddr(uri, &connect_addr));
+ std::string peer;
+ ARROW_UNUSED(SockaddrToString(connect_addr).Value(&peer));
+ ARROW_LOG(DEBUG) << "Connecting to " << peer;
+
+ ucp_ep_params_t params;
+ params.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_NAME |
+ UCP_EP_PARAM_FIELD_SOCK_ADDR;
+ params.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER;
+ params.name = "UcxClientImpl";
+ params.sockaddr.addr = reinterpret_cast<const sockaddr*>(&connect_addr);
+ params.sockaddr.addrlen = addrlen;
+
+ auto status = ucp_ep_create(ucp_worker_->get(), ¶ms, &remote_endpoint_);
+ RETURN_NOT_OK(FromUcsStatus("ucp_ep_create", status));
+ }
+
+ driver_.reset(new UcpCallDriver(ucp_worker_, remote_endpoint_));
+ ARROW_LOG(DEBUG) << "Connected to " << driver_->peer();
+
+ {
+ // Set up Active Message (AM) handler
+ ucp_am_handler_param_t handler_params;
+ handler_params.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID |
+ UCP_AM_HANDLER_PARAM_FIELD_CB |
+ UCP_AM_HANDLER_PARAM_FIELD_ARG;
+ handler_params.id = kUcpAmHandlerId;
+ handler_params.cb = HandleIncomingActiveMessage;
+ handler_params.arg = driver_.get();
+ ucs_status_t status =
+ ucp_worker_set_am_recv_handler(ucp_worker_->get(), &handler_params);
+ RETURN_NOT_OK(FromUcsStatus("ucp_worker_set_am_recv_handler", status));
+ }
+
+ return Status::OK();
+ }
+
+ Status Close() {
+ if (!driver_) return Status::OK();
+
+ auto status = driver_->SendFrame(FrameType::kDisconnect, nullptr, 0);
+ const auto ucs_status = FlightUcxStatusDetail::Unwrap(status);
+ if (IsIgnorableDisconnectError(ucs_status)) {
+ status = Status::OK();
+ }
+ status = MergeStatuses(std::move(status), driver_->Close());
+
+ driver_.reset();
+ remote_endpoint_ = nullptr;
+ ucp_worker_.reset();
+ return status;
+ }
+
+ UcpCallDriver* driver() {
+ DCHECK(driver_);
+ return driver_.get();
+ }
+
+ private:
+ static ucs_status_t HandleIncomingActiveMessage(void* self, const void* header,
+ size_t header_length, void* data,
+ size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ auto* driver = reinterpret_cast<UcpCallDriver*>(self);
+ return driver->RecvActiveMessage(header, header_length, data, data_length, param);
+ }
+
+ std::shared_ptr<UcpWorker> ucp_worker_;
+ ucp_ep_h remote_endpoint_;
+ std::unique_ptr<UcpCallDriver> driver_;
+};
+
+class UcxClientStream : public internal::ClientDataStream {
+ public:
+ UcxClientStream(UcxClientImpl* impl, ClientConnection conn)
+ : impl_(impl),
+ conn_(std::move(conn)),
+ driver_(conn_.driver()),
+ writes_done_(false),
+ finished_(false) {}
+
+ protected:
+ Status DoFinish() override;
+
+ UcxClientImpl* impl_;
+ ClientConnection conn_;
+ UcpCallDriver* driver_;
+ bool writes_done_;
+ bool finished_;
+ Status io_status_;
+ Status server_status_;
+};
+
+// Read-only stream backing DoGet.  The write side is closed from the
+// start; each ReadData() pulls one payload-header frame (plus an optional
+// body frame) until the server sends its trailing headers frame.
+class GetClientStream : public UcxClientStream {
+ public:
+  GetClientStream(UcxClientImpl* impl, ClientConnection conn)
+      : UcxClientStream(impl, std::move(conn)) {
+    writes_done_ = true;
+  }
+
+  // Returns false on end-of-stream or error; the error (if any) is
+  // recorded in io_status_ and surfaced by DoFinish().
+  bool ReadData(internal::FlightData* data) override {
+    if (finished_) return false;
+
+    bool success = true;
+    io_status_ = ReadImpl(data).Value(&success);
+
+    if (!io_status_.ok() || !success) {
+      finished_ = true;
+    }
+    return success;
+  }
+
+ private:
+  // Reads one logical message.  Result is false when the server's
+  // trailers were received (normal end of stream).
+  ::arrow::Result<bool> ReadImpl(internal::FlightData* data) {
+    ARROW_ASSIGN_OR_RAISE(auto frame, driver_->ReadNextFrame());
+
+    if (frame->type == FrameType::kHeaders) {
+      // Trailers, stream is over
+      ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Parse(std::move(frame->buffer)));
+      RETURN_NOT_OK(headers.GetStatus(&server_status_));
+      return false;
+    }
+
+    RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadHeader));
+    PayloadHeaderFrame payload_header(std::move(frame->buffer));
+    RETURN_NOT_OK(payload_header.ToFlightData(data));
+
+    // DoGet does not support metadata-only messages, so we can always
+    // assume we have an IPC payload
+    ARROW_ASSIGN_OR_RAISE(auto message, ipc::Message::Open(data->metadata, nullptr));
+
+    if (ipc::Message::HasBody(message->type())) {
+      // Record batches etc. carry a separate body frame.
+      ARROW_ASSIGN_OR_RAISE(frame, driver_->ReadNextFrame());
+      RETURN_NOT_OK(driver_->ExpectFrameType(*frame, FrameType::kPayloadBody));
+      data->body = std::move(frame->buffer);
+    }
+    return true;
+  }
+};
+
+// Base for client streams that write to the server (DoPut, DoExchange).
+// Flight's API allows concurrent reads and writes, but the UCX driver is
+// single-threaded, so all UCX work is funneled through one worker thread
+// (DriveWorker).  Public methods post a future (incoming_/outgoing_) under
+// driver_mutex_, wake the worker via working_cv_, and block on
+// completed_cv_ until the worker reports completion.
+class WriteClientStream : public UcxClientStream {
+ public:
+  WriteClientStream(UcxClientImpl* impl, ClientConnection conn)
+      : UcxClientStream(impl, std::move(conn)) {
+    std::thread t(&WriteClientStream::DriveWorker, this);
+    driver_thread_.swap(t);
+  }
+  // Send one Flight payload; blocks until the worker drives it to
+  // completion.  Fails once the stream is finished or writes are done.
+  arrow::Result<bool> WriteData(const FlightPayload& payload) override {
+    std::unique_lock<std::mutex> guard(driver_mutex_);
+    if (finished_ || writes_done_) return Status::Invalid("Already done writing");
+    outgoing_ = driver_->SendFlightPayload(payload);
+    working_cv_.notify_all();
+    completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
+
+    auto status = outgoing_.status();
+    outgoing_ = Future<>();  // clear so the worker goes back to sleep
+    RETURN_NOT_OK(status);
+    return true;
+  }
+  // Half-close the write side by sending an empty headers frame.
+  // Idempotent: repeated calls are no-ops.
+  Status WritesDone() override {
+    std::unique_lock<std::mutex> guard(driver_mutex_);
+    if (!writes_done_) {
+      ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make({}));
+      outgoing_ =
+          driver_->SendFrameAsync(FrameType::kHeaders, std::move(headers).GetBuffer());
+      working_cv_.notify_all();
+      completed_cv_.wait(guard, [this] { return outgoing_.is_finished(); });
+
+      writes_done_ = true;
+      auto status = outgoing_.status();
+      outgoing_ = Future<>();
+      RETURN_NOT_OK(status);
+    }
+    return Status::OK();
+  }
+
+ protected:
+  // Join the worker thread; ignore join errors (e.g. already joined)
+  // that can occur when the stream is torn down more than once.
+  void JoinThread() {
+    try {
+      driver_thread_.join();
+    } catch (const std::system_error&) {
+      // Ignore
+    }
+  }
+  // Flight's API allows concurrent reads/writes, but the UCX driver
+  // here is single-threaded, so push all UCX work onto a single
+  // worker thread
+  void DriveWorker() {
+    while (true) {
+      {
+        std::unique_lock<std::mutex> guard(driver_mutex_);
+        // Sleep until a caller posts a pending read or write.
+        working_cv_.wait(guard,
+                         [this] { return incoming_.is_valid() || outgoing_.is_valid(); });
+      }
+
+      // Pump the driver until the posted future completes.  The lock is
+      // re-taken each iteration so callers can observe progress.
+      while (true) {
+        std::unique_lock<std::mutex> guard(driver_mutex_);
+        if (!incoming_.is_valid() && !outgoing_.is_valid()) break;
+        if (incoming_.is_valid() && incoming_.is_finished()) {
+          if (!incoming_.status().ok()) {
+            io_status_ = incoming_.status();
+            finished_ = true;
+          } else {
+            // Subclass interprets the frame (lock is held here).
+            HandleIncomingMessage(*incoming_.result());
+          }
+          incoming_ = Future<std::shared_ptr<Frame>>();
+          completed_cv_.notify_all();
+          break;
+        }
+        if (outgoing_.is_valid() && outgoing_.is_finished()) {
+          // Caller clears outgoing_ after waking up.
+          completed_cv_.notify_all();
+          break;
+        }
+        driver_->MakeProgress();
+      }
+      if (finished_) return;
+    }
+  }
+
+  // Called with driver_mutex_ held when a posted read completes.
+  virtual void HandleIncomingMessage(const std::shared_ptr<Frame>& frame) {}
+
+  std::mutex driver_mutex_;
+  std::thread driver_thread_;
+  std::condition_variable completed_cv_;     // worker -> callers: work done
+  std::condition_variable working_cv_;       // callers -> worker: work posted
+  Future<std::shared_ptr<Frame>> incoming_;  // pending read, if any
+  Future<> outgoing_;                        // pending write, if any
+};
+
+// Stream backing DoPut: writes payloads via WriteClientStream and reads
+// back the server's app-metadata acknowledgements.
+class PutClientStream : public WriteClientStream {
+ public:
+  using WriteClientStream::WriteClientStream;
+  // Blocks for the next metadata message from the server; returns false
+  // (with *out = nullptr) at end of stream, joining the worker thread.
+  bool ReadPutMetadata(std::shared_ptr<Buffer>* out) override {
+    std::unique_lock<std::mutex> guard(driver_mutex_);
+    if (finished_) {
+      *out = nullptr;
+      // Unlock before joining so the worker can finish its last iteration.
+      guard.unlock();
+      JoinThread();
+      return false;
+    }
+    next_metadata_ = nullptr;
+    incoming_ = driver_->ReadFrameAsync();
+    working_cv_.notify_all();
+    completed_cv_.wait(guard, [this] { return next_metadata_ != nullptr || finished_; });
+
+    if (finished_) {
+      *out = nullptr;
+      guard.unlock();
+      JoinThread();
+      return false;
+    }
+    *out = std::move(next_metadata_);
+    return true;
+  }
+
+ private:
+  void HandleIncomingMessage(const std::shared_ptr<Frame>& frame) override {
+    // No lock here, since this is called from DriveWorker() which is
+    // holding the lock
+    if (frame->type == FrameType::kBuffer) {
+      next_metadata_ = std::move(frame->buffer);
+    } else if (frame->type == FrameType::kHeaders) {
+      // Trailers, stream is over
+      finished_ = true;
+      HeadersFrame headers;
+      io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers);
+      if (!io_status_.ok()) {
+        finished_ = true;  // already set above; kept for symmetry
+        return;
+      }
+      io_status_ = headers.GetStatus(&server_status_);
+      if (!io_status_.ok()) {
+        finished_ = true;
+        return;
+      }
+    } else {
+      finished_ = true;
+      io_status_ =
+          Status::IOError("Unexpected frame type ", static_cast<int>(frame->type));
+    }
+  }
+  // Metadata delivered by the worker thread, consumed by ReadPutMetadata.
+  std::shared_ptr<Buffer> next_metadata_;
+};
+
+// Stream backing DoExchange: full bidirectional reads and writes.  Reads
+// are driven as a small state machine (header frame, then an optional
+// body frame) executed on the worker thread.
+class ExchangeClientStream : public WriteClientStream {
+ public:
+  ExchangeClientStream(UcxClientImpl* impl, ClientConnection conn)
+      : WriteClientStream(impl, std::move(conn)), read_state_(ReadState::kFinished) {}
+
+  // Blocks for the next data message; returns false at end of stream or
+  // on error (recorded in io_status_), joining the worker thread.
+  bool ReadData(internal::FlightData* data) override {
+    std::unique_lock<std::mutex> guard(driver_mutex_);
+    if (finished_) {
+      guard.unlock();
+      JoinThread();
+      return false;
+    }
+
+    // Drive the read loop here. (We can't recursively call
+    // ReadFrameAsync below since the internal mutex is not
+    // recursive.)
+    read_state_ = ReadState::kExpectHeader;
+    incoming_ = driver_->ReadFrameAsync();
+    working_cv_.notify_all();
+    completed_cv_.wait(guard, [this] { return read_state_ != ReadState::kExpectHeader; });
+    if (read_state_ != ReadState::kFinished) {
+      // The header said a body frame follows; post a second read.
+      incoming_ = driver_->ReadFrameAsync();
+      working_cv_.notify_all();
+      completed_cv_.wait(guard, [this] { return read_state_ == ReadState::kFinished; });
+    }
+
+    if (finished_) {
+      guard.unlock();
+      JoinThread();
+      return false;
+    }
+    *data = std::move(next_data_);
+    return true;
+  }
+
+ private:
+  // What frame (if any) the reader currently expects from the server.
+  enum class ReadState {
+    kFinished,
+    kExpectHeader,
+    kExpectBody,
+  };
+
+  // Human-readable form of read_state_ for error messages.
+  std::string DebugExpectingString() {
+    switch (read_state_) {
+      case ReadState::kFinished:
+        return "(not expecting a frame)";
+      case ReadState::kExpectHeader:
+        return "payload header frame";
+      case ReadState::kExpectBody:
+        return "payload body frame";
+    }
+    return "(unknown or invalid state)";
+  }
+
+  void HandleIncomingMessage(const std::shared_ptr<Frame>& frame) override {
+    // No lock here, since this is called from MakeProgress()
+    // which is called under the lock already
+    if (frame->type == FrameType::kPayloadHeader) {
+      if (read_state_ != ReadState::kExpectHeader) {
+        finished_ = true;
+        io_status_ = Status::IOError("Got unexpected payload header frame, expected: ",
+                                     DebugExpectingString());
+        return;
+      }
+
+      PayloadHeaderFrame payload_header(std::move(frame->buffer));
+      io_status_ = payload_header.ToFlightData(&next_data_);
+      if (!io_status_.ok()) {
+        finished_ = true;
+        return;
+      }
+
+      // Inspect the IPC metadata (if present) to decide whether a body
+      // frame follows; metadata-only messages complete immediately.
+      if (next_data_.metadata) {
+        std::unique_ptr<ipc::Message> message;
+        io_status_ = ipc::Message::Open(next_data_.metadata, nullptr).Value(&message);
+        if (!io_status_.ok()) {
+          finished_ = true;
+          return;
+        }
+        if (ipc::Message::HasBody(message->type())) {
+          read_state_ = ReadState::kExpectBody;
+          return;
+        }
+      }
+      read_state_ = ReadState::kFinished;
+    } else if (frame->type == FrameType::kPayloadBody) {
+      next_data_.body = std::move(frame->buffer);
+      read_state_ = ReadState::kFinished;
+    } else if (frame->type == FrameType::kHeaders) {
+      // Trailers, stream is over
+      finished_ = true;
+      read_state_ = ReadState::kFinished;
+      HeadersFrame headers;
+      io_status_ = HeadersFrame::Parse(std::move(frame->buffer)).Value(&headers);
+      if (!io_status_.ok()) {
+        finished_ = true;
+        return;
+      }
+      io_status_ = headers.GetStatus(&server_status_);
+      if (!io_status_.ok()) {
+        finished_ = true;
+        return;
+      }
+    } else {
+      finished_ = true;
+      io_status_ =
+          Status::IOError("Unexpected frame type ", static_cast<int>(frame->type));
+      read_state_ = ReadState::kFinished;
+    }
+  }
+
+  internal::FlightData next_data_;  // message assembled by the worker
+  ReadState read_state_;            // guarded by driver_mutex_
+};
+
+// UCX implementation of the Flight client transport.  Maintains a small
+// pool of connections (each with its own UCP worker/endpoint); each RPC
+// checks a connection out of the pool, runs the call, and returns it.
+//
+// BUG FIX: Close(), CheckoutConnection() and ReturnConnection() previously
+// declared `std::unique_lock<std::mutex> connections_mutex_;` -- a
+// default-constructed lock owning NO mutex that merely shadowed the member
+// -- so the pool was never actually synchronized.  They now acquire the
+// member mutex with std::lock_guard.
+class UcxClientImpl : public arrow::flight::internal::ClientTransport {
+ public:
+  UcxClientImpl() {}
+
+  virtual ~UcxClientImpl() {
+    if (!ucp_context_) return;
+    auto status = Close();
+    if (!status.ok()) {
+      ARROW_LOG(WARNING) << "UcxClientImpl errored in Close() in destructor: "
+                         << status.ToString();
+    }
+  }
+
+  // Set up the UCP context (adjusting address-family priority for IPv6
+  // locations) and establish the first pooled connection.
+  Status Init(const FlightClientOptions& options, const Location& location,
+              const arrow::internal::Uri& uri) override {
+    RETURN_NOT_OK(uri_.Parse(uri.ToString()));
+    {
+      ucp_config_t* ucp_config;
+      ucp_params_t ucp_params;
+      ucs_status_t status;
+
+      status = ucp_config_read(nullptr, nullptr, &ucp_config);
+      RETURN_NOT_OK(FromUcsStatus("ucp_config_read", status));
+
+      // If location is IPv6, must adjust UCX config
+      // XXX: we assume locations always resolve to IPv6 or IPv4 but
+      // that is not necessarily true.
+      {
+        struct sockaddr_storage connect_addr;
+        RETURN_NOT_OK(UriToSockaddr(uri, &connect_addr));
+        if (connect_addr.ss_family == AF_INET6) {
+          status = ucp_config_modify(ucp_config, "AF_PRIO", "inet6");
+          RETURN_NOT_OK(FromUcsStatus("ucp_config_modify", status));
+        }
+      }
+
+      std::memset(&ucp_params, 0, sizeof(ucp_params));
+      ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES;
+      ucp_params.features = UCP_FEATURE_AM | UCP_FEATURE_WAKEUP;
+
+      ucp_context_h ucp_context;
+      status = ucp_init(&ucp_params, ucp_config, &ucp_context);
+      ucp_config_release(ucp_config);
+      RETURN_NOT_OK(FromUcsStatus("ucp_init", status));
+      ucp_context_.reset(new UcpContext(ucp_context));
+    }
+
+    RETURN_NOT_OK(MakeConnection());
+    return Status::OK();
+  }
+
+  // Close every pooled connection.  Stops on the first error.
+  Status Close() override {
+    // Was: `std::unique_lock<std::mutex> connections_mutex_;` (locked nothing).
+    std::lock_guard<std::mutex> guard(connections_mutex_);
+    while (!connections_.empty()) {
+      RETURN_NOT_OK(connections_.front().Close());
+      connections_.pop_front();
+    }
+    return Status::OK();
+  }
+
+  Status GetFlightInfo(const FlightCallOptions& options,
+                       const FlightDescriptor& descriptor,
+                       std::unique_ptr<FlightInfo>* info) override {
+    ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+    UcpCallDriver* driver = connection.driver();
+
+    auto impl = [&]() {
+      RETURN_NOT_OK(driver->StartCall(kMethodGetFlightInfo));
+
+      ARROW_ASSIGN_OR_RAISE(std::string payload, descriptor.SerializeToString());
+
+      RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer,
+                                      reinterpret_cast<const uint8_t*>(payload.data()),
+                                      static_cast<int64_t>(payload.size())));
+
+      // Optional FlightInfo buffer frame, then mandatory trailing headers.
+      ARROW_ASSIGN_OR_RAISE(auto incoming_message, driver->ReadNextFrame());
+      if (incoming_message->type == FrameType::kBuffer) {
+        ARROW_ASSIGN_OR_RAISE(
+            *info, FlightInfo::Deserialize(util::string_view(*incoming_message->buffer)));
+        ARROW_ASSIGN_OR_RAISE(incoming_message, driver->ReadNextFrame());
+      }
+      RETURN_NOT_OK(driver->ExpectFrameType(*incoming_message, FrameType::kHeaders));
+      ARROW_ASSIGN_OR_RAISE(auto headers,
+                            HeadersFrame::Parse(std::move(incoming_message->buffer)));
+      Status status;
+      RETURN_NOT_OK(headers.GetStatus(&status));
+      return status;
+    };
+    auto status = impl();
+    return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+  }
+
+  Status DoExchange(const FlightCallOptions& options,
+                    std::unique_ptr<internal::ClientDataStream>* out) override {
+    ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+    UcpCallDriver* driver = connection.driver();
+
+    auto status = driver->StartCall(kMethodDoExchange);
+    if (ARROW_PREDICT_TRUE(status.ok())) {
+      // The stream takes ownership of the connection and returns it
+      // to the pool in DoFinish().
+      *out =
+          arrow::internal::make_unique<ExchangeClientStream>(this, std::move(connection));
+      return Status::OK();
+    }
+    return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+  }
+
+  Status DoGet(const FlightCallOptions& options, const Ticket& ticket,
+               std::unique_ptr<internal::ClientDataStream>* stream) override {
+    ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+    UcpCallDriver* driver = connection.driver();
+
+    auto impl = [&]() {
+      RETURN_NOT_OK(driver->StartCall(kMethodDoGet));
+      ARROW_ASSIGN_OR_RAISE(std::string payload, ticket.SerializeToString());
+      RETURN_NOT_OK(driver->SendFrame(FrameType::kBuffer,
+                                      reinterpret_cast<const uint8_t*>(payload.data()),
+                                      static_cast<int64_t>(payload.size())));
+      *stream =
+          arrow::internal::make_unique<GetClientStream>(this, std::move(connection));
+      return Status::OK();
+    };
+
+    auto status = impl();
+    if (ARROW_PREDICT_TRUE(status.ok())) return status;
+    return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+  }
+
+  Status DoPut(const FlightCallOptions& options,
+               std::unique_ptr<internal::ClientDataStream>* out) override {
+    ARROW_ASSIGN_OR_RAISE(auto connection, CheckoutConnection(options));
+    UcpCallDriver* driver = connection.driver();
+
+    auto status = driver->StartCall(kMethodDoPut);
+    if (ARROW_PREDICT_TRUE(status.ok())) {
+      *out = arrow::internal::make_unique<PutClientStream>(this, std::move(connection));
+      return Status::OK();
+    }
+    return MergeStatuses(std::move(status), ReturnConnection(std::move(connection)));
+  }
+
+  Status DoAction(const FlightCallOptions& options, const Action& action,
+                  std::unique_ptr<ResultStream>* results) override {
+    // XXX: fake this for now to get the perf test to work
+    return Status::OK();
+  }
+
+  // Open a new connection and add it to the pool.  Not synchronized:
+  // callers are Init() (single-threaded) and CheckoutConnection()
+  // (which already holds connections_mutex_).
+  Status MakeConnection() {
+    ClientConnection conn;
+    RETURN_NOT_OK(conn.Init(ucp_context_, uri_));
+    connections_.push_back(std::move(conn));
+    return Status::OK();
+  }
+
+  // Take a connection out of the pool (creating one if the pool is
+  // empty) and configure it for this call's options.
+  arrow::Result<ClientConnection> CheckoutConnection(const FlightCallOptions& options) {
+    // Was: `std::unique_lock<std::mutex> connections_mutex_;` (locked nothing).
+    std::lock_guard<std::mutex> guard(connections_mutex_);
+    if (connections_.empty()) RETURN_NOT_OK(MakeConnection());
+    ClientConnection conn = std::move(connections_.front());
+    conn.driver()->set_memory_manager(options.memory_manager);
+    conn.driver()->set_read_memory_pool(options.read_options.memory_pool);
+    conn.driver()->set_write_memory_pool(options.write_options.memory_pool);
+    connections_.pop_front();
+    return conn;
+  }
+
+  // Return a connection to the pool, closing it instead if the pool is
+  // already at capacity.
+  Status ReturnConnection(ClientConnection conn) {
+    // Was: `std::unique_lock<std::mutex> connections_mutex_;` (locked nothing).
+    std::lock_guard<std::mutex> guard(connections_mutex_);
+    // TODO(ARROW-16127): for future improvement: reclaim clients
+    // asynchronously in the background (try to avoid issues like
+    // constantly opening/closing clients because the application is
+    // just barely over the limit of open connections)
+    if (connections_.size() >= kMaxOpenConnections) {
+      RETURN_NOT_OK(conn.Close());
+      return Status::OK();
+    }
+    connections_.push_back(std::move(conn));
+    return Status::OK();
+  }
+
+ private:
+  static constexpr size_t kMaxOpenConnections = 3;
+
+  arrow::internal::Uri uri_;
+  std::shared_ptr<UcpContext> ucp_context_;
+  std::mutex connections_mutex_;  // guards connections_
+  std::deque<ClientConnection> connections_;
+};
+
+// Terminate the stream: half-close writes, drain any remaining messages
+// so the trailers (and server status) are observed, then hand the
+// connection back to the client pool.  Returns the combined server/IO
+// status for the call.
+Status UcxClientStream::DoFinish() {
+  RETURN_NOT_OK(WritesDone());
+  if (!finished_) {
+    internal::FlightData message;
+    std::shared_ptr<Buffer> metadata;
+    // Drain both channels; the base-class defaults make the inapplicable
+    // one a no-op for each stream kind.
+    while (ReadData(&message)) {
+    }
+    while (ReadPutMetadata(&metadata)) {
+    }
+    finished_ = true;
+  }
+  if (impl_) {
+    auto status = impl_->ReturnConnection(std::move(conn_));
+    impl_ = nullptr;
+    driver_ = nullptr;
+    if (!status.ok()) {
+      if (io_status_.ok()) {
+        io_status_ = std::move(status);
+      } else {
+        // Keep the original error but append the connection-return
+        // failure as context.
+        io_status_ = Status::FromDetailAndArgs(
+            io_status_.code(), io_status_.detail(), io_status_.message(),
+            ". Transport context: ", status.ToString());
+      }
+    }
+  }
+  return MergeStatuses(server_status_, io_status_);
+}
+} // namespace
+
+// Factory used by the Flight client machinery to instantiate the UCX
+// transport (the concrete UcxClientImpl lives in an anonymous namespace).
+std::unique_ptr<arrow::flight::internal::ClientTransport> MakeUcxClientImpl() {
+  return arrow::internal::make_unique<UcxClientImpl>();
+}
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
new file mode 100644
index 0000000000..ab4cc323f4
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
@@ -0,0 +1,1171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include <array>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+// Defines to test different implementation strategies
+// Enable the CONTIG path for CPU-only data
+// #define ARROW_FLIGHT_UCX_SEND_CONTIG
+// Enable ucp_mem_map in IOV path
+// #define ARROW_FLIGHT_UCX_SEND_IOV_MAP
+
+constexpr char kHeaderMethod[] = ":method:";
+
+namespace {
+// Encode a non-negative int64 length as a big-endian uint32 at `out`,
+// failing if it is negative or does not fit in 32 bits.
+Status SizeToUInt32BytesBe(const int64_t in, uint8_t* out) {
+  if (ARROW_PREDICT_FALSE(in < 0)) {
+    return Status::Invalid("Length cannot be negative");
+  } else if (ARROW_PREDICT_FALSE(
+                 in > static_cast<int64_t>(std::numeric_limits<uint32_t>::max()))) {
+    return Status::Invalid("Length cannot exceed uint32_t");
+  }
+  UInt32ToBytesBe(static_cast<uint32_t>(in), out);
+  return Status::OK();
+}
+// Map a Buffer's device to a UCS memory type for ucp_mem_map.
+// NOTE(review): treats every non-CPU buffer as CUDA device memory --
+// confirm this is intended for other device kinds.  For CPU buffers,
+// UNKNOWN lets UCX detect the type itself.
+ucs_memory_type InferMemoryType(const Buffer& buffer) {
+  if (!buffer.is_cpu()) {
+    return UCS_MEMORY_TYPE_CUDA;
+  }
+  return UCS_MEMORY_TYPE_UNKNOWN;
+}
+// Best-effort registration of a memory region with UCX (enables zero-copy
+// transfers).  On failure *memh_p is set to nullptr and a warning is
+// logged; callers proceed without the mapping.
+void TryMapBuffer(ucp_context_h context, const void* buffer, const size_t size,
+                  ucs_memory_type memory_type, ucp_mem_h* memh_p) {
+  ucp_mem_map_params_t map_param;
+  std::memset(&map_param, 0, sizeof(map_param));
+  map_param.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
+                         UCP_MEM_MAP_PARAM_FIELD_LENGTH |
+                         UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
+  map_param.address = const_cast<void*>(buffer);
+  map_param.length = size;
+  map_param.memory_type = memory_type;
+  auto ucs_status = ucp_mem_map(context, &map_param, memh_p);
+  if (ucs_status != UCS_OK) {
+    *memh_p = nullptr;
+    ARROW_LOG(WARNING) << "Could not map memory: "
+                       << FromUcsStatus("ucp_mem_map", ucs_status);
+  }
+}
+// Convenience overload: map an Arrow Buffer, inferring its memory type.
+void TryMapBuffer(ucp_context_h context, const Buffer& buffer, ucp_mem_h* memh_p) {
+  TryMapBuffer(context, reinterpret_cast<void*>(buffer.address()),
+               static_cast<size_t>(buffer.size()), InferMemoryType(buffer), memh_p);
+}
+// Release a mapping created by TryMapBuffer.  Accepts the nullptr that
+// TryMapBuffer leaves on failure; unmap errors are only logged.
+void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) {
+  if (memh_p) {
+    auto ucs_status = ucp_mem_unmap(context, memh_p);
+    if (ucs_status != UCS_OK) {
+      ARROW_LOG(WARNING) << "Could not unmap memory: "
+                         << FromUcsStatus("ucp_mem_unmap", ucs_status);
+    }
+  }
+}
+
+/// \brief Wrapper around a UCX zero copy buffer (a host memory DATA
+/// buffer).
+///
+/// Owns a reference to the associated worker to avoid undefined
+/// behavior.
+class UcxDataBuffer : public Buffer {
+ public:
+  explicit UcxDataBuffer(std::shared_ptr<UcpWorker> worker, void* data, size_t size)
+      : Buffer(reinterpret_cast<uint8_t*>(data), static_cast<int64_t>(size)),
+        worker_(std::move(worker)) {}
+
+  // Hand the data region back to UCX when the last Arrow reference drops.
+  ~UcxDataBuffer() {
+    ucp_am_data_release(worker_->get(),
+                        const_cast<void*>(reinterpret_cast<const void*>(data())));
+  }
+
+ private:
+  // Must outlive the data pointer released in the destructor.
+  std::shared_ptr<UcpWorker> worker_;
+};
+}; // namespace
+
+// Out-of-line definitions required pre-C++17 for ODR-used constexpr members.
+constexpr size_t FrameHeader::kFrameHeaderBytes;
+constexpr uint8_t FrameHeader::kFrameVersion;
+
+// Serialize the frame header: byte 0 = protocol version, byte 1 = frame
+// type, bytes 4..7 = counter (big-endian), bytes 8..11 = body size (BE).
+// NOTE(review): bytes 2-3 are never written here -- presumably reserved
+// padding; confirm the header array is zero-initialized before it is sent.
+Status FrameHeader::Set(FrameType frame_type, uint32_t counter, int64_t body_size) {
+  header[0] = kFrameVersion;
+  header[1] = static_cast<uint8_t>(frame_type);
+  UInt32ToBytesBe(counter, header.data() + 4);
+  RETURN_NOT_OK(SizeToUInt32BytesBe(body_size, header.data() + 8));
+  return Status::OK();
+}
+
+// Validate and decode a received frame header (see FrameHeader::Set for
+// the wire layout).  Returns a Frame with its metadata filled in but no
+// payload buffer; a kDisconnect frame is reported as Cancelled.
+arrow::Result<std::shared_ptr<Frame>> Frame::ParseHeader(const void* header,
+                                                         size_t header_length) {
+  if (header_length < FrameHeader::kFrameHeaderBytes) {
+    return Status::IOError("Header is too short, must be at least ",
+                           FrameHeader::kFrameHeaderBytes, " bytes, got ", header_length);
+  }
+
+  const uint8_t* frame_header = reinterpret_cast<const uint8_t*>(header);
+  if (frame_header[0] != FrameHeader::kFrameVersion) {
+    return Status::IOError("Expected frame version ",
+                           static_cast<int>(FrameHeader::kFrameVersion), " but got ",
+                           static_cast<int>(frame_header[0]));
+  } else if (frame_header[1] > static_cast<uint8_t>(FrameType::kMaxFrameType)) {
+    return Status::IOError("Unknown frame type ", static_cast<int>(frame_header[1]));
+  }
+
+  const FrameType frame_type = static_cast<FrameType>(frame_header[1]);
+  const uint32_t frame_counter = BytesToUInt32Be(frame_header + 4);
+  const uint32_t frame_size = BytesToUInt32Be(frame_header + 8);
+
+  if (frame_type == FrameType::kDisconnect) {
+    return Status::Cancelled("Client initiated disconnect");
+  }
+
+  return std::make_shared<Frame>(frame_type, frame_size, frame_counter, nullptr);
+}
+
+// Decode a headers frame.  Wire format: [u32 BE count] then, per header,
+// [u32 BE key length][u32 BE value length][key bytes][value bytes].
+// The parsed key/value string_views alias `buffer`, which the result
+// keeps alive; every length is bounds-checked before it is consumed.
+arrow::Result<HeadersFrame> HeadersFrame::Parse(std::unique_ptr<Buffer> buffer) {
+  HeadersFrame result;
+  const uint8_t* payload = buffer->data();
+  const uint8_t* end = payload + buffer->size();
+  if (ARROW_PREDICT_FALSE((end - payload) < 4)) {
+    return Status::Invalid("Buffer underflow, expected number of headers");
+  }
+  const uint32_t num_headers = BytesToUInt32Be(payload);
+  payload += 4;
+  for (uint32_t i = 0; i < num_headers; i++) {
+    if (ARROW_PREDICT_FALSE((end - payload) < 4)) {
+      return Status::Invalid("Buffer underflow, expected length of key ", i + 1);
+    }
+    const uint32_t key_length = BytesToUInt32Be(payload);
+    payload += 4;
+
+    if (ARROW_PREDICT_FALSE((end - payload) < 4)) {
+      return Status::Invalid("Buffer underflow, expected length of value ", i + 1);
+    }
+    const uint32_t value_length = BytesToUInt32Be(payload);
+    payload += 4;
+
+    if (ARROW_PREDICT_FALSE((end - payload) < key_length)) {
+      return Status::Invalid("Buffer underflow, expected key ", i + 1, " to have length ",
+                             key_length, ", but only ", (end - payload), " bytes remain");
+    }
+    const util::string_view key(reinterpret_cast<const char*>(payload), key_length);
+    payload += key_length;
+
+    if (ARROW_PREDICT_FALSE((end - payload) < value_length)) {
+      return Status::Invalid("Buffer underflow, expected value ", i + 1,
+                             " to have length ", value_length, ", but only ",
+                             (end - payload), " bytes remain");
+    }
+    const util::string_view value(reinterpret_cast<const char*>(payload), value_length);
+    payload += value_length;
+    result.headers_.emplace_back(key, value);
+  }
+
+  // Retain the backing storage for the string_views above.
+  result.buffer_ = std::move(buffer);
+  return result;
+}
+// Serialize key/value headers into the wire format documented on Parse()
+// and round-trip through Parse() so the result is ready for reading.
+// NOTE(review): total_length is an int32_t accumulator -- presumably
+// header payloads are far below 2 GiB; confirm oversized headers are
+// rejected upstream.
+arrow::Result<HeadersFrame> HeadersFrame::Make(
+    const std::vector<std::pair<std::string, std::string>>& headers) {
+  int32_t total_length = 4 /* # of headers */;
+  for (const auto& header : headers) {
+    total_length += 4 /* key length */ + 4 /* value length */ +
+                    header.first.size() /* key */ + header.second.size();
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_length));
+  uint8_t* payload = buffer->mutable_data();
+
+  RETURN_NOT_OK(SizeToUInt32BytesBe(headers.size(), payload));
+  payload += 4;
+  for (const auto& header : headers) {
+    RETURN_NOT_OK(SizeToUInt32BytesBe(header.first.size(), payload));
+    payload += 4;
+    RETURN_NOT_OK(SizeToUInt32BytesBe(header.second.size(), payload));
+    payload += 4;
+    std::memcpy(payload, header.first.data(), header.first.size());
+    payload += header.first.size();
+    std::memcpy(payload, header.second.data(), header.second.size());
+    payload += header.second.size();
+  }
+  return Parse(std::move(buffer));
+}
+// Build a trailers frame carrying a Status: appends the status code,
+// message, and (when present) detail headers to the caller's headers.
+// Flight-specific details also get a detail-code header; other detail
+// types are flattened to their string form.
+arrow::Result<HeadersFrame> HeadersFrame::Make(
+    const Status& status,
+    const std::vector<std::pair<std::string, std::string>>& headers) {
+  auto all_headers = headers;
+  all_headers.emplace_back(kHeaderStatusCode,
+                           std::to_string(static_cast<int32_t>(status.code())));
+  all_headers.emplace_back(kHeaderStatusMessage, status.message());
+  if (status.detail()) {
+    auto fsd = FlightStatusDetail::UnwrapStatus(status);
+    if (fsd) {
+      all_headers.emplace_back(kHeaderStatusDetailCode,
+                               std::to_string(static_cast<int32_t>(fsd->code())));
+      all_headers.emplace_back(kHeaderStatusDetail, fsd->extra_info());
+    } else {
+      all_headers.emplace_back(kHeaderStatusDetail, status.detail()->ToString());
+    }
+  }
+  return Make(all_headers);
+}
+
+// Linear lookup of the first header with the given key; KeyError if
+// absent.  The returned view aliases this frame's buffer.
+arrow::Result<util::string_view> HeadersFrame::Get(const std::string& key) {
+  for (const auto& pair : headers_) {
+    if (pair.first == key) return pair.second;
+  }
+  return Status::KeyError(key);
+}
+
+// Reconstruct the server's Status from the status headers
+// (kHeaderStatusCode/Message/DetailCode/Detail).  Fails only when the
+// mandatory status-code header is missing; otherwise the decoded status
+// is written to *out and OK is returned.
+//
+// BUG FIX: the header values are string_views into the received frame
+// buffer and are NOT NUL-terminated, but std::strtol requires a
+// NUL-terminated string -- parsing directly from .data() could read past
+// the end of the buffer.  Both call sites now copy the view into a
+// std::string first.
+Status HeadersFrame::GetStatus(Status* out) {
+  util::string_view code_str, message_str;
+  auto status = Get(kHeaderStatusCode).Value(&code_str);
+  if (!status.ok()) {
+    return Status::KeyError("Server did not send status code header ", kHeaderStatusCode);
+  }
+
+  StatusCode status_code = StatusCode::OK;
+  // Copy before strtol: the view is not NUL-terminated (see above).
+  const std::string code_string(code_str.data(), code_str.size());
+  auto code = std::strtol(code_string.c_str(), nullptr, /*base=*/10);
+  // Mirror of arrow::StatusCode's numeric values (12 is unassigned).
+  switch (code) {
+    case 0:
+      status_code = StatusCode::OK;
+      break;
+    case 1:
+      status_code = StatusCode::OutOfMemory;
+      break;
+    case 2:
+      status_code = StatusCode::KeyError;
+      break;
+    case 3:
+      status_code = StatusCode::TypeError;
+      break;
+    case 4:
+      status_code = StatusCode::Invalid;
+      break;
+    case 5:
+      status_code = StatusCode::IOError;
+      break;
+    case 6:
+      status_code = StatusCode::CapacityError;
+      break;
+    case 7:
+      status_code = StatusCode::IndexError;
+      break;
+    case 8:
+      status_code = StatusCode::Cancelled;
+      break;
+    case 9:
+      status_code = StatusCode::UnknownError;
+      break;
+    case 10:
+      status_code = StatusCode::NotImplemented;
+      break;
+    case 11:
+      status_code = StatusCode::SerializationError;
+      break;
+    case 13:
+      status_code = StatusCode::RError;
+      break;
+    case 40:
+      status_code = StatusCode::CodeGenError;
+      break;
+    case 41:
+      status_code = StatusCode::ExpressionValidationError;
+      break;
+    case 42:
+      status_code = StatusCode::ExecutionError;
+      break;
+    case 45:
+      status_code = StatusCode::AlreadyExists;
+      break;
+    default:
+      status_code = StatusCode::UnknownError;
+      break;
+  }
+  if (status_code == StatusCode::OK) {
+    *out = Status::OK();
+    return Status::OK();
+  }
+
+  status = Get(kHeaderStatusMessage).Value(&message_str);
+  if (!status.ok()) {
+    // Degrade gracefully: keep the code even if the message is missing.
+    *out = Status(status_code, "Server did not send status message header", nullptr);
+    return Status::OK();
+  }
+
+  util::string_view detail_code_str, detail_str;
+  FlightStatusCode detail_code = FlightStatusCode::Internal;
+
+  if (Get(kHeaderStatusDetailCode).Value(&detail_code_str).ok()) {
+    // Copy before strtol: the view is not NUL-terminated (see above).
+    const std::string detail_code_string(detail_code_str.data(), detail_code_str.size());
+    auto detail_code_int = std::strtol(detail_code_string.c_str(), nullptr, /*base=*/10);
+    switch (detail_code_int) {
+      case 1:
+        detail_code = FlightStatusCode::TimedOut;
+        break;
+      case 2:
+        detail_code = FlightStatusCode::Cancelled;
+        break;
+      case 3:
+        detail_code = FlightStatusCode::Unauthenticated;
+        break;
+      case 4:
+        detail_code = FlightStatusCode::Unauthorized;
+        break;
+      case 5:
+        detail_code = FlightStatusCode::Unavailable;
+        break;
+      case 6:
+        detail_code = FlightStatusCode::Failed;
+        break;
+      case 0:
+      default:
+        detail_code = FlightStatusCode::Internal;
+        break;
+    }
+  }
+  ARROW_UNUSED(Get(kHeaderStatusDetail).Value(&detail_str));
+
+  std::shared_ptr<StatusDetail> detail = nullptr;
+  if (!detail_str.empty()) {
+    detail = std::make_shared<FlightStatusDetail>(detail_code, std::string(detail_str));
+  }
+  *out = Status(status_code, std::string(message_str), std::move(detail));
+  return Status::OK();
+}
+
+namespace {
+// Sentinel length distinguishing "field absent" from a 0-byte field.
+static constexpr uint32_t kMissingFieldSentinel = std::numeric_limits<uint32_t>::max();
+static constexpr uint32_t kInt32Max =
+    static_cast<uint32_t>(std::numeric_limits<int32_t>::max());
+// Validate one optional payload-header field and accumulate its size into
+// *total_size.  Returns the field's encoded length (kMissingFieldSentinel
+// when `data` is null, which contributes nothing to the total).
+arrow::Result<uint32_t> PayloadHeaderFieldSize(const std::string& field,
+                                               const std::shared_ptr<Buffer>& data,
+                                               uint32_t* total_size) {
+  if (!data) return kMissingFieldSentinel;
+  if (data->size() > kInt32Max) {
+    return Status::Invalid(field, " must be less than 2 GiB, was: ", data->size());
+  }
+  const uint32_t size = static_cast<uint32_t>(data->size());
+  // BUG FIX: guard against uint32_t overflow of the running total.  The
+  // original wrote `if (*total_size < 0)` after the addition, which can
+  // never be true for an unsigned type; check before adding instead.
+  if (*total_size > std::numeric_limits<uint32_t>::max() - size) {
+    return Status::Invalid("Payload header must fit in a uint32_t");
+  }
+  *total_size += size;
+  return size;
+}
+// Write one [u32 BE length][bytes] field at `out` and return the cursor
+// past it.  A missing field writes only the sentinel length (no bytes).
+uint8_t* PackField(uint32_t size, const std::shared_ptr<Buffer>& data, uint8_t* out) {
+  UInt32ToBytesBe(size, out);
+  if (size != kMissingFieldSentinel) {
+    std::memcpy(out + 4, data->data(), size);
+    return out + 4 + size;
+  } else {
+    return out + 4;
+  }
+}
+} // namespace
+
+// Pack a FlightPayload's non-body fields (descriptor, app_metadata, IPC
+// metadata) into a single contiguous payload-header frame.
+arrow::Result<PayloadHeaderFrame> PayloadHeaderFrame::Make(const FlightPayload& payload,
+                                                           MemoryPool* memory_pool) {
+  // Assemble all non-data fields here. Presumably this is much less
+  // than data size so we will pay the copy.
+
+  // Structure per field: [4 byte length][data]. If a field is not
+  // present, UINT32_MAX is used as the sentinel (since 0-sized fields
+  // are acceptable)
+  uint32_t header_size = 12;  // three 4-byte length prefixes
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t descriptor_size,
+      PayloadHeaderFieldSize("descriptor", payload.descriptor, &header_size));
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t app_metadata_size,
+      PayloadHeaderFieldSize("app_metadata", payload.app_metadata, &header_size));
+  ARROW_ASSIGN_OR_RAISE(
+      const uint32_t ipc_metadata_size,
+      PayloadHeaderFieldSize("ipc_message.metadata", payload.ipc_message.metadata,
+                             &header_size));
+
+  ARROW_ASSIGN_OR_RAISE(auto header_buffer, AllocateBuffer(header_size, memory_pool));
+  uint8_t* payload_header = header_buffer->mutable_data();
+
+  // Field order must match the unpacking order in ToFlightData().
+  payload_header = PackField(descriptor_size, payload.descriptor, payload_header);
+  payload_header = PackField(app_metadata_size, payload.app_metadata, payload_header);
+  payload_header =
+      PackField(ipc_metadata_size, payload.ipc_message.metadata, payload_header);
+
+  return PayloadHeaderFrame(std::move(header_buffer));
+}
+// Unpack a payload-header frame (see Make() for the layout) into a
+// FlightData.  Consumes this frame's buffer; app_metadata/metadata slices
+// alias it.  The body is always null here -- it arrives in a separate
+// kPayloadBody frame.
+// NOTE(review): the three 4-byte length reads themselves are not
+// bounds-checked before dereferencing -- presumably the frame is at least
+// 12 bytes because Make() always writes three prefixes; confirm truncated
+// frames are rejected before reaching here.
+Status PayloadHeaderFrame::ToFlightData(internal::FlightData* data) {
+  std::shared_ptr<Buffer> buffer = std::move(buffer_);
+
+  // Unpack the descriptor
+  uint32_t offset = 0;
+  uint32_t size = BytesToUInt32Be(buffer->data());
+  offset += 4;
+  if (size != kMissingFieldSentinel) {
+    if (static_cast<int64_t>(offset + size) > buffer->size()) {
+      return Status::Invalid("Buffer is too small: expected ", offset + size,
+                             " bytes but have ", buffer->size());
+    }
+    util::string_view desc(reinterpret_cast<const char*>(buffer->data() + offset), size);
+    data->descriptor.reset(new FlightDescriptor());
+    ARROW_ASSIGN_OR_RAISE(*data->descriptor, FlightDescriptor::Deserialize(desc));
+    offset += size;
+  } else {
+    data->descriptor = nullptr;
+  }
+
+  // Unpack app_metadata
+  size = BytesToUInt32Be(buffer->data() + offset);
+  offset += 4;
+  // While we properly handle zero-size vs nullptr metadata here, gRPC
+  // doesn't (Protobuf doesn't differentiate between the two)
+  if (size != kMissingFieldSentinel) {
+    if (static_cast<int64_t>(offset + size) > buffer->size()) {
+      return Status::Invalid("Buffer is too small: expected ", offset + size,
+                             " bytes but have ", buffer->size());
+    }
+    data->app_metadata = SliceBuffer(buffer, offset, size);
+    offset += size;
+  } else {
+    data->app_metadata = nullptr;
+  }
+
+  // Unpack the IPC header
+  size = BytesToUInt32Be(buffer->data() + offset);
+  offset += 4;
+  if (size != kMissingFieldSentinel) {
+    if (static_cast<int64_t>(offset + size) > buffer->size()) {
+      return Status::Invalid("Buffer is too small: expected ", offset + size,
+                             " bytes but have ", buffer->size());
+    }
+    data->metadata = SliceBuffer(std::move(buffer), offset, size);
+  } else {
+    data->metadata = nullptr;
+  }
+  data->body = nullptr;
+  return Status::OK();
+}
+
+// pImpl the driver since async methods require a stable address
+class UcpCallDriver::Impl {
+ public:
+#if defined(ARROW_FLIGHT_UCX_SEND_CONTIG)
+ constexpr static bool kEnableContigSend = true;
+#else
+ constexpr static bool kEnableContigSend = false;
+#endif
+
+ // Wrap an existing UCP endpoint on the given worker. Best-effort:
+ // the endpoint's socket addresses are queried to build a
+ // human-readable peer name for logging; errors there are
+ // deliberately ignored and the name stays "(unknown remote)".
+ Impl(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint)
+ : padding_bytes_({0, 0, 0, 0, 0, 0, 0, 0}),
+ worker_(std::move(worker)),
+ endpoint_(endpoint),
+ read_memory_pool_(default_memory_pool()),
+ write_memory_pool_(default_memory_pool()),
+ memory_manager_(CPUDevice::Instance()->default_memory_manager()),
+ name_("(unknown remote)"),
+ counter_(0) {
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ // Pre-map the shared padding buffer so IOV sends can reference it
+ TryMapBuffer(worker_->context().get(), padding_bytes_.data(), padding_bytes_.size(),
+ UCS_MEMORY_TYPE_HOST, &padding_memh_p_);
+#endif
+
+ ucp_ep_attr_t attrs;
+ std::memset(&attrs, 0, sizeof(attrs));
+ attrs.field_mask =
+ UCP_EP_ATTR_FIELD_LOCAL_SOCKADDR | UCP_EP_ATTR_FIELD_REMOTE_SOCKADDR;
+ if (ucp_ep_query(endpoint_, &attrs) == UCS_OK) {
+ std::string local_addr, remote_addr;
+ ARROW_UNUSED(SockaddrToString(attrs.local_sockaddr).Value(&local_addr));
+ ARROW_UNUSED(SockaddrToString(attrs.remote_sockaddr).Value(&remote_addr));
+ name_ = "local:" + local_addr + ";remote:" + remote_addr;
+ }
+ }
+
+ ~Impl() {
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ TryUnmapBuffer(worker_->context().get(), padding_memh_p_);
+#endif
+ }
+
+ // Blocking read: drive the worker's progress loop until the next
+ // frame's future completes, then return its result.
+ arrow::Result<std::shared_ptr<Frame>> ReadNextFrame() {
+ auto fut = ReadFrameAsync();
+ while (!fut.is_finished()) MakeProgress();
+ RETURN_NOT_OK(fut.status());
+ return fut.MoveResult();
+ }
+
+ // Return a future for the frame with the next expected counter
+ // value. Frames may be delivered by UCX out of order; the
+ // frames_ map pairs waiters with deliveries (see Push).
+ Future<std::shared_ptr<Frame>> ReadFrameAsync() {
+ RETURN_NOT_OK(CheckClosed());
+
+ std::unique_lock<std::mutex> guard(frame_mutex_);
+ if (ARROW_PREDICT_FALSE(!status_.ok())) return status_;
+
+ // Expected value of "counter" field in the frame header
+ const uint32_t counter_value = next_counter_++;
+ auto it = frames_.find(counter_value);
+ if (it != frames_.end()) {
+ // Message already delivered, return it
+ Future<std::shared_ptr<Frame>> fut = it->second;
+ frames_.erase(it);
+ return fut;
+ }
+ // Message not yet delivered, insert a future and wait
+ auto pair = frames_.insert({counter_value, Future<std::shared_ptr<Frame>>::Make()});
+ DCHECK(pair.second);
+ return pair.first->second;
+ }
+
+ // Synchronous frame send: emit a frame header followed by `data`,
+ // then spin on the progress loop until UCX completes the request.
+ // The caller retains ownership of `data` (safe, since we block).
+ Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size) {
+ RETURN_NOT_OK(CheckClosed());
+
+ void* request = nullptr;
+ ucp_request_param_t request_param;
+ // Only the FLAGS field is marked valid, so the other members may
+ // stay uninitialized
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS;
+ request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+ // Send frame header
+ FrameHeader header;
+ RETURN_NOT_OK(header.Set(frame_type, counter_++, size));
+ if (size == 0) {
+ // UCX appears to crash on zero-byte payloads; send one padding
+ // byte instead (the header still records a body size of 0)
+ request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(),
+ padding_bytes_.data(),
+ /*size=*/1, &request_param);
+ } else {
+ request = ucp_am_send_nbx(endpoint_, kUcpAmHandlerId, header.data(), header.size(),
+ data, size, &request_param);
+ }
+ RETURN_NOT_OK(CompleteRequestBlocking("ucp_am_send_nbx", request));
+
+ return Status::OK();
+ }
+
+  /// \brief Asynchronously send a frame of the given type.
+  ///
+  /// \param frame_type the frame type to record in the frame header
+  /// \param buffer the frame body; ownership transfers to the pending
+  ///   send state and is held until UCX completes the request
+  /// \return a Future that completes when the send finishes (or an
+  ///   immediate Status if it completed/failed synchronously)
+  Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr<Buffer> buffer) {
+    RETURN_NOT_OK(CheckClosed());
+
+    ucp_request_param_t request_param;
+    std::memset(&request_param, 0, sizeof(request_param));
+    request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE |
+                                 UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA;
+    request_param.cb.send = AmSendCallback;
+    request_param.datatype = ucp_dt_make_contig(1);
+    request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+    // The header records the logical body size; for an empty body we
+    // still must hand UCX a non-empty payload (it appears to crash on
+    // zero-byte payloads), so substitute a 1-byte dummy buffer.
+    const int64_t size = buffer->size();
+    if (size == 0) {
+      // BUG FIX: this previously declared a shadowing local
+      // (`ARROW_ASSIGN_OR_RAISE(auto buffer, ...)`) whose allocation
+      // was destroyed immediately, so the zero-byte payload was still
+      // sent. Assign into the parameter instead.
+      ARROW_ASSIGN_OR_RAISE(buffer, AllocateBuffer(1, write_memory_pool_));
+      buffer->mutable_data()[0] = 0;  // deterministic padding byte
+    }
+
+    std::unique_ptr<PendingContigSend> pending_send(new PendingContigSend());
+    RETURN_NOT_OK(pending_send->header.Set(frame_type, counter_++, size));
+    pending_send->ipc_message = std::move(buffer);
+    pending_send->driver = this;
+    pending_send->completed = Future<>::Make();
+    pending_send->memh_p = nullptr;
+
+    // Ownership of the pending state passes to UCX via user_data; it
+    // is reclaimed either here (synchronous completion/error) or in
+    // AmSendCallback.
+    request_param.user_data = pending_send.release();
+    {
+      auto* send_state = reinterpret_cast<PendingContigSend*>(request_param.user_data);
+
+      void* request = ucp_am_send_nbx(
+          endpoint_, kUcpAmHandlerId, send_state->header.data(),
+          send_state->header.size(),
+          reinterpret_cast<void*>(send_state->ipc_message->mutable_data()),
+          static_cast<size_t>(send_state->ipc_message->size()), &request_param);
+      if (!request) {
+        // Request completed immediately; the callback will not fire
+        delete send_state;
+        return Status::OK();
+      } else if (UCS_PTR_IS_ERR(request)) {
+        delete send_state;
+        return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request));
+      }
+      return send_state->completed;
+    }
+  }
+
+ // Send a FlightPayload as two frames: a kPayloadHeader frame (sent
+ // synchronously) carrying descriptor/app_metadata/IPC metadata, and
+ // a kPayloadBody frame (sent asynchronously) carrying the body
+ // buffers, either concatenated (CONTIG) or scatter-gather (IOV).
+ Future<> SendFlightPayload(const FlightPayload& payload) {
+ static const int64_t kMaxBatchSize = std::numeric_limits<int32_t>::max();
+ RETURN_NOT_OK(CheckClosed());
+
+ if (payload.ipc_message.body_length > kMaxBatchSize) {
+ return Status::Invalid("Cannot send record batches exceeding 2GiB yet");
+ }
+
+ {
+ ARROW_ASSIGN_OR_RAISE(auto frame,
+ PayloadHeaderFrame::Make(payload, write_memory_pool_));
+ RETURN_NOT_OK(SendFrame(FrameType::kPayloadHeader, frame.data(), frame.size()));
+ }
+
+ // Some IPC message types (e.g. schema) have no body at all
+ if (!ipc::Message::HasBody(payload.ipc_message.type)) {
+ return Status::OK();
+ }
+
+ // While IOV (scatter-gather) might seem like it avoids a memcpy,
+ // profiling shows that at least for the TCP/SHM/RDMA transports,
+ // UCX just does a memcpy internally. Furthermore, on the receiver
+ // side, a sender-side IOV send prevents optimizations based on
+ // mapped buffers (UCX will memcpy to the destination buffer
+ // regardless of whether it's mapped or not).
+
+ // If all buffers are on the CPU, concatenate them ourselves and
+ // do a regular send to avoid this. Else, use IOV and let UCX
+ // figure out what to do.
+
+ // Weirdness: UCX prefers TCP over shared memory for CONTIG? We
+ // can avoid this by setting UCX_RNDV_THRESH=inf, this will make
+ // UCX prefer shared memory again. However, we still want to avoid
+ // the CONTIG path when shared memory is available, because the
+ // total amount of time spent in memcpy is greater than using IOV
+ // and letting UCX handle it.
+
+ // Consider: if we can figure out how to make IOV always as fast
+ // as CONTIG, we can just send the metadata fields as part of the
+ // IOV payload and avoid having to send two distinct messages.
+
+ // Count IOV entries: one per non-empty buffer, plus one per buffer
+ // needing 8-byte alignment padding (Arrow IPC requirement)
+ bool all_cpu = true;
+ int32_t total_buffers = 0;
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+ all_cpu = all_cpu && buffer->is_cpu();
+ total_buffers++;
+
+ // Arrow IPC requires that we align buffers to 8 byte boundary
+ const auto remainder = static_cast<int>(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) total_buffers++;
+ }
+
+ ucp_request_param_t request_param;
+ std::memset(&request_param, 0, sizeof(request_param));
+ request_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE |
+ UCP_OP_ATTR_FIELD_FLAGS | UCP_OP_ATTR_FIELD_USER_DATA;
+ request_param.cb.send = AmSendCallback;
+ request_param.flags = UCP_AM_SEND_FLAG_REPLY;
+
+ std::unique_ptr<PendingAmSend> pending_send;
+ void* send_data = nullptr;
+ size_t send_size = 0;
+
+ if (!all_cpu) {
+ request_param.op_attr_mask =
+ request_param.op_attr_mask | UCP_OP_ATTR_FIELD_MEMORY_TYPE;
+ // XXX: UCX doesn't appear to autodetect this correctly if we
+ // use UNKNOWN
+ request_param.memory_type = UCS_MEMORY_TYPE_CUDA;
+ }
+
+ if (kEnableContigSend && all_cpu) {
+ // CONTIG - concatenate buffers into one before sending
+
+ // TODO(ARROW-16126): this needs to be pipelined since it can be expensive.
+ // Preliminary profiling shows ~5% overhead just from mapping the buffer
+ // alone (on Infiniband; it seems to be trivial for shared memory)
+ request_param.datatype = ucp_dt_make_contig(1);
+ pending_send = arrow::internal::make_unique<PendingContigSend>();
+ auto* pending_contig = reinterpret_cast<PendingContigSend*>(pending_send.get());
+
+ // Allocate at least 1 byte: UCX chokes on zero-byte payloads
+ const int64_t body_length = std::max<int64_t>(payload.ipc_message.body_length, 1);
+ ARROW_ASSIGN_OR_RAISE(pending_contig->ipc_message,
+ AllocateBuffer(body_length, write_memory_pool_));
+ TryMapBuffer(worker_->context().get(), *pending_contig->ipc_message,
+ &pending_contig->memh_p);
+
+ uint8_t* ipc_message = pending_contig->ipc_message->mutable_data();
+ if (payload.ipc_message.body_length == 0) {
+ std::memset(ipc_message, '\0', 1);
+ }
+
+ // Concatenate all buffers, zero-padding each to 8-byte alignment
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+
+ std::memcpy(ipc_message, buffer->data(), buffer->size());
+ ipc_message += buffer->size();
+
+ const auto remainder = static_cast<int>(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) {
+ std::memset(ipc_message, 0, remainder);
+ ipc_message += remainder;
+ }
+ }
+
+ send_data = reinterpret_cast<void*>(pending_contig->ipc_message->mutable_data());
+ send_size = static_cast<size_t>(pending_contig->ipc_message->size());
+ } else {
+ // IOV - let UCX use scatter-gather path
+ request_param.datatype = UCP_DATATYPE_IOV;
+ pending_send = arrow::internal::make_unique<PendingIovSend>();
+ auto* pending_iov = reinterpret_cast<PendingIovSend*>(pending_send.get());
+
+ // Keep the payload alive until the send completes
+ pending_iov->payload = payload;
+ pending_iov->iovs.resize(total_buffers);
+ ucp_dt_iov_t* iov = pending_iov->iovs.data();
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ // XXX: this seems to have no benefits in tests so far
+ pending_iov->memh_ps.resize(total_buffers);
+ ucp_mem_h* memh_p = pending_iov->memh_ps.data();
+#endif
+ for (const auto& buffer : payload.ipc_message.body_buffers) {
+ if (!buffer || buffer->size() == 0) continue;
+
+ iov->buffer = const_cast<void*>(reinterpret_cast<const void*>(buffer->address()));
+ iov->length = buffer->size();
+ ++iov;
+
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ TryMapBuffer(worker_->context().get(), *buffer, memh_p);
+ memh_p++;
+#endif
+
+ // Alignment padding points at the shared zero-filled buffer
+ const auto remainder = static_cast<int>(
+ bit_util::RoundUpToMultipleOf8(buffer->size()) - buffer->size());
+ if (remainder) {
+ iov->buffer =
+ const_cast<void*>(reinterpret_cast<const void*>(padding_bytes_.data()));
+ iov->length = remainder;
+ ++iov;
+ }
+ }
+
+ if (total_buffers == 0) {
+ // UCX cannot handle zero-byte payloads
+ pending_iov->iovs.resize(1);
+ pending_iov->iovs[0].buffer =
+ const_cast<void*>(reinterpret_cast<const void*>(padding_bytes_.data()));
+ pending_iov->iovs[0].length = 1;
+ }
+
+ send_data = pending_iov->iovs.data();
+ send_size = pending_iov->iovs.size();
+ }
+
+ DCHECK(send_data) << "Payload cannot be nullptr";
+ DCHECK_GT(send_size, 0) << "Payload cannot be empty";
+
+ RETURN_NOT_OK(pending_send->header.Set(FrameType::kPayloadBody, counter_++,
+ payload.ipc_message.body_length));
+ pending_send->driver = this;
+ pending_send->completed = Future<>::Make();
+
+ // Ownership of the pending state passes to UCX via user_data; it is
+ // reclaimed either here (synchronous completion/error) or in
+ // AmSendCallback
+ request_param.user_data = pending_send.release();
+ {
+ auto* pending_send = reinterpret_cast<PendingAmSend*>(request_param.user_data);
+
+ void* request = ucp_am_send_nbx(
+ endpoint_, kUcpAmHandlerId, pending_send->header.data(),
+ pending_send->header.size(), send_data, send_size, &request_param);
+ if (!request) {
+ // Request completed immediately
+ delete pending_send;
+ return Status::OK();
+ } else if (UCS_PTR_IS_ERR(request)) {
+ delete pending_send;
+ return FromUcsStatus("ucp_am_send_nbx", UCS_PTR_STATUS(request));
+ }
+ return pending_send->completed;
+ }
+ }
+
+ // Close the endpoint: cancel all pending read futures, then flush
+ // and close the UCP endpoint, driving the progress loop until the
+ // close request completes. Idempotent; benign disconnect errors
+ // (see IsIgnorableDisconnectError) are swallowed.
+ Status Close() {
+ if (!endpoint_) return Status::OK();
+
+ for (auto& item : frames_) {
+ item.second.MarkFinished(Status::Cancelled("UcpCallDriver is being closed"));
+ }
+ frames_.clear();
+
+ void* request = ucp_ep_close_nb(endpoint_, UCP_EP_CLOSE_MODE_FLUSH);
+ ucs_status_t status = UCS_OK;
+ std::string origin = "ucp_ep_close_nb";
+ if (UCS_PTR_IS_ERR(request)) {
+ status = UCS_PTR_STATUS(request);
+ } else if (UCS_PTR_IS_PTR(request)) {
+ origin = "ucp_request_check_status";
+ while ((status = ucp_request_check_status(request)) == UCS_INPROGRESS) {
+ MakeProgress();
+ }
+ ucp_request_free(request);
+ } else {
+ // Close completed instantly (request is null)
+ DCHECK(!request);
+ }
+
+ endpoint_ = nullptr;
+ if (status != UCS_OK && !IsIgnorableDisconnectError(status)) {
+ return FromUcsStatus(origin, status);
+ }
+ return Status::OK();
+ }
+
+ // Drive the UCX progress engine one step (fires pending callbacks).
+ void MakeProgress() { ucp_worker_progress(worker_->get()); }
+
+ // Deliver a received frame: either complete the future of a reader
+ // already waiting on this counter value, or stash the frame until
+ // ReadFrameAsync asks for it. Counterpart of ReadFrameAsync.
+ void Push(std::shared_ptr<Frame> frame) {
+ std::unique_lock<std::mutex> guard(frame_mutex_);
+ if (ARROW_PREDICT_FALSE(!status_.ok())) return;
+ auto pair = frames_.insert({frame->counter, frame});
+ if (!pair.second) {
+ // Not inserted, because ReadFrameAsync was called for this
+ // frame counter value and the client is already waiting on
+ // it. Complete the existing future.
+ pair.first->second.MarkFinished(std::move(frame));
+ frames_.erase(pair.first);
+ }
+ // Otherwise, we inserted the frame, meaning the client was not
+ // currently waiting for that frame counter value
+ }
+
+ // Fail the stream: record the error (future reads return it too)
+ // and fail every outstanding read future.
+ void Push(Status status) {
+ std::unique_lock<std::mutex> guard(frame_mutex_);
+ status_ = std::move(status);
+ for (auto& item : frames_) {
+ item.second.MarkFinished(status_);
+ }
+ frames_.clear();
+ }
+
+ // UCX active-message entry point. Errors are routed into the frame
+ // stream via Push(Status) rather than returned to UCX, so the
+ // handler itself always reports success to UCX on failure.
+ ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data,
+ const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ auto maybe_status =
+ RecvActiveMessageImpl(header, header_length, data, data_length, param);
+ if (!maybe_status.ok()) {
+ Push(maybe_status.status());
+ return UCS_OK;
+ }
+ return maybe_status.MoveValueUnsafe();
+ }
+
+ // Accessors/setters; null arguments reset to the CPU defaults.
+ const std::shared_ptr<MemoryManager>& memory_manager() const { return memory_manager_; }
+ void set_memory_manager(std::shared_ptr<MemoryManager> memory_manager) {
+ if (memory_manager) {
+ memory_manager_ = std::move(memory_manager);
+ } else {
+ memory_manager_ = CPUDevice::Instance()->default_memory_manager();
+ }
+ }
+ void set_read_memory_pool(MemoryPool* pool) {
+ read_memory_pool_ = pool ? pool : default_memory_pool();
+ }
+ void set_write_memory_pool(MemoryPool* pool) {
+ write_memory_pool_ = pool ? pool : default_memory_pool();
+ }
+ // Human-readable peer name for logging/tracing
+ const std::string& peer() const { return name_; }
+
+ private:
+ // State kept alive for the duration of an async send; owned by UCX
+ // via request_param.user_data and deleted in AmSendCallback (or
+ // inline on synchronous completion/error).
+ class PendingAmSend {
+ public:
+ virtual ~PendingAmSend() = default;
+ UcpCallDriver::Impl* driver;
+ Future<> completed;
+ FrameHeader header;
+ };
+
+ // Send state for the CONTIG path: owns the concatenated body buffer
+ // and (optionally) its UCX memory-map handle.
+ class PendingContigSend : public PendingAmSend {
+ public:
+ std::unique_ptr<Buffer> ipc_message;
+ ucp_mem_h memh_p;
+
+ virtual ~PendingContigSend() {
+ TryUnmapBuffer(driver->worker_->context().get(), memh_p);
+ }
+ };
+
+ // Send state for the IOV path: keeps the FlightPayload (and hence
+ // its body buffers) alive until the send completes.
+ class PendingIovSend : public PendingAmSend {
+ public:
+ FlightPayload payload;
+ std::vector<ucp_dt_iov_t> iovs;
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ std::vector<ucp_mem_h> memh_ps;
+
+ virtual ~PendingIovSend() {
+ for (ucp_mem_h memh_p : memh_ps) {
+ TryUnmapBuffer(driver->worker_->context().get(), memh_p);
+ }
+ }
+#endif
+ };
+
+ // State for an async ucp_am_recv_data_nbx receive; deleted in
+ // AmRecvCallback (or inline on synchronous completion/error).
+ // NOTE(review): memh_p is left uninitialized by the constructor and
+ // unconditionally unmapped in the destructor — presumably
+ // TryMapBuffer always assigns it (even on failure); verify, else
+ // ~PendingAmRecv passes an indeterminate handle to TryUnmapBuffer.
+ struct PendingAmRecv {
+ UcpCallDriver::Impl* driver;
+ std::shared_ptr<Frame> frame;
+ ucp_mem_h memh_p;
+
+ PendingAmRecv(UcpCallDriver::Impl* driver_, std::shared_ptr<Frame> frame_)
+ : driver(driver_), frame(std::move(frame_)) {}
+
+ ~PendingAmRecv() { TryUnmapBuffer(driver->worker_->context().get(), memh_p); }
+ };
+
+ // Completion callback for async sends: settle the caller's future
+ // and reclaim the PendingAmSend allocated by the send path.
+ static void AmSendCallback(void* request, ucs_status_t status, void* user_data) {
+ auto* pending_send = reinterpret_cast<PendingAmSend*>(user_data);
+ if (status == UCS_OK) {
+ pending_send->completed.MarkFinished();
+ } else {
+ pending_send->completed.MarkFinished(FromUcsStatus("ucp_am_send_nbx", status));
+ }
+ // TODO(ARROW-16126): delete should occur on a background thread if there's
+ // mapped buffers, since unmapping can be nontrivial and we don't want to block
+ // the thread doing UCX work. (Borrow the Rust transfer-and-drop pattern.)
+ delete pending_send;
+ ucp_request_free(request);
+ }
+
+ // Completion callback for async receives: hand the finished frame
+ // (or the error) to the driver and reclaim the PendingAmRecv.
+ static void AmRecvCallback(void* request, ucs_status_t status, size_t length,
+ void* user_data) {
+ auto* pending_recv = reinterpret_cast<PendingAmRecv*>(user_data);
+ ucp_request_free(request);
+ if (status != UCS_OK) {
+ pending_recv->driver->Push(
+ FromUcsStatus("ucp_am_recv_data_nbx (callback)", status));
+ } else {
+ pending_recv->driver->Push(std::move(pending_recv->frame));
+ }
+ delete pending_recv;
+ }
+
+ // Handle one incoming active message. Returns the status code to
+ // hand back to UCX: UCS_INPROGRESS means we kept ownership of
+ // `data` (zero-copy or pending async receive); UCS_OK means UCX may
+ // reclaim `data` when the handler returns.
+ arrow::Result<ucs_status_t> RecvActiveMessageImpl(const void* header,
+ size_t header_length, void* data,
+ const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ DCHECK(param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP);
+
+ if (data_length > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+ return Status::Invalid("Cannot allocate buffer greater than 2 GiB, requested: ",
+ data_length);
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto frame, Frame::ParseHeader(header, header_length));
+ if (data_length < frame->size) {
+ return Status::IOError("Expected frame of ", frame->size, " bytes, but got only ",
+ data_length);
+ }
+
+ // Case 1: UCX handed us its own buffer and the destination is CPU
+ // memory (or the frame is not a body frame): wrap it zero-copy.
+ if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) &&
+ (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody)) {
+ // Zero-copy path. UCX-allocated buffer must be freed later.
+
+ // XXX: this buffer can NOT be freed until AFTER we return from
+ // this handler. Otherwise, UCX won't have fully set up its
+ // internal data structures (allocated just before the buffer)
+ // and we'll crash when we free the buffer. Effectively: we can
+ // never use Then/AddCallback on a Future<> from ReadFrameAsync,
+ // because we might run the callback synchronously (which might
+ // free the buffer) when we call Push here.
+ frame->buffer =
+ arrow::internal::make_unique<UcxDataBuffer>(worker_, data, data_length);
+ Push(std::move(frame));
+ return UCS_INPROGRESS;
+ }
+
+ // Case 2: receive into a buffer we allocate (async completion via
+ // AmRecvCallback, unless UCX finishes instantly).
+ if ((param->recv_attr & UCP_AM_RECV_ATTR_FLAG_DATA) ||
+ (param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV)) {
+ // Rendezvous protocol (RNDV), or unpack to destination (DATA).
+
+ // We want to map/pin/register the buffer for faster transfer
+ // where possible. (It gets unmapped in ~PendingAmRecv.)
+ // TODO(ARROW-16126): This takes non-trivial time, so return
+ // UCS_INPROGRESS, kick off the allocation in the background,
+ // and recv the data later (is it allowed to call
+ // ucp_am_recv_data_nbx asynchronously?).
+ if (frame->type == FrameType::kPayloadBody) {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ memory_manager_->AllocateBuffer(data_length));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ AllocateBuffer(data_length, read_memory_pool_));
+ }
+
+ PendingAmRecv* pending_recv = new PendingAmRecv(this, std::move(frame));
+ TryMapBuffer(worker_->context().get(), *pending_recv->frame->buffer,
+ &pending_recv->memh_p);
+
+ ucp_request_param_t recv_param;
+ recv_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
+ UCP_OP_ATTR_FIELD_MEMORY_TYPE |
+ UCP_OP_ATTR_FIELD_USER_DATA;
+ recv_param.cb.recv_am = AmRecvCallback;
+ recv_param.user_data = pending_recv;
+ recv_param.memory_type = InferMemoryType(*pending_recv->frame->buffer);
+
+ void* dest =
+ reinterpret_cast<void*>(pending_recv->frame->buffer->mutable_address());
+ void* request =
+ ucp_am_recv_data_nbx(worker_->get(), data, dest, data_length, &recv_param);
+ if (UCS_PTR_IS_ERR(request)) {
+ delete pending_recv;
+ return FromUcsStatus("ucp_am_recv_data_nbx", UCS_PTR_STATUS(request));
+ } else if (!request) {
+ // Request completed instantly
+ Push(std::move(pending_recv->frame));
+ delete pending_recv;
+ }
+ return UCS_OK;
+ } else {
+ // Case 3: eager data owned by UCX only for the duration of this
+ // handler — copy it out before returning.
+ // Data will be freed after callback returns - copy to buffer
+ if (memory_manager_->is_cpu() || frame->type != FrameType::kPayloadBody) {
+ ARROW_ASSIGN_OR_RAISE(frame->buffer,
+ AllocateBuffer(data_length, read_memory_pool_));
+ std::memcpy(frame->buffer->mutable_data(), data, data_length);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(
+ frame->buffer,
+ MemoryManager::CopyNonOwned(Buffer(reinterpret_cast<uint8_t*>(data),
+ static_cast<int64_t>(data_length)),
+ memory_manager_));
+ }
+ Push(std::move(frame));
+ return UCS_OK;
+ }
+ }
+
+ // Spin on the progress loop until the given UCX request completes,
+ // translating any UCS error into a Status tagged with `context`.
+ // NOTE(review): the error path calls ucp_request_release while the
+ // success path calls ucp_request_free — presumably intentional (or
+ // harmless aliases in this UCX version); confirm and unify.
+ Status CompleteRequestBlocking(const std::string& context, void* request) {
+ if (UCS_PTR_IS_ERR(request)) {
+ return FromUcsStatus(context, UCS_PTR_STATUS(request));
+ } else if (UCS_PTR_IS_PTR(request)) {
+ while (true) {
+ auto status = ucp_request_check_status(request);
+ if (status == UCS_OK) {
+ break;
+ } else if (status != UCS_INPROGRESS) {
+ ucp_request_release(request);
+ return FromUcsStatus("ucp_request_check_status", status);
+ }
+ MakeProgress();
+ }
+ ucp_request_free(request);
+ } else {
+ // Send was completed instantly
+ DCHECK(!request);
+ }
+ return Status::OK();
+ }
+
+ // Guard used by every public operation: fail fast after Close().
+ Status CheckClosed() {
+ if (!endpoint_) {
+ return Status::Invalid("UcpCallDriver is closed");
+ }
+ return Status::OK();
+ }
+
+ // Shared zero-filled buffer used for IPC alignment padding and as a
+ // dummy payload for zero-byte sends
+ const std::array<uint8_t, 8> padding_bytes_;
+#if defined(ARROW_FLIGHT_UCX_SEND_IOV_MAP)
+ ucp_mem_h padding_memh_p_;
+#endif
+
+ std::shared_ptr<UcpWorker> worker_;
+ // Null once Close() has run (see CheckClosed)
+ ucp_ep_h endpoint_;
+ MemoryPool* read_memory_pool_;
+ MemoryPool* write_memory_pool_;
+ std::shared_ptr<MemoryManager> memory_manager_;
+
+ // Internal name for logging/tracing
+ std::string name_;
+ // Counter used to reorder messages (stamped into outgoing headers)
+ uint32_t counter_ = 0;
+
+ // Guards status_, frames_, and next_counter_
+ std::mutex frame_mutex_;
+ // Sticky stream error set by Push(Status)
+ Status status_;
+ // Frames delivered but not yet read, and readers not yet satisfied,
+ // keyed by frame counter (see ReadFrameAsync/Push)
+ std::unordered_map<uint32_t, Future<std::shared_ptr<Frame>>> frames_;
+ // Counter value the next ReadFrameAsync call expects
+ uint32_t next_counter_ = 0;
+};
+
+// Movable (not copyable) pImpl wrapper; see Impl for the real logic
+UcpCallDriver::UcpCallDriver(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint)
+ : impl_(new Impl(std::move(worker), endpoint)) {}
+UcpCallDriver::UcpCallDriver(UcpCallDriver&&) = default;
+UcpCallDriver& UcpCallDriver::operator=(UcpCallDriver&&) = default;
+UcpCallDriver::~UcpCallDriver() = default;
+
+// Thin forwarders to the pImpl
+arrow::Result<std::shared_ptr<Frame>> UcpCallDriver::ReadNextFrame() {
+ return impl_->ReadNextFrame();
+}
+
+Future<std::shared_ptr<Frame>> UcpCallDriver::ReadFrameAsync() {
+ return impl_->ReadFrameAsync();
+}
+
+// Verify that a received frame carries the expected type, producing
+// an IOError naming both types (as integers) on mismatch.
+Status UcpCallDriver::ExpectFrameType(const Frame& frame, FrameType type) {
+  if (frame.type == type) return Status::OK();
+  return Status::IOError("Expected frame type ", static_cast<int32_t>(type),
+                         ", but got frame type ", static_cast<int32_t>(frame.type));
+}
+
+// Begin an RPC by sending a kHeaders frame whose sole header names
+// the method to invoke on the server.
+Status UcpCallDriver::StartCall(const std::string& method) {
+ std::vector<std::pair<std::string, std::string>> headers;
+ headers.emplace_back(kHeaderMethod, method);
+ ARROW_ASSIGN_OR_RAISE(auto frame, HeadersFrame::Make(headers));
+ auto buffer = std::move(frame).GetBuffer();
+ RETURN_NOT_OK(impl_->SendFrame(FrameType::kHeaders, buffer->data(), buffer->size()));
+ return Status::OK();
+}
+
+// Remaining public API: thin forwarders to the pImpl
+Future<> UcpCallDriver::SendFlightPayload(const FlightPayload& payload) {
+ return impl_->SendFlightPayload(payload);
+}
+
+Status UcpCallDriver::SendFrame(FrameType frame_type, const uint8_t* data,
+ const int64_t size) {
+ return impl_->SendFrame(frame_type, data, size);
+}
+
+Future<> UcpCallDriver::SendFrameAsync(FrameType frame_type,
+ std::unique_ptr<Buffer> buffer) {
+ return impl_->SendFrameAsync(frame_type, std::move(buffer));
+}
+
+Status UcpCallDriver::Close() { return impl_->Close(); }
+
+void UcpCallDriver::MakeProgress() { impl_->MakeProgress(); }
+
+ucs_status_t UcpCallDriver::RecvActiveMessage(const void* header, size_t header_length,
+ void* data, const size_t data_length,
+ const ucp_am_recv_param_t* param) {
+ return impl_->RecvActiveMessage(header, header_length, data, data_length, param);
+}
+
+const std::shared_ptr<MemoryManager>& UcpCallDriver::memory_manager() const {
+ return impl_->memory_manager();
+}
+
+void UcpCallDriver::set_memory_manager(std::shared_ptr<MemoryManager> memory_manager) {
+ impl_->set_memory_manager(std::move(memory_manager));
+}
+void UcpCallDriver::set_read_memory_pool(MemoryPool* pool) {
+ impl_->set_read_memory_pool(pool);
+}
+void UcpCallDriver::set_write_memory_pool(MemoryPool* pool) {
+ impl_->set_write_memory_pool(pool);
+}
+const std::string& UcpCallDriver::peer() const { return impl_->peer(); }
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
new file mode 100644
index 0000000000..bd176e2369
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Common implementation of UCX communication primitives.
+
+#pragma once
+
+#include <array>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <ucp/api/ucp.h>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/visibility.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+//------------------------------------------------------------
+// Protocol Constants
+
+static constexpr char kMethodDoExchange[] = "DoExchange";
+static constexpr char kMethodDoGet[] = "DoGet";
+static constexpr char kMethodDoPut[] = "DoPut";
+static constexpr char kMethodGetFlightInfo[] = "GetFlightInfo";
+
+static constexpr char kHeaderStatusCode[] = "flight-status-code";
+static constexpr char kHeaderStatusMessage[] = "flight-status-message";
+static constexpr char kHeaderStatusDetail[] = "flight-status-detail";
+static constexpr char kHeaderStatusDetailCode[] = "flight-status-detail-code";
+
+//------------------------------------------------------------
+// UCX Helpers
+
+/// \brief A wrapper around a ucp_context_h.
+///
+/// Used so that multiple resources can share ownership of the
+/// context. UCX has zero-copy optimizations where an application can
+/// directly use a UCX buffer, but the lifetime of such buffers is
+/// tied to the UCX context and worker, so ownership needs to be
+/// preserved.
+// RAII owner of the ucp_context_h (ucp_cleanup on destruction).
+// NOTE(review): the class has a destructor but copy/move operations
+// are not deleted — a copy would double-free the context. Callers in
+// this file only hold it via shared_ptr, but deleting the copy
+// operations (Rule of Five) would make that invariant enforceable.
+class UcpContext final {
+ public:
+ UcpContext() : ucp_context_(nullptr) {}
+ explicit UcpContext(ucp_context_h context) : ucp_context_(context) {}
+ ~UcpContext() {
+ if (ucp_context_) ucp_cleanup(ucp_context_);
+ ucp_context_ = nullptr;
+ }
+ ucp_context_h get() const {
+ DCHECK(ucp_context_);
+ return ucp_context_;
+ }
+
+ private:
+ ucp_context_h ucp_context_;
+};
+
+/// \brief A wrapper around a ucp_worker_h.
+// RAII owner of a ucp_worker_h; shares ownership of the parent
+// UcpContext so the context outlives the worker.
+// NOTE(review): as with UcpContext, copy/move operations are not
+// deleted even though the destructor frees the handle — copying
+// would double-destroy the worker; consider deleting them.
+class UcpWorker final {
+ public:
+ UcpWorker() : ucp_worker_(nullptr) {}
+ UcpWorker(std::shared_ptr<UcpContext> context, ucp_worker_h worker)
+ : ucp_context_(std::move(context)), ucp_worker_(worker) {}
+ ~UcpWorker() {
+ if (ucp_worker_) ucp_worker_destroy(ucp_worker_);
+ ucp_worker_ = nullptr;
+ }
+ ucp_worker_h get() const {
+ DCHECK(ucp_worker_);
+ return ucp_worker_;
+ }
+ const UcpContext& context() const { return *ucp_context_; }
+
+ private:
+ std::shared_ptr<UcpContext> ucp_context_;
+ ucp_worker_h ucp_worker_;
+};
+
+//------------------------------------------------------------
+// Message Framing
+
+/// \brief The message type.
+/// \brief The message type. The underlying uint8_t value is stored
+/// in byte 1 of the frame header (see FrameHeader).
+enum class FrameType : uint8_t {
+ /// Key-value headers. Sent at the beginning (client->server) and
+ /// end (server->client) of a call. Also, for client-streaming calls
+ /// (e.g. DoPut), the client should send a headers frame to signal
+ /// end-of-stream.
+ kHeaders = 0,
+ /// Binary blob, does not contain Arrow data.
+ kBuffer,
+ /// Binary blob. Contains IPC metadata, app metadata.
+ kPayloadHeader,
+ /// Binary blob. Contains IPC body. Body is sent separately since it
+ /// may use a different memory type.
+ kPayloadBody,
+ /// Ask server to disconnect (to avoid client/server waiting on each
+ /// other and getting stuck).
+ kDisconnect,
+ /// Keep at end.
+ kMaxFrameType = kDisconnect,
+};
+
+/// \brief The header of a message frame. Used when sending only.
+///
+/// A frame is expected to be sent over UCP Active Messages and
+/// consists of a header (of kFrameHeaderBytes bytes) and a body.
+///
+/// The header is as follows:
+/// +-------+---------------------------------+
+/// | Bytes | Function |
+/// +=======+=================================+
+/// | 0 | Version tag (see kFrameVersion) |
+/// | 1 | Frame type (see FrameType) |
+/// | 2-3 | Unused, reserved |
+/// | 4-7 | Frame counter (big-endian) |
+/// | 8-11 | Body size (big-endian) |
+/// +-------+---------------------------------+
+///
+/// The frame counter lets the receiver ensure messages are processed
+/// in-order. (The message receive callback may use
+/// ucp_am_recv_data_nbx which is asynchronous.)
+///
+/// The body size reports the expected message size (UCX chokes on
+/// zero-size payloads which we occasionally want to send, so the size
+/// field in the header lets us know when a payload was meant to be
+/// empty).
+struct FrameHeader {
+ /// \brief The size of a frame header.
+ static constexpr size_t kFrameHeaderBytes = 12;
+ /// \brief The expected version tag in the header.
+ static constexpr uint8_t kFrameVersion = 0x01;
+
+ FrameHeader() = default;
+ /// \brief Initialize the frame header.
+ /// \param frame_type stored in byte 1
+ /// \param counter big-endian, bytes 4-7
+ /// \param body_size big-endian, bytes 8-11
+ Status Set(FrameType frame_type, uint32_t counter, int64_t body_size);
+ void* data() const { return header.data(); }
+ size_t size() const { return kFrameHeaderBytes; }
+
+ // mutable since UCX expects void* not const void*
+ mutable std::array<uint8_t, kFrameHeaderBytes> header = {0};
+};
+
+/// \brief A single message received via UCX. Used when receiving only.
+/// \brief A single message received via UCX. Used when receiving only.
+struct Frame {
+ /// \brief The message type.
+ FrameType type;
+ /// \brief The message length.
+ uint32_t size;
+ /// \brief An incrementing message counter (may wrap over).
+ uint32_t counter;
+ /// \brief The message contents. Populated by the receive path after
+ /// ParseHeader; view() must not be called before it is set.
+ std::unique_ptr<Buffer> buffer;
+
+ Frame() = default;
+ Frame(FrameType type_, uint32_t size_, uint32_t counter_,
+ std::unique_ptr<Buffer> buffer_)
+ : type(type_), size(size_), counter(counter_), buffer(std::move(buffer_)) {}
+
+ // View of the first `size` bytes of the buffer (assumes buffer set)
+ util::string_view view() const {
+ return util::string_view(reinterpret_cast<const char*>(buffer->data()), size);
+ }
+
+ /// \brief Parse a UCX active message header. This will not
+ /// initialize the buffer field.
+ static arrow::Result<std::shared_ptr<Frame>> ParseHeader(const void* header,
+ size_t header_length);
+};
+
+/// \brief The active message handler callback ID.
+static constexpr uint32_t kUcpAmHandlerId = 0x1024;
+
+/// \brief A collection of key-value headers.
+///
+/// This should be stored in a frame of type kHeaders.
+///
+/// Format:
+/// +-------+----------------------------------+
+/// | Bytes | Contents |
+/// +=======+==================================+
+/// | 0-4 | # of headers (big-endian) |
+/// | 4-8 | Header key length (big-endian) |
+/// | 8-12 | Header value length (big-endian) |
+/// | (...) | Header key |
+/// | (...) | Header value |
+/// | (...) | (repeat from row 2) |
+/// +-------+----------------------------------+
+class HeadersFrame {
+ public:
+ /// \brief Get a header value (or an error if it was not found)
+ arrow::Result<util::string_view> Get(const std::string& key);
+ /// \brief Extract the server-sent status.
+ Status GetStatus(Status* out);
+ /// \brief Parse the headers from the buffer.
+ static arrow::Result<HeadersFrame> Parse(std::unique_ptr<Buffer> buffer);
+ /// \brief Create a new frame with the given headers.
+ static arrow::Result<HeadersFrame> Make(
+ const std::vector<std::pair<std::string, std::string>>& headers);
+ /// \brief Create a new frame with the given headers and the given status.
+ static arrow::Result<HeadersFrame> Make(
+ const Status& status,
+ const std::vector<std::pair<std::string, std::string>>& headers);
+
+ /// \brief Take ownership of the underlying buffer.
+ /// Note: the parsed header views point into this buffer, so they
+ /// must not be used after the buffer has been taken.
+ std::unique_ptr<Buffer> GetBuffer() && { return std::move(buffer_); }
+
+ private:
+ std::unique_ptr<Buffer> buffer_;
+ // Views into buffer_ (valid only while buffer_ is held)
+ std::vector<std::pair<util::string_view, util::string_view>> headers_;
+};
+
+/// \brief A representation of a kPayloadHeader frame (i.e. all of the
+/// metadata in a FlightPayload/FlightData).
+///
+/// Data messages are sent in two parts: one containing all metadata
+/// (the Flatbuffers header, FlightDescriptor, and app_metadata
+/// fields) and one containing the actual data. This was done to avoid
+/// having to concatenate these fields with the data itself (in the
+/// cases where we are not using IOV).
+///
+/// Format:
+/// +--------+----------------------------------+
+/// | Bytes | Contents |
+/// +========+==================================+
+/// | 0-4 | Descriptor length (big-endian) |
+/// | 4..a | Descriptor bytes |
+/// | a-a+4 | app_metadata length (big-endian) |
+/// | a+4..b | app_metadata bytes |
+/// | b-b+4 | ipc_metadata length (big-endian) |
+/// | b+4..c | ipc_metadata bytes |
+/// +--------+----------------------------------+
+///
+/// If a field is not present, its length is still there, but is set
+/// to UINT32_MAX.
+class PayloadHeaderFrame {
+ public:
+ explicit PayloadHeaderFrame(std::unique_ptr<Buffer> buffer)
+ : buffer_(std::move(buffer)) {}
+ /// \brief Unpack the internal buffer into a FlightData.
+ Status ToFlightData(internal::FlightData* data);
+ /// \brief Pack a payload into the internal buffer.
+ static arrow::Result<PayloadHeaderFrame> Make(const FlightPayload& payload,
+ MemoryPool* memory_pool);
+ const uint8_t* data() const { return buffer_->data(); }
+ int64_t size() const { return buffer_->size(); }
+
+ private:
+ std::unique_ptr<Buffer> buffer_;
+};
+
+/// \brief Manage the state of a UCX connection.
+class UcpCallDriver {
+ public:
+ UcpCallDriver(std::shared_ptr<UcpWorker> worker, ucp_ep_h endpoint);
+
+ UcpCallDriver(const UcpCallDriver&) = delete;
+ UcpCallDriver(UcpCallDriver&&);
+ void operator=(const UcpCallDriver&) = delete;
+ UcpCallDriver& operator=(UcpCallDriver&&);
+
+ ~UcpCallDriver();
+
+ /// \brief Start a call by sending a headers frame. Client side only.
+ ///
+ /// \param[in] method The RPC method.
+ Status StartCall(const std::string& method);
+
+ /// \brief Synchronously send a generic message with binary payload.
+ Status SendFrame(FrameType frame_type, const uint8_t* data, const int64_t size);
+ /// \brief Asynchronously send a generic message with binary payload.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future<> SendFrameAsync(FrameType frame_type, std::unique_ptr<Buffer> buffer);
+ /// \brief Asynchronously send a data message.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future<> SendFlightPayload(const FlightPayload& payload);
+
+ /// \brief Synchronously read the next frame.
+ arrow::Result<std::shared_ptr<Frame>> ReadNextFrame();
+ /// \brief Asynchronously read the next frame.
+ ///
+ /// The UCP driver must be manually polled (call MakeProgress()).
+ Future<std::shared_ptr<Frame>> ReadFrameAsync();
+
+ /// \brief Validate that the frame is of the given type.
+ Status ExpectFrameType(const Frame& frame, FrameType type);
+
+ /// \brief Disconnect the other side of the connection. Note, this
+ /// can cause deadlock.
+ Status Close();
+
+ /// \brief Synchronously make progress (to adapt async to sync APIs)
+ void MakeProgress();
+
+ /// \brief Get the associated memory manager.
+ const std::shared_ptr<MemoryManager>& memory_manager() const;
+ /// \brief Set the associated memory manager.
+ void set_memory_manager(std::shared_ptr<MemoryManager> memory_manager);
+ /// \brief Set memory pool for scratch space used during reading.
+ void set_read_memory_pool(MemoryPool* memory_pool);
+ /// \brief Set memory pool for scratch space used during writing.
+ void set_write_memory_pool(MemoryPool* memory_pool);
+ /// \brief Get a debug string naming the peer.
+ const std::string& peer() const;
+
+ /// \brief Process an incoming active message. This will unblock the
+ /// corresponding call to ReadFrameAsync/ReadNextFrame.
+ ucs_status_t RecvActiveMessage(const void* header, size_t header_length, void* data,
+ const size_t data_length,
+ const ucp_am_recv_param_t* param);
+
+ private:
+ class Impl;
+ std::unique_ptr<Impl> impl_;
+};
+
+ARROW_FLIGHT_EXPORT
+std::unique_ptr<arrow::flight::internal::ClientTransport> MakeUcxClientImpl();
+
+ARROW_FLIGHT_EXPORT
+std::unique_ptr<arrow::flight::internal::ServerTransport> MakeUcxServerImpl(
+ FlightServerBase* base, std::shared_ptr<MemoryManager> memory_manager);
+
+} // namespace ucx
+} // namespace transport
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
new file mode 100644
index 0000000000..74a9311d0c
--- /dev/null
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc
@@ -0,0 +1,628 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include <atomic>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <unordered_map>
+
+#include <arpa/inet.h>
+#include <ucp/api/ucp.h>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/transport.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/transport_server.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/io_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/thread_pool.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+// Send an error to the client and return OK.
+// Statuses returned up to the main server loop trigger a kReset instead.
+#define SERVER_RETURN_NOT_OK(driver, status) \
+ do { \
+ ::arrow::Status s = (status); \
+ if (!s.ok()) { \
+ ARROW_ASSIGN_OR_RAISE(auto headers, HeadersFrame::Make(s, {})); \
+ auto payload = std::move(headers).GetBuffer(); \
+ RETURN_NOT_OK( \
+ driver->SendFrame(FrameType::kHeaders, payload->data(), payload->size())); \
+ return ::arrow::Status::OK(); \
+ } \
+ } while (false)
+
+#define FLIGHT_LOG(LEVEL) (ARROW_LOG(LEVEL) << "[server] ")
+#define FLIGHT_LOG_PEER(LEVEL, PEER) \
+ (ARROW_LOG(LEVEL) << "[server]" \
+ << "[peer=" << (PEER) << "] ")
+
+namespace {
+class UcxServerCallContext : public flight::ServerCallContext {
+ public:
+ const std::string& peer_identity() const override { return peer_; }
+ const std::string& peer() const override { return peer_; }
+ ServerMiddleware* GetMiddleware(const std::string& key) const override {
+ return nullptr;
+ }
+ bool is_cancelled() const override { return false; }
+
+ private:
+ std::string peer_;
+};
+
+class UcxServerStream : public internal::ServerDataStream {
+ public:
+ explicit UcxServerStream(UcpCallDriver* driver)
+ : peer_(driver->peer()), driver_(driver), writes_done_(false) {}
+
+ Status WritesDone() override {
+ writes_done_ = true;
+ return Status::OK();
+ }
+
+ protected:
+ std::string peer_;
... 37056 lines suppressed ...