westonpace commented on a change in pull request #11579:
URL: https://github.com/apache/arrow/pull/11579#discussion_r745504468
##########
File path: cpp/src/arrow/compute/exec/hash_join.cc
##########
@@ -282,6 +279,125 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_batches_produced_++;
}
+ Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state,
+ std::vector<int32_t>& match,
+ std::vector<int32_t>& no_match,
+ std::vector<int32_t>& match_left,
+ std::vector<int32_t>& match_right) {
+ if (filter_ == literal(true)) {
+ return Status::OK();
+ }
+ ARROW_DCHECK_EQ(match_left.size(), match_right.size());
+
+ ExecBatch concatenated({}, match_left.size());
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch left_key,
local_state.exec_batch_keys.Decode(
+ match_left.size(),
match_left.data()));
+ ARROW_ASSIGN_OR_RAISE(
+ ExecBatch right_key,
+ hash_table_keys_.Decode(match_right.size(), match_right.data()));
+
+ ExecBatch left_payload;
+ if (!schema_mgr_->LeftPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(left_payload,
local_state.exec_batch_payloads.Decode(
+ match_left.size(),
match_left.data()));
+ }
+
+ ExecBatch right_payload;
+ if (!schema_mgr_->RightPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
+ match_right.size(),
match_right.data()));
+ }
+
+ auto AppendFields = [&concatenated](const SchemaProjectionMap& to_key,
+ const SchemaProjectionMap& to_pay,
+ const ExecBatch& key, const ExecBatch&
payload) {
+ ARROW_DCHECK(to_key.num_cols == to_pay.num_cols);
+ for (int i = 0; i < to_key.num_cols; i++) {
+ if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
+ int key_idx = to_key.get(i);
+ concatenated.values.push_back(key.values[key_idx]);
+ } else if (to_pay.get(i) != SchemaProjectionMap::kMissingField) {
+ int pay_idx = to_pay.get(i);
+ concatenated.values.push_back(payload.values[pay_idx]);
+ }
+ }
+ };
+
+ SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
+ HashJoinProjection::FILTER, HashJoinProjection::KEY);
+ SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
+ HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+ SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
+ HashJoinProjection::FILTER, HashJoinProjection::KEY);
+ SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
+ HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+
+ AppendFields(left_to_key, left_to_pay, left_key, left_payload);
+ AppendFields(right_to_key, right_to_pay, right_key, right_payload);
+
+ ARROW_ASSIGN_OR_RAISE(Datum mask,
+ ExecuteScalarExpression(filter_, concatenated,
ctx_));
+
+ size_t num_probed_rows = match.size() + no_match.size();
+ if (mask.is_scalar()) {
+ const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
+ if (mask_scalar.is_valid && mask_scalar.value) {
+ // All rows passed, nothing left to do
+ return Status::OK();
+ } else {
+ // Nothing passed, no_match becomes everything
+ no_match.resize(num_probed_rows);
+ std::iota(no_match.begin(), no_match.end(), 0);
+ match_left.clear();
+ match_right.clear();
+ match.clear();
+ return Status::OK();
+ }
+ }
+ ARROW_DCHECK(mask.array()->offset == 0);
+ ARROW_DCHECK(mask.array()->length ==
static_cast<int64_t>(match_left.size()));
Review comment:
Nit: use `ARROW_DCHECK_EQ` here instead of `ARROW_DCHECK(a == b)`.
##########
File path: cpp/src/arrow/compute/exec/hash_join.cc
##########
@@ -282,6 +279,125 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_batches_produced_++;
}
+ Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state,
+ std::vector<int32_t>& match,
+ std::vector<int32_t>& no_match,
+ std::vector<int32_t>& match_left,
+ std::vector<int32_t>& match_right) {
+ if (filter_ == literal(true)) {
+ return Status::OK();
+ }
+ ARROW_DCHECK_EQ(match_left.size(), match_right.size());
+
+ ExecBatch concatenated({}, match_left.size());
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch left_key,
local_state.exec_batch_keys.Decode(
+ match_left.size(),
match_left.data()));
+ ARROW_ASSIGN_OR_RAISE(
+ ExecBatch right_key,
+ hash_table_keys_.Decode(match_right.size(), match_right.data()));
+
+ ExecBatch left_payload;
+ if (!schema_mgr_->LeftPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(left_payload,
local_state.exec_batch_payloads.Decode(
+ match_left.size(),
match_left.data()));
+ }
+
+ ExecBatch right_payload;
+ if (!schema_mgr_->RightPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
+ match_right.size(),
match_right.data()));
+ }
+
+ auto AppendFields = [&concatenated](const SchemaProjectionMap& to_key,
+ const SchemaProjectionMap& to_pay,
+ const ExecBatch& key, const ExecBatch&
payload) {
+ ARROW_DCHECK(to_key.num_cols == to_pay.num_cols);
+ for (int i = 0; i < to_key.num_cols; i++) {
+ if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
+ int key_idx = to_key.get(i);
+ concatenated.values.push_back(key.values[key_idx]);
+ } else if (to_pay.get(i) != SchemaProjectionMap::kMissingField) {
+ int pay_idx = to_pay.get(i);
+ concatenated.values.push_back(payload.values[pay_idx]);
+ }
+ }
+ };
+
+ SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
+ HashJoinProjection::FILTER, HashJoinProjection::KEY);
+ SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
+ HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+ SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
+ HashJoinProjection::FILTER, HashJoinProjection::KEY);
+ SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
+ HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+
+ AppendFields(left_to_key, left_to_pay, left_key, left_payload);
+ AppendFields(right_to_key, right_to_pay, right_key, right_payload);
+
+ ARROW_ASSIGN_OR_RAISE(Datum mask,
+ ExecuteScalarExpression(filter_, concatenated,
ctx_));
+
+ size_t num_probed_rows = match.size() + no_match.size();
+ if (mask.is_scalar()) {
+ const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
+ if (mask_scalar.is_valid && mask_scalar.value) {
+ // All rows passed, nothing left to do
+ return Status::OK();
+ } else {
+ // Nothing passed, no_match becomes everything
+ no_match.resize(num_probed_rows);
+ std::iota(no_match.begin(), no_match.end(), 0);
+ match_left.clear();
+ match_right.clear();
+ match.clear();
+ return Status::OK();
+ }
+ }
+ ARROW_DCHECK(mask.array()->offset == 0);
+ ARROW_DCHECK(mask.array()->length ==
static_cast<int64_t>(match_left.size()));
+ const uint8_t* nulls = mask.array()->buffers[0]->data();
+ const uint8_t* comparisons = mask.array()->buffers[1]->data();
+ size_t num_rows = match_left.size();
+
+ match.clear();
+ no_match.clear();
+
+ int32_t match_idx = 0; // current size of new match_left
+ int32_t irow = 0; // index into match_left
+ for (int32_t curr_left = 0; static_cast<size_t>(curr_left) <
num_probed_rows;
+ curr_left++) {
+ int32_t advance_to = static_cast<size_t>(irow) < num_rows
+ ? match_left[irow]
+ : static_cast<int32_t>(num_probed_rows);
+ while (curr_left < advance_to) {
+ no_match.push_back(curr_left++);
+ }
+ bool passed = false;
+ for (; static_cast<size_t>(irow) < num_rows && match_left[irow] ==
curr_left;
+ irow++) {
+ bool is_null = !BitUtil::GetBit(nulls, irow);
+ bool is_cmp_true = BitUtil::GetBit(comparisons, irow);
+ // We treat a null comparison result as false, like in SQL
+ if (!is_null && is_cmp_true) {
Review comment:
Super minor nit: naming this `is_valid` (instead of `is_null`) might be clearer
to read, since you wouldn't have a double negation (`!is_null`).
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -43,32 +43,49 @@ bool HashJoinSchema::IsTypeSupported(const DataType& type) {
return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id);
}
-Result<std::vector<FieldRef>> HashJoinSchema::VectorDiff(const Schema& schema,
- const
std::vector<FieldRef>& a,
- const
std::vector<FieldRef>& b) {
- std::unordered_set<int> b_paths;
- for (size_t i = 0; i < b.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto match, b[i].FindOne(schema));
- b_paths.insert(match[0]);
+Result<std::vector<FieldRef>> HashJoinSchema::ComputePayload(
+ const Schema& schema, const std::vector<FieldRef>& output,
+ const std::vector<FieldRef>& filter, const std::vector<FieldRef>& keys) {
+ // payload = (output + filter) - keys, with no duplicates
Review comment:
So what happens if there are duplicate field names in the schema? Is
this a malformed query?
@bkietz Is this something the IR would detect and prevent in the first place?
##########
File path: cpp/src/arrow/compute/exec/hash_join.cc
##########
@@ -282,6 +279,125 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_batches_produced_++;
}
+ Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state,
+ std::vector<int32_t>& match,
+ std::vector<int32_t>& no_match,
+ std::vector<int32_t>& match_left,
+ std::vector<int32_t>& match_right) {
+ if (filter_ == literal(true)) {
+ return Status::OK();
+ }
+ ARROW_DCHECK_EQ(match_left.size(), match_right.size());
+
+ ExecBatch concatenated({}, match_left.size());
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch left_key,
local_state.exec_batch_keys.Decode(
+ match_left.size(),
match_left.data()));
+ ARROW_ASSIGN_OR_RAISE(
+ ExecBatch right_key,
+ hash_table_keys_.Decode(match_right.size(), match_right.data()));
+
+ ExecBatch left_payload;
+ if (!schema_mgr_->LeftPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(left_payload,
local_state.exec_batch_payloads.Decode(
+ match_left.size(),
match_left.data()));
+ }
+
+ ExecBatch right_payload;
+ if (!schema_mgr_->RightPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
+ match_right.size(),
match_right.data()));
+ }
Review comment:
Do your tests cover situations where the `left_payload` or
`right_payload` is empty?
##########
File path: cpp/src/arrow/compute/exec/hash_join.cc
##########
@@ -282,6 +279,125 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_batches_produced_++;
}
+ Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state,
+ std::vector<int32_t>& match,
+ std::vector<int32_t>& no_match,
+ std::vector<int32_t>& match_left,
+ std::vector<int32_t>& match_right) {
+ if (filter_ == literal(true)) {
+ return Status::OK();
+ }
+ ARROW_DCHECK_EQ(match_left.size(), match_right.size());
+
+ ExecBatch concatenated({}, match_left.size());
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch left_key,
local_state.exec_batch_keys.Decode(
+ match_left.size(),
match_left.data()));
+ ARROW_ASSIGN_OR_RAISE(
+ ExecBatch right_key,
+ hash_table_keys_.Decode(match_right.size(), match_right.data()));
+
+ ExecBatch left_payload;
+ if (!schema_mgr_->LeftPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(left_payload,
local_state.exec_batch_payloads.Decode(
+ match_left.size(),
match_left.data()));
+ }
+
+ ExecBatch right_payload;
+ if (!schema_mgr_->RightPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
+ match_right.size(),
match_right.data()));
+ }
+
+ auto AppendFields = [&concatenated](const SchemaProjectionMap& to_key,
+ const SchemaProjectionMap& to_pay,
+ const ExecBatch& key, const ExecBatch&
payload) {
+ ARROW_DCHECK(to_key.num_cols == to_pay.num_cols);
+ for (int i = 0; i < to_key.num_cols; i++) {
+ if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
+ int key_idx = to_key.get(i);
+ concatenated.values.push_back(key.values[key_idx]);
+ } else if (to_pay.get(i) != SchemaProjectionMap::kMissingField) {
+ int pay_idx = to_pay.get(i);
+ concatenated.values.push_back(payload.values[pay_idx]);
+ }
+ }
+ };
+
+ SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
+ HashJoinProjection::FILTER, HashJoinProjection::KEY);
+ SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
+ HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+ SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
+ HashJoinProjection::FILTER, HashJoinProjection::KEY);
+ SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
+ HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+
+ AppendFields(left_to_key, left_to_pay, left_key, left_payload);
+ AppendFields(right_to_key, right_to_pay, right_key, right_payload);
+
+ ARROW_ASSIGN_OR_RAISE(Datum mask,
+ ExecuteScalarExpression(filter_, concatenated,
ctx_));
+
+ size_t num_probed_rows = match.size() + no_match.size();
+ if (mask.is_scalar()) {
+ const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
+ if (mask_scalar.is_valid && mask_scalar.value) {
+ // All rows passed, nothing left to do
+ return Status::OK();
+ } else {
+ // Nothing passed, no_match becomes everything
+ no_match.resize(num_probed_rows);
+ std::iota(no_match.begin(), no_match.end(), 0);
+ match_left.clear();
+ match_right.clear();
+ match.clear();
+ return Status::OK();
+ }
+ }
+ ARROW_DCHECK(mask.array()->offset == 0);
+ ARROW_DCHECK(mask.array()->length ==
static_cast<int64_t>(match_left.size()));
+ const uint8_t* nulls = mask.array()->buffers[0]->data();
Review comment:
Is `buffers[0]` guaranteed to exist? In general, for arrays, the
validity bitmap is optional (in which case `buffers[0] == nullptr`, and
calling `->data()` on it would dereference a null pointer). Maybe
`ExecuteScalarExpression` always creates one today, but that could change.
Also, a naming nit: maybe `validity` instead of `nulls`? The latter name
would imply that a set bit indicates a null value.
##########
File path: cpp/src/arrow/compute/exec/hash_join.cc
##########
@@ -282,6 +279,125 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_batches_produced_++;
}
+ Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state,
+ std::vector<int32_t>& match,
+ std::vector<int32_t>& no_match,
+ std::vector<int32_t>& match_left,
+ std::vector<int32_t>& match_right) {
+ if (filter_ == literal(true)) {
+ return Status::OK();
+ }
+ ARROW_DCHECK_EQ(match_left.size(), match_right.size());
+
+ ExecBatch concatenated({}, match_left.size());
+
+ ARROW_ASSIGN_OR_RAISE(ExecBatch left_key,
local_state.exec_batch_keys.Decode(
+ match_left.size(),
match_left.data()));
+ ARROW_ASSIGN_OR_RAISE(
+ ExecBatch right_key,
+ hash_table_keys_.Decode(match_right.size(), match_right.data()));
+
+ ExecBatch left_payload;
+ if (!schema_mgr_->LeftPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(left_payload,
local_state.exec_batch_payloads.Decode(
+ match_left.size(),
match_left.data()));
+ }
+
+ ExecBatch right_payload;
+ if (!schema_mgr_->RightPayloadIsEmpty()) {
+ ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
+ match_right.size(),
match_right.data()));
+ }
+
+ auto AppendFields = [&concatenated](const SchemaProjectionMap& to_key,
+ const SchemaProjectionMap& to_pay,
+ const ExecBatch& key, const ExecBatch&
payload) {
+ ARROW_DCHECK(to_key.num_cols == to_pay.num_cols);
+ for (int i = 0; i < to_key.num_cols; i++) {
+ if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
+ int key_idx = to_key.get(i);
+ concatenated.values.push_back(key.values[key_idx]);
+ } else if (to_pay.get(i) != SchemaProjectionMap::kMissingField) {
+ int pay_idx = to_pay.get(i);
+ concatenated.values.push_back(payload.values[pay_idx]);
+ }
+ }
+ };
+
+ SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
+ HashJoinProjection::FILTER, HashJoinProjection::KEY);
+ SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
+ HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+ SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
+ HashJoinProjection::FILTER, HashJoinProjection::KEY);
+ SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
+ HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+
+ AppendFields(left_to_key, left_to_pay, left_key, left_payload);
+ AppendFields(right_to_key, right_to_pay, right_key, right_payload);
+
+ ARROW_ASSIGN_OR_RAISE(Datum mask,
+ ExecuteScalarExpression(filter_, concatenated,
ctx_));
+
+ size_t num_probed_rows = match.size() + no_match.size();
+ if (mask.is_scalar()) {
+ const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
Review comment:
Is this branch covered by your testing?
##########
File path: cpp/src/arrow/compute/exec/hash_join_node_test.cc
##########
@@ -1061,17 +1066,17 @@ TEST(HashJoin, Random) {
// Print test case parameters
// print num_rows, batch_size, join_type, join_cmp
- std::cout << join_type_name << " " << key_cmp_str << " ";
+ std::cout << "Trial " << test_id << ":\n";
+ std::cout << " " << join_type_name << " " << key_cmp_str << " ";
Review comment:
I'm a little surprised we have print statements here. I think we
normally avoid this kind of thing for tests, favoring `ARROW_SCOPED_TRACE`
instead, but since they were here already :shrug:
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -43,32 +43,49 @@ bool HashJoinSchema::IsTypeSupported(const DataType& type) {
return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id);
}
-Result<std::vector<FieldRef>> HashJoinSchema::VectorDiff(const Schema& schema,
- const
std::vector<FieldRef>& a,
- const
std::vector<FieldRef>& b) {
- std::unordered_set<int> b_paths;
- for (size_t i = 0; i < b.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto match, b[i].FindOne(schema));
- b_paths.insert(match[0]);
+Result<std::vector<FieldRef>> HashJoinSchema::ComputePayload(
+ const Schema& schema, const std::vector<FieldRef>& output,
+ const std::vector<FieldRef>& filter, const std::vector<FieldRef>& keys) {
+ // payload = (output + filter) - keys, with no duplicates
+ std::unordered_set<int> payload_fields;
+ for (auto ref : output) {
+ ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
+ payload_fields.insert(match[0]);
}
- std::vector<FieldRef> result;
+ for (auto ref : filter) {
+ ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
+ payload_fields.insert(match[0]);
+ }
- for (size_t i = 0; i < a.size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(auto match, a[i].FindOne(schema));
- bool is_found = (b_paths.find(match[0]) != b_paths.end());
- if (!is_found) {
- result.push_back(a[i]);
- }
+ for (auto ref : keys) {
+ ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
+ payload_fields.erase(match[0]);
}
+ std::vector<FieldRef> result;
Review comment:
Minor nit: I know it was this way before, but in general it's nice if we
can avoid names like `result` and use something meaningful like `payload_refs`.
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -274,17 +303,83 @@ std::shared_ptr<Schema> HashJoinSchema::MakeOutputSchema(
return std::make_shared<Schema>(std::move(fields));
}
+Result<Expression> HashJoinSchema::BindFilter(Expression filter,
+ const Schema& left_schema,
+ const Schema& right_schema) {
+ if (filter.IsBound()) {
+ return std::move(filter);
+ }
Review comment:
Maybe `if (filter.IsBound() || filter == literal(true))`, and then you
could get rid of the next `if`.
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -300,30 +395,36 @@ class HashJoinNode : public ExecNode {
const auto& join_options = checked_cast<const
HashJoinNodeOptions&>(options);
+ const auto& left_schema = *(inputs[0]->output_schema());
+ const auto& right_schema = *(inputs[1]->output_schema());
// This will also validate input schemas
if (join_options.output_all) {
RETURN_NOT_OK(schema_mgr->Init(
- join_options.join_type, *(inputs[0]->output_schema()),
join_options.left_keys,
- *(inputs[1]->output_schema()), join_options.right_keys,
+ join_options.join_type, left_schema, join_options.left_keys,
right_schema,
+ join_options.right_keys, join_options.filter,
join_options.output_prefix_for_left,
join_options.output_prefix_for_right));
} else {
RETURN_NOT_OK(schema_mgr->Init(
- join_options.join_type, *(inputs[0]->output_schema()),
join_options.left_keys,
- join_options.left_output, *(inputs[1]->output_schema()),
- join_options.right_keys, join_options.right_output,
+ join_options.join_type, left_schema, join_options.left_keys,
+ join_options.left_output, right_schema, join_options.right_keys,
+ join_options.right_output, join_options.filter,
join_options.output_prefix_for_left,
join_options.output_prefix_for_right));
}
+ ARROW_ASSIGN_OR_RAISE(Expression filter,
+
schema_mgr->BindFilter(std::move(join_options.filter),
+ left_schema, right_schema));
+
Review comment:
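I think you'd get away with it 9 times out of 10, but I'm not sure it is
safe to move `join_options.filter`: `join_options` is a const reference to
options this node doesn't own, so at best the `std::move` quietly falls back
to a copy.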
##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -274,17 +303,83 @@ std::shared_ptr<Schema> HashJoinSchema::MakeOutputSchema(
return std::make_shared<Schema>(std::move(fields));
}
+Result<Expression> HashJoinSchema::BindFilter(Expression filter,
+ const Schema& left_schema,
+ const Schema& right_schema) {
+ if (filter.IsBound()) {
+ return std::move(filter);
+ }
+ if (filter != literal(true)) {
+ FieldVector fields;
+ auto left = proj_maps[0].map(HashJoinProjection::FILTER,
HashJoinProjection::INPUT);
+ auto right = proj_maps[1].map(HashJoinProjection::FILTER,
HashJoinProjection::INPUT);
+
+ auto AppendFieldsInMap = [&fields](const SchemaProjectionMap& map,
+ const Schema& schema) {
+ for (int i = 0; i < map.num_cols; i++) {
+ int input_idx = map.get(i);
+ fields.push_back(schema.fields()[input_idx]);
+ }
+ };
+ AppendFieldsInMap(left, left_schema);
+ AppendFieldsInMap(right, right_schema);
+ Schema filter_schema(fields);
+ ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema));
+ if (filter.type()->id() != Type::BOOL) {
+ return Status::TypeError("Filter expression must evaluate to bool, but ",
+ filter.ToString(), " evaluates to ",
+ filter.type()->ToString());
+ }
+ return std::move(filter);
+ }
+ return literal(true);
+}
+
+Result<std::vector<FieldRef>> HashJoinSchema::CollectFilterColumns(
+ const Expression& filter, const Schema& schema) {
+ std::vector<FieldRef> nonunique_refs;
+ RETURN_NOT_OK(TraverseExpression(nonunique_refs, filter, schema));
+
+ std::vector<FieldRef> result;
+ std::unordered_set<int> seen_paths;
+ for (auto ref : nonunique_refs) {
+ ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
+ if (seen_paths.find(match[0]) == seen_paths.end()) {
+ seen_paths.insert(match[0]);
+ result.push_back(ref);
+ }
+ }
+ return result;
+}
+
+Status HashJoinSchema::TraverseExpression(std::vector<FieldRef>& refs,
+ const Expression& filter,
+ const Schema& schema) {
+ if (filter == literal(true)) return Status::OK();
+ if (auto* call = filter.call()) {
+ for (const Expression& arg : call->arguments)
+ RETURN_NOT_OK(TraverseExpression(refs, arg, schema));
+ } else if (auto* param = filter.parameter()) {
+ if (!param->ref.IsName())
+ return Status::Invalid("Filter parameters to join must be by name");
+ ARROW_ASSIGN_OR_RAISE(auto match, param->ref.FindOneOrNone(schema));
+ if (match != FieldPath()) refs.push_back(param->ref);
+ }
+ return Status::OK();
+}
+
Review comment:
Seems like this could be some kind of utility method as part of
`Expression` (collecting the field references an expression uses).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]