lidavidm commented on code in PR #12891:
URL: https://github.com/apache/arrow/pull/12891#discussion_r853529547


##########
cpp/src/arrow/compute/exec/expression.cc:
##########
@@ -879,79 +918,183 @@ Result<Expression> Canonicalize(Expression expr, 
compute::ExecContext* exec_cont
 
 namespace {
 
-Result<Expression> DirectComparisonSimplification(Expression expr,
-                                                  const Expression::Call& 
guarantee) {
-  return Modify(
-      std::move(expr), [](Expression expr) { return expr; },
-      [&guarantee](Expression expr, ...) -> Result<Expression> {
-        auto call = expr.call();
-        if (!call) return expr;
+// An inequality comparison which a target Expression is known to satisfy. If 
nullable,
+// the target may evaluate to null in addition to values satisfying the 
comparison.
+struct Inequality {
+  Comparison::type cmp;
+  const FieldRef& target;
+  const Datum& bound;
+  bool nullable;
+
+  // Extract an Inequality if possible, derived from "less",
+  // "greater", "less_equal", and "greater_equal" expressions,
+  // possibly disjuncted with an "is_null" Expression.
+  // cmp(a, 2)
+  // cmp(a, 2) or is_null(a)
+  static util::optional<Inequality> ExtractOne(const Expression& guarantee) {
+    auto call = guarantee.call();
+    if (!call) return util::nullopt;
+
+    if (call->function_name == "or_kleene") {
+      // expect the LHS to be a usable field inequality
+      auto out = ExtractOneFromComparison(call->arguments[0]);
+      if (!out) return util::nullopt;
+
+      // expect the RHS to be an is_null expression
+      auto call_rhs = call->arguments[1].call();
+      if (!call_rhs) return util::nullopt;
+      if (call_rhs->function_name != "is_null") return util::nullopt;
+
+      // ... and that it references the same target
+      auto target = call_rhs->arguments[0].field_ref();
+      if (!target) return util::nullopt;
+      if (*target != out->target) return util::nullopt;
+
+      out->nullable = true;
+      return out;
+    }
 
-        // Ensure both calls are comparisons with equal LHS and scalar RHS
-        auto cmp = Comparison::Get(expr);
-        auto cmp_guarantee = Comparison::Get(guarantee.function_name);
+    // fall back to a simple comparison with no "is_null"
+    return ExtractOneFromComparison(guarantee);
+  }
 
-        if (!cmp) return expr;
-        if (!cmp_guarantee) return expr;
+  static util::optional<Inequality> ExtractOneFromComparison(
+      const Expression& guarantee) {
+    auto call = guarantee.call();
+    if (!call) return util::nullopt;
 
-        const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
-        const auto& guarantee_lhs = guarantee.arguments[0];
-        if (lhs != guarantee_lhs) return expr;
+    if (auto cmp = Comparison::Get(call->function_name)) {
+      // not_equal comparisons are not very usable as guarantees
+      if (*cmp == Comparison::NOT_EQUAL) return util::nullopt;
 
-        auto rhs = call->arguments[1].literal();
-        auto guarantee_rhs = guarantee.arguments[1].literal();
+      auto target = call->arguments[0].field_ref();
+      if (!target) return util::nullopt;
 
-        if (!rhs) return expr;
-        if (!rhs->is_scalar()) return expr;
+      auto bound = call->arguments[1].literal();
+      if (!bound) return util::nullopt;
+      if (!bound->is_scalar()) return util::nullopt;
 
-        if (!guarantee_rhs) return expr;
-        if (!guarantee_rhs->is_scalar()) return expr;
+      return Inequality{*cmp, /*target=*/*target, *bound, /*nullable=*/false};
+    }
 
-        ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
-                              Comparison::Execute(*rhs, *guarantee_rhs));
-        DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);
+    return util::nullopt;
+  }
 
-        if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
-          // RHS of filter is equal to RHS of guarantee
+  /// The given expression simplifies to `value` if the inequality
+  /// target is not nullable. Otherwise, it simplifies to either a
+  /// call to true_unless_null or !true_unless_null.
+  Result<Expression> simplified_to(const Expression& bound_target, bool value) 
const {
+    if (!nullable) return literal(value);
+
+    ExecContext exec_context;
+
+    // Data may be null, so comparison will yield `value` - or null IFF the 
data was null
+    //
+    // true_unless_null is cheap; it purely reuses the validity bitmap for the 
values
+    // buffer. Inversion is less cheap but we expect that term never to be 
evaluated
+    // since invert(true_unless_null(x)) is not satisfiable.
+    Expression::Call call;
+    call.function_name = "true_unless_null";
+    call.arguments = {bound_target};
+    ARROW_ASSIGN_OR_RAISE(
+        auto true_unless_null,
+        BindNonRecursive(std::move(call),
+                         /*insert_implicit_casts=*/false, &exec_context));
+    if (value) return true_unless_null;
+
+    Expression::Call invert;
+    invert.function_name = "invert";
+    invert.arguments = {std::move(true_unless_null)};
+    return BindNonRecursive(std::move(invert),
+                            /*insert_implicit_casts=*/false, &exec_context);
+  }
 
-          if ((*cmp & *cmp_guarantee) == *cmp_guarantee) {
-            // guarantee is a subset of filter, so all data will be included
-            // x > 1, x >= 1, x != 1 guaranteed by x > 1
-            return literal(true);
-          }
+  /// \brief Simplify the given expression given this inequality as a 
guarantee.
+  Result<Expression> Simplify(Expression expr) {
+    const auto& guarantee = *this;
 
-          if ((*cmp & *cmp_guarantee) == 0) {
-            // guarantee disjoint with filter, so all data will be excluded
-            // x > 1, x >= 1, x != 1 unsatisfiable if x == 1
-            return literal(false);
-          }
+    auto call = expr.call();
+    if (!call) return expr;
 
-          return expr;
-        }
+    auto cmp = Comparison::Get(expr);
+    if (!cmp) return expr;
 
-        if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
-          // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
-          return expr;
-        }
+    auto rhs = call->arguments[1].literal();
+    if (!rhs) return expr;
+    if (!rhs->is_scalar()) return expr;
 
-        if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
-          // x > 1, x >= 1, x != 1 guaranteed by x >= 3
-          return literal(true);
-        } else {
-          // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
-          return literal(false);
-        }
+    const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+    if (!lhs.field_ref()) return expr;
+    if (*lhs.field_ref() != guarantee.target) return expr;
+
+    ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_bound, Comparison::Execute(*rhs, 
guarantee.bound));
+    DCHECK_NE(cmp_rhs_bound, Comparison::NA);
+
+    if (cmp_rhs_bound == Comparison::EQUAL) {
+      // RHS of filter is equal to RHS of guarantee
+
+      if ((*cmp & guarantee.cmp) == guarantee.cmp) {
+        // guarantee is a subset of filter, so all data will be included
+        // x > 1, x >= 1, x != 1 guaranteed by x > 1
+        return simplified_to(lhs, true);
+      }
+
+      if ((*cmp & guarantee.cmp) == 0) {
+        // guarantee disjoint with filter, so all data will be excluded
+        // x > 1, x >= 1, x != 1 unsatisfiable if x == 1
+        return simplified_to(lhs, false);
+      }
+
+      return expr;
+    }
+
+    if (guarantee.cmp & cmp_rhs_bound) {

Review Comment:
   Alright, I added comments here and in response to the other feedback to try to
clarify things. This conditional is surprisingly subtle…



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to