felipecrv commented on code in PR #43256:
URL: https://github.com/apache/arrow/pull/43256#discussion_r1697456034
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1258,6 +1373,92 @@ struct Inequality {
return call->function_name == "is_valid" ? literal(true) :
literal(false);
}
+ if (call->function_name == "is_in") {
+ // Null-matching behavior is complex and reduces the chances of reduction
+ // of `is_in` calls to a single literal for every possible input, so we
+ // abort the simplification if nulls are possible in the input.
+ if (guarantee.nullable) return expr;
+
+ if (!guarantee.bound.is_scalar()) {
+ return Status::Invalid("Cannot simplify inequality on a non-scalar
bound");
+ }
+
+ const auto& lhs =
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+ if (!lhs.field_ref()) return expr;
+ if (*lhs.field_ref() != guarantee.target) return expr;
+
+ auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+ auto unsimplified_value_set = options->value_set.make_array();
+ if (unsimplified_value_set->length() == 0) return literal(false);
+
+ Type::type value_set_type = unsimplified_value_set->type_id();
+ // Simplification for `is_in` requires that comparison kernels exist, so
+ // we skip simplification for non-primitive and non-base-binary types.
+ if (!is_integer(value_set_type) && !is_temporal(value_set_type)
+ && !is_base_binary_like(value_set_type)) {
+ return expr;
+ }
Review Comment:
You need to find a way to make the `switch` you have below serve as both the
dispatch and the type validation (e.g. let its `default:` branch simply return
`expr` instead of asserting). Keeping two separate lists of types in sync is
error-prone, inefficient, and inelegant.
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1258,6 +1373,92 @@ struct Inequality {
return call->function_name == "is_valid" ? literal(true) :
literal(false);
}
+ if (call->function_name == "is_in") {
+ // Null-matching behavior is complex and reduces the chances of reduction
+ // of `is_in` calls to a single literal for every possible input, so we
+ // abort the simplification if nulls are possible in the input.
+ if (guarantee.nullable) return expr;
+
+ if (!guarantee.bound.is_scalar()) {
+ return Status::Invalid("Cannot simplify inequality on a non-scalar
bound");
+ }
+
+ const auto& lhs =
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+ if (!lhs.field_ref()) return expr;
+ if (*lhs.field_ref() != guarantee.target) return expr;
+
+ auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+ auto unsimplified_value_set = options->value_set.make_array();
+ if (unsimplified_value_set->length() == 0) return literal(false);
+
+ Type::type value_set_type = unsimplified_value_set->type_id();
+ // Simplification for `is_in` requires that comparison kernels exist, so
+ // we skip simplification for non-primitive and non-base-binary types.
+ if (!is_integer(value_set_type) && !is_temporal(value_set_type)
+ && !is_base_binary_like(value_set_type)) {
+ return expr;
+ }
+ // For now, we abort simplification if the guarantee bound's type does
not
+ // exactly match the value set's type.
+ if (guarantee.bound.type()->id() != value_set_type) return expr;
+
+ std::shared_ptr<Array>& value_set = context.is_in_value_sets[call];
+ if (!value_set) {
+ // Simplification for `is_in` requires that the value set is
preprocessed.
+ // We store the prepared value set in the kernel state to avoid
repeated
+ // preprocessing across calls to `SimplifyWithGuarantee`.
+ auto state = checked_pointer_cast<internal::SetLookupStateBase>(
+ call->kernel_state);
+ if (!state->sorted_and_unique_value_set) {
+ ARROW_ASSIGN_OR_RAISE(state->sorted_and_unique_value_set,
+ PrepareIsInValueSet(unsimplified_value_set));
+ }
+ if (state->sorted_and_unique_value_set->length() == 0) {
+ context.is_in_value_sets.erase(call);
+ return literal(false);
+ }
+ value_set = state->sorted_and_unique_value_set;
+ }
+
+#define CASE(TYPE_CLASS)
\
+ case TYPE_CLASS##Type::type_id:
\
+ result = SimplifyIsInValueSet<TYPE_CLASS##Type>(guarantee, value_set);
\
+ break;
+
+ std::variant<std::shared_ptr<Array>, bool> result;
+ switch (value_set_type) {
+ CASE(UInt8)
+ CASE(Int8)
+ CASE(UInt16)
+ CASE(Int16)
+ CASE(UInt32)
+ CASE(Int32)
+ CASE(UInt64)
+ CASE(Int64)
+ CASE(Date32)
+ CASE(Date64)
+ CASE(Time32)
+ CASE(Time64)
+ CASE(Timestamp)
+ CASE(Binary)
+ CASE(String)
+ CASE(LargeBinary)
+ CASE(LargeString)
+ default:
+ DCHECK(false);
+ return expr;
+ }
+
+#undef CASE
+
+ if (std::holds_alternative<bool>(result)) {
+ context.is_in_value_sets.erase(call);
+ return literal(std::get<bool>(result));
+ }
+ value_set = std::get<std::shared_ptr<Array>>(result);
+ return expr;
Review Comment:
These ~70 lines can be moved into a separate function -- `SimplifyIsIn` --
because we shouldn't cram the simplifications for every simplifiable function
into this one place.
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1273,92 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}
+ /// Simplify an `is_in` value set against an inequality guarantee.
+ ///
+ /// Simplifying an `is_in` predicate involves filtering out any values from
+ /// the value set that cannot possibly be found given the guarantee. For
+ /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the
guarantee
+ /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+ /// This can be done efficiently if the value set is sorted and unique by
+ /// binary searching the inequality gound and slicing the value set
+ /// accordingly.
+ ///
+ /// \pre `guarantee` is non-nullable
+ /// \pre `guarantee.bound` is a scalar
+ /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+ /// \pre `value_set` is non-empty
+ /// \return a simplified value set, or a bool if the simplification results
in
+ /// a boolean literal predicate.
+ template <typename ArrowType>
+ static std::variant<std::shared_ptr<Array>, bool> SimplifyIsInValueSet(
+ const Inequality& guarantee, std::shared_ptr<Array> value_set) {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+ using CType =
decltype(checked_pointer_cast<ArrayType>(value_set)->Value(0));
Review Comment:
You can get the C type from `TypeTraits`. If that fails for some type, you
probably shouldn't be calling this function with that type.
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1273,92 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}
+ /// Simplify an `is_in` value set against an inequality guarantee.
+ ///
+ /// Simplifying an `is_in` predicate involves filtering out any values from
+ /// the value set that cannot possibly be found given the guarantee. For
+ /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the
guarantee
+ /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+ /// This can be done efficiently if the value set is sorted and unique by
+ /// binary searching the inequality gound and slicing the value set
+ /// accordingly.
+ ///
+ /// \pre `guarantee` is non-nullable
+ /// \pre `guarantee.bound` is a scalar
+ /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+ /// \pre `value_set` is non-empty
+ /// \return a simplified value set, or a bool if the simplification results
in
+ /// a boolean literal predicate.
Review Comment:
```suggestion
/// \return a simplified value set, or a bool if the simplification of the value set
/// means the whole is_in expr can become a boolean literal.
```
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1273,92 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}
+ /// Simplify an `is_in` value set against an inequality guarantee.
+ ///
+ /// Simplifying an `is_in` predicate involves filtering out any values from
+ /// the value set that cannot possibly be found given the guarantee. For
+ /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the
guarantee
+ /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+ /// This can be done efficiently if the value set is sorted and unique by
+ /// binary searching the inequality gound and slicing the value set
+ /// accordingly.
+ ///
+ /// \pre `guarantee` is non-nullable
+ /// \pre `guarantee.bound` is a scalar
+ /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+ /// \pre `value_set` is non-empty
+ /// \return a simplified value set, or a bool if the simplification results
in
+ /// a boolean literal predicate.
+ template <typename ArrowType>
+ static std::variant<std::shared_ptr<Array>, bool> SimplifyIsInValueSet(
Review Comment:
Templating this whole function is overkill. The binary search should be
extracted into its own function, because it is really the only part that needs
the ~17 template specializations. That will require moving the `switch` into
this function.
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,8 +1273,92 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}
+ /// Simplify an `is_in` value set against an inequality guarantee.
+ ///
+ /// Simplifying an `is_in` predicate involves filtering out any values from
+ /// the value set that cannot possibly be found given the guarantee. For
+ /// example, if we have the predicate 'x is_in [1, 2, 3, 4]' and the
guarantee
+ /// 'x > 2', then the simplified predicate 'x is_in [3, 4]' is equivalent.
+ /// This can be done efficiently if the value set is sorted and unique by
+ /// binary searching the inequality gound and slicing the value set
+ /// accordingly.
+ ///
+ /// \pre `guarantee` is non-nullable
+ /// \pre `guarantee.bound` is a scalar
+ /// \pre `guarantee.bound.type()->id() == value_set->type_id()`
+ /// \pre `value_set` is non-empty
+ /// \return a simplified value set, or a bool if the simplification results
in
+ /// a boolean literal predicate.
+ template <typename ArrowType>
+ static std::variant<std::shared_ptr<Array>, bool> SimplifyIsInValueSet(
+ const Inequality& guarantee, std::shared_ptr<Array> value_set) {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+ using CType =
decltype(checked_pointer_cast<ArrayType>(value_set)->Value(0));
+
+ DCHECK(guarantee.bound.is_scalar());
+ DCHECK_EQ(guarantee.bound.type()->id(), value_set->type_id());
+ DCHECK_GT(value_set->length(), 0);
+
+ CType bound;
+ if constexpr (std::is_same_v<std::shared_ptr<Buffer>,
+ typename ScalarType::ValueType>) {
+ bound =
static_cast<CType>(*guarantee.bound.scalar_as<ScalarType>().value);
+ } else {
+ bound = guarantee.bound.scalar_as<ScalarType>().value;
+ }
+
+ auto compare = [&bound, &value_set](size_t i) -> Comparison::type {
+ DCHECK(value_set->IsValid(i));
+ auto value = checked_pointer_cast<ArrayType>(value_set)->Value(i);
+ return value == bound ? Comparison::EQUAL
+ : value < bound ? Comparison::LESS
+ : Comparison::GREATER;
+ };
+
+ size_t lo = 0;
+ size_t hi = value_set->length();
+ while (lo + 1 < hi) {
+ size_t mid = (lo + hi) / 2;
+ Comparison::type cmp = compare(mid);
+ if (cmp & Comparison::LESS_EQUAL) {
+ lo = mid;
+ } else {
+ hi = mid;
+ }
Review Comment:
I suspect this binary search doesn't necessarily work for all possible
comparison operators. I haven't had time to think about it carefully yet,
though.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]