felipecrv commented on code in PR #43256:
URL: https://github.com/apache/arrow/pull/43256#discussion_r1697456034


##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1258,6 +1373,92 @@ struct Inequality {
       return call->function_name == "is_valid" ? literal(true) : 
literal(false);
     }
 
+    if (call->function_name == "is_in") {
+      // Null-matching behavior is complex and reduces the chances of reduction
+      // of `is_in` calls to a single literal for every possible input, so we
+      // abort the simplification if nulls are possible in the input.
+      if (guarantee.nullable) return expr;
+
+      if (!guarantee.bound.is_scalar()) {
+        return Status::Invalid("Cannot simplify inequality on a non-scalar 
bound");
+      }
+
+      const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+      if (!lhs.field_ref()) return expr;
+      if (*lhs.field_ref() != guarantee.target) return expr;
+
+      auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+      auto unsimplified_value_set = options->value_set.make_array();
+      if (unsimplified_value_set->length() == 0) return literal(false);
+
+      Type::type value_set_type = unsimplified_value_set->type_id();
+      // Simplification for `is_in` requires that comparison kernels exist, so
+      // we skip simplification for non-primitive and non-base-binary types.
+      if (!is_integer(value_set_type) && !is_temporal(value_set_type)
+          && !is_base_binary_like(value_set_type)) {
+        return expr;
+      }

Review Comment:
   You need to find a way to make the `switch` below serve as both the dispatch and the validation. Keeping a separate type check in sync with that `switch` is error-prone, inefficient, and inelegant.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1258,6 +1373,92 @@ struct Inequality {
       return call->function_name == "is_valid" ? literal(true) : 
literal(false);
     }
 
+    if (call->function_name == "is_in") {
+      // Null-matching behavior is complex and reduces the chances of reduction
+      // of `is_in` calls to a single literal for every possible input, so we
+      // abort the simplification if nulls are possible in the input.
+      if (guarantee.nullable) return expr;
+
+      if (!guarantee.bound.is_scalar()) {
+        return Status::Invalid("Cannot simplify inequality on a non-scalar 
bound");
+      }
+
+      const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+      if (!lhs.field_ref()) return expr;
+      if (*lhs.field_ref() != guarantee.target) return expr;
+
+      auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+      auto unsimplified_value_set = options->value_set.make_array();
+      if (unsimplified_value_set->length() == 0) return literal(false);
+
+      Type::type value_set_type = unsimplified_value_set->type_id();
+      // Simplification for `is_in` requires that comparison kernels exist, so
+      // we skip simplification for non-primitive and non-base-binary types.
+      if (!is_integer(value_set_type) && !is_temporal(value_set_type)
+          && !is_base_binary_like(value_set_type)) {
+        return expr;
+      }
+      // For now, we abort simplification if the guarantee bound's type does 
not
+      // exactly match the value set's type.
+      if (guarantee.bound.type()->id() != value_set_type) return expr;
+
+      std::shared_ptr<Array>& value_set = context.is_in_value_sets[call];
+      if (!value_set) {
+        // Simplification for `is_in` requires that the value set is 
preprocessed.
+        // We store the prepared value set in the kernel state to avoid 
repeated
+        // preprocessing across calls to `SimplifyWithGuarantee`.
+        auto state = checked_pointer_cast<internal::SetLookupStateBase>(
+            call->kernel_state);
+        if (!state->sorted_and_unique_value_set) {
+          ARROW_ASSIGN_OR_RAISE(state->sorted_and_unique_value_set,
+                                PrepareIsInValueSet(unsimplified_value_set));
+        }
+        if (state->sorted_and_unique_value_set->length() == 0) {
+          context.is_in_value_sets.erase(call);
+          return literal(false);
+        }
+        value_set = state->sorted_and_unique_value_set;
+      }
+
+#define CASE(TYPE_CLASS)                                                       
\
+  case TYPE_CLASS##Type::type_id:                                              
\
+    result = SimplifyIsInValueSet<TYPE_CLASS##Type>(guarantee, value_set);     
\
+    break;
+
+      std::variant<std::shared_ptr<Array>, bool> result;
+      switch (value_set_type) {
+        CASE(UInt8)
+        CASE(Int8)
+        CASE(UInt16)
+        CASE(Int16)
+        CASE(UInt32)
+        CASE(Int32)
+        CASE(UInt64)
+        CASE(Int64)
+        CASE(Date32)
+        CASE(Date64)
+        CASE(Time32)
+        CASE(Time64)
+        CASE(Timestamp)
+        CASE(Binary)
+        CASE(String)
+        CASE(LargeBinary)
+        CASE(LargeString)
+        default:
+          DCHECK(false);
+          return expr;
+      }
+
+#undef CASE
+
+      if (std::holds_alternative<bool>(result)) {
+        context.is_in_value_sets.erase(call);
+        return literal(std::get<bool>(result));
+      }
+      value_set = std::get<std::shared_ptr<Array>>(result);
+      return expr;

Review Comment:
   These 70 lines can be extracted into a separate function -- `SimplifyIsIn` -- because we shouldn't inline the simplifications for every simplifiable function here.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1273,92 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an `is_in` value set against an inequality guarantee.
+  ///
+  /// Simplifying an `is_in` predicate involves filtering out any values from
+  /// the value set that cannot possibly be found given the guarantee. For
+  /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the 
guarantee
+  /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+  /// This can be done efficiently if the value set is sorted and unique by
+  /// binary searching the inequality gound and slicing the value set
+  /// accordingly.
+  ///
+  /// \pre `guarantee` is non-nullable
+  /// \pre `guarantee.bound` is a scalar
+  /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+  /// \pre `value_set` is non-empty
+  /// \return a simplified value set, or a bool if the simplification results 
in
+  ///   a boolean literal predicate.
+  template <typename ArrowType>
+  static std::variant<std::shared_ptr<Array>, bool> SimplifyIsInValueSet(
+      const Inequality& guarantee, std::shared_ptr<Array> value_set) {
+    using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+    using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+    using CType = 
decltype(checked_pointer_cast<ArrayType>(value_set)->Value(0));

Review Comment:
   You can get the C type from `TypeTraits`. If that fails for some type, maybe you shouldn't be calling this function with that type.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1273,92 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an `is_in` value set against an inequality guarantee.
+  ///
+  /// Simplifying an `is_in` predicate involves filtering out any values from
+  /// the value set that cannot possibly be found given the guarantee. For
+  /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the 
guarantee
+  /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+  /// This can be done efficiently if the value set is sorted and unique by
+  /// binary searching the inequality gound and slicing the value set
+  /// accordingly.
+  ///
+  /// \pre `guarantee` is non-nullable
+  /// \pre `guarantee.bound` is a scalar
+  /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+  /// \pre `value_set` is non-empty
+  /// \return a simplified value set, or a bool if the simplification results 
in
+  ///   a boolean literal predicate.

Review Comment:
   ```suggestion
     /// \return a simplified value set, or a bool if the simplification of the value set
     ///         means the whole `is_in` expr can become a boolean literal.
   ```



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1273,92 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an `is_in` value set against an inequality guarantee.
+  ///
+  /// Simplifying an `is_in` predicate involves filtering out any values from
+  /// the value set that cannot possibly be found given the guarantee. For
+  /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the 
guarantee
+  /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+  /// This can be done efficiently if the value set is sorted and unique by
+  /// binary searching the inequality gound and slicing the value set
+  /// accordingly.
+  ///
+  /// \pre `guarantee` is non-nullable
+  /// \pre `guarantee.bound` is a scalar
+  /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+  /// \pre `value_set` is non-empty
+  /// \return a simplified value set, or a bool if the simplification results 
in
+  ///   a boolean literal predicate.
+  template <typename ArrowType>
+  static std::variant<std::shared_ptr<Array>, bool> SimplifyIsInValueSet(

Review Comment:
   Templating this entire function is overkill. The binary search should be extracted into its own function, because it is really the only part that needs the ~17 template specializations. That will require moving the `switch` into this function.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1273,92 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an `is_in` value set against an inequality guarantee.
+  ///
+  /// Simplifying an `is_in` predicate involves filtering out any values from
+  /// the value set that cannot possibly be found given the guarantee. For
+  /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the 
guarantee
+  /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+  /// This can be done efficiently if the value set is sorted and unique by
+  /// binary searching the inequality gound and slicing the value set
+  /// accordingly.
+  ///
+  /// \pre `guarantee` is non-nullable
+  /// \pre `guarantee.bound` is a scalar
+  /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+  /// \pre `value_set` is non-empty
+  /// \return a simplified value set, or a bool if the simplification results 
in
+  ///   a boolean literal predicate.
+  template <typename ArrowType>
+  static std::variant<std::shared_ptr<Array>, bool> SimplifyIsInValueSet(
+      const Inequality& guarantee, std::shared_ptr<Array> value_set) {
+    using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+    using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+    using CType = 
decltype(checked_pointer_cast<ArrayType>(value_set)->Value(0));
+
+    DCHECK(guarantee.bound.is_scalar());
+    DCHECK_EQ(guarantee.bound.type()->id(), value_set->type_id());
+    DCHECK_GT(value_set->length(), 0);
+
+    CType bound;
+    if constexpr (std::is_same_v<std::shared_ptr<Buffer>,
+                                 typename ScalarType::ValueType>) {
+        bound = 
static_cast<CType>(*guarantee.bound.scalar_as<ScalarType>().value);
+    } else {
+        bound = guarantee.bound.scalar_as<ScalarType>().value;
+    }
+
+    auto compare = [&bound, &value_set](size_t i) -> Comparison::type {
+      DCHECK(value_set->IsValid(i));
+      auto value = checked_pointer_cast<ArrayType>(value_set)->Value(i);
+      return value == bound ? Comparison::EQUAL
+                            : value < bound ? Comparison::LESS
+                                            : Comparison::GREATER;
+    };
+
+    size_t lo = 0;
+    size_t hi = value_set->length();
+    while (lo + 1 < hi) {
+      size_t mid = (lo + hi) / 2;
+      Comparison::type cmp = compare(mid);
+      if (cmp & Comparison::LESS_EQUAL) {
+        lo = mid;
+      } else {
+        hi = mid;
+      }

Review Comment:
   I suspect this binary search doesn't necessarily work for all possible comparison operators, but I haven't had time to think this through carefully yet.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to