You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by jo...@apache.org on 2022/04/21 18:55:24 UTC
[arrow] branch master updated: ARROW-12659: [C++] Support is_valid as a guarantee
This is an automated email from the ASF dual-hosted git repository.
jonkeane pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 0e03af446c ARROW-12659: [C++] Support is_valid as a guarantee
0e03af446c is described below
commit 0e03af446c328d0ef963510c3292cb14e092b917
Author: David Li <li...@gmail.com>
AuthorDate: Thu Apr 21 13:55:15 2022 -0500
ARROW-12659: [C++] Support is_valid as a guarantee
This rebases #10253 and fixes it up to also address ARROW-15312, including a regression test.
This refactors how inequalities, is_valid, and is_null are treated in expression simplification, and updates the guarantees that the Datasets Parquet file format emits for row groups so that they properly reflect nullability.
Closes #12891 from lidavidm/arrow-12659
Lead-authored-by: David Li <li...@gmail.com>
Co-authored-by: Benjamin Kietzman <be...@gmail.com>
Co-authored-by: Antoine Pitrou <pi...@free.fr>
Signed-off-by: Jonathan Keane <jk...@gmail.com>
---
cpp/src/arrow/compute/exec/expression.cc | 405 +++++++++++++++------
cpp/src/arrow/compute/exec/expression.h | 7 +-
cpp/src/arrow/compute/exec/expression_test.cc | 137 ++++++-
cpp/src/arrow/compute/kernels/scalar_validity.cc | 72 +++-
.../arrow/compute/kernels/scalar_validity_test.cc | 19 +
cpp/src/arrow/dataset/file_csv_test.cc | 1 +
cpp/src/arrow/dataset/file_ipc_test.cc | 1 +
cpp/src/arrow/dataset/file_orc_test.cc | 1 +
cpp/src/arrow/dataset/file_parquet.cc | 26 +-
cpp/src/arrow/dataset/file_parquet_test.cc | 21 +-
cpp/src/arrow/dataset/test_util.h | 26 +-
cpp/src/arrow/type.h | 2 +-
cpp/src/arrow/util/stl_util_test.cc | 7 +
cpp/src/arrow/util/vector.h | 4 +-
docs/source/cpp/compute.rst | 9 +-
docs/source/python/api/compute.rst | 1 +
16 files changed, 570 insertions(+), 169 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc
index 1ef5c6e7b9..8f7a9a1c8c 100644
--- a/cpp/src/arrow/compute/exec/expression.cc
+++ b/cpp/src/arrow/compute/exec/expression.cc
@@ -34,6 +34,7 @@
#include "arrow/util/optional.h"
#include "arrow/util/string.h"
#include "arrow/util/value_parsing.h"
+#include "arrow/util/vector.h"
namespace arrow {
@@ -110,7 +111,7 @@ namespace {
std::string PrintDatum(const Datum& datum) {
if (datum.is_scalar()) {
- if (!datum.scalar()->is_valid) return "null";
+ if (!datum.scalar()->is_valid) return "null[" + datum.type()->ToString() + "]";
switch (datum.type()->id()) {
case Type::STRING:
@@ -129,6 +130,8 @@ std::string PrintDatum(const Datum& datum) {
}
return datum.scalar()->ToString();
+ } else if (datum.is_array()) {
+ return "Array[" + datum.type()->ToString() + "]";
}
return datum.ToString();
}
@@ -305,19 +308,49 @@ bool Expression::IsNullLiteral() const {
return false;
}
-bool Expression::IsSatisfiable() const {
- if (type() && type()->id() == Type::NA) {
- return false;
+namespace {
+util::optional<compute::NullHandling::type> GetNullHandling(
+ const Expression::Call& call) {
+ DCHECK_NE(call.function, nullptr);
+ if (call.function->kind() == compute::Function::SCALAR) {
+ return static_cast<const compute::ScalarKernel*>(call.kernel)->null_handling;
}
+ return util::nullopt;
+}
+} // namespace
+
+bool Expression::IsSatisfiable() const {
+ if (!type()) return true;
+ if (type()->id() != Type::BOOL) return true;
if (auto lit = literal()) {
if (lit->null_count() == lit->length()) {
return false;
}
- if (lit->is_scalar() && lit->type()->id() == Type::BOOL) {
+ if (lit->is_scalar()) {
return lit->scalar_as<BooleanScalar>().value;
}
+
+ return true;
+ }
+
+ if (field_ref()) return true;
+
+ auto call = CallNotNull(*this);
+
+ // invert(true_unless_null(x)) is always false or null by definition
+ // true_unless_null arises in simplification of inequalities below
+ if (call->function_name == "invert") {
+ if (auto nested_call = call->arguments[0].call()) {
+ if (nested_call->function_name == "true_unless_null") return false;
+ }
+ }
+
+ if (call->function_name == "and_kleene" || call->function_name == "and") {
+ for (const Expression& arg : call->arguments) {
+ if (!arg.IsSatisfiable()) return false;
+ }
}
return true;
@@ -370,9 +403,11 @@ Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_
compute::KernelContext kernel_context(exec_context);
if (call.kernel->init) {
+ const FunctionOptions* options =
+ call.options ? call.options.get() : call.function->default_options();
ARROW_ASSIGN_OR_RAISE(
call.kernel_state,
- call.kernel->init(&kernel_context, {call.kernel, descrs, call.options.get()}));
+ call.kernel->init(&kernel_context, {call.kernel, descrs, options}));
kernel_context.SetState(call.kernel_state.get());
}
@@ -575,14 +610,6 @@ util::optional<Out> FoldLeft(It begin, It end, const BinOp& bin_op) {
return folded;
}
-util::optional<compute::NullHandling::type> GetNullHandling(
- const Expression::Call& call) {
- if (call.function && call.function->kind() == compute::Function::SCALAR) {
- return static_cast<const compute::ScalarKernel*>(call.kernel)->null_handling;
- }
- return util::nullopt;
-}
-
} // namespace
std::vector<FieldRef> FieldsInExpression(const Expression& expr) {
@@ -632,9 +659,17 @@ Result<Expression> FoldConstants(Expression expr) {
if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) {
// kernels which always produce intersected validity can be resolved
// to null *now* if any of their inputs is a null literal
+ if (!call->descr.type) {
+ return Status::Invalid("Cannot fold constants for unbound expression ",
+ expr.ToString());
+ }
for (const auto& argument : call->arguments) {
if (argument.IsNullLiteral()) {
- return argument;
+ if (argument.type()->Equals(*call->descr.type)) {
+ return argument;
+ } else {
+ return literal(MakeNullScalar(call->descr.type));
+ }
}
}
}
@@ -682,46 +717,52 @@ std::vector<Expression> GuaranteeConjunctionMembers(
return FlattenedAssociativeChain(guaranteed_true_predicate).fringe;
}
-// Conjunction members which are represented in known_values are erased from
-// conjunction_members
-Status ExtractKnownFieldValuesImpl(
- std::vector<Expression>* conjunction_members,
- std::unordered_map<FieldRef, Datum, FieldRef::Hash>* known_values) {
- auto unconsumed_end =
- std::partition(conjunction_members->begin(), conjunction_members->end(),
- [](const Expression& expr) {
- // search for an equality conditions between a field and a literal
- auto call = expr.call();
- if (!call) return true;
-
- if (call->function_name == "equal") {
- auto ref = call->arguments[0].field_ref();
- auto lit = call->arguments[1].literal();
- return !(ref && lit);
- }
-
- if (call->function_name == "is_null") {
- auto ref = call->arguments[0].field_ref();
- return !ref;
- }
-
- return true;
- });
-
- for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) {
- auto call = CallNotNull(*it);
-
- if (call->function_name == "equal") {
- auto ref = call->arguments[0].field_ref();
- auto lit = call->arguments[1].literal();
- known_values->emplace(*ref, *lit);
- } else if (call->function_name == "is_null") {
- auto ref = call->arguments[0].field_ref();
- known_values->emplace(*ref, Datum(std::make_shared<NullScalar>()));
- }
+/// \brief Extract an equality from an expression.
+///
+/// Recognizes expressions of the form:
+/// equal(a, 2)
+/// is_null(a)
+util::optional<std::pair<FieldRef, Datum>> ExtractOneFieldValue(
+ const Expression& guarantee) {
+ auto call = guarantee.call();
+ if (!call) return util::nullopt;
+
+ // search for an equality conditions between a field and a literal
+ if (call->function_name == "equal") {
+ auto ref = call->arguments[0].field_ref();
+ if (!ref) return util::nullopt;
+
+ auto lit = call->arguments[1].literal();
+ if (!lit) return util::nullopt;
+
+ return std::make_pair(*ref, *lit);
+ }
+
+ // ... or a known null field
+ if (call->function_name == "is_null") {
+ auto ref = call->arguments[0].field_ref();
+ if (!ref) return util::nullopt;
+
+ return std::make_pair(*ref, Datum(std::make_shared<NullScalar>()));
}
- conjunction_members->erase(unconsumed_end, conjunction_members->end());
+ return util::nullopt;
+}
+
+// Conjunction members which are represented in known_values are erased from
+// conjunction_members
+Status ExtractKnownFieldValues(std::vector<Expression>* conjunction_members,
+ KnownFieldValues* known_values) {
+ // filter out consumed conjunction members, leaving only unconsumed
+ *conjunction_members = arrow::internal::FilterVector(
+ std::move(*conjunction_members),
+ [known_values](const Expression& guarantee) -> bool {
+ if (auto known_value = ExtractOneFieldValue(guarantee)) {
+ known_values->map.insert(std::move(*known_value));
+ return false;
+ }
+ return true;
+ });
return Status::OK();
}
@@ -730,9 +771,9 @@ Status ExtractKnownFieldValuesImpl(
Result<KnownFieldValues> ExtractKnownFieldValues(
const Expression& guaranteed_true_predicate) {
- auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
KnownFieldValues known_values;
- RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map));
+ auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
+ RETURN_NOT_OK(ExtractKnownFieldValues(&conjunction_members, &known_values));
return known_values;
}
@@ -879,68 +920,199 @@ Result<Expression> Canonicalize(Expression expr, compute::ExecContext* exec_cont
namespace {
-Result<Expression> DirectComparisonSimplification(Expression expr,
- const Expression::Call& guarantee) {
- return Modify(
- std::move(expr), [](Expression expr) { return expr; },
- [&guarantee](Expression expr, ...) -> Result<Expression> {
- auto call = expr.call();
- if (!call) return expr;
+// An inequality comparison which a target Expression is known to satisfy. If nullable,
+// the target may evaluate to null in addition to values satisfying the comparison.
+struct Inequality {
+ // The inequality type
+ Comparison::type cmp;
+ // The LHS of the inequality
+ const FieldRef& target;
+ // The RHS of the inequality
+ const Datum& bound;
+ // Whether target can be null
+ bool nullable;
+
+ // Extract an Inequality if possible, derived from "less",
+ // "greater", "less_equal", and "greater_equal" expressions,
+ // possibly disjuncted with an "is_null" Expression.
+ // cmp(a, 2)
+ // cmp(a, 2) or is_null(a)
+ static util::optional<Inequality> ExtractOne(const Expression& guarantee) {
+ auto call = guarantee.call();
+ if (!call) return util::nullopt;
+
+ if (call->function_name == "or_kleene") {
+ // expect the LHS to be a usable field inequality
+ auto out = ExtractOneFromComparison(call->arguments[0]);
+ if (!out) return util::nullopt;
+
+ // expect the RHS to be an is_null expression
+ auto call_rhs = call->arguments[1].call();
+ if (!call_rhs) return util::nullopt;
+ if (call_rhs->function_name != "is_null") return util::nullopt;
+
+ // ... and that it references the same target
+ auto target = call_rhs->arguments[0].field_ref();
+ if (!target) return util::nullopt;
+ if (*target != out->target) return util::nullopt;
+
+ out->nullable = true;
+ return out;
+ }
- // Ensure both calls are comparisons with equal LHS and scalar RHS
- auto cmp = Comparison::Get(expr);
- auto cmp_guarantee = Comparison::Get(guarantee.function_name);
+ // fall back to a simple comparison with no "is_null"
+ return ExtractOneFromComparison(guarantee);
+ }
- if (!cmp) return expr;
- if (!cmp_guarantee) return expr;
+ static util::optional<Inequality> ExtractOneFromComparison(
+ const Expression& guarantee) {
+ auto call = guarantee.call();
+ if (!call) return util::nullopt;
- const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
- const auto& guarantee_lhs = guarantee.arguments[0];
- if (lhs != guarantee_lhs) return expr;
+ if (auto cmp = Comparison::Get(call->function_name)) {
+ // not_equal comparisons are not very usable as guarantees
+ if (*cmp == Comparison::NOT_EQUAL) return util::nullopt;
- auto rhs = call->arguments[1].literal();
- auto guarantee_rhs = guarantee.arguments[1].literal();
+ auto target = call->arguments[0].field_ref();
+ if (!target) return util::nullopt;
- if (!rhs) return expr;
- if (!rhs->is_scalar()) return expr;
+ auto bound = call->arguments[1].literal();
+ if (!bound) return util::nullopt;
+ if (!bound->is_scalar()) return util::nullopt;
- if (!guarantee_rhs) return expr;
- if (!guarantee_rhs->is_scalar()) return expr;
+ return Inequality{*cmp, /*target=*/*target, *bound, /*nullable=*/false};
+ }
- ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
- Comparison::Execute(*rhs, *guarantee_rhs));
- DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);
+ return util::nullopt;
+ }
- if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
- // RHS of filter is equal to RHS of guarantee
+ /// The given expression simplifies to `value` if the inequality
+ /// target is not nullable. Otherwise, it simplifies to either a
+ /// call to true_unless_null or !true_unless_null.
+ Result<Expression> simplified_to(const Expression& bound_target, bool value) const {
+ if (!nullable) return literal(value);
+
+ ExecContext exec_context;
+
+ // Data may be null, so comparison will yield `value` - or null IFF the data was null
+ //
+ // true_unless_null is cheap; it purely reuses the validity bitmap for the values
+ // buffer. Inversion is less cheap but we expect that term never to be evaluated
+ // since invert(true_unless_null(x)) is not satisfiable.
+ Expression::Call call;
+ call.function_name = "true_unless_null";
+ call.arguments = {bound_target};
+ ARROW_ASSIGN_OR_RAISE(
+ auto true_unless_null,
+ BindNonRecursive(std::move(call),
+ /*insert_implicit_casts=*/false, &exec_context));
+ if (value) return true_unless_null;
+
+ Expression::Call invert;
+ invert.function_name = "invert";
+ invert.arguments = {std::move(true_unless_null)};
+ return BindNonRecursive(std::move(invert),
+ /*insert_implicit_casts=*/false, &exec_context);
+ }
- if ((*cmp & *cmp_guarantee) == *cmp_guarantee) {
- // guarantee is a subset of filter, so all data will be included
- // x > 1, x >= 1, x != 1 guaranteed by x > 1
- return literal(true);
- }
+ /// \brief Simplify the given expression given this inequality as a guarantee.
+ Result<Expression> Simplify(Expression expr) {
+ const auto& guarantee = *this;
- if ((*cmp & *cmp_guarantee) == 0) {
- // guarantee disjoint with filter, so all data will be excluded
- // x > 1, x >= 1, x != 1 unsatisfiable if x == 1
- return literal(false);
- }
+ auto call = expr.call();
+ if (!call) return expr;
- return expr;
- }
+ if (call->function_name == "is_valid" || call->function_name == "is_null") {
+ if (guarantee.nullable) return expr;
+ const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
+ if (!lhs.field_ref()) return expr;
+ if (*lhs.field_ref() != guarantee.target) return expr;
- if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
- // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
- return expr;
- }
+ return call->function_name == "is_valid" ? literal(true) : literal(false);
+ }
- if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
- // x > 1, x >= 1, x != 1 guaranteed by x >= 3
- return literal(true);
- } else {
- // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
- return literal(false);
- }
+ auto cmp = Comparison::Get(expr);
+ if (!cmp) return expr;
+
+ auto rhs = call->arguments[1].literal();
+ if (!rhs) return expr;
+ if (!rhs->is_scalar()) return expr;
+
+ const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
+ if (!lhs.field_ref()) return expr;
+ if (*lhs.field_ref() != guarantee.target) return expr;
+
+ // Whether the RHS of the expression is EQUAL, LESS, or GREATER than the
+ // RHS of the guarantee. N.B. Comparison::type is a bitmask
+ ARROW_ASSIGN_OR_RAISE(const Comparison::type cmp_rhs_bound,
+ Comparison::Execute(*rhs, guarantee.bound));
+ DCHECK_NE(cmp_rhs_bound, Comparison::NA);
+
+ if (cmp_rhs_bound == Comparison::EQUAL) {
+ // RHS of filter is equal to RHS of guarantee
+
+ if ((*cmp & guarantee.cmp) == guarantee.cmp) {
+ // guarantee is a subset of filter, so all data will be included
+ // x > 1, x >= 1, x != 1 guaranteed by x > 1
+ return simplified_to(lhs, true);
+ }
+
+ if ((*cmp & guarantee.cmp) == 0) {
+ // guarantee disjoint with filter, so all data will be excluded
+ // x > 1, x >= 1 unsatisfiable if x == 1
+ return simplified_to(lhs, false);
+ }
+
+ return expr;
+ }
+
+ if (guarantee.cmp & cmp_rhs_bound) {
+ // We guarantee (x (?) N) and are trying to simplify (x (?) M). We know
+ // either M < N or M > N (i.e. cmp_rhs_bound is either LESS or GREATER).
+
+ // If M > N, then if the guarantee is (x > N), (x >= N), or (x != N)
+ // (i.e. guarantee.cmp & cmp_rhs_bound), we cannot do anything with the
+ // guarantee, and bail out here.
+
+ // For example, take M = 5, N = 3. Then cmp_rhs_bound = GREATER.
+ // x > 3, x >= 3, x != 3 implies nothing about x < 5, x <= 5, x > 5,
+ // x >= 5, x != 5 and we bail out here.
+ // x < 3, x <= 3 could simplify (some of) those expressions.
+ return expr;
+ }
+
+ if (*cmp & Comparison::GetFlipped(cmp_rhs_bound)) {
+ // x > 1, x >= 1, x != 1 guaranteed by x >= 3
+ // (where `guarantee.cmp` is GREATER_EQUAL, `cmp_rhs_bound` is LESS)
+ return simplified_to(lhs, true);
+ } else {
+ // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
+ return simplified_to(lhs, false);
+ }
+ }
+};
+
+/// \brief Simplify an expression given a guarantee, if the guarantee
+/// is is_valid().
+Result<Expression> SimplifyIsValidGuarantee(Expression expr,
+ const Expression::Call& guarantee) {
+ if (guarantee.function_name != "is_valid") return expr;
+
+ return Modify(
+ std::move(expr), [](Expression expr) { return expr; },
+ [&](Expression expr, ...) -> Result<Expression> {
+ auto call = expr.call();
+ if (!call) return expr;
+
+ if (call->arguments[0] != guarantee.arguments[0]) return expr;
+
+ if (call->function_name == "is_valid") return literal(true);
+
+ if (call->function_name == "true_unless_null") return literal(true);
+
+ if (call->function_name == "is_null") return literal(false);
+
+ return expr;
});
}
@@ -948,10 +1120,10 @@ Result<Expression> DirectComparisonSimplification(Expression expr,
Result<Expression> SimplifyWithGuarantee(Expression expr,
const Expression& guaranteed_true_predicate) {
+ KnownFieldValues known_values;
auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
- KnownFieldValues known_values;
- RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map));
+ RETURN_NOT_OK(ExtractKnownFieldValues(&conjunction_members, &known_values));
ARROW_ASSIGN_OR_RAISE(expr,
ReplaceFieldsWithKnownValues(known_values, std::move(expr)));
@@ -964,9 +1136,26 @@ Result<Expression> SimplifyWithGuarantee(Expression expr,
RETURN_NOT_OK(CanonicalizeAndFoldConstants());
for (const auto& guarantee : conjunction_members) {
- if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) {
+ if (!guarantee.call()) continue;
+
+ if (auto inequality = Inequality::ExtractOne(guarantee)) {
+ ARROW_ASSIGN_OR_RAISE(auto simplified,
+ Modify(
+ std::move(expr), [](Expression expr) { return expr; },
+ [&](Expression expr, ...) -> Result<Expression> {
+ return inequality->Simplify(std::move(expr));
+ }));
+
+ if (Identical(simplified, expr)) continue;
+
+ expr = std::move(simplified);
+ RETURN_NOT_OK(CanonicalizeAndFoldConstants());
+ }
+
+ if (guarantee.call()->function_name == "is_valid") {
ARROW_ASSIGN_OR_RAISE(
- auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee)));
+ auto simplified,
+ SimplifyIsValidGuarantee(std::move(expr), *CallNotNull(guarantee)));
if (Identical(simplified, expr)) continue;
diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h
index dbc8da7bbb..38d8075966 100644
--- a/cpp/src/arrow/compute/exec/expression.h
+++ b/cpp/src/arrow/compute/exec/expression.h
@@ -93,7 +93,8 @@ class ARROW_EXPORT Expression {
/// Return true if this expression is literal and entirely null.
bool IsNullLiteral() const;
- /// Return true if this expression could evaluate to true.
+ /// Return true if this expression could evaluate to true. Will return true for any
+ /// unbound, non-boolean, or unsimplified Expressions
bool IsSatisfiable() const;
// XXX someday
@@ -171,8 +172,10 @@ std::vector<FieldRef> FieldsInExpression(const Expression&);
ARROW_EXPORT
bool ExpressionHasFieldRefs(const Expression&);
-/// Assemble a mapping from field references to known values.
struct ARROW_EXPORT KnownFieldValues;
+
+/// Assemble a mapping from field references to known values. This derives known values
+/// from "equal" and "is_null" Expressions referencing a field and a literal.
ARROW_EXPORT
Result<KnownFieldValues> ExtractKnownFieldValues(
const Expression& guaranteed_true_predicate);
diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc
index f916bc2a1c..95adb1652e 100644
--- a/cpp/src/arrow/compute/exec/expression_test.cc
+++ b/cpp/src/arrow/compute/exec/expression_test.cc
@@ -66,6 +66,10 @@ Expression cast(Expression argument, std::shared_ptr<DataType> to_type) {
compute::CastOptions::Safe(std::move(to_type)));
}
+Expression true_unless_null(Expression argument) {
+ return call("true_unless_null", {std::move(argument)});
+}
+
template <typename Actual, typename Expected>
void ExpectResultsEqual(Actual&& actual, Expected&& expected) {
using MaybeActual = typename EnsureResult<typename std::decay<Actual>::type>::type;
@@ -250,8 +254,8 @@ TEST(Expression, ToString) {
EXPECT_EQ(literal(3).ToString(), "3");
EXPECT_EQ(literal("a").ToString(), "\"a\"");
EXPECT_EQ(literal("a\nb").ToString(), "\"a\\nb\"");
- EXPECT_EQ(literal(std::make_shared<BooleanScalar>()).ToString(), "null");
- EXPECT_EQ(literal(std::make_shared<Int64Scalar>()).ToString(), "null");
+ EXPECT_EQ(literal(std::make_shared<BooleanScalar>()).ToString(), "null[bool]");
+ EXPECT_EQ(literal(std::make_shared<Int64Scalar>()).ToString(), "null[int64]");
EXPECT_EQ(literal(std::make_shared<BinaryScalar>(Buffer::FromString("az"))).ToString(),
"\"617A\"");
@@ -388,29 +392,49 @@ TEST(Expression, IsScalarExpression) {
}
TEST(Expression, IsSatisfiable) {
+ auto Bind = [](Expression expr) { return expr.Bind(*kBoringSchema).ValueOrDie(); };
+
EXPECT_TRUE(literal(true).IsSatisfiable());
EXPECT_FALSE(literal(false).IsSatisfiable());
auto null = std::make_shared<BooleanScalar>();
EXPECT_FALSE(literal(null).IsSatisfiable());
- EXPECT_TRUE(field_ref("a").IsSatisfiable());
+ // NB: no implicit conversion to bool
+ EXPECT_TRUE(literal(0).IsSatisfiable());
+
+ EXPECT_TRUE(field_ref("i32").IsSatisfiable());
+ EXPECT_TRUE(Bind(field_ref("i32")).IsSatisfiable());
- EXPECT_TRUE(equal(field_ref("a"), literal(1)).IsSatisfiable());
+ EXPECT_TRUE(equal(field_ref("i32"), literal(1)).IsSatisfiable());
+ EXPECT_TRUE(Bind(equal(field_ref("i32"), literal(1))).IsSatisfiable());
// NB: no constant folding here
- EXPECT_TRUE(equal(literal(0), literal(1)).IsSatisfiable());
-
- // When a top level conjunction contains an Expression which is certain to evaluate to
- // null, it can only evaluate to null or false.
- auto never_true = and_(literal(null), field_ref("a"));
- // This may appear in satisfiable filters if coalesced (for example, wrapped in fill_na)
- EXPECT_TRUE(call("is_null", {never_true}).IsSatisfiable());
- // ... but at the top level it is not satisfiable.
+ EXPECT_TRUE(Bind(equal(literal(0), literal(1))).IsSatisfiable());
+
+ // Special case invert(true_unless_null(x)): arises in simplification against a
+ // guarantee with a nullable caveat.
+ EXPECT_FALSE(Bind(not_(true_unless_null(field_ref("i32")))).IsSatisfiable());
+ // NB: no effort to examine unbound expressions
+ EXPECT_TRUE(not_(true_unless_null(field_ref("i32"))).IsSatisfiable());
+
+ // When a top level conjunction contains an Expression which is not satisfiable
+ // (guaranteed to evaluate to null or false), it can only evaluate to null or false.
// This special case arises when (for example) an absent column has made
- // one member of the conjunction always-null. This is fairly common and
- // would be a worthwhile optimization to support.
- // EXPECT_FALSE(null_or_false).IsSatisfiable());
+ // one member of the conjunction always-null.
+ for (const auto& never_true : {
+ // N.B. this is "and_kleene"
+ and_(literal(false), field_ref("bool")),
+ and_(literal(null), field_ref("bool")),
+ call("and", {literal(false), field_ref("bool")}),
+ call("and", {literal(null), field_ref("bool")}),
+ }) {
+ ARROW_SCOPED_TRACE(never_true.ToString());
+ EXPECT_FALSE(Bind(never_true).IsSatisfiable());
+ // ... but it may appear in satisfiable filters if coalesced (for example, wrapped in
+ // fill_na)
+ EXPECT_TRUE(Bind(call("is_null", {never_true})).IsSatisfiable());
+ }
}
TEST(Expression, FieldsInExpression) {
@@ -846,6 +870,10 @@ TEST(Expression, FoldConstants) {
}),
literal(4));
+ // INTERSECTION null handling and null input -> null output
+ ExpectFoldsTo(call("equal", {field_ref("i32"), null_literal(int32())}),
+ null_literal(boolean()));
+
// nested call against literals with one field_ref
// (i32 - (2 * 3)) + 2 == (i32 - 6) + 2
// NB this could be improved further by using associativity of addition; another pass
@@ -1066,8 +1094,7 @@ TEST(Expression, CanonicalizeAnd) {
and_(and_(and_(and_(null_, null_), true_), b), c));
// catches and_kleene even when it's a subexpression
- ExpectCanonicalizesTo(call("is_valid", {and_(b, true_)}),
- call("is_valid", {and_(true_, b)}));
+ ExpectCanonicalizesTo(is_valid(and_(b, true_)), is_valid(and_(true_, b)));
}
TEST(Expression, CanonicalizeComparison) {
@@ -1279,13 +1306,89 @@ TEST(Expression, SimplifyWithGuarantee) {
.WithGuarantee(not_(equal(field_ref("i32"), literal(7))))
.Expect(equal(field_ref("i32"), literal(7)));
+ // In the absence of is_null(i32) we assume i32 is valid
+ Simplify{
+ is_null(field_ref("i32")),
+ }
+ .WithGuarantee(greater_equal(field_ref("i32"), literal(1)))
+ .Expect(false);
+
+ Simplify{
+ is_null(field_ref("i32")),
+ }
+ .WithGuarantee(
+ or_(greater_equal(field_ref("i32"), literal(1)), is_null(field_ref("i32"))))
+ .Expect(is_null(field_ref("i32")));
+
+ Simplify{
+ is_null(field_ref("i32")),
+ }
+ .WithGuarantee(
+ and_(greater_equal(field_ref("i32"), literal(1)), is_valid(field_ref("i32"))))
+ .Expect(false);
+
+ Simplify{
+ is_valid(field_ref("i32")),
+ }
+ .WithGuarantee(greater_equal(field_ref("i32"), literal(1)))
+ .Expect(true);
+
+ Simplify{
+ is_valid(field_ref("i32")),
+ }
+ .WithGuarantee(
+ or_(greater_equal(field_ref("i32"), literal(1)), is_null(field_ref("i32"))))
+ .Expect(is_valid(field_ref("i32")));
+
+ Simplify{
+ is_valid(field_ref("i32")),
+ }
+ .WithGuarantee(
+ and_(greater_equal(field_ref("i32"), literal(1)), is_valid(field_ref("i32"))))
+ .Expect(true);
+}
+
+TEST(Expression, SimplifyWithValidityGuarantee) {
Simplify{is_null(field_ref("i32"))}
.WithGuarantee(is_null(field_ref("i32")))
.Expect(literal(true));
+ Simplify{is_valid(field_ref("i32"))}
+ .WithGuarantee(is_null(field_ref("i32")))
+ .Expect(literal(false));
+
Simplify{is_valid(field_ref("i32"))}
.WithGuarantee(is_valid(field_ref("i32")))
+ .Expect(literal(true));
+
+ Simplify{is_valid(field_ref("i32"))}
+ .WithGuarantee(is_valid(field_ref("dict_i32"))) // different field
.Expect(is_valid(field_ref("i32")));
+
+ Simplify{is_null(field_ref("i32"))}
+ .WithGuarantee(is_valid(field_ref("i32")))
+ .Expect(literal(false));
+
+ Simplify{true_unless_null(field_ref("i32"))}
+ .WithGuarantee(is_valid(field_ref("i32")))
+ .Expect(literal(true));
+}
+
+TEST(Expression, SimplifyWithComparisonAndNullableCaveat) {
+ auto i32_is_2_or_null =
+ or_(equal(field_ref("i32"), literal(2)), is_null(field_ref("i32")));
+
+ Simplify{equal(field_ref("i32"), literal(2))}
+ .WithGuarantee(i32_is_2_or_null)
+ .Expect(true_unless_null(field_ref("i32")));
+
+ // XXX: needs a rule for 'true_unless_null(x) || is_null(x)'
+ // Simplify{i32_is_2_or_null}.WithGuarantee(i32_is_2_or_null).Expect(literal(true));
+
+ Simplify{equal(field_ref("i32"), literal(3))}
+ .WithGuarantee(i32_is_2_or_null)
+ .Expect(not_(
+ true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group
}
TEST(Expression, SimplifyThenExecute) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_validity.cc b/cpp/src/arrow/compute/kernels/scalar_validity.cc
index e32f73ce56..ff16e9d935 100644
--- a/cpp/src/arrow/compute/kernels/scalar_validity.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_validity.cc
@@ -39,6 +39,15 @@ struct IsValidOperator {
}
static Status Call(KernelContext* ctx, const ArrayData& arr, ArrayData* out) {
+ if (arr.type->id() == Type::NA) {
+ // Input is all nulls => output is entirely false.
+ ARROW_ASSIGN_OR_RAISE(out->buffers[1],
+ ctx->AllocateBitmap(out->length + out->offset));
+ bit_util::SetBitsTo(out->buffers[1]->mutable_data(), out->offset, out->length,
+ false);
+ return Status::OK();
+ }
+
DCHECK_EQ(out->offset, 0);
DCHECK_LE(out->length, arr.length);
if (arr.MayHaveNulls()) {
@@ -146,6 +155,29 @@ struct IsNullOperator {
}
};
+struct TrueUnlessNullOperator {
+ static Status Call(KernelContext* ctx, const Scalar& in, Scalar* out) {
+ checked_cast<BooleanScalar*>(out)->is_valid = in.is_valid;
+ checked_cast<BooleanScalar*>(out)->value = true;
+ return Status::OK();
+ }
+
+ static Status Call(KernelContext* ctx, const ArrayData& arr, ArrayData* out) {
+ // NullHandling::INTERSECTION with a single input means the execution engine
+ // has already reused or allocated a null_bitmap which can be reused as the values
+ // buffer.
+ if (out->buffers[0]) {
+ out->buffers[1] = out->buffers[0];
+ } else {
+ // But for all-valid inputs, the engine will skip allocating a
+ // buffer; we have to allocate one ourselves
+ ARROW_ASSIGN_OR_RAISE(out->buffers[1], ctx->AllocateBitmap(arr.length));
+ std::memset(out->buffers[1]->mutable_data(), 0xFF, out->buffers[1]->size());
+ }
+ return Status::OK();
+ }
+};
+
struct IsNanOperator {
template <typename OutType, typename InType>
static constexpr OutType Call(KernelContext*, const InType& value, Status*) {
@@ -156,14 +188,15 @@ struct IsNanOperator {
void MakeFunction(std::string name, const FunctionDoc* doc,
std::vector<InputType> in_types, OutputType out_type,
ArrayKernelExec exec, FunctionRegistry* registry,
- MemAllocation::type mem_allocation, bool can_write_into_slices,
+ MemAllocation::type mem_allocation, NullHandling::type null_handling,
+ bool can_write_into_slices,
const FunctionOptions* default_options = NULLPTR,
KernelInit init = NULLPTR) {
Arity arity{static_cast<int>(in_types.size())};
auto func = std::make_shared<ScalarFunction>(name, arity, doc, default_options);
ScalarKernel kernel(std::move(in_types), out_type, exec, init);
- kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
+ kernel.null_handling = null_handling;
kernel.can_write_into_slices = can_write_into_slices;
kernel.mem_allocation = mem_allocation;
@@ -247,21 +280,7 @@ std::shared_ptr<ScalarFunction> MakeIsNanFunction(std::string name,
}
Status IsValidExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const Datum& arg0 = batch[0];
- if (arg0.type()->id() == Type::NA) {
- auto false_value = std::make_shared<BooleanScalar>(false);
- if (arg0.kind() == Datum::SCALAR) {
- out->value = false_value;
- } else {
- std::shared_ptr<Array> false_values;
- RETURN_NOT_OK(MakeArrayFromScalar(*false_value, out->length(), ctx->memory_pool())
- .Value(&false_values));
- out->value = false_values->data();
- }
- return Status::OK();
- } else {
- return applicator::SimpleUnary<IsValidOperator>(ctx, batch, out);
- }
+ return applicator::SimpleUnary<IsValidOperator>(ctx, batch, out);
}
Status IsNullExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
@@ -281,6 +300,10 @@ Status IsNullExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
}
}
+Status TrueUnlessNullExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {  // kernel for "true_unless_null": applies TrueUnlessNullOperator via applicator::SimpleUnary
+ return applicator::SimpleUnary<TrueUnlessNullOperator>(ctx, batch, out);
+}
+
const FunctionDoc is_valid_doc(
"Return true if non-null",
("For each input value, emit true iff the value is valid (i.e. non-null)."),
@@ -303,6 +326,11 @@ const FunctionDoc is_null_doc(
"True may also be emitted for NaN values by setting the `nan_is_null` flag."),
{"values"}, "NullOptions");
+const FunctionDoc true_unless_null_doc("Return true if non-null, else return null",
+ ("For each input value, emit true iff the value\n"
+ "is valid (non-null), otherwise emit null."),
+ {"values"});  // single argument "values"; unlike is_null_doc, no options class is named
+
const FunctionDoc is_nan_doc("Return true if NaN",
("For each input value, emit true iff the value is NaN."),
{"values"});
@@ -312,12 +340,18 @@ const FunctionDoc is_nan_doc("Return true if NaN",
void RegisterScalarValidity(FunctionRegistry* registry) {
static auto kNullOptions = NullOptions::Defaults();
MakeFunction("is_valid", &is_valid_doc, {ValueDescr::ANY}, boolean(), IsValidExec,
- registry, MemAllocation::NO_PREALLOCATE, /*can_write_into_slices=*/false);
+ registry, MemAllocation::NO_PREALLOCATE, NullHandling::OUTPUT_NOT_NULL,
+ /*can_write_into_slices=*/false);
MakeFunction("is_null", &is_null_doc, {ValueDescr::ANY}, boolean(), IsNullExec,
- registry, MemAllocation::PREALLOCATE,
+ registry, MemAllocation::PREALLOCATE, NullHandling::OUTPUT_NOT_NULL,
/*can_write_into_slices=*/true, &kNullOptions, NanOptionsState::Init);
+ MakeFunction("true_unless_null", &true_unless_null_doc, {ValueDescr::ANY}, boolean(),
+ TrueUnlessNullExec, registry, MemAllocation::NO_PREALLOCATE,
+ NullHandling::INTERSECTION,
+ /*can_write_into_slices=*/false);
+
DCHECK_OK(registry->AddFunction(MakeIsFiniteFunction("is_finite", &is_finite_doc)));
DCHECK_OK(registry->AddFunction(MakeIsInfFunction("is_inf", &is_inf_doc)));
DCHECK_OK(registry->AddFunction(MakeIsNanFunction("is_nan", &is_nan_doc)));
diff --git a/cpp/src/arrow/compute/kernels/scalar_validity_test.cc b/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
index fb9358b143..df7ccc2909 100644
--- a/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
@@ -48,6 +48,25 @@ TEST_F(TestBooleanValidityKernels, ArrayIsValid) {
"[false, true, true, false]");
}
+TEST_F(TestBooleanValidityKernels, TrueUnlessNull) {  // any valid input (true or false) maps to true; null inputs stay null
+ CheckScalarUnary("true_unless_null", type_singleton(), "[]", type_singleton(), "[]");
+ CheckScalarUnary("true_unless_null", type_singleton(), "[null]", type_singleton(),
+ "[null]");
+ CheckScalarUnary("true_unless_null", type_singleton(), "[0, 1]", type_singleton(),
+ "[true, true]");
+ CheckScalarUnary("true_unless_null", type_singleton(), "[null, 1, 0, null]",
+ type_singleton(), "[null, true, true, null]");
+}
+
+TEST_F(TestBooleanValidityKernels, IsValidIsNullNullType) {  // NullType input: every slot is null, checked for is_null, is_valid and true_unless_null
+ CheckScalarUnary("is_null", std::make_shared<NullArray>(5),
+ ArrayFromJSON(boolean(), "[true, true, true, true, true]"));
+ CheckScalarUnary("is_valid", std::make_shared<NullArray>(5),
+ ArrayFromJSON(boolean(), "[false, false, false, false, false]"));
+ CheckScalarUnary("true_unless_null", std::make_shared<NullArray>(5),
+ ArrayFromJSON(boolean(), "[null, null, null, null, null]"));
+}
+
TEST_F(TestBooleanValidityKernels, ArrayIsValidBufferPassthruOptimization) {
Datum arg = ArrayFromJSON(boolean(), "[null, 1, 0, null]");
ASSERT_OK_AND_ASSIGN(auto validity, arrow::compute::IsValid(arg));
diff --git a/cpp/src/arrow/dataset/file_csv_test.cc b/cpp/src/arrow/dataset/file_csv_test.cc
index 00644b46eb..2064c58148 100644
--- a/cpp/src/arrow/dataset/file_csv_test.cc
+++ b/cpp/src/arrow/dataset/file_csv_test.cc
@@ -397,6 +397,7 @@ TEST_P(TestCsvFileFormatScan, ScanRecordBatchReaderWithVirtualColumn) {
TEST_P(TestCsvFileFormatScan, ScanRecordBatchReaderWithDuplicateColumnError) {
TestScanWithDuplicateColumnError();
}
+TEST_P(TestCsvFileFormatScan, ScanWithPushdownNulls) { TestScanWithPushdownNulls(); }  // ARROW-15312 regression
INSTANTIATE_TEST_SUITE_P(TestScan, TestCsvFileFormatScan,
::testing::ValuesIn(TestFormatParams::Values()),
diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc
index b085ad6a1d..35a2ef273f 100644
--- a/cpp/src/arrow/dataset/file_ipc_test.cc
+++ b/cpp/src/arrow/dataset/file_ipc_test.cc
@@ -150,6 +150,7 @@ TEST_P(TestIpcFileFormatScan, ScanRecordBatchReaderWithDuplicateColumn) {
TEST_P(TestIpcFileFormatScan, ScanRecordBatchReaderWithDuplicateColumnError) {
TestScanWithDuplicateColumnError();
}
+TEST_P(TestIpcFileFormatScan, ScanWithPushdownNulls) { TestScanWithPushdownNulls(); }  // ARROW-15312 regression
TEST_P(TestIpcFileFormatScan, FragmentScanOptions) {
auto reader = GetRecordBatchReader(
// ARROW-12077: on Windows/mimalloc/release, nullable list column leads to crash
diff --git a/cpp/src/arrow/dataset/file_orc_test.cc b/cpp/src/arrow/dataset/file_orc_test.cc
index 0e5dfa0176..aaa3aeff94 100644
--- a/cpp/src/arrow/dataset/file_orc_test.cc
+++ b/cpp/src/arrow/dataset/file_orc_test.cc
@@ -85,6 +85,7 @@ TEST_P(TestOrcFileFormatScan, ScanRecordBatchReaderWithDuplicateColumn) {
TEST_P(TestOrcFileFormatScan, ScanRecordBatchReaderWithDuplicateColumnError) {
TestScanWithDuplicateColumnError();
}
+TEST_P(TestOrcFileFormatScan, ScanWithPushdownNulls) { TestScanWithPushdownNulls(); }  // ARROW-15312 regression
INSTANTIATE_TEST_SUITE_P(TestScan, TestOrcFileFormatScan,
::testing::ValuesIn(TestFormatParams::Values()),
TestFormatParams::ToTestNameString);
diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc
index 4a8d409312..7be226e765 100644
--- a/cpp/src/arrow/dataset/file_parquet.cc
+++ b/cpp/src/arrow/dataset/file_parquet.cc
@@ -128,17 +128,27 @@ util::optional<compute::Expression> ColumnChunkStatisticsAsExpression(
auto maybe_min = min->CastTo(field->type());
auto maybe_max = max->CastTo(field->type());
if (maybe_min.ok() && maybe_max.ok()) {
- auto col_min = maybe_min.MoveValueUnsafe();
- auto col_max = maybe_max.MoveValueUnsafe();
- if (col_min->Equals(col_max)) {
- return compute::equal(std::move(field_expr), compute::literal(std::move(col_min)));
+ min = maybe_min.MoveValueUnsafe();
+ max = maybe_max.MoveValueUnsafe();
+
+ if (min->Equals(max)) {
+ auto single_value = compute::equal(field_expr, compute::literal(std::move(min)));
+
+ if (statistics->null_count() == 0) {
+ return single_value;
+ }
+ return compute::or_(std::move(single_value), is_null(std::move(field_expr)));
}
auto lower_bound =
- compute::greater_equal(field_expr, compute::literal(std::move(col_min)));
- auto upper_bound =
- compute::less_equal(std::move(field_expr), compute::literal(std::move(col_max)));
- return compute::and_(std::move(lower_bound), std::move(upper_bound));
+ compute::greater_equal(field_expr, compute::literal(std::move(min)));
+ auto upper_bound = compute::less_equal(field_expr, compute::literal(std::move(max)));
+
+ auto in_range = compute::and_(std::move(lower_bound), std::move(upper_bound));
+ if (statistics->null_count() != 0) {
+ return compute::or_(std::move(in_range), compute::is_null(field_expr));
+ }
+ return in_range;
}
return util::nullopt;
diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc
index d5c7a0b985..2c7cad8a7e 100644
--- a/cpp/src/arrow/dataset/file_parquet_test.cc
+++ b/cpp/src/arrow/dataset/file_parquet_test.cc
@@ -265,17 +265,27 @@ TEST_F(TestParquetFileFormat, CountRowsPredicatePushdown) {
[1],
[2]
])");
- ASSERT_OK_AND_ASSIGN(auto reader,
- RecordBatchReader::Make({null_batch, batch}, dataset_schema));
+ auto batch2 = RecordBatchFromJSON(dataset_schema, R"([
+[4],
+[4]
+])");
+ ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make({null_batch, batch, batch2},
+ dataset_schema));
auto source = GetFileSource(reader.get());
auto fragment = MakeFragment(*source);
ASSERT_OK_AND_ASSIGN(
auto predicate,
greater_equal(field_ref("i64"), literal(1)).Bind(*dataset_schema));
- ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(2),
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(4),
+ fragment->CountRows(predicate, options));
+
+ ASSERT_OK_AND_ASSIGN(predicate, is_null(field_ref("i64")).Bind(*dataset_schema));
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(3),
+ fragment->CountRows(predicate, options));
+
+ ASSERT_OK_AND_ASSIGN(predicate, is_valid(field_ref("i64")).Bind(*dataset_schema));
+ ASSERT_FINISHES_OK_AND_EQ(util::make_optional<int64_t>(4),
fragment->CountRows(predicate, options));
- // TODO(ARROW-12659): SimplifyWithGuarantee can't handle
- // not(is_null) so trying to count with is_null doesn't work
}
}
@@ -393,6 +403,7 @@ TEST_P(TestParquetFileFormatScan, ScanRecordBatchReaderWithDuplicateColumn) {
TEST_P(TestParquetFileFormatScan, ScanRecordBatchReaderWithDuplicateColumnError) {
TestScanWithDuplicateColumnError();
}
+TEST_P(TestParquetFileFormatScan, ScanWithPushdownNulls) { TestScanWithPushdownNulls(); }  // ARROW-15312 regression
TEST_P(TestParquetFileFormatScan, ScanRecordBatchReaderDictEncoded) {
auto reader = GetRecordBatchReader(schema({field("utf8", utf8())}));
auto source = GetFileSource(reader.get());
diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h
index 3f826fa09c..9ec0a59860 100644
--- a/cpp/src/arrow/dataset/test_util.h
+++ b/cpp/src/arrow/dataset/test_util.h
@@ -853,6 +853,27 @@ class FileFormatScanMixin : public FileFormatFixtureMixin<FormatHelper>,
ASSERT_RAISES(Invalid,
ProjectionDescr::FromNames({"i32"}, *this->opts_->dataset_schema));
}
+ void TestScanWithPushdownNulls() {
+ // Regression test for ARROW-15312: writes one null and one non-null i64 row, filters on is_null(i64)
+ auto i64 = field("i64", int64());
+ this->SetSchema({i64});
+ this->SetFilter(is_null(field_ref("i64")));  // keep only rows where i64 is null
+
+ auto rb = RecordBatchFromJSON(schema({i64}), R"([
+ [null],
+ [32]
+ ])");
+ ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make({rb}));
+ auto source = this->GetFileSource(reader.get());
+
+ auto fragment = this->MakeFragment(*source);
+ int64_t row_count = 0;
+ for (auto maybe_batch : Batches(fragment)) {  // sum the rows of every batch the scan yields
+ ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
+ row_count += batch->num_rows();
+ }
+ ASSERT_EQ(row_count, 1);  // exactly the [null] row is expected to survive the filter
+ }
protected:
using FileFormatFixtureMixin<FormatHelper>::opts_;
@@ -1002,13 +1023,11 @@ struct MakeFileSystemDatasetMixin {
continue;
}
- ASSERT_OK_AND_ASSIGN(partitions[i], partitions[i].Bind(*s));
ASSERT_OK_AND_ASSIGN(auto fragment,
format->MakeFragment({info, fs_}, partitions[i]));
fragments.push_back(std::move(fragment));
}
- ASSERT_OK_AND_ASSIGN(root_partition, root_partition.Bind(*s));
ASSERT_OK_AND_ASSIGN(dataset_, FileSystemDataset::Make(s, root_partition, format, fs_,
std::move(fragments)));
}
@@ -1059,9 +1078,6 @@ static std::vector<compute::Expression> PartitionExpressionsOf(
void AssertFragmentsHavePartitionExpressions(std::shared_ptr<Dataset> dataset,
std::vector<compute::Expression> expected) {
ASSERT_OK_AND_ASSIGN(auto fragment_it, dataset->GetFragments());
- for (auto& expr : expected) {
- ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*dataset->schema()));
- }
// Ordering is not guaranteed.
EXPECT_THAT(PartitionExpressionsOf(IteratorToVector(std::move(fragment_it))),
testing::UnorderedElementsAreArray(expected));
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 440b95ce59..030c81a9b8 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1652,7 +1652,7 @@ class ARROW_EXPORT FieldRef {
bool Equals(const FieldRef& other) const { return impl_ == other.impl_; }
bool operator==(const FieldRef& other) const { return Equals(other); }
- bool operator!=(const FieldRef& other) const { return !(*this == other); }
+ bool operator!=(const FieldRef& other) const { return !Equals(other); }  // same result as !(*this == other), since operator== just returns Equals(other)
std::string ToString() const;
diff --git a/cpp/src/arrow/util/stl_util_test.cc b/cpp/src/arrow/util/stl_util_test.cc
index 2a8784e13a..3f16051f1d 100644
--- a/cpp/src/arrow/util/stl_util_test.cc
+++ b/cpp/src/arrow/util/stl_util_test.cc
@@ -103,6 +103,13 @@ TEST(StlUtilTest, VectorFlatten) {
ASSERT_EQ(expected, actual);
}
+TEST(StlUtilTest, VectorFilter) {
+ std::vector<int> input{1, 2, 3, 4, 5, 6, 7, 8, 9};
+ auto filtered = FilterVector(input, [](int i) { return i % 3 == 0; });
+
+ EXPECT_THAT(filtered, ::testing::ElementsAre(3, 6, 9));  // elements satisfying the predicate are kept, in their original order
+}
+
static std::string int_to_str(int val) { return std::to_string(val); }
TEST(StlUtilTest, VectorMap) {
diff --git a/cpp/src/arrow/util/vector.h b/cpp/src/arrow/util/vector.h
index 041bdb424a..fb15e71f98 100644
--- a/cpp/src/arrow/util/vector.h
+++ b/cpp/src/arrow/util/vector.h
@@ -78,8 +78,8 @@ std::vector<T> ReplaceVectorElement(const std::vector<T>& values, size_t index,
template <typename T, typename Predicate>
std::vector<T> FilterVector(std::vector<T> values, Predicate&& predicate) {
- auto new_end =
- std::remove_if(values.begin(), values.end(), std::forward<Predicate>(predicate));
+ auto new_end = std::stable_partition(values.begin(), values.end(),  // remove_if dropped the matching elements, i.e. the inverse of a filter
+ std::forward<Predicate>(predicate));  // matching elements move to the front in their original order; the tail is erased below
values.erase(new_end, values.end());
return values;
}
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index cb1fde7a31..50977b750c 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -380,7 +380,7 @@ equivalents above and reflects how they are implemented internally.
* \(6) ``hash_one`` returns one arbitrary value from the input for each
group. The function is biased towards non-null values: if there is at least
one non-null value for a certain group, that value is returned, and only if
- all the values are ``null`` for the group will the function return ``null``.
+ all the values are ``null`` for the group will the function return ``null``.
* \(7) Output is Int64, UInt64, Float64, or Decimal128/256, depending on the
input type.
@@ -1176,6 +1176,8 @@ Categorizations
+-------------------+------------+-------------------------+---------------------+------------------------+---------+
| is_valid | Unary | Any | Boolean | | \(5) |
+-------------------+------------+-------------------------+---------------------+------------------------+---------+
+| true_unless_null | Unary | Any | Boolean | | \(6) |
++-------------------+------------+-------------------------+---------------------+------------------------+---------+
* \(1) Output is true iff the corresponding input element is finite (neither Infinity,
-Infinity, nor NaN). Hence, for Decimal and integer inputs this always returns true.
@@ -1189,7 +1191,10 @@ Categorizations
* \(4) Output is true iff the corresponding input element is null. NaN values
can also be considered null by setting :member:`NullOptions::nan_is_null`.
-* \(5) Output is true iff the corresponding input element is non-null.
+* \(5) Output is true iff the corresponding input element is non-null, else false.
+
+* \(6) Output is true iff the corresponding input element is non-null, else null.
+ Mostly intended for expression simplification/guarantees.
.. _cpp-compute-scalar-selections:
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index 579c4dad80..3d52f48ed8 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -359,6 +359,7 @@ Categorizations
is_nan
is_null
is_valid
+ true_unless_null
Selecting / Multiplexing
------------------------