This is an automated email from the ASF dual-hosted git repository.
zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 6f62c2adf2 GH-48268: [C++][Acero] Enhance the type checking for hash
join residual filter (#48272)
6f62c2adf2 is described below
commit 6f62c2adf27b7b5517c4010aefee055fd7d2f859
Author: Rossi Sun <[email protected]>
AuthorDate: Sat Nov 29 06:14:46 2025 -0800
GH-48268: [C++][Acero] Enhance the type checking for hash join residual
filter (#48272)
### Rationale for this change
Type checking for the hash join filter isn't enforced in some corner cases
(literal filter expressions). Consequently, some invalid test cases were
introduced.
### What changes are included in this PR?
Enforce the type checking for all cases. Also fix the problematic test
cases. Also refine the trivial residual filter handling in the swiss join.
### Are these changes tested?
Test included.
### Are there any user-facing changes?
None.
* GitHub Issue: #48268
Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
---
cpp/src/arrow/acero/hash_join_node.cc | 20 +++++++++------
cpp/src/arrow/acero/hash_join_node_test.cc | 39 ++++++++++++++++++++++++++++--
cpp/src/arrow/acero/swiss_join.cc | 26 +++++++++++---------
cpp/src/arrow/acero/swiss_join_internal.h | 2 ++
4 files changed, 67 insertions(+), 20 deletions(-)
diff --git a/cpp/src/arrow/acero/hash_join_node.cc
b/cpp/src/arrow/acero/hash_join_node.cc
index 28e3eb0e04..8311a76165 100644
--- a/cpp/src/arrow/acero/hash_join_node.cc
+++ b/cpp/src/arrow/acero/hash_join_node.cc
@@ -370,9 +370,19 @@ Result<Expression> HashJoinSchema::BindFilter(Expression
filter,
const Schema& left_schema,
const Schema& right_schema,
ExecContext* exec_context) {
- if (filter.IsBound() || filter == literal(true)) {
+ auto ValidateFilterTypeAndReturn = [](Expression filter) ->
Result<Expression> {
+ if (filter.type()->id() != Type::BOOL) {
+ return Status::TypeError("Filter expression must evaluate to bool, but ",
+ filter.ToString(), " evaluates to ",
+ filter.type()->ToString());
+ }
return filter;
+ };
+
+ if (filter.IsBound()) {
+ return ValidateFilterTypeAndReturn(std::move(filter));
}
+
// Step 1: Construct filter schema
FieldVector fields;
auto left_f_to_i =
@@ -401,12 +411,8 @@ Result<Expression> HashJoinSchema::BindFilter(Expression
filter,
// Step 3: Bind
ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema, exec_context));
- if (filter.type()->id() != Type::BOOL) {
- return Status::TypeError("Filter expression must evaluate to bool, but ",
- filter.ToString(), " evaluates to ",
- filter.type()->ToString());
- }
- return filter;
+
+ return ValidateFilterTypeAndReturn(std::move(filter));
}
Expression HashJoinSchema::RewriteFilterToUseFilterSchema(
diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc
b/cpp/src/arrow/acero/hash_join_node_test.cc
index 654fd59c45..73355af8c4 100644
--- a/cpp/src/arrow/acero/hash_join_node_test.cc
+++ b/cpp/src/arrow/acero/hash_join_node_test.cc
@@ -1902,6 +1902,41 @@ TEST(HashJoin, CheckHashJoinNodeOptionsValidation) {
}
}
+TEST(HashJoin, CheckResidualFilterType) {
+ BatchesWithSchema input_left;
+ input_left.schema = schema({field("lkey", int32()), field("lpayload",
int32())});
+
+ BatchesWithSchema input_right;
+ input_right.schema = schema({field("rkey", int32()), field("rpayload",
int32())});
+
+ Declaration left{"source",
+ SourceNodeOptions{input_left.schema,
input_left.gen(/*parallel=*/false,
+
/*slow=*/false)}};
+ Declaration right{
+ "source", SourceNodeOptions{input_right.schema,
input_right.gen(/*parallel=*/false,
+
/*slow=*/false)}};
+
+ for (const auto& filter :
+ {literal(MakeNullScalar(boolean())), literal(true), literal(false),
+ equal(field_ref("lpayload"), field_ref("rpayload"))}) {
+ HashJoinNodeOptions options{
+ JoinType::INNER, {FieldRef("lkey")}, {FieldRef("rkey")}, filter};
+ Declaration join{"hashjoin", {left, right}, options};
+ ASSERT_OK(DeclarationToStatus(std::move(join)));
+ }
+
+ for (const auto& filter :
+ {literal(NullScalar()), literal(42),
+ call("add", {field_ref("lpayload"), field_ref("rpayload")})}) {
+ HashJoinNodeOptions options{
+ JoinType::INNER, {FieldRef("lkey")}, {FieldRef("rkey")}, filter};
+ Declaration join{"hashjoin", {left, right}, options};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError,
+ ::testing::HasSubstr("must evaluate to
bool"),
+ DeclarationToStatus(std::move(join)));
+ }
+}
+
class ResidualFilterCaseRunner {
public:
ResidualFilterCaseRunner(BatchesWithSchema left_input, BatchesWithSchema
right_input)
@@ -2369,8 +2404,8 @@ TEST(HashJoin, FineGrainedResidualFilter) {
{
// Literal false, null, and scalar false, null.
for (Expression filter :
- {literal(false), literal(NullScalar()), equal(literal(0), literal(1)),
- equal(literal(1), literal(NullScalar()))}) {
+ {literal(false), literal(MakeNullScalar(boolean())),
+ equal(literal(0), literal(1)), equal(literal(1),
literal(NullScalar()))}) {
std::vector<FieldRef> left_keys{"l_key", "l_filter"},
right_keys{"r_key", "r_filter"};
{
diff --git a/cpp/src/arrow/acero/swiss_join.cc
b/cpp/src/arrow/acero/swiss_join.cc
index 9bdfaaae2f..97632e0ca0 100644
--- a/cpp/src/arrow/acero/swiss_join.cc
+++ b/cpp/src/arrow/acero/swiss_join.cc
@@ -1845,6 +1845,11 @@ void JoinResidualFilter::Init(Expression filter,
QueryContext* ctx, MemoryPool*
const HashJoinProjectionMaps* build_schemas,
SwissTableForJoin* hash_table) {
filter_ = std::move(filter);
+ if (auto lit = filter_.literal(); lit) {
+ const auto& scalar = lit->scalar_as<BooleanScalar>();
+ is_trivial_ = true;
+ is_literal_true_ = scalar.is_valid && scalar.value;
+ }
ctx_ = ctx;
pool_ = pool;
hardware_flags_ = hardware_flags;
@@ -1918,14 +1923,14 @@ Status JoinResidualFilter::FilterLeftSemi(const
ExecBatch& keypayload_batch,
arrow::util::TempVectorStack*
temp_stack,
int* num_passing_ids,
uint16_t* passing_batch_row_ids)
const {
- if (filter_ == literal(true)) {
+ if (is_literal_true_) {
CollectPassingBatchIds(1, hardware_flags_, batch_start_row, num_batch_rows,
match_bitvector, num_passing_ids,
passing_batch_row_ids);
return Status::OK();
}
*num_passing_ids = 0;
- if (filter_.IsNullLiteral() || filter_ == literal(false)) {
+ if (is_trivial_ && !is_literal_true_) {
return Status::OK();
}
@@ -1993,7 +1998,7 @@ Status JoinResidualFilter::FilterLeftAnti(const
ExecBatch& keypayload_batch,
arrow::util::TempVectorStack*
temp_stack,
int* num_passing_ids,
uint16_t* passing_batch_row_ids)
const {
- if (filter_ == literal(true)) {
+ if (is_literal_true_) {
CollectPassingBatchIds(0, hardware_flags_, batch_start_row, num_batch_rows,
match_bitvector, num_passing_ids,
passing_batch_row_ids);
return Status::OK();
@@ -2032,12 +2037,12 @@ Status JoinResidualFilter::FilterRightSemiAnti(
int64_t thread_id, const ExecBatch& keypayload_batch, int batch_start_row,
int num_batch_rows, const uint8_t* match_bitvector, const uint32_t*
key_ids,
bool no_duplicate_keys, arrow::util::TempVectorStack* temp_stack) const {
- if (filter_.IsNullLiteral() || filter_ == literal(false)) {
+ if (is_trivial_ && !is_literal_true_) {
return Status::OK();
}
int num_matching_ids = 0;
- if (filter_ == literal(true)) {
+ if (is_literal_true_) {
auto match_relative_batch_ids_buf =
arrow::util::TempVectorHolder<uint16_t>(temp_stack, num_batch_rows);
auto match_key_ids_buf =
@@ -2091,13 +2096,13 @@ Status JoinResidualFilter::FilterInner(
const ExecBatch& keypayload_batch, int num_batch_rows, uint16_t*
batch_row_ids,
uint32_t* key_ids, uint32_t* payload_ids_maybe_null, bool
output_payload_ids,
arrow::util::TempVectorStack* temp_stack, int* num_passing_rows) const {
- if (filter_ == literal(true)) {
+ if (is_literal_true_) {
*num_passing_rows = num_batch_rows;
return Status::OK();
}
*num_passing_rows = 0;
- if (filter_.IsNullLiteral() || filter_ == literal(false)) {
+ if (is_trivial_ && !is_literal_true_) {
return Status::OK();
}
@@ -2114,8 +2119,7 @@ Status JoinResidualFilter::FilterOneBatch(const
ExecBatch& keypayload_batch,
arrow::util::TempVectorStack*
temp_stack,
int* num_passing_rows) const {
// Caller must do shortcuts for trivial filter.
- ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
- filter_ != literal(false));
+ ARROW_DCHECK(!is_trivial_);
ARROW_DCHECK(!output_key_ids || key_ids_maybe_null);
ARROW_DCHECK(!output_payload_ids || payload_ids_maybe_null);
@@ -2128,6 +2132,7 @@ Status JoinResidualFilter::FilterOneBatch(const
ExecBatch& keypayload_batch,
ARROW_ASSIGN_OR_RAISE(Datum mask,
EvalFilter(keypayload_batch, num_batch_rows,
batch_row_ids,
key_ids_maybe_null,
payload_ids_maybe_null));
+ DCHECK_EQ(mask.type()->id(), Type::BOOL);
if (mask.is_scalar()) {
const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
if (mask_scalar.is_valid && mask_scalar.value) {
@@ -2162,8 +2167,7 @@ Status JoinResidualFilter::FilterOneBatch(const
ExecBatch& keypayload_batch,
Result<Datum> JoinResidualFilter::EvalFilter(
const ExecBatch& keypayload_batch, int num_batch_rows, const uint16_t*
batch_row_ids,
const uint32_t* key_ids_maybe_null, const uint32_t*
payload_ids_maybe_null) const {
- ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
- filter_ != literal(false));
+ ARROW_DCHECK(!is_trivial_);
ARROW_ASSIGN_OR_RAISE(
ExecBatch input,
diff --git a/cpp/src/arrow/acero/swiss_join_internal.h
b/cpp/src/arrow/acero/swiss_join_internal.h
index 2512f9a752..47f1b36149 100644
--- a/cpp/src/arrow/acero/swiss_join_internal.h
+++ b/cpp/src/arrow/acero/swiss_join_internal.h
@@ -980,6 +980,8 @@ class JoinResidualFilter {
private:
Expression filter_;
+ bool is_trivial_ = false;
+ bool is_literal_true_ = false;
QueryContext* ctx_;
MemoryPool* pool_;