You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2022/02/16 02:24:56 UTC
[arrow] branch master updated: ARROW-15238: [C++] ARROW_ENGINE module with substrait consumer
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new a935c81 ARROW-15238: [C++] ARROW_ENGINE module with substrait consumer
a935c81 is described below
commit a935c81b595d24179e115d64cda944efa93aa0e0
Author: Benjamin Kietzman <be...@gmail.com>
AuthorDate: Tue Feb 15 16:22:12 2022 -1000
ARROW-15238: [C++] ARROW_ENGINE module with substrait consumer
Continuation of https://github.com/apache/arrow/pull/11707. I'm taking over from @bkietz for now because he's unavailable right now for personal reasons.
Closes #12279 from jvanstraten/substrait-consumer
Lead-authored-by: Benjamin Kietzman <be...@gmail.com>
Co-authored-by: Jeroen van Straten <je...@gmail.com>
Co-authored-by: Weston Pace <we...@gmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
.travis.yml | 1 +
cpp/CMakeLists.txt | 2 +
cpp/cmake_modules/DefineOptions.cmake | 12 +-
cpp/cmake_modules/FindArrowEngine.cmake | 88 ++
cpp/cmake_modules/ThirdpartyToolchain.cmake | 10 +-
cpp/examples/arrow/CMakeLists.txt | 4 +
cpp/examples/arrow/engine_substrait_consumption.cc | 186 +++++
cpp/src/arrow/CMakeLists.txt | 4 +
cpp/src/arrow/array/array_base.cc | 2 +
cpp/src/arrow/array/array_base.h | 7 +-
cpp/src/arrow/array/builder_base.h | 23 +
cpp/src/arrow/compute/exec/expression_internal.h | 11 +-
cpp/src/arrow/csv/column_decoder_test.cc | 1 +
cpp/src/arrow/dataset/scanner.cc | 11 +
cpp/src/arrow/datum.h | 12 +-
cpp/src/arrow/engine/ArrowEngineConfig.cmake.in | 38 +
cpp/src/arrow/engine/CMakeLists.txt | 143 ++++
cpp/src/arrow/engine/api.h | 23 +
cpp/src/arrow/engine/arrow-engine.pc.in | 25 +
cpp/src/arrow/engine/pch.h | 23 +
.../arrow/engine/simple_extension_type_internal.h | 196 +++++
.../arrow/engine/substrait/expression_internal.cc | 896 +++++++++++++++++++++
.../arrow/engine/substrait/expression_internal.h | 49 ++
cpp/src/arrow/engine/substrait/extension_set.cc | 367 +++++++++
cpp/src/arrow/engine/substrait/extension_set.h | 240 ++++++
cpp/src/arrow/engine/substrait/extension_types.cc | 147 ++++
cpp/src/arrow/engine/substrait/extension_types.h | 82 ++
cpp/src/arrow/engine/substrait/plan_internal.cc | 161 ++++
cpp/src/arrow/engine/substrait/plan_internal.h | 55 ++
.../arrow/engine/substrait/relation_internal.cc | 193 +++++
cpp/src/arrow/engine/substrait/relation_internal.h | 37 +
cpp/src/arrow/engine/substrait/serde.cc | 232 ++++++
cpp/src/arrow/engine/substrait/serde.h | 168 ++++
cpp/src/arrow/engine/substrait/serde_test.cc | 728 +++++++++++++++++
cpp/src/arrow/engine/substrait/type_internal.cc | 494 ++++++++++++
cpp/src/arrow/engine/substrait/type_internal.h | 51 ++
cpp/src/arrow/engine/visibility.h | 50 ++
cpp/src/arrow/flight/CMakeLists.txt | 6 +-
cpp/src/arrow/scalar.cc | 12 +
cpp/src/arrow/scalar.h | 20 +-
cpp/src/arrow/status_test.cc | 7 +-
cpp/src/arrow/testing/matchers.h | 110 ++-
cpp/src/arrow/type.h | 6 +-
cpp/src/arrow/util/hashing.h | 9 +
dev/archery/archery/cli.py | 6 +
dev/archery/archery/lang/cpp.py | 2 +-
dev/release/rat_exclude_files.txt | 1 +
format/substrait/extension_types.yaml | 87 ++
48 files changed, 4998 insertions(+), 40 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index 85e239e..226d427 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -99,6 +99,7 @@ jobs:
-e ARROW_GCS=OFF
-e ARROW_MIMALLOC=OFF
-e ARROW_ORC=OFF
+ -e ARROW_ENGINE=OFF
-e ARROW_PARQUET=OFF
-e ARROW_S3=OFF
-e CMAKE_UNITY_BUILD=ON
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 5a0233b..c350787 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -351,7 +351,9 @@ if(ARROW_CUDA
endif()
if(ARROW_ENGINE)
+ set(ARROW_PARQUET ON)
set(ARROW_COMPUTE ON)
+ set(ARROW_DATASET ON)
endif()
if(ARROW_SKYHOOK)
diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake
index 0a43ec1..30b1d0e 100644
--- a/cpp/cmake_modules/DefineOptions.cmake
+++ b/cpp/cmake_modules/DefineOptions.cmake
@@ -225,7 +225,7 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
define_option(ARROW_DATASET "Build the Arrow Dataset Modules" OFF)
- define_option(ARROW_ENGINE "Build the Arrow Execution Engine" OFF)
+ define_option(ARROW_ENGINE "Build the Arrow Query Engine Module" OFF)
define_option(ARROW_FILESYSTEM "Build the Arrow Filesystem Layer" OFF)
@@ -478,6 +478,16 @@ advised that if this is enabled 'install' will fail silently on components;\
that have not been built"
OFF)
+ set(ARROW_SUBSTRAIT_REPO_DEFAULT "https://github.com/substrait-io/substrait")
+ define_option_string(ARROW_SUBSTRAIT_REPO
+ "Custom git repository URL for downloading Substrait sources.;\
+See also ARROW_SUBSTRAIT_TAG" "${ARROW_SUBSTRAIT_REPO_DEFAULT}")
+
+ set(ARROW_SUBSTRAIT_TAG_DEFAULT "e1b4c04a1b518912f4c4065b16a1b2c0ac8e14cf")
+ define_option_string(ARROW_SUBSTRAIT_TAG
+ "Custom git hash/tag/branch for Substrait repository.;\
+See also ARROW_SUBSTRAIT_REPO" "${ARROW_SUBSTRAIT_TAG_DEFAULT}")
+
option(ARROW_BUILD_CONFIG_SUMMARY_JSON "Summarize build configuration in a JSON file"
ON)
endif()
diff --git a/cpp/cmake_modules/FindArrowEngine.cmake b/cpp/cmake_modules/FindArrowEngine.cmake
new file mode 100644
index 0000000..3ee09e0
--- /dev/null
+++ b/cpp/cmake_modules/FindArrowEngine.cmake
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# - Find Arrow Engine (arrow/engine/api.h, libarrow_engine.a, libarrow_engine.so)
+#
+# This module requires Arrow from which it uses
+# arrow_find_package()
+#
+# This module defines
+# ARROW_ENGINE_FOUND, whether Arrow Engine has been found
+# ARROW_ENGINE_IMPORT_LIB,
+# path to libarrow_engine's import library (Windows only)
+# ARROW_ENGINE_INCLUDE_DIR, directory containing headers
+# ARROW_ENGINE_LIB_DIR, directory containing Arrow Engine libraries
+# ARROW_ENGINE_SHARED_LIB, path to libarrow_engine's shared library
+# ARROW_ENGINE_STATIC_LIB, path to libarrow_engine.a
+
+# Nothing to do if an earlier run of this module already set ARROW_ENGINE_FOUND.
+if(DEFINED ARROW_ENGINE_FOUND)
+ return()
+endif()
+
+# Collect the version / REQUIRED / QUIET settings of the current find_package()
+# call so that they can be forwarded to the find_package() calls for Arrow and
+# Parquet below.
+set(find_package_arguments)
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION)
+ list(APPEND find_package_arguments "${${CMAKE_FIND_PACKAGE_NAME}_FIND_VERSION}")
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_REQUIRED)
+ list(APPEND find_package_arguments REQUIRED)
+endif()
+if(${CMAKE_FIND_PACKAGE_NAME}_FIND_QUIETLY)
+ list(APPEND find_package_arguments QUIET)
+endif()
+find_package(Arrow ${find_package_arguments})
+find_package(Parquet ${find_package_arguments})
+
+# Only search for Arrow Engine once both Arrow and Parquet have been found.
+# arrow_find_package() comes from the Arrow find module (see the header comment
+# of this file).
+if(ARROW_FOUND AND PARQUET_FOUND)
+ arrow_find_package(ARROW_ENGINE
+ "${ARROW_HOME}"
+ arrow_engine
+ arrow/engine/api.h
+ ArrowEngine
+ arrow-engine)
+ # Fall back to Arrow's own version when the search did not set one.
+ if(NOT ARROW_ENGINE_VERSION)
+ set(ARROW_ENGINE_VERSION "${ARROW_VERSION}")
+ endif()
+endif()
+
+# Record whether the engine version equals the Arrow version; this flag is one
+# of the REQUIRED_VARS passed to find_package_handle_standard_args() below.
+if("${ARROW_ENGINE_VERSION}" VERSION_EQUAL "${ARROW_VERSION}")
+ set(ARROW_ENGINE_VERSION_MATCH TRUE)
+else()
+ set(ARROW_ENGINE_VERSION_MATCH FALSE)
+endif()
+
+# Mark the result variables as advanced cache entries.
+mark_as_advanced(ARROW_ENGINE_IMPORT_LIB
+ ARROW_ENGINE_INCLUDE_DIR
+ ARROW_ENGINE_LIBS
+ ARROW_ENGINE_LIB_DIR
+ ARROW_ENGINE_SHARED_IMP_LIB
+ ARROW_ENGINE_SHARED_LIB
+ ARROW_ENGINE_STATIC_LIB
+ ARROW_ENGINE_VERSION
+ ARROW_ENGINE_VERSION_MATCH)
+
+# ArrowEngine counts as found only when the include directory and library
+# directory are set and ARROW_ENGINE_VERSION_MATCH is true.
+find_package_handle_standard_args(
+ ArrowEngine
+ REQUIRED_VARS ARROW_ENGINE_INCLUDE_DIR ARROW_ENGINE_LIB_DIR ARROW_ENGINE_VERSION_MATCH
+ VERSION_VAR ARROW_ENGINE_VERSION)
+# Mirror the result under the upper-case name tested by the guard at the top of
+# this file.
+set(ARROW_ENGINE_FOUND ${ArrowEngine_FOUND})
+
+# Report what was found unless the caller requested a quiet search.
+if(ArrowEngine_FOUND AND NOT ArrowEngine_FIND_QUIETLY)
+ message(STATUS "Found the Arrow Engine by ${ARROW_ENGINE_FIND_APPROACH}")
+ message(STATUS "Found the Arrow Engine shared library: ${ARROW_ENGINE_SHARED_LIB}")
+ message(STATUS "Found the Arrow Engine import library: ${ARROW_ENGINE_IMPORT_LIB}")
+ message(STATUS "Found the Arrow Engine static library: ${ARROW_ENGINE_STATIC_LIB}")
+endif()
diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake
index b7e6f6e..32af0a0 100644
--- a/cpp/cmake_modules/ThirdpartyToolchain.cmake
+++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake
@@ -309,7 +309,8 @@ endif()
if(ARROW_ORC
OR ARROW_FLIGHT
- OR ARROW_GANDIVA)
+ OR ARROW_GANDIVA
+ OR ARROW_ENGINE)
set(ARROW_WITH_PROTOBUF ON)
endif()
@@ -1427,6 +1428,11 @@ macro(build_protobuf)
set(PROTOBUF_VENDORED TRUE)
set(PROTOBUF_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/protobuf_ep-install")
set(PROTOBUF_INCLUDE_DIR "${PROTOBUF_PREFIX}/include")
+ # This flag is based on what the user initially requested but if
+ # we've fallen back to building protobuf we always build it statically
+ # so we need to reset the flag so that we can link against it correctly
+ # later.
+ set(Protobuf_USE_STATIC_LIBS ON)
# Newer protobuf releases always have a lib prefix independent from CMAKE_STATIC_LIBRARY_PREFIX
set(PROTOBUF_STATIC_LIB
"${PROTOBUF_PREFIX}/lib/libprotobuf${CMAKE_STATIC_LIBRARY_SUFFIX}")
@@ -1533,7 +1539,7 @@ if(ARROW_WITH_PROTOBUF)
PC_PACKAGE_NAMES
protobuf)
- if(ARROW_PROTOBUF_USE_SHARED AND MSVC_TOOLCHAIN)
+ if(NOT Protobuf_USE_STATIC_LIBS AND MSVC_TOOLCHAIN)
add_definitions(-DPROTOBUF_USE_DLLS)
endif()
diff --git a/cpp/examples/arrow/CMakeLists.txt b/cpp/examples/arrow/CMakeLists.txt
index 54b7eeb..a13942f 100644
--- a/cpp/examples/arrow/CMakeLists.txt
+++ b/cpp/examples/arrow/CMakeLists.txt
@@ -21,6 +21,10 @@ if(ARROW_COMPUTE)
add_arrow_example(compute_register_example)
endif()
+if(ARROW_ENGINE)
+ add_arrow_example(engine_substrait_consumption EXTRA_LINK_LIBS arrow_engine_shared)
+endif()
+
if(ARROW_COMPUTE AND ARROW_CSV)
add_arrow_example(compute_and_write_csv_example)
endif()
diff --git a/cpp/examples/arrow/engine_substrait_consumption.cc b/cpp/examples/arrow/engine_substrait_consumption.cc
new file mode 100644
index 0000000..b0109b3
--- /dev/null
+++ b/cpp/examples/arrow/engine_substrait_consumption.cc
@@ -0,0 +1,186 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/api.h>
+#include <arrow/compute/api.h>
+#include <arrow/engine/api.h>
+
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+namespace eng = arrow::engine;
+namespace cp = arrow::compute;
+
+// Evaluate `expr` (an arrow::Status). If it is not OK, print its message to
+// stderr and abort the process.
+//
+// There is deliberately no semicolon after `while (0)`: the caller writes it,
+// so the macro expands to exactly one statement and stays safe inside an
+// unbraced if/else.
+#define ABORT_ON_FAILURE(expr)                     \
+  do {                                             \
+    arrow::Status status_ = (expr);                \
+    if (!status_.ok()) {                           \
+      std::cerr << status_.message() << std::endl; \
+      abort();                                     \
+    }                                              \
+  } while (0)
+
+// A sink consumer that throws away the data it receives and only logs its
+// progress to stdout.
+class IgnoringConsumer : public cp::SinkNodeConsumer {
+ public:
+  explicit IgnoringConsumer(size_t tag) : tag_(tag) {}
+
+  // Receives one batch of data; the only "work" done here is printing the
+  // batch's row count.
+  arrow::Status Consume(cp::ExecBatch batch) override {
+    const auto num_rows = batch.length;
+    std::cout << "-" << tag_ << " consumed " << num_rows << " rows" << std::endl;
+    return arrow::Status::OK();
+  }
+
+  // Tells the consumer that the last batch has been delivered. Consume is
+  // guaranteed not to be called again after this method has been called.
+  //
+  // The returned future should only finish once all outstanding tasks have
+  // completed; this consumer has no real work, so it is finished immediately.
+  arrow::Future<> Finish() override {
+    std::cout << "-" << tag_ << " finished" << std::endl;
+    auto already_done = arrow::Future<>::MakeFinished();
+    return already_done;
+  }
+
+ private:
+  // Label that distinguishes the log output of different instances when a plan
+  // has several sinks. In this example it is the zero-based index of the
+  // relation tree in the plan.
+  size_t tag_;
+};
+
+// Produce the serialized example plan.
+//
+// A hard-coded JSON Substrait plan is used in place of a real server response.
+// `filename` is substituted for FILENAME_PLACEHOLDER in the plan's "uri_file"
+// entry, and the resulting JSON is passed to
+// eng::internal::SubstraitFromJSON("Plan", ...), whose result is returned.
+arrow::Future<std::shared_ptr<arrow::Buffer>> GetSubstraitFromServer(
+ const std::string& filename) {
+ // Emulate server interaction by parsing hard coded JSON
+ std::string substrait_json = R"({
+ "relations": [
+ {"rel": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [ {"i64": {}}, {"bool": {}} ]
+ },
+ "names": ["i", "b"]
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file://FILENAME_PLACEHOLDER",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ }}
+ ],
+ "extension_uris": [
+ {
+ "extension_uri_anchor": 7,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }
+ ],
+ "extensions": [
+ {"extension_type": {
+ "extension_uri_reference": 7,
+ "type_anchor": 42,
+ "name": "null"
+ }},
+ {"extension_type_variation": {
+ "extension_uri_reference": 7,
+ "type_variation_anchor": 23,
+ "name": "u8"
+ }},
+ {"extension_function": {
+ "extension_uri_reference": 7,
+ "function_anchor": 42,
+ "name": "add"
+ }}
+ ]
+ })";
+ // Replace the first (and, in the JSON above, only) occurrence of the
+ // placeholder with the caller's file name.
+ std::string filename_placeholder = "FILENAME_PLACEHOLDER";
+ substrait_json.replace(substrait_json.find(filename_placeholder),
+ filename_placeholder.size(), filename);
+ return eng::internal::SubstraitFromJSON("Plan", substrait_json);
+}
+
+// Example driver: builds a Substrait plan that scans the parquet file named by
+// argv[1], deserializes it into Arrow compute Declarations, adds them to an
+// ExecPlan and runs that plan to completion.
+//
+// Returns EXIT_SUCCESS when the plan finishes, and also when no file is given
+// (so that CI can run the example without arguments). Any Arrow failure aborts
+// the process via ABORT_ON_FAILURE.
+int main(int argc, char** argv) {
+  if (argc < 2) {
+    std::cout << "Please specify a parquet file to scan" << std::endl;
+    // Fake pass for CI
+    return EXIT_SUCCESS;
+  }
+
+  // Plans arrive at the consumer serialized in a Buffer, using the binary protobuf
+  // serialization of a substrait Plan
+  auto maybe_serialized_plan = GetSubstraitFromServer(argv[1]).result();
+  ABORT_ON_FAILURE(maybe_serialized_plan.status());
+  std::shared_ptr<arrow::Buffer> serialized_plan =
+      std::move(maybe_serialized_plan).ValueOrDie();
+
+  // Print the received plan to stdout as JSON
+  arrow::Result<std::string> maybe_plan_json =
+      eng::internal::SubstraitToJSON("Plan", *serialized_plan);
+  ABORT_ON_FAILURE(maybe_plan_json.status());
+  std::cout << std::string(50, '#') << " received substrait::Plan:" << std::endl;
+  std::cout << maybe_plan_json.ValueOrDie() << std::endl;
+
+  // The data sink(s) for plans is/are implicit in substrait plans, but explicit in
+  // Arrow. Therefore, deserializing a plan requires a factory for consumers: each
+  // time the root of a substrait relation tree is deserialized, an Arrow consumer is
+  // constructed into which its batches will be piped.
+  std::vector<std::shared_ptr<cp::SinkNodeConsumer>> consumers;
+  std::function<std::shared_ptr<cp::SinkNodeConsumer>()> consumer_factory = [&] {
+    // All batches produced by the plan will be fed into IgnoringConsumers.
+    // make_shared gives the new consumer an owner immediately, so it cannot leak
+    // if growing the vector throws.
+    auto tag = consumers.size();
+    consumers.push_back(std::make_shared<IgnoringConsumer>(tag));
+    return consumers.back();
+  };
+
+  // Deserialize each relation tree in the substrait plan to an Arrow compute Declaration
+  arrow::Result<std::vector<cp::Declaration>> maybe_decls =
+      eng::DeserializePlan(*serialized_plan, consumer_factory);
+  ABORT_ON_FAILURE(maybe_decls.status());
+  std::vector<cp::Declaration> decls = std::move(maybe_decls).ValueOrDie();
+
+  // It's safe to drop the serialized plan; we don't leave references to its memory
+  serialized_plan.reset();
+
+  // Construct an empty plan (note: configure Function registry and ThreadPool here)
+  arrow::Result<std::shared_ptr<cp::ExecPlan>> maybe_plan = cp::ExecPlan::Make();
+  ABORT_ON_FAILURE(maybe_plan.status());
+  std::shared_ptr<cp::ExecPlan> plan = std::move(maybe_plan).ValueOrDie();
+
+  // Add decls to plan (note: configure ExecNode registry before this point)
+  for (const cp::Declaration& decl : decls) {
+    ABORT_ON_FAILURE(decl.AddToPlan(plan.get()).status());
+  }
+
+  // Validate the plan and print it to stdout
+  ABORT_ON_FAILURE(plan->Validate());
+  std::cout << std::string(50, '#') << " produced arrow::ExecPlan:" << std::endl;
+  std::cout << plan->ToString() << std::endl;
+
+  // Start the plan...
+  std::cout << std::string(50, '#') << " consuming batches:" << std::endl;
+  ABORT_ON_FAILURE(plan->StartProducing());
+
+  // ... and wait for it to finish
+  ABORT_ON_FAILURE(plan->finished().status());
+  return EXIT_SUCCESS;
+}
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 89e7e45..a895881 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -725,6 +725,10 @@ if(ARROW_COMPUTE)
add_subdirectory(compute)
endif()
+if(ARROW_ENGINE)
+ add_subdirectory(engine)
+endif()
+
if(ARROW_CUDA)
add_subdirectory(gpu)
endif()
diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc
index c31a7b7..11b6b16 100644
--- a/cpp/src/arrow/array/array_base.cc
+++ b/cpp/src/arrow/array/array_base.cc
@@ -282,6 +282,8 @@ std::string Array::ToString() const {
return ss.str();
}
+void PrintTo(const Array& x, std::ostream* os) { *os << x.ToString(); }
+
Result<std::shared_ptr<Array>> Array::View(
const std::shared_ptr<DataType>& out_type) const {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> result,
diff --git a/cpp/src/arrow/array/array_base.h b/cpp/src/arrow/array/array_base.h
index b6b769c..c17daad 100644
--- a/cpp/src/arrow/array/array_base.h
+++ b/cpp/src/arrow/array/array_base.h
@@ -187,10 +187,11 @@ class ARROW_EXPORT Array {
Status ValidateFull() const;
protected:
- Array() : null_bitmap_data_(NULLPTR) {}
+ Array() = default;
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(Array);
std::shared_ptr<ArrayData> data_;
- const uint8_t* null_bitmap_data_;
+ const uint8_t* null_bitmap_data_ = NULLPTR;
/// Protected method for constructors
void SetData(const std::shared_ptr<ArrayData>& data) {
@@ -204,6 +205,8 @@ class ARROW_EXPORT Array {
private:
ARROW_DISALLOW_COPY_AND_ASSIGN(Array);
+
+ ARROW_EXPORT friend void PrintTo(const Array& x, std::ostream* os);
};
static inline std::ostream& operator<<(std::ostream& os, const Array& x) {
diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h
index 931324b..4d0b477 100644
--- a/cpp/src/arrow/array/builder_base.h
+++ b/cpp/src/arrow/array/builder_base.h
@@ -28,6 +28,7 @@
#include "arrow/array/array_primitive.h"
#include "arrow/buffer.h"
#include "arrow/buffer_builder.h"
+#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/type_fwd.h"
#include "arrow/util/macros.h"
@@ -306,6 +307,13 @@ ARROW_EXPORT
Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
std::unique_ptr<ArrayBuilder>* out);
+/// \brief Construct an empty ArrayBuilder for the given type, returned in a Result
+///
+/// Forwards to the Status-returning MakeBuilder overload and wraps its output.
+/// \param[in] type the data type to create the builder for
+/// \param[in] pool the MemoryPool to use for allocations
+inline Result<std::unique_ptr<ArrayBuilder>> MakeBuilder(
+    const std::shared_ptr<DataType>& type, MemoryPool* pool = default_memory_pool()) {
+  std::unique_ptr<ArrayBuilder> builder;
+  ARROW_RETURN_NOT_OK(MakeBuilder(pool, type, &builder));
+  return std::move(builder);
+}
+
/// \brief Construct an empty ArrayBuilder corresponding to the data
/// type, where any top-level or nested dictionary builders return the
/// exact index type specified by the type.
@@ -313,6 +321,13 @@ ARROW_EXPORT
Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr<DataType>& type,
std::unique_ptr<ArrayBuilder>* out);
+/// \brief Result-returning form of MakeBuilderExactIndex
+///
+/// Forwards to the Status-returning MakeBuilderExactIndex overload and wraps
+/// its output.
+/// \param[in] type the data type to create the builder for
+/// \param[in] pool the MemoryPool to use for allocations
+inline Result<std::unique_ptr<ArrayBuilder>> MakeBuilderExactIndex(
+    const std::shared_ptr<DataType>& type, MemoryPool* pool = default_memory_pool()) {
+  std::unique_ptr<ArrayBuilder> builder;
+  ARROW_RETURN_NOT_OK(MakeBuilderExactIndex(pool, type, &builder));
+  return std::move(builder);
+}
+
/// \brief Construct an empty DictionaryBuilder initialized optionally
/// with a pre-existing dictionary
/// \param[in] pool the MemoryPool to use for allocations
@@ -324,4 +339,12 @@ Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr<DataType>&
const std::shared_ptr<Array>& dictionary,
std::unique_ptr<ArrayBuilder>* out);
+/// \brief Result-returning form of MakeDictionaryBuilder
+///
+/// Forwards to the Status-returning MakeDictionaryBuilder overload and wraps
+/// its output.
+/// \param[in] type the data type forwarded to MakeDictionaryBuilder
+/// \param[in] dictionary the dictionary forwarded to MakeDictionaryBuilder
+/// \param[in] pool the MemoryPool to use for allocations
+inline Result<std::unique_ptr<ArrayBuilder>> MakeDictionaryBuilder(
+    const std::shared_ptr<DataType>& type, const std::shared_ptr<Array>& dictionary,
+    MemoryPool* pool = default_memory_pool()) {
+  std::unique_ptr<ArrayBuilder> builder;
+  ARROW_RETURN_NOT_OK(MakeDictionaryBuilder(pool, type, dictionary, &builder));
+  return std::move(builder);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/expression_internal.h b/cpp/src/arrow/compute/exec/expression_internal.h
index dc38924..f8c686d 100644
--- a/cpp/src/arrow/compute/exec/expression_internal.h
+++ b/cpp/src/arrow/compute/exec/expression_internal.h
@@ -29,9 +29,6 @@
#include "arrow/util/logging.h"
namespace arrow {
-
-using internal::checked_cast;
-
namespace compute {
struct KnownFieldValues {
@@ -213,7 +210,7 @@ struct Comparison {
inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) {
if (call.function_name != "cast") return nullptr;
- return checked_cast<const compute::CastOptions*>(call.options.get());
+ return ::arrow::internal::checked_cast<const compute::CastOptions*>(call.options.get());
}
inline bool IsSetLookup(const std::string& function) {
@@ -223,7 +220,8 @@ inline bool IsSetLookup(const std::string& function) {
inline const compute::MakeStructOptions* GetMakeStructOptions(
const Expression::Call& call) {
if (call.function_name != "make_struct") return nullptr;
- return checked_cast<const compute::MakeStructOptions*>(call.options.get());
+ return ::arrow::internal::checked_cast<const compute::MakeStructOptions*>(
+ call.options.get());
}
/// A helper for unboxing an Expression composed of associative function calls.
@@ -281,7 +279,8 @@ inline Result<std::shared_ptr<compute::Function>> GetFunction(
return exec_context->func_registry()->GetFunction(call.function_name);
}
// XXX this special case is strange; why not make "cast" a ScalarFunction?
- const auto& to_type = checked_cast<const compute::CastOptions&>(*call.options).to_type;
+ const auto& to_type =
+ ::arrow::internal::checked_cast<const compute::CastOptions&>(*call.options).to_type;
return compute::GetCastFunction(to_type);
}
diff --git a/cpp/src/arrow/csv/column_decoder_test.cc b/cpp/src/arrow/csv/column_decoder_test.cc
index c8b96e0..ebac7a3 100644
--- a/cpp/src/arrow/csv/column_decoder_test.cc
+++ b/cpp/src/arrow/csv/column_decoder_test.cc
@@ -22,6 +22,7 @@
#include <gtest/gtest.h>
+#include "arrow/array/array_base.h"
#include "arrow/csv/column_decoder.h"
#include "arrow/csv/options.h"
#include "arrow/csv/test_common.h"
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index c99316f..7ecfb3e 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -766,6 +766,17 @@ Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan* plan,
scan_options->filter.Bind(*dataset->schema()));
}
+ // If no projection schema is specified we will use a default projection. In
+ // general we should not be able to get here if using the ScannerBuilder but
+ // it is possible to get here if scan_options is used directly. To be cleaned up
+ // in ARROW-12311
+ if (!scan_options->projected_schema) {
+ ARROW_ASSIGN_OR_RAISE(auto projection_descr,
+ ProjectionDescr::Default(*dataset->schema()));
+ scan_options->projected_schema = std::move(projection_descr.schema);
+ scan_options->projection = projection_descr.expression;
+ }
+
if (!scan_options->projection.IsBound()) {
auto fields = dataset->schema()->fields();
for (const auto& aug_field : kAugmentedFields) {
diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h
index 514b424..bce53de 100644
--- a/cpp/src/arrow/datum.h
+++ b/cpp/src/arrow/datum.h
@@ -149,9 +149,17 @@ struct ARROW_EXPORT Datum {
template <typename T, bool IsArray = std::is_base_of<Array, T>::value,
bool IsScalar = std::is_base_of<Scalar, T>::value,
typename = enable_if_t<IsArray || IsScalar>>
- Datum(const std::shared_ptr<T>& value) // NOLINT implicit conversion
+ Datum(std::shared_ptr<T> value) // NOLINT implicit conversion
: Datum(std::shared_ptr<typename std::conditional<IsArray, Array, Scalar>::type>(
- value)) {}
+ std::move(value))) {}
+
+ // Cast from subtypes of Array or Scalar to Datum
+ template <typename T, typename TV = typename std::remove_reference<T>::type,
+ bool IsArray = std::is_base_of<Array, T>::value,
+ bool IsScalar = std::is_base_of<Scalar, T>::value,
+ typename = enable_if_t<IsArray || IsScalar>>
+ Datum(T&& value) // NOLINT implicit conversion
+ : Datum(std::make_shared<TV>(std::forward<T>(value))) {}
// Convenience constructors
explicit Datum(bool value);
diff --git a/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in b/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in
new file mode 100644
index 0000000..8fafcda
--- /dev/null
+++ b/cpp/src/arrow/engine/ArrowEngineConfig.cmake.in
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This config sets the following variables in your project::
+#
+# ArrowEngine_FOUND - true if Arrow Engine found on the system
+#
+# This config sets the following targets in your project::
+#
+# arrow_engine_shared - for linked as shared library if shared library is built
+# arrow_engine_static - for linked as static library if static library is built
+
+@PACKAGE_INIT@
+
+# Locate the packages this config depends on before the Arrow Engine targets
+# are loaded.
+include(CMakeFindDependencyMacro)
+find_dependency(Arrow)
+find_dependency(ArrowDataset)
+find_dependency(Parquet)
+
+# Load targets only once. If we load targets multiple times, CMake reports
+# already existent target error.
+if(NOT (TARGET arrow_engine_shared OR TARGET arrow_engine_static))
+ include("${CMAKE_CURRENT_LIST_DIR}/ArrowEngineTargets.cmake")
+endif()
diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt
new file mode 100644
index 0000000..0f00a66
--- /dev/null
+++ b/cpp/src/arrow/engine/CMakeLists.txt
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Umbrella target for the engine module.
+add_custom_target(arrow_engine)
+
+arrow_install_all_headers("arrow/engine")
+
+# Libraries the engine links against in addition to Arrow itself.
+set(ARROW_ENGINE_LINK_LIBS ${ARROW_PROTOBUF_LIBPROTOBUF})
+
+#if(WIN32)
+# list(APPEND ARROW_ENGINE_LINK_LIBS ws2_32.lib)
+#endif()
+
+# Hand-written engine sources; the protobuf-generated sources are appended to
+# this list by the generation loop further down.
+set(ARROW_ENGINE_SRCS
+ substrait/expression_internal.cc
+ substrait/extension_set.cc
+ substrait/extension_types.cc
+ substrait/serde.cc
+ substrait/plan_internal.cc
+ substrait/relation_internal.cc
+ substrait/type_internal.cc)
+
+# SUBSTRAIT_LOCAL_DIR: where the Substrait repository is checked out.
+# SUBSTRAIT_GEN_DIR: where protoc writes the generated C++ files.
+set(SUBSTRAIT_LOCAL_DIR "${CMAKE_CURRENT_BINARY_DIR}/substrait")
+set(SUBSTRAIT_GEN_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
+# Substrait proto files to compile, given relative to proto/substrait/ in the
+# checkout and without the .proto extension.
+set(SUBSTRAIT_PROTOS
+ capabilities
+ expression
+ extensions/extensions
+ function
+ parameterized_types
+ plan
+ relations
+ type
+ type_expressions)
+
+# Check out the Substrait repository at ARROW_SUBSTRAIT_TAG into
+# SUBSTRAIT_LOCAL_DIR. The configure, build and install steps are empty, so the
+# external project only fetches the sources.
+externalproject_add(substrait_ep
+ GIT_REPOSITORY "${ARROW_SUBSTRAIT_REPO}"
+ GIT_TAG "${ARROW_SUBSTRAIT_TAG}"
+ SOURCE_DIR "${SUBSTRAIT_LOCAL_DIR}"
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND "")
+
+# Compiler options applied to the generated protobuf files (see the
+# set_source_files_properties() call in the generation loop); empty unless MSVC.
+set(SUBSTRAIT_SUPPRESSED_WARNINGS)
+if(MSVC)
+ # Protobuf generated files trigger some spurious warnings on MSVC.
+
+ # Implicit conversion from uint64_t to uint32_t:
+ list(APPEND SUBSTRAIT_SUPPRESSED_WARNINGS "/wd4244")
+
+ # Missing dll-interface:
+ list(APPEND SUBSTRAIT_SUPPRESSED_WARNINGS "/wd4251")
+endif()
+
+# Make sure the protoc output directory exists before any generation rule runs.
+file(MAKE_DIRECTORY "${SUBSTRAIT_GEN_DIR}")
+
+# Generate <name>.pb.h / <name>.pb.cc for every entry of SUBSTRAIT_PROTOS.
+# SUBSTRAIT_PROTO_GEN_ALL collects every generated file (for the substrait_gen
+# target); the generated .cc files are also appended to ARROW_ENGINE_SRCS.
+set(SUBSTRAIT_PROTO_GEN_ALL)
+foreach(SUBSTRAIT_PROTO ${SUBSTRAIT_PROTOS})
+  set(SUBSTRAIT_PROTO_GEN "${SUBSTRAIT_GEN_DIR}/substrait/${SUBSTRAIT_PROTO}.pb")
+
+  foreach(EXT h cc)
+    set_source_files_properties("${SUBSTRAIT_PROTO_GEN}.${EXT}"
+                                PROPERTIES COMPILE_OPTIONS
+                                           "${SUBSTRAIT_SUPPRESSED_WARNINGS}"
+                                           GENERATED TRUE
+                                           SKIP_UNITY_BUILD_INCLUSION TRUE)
+    list(APPEND SUBSTRAIT_PROTO_GEN_ALL "${SUBSTRAIT_PROTO_GEN}.${EXT}")
+  endforeach()
+
+  # One protoc invocation writes both the header and the source, so both files
+  # are declared as outputs of a single rule. Declaring one rule per extension
+  # would run the same command twice and let the two runs write the same files
+  # concurrently in a parallel build.
+  # NOTE(review): PROTO_DEPENDS is not set anywhere in this file; if it is also
+  # unset in the enclosing scope it expands to nothing - confirm.
+  add_custom_command(OUTPUT "${SUBSTRAIT_PROTO_GEN}.cc" "${SUBSTRAIT_PROTO_GEN}.h"
+                     COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${SUBSTRAIT_LOCAL_DIR}/proto"
+                             "--cpp_out=${SUBSTRAIT_GEN_DIR}"
+                             "${SUBSTRAIT_LOCAL_DIR}/proto/substrait/${SUBSTRAIT_PROTO}.proto"
+                     DEPENDS ${PROTO_DEPENDS} substrait_ep
+                     VERBATIM)
+
+  list(APPEND ARROW_ENGINE_SRCS "${SUBSTRAIT_PROTO_GEN}.cc")
+endforeach()
+
+add_custom_target(substrait_gen ALL DEPENDS ${SUBSTRAIT_PROTO_GEN_ALL})
+
+# Optional target that runs `git diff --exit-code` on the generated directory
+# after regenerating the Substrait files. It is not part of ALL.
+find_package(Git)
+add_custom_target(substrait_gen_verify
+                  COMMAND ${GIT_EXECUTABLE} diff --exit-code ${SUBSTRAIT_GEN_DIR}
+                  COMMENT "Verifying that generated substrait accessors are consistent with \
+ARROW_SUBSTRAIT_REPO='${ARROW_SUBSTRAIT_REPO}' ARROW_SUBSTRAIT_TAG='${ARROW_SUBSTRAIT_TAG}'"
+                  VERBATIM)
+# Target-level ordering goes through add_dependencies(); DEPENDS of
+# add_custom_target() is meant for files and custom-command outputs.
+add_dependencies(substrait_gen_verify substrait_gen)
+
+# Define the arrow_engine library from the hand-written and generated sources.
+# add_arrow_lib() is a project helper defined elsewhere; the names of the
+# library targets it creates are presumably returned in ARROW_ENGINE_LIBRARIES,
+# which is iterated below.
+add_arrow_lib(arrow_engine
+ CMAKE_PACKAGE_NAME
+ ArrowEngine
+ PKG_CONFIG_NAME
+ arrow-engine
+ OUTPUTS
+ ARROW_ENGINE_LIBRARIES
+ SOURCES
+ ${ARROW_ENGINE_SRCS}
+ PRECOMPILED_HEADERS
+ "$<$<COMPILE_LANGUAGE:CXX>:arrow/engine/pch.h>"
+ SHARED_LINK_FLAGS
+ ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt
+ SHARED_LINK_LIBS
+ arrow_shared
+ arrow_dataset_shared
+ ${ARROW_ENGINE_LINK_LIBS}
+ STATIC_LINK_LIBS
+ arrow_static
+ arrow_dataset_static
+ ${ARROW_ENGINE_LINK_LIBS}
+ PRIVATE_INCLUDES
+ ${SUBSTRAIT_GEN_DIR})
+
+# Define ARROW_ENGINE_EXPORTING only while compiling the engine libraries
+# themselves (PRIVATE, so it is not passed on to consumers).
+foreach(LIB_TARGET ${ARROW_ENGINE_LIBRARIES})
+ target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_ENGINE_EXPORTING)
+endforeach()
+
+# Libraries for the engine tests: the engine's own link libraries, the common
+# Arrow test libraries, and the engine library in the linkage selected by
+# ARROW_TEST_LINKAGE.
+# CMake variable names are case-sensitive, so ARROW_ENGINE_LINK_LIBS must be
+# spelled exactly as it is where it is set.
+set(ARROW_ENGINE_TEST_LINK_LIBS ${ARROW_ENGINE_LINK_LIBS} ${ARROW_TEST_LINK_LIBS})
+if(ARROW_TEST_LINKAGE STREQUAL "static")
+  list(APPEND ARROW_ENGINE_TEST_LINK_LIBS arrow_engine_static)
+else()
+  list(APPEND ARROW_ENGINE_TEST_LINK_LIBS arrow_engine_shared)
+endif()
+
+add_arrow_test(substrait_test
+               SOURCES
+               substrait/serde_test.cc
+               EXTRA_LINK_LIBS
+               ${ARROW_ENGINE_TEST_LINK_LIBS}
+               PREFIX
+               "arrow-engine"
+               LABELS
+               "arrow_engine")
diff --git a/cpp/src/arrow/engine/api.h b/cpp/src/arrow/engine/api.h
new file mode 100644
index 0000000..de996e4
--- /dev/null
+++ b/cpp/src/arrow/engine/api.h
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include "arrow/engine/substrait/extension_types.h"
+#include "arrow/engine/substrait/serde.h"
diff --git a/cpp/src/arrow/engine/arrow-engine.pc.in b/cpp/src/arrow/engine/arrow-engine.pc.in
new file mode 100644
index 0000000..90fba82
--- /dev/null
+++ b/cpp/src/arrow/engine/arrow-engine.pc.in
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: Apache Arrow Engine
+Description: Apache Arrow's Query Engine.
+Version: @ARROW_VERSION@
+Requires: arrow
+Libs: -L${libdir} -larrow_engine
diff --git a/cpp/src/arrow/engine/pch.h b/cpp/src/arrow/engine/pch.h
new file mode 100644
index 0000000..ddb4c12
--- /dev/null
+++ b/cpp/src/arrow/engine/pch.h
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Often-used headers, for precompiling.
+// If updating this header, please make sure you check compilation speed
+// before checking in. Adding headers which are not used extremely often
+// may incur a slowdown, since it makes the precompiled header heavier to load.
+
+#include "arrow/pch.h"
diff --git a/cpp/src/arrow/engine/simple_extension_type_internal.h b/cpp/src/arrow/engine/simple_extension_type_internal.h
new file mode 100644
index 0000000..b177425
--- /dev/null
+++ b/cpp/src/arrow/engine/simple_extension_type_internal.h
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "arrow/extension_type.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/reflection_internal.h"
+#include "arrow/util/string.h"
+
+namespace arrow {
+namespace engine {
+
+/// \brief A helper class for creating simple extension types
+///
+/// Extension types can be parameterized by flat structs
+///
+/// Each item in the struct will be serialized and deserialized using
+/// the STL insertion and extraction operators (i.e. << and >>).
+///
+/// Note: The serialization is a very barebones JSON-like format and
+/// probably shouldn't be hand-edited
+
+template <const util::string_view& kExtensionName, typename Params,
+ typename ParamsProperties, const ParamsProperties* kProperties,
+ std::shared_ptr<DataType> GetStorage(const Params&)>
+class SimpleExtensionType : public ExtensionType {
+ public:
+ explicit SimpleExtensionType(std::shared_ptr<DataType> storage_type, Params params = {})
+ : ExtensionType(std::move(storage_type)), params_(std::move(params)) {}
+
+ static std::shared_ptr<DataType> Make(Params params) {
+ auto storage_type = GetStorage(params);  // Storage type is derived from the parameters
+ return std::make_shared<SimpleExtensionType>(std::move(storage_type),
+ std::move(params));
+ }
+
+ /// \brief Returns the parameters object for the type
+ ///
+ /// If the type is not an instance of this extension type then nullptr will be returned
+ static const Params* GetIf(const DataType& type) {
+ if (type.id() != Type::EXTENSION) return nullptr;
+
+ const auto& ext_type = ::arrow::internal::checked_cast<const ExtensionType&>(type);
+ if (ext_type.extension_name() != kExtensionName) return nullptr;
+
+ return &::arrow::internal::checked_cast<const SimpleExtensionType&>(type).params_;
+ }
+
+ std::string extension_name() const override { return kExtensionName.to_string(); }
+
+ std::string ToString() const override { return "extension<" + this->Serialize() + ">"; }
+
+ /// \brief A comparator which returns true iff all parameter properties are equal
+ struct ExtensionEqualsImpl {
+ ExtensionEqualsImpl(const Params& l, const Params& r) : left_(l), right_(r) {
+ kProperties->ForEach(*this);  // Invokes operator() once per property
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t i) {
+ equal_ &= prop.get(left_) == prop.get(right_);
+ }
+
+ const Params& left_;
+ const Params& right_;
+ bool equal_ = true;
+ };
+
+ bool ExtensionEquals(const ExtensionType& other) const override {
+ if (kExtensionName != other.extension_name()) return false;
+ const auto& other_params = static_cast<const SimpleExtensionType&>(other).params_;  // Cast relies on the name check above
+ return ExtensionEqualsImpl(params_, other_params).equal_;
+ }
+
+ std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
+ DCHECK_EQ(data->type->id(), Type::EXTENSION);
+ DCHECK_EQ(static_cast<const ExtensionType&>(*data->type).extension_name(),
+ kExtensionName);
+ return std::make_shared<ExtensionArray>(data);
+ }
+
+ struct DeserializeImpl {
+ explicit DeserializeImpl(util::string_view repr) : params_(Params{}) {  // Start from default params; Fail() resets them to nullopt
+ Init(kExtensionName, repr, kProperties->size());
+ kProperties->ForEach(*this);  // Parses one member per property unless Init already failed
+ }
+
+ void Fail() { params_ = util::nullopt; }
+
+ void Init(util::string_view class_name, util::string_view repr,
+ size_t num_properties) {
+ if (!repr.starts_with(class_name)) return Fail();
+
+ repr = repr.substr(class_name.size());
+ if (repr.empty()) return Fail();
+ if (repr.front() != '{') return Fail();
+ if (repr.back() != '}') return Fail();
+
+ repr = repr.substr(1, repr.size() - 2);  // Strip the surrounding braces
+ if (!repr.empty()) members_ = ::arrow::internal::SplitString(repr, ',');  // "{}" means zero members
+ if (members_.size() != num_properties) return Fail();
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t i) {
+ if (!params_) return;  // A previous step failed; skip the remaining properties
+
+ auto first_colon = members_[i].find_first_of(':');
+ if (first_colon == util::string_view::npos) return Fail();
+
+ auto name = members_[i].substr(0, first_colon);
+ if (name != prop.name()) return Fail();  // Members must appear in property order
+
+ auto value_repr = members_[i].substr(first_colon + 1);
+ typename Property::Type value;
+ try {
+ std::stringstream ss(value_repr.to_string());
+ ss >> value;
+ if (ss.fail() || !ss.eof()) return Fail();  // Reject unparsable or only partially consumed values
+ } catch (...) {
+ return Fail();
+ }
+ prop.set(&*params_, std::move(value));
+ }
+
+ util::optional<Params> params_;
+ std::vector<util::string_view> members_;
+ };
+ Result<std::shared_ptr<DataType>> Deserialize(
+ std::shared_ptr<DataType> storage_type,
+ const std::string& serialized) const override {
+ if (auto params = DeserializeImpl(serialized).params_) {
+ if (!storage_type->Equals(GetStorage(*params))) {  // Storage must match what the parameters imply
+ return Status::Invalid("Invalid storage type for ", kExtensionName, ": ",
+ storage_type->ToString(), " (expected ",
+ GetStorage(*params)->ToString(), ")");
+ }
+
+ return std::make_shared<SimpleExtensionType>(std::move(storage_type),
+ std::move(*params));
+ }
+
+ return Status::Invalid("Could not parse parameters for extension type ",
+ extension_name(), " from ", serialized);
+ }
+
+ struct SerializeImpl {
+ explicit SerializeImpl(const Params& params)
+ : params_(params), members_(kProperties->size()) {
+ kProperties->ForEach(*this);  // Fills members_[i] for each property
+ }
+
+ template <typename Property>
+ void operator()(const Property& prop, size_t i) {
+ std::stringstream ss;
+ ss << prop.name() << ":" << prop.get(params_);
+ members_[i] = ss.str();
+ }
+
+ std::string Finish() {
+ return kExtensionName.to_string() + "{" +
+ ::arrow::internal::JoinStrings(members_, ",") + "}";
+ }
+
+ const Params& params_;
+ std::vector<std::string> members_;
+ };
+ std::string Serialize() const override { return SerializeImpl(params_).Finish(); }
+
+ private:
+ Params params_;
+};
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc
new file mode 100644
index 0000000..686ef5d
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/expression_internal.cc
@@ -0,0 +1,896 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#include "arrow/engine/substrait/expression_internal.h"
+
+#include <utility>
+
+#include "arrow/builder.h"
+#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec/expression_internal.h"
+#include "arrow/engine/substrait/extension_types.h"
+#include "arrow/engine/substrait/type_internal.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/visit_scalar_inline.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace engine {
+
+namespace internal {
+using ::arrow::internal::make_unique;
+} // namespace internal
+
+Result<compute::Expression> FromProto(const substrait::Expression& expr,
+ const ExtensionSet& ext_set) {
+ switch (expr.rex_type_case()) {
+ case substrait::Expression::kLiteral: {
+ ARROW_ASSIGN_OR_RAISE(auto datum, FromProto(expr.literal(), ext_set));
+ return compute::literal(std::move(datum));
+ }
+
+ case substrait::Expression::kSelection: {
+ if (!expr.selection().has_direct_reference()) break;  // Only direct references are handled; others end in NotImplemented below
+
+ util::optional<compute::Expression> out;
+ if (expr.selection().has_expression()) {
+ ARROW_ASSIGN_OR_RAISE(out, FromProto(expr.selection().expression(), ext_set));
+ }
+
+ const auto* ref = &expr.selection().direct_reference();
+ while (ref != nullptr) {  // Walks the chain of reference segments, one child at a time
+ switch (ref->reference_type_case()) {
+ case substrait::Expression::ReferenceSegment::kStructField: {
+ auto index = ref->struct_field().field();
+ if (!out) {
+ // Root StructField (column selection)
+ out = compute::field_ref(FieldRef(index));
+ } else if (auto out_ref = out->field_ref()) {
+ // Nested StructFields on the root (selection of struct-typed column
+ // combined with selecting struct fields)
+ out = compute::field_ref(FieldRef(*out_ref, index));
+ } else if (out->call() && out->call()->function_name == "struct_field") {
+ // Nested StructFields on top of an arbitrary expression
+ std::static_pointer_cast<arrow::compute::StructFieldOptions>(
+ out->call()->options)
+ ->indices.push_back(index);
+ } else {
+ // First StructField on top of an arbitrary expression
+ out = compute::call("struct_field", {std::move(*out)},
+ arrow::compute::StructFieldOptions({index}));
+ }
+
+ // Segment handled, continue with child segment (if any)
+ if (ref->struct_field().has_child()) {
+ ref = &ref->struct_field().child();
+ } else {
+ ref = nullptr;
+ }
+ break;
+ }
+ case substrait::Expression::ReferenceSegment::kListElement: {
+ if (!out) {
+ // Root ListField (illegal)
+ return Status::Invalid(
+ "substrait::ListElement cannot take a Relation as an argument");
+ }
+
+ // ListField on top of an arbitrary expression
+ out = compute::call(
+ "list_element",
+ {std::move(*out), compute::literal(ref->list_element().offset())});
+
+ // Segment handled, continue with child segment (if any)
+ if (ref->list_element().has_child()) {
+ ref = &ref->list_element().child();
+ } else {
+ ref = nullptr;
+ }
+ break;
+ }
+ default:
+ // Unimplemented construct, break out of loop
+ out.reset();  // Discard the partial result so NotImplemented is returned below
+ ref = nullptr;
+ }
+ }
+ if (out) {
+ return *std::move(out);
+ }
+ break;
+ }
+
+ case substrait::Expression::kIfThen: {
+ const auto& if_then = expr.if_then();
+ if (!if_then.has_else_()) break;  // An else branch is required; otherwise NotImplemented
+ if (if_then.ifs_size() == 0) break;
+
+ if (if_then.ifs_size() == 1) {  // A single condition maps directly onto "if_else"
+ ARROW_ASSIGN_OR_RAISE(auto if_, FromProto(if_then.ifs(0).if_(), ext_set));
+ ARROW_ASSIGN_OR_RAISE(auto then, FromProto(if_then.ifs(0).then(), ext_set));
+ ARROW_ASSIGN_OR_RAISE(auto else_, FromProto(if_then.else_(), ext_set));
+ return compute::call("if_else",
+ {std::move(if_), std::move(then), std::move(else_)});
+ }
+
+ std::vector<compute::Expression> conditions, args;
+ std::vector<std::string> condition_names;
+ conditions.reserve(if_then.ifs_size());
+ condition_names.reserve(if_then.ifs_size());
+ size_t name_counter = 0;
+ args.reserve(if_then.ifs_size() + 2);  // Condition struct + one value per branch + else value
+ args.emplace_back();  // Placeholder, replaced by the make_struct call assigned to args[0] below
+ for (const auto& if_ : if_then.ifs()) {
+ ARROW_ASSIGN_OR_RAISE(auto compute_if, FromProto(if_.if_(), ext_set));
+ ARROW_ASSIGN_OR_RAISE(auto compute_then, FromProto(if_.then(), ext_set));
+ conditions.emplace_back(std::move(compute_if));
+ args.emplace_back(std::move(compute_then));
+ condition_names.emplace_back("cond" + std::to_string(++name_counter));
+ }
+ ARROW_ASSIGN_OR_RAISE(auto compute_else, FromProto(if_then.else_(), ext_set));
+ args.emplace_back(std::move(compute_else));
+ args[0] = compute::call("make_struct", std::move(conditions),
+ compute::MakeStructOptions(condition_names));
+ return compute::call("case_when", std::move(args));
+ }
+
+ case substrait::Expression::kScalarFunction: {
+ const auto& scalar_fn = expr.scalar_function();
+
+ ARROW_ASSIGN_OR_RAISE(auto decoded_function,
+ ext_set.DecodeFunction(scalar_fn.function_reference()));
+
+ std::vector<compute::Expression> arguments(scalar_fn.args_size());
+ for (int i = 0; i < scalar_fn.args_size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(arguments[i], FromProto(scalar_fn.args(i), ext_set));
+ }
+
+ return compute::call(decoded_function.name.to_string(), std::move(arguments));
+ }
+
+ default:
+ break;
+ }
+
+ return Status::NotImplemented(  // Reached by every `break` out of the switch above
+ "conversion to arrow::compute::Expression from Substrait expression ",
+ expr.DebugString());
+}
+
+Result<Datum> FromProto(const substrait::Expression::Literal& lit,
+ const ExtensionSet& ext_set) {
+ if (lit.nullable()) {
+ // FIXME not sure how this field should be interpreted and there's no way to round
+ // trip it through arrow
+ return Status::Invalid(
+ "Nullable Literals - Literal.nullable must be left at the default");
+ }
+
+ switch (lit.literal_type_case()) {
+ case substrait::Expression::Literal::kBoolean:
+ return Datum(lit.boolean());
+
+ case substrait::Expression::Literal::kI8:
+ return Datum(static_cast<int8_t>(lit.i8()));
+ case substrait::Expression::Literal::kI16:
+ return Datum(static_cast<int16_t>(lit.i16()));
+ case substrait::Expression::Literal::kI32:
+ return Datum(static_cast<int32_t>(lit.i32()));
+ case substrait::Expression::Literal::kI64:
+ return Datum(static_cast<int64_t>(lit.i64()));
+
+ case substrait::Expression::Literal::kFp32:
+ return Datum(lit.fp32());
+ case substrait::Expression::Literal::kFp64:
+ return Datum(lit.fp64());
+
+ case substrait::Expression::Literal::kString:
+ return Datum(lit.string());
+ case substrait::Expression::Literal::kBinary:
+ return Datum(BinaryScalar(lit.binary()));
+
+ case substrait::Expression::Literal::kTimestamp:
+ return Datum(
+ TimestampScalar(static_cast<int64_t>(lit.timestamp()), TimeUnit::MICRO));
+
+ case substrait::Expression::Literal::kTimestampTz:
+ return Datum(TimestampScalar(static_cast<int64_t>(lit.timestamp_tz()),
+ TimeUnit::MICRO, TimestampTzTimezoneString()));
+
+ case substrait::Expression::Literal::kDate:
+ return Datum(Date32Scalar(lit.date()));
+ case substrait::Expression::Literal::kTime:
+ return Datum(Time64Scalar(lit.time(), TimeUnit::MICRO));
+
+ case substrait::Expression::Literal::kIntervalYearToMonth:
+ case substrait::Expression::Literal::kIntervalDayToSecond: {
+ Int32Builder builder;  // Both interval kinds are stored as a pair of int32 values
+ std::shared_ptr<DataType> type;
+ if (lit.has_interval_year_to_month()) {
+ RETURN_NOT_OK(builder.Append(lit.interval_year_to_month().years()));
+ RETURN_NOT_OK(builder.Append(lit.interval_year_to_month().months()));
+ type = interval_year();
+ } else {
+ RETURN_NOT_OK(builder.Append(lit.interval_day_to_second().days()));
+ RETURN_NOT_OK(builder.Append(lit.interval_day_to_second().seconds()));
+ type = interval_day();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto array, builder.Finish());
+ return Datum(
+ ExtensionScalar(FixedSizeListScalar(std::move(array)), std::move(type)));
+ }
+
+ case substrait::Expression::Literal::kUuid:
+ return Datum(ExtensionScalar(FixedSizeBinaryScalar(lit.uuid()), uuid()));
+
+ case substrait::Expression::Literal::kFixedChar:
+ return Datum(
+ ExtensionScalar(FixedSizeBinaryScalar(lit.fixed_char()),
+ fixed_char(static_cast<int32_t>(lit.fixed_char().size()))));
+
+ case substrait::Expression::Literal::kVarChar:
+ return Datum(
+ ExtensionScalar(StringScalar(lit.var_char().value()),
+ varchar(static_cast<int32_t>(lit.var_char().length()))));
+
+ case substrait::Expression::Literal::kFixedBinary:
+ return Datum(FixedSizeBinaryScalar(lit.fixed_binary()));
+
+ case substrait::Expression::Literal::kDecimal: {
+ if (lit.decimal().value().size() != sizeof(Decimal128)) {
+ return Status::Invalid("Decimal literal had ", lit.decimal().value().size(),
+ " bytes (expected ", sizeof(Decimal128), ")");
+ }
+
+ Decimal128 value;
+ std::memcpy(value.mutable_native_endian_bytes(), lit.decimal().value().data(),
+ sizeof(Decimal128));
+#if !ARROW_LITTLE_ENDIAN
+ std::reverse(value.mutable_native_endian_bytes(),
+ value.mutable_native_endian_bytes() + sizeof(Decimal128));
+#endif
+ auto type = decimal128(lit.decimal().precision(), lit.decimal().scale());
+ return Datum(Decimal128Scalar(value, std::move(type)));
+ }
+
+ case substrait::Expression::Literal::kStruct: {
+ const auto& struct_ = lit.struct_();
+
+ ScalarVector fields(struct_.fields_size());
+ for (int i = 0; i < struct_.fields_size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto field, FromProto(struct_.fields(i), ext_set));
+ DCHECK(field.is_scalar());
+ fields[i] = field.scalar();
+ }
+
+ // Note that Substrait struct types don't have field names, but Arrow does, so we
+ // just use empty strings for them.
+ std::vector<std::string> field_names(fields.size(), "");
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto scalar, StructScalar::Make(std::move(fields), std::move(field_names)));
+ return Datum(std::move(scalar));
+ }
+
+ case substrait::Expression::Literal::kList: {
+ const auto& list = lit.list();
+ if (list.values_size() == 0) {
+ return Status::Invalid(
+ "substrait::Expression::Literal::List had no values; should have been an "
+ "substrait::Expression::Literal::EmptyList");
+ }
+
+ std::shared_ptr<DataType> element_type;  // Taken from the first value; later values must match
+
+ ScalarVector values(list.values_size());
+ for (int i = 0; i < list.values_size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(auto value, FromProto(list.values(i), ext_set));
+ DCHECK(value.is_scalar());
+ values[i] = value.scalar();
+ if (element_type) {
+ if (!value.type()->Equals(*element_type)) {
+ return Status::Invalid(
+ list.DebugString(),
+ " has a value whose type doesn't match the other list values");
+ }
+ } else {
+ element_type = value.type();
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto builder, MakeBuilder(element_type));
+ RETURN_NOT_OK(builder->AppendScalars(values));
+ ARROW_ASSIGN_OR_RAISE(auto arr, builder->Finish());
+ return Datum(ListScalar(std::move(arr)));
+ }
+
+ case substrait::Expression::Literal::kMap: {
+ const auto& map = lit.map();
+ if (map.key_values_size() == 0) {
+ return Status::Invalid(
+ "substrait::Expression::Literal::Map had no values; should have been an "
+ "substrait::Expression::Literal::EmptyMap");
+ }
+
+ std::shared_ptr<DataType> key_type, value_type;  // Taken from the first entry; later entries must match
+ ScalarVector keys(map.key_values_size()), values(map.key_values_size());
+ for (int i = 0; i < map.key_values_size(); ++i) {
+ const auto& kv = map.key_values(i);
+
+ static const std::array<char const*, 4> kMissing = {"key and value", "value",
+ "key", nullptr};
+ if (auto missing = kMissing[kv.has_key() + kv.has_value() * 2]) {  // Index 3 (both present) maps to nullptr
+ return Status::Invalid("While converting to MapScalar encountered missing ",
+ missing, " in ", map.DebugString());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto key, FromProto(kv.key(), ext_set));
+ ARROW_ASSIGN_OR_RAISE(auto value, FromProto(kv.value(), ext_set));
+
+ DCHECK(key.is_scalar());
+ DCHECK(value.is_scalar());
+
+ keys[i] = key.scalar();
+ values[i] = value.scalar();
+
+ if (key_type) {
+ if (!key.type()->Equals(*key_type)) {
+ return Status::Invalid(map.DebugString(),
+ " has a key whose type doesn't match key_type");
+ }
+ } else {
+ key_type = key.type();  // The key type comes from the key, not from the value
+ }
+
+ if (value_type) {
+ if (!value.type()->Equals(*value_type)) {
+ return Status::Invalid(map.DebugString(),
+ " has a value whose type doesn't match value_type");
+ }
+ } else {
+ value_type = value.type();
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto key_builder, MakeBuilder(key_type));
+ ARROW_ASSIGN_OR_RAISE(auto value_builder, MakeBuilder(value_type));
+ RETURN_NOT_OK(key_builder->AppendScalars(keys));
+ RETURN_NOT_OK(value_builder->AppendScalars(values));
+ ARROW_ASSIGN_OR_RAISE(auto key_arr, key_builder->Finish());
+ ARROW_ASSIGN_OR_RAISE(auto value_arr, value_builder->Finish());
+ ARROW_ASSIGN_OR_RAISE(
+ auto kv_arr,
+ StructArray::Make(ArrayVector{std::move(key_arr), std::move(value_arr)},
+ std::vector<std::string>{"key", "value"}));
+ return Datum(std::make_shared<MapScalar>(std::move(kv_arr)));
+ }
+
+ case substrait::Expression::Literal::kEmptyList: {
+ ARROW_ASSIGN_OR_RAISE(auto type_nullable,
+ FromProto(lit.empty_list().type(), ext_set));
+ ARROW_ASSIGN_OR_RAISE(auto values, MakeEmptyArray(type_nullable.first));
+ return ListScalar{std::move(values)};
+ }
+
+ case substrait::Expression::Literal::kEmptyMap: {
+ ARROW_ASSIGN_OR_RAISE(auto key_type_nullable,
+ FromProto(lit.empty_map().key(), ext_set));
+ ARROW_ASSIGN_OR_RAISE(auto keys,
+ MakeEmptyArray(std::move(key_type_nullable.first)));
+
+ ARROW_ASSIGN_OR_RAISE(auto value_type_nullable,
+ FromProto(lit.empty_map().value(), ext_set));
+ ARROW_ASSIGN_OR_RAISE(auto values,
+ MakeEmptyArray(std::move(value_type_nullable.first)));
+
+ auto map_type = std::make_shared<MapType>(keys->type(), values->type());
+ ARROW_ASSIGN_OR_RAISE(
+ auto key_values,
+ StructArray::Make(
+ {std::move(keys), std::move(values)},
+ checked_cast<const ListType&>(*map_type).value_type()->fields()));
+
+ return MapScalar{std::move(key_values)};
+ }
+
+ case substrait::Expression::Literal::kNull: {
+ ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(lit.null(), ext_set));
+ if (!type_nullable.second) {
+ return Status::Invalid("Substrait null literal ", lit.DebugString(),
+ " is of non-nullable type");
+ }
+
+ return Datum(MakeNullScalar(std::move(type_nullable.first)));
+ }
+
+ default:
+ break;
+ }
+
+ return Status::NotImplemented("conversion to arrow::Datum from Substrait literal ",
+ lit.DebugString());
+}
+
+namespace {
+struct ScalarToProtoImpl {
+ Status Visit(const NullScalar& s) { return NotImplemented(s); }
+
+ using Lit = substrait::Expression::Literal;
+
+ template <typename Arg, typename PrimitiveScalar>
+ Status Primitive(void (substrait::Expression::Literal::*set)(Arg),
+ const PrimitiveScalar& primitive_scalar) {
+ (lit_->*set)(static_cast<Arg>(primitive_scalar.value));
+ return Status::OK();
+ }
+
+ template <typename ScalarWithBufferValue>
+ Status FromBuffer(void (substrait::Expression::Literal::*set)(std::string&&),
+ const ScalarWithBufferValue& scalar_with_buffer) {
+ (lit_->*set)(scalar_with_buffer.value->ToString());
+ return Status::OK();
+ }
+
+ Status Visit(const BooleanScalar& s) { return Primitive(&Lit::set_boolean, s); }
+
+ Status Visit(const Int8Scalar& s) { return Primitive(&Lit::set_i8, s); }
+ Status Visit(const Int16Scalar& s) { return Primitive(&Lit::set_i16, s); }
+ Status Visit(const Int32Scalar& s) { return Primitive(&Lit::set_i32, s); }
+ Status Visit(const Int64Scalar& s) { return Primitive(&Lit::set_i64, s); }
+
+ Status Visit(const UInt8Scalar& s) { return NotImplemented(s); }
+ Status Visit(const UInt16Scalar& s) { return NotImplemented(s); }
+ Status Visit(const UInt32Scalar& s) { return NotImplemented(s); }
+ Status Visit(const UInt64Scalar& s) { return NotImplemented(s); }
+
+ Status Visit(const HalfFloatScalar& s) { return NotImplemented(s); }
+ Status Visit(const FloatScalar& s) { return Primitive(&Lit::set_fp32, s); }
+ Status Visit(const DoubleScalar& s) { return Primitive(&Lit::set_fp64, s); }
+
+ Status Visit(const StringScalar& s) { return FromBuffer(&Lit::set_string, s); }
+ Status Visit(const BinaryScalar& s) { return FromBuffer(&Lit::set_binary, s); }
+
+ Status Visit(const FixedSizeBinaryScalar& s) {
+ return FromBuffer(&Lit::set_fixed_binary, s);
+ }
+
+ Status Visit(const Date32Scalar& s) { return Primitive(&Lit::set_date, s); }
+ Status Visit(const Date64Scalar& s) { return NotImplemented(s); }
+
+ Status Visit(const TimestampScalar& s) {
+ const auto& t = checked_cast<const TimestampType&>(*s.type);
+
+ if (t.unit() != TimeUnit::MICRO) return NotImplemented(s);
+
+ if (t.timezone() == "") return Primitive(&Lit::set_timestamp, s);
+
+ if (t.timezone() == TimestampTzTimezoneString()) {
+ return Primitive(&Lit::set_timestamp_tz, s);
+ }
+
+ return NotImplemented(s);
+ }
+
+ Status Visit(const Time32Scalar& s) { return NotImplemented(s); }
+ Status Visit(const Time64Scalar& s) {
+ if (checked_cast<const Time64Type&>(*s.type).unit() != TimeUnit::MICRO) {
+ return NotImplemented(s);
+ }
+ return Primitive(&Lit::set_time, s);
+ }
+
+ Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); }
+ Status Visit(const DayTimeIntervalScalar& s) { return NotImplemented(s); }
+
+ Status Visit(const Decimal128Scalar& s) {
+ auto decimal = internal::make_unique<Lit::Decimal>();
+
+ auto decimal_type = checked_cast<const Decimal128Type*>(s.type.get());
+ decimal->set_precision(decimal_type->precision());
+ decimal->set_scale(decimal_type->scale());
+
+ decimal->set_value(reinterpret_cast<const char*>(s.value.native_endian_bytes()),
+ sizeof(Decimal128));
+#if !ARROW_LITTLE_ENDIAN
+ std::reverse(decimal->mutable_value()->begin(), decimal->mutable_value()->end());
+#endif
+ lit_->set_allocated_decimal(decimal.release());
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Scalar& s) { return NotImplemented(s); }
+
+ Status Visit(const ListScalar& s) {
+ if (s.value->length() == 0) {
+ ARROW_ASSIGN_OR_RAISE(auto list_type,
+ ToProto(*s.type, /*nullable=*/true, ext_set_));
+ lit_->set_allocated_empty_list(list_type->release_list());
+ return Status::OK();
+ }
+
+ lit_->set_allocated_list(new Lit::List());
+
+ const auto& list_type = checked_cast<const ListType&>(*s.type);
+ ARROW_ASSIGN_OR_RAISE(
+ auto element_type,
+ ToProto(*list_type.value_type(), list_type.value_field()->nullable(), ext_set_));
+
+ auto values = lit_->mutable_list()->mutable_values();
+ values->Reserve(static_cast<int>(s.value->length()));
+
+ for (int64_t i = 0; i < s.value->length(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(Datum list_element, s.value->GetScalar(i));
+ ARROW_ASSIGN_OR_RAISE(auto lit, ToProto(list_element, ext_set_));
+ values->AddAllocated(lit.release());
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const StructScalar& s) {
+ lit_->set_allocated_struct_(new Lit::Struct());
+
+ auto fields = lit_->mutable_struct_()->mutable_fields();
+ fields->Reserve(static_cast<int>(s.value.size()));
+
+ for (Datum field : s.value) {
+ ARROW_ASSIGN_OR_RAISE(auto lit, ToProto(field, ext_set_));
+ fields->AddAllocated(lit.release());
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionScalar& s) { return NotImplemented(s); }
+ Status Visit(const DenseUnionScalar& s) { return NotImplemented(s); }
+ Status Visit(const DictionaryScalar& s) { return NotImplemented(s); }
+
+ Status Visit(const MapScalar& s) {
+ if (s.value->length() == 0) {
+ ARROW_ASSIGN_OR_RAISE(auto map_type, ToProto(*s.type, /*nullable=*/true, ext_set_));
+ lit_->set_allocated_empty_map(map_type->release_map());
+ return Status::OK();
+ }
+
+ lit_->set_allocated_map(new Lit::Map());
+
+ const auto& kv_arr = checked_cast<const StructArray&>(*s.value);
+
+ auto key_values = lit_->mutable_map()->mutable_key_values();
+ key_values->Reserve(static_cast<int>(kv_arr.length()));
+
+ for (int64_t i = 0; i < s.value->length(); ++i) {
+ auto kv = internal::make_unique<Lit::Map::KeyValue>();
+
+ ARROW_ASSIGN_OR_RAISE(Datum key_scalar, kv_arr.field(0)->GetScalar(i));
+ ARROW_ASSIGN_OR_RAISE(auto key, ToProto(key_scalar, ext_set_));
+ kv->set_allocated_key(key.release());
+
+ ARROW_ASSIGN_OR_RAISE(Datum value_scalar, kv_arr.field(1)->GetScalar(i));
+ ARROW_ASSIGN_OR_RAISE(auto value, ToProto(value_scalar, ext_set_));
+ kv->set_allocated_value(value.release());
+
+ key_values->AddAllocated(kv.release());
+ }
+ return Status::OK();
+ }
+
+ // Serializes scalars of the extension types recognized by the Unwrap* helpers (uuid,
+ // fixed_char, varchar, interval_year, interval_day). Any other extension type yields
+ // NotImplemented.
+ Status Visit(const ExtensionScalar& s) {
+   if (UnwrapUuid(*s.type)) {
+     const auto& storage = checked_cast<const FixedSizeBinaryScalar&>(*s.value);
+     return FromBuffer(&Lit::set_uuid, storage);
+   }
+
+   if (UnwrapFixedChar(*s.type)) {
+     const auto& storage = checked_cast<const FixedSizeBinaryScalar&>(*s.value);
+     return FromBuffer(&Lit::set_fixed_char, storage);
+   }
+
+   if (auto length = UnwrapVarChar(*s.type)) {
+     auto* var_char = lit_->mutable_var_char();
+     var_char->set_length(*length);
+     var_char->set_value(checked_cast<const StringScalar&>(*s.value).value->ToString());
+     return Status::OK();
+   }
+
+   // Both interval flavours read the first two int32 values of a fixed-size-list
+   // storage scalar.
+   auto read_int_pair = [&] {
+     const auto& storage = *checked_cast<const FixedSizeListScalar&>(*s.value).value;
+     auto raw = checked_cast<const Int32Array&>(storage).raw_values();
+     return std::make_pair(raw[0], raw[1]);
+   };
+
+   if (UnwrapIntervalYear(*s.type)) {
+     const auto years_months = read_int_pair();
+     auto* interval = lit_->mutable_interval_year_to_month();
+     interval->set_years(years_months.first);
+     interval->set_months(years_months.second);
+     return Status::OK();
+   }
+
+   if (UnwrapIntervalDay(*s.type)) {
+     const auto days_seconds = read_int_pair();
+     auto* interval = lit_->mutable_interval_day_to_second();
+     interval->set_days(days_seconds.first);
+     interval->set_seconds(days_seconds.second);
+     return Status::OK();
+   }
+
+   return NotImplemented(s);
+ }
+
+ // Scalar kinds with no conversion in this visitor; each overload returns the
+ // NotImplemented status built from the scalar.
+ Status Visit(const FixedSizeListScalar& s) { return NotImplemented(s); }
+ Status Visit(const DurationScalar& s) { return NotImplemented(s); }
+ Status Visit(const LargeStringScalar& s) { return NotImplemented(s); }
+ Status Visit(const LargeBinaryScalar& s) { return NotImplemented(s); }
+ Status Visit(const LargeListScalar& s) { return NotImplemented(s); }
+ Status Visit(const MonthDayNanoIntervalScalar& s) { return NotImplemented(s); }
+
+ // Builds the NotImplemented status returned by the unsupported Visit overloads; the
+ // message ends with the scalar's string representation.
+ Status NotImplemented(const Scalar& s) {
+ return Status::NotImplemented("conversion to substrait::Expression::Literal from ",
+ s.ToString());
+ }
+
+ // Entry point of the visitor: forwards the scalar and this visitor to
+ // VisitScalarInline (presumably selecting the Visit overload matching the scalar's
+ // concrete type -- confirm against VisitScalarInline).
+ Status operator()(const Scalar& scalar) { return VisitScalarInline(scalar, this); }
+
+ // Literal message that the Visit overloads populate (not owned by the visitor).
+ substrait::Expression::Literal* lit_;
+ // Extension set forwarded to the nested ToProto calls.
+ ExtensionSet* ext_set_;
+};
+} // namespace
+
+ // Converts a scalar Datum into a substrait::Expression::Literal.
+ //
+ // Non-scalar datums are rejected with NotImplemented. A null scalar is written as a
+ // `null` literal carrying only the (nullable) type; a valid scalar is handed to
+ // ScalarToProtoImpl.
+ Result<std::unique_ptr<substrait::Expression::Literal>> ToProto(const Datum& datum,
+                                                                 ExtensionSet* ext_set) {
+   if (!datum.is_scalar()) {
+     return Status::NotImplemented("representing ", datum.ToString(),
+                                   " as a substrait::Expression::Literal");
+   }
+
+   auto literal = internal::make_unique<substrait::Expression::Literal>();
+
+   if (!datum.scalar()->is_valid) {
+     ARROW_ASSIGN_OR_RAISE(auto null_type,
+                           ToProto(*datum.type(), /*nullable=*/true, ext_set));
+     literal->set_allocated_null(null_type.release());
+     return std::move(literal);
+   }
+
+   ScalarToProtoImpl visitor{literal.get(), ext_set};
+   RETURN_NOT_OK(visitor(*datum.scalar()));
+   return std::move(literal);
+ }
+
+ // Appends `child` at the end of the chain of reference segments rooted at `segment`.
+ //
+ // The chain is walked recursively through each segment's `child` field until a
+ // segment without a child is found; that segment takes ownership of `child`. If a
+ // segment whose reference type is not set is encountered, Invalid is returned and
+ // `child` is left untouched.
+ static Status AddChildToReferenceSegment(
+     substrait::Expression::ReferenceSegment& segment,
+     std::unique_ptr<substrait::Expression::ReferenceSegment>&& child) {
+   using Segment = substrait::Expression::ReferenceSegment;
+
+   switch (segment.reference_type_case()) {
+     case Segment::kMapKey: {
+       auto* map_key = segment.mutable_map_key();
+       if (!map_key->has_child()) {
+         map_key->set_allocated_child(child.release());
+         return Status::OK();
+       }
+       return AddChildToReferenceSegment(*map_key->mutable_child(), std::move(child));
+     }
+     case Segment::kStructField: {
+       auto* struct_field = segment.mutable_struct_field();
+       if (!struct_field->has_child()) {
+         struct_field->set_allocated_child(child.release());
+         return Status::OK();
+       }
+       return AddChildToReferenceSegment(*struct_field->mutable_child(),
+                                         std::move(child));
+     }
+     case Segment::kListElement: {
+       auto* list_element = segment.mutable_list_element();
+       if (!list_element->has_child()) {
+         list_element->set_allocated_child(child.release());
+         return Status::OK();
+       }
+       return AddChildToReferenceSegment(*list_element->mutable_child(),
+                                         std::move(child));
+     }
+     default:
+       break;
+   }
+   return Status::Invalid("Attempt to add child to incomplete reference segment");
+ }
+
+ // Applies `ref_segment` to `expr`, or to the root when `expr` is null or has no
+ // expression type set.
+ //
+ // When `expr` is already a direct-reference selection, the segment is appended to its
+ // existing segment chain and `expr` itself is returned. Otherwise (or if appending
+ // fails) a new selection expression is built around `expr` / the root.
+ static Result<std::unique_ptr<substrait::Expression>> MakeDirectReference(
+     std::unique_ptr<substrait::Expression>&& expr,
+     std::unique_ptr<substrait::Expression::ReferenceSegment>&& ref_segment) {
+   const bool extends_existing_reference =
+       expr && expr->has_selection() && expr->selection().has_direct_reference();
+
+   if (extends_existing_reference) {
+     auto* existing_root = expr->mutable_selection()->mutable_direct_reference();
+     // On failure ref_segment has not been released, so the fallback below still owns it.
+     if (AddChildToReferenceSegment(*existing_root, std::move(ref_segment)).ok()) {
+       return std::move(expr);
+     }
+   }
+
+   auto wrapped = internal::make_unique<substrait::Expression>();
+   auto* field_ref = wrapped->mutable_selection();
+   field_ref->set_allocated_direct_reference(ref_segment.release());
+
+   const bool has_input_expression =
+       expr && expr->rex_type_case() != substrait::Expression::REX_TYPE_NOT_SET;
+   if (has_input_expression) {
+     field_ref->set_allocated_expression(expr.release());
+   } else {
+     field_ref->mutable_root_reference();
+   }
+
+   return std::move(wrapped);
+ }
+
+ // Builds a reference to field number `field` of the struct-typed expression `expr`,
+ // or of the root when `expr` is empty.
+ static Result<std::unique_ptr<substrait::Expression>> MakeStructFieldReference(
+     std::unique_ptr<substrait::Expression>&& expr, int field) {
+   auto segment = internal::make_unique<substrait::Expression::ReferenceSegment>();
+   segment->mutable_struct_field()->set_field(field);
+
+   return MakeDirectReference(std::move(expr), std::move(segment));
+ }
+
+ // Builds a reference to the element at `offset` of the list-typed expression `expr`.
+ static Result<std::unique_ptr<substrait::Expression>> MakeListElementReference(
+     std::unique_ptr<substrait::Expression>&& expr, int offset) {
+   auto segment = internal::make_unique<substrait::Expression::ReferenceSegment>();
+   segment->mutable_list_element()->set_offset(offset);
+
+   return MakeDirectReference(std::move(expr), std::move(segment));
+ }
+
+ // Converts a bound compute::Expression into a substrait::Expression.
+ //
+ // Unbound expressions are rejected with Invalid. Literals and field references are
+ // converted directly. Calls to case_when (over make_struct), struct_field,
+ // list_element and if_else are pattern-matched onto the corresponding Substrait
+ // constructs; every other call becomes a ScalarFunction whose reference is the anchor
+ // returned by ext_set->EncodeFunction.
+ Result<std::unique_ptr<substrait::Expression>> ToProto(const compute::Expression& expr,
+                                                        ExtensionSet* ext_set) {
+   if (!expr.IsBound()) {
+     return Status::Invalid("ToProto requires a bound Expression");
+   }
+
+   auto out = internal::make_unique<substrait::Expression>();
+
+   if (auto literal_datum = expr.literal()) {
+     ARROW_ASSIGN_OR_RAISE(auto literal, ToProto(*literal_datum, ext_set));
+     out->set_allocated_literal(literal.release());
+     return std::move(out);
+   }
+
+   if (auto field_param = expr.parameter()) {
+     // A (possibly nested) field reference: apply one StructField segment per index,
+     // starting from the still-empty `out`, which MakeDirectReference treats as root.
+     DCHECK(!field_param->indices.empty());
+
+     for (int index : field_param->indices) {
+       ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index));
+     }
+
+     return std::move(out);
+   }
+
+   auto call = CallNotNull(expr);
+
+   // case_when(make_struct(cond...), value..., [else]) is convertible to IfThen.
+   auto conditions =
+       call->function_name == "case_when" ? call->arguments[0].call() : nullptr;
+   if (conditions && conditions->function_name == "make_struct") {
+     // Argument 0 of the case_when is not converted as a whole; the elements of
+     // make_struct are converted individually below.
+     std::vector<std::unique_ptr<substrait::Expression>> branch_values;
+     branch_values.reserve(call->arguments.size() - 1);
+     for (auto it = call->arguments.begin() + 1; it != call->arguments.end(); ++it) {
+       ARROW_ASSIGN_OR_RAISE(auto value, ToProto(*it, ext_set));
+       branch_values.push_back(std::move(value));
+     }
+
+     auto if_then = internal::make_unique<substrait::Expression::IfThen>();
+     for (size_t i = 0; i < conditions->arguments.size(); ++i) {
+       ARROW_ASSIGN_OR_RAISE(auto condition, ToProto(conditions->arguments[i], ext_set));
+       auto* clause = if_then->add_ifs();
+       clause->set_allocated_if_(condition.release());
+       clause->set_allocated_then(branch_values[i].release());
+     }
+
+     // The last value is the else branch; if it was already consumed as a `then`
+     // above, this passes a null pointer.
+     if_then->set_allocated_else_(branch_values.back().release());
+
+     out->set_allocated_if_then(if_then.release());
+     return std::move(out);
+   }
+
+   // The remaining pattern matchers only convert the function itself, so all of its
+   // arguments are converted up front.
+   std::vector<std::unique_ptr<substrait::Expression>> arguments;
+   arguments.reserve(call->arguments.size());
+   for (const auto& argument : call->arguments) {
+     ARROW_ASSIGN_OR_RAISE(auto converted, ToProto(argument, ext_set));
+     arguments.push_back(std::move(converted));
+   }
+
+   if (call->function_name == "struct_field") {
+     // Convertible to a chain of StructField segments applied to argument 0.
+     out = std::move(arguments[0]);
+     const auto& options =
+         checked_cast<const arrow::compute::StructFieldOptions&>(*call->options);
+     for (int index : options.indices) {
+       ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index));
+     }
+
+     return std::move(out);
+   }
+
+   if (call->function_name == "list_element") {
+     // Convertible to a ListElement segment only when the list is itself a direct
+     // reference and the offset is an i32 literal; otherwise fall through to the
+     // generic function encoding below.
+     if (arguments[0]->has_selection() &&
+         arguments[0]->selection().has_direct_reference() &&
+         arguments[1]->has_literal() && arguments[1]->literal().has_i32()) {
+       return MakeListElementReference(std::move(arguments[0]),
+                                       arguments[1]->literal().i32());
+     }
+   }
+
+   if (call->function_name == "if_else") {
+     // Convertible to an IfThen with a single clause.
+     auto if_then = internal::make_unique<substrait::Expression::IfThen>();
+     auto* clause = if_then->add_ifs();
+     clause->set_allocated_if_(arguments[0].release());
+     clause->set_allocated_then(arguments[1].release());
+     if_then->set_allocated_else_(arguments[2].release());
+
+     out->set_allocated_if_then(if_then.release());
+     return std::move(out);
+   }
+
+   // Other calls are encoded as extension functions.
+   ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set->EncodeFunction(call->function_name));
+
+   auto* scalar_fn = out->mutable_scalar_function();
+   scalar_fn->set_function_reference(anchor);
+   scalar_fn->mutable_args()->Reserve(static_cast<int>(arguments.size()));
+   for (auto& argument : arguments) {
+     scalar_fn->mutable_args()->AddAllocated(argument.release());
+   }
+
+   return std::move(out);
+ }
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h
new file mode 100644
index 0000000..e491fa6
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/expression_internal.h
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <utility>
+
+#include "arrow/compute/type_fwd.h"
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/visibility.h"
+#include "arrow/type_fwd.h"
+
+#include "substrait/expression.pb.h" // IWYU pragma: export
+
+ namespace arrow {
+ namespace engine {
+
+ /// Convert a substrait::Expression into a compute::Expression, using the given
+ /// ExtensionSet (presumably to resolve extension anchors -- the implementation is not
+ /// part of this header).
+ ARROW_ENGINE_EXPORT
+ Result<compute::Expression> FromProto(const substrait::Expression&, const ExtensionSet&);
+
+ /// Convert a compute::Expression into a substrait::Expression. The ExtensionSet is
+ /// passed by mutable pointer so that it can be updated during the conversion.
+ ARROW_ENGINE_EXPORT
+ Result<std::unique_ptr<substrait::Expression>> ToProto(const compute::Expression&,
+ ExtensionSet*);
+
+ /// Convert a substrait literal into a Datum, using the given ExtensionSet.
+ ARROW_ENGINE_EXPORT
+ Result<Datum> FromProto(const substrait::Expression::Literal&, const ExtensionSet&);
+
+ /// Convert a Datum into a substrait literal. The ExtensionSet is passed by mutable
+ /// pointer so that it can be updated during the conversion.
+ ARROW_ENGINE_EXPORT
+ Result<std::unique_ptr<substrait::Expression::Literal>> ToProto(const Datum&,
+ ExtensionSet*);
+
+ } // namespace engine
+ } // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
new file mode 100644
index 0000000..fe43ab2
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -0,0 +1,367 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/extension_set.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "arrow/util/hash_util.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace engine {
+namespace {
+
+ // Hash and equality functor for pointer-like handles to data types: both operations
+ // look through the pointer and act on the pointed-to object.
+ struct TypePtrHashEq {
+   // Hash of the pointed-to type.
+   template <typename TypePtr>
+   size_t operator()(const TypePtr& type_ptr) const {
+     return type_ptr->Hash();
+   }
+
+   // Equality of the pointed-to types (not of the pointers themselves).
+   template <typename TypePtr>
+   bool operator()(const TypePtr& lhs, const TypePtr& rhs) const {
+     return *lhs == *rhs;
+   }
+ };
+
+ // Hash and equality functor for (uri, name) extension ids.
+ struct IdHashEq {
+   using Id = ExtensionSet::Id;
+
+   // Combines the hash of the uri with the hash of the name.
+   size_t operator()(Id id) const {
+     constexpr ::arrow::internal::StringViewHash hasher = {};
+     size_t seed = static_cast<size_t>(hasher(id.uri));
+     ::arrow::internal::hash_combine(seed, hasher(id.name));
+     return seed;
+   }
+
+   // Two ids are equal when both their uri and their name match.
+   bool operator()(Id lhs, Id rhs) const {
+     if (lhs.uri != rhs.uri) return false;
+     return lhs.name == rhs.name;
+   }
+ };
+
+} // namespace
+
+ // A builder used when creating a Substrait plan from an Arrow execution plan. In
+ // that situation we do not have a set of anchor values already defined so we keep
+ // a map of what Ids we have seen.
+ //
+ // The lookup containers below only hold views; the viewed strings are expected to be
+ // owned by the registry of the ExtensionSet this Impl belongs to.
+ struct ExtensionSet::Impl {
+   // Records `uri` in both the lookup set and self->uris_, unless already recorded.
+   void AddUri(util::string_view uri, ExtensionSet* self) {
+     if (uris_.find(uri) != uris_.end()) return;
+
+     self->uris_.push_back(uri);
+     uris_.insert(self->uris_.back());  // lookup helper's keys should reference memory
+                                        // owned by this ExtensionSet
+   }
+
+   // Returns OK if `uri` was previously recorded with AddUri, Invalid otherwise.
+   Status CheckHasUri(util::string_view uri) {
+     if (uris_.find(uri) != uris_.end()) return Status::OK();
+
+     return Status::Invalid(
+         "Uri ", uri,
+         " was referenced by an extension but was not declared in the ExtensionSet.");
+   }
+
+   // Returns the anchor for `type_record`, appending it to self->types_ the first
+   // time its id is seen by this Impl.
+   uint32_t EncodeType(ExtensionIdRegistry::TypeRecord type_record, ExtensionSet* self) {
+     // note: at this point we're guaranteed to have an Id which points to memory owned by
+     // the set's registry.
+     AddUri(type_record.id.uri, self);
+
+     // An anchor is decoded as an index into self->types_, so a new anchor must be the
+     // index the record is about to occupy there. The size of the lookup map is not
+     // usable for this: self->types_ may already hold entries that were never inserted
+     // into the map (ExtensionSet::Make pre-sizes the vector).
+     auto it_success =
+         types_.emplace(type_record.id, static_cast<uint32_t>(self->types_.size()));
+
+     if (it_success.second) {
+       self->types_.push_back(
+           {type_record.id, type_record.type, type_record.is_variation});
+     }
+
+     return it_success.first->second;
+   }
+
+   // Returns the anchor for the function `id`, appending {id, function_name} to
+   // self->functions_ the first time the id is seen by this Impl.
+   uint32_t EncodeFunction(Id id, util::string_view function_name, ExtensionSet* self) {
+     // note: at this point we're guaranteed to have an Id which points to memory owned by
+     // the set's registry.
+     AddUri(id.uri, self);
+
+     // See EncodeType: the anchor must be the index taken in self->functions_.
+     auto it_success =
+         functions_.emplace(id, static_cast<uint32_t>(self->functions_.size()));
+
+     if (it_success.second) {
+       self->functions_.push_back({id, function_name});
+     }
+
+     return it_success.first->second;
+   }
+
+   // Uris recorded so far.
+   std::unordered_set<util::string_view, ::arrow::internal::StringViewHash> uris_;
+   // Id -> anchor for the types / functions encoded through this Impl.
+   std::unordered_map<Id, uint32_t, IdHashEq, IdHashEq> types_, functions_;
+ };
+
+ // Creates an empty set bound to `registry`. The Impl is held through a unique_ptr
+ // with an explicit deleter, as declared for impl_ in the header.
+ ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry)
+ : registry_(registry), impl_(new Impl(), [](Impl* impl) { delete impl; }) {}
+
+ // Builds an ExtensionSet from explicitly declared uris, type ids and function ids.
+ //
+ // The position of an entry in each input vector is its anchor; empty entries are
+ // placeholders for unused anchors and are skipped. Every non-empty uri must be known
+ // to `registry` (KeyError otherwise), every non-empty id must use a declared uri and
+ // must be found in `registry` (Invalid otherwise), and `type_ids` must be as long as
+ // `type_is_variation` (Invalid otherwise). Views stored in the result point into
+ // memory owned by the registry.
+ Result<ExtensionSet> ExtensionSet::Make(std::vector<util::string_view> uris,
+                                         std::vector<Id> type_ids,
+                                         std::vector<bool> type_is_variation,
+                                         std::vector<Id> function_ids,
+                                         ExtensionIdRegistry* registry) {
+   ExtensionSet set;
+   set.registry_ = registry;
+
+   // TODO(bkietz) move this into the registry as registry->OwnUris(&uris) or so
+   std::unordered_set<util::string_view, ::arrow::internal::StringViewHash>
+       uris_owned_by_registry;
+   for (util::string_view uri : registry->Uris()) {
+     uris_owned_by_registry.insert(uri);
+   }
+
+   for (auto& uri : uris) {
+     if (uri.empty()) continue;
+     auto it = uris_owned_by_registry.find(uri);
+     if (it == uris_owned_by_registry.end()) {
+       return Status::KeyError("Uri '", uri, "' not found in registry");
+     }
+     uri = *it;  // Ensure uris point into the registry's memory
+     set.impl_->AddUri(*it, &set);
+   }
+
+   if (type_ids.size() != type_is_variation.size()) {
+     return Status::Invalid("Received ", type_ids.size(), " type ids but a ",
+                            type_is_variation.size(), "-long is_variation vector");
+   }
+
+   set.types_.resize(type_ids.size());
+
+   for (size_t i = 0; i < type_ids.size(); ++i) {
+     if (type_ids[i].empty()) continue;
+     RETURN_NOT_OK(set.impl_->CheckHasUri(type_ids[i].uri));
+
+     if (auto rec = registry->GetType(type_ids[i], type_is_variation[i])) {
+       set.types_[i] = {rec->id, rec->type, rec->is_variation};
+       continue;
+     }
+     return Status::Invalid("Type", (type_is_variation[i] ? " variation" : ""), " ",
+                            type_ids[i].uri, "#", type_ids[i].name, " not found");
+   }
+
+   set.functions_.resize(function_ids.size());
+
+   for (size_t i = 0; i < function_ids.size(); ++i) {
+     if (function_ids[i].empty()) continue;
+     RETURN_NOT_OK(set.impl_->CheckHasUri(function_ids[i].uri));
+
+     if (auto rec = registry->GetFunction(function_ids[i])) {
+       set.functions_[i] = {rec->id, rec->function_name};
+       continue;
+     }
+     // Report the function's own name: indexing type_ids here would print an
+     // unrelated name and read out of bounds when there are more functions than types.
+     return Status::Invalid("Function ", function_ids[i].uri, "#", function_ids[i].name,
+                            " not found");
+   }
+
+   // Replace the uris accumulated by AddUri with the positional list, so that the
+   // index of a uri equals its anchor.
+   set.uris_ = std::move(uris);
+
+   return std::move(set);
+ }
+
+ // Looks up the type record stored at `anchor`. Anchors past the end of types_ and
+ // anchors holding an empty placeholder id are reported as Invalid.
+ Result<ExtensionSet::TypeRecord> ExtensionSet::DecodeType(uint32_t anchor) const {
+   const bool is_declared = anchor < types_.size() && !types_[anchor].id.empty();
+   if (is_declared) {
+     return types_[anchor];
+   }
+   return Status::Invalid("User defined type reference ", anchor,
+                          " did not have a corresponding anchor in the extension set");
+ }
+
+ // Returns the anchor for `type`, or KeyError if the registry has no record of it.
+ Result<uint32_t> ExtensionSet::EncodeType(const DataType& type) {
+   auto record = registry_->GetType(type);
+   if (!record) {
+     return Status::KeyError("type ", type.ToString(), " not found in the registry");
+   }
+   return impl_->EncodeType(*record, this);
+ }
+
+ // Looks up the function record stored at `anchor`. Anchors past the end of
+ // functions_ and anchors holding an empty placeholder id are reported as Invalid.
+ Result<ExtensionSet::FunctionRecord> ExtensionSet::DecodeFunction(uint32_t anchor) const {
+   const bool is_declared = anchor < functions_.size() && !functions_[anchor].id.empty();
+   if (is_declared) {
+     return functions_[anchor];
+   }
+   return Status::Invalid("User defined function reference ", anchor,
+                          " did not have a corresponding anchor in the extension set");
+ }
+
+ // Returns the anchor for the Arrow function `function_name`, or KeyError if the
+ // registry has no record of it.
+ Result<uint32_t> ExtensionSet::EncodeFunction(util::string_view function_name) {
+   auto record = registry_->GetFunction(function_name);
+   if (!record) {
+     return Status::KeyError("function ", function_name, " not found in the registry");
+   }
+   return impl_->EncodeFunction(record->id, record->function_name, this);
+ }
+
+ // Returns a pointer to the index mapped to `key`, or nullptr when `key` is absent.
+ // The pointer refers to the value stored inside `key_to_index`.
+ template <typename KeyToIndex, typename Key>
+ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) {
+   const auto found = key_to_index.find(key);
+   return found != key_to_index.end() ? &found->second : nullptr;
+ }
+
+ // Returns a pointer to the default registry: a function-local static instance whose
+ // constructor registers the Arrow extension types and the placeholder function list
+ // below. Every call returns the same instance.
+ ExtensionIdRegistry* default_extension_id_registry() {
+ static struct Impl : ExtensionIdRegistry {
+ Impl() {
+ struct TypeName {
+ std::shared_ptr<DataType> type;
+ util::string_view name;
+ };
+
+ // The type (variation) mappings listed below need to be kept in sync
+ // with the YAML at substrait/format/extension_types.yaml manually;
+ // see ARROW-15535.
+ // Registered as type variations:
+ for (TypeName e : {
+ TypeName{uint8(), "u8"},
+ TypeName{uint16(), "u16"},
+ TypeName{uint32(), "u32"},
+ TypeName{uint64(), "u64"},
+ TypeName{float16(), "fp16"},
+ }) {
+ DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type),
+ /*is_variation=*/true));
+ }
+
+ // Registered as types (not variations):
+ for (TypeName e : {
+ TypeName{null(), "null"},
+ TypeName{month_interval(), "interval_month"},
+ TypeName{day_time_interval(), "interval_day_milli"},
+ TypeName{month_day_nano_interval(), "interval_month_day_nano"},
+ }) {
+ DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type),
+ /*is_variation=*/false));
+ }
+
+ // TODO: this is just a placeholder right now. We'll need a YAML file for
+ // all functions (and prototypes) that Arrow provides that are relevant
+ // for Substrait, and include mappings for all of them here. See
+ // ARROW-15535.
+ for (util::string_view name : {
+ "add",
+ }) {
+ DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
+ }
+ }
+
+ // Views of every uri stored in uris_.
+ std::vector<util::string_view> Uris() const override {
+ return {uris_.begin(), uris_.end()};
+ }
+
+ // Looks a type up by value (type_to_index_ hashes and compares the pointed-to
+ // DataType, see TypePtrHashEq); returns an empty optional when not registered.
+ util::optional<TypeRecord> GetType(const DataType& type) const override {
+ if (auto index = GetIndex(type_to_index_, &type)) {
+ return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]};
+ }
+ return {};
+ }
+
+ // Looks a type up by id; variations and non-variations live in separate maps, so
+ // the same id may be registered once in each. Returns an empty optional when absent.
+ util::optional<TypeRecord> GetType(Id id, bool is_variation) const override {
+ if (auto index =
+ GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) {
+ return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]};
+ }
+ return {};
+ }
+
+ // Registers `type` under `id`. Returns Invalid if the id is already present in the
+ // map selected by `is_variation`, or if an equal type is already registered.
+ Status RegisterType(Id id, std::shared_ptr<DataType> type,
+ bool is_variation) override {
+ DCHECK_EQ(type_ids_.size(), types_.size());
+ DCHECK_EQ(type_ids_.size(), type_is_variation_.size());
+
+ // Copy the id's strings into owned storage and rebuild the id from views of the
+ // stored copies, so it does not depend on the caller's memory.
+ Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+ *names_.emplace(id.name.to_string()).first};
+
+ auto index = static_cast<int>(type_ids_.size());
+
+ auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_;
+ auto it_success = id_to_index->emplace(copied_id, index);
+
+ if (!it_success.second) {
+ return Status::Invalid("Type id was already registered");
+ }
+
+ if (!type_to_index_.emplace(type.get(), index).second) {
+ // Roll back the id insertion so both lookup maps stay consistent.
+ id_to_index->erase(it_success.first);
+ return Status::Invalid("Type was already registered");
+ }
+
+ type_ids_.push_back(copied_id);
+ types_.push_back(std::move(type));
+ type_is_variation_.push_back(is_variation);
+ return Status::OK();
+ }
+
+ // Looks a function up by its Arrow function name; empty optional when absent.
+ util::optional<FunctionRecord> GetFunction(
+ util::string_view arrow_function_name) const override {
+ if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
+ return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
+ }
+ return {};
+ }
+
+ // Looks a function up by its Substrait id; empty optional when absent.
+ util::optional<FunctionRecord> GetFunction(Id id) const override {
+ if (auto index = GetIndex(function_id_to_index_, id)) {
+ return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
+ }
+ return {};
+ }
+
+ // Registers the Arrow function `arrow_function_name` under `id`. Returns Invalid if
+ // either the id or the function name is already registered.
+ Status RegisterFunction(Id id, std::string arrow_function_name) override {
+ DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
+
+ // As in RegisterType: rebuild the id from views of strings owned by this registry.
+ Id copied_id{*uris_.emplace(id.uri.to_string()).first,
+ *names_.emplace(id.name.to_string()).first};
+
+ const std::string& copied_function_name{
+ *function_names_.emplace(std::move(arrow_function_name)).first};
+
+ auto index = static_cast<int>(function_ids_.size());
+
+ auto it_success = function_id_to_index_.emplace(copied_id, index);
+
+ if (!it_success.second) {
+ return Status::Invalid("Function id was already registered");
+ }
+
+ if (!function_name_to_index_.emplace(copied_function_name, index).second) {
+ // Roll back the id insertion so both lookup maps stay consistent.
+ function_id_to_index_.erase(it_success.first);
+ return Status::Invalid("Function name was already registered");
+ }
+
+ function_name_ptrs_.push_back(&copied_function_name);
+ function_ids_.push_back(copied_id);
+ return Status::OK();
+ }
+
+ // owning storage of uris, names, (arrow::)function_names, types
+ // note that storing strings like this is safe since references into an
+ // unordered_set are not invalidated on insertion
+ std::unordered_set<std::string> uris_, names_, function_names_;
+ DataTypeVector types_;
+ std::vector<bool> type_is_variation_;
+
+ // non-owning lookup helpers
+ // (type_ids_, types_ and type_is_variation_ are parallel vectors sharing one index)
+ std::vector<Id> type_ids_, function_ids_;
+ std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_, variation_id_to_index_;
+ std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
+
+ // (function_ids_ and function_name_ptrs_ are parallel vectors sharing one index)
+ std::vector<const std::string*> function_name_ptrs_;
+ std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
+ std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
+ function_name_to_index_;
+ } impl_;
+
+ return &impl_;
+ }
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h
new file mode 100644
index 0000000..2eb4482
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/extension_set.h
@@ -0,0 +1,240 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <vector>
+
+#include "arrow/engine/visibility.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace engine {
+
+ /// Substrait identifies functions and custom data types using a (uri, name) pair.
+ ///
+ /// This registry is a bidirectional mapping between Substrait IDs and their corresponding
+ /// Arrow counterparts (arrow::DataType and function names in a function registry)
+ ///
+ /// Substrait extension types and variations must be registered with their corresponding
+ /// arrow::DataType before they can be used!
+ ///
+ /// Conceptually this can be thought of as two pairs of `unordered_map`s. One pair to
+ /// go back and forth between Substrait ID and arrow::DataType and another pair to go
+ /// back and forth between Substrait ID and Arrow function names.
+ ///
+ /// Unlike an ExtensionSet this registry is not created automatically when consuming
+ /// Substrait plans and must be configured ahead of time (although there is a default
+ /// instance).
+ class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
+ public:
+ /// All uris registered in this ExtensionIdRegistry
+ virtual std::vector<util::string_view> Uris() const = 0;
+
+ /// \brief A Substrait (uri, name) identifier
+ ///
+ /// Both members are non-owning views.
+ struct Id {
+ util::string_view uri, name;
+
+ /// True only when both the uri and the name are empty
+ bool empty() const { return uri.empty() && name.empty(); }
+ };
+
+ /// \brief A mapping between a Substrait ID and an arrow::DataType
+ ///
+ /// Note: `type` is held by reference, not by value.
+ /// NOTE(review): a record presumably must not outlive the registry entry it was
+ /// obtained from -- confirm.
+ struct TypeRecord {
+ Id id;
+ const std::shared_ptr<DataType>& type;
+ bool is_variation;
+ };
+ /// Look up the record of a registered arrow::DataType
+ virtual util::optional<TypeRecord> GetType(const DataType&) const = 0;
+ /// Look up the record of a type (or, if is_variation, a type variation) by its ID
+ virtual util::optional<TypeRecord> GetType(Id, bool is_variation) const = 0;
+ /// Register an arrow::DataType under the given ID
+ virtual Status RegisterType(Id, std::shared_ptr<DataType>, bool is_variation) = 0;
+
+ /// \brief A mapping between a Substrait ID and an Arrow function
+ ///
+ /// Note: At the moment we identify functions solely by the name
+ /// of the function in the function registry.
+ ///
+ /// TODO(ARROW-15582) some functions will not be simple enough to convert without access
+ /// to their arguments/options. For example is_in embeds the set in options rather than
+ /// using an argument:
+ /// is_in(x, SetLookupOptions(set)) <-> (k...Uri, "is_in")(x, set)
+ ///
+ /// ... for another example, depending on the value of the first argument to
+ /// substrait::add it either corresponds to arrow::add or arrow::add_checked
+ struct FunctionRecord {
+ Id id;
+ const std::string& function_name;
+ };
+ /// Look up the record of a function by its Substrait ID
+ virtual util::optional<FunctionRecord> GetFunction(Id) const = 0;
+ /// Look up the record of a function by its Arrow function name
+ virtual util::optional<FunctionRecord> GetFunction(
+ util::string_view arrow_function_name) const = 0;
+ /// Register an Arrow function name under the given ID
+ virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0;
+ };
+
+ /// The uri under which the default registry registers its types and functions
+ constexpr util::string_view kArrowExtTypesUri =
+ "https://github.com/apache/arrow/blob/master/format/substrait/"
+ "extension_types.yaml";
+
+ /// A default registry with all supported functions and data types registered
+ ///
+ /// Note: Function support is currently very minimal, see ARROW-15538
+ ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry();
+
+ /// \brief A set of extensions used within a plan
+ ///
+ /// Each time an extension is used within a Substrait plan the extension
+ /// must be included in an extension set that is defined at the root of the
+ /// plan.
+ ///
+ /// The plan refers to a specific extension using an "anchor" which is an
+ /// arbitrary integer invented by the producer that has no meaning beyond a
+ /// plan but which should be consistent within a plan.
+ ///
+ /// To support serialization and deserialization this type serves as a
+ /// bidirectional map between Substrait ID and "anchor"s.
+ ///
+ /// When deserializing a Substrait plan the extension set should be extracted
+ /// after the plan has been converted from Protobuf and before the plan
+ /// is converted to an execution plan.
+ ///
+ /// The extension set can be kept and reused during serialization if a perfect
+ /// round trip is required. If serialization is not needed or round tripping
+ /// is not required then the extension set can be safely discarded after the
+ /// plan has been converted into an execution plan.
+ ///
+ /// When converting an execution plan into a Substrait plan an extension set
+ /// can be automatically generated or a previously generated extension set can
+ /// be used.
+ ///
+ /// ExtensionSet does not own strings; it only refers to strings in an
+ /// ExtensionIdRegistry.
+ class ARROW_ENGINE_EXPORT ExtensionSet {
+ public:
+ using Id = ExtensionIdRegistry::Id;
+
+ /// \brief A function stored in this set: its Substrait ID and a view of a name
+ struct FunctionRecord {
+ Id id;
+ util::string_view name;
+ };
+
+ /// \brief A type stored in this set: its Substrait ID, the arrow::DataType, and
+ /// whether it is a type variation
+ struct TypeRecord {
+ Id id;
+ std::shared_ptr<DataType> type;
+ bool is_variation;
+ };
+
+ /// Construct an empty ExtensionSet to be populated during serialization.
+ explicit ExtensionSet(ExtensionIdRegistry* = default_extension_id_registry());
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionSet);
+
+ /// Construct an ExtensionSet with explicit extension ids for efficient referencing
+ /// during deserialization. Note that input vectors need not be densely packed; an empty
+ /// (default constructed) Id may be used as a placeholder to indicate an unused
+ /// _anchor/_reference. This factory will be used to wrap the extensions declared in a
+ /// substrait::Plan before deserializing the plan's relations.
+ ///
+ /// Views will be replaced with equivalent views pointing to memory owned by the
+ /// registry.
+ ///
+ /// Note: This is an advanced operation. The order of the ids, types, and functions
+ /// must match the anchor numbers chosen for a plan.
+ ///
+ /// An extension set should instead be created using
+ /// arrow::engine::GetExtensionSetFromPlan
+ static Result<ExtensionSet> Make(
+ std::vector<util::string_view> uris, std::vector<Id> type_ids,
+ std::vector<bool> type_is_variation, std::vector<Id> function_ids,
+ ExtensionIdRegistry* = default_extension_id_registry());
+
+ // index in these vectors == value of _anchor/_reference fields
+ /// TODO(ARROW-15583) this assumes that _anchor/_references won't be huge, which is not
+ /// guaranteed. Could it be?
+ const std::vector<util::string_view>& uris() const { return uris_; }
+
+ /// \brief Returns a data type given an anchor
+ ///
+ /// This is used when converting a Substrait plan to an Arrow execution plan.
+ ///
+ /// If the anchor does not exist in this extension set an error will be returned.
+ Result<TypeRecord> DecodeType(uint32_t anchor) const;
+
+ /// \brief Returns the number of custom type records in this extension set
+ ///
+ /// Note: the types are currently stored as a sparse vector, so this may return a value
+ /// larger than the actual number of types. This behavior may change in the future; see
+ /// ARROW-15583.
+ std::size_t num_types() const { return types_.size(); }
+
+ /// \brief Lookup the anchor for a given type
+ ///
+ /// This operation is used when converting an Arrow execution plan to a Substrait plan.
+ /// If the type has been previously encoded then the same anchor value will be returned.
+ ///
+ /// If the type has not been previously encoded then a new anchor value will be created.
+ ///
+ /// If the type does not exist in the extension id registry then an error will be
+ /// returned.
+ ///
+ /// \return An anchor that can be used to refer to the type within a plan
+ Result<uint32_t> EncodeType(const DataType& type);
+
+ /// \brief Returns a function given an anchor
+ ///
+ /// This is used when converting a Substrait plan to an Arrow execution plan.
+ ///
+ /// If the anchor does not exist in this extension set an error will be returned.
+ Result<FunctionRecord> DecodeFunction(uint32_t anchor) const;
+
+ /// \brief Lookup the anchor for a given function
+ ///
+ /// This operation is used when converting an Arrow execution plan to a Substrait plan.
+ /// If the function has been previously encoded then the same anchor value will be
+ /// returned.
+ ///
+ /// If the function has not been previously encoded then a new anchor value will be
+ /// created.
+ ///
+ /// If the function name is not in the extension id registry then an error will be
+ /// returned.
+ ///
+ /// \return An anchor that can be used to refer to the function within a plan
+ Result<uint32_t> EncodeFunction(util::string_view function_name);
+
+ /// \brief Returns the number of custom functions in this extension set
+ ///
+ /// Note: the functions are currently stored as a sparse vector, so this may return a
+ /// value larger than the actual number of functions. This behavior may change in the
+ /// future; see ARROW-15583.
+ std::size_t num_functions() const { return functions_.size(); }
+
+ private:
+ /// The registry used to resolve ids, types and function names (not owned)
+ ExtensionIdRegistry* registry_;
+ /// The subset of extension registry URIs referenced by this extension set
+ std::vector<util::string_view> uris_;
+ /// Type records, indexed by anchor; unused anchors hold a record with an empty id
+ std::vector<TypeRecord> types_;
+
+ /// Function records, indexed by anchor; unused anchors hold a record with an empty id
+ std::vector<FunctionRecord> functions_;
+
+ // pimpl pattern to hide lookup details
+ struct Impl;
+ std::unique_ptr<Impl, void (*)(Impl*)> impl_;
+ };
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/extension_types.cc b/cpp/src/arrow/engine/substrait/extension_types.cc
new file mode 100644
index 0000000..b8fd191
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/extension_types.cc
@@ -0,0 +1,147 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/extension_types.h"
+
+#include "arrow/engine/simple_extension_type_internal.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+
+using internal::DataMember;
+using internal::MakeProperties;
+
+namespace engine {
+namespace {
+
+// Extension type wrapping Substrait's `uuid`.
+//
+// The type carries no parameters (UuidExtensionParams is empty, so the property list
+// is empty too) and its storage is always fixed_size_binary(16).
+constexpr util::string_view kUuidExtensionName = "uuid";
+struct UuidExtensionParams {};
+// Storage type for uuid; the (empty) parameters are ignored.
+std::shared_ptr<DataType> UuidGetStorage(const UuidExtensionParams&) {
+ return fixed_size_binary(16);
+}
+static auto kUuidExtensionParamsProperties = MakeProperties();
+
+// The name, parameter struct, property list and storage function are all bound as
+// template arguments of SimpleExtensionType.
+using UuidType = SimpleExtensionType<kUuidExtensionName, UuidExtensionParams,
+ decltype(kUuidExtensionParamsProperties),
+ &kUuidExtensionParamsProperties, UuidGetStorage>;
+
+// Extension type wrapping Substrait's `fixed_char`.
+//
+// It has a single parameter, `length`, which is exposed through the property list and
+// also selects the storage type: fixed_size_binary(length).
+constexpr util::string_view kFixedCharExtensionName = "fixed_char";
+struct FixedCharExtensionParams {
+ // Passed unchanged to fixed_size_binary() as the storage width.
+ int32_t length;
+};
+// Storage type for fixed_char: a fixed-size binary whose width is params.length.
+std::shared_ptr<DataType> FixedCharGetStorage(const FixedCharExtensionParams& params) {
+ return fixed_size_binary(params.length);
+}
+static auto kFixedCharExtensionParamsProperties =
+ MakeProperties(DataMember("length", &FixedCharExtensionParams::length));
+
+using FixedCharType =
+ SimpleExtensionType<kFixedCharExtensionName, FixedCharExtensionParams,
+ decltype(kFixedCharExtensionParamsProperties),
+ &kFixedCharExtensionParamsProperties, FixedCharGetStorage>;
+
+// Extension type wrapping Substrait's `varchar`.
+//
+// It has a single parameter, `length`, which is recorded in the property list. Unlike
+// fixed_char, the parameter does not influence the storage type, which is always
+// utf8().
+constexpr util::string_view kVarCharExtensionName = "varchar";
+struct VarCharExtensionParams {
+ // Recorded as a type parameter only; not used when choosing the storage type.
+ int32_t length;
+};
+// Storage type for varchar; the parameters are ignored.
+std::shared_ptr<DataType> VarCharGetStorage(const VarCharExtensionParams&) {
+ return utf8();
+}
+static auto kVarCharExtensionParamsProperties =
+ MakeProperties(DataMember("length", &VarCharExtensionParams::length));
+
+using VarCharType =
+ SimpleExtensionType<kVarCharExtensionName, VarCharExtensionParams,
+ decltype(kVarCharExtensionParamsProperties),
+ &kVarCharExtensionParamsProperties, VarCharGetStorage>;
+
+// Extension type wrapping Substrait's `interval_year`.
+//
+// The type carries no parameters and its storage is always
+// fixed_size_list(int32(), 2).
+constexpr util::string_view kIntervalYearExtensionName = "interval_year";
+struct IntervalYearExtensionParams {};
+// Storage type for interval_year; the (empty) parameters are ignored.
+std::shared_ptr<DataType> IntervalYearGetStorage(const IntervalYearExtensionParams&) {
+ return fixed_size_list(int32(), 2);
+}
+static auto kIntervalYearExtensionParamsProperties = MakeProperties();
+
+using IntervalYearType =
+ SimpleExtensionType<kIntervalYearExtensionName, IntervalYearExtensionParams,
+ decltype(kIntervalYearExtensionParamsProperties),
+ &kIntervalYearExtensionParamsProperties, IntervalYearGetStorage>;
+
+// Extension type wrapping Substrait's `interval_day`.
+//
+// The type carries no parameters and its storage is always
+// fixed_size_list(int32(), 2), i.e. the same storage as interval_year; the two are
+// told apart by their extension names.
+constexpr util::string_view kIntervalDayExtensionName = "interval_day";
+struct IntervalDayExtensionParams {};
+// Storage type for interval_day; the (empty) parameters are ignored.
+std::shared_ptr<DataType> IntervalDayGetStorage(const IntervalDayExtensionParams&) {
+ return fixed_size_list(int32(), 2);
+}
+static auto kIntervalDayExtensionParamsProperties = MakeProperties();
+
+using IntervalDayType =
+ SimpleExtensionType<kIntervalDayExtensionName, IntervalDayExtensionParams,
+ decltype(kIntervalDayExtensionParamsProperties),
+ &kIntervalDayExtensionParamsProperties, IntervalDayGetStorage>;
+
+} // namespace
+
+// Builds the parameterless uuid extension type.
+std::shared_ptr<DataType> uuid() {
+  UuidExtensionParams no_params{};
+  return UuidType::Make(no_params);
+}
+
+// Builds the fixed_char extension type whose single parameter is `length`.
+std::shared_ptr<DataType> fixed_char(int32_t length) {
+  FixedCharExtensionParams fixed_char_params{length};
+  return FixedCharType::Make(fixed_char_params);
+}
+
+// Builds the varchar extension type whose single parameter is `length`.
+std::shared_ptr<DataType> varchar(int32_t length) {
+  VarCharExtensionParams varchar_params{length};
+  return VarCharType::Make(varchar_params);
+}
+
+// Builds the parameterless interval_year extension type.
+std::shared_ptr<DataType> interval_year() {
+  IntervalYearExtensionParams no_params{};
+  return IntervalYearType::Make(no_params);
+}
+
+// Builds the parameterless interval_day extension type.
+std::shared_ptr<DataType> interval_day() {
+  IntervalDayExtensionParams no_params{};
+  return IntervalDayType::Make(no_params);
+}
+
+// Reports whether `t` is the uuid extension type, i.e. whether UuidType::GetIf
+// yields a truthy result for it.
+bool UnwrapUuid(const DataType& t) { return static_cast<bool>(UuidType::GetIf(t)); }
+
+// Extracts the `length` parameter when `t` is the fixed_char extension type;
+// yields nullopt for every other type.
+util::optional<int32_t> UnwrapFixedChar(const DataType& t) {
+  auto fixed_char_params = FixedCharType::GetIf(t);
+  if (!fixed_char_params) {
+    return util::nullopt;
+  }
+  return fixed_char_params->length;
+}
+
+// Extracts the `length` parameter when `t` is the varchar extension type;
+// yields nullopt for every other type.
+util::optional<int32_t> UnwrapVarChar(const DataType& t) {
+  auto varchar_params = VarCharType::GetIf(t);
+  if (!varchar_params) {
+    return util::nullopt;
+  }
+  return varchar_params->length;
+}
+
+// Reports whether `t` is the interval_year extension type.
+bool UnwrapIntervalYear(const DataType& t) {
+  return static_cast<bool>(IntervalYearType::GetIf(t));
+}
+
+// Reports whether `t` is the interval_day extension type.
+bool UnwrapIntervalDay(const DataType& t) {
+  return static_cast<bool>(IntervalDayType::GetIf(t));
+}
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/extension_types.h b/cpp/src/arrow/engine/substrait/extension_types.h
new file mode 100644
index 0000000..e689e94
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/extension_types.h
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/compute/function.h"
+#include "arrow/engine/visibility.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace engine {
+
+// arrow::ExtensionTypes are provided to wrap uuid, fixed_char, varchar, interval_year,
+// and interval_day which are first-class types in substrait but do not appear in
+// the arrow type system.
+//
+// Note that these are not automatically registered with arrow::RegisterExtensionType(),
+// which means among other things that serialization of these types to IPC would fail.
+
+/// fixed_size_binary(16) for storing Universally Unique IDentifiers
+///
+/// The returned type is an extension type named "uuid" without parameters.
+ARROW_ENGINE_EXPORT
+std::shared_ptr<DataType> uuid();
+
+/// fixed_size_binary(length) constrained to contain only valid UTF-8
+///
+/// The returned type is an extension type named "fixed_char"; `length` is kept as a
+/// type parameter and is also used as the width of the storage type.
+ARROW_ENGINE_EXPORT
+std::shared_ptr<DataType> fixed_char(int32_t length);
+
+/// utf8() constrained to be shorter than `length`
+///
+/// The returned type is an extension type named "varchar"; `length` is kept as a
+/// type parameter only, the storage type is utf8() regardless of its value.
+ARROW_ENGINE_EXPORT
+std::shared_ptr<DataType> varchar(int32_t length);
+
+/// fixed_size_list(int32(), 2) storing a number of [years, months]
+///
+/// The returned type is an extension type named "interval_year" without parameters.
+ARROW_ENGINE_EXPORT
+std::shared_ptr<DataType> interval_year();
+
+/// fixed_size_list(int32(), 2) storing a number of [days, seconds]
+///
+/// The returned type is an extension type named "interval_day" without parameters.
+ARROW_ENGINE_EXPORT
+std::shared_ptr<DataType> interval_day();
+
+/// Return true if t is Uuid, otherwise false
+///
+/// \param[in] t the type to inspect
+ARROW_ENGINE_EXPORT
+bool UnwrapUuid(const DataType& t);
+
+/// Return FixedChar length if t is FixedChar, otherwise nullopt
+///
+/// \param[in] t the type to inspect
+ARROW_ENGINE_EXPORT
+util::optional<int32_t> UnwrapFixedChar(const DataType& t);
+
+/// Return Varchar (max) length if t is VarChar, otherwise nullopt
+///
+/// \param[in] t the type to inspect
+ARROW_ENGINE_EXPORT
+util::optional<int32_t> UnwrapVarChar(const DataType& t);
+
+/// Return true if t is IntervalYear, otherwise false
+///
+/// \param[in] t the type to inspect
+ARROW_ENGINE_EXPORT
+bool UnwrapIntervalYear(const DataType& t);
+
+/// Return true if t is IntervalDay, otherwise false
+///
+/// \param[in] t the type to inspect
+ARROW_ENGINE_EXPORT
+bool UnwrapIntervalDay(const DataType& t);
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc
new file mode 100644
index 0000000..8ffbcc0
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/plan_internal.cc
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/plan_internal.h"
+
+#include "arrow/result.h"
+#include "arrow/util/hashing.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/unreachable.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace engine {
+
+namespace internal {
+using ::arrow::internal::make_unique;
+} // namespace internal
+
+// Replaces the extension information of `plan` with the contents of `ext_set`.
+//
+// Three kinds of entries are written:
+//   - one SimpleExtensionURI per non-empty URI of the extension set, anchored at the
+//     URI's index in ext_set.uris()
+//   - one ExtensionType or ExtensionTypeVariation declaration per non-empty type
+//     record, anchored at the record's index
+//   - one ExtensionFunction declaration per non-empty function record, anchored at
+//     the record's index
+//
+// Empty slots are skipped because the extension set stores its records sparsely.
+//
+// Returns an error only if decoding a type or function record from `ext_set` fails.
+Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) {
+  // Drop the URI table *and* the declarations which refer to it. Keeping stale
+  // declarations around would leave them pointing at URI anchors that are about to
+  // be reassigned, and would duplicate declarations on repeated calls.
+  plan->clear_extension_uris();
+  plan->clear_extensions();
+
+  // Maps each emitted URI to the anchor it was emitted under.
+  std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash> map;
+
+  auto uris = plan->mutable_extension_uris();
+  uris->Reserve(static_cast<int>(ext_set.uris().size()));
+  for (uint32_t anchor = 0; anchor < ext_set.uris().size(); ++anchor) {
+    auto uri = ext_set.uris()[anchor];
+    if (uri.empty()) continue;
+
+    auto ext_uri = internal::make_unique<substrait::extensions::SimpleExtensionURI>();
+    ext_uri->set_uri(uri.to_string());
+    ext_uri->set_extension_uri_anchor(anchor);
+    // The repeated field takes ownership of the message
+    uris->AddAllocated(ext_uri.release());
+
+    map[uri] = anchor;
+  }
+
+  auto extensions = plan->mutable_extensions();
+  extensions->Reserve(static_cast<int>(ext_set.num_types() + ext_set.num_functions()));
+
+  using ExtDecl = substrait::extensions::SimpleExtensionDeclaration;
+
+  for (uint32_t anchor = 0; anchor < ext_set.num_types(); ++anchor) {
+    ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor));
+    if (type_record.id.empty()) continue;
+
+    auto ext_decl = internal::make_unique<ExtDecl>();
+
+    // Type variations and types share one anchor space in the extension set but are
+    // declared through different messages.
+    if (type_record.is_variation) {
+      auto type_var = internal::make_unique<ExtDecl::ExtensionTypeVariation>();
+      type_var->set_extension_uri_reference(map[type_record.id.uri]);
+      type_var->set_type_variation_anchor(anchor);
+      type_var->set_name(type_record.id.name.to_string());
+      ext_decl->set_allocated_extension_type_variation(type_var.release());
+    } else {
+      auto type = internal::make_unique<ExtDecl::ExtensionType>();
+      type->set_extension_uri_reference(map[type_record.id.uri]);
+      type->set_type_anchor(anchor);
+      type->set_name(type_record.id.name.to_string());
+      ext_decl->set_allocated_extension_type(type.release());
+    }
+
+    extensions->AddAllocated(ext_decl.release());
+  }
+
+  for (uint32_t anchor = 0; anchor < ext_set.num_functions(); ++anchor) {
+    ARROW_ASSIGN_OR_RAISE(auto function_record, ext_set.DecodeFunction(anchor));
+    if (function_record.id.empty()) continue;
+
+    auto fn = internal::make_unique<ExtDecl::ExtensionFunction>();
+    fn->set_extension_uri_reference(map[function_record.id.uri]);
+    fn->set_function_anchor(anchor);
+    fn->set_name(function_record.id.name.to_string());
+
+    auto ext_decl = internal::make_unique<ExtDecl>();
+    ext_decl->set_allocated_extension_function(fn.release());
+    extensions->AddAllocated(ext_decl.release());
+  }
+
+  return Status::OK();
+}
+
+namespace {
+
+// Largest anchor accepted from a plan. Anchors index sparse vectors, so an unchecked
+// anchor taken from a serialized (and possibly untrusted) plan could otherwise force
+// an allocation of up to 2^32 elements.
+constexpr size_t kMaxAnchor = 1 << 20;
+
+// Stores `element` (converted to T) at index `i` of `vector`, growing the vector with
+// default-constructed elements as needed.
+//
+// Returns Invalid if `i` exceeds kMaxAnchor.
+template <typename Element, typename T>
+Status SetElement(size_t i, const Element& element, std::vector<T>* vector) {
+  if (i > kMaxAnchor) {
+    return Status::Invalid("Substrait extension anchor ", i,
+                           " is larger than the maximum supported anchor ", kMaxAnchor);
+  }
+  if (i >= vector->size()) {
+    vector->resize(i + 1);
+  }
+  (*vector)[i] = static_cast<T>(element);
+  return Status::OK();
+}
+
+// Looks up the URI which a declaration refers to.
+//
+// Returns Invalid if `reference` lies outside of the URIs declared by the plan.
+Result<util::string_view> GetExtensionUri(const std::vector<util::string_view>& uris,
+                                          uint32_t reference) {
+  if (reference >= uris.size()) {
+    return Status::Invalid("Substrait extension URI reference ", reference,
+                           " does not refer to an extension URI declared by the plan");
+  }
+  return uris[reference];
+}
+
+}  // namespace
+
+// Collects the extension URIs, types, type variations and functions declared by
+// `plan` into sparse vectors indexed by anchor and builds an ExtensionSet from them.
+//
+// Returns Invalid if an anchor is too large, if a declaration references a URI which
+// the plan does not declare, or if a declaration has no mapping type set. Any error
+// raised by ExtensionSet::Make is forwarded.
+Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
+                                             ExtensionIdRegistry* registry) {
+  std::vector<util::string_view> uris;
+  for (const auto& uri : plan.extension_uris()) {
+    ARROW_RETURN_NOT_OK(SetElement(uri.extension_uri_anchor(), uri.uri(), &uris));
+  }
+
+  // NOTE: it's acceptable to use views to memory owned by plan; ExtensionSet::Make
+  // will only store views to memory owned by registry.
+
+  using Id = ExtensionSet::Id;
+
+  std::vector<Id> type_ids, function_ids;
+  std::vector<bool> type_is_variation;
+  for (const auto& ext : plan.extensions()) {
+    switch (ext.mapping_type_case()) {
+      case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: {
+        const auto& type_var = ext.extension_type_variation();
+        ARROW_ASSIGN_OR_RAISE(util::string_view uri,
+                              GetExtensionUri(uris, type_var.extension_uri_reference()));
+        ARROW_RETURN_NOT_OK(SetElement(type_var.type_variation_anchor(),
+                                       Id{uri, type_var.name()}, &type_ids));
+        ARROW_RETURN_NOT_OK(
+            SetElement(type_var.type_variation_anchor(), true, &type_is_variation));
+        break;
+      }
+
+      case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: {
+        const auto& type = ext.extension_type();
+        ARROW_ASSIGN_OR_RAISE(util::string_view uri,
+                              GetExtensionUri(uris, type.extension_uri_reference()));
+        ARROW_RETURN_NOT_OK(
+            SetElement(type.type_anchor(), Id{uri, type.name()}, &type_ids));
+        ARROW_RETURN_NOT_OK(SetElement(type.type_anchor(), false, &type_is_variation));
+        break;
+      }
+
+      case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: {
+        const auto& fn = ext.extension_function();
+        ARROW_ASSIGN_OR_RAISE(util::string_view uri,
+                              GetExtensionUri(uris, fn.extension_uri_reference()));
+        ARROW_RETURN_NOT_OK(
+            SetElement(fn.function_anchor(), Id{uri, fn.name()}, &function_ids));
+        break;
+      }
+
+      default:
+        // Reachable for a declaration whose oneof was left unset, which a serialized
+        // plan is free to contain; report it instead of aborting the process.
+        return Status::Invalid(
+            "substrait::extensions::SimpleExtensionDeclaration without a supported "
+            "mapping_type");
+    }
+  }
+
+  return ExtensionSet::Make(std::move(uris), std::move(type_ids),
+                            std::move(type_is_variation), std::move(function_ids),
+                            registry);
+}
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h
new file mode 100644
index 0000000..0ab06ec
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/plan_internal.h
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/visibility.h"
+#include "arrow/type_fwd.h"
+
+#include "substrait/plan.pb.h" // IWYU pragma: export
+
+namespace arrow {
+namespace engine {
+
+/// \brief Replaces the extension information of a Substrait Plan message with the given
+/// extension set, such that the anchors defined therein can be used in the rest of the
+/// plan.
+///
+/// Anchors written to the plan are the indices of the corresponding entries within the
+/// extension set; empty (unused) entries are not written.
+///
+/// \param[in] ext_set the extension set to copy the extension information from
+/// \param[in,out] plan the Substrait plan message that is to be updated
+/// \return success or failure
+ARROW_ENGINE_EXPORT
+Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan);
+
+/// \brief Interprets the extension information of a Substrait Plan message into an
+/// ExtensionSet.
+///
+/// Note that the extension registry is not currently mutated, but may be in the future.
+///
+/// \param[in] plan the plan message to take the information from
+/// \param[in,out] registry registry defining which Arrow types and compute functions
+/// correspond to Substrait's URI/name pairs
+/// \return the extension set described by the plan, or the error produced while
+/// constructing it
+ARROW_ENGINE_EXPORT
+Result<ExtensionSet> GetExtensionSetFromPlan(
+ const substrait::Plan& plan,
+ ExtensionIdRegistry* registry = default_extension_id_registry());
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
new file mode 100644
index 0000000..ae2244c
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -0,0 +1,193 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/relation_internal.h"
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/dataset/file_parquet.h"
+#include "arrow/dataset/plan.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/engine/substrait/expression_internal.h"
+#include "arrow/engine/substrait/type_internal.h"
+#include "arrow/filesystem/localfs.h"
+
+namespace arrow {
+namespace engine {
+
+// Rejects the parts of a relation message that the consumer cannot honor yet.
+//
+// NotImplemented is returned if the relation's RelCommon carries an emit, a hint or
+// an advanced extension, or if the relation itself carries an advanced extension.
+// Works for every relation message exposing has_common()/common() and
+// has_advanced_extension().
+template <typename RelMessage>
+Status CheckRelCommon(const RelMessage& rel) {
+  if (rel.has_common()) {
+    const auto& common = rel.common();
+
+    if (common.has_emit()) {
+      return Status::NotImplemented("substrait::RelCommon::Emit");
+    }
+    if (common.has_hint()) {
+      return Status::NotImplemented("substrait::RelCommon::Hint");
+    }
+    if (common.has_advanced_extension()) {
+      return Status::NotImplemented("substrait::RelCommon::advanced_extension");
+    }
+  }
+
+  return rel.has_advanced_extension()
+             ? Status::NotImplemented("substrait AdvancedExtensions")
+             : Status::OK();
+}
+
+// Converts a Substrait relation into an Arrow compute::Declaration.
+//
+// Supported relations:
+//   - ReadRel: becomes a "scan" declaration over a FileSystemDataset. Only LocalFiles
+//     items that are Parquet files addressed by a "file:///" uri_file, with default
+//     partition_index/start/length, are accepted. An optional filter is forwarded to
+//     the scan options; projections are not supported.
+//   - FilterRel: the converted input followed by a "filter" declaration.
+//   - ProjectRel: the converted input followed by a "project" declaration.
+//
+// Every other relation, and every unsupported feature of the relations above, yields
+// NotImplemented. A FilterRel/ProjectRel without input (or a FilterRel without
+// condition) yields Invalid.
+Result<compute::Declaration> FromProto(const substrait::Rel& rel,
+                                       const ExtensionSet& ext_set) {
+  // Initialize the dataset module exactly once. Initialization of a function-local
+  // static is thread-safe, unlike a hand-written "already initialized" flag.
+  static const bool dataset_init = [] {
+    dataset::internal::Initialize();
+    return true;
+  }();
+  (void)dataset_init;
+
+  switch (rel.rel_type_case()) {
+    case substrait::Rel::RelTypeCase::kRead: {
+      const auto& read = rel.read();
+      RETURN_NOT_OK(CheckRelCommon(read));
+
+      ARROW_ASSIGN_OR_RAISE(auto base_schema, FromProto(read.base_schema(), ext_set));
+
+      auto scan_options = std::make_shared<dataset::ScanOptions>();
+
+      if (read.has_filter()) {
+        ARROW_ASSIGN_OR_RAISE(scan_options->filter, FromProto(read.filter(), ext_set));
+      }
+
+      if (read.has_projection()) {
+        // NOTE: scan_options->projection is not used by the scanner and thus can't be
+        // used for this
+        return Status::NotImplemented("substrait::ReadRel::projection");
+      }
+
+      if (!read.has_local_files()) {
+        return Status::NotImplemented(
+            "substrait::ReadRel with read_type other than LocalFiles");
+      }
+
+      if (read.local_files().has_advanced_extension()) {
+        return Status::NotImplemented(
+            "substrait::ReadRel::LocalFiles::advanced_extension");
+      }
+
+      auto format = std::make_shared<dataset::ParquetFileFormat>();
+      auto filesystem = std::make_shared<fs::LocalFileSystem>();
+      std::vector<std::shared_ptr<dataset::FileFragment>> fragments;
+
+      for (const auto& item : read.local_files().items()) {
+        if (!item.has_uri_file()) {
+          return Status::NotImplemented(
+              "substrait::ReadRel::LocalFiles::FileOrFiles with "
+              "path_type other than uri_file");
+        }
+
+        if (item.format() !=
+            substrait::ReadRel::LocalFiles::FileOrFiles::FILE_FORMAT_PARQUET) {
+          return Status::NotImplemented(
+              "substrait::ReadRel::LocalFiles::FileOrFiles::format "
+              "other than FILE_FORMAT_PARQUET");
+        }
+
+        if (!util::string_view{item.uri_file()}.starts_with("file:///")) {
+          return Status::NotImplemented(
+              "substrait::ReadRel::LocalFiles::FileOrFiles::uri_file "
+              "with other than local filesystem (file:///)");
+        }
+        // Strip only "file://" (7 characters) so that the path keeps its leading '/'
+        auto path = item.uri_file().substr(7);
+
+        if (item.partition_index() != 0) {
+          return Status::NotImplemented(
+              "non-default substrait::ReadRel::LocalFiles::FileOrFiles::partition_index");
+        }
+
+        if (item.start() != 0) {
+          return Status::NotImplemented(
+              "non-default substrait::ReadRel::LocalFiles::FileOrFiles::start offset");
+        }
+
+        if (item.length() != 0) {
+          return Status::NotImplemented(
+              "non-default substrait::ReadRel::LocalFiles::FileOrFiles::length");
+        }
+
+        ARROW_ASSIGN_OR_RAISE(auto fragment, format->MakeFragment(dataset::FileSource{
+                                                 std::move(path), filesystem}));
+        fragments.push_back(std::move(fragment));
+      }
+
+      ARROW_ASSIGN_OR_RAISE(
+          auto ds, dataset::FileSystemDataset::Make(
+                       std::move(base_schema), /*root_partition=*/compute::literal(true),
+                       std::move(format), std::move(filesystem), std::move(fragments)));
+
+      return compute::Declaration{
+          "scan", dataset::ScanNodeOptions{std::move(ds), std::move(scan_options)}};
+    }
+
+    case substrait::Rel::RelTypeCase::kFilter: {
+      const auto& filter = rel.filter();
+      RETURN_NOT_OK(CheckRelCommon(filter));
+
+      if (!filter.has_input()) {
+        return Status::Invalid("substrait::FilterRel with no input relation");
+      }
+      ARROW_ASSIGN_OR_RAISE(auto input, FromProto(filter.input(), ext_set));
+
+      if (!filter.has_condition()) {
+        return Status::Invalid("substrait::FilterRel with no condition expression");
+      }
+      ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(), ext_set));
+
+      return compute::Declaration::Sequence({
+          std::move(input),
+          {"filter", compute::FilterNodeOptions{std::move(condition)}},
+      });
+    }
+
+    case substrait::Rel::RelTypeCase::kProject: {
+      const auto& project = rel.project();
+      RETURN_NOT_OK(CheckRelCommon(project));
+
+      if (!project.has_input()) {
+        return Status::Invalid("substrait::ProjectRel with no input relation");
+      }
+      ARROW_ASSIGN_OR_RAISE(auto input, FromProto(project.input(), ext_set));
+
+      std::vector<compute::Expression> expressions;
+      expressions.reserve(static_cast<size_t>(project.expressions_size()));
+      for (const auto& expr : project.expressions()) {
+        expressions.emplace_back();
+        ARROW_ASSIGN_OR_RAISE(expressions.back(), FromProto(expr, ext_set));
+      }
+
+      return compute::Declaration::Sequence({
+          std::move(input),
+          {"project", compute::ProjectNodeOptions{std::move(expressions)}},
+      });
+    }
+
+    default:
+      break;
+  }
+
+  return Status::NotImplemented(
+      "conversion to arrow::compute::Declaration from Substrait relation ",
+      rel.DebugString());
+}
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h
new file mode 100644
index 0000000..d9b90f5
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/relation_internal.h
@@ -0,0 +1,37 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/engine/substrait/extension_types.h"
+#include "arrow/engine/substrait/serde.h"
+#include "arrow/engine/visibility.h"
+#include "arrow/type_fwd.h"
+
+#include "substrait/relations.pb.h" // IWYU pragma: export
+
+namespace arrow {
+namespace engine {
+
+/// \brief Convert a Substrait relation to an Arrow compute::Declaration
+///
+/// Read, filter and project relations are handled; other relation types result in
+/// Status::NotImplemented.
+///
+/// The first argument is the relation to convert, the second argument is the
+/// extension set used to resolve the types and functions referenced by the relation.
+ARROW_ENGINE_EXPORT
+Result<compute::Declaration> FromProto(const substrait::Rel&, const ExtensionSet&);
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc
new file mode 100644
index 0000000..ea916d8
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/serde.cc
@@ -0,0 +1,232 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/serde.h"
+
+#include "arrow/engine/substrait/expression_internal.h"
+#include "arrow/engine/substrait/plan_internal.h"
+#include "arrow/engine/substrait/relation_internal.h"
+#include "arrow/engine/substrait/type_internal.h"
+#include "arrow/util/string_view.h"
+
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/util/json_util.h>
+#include <google/protobuf/util/message_differencer.h>
+#include <google/protobuf/util/type_resolver_util.h>
+
+namespace arrow {
+namespace engine {
+
+// Parses the protobuf serialization held by `buf` into `message`.
+//
+// `full_name` is the fully qualified name of the message type; it is only used to
+// build error messages.
+//
+// Returns Invalid if the buffer is too large to be handed to protobuf and IOError if
+// protobuf fails to parse it.
+Status ParseFromBufferImpl(const Buffer& buf, const std::string& full_name,
+                           google::protobuf::Message* message) {
+  // ArrayInputStream takes its size as an int. Verify that the conversion round-trips
+  // instead of silently parsing a truncated (or negatively sized) view of the buffer.
+  const auto size = static_cast<int>(buf.size());
+  if (size < 0 || static_cast<int64_t>(size) != buf.size()) {
+    return Status::Invalid("Buffer of ", buf.size(),
+                           " bytes is too large to be parsed as ", full_name);
+  }
+
+  google::protobuf::io::ArrayInputStream buf_stream{buf.data(), size};
+
+  if (message->ParseFromZeroCopyStream(&buf_stream)) {
+    return Status::OK();
+  }
+  return Status::IOError("ParseFromZeroCopyStream failed for ", full_name);
+}
+
+// Typed convenience wrapper around ParseFromBufferImpl: default-constructs a
+// `Message`, fills it from `buf` and returns it, or forwards the parse error.
+template <typename Message>
+Result<Message> ParseFromBuffer(const Buffer& buf) {
+  Message parsed;
+  google::protobuf::Message* target = &parsed;
+  ARROW_RETURN_NOT_OK(
+      ParseFromBufferImpl(buf, Message::descriptor()->full_name(), target));
+  return parsed;
+}
+
+// Parses `buf` as a substrait::Rel message and converts it to a declaration,
+// resolving extension references through `ext_set`.
+Result<compute::Declaration> DeserializeRelation(const Buffer& buf,
+                                                 const ExtensionSet& ext_set) {
+  ARROW_ASSIGN_OR_RAISE(substrait::Rel relation_message,
+                        ParseFromBuffer<substrait::Rel>(buf));
+  return FromProto(relation_message, ext_set);
+}
+
+// Parses `buf` as a substrait::Plan message and converts each of its relations into
+// a declaration which ends in a "consuming_sink" node.
+//
+// The extension set is derived from the plan itself. A fresh consumer is requested
+// from `consumer_factory` for every relation. Relations which use `root` (custom
+// output field names) are rejected with NotImplemented. If `ext_set_out` is non-null
+// it receives the plan's extension set on success.
+Result<std::vector<compute::Declaration>> DeserializePlan(
+    const Buffer& buf, const ConsumerFactory& consumer_factory,
+    ExtensionSet* ext_set_out) {
+  ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<substrait::Plan>(buf));
+  ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan));
+
+  std::vector<compute::Declaration> sink_decls;
+  for (const substrait::PlanRel& plan_rel : plan.relations()) {
+    if (plan_rel.has_root()) {
+      return Status::NotImplemented("substrait::PlanRel with custom output field names");
+    }
+
+    ARROW_ASSIGN_OR_RAISE(auto relation_decl, FromProto(plan_rel.rel(), ext_set));
+
+    // Terminate the relation with a consuming_sink node and collect the result
+    sink_decls.push_back(compute::Declaration::Sequence({
+        std::move(relation_decl),
+        {"consuming_sink", compute::ConsumingSinkNodeOptions{consumer_factory()}},
+    }));
+  }
+
+  if (ext_set_out != nullptr) {
+    *ext_set_out = std::move(ext_set);
+  }
+  return sink_decls;
+}
+
+// Parses `buf` as a substrait::NamedStruct message and converts it to a Schema,
+// resolving extension references through `ext_set`.
+Result<std::shared_ptr<Schema>> DeserializeSchema(const Buffer& buf,
+                                                  const ExtensionSet& ext_set) {
+  ARROW_ASSIGN_OR_RAISE(substrait::NamedStruct schema_message,
+                        ParseFromBuffer<substrait::NamedStruct>(buf));
+  return FromProto(schema_message, ext_set);
+}
+
+// Converts `schema` to its Substrait message (registering anything it needs in
+// `ext_set`) and returns the message's protobuf serialization.
+Result<std::shared_ptr<Buffer>> SerializeSchema(const Schema& schema,
+                                                ExtensionSet* ext_set) {
+  ARROW_ASSIGN_OR_RAISE(auto schema_message, ToProto(schema, ext_set));
+  return Buffer::FromString(schema_message->SerializeAsString());
+}
+
+// Parses `buf` as a substrait::Type message and converts it to a DataType.
+//
+// The conversion yields a pair of which only the first member, the data type, is
+// returned; the second member is discarded.
+Result<std::shared_ptr<DataType>> DeserializeType(const Buffer& buf,
+                                                  const ExtensionSet& ext_set) {
+  ARROW_ASSIGN_OR_RAISE(substrait::Type type_message,
+                        ParseFromBuffer<substrait::Type>(buf));
+  ARROW_ASSIGN_OR_RAISE(auto type_and_nullability, FromProto(type_message, ext_set));
+  return std::move(type_and_nullability.first);
+}
+
+// Converts `type` (always treated as nullable) to its Substrait message, registering
+// anything it needs in `ext_set`, and returns the message's protobuf serialization.
+Result<std::shared_ptr<Buffer>> SerializeType(const DataType& type,
+                                              ExtensionSet* ext_set) {
+  ARROW_ASSIGN_OR_RAISE(auto type_message, ToProto(type, /*nullable=*/true, ext_set));
+  return Buffer::FromString(type_message->SerializeAsString());
+}
+
+// Parses `buf` as a substrait::Expression message and converts it to a
+// compute::Expression, resolving extension references through `ext_set`.
+Result<compute::Expression> DeserializeExpression(const Buffer& buf,
+                                                  const ExtensionSet& ext_set) {
+  ARROW_ASSIGN_OR_RAISE(substrait::Expression expression_message,
+                        ParseFromBuffer<substrait::Expression>(buf));
+  return FromProto(expression_message, ext_set);
+}
+
+// Converts `expr` to its Substrait message (registering anything it needs in
+// `ext_set`) and returns the message's protobuf serialization.
+Result<std::shared_ptr<Buffer>> SerializeExpression(const compute::Expression& expr,
+                                                    ExtensionSet* ext_set) {
+  ARROW_ASSIGN_OR_RAISE(auto expression_message, ToProto(expr, ext_set));
+  return Buffer::FromString(expression_message->SerializeAsString());
+}
+
+namespace internal {
+
+// Parses both buffers as `Message` and compares them with protobuf's
+// MessageDifferencer using EQUIVALENT field comparison.
+//
+// Returns OK if the messages are equivalent, Invalid with a description of the
+// differences otherwise. Parse errors are forwarded.
+template <typename Message>
+static Status CheckMessagesEquivalent(const Buffer& l_buf, const Buffer& r_buf) {
+  ARROW_ASSIGN_OR_RAISE(auto l, ParseFromBuffer<Message>(l_buf));
+  ARROW_ASSIGN_OR_RAISE(auto r, ParseFromBuffer<Message>(r_buf));
+
+  using google::protobuf::util::MessageDifferencer;
+
+  // Let the differencer manage the reporter which writes into `out`. A hand-built
+  // StreamReporter on top of a StringOutputStream only trims the string down to the
+  // bytes actually written once the reporter is destroyed, so reading the string
+  // while the reporter is still alive can yield trailing garbage.
+  std::string out;
+
+  MessageDifferencer differencer;
+  differencer.set_message_field_comparison(MessageDifferencer::EQUIVALENT);
+  differencer.ReportDifferencesToString(&out);
+
+  if (differencer.Compare(l, r)) {
+    return Status::OK();
+  }
+  return Status::Invalid("Messages were not equivalent: ", out);
+}
+
+// Compares two serialized Substrait messages of the type named by `message_name`.
+//
+// Supported names are "Type", "NamedStruct", "Expression" and "Rel". The common
+// misnomers "Schema" and "Relation" are rejected with a hint towards the correct
+// name; every other name is rejected as unsupported.
+Status CheckMessagesEquivalent(util::string_view message_name, const Buffer& l_buf,
+                               const Buffer& r_buf) {
+  // Supported message types
+  if (message_name == "Type") {
+    return CheckMessagesEquivalent<substrait::Type>(l_buf, r_buf);
+  }
+  if (message_name == "NamedStruct") {
+    return CheckMessagesEquivalent<substrait::NamedStruct>(l_buf, r_buf);
+  }
+  if (message_name == "Expression") {
+    return CheckMessagesEquivalent<substrait::Expression>(l_buf, r_buf);
+  }
+  if (message_name == "Rel") {
+    return CheckMessagesEquivalent<substrait::Rel>(l_buf, r_buf);
+  }
+
+  // Names which are likely to be tried by mistake get a dedicated explanation
+  if (message_name == "Schema") {
+    return Status::Invalid(
+        "There is no substrait message named Schema. The substrait message type which "
+        "corresponds to Schemas is NamedStruct");
+  }
+  if (message_name == "Relation") {
+    return Status::Invalid(
+        "There is no substrait message named Relation. You probably meant \"Rel\"");
+  }
+
+  return Status::Invalid("Unsupported message name ", message_name,
+                         " for CheckMessagesEquivalent");
+}
+
+// Returns a process-wide TypeResolver for protobuf's generated descriptor pool,
+// created with an empty URL prefix.
+//
+// The resolver is created on first use. It is constructed as part of the static's
+// initialization, which the language guarantees to happen exactly once even when
+// several threads make the first call concurrently; a separate "is it set yet?"
+// check would not give that guarantee.
+inline google::protobuf::util::TypeResolver* GetGeneratedTypeResolver() {
+  static std::unique_ptr<google::protobuf::util::TypeResolver> type_resolver{
+      google::protobuf::util::NewTypeResolverForDescriptorPool(
+          /*url_prefix=*/"", google::protobuf::DescriptorPool::generated_pool())};
+  return type_resolver.get();
+}
+
+// Converts the JSON representation of the Substrait message `type_name` (given
+// without the "substrait." package prefix) into its binary protobuf serialization.
+//
+// Returns Invalid if the input is too large to be handed to protobuf or if protobuf
+// reports a conversion failure.
+Result<std::shared_ptr<Buffer>> SubstraitFromJSON(util::string_view type_name,
+                                                  util::string_view json) {
+  std::string type_url = "/substrait." + type_name.to_string();
+
+  // ArrayInputStream takes its size as an int. Verify that the conversion round-trips
+  // instead of silently converting a truncated view of the input.
+  const auto json_size = static_cast<int>(json.size());
+  if (json_size < 0 || static_cast<size_t>(json_size) != json.size()) {
+    return Status::Invalid("JSON input of ", json.size(),
+                           " bytes is too large to be converted to ", type_url);
+  }
+
+  google::protobuf::io::ArrayInputStream json_stream{json.data(), json_size};
+
+  std::string out;
+  google::protobuf::io::StringOutputStream out_stream{&out};
+
+  auto status = google::protobuf::util::JsonToBinaryStream(
+      GetGeneratedTypeResolver(), type_url, &json_stream, &out_stream);
+
+  if (!status.ok()) {
+    return Status::Invalid("JsonToBinaryStream returned ", status);
+  }
+  return Buffer::FromString(std::move(out));
+}
+
+// Converts the binary protobuf serialization of the Substrait message `type_name`
+// (given without the "substrait." package prefix) into its JSON representation.
+//
+// Returns Invalid if the buffer is too large to be handed to protobuf or if protobuf
+// reports a conversion failure.
+Result<std::string> SubstraitToJSON(util::string_view type_name, const Buffer& buf) {
+  std::string type_url = "/substrait." + type_name.to_string();
+
+  // ArrayInputStream takes its size as an int. Verify that the conversion round-trips
+  // instead of silently converting a truncated view of the buffer.
+  const auto buf_size = static_cast<int>(buf.size());
+  if (buf_size < 0 || static_cast<int64_t>(buf_size) != buf.size()) {
+    return Status::Invalid("Buffer of ", buf.size(),
+                           " bytes is too large to be converted from ", type_url);
+  }
+
+  google::protobuf::io::ArrayInputStream buf_stream{buf.data(), buf_size};
+
+  std::string out;
+  google::protobuf::io::StringOutputStream out_stream{&out};
+
+  auto status = google::protobuf::util::BinaryToJsonStream(
+      GetGeneratedTypeResolver(), type_url, &buf_stream, &out_stream);
+  if (!status.ok()) {
+    return Status::Invalid("BinaryToJsonStream returned ", status);
+  }
+  return out;
+}
+
+} // namespace internal
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h
new file mode 100644
index 0000000..9e63a1b
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/serde.h
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/visibility.h"
+#include "arrow/result.h"
+#include "arrow/util/string_view.h"
+
+namespace arrow {
+namespace engine {
+
+/// Factory function type for generating the node that consumes the batches produced by
+/// each toplevel Substrait relation when deserializing a Substrait Plan.
+using ConsumerFactory = std::function<std::shared_ptr<compute::SinkNodeConsumer>()>;
+
+/// \brief Deserializes a Substrait Plan message to a list of ExecNode declarations
+///
+/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
+/// message
+/// \param[in] consumer_factory factory function for generating the node that consumes
+/// the batches produced by each toplevel Substrait relation
+/// \param[out] ext_set if non-null, the extension mapping used by the Substrait Plan is
+/// returned here.
+/// \return a vector of ExecNode declarations, one for each toplevel relation in the
+/// Substrait Plan
+ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlan(
+ const Buffer& buf, const ConsumerFactory& consumer_factory,
+ ExtensionSet* ext_set = NULLPTR);
+
+/// \brief Deserializes a Substrait Type message to the corresponding Arrow type
+///
+/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type
+/// message
+/// \param[in] ext_set the extension mapping to use, normally provided by the
+/// surrounding Plan message
+/// \return the corresponding Arrow data type
+ARROW_ENGINE_EXPORT
+Result<std::shared_ptr<DataType>> DeserializeType(const Buffer& buf,
+ const ExtensionSet& ext_set);
+
+/// \brief Serializes an Arrow type to a Substrait Type message
+///
+/// \param[in] type the Arrow data type to serialize
+/// \param[in,out] ext_set the extension mapping to use; may be updated to add a
+/// mapping for the given type
+/// \return a buffer containing the protobuf serialization of the corresponding Substrait
+/// Type message
+ARROW_ENGINE_EXPORT
+Result<std::shared_ptr<Buffer>> SerializeType(const DataType& type,
+ ExtensionSet* ext_set);
+
+/// \brief Deserializes a Substrait NamedStruct message to an Arrow schema
+///
+/// \param[in] buf a buffer containing the protobuf serialization of a Substrait
+/// NamedStruct message
+/// \param[in] ext_set the extension mapping to use, normally provided by the
+/// surrounding Plan message
+/// \return the corresponding Arrow schema
+ARROW_ENGINE_EXPORT
+Result<std::shared_ptr<Schema>> DeserializeSchema(const Buffer& buf,
+ const ExtensionSet& ext_set);
+
+/// \brief Serializes an Arrow schema to a Substrait NamedStruct message
+///
+/// \param[in] schema the Arrow schema to serialize
+/// \param[in,out] ext_set the extension mapping to use; may be updated to add
+/// mappings for the types used in the schema
+/// \return a buffer containing the protobuf serialization of the corresponding Substrait
+/// NamedStruct message
+ARROW_ENGINE_EXPORT
+Result<std::shared_ptr<Buffer>> SerializeSchema(const Schema& schema,
+ ExtensionSet* ext_set);
+
+/// \brief Deserializes a Substrait Expression message to a compute expression
+///
+/// \param[in] buf a buffer containing the protobuf serialization of a Substrait
+/// Expression message
+/// \param[in] ext_set the extension mapping to use, normally provided by the
+/// surrounding Plan message
+/// \return the corresponding Arrow compute expression
+ARROW_ENGINE_EXPORT
+Result<compute::Expression> DeserializeExpression(const Buffer& buf,
+ const ExtensionSet& ext_set);
+
+/// \brief Serializes an Arrow compute expression to a Substrait Expression message
+///
+/// \param[in] expr the Arrow compute expression to serialize
+/// \param[in,out] ext_set the extension mapping to use; may be updated to add
+/// mappings for the types used in the expression
+/// \return a buffer containing the protobuf serialization of the corresponding Substrait
+/// Expression message
+ARROW_ENGINE_EXPORT
+Result<std::shared_ptr<Buffer>> SerializeExpression(const compute::Expression& expr,
+ ExtensionSet* ext_set);
+
+/// \brief Deserializes a Substrait Rel (relation) message to an ExecNode declaration
+///
+/// \param[in] buf a buffer containing the protobuf serialization of a Substrait
+/// Rel message
+/// \param[in] ext_set the extension mapping to use, normally provided by the
+/// surrounding Plan message
+/// \return the corresponding ExecNode declaration
+ARROW_ENGINE_EXPORT Result<compute::Declaration> DeserializeRelation(
+ const Buffer& buf, const ExtensionSet& ext_set);
+
+namespace internal {
+
+/// \brief Checks whether two protobuf serializations of a particular Substrait message
+/// type are equivalent
+///
+/// Note that a binary comparison of the two buffers is insufficient. One reason for this
+/// is that the fields of a message can be specified in any order in the serialization.
+///
+/// \param[in] message_name the name of the Substrait message type to check, e.g. "Rel"
+/// \param[in] l_buf buffer containing the first protobuf serialization to compare
+/// \param[in] r_buf buffer containing the second protobuf serialization to compare
+/// \return success if equivalent, failure if not; Invalid if message_name is unsupported
+ARROW_ENGINE_EXPORT
+Status CheckMessagesEquivalent(util::string_view message_name, const Buffer& l_buf,
+ const Buffer& r_buf);
+
+/// \brief Utility function to convert a JSON serialization of a Substrait message to
+/// its binary serialization
+///
+/// \param[in] type_name the unqualified name of the Substrait message type, e.g. "Type"
+/// \param[in] json the JSON string to convert
+/// \return a buffer filled with the binary protobuf serialization, or Invalid on failure
+ARROW_ENGINE_EXPORT
+Result<std::shared_ptr<Buffer>> SubstraitFromJSON(util::string_view type_name,
+ util::string_view json);
+
+/// \brief Utility function to convert a binary protobuf serialization of a Substrait
+/// message to JSON
+///
+/// \param[in] type_name the unqualified name of the Substrait message type, e.g. "Type"
+/// \param[in] buf the buffer containing the binary protobuf serialization of the message
+/// \return a JSON string representing the message, or Invalid if the conversion fails
+ARROW_ENGINE_EXPORT
+Result<std::string> SubstraitToJSON(util::string_view type_name, const Buffer& buf);
+
+} // namespace internal
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
new file mode 100644
index 0000000..6af5d71
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -0,0 +1,728 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/serde.h"
+
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/util/json_util.h>
+#include <google/protobuf/util/type_resolver_util.h>
+#include <gtest/gtest.h>
+
+#include "arrow/compute/exec/expression_internal.h"
+#include "arrow/dataset/file_base.h"
+#include "arrow/dataset/scanner.h"
+#include "arrow/engine/substrait/extension_types.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/matchers.h"
+#include "arrow/util/key_value_metadata.h"
+
+using testing::ElementsAre;
+using testing::Eq;
+using testing::HasSubstr;
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace engine {
+
+const std::shared_ptr<Schema> kBoringSchema = schema({  // fixture the expression tests bind against; field order matters
+ field("bool", boolean()),
+ field("i8", int8()),
+ field("i32", int32()),
+ field("i32_req", int32(), /*nullable=*/false),
+ field("u32", uint32()),
+ field("i64", int64()),
+ field("f32", float32()),
+ field("f32_req", float32(), /*nullable=*/false),
+ field("f64", float64()),
+ field("date64", date64()),
+ field("str", utf8()),
+ field("list_i32", list(int32())),
+ field("struct", struct_({  // top-level index 12; tests in this file hard-code that position
+ field("i32", int32()),
+ field("str", utf8()),  // child index 1 of "struct"
+ field("struct_i32_str",
+ struct_({field("i32", int32()), field("str", utf8())})),
+ })),
+ field("list_struct", list(struct_({
+ field("i32", int32()),
+ field("str", utf8()),
+ field("struct_i32_str", struct_({field("i32", int32()),
+ field("str", utf8())})),
+ }))),
+ field("dict_str", dictionary(int32(), utf8())),
+ field("dict_i32", dictionary(int32(), int32())),
+ field("ts_ns", timestamp(TimeUnit::NANO)),
+});
+
+// Returns a copy of `type` whose immediate child fields are renamed to "".
+//
+// Only struct and list types are rewritten, and only one level deep: the children
+// keep their own types untouched. Every other type is handed back unchanged.
+std::shared_ptr<DataType> StripFieldNames(std::shared_ptr<DataType> type) {
+  switch (type->id()) {
+    case Type::STRUCT: {
+      FieldVector unnamed;
+      unnamed.reserve(type->num_fields());
+      for (int child = 0; child < type->num_fields(); ++child) {
+        unnamed.push_back(type->field(child)->WithName(""));
+      }
+      return struct_(std::move(unnamed));
+    }
+    case Type::LIST:
+      return list(type->field(0)->WithName(""));
+    default:
+      return type;
+  }
+}
+
+// Rewrites every field reference inside `expr` into the form produced by resolving
+// it against kBoringSchema, so that two expressions naming the same field in
+// different ways compare equal.
+//
+// Literals are returned as-is. Calls are copied and their arguments rewritten
+// recursively.
+//
+// NOTE(review): the FindOne result is dereferenced without checking it, so a
+// reference that does not resolve in kBoringSchema is not reported as a failure here.
+inline compute::Expression UseBoringRefs(const compute::Expression& expr) {
+  if (expr.literal()) return expr;
+
+  auto ref = expr.field_ref();
+  if (ref) {
+    auto resolved = ref->FindOne(*kBoringSchema);
+    return compute::field_ref(*std::move(resolved));
+  }
+
+  auto rewritten_call = *CallNotNull(expr);
+  for (auto& argument : rewritten_call.arguments) {
+    argument = UseBoringRefs(argument);
+  }
+  return compute::Expression{std::move(rewritten_call)};
+}
+
+TEST(Substrait, SupportedTypes) {
+ auto ExpectEq = [](util::string_view json, std::shared_ptr<DataType> expected_type) {
+ ARROW_SCOPED_TRACE(json);
+
+ ExtensionSet empty;
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Type", json));
+ ASSERT_OK_AND_ASSIGN(auto type, DeserializeType(*buf, empty));
+
+ EXPECT_EQ(*type, *expected_type);
+
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeType(*type, &empty));
+ EXPECT_EQ(empty.num_types(), 0);
+
+ // FIXME chokes on NULLABILITY_UNSPECIFIED
+ // EXPECT_THAT(internal::CheckMessagesEquivalent("Type", *buf, *serialized), Ok());
+
+ ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeType(*serialized, empty));
+
+ EXPECT_EQ(*roundtripped, *expected_type);
+ };
+
+ ExpectEq(R"({"bool": {}})", boolean());
+
+ ExpectEq(R"({"i8": {}})", int8());
+ ExpectEq(R"({"i16": {}})", int16());
+ ExpectEq(R"({"i32": {}})", int32());
+ ExpectEq(R"({"i64": {}})", int64());
+
+ ExpectEq(R"({"fp32": {}})", float32());
+ ExpectEq(R"({"fp64": {}})", float64());
+
+ ExpectEq(R"({"string": {}})", utf8());
+ ExpectEq(R"({"binary": {}})", binary());
+
+ ExpectEq(R"({"timestamp": {}})", timestamp(TimeUnit::MICRO));
+ ExpectEq(R"({"date": {}})", date32());
+ ExpectEq(R"({"time": {}})", time64(TimeUnit::MICRO));
+ ExpectEq(R"({"timestamp_tz": {}})", timestamp(TimeUnit::MICRO, "UTC"));
+ ExpectEq(R"({"interval_year": {}})", interval_year());
+ ExpectEq(R"({"interval_day": {}})", interval_day());
+
+ ExpectEq(R"({"uuid": {}})", uuid());
+
+ ExpectEq(R"({"fixed_char": {"length": 32}})", fixed_char(32));
+ ExpectEq(R"({"varchar": {"length": 1024}})", varchar(1024));
+ ExpectEq(R"({"fixed_binary": {"length": 32}})", fixed_size_binary(32));
+
+ ExpectEq(R"({"decimal": {"precision": 27, "scale": 5}})", decimal128(27, 5));
+
+ ExpectEq(R"({"struct": {
+ "types": [
+ {"i64": {}},
+ {"list": {"type": {"string":{}} }}
+ ]
+ }})",
+ struct_({
+ field("", int64()),
+ field("", list(utf8())),
+ }));
+
+ ExpectEq(R"({"map": {
+ "key": {"string":{"nullability": "NULLABILITY_REQUIRED"}},
+ "value": {"string":{}}
+ }})",
+ map(utf8(), field("", utf8()), false));
+}
+
+TEST(Substrait, SupportedExtensionTypes) {
+ ExtensionSet ext_set;
+
+ for (auto expected_type : {
+ null(),
+ uint8(),
+ uint16(),
+ uint32(),
+ uint64(),
+ }) {
+ auto anchor = ext_set.num_types();
+
+ EXPECT_THAT(ext_set.EncodeType(*expected_type), ResultWith(Eq(anchor)));
+ ASSERT_OK_AND_ASSIGN(
+ auto buf,
+ internal::SubstraitFromJSON(
+ "Type", "{\"user_defined_type_reference\": " + std::to_string(anchor) + "}"));
+
+ ASSERT_OK_AND_ASSIGN(auto type, DeserializeType(*buf, ext_set));
+ EXPECT_EQ(*type, *expected_type);
+
+ auto size = ext_set.num_types();
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeType(*type, &ext_set));
+ EXPECT_EQ(ext_set.num_types(), size) << "was already added to the set above";
+
+ ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeType(*serialized, ext_set));
+ EXPECT_EQ(*roundtripped, *expected_type);
+ }
+}
+
+TEST(Substrait, NamedStruct) {  // NamedStruct <-> Schema round trip plus rejected inputs
+ ExtensionSet ext_set;
+
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("NamedStruct", R"({
+ "struct": {
+ "types": [
+ {"i64": {}},
+ {"list": {"type": {"string":{}} }},
+ {"struct": {
+ "types": [
+ {"fp32": {"nullability": "NULLABILITY_REQUIRED"}},
+ {"string": {}}
+ ]
+ }},
+ {"list": {"type": {"string":{}} }},
+ ]
+ },
+ "names": ["a", "b", "c", "d", "e", "f"]
+ })"));
+ ASSERT_OK_AND_ASSIGN(auto schema, DeserializeSchema(*buf, ext_set));
+ Schema expected_schema({  // names are consumed depth-first: "d" and "e" name the children of "c"
+ field("a", int64()),
+ field("b", list(utf8())),
+ field("c", struct_({
+ field("d", float32(), /*nullable=*/false),
+ field("e", utf8()),
+ })),
+ field("f", list(utf8())),
+ });
+ EXPECT_EQ(*schema, expected_schema);
+
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeSchema(*schema, &ext_set));
+ ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeSchema(*serialized, ext_set));
+ EXPECT_EQ(*roundtripped, expected_schema);
+
+ // too few names
+ ASSERT_OK_AND_ASSIGN(buf, internal::SubstraitFromJSON("NamedStruct", R"({
+ "struct": {"types": [{"i32": {}}, {"i32": {}}, {"i32": {}}]},
+ "names": []
+ })"));
+ EXPECT_THAT(DeserializeSchema(*buf, ext_set), Raises(StatusCode::Invalid));
+
+ // too many names
+ ASSERT_OK_AND_ASSIGN(buf, internal::SubstraitFromJSON("NamedStruct", R"({
+ "struct": {"types": []},
+ "names": ["a", "b", "c"]
+ })"));
+ EXPECT_THAT(DeserializeSchema(*buf, ext_set), Raises(StatusCode::Invalid));
+
+ // no schema metadata allowed
+ EXPECT_THAT(SerializeSchema(Schema({}, key_value_metadata({{"ext", "yes"}})), &ext_set),
+ Raises(StatusCode::Invalid));
+
+ // no field metadata allowed
+ EXPECT_THAT(
+ SerializeSchema(Schema({field("a", int32(), key_value_metadata({{"ext", "yes"}}))}),
+ &ext_set),
+ Raises(StatusCode::Invalid));
+}
+
+TEST(Substrait, NoEquivalentArrowType) {  // a user-defined type anchor absent from the ExtensionSet must be rejected
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON(
+ "Type", R"({"user_defined_type_reference": 99})"));
+ ExtensionSet empty;  // nothing is registered by this test, so anchor 99 is expected not to resolve
+ ASSERT_THAT(
+ DeserializeType(*buf, empty),
+ Raises(StatusCode::Invalid, HasSubstr("did not have a corresponding anchor")));
+}
+
+TEST(Substrait, NoEquivalentSubstraitType) {
+ for (auto type : {
+ date64(),
+ timestamp(TimeUnit::SECOND),
+ timestamp(TimeUnit::NANO),
+ timestamp(TimeUnit::MICRO, "New York"),
+ time32(TimeUnit::SECOND),
+ time32(TimeUnit::MILLI),
+ time64(TimeUnit::NANO),
+
+ decimal256(76, 67),
+
+ sparse_union({field("i8", int8()), field("f32", float32())}),
+ dense_union({field("i8", int8()), field("f32", float32())}),
+ dictionary(int32(), utf8()),
+
+ fixed_size_list(float16(), 3),
+
+ duration(TimeUnit::MICRO),
+
+ large_utf8(),
+ large_binary(),
+ large_list(utf8()),
+ }) {
+ ARROW_SCOPED_TRACE(type->ToString());
+ ExtensionSet set;
+ EXPECT_THAT(SerializeType(*type, &set), Raises(StatusCode::NotImplemented));
+ }
+}
+
+TEST(Substrait, SupportedLiterals) {
+ auto ExpectEq = [](util::string_view json, Datum expected_value) {
+ ARROW_SCOPED_TRACE(json);
+
+ ASSERT_OK_AND_ASSIGN(
+ auto buf, internal::SubstraitFromJSON("Expression",
+ "{\"literal\":" + json.to_string() + "}"));
+ ExtensionSet ext_set;
+ ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
+
+ ASSERT_TRUE(expr.literal());
+ ASSERT_THAT(*expr.literal(), DataEq(expected_value));
+
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
+ EXPECT_EQ(ext_set.num_functions(), 0); // shouldn't need extensions for core literals
+
+ ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set));
+
+ ASSERT_TRUE(roundtripped.literal());
+ ASSERT_THAT(*roundtripped.literal(), DataEq(expected_value));
+ };
+
+ ExpectEq(R"({"boolean": true})", Datum(true));
+
+ ExpectEq(R"({"i8": 34})", Datum(int8_t(34)));
+ ExpectEq(R"({"i16": 34})", Datum(int16_t(34)));
+ ExpectEq(R"({"i32": 34})", Datum(int32_t(34)));
+ ExpectEq(R"({"i64": "34"})", Datum(int64_t(34)));
+
+ ExpectEq(R"({"fp32": 3.5})", Datum(3.5F));
+ ExpectEq(R"({"fp64": 7.125})", Datum(7.125));
+
+ ExpectEq(R"({"string": "hello world"})", Datum("hello world"));
+
+ ExpectEq(R"({"binary": "enp6"})", BinaryScalar(Buffer::FromString("zzz")));
+
+ ExpectEq(R"({"timestamp": "579"})", TimestampScalar(579, TimeUnit::MICRO));
+
+ ExpectEq(R"({"date": "5"})", Date32Scalar(5));
+
+ ExpectEq(R"({"time": "64"})", Time64Scalar(64, TimeUnit::MICRO));
+
+ ExpectEq(R"({"interval_year_to_month": {"years": 34, "months": 3}})",
+ ExtensionScalar(FixedSizeListScalar(ArrayFromJSON(int32(), "[34, 3]")),
+ interval_year()));
+
+ ExpectEq(R"({"interval_day_to_second": {"days": 34, "seconds": 3}})",
+ ExtensionScalar(FixedSizeListScalar(ArrayFromJSON(int32(), "[34, 3]")),
+ interval_day()));
+
+ ExpectEq(R"({"fixed_char": "zzz"})",
+ ExtensionScalar(
+ FixedSizeBinaryScalar(Buffer::FromString("zzz"), fixed_size_binary(3)),
+ fixed_char(3)));
+
+ ExpectEq(R"({"var_char": {"value": "zzz", "length": 1024}})",
+ ExtensionScalar(StringScalar("zzz"), varchar(1024)));
+
+ ExpectEq(R"({"fixed_binary": "enp6"})",
+ FixedSizeBinaryScalar(Buffer::FromString("zzz"), fixed_size_binary(3)));
+
+ ExpectEq(
+ R"({"decimal": {"value": "0gKWSQAAAAAAAAAAAAAAAA==", "precision": 27, "scale": 5}})",
+ Decimal128Scalar(Decimal128("123456789.0"), decimal128(27, 5)));
+
+ ExpectEq(R"({"timestamp_tz": "579"})", TimestampScalar(579, TimeUnit::MICRO, "UTC"));
+
+ // special case for empty lists
+ ExpectEq(R"({"empty_list": {"type": {"i32": {}}}})",
+ ScalarFromJSON(list(int32()), "[]"));
+
+ ExpectEq(R"({"struct": {
+ "fields": [
+ {"i64": "32"},
+ {"list": {"values": [
+ {"string": "hello"},
+ {"string": "world"}
+ ]}}
+ ]
+ }})",
+ ScalarFromJSON(struct_({
+ field("", int64()),
+ field("", list(utf8())),
+ }),
+ R"([32, ["hello", "world"]])"));
+
+ // check null scalars:
+ for (auto type : {
+ boolean(),
+
+ int8(),
+ int64(),
+
+ timestamp(TimeUnit::MICRO),
+ interval_year(),
+
+ struct_({
+ field("", int64()),
+ field("", list(utf8())),
+ }),
+ }) {
+ ExtensionSet set;
+ ASSERT_OK_AND_ASSIGN(auto buf, SerializeType(*type, &set));
+ ASSERT_OK_AND_ASSIGN(auto json, internal::SubstraitToJSON("Type", *buf));
+ ExpectEq("{\"null\": " + json + "}", MakeNullScalar(type));
+ }
+}
+
+TEST(Substrait, CannotDeserializeLiteral) {
+ ExtensionSet ext_set;
+
+ // Invalid: missing List.element_type
+ ASSERT_OK_AND_ASSIGN(
+ auto buf, internal::SubstraitFromJSON("Expression",
+ R"({"literal": {"list": {"values": []}}})"));
+ EXPECT_THAT(DeserializeExpression(*buf, ext_set), Raises(StatusCode::Invalid));
+
+ // Invalid: required null literal
+ ASSERT_OK_AND_ASSIGN(
+ buf,
+ internal::SubstraitFromJSON(
+ "Expression",
+ R"({"literal": {"null": {"bool": {"nullability": "NULLABILITY_REQUIRED"}}}})"));
+ EXPECT_THAT(DeserializeExpression(*buf, ext_set), Raises(StatusCode::Invalid));
+
+ // no equivalent arrow scalar
+ // FIXME no way to specify scalars of user_defined_type_reference
+}
+
+TEST(Substrait, FieldRefRoundTrip) {
+ for (FieldRef ref : {
+ // by name
+ FieldRef("i32"),
+ FieldRef("ts_ns"),
+ FieldRef("struct"),
+
+ // by index
+ FieldRef(0),
+ FieldRef(1),
+ FieldRef(kBoringSchema->num_fields() - 1),
+ FieldRef(kBoringSchema->GetFieldIndex("struct")),
+
+ // nested
+ FieldRef("struct", "i32"),
+ FieldRef("struct", "struct_i32_str", "i32"),
+ FieldRef(kBoringSchema->GetFieldIndex("struct"), 1),
+ }) {
+ ARROW_SCOPED_TRACE(ref.ToString());
+ ASSERT_OK_AND_ASSIGN(auto expr, compute::field_ref(ref).Bind(*kBoringSchema));
+
+ ExtensionSet ext_set;
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
+ EXPECT_EQ(ext_set.num_functions(),
+ 0); // shouldn't need extensions for core field references
+ ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set));
+ ASSERT_TRUE(roundtripped.field_ref());
+
+ ASSERT_OK_AND_ASSIGN(auto expected, ref.FindOne(*kBoringSchema));
+ ASSERT_OK_AND_ASSIGN(auto actual, roundtripped.field_ref()->FindOne(*kBoringSchema));
+ EXPECT_EQ(actual.indices(), expected.indices());
+ }
+}
+
+TEST(Substrait, RecursiveFieldRef) {  // a nested field reference serializes as nested structField messages
+ FieldRef ref("struct", "str");  // "struct" is field 12 of kBoringSchema; "str" is child 1 of that struct
+
+ ARROW_SCOPED_TRACE(ref.ToString());
+ ASSERT_OK_AND_ASSIGN(auto expr, compute::field_ref(ref).Bind(*kBoringSchema));
+ ExtensionSet ext_set;
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
+ ASSERT_OK_AND_ASSIGN(auto expected, internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 12,
+ "child": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ },
+ "rootReference": {}
+ }
+ })"));
+ ASSERT_OK(internal::CheckMessagesEquivalent("Expression", *serialized, *expected));  // compares messages, not raw bytes
+}
+
+TEST(Substrait, FieldRefsInExpressions) {
+ ASSERT_OK_AND_ASSIGN(auto expr,
+ compute::call("struct_field",
+ {compute::call("if_else",
+ {
+ compute::literal(true),
+ compute::field_ref("struct"),
+ compute::field_ref("struct"),
+ })},
+ compute::StructFieldOptions({0}))
+ .Bind(*kBoringSchema));
+
+ ExtensionSet ext_set;
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
+ ASSERT_OK_AND_ASSIGN(auto expected, internal::SubstraitFromJSON("Expression", R"({
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "expression": {
+ "if_then": {
+ "ifs": [
+ {
+ "if": {"literal": {"boolean": true}},
+ "then": {"selection": {"directReference": {"structField": {"field": 12}}}}
+ }
+ ],
+ "else": {"selection": {"directReference": {"structField": {"field": 12}}}}
+ }
+ }
+ }
+ })"));
+ ASSERT_OK(internal::CheckMessagesEquivalent("Expression", *serialized, *expected));
+}
+
+TEST(Substrait, CallSpecialCaseRoundTrip) {
+ for (compute::Expression expr : {
+ compute::call("if_else",
+ {
+ compute::literal(true),
+ compute::field_ref({"struct", 1}),
+ compute::field_ref("str"),
+ }),
+
+ compute::call(
+ "case_when",
+ {
+ compute::call("make_struct",
+ {compute::literal(false), compute::literal(true)},
+ compute::MakeStructOptions({"cond1", "cond2"})),
+ compute::field_ref({"struct", "str"}),
+ compute::field_ref({"struct", "struct_i32_str", "str"}),
+ compute::field_ref("str"),
+ }),
+
+ compute::call("list_element",
+ {
+ compute::field_ref("list_i32"),
+ compute::literal(3),
+ }),
+
+ compute::call("struct_field",
+ {compute::call("list_element",
+ {
+ compute::field_ref("list_struct"),
+ compute::literal(42),
+ })},
+ arrow::compute::StructFieldOptions({1})),
+
+ compute::call("struct_field",
+ {compute::call("list_element",
+ {
+ compute::field_ref("list_struct"),
+ compute::literal(42),
+ })},
+ arrow::compute::StructFieldOptions({2, 0})),
+
+ compute::call("struct_field",
+ {compute::call("if_else",
+ {
+ compute::literal(true),
+ compute::field_ref("struct"),
+ compute::field_ref("struct"),
+ })},
+ compute::StructFieldOptions({0})),
+ }) {
+ ARROW_SCOPED_TRACE(expr.ToString());
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
+
+ ExtensionSet ext_set;
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
+
+ // These are special cased as core expressions in substrait; shouldn't require any
+ // extensions.
+ EXPECT_EQ(ext_set.num_functions(), 0);
+
+ ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set));
+ ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*kBoringSchema));
+ EXPECT_EQ(UseBoringRefs(roundtripped), UseBoringRefs(expr));
+ }
+}
+
+TEST(Substrait, CallExtensionFunction) {
+ for (compute::Expression expr : {
+ compute::call("add", {compute::literal(0), compute::literal(1)}),
+ }) {
+ ARROW_SCOPED_TRACE(expr.ToString());
+ ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
+
+ ExtensionSet ext_set;
+ ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(expr, &ext_set));
+
+ // These require an extension, so we should have a single-element ext_set.
+ EXPECT_EQ(ext_set.num_functions(), 1);
+
+ ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set));
+ ASSERT_OK_AND_ASSIGN(roundtripped, roundtripped.Bind(*kBoringSchema));
+ EXPECT_EQ(UseBoringRefs(roundtripped), UseBoringRefs(expr));
+ }
+}
+
+TEST(Substrait, ReadRel) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Rel", R"({
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [ {"i64": {}}, {"bool": {}} ]
+ },
+ "names": ["i", "b"]
+ },
+ "filter": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 1
+ }
+ }
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat1.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ },
+ {
+ "uri_file": "file:///tmp/dat2.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ })"));
+ ExtensionSet ext_set;
+ ASSERT_OK_AND_ASSIGN(auto rel, DeserializeRelation(*buf, ext_set));
+
+ // converting a ReadRel produces a scan Declaration
+ ASSERT_EQ(rel.factory_name, "scan");
+ const auto& scan_node_options =
+ checked_cast<const dataset::ScanNodeOptions&>(*rel.options);
+
+ // filter on the boolean field (#1)
+ EXPECT_EQ(scan_node_options.scan_options->filter, compute::field_ref(1));
+
+ // dataset is a FileSystemDataset in parquet format with the specified schema
+ ASSERT_EQ(scan_node_options.dataset->type_name(), "filesystem");
+ const auto& dataset =
+ checked_cast<const dataset::FileSystemDataset&>(*scan_node_options.dataset);
+ EXPECT_THAT(dataset.files(), ElementsAre("/tmp/dat1.parquet", "/tmp/dat2.parquet"));
+ EXPECT_EQ(dataset.format()->type_name(), "parquet");
+ EXPECT_EQ(*dataset.schema(), Schema({field("i", int64()), field("b", boolean())}));
+}
+
+TEST(Substrait, ExtensionSetFromPlan) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [
+ {"rel": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [ {"i64": {}}, {"bool": {}} ]
+ },
+ "names": ["i", "b"]
+ },
+ "local_files": { "items": [] }
+ }
+ }}
+ ],
+ "extension_uris": [
+ {
+ "extension_uri_anchor": 7,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }
+ ],
+ "extensions": [
+ {"extension_type": {
+ "extension_uri_reference": 7,
+ "type_anchor": 42,
+ "name": "null"
+ }},
+ {"extension_type_variation": {
+ "extension_uri_reference": 7,
+ "type_variation_anchor": 23,
+ "name": "u8"
+ }},
+ {"extension_function": {
+ "extension_uri_reference": 7,
+ "function_anchor": 42,
+ "name": "add"
+ }}
+ ]
+ })"));
+
+ ExtensionSet ext_set;
+ ASSERT_OK_AND_ASSIGN(
+ auto sink_decls,
+ DeserializePlan(
+ *buf, [] { return std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+ &ext_set));
+
+ EXPECT_OK_AND_ASSIGN(auto decoded_null_type, ext_set.DecodeType(42));
+ EXPECT_EQ(decoded_null_type.id.uri, kArrowExtTypesUri);
+ EXPECT_EQ(decoded_null_type.id.name, "null");
+ EXPECT_EQ(*decoded_null_type.type, NullType());
+ EXPECT_FALSE(decoded_null_type.is_variation);
+
+ EXPECT_OK_AND_ASSIGN(auto decoded_uint8_type, ext_set.DecodeType(23));
+ EXPECT_EQ(decoded_uint8_type.id.uri, kArrowExtTypesUri);
+ EXPECT_EQ(decoded_uint8_type.id.name, "u8");
+ EXPECT_EQ(*decoded_uint8_type.type, UInt8Type());
+ EXPECT_TRUE(decoded_uint8_type.is_variation);
+
+ EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set.DecodeFunction(42));
+ EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri);
+ EXPECT_EQ(decoded_add_func.id.name, "add");
+ EXPECT_EQ(decoded_add_func.name, "add");
+}
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc
new file mode 100644
index 0000000..49ca1bb
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/type_internal.cc
@@ -0,0 +1,494 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/type_internal.h"
+
+#include <string>
+#include <vector>
+
+#include "arrow/engine/substrait/extension_types.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/visit_type_inline.h"
+
+namespace arrow {
+namespace engine {
+
+namespace internal {
+using ::arrow::internal::make_unique;
+} // namespace internal
+
+namespace {
+
+template <typename TypeMessage>
+Status CheckVariation(const TypeMessage& type) {
+ if (type.type_variation_reference() == 0) return Status::OK();
+ return Status::NotImplemented("Type variations for ", type.DebugString());
+}
+
+ template <typename TypeMessage>
+ bool IsNullable(const TypeMessage& type) {
+ // FIXME: NULLABILITY_UNSPECIFIED currently falls through and is treated as nullable
+ return type.nullability() != substrait::Type::NULLABILITY_REQUIRED;
+ }
+
+template <typename ArrowType, typename TypeMessage, typename... A>
+Result<std::pair<std::shared_ptr<DataType>, bool>> FromProtoImpl(const TypeMessage& type,
+ A&&... args) {
+ RETURN_NOT_OK(CheckVariation(type));
+
+ return std::make_pair(std::static_pointer_cast<DataType>(
+ std::make_shared<ArrowType>(std::forward<A>(args)...)),
+ IsNullable(type));
+}
+
+template <typename TypeMessage, typename... A>
+Result<std::pair<std::shared_ptr<DataType>, bool>> FromProtoImpl(
+ const TypeMessage& type, std::shared_ptr<DataType> type_factory(A...), A&&... args) {
+ RETURN_NOT_OK(CheckVariation(type));
+
+ return std::make_pair(
+ std::static_pointer_cast<DataType>(type_factory(std::forward<A>(args)...)),
+ IsNullable(type));
+}
+
+ // Builds `size` fields from Substrait types; each name is pulled from next_name().
+ template <typename Types, typename NextName>
+ Result<FieldVector> FieldsFromProto(int size, const Types& types,
+ const NextName& next_name, const ExtensionSet& ext_set) {
+ FieldVector fields(size);
+ for (int i = 0; i < size; ++i) {
+ std::string name = next_name();
+ std::shared_ptr<DataType> type;
+ bool nullable;
+ // Structs recurse here (not via FromProto) so nested fields also consume names.
+ if (types[i].has_struct_()) {
+ const auto& struct_ = types[i].struct_();
+ RETURN_NOT_OK(CheckVariation(struct_));  // reject variations like every other path
+ ARROW_ASSIGN_OR_RAISE(
+ type, FieldsFromProto(struct_.types_size(), struct_.types(), next_name, ext_set)
+ .Map(arrow::struct_));
+
+ nullable = IsNullable(struct_);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(std::tie(type, nullable), FromProto(types[i], ext_set));
+ }
+
+ fields[i] = field(std::move(name), std::move(type), nullable);
+ }
+ return fields;
+ }
+
+} // namespace
+
+Result<std::pair<std::shared_ptr<DataType>, bool>> FromProto(
+ const substrait::Type& type, const ExtensionSet& ext_set) {
+ switch (type.kind_case()) {
+ case substrait::Type::kBool:
+ return FromProtoImpl<BooleanType>(type.bool_());
+
+ case substrait::Type::kI8:
+ return FromProtoImpl<Int8Type>(type.i8());
+ case substrait::Type::kI16:
+ return FromProtoImpl<Int16Type>(type.i16());
+ case substrait::Type::kI32:
+ return FromProtoImpl<Int32Type>(type.i32());
+ case substrait::Type::kI64:
+ return FromProtoImpl<Int64Type>(type.i64());
+
+ case substrait::Type::kFp32:
+ return FromProtoImpl<FloatType>(type.fp32());
+ case substrait::Type::kFp64:
+ return FromProtoImpl<DoubleType>(type.fp64());
+
+ case substrait::Type::kString:
+ return FromProtoImpl<StringType>(type.string());
+ case substrait::Type::kBinary:
+ return FromProtoImpl<BinaryType>(type.binary());
+
+ case substrait::Type::kTimestamp:
+ return FromProtoImpl<TimestampType>(type.timestamp(), TimeUnit::MICRO);
+ case substrait::Type::kTimestampTz:
+ return FromProtoImpl<TimestampType>(type.timestamp_tz(), TimeUnit::MICRO,
+ TimestampTzTimezoneString());
+ case substrait::Type::kDate:
+ return FromProtoImpl<Date32Type>(type.date());
+
+ case substrait::Type::kTime:
+ return FromProtoImpl<Time64Type>(type.time(), TimeUnit::MICRO);
+
+ case substrait::Type::kIntervalYear:
+ return FromProtoImpl(type.interval_year(), interval_year);
+
+ case substrait::Type::kIntervalDay:
+ return FromProtoImpl(type.interval_day(), interval_day);
+
+ case substrait::Type::kUuid:
+ return FromProtoImpl(type.uuid(), uuid);
+
+ case substrait::Type::kFixedChar:
+ return FromProtoImpl(type.fixed_char(), fixed_char, type.fixed_char().length());
+
+ case substrait::Type::kVarchar:
+ return FromProtoImpl(type.varchar(), varchar, type.varchar().length());
+
+ case substrait::Type::kFixedBinary:
+ return FromProtoImpl<FixedSizeBinaryType>(type.fixed_binary(),
+ type.fixed_binary().length());
+
+ case substrait::Type::kDecimal: {
+ const auto& decimal = type.decimal();
+ return FromProtoImpl<Decimal128Type>(decimal, decimal.precision(), decimal.scale());
+ }
+
+ case substrait::Type::kStruct: {
+ const auto& struct_ = type.struct_();
+
+ ARROW_ASSIGN_OR_RAISE(auto fields, FieldsFromProto(
+ struct_.types_size(), struct_.types(),
+ /*next_name=*/[] { return ""; }, ext_set));
+
+ return FromProtoImpl<StructType>(struct_, std::move(fields));
+ }
+
+ case substrait::Type::kList: {
+ const auto& list = type.list();
+
+ if (!list.has_type()) {
+ return Status::Invalid(
+ "While converting to ListType encountered a missing item type in ",
+ list.DebugString());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto type_nullable, FromProto(list.type(), ext_set));
+ return FromProtoImpl<ListType>(
+ list, field("item", std::move(type_nullable.first), type_nullable.second));
+ }
+
+ case substrait::Type::kMap: {
+ const auto& map = type.map();
+
+ static const std::array<char const*, 4> kMissing = {"key and value", "value", "key",
+ nullptr};
+ if (auto missing = kMissing[map.has_key() + map.has_value() * 2]) {
+ return Status::Invalid("While converting to MapType encountered missing ",
+ missing, " type in ", map.DebugString());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto key_nullable, FromProto(map.key(), ext_set));
+ ARROW_ASSIGN_OR_RAISE(auto value_nullable, FromProto(map.value(), ext_set));
+
+ if (key_nullable.second) {
+ return Status::Invalid(
+ "While converting to MapType encountered nullable key field in ",
+ map.DebugString());
+ }
+
+ return FromProtoImpl<MapType>(
+ map, std::move(key_nullable.first),
+ field("value", std::move(value_nullable.first), value_nullable.second));
+ }
+
+ case substrait::Type::kUserDefinedTypeReference: {
+ uint32_t anchor = type.user_defined_type_reference();
+ ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor));
+ return std::make_pair(std::move(type_record.type), true);
+ }
+
+ default:
+ break;
+ }
+
+ return Status::NotImplemented("conversion to arrow::DataType from Substrait type ",
+ type.DebugString());
+}
+
+namespace {
+
+struct DataTypeToProtoImpl {
+ Status Visit(const NullType& t) { return EncodeUserDefined(t); }
+
+ Status Visit(const BooleanType& t) {
+ return SetWith(&substrait::Type::set_allocated_bool_);
+ }
+
+ Status Visit(const Int8Type& t) { return SetWith(&substrait::Type::set_allocated_i8); }
+ Status Visit(const Int16Type& t) {
+ return SetWith(&substrait::Type::set_allocated_i16);
+ }
+ Status Visit(const Int32Type& t) {
+ return SetWith(&substrait::Type::set_allocated_i32);
+ }
+ Status Visit(const Int64Type& t) {
+ return SetWith(&substrait::Type::set_allocated_i64);
+ }
+
+ Status Visit(const UInt8Type& t) { return EncodeUserDefined(t); }
+ Status Visit(const UInt16Type& t) { return EncodeUserDefined(t); }
+ Status Visit(const UInt32Type& t) { return EncodeUserDefined(t); }
+ Status Visit(const UInt64Type& t) { return EncodeUserDefined(t); }
+
+ Status Visit(const HalfFloatType& t) { return EncodeUserDefined(t); }
+ Status Visit(const FloatType& t) {
+ return SetWith(&substrait::Type::set_allocated_fp32);
+ }
+ Status Visit(const DoubleType& t) {
+ return SetWith(&substrait::Type::set_allocated_fp64);
+ }
+
+ Status Visit(const StringType& t) {
+ return SetWith(&substrait::Type::set_allocated_string);
+ }
+ Status Visit(const BinaryType& t) {
+ return SetWith(&substrait::Type::set_allocated_binary);
+ }
+
+ Status Visit(const FixedSizeBinaryType& t) {
+ SetWithThen(&substrait::Type::set_allocated_fixed_binary)->set_length(t.byte_width());
+ return Status::OK();
+ }
+
+ Status Visit(const Date32Type& t) {
+ return SetWith(&substrait::Type::set_allocated_date);
+ }
+ Status Visit(const Date64Type& t) { return NotImplemented(t); }
+
+ Status Visit(const TimestampType& t) {
+ if (t.unit() != TimeUnit::MICRO) return NotImplemented(t);
+
+ if (t.timezone() == "") {
+ return SetWith(&substrait::Type::set_allocated_timestamp);
+ }
+ if (t.timezone() == TimestampTzTimezoneString()) {
+ return SetWith(&substrait::Type::set_allocated_timestamp_tz);
+ }
+
+ return NotImplemented(t);
+ }
+
+ Status Visit(const Time32Type& t) { return NotImplemented(t); }
+ Status Visit(const Time64Type& t) {
+ if (t.unit() != TimeUnit::MICRO) return NotImplemented(t);
+ return SetWith(&substrait::Type::set_allocated_time);
+ }
+
+ Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); }
+ Status Visit(const DayTimeIntervalType& t) { return EncodeUserDefined(t); }
+
+ Status Visit(const Decimal128Type& t) {
+ auto dec = SetWithThen(&substrait::Type::set_allocated_decimal);
+ dec->set_precision(t.precision());
+ dec->set_scale(t.scale());
+ return Status::OK();
+ }
+ Status Visit(const Decimal256Type& t) { return NotImplemented(t); }
+
+ Status Visit(const ListType& t) {
+ // FIXME assert default field name; custom ones won't roundtrip
+ ARROW_ASSIGN_OR_RAISE(
+ auto type, ToProto(*t.value_type(), t.value_field()->nullable(), ext_set_));
+ SetWithThen(&substrait::Type::set_allocated_list)->set_allocated_type(type.release());
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& t) {
+ auto types = SetWithThen(&substrait::Type::set_allocated_struct_)->mutable_types();
+
+ types->Reserve(t.num_fields());
+
+ for (const auto& field : t.fields()) {
+ if (field->metadata() != nullptr) {
+ return Status::Invalid("substrait::Type::Struct does not support field metadata");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto type,
+ ToProto(*field->type(), field->nullable(), ext_set_));
+ types->AddAllocated(type.release());
+ }
+ return Status::OK();
+ }
+
+ Status Visit(const SparseUnionType& t) { return NotImplemented(t); }
+ Status Visit(const DenseUnionType& t) { return NotImplemented(t); }
+ Status Visit(const DictionaryType& t) { return NotImplemented(t); }
+
+ Status Visit(const MapType& t) {
+ // FIXME assert default field names; custom ones won't roundtrip
+ auto map = SetWithThen(&substrait::Type::set_allocated_map);
+
+ ARROW_ASSIGN_OR_RAISE(auto key, ToProto(*t.key_type(), /*nullable=*/false, ext_set_));
+ map->set_allocated_key(key.release());
+
+ ARROW_ASSIGN_OR_RAISE(auto value,
+ ToProto(*t.item_type(), t.item_field()->nullable(), ext_set_));
+ map->set_allocated_value(value.release());
+
+ return Status::OK();
+ }
+
+ Status Visit(const ExtensionType& t) {
+ if (UnwrapUuid(t)) {
+ return SetWith(&substrait::Type::set_allocated_uuid);
+ }
+
+ if (auto length = UnwrapFixedChar(t)) {
+ SetWithThen(&substrait::Type::set_allocated_fixed_char)->set_length(*length);
+ return Status::OK();
+ }
+
+ if (auto length = UnwrapVarChar(t)) {
+ SetWithThen(&substrait::Type::set_allocated_varchar)->set_length(*length);
+ return Status::OK();
+ }
+
+ if (UnwrapIntervalYear(t)) {
+ return SetWith(&substrait::Type::set_allocated_interval_year);
+ }
+
+ if (UnwrapIntervalDay(t)) {
+ return SetWith(&substrait::Type::set_allocated_interval_day);
+ }
+
+ return NotImplemented(t);
+ }
+
+ Status Visit(const FixedSizeListType& t) { return NotImplemented(t); }
+ Status Visit(const DurationType& t) { return NotImplemented(t); }
+ Status Visit(const LargeStringType& t) { return NotImplemented(t); }
+ Status Visit(const LargeBinaryType& t) { return NotImplemented(t); }
+ Status Visit(const LargeListType& t) { return NotImplemented(t); }
+ Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); }
+
+ template <typename Sub>
+ Sub* SetWithThen(void (substrait::Type::*set_allocated_sub)(Sub*)) {
+ auto sub = internal::make_unique<Sub>();
+ sub->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE
+ : substrait::Type::NULLABILITY_REQUIRED);
+
+ auto out = sub.get();
+ (type_->*set_allocated_sub)(sub.release());
+ return out;
+ }
+
+ template <typename Sub>
+ Status SetWith(void (substrait::Type::*set_allocated_sub)(Sub*)) {
+ return SetWithThen(set_allocated_sub), Status::OK();
+ }
+
+ template <typename T>
+ Status EncodeUserDefined(const T& t) {
+ ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(t));
+ type_->set_user_defined_type_reference(anchor);
+ return Status::OK();
+ }
+
+ Status NotImplemented(const DataType& t) {
+ return Status::NotImplemented("conversion to substrait::Type from ", t.ToString());
+ }
+
+ Status operator()(const DataType& type) { return VisitTypeInline(type, this); }
+
+ substrait::Type* type_;
+ bool nullable_;
+ ExtensionSet* ext_set_;
+};
+} // namespace
+
+Result<std::unique_ptr<substrait::Type>> ToProto(const DataType& type, bool nullable,
+ ExtensionSet* ext_set) {
+ auto out = internal::make_unique<substrait::Type>();
+ RETURN_NOT_OK((DataTypeToProtoImpl{out.get(), nullable, ext_set})(type));
+ return std::move(out);
+}
+
+Result<std::shared_ptr<Schema>> FromProto(const substrait::NamedStruct& named_struct,
+ const ExtensionSet& ext_set) {
+ if (!named_struct.has_struct_()) {
+ return Status::Invalid("While converting ", named_struct.DebugString(),
+ " no anonymous struct type was provided to which names "
+ "could be attached.");
+ }
+ const auto& struct_ = named_struct.struct_();
+ RETURN_NOT_OK(CheckVariation(struct_));
+
+ int requested_names_count = 0;
+ ARROW_ASSIGN_OR_RAISE(auto fields, FieldsFromProto(
+ struct_.types_size(), struct_.types(),
+ /*next_name=*/
+ [&] {
+ int i = requested_names_count++;
+ return i < named_struct.names_size()
+ ? named_struct.names().Get(i)
+ : "";
+ },
+ ext_set));
+
+ if (requested_names_count != named_struct.names_size()) {
+ return Status::Invalid("While converting ", named_struct.DebugString(), " received ",
+ named_struct.names_size(), " names but ",
+ requested_names_count, " struct fields");
+ }
+
+ return schema(std::move(fields));
+}
+
+ namespace {
+ // Appends each field's name, then recurses into struct-typed fields (parent first).
+ void ToProtoGetDepthFirstNames(const FieldVector& fields,
+ google::protobuf::RepeatedPtrField<std::string>* names) {
+ for (const auto& child : fields) {
+ *names->Add() = child->name();
+ const auto& child_type = child->type();
+ if (child_type->id() != Type::STRUCT) continue;
+ ToProtoGetDepthFirstNames(child_type->fields(), names);
+ }
+ }
+ } // namespace
+
+Result<std::unique_ptr<substrait::NamedStruct>> ToProto(const Schema& schema,
+ ExtensionSet* ext_set) {
+ if (schema.metadata()) {
+ return Status::Invalid("substrait::NamedStruct does not support schema metadata");
+ }
+
+ auto named_struct = internal::make_unique<substrait::NamedStruct>();
+
+ auto names = named_struct->mutable_names();
+ names->Reserve(schema.num_fields());
+ ToProtoGetDepthFirstNames(schema.fields(), names);
+
+ auto struct_ = internal::make_unique<substrait::Type::Struct>();
+ auto types = struct_->mutable_types();
+ types->Reserve(schema.num_fields());
+
+ for (const auto& field : schema.fields()) {
+ if (field->metadata() != nullptr) {
+ return Status::Invalid("substrait::NamedStruct does not support field metadata");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(), field->nullable(), ext_set));
+ types->AddAllocated(type.release());
+ }
+
+ named_struct->set_allocated_struct_(struct_.release());
+ return std::move(named_struct);
+}
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/type_internal.h b/cpp/src/arrow/engine/substrait/type_internal.h
new file mode 100644
index 0000000..058019c
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/type_internal.h
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#include <utility>
+
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/visibility.h"
+#include "arrow/type_fwd.h"
+
+#include "substrait/type.pb.h" // IWYU pragma: export
+
+namespace arrow {
+namespace engine {
+
+ARROW_ENGINE_EXPORT
+Result<std::pair<std::shared_ptr<DataType>, bool>> FromProto(const substrait::Type&,
+ const ExtensionSet&);
+
+ARROW_ENGINE_EXPORT
+Result<std::unique_ptr<substrait::Type>> ToProto(const DataType&, bool nullable,
+ ExtensionSet*);
+
+ARROW_ENGINE_EXPORT
+Result<std::shared_ptr<Schema>> FromProto(const substrait::NamedStruct&,
+ const ExtensionSet&);
+
+ARROW_ENGINE_EXPORT
+Result<std::unique_ptr<substrait::NamedStruct>> ToProto(const Schema&, ExtensionSet*);
+
+inline std::string TimestampTzTimezoneString() { return "UTC"; }
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/visibility.h b/cpp/src/arrow/engine/visibility.h
new file mode 100644
index 0000000..5b1651f
--- /dev/null
+++ b/cpp/src/arrow/engine/visibility.h
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This API is EXPERIMENTAL.
+
+#pragma once
+
+#if defined(_WIN32) || defined(__CYGWIN__)
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4251)
+#else
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif
+
+#ifdef ARROW_ENGINE_STATIC
+#define ARROW_ENGINE_EXPORT
+#elif defined(ARROW_ENGINE_EXPORTING)
+#define ARROW_ENGINE_EXPORT __declspec(dllexport)
+#else
+#define ARROW_ENGINE_EXPORT __declspec(dllimport)
+#endif
+
+#define ARROW_ENGINE_NO_EXPORT
+#else // Not Windows
+#ifndef ARROW_ENGINE_EXPORT
+#define ARROW_ENGINE_EXPORT __attribute__((visibility("default")))
+#endif
+#ifndef ARROW_ENGINE_NO_EXPORT
+#define ARROW_ENGINE_NO_EXPORT __attribute__((visibility("hidden")))
+#endif
+#endif // Non-Windows
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 1038497..5861d84 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -59,7 +59,7 @@ endif()
# TODO(wesm): Protobuf shared vs static linking
set(FLIGHT_PROTO_PATH "${ARROW_SOURCE_DIR}/../format")
-set(FLIGHT_PROTO ${ARROW_SOURCE_DIR}/../format/Flight.proto)
+set(FLIGHT_PROTO "${ARROW_SOURCE_DIR}/../format/Flight.proto")
set(FLIGHT_GENERATED_PROTO_FILES
"${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.cc" "${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.h"
@@ -163,9 +163,9 @@ endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}")
# Note, we do not compile the generated Protobuf sources directly, instead
-# compiling then via protocol_internal.cc which contains some gRPC template
+# compiling them via protocol_internal.cc which contains some gRPC template
# overrides to enable Flight-specific optimizations. See comments in
-# protobuf-internal.cc
+# protocol_internal.cc
set(ARROW_FLIGHT_SRCS
client.cc
client_cookie_middleware.cc
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index 23c07e6..23c463b 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -475,9 +475,15 @@ Status Scalar::ValidateFull() const {
return ScalarValidateImpl(/*full_validation=*/true).Validate(*this);
}
+BinaryScalar::BinaryScalar(std::string s)
+ : BinaryScalar(Buffer::FromString(std::move(s))) {}
+
StringScalar::StringScalar(std::string s)
: StringScalar(Buffer::FromString(std::move(s))) {}
+LargeBinaryScalar::LargeBinaryScalar(std::string s)
+ : LargeBinaryScalar(Buffer::FromString(std::move(s))) {}
+
LargeStringScalar::LargeStringScalar(std::string s)
: LargeStringScalar(Buffer::FromString(std::move(s))) {}
@@ -488,6 +494,12 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr<Buffer> value,
this->value->size());
}
+FixedSizeBinaryScalar::FixedSizeBinaryScalar(const std::shared_ptr<Buffer>& value)
+ : BinaryScalar(value, fixed_size_binary(static_cast<int>(value->size()))) {}
+
+FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s)
+ : FixedSizeBinaryScalar(Buffer::FromString(std::move(s))) {}
+
BaseListScalar::BaseListScalar(std::shared_ptr<Array> value,
std::shared_ptr<DataType> type)
: Scalar{std::move(type), true}, value(std::move(value)) {
diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h
index 9df3e3c..943a642 100644
--- a/cpp/src/arrow/scalar.h
+++ b/cpp/src/arrow/scalar.h
@@ -250,6 +250,8 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar {
explicit BinaryScalar(std::shared_ptr<Buffer> value)
: BinaryScalar(std::move(value), binary()) {}
+ explicit BinaryScalar(std::string s);
+
BinaryScalar() : BinaryScalar(binary()) {}
};
@@ -275,6 +277,8 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar {
explicit LargeBinaryScalar(std::shared_ptr<Buffer> value)
: LargeBinaryScalar(std::move(value), large_binary()) {}
+ explicit LargeBinaryScalar(std::string s);
+
LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {}
};
@@ -295,7 +299,12 @@ struct ARROW_EXPORT FixedSizeBinaryScalar : public BinaryScalar {
FixedSizeBinaryScalar(std::shared_ptr<Buffer> value, std::shared_ptr<DataType> type);
- explicit FixedSizeBinaryScalar(std::shared_ptr<DataType> type) : BinaryScalar(type) {}
+ explicit FixedSizeBinaryScalar(const std::shared_ptr<Buffer>& value);
+
+ explicit FixedSizeBinaryScalar(std::string s);
+
+ explicit FixedSizeBinaryScalar(std::shared_ptr<DataType> type)
+ : BinaryScalar(std::move(type)) {}
};
template <typename T>
@@ -345,8 +354,8 @@ struct ARROW_EXPORT TimestampScalar : public TemporalScalar<TimestampType> {
using TemporalScalar<TimestampType>::TemporalScalar;
TimestampScalar(typename TemporalScalar<TimestampType>::ValueType value,
- TimeUnit::type unit)
- : TimestampScalar(std::move(value), timestamp(unit)) {}
+ TimeUnit::type unit, std::string tz = "")
+ : TimestampScalar(std::move(value), timestamp(unit, std::move(tz))) {}
};
template <typename T>
@@ -533,6 +542,11 @@ struct ARROW_EXPORT ExtensionScalar : public Scalar {
ExtensionScalar(std::shared_ptr<Scalar> storage, std::shared_ptr<DataType> type)
: Scalar(std::move(type), true), value(std::move(storage)) {}
+ template <typename Storage,
+ typename = enable_if_t<std::is_base_of<Scalar, Storage>::value>>
+ ExtensionScalar(Storage&& storage, std::shared_ptr<DataType> type)
+ : ExtensionScalar(std::make_shared<Storage>(std::move(storage)), std::move(type)) {}
+
std::shared_ptr<Scalar> value;
};
diff --git a/cpp/src/arrow/status_test.cc b/cpp/src/arrow/status_test.cc
index 10a79d9..a8e1d1c 100644
--- a/cpp/src/arrow/status_test.cc
+++ b/cpp/src/arrow/status_test.cc
@@ -179,20 +179,19 @@ TEST(StatusTest, MatcherExplanations) {
{
testing::StringMatchResultListener listener;
EXPECT_TRUE(matcher.MatchAndExplain(Status::Invalid("XXX"), &listener));
- EXPECT_THAT(listener.str(), testing::StrEq("whose value \"Invalid: XXX\" matches"));
+ EXPECT_THAT(listener.str(), testing::StrEq("whose error matches"));
}
{
testing::StringMatchResultListener listener;
EXPECT_FALSE(matcher.MatchAndExplain(Status::OK(), &listener));
- EXPECT_THAT(listener.str(), testing::StrEq("whose value \"OK\" doesn't match"));
+ EXPECT_THAT(listener.str(), testing::StrEq("whose non-error doesn't match"));
}
{
testing::StringMatchResultListener listener;
EXPECT_FALSE(matcher.MatchAndExplain(Status::TypeError("XXX"), &listener));
- EXPECT_THAT(listener.str(),
- testing::StrEq("whose value \"Type error: XXX\" doesn't match"));
+ EXPECT_THAT(listener.str(), testing::StrEq("whose error doesn't match"));
}
}
diff --git a/cpp/src/arrow/testing/matchers.h b/cpp/src/arrow/testing/matchers.h
index ddfe60f..be88c3f 100644
--- a/cpp/src/arrow/testing/matchers.h
+++ b/cpp/src/arrow/testing/matchers.h
@@ -24,9 +24,11 @@
#include "arrow/datum.h"
#include "arrow/result.h"
#include "arrow/status.h"
+#include "arrow/stl_iterator.h"
#include "arrow/testing/future_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/future.h"
+#include "arrow/util/unreachable.h"
namespace arrow {
@@ -196,8 +198,14 @@ class ErrorMatcher {
message_matcher_->MatchAndExplain(status.message(), &value_listener);
}
- *listener << "whose value " << testing::PrintToString(status.ToString())
- << (match ? " matches" : " doesn't match");
+ if (match) {
+ *listener << "whose error matches";
+ } else if (status.ok()) {
+ *listener << "whose non-error doesn't match";
+ } else {
+ *listener << "whose error doesn't match";
+ }
+
testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream());
return match;
}
@@ -228,8 +236,7 @@ class OkMatcher {
const Status& status = internal::GenericToStatus(maybe_value);
const bool match = status.ok();
- *listener << "whose value " << testing::PrintToString(status.ToString())
- << (match ? " matches" : " doesn't match");
+ *listener << "whose " << (match ? "non-error matches" : "error doesn't match");
return match;
}
};
@@ -268,6 +275,9 @@ ErrorMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) {
class DataEqMatcher {
public:
+ // TODO(bkietz) support EqualOptions, ApproxEquals, etc
+ // Probably it's better to use something like config-through-key_value_metadata
+ // as with the random generators to decouple this from EqualOptions etc.
explicit DataEqMatcher(Datum expected) : expected_(std::move(expected)) {}
template <typename Data>
@@ -295,17 +305,34 @@ class DataEqMatcher {
return false;
}
- if (*boxed.type() != *expected_.type()) {
- *listener << "whose DataType " << boxed.type()->ToString() << " doesn't match "
- << expected_.type()->ToString();
- return false;
+ if (const auto& boxed_type = boxed.type()) {
+ if (*boxed_type != *expected_.type()) {
+ *listener << "whose DataType " << boxed_type->ToString() << " doesn't match "
+ << expected_.type()->ToString();
+ return false;
+ }
+ } else if (const auto& boxed_schema = boxed.schema()) {
+ if (*boxed_schema != *expected_.schema()) {
+ *listener << "whose Schema " << boxed_schema->ToString() << " doesn't match "
+ << expected_.schema()->ToString();
+ return false;
+ }
+ } else {
+ Unreachable();
}
- const bool match = boxed == expected_;
- *listener << "whose value ";
- PrintTo(boxed, listener->stream());
- *listener << (match ? " matches" : " doesn't match");
- return match;
+ if (boxed == expected_) {
+ *listener << "whose value matches";
+ return true;
+ }
+
+ if (listener->IsInterested() && boxed.kind() == Datum::ARRAY) {
+ *listener << "whose value differs from the expected value by "
+ << boxed.make_array()->Diff(*expected_.make_array());
+ } else {
+ *listener << "whose value doesn't match";
+ }
+ return false;
}
Datum expected_;
@@ -318,9 +345,66 @@ class DataEqMatcher {
Datum expected_;
};
+/// Constructs a datum against which arguments are matched
template <typename Data>
DataEqMatcher DataEq(Data&& dat) {
return DataEqMatcher(Datum(std::forward<Data>(dat)));
}
+/// Constructs an array with ArrayFromJSON against which arguments are matched
+inline DataEqMatcher DataEqArray(const std::shared_ptr<DataType>& type,
+ util::string_view json) {
+ return DataEq(ArrayFromJSON(type, json));
+}
+
+ /// Constructs an array from a vector of optionals against which arguments are matched
+ template <typename T, typename ArrayType = typename TypeTraits<T>::ArrayType,
+ typename BuilderType = typename TypeTraits<T>::BuilderType,
+ typename ValueType =
+ typename ::arrow::stl::detail::DefaultValueAccessor<ArrayType>::ValueType>
+ DataEqMatcher DataEqArray(T type, const std::vector<util::optional<ValueType>>& values) {
+ // FIXME(bkietz) broken until DataType is move constructible
+ BuilderType builder(std::make_shared<T>(std::move(type)), default_memory_pool());
+ DCHECK_OK(builder.Reserve(static_cast<int64_t>(values.size())));
+
+ // pseudo constexpr: Reserve() above is the only reservation, so non-fixed-width
+ static const bool need_safe_append = !is_fixed_width(T::type_id);
+ // values must go through the checked Append(); fixed-width ones may skip the check.
+ for (const auto& value : values) {
+ if (value) {
+ if (need_safe_append) {
+ DCHECK_OK(builder.Append(*value));
+ } else {
+ builder.UnsafeAppend(*value);
+ }
+ } else {
+ builder.UnsafeAppendNull();
+ }
+ }
+
+ return DataEq(builder.Finish().ValueOrDie());
+ }
+
+/// Constructs a scalar with ScalarFromJSON against which arguments are matched
+inline DataEqMatcher DataEqScalar(const std::shared_ptr<DataType>& type,
+ util::string_view json) {
+ return DataEq(ScalarFromJSON(type, json));
+}
+
+/// Constructs a scalar against which arguments are matched
+template <typename T, typename ScalarType = typename TypeTraits<T>::ScalarType,
+ typename ValueType = typename ScalarType::ValueType>
+DataEqMatcher DataEqScalar(T type, util::optional<ValueType> value) {
+ ScalarType expected(std::make_shared<T>(std::move(type)));
+
+ if (value) {
+ expected.is_valid = true;
+ expected.value = std::move(*value);
+ }
+
+ return DataEq(std::move(expected));
+}
+
+ // TODO: add HasType and HasSchema matchers
+
} // namespace arrow
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 881e597..c05e20f 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -82,8 +82,8 @@ class ARROW_EXPORT Fingerprintable {
virtual std::string ComputeFingerprint() const = 0;
virtual std::string ComputeMetadataFingerprint() const = 0;
- mutable std::atomic<std::string*> fingerprint_;
- mutable std::atomic<std::string*> metadata_fingerprint_;
+ mutable std::atomic<std::string*> fingerprint_{NULLPTR};
+ mutable std::atomic<std::string*> metadata_fingerprint_{NULLPTR};
};
} // namespace detail
@@ -819,7 +819,7 @@ class ARROW_EXPORT Decimal256Type : public DecimalType {
class ARROW_EXPORT BaseListType : public NestedType {
public:
using NestedType::NestedType;
- std::shared_ptr<Field> value_field() const { return children_[0]; }
+ const std::shared_ptr<Field>& value_field() const { return children_[0]; }
std::shared_ptr<DataType> value_type() const { return children_[0]->type(); }
};
diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h
index 328d7e7..d2c0178 100644
--- a/cpp/src/arrow/util/hashing.h
+++ b/cpp/src/arrow/util/hashing.h
@@ -882,5 +882,14 @@ static inline Status ComputeNullBitmap(MemoryPool* pool, const MemoTableType& me
return Status::OK();
}
+struct StringViewHash {
+ // std::hash compatible hasher for util::string_view, for use with
+ // std::unordered_* containers. (The std::hash specialization provided by nonstd
+ // constructs a std::string temporary, then invokes std::hash<std::string> on it.)
+ hash_t operator()(const util::string_view& value) const {
+ return ComputeStringHash<0>(value.data(), static_cast<int64_t>(value.size()));  // passes the view's pointer and length; no std::string is built here
+ }
+};
+
} // namespace internal
} // namespace arrow
diff --git a/dev/archery/archery/cli.py b/dev/archery/archery/cli.py
index 2a2e13a..f58d99f 100644
--- a/dev/archery/archery/cli.py
+++ b/dev/archery/archery/cli.py
@@ -117,6 +117,12 @@ def _apply_options(cmd, options):
@cpp_toolchain_options
@click.option("--build-type", default=None, type=build_type,
help="CMake's CMAKE_BUILD_TYPE")
+@click.option("--build-static", default=True, type=BOOL,
+ help="Build static libraries")
+@click.option("--build-shared", default=True, type=BOOL,
+ help="Build shared libraries")
+@click.option("--build-unity", default=True, type=BOOL,
+ help="Use CMAKE_UNITY_BUILD")
@click.option("--warn-level", default="production", type=warn_level_type,
help="Controls compiler warnings -W(no-)error.")
@click.option("--use-gold-linker", default=True, type=BOOL,
diff --git a/dev/archery/archery/lang/cpp.py b/dev/archery/archery/lang/cpp.py
index 251ad54..ac3b251 100644
--- a/dev/archery/archery/lang/cpp.py
+++ b/dev/archery/archery/lang/cpp.py
@@ -42,7 +42,7 @@ class CppConfiguration:
cc=None, cxx=None, cxx_flags=None,
build_type=None, warn_level=None,
cpp_package_prefix=None, install_prefix=None, use_conda=None,
- build_static=False, build_shared=True, build_unity=True,
+ build_static=True, build_shared=True, build_unity=True,
# tests & examples
with_tests=None, with_benchmarks=None, with_examples=None,
with_integration=None,
diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt
index 0a96858..4f0deaf 100644
--- a/dev/release/rat_exclude_files.txt
+++ b/dev/release/rat_exclude_files.txt
@@ -31,6 +31,7 @@ cpp/src/generated/parquet_constants.cpp
cpp/src/generated/parquet_constants.h
cpp/src/generated/parquet_types.cpp
cpp/src/generated/parquet_types.h
+cpp/src/generated/substrait/*
cpp/src/plasma/thirdparty/ae/ae.c
cpp/src/plasma/thirdparty/ae/ae.h
cpp/src/plasma/thirdparty/ae/ae_epoll.c
diff --git a/format/substrait/extension_types.yaml b/format/substrait/extension_types.yaml
new file mode 100644
index 0000000..c905c8b
--- /dev/null
+++ b/format/substrait/extension_types.yaml
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# substrait::{ExtensionTypeVariation, ExtensionType}s
+# for wrapping types which appear in the arrow type system but
+# are not first-class in substrait. These include:
+# - null
+# - unsigned integers
+# - half-precision floating point numbers
+# - 32-bit times and 64-bit dates
+# - timestamps with units other than microseconds
+# - timestamps with timezones other than UTC
+# - 256-bit decimals
+# - sparse and dense unions
+# - dictionary encoded types
+# - durations
+# - string and binary with 64-bit offsets
+# - list with 64-bit offsets
+# - interval<months: i32>
+# - interval<days: i32, millis: i32>
+# - interval<months: i32, days: i32, nanos: i64>
+# - arrow::ExtensionTypes
+#
+# Note that not all of these are currently implemented. In particular, these
+# extension types are currently not parameterizable in Substrait, which means,
+# among other things, that we can't declare a dictionary type here at all: we
+# would have to declare a different dictionary type for every encoded type,
+# and that is an infinite space. Similarly, we would have to declare a
+# timestamp variation for every possible timezone string.
+#
+# Ultimately these declarations are a promise which needs to be backed by
+# equivalent serde in C++. This is handled by default_extension_id_registry(),
+# defined in cpp/src/arrow/engine/substrait/extension_set.cc. The two files
+# currently need to be kept in sync manually; see ARROW-15535.
+
+type_variations:  # each entry lists: parent, name, description, functions
+ - parent: i8
+ name: u8
+ description: an unsigned 8 bit integer
+ functions: SEPARATE
+ - parent: i16
+ name: u16
+ description: an unsigned 16 bit integer
+ functions: SEPARATE
+ - parent: i32
+ name: u32
+ description: an unsigned 32 bit integer
+ functions: SEPARATE
+ - parent: i64
+ name: u64
+ description: an unsigned 64 bit integer
+ functions: SEPARATE
+
+ - parent: i16  # NOTE(review): fp16 presumably hangs off i16 for its matching 16-bit width — confirm
+ name: fp16
+ description: a 16 bit floating point number
+ functions: SEPARATE
+
+types:  # each entry gives a type name and its structure as field: type pairs
+ - name: "null"  # quoted so YAML reads the string "null" rather than a null value
+ structure: {}  # empty mapping: no fields
+ - name: interval_month
+ structure:
+ months: i32
+ - name: interval_day_milli
+ structure:
+ days: i32
+ millis: i32
+ - name: interval_month_day_nano
+ structure:
+ months: i32
+ days: i32
+ nanos: i64