You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2022/08/10 19:27:07 UTC
[arrow] branch master updated: ARROW-15582: [C++] Add support for registering standard Substrait functions (#13613)
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new cdb5b2019f ARROW-15582: [C++] Add support for registering standard Substrait functions (#13613)
cdb5b2019f is described below
commit cdb5b2019f6723cb37127487c91daccbf9d238d4
Author: Weston Pace <we...@gmail.com>
AuthorDate: Wed Aug 10 12:26:59 2022 -0700
ARROW-15582: [C++] Add support for registering standard Substrait functions (#13613)
This picks up where #13285 left off. It mostly focuses on the Substrait->Arrow direction at the moment. In addition, basic support is added for named tables. This makes it possible to create unit tests that read from in-memory tables instead of requiring every unit test to perform a scan.
The PR creates some utilities in `test_plan_builder.h` that allow simple Substrait plans to be constructed programmatically. These are used to create unit tests for the function mapping.
The PR extracts id ownership (storage of the URI and name strings that ids refer to) out of the `ExtensionIdRegistry` and into its own `IdStorage` class.
The PR removes `NestedExtensionIdRegistryImpl`; instead, `ExtensionIdRegistryImpl` itself acts as a nested registry when `parent_ != nullptr`.
Authored-by: Weston Pace <we...@gmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
cpp/src/arrow/compute/exec/options.cc | 2 +
cpp/src/arrow/compute/exec/options.h | 4 +-
cpp/src/arrow/compute/exec/sink_node.cc | 41 --
cpp/src/arrow/compute/exec/util.cc | 20 +
cpp/src/arrow/compute/exec/util.h | 19 +
cpp/src/arrow/engine/CMakeLists.txt | 4 +-
.../arrow/engine/substrait/expression_internal.cc | 165 ++++-
.../arrow/engine/substrait/expression_internal.h | 4 +
cpp/src/arrow/engine/substrait/ext_test.cc | 82 ++-
cpp/src/arrow/engine/substrait/extension_set.cc | 786 ++++++++++++++++-----
cpp/src/arrow/engine/substrait/extension_set.h | 262 +++++--
cpp/src/arrow/engine/substrait/function_test.cc | 495 +++++++++++++
cpp/src/arrow/engine/substrait/options.h | 9 +
cpp/src/arrow/engine/substrait/plan_internal.cc | 9 +-
.../arrow/engine/substrait/relation_internal.cc | 70 +-
cpp/src/arrow/engine/substrait/serde.cc | 24 +-
cpp/src/arrow/engine/substrait/serde.h | 2 +-
cpp/src/arrow/engine/substrait/serde_test.cc | 175 +++--
.../arrow/engine/substrait/test_plan_builder.cc | 216 ++++++
cpp/src/arrow/engine/substrait/test_plan_builder.h | 72 ++
cpp/src/arrow/engine/substrait/util.cc | 10 -
cpp/src/arrow/engine/substrait/util.h | 18 -
docs/source/cpp/streaming_execution.rst | 129 +++-
python/pyarrow/_substrait.pyx | 25 +
python/pyarrow/includes/libarrow_substrait.pxd | 12 +-
python/pyarrow/substrait.py | 1 +
python/pyarrow/tests/test_substrait.py | 20 +
27 files changed, 2196 insertions(+), 480 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/options.cc b/cpp/src/arrow/compute/exec/options.cc
index c09ab1c1b6..ef1a0c7e2e 100644
--- a/cpp/src/arrow/compute/exec/options.cc
+++ b/cpp/src/arrow/compute/exec/options.cc
@@ -25,6 +25,8 @@
namespace arrow {
namespace compute {
+constexpr int64_t TableSourceNodeOptions::kDefaultMaxBatchSize;
+
std::string ToString(JoinType t) {
switch (t) {
case JoinType::LEFT_SEMI:
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 4a0cd602ef..a8e8c1ee23 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -64,7 +64,9 @@ class ARROW_EXPORT SourceNodeOptions : public ExecNodeOptions {
/// \brief An extended Source node which accepts a table
class ARROW_EXPORT TableSourceNodeOptions : public ExecNodeOptions {
public:
- TableSourceNodeOptions(std::shared_ptr<Table> table, int64_t max_batch_size)
+ static constexpr int64_t kDefaultMaxBatchSize = 1 << 20;
+ TableSourceNodeOptions(std::shared_ptr<Table> table,
+ int64_t max_batch_size = kDefaultMaxBatchSize)
: table(table), max_batch_size(max_batch_size) {}
// arrow table which acts as the data source
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index a1426265cf..8af4e8e996 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -388,47 +388,6 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
std::vector<std::string> names_;
int32_t backpressure_counter_ = 0;
};
-
-/**
- * @brief This node is an extension on ConsumingSinkNode
- * to facilitate to get the output from an execution plan
- * as a table. We define a custom SinkNodeConsumer to
- * enable this functionality.
- */
-
-struct TableSinkNodeConsumer : public SinkNodeConsumer {
- public:
- TableSinkNodeConsumer(std::shared_ptr<Table>* out, MemoryPool* pool)
- : out_(out), pool_(pool) {}
-
- Status Init(const std::shared_ptr<Schema>& schema,
- BackpressureControl* backpressure_control) override {
- // If the user is collecting into a table then backpressure is meaningless
- ARROW_UNUSED(backpressure_control);
- schema_ = schema;
- return Status::OK();
- }
-
- Status Consume(ExecBatch batch) override {
- std::lock_guard<std::mutex> guard(consume_mutex_);
- ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_));
- batches_.push_back(rb);
- return Status::OK();
- }
-
- Future<> Finish() override {
- ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(batches_));
- return Status::OK();
- }
-
- private:
- std::shared_ptr<Table>* out_;
- MemoryPool* pool_;
- std::shared_ptr<Schema> schema_;
- std::vector<std::shared_ptr<RecordBatch>> batches_;
- std::mutex consume_mutex_;
-};
-
static Result<ExecNode*> MakeTableConsumingSinkNode(
compute::ExecPlan* plan, std::vector<compute::ExecNode*> inputs,
const compute::ExecNodeOptions& options) {
diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc
index ae70cfcd46..a34a9c6271 100644
--- a/cpp/src/arrow/compute/exec/util.cc
+++ b/cpp/src/arrow/compute/exec/util.cc
@@ -383,5 +383,25 @@ size_t ThreadIndexer::Check(size_t thread_index) {
return thread_index;
}
+Status TableSinkNodeConsumer::Init(const std::shared_ptr<Schema>& schema,
+ BackpressureControl* backpressure_control) {
+ // If the user is collecting into a table then backpressure is meaningless
+ ARROW_UNUSED(backpressure_control);
+ schema_ = schema;
+ return Status::OK();
+}
+
+Status TableSinkNodeConsumer::Consume(ExecBatch batch) {
+ auto guard = consume_mutex_.Lock();
+ ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_));
+ batches_.push_back(std::move(rb));
+ return Status::OK();
+}
+
+Future<> TableSinkNodeConsumer::Finish() {
+ ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(batches_));
+ return Status::OK();
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h
index 30526cb835..7e716808fa 100644
--- a/cpp/src/arrow/compute/exec/util.h
+++ b/cpp/src/arrow/compute/exec/util.h
@@ -24,6 +24,7 @@
#include <vector>
#include "arrow/buffer.h"
+#include "arrow/compute/exec/options.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
@@ -342,5 +343,23 @@ class TailSkipForSIMD {
}
};
+/// \brief A consumer that collects results into an in-memory table
+struct ARROW_EXPORT TableSinkNodeConsumer : public SinkNodeConsumer {
+ public:
+ TableSinkNodeConsumer(std::shared_ptr<Table>* out, MemoryPool* pool)
+ : out_(out), pool_(pool) {}
+ Status Init(const std::shared_ptr<Schema>& schema,
+ BackpressureControl* backpressure_control) override;
+ Status Consume(ExecBatch batch) override;
+ Future<> Finish() override;
+
+ private:
+ std::shared_ptr<Table>* out_;
+ MemoryPool* pool_;
+ std::shared_ptr<Schema> schema_;
+ std::vector<std::shared_ptr<RecordBatch>> batches_;
+ util::Mutex consume_mutex_;
+};
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt
index 8edd22900e..4109b7b3bc 100644
--- a/cpp/src/arrow/engine/CMakeLists.txt
+++ b/cpp/src/arrow/engine/CMakeLists.txt
@@ -23,9 +23,10 @@ set(ARROW_SUBSTRAIT_SRCS
substrait/expression_internal.cc
substrait/extension_set.cc
substrait/extension_types.cc
- substrait/serde.cc
substrait/plan_internal.cc
substrait/relation_internal.cc
+ substrait/serde.cc
+ substrait/test_plan_builder.cc
substrait/type_internal.cc
substrait/util.cc)
@@ -67,6 +68,7 @@ endif()
add_arrow_test(substrait_test
SOURCES
substrait/ext_test.cc
+ substrait/function_test.cc
substrait/serde_test.cc
EXTRA_LINK_LIBS
${ARROW_SUBSTRAIT_TEST_LINK_LIBS}
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc
index 07c222bc4c..589d7e6ac6 100644
--- a/cpp/src/arrow/engine/substrait/expression_internal.cc
+++ b/cpp/src/arrow/engine/substrait/expression_internal.cc
@@ -41,6 +41,84 @@ namespace internal {
using ::arrow::internal::make_unique;
} // namespace internal
+Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
+ SubstraitCall* call, const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options) {
+ if (arg.has_enum_()) {
+ const substrait::FunctionArgument::Enum& enum_val = arg.enum_();
+ if (enum_val.has_specified()) {
+ call->SetEnumArg(idx, enum_val.specified());
+ } else {
+ call->SetEnumArg(idx, util::nullopt);
+ }
+ } else if (arg.has_value()) {
+ ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
+ FromProto(arg.value(), ext_set, conversion_options));
+ call->SetValueArg(idx, std::move(expr));
+ } else if (arg.has_type()) {
+ return Status::NotImplemented("Type arguments not currently supported");
+ } else {
+ return Status::NotImplemented("Unrecognized function argument class");
+ }
+ return Status::OK();
+}
+
+Result<SubstraitCall> DecodeScalarFunction(
+ Id id, const substrait::Expression::ScalarFunction& scalar_fn,
+ const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
+ ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
+ FromProto(scalar_fn.output_type(), ext_set, conversion_options));
+ SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second);
+ for (int i = 0; i < scalar_fn.arguments_size(); i++) {
+ ARROW_RETURN_NOT_OK(DecodeArg(scalar_fn.arguments(i), static_cast<uint32_t>(i), &call,
+ ext_set, conversion_options));
+ }
+ return std::move(call);
+}
+
+std::string EnumToString(int value, const google::protobuf::EnumDescriptor* descriptor) {
+ const google::protobuf::EnumValueDescriptor* value_desc =
+ descriptor->FindValueByNumber(value);
+ if (value_desc == nullptr) {
+ return "unknown";
+ }
+ return value_desc->name();
+}
+
+Result<SubstraitCall> FromProto(const substrait::AggregateFunction& func, bool is_hash,
+ const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options) {
+ if (func.phase() != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) {
+ return Status::NotImplemented(
+ "Unsupported aggregation phase '",
+ EnumToString(func.phase(), substrait::AggregationPhase_descriptor()),
+ "'. Only INITIAL_TO_RESULT is supported");
+ }
+ if (func.invocation() !=
+ substrait::AggregateFunction::AggregationInvocation::
+ AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL) {
+ return Status::NotImplemented(
+ "Unsupported aggregation invocation '",
+ EnumToString(func.invocation(),
+ substrait::AggregateFunction::AggregationInvocation_descriptor()),
+ "'. Only AGGREGATION_INVOCATION_ALL is "
+ "supported");
+ }
+ if (func.sorts_size() > 0) {
+ return Status::NotImplemented("Aggregation sorts are not supported");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
+ FromProto(func.output_type(), ext_set, conversion_options));
+ ARROW_ASSIGN_OR_RAISE(Id id, ext_set.DecodeFunction(func.function_reference()));
+ SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second,
+ is_hash);
+ for (int i = 0; i < func.arguments_size(); i++) {
+ ARROW_RETURN_NOT_OK(DecodeArg(func.arguments(i), static_cast<uint32_t>(i), &call,
+ ext_set, conversion_options));
+ }
+ return std::move(call);
+}
+
Result<compute::Expression> FromProto(const substrait::Expression& expr,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
@@ -166,34 +244,14 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
case substrait::Expression::kScalarFunction: {
const auto& scalar_fn = expr.scalar_function();
- ARROW_ASSIGN_OR_RAISE(auto decoded_function,
+ ARROW_ASSIGN_OR_RAISE(Id function_id,
ext_set.DecodeFunction(scalar_fn.function_reference()));
-
- std::vector<compute::Expression> arguments(scalar_fn.arguments_size());
- for (int i = 0; i < scalar_fn.arguments_size(); ++i) {
- const auto& argument = scalar_fn.arguments(i);
- switch (argument.arg_type_case()) {
- case substrait::FunctionArgument::kValue: {
- ARROW_ASSIGN_OR_RAISE(
- arguments[i], FromProto(argument.value(), ext_set, conversion_options));
- break;
- }
- default:
- return Status::NotImplemented(
- "only value arguments are currently supported for functions");
- }
- }
-
- auto func_name = decoded_function.name.to_string();
- if (func_name != "cast") {
- return compute::call(func_name, std::move(arguments));
- } else {
- ARROW_ASSIGN_OR_RAISE(
- auto output_type_desc,
- FromProto(scalar_fn.output_type(), ext_set, conversion_options));
- auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first));
- return compute::call(func_name, std::move(arguments), std::move(cast_options));
- }
+ ARROW_ASSIGN_OR_RAISE(ExtensionIdRegistry::SubstraitCallToArrow function_converter,
+ ext_set.registry()->GetSubstraitCallToArrow(function_id));
+ ARROW_ASSIGN_OR_RAISE(
+ SubstraitCall substrait_call,
+ DecodeScalarFunction(function_id, scalar_fn, ext_set, conversion_options));
+ return function_converter(substrait_call);
}
default:
@@ -827,6 +885,42 @@ static Result<std::unique_ptr<substrait::Expression>> MakeListElementReference(
return MakeDirectReference(std::move(expr), std::move(ref_segment));
}
+Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCall(
+ const SubstraitCall& call, ExtensionSet* ext_set,
+ const ConversionOptions& conversion_options) {
+ ARROW_ASSIGN_OR_RAISE(uint32_t anchor, ext_set->EncodeFunction(call.id()));
+ auto scalar_fn = internal::make_unique<substrait::Expression::ScalarFunction>();
+ scalar_fn->set_function_reference(anchor);
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr<substrait::Type> output_type,
+ ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options));
+ scalar_fn->set_allocated_output_type(output_type.release());
+
+ for (uint32_t i = 0; i < call.size(); i++) {
+ substrait::FunctionArgument* arg = scalar_fn->add_arguments();
+ if (call.HasEnumArg(i)) {
+ auto enum_val = internal::make_unique<substrait::FunctionArgument::Enum>();
+ ARROW_ASSIGN_OR_RAISE(util::optional<util::string_view> enum_arg,
+ call.GetEnumArg(i));
+ if (enum_arg) {
+ enum_val->set_specified(enum_arg->to_string());
+ } else {
+ enum_val->set_allocated_unspecified(new google::protobuf::Empty());
+ }
+ arg->set_allocated_enum_(enum_val.release());
+ } else if (call.HasValueArg(i)) {
+ ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression> value_expr,
+ ToProto(value_arg, ext_set, conversion_options));
+ arg->set_allocated_value(value_expr.release());
+ } else {
+ return Status::Invalid("Call reported having ", call.size(),
+ " arguments but no argument could be found at index ", i);
+ }
+ }
+ return std::move(scalar_fn);
+}
+
Result<std::unique_ptr<substrait::Expression>> ToProto(
const compute::Expression& expr, ExtensionSet* ext_set,
const ConversionOptions& conversion_options) {
@@ -933,17 +1027,12 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(
}
// other expression types dive into extensions immediately
- ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set->EncodeFunction(call->function_name));
-
- auto scalar_fn = internal::make_unique<substrait::Expression::ScalarFunction>();
- scalar_fn->set_function_reference(anchor);
- scalar_fn->mutable_arguments()->Reserve(static_cast<int>(arguments.size()));
- for (auto& arg : arguments) {
- auto argument = internal::make_unique<substrait::FunctionArgument>();
- argument->set_allocated_value(arg.release());
- scalar_fn->mutable_arguments()->AddAllocated(argument.release());
- }
-
+ ARROW_ASSIGN_OR_RAISE(
+ ExtensionIdRegistry::ArrowToSubstraitCall converter,
+ ext_set->registry()->GetArrowToSubstraitCall(call->function_name));
+ ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression::ScalarFunction> scalar_fn,
+ EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
out->set_allocated_scalar_function(scalar_fn.release());
return std::move(out);
}
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h
index 2b4dec2a00..f132afc0c1 100644
--- a/cpp/src/arrow/engine/substrait/expression_internal.h
+++ b/cpp/src/arrow/engine/substrait/expression_internal.h
@@ -50,5 +50,9 @@ Result<std::unique_ptr<substrait::Expression::Literal>> ToProto(const Datum&,
ExtensionSet*,
const ConversionOptions&);
+ARROW_ENGINE_EXPORT
+Result<SubstraitCall> FromProto(const substrait::AggregateFunction&, bool is_hash,
+ const ExtensionSet&, const ConversionOptions&);
+
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc
index 8e41cb7c98..4b37aa8fcd 100644
--- a/cpp/src/arrow/engine/substrait/ext_test.cc
+++ b/cpp/src/arrow/engine/substrait/ext_test.cc
@@ -56,12 +56,10 @@ struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
virtual ~NestedExtensionIdRegistryProvider() {}
- std::shared_ptr<ExtensionIdRegistry> registry_ = substrait::MakeExtensionIdRegistry();
+ std::shared_ptr<ExtensionIdRegistry> registry_ = MakeExtensionIdRegistry();
ExtensionIdRegistry* get() const override { return &*registry_; }
};
-using Id = ExtensionIdRegistry::Id;
-
bool operator==(const Id& id1, const Id& id2) {
return id1.uri == id2.uri && id1.name == id2.name;
}
@@ -85,8 +83,8 @@ static const std::vector<TypeName> kTypeNames = {
TypeName{month_day_nano_interval(), "interval_month_day_nano"},
};
-static const std::vector<util::string_view> kFunctionNames = {
- "add",
+static const std::vector<Id> kFunctionIds = {
+ {kSubstraitArithmeticFunctionsUri, "add"},
};
static const std::vector<util::string_view> kTempFunctionNames = {
@@ -141,15 +139,12 @@ TEST_P(ExtensionIdRegistryTest, GetFunctions) {
auto provider = std::get<0>(GetParam());
auto registry = provider->get();
- for (util::string_view name : kFunctionNames) {
- auto id = Id{kArrowExtTypesUri, name};
- for (auto funcrec_opt : {registry->GetFunction(id), registry->GetFunction(name)}) {
- ASSERT_TRUE(funcrec_opt);
- auto funcrec = funcrec_opt.value();
- ASSERT_EQ(id, funcrec.id);
- ASSERT_EQ(name, funcrec.function_name);
- }
+ for (Id func_id : kFunctionIds) {
+ ASSERT_OK_AND_ASSIGN(ExtensionIdRegistry::SubstraitCallToArrow converter,
+ registry->GetSubstraitCallToArrow(func_id));
+ ASSERT_TRUE(converter);
}
+ ASSERT_RAISES(NotImplemented, registry->GetSubstraitCallToArrow(kNonExistentId));
ASSERT_FALSE(registry->GetType(kNonExistentId));
ASSERT_FALSE(registry->GetType(*kNonExistentTypeName.type));
}
@@ -158,10 +153,10 @@ TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) {
auto provider = std::get<0>(GetParam());
auto registry = provider->get();
- for (util::string_view name : kFunctionNames) {
- auto id = Id{kArrowExtTypesUri, name};
- ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string()));
- ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string()));
+ for (Id function_id : kFunctionIds) {
+ ASSERT_RAISES(Invalid, registry->CanAddSubstraitCallToArrow(function_id));
+ ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow(
+ function_id, function_id.name.to_string()));
}
}
@@ -173,11 +168,26 @@ INSTANTIATE_TEST_SUITE_P(
std::make_tuple(std::make_shared<NestedExtensionIdRegistryProvider>(),
"nested")));
+TEST(ExtensionIdRegistryTest, GetSupportedSubstraitFunctions) {
+ ExtensionIdRegistry* default_registry = default_extension_id_registry();
+ std::vector<std::string> supported_functions =
+ default_registry->GetSupportedSubstraitFunctions();
+ std::size_t num_functions = supported_functions.size();
+ ASSERT_GT(num_functions, 0);
+
+ std::shared_ptr<ExtensionIdRegistry> nested =
+ nested_extension_id_registry(default_registry);
+ ASSERT_OK(nested->AddSubstraitCallToArrow(kNonExistentId, "some_function"));
+
+ std::size_t num_nested_functions = nested->GetSupportedSubstraitFunctions().size();
+ ASSERT_EQ(num_functions + 1, num_nested_functions);
+}
+
TEST(ExtensionIdRegistryTest, RegisterTempTypes) {
auto default_registry = default_extension_id_registry();
constexpr int rounds = 3;
for (int i = 0; i < rounds; i++) {
- auto registry = substrait::MakeExtensionIdRegistry();
+ auto registry = MakeExtensionIdRegistry();
for (TypeName e : kTempTypeNames) {
auto id = Id{kArrowExtTypesUri, e.name};
@@ -194,15 +204,15 @@ TEST(ExtensionIdRegistryTest, RegisterTempFunctions) {
auto default_registry = default_extension_id_registry();
constexpr int rounds = 3;
for (int i = 0; i < rounds; i++) {
- auto registry = substrait::MakeExtensionIdRegistry();
+ auto registry = MakeExtensionIdRegistry();
for (util::string_view name : kTempFunctionNames) {
auto id = Id{kArrowExtTypesUri, name};
- ASSERT_OK(registry->CanRegisterFunction(id, name.to_string()));
- ASSERT_OK(registry->RegisterFunction(id, name.to_string()));
- ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string()));
- ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string()));
- ASSERT_OK(default_registry->CanRegisterFunction(id, name.to_string()));
+ ASSERT_OK(registry->CanAddSubstraitCallToArrow(id));
+ ASSERT_OK(registry->AddSubstraitCallToArrow(id, name.to_string()));
+ ASSERT_RAISES(Invalid, registry->CanAddSubstraitCallToArrow(id));
+ ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow(id, name.to_string()));
+ ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id));
}
}
}
@@ -246,24 +256,24 @@ TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) {
auto default_registry = default_extension_id_registry();
constexpr int rounds = 3;
for (int i = 0; i < rounds; i++) {
- auto registry1 = substrait::MakeExtensionIdRegistry();
+ auto registry1 = MakeExtensionIdRegistry();
- ASSERT_OK(registry1->CanRegisterFunction(id1, name1.to_string()));
- ASSERT_OK(registry1->RegisterFunction(id1, name1.to_string()));
+ ASSERT_OK(registry1->CanAddSubstraitCallToArrow(id1));
+ ASSERT_OK(registry1->AddSubstraitCallToArrow(id1, name1.to_string()));
for (int j = 0; j < rounds; j++) {
- auto registry2 = substrait::MakeExtensionIdRegistry();
+ auto registry2 = MakeExtensionIdRegistry();
- ASSERT_OK(registry2->CanRegisterFunction(id2, name2.to_string()));
- ASSERT_OK(registry2->RegisterFunction(id2, name2.to_string()));
- ASSERT_RAISES(Invalid, registry2->CanRegisterFunction(id2, name2.to_string()));
- ASSERT_RAISES(Invalid, registry2->RegisterFunction(id2, name2.to_string()));
- ASSERT_OK(default_registry->CanRegisterFunction(id2, name2.to_string()));
+ ASSERT_OK(registry2->CanAddSubstraitCallToArrow(id2));
+ ASSERT_OK(registry2->AddSubstraitCallToArrow(id2, name2.to_string()));
+ ASSERT_RAISES(Invalid, registry2->CanAddSubstraitCallToArrow(id2));
+ ASSERT_RAISES(Invalid, registry2->AddSubstraitCallToArrow(id2, name2.to_string()));
+ ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id2));
}
- ASSERT_RAISES(Invalid, registry1->CanRegisterFunction(id1, name1.to_string()));
- ASSERT_RAISES(Invalid, registry1->RegisterFunction(id1, name1.to_string()));
- ASSERT_OK(default_registry->CanRegisterFunction(id1, name1.to_string()));
+ ASSERT_RAISES(Invalid, registry1->CanAddSubstraitCallToArrow(id1));
+ ASSERT_RAISES(Invalid, registry1->AddSubstraitCallToArrow(id1, name1.to_string()));
+ ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id1));
}
}
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index 08eb6acc9c..493d576e83 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -17,9 +17,9 @@
#include "arrow/engine/substrait/extension_set.h"
-#include <unordered_map>
-#include <unordered_set>
+#include <sstream>
+#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/util/hash_util.h"
#include "arrow/util/hashing.h"
#include "arrow/util/string_view.h"
@@ -28,6 +28,9 @@ namespace arrow {
namespace engine {
namespace {
+// TODO(ARROW-16988): replace this with EXACT_ROUNDTRIP mode
+constexpr bool kExactRoundTrip = true;
+
struct TypePtrHashEq {
template <typename Ptr>
size_t operator()(const Ptr& type) const {
@@ -42,16 +45,115 @@ struct TypePtrHashEq {
} // namespace
-size_t ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id id) const {
+std::string Id::ToString() const {
+ std::stringstream sstream;
+ sstream << uri;
+ sstream << '#';
+ sstream << name;
+ return sstream.str();
+}
+
+size_t IdHashEq::operator()(Id id) const {
constexpr ::arrow::internal::StringViewHash hash = {};
auto out = static_cast<size_t>(hash(id.uri));
::arrow::internal::hash_combine(out, hash(id.name));
return out;
}
-bool ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id l,
- ExtensionIdRegistry::Id r) const {
- return l.uri == r.uri && l.name == r.name;
+bool IdHashEq::operator()(Id l, Id r) const { return l.uri == r.uri && l.name == r.name; }
+
+Id IdStorage::Emplace(Id id) {
+ util::string_view owned_uri = EmplaceUri(id.uri);
+
+ util::string_view owned_name;
+ auto name_itr = names_.find(id.name);
+ if (name_itr == names_.end()) {
+ owned_names_.emplace_back(id.name);
+ owned_name = owned_names_.back();
+ names_.insert(owned_name);
+ } else {
+ owned_name = *name_itr;
+ }
+
+ return {owned_uri, owned_name};
+}
+
+util::optional<Id> IdStorage::Find(Id id) const {
+ util::optional<util::string_view> maybe_owned_uri = FindUri(id.uri);
+ if (!maybe_owned_uri) {
+ return util::nullopt;
+ }
+
+ auto name_itr = names_.find(id.name);
+ if (name_itr == names_.end()) {
+ return util::nullopt;
+ } else {
+ return Id{*maybe_owned_uri, *name_itr};
+ }
+}
+
+util::optional<util::string_view> IdStorage::FindUri(util::string_view uri) const {
+ auto uri_itr = uris_.find(uri);
+ if (uri_itr == uris_.end()) {
+ return util::nullopt;
+ }
+ return *uri_itr;
+}
+
+util::string_view IdStorage::EmplaceUri(util::string_view uri) {
+ auto uri_itr = uris_.find(uri);
+ if (uri_itr == uris_.end()) {
+ owned_uris_.emplace_back(uri);
+ util::string_view owned_uri = owned_uris_.back();
+ uris_.insert(owned_uri);
+ return owned_uri;
+ }
+ return *uri_itr;
+}
+
+Result<util::optional<util::string_view>> SubstraitCall::GetEnumArg(
+ uint32_t index) const {
+ if (index >= size_) {
+ return Status::Invalid("Expected Substrait call to have an enum argument at index ",
+ index, " but it did not have enough arguments");
+ }
+ auto enum_arg_it = enum_args_.find(index);
+ if (enum_arg_it == enum_args_.end()) {
+ return Status::Invalid("Expected Substrait call to have an enum argument at index ",
+ index, " but the argument was not an enum.");
+ }
+ return enum_arg_it->second;
+}
+
+bool SubstraitCall::HasEnumArg(uint32_t index) const {
+ return enum_args_.find(index) != enum_args_.end();
+}
+
+void SubstraitCall::SetEnumArg(uint32_t index, util::optional<std::string> enum_arg) {
+ size_ = std::max(size_, index + 1);
+ enum_args_[index] = std::move(enum_arg);
+}
+
+Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
+ if (index >= size_) {
+ return Status::Invalid("Expected Substrait call to have a value argument at index ",
+ index, " but it did not have enough arguments");
+ }
+ auto value_arg_it = value_args_.find(index);
+ if (value_arg_it == value_args_.end()) {
+ return Status::Invalid("Expected Substrait call to have a value argument at index ",
+ index, " but the argument was not a value");
+ }
+ return value_arg_it->second;
+}
+
+bool SubstraitCall::HasValueArg(uint32_t index) const {
+ return value_args_.find(index) != value_args_.end();
+}
+
+void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) {
+ size_ = std::max(size_, index + 1);
+ value_args_[index] = std::move(value_arg);
}
// A builder used when creating a Substrait plan from an Arrow execution plan. In
@@ -97,54 +199,54 @@ Result<ExtensionSet> ExtensionSet::Make(
std::unordered_map<uint32_t, util::string_view> uris,
std::unordered_map<uint32_t, Id> type_ids,
std::unordered_map<uint32_t, Id> function_ids, const ExtensionIdRegistry* registry) {
- ExtensionSet set;
+ ExtensionSet set(default_extension_id_registry());
set.registry_ = registry;
- // TODO(bkietz) move this into the registry as registry->OwnUris(&uris) or so
- std::unordered_set<util::string_view, ::arrow::internal::StringViewHash>
- uris_owned_by_registry;
- for (util::string_view uri : registry->Uris()) {
- uris_owned_by_registry.insert(uri);
- }
-
for (auto& uri : uris) {
- auto it = uris_owned_by_registry.find(uri.second);
- if (it == uris_owned_by_registry.end()) {
- return Status::KeyError("Uri '", uri.second, "' not found in registry");
+ util::optional<util::string_view> maybe_uri_internal = registry->FindUri(uri.second);
+ if (maybe_uri_internal) {
+ set.uris_[uri.first] = *maybe_uri_internal;
+ } else {
+ if (kExactRoundTrip) {
+ return Status::Invalid(
+ "Plan contained a URI that the extension registry is unaware of: ",
+ uri.second);
+ }
+ set.uris_[uri.first] = set.plan_specific_ids_.EmplaceUri(uri.second);
}
- uri.second = *it; // Ensure uris point into the registry's memory
- set.AddUri(uri);
}
set.types_.reserve(type_ids.size());
+ for (const auto& type_id : type_ids) {
+ if (type_id.second.empty()) continue;
+ RETURN_NOT_OK(set.CheckHasUri(type_id.second.uri));
- for (unsigned int i = 0; i < static_cast<unsigned int>(type_ids.size()); ++i) {
- if (type_ids[i].empty()) continue;
- RETURN_NOT_OK(set.CheckHasUri(type_ids[i].uri));
-
- if (auto rec = registry->GetType(type_ids[i])) {
- set.types_[i] = {rec->id, rec->type};
+ if (auto rec = registry->GetType(type_id.second)) {
+ set.types_[type_id.first] = {rec->id, rec->type};
continue;
}
- return Status::Invalid("Type ", type_ids[i].uri, "#", type_ids[i].name, " not found");
+ return Status::Invalid("Type ", type_id.second.uri, "#", type_id.second.name,
+ " not found");
}
set.functions_.reserve(function_ids.size());
-
- for (unsigned int i = 0; i < static_cast<unsigned int>(function_ids.size()); ++i) {
- if (function_ids[i].empty()) continue;
- RETURN_NOT_OK(set.CheckHasUri(function_ids[i].uri));
-
- if (auto rec = registry->GetFunction(function_ids[i])) {
- set.functions_[i] = {rec->id, rec->function_name};
- continue;
+ for (const auto& function_id : function_ids) {
+ if (function_id.second.empty()) continue;
+ RETURN_NOT_OK(set.CheckHasUri(function_id.second.uri));
+ util::optional<Id> maybe_id_internal = registry->FindId(function_id.second);
+ if (maybe_id_internal) {
+ set.functions_[function_id.first] = *maybe_id_internal;
+ } else {
+ if (kExactRoundTrip) {
+ return Status::Invalid(
+ "Plan contained a function id that the extension registry is unaware of: ",
+ function_id.second.uri, "#", function_id.second.name);
+ }
+ set.functions_[function_id.first] =
+ set.plan_specific_ids_.Emplace(function_id.second);
}
- return Status::Invalid("Function ", function_ids[i].uri, "#", function_ids[i].name,
- " not found");
}
- set.uris_ = std::move(uris);
-
return std::move(set);
}
@@ -162,39 +264,34 @@ Result<uint32_t> ExtensionSet::EncodeType(const DataType& type) {
auto it_success =
types_map_.emplace(rec->id, static_cast<uint32_t>(types_map_.size()));
if (it_success.second) {
- DCHECK_EQ(types_.find(static_cast<unsigned int>(types_.size())), types_.end())
+ DCHECK_EQ(types_.find(static_cast<uint32_t>(types_.size())), types_.end())
<< "Type existed in types_ but not types_map_. ExtensionSet is inconsistent";
- types_[static_cast<unsigned int>(types_.size())] = {rec->id, rec->type};
+ types_[static_cast<uint32_t>(types_.size())] = {rec->id, rec->type};
}
return it_success.first->second;
}
return Status::KeyError("type ", type.ToString(), " not found in the registry");
}
-Result<ExtensionSet::FunctionRecord> ExtensionSet::DecodeFunction(uint32_t anchor) const {
- if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).id.empty()) {
+// Resolve a plan-local function anchor to the Id recorded for it in this set.
+// A single find() serves both the existence check and the returned value, so the
+// anchor is looked up once instead of three times (find + two at() calls).
+Result<Id> ExtensionSet::DecodeFunction(uint32_t anchor) const {
+ auto it = functions_.find(anchor);
+ if (it == functions_.end() || it->second.empty()) {
return Status::Invalid("User defined function reference ", anchor,
" did not have a corresponding anchor in the extension set");
}
- return functions_.at(anchor);
+ return it->second;
}
-Result<uint32_t> ExtensionSet::EncodeFunction(util::string_view function_name) {
- if (auto rec = registry_->GetFunction(function_name)) {
- RETURN_NOT_OK(this->AddUri(rec->id));
- auto it_success =
- functions_map_.emplace(rec->id, static_cast<uint32_t>(functions_map_.size()));
- if (it_success.second) {
- DCHECK_EQ(functions_.find(static_cast<unsigned int>(functions_.size())),
- functions_.end())
- << "Function existed in functions_ but not functions_map_. ExtensionSet is "
- "inconsistent";
- functions_[static_cast<unsigned int>(functions_.size())] = {rec->id,
- rec->function_name};
- }
- return it_success.first->second;
+// Return the plan-local anchor for `function_id`. AddUri is called for the id on
+// every call. The first time an id is seen it gets the next sequential anchor
+// (the current functions_map_ size); later calls return that same anchor.
+Result<uint32_t> ExtensionSet::EncodeFunction(Id function_id) {
+ RETURN_NOT_OK(this->AddUri(function_id));
+ auto it_success =
+ functions_map_.emplace(function_id, static_cast<uint32_t>(functions_map_.size()));
+ if (it_success.second) {
+ DCHECK_EQ(functions_.find(static_cast<uint32_t>(functions_.size())), functions_.end())
+ << "Function existed in functions_ but not functions_map_. ExtensionSet is "
+ "inconsistent";
+ functions_[static_cast<uint32_t>(functions_.size())] = function_id;
}
- return Status::KeyError("function ", function_name, " not found in the registry");
+ return it_success.first->second;
}
template <typename KeyToIndex, typename Key>
@@ -207,16 +304,38 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) {
namespace {
struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
+ // With a null parent this is a root registry. Otherwise lookups and registration
+ // checks also consult `parent`, which is stored as a raw non-owning pointer and
+ // so must outlive this registry.
+ ExtensionIdRegistryImpl() : parent_(nullptr) {}
+ explicit ExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {}
+
virtual ~ExtensionIdRegistryImpl() {}
- std::vector<util::string_view> Uris() const override {
- return {uris_.begin(), uris_.end()};
+ // URI and id lookups prefer the parent chain's storage, then fall back to ids_.
+ util::optional<util::string_view> FindUri(util::string_view uri) const override {
+ if (parent_) {
+ util::optional<util::string_view> parent_uri = parent_->FindUri(uri);
+ if (parent_uri) {
+ return parent_uri;
+ }
+ }
+ return ids_.FindUri(uri);
+ }
+
+ util::optional<Id> FindId(Id id) const override {
+ if (parent_) {
+ util::optional<Id> parent_id = parent_->FindId(id);
+ if (parent_id) {
+ return parent_id;
+ }
+ }
+ return ids_.Find(id);
}
util::optional<TypeRecord> GetType(const DataType& type) const override {
if (auto index = GetIndex(type_to_index_, &type)) {
return TypeRecord{type_ids_[*index], types_[*index]};
}
+ // Not registered locally: defer to the parent registry, if any.
+ if (parent_) {
+ return parent_->GetType(type);
+ }
return {};
}
@@ -224,6 +343,9 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
if (auto index = GetIndex(id_to_index_, id)) {
return TypeRecord{type_ids_[*index], types_[*index]};
}
+ if (parent_) {
+ return parent_->GetType(id);
+ }
return {};
}
@@ -234,14 +356,20 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
if (type_to_index_.find(&*type) != type_to_index_.end()) {
return Status::Invalid("Type was already registered");
}
+ if (parent_) {
+ return parent_->CanRegisterType(id, type);
+ }
return Status::OK();
}
Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
DCHECK_EQ(type_ids_.size(), types_.size());
- Id copied_id{*uris_.emplace(id.uri.to_string()).first,
- *names_.emplace(id.name.to_string()).first};
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(parent_->CanRegisterType(id, type));
+ }
+
+ Id copied_id = ids_.Emplace(id);
auto index = static_cast<int>(type_ids_.size());
@@ -261,155 +389,394 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
return Status::OK();
}
- util::optional<FunctionRecord> GetFunction(
- util::string_view arrow_function_name) const override {
- if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
- return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
+ // OK only if neither this registry nor its parent chain already has a call
+ // converter for substrait_function_id.
+ Status CanAddSubstraitCallToArrow(Id substrait_function_id) const override {
+ if (substrait_to_arrow_.find(substrait_function_id) != substrait_to_arrow_.end()) {
+ return Status::Invalid("Cannot register function converter for Substrait id ",
+ substrait_function_id.ToString(),
+ " because a converter already exists");
}
- return {};
+ if (parent_) {
+ return parent_->CanAddSubstraitCallToArrow(substrait_function_id);
+ }
+ return Status::OK();
}
- util::optional<FunctionRecord> GetFunction(Id id) const override {
- if (auto index = GetIndex(function_id_to_index_, id)) {
- return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
+ // Aggregate counterpart of CanAddSubstraitCallToArrow: checks this registry,
+ // then the parent chain, for an existing aggregate converter.
+ Status CanAddSubstraitAggregateToArrow(Id substrait_function_id) const override {
+ if (substrait_to_arrow_agg_.find(substrait_function_id) !=
+ substrait_to_arrow_agg_.end()) {
+ return Status::Invalid(
+ "Cannot register aggregate function converter for Substrait id ",
+ substrait_function_id.ToString(),
+ " because an aggregate converter already exists");
}
- return {};
+ if (parent_) {
+ return parent_->CanAddSubstraitAggregateToArrow(substrait_function_id);
+ }
+ return Status::OK();
+ }
+
+ // Stores `conversion_func` in `dest`, keyed by a copy of the id held in ids_ so
+ // the key does not depend on the lifetime of the caller's strings.
+ template <typename ConverterType>
+ Status AddSubstraitToArrowFunc(
+ Id substrait_id, ConverterType conversion_func,
+ std::unordered_map<Id, ConverterType, IdHashEq, IdHashEq>* dest) {
+ // Convert id to view into registry-owned memory
+ Id copied_id = ids_.Emplace(substrait_id);
+
+ auto add_result = dest->emplace(copied_id, std::move(conversion_func));
+ if (!add_result.second) {
+ return Status::Invalid(
+ "Failed to register Substrait to Arrow function converter because a converter "
+ "already existed for Substrait id ",
+ substrait_id.ToString());
+ }
+
+ return Status::OK();
+ }
+
+ // Registers a Substrait -> Arrow call converter after checking that no parent
+ // registry already has one for substrait_function_id.
+ Status AddSubstraitCallToArrow(Id substrait_function_id,
+ SubstraitCallToArrow conversion_func) override {
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(parent_->CanAddSubstraitCallToArrow(substrait_function_id));
+ }
+ return AddSubstraitToArrowFunc<SubstraitCallToArrow>(
+ substrait_function_id, std::move(conversion_func), &substrait_to_arrow_);
}
- Status CanRegisterFunction(Id id,
- const std::string& arrow_function_name) const override {
- if (function_id_to_index_.find(id) != function_id_to_index_.end()) {
- return Status::Invalid("Function id was already registered");
+ // Registers a Substrait -> Arrow aggregate converter after checking that no
+ // parent registry already has one for substrait_function_id.
+ Status AddSubstraitAggregateToArrow(
+ Id substrait_function_id, SubstraitAggregateToArrow conversion_func) override {
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(
+ parent_->CanAddSubstraitAggregateToArrow(substrait_function_id));
}
- if (function_name_to_index_.find(arrow_function_name) !=
- function_name_to_index_.end()) {
- return Status::Invalid("Function name was already registered");
+ return AddSubstraitToArrowFunc<SubstraitAggregateToArrow>(
+ substrait_function_id, std::move(conversion_func), &substrait_to_arrow_agg_);
+ }
+
+ // Stores `converter` in `dest` under `arrow_function_name`, failing if that name
+ // already has a converter in `dest`.
+ template <typename ConverterType>
+ Status AddArrowToSubstraitFunc(std::string arrow_function_name, ConverterType converter,
+ std::unordered_map<std::string, ConverterType>* dest) {
+ // Check for a duplicate before inserting: emplace() may construct (and so move
+ // from) its key argument even when the key already exists, which would leave
+ // arrow_function_name unspecified in the error message below.
+ if (dest->find(arrow_function_name) != dest->end()) {
+ return Status::Invalid(
+ "Failed to register Arrow to Substrait function converter for Arrow function ",
+ arrow_function_name, " because a converter already existed");
}
+ dest->emplace(std::move(arrow_function_name), std::move(converter));
return Status::OK();
}
- Status RegisterFunction(Id id, std::string arrow_function_name) override {
- DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
+ // Register an Arrow -> Substrait converter after checking the parent chain. The
+ // by-value converter is moved (not copied) into the map.
+ Status AddArrowToSubstraitCall(std::string arrow_function_name,
+ ArrowToSubstraitCall converter) override {
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(parent_->CanAddArrowToSubstraitCall(arrow_function_name));
+ }
+ return AddArrowToSubstraitFunc(std::move(arrow_function_name), std::move(converter),
+ &arrow_to_substrait_);
+ }
- Id copied_id{*uris_.emplace(id.uri.to_string()).first,
- *names_.emplace(id.name.to_string()).first};
+ Status AddArrowToSubstraitAggregate(std::string arrow_function_name,
+ ArrowToSubstraitAggregate converter) override {
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(parent_->CanAddArrowToSubstraitAggregate(arrow_function_name));
+ }
+ return AddArrowToSubstraitFunc(std::move(arrow_function_name), std::move(converter),
+ &arrow_to_substrait_agg_);
+ }
- const std::string& copied_function_name{
- *function_names_.emplace(std::move(arrow_function_name)).first};
+ // OK only if neither this registry nor its parent chain already has a converter
+ // for function_name.
+ Status CanAddArrowToSubstraitCall(const std::string& function_name) const override {
+ if (arrow_to_substrait_.find(function_name) != arrow_to_substrait_.end()) {
+ return Status::Invalid(
+ "Cannot register function converter because a converter already exists");
+ }
+ if (parent_) {
+ return parent_->CanAddArrowToSubstraitCall(function_name);
+ }
+ return Status::OK();
+ }
- auto index = static_cast<int>(function_ids_.size());
+ Status CanAddArrowToSubstraitAggregate(
+ const std::string& function_name) const override {
+ if (arrow_to_substrait_agg_.find(function_name) != arrow_to_substrait_agg_.end()) {
+ return Status::Invalid(
+ "Cannot register function converter because a converter already exists");
+ }
+ if (parent_) {
+ return parent_->CanAddArrowToSubstraitAggregate(function_name);
+ }
+ return Status::OK();
+ }
- auto it_success = function_id_to_index_.emplace(copied_id, index);
+ // Converter lookups check this registry first, then the parent chain; when no
+ // registry has an entry, NotImplemented is returned.
+ Result<SubstraitCallToArrow> GetSubstraitCallToArrow(
+ Id substrait_function_id) const override {
+ auto maybe_converter = substrait_to_arrow_.find(substrait_function_id);
+ if (maybe_converter == substrait_to_arrow_.end()) {
+ if (parent_) {
+ return parent_->GetSubstraitCallToArrow(substrait_function_id);
+ }
+ return Status::NotImplemented(
+ "No conversion function exists to convert the Substrait function ",
+ substrait_function_id.uri, "#", substrait_function_id.name,
+ " to an Arrow call expression");
+ }
+ return maybe_converter->second;
+ }
- if (!it_success.second) {
- return Status::Invalid("Function id was already registered");
+ Result<SubstraitAggregateToArrow> GetSubstraitAggregateToArrow(
+ Id substrait_function_id) const override {
+ auto maybe_converter = substrait_to_arrow_agg_.find(substrait_function_id);
+ if (maybe_converter == substrait_to_arrow_agg_.end()) {
+ if (parent_) {
+ return parent_->GetSubstraitAggregateToArrow(substrait_function_id);
+ }
+ return Status::NotImplemented(
+ "No conversion function exists to convert the Substrait aggregate function ",
+ substrait_function_id.uri, "#", substrait_function_id.name,
+ " to an Arrow aggregate");
}
+ return maybe_converter->second;
+ }
- if (!function_name_to_index_.emplace(copied_function_name, index).second) {
- function_id_to_index_.erase(it_success.first);
- return Status::Invalid("Function name was already registered");
+ Result<ArrowToSubstraitCall> GetArrowToSubstraitCall(
+ const std::string& arrow_function_name) const override {
+ auto maybe_converter = arrow_to_substrait_.find(arrow_function_name);
+ if (maybe_converter == arrow_to_substrait_.end()) {
+ if (parent_) {
+ return parent_->GetArrowToSubstraitCall(arrow_function_name);
+ }
+ return Status::NotImplemented(
+ "No conversion function exists to convert the Arrow function ",
+ arrow_function_name, " to a Substrait call");
}
+ return maybe_converter->second;
+ }
- function_name_ptrs_.push_back(&copied_function_name);
- function_ids_.push_back(copied_id);
- return Status::OK();
+ Result<ArrowToSubstraitAggregate> GetArrowToSubstraitAggregate(
+ const std::string& arrow_function_name) const override {
+ auto maybe_converter = arrow_to_substrait_agg_.find(arrow_function_name);
+ if (maybe_converter == arrow_to_substrait_agg_.end()) {
+ if (parent_) {
+ return parent_->GetArrowToSubstraitAggregate(arrow_function_name);
+ }
+ return Status::NotImplemented(
+ "No conversion function exists to convert the Arrow aggregate ",
+ arrow_function_name, " to a Substrait aggregate");
+ }
+ return maybe_converter->second;
}
- Status RegisterFunction(std::string uri, std::string name,
- std::string arrow_function_name) override {
- return RegisterFunction({uri, name}, arrow_function_name);
+ // Returns Id::ToString() for every Substrait function that has a call or
+ // aggregate converter here or anywhere in the parent chain, sorted.
+ std::vector<std::string> GetSupportedSubstraitFunctions() const override {
+ std::vector<std::string> encoded_ids;
+ for (const auto& entry : substrait_to_arrow_) {
+ encoded_ids.push_back(entry.first.ToString());
+ }
+ for (const auto& entry : substrait_to_arrow_agg_) {
+ encoded_ids.push_back(entry.first.ToString());
+ }
+ if (parent_) {
+ std::vector<std::string> parent_ids = parent_->GetSupportedSubstraitFunctions();
+ encoded_ids.insert(encoded_ids.end(), make_move_iterator(parent_ids.begin()),
+ make_move_iterator(parent_ids.end()));
+ }
+ std::sort(encoded_ids.begin(), encoded_ids.end());
+ return encoded_ids;
}
- // owning storage of uris, names, (arrow::)function_names, types
- // note that storing strings like this is safe since references into an
- // unordered_set are not invalidated on insertion
- std::unordered_set<std::string> uris_, names_, function_names_;
+ // Defined out of line (below) because it uses helper functions that are
+ // declared later in this file
+ Status AddSubstraitCallToArrow(Id substrait_function_id,
+ std::string arrow_function_name) override;
+
+ // Parent registry, null for the root, non-null for nested
+ const ExtensionIdRegistry* parent_;
+
+ // owning storage of ids & types
+ IdStorage ids_;
DataTypeVector types_;
+ // There should only be one entry per Arrow function so there is no need
+ // to separate ownership and lookup
+ std::unordered_map<std::string, ArrowToSubstraitCall> arrow_to_substrait_;
+ std::unordered_map<std::string, ArrowToSubstraitAggregate> arrow_to_substrait_agg_;
// non-owning lookup helpers
- std::vector<Id> type_ids_, function_ids_;
+ std::vector<Id> type_ids_;
std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_;
std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
-
- std::vector<const std::string*> function_name_ptrs_;
- std::unordered_map<Id, int, IdHashEq, IdHashEq> function_id_to_index_;
- std::unordered_map<util::string_view, int, ::arrow::internal::StringViewHash>
- function_name_to_index_;
+ // Substrait -> Arrow converters; keys are views into ids_ (see
+ // AddSubstraitToArrowFunc)
+ std::unordered_map<Id, SubstraitCallToArrow, IdHashEq, IdHashEq> substrait_to_arrow_;
+ std::unordered_map<Id, SubstraitAggregateToArrow, IdHashEq, IdHashEq>
+ substrait_to_arrow_agg_;
};
-struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl {
- explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent)
- : parent_(parent) {}
-
- virtual ~NestedExtensionIdRegistryImpl() {}
+// Parses an optional Substrait enum option string into an Enum value.
+template <typename Enum>
+using EnumParser = std::function<Result<Enum>(util::optional<util::string_view>)>;
- std::vector<util::string_view> Uris() const override {
- std::vector<util::string_view> uris = parent_->Uris();
- std::unordered_set<util::string_view> uri_set;
- uri_set.insert(uris.begin(), uris.end());
- uri_set.insert(uris_.begin(), uris_.end());
- return std::vector<util::string_view>(uris);
+// Builds an EnumParser mapping options[i] to enumerator (i + 1). A missing value
+// parses as enumerator 0 (expected to be kUnspecified); an unknown string is Invalid.
+template <typename Enum>
+EnumParser<Enum> GetEnumParser(const std::vector<std::string>& options) {
+ std::unordered_map<std::string, Enum> parse_map;
+ for (std::size_t i = 0; i < options.size(); i++) {
+ parse_map[options[i]] = static_cast<Enum>(i + 1);
}
-
- util::optional<TypeRecord> GetType(const DataType& type) const override {
- auto type_opt = ExtensionIdRegistryImpl::GetType(type);
- if (type_opt) {
- return type_opt;
+ return [parse_map](util::optional<util::string_view> enum_val) -> Result<Enum> {
+ if (!enum_val) {
+ // Assumes 0 is always kUnspecified in Enum
+ return static_cast<Enum>(0);
}
- return parent_->GetType(type);
- }
-
- util::optional<TypeRecord> GetType(Id id) const override {
- auto type_opt = ExtensionIdRegistryImpl::GetType(id);
- if (type_opt) {
- return type_opt;
+ auto maybe_parsed = parse_map.find(enum_val->to_string());
+ if (maybe_parsed == parse_map.end()) {
+ return Status::Invalid("The value ", *enum_val, " is not an expected enum value");
}
- return parent_->GetType(id);
- }
+ return maybe_parsed->second;
+ };
+}
- Status CanRegisterType(Id id, const std::shared_ptr<DataType>& type) const override {
- return parent_->CanRegisterType(id, type) &
- ExtensionIdRegistryImpl::CanRegisterType(id, type);
- }
+// Enumerator order must match the option vectors: GetEnumParser maps options[i] to
+// enumerator i + 1 and reserves 0 for an unspecified option. The tables are const
+// because every converter reads them and none should ever modify them.
+enum class TemporalComponent { kUnspecified = 0, kYear, kMonth, kDay, kSecond };
+static const std::vector<std::string> kTemporalComponentOptions = {"YEAR", "MONTH",
+ "DAY", "SECOND"};
+static const EnumParser<TemporalComponent> kTemporalComponentParser =
+ GetEnumParser<TemporalComponent>(kTemporalComponentOptions);
+
+enum class OverflowBehavior { kUnspecified = 0, kSilent, kSaturate, kError };
+static const std::vector<std::string> kOverflowOptions = {"SILENT", "SATURATE", "ERROR"};
+static const EnumParser<OverflowBehavior> kOverflowParser =
+ GetEnumParser<OverflowBehavior>(kOverflowOptions);
+
+// Reads the enum argument at arg_index of `call` and parses it with `parser`.
+template <typename Enum>
+Result<Enum> ParseEnumArg(const SubstraitCall& call, uint32_t arg_index,
+ const EnumParser<Enum>& parser) {
+ ARROW_ASSIGN_OR_RAISE(util::optional<util::string_view> enum_arg,
+ call.GetEnumArg(arg_index));
+ return parser(enum_arg);
+}
- Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
- return parent_->CanRegisterType(id, type) &
- ExtensionIdRegistryImpl::RegisterType(id, type);
+// Collects the value arguments of `call` from `start_index` (inclusive) up to
+// call.size() (exclusive); callers use start_index to skip leading enum arguments.
+// Any error from SubstraitCall::GetValueArg is propagated. A negative start_index
+// is rejected instead of silently wrapping around when converted to uint32_t.
+Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
+ int start_index) {
+ if (start_index < 0) {
+ return Status::Invalid("Value argument start index must be non-negative but was ",
+ start_index);
+ }
+ const auto first = static_cast<uint32_t>(start_index);
+ std::vector<compute::Expression> expressions;
+ if (first < call.size()) {
+ expressions.reserve(call.size() - first);
+ }
+ for (uint32_t index = first; index < call.size(); index++) {
+ ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(index));
+ expressions.push_back(std::move(arg));
}
+ return std::move(expressions);
+}
- util::optional<FunctionRecord> GetFunction(
- util::string_view arrow_function_name) const override {
- auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name);
- if (func_opt) {
- return func_opt;
+// Builds a decoder for arithmetic calls whose argument 0 is the overflow enum and
+// whose remaining arguments are values. SILENT (or unspecified) calls
+// function_name; ERROR calls function_name + "_checked"; SATURATE is NotImplemented.
+ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic(
+ const std::string& function_name) {
+ return [function_name](const SubstraitCall& call) -> Result<compute::Expression> {
+ ARROW_ASSIGN_OR_RAISE(OverflowBehavior overflow_behavior,
+ ParseEnumArg(call, 0, kOverflowParser));
+ ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
+ GetValueArgs(call, 1));
+ if (overflow_behavior == OverflowBehavior::kUnspecified) {
+ overflow_behavior = OverflowBehavior::kSilent;
}
- return parent_->GetFunction(arrow_function_name);
- }
+ if (overflow_behavior == OverflowBehavior::kSilent) {
+ return arrow::compute::call(function_name, std::move(value_args));
+ } else if (overflow_behavior == OverflowBehavior::kError) {
+ return arrow::compute::call(function_name + "_checked", std::move(value_args));
+ } else {
+ return Status::NotImplemented(
+ "Only SILENT and ERROR arithmetic kernels are currently implemented but ",
+ kOverflowOptions[static_cast<int>(overflow_behavior) - 1], " was requested");
+ }
+ };
+}
- util::optional<FunctionRecord> GetFunction(Id id) const override {
- auto func_opt = ExtensionIdRegistryImpl::GetFunction(id);
- if (func_opt) {
- return func_opt;
+// Builds an encoder that emits a Substrait call to substrait_fn_id. Enum argument 0
+// is "ERROR" when kChecked and "SILENT" otherwise; the Arrow call's arguments follow
+// as value arguments 1..n.
+template <bool kChecked>
+ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic(
+ Id substrait_fn_id) {
+ return
+ [substrait_fn_id](const compute::Expression::Call& call) -> Result<SubstraitCall> {
+ // nullable=true isn't quite correct but we don't know the nullability of
+ // the inputs
+ SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
+ /*nullable=*/true);
+ if (kChecked) {
+ substrait_call.SetEnumArg(0, "ERROR");
+ } else {
+ substrait_call.SetEnumArg(0, "SILENT");
+ }
+ for (std::size_t i = 0; i < call.arguments.size(); i++) {
+ substrait_call.SetValueArg(static_cast<uint32_t>(i + 1), call.arguments[i]);
+ }
+ return std::move(substrait_call);
+ };
+}
+
+// Builds a decoder that forwards every value argument to `function_name`. Calls with
+// more than max_args arguments are rejected as NotImplemented.
+ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping(
+ const std::string& function_name, uint32_t max_args) {
+ return [function_name,
+ max_args](const SubstraitCall& call) -> Result<compute::Expression> {
+ if (call.size() > max_args) {
+ return Status::NotImplemented("Acero does not have a kernel for ", function_name,
+ " that receives ", call.size(), " arguments");
}
- return parent_->GetFunction(id);
- }
+ ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
+ GetValueArgs(call, 0));
+ return arrow::compute::call(function_name, std::move(value_args));
+ };
+}
- Status CanRegisterFunction(Id id,
- const std::string& arrow_function_name) const override {
- return parent_->CanRegisterFunction(id, arrow_function_name) &
- ExtensionIdRegistryImpl::CanRegisterFunction(id, arrow_function_name);
- }
+// Decodes Substrait extract(component, value...). The mandatory component enum
+// (YEAR/MONTH/DAY/SECOND) selects the Arrow year/month/day/second function that
+// receives the remaining value arguments.
+ExtensionIdRegistry::SubstraitCallToArrow DecodeTemporalExtractionMapping() {
+ return [](const SubstraitCall& call) -> Result<compute::Expression> {
+ ARROW_ASSIGN_OR_RAISE(TemporalComponent temporal_component,
+ ParseEnumArg(call, 0, kTemporalComponentParser));
+ if (temporal_component == TemporalComponent::kUnspecified) {
+ return Status::Invalid(
+ "The temporal component enum is a required option for the extract function "
+ "and is not specified");
+ }
+ ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
+ GetValueArgs(call, 1));
+ std::string func_name;
+ switch (temporal_component) {
+ case TemporalComponent::kYear:
+ func_name = "year";
+ break;
+ case TemporalComponent::kMonth:
+ func_name = "month";
+ break;
+ case TemporalComponent::kDay:
+ func_name = "day";
+ break;
+ case TemporalComponent::kSecond:
+ func_name = "second";
+ break;
+ default:
+ return Status::Invalid("Unexpected value for temporal component in extract call");
+ }
+ return compute::call(func_name, std::move(value_args));
+ };
+}
- Status RegisterFunction(Id id, std::string arrow_function_name) override {
- return parent_->CanRegisterFunction(id, arrow_function_name) &
- ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name);
- }
+// Decodes Substrait concat as binary_join_element_wise over all value arguments,
+// appending an empty-string literal as the final (separator) argument.
+ExtensionIdRegistry::SubstraitCallToArrow DecodeConcatMapping() {
+ return [](const SubstraitCall& call) -> Result<compute::Expression> {
+ ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
+ GetValueArgs(call, 0));
+ value_args.push_back(compute::literal(""));
+ return compute::call("binary_join_element_wise", std::move(value_args));
+ };
+}
- const ExtensionIdRegistry* parent_;
-};
+// Builds a decoder for unary aggregates whose single argument must be a direct
+// field reference. Grouped calls (is_hash()) use the "hash_"-prefixed Arrow function.
+ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
+ const std::string& arrow_function_name) {
+ return [arrow_function_name](const SubstraitCall& call) -> Result<compute::Aggregate> {
+ if (call.size() != 1) {
+ return Status::NotImplemented(
+ "Only unary aggregate functions are currently supported");
+ }
+ ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
+ const FieldRef* arg_ref = arg.field_ref();
+ if (!arg_ref) {
+ return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
+ call.id().name, " to have a direct reference");
+ }
+ std::string fixed_arrow_func = arrow_function_name;
+ if (call.is_hash()) {
+ fixed_arrow_func = "hash_" + arrow_function_name;
+ }
+ return compute::Aggregate{std::move(fixed_arrow_func), nullptr, *arg_ref, ""};
+ };
+}
struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
DefaultExtensionIdRegistry() {
+ // ----------- Extension Types ----------------------------
struct TypeName {
std::shared_ptr<DataType> type;
util::string_view name;
@@ -428,32 +795,91 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
}
- for (TypeName e : {
- TypeName{null(), "null"},
- TypeName{month_interval(), "interval_month"},
- TypeName{day_time_interval(), "interval_day_milli"},
- TypeName{month_day_nano_interval(), "interval_month_day_nano"},
- }) {
+ for (TypeName e :
+ {TypeName{null(), "null"}, TypeName{month_interval(), "interval_month"},
+ TypeName{day_time_interval(), "interval_day_milli"},
+ TypeName{month_day_nano_interval(), "interval_month_day_nano"}}) {
DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
}
- // TODO: this is just a placeholder right now. We'll need a YAML file for
- // all functions (and prototypes) that Arrow provides that are relevant
- // for Substrait, and include mappings for all of them here. See
- // ARROW-15535.
- for (util::string_view name : {
- "add",
- "equal",
- "is_not_distinct_from",
- "hash_count",
- }) {
- DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
+ // -------------- Substrait -> Arrow Functions -----------------
+ // Mappings with a _checked variant
+ for (const auto& function_name : {"add", "subtract", "multiply", "divide"}) {
+ DCHECK_OK(
+ AddSubstraitCallToArrow({kSubstraitArithmeticFunctionsUri, function_name},
+ DecodeOptionlessOverflowableArithmetic(function_name)));
+ }
+ // Basic mappings that need _kleene appended to them
+ for (const auto& function_name : {"or", "and"}) {
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {kSubstraitBooleanFunctionsUri, function_name},
+ DecodeOptionlessBasicMapping(std::string(function_name) + "_kleene",
+ /*max_args=*/2)));
+ }
+ // Basic binary mappings
+ for (const auto& function_name :
+ std::vector<std::pair<util::string_view, util::string_view>>{
+ {kSubstraitBooleanFunctionsUri, "xor"},
+ {kSubstraitComparisonFunctionsUri, "equal"},
+ {kSubstraitComparisonFunctionsUri, "not_equal"}}) {
+ DCHECK_OK(
+ AddSubstraitCallToArrow({function_name.first, function_name.second},
+ DecodeOptionlessBasicMapping(
+ function_name.second.to_string(), /*max_args=*/2)));
+ }
+ for (const auto& uri :
+ {kSubstraitComparisonFunctionsUri, kSubstraitDatetimeFunctionsUri}) {
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {uri, "lt"}, DecodeOptionlessBasicMapping("less", /*max_args=*/2)));
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {uri, "lte"}, DecodeOptionlessBasicMapping("less_equal", /*max_args=*/2)));
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {uri, "gt"}, DecodeOptionlessBasicMapping("greater", /*max_args=*/2)));
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {uri, "gte"}, DecodeOptionlessBasicMapping("greater_equal", /*max_args=*/2)));
+ }
+ // One-off mappings
+ DCHECK_OK(
+ AddSubstraitCallToArrow({kSubstraitBooleanFunctionsUri, "not"},
+ DecodeOptionlessBasicMapping("invert", /*max_args=*/1)));
+ DCHECK_OK(AddSubstraitCallToArrow({kSubstraitDatetimeFunctionsUri, "extract"},
+ DecodeTemporalExtractionMapping()));
+ DCHECK_OK(AddSubstraitCallToArrow({kSubstraitStringFunctionsUri, "concat"},
+ DecodeConcatMapping()));
+
+ // --------------- Substrait -> Arrow Aggregates --------------
+ for (const auto& fn_name : {"sum", "min", "max"}) {
+ DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, fn_name},
+ DecodeBasicAggregate(fn_name)));
+ }
+ DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, "avg"},
+ DecodeBasicAggregate("mean")));
+
+ // --------------- Arrow -> Substrait Functions ---------------
+ for (const auto& fn_name : {"add", "subtract", "multiply", "divide"}) {
+ Id fn_id{kSubstraitArithmeticFunctionsUri, fn_name};
+ DCHECK_OK(AddArrowToSubstraitCall(
+ fn_name, EncodeOptionlessOverflowableArithmetic<false>(fn_id)));
+ DCHECK_OK(
+ AddArrowToSubstraitCall(std::string(fn_name) + "_checked",
+ EncodeOptionlessOverflowableArithmetic<true>(fn_id)));
}
}
};
} // namespace
+// Convenience overload: registers a converter that passes every value argument of
+// the Substrait call straight to the Arrow function `arrow_function_name`.
+Status ExtensionIdRegistryImpl::AddSubstraitCallToArrow(Id substrait_function_id,
+ std::string arrow_function_name) {
+ return AddSubstraitCallToArrow(
+ substrait_function_id,
+ [arrow_function_name](const SubstraitCall& call) -> Result<compute::Expression> {
+ ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
+ GetValueArgs(call, 0));
+ return compute::call(arrow_function_name, std::move(value_args));
+ });
+}
+
ExtensionIdRegistry* default_extension_id_registry() {
static DefaultExtensionIdRegistry impl_;
return &impl_;
@@ -461,7 +887,7 @@ ExtensionIdRegistry* default_extension_id_registry() {
+// Creates a registry whose lookups fall back to, and whose registrations are
+// checked against, `parent` (held as a non-owning pointer).
std::shared_ptr<ExtensionIdRegistry> nested_extension_id_registry(
const ExtensionIdRegistry* parent) {
- return std::make_shared<NestedExtensionIdRegistryImpl>(parent);
+ return std::make_shared<ExtensionIdRegistryImpl>(parent);
}
} // namespace engine
diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h
index 04e4586a9f..9cb42f6613 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.h
+++ b/cpp/src/arrow/engine/substrait/extension_set.h
@@ -19,26 +19,130 @@
#pragma once
+#include <list>
#include <unordered_map>
+#include <unordered_set>
#include <vector>
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/expression.h"
#include "arrow/engine/substrait/visibility.h"
+#include "arrow/result.h"
#include "arrow/type_fwd.h"
+#include "arrow/util/hash_util.h"
+#include "arrow/util/hashing.h"
#include "arrow/util/optional.h"
#include "arrow/util/string_view.h"
-#include "arrow/util/hash_util.h"
-
namespace arrow {
namespace engine {
+/// URIs of the standard Substrait function extension YAML files whose functions
+/// the default extension id registry maps to Arrow.
+constexpr const char* kSubstraitArithmeticFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_arithmetic.yaml";
+constexpr const char* kSubstraitBooleanFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_boolean.yaml";
+constexpr const char* kSubstraitComparisonFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_comparison.yaml";
+constexpr const char* kSubstraitDatetimeFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_datetime.yaml";
+constexpr const char* kSubstraitStringFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_string.yaml";
+
+/// \brief A Substrait (uri, name) identifier. Both fields are non-owning views, so
+/// the referenced strings must outlive the Id (IdStorage can provide owned copies).
+struct Id {
+ util::string_view uri, name;
+ /// True only when both uri and name are empty
+ bool empty() const { return uri.empty() && name.empty(); }
+ std::string ToString() const;
+};
+/// \brief Hash and equality functor for Id, usable as both the Hash and KeyEqual
+/// parameters of unordered containers
+struct IdHashEq {
+ size_t operator()(Id id) const;
+ bool operator()(Id l, Id r) const;
+};
+
+/// \brief Owning storage for ids
+///
+/// Substrait plans may reuse URIs and names in many places. For convenience
+/// and performance Substrait ids are typically passed around as views. As we
+/// convert a plan from Substrait to Arrow we need to copy these strings out of
+/// the Substrait buffer and into owned storage. This class serves as that owned
+/// storage.
+class IdStorage {
+ public:
+ /// \brief Get an equivalent id pointing into this storage
+ ///
+ /// This operation will copy the ids into storage if they do not already exist
+ Id Emplace(Id id);
+ /// \brief Get an equivalent view pointing into this storage for a URI
+ ///
+ /// If no URI is found then the uri will be copied into storage
+ util::string_view EmplaceUri(util::string_view uri);
+ /// \brief Get an equivalent id pointing into this storage
+ ///
+ /// If no id is found then nullopt will be returned
+ util::optional<Id> Find(Id id) const;
+ /// \brief Get an equivalent view pointing into this storage for a URI
+ ///
+ /// If no URI is found then nullopt will be returned
+ util::optional<util::string_view> FindUri(util::string_view uri) const;
+
+ private:
+ // NOTE(review): presumably uris_/names_ hold views into owned_uris_/owned_names_;
+ // std::list never relocates its elements, so such views stay valid on insertion.
+ std::unordered_set<util::string_view, ::arrow::internal::StringViewHash> uris_;
+ std::unordered_set<util::string_view, ::arrow::internal::StringViewHash> names_;
+ std::list<std::string> owned_uris_;
+ std::list<std::string> owned_names_;
+};
+
+/// \brief Describes a Substrait call
+///
+/// Substrait call expressions contain a list of arguments which can either
+/// be enum arguments (which are serialized as strings), value arguments (which
+/// are Arrow expressions), or type arguments (not yet implemented).
+///
+/// The Id is stored as non-owning views, so its strings must outlive the call.
+class SubstraitCall {
+ public:
+ SubstraitCall(Id id, std::shared_ptr<DataType> output_type, bool output_nullable,
+ bool is_hash = false)
+ : id_(id),
+ output_type_(std::move(output_type)),
+ output_nullable_(output_nullable),
+ is_hash_(is_hash) {}
+
+ const Id& id() const { return id_; }
+ const std::shared_ptr<DataType>& output_type() const { return output_type_; }
+ bool output_nullable() const { return output_nullable_; }
+ bool is_hash() const { return is_hash_; }
+
+ bool HasEnumArg(uint32_t index) const;
+ Result<util::optional<util::string_view>> GetEnumArg(uint32_t index) const;
+ void SetEnumArg(uint32_t index, util::optional<std::string> enum_arg);
+ Result<compute::Expression> GetValueArg(uint32_t index) const;
+ bool HasValueArg(uint32_t index) const;
+ void SetValueArg(uint32_t index, compute::Expression value_arg);
+ // NOTE(review): presumably one past the highest index passed to SetEnumArg or
+ // SetValueArg (setter definitions are not shown here) -- confirm.
+ uint32_t size() const { return size_; }
+
+ private:
+ Id id_;
+ std::shared_ptr<DataType> output_type_;
+ bool output_nullable_;
+ // Only needed when converting from Substrait -> Arrow aggregates. The
+ // Arrow function name depends on whether or not there are any groups
+ bool is_hash_;
+ std::unordered_map<uint32_t, util::optional<std::string>> enum_args_;
+ std::unordered_map<uint32_t, compute::Expression> value_args_;
+ uint32_t size_ = 0;
+};
+
/// Substrait identifies functions and custom data types using a (uri, name) pair.
///
-/// This registry is a bidirectional mapping between Substrait IDs and their corresponding
-/// Arrow counterparts (arrow::DataType and function names in a function registry)
+/// This registry is a bidirectional mapping between Substrait IDs and their
+/// corresponding Arrow counterparts (arrow::DataType and function names in a function
+/// registry)
///
-/// Substrait extension types and variations must be registered with their corresponding
-/// arrow::DataType before they can be used!
+/// Substrait extension types and variations must be registered with their
+/// corresponding arrow::DataType before they can be used!
///
/// Conceptually this can be thought of as two pairs of `unordered_map`s. One pair to
/// go back and forth between Substrait ID and arrow::DataType and another pair to go
@@ -49,56 +153,103 @@ namespace engine {
/// instance).
class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
public:
- /// All uris registered in this ExtensionIdRegistry
- virtual std::vector<util::string_view> Uris() const = 0;
-
- struct Id {
- util::string_view uri, name;
-
- bool empty() const { return uri.empty() && name.empty(); }
- };
-
- struct IdHashEq {
- size_t operator()(Id id) const;
- bool operator()(Id l, Id r) const;
- };
+ using ArrowToSubstraitCall =
+ std::function<Result<SubstraitCall>(const arrow::compute::Expression::Call&)>;
+ using SubstraitCallToArrow =
+ std::function<Result<arrow::compute::Expression>(const SubstraitCall&)>;
+ using ArrowToSubstraitAggregate =
+ std::function<Result<SubstraitCall>(const arrow::compute::Aggregate&)>;
+ using SubstraitAggregateToArrow =
+ std::function<Result<arrow::compute::Aggregate>(const SubstraitCall&)>;
/// \brief A mapping between a Substrait ID and an arrow::DataType
struct TypeRecord {
Id id;
const std::shared_ptr<DataType>& type;
};
+
+ /// \brief Return a uri view owned by this registry
+ ///
+ /// If the URI has never been emplaced it will return nullopt
+ virtual util::optional<util::string_view> FindUri(util::string_view uri) const = 0;
+ /// \brief Return an id view owned by this registry
+ ///
+ /// If the id has never been emplaced it will return nullopt
+ virtual util::optional<Id> FindId(Id id) const = 0;
virtual util::optional<TypeRecord> GetType(const DataType&) const = 0;
virtual util::optional<TypeRecord> GetType(Id) const = 0;
virtual Status CanRegisterType(Id, const std::shared_ptr<DataType>& type) const = 0;
virtual Status RegisterType(Id, std::shared_ptr<DataType>) = 0;
+ /// \brief Register a converter that converts an Arrow call to a Substrait call
+ ///
+ /// Note that there may not be 1:1 parity between ArrowToSubstraitCall and
+ /// SubstraitCallToArrow because some standard functions (e.g. add) may map to
+ /// multiple Arrow functions (e.g. add, add_checked)
+ virtual Status AddArrowToSubstraitCall(std::string arrow_function_name,
+ ArrowToSubstraitCall conversion_func) = 0;
+ /// \brief Check to see if a converter can be registered
+ ///
+ /// \return Status::OK if there are no conflicts, otherwise an error is returned
+ virtual Status CanAddArrowToSubstraitCall(
+ const std::string& arrow_function_name) const = 0;
- /// \brief A mapping between a Substrait ID and an Arrow function
+ /// \brief Register a converter that converts an Arrow aggregate to a Substrait
+ /// aggregate
+ virtual Status AddArrowToSubstraitAggregate(
+ std::string arrow_function_name, ArrowToSubstraitAggregate conversion_func) = 0;
+ /// \brief Check to see if a converter can be registered
///
- /// Note: At the moment we identify functions solely by the name
- /// of the function in the function registry.
+ /// \return Status::OK if there are no conflicts, otherwise an error is returned
+ virtual Status CanAddArrowToSubstraitAggregate(
+ const std::string& arrow_function_name) const = 0;
+
+ /// \brief Register a converter that converts a Substrait call to an Arrow call
+ virtual Status AddSubstraitCallToArrow(Id substrait_function_id,
+ SubstraitCallToArrow conversion_func) = 0;
+ /// \brief Check to see if a converter can be registered
///
- /// TODO(ARROW-15582) some functions will not be simple enough to convert without access
- /// to their arguments/options. For example is_in embeds the set in options rather than
- /// using an argument:
- /// is_in(x, SetLookupOptions(set)) <-> (k...Uri, "is_in")(x, set)
+ /// \return Status::OK if there are no conflicts, otherwise an error is returned
+ virtual Status CanAddSubstraitCallToArrow(Id substrait_function_id) const = 0;
+ /// \brief Register a simple mapping function
///
- /// ... for another example, depending on the value of the first argument to
- /// substrait::add it either corresponds to arrow::add or arrow::add_checked
- struct FunctionRecord {
- Id id;
- const std::string& function_name;
- };
- virtual util::optional<FunctionRecord> GetFunction(Id) const = 0;
- virtual util::optional<FunctionRecord> GetFunction(
- util::string_view arrow_function_name) const = 0;
- virtual Status CanRegisterFunction(Id,
- const std::string& arrow_function_name) const = 0;
- // registers a function without taking ownership of uri and name within Id
- virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0;
- // registers a function while taking ownership of uri and name
- virtual Status RegisterFunction(std::string uri, std::string name,
- std::string arrow_function_name) = 0;
+ /// All calls to the function must pass only value arguments. The arguments
+ /// will be converted to expressions and passed to the Arrow function
+ virtual Status AddSubstraitCallToArrow(Id substrait_function_id,
+ std::string arrow_function_name) = 0;
+
+ /// \brief Register a converter that converts a Substrait aggregate to an Arrow
+ /// aggregate
+ virtual Status AddSubstraitAggregateToArrow(
+ Id substrait_function_id, SubstraitAggregateToArrow conversion_func) = 0;
+ /// \brief Check to see if a converter can be registered
+ ///
+ /// \return Status::OK if there are no conflicts, otherwise an error is returned
+ virtual Status CanAddSubstraitAggregateToArrow(Id substrait_function_id) const = 0;
+
+ /// \brief Return a list of Substrait functions that have a converter
+ ///
+ /// The function ids are encoded as strings using the pattern {uri}#{name}
+ virtual std::vector<std::string> GetSupportedSubstraitFunctions() const = 0;
+
+ /// \brief Find a converter to map Arrow calls to Substrait calls
+ /// \return A converter function or an invalid status if no converter is registered
+ virtual Result<ArrowToSubstraitCall> GetArrowToSubstraitCall(
+ const std::string& arrow_function_name) const = 0;
+
+ /// \brief Find a converter to map Arrow aggregates to Substrait aggregates
+ /// \return A converter function or an invalid status if no converter is registered
+ virtual Result<ArrowToSubstraitAggregate> GetArrowToSubstraitAggregate(
+ const std::string& arrow_function_name) const = 0;
+
+ /// \brief Find a converter to map a Substrait aggregate to an Arrow aggregate
+ /// \return A converter function or an invalid status if no converter is registered
+ virtual Result<SubstraitAggregateToArrow> GetSubstraitAggregateToArrow(
+ Id substrait_function_id) const = 0;
+
+ /// \brief Find a converter to map a Substrait call to an Arrow call
+ /// \return A converter function or an invalid status if no converter is registered
+ virtual Result<SubstraitCallToArrow> GetSubstraitCallToArrow(
+ Id substrait_function_id) const = 0;
};
constexpr util::string_view kArrowExtTypesUri =
@@ -153,9 +304,6 @@ ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionIdRegistry> nested_extension_id_reg
/// ExtensionIdRegistry.
class ARROW_ENGINE_EXPORT ExtensionSet {
public:
- using Id = ExtensionIdRegistry::Id;
- using IdHashEq = ExtensionIdRegistry::IdHashEq;
-
struct FunctionRecord {
Id id;
util::string_view name;
@@ -219,12 +367,12 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
/// \return An anchor that can be used to refer to the type within a plan
Result<uint32_t> EncodeType(const DataType& type);
- /// \brief Returns a function given an anchor
+ /// \brief Return a function id given an anchor
///
/// This is used when converting a Substrait plan to an Arrow execution plan.
///
/// If the anchor does not exist in this extension set an error will be returned.
- Result<FunctionRecord> DecodeFunction(uint32_t anchor) const;
+ Result<Id> DecodeFunction(uint32_t anchor) const;
/// \brief Lookup the anchor for a given function
///
@@ -239,26 +387,30 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
/// returned.
///
/// \return An anchor that can be used to refer to the function within a plan
- Result<uint32_t> EncodeFunction(util::string_view function_name);
+ Result<uint32_t> EncodeFunction(Id function_id);
- /// \brief Returns the number of custom functions in this extension set
- ///
- /// Note: the functions are currently stored as a sparse vector, so this may return a
- /// value larger than the actual number of functions. This behavior may change in the
- /// future; see ARROW-15583.
+ /// \brief Return the number of custom functions in this extension set
std::size_t num_functions() const { return functions_.size(); }
+ const ExtensionIdRegistry* registry() const { return registry_; }
+
private:
const ExtensionIdRegistry* registry_;
+ // If the registry is not aware of an id then we probably can't do anything
+ // with it. However, in some cases, these may represent extensions or features
+ // that we can safely ignore. For example, we can usually safely ignore
+ // extension type variations if we assume the plan is valid. These ignorable
+ // ids are stored here.
+ IdStorage plan_specific_ids_;
// Map from anchor values to URI values referenced by this extension set
std::unordered_map<uint32_t, util::string_view> uris_;
// Map from anchor values to type definitions, used during Substrait->Arrow
// and populated from the Substrait extension set
std::unordered_map<uint32_t, TypeRecord> types_;
- // Map from anchor values to function definitions, used during Substrait->Arrow
+ // Map from anchor values to function ids, used during Substrait->Arrow
// and populated from the Substrait extension set
- std::unordered_map<uint32_t, FunctionRecord> functions_;
+ std::unordered_map<uint32_t, Id> functions_;
// Map from type names to anchor values. Used during Arrow->Substrait
// and built as the plan is created.
std::unordered_map<Id, uint32_t, IdHashEq, IdHashEq> types_map_;
diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc
new file mode 100644
index 0000000000..225bc56d13
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/function_test.cc
@@ -0,0 +1,495 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/substrait/plan_internal.h"
+#include "arrow/engine/substrait/serde.h"
+#include "arrow/engine/substrait/test_plan_builder.h"
+#include "arrow/engine/substrait/type_internal.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+namespace engine {
+struct FunctionTestCase {
+ Id function_id;
+ std::vector<std::string> arguments;  // one string per call argument
+ std::vector<std::shared_ptr<DataType>> data_types;  // a null type means no input column
+ // An empty string is decoded as null; error test cases also use it as a placeholder
+ std::string expected_output;
+ std::shared_ptr<DataType> expected_output_type;
+};
+
+Result<std::shared_ptr<Array>> GetArray(const std::string& value,
+ const std::shared_ptr<DataType>& data_type) {
+ StringBuilder str_builder;
+ if (value.empty()) {  // "" encodes a null value
+ ARROW_RETURN_NOT_OK(str_builder.AppendNull());
+ } else {
+ ARROW_RETURN_NOT_OK(str_builder.Append(value));
+ }
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> value_str, str_builder.Finish());
+ ARROW_ASSIGN_OR_RAISE(Datum value_datum, compute::Cast(value_str, data_type));
+ return value_datum.make_array();
+}
+
+Result<std::shared_ptr<Table>> GetInputTable(
+ const std::vector<std::string>& arguments,
+ const std::vector<std::shared_ptr<DataType>>& data_types) {
+ std::vector<std::shared_ptr<Array>> columns;
+ std::vector<std::shared_ptr<Field>> fields;
+ if (arguments.size() != data_types.size()) return Status::Invalid("need one data type per argument");
+ for (std::size_t i = 0; i < arguments.size(); i++) {
+ if (data_types[i]) {  // arguments without a data type (e.g. enum options) get no column
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> arg_array,
+ GetArray(arguments[i], data_types[i]));
+ columns.push_back(std::move(arg_array));
+ fields.push_back(field("arg_" + std::to_string(i), data_types[i]));
+ }
+ }
+ std::shared_ptr<RecordBatch> batch =
+ RecordBatch::Make(schema(std::move(fields)), 1, columns);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> table, Table::FromRecordBatches({batch}));
+ return table;
+}
+
+Result<std::shared_ptr<Table>> GetOutputTable(
+ const std::string& output_value, const std::shared_ptr<DataType>& output_type) {
+ // A single column named "output" holding one row
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> out_array,
+ GetArray(output_value, output_type));
+ std::shared_ptr<Schema> out_schema = schema({field("output", output_type)});
+ std::shared_ptr<RecordBatch> out_batch =
+ RecordBatch::Make(std::move(out_schema), /*num_rows=*/1, {out_array});
+ ARROW_ASSIGN_OR_RAISE(auto out_table, Table::FromRecordBatches({out_batch}));
+ return out_table;
+}
+
+Result<std::shared_ptr<compute::ExecPlan>> PlanFromTestCase(
+ const FunctionTestCase& test_case, std::shared_ptr<Table>* output_table) {
+ // The plan's named-table read is served from this in-memory table
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> source_table,
+ GetInputTable(test_case.arguments, test_case.data_types));
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<Buffer> plan_buf,
+ internal::CreateScanProjectSubstrait(test_case.function_id, source_table,
+ test_case.arguments, test_case.data_types,
+ *test_case.expected_output_type));
+ std::shared_ptr<compute::SinkNodeConsumer> sink_consumer =
+ std::make_shared<compute::TableSinkNodeConsumer>(output_table,
+ default_memory_pool());
+
+ // Mock provider: whatever table name is requested, serve source_table
+ ConversionOptions conv_opts;
+ conv_opts.named_table_provider = [source_table](const std::vector<std::string>&) {
+ std::shared_ptr<compute::ExecNodeOptions> source_opts =
+ std::make_shared<compute::TableSourceNodeOptions>(source_table);
+ return compute::Declaration("table_source", {}, source_opts, "mock_source");
+ };
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr<compute::ExecPlan> plan,
+ DeserializePlan(*plan_buf, std::move(sink_consumer),
+ default_extension_id_registry(), /*ext_set_out=*/nullptr, conv_opts));
+ return plan;
+}
+
+void CheckValidTestCases(const std::vector<FunctionTestCase>& valid_cases) {
+ for (const auto& tc : valid_cases) {
+ // Run the plan to completion; the table sink fills `actual`
+ std::shared_ptr<Table> actual;
+ ASSERT_OK_AND_ASSIGN(auto exec_plan, PlanFromTestCase(tc, &actual));
+ ASSERT_OK(exec_plan->StartProducing());
+ ASSERT_FINISHES_OK(exec_plan->finished());
+
+ // Only the last column is checked; an emit in the Substrait plan could drop
+ // the leading columns instead
+ const int result_col = actual->num_columns() - 1;
+ ASSERT_OK_AND_ASSIGN(actual, actual->SelectColumns({result_col}));
+
+ ASSERT_OK_AND_ASSIGN(auto expected,
+ GetOutputTable(tc.expected_output, tc.expected_output_type));
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+ }
+}
+
+void CheckErrorTestCases(const std::vector<FunctionTestCase>& error_cases) {
+ for (const auto& tc : error_cases) {
+ // Planning and start-up must succeed; the plan must then finish with Invalid
+ std::shared_ptr<Table> unused_output;
+ ASSERT_OK_AND_ASSIGN(auto exec_plan, PlanFromTestCase(tc, &unused_output));
+ ASSERT_OK(exec_plan->StartProducing());
+ ASSERT_FINISHES_AND_RAISES(Invalid, exec_plan->finished());
+ }
+}
+
+// These are not meant to be an exhaustive test of Substrait
+// conformance. Instead, we should test just enough to ensure
+// we are mapping to the correct function
+TEST(FunctionMapping, ValidCases) {
+ const std::vector<FunctionTestCase> valid_test_cases = {
+ {{kSubstraitArithmeticFunctionsUri, "add"},
+ {"SILENT", "127", "10"},  // arg 0 is an enum option with no column (cf. "ERROR" in ErrorCases)
+ {nullptr, int8(), int8()},
+ "-119",  // 127 + 10 wraps around in int8
+ int8()},
+ {{kSubstraitArithmeticFunctionsUri, "subtract"},
+ {"SILENT", "-119", "10"},
+ {nullptr, int8(), int8()},
+ "127",  // -119 - 10 wraps around in int8
+ int8()},
+ {{kSubstraitArithmeticFunctionsUri, "multiply"},
+ {"SILENT", "10", "13"},
+ {nullptr, int8(), int8()},
+ "-126",  // 10 * 13 = 130 wraps around in int8
+ int8()},
+ {{kSubstraitArithmeticFunctionsUri, "divide"},
+ {"SILENT", "-128", "-1"},
+ {nullptr, int8(), int8()},
+ "0",  // -128 / -1 overflows int8; the expected result is 0
+ int8()},
+ {{kSubstraitBooleanFunctionsUri, "or"},
+ {"1", ""},  // "" decodes to null
+ {boolean(), boolean()},
+ "1",  // true OR null is true
+ boolean()},
+ {{kSubstraitBooleanFunctionsUri, "and"},
+ {"1", ""},
+ {boolean(), boolean()},
+ "",  // true AND null is null
+ boolean()},
+ {{kSubstraitBooleanFunctionsUri, "xor"},
+ {"1", "1"},
+ {boolean(), boolean()},
+ "0",
+ boolean()},
+ {{kSubstraitBooleanFunctionsUri, "not"}, {"1"}, {boolean()}, "0", boolean()},
+ {{kSubstraitComparisonFunctionsUri, "equal"},
+ {"57", "57"},
+ {int8(), int8()},
+ "1",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "not_equal"},
+ {"57", "57"},
+ {int8(), int8()},
+ "0",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "lt"},
+ {"57", "80"},
+ {int8(), int8()},
+ "1",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "lt"},
+ {"57", "57"},
+ {int8(), int8()},
+ "0",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "gt"},
+ {"57", "30"},
+ {int8(), int8()},
+ "1",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "gt"},
+ {"57", "57"},
+ {int8(), int8()},
+ "0",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "lte"},
+ {"57", "57"},
+ {int8(), int8()},
+ "1",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "lte"},
+ {"50", "57"},
+ {int8(), int8()},
+ "1",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "gte"},
+ {"57", "57"},
+ {int8(), int8()},
+ "1",
+ boolean()},
+ {{kSubstraitComparisonFunctionsUri, "gte"},
+ {"60", "57"},
+ {int8(), int8()},
+ "1",
+ boolean()},
+ {{kSubstraitDatetimeFunctionsUri, "extract"},
+ {"YEAR", "2022-07-15T14:33:14"},
+ {nullptr, timestamp(TimeUnit::MICRO)},
+ "2022",
+ int64()},
+ {{kSubstraitDatetimeFunctionsUri, "extract"},
+ {"MONTH", "2022-07-15T14:33:14"},
+ {nullptr, timestamp(TimeUnit::MICRO)},
+ "7",
+ int64()},
+ {{kSubstraitDatetimeFunctionsUri, "extract"},
+ {"DAY", "2022-07-15T14:33:14"},
+ {nullptr, timestamp(TimeUnit::MICRO)},
+ "15",
+ int64()},
+ {{kSubstraitDatetimeFunctionsUri, "extract"},
+ {"SECOND", "2022-07-15T14:33:14"},
+ {nullptr, timestamp(TimeUnit::MICRO)},
+ "14",
+ int64()},
+ {{kSubstraitDatetimeFunctionsUri, "extract"},
+ {"YEAR", "2022-07-15T14:33:14Z"},
+ {nullptr, timestamp(TimeUnit::MICRO, "UTC")},
+ "2022",
+ int64()},
+ {{kSubstraitDatetimeFunctionsUri, "extract"},
+ {"MONTH", "2022-07-15T14:33:14Z"},
+ {nullptr, timestamp(TimeUnit::MICRO, "UTC")},
+ "7",
+ int64()},
+ {{kSubstraitDatetimeFunctionsUri, "extract"},
+ {"DAY", "2022-07-15T14:33:14Z"},
+ {nullptr, timestamp(TimeUnit::MICRO, "UTC")},
+ "15",
+ int64()},
+ {{kSubstraitDatetimeFunctionsUri, "extract"},
+ {"SECOND", "2022-07-15T14:33:14Z"},
+ {nullptr, timestamp(TimeUnit::MICRO, "UTC")},
+ "14",
+ int64()},
+ {{kSubstraitDatetimeFunctionsUri, "lt"},
+ {"2022-07-15T14:33:14", "2022-07-15T14:33:20"},
+ {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)},
+ "1",
+ boolean()},
+ {{kSubstraitDatetimeFunctionsUri, "lte"},
+ {"2022-07-15T14:33:14", "2022-07-15T14:33:14"},
+ {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)},
+ "1",
+ boolean()},
+ {{kSubstraitDatetimeFunctionsUri, "gt"},
+ {"2022-07-15T14:33:30", "2022-07-15T14:33:14"},
+ {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)},
+ "1",
+ boolean()},
+ {{kSubstraitDatetimeFunctionsUri, "gte"},
+ {"2022-07-15T14:33:14", "2022-07-15T14:33:14"},
+ {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)},
+ "1",
+ boolean()},
+ {{kSubstraitStringFunctionsUri, "concat"},
+ {"abc", "def"},
+ {utf8(), utf8()},
+ "abcdef",
+ utf8()}};
+ CheckValidTestCases(valid_test_cases);
+}
+
+TEST(FunctionMapping, ErrorCases) {
+ const std::vector<FunctionTestCase> error_test_cases = {  // ValidCases overflow inputs, "ERROR" option
+ {{kSubstraitArithmeticFunctionsUri, "add"},
+ {"ERROR", "127", "10"},
+ {nullptr, int8(), int8()},
+ "",  // placeholder: the plan is expected to finish with Invalid
+ int8()},
+ {{kSubstraitArithmeticFunctionsUri, "subtract"},
+ {"ERROR", "-119", "10"},
+ {nullptr, int8(), int8()},
+ "",
+ int8()},
+ {{kSubstraitArithmeticFunctionsUri, "multiply"},
+ {"ERROR", "10", "13"},
+ {nullptr, int8(), int8()},
+ "",
+ int8()},
+ {{kSubstraitArithmeticFunctionsUri, "divide"},
+ {"ERROR", "-128", "-1"},
+ {nullptr, int8(), int8()},
+ "",
+ int8()}};
+ CheckErrorTestCases(error_test_cases);
+}
+
+// For each aggregate test case we take in three values. We compute the
+// aggregate both on the entire set (all three values) and on groups. The
+// first two rows will be in the first group and the last row will be in the
+// second group. It's important to test both for coverage since the Arrow
+// function used actually changes when group ids are present.
+struct AggregateTestCase {
+ // The Substrait function id
+ Id function_id;
+ // The three values, as a JSON array string (e.g. "[1, 2, 3]")
+ std::string arguments;
+ // The data type of the three values
+ std::shared_ptr<DataType> data_type;
+ // The result of the aggregate on all three, as a JSON array of size 1
+ std::string combined_output;
+ // The result of the aggregate on each group (i.e. the first two rows
+ // and the last row). Should be a JSON-encoded array of size 2
+ std::string group_outputs;
+ // The data type of the outputs (both combined_output and group_outputs)
+ std::shared_ptr<DataType> output_type;
+};
+
+std::shared_ptr<Table> GetInputTableForAggregateCase(const AggregateTestCase& test_case) {
+ // Group keys: rows 0 and 1 share key 1, row 2 has key 2
+ std::shared_ptr<Array> keys = ArrayFromJSON(int8(), "[1, 1, 2]");
+ std::shared_ptr<Array> values = ArrayFromJSON(test_case.data_type, test_case.arguments);
+ std::shared_ptr<Schema> input_schema =
+ schema({field("key", int8()), field("value", test_case.data_type)});
+ std::shared_ptr<RecordBatch> batch =
+ RecordBatch::Make(std::move(input_schema), /*num_rows=*/3, {keys, values});
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<Table> input_table,
+ Table::FromRecordBatches({batch}));
+ return input_table;
+}
+
+std::shared_ptr<Table> GetOutputTableForAggregateCase(
+ const std::shared_ptr<DataType>& output_type, const std::string& json_data) {
+ std::shared_ptr<Array> out_arr = ArrayFromJSON(output_type, json_data);
+ std::shared_ptr<RecordBatch> batch =  // num_rows must match the array length (1 or 2)
+ RecordBatch::Make(schema({field("", output_type)}), out_arr->length(), {out_arr});
+ EXPECT_OK_AND_ASSIGN(std::shared_ptr<Table> table, Table::FromRecordBatches({batch}));
+ return table;
+}
+
+std::shared_ptr<compute::ExecPlan> PlanFromAggregateCase(
+ const AggregateTestCase& test_case, std::shared_ptr<Table>* output_table,
+ bool with_keys) {
+ std::shared_ptr<Table> source_table = GetInputTableForAggregateCase(test_case);
+
+ // Column 0 holds the group keys and column 1 the values being aggregated
+ std::vector<int> key_idxs =
+ with_keys ? std::vector<int>{0} : std::vector<int>{};
+
+ EXPECT_OK_AND_ASSIGN(
+ std::shared_ptr<Buffer> plan_buf,
+ internal::CreateScanAggSubstrait(test_case.function_id, source_table, key_idxs,
+ /*arg_idx=*/1, *test_case.output_type));
+ std::shared_ptr<compute::SinkNodeConsumer> sink_consumer =
+ std::make_shared<compute::TableSinkNodeConsumer>(output_table,
+ default_memory_pool());
+
+ // Mock provider: whatever table name is requested, serve source_table
+ ConversionOptions conv_opts;
+ conv_opts.named_table_provider = [source_table](const std::vector<std::string>&) {
+ std::shared_ptr<compute::ExecNodeOptions> source_opts =
+ std::make_shared<compute::TableSourceNodeOptions>(source_table);
+ return compute::Declaration("table_source", {}, source_opts, "mock_source");
+ };
+
+ EXPECT_OK_AND_ASSIGN(
+ std::shared_ptr<compute::ExecPlan> exec_plan,
+ DeserializePlan(*plan_buf, std::move(sink_consumer),
+ default_extension_id_registry(), /*ext_set_out=*/nullptr,
+ conv_opts));
+ return exec_plan;
+}
+
+void CheckWholeAggregateCase(const AggregateTestCase& test_case) {
+ // Aggregate over all rows at once (no grouping keys)
+ std::shared_ptr<Table> actual;
+ auto exec_plan = PlanFromAggregateCase(test_case, &actual, /*with_keys=*/false);
+ ASSERT_OK(exec_plan->StartProducing());
+ ASSERT_FINISHES_OK(exec_plan->finished());
+
+ // Compare only the last output column
+ const int result_col = actual->num_columns() - 1;
+ ASSERT_OK_AND_ASSIGN(actual, actual->SelectColumns({result_col}));
+
+ auto expected =
+ GetOutputTableForAggregateCase(test_case.output_type, test_case.combined_output);
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+}
+
+void CheckGroupedAggregateCase(const AggregateTestCase& test_case) {
+ // Aggregate per group, using input column 0 as the grouping key
+ std::shared_ptr<Table> actual;
+ auto exec_plan = PlanFromAggregateCase(test_case, &actual, /*with_keys=*/true);
+
+ ASSERT_OK(exec_plan->StartProducing());
+ ASSERT_FINISHES_OK(exec_plan->finished());
+
+ // Group order out of the aggregate node is unpredictable, so order the rows by
+ // the key column (the last output column) before comparing
+ const int key_col = actual->num_columns() - 1;
+ compute::SortOptions by_key({compute::SortKey(key_col, compute::SortOrder::Ascending)});
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> order,
+ compute::SortIndices(actual, by_key));
+ ASSERT_OK_AND_ASSIGN(Datum reordered, compute::Take(actual, order));
+ actual = reordered.table();
+
+ // TODO(ARROW-17245) We should be selecting N-1 here but Acero
+ // currently emits things in reverse order
+ ASSERT_OK_AND_ASSIGN(actual, actual->SelectColumns({0}));
+
+ // Expect one row per group, in ascending key order
+ auto expected =
+ GetOutputTableForAggregateCase(test_case.output_type, test_case.group_outputs);
+ AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
+}
+
+void CheckAggregateCases(const std::vector<AggregateTestCase>& test_cases) {
+ for (std::size_t i = 0; i < test_cases.size(); ++i) {  // ungrouped, then grouped
+ CheckWholeAggregateCase(test_cases[i]);
+ CheckGroupedAggregateCase(test_cases[i]);
+ }
+}
+
+TEST(FunctionMapping, AggregateCases) {
+ const std::vector<AggregateTestCase> test_cases = {
+ {{kSubstraitArithmeticFunctionsUri, "sum"},
+ "[1, 2, 3]",
+ int8(),
+ "[6]",
+ "[3, 3]",  // groups {1, 2} and {3}
+ int64()},  // int8 input, int64 sum
+ {{kSubstraitArithmeticFunctionsUri, "min"},
+ "[1, 2, 3]",
+ int8(),
+ "[1]",
+ "[1, 3]",
+ int8()},
+ {{kSubstraitArithmeticFunctionsUri, "max"},
+ "[1, 2, 3]",
+ int8(),
+ "[3]",
+ "[2, 3]",
+ int8()},
+ {{kSubstraitArithmeticFunctionsUri, "avg"},
+ "[1, 2, 3]",
+ float64(),
+ "[2]",
+ "[1.5, 3]",  // (1 + 2) / 2 and 3
+ float64()}};
+ CheckAggregateCases(test_cases);
+}
+
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h
index dcb2088416..eace200f0a 100644
--- a/cpp/src/arrow/engine/substrait/options.h
+++ b/cpp/src/arrow/engine/substrait/options.h
@@ -54,11 +54,20 @@ enum class ConversionStrictness {
BEST_EFFORT,
};
+using NamedTableProvider =
+ std::function<Result<compute::Declaration>(const std::vector<std::string>&)>;
+static const NamedTableProvider kDefaultNamedTableProvider{};  // empty: named tables are rejected
+
/// Options that control the conversion between Substrait and Acero representations of a
/// plan.
struct ConversionOptions {
/// \brief How strictly the converter should adhere to the structure of the input.
ConversionStrictness strictness = ConversionStrictness::BEST_EFFORT;
+ /// \brief Callback that turns a named table's name parts into a source Declaration
+ ///
+ /// The default behavior will return an invalid status if the plan has any
+ /// named table relations.
+ NamedTableProvider named_table_provider = kDefaultNamedTableProvider;
};
} // namespace engine
diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc
index 2da037000c..b0fdb9bdc2 100644
--- a/cpp/src/arrow/engine/substrait/plan_internal.cc
+++ b/cpp/src/arrow/engine/substrait/plan_internal.cc
@@ -74,13 +74,12 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
}
for (uint32_t anchor = 0; anchor < ext_set.num_functions(); ++anchor) {
- ARROW_ASSIGN_OR_RAISE(auto function_record, ext_set.DecodeFunction(anchor));
- if (function_record.id.empty()) continue;
+ ARROW_ASSIGN_OR_RAISE(Id function_id, ext_set.DecodeFunction(anchor));
auto fn = internal::make_unique<ExtDecl::ExtensionFunction>();
- fn->set_extension_uri_reference(map[function_record.id.uri]);
+ fn->set_extension_uri_reference(map[function_id.uri]);
fn->set_function_anchor(anchor);
- fn->set_name(function_record.id.name.to_string());
+ fn->set_name(function_id.name.to_string());
auto ext_decl = internal::make_unique<ExtDecl>();
ext_decl->set_allocated_extension_function(fn.release());
@@ -104,8 +103,6 @@ Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
// NOTE: it's acceptable to use views to memory owned by plan; ExtensionSet::Make
// will only store views to memory owned by registry.
- using Id = ExtensionSet::Id;
-
std::unordered_map<uint32_t, Id> type_ids, function_ids;
for (const auto& ext : plan.extensions()) {
switch (ext.mapping_type_case()) {
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 8cc1da4d90..c5c02f5155 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -67,6 +67,7 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
ARROW_ASSIGN_OR_RAISE(auto base_schema,
FromProto(read.base_schema(), ext_set, conversion_options));
+ auto num_columns = static_cast<int>(base_schema->fields().size());
auto scan_options = std::make_shared<dataset::ScanOptions>();
scan_options->use_threads = true;
@@ -82,6 +83,22 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
return Status::NotImplemented("substrait::ReadRel::projection");
}
+ if (read.has_named_table()) {
+ if (!conversion_options.named_table_provider) {
+ return Status::Invalid(
+ "plan contained a named table but a NamedTableProvider has not been "
+ "configured");
+ }
+ const NamedTableProvider& named_table_provider =
+ conversion_options.named_table_provider;
+ const substrait::ReadRel::NamedTable& named_table = read.named_table();
+ std::vector<std::string> table_names(named_table.names().begin(),
+ named_table.names().end());
+ ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl,
+ named_table_provider(table_names));
+ return DeclarationInfo{std::move(source_decl), num_columns};
+ }
+
if (!read.has_local_files()) {
return Status::NotImplemented(
"substrait::ReadRel with read_type other than LocalFiles");
@@ -182,7 +199,6 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
std::move(filesystem), std::move(files),
std::move(format), {}));
- auto num_columns = static_cast<int>(base_schema->fields().size());
ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(std::move(base_schema)));
return DeclarationInfo{
@@ -349,17 +365,20 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
"than one item");
}
std::vector<FieldRef> keys;
- auto group = aggregate.groupings(0);
- keys.reserve(group.grouping_expressions_size());
- for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) {
- ARROW_ASSIGN_OR_RAISE(auto expr, FromProto(group.grouping_expressions(exp_id),
- ext_set, conversion_options));
- const auto* field_ref = expr.field_ref();
- if (field_ref) {
- keys.emplace_back(std::move(*field_ref));
- } else {
- return Status::Invalid(
- "The grouping expression for an aggregate must be a direct reference.");
+ if (aggregate.groupings_size() > 0) {
+ const substrait::AggregateRel::Grouping& group = aggregate.groupings(0);
+ keys.reserve(group.grouping_expressions_size());
+ for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) {
+ ARROW_ASSIGN_OR_RAISE(
+ compute::Expression expr,
+ FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options));
+ const FieldRef* field_ref = expr.field_ref();
+ if (field_ref) {
+ keys.emplace_back(std::move(*field_ref));
+ } else {
+ return Status::Invalid(
+ "The grouping expression for an aggregate must be a direct reference.");
+ }
}
}
@@ -373,25 +392,14 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
return Status::NotImplemented("Aggregate filters are not supported.");
}
const auto& agg_func = agg_measure.measure();
- if (agg_func.arguments_size() != 1) {
- return Status::NotImplemented("Aggregate function must be a unary function.");
- }
- int func_reference = agg_func.function_reference();
- ARROW_ASSIGN_OR_RAISE(auto func_record, ext_set.DecodeFunction(func_reference));
- // aggreagte function name
- auto func_name = std::string(func_record.id.name);
- // aggregate target
- auto subs_func_args = agg_func.arguments(0);
- ARROW_ASSIGN_OR_RAISE(auto field_expr, FromProto(subs_func_args.value(),
- ext_set, conversion_options));
- auto target = field_expr.field_ref();
- if (!target) {
- return Status::Invalid(
- "The input expression to an aggregate function must be a direct "
- "reference.");
- }
- aggregates.emplace_back(compute::Aggregate{std::move(func_name), NULLPTR,
- std::move(*target), std::move("")});
+ ARROW_ASSIGN_OR_RAISE(
+ SubstraitCall aggregate_call,
+ FromProto(agg_func, !keys.empty(), ext_set, conversion_options));
+ ARROW_ASSIGN_OR_RAISE(
+ ExtensionIdRegistry::SubstraitAggregateToArrow converter,
+ ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id()));
+ ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call));
+ aggregates.push_back(std::move(arrow_agg));
} else {
return Status::Invalid("substrait::AggregateFunction not provided");
}
diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc
index 87ad88dccb..9f7d979e2f 100644
--- a/cpp/src/arrow/engine/substrait/serde.cc
+++ b/cpp/src/arrow/engine/substrait/serde.cc
@@ -172,7 +172,7 @@ Result<std::shared_ptr<compute::ExecPlan>> MakeSingleDeclarationPlan(
} else {
ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make());
ARROW_RETURN_NOT_OK(declarations[0].AddToPlan(plan.get()));
- return plan;
+ return std::move(plan);
}
}
@@ -182,17 +182,21 @@ Result<std::shared_ptr<compute::ExecPlan>> DeserializePlan(
const Buffer& buf, const std::shared_ptr<compute::SinkNodeConsumer>& consumer,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out,
const ConversionOptions& conversion_options) {
- bool factory_done = false;
- auto single_consumer = [&factory_done, &consumer] {
- if (factory_done) {
- return std::shared_ptr<compute::SinkNodeConsumer>{};
+ struct SingleConsumer {
+ std::shared_ptr<compute::SinkNodeConsumer> operator()() {
+ if (factory_done) {
+ Status::Invalid("SingleConsumer invoked more than once").Warn();
+ return std::shared_ptr<compute::SinkNodeConsumer>{};
+ }
+ factory_done = true;
+ return consumer;
}
- factory_done = true;
- return consumer;
+ bool factory_done;
+ std::shared_ptr<compute::SinkNodeConsumer> consumer;
};
- ARROW_ASSIGN_OR_RAISE(
- auto declarations,
- DeserializePlans(buf, single_consumer, registry, ext_set_out, conversion_options));
+ ARROW_ASSIGN_OR_RAISE(auto declarations,
+ DeserializePlans(buf, SingleConsumer{false, consumer}, registry,
+ ext_set_out, conversion_options));
return MakeSingleDeclarationPlan(declarations);
}
diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h
index 5214606e1c..6c2083fb56 100644
--- a/cpp/src/arrow/engine/substrait/serde.h
+++ b/cpp/src/arrow/engine/substrait/serde.h
@@ -75,7 +75,7 @@ ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
/// Plan is returned here.
/// \return an ExecNode corresponding to the single toplevel relation in the Substrait
/// Plan
-Result<std::shared_ptr<compute::ExecPlan>> DeserializePlan(
+ARROW_ENGINE_EXPORT Result<std::shared_ptr<compute::ExecPlan>> DeserializePlan(
const Buffer& buf, const std::shared_ptr<compute::SinkNodeConsumer>& consumer,
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR,
const ConversionOptions& conversion_options = {});
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index 3bb4de4e92..04405b3168 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -701,7 +701,12 @@ TEST(Substrait, ExtensionSetFromPlan) {
"extension_uris": [
{
"extension_uri_anchor": 7,
- "uri": ")" + substrait::default_extension_types_uri() +
+ "uri": ")" + default_extension_types_uri() +
+ R"("
+ },
+ {
+ "extension_uri_anchor": 18,
+ "uri": ")" + kSubstraitArithmeticFunctionsUri +
R"("
}
],
@@ -712,15 +717,15 @@ TEST(Substrait, ExtensionSetFromPlan) {
"name": "null"
}},
{"extension_function": {
- "extension_uri_reference": 7,
+ "extension_uri_reference": 18,
"function_anchor": 42,
"name": "add"
}}
]
- })";
+})";
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
for (auto sp_ext_id_reg :
- {std::shared_ptr<ExtensionIdRegistry>(), substrait::MakeExtensionIdRegistry()}) {
+ {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
ASSERT_OK_AND_ASSIGN(auto sink_decls,
@@ -732,10 +737,9 @@ TEST(Substrait, ExtensionSetFromPlan) {
EXPECT_EQ(decoded_null_type.id.name, "null");
EXPECT_EQ(*decoded_null_type.type, NullType());
- EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set.DecodeFunction(42));
- EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri);
- EXPECT_EQ(decoded_add_func.id.name, "add");
- EXPECT_EQ(decoded_add_func.name, "add");
+ EXPECT_OK_AND_ASSIGN(Id decoded_add_func_id, ext_set.DecodeFunction(42));
+ EXPECT_EQ(decoded_add_func_id.uri, kSubstraitArithmeticFunctionsUri);
+ EXPECT_EQ(decoded_add_func_id.name, "add");
}
}
@@ -745,7 +749,7 @@ TEST(Substrait, ExtensionSetFromPlanMissingFunc) {
"extension_uris": [
{
"extension_uri_anchor": 7,
- "uri": ")" + substrait::default_extension_types_uri() +
+ "uri": ")" + default_extension_types_uri() +
R"("
}
],
@@ -760,7 +764,7 @@ TEST(Substrait, ExtensionSetFromPlanMissingFunc) {
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
for (auto sp_ext_id_reg :
- {std::shared_ptr<ExtensionIdRegistry>(), substrait::MakeExtensionIdRegistry()}) {
+ {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
ASSERT_RAISES(Invalid, DeserializePlans(
@@ -786,7 +790,7 @@ TEST(Substrait, ExtensionSetFromPlanExhaustedFactory) {
"extension_uris": [
{
"extension_uri_anchor": 7,
- "uri": ")" + substrait::default_extension_types_uri() +
+ "uri": ")" + default_extension_types_uri() +
R"("
}
],
@@ -801,7 +805,7 @@ TEST(Substrait, ExtensionSetFromPlanExhaustedFactory) {
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
for (auto sp_ext_id_reg :
- {std::shared_ptr<ExtensionIdRegistry>(), substrait::MakeExtensionIdRegistry()}) {
+ {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
ASSERT_RAISES(
@@ -823,7 +827,7 @@ TEST(Substrait, ExtensionSetFromPlanRegisterFunc) {
"extension_uris": [
{
"extension_uri_anchor": 7,
- "uri": ")" + substrait::default_extension_types_uri() +
+ "uri": ")" + default_extension_types_uri() +
R"("
}
],
@@ -837,24 +841,23 @@ TEST(Substrait, ExtensionSetFromPlanRegisterFunc) {
})";
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
- auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ auto sp_ext_id_reg = MakeExtensionIdRegistry();
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
// invalid before registration
ExtensionSet ext_set_invalid(ext_id_reg);
ASSERT_RAISES(Invalid,
DeserializePlans(
*buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set_invalid));
- ASSERT_OK(substrait::RegisterFunction(
- *ext_id_reg, substrait::default_extension_types_uri(), "new_func", "multiply"));
+ ASSERT_OK(ext_id_reg->AddSubstraitCallToArrow(
+ {default_extension_types_uri(), "new_func"}, "multiply"));
// valid after registration
ExtensionSet ext_set_valid(ext_id_reg);
ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
*buf, [] { return kNullConsumer; },
ext_id_reg, &ext_set_valid));
- EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set_valid.DecodeFunction(42));
- EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri);
- EXPECT_EQ(decoded_add_func.id.name, "new_func");
- EXPECT_EQ(decoded_add_func.name, "multiply");
+ EXPECT_OK_AND_ASSIGN(Id decoded_add_func_id, ext_set_valid.DecodeFunction(42));
+ EXPECT_EQ(decoded_add_func_id.uri, kArrowExtTypesUri);
+ EXPECT_EQ(decoded_add_func_id.name, "new_func");
}
Result<std::string> GetSubstraitJSON() {
@@ -900,7 +903,7 @@ TEST(Substrait, DeserializeWithConsumerFactory) {
GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
#else
ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON());
- ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json));
+ ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json));
ASSERT_OK_AND_ASSIGN(auto declarations,
DeserializePlans(*buf, NullSinkNodeConsumer::Make));
ASSERT_EQ(declarations.size(), 1);
@@ -923,7 +926,7 @@ TEST(Substrait, DeserializeSinglePlanWithConsumerFactory) {
GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
#else
ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON());
- ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json));
+ ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json));
ASSERT_OK_AND_ASSIGN(std::shared_ptr<compute::ExecPlan> plan,
DeserializePlan(*buf, NullSinkNodeConsumer::Make()));
ASSERT_EQ(1, plan->sinks().size());
@@ -960,7 +963,7 @@ TEST(Substrait, DeserializeWithWriteOptionsFactory) {
return std::make_shared<dataset::WriteNodeOptions>(options);
};
ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON());
- ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json));
+ ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json));
ASSERT_OK_AND_ASSIGN(auto declarations, DeserializePlans(*buf, write_options_factory));
ASSERT_EQ(declarations.size(), 1);
compute::Declaration* decl = &declarations[0];
@@ -984,7 +987,7 @@ TEST(Substrait, DeserializeWithWriteOptionsFactory) {
static void test_with_registries(
std::function<void(ExtensionIdRegistry*, compute::FunctionRegistry*)> test) {
auto default_func_reg = compute::GetFunctionRegistry();
- auto nested_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ auto nested_ext_id_reg = MakeExtensionIdRegistry();
auto nested_func_reg = compute::FunctionRegistry::Make(default_func_reg);
test(nullptr, default_func_reg);
test(nullptr, nested_func_reg.get());
@@ -999,8 +1002,8 @@ TEST(Substrait, GetRecordBatchReader) {
ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON());
test_with_registries([&substrait_json](ExtensionIdRegistry* ext_id_reg,
compute::FunctionRegistry* func_registry) {
- ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json));
- ASSERT_OK_AND_ASSIGN(auto reader, substrait::ExecuteSerializedPlan(*buf));
+ ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json));
+ ASSERT_OK_AND_ASSIGN(auto reader, ExecuteSerializedPlan(*buf));
ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatchReader(reader.get()));
// Note: assuming the binary.parquet file contains fixed amount of records
// in case of a test failure, re-evalaute the content in the file
@@ -1016,8 +1019,8 @@ TEST(Substrait, InvalidPlan) {
})";
test_with_registries([&substrait_json](ExtensionIdRegistry* ext_id_reg,
compute::FunctionRegistry* func_registry) {
- ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json));
- ASSERT_RAISES(Invalid, substrait::ExecuteSerializedPlan(*buf));
+ ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json));
+ ASSERT_RAISES(Invalid, ExecuteSerializedPlan(*buf));
});
}
@@ -1101,7 +1104,10 @@ TEST(Substrait, JoinPlanBasic) {
}
}
}
- }]
+ }],
+ "output_type": {
+ "bool": {}
+ }
}
},
"type": "JOIN_TYPE_INNER"
@@ -1111,7 +1117,7 @@ TEST(Substrait, JoinPlanBasic) {
"extension_uris": [
{
"extension_uri_anchor": 0,
- "uri": ")" + substrait::default_extension_types_uri() +
+ "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) +
R"("
}
],
@@ -1125,7 +1131,7 @@ TEST(Substrait, JoinPlanBasic) {
})";
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
for (auto sp_ext_id_reg :
- {std::shared_ptr<ExtensionIdRegistry>(), substrait::MakeExtensionIdRegistry()}) {
+ {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
ASSERT_OK_AND_ASSIGN(auto sink_decls,
@@ -1241,7 +1247,10 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) {
}
}
}
- }]
+ }],
+ "output_type": {
+ "bool": {}
+ }
}
},
"type": "JOIN_TYPE_INNER"
@@ -1251,7 +1260,7 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) {
"extension_uris": [
{
"extension_uri_anchor": 0,
- "uri": ")" + substrait::default_extension_types_uri() +
+ "uri": ")" + std::string(kSubstraitArithmeticFunctionsUri) +
R"("
}
],
@@ -1265,7 +1274,7 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) {
})";
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
for (auto sp_ext_id_reg :
- {std::shared_ptr<ExtensionIdRegistry>(), substrait::MakeExtensionIdRegistry()}) {
+ {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
ASSERT_RAISES(Invalid, DeserializePlans(
@@ -1333,7 +1342,7 @@ TEST(Substrait, JoinPlanInvalidExpression) {
}]
})"));
for (auto sp_ext_id_reg :
- {std::shared_ptr<ExtensionIdRegistry>(), substrait::MakeExtensionIdRegistry()}) {
+ {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
ASSERT_RAISES(Invalid, DeserializePlans(
@@ -1406,7 +1415,7 @@ TEST(Substrait, JoinPlanInvalidKeys) {
}]
})"));
for (auto sp_ext_id_reg :
- {std::shared_ptr<ExtensionIdRegistry>(), substrait::MakeExtensionIdRegistry()}) {
+ {std::shared_ptr<ExtensionIdRegistry>(), MakeExtensionIdRegistry()}) {
ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
ExtensionSet ext_set(ext_id_reg);
ASSERT_RAISES(Invalid, DeserializePlans(
@@ -1470,6 +1479,7 @@ TEST(Substrait, AggregateBasic) {
}],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "invocation": "AGGREGATION_INVOCATION_ALL",
"outputType": {
"i64": {}
}
@@ -1480,18 +1490,18 @@ TEST(Substrait, AggregateBasic) {
}],
"extensionUris": [{
"extension_uri_anchor": 0,
- "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
- "name": "hash_count"
+ "name": "sum"
}
}],
})"));
- auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry();
+ auto sp_ext_id_reg = MakeExtensionIdRegistry();
ASSERT_OK_AND_ASSIGN(auto sink_decls,
DeserializePlans(*buf, [] { return kNullConsumer; }));
auto agg_decl = sink_decls[0].inputs[0];
@@ -1503,7 +1513,7 @@ TEST(Substrait, AggregateBasic) {
EXPECT_EQ(agg_rel->factory_name, "aggregate");
EXPECT_EQ(agg_options.aggregates[0].name, "");
- EXPECT_EQ(agg_options.aggregates[0].function, "hash_count");
+ EXPECT_EQ(agg_options.aggregates[0].function, "hash_sum");
}
TEST(Substrait, AggregateInvalidRel) {
@@ -1516,13 +1526,13 @@ TEST(Substrait, AggregateInvalidRel) {
}],
"extensionUris": [{
"extension_uri_anchor": 0,
- "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
- "name": "hash_count"
+ "name": "sum"
}
}],
})"));
@@ -1577,13 +1587,13 @@ TEST(Substrait, AggregateInvalidFunction) {
}],
"extensionUris": [{
"extension_uri_anchor": 0,
- "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
- "name": "hash_count"
+ "name": "sum"
}
}],
})"));
@@ -1637,6 +1647,7 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) {
"args": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "invocation": "AGGREGATION_INVOCATION_ALL",
"outputType": {
"i64": {}
}
@@ -1647,13 +1658,13 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) {
}],
"extensionUris": [{
"extension_uri_anchor": 0,
- "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
}],
"extensions": [{
"extension_function": {
"extension_uri_reference": 0,
"function_anchor": 0,
- "name": "hash_count"
+ "name": "sum"
}
}],
})"));
@@ -1707,6 +1718,78 @@ TEST(Substrait, AggregateWithFilter) {
"args": [],
"sorts": [],
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "invocation": "AGGREGATION_INVOCATION_ALL",
+ "outputType": {
+ "i64": {}
+ }
+ }
+ }]
+ }
+ }
+ }],
+ "extensionUris": [{
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }],
+ "extensions": [{
+ "extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "equal"
+ }
+ }],
+ })"));
+
+ ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; }));
+}
+
+TEST(Substrait, AggregateBadPhase) {
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat.parquet",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ },
+ "groupings": [{
+ "groupingExpressions": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ }
+ }
+ }]
+ }],
+ "measures": [{
+ "measure": {
+ "functionReference": 0,
+ "args": [],
+ "sorts": [],
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "invocation": "AGGREGATION_INVOCATION_DISTINCT",
"outputType": {
"i64": {}
}
diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.cc b/cpp/src/arrow/engine/substrait/test_plan_builder.cc
new file mode 100644
index 0000000000..3bd373ae5f
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/test_plan_builder.cc
@@ -0,0 +1,216 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/test_plan_builder.h"
+
+#include <cstdint>
+
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/engine/substrait/plan_internal.h"
+#include "arrow/engine/substrait/type_internal.h"
+#include "arrow/util/macros.h"
+#include "arrow/util/make_unique.h"
+
+#include "substrait/algebra.pb.h"
+#include "substrait/plan.pb.h"
+#include "substrait/type.pb.h"
+
+namespace arrow {
+
+using internal::make_unique;
+
+namespace engine {
+namespace internal {
+
+static const ConversionOptions kPlanBuilderConversionOptions;
+
+Result<std::unique_ptr<substrait::ReadRel>> CreateRead(const Table& table,
+ ExtensionSet* ext_set) {
+ auto read = make_unique<substrait::ReadRel>();
+
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::NamedStruct> schema,
+ ToProto(*table.schema(), ext_set, kPlanBuilderConversionOptions));
+ read->set_allocated_base_schema(schema.release());
+
+ auto named_table = make_unique<substrait::ReadRel::NamedTable>();
+ named_table->add_names("test");
+ read->set_allocated_named_table(named_table.release());
+
+ return read;
+}
+
+void CreateDirectReference(int32_t index, substrait::Expression* expr) {
+ auto reference = make_unique<substrait::Expression::FieldReference>();
+ auto reference_segment = make_unique<substrait::Expression::ReferenceSegment>();
+ auto struct_field = make_unique<substrait::Expression::ReferenceSegment::StructField>();
+ struct_field->set_field(index);
+ reference_segment->set_allocated_struct_field(struct_field.release());
+ reference->set_allocated_direct_reference(reference_segment.release());
+
+ auto root_reference =
+ make_unique<substrait::Expression::FieldReference::RootReference>();
+ reference->set_allocated_root_reference(root_reference.release());
+ expr->set_allocated_selection(reference.release());
+}
+
+Result<std::unique_ptr<substrait::ProjectRel>> CreateProject(
+ Id function_id, const std::vector<std::string>& arguments,
+ const std::vector<std::shared_ptr<DataType>>& arg_types, const DataType& output_type,
+ ExtensionSet* ext_set) {
+ auto project = make_unique<substrait::ProjectRel>();
+
+ auto call = make_unique<substrait::Expression::ScalarFunction>();
+ ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(function_id));
+ call->set_function_reference(function_anchor);
+
+ std::size_t arg_index = 0;
+ std::size_t table_arg_index = 0;
+ for (const std::shared_ptr<DataType>& arg_type : arg_types) {
+ substrait::FunctionArgument* argument = call->add_arguments();
+ if (arg_type) {
+ // If it has a type then it's a reference to the input table
+ auto expression = make_unique<substrait::Expression>();
+ CreateDirectReference(static_cast<int32_t>(table_arg_index++), expression.get());
+ argument->set_allocated_value(expression.release());
+ } else {
+ // If it doesn't have a type then it's an enum
+ const std::string& enum_value = arguments[arg_index];
+ auto enum_ = make_unique<substrait::FunctionArgument::Enum>();
+ if (enum_value.size() > 0) {
+ enum_->set_specified(enum_value);
+ } else {
+ auto unspecified = make_unique<google::protobuf::Empty>();
+ enum_->set_allocated_unspecified(unspecified.release());
+ }
+ argument->set_allocated_enum_(enum_.release());
+ }
+ arg_index++;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr<substrait::Type> output_type_substrait,
+ ToProto(output_type, /*nullable=*/true, ext_set, kPlanBuilderConversionOptions));
+ call->set_allocated_output_type(output_type_substrait.release());
+
+ substrait::Expression* call_expression = project->add_expressions();
+ call_expression->set_allocated_scalar_function(call.release());
+
+ return project;
+}
+
+Result<std::unique_ptr<substrait::AggregateRel>> CreateAgg(Id function_id,
+ const std::vector<int>& keys,
+ int arg_idx,
+ const DataType& output_type,
+ ExtensionSet* ext_set) {
+ auto agg = make_unique<substrait::AggregateRel>();
+
+ if (!keys.empty()) {
+ substrait::AggregateRel::Grouping* grouping = agg->add_groupings();
+ for (int key : keys) {
+ substrait::Expression* key_expr = grouping->add_grouping_expressions();
+ CreateDirectReference(key, key_expr);
+ }
+ }
+
+ substrait::AggregateRel::Measure* measure_wrapper = agg->add_measures();
+ auto agg_func = make_unique<substrait::AggregateFunction>();
+ ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(function_id));
+
+ agg_func->set_function_reference(function_anchor);
+
+ substrait::FunctionArgument* arg = agg_func->add_arguments();
+ auto arg_expr = make_unique<substrait::Expression>();
+ CreateDirectReference(arg_idx, arg_expr.get());
+ arg->set_allocated_value(arg_expr.release());
+
+ agg_func->set_phase(substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT);
+ agg_func->set_invocation(
+ substrait::AggregateFunction::AggregationInvocation::
+ AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL);
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr<substrait::Type> output_type_substrait,
+ ToProto(output_type, /*nullable=*/true, ext_set, kPlanBuilderConversionOptions));
+ agg_func->set_allocated_output_type(output_type_substrait.release());
+ measure_wrapper->set_allocated_measure(agg_func.release());
+
+ return agg;
+}
+
+Result<std::unique_ptr<substrait::Plan>> CreatePlan(std::unique_ptr<substrait::Rel> root,
+ ExtensionSet* ext_set) {
+ auto plan = make_unique<substrait::Plan>();
+
+ substrait::PlanRel* plan_rel = plan->add_relations();
+ auto rel_root = make_unique<substrait::RelRoot>();
+ rel_root->set_allocated_input(root.release());
+ plan_rel->set_allocated_root(rel_root.release());
+
+ ARROW_RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, plan.get()));
+ return plan;
+}
+
+Result<std::shared_ptr<Buffer>> CreateScanProjectSubstrait(
+ Id function_id, const std::shared_ptr<Table>& input_table,
+ const std::vector<std::string>& arguments,
+ const std::vector<std::shared_ptr<DataType>>& data_types,
+ const DataType& output_type) {
+ ExtensionSet ext_set;
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::ReadRel> read,
+ CreateRead(*input_table, &ext_set));
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr<substrait::ProjectRel> project,
+ CreateProject(function_id, arguments, data_types, output_type, &ext_set));
+
+ auto read_rel = make_unique<substrait::Rel>();
+ read_rel->set_allocated_read(read.release());
+ project->set_allocated_input(read_rel.release());
+
+ auto project_rel = make_unique<substrait::Rel>();
+ project_rel->set_allocated_project(project.release());
+
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Plan> plan,
+ CreatePlan(std::move(project_rel), &ext_set));
+ return Buffer::FromString(plan->SerializeAsString());
+}
+
+Result<std::shared_ptr<Buffer>> CreateScanAggSubstrait(
+ Id function_id, const std::shared_ptr<Table>& input_table,
+ const std::vector<int>& key_idxs, int arg_idx, const DataType& output_type) {
+ ExtensionSet ext_set;
+
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::ReadRel> read,
+ CreateRead(*input_table, &ext_set));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::AggregateRel> agg,
+ CreateAgg(function_id, key_idxs, arg_idx, output_type, &ext_set));
+
+ auto read_rel = make_unique<substrait::Rel>();
+ read_rel->set_allocated_read(read.release());
+ agg->set_allocated_input(read_rel.release());
+
+ auto agg_rel = make_unique<substrait::Rel>();
+ agg_rel->set_allocated_aggregate(agg.release());
+
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Plan> plan,
+ CreatePlan(std::move(agg_rel), &ext_set));
+ return Buffer::FromString(plan->SerializeAsString());
+}
+
+} // namespace internal
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.h b/cpp/src/arrow/engine/substrait/test_plan_builder.h
new file mode 100644
index 0000000000..9d2d97a8cc
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/test_plan_builder.h
@@ -0,0 +1,72 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// These utilities are for internal / unit test use only.
+// They allow for the construction of simple Substrait plans
+// programmatically without first requiring the construction
+// of an ExecPlan
+
+// These utilities have to be here, and not in a test_util.cc
+// file (or in a unit test) because only one .so is allowed
+// to include each .pb.h file or else protobuf will encounter
+// global namespace conflicts.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/result.h"
+#include "arrow/table.h"
+#include "arrow/type.h"
+
+namespace arrow {
+namespace engine {
+namespace internal {
+
+/// \brief Create a scan->project->sink plan for tests
+///
+/// The plan will project one additional column using the function
+/// defined by `function_id`, `arguments`, and data_types. `arguments`
+/// and `data_types` should have the same length but only one of each
+/// should be defined at each index.
+///
+/// If `data_types` is defined at an index then the plan will create a
+/// direct reference (starting at index 0 and increasing by 1 for each
+/// argument of this type).
+///
+/// If `arguments` is defined at an index then the plan will create an
+/// enum argument with that value.
+ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> CreateScanProjectSubstrait(
+ Id function_id, const std::shared_ptr<Table>& input_table,
+ const std::vector<std::string>& arguments,
+ const std::vector<std::shared_ptr<DataType>>& data_types,
+ const DataType& output_type);
+
+/// \brief Create a scan->aggregate->sink plan for tests
+///
+/// The plan will create an aggregate with one grouping set (defined by
+/// key_idxs) and one measure. The measure will be a unary function
+/// defined by `function_id` and a direct reference to `arg_idx`.
+ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> CreateScanAggSubstrait(
+ Id function_id, const std::shared_ptr<Table>& input_table,
+ const std::vector<int>& key_idxs, int arg_idx, const DataType& output_type);
+
+} // namespace internal
+} // namespace engine
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc
index 36240d4682..936bde5c65 100644
--- a/cpp/src/arrow/engine/substrait/util.cc
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -23,8 +23,6 @@ namespace arrow {
namespace engine {
-namespace substrait {
-
namespace {
/// \brief A SinkNodeConsumer specialized to output ExecBatches via PushGenerator
@@ -136,19 +134,11 @@ std::shared_ptr<ExtensionIdRegistry> MakeExtensionIdRegistry() {
return nested_extension_id_registry(default_extension_id_registry());
}
-Status RegisterFunction(ExtensionIdRegistry& registry, const std::string& id_uri,
- const std::string& id_name,
- const std::string& arrow_function_name) {
- return registry.RegisterFunction(id_uri, id_name, arrow_function_name);
-}
-
/// \brief Return `engine::kArrowExtTypesUri` as a `std::string`.
///
/// The string is built once, on the first call (function-local static), and
/// every caller receives a reference to that same cached object.
const std::string& default_extension_types_uri() {
  // Declared const: the cached value is only ever exposed through a const
  // reference, so nothing should be able to modify it after initialization.
  static const std::string uri = engine::kArrowExtTypesUri.to_string();
  return uri;
}
-} // namespace substrait
-
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h
index 134d633bb3..3ac9320e1d 100644
--- a/cpp/src/arrow/engine/substrait/util.h
+++ b/cpp/src/arrow/engine/substrait/util.h
@@ -27,8 +27,6 @@ namespace arrow {
namespace engine {
-namespace substrait {
-
/// \brief Retrieve a RecordBatchReader from a Substrait plan.
ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan(
const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR,
@@ -43,24 +41,8 @@ ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> SerializeJsonPlan(
/// See arrow::engine::nested_extension_id_registry for details.
ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionIdRegistry> MakeExtensionIdRegistry();
-/// \brief Register a function manually.
-///
-/// Register an arrow function name by an ID, defined by a URI and a name, on a given
-/// extension-id-registry.
-///
-/// \param[in] registry an extension-id-registry to use
-/// \param[in] id_uri a URI of the ID to register by
-/// \param[in] id_name a name of the ID to register by
-/// \param[in] arrow_function_name name of arrow function to register
-ARROW_ENGINE_EXPORT Status RegisterFunction(ExtensionIdRegistry& registry,
- const std::string& id_uri,
- const std::string& id_name,
- const std::string& arrow_function_name);
-
ARROW_ENGINE_EXPORT const std::string& default_extension_types_uri();
-} // namespace substrait
-
} // namespace engine
} // namespace arrow
diff --git a/docs/source/cpp/streaming_execution.rst b/docs/source/cpp/streaming_execution.rst
index e49225637d..daa5f4be2f 100644
--- a/docs/source/cpp/streaming_execution.rst
+++ b/docs/source/cpp/streaming_execution.rst
@@ -144,6 +144,17 @@ Join Relations
join key is supported.
* The ``post_join_filter`` property is not supported and will be ignored.
+Aggregate Relations
+^^^^^^^^^^^^^^^^^^^
+
+ * At most one grouping set is supported.
+ * Each grouping expression must be a direct reference.
+ * Each measure's arguments must be direct references.
+ * A measure may not have a filter.
+ * A measure may not have sorts.
+ * A measure's invocation must be AGGREGATION_INVOCATION_ALL.
+ * A measure's phase must be AGGREGATION_PHASE_INITIAL_TO_RESULT.
+
Expressions (general)
^^^^^^^^^^^^^^^^^^^^^
@@ -152,20 +163,128 @@ Expressions (general)
grouping set. Acero typically expects these expressions to be direct references.
Planners should extract the implicit projection into a formal project relation
before delivering the plan to Acero.
+ * Older versions of Isthmus would omit optional arguments instead of including them
+ as unspecified enums. Acero does not support these plans.
Literals
^^^^^^^^
* A literal with non-default nullability will cause a plan to be rejected.
+Types
+^^^^^
+
+ * Acero does not have full support for non-nullable types and may allow input
+ to have nulls without rejecting it.
+ * The table below shows the mapping between Arrow types and Substrait type
+ classes that are currently supported.
+
+.. list-table:: Substrait / Arrow Type Mapping
+ :widths: 25 25
+ :header-rows: 1
+
+ * - Substrait Type
+ - Arrow Type
+ - Caveat
+ * - boolean
+ - boolean
+ -
+ * - i8
+ - int8
+ -
+ * - i16
+ - int16
+ -
+ * - i32
+ - int32
+ -
+ * - i64
+ - int64
+ -
+ * - fp32
+ - float32
+ -
+ * - fp64
+ - float64
+ -
+ * - string
+ - string
+ -
+ * - binary
+ - binary
+ -
+ * - timestamp
+ - timestamp<MICRO,"">
+ -
+ * - timestamp_tz
+ - timestamp<MICRO,"UTC">
+ -
+ * - date
+ - date32<DAY>
+ -
+ * - time
+ - time64<MICRO>
+ -
+ * - interval_year
+ -
+ - Not currently supported
+ * - interval_day
+ -
+ - Not currently supported
+ * - uuid
+ -
+ - Not currently supported
+ * - FIXEDCHAR<L>
+ -
+ - Not currently supported
+ * - VARCHAR<L>
+ -
+ - Not currently supported
+ * - FIXEDBINARY<L>
+ - fixed_size_binary<L>
+ -
+ * - DECIMAL<P,S>
+ - decimal128<P,S>
+ -
+ * - STRUCT<T1...TN>
+ - struct<T1...TN>
+ - Arrow struct fields will have no name (empty string)
+ * - NSTRUCT<N:T1...N:Tn>
+ -
+ - Not currently supported
+ * - LIST<T>
+ - list<T>
+ -
+ * - MAP<K,V>
+ - map<K,V>
+ - K must not be nullable
+
Functions
^^^^^^^^^
- * The only functions currently supported by Acero are:
-
- * add
- * equal
- * is_not_distinct_from
+ * Acero does not support the legacy ``args`` style of declaring arguments.
+ * The following functions have caveats or are not supported at all. Note that
+ this is not a comprehensive list. Functions are being added to Substrait at
+ a rapid pace and new functions may be missing.
+
+ * Acero does not support the SATURATE option for overflow
+ * Acero does not support kernels that take more than two arguments
+ for the functions ``and``, ``or``, ``xor``
+ * Acero does not support temporal arithmetic
+ * Acero does not support the following standard functions:
+
+ * ``is_not_distinct_from``
+ * ``like``
+ * ``substring``
+ * ``starts_with``
+ * ``ends_with``
+ * ``contains``
+ * ``count``
+ * ``count_distinct``
+ * ``approx_count_distinct``
* The functions above must be referenced using the URI
``https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml``
diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx
index 7f079fb717..05794a95a2 100644
--- a/python/pyarrow/_substrait.pyx
+++ b/python/pyarrow/_substrait.pyx
@@ -19,6 +19,7 @@
from cython.operator cimport dereference as deref
from pyarrow import Buffer
+from pyarrow.lib import frombytes
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_substrait cimport *
@@ -77,3 +78,27 @@ def _parse_json_plan(plan):
with nogil:
c_buf_plan = GetResultValue(c_res_buffer)
return pyarrow_wrap_buffer(c_buf_plan)
+
+
+def get_supported_functions():
+    """
+    List the Substrait functions that the underlying engine supports.
+
+    The ids are read from the default extension-id registry.
+
+    Returns
+    -------
+    list[str]
+        One entry per supported function, encoded as '{uri}#{name}'.
+    """
+    cdef ExtensionIdRegistry* registry = default_extension_id_registry()
+    cdef std_vector[c_string] encoded_ids = \
+        registry.GetSupportedSubstraitFunctions()
+    return [frombytes(encoded_id) for encoded_id in encoded_ids]
diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd
index 2e1a17b06b..0b3ace75d9 100644
--- a/python/pyarrow/includes/libarrow_substrait.pxd
+++ b/python/pyarrow/includes/libarrow_substrait.pxd
@@ -17,10 +17,20 @@
# distutils: language = c++
+from libcpp.vector cimport vector as std_vector
+
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
-cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine::substrait" nogil:
+# Declarations from arrow/engine/substrait/util.h. These functions now live
+# directly in the arrow::engine namespace (the arrow::engine::substrait
+# namespace was removed from util.h).
+cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil:
CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer)
CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json)
+
+# Extension-id registry, used by get_supported_functions() in _substrait.pyx
+# to list the Substrait functions the engine supports.
+cdef extern from "arrow/engine/substrait/extension_set.h" \
+ namespace "arrow::engine" nogil:
+
+ cdef cppclass ExtensionIdRegistry:
+ # Each entry is a function id encoded as '{uri}#{name}' (bytes).
+ std_vector[c_string] GetSupportedSubstraitFunctions()
+
+ # Returns a raw pointer to the default registry; NOTE(review): presumably
+ # non-owning and process-wide -- confirm lifetime in extension_set.h.
+ ExtensionIdRegistry* default_extension_id_registry()
diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py
index e3ff28f4eb..590d03521f 100644
--- a/python/pyarrow/substrait.py
+++ b/python/pyarrow/substrait.py
@@ -16,5 +16,6 @@
# under the License.
from pyarrow._substrait import ( # noqa
+ get_supported_functions,
run_query,
)
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
index f05d68a95a..c8fa6afcb9 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -145,3 +145,23 @@ def test_binary_conversion_with_json_options(tmpdir):
res_tb = reader.read_all()
assert table.select(["bar"]) == res_tb.select(["bar"])
+
+
+# Substrait has not finalized what the URI should be for standard functions.
+# In the meantime, let's just check the suffix.
+def has_function(fns, ext_file, fn_name):
+ suffix = f'{ext_file}#{fn_name}'
+ for fn in fns:
+ if fn.endswith(suffix):
+ return True
+ return False
+
+
+def test_get_supported_functions():
+ supported_functions = pa._substrait.get_supported_functions()
+ # It probably doesn't make sense to exhaustively verfiy this list but
+ # we can check a sample aggregate and a sample non-aggregate entry
+ assert has_function(supported_functions,
+ 'functions_arithmetic.yaml', 'add')
+ assert has_function(supported_functions,
+ 'functions_arithmetic.yaml', 'sum')