felipecrv commented on code in PR #43256:
URL: https://github.com/apache/arrow/pull/43256#discussion_r1685772369
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}
+ /// Simplify an is_in predicate against this inequality as a guarantee.
+ Result<Expression> SimplifyIsIn(Expression expr) {
+ const auto& guarantee = *this;
+ auto call = expr.call();
+ auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+ auto value_set = options->value_set.make_array();
+ if (!value_set) return expr;
+ if (value_set->length() == 0) return literal(false);
+
+ if (!options->sorted_and_deduped) return expr;
+
+ // For now, only simplify when the guarantee is non-nullable.
+ if (guarantee.nullable) return expr;
+
+ auto compare = [&value_set, &guarantee](size_t i) ->
Result<Comparison::type> {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
value_set->GetScalar(i));
+ // Nulls compare greater than any non-null value.
+ if (!scalar->is_valid) {
+ return Comparison::GREATER;
+ }
+ ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+ Comparison::Execute(scalar, guarantee.bound));
+ return cmp;
+ };
+
+ size_t lo = 0;
+ size_t hi = value_set->length();
+ while (lo + 1 < hi) {
+ size_t mid = (lo + hi) / 2;
+ ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+ if (cmp & Comparison::LESS_EQUAL) {
+ lo = mid;
+ } else {
+ hi = mid;
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+ size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+ bool found = cmp == Comparison::EQUAL;
+
+ std::shared_ptr<Array> simplified_value_set;
+ if (guarantee.cmp == Comparison::EQUAL) {
+ return literal(found);
+ }
+ if (guarantee.cmp == Comparison::LESS) {
+ simplified_value_set = value_set->Slice(0, pivot);
+ } else if (guarantee.cmp == Comparison::LESS_EQUAL) {
+ simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+ } else if (guarantee.cmp == Comparison::GREATER) {
+ simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
+ } else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
+ simplified_value_set = value_set->Slice(pivot);
+ } else {
+ // We should never reach here.
+ return expr;
+ }
Review Comment:
Can you extract a function from here and call it
`SimplifiedValueSetForIsIn(guarantee, options)`? Document all the
preconditions:
```cpp
/// \pre !guarantee.nullable (the values matched against value_set are never null)
/// \pre options.value_set is not empty
/// \pre options.value_set only contains distinct values
/// \pre options.value_set is sorted
```
Even better if you can define it as a free function outside the `Inequality` struct.
##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}
+ /// Simplify an is_in predicate against this inequality as a guarantee.
+ Result<Expression> SimplifyIsIn(Expression expr) {
+ const auto& guarantee = *this;
+ auto call = expr.call();
+ auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+ auto value_set = options->value_set.make_array();
+ if (!value_set) return expr;
+ if (value_set->length() == 0) return literal(false);
+
+ if (!options->sorted_and_deduped) return expr;
+
+ // For now, only simplify when the guarantee is non-nullable.
+ if (guarantee.nullable) return expr;
+
+ auto compare = [&value_set, &guarantee](size_t i) ->
Result<Comparison::type> {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
value_set->GetScalar(i));
+ // Nulls compare greater than any non-null value.
+ if (!scalar->is_valid) {
+ return Comparison::GREATER;
+ }
+ ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+ Comparison::Execute(scalar, guarantee.bound));
+ return cmp;
+ };
+
+ size_t lo = 0;
+ size_t hi = value_set->length();
+ while (lo + 1 < hi) {
+ size_t mid = (lo + hi) / 2;
+ ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+ if (cmp & Comparison::LESS_EQUAL) {
+ lo = mid;
+ } else {
+ hi = mid;
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+ size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+ bool found = cmp == Comparison::EQUAL;
+
+ std::shared_ptr<Array> simplified_value_set;
+ if (guarantee.cmp == Comparison::EQUAL) {
+ return literal(found);
+ }
+ if (guarantee.cmp == Comparison::LESS) {
+ simplified_value_set = value_set->Slice(0, pivot);
+ } else if (guarantee.cmp == Comparison::LESS_EQUAL) {
+ simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+ } else if (guarantee.cmp == Comparison::GREATER) {
+ simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
+ } else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
+ simplified_value_set = value_set->Slice(pivot);
+ } else {
+ // We should never reach here.
+ return expr;
+ }
Review Comment:
Can you extract a function from here and call it
`SimplifiedValueSetForIsIn(guarantee, options)`? Document all the
preconditions:
```cpp
/// \pre !guarantee.nullable (i.e. the values matched against value_set are never null)
/// \pre options.value_set is not empty
/// \pre options.value_set only contains distinct values
/// \pre options.value_set is sorted
```
Even better if you can define it as a free function outside the `Inequality` struct.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]