You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ic...@apache.org on 2023/05/01 12:45:16 UTC

[arrow] branch main updated: GH-35363: [C++] Fix Substrait schema names and output schema for segmented aggregation (#35364)

This is an automated email from the ASF dual-hosted git repository.

icexelloss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b48834989 GH-35363: [C++] Fix Substrait schema names and output schema for segmented aggregation (#35364)
3b48834989 is described below

commit 3b48834989fdeb9cc561f11585da0f78f4ced48e
Author: rtpsw <rt...@hotmail.com>
AuthorDate: Mon May 1 15:45:05 2023 +0300

    GH-35363: [C++] Fix Substrait schema names and output schema for segmented aggregation (#35364)
    
    See https://github.com/apache/arrow/issues/35363
    * Closes: #35363
    
    Authored-by: Yaron Gvili <rt...@hotmail.com>
    Signed-off-by: Li Jin <ic...@gmail.com>
---
 cpp/src/arrow/acero/aggregate_node.cc        | 2 +-
 cpp/src/arrow/engine/substrait/options.cc    | 6 +++---
 cpp/src/arrow/engine/substrait/serde.cc      | 7 +++++++
 cpp/src/arrow/engine/substrait/serde_test.cc | 2 +-
 cpp/src/arrow/engine/substrait/type_fwd.h    | 1 -
 5 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/cpp/src/arrow/acero/aggregate_node.cc b/cpp/src/arrow/acero/aggregate_node.cc
index 7adbb8c565..0de59899b4 100644
--- a/cpp/src/arrow/acero/aggregate_node.cc
+++ b/cpp/src/arrow/acero/aggregate_node.cc
@@ -857,7 +857,7 @@ class GroupByNode : public ExecNode, public TracedNode {
 
     // Segment keys come first
     PlaceFields(out_data, 0, segmenter_values_);
-    // Followed by segment keys
+    // Followed by keys
     ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques());
     std::move(out_keys.values.begin(), out_keys.values.end(),
               out_data.values.begin() + segment_key_field_ids_.size());
diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc
index 67fc4f329d..4813750767 100644
--- a/cpp/src/arrow/engine/substrait/options.cc
+++ b/cpp/src/arrow/engine/substrait/options.cc
@@ -208,9 +208,9 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
       aggregates.push_back(std::move(aggregate));
     }
 
-    ARROW_ASSIGN_OR_RAISE(auto aggregate_schema,
-                          acero::aggregate::MakeOutputSchema(
-                              input_schema, keys, /*segment_keys=*/{}, aggregates));
+    ARROW_ASSIGN_OR_RAISE(
+        auto aggregate_schema,
+        acero::aggregate::MakeOutputSchema(input_schema, keys, segment_keys, aggregates));
 
     return internal::MakeAggregateDeclaration(
         std::move(inputs[0].declaration), std::move(aggregate_schema),
diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc
index 78f0ea8892..a81e488bf9 100644
--- a/cpp/src/arrow/engine/substrait/serde.cc
+++ b/cpp/src/arrow/engine/substrait/serde.cc
@@ -154,6 +154,13 @@ Result<std::vector<acero::Declaration>> DeserializePlans(
     if (plan_rel.has_root()) {
       names.assign(plan_rel.root().names().begin(), plan_rel.root().names().end());
     }
+    if (names.size() > 0) {
+      if (decl_info.output_schema->num_fields() != plan_rel.root().names_size()) {
+        return Status::Invalid("Substrait plan has ", plan_rel.root().names_size(),
+                               " names that cannot be applied to extension schema:\n",
+                               decl_info.output_schema->ToString(false));
+      }
+    }
 
     // pipe each relation
     ARROW_ASSIGN_OR_RAISE(
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index d89924b807..fb211707f0 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -3937,7 +3937,7 @@ TEST(Substrait, ProjectWithMultiFieldExpressions) {
             }]
           }
         },
-        "names": ["A", "B", "C", "D"]
+        "names": ["A", "B", "C"]
       }
     }]
   })";
diff --git a/cpp/src/arrow/engine/substrait/type_fwd.h b/cpp/src/arrow/engine/substrait/type_fwd.h
index bbeb4ef689..6089d3f747 100644
--- a/cpp/src/arrow/engine/substrait/type_fwd.h
+++ b/cpp/src/arrow/engine/substrait/type_fwd.h
@@ -27,7 +27,6 @@ class ExtensionSet;
 
 struct ConversionOptions;
 struct DeclarationInfo;
-struct RelationInfo;
 
 }  // namespace engine
 }  // namespace arrow