You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2022/05/25 08:54:56 UTC
[arrow] branch master updated: ARROW-15583: [C++] The Substrait consumer could potentially use a massive amount of RAM if the producer uses large anchors
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 8bbc695520 ARROW-15583: [C++] The Substrait consumer could potentially use a massive amount of RAM if the producer uses large anchors
8bbc695520 is described below
commit 8bbc6955202650ab1c3f9f564ec99e4e499a1f40
Author: Sanjiban Sengupta <sa...@gmail.com>
AuthorDate: Tue May 24 22:54:21 2022 -1000
ARROW-15583: [C++] The Substrait consumer could potentially use a massive amount of RAM if the producer uses large anchors
This PR modifies the ExtensionSet in the C++ consumer to use an `unordered_map` instead of a `vector` as the lookup table for the `uri` anchors, so the table no longer has to be sized to the largest anchor value. It also removes the `Impl` struct: the functions it contained are now defined directly as part of the ExtensionSet implementation.
Closes #12852 from sanjibansg/substrait/uri_map
Authored-by: Sanjiban Sengupta <sa...@gmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
cpp/examples/arrow/engine_substrait_consumption.cc | 5 -
cpp/src/arrow/engine/substrait/extension_set.cc | 204 ++++++++++-----------
cpp/src/arrow/engine/substrait/extension_set.h | 51 ++++--
cpp/src/arrow/engine/substrait/plan_internal.cc | 54 ++----
cpp/src/arrow/engine/substrait/serde_test.cc | 13 --
cpp/src/arrow/engine/substrait/type_internal.cc | 11 --
6 files changed, 142 insertions(+), 196 deletions(-)
diff --git a/cpp/examples/arrow/engine_substrait_consumption.cc b/cpp/examples/arrow/engine_substrait_consumption.cc
index 779fc0bf55..8ee6bf7a59 100644
--- a/cpp/examples/arrow/engine_substrait_consumption.cc
+++ b/cpp/examples/arrow/engine_substrait_consumption.cc
@@ -106,11 +106,6 @@ arrow::Future<std::shared_ptr<arrow::Buffer>> GetSubstraitFromServer(
"type_anchor": 42,
"name": "null"
}},
- {"extension_type_variation": {
- "extension_uri_reference": 7,
- "type_variation_anchor": 23,
- "name": "u8"
- }},
{"extension_function": {
"extension_uri_reference": 7,
"function_anchor": 42,
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index 80cdf59f49..b7d2f87b74 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -40,81 +40,63 @@ struct TypePtrHashEq {
}
};
-struct IdHashEq {
- using Id = ExtensionSet::Id;
-
- size_t operator()(Id id) const {
- constexpr ::arrow::internal::StringViewHash hash = {};
- auto out = static_cast<size_t>(hash(id.uri));
- ::arrow::internal::hash_combine(out, hash(id.name));
- return out;
- }
+} // namespace
- bool operator()(Id l, Id r) const { return l.uri == r.uri && l.name == r.name; }
-};
+size_t ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id id) const {
+ constexpr ::arrow::internal::StringViewHash hash = {};
+ auto out = static_cast<size_t>(hash(id.uri));
+ ::arrow::internal::hash_combine(out, hash(id.name));
+ return out;
+}
-} // namespace
+bool ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id l,
+ ExtensionIdRegistry::Id r) const {
+ return l.uri == r.uri && l.name == r.name;
+}
// A builder used when creating a Substrait plan from an Arrow execution plan. In
// that situation we do not have a set of anchor values already defined so we keep
// a map of what Ids we have seen.
-struct ExtensionSet::Impl {
- void AddUri(util::string_view uri, ExtensionSet* self) {
- if (uris_.find(uri) != uris_.end()) return;
-
- self->uris_.push_back(uri);
- uris_.insert(self->uris_.back()); // lookup helper's keys should reference memory
- // owned by this ExtensionSet
- }
-
- Status CheckHasUri(util::string_view uri) {
- if (uris_.find(uri) != uris_.end()) return Status::OK();
-
- return Status::Invalid(
- "Uri ", uri,
- " was referenced by an extension but was not declared in the ExtensionSet.");
- }
-
- uint32_t EncodeType(ExtensionIdRegistry::TypeRecord type_record, ExtensionSet* self) {
- // note: at this point we're guaranteed to have an Id which points to memory owned by
- // the set's registry.
- AddUri(type_record.id.uri, self);
- auto it_success =
- types_.emplace(type_record.id, static_cast<uint32_t>(types_.size()));
-
- if (it_success.second) {
- self->types_.push_back(
- {type_record.id, type_record.type, type_record.is_variation});
- }
-
- return it_success.first->second;
- }
-
- uint32_t EncodeFunction(Id id, util::string_view function_name, ExtensionSet* self) {
- // note: at this point we're guaranteed to have an Id which points to memory owned by
- // the set's registry.
- AddUri(id.uri, self);
- auto it_success = functions_.emplace(id, static_cast<uint32_t>(functions_.size()));
+ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry) : registry_(registry) {}
+
+Status ExtensionSet::CheckHasUri(util::string_view uri) {
+ auto it =
+ std::find_if(uris_.begin(), uris_.end(),
+ [&uri](const std::pair<uint32_t, util::string_view>& anchor_uri_pair) {
+ return anchor_uri_pair.second == uri;
+ });
+ if (it != uris_.end()) return Status::OK();
+
+ return Status::Invalid(
+ "Uri ", uri,
+ " was referenced by an extension but was not declared in the ExtensionSet.");
+}
- if (it_success.second) {
- self->functions_.push_back({id, function_name});
- }
+void ExtensionSet::AddUri(std::pair<uint32_t, util::string_view> uri) {
+ auto it =
+ std::find_if(uris_.begin(), uris_.end(),
+ [&uri](const std::pair<uint32_t, util::string_view>& anchor_uri_pair) {
+ return anchor_uri_pair.second == uri.second;
+ });
+ if (it != uris_.end()) return;
+ uris_[uri.first] = uri.second;
+}
- return it_success.first->second;
+Status ExtensionSet::AddUri(Id id) {
+ auto uris_size = static_cast<unsigned int>(uris_.size());
+ if (uris_.find(uris_size) != uris_.end()) {
+ // Substrait plans shouldn't have repeated URIs in the extension set
+ return Status::Invalid("Key already exists in the uris map");
}
+ uris_[uris_size] = id.uri;
+ return Status::OK();
+}
- std::unordered_set<util::string_view, ::arrow::internal::StringViewHash> uris_;
- std::unordered_map<Id, uint32_t, IdHashEq, IdHashEq> types_, functions_;
-};
-
-ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry)
- : registry_(registry), impl_(new Impl(), [](Impl* impl) { delete impl; }) {}
-
-Result<ExtensionSet> ExtensionSet::Make(std::vector<util::string_view> uris,
- std::vector<Id> type_ids,
- std::vector<bool> type_is_variation,
- std::vector<Id> function_ids,
- ExtensionIdRegistry* registry) {
+// Creates an extension set from the Substrait plan's top-level extensions block
+Result<ExtensionSet> ExtensionSet::Make(
+ std::unordered_map<uint32_t, util::string_view> uris,
+ std::unordered_map<uint32_t, Id> type_ids,
+ std::unordered_map<uint32_t, Id> function_ids, ExtensionIdRegistry* registry) {
ExtensionSet set;
set.registry_ = registry;
@@ -126,39 +108,32 @@ Result<ExtensionSet> ExtensionSet::Make(std::vector<util::string_view> uris,
}
for (auto& uri : uris) {
- if (uri.empty()) continue;
- auto it = uris_owned_by_registry.find(uri);
+ auto it = uris_owned_by_registry.find(uri.second);
if (it == uris_owned_by_registry.end()) {
- return Status::KeyError("Uri '", uri, "' not found in registry");
+ return Status::KeyError("Uri '", uri.second, "' not found in registry");
}
- uri = *it; // Ensure uris point into the registry's memory
- set.impl_->AddUri(*it, &set);
+ uri.second = *it; // Ensure uris point into the registry's memory
+ set.AddUri(uri);
}
- if (type_ids.size() != type_is_variation.size()) {
- return Status::Invalid("Received ", type_ids.size(), " type ids but a ",
- type_is_variation.size(), "-long is_variation vector");
- }
+ set.types_.reserve(type_ids.size());
- set.types_.resize(type_ids.size());
-
- for (size_t i = 0; i < type_ids.size(); ++i) {
+ for (unsigned int i = 0; i < static_cast<unsigned int>(type_ids.size()); ++i) {
if (type_ids[i].empty()) continue;
- RETURN_NOT_OK(set.impl_->CheckHasUri(type_ids[i].uri));
+ RETURN_NOT_OK(set.CheckHasUri(type_ids[i].uri));
- if (auto rec = registry->GetType(type_ids[i], type_is_variation[i])) {
- set.types_[i] = {rec->id, rec->type, rec->is_variation};
+ if (auto rec = registry->GetType(type_ids[i])) {
+ set.types_[i] = {rec->id, rec->type};
continue;
}
- return Status::Invalid("Type", (type_is_variation[i] ? " variation" : ""), " ",
- type_ids[i].uri, "#", type_ids[i].name, " not found");
+ return Status::Invalid("Type ", type_ids[i].uri, "#", type_ids[i].name, " not found");
}
- set.functions_.resize(function_ids.size());
+ set.functions_.reserve(function_ids.size());
- for (size_t i = 0; i < function_ids.size(); ++i) {
+ for (unsigned int i = 0; i < static_cast<unsigned int>(function_ids.size()); ++i) {
if (function_ids[i].empty()) continue;
- RETURN_NOT_OK(set.impl_->CheckHasUri(function_ids[i].uri));
+ RETURN_NOT_OK(set.CheckHasUri(function_ids[i].uri));
if (auto rec = registry->GetFunction(function_ids[i])) {
set.functions_[i] = {rec->id, rec->function_name};
@@ -174,31 +149,50 @@ Result<ExtensionSet> ExtensionSet::Make(std::vector<util::string_view> uris,
}
Result<ExtensionSet::TypeRecord> ExtensionSet::DecodeType(uint32_t anchor) const {
- if (anchor >= types_.size() || types_[anchor].id.empty()) {
+ if (types_.find(anchor) == types_.end() || types_.at(anchor).id.empty()) {
return Status::Invalid("User defined type reference ", anchor,
" did not have a corresponding anchor in the extension set");
}
- return types_[anchor];
+ return types_.at(anchor);
}
Result<uint32_t> ExtensionSet::EncodeType(const DataType& type) {
if (auto rec = registry_->GetType(type)) {
- return impl_->EncodeType(*rec, this);
+ RETURN_NOT_OK(this->AddUri(rec->id));
+ auto it_success =
+ types_map_.emplace(rec->id, static_cast<uint32_t>(types_map_.size()));
+ if (it_success.second) {
+ DCHECK_EQ(types_.find(static_cast<unsigned int>(types_.size())), types_.end())
+ << "Type existed in types_ but not types_map_. ExtensionSet is inconsistent";
+ types_[static_cast<unsigned int>(types_.size())] = {rec->id, rec->type};
+ }
+ return it_success.first->second;
}
return Status::KeyError("type ", type.ToString(), " not found in the registry");
}
Result<ExtensionSet::FunctionRecord> ExtensionSet::DecodeFunction(uint32_t anchor) const {
- if (anchor >= functions_.size() || functions_[anchor].id.empty()) {
+ if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).id.empty()) {
return Status::Invalid("User defined function reference ", anchor,
" did not have a corresponding anchor in the extension set");
}
- return functions_[anchor];
+ return functions_.at(anchor);
}
Result<uint32_t> ExtensionSet::EncodeFunction(util::string_view function_name) {
if (auto rec = registry_->GetFunction(function_name)) {
- return impl_->EncodeFunction(rec->id, rec->function_name, this);
+ RETURN_NOT_OK(this->AddUri(rec->id));
+ auto it_success =
+ functions_map_.emplace(rec->id, static_cast<uint32_t>(functions_map_.size()));
+ if (it_success.second) {
+ DCHECK_EQ(functions_.find(static_cast<unsigned int>(functions_.size())),
+ functions_.end())
+ << "Function existed in functions_ but not functions_map_. ExtensionSet is "
+ "inconsistent";
+ functions_[static_cast<unsigned int>(functions_.size())] = {rec->id,
+ rec->function_name};
+ }
+ return it_success.first->second;
}
return Status::KeyError("function ", function_name, " not found in the registry");
}
@@ -228,8 +222,7 @@ ExtensionIdRegistry* default_extension_id_registry() {
TypeName{uint64(), "u64"},
TypeName{float16(), "fp16"},
}) {
- DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type),
- /*is_variation=*/true));
+ DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
}
for (TypeName e : {
@@ -238,8 +231,7 @@ ExtensionIdRegistry* default_extension_id_registry() {
TypeName{day_time_interval(), "interval_day_milli"},
TypeName{month_day_nano_interval(), "interval_month_day_nano"},
}) {
- DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type),
- /*is_variation=*/false));
+ DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
}
// TODO: this is just a placeholder right now. We'll need a YAML file for
@@ -259,44 +251,39 @@ ExtensionIdRegistry* default_extension_id_registry() {
util::optional<TypeRecord> GetType(const DataType& type) const override {
if (auto index = GetIndex(type_to_index_, &type)) {
- return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]};
+ return TypeRecord{type_ids_[*index], types_[*index]};
}
return {};
}
- util::optional<TypeRecord> GetType(Id id, bool is_variation) const override {
- if (auto index =
- GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) {
- return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]};
+ util::optional<TypeRecord> GetType(Id id) const override {
+ if (auto index = GetIndex(id_to_index_, id)) {
+ return TypeRecord{type_ids_[*index], types_[*index]};
}
return {};
}
- Status RegisterType(Id id, std::shared_ptr<DataType> type,
- bool is_variation) override {
+ Status RegisterType(Id id, std::shared_ptr<DataType> type) override {
DCHECK_EQ(type_ids_.size(), types_.size());
- DCHECK_EQ(type_ids_.size(), type_is_variation_.size());
Id copied_id{*uris_.emplace(id.uri.to_string()).first,
*names_.emplace(id.name.to_string()).first};
auto index = static_cast<int>(type_ids_.size());
- auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_;
- auto it_success = id_to_index->emplace(copied_id, index);
+ auto it_success = id_to_index_.emplace(copied_id, index);
if (!it_success.second) {
return Status::Invalid("Type id was already registered");
}
if (!type_to_index_.emplace(type.get(), index).second) {
- id_to_index->erase(it_success.first);
+ id_to_index_.erase(it_success.first);
return Status::Invalid("Type was already registered");
}
type_ids_.push_back(copied_id);
types_.push_back(std::move(type));
- type_is_variation_.push_back(is_variation);
return Status::OK();
}
@@ -347,11 +334,10 @@ ExtensionIdRegistry* default_extension_id_registry() {
// unordered_set are not invalidated on insertion
std::unordered_set<std::string> uris_, names_, function_names_;
DataTypeVector types_;
- std::vector<bool> type_is_variation_;
// non-owning lookup helpers
std::vector<Id> type_ids_, function_ids_;
- std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_, variation_id_to_index_;
+ std::unordered_map<Id, int, IdHashEq, IdHashEq> id_to_index_;
std::unordered_map<const DataType*, int, TypePtrHashEq, TypePtrHashEq> type_to_index_;
std::vector<const std::string*> function_name_ptrs_;
diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h
index 951f7ffa3a..55ea4d0232 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.h
+++ b/cpp/src/arrow/engine/substrait/extension_set.h
@@ -19,6 +19,7 @@
#pragma once
+#include <unordered_map>
#include <vector>
#include "arrow/engine/substrait/visibility.h"
@@ -26,6 +27,8 @@
#include "arrow/util/optional.h"
#include "arrow/util/string_view.h"
+#include "arrow/util/hash_util.h"
+
namespace arrow {
namespace engine {
@@ -55,15 +58,19 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
bool empty() const { return uri.empty() && name.empty(); }
};
+ struct IdHashEq {
+ size_t operator()(Id id) const;
+ bool operator()(Id l, Id r) const;
+ };
+
/// \brief A mapping between a Substrait ID and an arrow::DataType
struct TypeRecord {
Id id;
const std::shared_ptr<DataType>& type;
- bool is_variation;
};
virtual util::optional<TypeRecord> GetType(const DataType&) const = 0;
- virtual util::optional<TypeRecord> GetType(Id, bool is_variation) const = 0;
- virtual Status RegisterType(Id, std::shared_ptr<DataType>, bool is_variation) = 0;
+ virtual util::optional<TypeRecord> GetType(Id) const = 0;
+ virtual Status RegisterType(Id, std::shared_ptr<DataType>) = 0;
/// \brief A mapping between a Substrait ID and an Arrow function
///
@@ -127,6 +134,7 @@ ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry();
class ARROW_ENGINE_EXPORT ExtensionSet {
public:
using Id = ExtensionIdRegistry::Id;
+ using IdHashEq = ExtensionIdRegistry::IdHashEq;
struct FunctionRecord {
Id id;
@@ -136,7 +144,6 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
struct TypeRecord {
Id id;
std::shared_ptr<DataType> type;
- bool is_variation;
};
/// Construct an empty ExtensionSet to be populated during serialization.
@@ -158,14 +165,12 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
/// An extension set should instead be created using
/// arrow::engine::GetExtensionSetFromPlan
static Result<ExtensionSet> Make(
- std::vector<util::string_view> uris, std::vector<Id> type_ids,
- std::vector<bool> type_is_variation, std::vector<Id> function_ids,
+ std::unordered_map<uint32_t, util::string_view> uris,
+ std::unordered_map<uint32_t, Id> type_ids,
+ std::unordered_map<uint32_t, Id> function_ids,
ExtensionIdRegistry* = default_extension_id_registry());
- // index in these vectors == value of _anchor/_reference fields
- /// TODO(ARROW-15583) this assumes that _anchor/_references won't be huge, which is not
- /// guaranteed. Could it be?
- const std::vector<util::string_view>& uris() const { return uris_; }
+ const std::unordered_map<uint32_t, util::string_view>& uris() const { return uris_; }
/// \brief Returns a data type given an anchor
///
@@ -225,15 +230,25 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
private:
ExtensionIdRegistry* registry_;
- /// The subset of extension registry URIs referenced by this extension set
- std::vector<util::string_view> uris_;
- std::vector<TypeRecord> types_;
-
- std::vector<FunctionRecord> functions_;
- // pimpl pattern to hide lookup details
- struct Impl;
- std::unique_ptr<Impl, void (*)(Impl*)> impl_;
+ // Map from anchor values to URI values referenced by this extension set
+ std::unordered_map<uint32_t, util::string_view> uris_;
+ // Map from anchor values to type definitions, used during Substrait->Arrow
+ // and populated from the Substrait extension set
+ std::unordered_map<uint32_t, TypeRecord> types_;
+ // Map from anchor values to function definitions, used during Substrait->Arrow
+ // and populated from the Substrait extension set
+ std::unordered_map<uint32_t, FunctionRecord> functions_;
+ // Map from type names to anchor values. Used during Arrow->Substrait
+ // and built as the plan is created.
+ std::unordered_map<Id, uint32_t, IdHashEq, IdHashEq> types_map_;
+ // Map from function names to anchor values. Used during Arrow->Substrait
+ // and built as the plan is created.
+ std::unordered_map<Id, uint32_t, IdHashEq, IdHashEq> functions_map_;
+
+ Status CheckHasUri(util::string_view uri);
+ void AddUri(std::pair<uint32_t, util::string_view> uri);
+ Status AddUri(Id id);
};
} // namespace engine
diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc
index 5813dcde24..cc20de6da6 100644
--- a/cpp/src/arrow/engine/substrait/plan_internal.cc
+++ b/cpp/src/arrow/engine/substrait/plan_internal.cc
@@ -43,7 +43,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
auto uris = plan->mutable_extension_uris();
uris->Reserve(static_cast<int>(ext_set.uris().size()));
for (uint32_t anchor = 0; anchor < ext_set.uris().size(); ++anchor) {
- auto uri = ext_set.uris()[anchor];
+ auto uri = ext_set.uris().at(anchor);
if (uri.empty()) continue;
auto ext_uri = internal::make_unique<substrait::extensions::SimpleExtensionURI>();
@@ -65,20 +65,11 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
auto ext_decl = internal::make_unique<ExtDecl>();
- if (type_record.is_variation) {
- auto type_var = internal::make_unique<ExtDecl::ExtensionTypeVariation>();
- type_var->set_extension_uri_reference(map[type_record.id.uri]);
- type_var->set_type_variation_anchor(anchor);
- type_var->set_name(type_record.id.name.to_string());
- ext_decl->set_allocated_extension_type_variation(type_var.release());
- } else {
- auto type = internal::make_unique<ExtDecl::ExtensionType>();
- type->set_extension_uri_reference(map[type_record.id.uri]);
- type->set_type_anchor(anchor);
- type->set_name(type_record.id.name.to_string());
- ext_decl->set_allocated_extension_type(type.release());
- }
-
+ auto type = internal::make_unique<ExtDecl::ExtensionType>();
+ type->set_extension_uri_reference(map[type_record.id.uri]);
+ type->set_type_anchor(anchor);
+ type->set_name(type_record.id.name.to_string());
+ ext_decl->set_allocated_extension_type(type.release());
extensions->AddAllocated(ext_decl.release());
}
@@ -99,22 +90,12 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
return Status::OK();
}
-namespace {
-template <typename Element, typename T>
-void SetElement(size_t i, const Element& element, std::vector<T>* vector) {
- DCHECK_LE(i, 1 << 20);
- if (i >= vector->size()) {
- vector->resize(i + 1);
- }
- (*vector)[i] = static_cast<T>(element);
-}
-} // namespace
-
Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
ExtensionIdRegistry* registry) {
- std::vector<util::string_view> uris;
+ std::unordered_map<uint32_t, util::string_view> uris;
+ uris.reserve(plan.extension_uris_size());
for (const auto& uri : plan.extension_uris()) {
- SetElement(uri.extension_uri_anchor(), uri.uri(), &uris);
+ uris[uri.extension_uri_anchor()] = uri.uri();
}
// NOTE: it's acceptable to use views to memory owned by plan; ExtensionSet::Make
@@ -122,30 +103,24 @@ Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
using Id = ExtensionSet::Id;
- std::vector<Id> type_ids, function_ids;
- std::vector<bool> type_is_variation;
+ std::unordered_map<uint32_t, Id> type_ids, function_ids;
for (const auto& ext : plan.extensions()) {
switch (ext.mapping_type_case()) {
case substrait::extensions::SimpleExtensionDeclaration::kExtensionTypeVariation: {
- const auto& type_var = ext.extension_type_variation();
- util::string_view uri = uris[type_var.extension_uri_reference()];
- SetElement(type_var.type_variation_anchor(), Id{uri, type_var.name()}, &type_ids);
- SetElement(type_var.type_variation_anchor(), true, &type_is_variation);
- break;
+ return Status::NotImplemented("Type Variations are not yet implemented");
}
case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: {
const auto& type = ext.extension_type();
util::string_view uri = uris[type.extension_uri_reference()];
- SetElement(type.type_anchor(), Id{uri, type.name()}, &type_ids);
- SetElement(type.type_anchor(), false, &type_is_variation);
+ type_ids[type.type_anchor()] = Id{uri, type.name()};
break;
}
case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: {
const auto& fn = ext.extension_function();
util::string_view uri = uris[fn.extension_uri_reference()];
- SetElement(fn.function_anchor(), Id{uri, fn.name()}, &function_ids);
+ function_ids[fn.function_anchor()] = Id{uri, fn.name()};
break;
}
@@ -154,8 +129,7 @@ Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
}
}
- return ExtensionSet::Make(std::move(uris), std::move(type_ids),
- std::move(type_is_variation), std::move(function_ids),
+ return ExtensionSet::Make(std::move(uris), std::move(type_ids), std::move(function_ids),
registry);
}
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index deee2d1445..fae23f200d 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -689,11 +689,6 @@ TEST(Substrait, ExtensionSetFromPlan) {
"type_anchor": 42,
"name": "null"
}},
- {"extension_type_variation": {
- "extension_uri_reference": 7,
- "type_variation_anchor": 23,
- "name": "u8"
- }},
{"extension_function": {
"extension_uri_reference": 7,
"function_anchor": 42,
@@ -701,7 +696,6 @@ TEST(Substrait, ExtensionSetFromPlan) {
}}
]
})"));
-
ExtensionSet ext_set;
ASSERT_OK_AND_ASSIGN(
auto sink_decls,
@@ -713,13 +707,6 @@ TEST(Substrait, ExtensionSetFromPlan) {
EXPECT_EQ(decoded_null_type.id.uri, kArrowExtTypesUri);
EXPECT_EQ(decoded_null_type.id.name, "null");
EXPECT_EQ(*decoded_null_type.type, NullType());
- EXPECT_FALSE(decoded_null_type.is_variation);
-
- EXPECT_OK_AND_ASSIGN(auto decoded_uint8_type, ext_set.DecodeType(23));
- EXPECT_EQ(decoded_uint8_type.id.uri, kArrowExtTypesUri);
- EXPECT_EQ(decoded_uint8_type.id.name, "u8");
- EXPECT_EQ(*decoded_uint8_type.type, UInt8Type());
- EXPECT_TRUE(decoded_uint8_type.is_variation);
EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set.DecodeFunction(42));
EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri);
diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc
index c7b94b4104..178e6d25a5 100644
--- a/cpp/src/arrow/engine/substrait/type_internal.cc
+++ b/cpp/src/arrow/engine/substrait/type_internal.cc
@@ -37,12 +37,6 @@ using ::arrow::internal::make_unique;
namespace {
-template <typename TypeMessage>
-Status CheckVariation(const TypeMessage& type) {
- if (type.type_variation_reference() == 0) return Status::OK();
- return Status::NotImplemented("Type variations for ", type.DebugString());
-}
-
template <typename TypeMessage>
bool IsNullable(const TypeMessage& type) {
// FIXME what can we do with NULLABILITY_UNSPECIFIED
@@ -52,8 +46,6 @@ bool IsNullable(const TypeMessage& type) {
template <typename ArrowType, typename TypeMessage, typename... A>
Result<std::pair<std::shared_ptr<DataType>, bool>> FromProtoImpl(const TypeMessage& type,
A&&... args) {
- RETURN_NOT_OK(CheckVariation(type));
-
return std::make_pair(std::static_pointer_cast<DataType>(
std::make_shared<ArrowType>(std::forward<A>(args)...)),
IsNullable(type));
@@ -62,8 +54,6 @@ Result<std::pair<std::shared_ptr<DataType>, bool>> FromProtoImpl(const TypeMessa
template <typename TypeMessage, typename... A>
Result<std::pair<std::shared_ptr<DataType>, bool>> FromProtoImpl(
const TypeMessage& type, std::shared_ptr<DataType> type_factory(A...), A&&... args) {
- RETURN_NOT_OK(CheckVariation(type));
-
return std::make_pair(
std::static_pointer_cast<DataType>(type_factory(std::forward<A>(args)...)),
IsNullable(type));
@@ -430,7 +420,6 @@ Result<std::shared_ptr<Schema>> FromProto(const ::substrait::NamedStruct& named_
"could be attached.");
}
const auto& struct_ = named_struct.struct_();
- RETURN_NOT_OK(CheckVariation(struct_));
int requested_names_count = 0;
ARROW_ASSIGN_OR_RAISE(auto fields, FieldsFromProto(