You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@arrow.apache.org by li...@apache.org on 2022/06/08 11:36:52 UTC

[arrow] branch master updated: ARROW-14185: [C++] HashJoinNode should validate HashJoinNodeOptions (#13051)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e87ac4f05a ARROW-14185: [C++] HashJoinNode should validate HashJoinNodeOptions (#13051)
e87ac4f05a is described below

commit e87ac4f05a8a39b09a50fe9e68e5e09fa4fbfeac
Author: Jabari Booker <o....@gmail.com>
AuthorDate: Wed Jun 8 07:36:48 2022 -0400

    ARROW-14185: [C++] HashJoinNode should validate HashJoinNodeOptions (#13051)
    
    Add validation for HashJoinNodeOptions: key_cmp, left_keys and right_keys must be non-empty and have the same size.
    
    Authored-by: JabariBooker <o....@gmail.com>
    Signed-off-by: David Li <li...@gmail.com>
---
 cpp/src/arrow/compute/exec/hash_join_node.cc      | 15 +++++
 cpp/src/arrow/compute/exec/hash_join_node_test.cc | 67 ++++++++++++++++++++++-
 2 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc
index e47d609554..c9232c6e43 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -454,6 +454,20 @@ Status HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
   return Status::OK();
 }
 
+// Checks that key_cmp, left_keys and right_keys are non-empty and of equal size.
+Status ValidateHashJoinNodeOptions(const HashJoinNodeOptions& join_options) {
+  const size_t num_cmp = join_options.key_cmp.size();
+  const size_t num_left = join_options.left_keys.size();
+  const size_t num_right = join_options.right_keys.size();
+  if (num_cmp == 0 || num_left == 0 || num_right == 0) {
+    return Status::Invalid("key_cmp and keys cannot be empty");
+  } else if (num_cmp != num_left || num_cmp != num_right) {
+    return Status::Invalid("key_cmp and keys must have the same size, got key_cmp=",
+                           num_cmp, ", left_keys=", num_left, ", right_keys=", num_right);
+  }
+  return Status::OK();
+}
+
 class HashJoinNode : public ExecNode {
  public:
   HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options,
@@ -481,6 +495,7 @@ class HashJoinNode : public ExecNode {
         ::arrow::internal::make_unique<HashJoinSchema>();
 
     const auto& join_options = checked_cast<const HashJoinNodeOptions&>(options);
+    RETURN_NOT_OK(ValidateHashJoinNodeOptions(join_options));
 
     const auto& left_schema = *(inputs[0]->output_schema());
     const auto& right_schema = *(inputs[1]->output_schema());
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
index c4eccd68d3..e752870486 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -977,7 +977,7 @@ TEST(HashJoin, Suffix) {
                        MakeExecNode("source", plan.get(), {},
                                     SourceNodeOptions{input_right.schema,
                                                       input_right.gen(/*parallel=*/false,
-                                                                      /*slow=*/false)}))
+                                                                      /*slow=*/false)}));
 
   HashJoinNodeOptions join_opts{JoinType::INNER,
                                 /*left_keys=*/{"lkey"},
@@ -1783,6 +1783,71 @@ TEST(HashJoin, UnsupportedTypes) {
   }
 }
 
+TEST(HashJoin, CheckHashJoinNodeOptionsValidation) {
+  auto exec_ctx =
+      arrow::internal::make_unique<ExecContext>(default_memory_pool(), nullptr);
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+
+  BatchesWithSchema input_left;
+  input_left.batches = {ExecBatchFromJSON({int32(), int32(), int32()}, R"([
+                   [1, 4, 7],
+                   [2, 5, 8],
+                   [3, 6, 9]
+                 ])")};
+  input_left.schema = schema(
+      {field("lkey", int32()), field("shared", int32()), field("ldistinct", int32())});
+
+  BatchesWithSchema input_right;
+  input_right.batches = {ExecBatchFromJSON({int32(), int32(), int32()}, R"([
+                   [1, 10, 13],
+                   [2, 11, 14],
+                   [3, 12, 15]
+                 ])")};
+  input_right.schema = schema(
+      {field("rkey", int32()), field("shared", int32()), field("rdistinct", int32())});
+
+  ExecNode* l_source;
+  ExecNode* r_source;
+  ASSERT_OK_AND_ASSIGN(
+      l_source,
+      MakeExecNode("source", plan.get(), {},
+                   SourceNodeOptions{input_left.schema, input_left.gen(/*parallel=*/false,
+                                                                       /*slow=*/false)}));
+
+  ASSERT_OK_AND_ASSIGN(r_source,
+                       MakeExecNode("source", plan.get(), {},
+                                    SourceNodeOptions{input_right.schema,
+                                                      input_right.gen(/*parallel=*/false,
+                                                                      /*slow=*/false)}));
+
+  const std::vector<std::vector<FieldRef>> l_keys = {
+      {},
+      {FieldRef("lkey")},
+      {FieldRef("lkey"), FieldRef("shared"), FieldRef("ldistinct")}};
+  const std::vector<std::vector<FieldRef>> r_keys = {
+      {},
+      {FieldRef("rkey")},
+      {FieldRef("rkey"), FieldRef("shared"), FieldRef("rdistinct")}};
+  const std::vector<std::vector<JoinKeyCmp>> key_cmps = {
+      {}, {JoinKeyCmp::EQ}, {JoinKeyCmp::EQ, JoinKeyCmp::EQ, JoinKeyCmp::EQ}};
+  // Every size combination must be rejected unless all three are equal and non-empty.
+  for (size_t i = 0; i < key_cmps.size(); ++i) {
+    for (size_t j = 0; j < l_keys.size(); ++j) {
+      for (size_t k = 0; k < r_keys.size(); ++k) {
+        if (i == j && j == k && i != 0) continue;  // a valid configuration
+        SCOPED_TRACE(::testing::Message()
+                     << "key_cmp size " << key_cmps[i].size() << ", left_keys size "
+                     << l_keys[j].size() << ", right_keys size " << r_keys[k].size());
+        HashJoinNodeOptions options{JoinType::INNER, l_keys[j], r_keys[k], {}, {},
+                                    key_cmps[i]};
+        EXPECT_RAISES_WITH_MESSAGE_THAT(
+            Invalid, ::testing::HasSubstr("key_cmp and keys"),
+            MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, options));
+      }
+    }
+  }
+}
+
 TEST(HashJoin, ResidualFilter) {
   for (bool parallel : {false, true}) {
     SCOPED_TRACE(parallel ? "parallel/merged" : "serial");