felipecrv commented on code in PR #43256:
URL: https://github.com/apache/arrow/pull/43256#discussion_r1693513214


##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1148,6 +1149,34 @@ Result<Expression> Canonicalize(Expression expr, 
compute::ExecContext* exec_cont
 
 namespace {
 
+/// Preprocess an `is_in` predicate value set for simplification.
+/// \pre `value_set` is non-empty
+/// \return the value set sorted with duplicate and null values removed
+Result<std::shared_ptr<Array>> PrepareIsInValueSet(std::shared_ptr<Array> 
value_set) {
+  DCHECK_GT(value_set->length(), 0);
+  ARROW_ASSIGN_OR_RAISE(value_set, Unique(value_set));

Review Comment:
   ```suggestion
     ARROW_ASSIGN_OR_RAISE(value_set, Unique(std::move(value_set)));
   ```
   
   This moves it into the `Datum` and avoids some reference-count updates.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1271,74 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an `is_in` value set against an inequality guarantee.
+  ///
+  /// Simplifying an `is_in` predicate involves filtering out any values from
+  /// the value set that cannot possibly be found given the guarantee. For
+  /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the 
guarantee
+  /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+  /// This can be done efficiently if the value set is sorted and unique by
+  /// binary searching the inequality bound and slicing the value set
+  /// accordingly.
+  ///
+  /// \pre `guarantee` is non-nullable
+  /// \pre `value_set` is non-empty
+  /// \return a simplified value set, or a bool if the simplification results 
in
+  ///   a boolean literal predicate.
+  static Result<std::variant<std::shared_ptr<Array>, bool>> 
SimplifyIsInValueSet(
+      const Inequality& guarantee, std::shared_ptr<Array> value_set) {
+    DCHECK_GT(value_set->length(), 0);
+
+    auto compare = [&guarantee, &value_set](size_t i) -> 
Result<Comparison::type> {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
value_set->GetScalar(i));
+      DCHECK(scalar->is_valid);
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+                            Comparison::Execute(scalar, guarantee.bound));
+      return cmp;
+    };
+
+    size_t lo = 0;
+    size_t hi = value_set->length();
+    while (lo + 1 < hi) {
+      size_t mid = (lo + hi) / 2;
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+      if (cmp & Comparison::LESS_EQUAL) {
+        lo = mid;
+      } else {
+        hi = mid;
+      }
+    }
+
+    ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+    size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+    bool found = cmp == Comparison::EQUAL;
+
+    switch (guarantee.cmp) {
+      case Comparison::EQUAL:
+        return found;
+      case Comparison::LESS:
+        value_set = value_set->Slice(0, pivot);
+        break;
+      case Comparison::LESS_EQUAL:
+        value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+        break;
+      case Comparison::GREATER:
+        value_set = value_set->Slice(pivot + (found ? 1 : 0));
+        break;
+      case Comparison::GREATER_EQUAL:
+        value_set = value_set->Slice(pivot);
+        break;
+      default:
+        DCHECK(false);
+        break;

Review Comment:
   List `Comparison::NA` explicitly instead of using a `default` case, so the compiler 
can warn (e.g. via `-Wswitch`) if another enum entry is added.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1148,6 +1149,34 @@ Result<Expression> Canonicalize(Expression expr, 
compute::ExecContext* exec_cont
 
 namespace {
 
+/// Preprocess an `is_in` predicate value set for simplification.
+/// \pre `value_set` is non-empty
+/// \return the value set sorted with duplicate and null values removed
+Result<std::shared_ptr<Array>> PrepareIsInValueSet(std::shared_ptr<Array> 
value_set) {
+  DCHECK_GT(value_set->length(), 0);
+  ARROW_ASSIGN_OR_RAISE(value_set, Unique(value_set));
+  ARROW_ASSIGN_OR_RAISE(
+      std::shared_ptr<Array> sort_indices,
+      SortIndices(value_set, SortOptions({}, NullPlacement::AtEnd)));
+  ARROW_ASSIGN_OR_RAISE(
+      value_set,
+      Take(*value_set, *sort_indices, TakeOptions(/*bounds_check=*/false)));
+  if (value_set->IsNull(value_set->length() - 1)) {
+    value_set = value_set->Slice(0, value_set->length() - 1);
+  }

Review Comment:
   ```suggestion
     if (value_set->IsNull(value_set->length() - 1)) {
       // If the last one is null we know it's the only
       // one because of the call to `Unique` above.
       value_set = value_set->Slice(0, value_set->length() - 1);
     }
   ```



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1148,6 +1149,34 @@ Result<Expression> Canonicalize(Expression expr, 
compute::ExecContext* exec_cont
 
 namespace {
 
+/// Preprocess an `is_in` predicate value set for simplification.
+/// \pre `value_set` is non-empty
+/// \return the value set sorted with duplicate and null values removed
+Result<std::shared_ptr<Array>> PrepareIsInValueSet(std::shared_ptr<Array> 
value_set) {
+  DCHECK_GT(value_set->length(), 0);
+  ARROW_ASSIGN_OR_RAISE(value_set, Unique(value_set));
+  ARROW_ASSIGN_OR_RAISE(
+      std::shared_ptr<Array> sort_indices,
+      SortIndices(value_set, SortOptions({}, NullPlacement::AtEnd)));
+  ARROW_ASSIGN_OR_RAISE(
+      value_set,
+      Take(*value_set, *sort_indices, TakeOptions(/*bounds_check=*/false)));
+  if (value_set->IsNull(value_set->length() - 1)) {
+    value_set = value_set->Slice(0, value_set->length() - 1);
+  }
+  return value_set;
+}
+
+/// Context for expression simplification.
+struct SimplificationContext {
+  /// Mapping from `is_in` calls to simplified value sets.
+  ///
+  /// `is_in` predicates with large value sets can be expensive to bind, so we
+  /// accumulate simplifications in the context and defer binding until all
+  /// inequalities have been processed.
+  std::unordered_map<const Expression::Call*, std::shared_ptr<Array>> 
is_in_value_sets;

Review Comment:
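   Also, using pointers as hash table keys is not a good idea: once a `Call` is destroyed, its address can be reused by an unrelated `Call`, which could then match a stale entry.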



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1148,6 +1149,34 @@ Result<Expression> Canonicalize(Expression expr, 
compute::ExecContext* exec_cont
 
 namespace {
 
+/// Preprocess an `is_in` predicate value set for simplification.
+/// \pre `value_set` is non-empty
+/// \return the value set sorted with duplicate and null values removed
+Result<std::shared_ptr<Array>> PrepareIsInValueSet(std::shared_ptr<Array> 
value_set) {
+  DCHECK_GT(value_set->length(), 0);
+  ARROW_ASSIGN_OR_RAISE(value_set, Unique(value_set));
+  ARROW_ASSIGN_OR_RAISE(
+      std::shared_ptr<Array> sort_indices,
+      SortIndices(value_set, SortOptions({}, NullPlacement::AtEnd)));
+  ARROW_ASSIGN_OR_RAISE(
+      value_set,
+      Take(*value_set, *sort_indices, TakeOptions(/*bounds_check=*/false)));
+  if (value_set->IsNull(value_set->length() - 1)) {
+    value_set = value_set->Slice(0, value_set->length() - 1);
+  }
+  return value_set;
+}
+
+/// Context for expression simplification.
+struct SimplificationContext {
+  /// Mapping from `is_in` calls to simplified value sets.
+  ///
+  /// `is_in` predicates with large value sets can be expensive to bind, so we
+  /// accumulate simplifications in the context and defer binding until all
+  /// inequalities have been processed.
+  std::unordered_map<const Expression::Call*, std::shared_ptr<Array>> 
is_in_value_sets;

Review Comment:
   Might be better to add a new nullable field to `Expression::Call` called 
`value_set_for_is_in`.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1271,74 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an `is_in` value set against an inequality guarantee.
+  ///
+  /// Simplifying an `is_in` predicate involves filtering out any values from
+  /// the value set that cannot possibly be found given the guarantee. For
+  /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the 
guarantee
+  /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+  /// This can be done efficiently if the value set is sorted and unique by
+  /// binary searching the inequality bound and slicing the value set
+  /// accordingly.
+  ///
+  /// \pre `guarantee` is non-nullable
+  /// \pre `value_set` is non-empty
+  /// \return a simplified value set, or a bool if the simplification results 
in
+  ///   a boolean literal predicate.
+  static Result<std::variant<std::shared_ptr<Array>, bool>> 
SimplifyIsInValueSet(

Review Comment:
   The `bool` is never `true`. You can use `nullptr` to represent the `false` 
case and get rid of the `variant`. You could name the function `Try...` and document 
that it may return `nullptr`. That is no less safe than dealing with a `std::variant`.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1258,6 +1353,47 @@ struct Inequality {
       return call->function_name == "is_valid" ? literal(true) : 
literal(false);
     }
 
+    if (call->function_name == "is_in") {
+      // Null-matching behavior is complex and reduces the chances of reduction
+      // of `is_in` calls to a single literal for every possible input, so we
+      // abort the simplification if nulls are possible in the input.
+      if (guarantee.nullable) return expr;
+
+      const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+      if (!lhs.field_ref()) return expr;
+      if (*lhs.field_ref() != guarantee.target) return expr;
+
+      auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+      auto unsimplified_value_set = options->value_set.make_array();
+      if (unsimplified_value_set->length() == 0) return literal(false);
+
+      std::shared_ptr<Array>& value_set = context.is_in_value_sets[call];
+      if (!value_set) {
+        // Simplification for `is_in` requires that the value set is 
preprocessed.
+        // We store the prepared value set in the kernel state to avoid 
repeated
+        // preprocessing across calls to `SimplifyWithGuarantee`.
+        auto state = checked_pointer_cast<internal::SetLookupStateBase>(
+            call->kernel_state);
+        if (!state->sorted_and_unique_value_set) {
+          ARROW_ASSIGN_OR_RAISE(state->sorted_and_unique_value_set,
+                                PrepareIsInValueSet(unsimplified_value_set));
+        }
+        if (state->sorted_and_unique_value_set->length() == 0) {
+          context.is_in_value_sets.erase(call);
+          return literal(false);
+        }
+        value_set = state->sorted_and_unique_value_set;

Review Comment:
   Why can't you keep the `value_set` in the 
`state->sorted_and_unique_value_set`? It's already associated with the `Call`. 
You don't need the unordered map.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to