You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/06/08 11:36:52 UTC
[arrow] branch master updated: ARROW-14185: [C++] HashJoinNode should validate HashJoinNodeOptions (#13051)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e87ac4f05a ARROW-14185: [C++] HashJoinNode should validate HashJoinNodeOptions (#13051)
e87ac4f05a is described below
commit e87ac4f05a8a39b09a50fe9e68e5e09fa4fbfeac
Author: Jabari Booker <o....@gmail.com>
AuthorDate: Wed Jun 8 07:36:48 2022 -0400
ARROW-14185: [C++] HashJoinNode should validate HashJoinNodeOptions (#13051)
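Add validation for HashJoinNodeOptions: creating a HashJoinNode now returns Status::Invalid when key_cmp, left_keys, or right_keys is empty, or when their sizes differ.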
Authored-by: JabariBooker <o....@gmail.com>
Signed-off-by: David Li <li...@gmail.com>
---
cpp/src/arrow/compute/exec/hash_join_node.cc | 15 +++++
cpp/src/arrow/compute/exec/hash_join_node_test.cc | 67 ++++++++++++++++++++++-
2 files changed, 81 insertions(+), 1 deletion(-)
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc
index e47d609554..c9232c6e43 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -454,6 +454,20 @@ Status HashJoinSchema::CollectFilterColumns(std::vector<FieldRef>& left_filter,
return Status::OK();
}
+Status ValidateHashJoinNodeOptions(const HashJoinNodeOptions& join_options) {  // Returns Invalid if key_cmp/left_keys/right_keys is empty or sizes differ, else OK.
+ if (join_options.key_cmp.empty() || join_options.left_keys.empty() ||
+ join_options.right_keys.empty()) {
+ return Status::Invalid("key_cmp and keys cannot be empty");
+ }
+ // key_cmp must have the same number of entries as both left_keys and right_keys.
+ if ((join_options.key_cmp.size() != join_options.left_keys.size()) ||
+ (join_options.key_cmp.size() != join_options.right_keys.size())) {
+ return Status::Invalid("key_cmp and keys must have the same size");
+ }
+
+ return Status::OK();
+}
+
class HashJoinNode : public ExecNode {
public:
HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options,
@@ -481,6 +495,7 @@ class HashJoinNode : public ExecNode {
::arrow::internal::make_unique<HashJoinSchema>();
const auto& join_options = checked_cast<const HashJoinNodeOptions&>(options);
+ RETURN_NOT_OK(ValidateHashJoinNodeOptions(join_options));
const auto& left_schema = *(inputs[0]->output_schema());
const auto& right_schema = *(inputs[1]->output_schema());
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
index c4eccd68d3..e752870486 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc
@@ -977,7 +977,7 @@ TEST(HashJoin, Suffix) {
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{input_right.schema,
input_right.gen(/*parallel=*/false,
- /*slow=*/false)}))
+ /*slow=*/false)}));
HashJoinNodeOptions join_opts{JoinType::INNER,
/*left_keys=*/{"lkey"},
@@ -1783,6 +1783,71 @@ TEST(HashJoin, UnsupportedTypes) {
}
}
+TEST(HashJoin, CheckHashJoinNodeOptionsValidation) {  // "hashjoin" creation must fail with Invalid for empty or size-mismatched keys/key_cmp.
+ auto exec_ctx =
+ arrow::internal::make_unique<ExecContext>(default_memory_pool(), nullptr);
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+ // Left input: three int32 columns named lkey, shared, ldistinct.
+ BatchesWithSchema input_left;
+ input_left.batches = {ExecBatchFromJSON({int32(), int32(), int32()}, R"([
+ [1, 4, 7],
+ [2, 5, 8],
+ [3, 6, 9]
+ ])")};
+ input_left.schema = schema(
+ {field("lkey", int32()), field("shared", int32()), field("ldistinct", int32())});
+ // Right input: three int32 columns named rkey, shared, rdistinct.
+ BatchesWithSchema input_right;
+ input_right.batches = {ExecBatchFromJSON({int32(), int32(), int32()}, R"([
+ [1, 10, 13],
+ [2, 11, 14],
+ [3, 12, 15]
+ ])")};
+ input_right.schema = schema(
+ {field("rkey", int32()), field("shared", int32()), field("rdistinct", int32())});
+ // Source nodes passed as the two inputs of every "hashjoin" node attempted below.
+ ExecNode* l_source;
+ ExecNode* r_source;
+ ASSERT_OK_AND_ASSIGN(
+ l_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{input_left.schema, input_left.gen(/*parallel=*/false,
+ /*slow=*/false)}));
+
+ ASSERT_OK_AND_ASSIGN(r_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{input_right.schema,
+ input_right.gen(/*parallel=*/false,
+ /*slow=*/false)}));
+ // Candidate lists of size 0, 1 and 3; below, i selects key_cmps, j l_keys, k r_keys.
+ std::vector<std::vector<FieldRef>> l_keys = {
+ {},
+ {FieldRef("lkey")},
+ {FieldRef("lkey"), FieldRef("shared"), FieldRef("ldistinct")}};
+ std::vector<std::vector<FieldRef>> r_keys = {
+ {},
+ {FieldRef("rkey")},
+ {FieldRef("rkey"), FieldRef("shared"), FieldRef("rdistinct")}};
+ std::vector<std::vector<JoinKeyCmp>> key_cmps = {
+ {}, {JoinKeyCmp::EQ}, {JoinKeyCmp::EQ, JoinKeyCmp::EQ, JoinKeyCmp::EQ}};
+ // Walk all 27 size combinations of (key_cmps, l_keys, r_keys).
+ for (int i = 0; i < 3; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ for (int k = 0; k < 3; ++k) {
+ if (i == j && j == k && i != 0) {  // all three lists non-empty and equally sized: skip
+ continue;
+ }
+ // Every remaining combination has an empty list or mismatched sizes.
+ HashJoinNodeOptions options{JoinType::INNER, l_keys[j], r_keys[k], {}, {},
+ key_cmps[i]};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("key_cmp and keys"),
+ MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, options));
+ }
+ }
+ }
+}
+
TEST(HashJoin, ResidualFilter) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel/merged" : "serial");