You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by np...@apache.org on 2020/08/10 17:42:13 UTC

[arrow] branch master updated: ARROW-9606: [C++][Dataset] Support `"a"_.In(<>).Assume(<compound>)`

This is an automated email from the ASF dual-hosted git repository.

npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 858059f  ARROW-9606: [C++][Dataset] Support `"a"_.In(<>).Assume(<compound>)`
858059f is described below

commit 858059fd1cfbae755341c2df5f94ef5d42b329da
Author: Benjamin Kietzman <be...@gmail.com>
AuthorDate: Mon Aug 10 10:41:37 2020 -0700

    ARROW-9606: [C++][Dataset] Support `"a"_.In(<>).Assume(<compound>)`
    
    This enables predicate pushdown of `%in%` filters in the presence of compound partition information (i.e. when there is more than one partition column).
    
    @mpjdem
    
    Closes #7911 from bkietz/9606-simplify-isin-query-nested-partitions
    
    Authored-by: Benjamin Kietzman <be...@gmail.com>
    Signed-off-by: Neal Richardson <ne...@gmail.com>
---
 cpp/src/arrow/dataset/filter.cc      | 73 ++++++++++++++++++++++++++++--------
 cpp/src/arrow/dataset/filter.h       | 10 +++++
 cpp/src/arrow/dataset/filter_test.cc | 13 ++++++-
 cpp/src/arrow/dataset/partition.cc   | 39 +++++++++----------
 r/tests/testthat/test-dataset.R      | 10 +++++
 5 files changed, 106 insertions(+), 39 deletions(-)

diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc
index ddfee24..25a3a0a 100644
--- a/cpp/src/arrow/dataset/filter.cc
+++ b/cpp/src/arrow/dataset/filter.cc
@@ -261,22 +261,36 @@ std::shared_ptr<Expression> Invert(const Expression& expr) {
 }
 
 std::shared_ptr<Expression> Expression::Assume(const Expression& given) const {
-  if (given.type() == ExpressionType::COMPARISON) {
+  std::shared_ptr<Expression> out;
+
+  DCHECK_OK(VisitConjunctionMembers(given, [&](const Expression& given) {
+    if (out != nullptr) {
+      return Status::OK();
+    }
+
+    if (given.type() != ExpressionType::COMPARISON) {
+      return Status::OK();
+    }
+
     const auto& given_cmp = checked_cast<const ComparisonExpression&>(given);
-    if (given_cmp.op() == CompareOperator::EQUAL) {
-      if (this->Equals(given_cmp.left_operand()) &&
-          given_cmp.right_operand()->type() == ExpressionType::SCALAR) {
-        return given_cmp.right_operand();
-      }
+    if (given_cmp.op() != CompareOperator::EQUAL) {
+      return Status::OK();
+    }
 
-      if (this->Equals(given_cmp.right_operand()) &&
-          given_cmp.left_operand()->type() == ExpressionType::SCALAR) {
-        return given_cmp.left_operand();
-      }
+    if (this->Equals(given_cmp.left_operand())) {
+      out = given_cmp.right_operand();
+      return Status::OK();
     }
-  }
 
-  return Copy();
+    if (this->Equals(given_cmp.right_operand())) {
+      out = given_cmp.left_operand();
+      return Status::OK();
+    }
+
+    return Status::OK();
+  }));
+
+  return out ? out : Copy();
 }
 
 std::shared_ptr<Expression> ComparisonExpression::Assume(const Expression& given) const {
@@ -571,15 +585,30 @@ std::shared_ptr<Expression> InExpression::Assume(const Expression& given) const
     return scalar(set_->null_count() > 0);
   }
 
-  const auto& value = checked_cast<const ScalarExpression&>(*operand).value();
+  Datum set, value;
+  if (set_->type_id() == Type::DICTIONARY) {
+    const auto& dict_set = checked_cast<const DictionaryArray&>(*set_);
+    auto maybe_decoded = compute::Take(dict_set.dictionary(), dict_set.indices());
+    auto maybe_value = checked_cast<const DictionaryScalar&>(
+                           *checked_cast<const ScalarExpression&>(*operand).value())
+                           .GetEncodedValue();
+    if (!maybe_decoded.ok() || !maybe_value.ok()) {
+      return std::make_shared<InExpression>(std::move(operand), set_);
+    }
+    set = *maybe_decoded;
+    value = *maybe_value;
+  } else {
+    set = set_;
+    value = checked_cast<const ScalarExpression&>(*operand).value();
+  }
 
   compute::CompareOptions eq(CompareOperator::EQUAL);
-  Result<Datum> out_result = compute::Compare(set_, value, eq);
-  if (!out_result.ok()) {
+  Result<Datum> maybe_out = compute::Compare(set, value, eq);
+  if (!maybe_out.ok()) {
     return std::make_shared<InExpression>(std::move(operand), set_);
   }
 
-  Datum out = out_result.ValueOrDie();
+  Datum out = maybe_out.ValueOrDie();
 
   DCHECK(out.is_array());
   DCHECK_EQ(out.type()->id(), Type::BOOL);
@@ -1046,6 +1075,18 @@ Result<std::shared_ptr<Expression>> InsertImplicitCasts(const Expression& expr,
   return VisitExpression(expr, InsertImplicitCastsImpl{schema});
 }
 
+Status VisitConjunctionMembers(const Expression& expr,
+                               const std::function<Status(const Expression&)>& visitor) {
+  if (expr.type() == ExpressionType::AND) {
+    const auto& and_ = checked_cast<const AndExpression&>(expr);
+    RETURN_NOT_OK(VisitConjunctionMembers(*and_.left_operand(), visitor));
+    RETURN_NOT_OK(VisitConjunctionMembers(*and_.right_operand(), visitor));
+    return Status::OK();
+  }
+
+  return visitor(expr);
+}
+
 std::vector<std::string> FieldsInExpression(const Expression& expr) {
   struct {
     void operator()(const FieldExpression& expr) { fields.push_back(expr.name()); }
diff --git a/cpp/src/arrow/dataset/filter.h b/cpp/src/arrow/dataset/filter.h
index ebf58cc..d4cdcd9 100644
--- a/cpp/src/arrow/dataset/filter.h
+++ b/cpp/src/arrow/dataset/filter.h
@@ -575,6 +575,16 @@ auto VisitExpression(const Expression& expr, Visitor&& visitor)
   return visitor(internal::checked_cast<const CustomExpression&>(expr));
 }
 
+/// \brief Visit each subexpression of an arbitrarily nested conjunction.
+///
+/// | given                          | visit                                       |
+/// |--------------------------------|---------------------------------------------|
+/// | a and b                        | visit(a), visit(b)                          |
+/// | c                              | visit(c)                                    |
+/// | (a and b) and ((c or d) and e) | visit(a), visit(b), visit(c or d), visit(e) |
+ARROW_DS_EXPORT Status VisitConjunctionMembers(
+    const Expression& expr, const std::function<Status(const Expression&)>& visitor);
+
 /// \brief Insert CastExpressions where necessary to make a valid expression.
 ARROW_DS_EXPORT Result<std::shared_ptr<Expression>> InsertImplicitCasts(
     const Expression& expr, const Schema& schema);
diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc
index 8e16208..7723912 100644
--- a/cpp/src/arrow/dataset/filter_test.cc
+++ b/cpp/src/arrow/dataset/filter_test.cc
@@ -72,7 +72,8 @@ class ExpressionsTest : public ::testing::Test {
   std::shared_ptr<DataType> ns = timestamp(TimeUnit::NANO);
   std::shared_ptr<Schema> schema_ =
       schema({field("a", int32()), field("b", int32()), field("f", float64()),
-              field("s", utf8()), field("ts", ns)});
+              field("s", utf8()), field("ts", ns),
+              field("dict_b", dictionary(int32(), int32()))});
   std::shared_ptr<Expression> always = scalar(true);
   std::shared_ptr<Expression> never = scalar(false);
 };
@@ -131,6 +132,16 @@ TEST_F(ExpressionsTest, SimplificationAgainstCompoundCondition) {
   AssertSimplifiesTo("b"_ > 5, "b"_ == 3 or "b"_ == 6, "b"_ > 5);
   AssertSimplifiesTo("b"_ > 7, "b"_ == 3 or "b"_ == 6, *never);
   AssertSimplifiesTo("b"_ > 5 and "b"_ < 10, "b"_ > 6 and "b"_ < 13, "b"_ < 10);
+
+  auto set_123 = ArrayFromJSON(int32(), R"([1, 2, 3])");
+  AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 3, *always);
+  AssertSimplifiesTo("b"_.In(set_123), "a"_ == 3 and "b"_ == 5, *never);
+
+  auto dict_set_123 =
+      DictArrayFromJSON(dictionary(int32(), int32()), R"([1,2,0])", R"([1,2,3])");
+  ASSERT_OK_AND_ASSIGN(auto b_dict, dict_set_123->GetScalar(0));
+  AssertSimplifiesTo("b_dict"_.In(dict_set_123), "a"_ == 3 and "b_dict"_ == b_dict,
+                     *always);
 }
 
 TEST_F(ExpressionsTest, SimplificationToNull) {
diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index a0ea91d..d43cf7a 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -80,32 +80,27 @@ Status KeyValuePartitioning::VisitKeys(
     const Expression& expr,
     const std::function<Status(const std::string& name,
                                const std::shared_ptr<Scalar>& value)>& visitor) {
-  if (expr.type() == ExpressionType::AND) {
-    const auto& and_ = checked_cast<const AndExpression&>(expr);
-    RETURN_NOT_OK(VisitKeys(*and_.left_operand(), visitor));
-    RETURN_NOT_OK(VisitKeys(*and_.right_operand(), visitor));
-    return Status::OK();
-  }
-
-  if (expr.type() != ExpressionType::COMPARISON) {
-    return Status::OK();
-  }
+  return VisitConjunctionMembers(expr, [visitor](const Expression& expr) {
+    if (expr.type() != ExpressionType::COMPARISON) {
+      return Status::OK();
+    }
 
-  const auto& cmp = checked_cast<const ComparisonExpression&>(expr);
-  if (cmp.op() != compute::CompareOperator::EQUAL) {
-    return Status::OK();
-  }
+    const auto& cmp = checked_cast<const ComparisonExpression&>(expr);
+    if (cmp.op() != compute::CompareOperator::EQUAL) {
+      return Status::OK();
+    }
 
-  auto lhs = cmp.left_operand().get();
-  auto rhs = cmp.right_operand().get();
-  if (lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs);
+    auto lhs = cmp.left_operand().get();
+    auto rhs = cmp.right_operand().get();
+    if (lhs->type() != ExpressionType::FIELD) std::swap(lhs, rhs);
 
-  if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) {
-    return Status::OK();
-  }
+    if (lhs->type() != ExpressionType::FIELD || rhs->type() != ExpressionType::SCALAR) {
+      return Status::OK();
+    }
 
-  return visitor(checked_cast<const FieldExpression*>(lhs)->name(),
-                 checked_cast<const ScalarExpression*>(rhs)->value());
+    return visitor(checked_cast<const FieldExpression*>(lhs)->name(),
+                   checked_cast<const ScalarExpression*>(rhs)->value());
+  });
 }
 
 Result<std::unordered_map<std::string, std::shared_ptr<Scalar>>>
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index 1e93d41..cc17b19 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -391,6 +391,16 @@ test_that("filter() with %in%", {
       collect(),
     tibble(int = df1$int[c(3, 4, 6)], part = 1)
   )
+
+# ARROW-9606: bug in %in% filter on partition column with >1 partition columns
+  ds <- open_dataset(hive_dir)
+  expect_equivalent(
+    ds %>%
+      filter(group %in% 2) %>%
+      select(names(df2)) %>%
+      collect(),
+    df2
+  )
 })
 
 test_that("filter() on timestamp columns", {