You are viewing a plain-text version of this content. (The "here" link to the canonical version was not preserved in this plain-text copy.)
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/12/13 19:54:18 UTC

[GitHub] [arrow] lidavidm commented on a change in pull request #11716: ARROW-14725: [C++][Compute] Extract Expression simplification pass registry

lidavidm commented on a change in pull request #11716:
URL: https://github.com/apache/arrow/pull/11716#discussion_r768060513



##########
File path: cpp/src/arrow/compute/exec/expression.h
##########
@@ -184,13 +184,51 @@ Result<KnownFieldValues> ExtractKnownFieldValues(
 /// guarantee on a field value, an Expression must be a call to "equal" with field_ref LHS
 /// and literal RHS. Flipping the arguments, "is_in" with a one-long value_set, ... or
 /// other semantically identical Expressions will not be recognized.
+///
+/// For any simplification, if no changes could be made the identical expression will be
+/// returned (`IsIdentical(old, new)` will be true).
 
 /// Weak canonicalization which establishes guarantees for subsequent passes. Even
 /// equivalent Expressions may result in different canonicalized expressions.
 /// TODO this could be a strong canonicalization
 ARROW_EXPORT
 Result<Expression> Canonicalize(Expression, ExecContext* = NULLPTR);
 
+/// An extensible registry for simplification passes over Expressions.
+class ARROW_EXPORT ExpressionSimplificationPassRegistry {
+ public:
+  /// A pass which can operate on a bound Expression independently.
+  /// Independent passes need not recurse into Call::arguments; all independent
+  /// passes will be applied to each argument before any is applied to the call.
+  /// Expressions will be canonicalized before each pass is run.
+  using IndependentPass = std::function<Result<Expression>(Expression, ExecContext*)>;
+
+  /// A pass which utilizes a guaranteed true predicate.
+  /// Guarantee passes are allowed to invalidate independent passes;
+  /// all independent passes will be applied when any guarantee pass makes a change.
+  /// Guarantee passes need not decompose conjunctions; they will be run for
+  /// each member of a guarantee conjunction.
+  /// Guarantee passes need not recurse into Call::arguments; all guarantee
+  /// passes will be applied to each argument before any is applied to the call.
+  /// Expressions will be canonicalized before each pass is run.
+  using GuaranteePass =
+      std::function<Result<Expression>(Expression, const Expression&, ExecContext*)>;
+
+  virtual ~ExpressionSimplificationPassRegistry() = default;
+
+  virtual void Add(IndependentPass) = 0;
+  virtual void Add(GuaranteePass) = 0;
+
+  virtual Result<Expression> RunIndependentPasses(Expression, ExecContext*) = 0;
+  virtual Result<Expression> RunAllPasses(Expression,
+                                          const Expression& guaranteed_true_predicate,
+                                          ExecContext*) = 0;
+};
+
+/// The default registry, which includes built-in simplification passes.
+ARROW_EXPORT
+ExpressionSimplificationPassRegistry* default_expression_simplification_registry();

Review comment:
       nit: DefaultExpressionSimplificationRegistry?

##########
File path: cpp/src/arrow/compute/exec/expression.cc
##########
@@ -1191,5 +1097,209 @@ Expression or_(const std::vector<Expression>& operands) {
 
 Expression not_(Expression operand) { return call("invert", {std::move(operand)}); }
 
+ExpressionSimplificationPassRegistry* default_expression_simplification_registry() {
+  class DefaultRegistry : public ExpressionSimplificationPassRegistry {
+   public:
+    DefaultRegistry() {

Review comment:
       Looks like the standalone FoldConstants pass should be removed above, now that constant folding is registered as the first pass in DefaultRegistry?

##########
File path: cpp/src/arrow/compute/exec/expression.h
##########
@@ -184,13 +184,51 @@ Result<KnownFieldValues> ExtractKnownFieldValues(
 /// guarantee on a field value, an Expression must be a call to "equal" with field_ref LHS
 /// and literal RHS. Flipping the arguments, "is_in" with a one-long value_set, ... or
 /// other semantically identical Expressions will not be recognized.
+///
+/// For any simplification, if no changes could be made the identical expression will be
+/// returned (`IsIdentical(old, new)` will be true).
 
 /// Weak canonicalization which establishes guarantees for subsequent passes. Even
 /// equivalent Expressions may result in different canonicalized expressions.
 /// TODO this could be a strong canonicalization
 ARROW_EXPORT
 Result<Expression> Canonicalize(Expression, ExecContext* = NULLPTR);
 
+/// An extensible registry for simplification passes over Expressions.
+class ARROW_EXPORT ExpressionSimplificationPassRegistry {
+ public:
+  /// A pass which can operate on a bound Expression independently.
+  /// Independent passes need not recurse into Call::arguments; all independent
+  /// passes will be applied to each argument before any is applied to the call.
+  /// Expressions will be canonicalized before each pass is run.
+  using IndependentPass = std::function<Result<Expression>(Expression, ExecContext*)>;
+
+  /// A pass which utilizes a guaranteed true predicate.
+  /// Guarantee passes are allowed to invalidate independent passes;
+  /// all independent passes will be applied when any guarantee pass makes a change.
+  /// Guarantee passes need not decompose conjunctions; they will be run for
+  /// each member of a guarantee conjunction.
+  /// Guarantee passes need not recurse into Call::arguments; all guarantee
+  /// passes will be applied to each argument before any is applied to the call.
+  /// Expressions will be canonicalized before each pass is run.
+  using GuaranteePass =
+      std::function<Result<Expression>(Expression, const Expression&, ExecContext*)>;
+
+  virtual ~ExpressionSimplificationPassRegistry() = default;
+
+  virtual void Add(IndependentPass) = 0;
+  virtual void Add(GuaranteePass) = 0;
+
+  virtual Result<Expression> RunIndependentPasses(Expression, ExecContext*) = 0;
+  virtual Result<Expression> RunAllPasses(Expression,
+                                          const Expression& guaranteed_true_predicate,
+                                          ExecContext*) = 0;
+};
+
+/// The default registry, which includes built-in simplification passes.
+ARROW_EXPORT
+ExpressionSimplificationPassRegistry* default_expression_simplification_registry();

Review comment:
       Though I see that the exec node registry and the function registry already use different naming schemes (`default_exec_factory_registry` vs. `GetFunctionRegistry`), so there is no single existing convention to match.

##########
File path: cpp/src/arrow/compute/exec/expression.cc
##########
@@ -1191,5 +1097,209 @@ Expression or_(const std::vector<Expression>& operands) {
 
 Expression not_(Expression operand) { return call("invert", {std::move(operand)}); }
 
+ExpressionSimplificationPassRegistry* default_expression_simplification_registry() {
+  class DefaultRegistry : public ExpressionSimplificationPassRegistry {
+   public:
+    DefaultRegistry() {
+      Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
+        // if all arguments to a call are literal, we can evaluate this call *now*
+        auto call = CallNotNull(expr);
+        if (std::all_of(call->arguments.begin(), call->arguments.end(),
+                        [](const Expression& argument) { return argument.literal(); })) {
+          static const ExecBatch ignored_input = ExecBatch{};
+          ARROW_ASSIGN_OR_RAISE(Datum constant,
+                                ExecuteScalarExpression(expr, ignored_input));
+
+          return literal(std::move(constant));
+        }
+        return expr;
+      });
+
+      Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
+        // kernels which always produce intersected validity can be resolved
+        // to null *now* if any of their inputs is a null literal
+        auto call = CallNotNull(expr);
+        if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) {
+          for (const auto& argument : call->arguments) {
+            if (argument.IsNullLiteral()) {
+              return argument;
+            }
+          }
+        }
+        return expr;
+      });
+
+      Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
+        auto call = CallNotNull(expr);
+        if (call->function_name == "and_kleene") {
+          // false and x == false
+          if (call->arguments[0] == literal(false)) return literal(false);
+          if (call->arguments[1] == literal(false)) return literal(false);
+
+          // true and x == x
+          if (call->arguments[0] == literal(true)) return call->arguments[1];
+          if (call->arguments[1] == literal(true)) return call->arguments[0];
+
+          // x and x == x
+          if (call->arguments[0] == call->arguments[1]) return call->arguments[0];
+        }
+        return expr;
+      });
+
+      Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
+        auto call = CallNotNull(expr);
+        if (call->function_name == "or_kleene") {
+          // true or x == true
+          if (call->arguments[0] == literal(true)) return literal(true);
+          if (call->arguments[1] == literal(true)) return literal(true);
+
+          // false or x == x
+          if (call->arguments[0] == literal(false)) return call->arguments[1];
+          if (call->arguments[1] == literal(false)) return call->arguments[0];
+
+          // x or x == x
+          if (call->arguments[0] == call->arguments[1]) return call->arguments[0];
+        }
+        return expr;
+      });
+
+      Add([](Expression expr, const Expression& guarantee_expr,
+             ExecContext* ctx) -> Result<Expression> {
+        // Ensure both calls are comparisons with equal LHS and scalar RHS
+        auto cmp = Comparison::Get(expr);
+        auto cmp_guarantee = Comparison::Get(guarantee_expr);
+
+        if (!cmp) return expr;
+        if (!cmp_guarantee) return expr;
+
+        const auto& args = CallNotNull(expr)->arguments;
+        const auto& guarantee_args = CallNotNull(guarantee_expr)->arguments;
+
+        const auto& lhs = Comparison::StripOrderPreservingCasts(args[0]);
+        const auto& guarantee_lhs = guarantee_args[0];
+        if (lhs != guarantee_lhs) return expr;
+
+        auto rhs = args[1].literal();
+        auto guarantee_rhs = guarantee_args[1].literal();
+
+        if (!rhs) return expr;
+        if (!rhs->is_scalar()) return expr;
+
+        if (!guarantee_rhs) return expr;
+        if (!guarantee_rhs->is_scalar()) return expr;
+
+        ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
+                              Comparison::Execute(*rhs, *guarantee_rhs));
+        DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);
+
+        if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
+          // RHS of filter is equal to RHS of guarantee
+
+          if ((*cmp & *cmp_guarantee) == *cmp_guarantee) {
+            // guarantee is a subset of filter, so all data will be included
+            // x > 1, x >= 1, x != 1 guaranteed by x > 1
+            return literal(true);
+          }
+
+          if ((*cmp & *cmp_guarantee) == 0) {
+            // guarantee disjoint with filter, so all data will be excluded
+            // x > 1, x >= 1, x != 1 unsatisfiable if x == 1
+            return literal(false);
+          }
+
+          return expr;
+        }
+
+        if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
+          // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
+          return expr;
+        }
+
+        if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
+          // x > 1, x >= 1, x != 1 guaranteed by x >= 3
+          return literal(true);
+        } else {
+          // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
+          return literal(false);
+        }
+      });
+    }
+
+    void Add(IndependentPass p) override { independent_passes_.push_back(std::move(p)); }
+
+    void Add(GuaranteePass p) override { guarantee_passes_.push_back(std::move(p)); }
+
+    Result<Expression> RunIndependentPasses(Expression expr, ExecContext* ctx) override {
+      ARROW_ASSIGN_OR_RAISE(auto canonicalized, Canonicalize(expr, ctx));
+
+      ARROW_ASSIGN_OR_RAISE(
+          auto simplified,
+          Modify(
+              canonicalized, [](Expression expr) { return expr; },
+              [&](Expression expr, ...) -> Result<Expression> {
+                for (const auto& pass : independent_passes_) {
+                  ARROW_ASSIGN_OR_RAISE(auto simplified, pass(expr, ctx));
+                  if (Identical(simplified, expr)) continue;
+
+                  ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(simplified), ctx));
+                  if (!expr.call()) return expr;
+                }
+                return expr;
+              }));
+
+      if (Identical(simplified, canonicalized)) return expr;

Review comment:
       It also seems we don't need to canonicalize below before calling this, since RunIndependentPasses already calls Canonicalize on its input.

##########
File path: cpp/src/arrow/compute/exec/expression.cc
##########
@@ -1191,5 +1097,209 @@ Expression or_(const std::vector<Expression>& operands) {
 
 Expression not_(Expression operand) { return call("invert", {std::move(operand)}); }
 
+ExpressionSimplificationPassRegistry* default_expression_simplification_registry() {
+  class DefaultRegistry : public ExpressionSimplificationPassRegistry {
+   public:
+    DefaultRegistry() {
+      Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
+        // if all arguments to a call are literal, we can evaluate this call *now*
+        auto call = CallNotNull(expr);
+        if (std::all_of(call->arguments.begin(), call->arguments.end(),
+                        [](const Expression& argument) { return argument.literal(); })) {
+          static const ExecBatch ignored_input = ExecBatch{};
+          ARROW_ASSIGN_OR_RAISE(Datum constant,
+                                ExecuteScalarExpression(expr, ignored_input));
+
+          return literal(std::move(constant));
+        }
+        return expr;
+      });
+
+      Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
+        // kernels which always produce intersected validity can be resolved
+        // to null *now* if any of their inputs is a null literal
+        auto call = CallNotNull(expr);
+        if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) {
+          for (const auto& argument : call->arguments) {
+            if (argument.IsNullLiteral()) {
+              return argument;
+            }
+          }
+        }
+        return expr;
+      });
+
+      Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
+        auto call = CallNotNull(expr);
+        if (call->function_name == "and_kleene") {
+          // false and x == false
+          if (call->arguments[0] == literal(false)) return literal(false);
+          if (call->arguments[1] == literal(false)) return literal(false);
+
+          // true and x == x
+          if (call->arguments[0] == literal(true)) return call->arguments[1];
+          if (call->arguments[1] == literal(true)) return call->arguments[0];
+
+          // x and x == x
+          if (call->arguments[0] == call->arguments[1]) return call->arguments[0];
+        }
+        return expr;
+      });
+
+      Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
+        auto call = CallNotNull(expr);
+        if (call->function_name == "or_kleene") {
+          // true or x == true
+          if (call->arguments[0] == literal(true)) return literal(true);
+          if (call->arguments[1] == literal(true)) return literal(true);
+
+          // false or x == x
+          if (call->arguments[0] == literal(false)) return call->arguments[1];
+          if (call->arguments[1] == literal(false)) return call->arguments[0];
+
+          // x or x == x
+          if (call->arguments[0] == call->arguments[1]) return call->arguments[0];
+        }
+        return expr;
+      });
+
+      Add([](Expression expr, const Expression& guarantee_expr,
+             ExecContext* ctx) -> Result<Expression> {
+        // Ensure both calls are comparisons with equal LHS and scalar RHS
+        auto cmp = Comparison::Get(expr);
+        auto cmp_guarantee = Comparison::Get(guarantee_expr);
+
+        if (!cmp) return expr;
+        if (!cmp_guarantee) return expr;
+
+        const auto& args = CallNotNull(expr)->arguments;
+        const auto& guarantee_args = CallNotNull(guarantee_expr)->arguments;
+
+        const auto& lhs = Comparison::StripOrderPreservingCasts(args[0]);
+        const auto& guarantee_lhs = guarantee_args[0];
+        if (lhs != guarantee_lhs) return expr;
+
+        auto rhs = args[1].literal();
+        auto guarantee_rhs = guarantee_args[1].literal();
+
+        if (!rhs) return expr;
+        if (!rhs->is_scalar()) return expr;
+
+        if (!guarantee_rhs) return expr;
+        if (!guarantee_rhs->is_scalar()) return expr;
+
+        ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
+                              Comparison::Execute(*rhs, *guarantee_rhs));
+        DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);
+
+        if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
+          // RHS of filter is equal to RHS of guarantee
+
+          if ((*cmp & *cmp_guarantee) == *cmp_guarantee) {
+            // guarantee is a subset of filter, so all data will be included
+            // x > 1, x >= 1, x != 1 guaranteed by x > 1
+            return literal(true);
+          }
+
+          if ((*cmp & *cmp_guarantee) == 0) {
+            // guarantee disjoint with filter, so all data will be excluded
+            // x > 1, x >= 1, x != 1 unsatisfiable if x == 1
+            return literal(false);
+          }
+
+          return expr;
+        }
+
+        if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
+          // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
+          return expr;
+        }
+
+        if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
+          // x > 1, x >= 1, x != 1 guaranteed by x >= 3
+          return literal(true);
+        } else {
+          // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
+          return literal(false);
+        }
+      });
+    }
+
+    void Add(IndependentPass p) override { independent_passes_.push_back(std::move(p)); }
+
+    void Add(GuaranteePass p) override { guarantee_passes_.push_back(std::move(p)); }
+
+    Result<Expression> RunIndependentPasses(Expression expr, ExecContext* ctx) override {
+      ARROW_ASSIGN_OR_RAISE(auto canonicalized, Canonicalize(expr, ctx));
+
+      ARROW_ASSIGN_OR_RAISE(
+          auto simplified,
+          Modify(
+              canonicalized, [](Expression expr) { return expr; },
+              [&](Expression expr, ...) -> Result<Expression> {
+                for (const auto& pass : independent_passes_) {
+                  ARROW_ASSIGN_OR_RAISE(auto simplified, pass(expr, ctx));
+                  if (Identical(simplified, expr)) continue;
+
+                  ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(simplified), ctx));
+                  if (!expr.call()) return expr;
+                }
+                return expr;
+              }));
+
+      if (Identical(simplified, canonicalized)) return expr;

Review comment:
       Should this return `canonicalized`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org