You are viewing a plain text version of this content. The canonical link for it was provided as a hyperlink, which is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by np...@apache.org on 2020/08/10 17:42:13 UTC
[arrow] branch master updated: ARROW-9606: [C++][Dataset] Support
`"a"_.In(<>).Assume(<compound>)`
This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 858059f ARROW-9606: [C++][Dataset] Support `"a"_.In(<>).Assume(<compound>)`
858059f is described below
commit 858059fd1cfbae755341c2df5f94ef5d42b329da
Author: Benjamin Kietzman <be...@gmail.com>
AuthorDate: Mon Aug 10 10:41:37 2020 -0700
ARROW-9606: [C++][Dataset] Support `"a"_.In(<>).Assume(<compound>)`
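This enables predicate pushdown of `%in%` filters in the presence of compound partition information, i.e. partition expressions that are conjunctions over several partition fields (such as `a == 3 and b == 3`).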
@mpjdem
Closes #7911 from bkietz/9606-simplify-isin-query-nested-partitions
Authored-by: Benjamin Kietzman <be...@gmail.com>
Signed-off-by: Neal Richardson <ne...@gmail.com>
---
cpp/src/arrow/dataset/filter.cc | 73 ++++++++++++++++++++++++++++--------
cpp/src/arrow/dataset/filter.h | 10 +++++
cpp/src/arrow/dataset/filter_test.cc | 13 ++++++-
cpp/src/arrow/dataset/partition.cc | 39 +++++++++----------
r/tests/testthat/test-dataset.R | 10 +++++
5 files changed, 106 insertions(+), 39 deletions(-)
diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc
index ddfee24..25a3a0a 100644
--- a/cpp/src/arrow/dataset/filter.cc
+++ b/cpp/src/arrow/dataset/filter.cc
@@ -261,22 +261,36 @@ std::shared_ptr<Expression> Invert(const Expression& expr) {
}
std::shared_ptr<Expression> Expression::Assume(const Expression& given) const {
- if (given.type() == ExpressionType::COMPARISON) {
+ std::shared_ptr<Expression> out;
+
+ DCHECK_OK(VisitConjunctionMembers(given, [&](const Expression& given) {
+ if (out != nullptr) {
+ return Status::OK();
+ }
+
+ if (given.type() != ExpressionType::COMPARISON) {
+ return Status::OK();
+ }
+
const auto& given_cmp = checked_cast<const ComparisonExpression&>(given);
- if (given_cmp.op() == CompareOperator::EQUAL) {
- if (this->Equals(given_cmp.left_operand()) &&
- given_cmp.right_operand()->type() == ExpressionType::SCALAR) {
- return given_cmp.right_operand();
- }
+ if (given_cmp.op() != CompareOperator::EQUAL) {
+ return Status::OK();
+ }
- if (this->Equals(given_cmp.right_operand()) &&
- given_cmp.left_operand()->type() == ExpressionType::SCALAR) {
- return given_cmp.left_operand();
- }
+ if (this->Equals(given_cmp.left_operand())) {
+ out = given_cmp.right_operand();
+ return Status::OK();
}
- }
- return Copy();
+ if (this->Equals(given_cmp.right_operand())) {
+ out = given_cmp.left_operand();
+ return Status::OK();
+ }
+
+ return Status::OK();
+ }));
+
+ return out ? out : Copy();
}
std::shared_ptr<Expression> ComparisonExpression::Assume(const Expression& given) const {
@@ -571,15 +585,30 @@ std::shared_ptr<Expression> InExpression::Assume(const Expression& given) const
return scalar(set_->null_count() > 0);
}
- const auto& value = checked_cast<const ScalarExpression&>(*operand).value();
+ Datum set, value;
+ if (set_->type_id() == Type::DICTIONARY) {
+ const auto& dict_set = checked_cast<const DictionaryArray&>(*set_);
+ auto maybe_decoded = compute::Take(dict_set.dictionary(), dict_set.indices());
+ auto maybe_value = checked_cast<const DictionaryScalar&>(
+ *checked_cast<const ScalarExpression&>(*operand).value())
+ .GetEncodedValue();
+ if (!maybe_decoded.ok() || !maybe_value.ok()) {
+ return std::make_shared<InExpression>(std::move(operand), set_);
+ }
+ set = *maybe_decoded;
+ value = *maybe_value;
+ } else {
+ set = set_;
+ value = checked_cast<const ScalarExpression&>(*operand).value();
+ }
compute::CompareOptions eq(CompareOperator::EQUAL);
- Result<Datum> out_result = compute::Compare(set_, value, eq);
- if (!out_result.ok()) {
+ Result<Datum> maybe_out = compute::Compare(set, value, eq);
+ if (!maybe_out.ok()) {
return std::make_shared<InExpression>(std::move(operand), set_);
}
- Datum out = out_result.ValueOrDie();
+ Datum out = maybe_out.ValueOrDie();
DCHECK(out.is_array());
DCHECK_EQ(out.type()->id(), Type::BOOL);
@@ -1046,6 +1075,18 @@ Result<std::shared_ptr<Expression>> InsertImplicitCasts(const Expression& expr,
return VisitExpression(expr, InsertImplicitCastsImpl{schema});
}
+Status VisitConjunctionMembers(const Expression& expr,
+ const std::function<Status(const Expression&)>& visitor) {
+ if (expr.type() == ExpressionType::AND) {
+ const auto& and_ = checked_cast<const AndExpression&>(expr);
+ RETURN_NOT_OK(VisitConjunctionMembers(*and_.left_operand(), visitor));
+ RETURN_NOT_OK(VisitConjunctionMembers(*and_.right_operand(), visitor));
+ return Status::OK();
+ }
+
+ return visitor(expr);
+}
+
std::vector<std::string> FieldsInExpression(const Expression& expr) {
struct {
void operator()(const FieldExpression& expr) { fields.push_back(expr.name()); }
diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h
index ebf58cc..d4cdcd9 100644
--- a/cpp/src/arrow/dataset/filter.h
+++ b/cpp/src/arrow/dataset/filter.h
@@ -575,6 +575,16 @@ auto VisitExpression(const Expression& expr, Visitor&& visitor)
return visitor(internal::checked_cast<const CustomExpression&>(expr));
}
+/// \brief Visit each subexpression of an arbitrarily nested conjunction.
+///
+/// | given | visit |
+/// |--------------------------------|---------------------------------------------|
+/// | a and b | visit(a), visit(b) |
+/// | c | visit(c) |
+/// | (a and b) and ((c or d) and e) | visit(a), visit(b), visit(c or d), visit(e) |
+ARROW_DS_EXPORT Status VisitConjunctionMembers(
+ const Expression& expr, const std::function<Status(const Expression&)>& visitor);
+
/// \brief Insert CastExpressions where necessary to make a valid expression.
ARROW_DS_EXPORT Result<std::shared_ptr<Expression>> InsertImplicitCasts(
const Expression& expr, const Schema& schema);
diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc
index 8e16208..7723912 100644
--- a/cpp/src/arrow/dataset/filter_test.cc
+++ b/cpp/src/arrow/dataset/filter_test.cc
@@ -72,7 +72,8 @@ class ExpressionsTest : public ::testing::Test {
std::shared_ptr<DataType> ns = timestamp(TimeUnit::NANO);
std::shared_ptr<Schema> schema_ =
schema({field("a", int32()), field("b", int32()), field("f", float64()),
- field("s", utf8()), field("ts", ns)});
+ field("s", utf8()), field("ts", ns),
+ field("dict_b", dictionary(int32(), int32()))});
std::shared_ptr<Expression> always = scalar(true);
std::shared_ptr<Expression> never = scalar(false);
};
@@ -131,6 +132,16 @@ TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) {
AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5);
AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never);
AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10);
+
+ auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])");
+ AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 3, *always);
+ AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 5, *never);
+
+ auto dict_set_123 =
+ DictArrayFromJSON(dictionary(int32(), int32()), R"([1,2,0])", R"([1,2,3])");
+ ASSERT_OK_AND_ASSIGN(auto b_dict, dict_set_123->GetScalar(0));
+ AssertSimplifiesTo("b_dict"_.In(dict_set_123), "a"_ == 3 and "b_dict"_ == b_dict,
+ *always);
}
TEST_F(ExpressionsTest, SimplificationToNull) {
diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index a0ea91d..d43cf7a 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -80,32 +80,27 @@ Status KeyValuePartitioning::VisitKeys(
const Expression& expr,
const std::function<Status(const std::string& name,
const std::shared_ptr<Scalar>& value)>& visitor) {
- if (expr.type() == ExpressionType::AND) {
- const auto& and_ = checked_cast<const AndExpression&>(expr);
- RETURN_NOT_OK(VisitKeys(*and_.left_operand(), visitor));
- RETURN_NOT_OK(VisitKeys(*and_.right_operand(), visitor));
- return Status::OK();
- }
-
- if (expr.type() != ExpressionType::COMPARISON) {
- return Status::OK();
- }
+ return VisitConjunctionMembers(expr, [visitor](const Expression& expr) {
+ if (expr.type() != ExpressionType::COMPARISON) {
+ return Status::OK();
+ }
- const auto& cmp = checked_cast<const ComparisonExpression&>(expr);
- if (cmp.op() != compute::CompareOperator::EQUAL) {
- return Status::OK();
- }
+ const auto& cmp = checked_cast<const ComparisonExpression&>(expr);
+ if (cmp.op() != compute::CompareOperator::EQUAL) {
+ return Status::OK();
+ }
- auto lhs = cmp.left_operand().get();
- auto rhs = cmp.right_operand().get();
- if (lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs);
+ auto lhs = cmp.left_operand().get();
+ auto rhs = cmp.right_operand().get();
+ if (lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs);
- if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) {
- return Status::OK();
- }
+ if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) {
+ return Status::OK();
+ }
- return visitor(checked_cast<const FieldExpression*>(lhs)->name(),
- checked_cast<const ScalarExpression*>(rhs)->value());
+ return visitor(checked_cast<const FieldExpression*>(lhs)->name(),
+ checked_cast<const ScalarExpression*>(rhs)->value());
+ });
}
Result<std::unordered_map<std::string, std::shared_ptr<Scalar>>>
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index 1e93d41..cc17b19 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -391,6 +391,16 @@ test_that("filter() with %in%", {
collect(),
tibble(int = df1$int[c(3, 4, 6)], part = 1)
)
+
+# ARROW-9606: bug in %in% filter on partition column with >1 partition columns
+ ds <- open_dataset(hive_dir)
+ expect_equivalent(
+ ds %>%
+ filter(group %in% 2) %>%
+ select(names(df2)) %>%
+ collect(),
+ df2
+ )
})
test_that("filter() on timestamp columns", {