You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2022/07/28 05:13:28 UTC

[arrow] branch master updated: ARROW-17230: [C++] Fix DeserializePlan, add additional option validation (#13728)

This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 60a1919527 ARROW-17230: [C++] Fix DeserializePlan, add additional option validation (#13728)
60a1919527 is described below

commit 60a1919527003a55a39785633bdbbbeef412c362
Author: David Li <li...@gmail.com>
AuthorDate: Thu Jul 28 01:13:22 2022 -0400

    ARROW-17230: [C++] Fix DeserializePlan, add additional option validation (#13728)
    
    - DeserializePlan's signature was a guaranteed use-after-free
    - Check that the write options are actually filled out
    - Don't call the consumer factory twice per declaration
    - Don't ignore plan construction errors
    - Actually attempt to run the plans we make
    
    Authored-by: David Li <li...@gmail.com>
    Signed-off-by: Weston Pace <we...@gmail.com>
---
 cpp/src/arrow/compute/exec/sink_node.cc      |  4 ++
 cpp/src/arrow/dataset/dataset_writer.cc      |  6 +++
 cpp/src/arrow/dataset/file_base.cc           |  4 ++
 cpp/src/arrow/engine/substrait/serde.cc      | 16 +++---
 cpp/src/arrow/engine/substrait/serde.h       |  4 +-
 cpp/src/arrow/engine/substrait/serde_test.cc | 75 ++++++++++++++++++++--------
 6 files changed, 77 insertions(+), 32 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 9118d5a50e..a1426265cf 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -271,6 +271,10 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
     RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode"));
 
     const auto& sink_options = checked_cast<const ConsumingSinkNodeOptions&>(options);
+    if (!sink_options.consumer) {
+      return Status::Invalid("A SinkNodeConsumer is required");
+    }
+
     return plan->EmplaceNode<ConsumingSinkNode>(plan, std::move(inputs),
                                                 std::move(sink_options.consumer),
                                                 std::move(sink_options.names));
diff --git a/cpp/src/arrow/dataset/dataset_writer.cc b/cpp/src/arrow/dataset/dataset_writer.cc
index 4fc0d814f3..36305eac73 100644
--- a/cpp/src/arrow/dataset/dataset_writer.cc
+++ b/cpp/src/arrow/dataset/dataset_writer.cc
@@ -398,6 +398,12 @@ Status ValidateBasenameTemplate(util::string_view basename_template) {
 
 Status ValidateOptions(const FileSystemDatasetWriteOptions& options) {
   ARROW_RETURN_NOT_OK(ValidateBasenameTemplate(options.basename_template));
+  if (!options.file_write_options) {
+    return Status::Invalid("Must provide file_write_options");
+  }
+  if (!options.filesystem) {
+    return Status::Invalid("Must provide filesystem");
+  }
   if (options.max_rows_per_group <= 0) {
     return Status::Invalid("max_rows_per_group must be a positive number");
   }
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index f50bbe1e0c..b3f161e92d 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -436,6 +436,10 @@ Result<compute::ExecNode*> MakeWriteNode(compute::ExecPlan* plan,
       write_node_options.custom_metadata;
   const FileSystemDatasetWriteOptions& write_options = write_node_options.write_options;
 
+  if (!write_options.partitioning) {
+    return Status::Invalid("Must provide partitioning");
+  }
+
   ARROW_ASSIGN_OR_RAISE(auto dataset_writer,
                         internal::DatasetWriter::Make(write_options));
 
diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc
index 00901b5e95..238008a714 100644
--- a/cpp/src/arrow/engine/substrait/serde.cc
+++ b/cpp/src/arrow/engine/substrait/serde.cc
@@ -70,12 +70,12 @@ DeclarationFactory MakeConsumingSinkDeclarationFactory(
              compute::Declaration input,
              std::vector<std::string> names) -> Result<compute::Declaration> {
     std::shared_ptr<compute::SinkNodeConsumer> consumer = consumer_factory();
-    if (consumer == NULLPTR) {
+    if (consumer == nullptr) {
       return Status::Invalid("consumer factory is exhausted");
     }
     std::shared_ptr<compute::ExecNodeOptions> options =
         std::make_shared<compute::ConsumingSinkNodeOptions>(
-            compute::ConsumingSinkNodeOptions{consumer_factory(), std::move(names)});
+            compute::ConsumingSinkNodeOptions{std::move(consumer), std::move(names)});
     return compute::Declaration::Sequence(
         {std::move(input), {"consuming_sink", options}});
   };
@@ -103,7 +103,7 @@ DeclarationFactory MakeWriteDeclarationFactory(
              compute::Declaration input,
              std::vector<std::string> names) -> Result<compute::Declaration> {
     std::shared_ptr<dataset::WriteNodeOptions> options = write_options_factory();
-    if (options == NULLPTR) {
+    if (options == nullptr) {
       return Status::Invalid("write options factory is exhausted");
     }
     compute::Declaration projected = ProjectByNamesDeclaration(input, names);
@@ -161,20 +161,20 @@ Result<std::vector<compute::Declaration>> DeserializePlans(
 
 namespace {
 
-Result<compute::ExecPlan> MakeSingleDeclarationPlan(
+Result<std::shared_ptr<compute::ExecPlan>> MakeSingleDeclarationPlan(
     std::vector<compute::Declaration> declarations) {
   if (declarations.size() > 1) {
     return Status::Invalid("DeserializePlan does not support multiple root relations");
   } else {
     ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make());
-    std::ignore = declarations[0].AddToPlan(plan.get());
-    return *std::move(plan);
+    ARROW_RETURN_NOT_OK(declarations[0].AddToPlan(plan.get()));
+    return plan;
   }
 }
 
 }  // namespace
 
-Result<compute::ExecPlan> DeserializePlan(
+Result<std::shared_ptr<compute::ExecPlan>> DeserializePlan(
     const Buffer& buf, const std::shared_ptr<compute::SinkNodeConsumer>& consumer,
     const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
   bool factory_done = false;
@@ -190,7 +190,7 @@ Result<compute::ExecPlan> DeserializePlan(
   return MakeSingleDeclarationPlan(declarations);
 }
 
-Result<compute::ExecPlan> DeserializePlan(
+Result<std::shared_ptr<compute::ExecPlan>> DeserializePlan(
     const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
     const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
   bool factory_done = false;
diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h
index 545e7449fb..9005553d30 100644
--- a/cpp/src/arrow/engine/substrait/serde.h
+++ b/cpp/src/arrow/engine/substrait/serde.h
@@ -71,7 +71,7 @@ ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
 /// Plan is returned here.
 /// \return an ExecNode corresponding to the single toplevel relation in the Substrait
 /// Plan
-Result<compute::ExecPlan> DeserializePlan(
+Result<std::shared_ptr<compute::ExecPlan>> DeserializePlan(
     const Buffer& buf, const std::shared_ptr<compute::SinkNodeConsumer>& consumer,
     const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);
 
@@ -111,7 +111,7 @@ ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
 /// Plan is returned here.
 /// \return a vector of ExecNode declarations, one for each toplevel relation in the
 /// Substrait Plan
-ARROW_ENGINE_EXPORT Result<compute::ExecPlan> DeserializePlan(
+ARROW_ENGINE_EXPORT Result<std::shared_ptr<compute::ExecPlan>> DeserializePlan(
     const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
     const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);
 
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index 8e5745d6df..3bb4de4e92 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -15,12 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#include "arrow/engine/substrait/serde.h"
-#include "arrow/dataset/plan.h"
-#include "arrow/engine/substrait/util.h"
-#include "arrow/filesystem/mockfs.h"
-#include "arrow/filesystem/test_util.h"
-
 #include <google/protobuf/descriptor.h>
 #include <google/protobuf/util/json_util.h>
 #include <google/protobuf/util/type_resolver_util.h>
@@ -28,8 +22,14 @@
 
 #include "arrow/compute/exec/expression_internal.h"
 #include "arrow/dataset/file_base.h"
+#include "arrow/dataset/file_ipc.h"
+#include "arrow/dataset/plan.h"
 #include "arrow/dataset/scanner.h"
 #include "arrow/engine/substrait/extension_types.h"
+#include "arrow/engine/substrait/serde.h"
+#include "arrow/engine/substrait/util.h"
+#include "arrow/filesystem/mockfs.h"
+#include "arrow/filesystem/test_util.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/testing/matchers.h"
 #include "arrow/util/key_value_metadata.h"
@@ -807,12 +807,12 @@ TEST(Substrait, ExtensionSetFromPlanExhaustedFactory) {
     ASSERT_RAISES(
         Invalid,
         DeserializePlans(
-            *buf, []() -> std::shared_ptr<compute::SinkNodeConsumer> { return NULLPTR; },
+            *buf, []() -> std::shared_ptr<compute::SinkNodeConsumer> { return nullptr; },
             ext_id_reg, &ext_set));
     ASSERT_RAISES(
         Invalid,
         DeserializePlans(
-            *buf, []() -> std::shared_ptr<dataset::WriteNodeOptions> { return NULLPTR; },
+            *buf, []() -> std::shared_ptr<dataset::WriteNodeOptions> { return nullptr; },
             ext_id_reg, &ext_set));
   }
 }
@@ -905,13 +905,36 @@ TEST(Substrait, DeserializeWithConsumerFactory) {
                        DeserializePlans(*buf, NullSinkNodeConsumer::Make));
   ASSERT_EQ(declarations.size(), 1);
   compute::Declaration* decl = &declarations[0];
-  ASSERT_TRUE(decl->factory_name == std::string("consuming_sink"));
+  ASSERT_EQ(decl->factory_name, "consuming_sink");
   ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make());
   ASSERT_OK_AND_ASSIGN(auto sink_node, declarations[0].AddToPlan(plan.get()));
-  ASSERT_TRUE(sink_node->kind_name() == std::string("ConsumingSinkNode"));
+  ASSERT_STREQ(sink_node->kind_name(), "ConsumingSinkNode");
+  ASSERT_EQ(sink_node->num_inputs(), 1);
+  auto& prev_node = sink_node->inputs()[0];
+  ASSERT_STREQ(prev_node->kind_name(), "SourceNode");
+
+  ASSERT_OK(plan->StartProducing());
+  ASSERT_FINISHES_OK(plan->finished());
+#endif
+}
+
+TEST(Substrait, DeserializeSinglePlanWithConsumerFactory) {
+#ifdef _WIN32
+  GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#else
+  ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON());
+  ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json));
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<compute::ExecPlan> plan,
+                       DeserializePlan(*buf, NullSinkNodeConsumer::Make()));
+  ASSERT_EQ(1, plan->sinks().size());
+  compute::ExecNode* sink_node = plan->sinks()[0];
+  ASSERT_STREQ(sink_node->kind_name(), "ConsumingSinkNode");
   ASSERT_EQ(sink_node->num_inputs(), 1);
   auto& prev_node = sink_node->inputs()[0];
-  ASSERT_TRUE(prev_node->kind_name() == std::string("SourceNode"));
+  ASSERT_STREQ(prev_node->kind_name(), "SourceNode");
+
+  ASSERT_OK(plan->StartProducing());
+  ASSERT_FINISHES_OK(plan->finished());
 #endif
 }
 
@@ -925,10 +948,15 @@ TEST(Substrait, DeserializeWithWriteOptionsFactory) {
   ASSERT_OK_AND_ASSIGN(std::shared_ptr<fs::FileSystem> fs,
                        fs::internal::MockFileSystem::Make(mock_now, {testdir}));
   auto write_options_factory = [&fs] {
+    std::shared_ptr<dataset::IpcFileFormat> format =
+        std::make_shared<dataset::IpcFileFormat>();
     dataset::FileSystemDatasetWriteOptions options;
+    options.file_write_options = format->DefaultWriteOptions();
     options.filesystem = fs;
     options.basename_template = "chunk-{i}.arrow";
     options.base_dir = "testdir";
+    options.partitioning =
+        std::make_shared<dataset::DirectoryPartitioning>(arrow::schema({}));
     return std::make_shared<dataset::WriteNodeOptions>(options);
   };
   ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON());
@@ -936,17 +964,20 @@ TEST(Substrait, DeserializeWithWriteOptionsFactory) {
   ASSERT_OK_AND_ASSIGN(auto declarations, DeserializePlans(*buf, write_options_factory));
   ASSERT_EQ(declarations.size(), 1);
   compute::Declaration* decl = &declarations[0];
-  ASSERT_TRUE(decl->factory_name == std::string("write"));
+  ASSERT_EQ(decl->factory_name, "write");
   ASSERT_EQ(decl->inputs.size(), 1);
   decl = util::get_if<compute::Declaration>(&decl->inputs[0]);
-  ASSERT_TRUE(decl != NULLPTR);
-  ASSERT_TRUE(decl->factory_name == std::string("scan"));
+  ASSERT_NE(decl, nullptr);
+  ASSERT_EQ(decl->factory_name, "scan");
   ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make());
   ASSERT_OK_AND_ASSIGN(auto sink_node, declarations[0].AddToPlan(plan.get()));
-  ASSERT_TRUE(sink_node->kind_name() == std::string("ConsumingSinkNode"));
+  ASSERT_STREQ(sink_node->kind_name(), "ConsumingSinkNode");
   ASSERT_EQ(sink_node->num_inputs(), 1);
   auto& prev_node = sink_node->inputs()[0];
-  ASSERT_TRUE(prev_node->kind_name() == std::string("SourceNode"));
+  ASSERT_STREQ(prev_node->kind_name(), "SourceNode");
+
+  ASSERT_OK(plan->StartProducing());
+  ASSERT_FINISHES_OK(plan->finished());
 #endif
 }
 
@@ -955,8 +986,8 @@ static void test_with_registries(
   auto default_func_reg = compute::GetFunctionRegistry();
   auto nested_ext_id_reg = substrait::MakeExtensionIdRegistry();
   auto nested_func_reg = compute::FunctionRegistry::Make(default_func_reg);
-  test(NULLPTR, default_func_reg);
-  test(NULLPTR, nested_func_reg.get());
+  test(nullptr, default_func_reg);
+  test(nullptr, nested_func_reg.get());
   test(nested_ext_id_reg.get(), default_func_reg);
   test(nested_ext_id_reg.get(), nested_func_reg.get());
 }
@@ -1402,7 +1433,7 @@ TEST(Substrait, AggregateBasic) {
                   }]
                 }
               },
-              "local_files": { 
+              "local_files": {
                 "items": [
                   {
                     "uri_file": "file:///tmp/dat.parquet",
@@ -1518,7 +1549,7 @@ TEST(Substrait, AggregateInvalidFunction) {
                   }]
                 }
               },
-              "local_files": { 
+              "local_files": {
                 "items": [
                   {
                     "uri_file": "file:///tmp/dat.parquet",
@@ -1579,7 +1610,7 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) {
                   }]
                 }
               },
-              "local_files": { 
+              "local_files": {
                 "items": [
                   {
                     "uri_file": "file:///tmp/dat.parquet",
@@ -1649,7 +1680,7 @@ TEST(Substrait, AggregateWithFilter) {
                   }]
                 }
               },
-              "local_files": { 
+              "local_files": {
                 "items": [
                   {
                     "uri_file": "file:///tmp/dat.parquet",