felipecrv commented on code in PR #43256:
URL: https://github.com/apache/arrow/pull/43256#discussion_r1685758875


##########
cpp/src/arrow/compute/api_scalar.cc:
##########
@@ -370,7 +370,8 @@ static auto kRoundToMultipleOptionsType = 
GetFunctionOptionsType<RoundToMultiple
 static auto kSetLookupOptionsType = GetFunctionOptionsType<SetLookupOptions>(
     DataMember("value_set", &SetLookupOptions::value_set),
     CoercedDataMember("null_matching_behavior", 
&SetLookupOptions::null_matching_behavior,
-                      &SetLookupOptions::GetNullMatchingBehavior));
+                      &SetLookupOptions::GetNullMatchingBehavior),
+    DataMember("sorted_and_deduped", &SetLookupOptions::sorted_and_deduped));

Review Comment:
   A name more in line with common usage in databases would be 
`sorted_and_unique`.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);
+
+    if (!options->sorted_and_deduped) return expr;
+
+    // For now, only simplify when the guarantee is non-nullable.
+    if (guarantee.nullable) return expr;
+
+    auto compare = [&value_set, &guarantee](size_t i) -> 
Result<Comparison::type> {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
value_set->GetScalar(i));
+      // Nulls compare greater than any non-null value.
+      if (!scalar->is_valid) {
+        return Comparison::GREATER;
+      }
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+                            Comparison::Execute(scalar, guarantee.bound));
+      return cmp;
+    };
+
+    size_t lo = 0;
+    size_t hi = value_set->length();
+    while (lo + 1 < hi) {
+      size_t mid = (lo + hi) / 2;
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+      if (cmp & Comparison::LESS_EQUAL) {
+        lo = mid;
+      } else {
+        hi = mid;
+      }
+    }
+
+    ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+    size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+    bool found = cmp == Comparison::EQUAL;
+
+    std::shared_ptr<Array> simplified_value_set;
+    if (guarantee.cmp == Comparison::EQUAL) {
+      return literal(found);
+    }
+    if (guarantee.cmp == Comparison::LESS) {
+      simplified_value_set = value_set->Slice(0, pivot);
+    } else if (guarantee.cmp == Comparison::LESS_EQUAL) {
+      simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER) {
+      simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
+      simplified_value_set = value_set->Slice(pivot);
+    } else {
+      // We should never reach here.
+      return expr;
+    }
+
+    if (simplified_value_set->length() == 0) return literal(false);
+    if (simplified_value_set->length() == value_set->length()) return expr;
+
+    Expression::Call simplified_call;
+    simplified_call.function_name = "is_in";
+    simplified_call.arguments = call->arguments;
+    simplified_call.options = std::make_shared<SetLookupOptions>(
+        std::move(simplified_value_set), options->null_matching_behavior,
+        /*sorted_and_deduped=*/true);
+    ExecContext exec_context;
+    return BindNonRecursive(std::move(simplified_call),
+                            /*insert_implicit_casts=*/false, &exec_context);
+  }
+
   /// \brief Simplify the given expression given this inequality as a 
guarantee.
   Result<Expression> Simplify(Expression expr) {
     const auto& guarantee = *this;
 
     auto call = expr.call();
     if (!call) return expr;
 
+    const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+    if (!lhs.field_ref()) return expr;
+    if (*lhs.field_ref() != guarantee.target) return expr;
+
     if (call->function_name == "is_valid" || call->function_name == "is_null") 
{
       if (guarantee.nullable) return expr;
-      const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
-      if (!lhs.field_ref()) return expr;
-      if (*lhs.field_ref() != guarantee.target) return expr;
-
       return call->function_name == "is_valid" ? literal(true) : 
literal(false);
     }
 
+    if (call->function_name == "is_in") return SimplifyIsIn(expr);

Review Comment:
   Now that you pattern-matched the expr variant and identified that it is a 
call, also pass the `const Expression::Call *` to `SimplifyIsIn`.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {

Review Comment:
   `expr` should probably be named `is_in_call_expr` to make it clear that this 
function assumes pre-conditions that should be ensured by the caller and it 
can't be called with any generic `Expression`.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;

Review Comment:
   ```suggestion
       // If `value_set.make_array()` fails, we abort the simplification
       // and let the kernel handle the error when the expression is
       // executed.
       if (!value_set) return expr;
   ```



##########
cpp/src/arrow/compute/api_scalar.cc:
##########
@@ -370,7 +370,8 @@ static auto kRoundToMultipleOptionsType = 
GetFunctionOptionsType<RoundToMultiple
 static auto kSetLookupOptionsType = GetFunctionOptionsType<SetLookupOptions>(
     DataMember("value_set", &SetLookupOptions::value_set),
     CoercedDataMember("null_matching_behavior", 
&SetLookupOptions::null_matching_behavior,
-                      &SetLookupOptions::GetNullMatchingBehavior));
+                      &SetLookupOptions::GetNullMatchingBehavior),
+    DataMember("sorted_and_deduped", &SetLookupOptions::sorted_and_deduped));

Review Comment:
   If the kernel is not using this option (only the simplification is), wouldn't 
it be fine to sort the `value_set` in the function that simplifies the value 
set?



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);

Review Comment:
   You have to be extremely careful with expression simplifications. The values 
you return must be the correct value for every possible options/input 
combination. If `SetLookupOptions::null_matching_behavior` is `INCONCLUSIVE`, 
nulls should map to `null` in the output even when `value_set` is empty. [1]
   
   [1] If that's currently not the case, then there is a bug in the kernel 
implementation.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);
+
+    if (!options->sorted_and_deduped) return expr;
+
+    // For now, only simplify when the guarantee is non-nullable.
+    if (guarantee.nullable) return expr;
+
+    auto compare = [&value_set, &guarantee](size_t i) -> 
Result<Comparison::type> {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
value_set->GetScalar(i));
+      // Nulls compare greater than any non-null value.
+      if (!scalar->is_valid) {
+        return Comparison::GREATER;
+      }
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+                            Comparison::Execute(scalar, guarantee.bound));
+      return cmp;
+    };
+
+    size_t lo = 0;
+    size_t hi = value_set->length();
+    while (lo + 1 < hi) {
+      size_t mid = (lo + hi) / 2;
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+      if (cmp & Comparison::LESS_EQUAL) {
+        lo = mid;
+      } else {
+        hi = mid;
+      }
+    }
+
+    ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+    size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+    bool found = cmp == Comparison::EQUAL;
+
+    std::shared_ptr<Array> simplified_value_set;
+    if (guarantee.cmp == Comparison::EQUAL) {
+      return literal(found);
+    }
+    if (guarantee.cmp == Comparison::LESS) {
+      simplified_value_set = value_set->Slice(0, pivot);
+    } else if (guarantee.cmp == Comparison::LESS_EQUAL) {
+      simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER) {
+      simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
+      simplified_value_set = value_set->Slice(pivot);
+    } else {
+      // We should never reach here.
+      return expr;
+    }

Review Comment:
   Can you extract a function from here and call it 
`SimplifiedValueSetForIsIn(guarantee, options)`? Document all the 
pre-conditions:
   
   ```cpp
   /// \pre !guarantee.nullable
   /// \pre !options.value_set.empty()
   /// \pre options.value_set is sorted and contains only unique values
   ```
   
   Even better if you can keep this outside the class.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);
+
+    if (!options->sorted_and_deduped) return expr;
+
+    // For now, only simplify when the guarantee is non-nullable.
+    if (guarantee.nullable) return expr;
+
+    auto compare = [&value_set, &guarantee](size_t i) -> 
Result<Comparison::type> {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
value_set->GetScalar(i));
+      // Nulls compare greater than any non-null value.
+      if (!scalar->is_valid) {
+        return Comparison::GREATER;
+      }
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+                            Comparison::Execute(scalar, guarantee.bound));
+      return cmp;
+    };
+
+    size_t lo = 0;
+    size_t hi = value_set->length();
+    while (lo + 1 < hi) {
+      size_t mid = (lo + hi) / 2;
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+      if (cmp & Comparison::LESS_EQUAL) {
+        lo = mid;
+      } else {
+        hi = mid;
+      }
+    }
+
+    ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+    size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+    bool found = cmp == Comparison::EQUAL;
+
+    std::shared_ptr<Array> simplified_value_set;
+    if (guarantee.cmp == Comparison::EQUAL) {
+      return literal(found);
+    }
+    if (guarantee.cmp == Comparison::LESS) {
+      simplified_value_set = value_set->Slice(0, pivot);
+    } else if (guarantee.cmp == Comparison::LESS_EQUAL) {
+      simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER) {
+      simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
+      simplified_value_set = value_set->Slice(pivot);
+    } else {
+      // We should never reach here.
+      return expr;
+    }
+
+    if (simplified_value_set->length() == 0) return literal(false);
+    if (simplified_value_set->length() == value_set->length()) return expr;
+
+    Expression::Call simplified_call;
+    simplified_call.function_name = "is_in";
+    simplified_call.arguments = call->arguments;
+    simplified_call.options = std::make_shared<SetLookupOptions>(
+        std::move(simplified_value_set), options->null_matching_behavior,
+        /*sorted_and_deduped=*/true);
+    ExecContext exec_context;
+    return BindNonRecursive(std::move(simplified_call),
+                            /*insert_implicit_casts=*/false, &exec_context);
+  }
+
   /// \brief Simplify the given expression given this inequality as a 
guarantee.
   Result<Expression> Simplify(Expression expr) {
     const auto& guarantee = *this;
 
     auto call = expr.call();
     if (!call) return expr;
 
+    const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+    if (!lhs.field_ref()) return expr;
+    if (*lhs.field_ref() != guarantee.target) return expr;
+

Review Comment:
   Please keep this code guarded by function checks as it was, for two reasons:
   1) performance of the simplification: we call `StripOrderPreservingCasts` 
only when we know it's beneficial for the function call we are simplifying.
   2) safety and debuggability: we don't transform arguments for every function 
in the registry.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);
+
+    if (!options->sorted_and_deduped) return expr;
+

Review Comment:
   The PR description says "When this option is set, SimplifyWithGuarantee will 
attempt to simplify any is_in predicates that it finds in the expression."
   
   I don't understand yet what "simplify any is_in predicates that it finds in 
the expression" means. Can you put a comment here in the code explaining what 
this simplification is?



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);
+
+    if (!options->sorted_and_deduped) return expr;
+
+    // For now, only simplify when the guarantee is non-nullable.
+    if (guarantee.nullable) return expr;
+
+    auto compare = [&value_set, &guarantee](size_t i) -> 
Result<Comparison::type> {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
value_set->GetScalar(i));
+      // Nulls compare greater than any non-null value.
+      if (!scalar->is_valid) {
+        return Comparison::GREATER;
+      }
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+                            Comparison::Execute(scalar, guarantee.bound));
+      return cmp;
+    };
+
+    size_t lo = 0;
+    size_t hi = value_set->length();
+    while (lo + 1 < hi) {
+      size_t mid = (lo + hi) / 2;
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+      if (cmp & Comparison::LESS_EQUAL) {
+        lo = mid;
+      } else {
+        hi = mid;
+      }
+    }
+
+    ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+    size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+    bool found = cmp == Comparison::EQUAL;
+
+    std::shared_ptr<Array> simplified_value_set;
+    if (guarantee.cmp == Comparison::EQUAL) {
+      return literal(found);
+    }
+    if (guarantee.cmp == Comparison::LESS) {
+      simplified_value_set = value_set->Slice(0, pivot);
+    } else if (guarantee.cmp == Comparison::LESS_EQUAL) {
+      simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER) {
+      simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
+      simplified_value_set = value_set->Slice(pivot);
+    } else {
+      // We should never reach here.

Review Comment:
   You can use a `switch` so the compiler guarantees all enum values are 
handled.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);
+
+    if (!options->sorted_and_deduped) return expr;
+
+    // For now, only simplify when the guarantee is non-nullable.
+    if (guarantee.nullable) return expr;
+
+    auto compare = [&value_set, &guarantee](size_t i) -> 
Result<Comparison::type> {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
value_set->GetScalar(i));
+      // Nulls compare greater than any non-null value.
+      if (!scalar->is_valid) {
+        return Comparison::GREATER;
+      }
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+                            Comparison::Execute(scalar, guarantee.bound));
+      return cmp;
+    };
+
+    size_t lo = 0;
+    size_t hi = value_set->length();
+    while (lo + 1 < hi) {
+      size_t mid = (lo + hi) / 2;
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+      if (cmp & Comparison::LESS_EQUAL) {
+        lo = mid;
+      } else {
+        hi = mid;
+      }
+    }
+
+    ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+    size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+    bool found = cmp == Comparison::EQUAL;
+
+    std::shared_ptr<Array> simplified_value_set;
+    if (guarantee.cmp == Comparison::EQUAL) {
+      return literal(found);
+    }
+    if (guarantee.cmp == Comparison::LESS) {
+      simplified_value_set = value_set->Slice(0, pivot);
+    } else if (guarantee.cmp == Comparison::LESS_EQUAL) {
+      simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER) {
+      simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
+      simplified_value_set = value_set->Slice(pivot);
+    } else {
+      // We should never reach here.
+      return expr;
+    }
+
+    if (simplified_value_set->length() == 0) return literal(false);
+    if (simplified_value_set->length() == value_set->length()) return expr;
+
+    Expression::Call simplified_call;
+    simplified_call.function_name = "is_in";
+    simplified_call.arguments = call->arguments;
+    simplified_call.options = std::make_shared<SetLookupOptions>(
+        std::move(simplified_value_set), options->null_matching_behavior,
+        /*sorted_and_deduped=*/true);
+    ExecContext exec_context;
+    return BindNonRecursive(std::move(simplified_call),
+                            /*insert_implicit_casts=*/false, &exec_context);
+  }
+
   /// \brief Simplify the given expression given this inequality as a 
guarantee.
   Result<Expression> Simplify(Expression expr) {
     const auto& guarantee = *this;
 
     auto call = expr.call();
     if (!call) return expr;
 
+    const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+    if (!lhs.field_ref()) return expr;
+    if (*lhs.field_ref() != guarantee.target) return expr;
+
     if (call->function_name == "is_valid" || call->function_name == "is_null") 
{
       if (guarantee.nullable) return expr;
-      const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
-      if (!lhs.field_ref()) return expr;
-      if (*lhs.field_ref() != guarantee.target) return expr;
-
       return call->function_name == "is_valid" ? literal(true) : 
literal(false);
     }
 
+    if (call->function_name == "is_in") return SimplifyIsIn(expr);
+
     auto cmp = Comparison::Get(expr);
     if (!cmp) return expr;
 
     auto rhs = call->arguments[1].literal();
     if (!rhs) return expr;
     if (!rhs->is_scalar()) return expr;
 
-    const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
-    if (!lhs.field_ref()) return expr;
-    if (*lhs.field_ref() != guarantee.target) return expr;
-

Review Comment:
   Keep this one here.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,103 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);
+
+    if (!options->sorted_and_deduped) return expr;
+
+    // For now, only simplify when the guarantee is non-nullable.

Review Comment:
   ```suggestion
       // Null-matching behavior is complex and reduces the chances of
       // reduction of `is_in` calls to a single literal for every possible
       // input, so we abort the simplification if nulls are possible in the
       // input.
   ```



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);

Review Comment:
   Or you can guard this with the `guarantee.nullable` check and write a 
comment with a short proof of correctness:
   
   > Since `!guarantee.nullable`, we don't have to consider all possible 
null-matching behaviors in the simplification and can guarantee that `false` is 
the return value when the `value_set` is empty.



##########
cpp/src/arrow/compute/expression.cc:
##########
@@ -1242,33 +1242,104 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
+  /// Simplify an is_in predicate against this inequality as a guarantee.
+  Result<Expression> SimplifyIsIn(Expression expr) {
+    const auto& guarantee = *this;
+    auto call = expr.call();
+    auto options = checked_pointer_cast<SetLookupOptions>(call->options);
+
+    auto value_set = options->value_set.make_array();
+    if (!value_set) return expr;
+    if (value_set->length() == 0) return literal(false);
+
+    if (!options->sorted_and_deduped) return expr;
+
+    // For now, only simplify when the guarantee is non-nullable.
+    if (guarantee.nullable) return expr;
+
+    auto compare = [&value_set, &guarantee](size_t i) -> 
Result<Comparison::type> {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, 
value_set->GetScalar(i));
+      // Nulls compare greater than any non-null value.
+      if (!scalar->is_valid) {
+        return Comparison::GREATER;
+      }
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
+                            Comparison::Execute(scalar, guarantee.bound));
+      return cmp;
+    };
+
+    size_t lo = 0;
+    size_t hi = value_set->length();
+    while (lo + 1 < hi) {
+      size_t mid = (lo + hi) / 2;
+      ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
+      if (cmp & Comparison::LESS_EQUAL) {
+        lo = mid;
+      } else {
+        hi = mid;
+      }
+    }
+
+    ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
+    size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
+    bool found = cmp == Comparison::EQUAL;
+
+    std::shared_ptr<Array> simplified_value_set;
+    if (guarantee.cmp == Comparison::EQUAL) {
+      return literal(found);
+    }
+    if (guarantee.cmp == Comparison::LESS) {
+      simplified_value_set = value_set->Slice(0, pivot);
+    } else if (guarantee.cmp == Comparison::LESS_EQUAL) {
+      simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER) {
+      simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
+    } else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
+      simplified_value_set = value_set->Slice(pivot);
+    } else {
+      // We should never reach here.
+      return expr;
+    }
+
+    if (simplified_value_set->length() == 0) return literal(false);
+    if (simplified_value_set->length() == value_set->length()) return expr;
+
+    Expression::Call simplified_call;
+    simplified_call.function_name = "is_in";
+    simplified_call.arguments = call->arguments;
+    simplified_call.options = std::make_shared<SetLookupOptions>(
+        std::move(simplified_value_set), options->null_matching_behavior,
+        /*sorted_and_deduped=*/true);
+    ExecContext exec_context;
+    return BindNonRecursive(std::move(simplified_call),
+                            /*insert_implicit_casts=*/false, &exec_context);
+  }
+
   /// \brief Simplify the given expression given this inequality as a 
guarantee.
   Result<Expression> Simplify(Expression expr) {
     const auto& guarantee = *this;
 
     auto call = expr.call();
     if (!call) return expr;
 
+    const auto& lhs = 
Comparison::StripOrderPreservingCasts(call->arguments[0]);
+    if (!lhs.field_ref()) return expr;
+    if (*lhs.field_ref() != guarantee.target) return expr;
+

Review Comment:
   You should have a copy of these 3 lines inside `SimplifyIsIn` instead of 
moving them here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to