This is an automated email from the ASF dual-hosted git repository.

zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f62c2adf2 GH-48268: [C++][Acero] Enhance the type checking for hash 
join residual filter (#48272)
6f62c2adf2 is described below

commit 6f62c2adf27b7b5517c4010aefee055fd7d2f859
Author: Rossi Sun <[email protected]>
AuthorDate: Sat Nov 29 06:14:46 2025 -0800

    GH-48268: [C++][Acero] Enhance the type checking for hash join residual 
filter (#48272)
    
    ### Rationale for this change
    
    Type checking for the hash join residual filter isn't enforced in some corner 
cases (e.g. a literal filter expression). As a result, some invalid test cases were introduced.
    
    ### What changes are included in this PR?
    
    Enforce the type checking for all cases. Also fix the problematic test 
cases and refine the trivial residual filter handling in swiss join.
    
    ### Are these changes tested?
    
    Test included.
    
    ### Are there any user-facing changes?
    
    None.
    * GitHub Issue: #48268
    
    Authored-by: Rossi Sun <[email protected]>
    Signed-off-by: Rossi Sun <[email protected]>
---
 cpp/src/arrow/acero/hash_join_node.cc      | 20 +++++++++------
 cpp/src/arrow/acero/hash_join_node_test.cc | 39 ++++++++++++++++++++++++++++--
 cpp/src/arrow/acero/swiss_join.cc          | 26 +++++++++++---------
 cpp/src/arrow/acero/swiss_join_internal.h  |  2 ++
 4 files changed, 67 insertions(+), 20 deletions(-)

diff --git a/cpp/src/arrow/acero/hash_join_node.cc 
b/cpp/src/arrow/acero/hash_join_node.cc
index 28e3eb0e04..8311a76165 100644
--- a/cpp/src/arrow/acero/hash_join_node.cc
+++ b/cpp/src/arrow/acero/hash_join_node.cc
@@ -370,9 +370,19 @@ Result<Expression> HashJoinSchema::BindFilter(Expression 
filter,
                                               const Schema& left_schema,
                                               const Schema& right_schema,
                                               ExecContext* exec_context) {
-  if (filter.IsBound() || filter == literal(true)) {
+  auto ValidateFilterTypeAndReturn = [](Expression filter) -> 
Result<Expression> {
+    if (filter.type()->id() != Type::BOOL) {
+      return Status::TypeError("Filter expression must evaluate to bool, but ",
+                               filter.ToString(), " evaluates to ",
+                               filter.type()->ToString());
+    }
     return filter;
+  };
+
+  if (filter.IsBound()) {
+    return ValidateFilterTypeAndReturn(std::move(filter));
   }
+
   // Step 1: Construct filter schema
   FieldVector fields;
   auto left_f_to_i =
@@ -401,12 +411,8 @@ Result<Expression> HashJoinSchema::BindFilter(Expression 
filter,
 
   // Step 3: Bind
   ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema, exec_context));
-  if (filter.type()->id() != Type::BOOL) {
-    return Status::TypeError("Filter expression must evaluate to bool, but ",
-                             filter.ToString(), " evaluates to ",
-                             filter.type()->ToString());
-  }
-  return filter;
+
+  return ValidateFilterTypeAndReturn(std::move(filter));
 }
 
 Expression HashJoinSchema::RewriteFilterToUseFilterSchema(
diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc 
b/cpp/src/arrow/acero/hash_join_node_test.cc
index 654fd59c45..73355af8c4 100644
--- a/cpp/src/arrow/acero/hash_join_node_test.cc
+++ b/cpp/src/arrow/acero/hash_join_node_test.cc
@@ -1902,6 +1902,41 @@ TEST(HashJoin, CheckHashJoinNodeOptionsValidation) {
   }
 }
 
+TEST(HashJoin, CheckResidualFilterType) {
+  BatchesWithSchema input_left;
+  input_left.schema = schema({field("lkey", int32()), field("lpayload", 
int32())});
+
+  BatchesWithSchema input_right;
+  input_right.schema = schema({field("rkey", int32()), field("rpayload", 
int32())});
+
+  Declaration left{"source",
+                   SourceNodeOptions{input_left.schema, 
input_left.gen(/*parallel=*/false,
+                                                                       
/*slow=*/false)}};
+  Declaration right{
+      "source", SourceNodeOptions{input_right.schema, 
input_right.gen(/*parallel=*/false,
+                                                                      
/*slow=*/false)}};
+
+  for (const auto& filter :
+       {literal(MakeNullScalar(boolean())), literal(true), literal(false),
+        equal(field_ref("lpayload"), field_ref("rpayload"))}) {
+    HashJoinNodeOptions options{
+        JoinType::INNER, {FieldRef("lkey")}, {FieldRef("rkey")}, filter};
+    Declaration join{"hashjoin", {left, right}, options};
+    ASSERT_OK(DeclarationToStatus(std::move(join)));
+  }
+
+  for (const auto& filter :
+       {literal(NullScalar()), literal(42),
+        call("add", {field_ref("lpayload"), field_ref("rpayload")})}) {
+    HashJoinNodeOptions options{
+        JoinType::INNER, {FieldRef("lkey")}, {FieldRef("rkey")}, filter};
+    Declaration join{"hashjoin", {left, right}, options};
+    EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError,
+                                    ::testing::HasSubstr("must evaluate to 
bool"),
+                                    DeclarationToStatus(std::move(join)));
+  }
+}
+
 class ResidualFilterCaseRunner {
  public:
   ResidualFilterCaseRunner(BatchesWithSchema left_input, BatchesWithSchema 
right_input)
@@ -2369,8 +2404,8 @@ TEST(HashJoin, FineGrainedResidualFilter) {
   {
     // Literal false, null, and scalar false, null.
     for (Expression filter :
-         {literal(false), literal(NullScalar()), equal(literal(0), literal(1)),
-          equal(literal(1), literal(NullScalar()))}) {
+         {literal(false), literal(MakeNullScalar(boolean())),
+          equal(literal(0), literal(1)), equal(literal(1), 
literal(NullScalar()))}) {
       std::vector<FieldRef> left_keys{"l_key", "l_filter"},
           right_keys{"r_key", "r_filter"};
       {
diff --git a/cpp/src/arrow/acero/swiss_join.cc 
b/cpp/src/arrow/acero/swiss_join.cc
index 9bdfaaae2f..97632e0ca0 100644
--- a/cpp/src/arrow/acero/swiss_join.cc
+++ b/cpp/src/arrow/acero/swiss_join.cc
@@ -1845,6 +1845,11 @@ void JoinResidualFilter::Init(Expression filter, 
QueryContext* ctx, MemoryPool*
                               const HashJoinProjectionMaps* build_schemas,
                               SwissTableForJoin* hash_table) {
   filter_ = std::move(filter);
+  if (auto lit = filter_.literal(); lit) {
+    const auto& scalar = lit->scalar_as<BooleanScalar>();
+    is_trivial_ = true;
+    is_literal_true_ = scalar.is_valid && scalar.value;
+  }
   ctx_ = ctx;
   pool_ = pool;
   hardware_flags_ = hardware_flags;
@@ -1918,14 +1923,14 @@ Status JoinResidualFilter::FilterLeftSemi(const 
ExecBatch& keypayload_batch,
                                           arrow::util::TempVectorStack* 
temp_stack,
                                           int* num_passing_ids,
                                           uint16_t* passing_batch_row_ids) 
const {
-  if (filter_ == literal(true)) {
+  if (is_literal_true_) {
     CollectPassingBatchIds(1, hardware_flags_, batch_start_row, num_batch_rows,
                            match_bitvector, num_passing_ids, 
passing_batch_row_ids);
     return Status::OK();
   }
 
   *num_passing_ids = 0;
-  if (filter_.IsNullLiteral() || filter_ == literal(false)) {
+  if (is_trivial_ && !is_literal_true_) {
     return Status::OK();
   }
 
@@ -1993,7 +1998,7 @@ Status JoinResidualFilter::FilterLeftAnti(const 
ExecBatch& keypayload_batch,
                                           arrow::util::TempVectorStack* 
temp_stack,
                                           int* num_passing_ids,
                                           uint16_t* passing_batch_row_ids) 
const {
-  if (filter_ == literal(true)) {
+  if (is_literal_true_) {
     CollectPassingBatchIds(0, hardware_flags_, batch_start_row, num_batch_rows,
                            match_bitvector, num_passing_ids, 
passing_batch_row_ids);
     return Status::OK();
@@ -2032,12 +2037,12 @@ Status JoinResidualFilter::FilterRightSemiAnti(
     int64_t thread_id, const ExecBatch& keypayload_batch, int batch_start_row,
     int num_batch_rows, const uint8_t* match_bitvector, const uint32_t* 
key_ids,
     bool no_duplicate_keys, arrow::util::TempVectorStack* temp_stack) const {
-  if (filter_.IsNullLiteral() || filter_ == literal(false)) {
+  if (is_trivial_ && !is_literal_true_) {
     return Status::OK();
   }
 
   int num_matching_ids = 0;
-  if (filter_ == literal(true)) {
+  if (is_literal_true_) {
     auto match_relative_batch_ids_buf =
         arrow::util::TempVectorHolder<uint16_t>(temp_stack, num_batch_rows);
     auto match_key_ids_buf =
@@ -2091,13 +2096,13 @@ Status JoinResidualFilter::FilterInner(
     const ExecBatch& keypayload_batch, int num_batch_rows, uint16_t* 
batch_row_ids,
     uint32_t* key_ids, uint32_t* payload_ids_maybe_null, bool 
output_payload_ids,
     arrow::util::TempVectorStack* temp_stack, int* num_passing_rows) const {
-  if (filter_ == literal(true)) {
+  if (is_literal_true_) {
     *num_passing_rows = num_batch_rows;
     return Status::OK();
   }
 
   *num_passing_rows = 0;
-  if (filter_.IsNullLiteral() || filter_ == literal(false)) {
+  if (is_trivial_ && !is_literal_true_) {
     return Status::OK();
   }
 
@@ -2114,8 +2119,7 @@ Status JoinResidualFilter::FilterOneBatch(const 
ExecBatch& keypayload_batch,
                                           arrow::util::TempVectorStack* 
temp_stack,
                                           int* num_passing_rows) const {
   // Caller must do shortcuts for trivial filter.
-  ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
-               filter_ != literal(false));
+  ARROW_DCHECK(!is_trivial_);
   ARROW_DCHECK(!output_key_ids || key_ids_maybe_null);
   ARROW_DCHECK(!output_payload_ids || payload_ids_maybe_null);
 
@@ -2128,6 +2132,7 @@ Status JoinResidualFilter::FilterOneBatch(const 
ExecBatch& keypayload_batch,
   ARROW_ASSIGN_OR_RAISE(Datum mask,
                         EvalFilter(keypayload_batch, num_batch_rows, 
batch_row_ids,
                                    key_ids_maybe_null, 
payload_ids_maybe_null));
+  DCHECK_EQ(mask.type()->id(), Type::BOOL);
   if (mask.is_scalar()) {
     const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
     if (mask_scalar.is_valid && mask_scalar.value) {
@@ -2162,8 +2167,7 @@ Status JoinResidualFilter::FilterOneBatch(const 
ExecBatch& keypayload_batch,
 Result<Datum> JoinResidualFilter::EvalFilter(
     const ExecBatch& keypayload_batch, int num_batch_rows, const uint16_t* 
batch_row_ids,
     const uint32_t* key_ids_maybe_null, const uint32_t* 
payload_ids_maybe_null) const {
-  ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
-               filter_ != literal(false));
+  ARROW_DCHECK(!is_trivial_);
 
   ARROW_ASSIGN_OR_RAISE(
       ExecBatch input,
diff --git a/cpp/src/arrow/acero/swiss_join_internal.h 
b/cpp/src/arrow/acero/swiss_join_internal.h
index 2512f9a752..47f1b36149 100644
--- a/cpp/src/arrow/acero/swiss_join_internal.h
+++ b/cpp/src/arrow/acero/swiss_join_internal.h
@@ -980,6 +980,8 @@ class JoinResidualFilter {
 
  private:
   Expression filter_;
+  bool is_trivial_ = false;
+  bool is_literal_true_ = false;
 
   QueryContext* ctx_;
   MemoryPool* pool_;

Reply via email to