michalursa commented on a change in pull request #11150: URL: https://github.com/apache/arrow/pull/11150#discussion_r717145397
########## File path: cpp/src/arrow/compute/exec/hash_join_node.cc ########## @@ -0,0 +1,448 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <set> + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/hash_join.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/schema_util.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { + +std::vector<FieldRef> HashJoinFieldMap::VectorDiff(const std::vector<FieldRef>& a, + const std::vector<FieldRef>& b) { + std::vector<FieldRef> result; + for (size_t i = 0; i < a.size(); ++i) { + bool is_found = false; + for (size_t j = 0; j < b.size(); ++j) { + if (a[i].Equals(b[j])) { + is_found = true; + break; + } + } + if (!is_found) { + result.push_back(a[i]); + } + } + return result; +} + +Status HashJoinFieldMap::Init(JoinType join_type, const Schema& left_schema, + const std::vector<FieldRef>& left_keys, + const Schema& right_schema, + const std::vector<FieldRef>& right_keys, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { + std::vector<FieldRef> left_output; + if (join_type != JoinType::RIGHT_SEMI && join_type != JoinType::RIGHT_ANTI) { + const FieldVector& left_fields = left_schema.fields(); + left_output.resize(left_fields.size()); + for (size_t i = 0; i < left_fields.size(); ++i) { + const std::string& name = left_fields[i]->name(); + left_output[i] = FieldRef(name); + } + } + // Repeat the same for the right side + std::vector<FieldRef> right_output; + if (join_type != JoinType::LEFT_SEMI && join_type != JoinType::LEFT_ANTI) { + const FieldVector& right_fields = right_schema.fields(); + right_output.resize(right_fields.size()); + for (size_t i = 0; i < right_fields.size(); ++i) { + const std::string& name = right_fields[i]->name(); + right_output[i] = FieldRef(name); + } + } + return Init(join_type, left_schema, left_keys, left_output, right_schema, right_keys, + right_output, left_field_name_prefix, right_field_name_prefix); +} + +Status HashJoinFieldMap::Init(JoinType join_type, const Schema& left_schema, + const std::vector<FieldRef>& left_keys, + const std::vector<FieldRef>& left_output, + const Schema& right_schema, + const std::vector<FieldRef>& right_keys, + const std::vector<FieldRef>& right_output, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { + RETURN_NOT_OK(ValidateSchemas(join_type, left_schema, left_keys, left_output, + right_schema, right_keys, right_output, + left_field_name_prefix, right_field_name_prefix)); + + RegisterSchema(HashJoinSchemaHandle::FIRST_INPUT, left_schema); + RegisterSchema(HashJoinSchemaHandle::SECOND_INPUT, right_schema); + RETURN_NOT_OK( + RegisterProjectedSchema(HashJoinSchemaHandle::FIRST_KEY, left_keys, left_schema)); + RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::SECOND_KEY, right_keys, + right_schema)); + auto left_payload = VectorDiff(left_output, left_keys); + RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::FIRST_PAYLOAD, left_payload, + left_schema)); + auto right_payload = VectorDiff(right_output, right_keys); + RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::SECOND_PAYLOAD, + right_payload, right_schema)); + RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::FIRST_OUTPUT, left_output, + left_schema)); + RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::SECOND_OUTPUT, right_output, + right_schema)); + RegisterEnd(); + return Status::OK(); +} + +Status HashJoinFieldMap::ValidateSchemas(JoinType join_type, const Schema& left_schema, + const std::vector<FieldRef>& left_keys, + const std::vector<FieldRef>& left_output, + const Schema& right_schema, + const std::vector<FieldRef>& right_keys, + const std::vector<FieldRef>& right_output, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { + // Checks for key fields: + // 1. Key field refs must match exactly one input field + // 2. Same number of key fields on left and right + // 3. At least one key field + // 4. Equal data types for corresponding key fields + // 5. Dictionary type is not supported in a key field + // 6. Some other data types may not be allowed in a key field + // + if (left_keys.size() != right_keys.size()) { + return Status::Invalid("Different number of key fields on left (", left_keys.size(), + ") and right (", right_keys.size(), ") side of the join"); + } + if (left_keys.size() < 1) { + return Status::Invalid("Join key cannot be empty"); + } + for (size_t i = 0; i < left_keys.size() + right_keys.size(); ++i) { + bool left_side = i < left_keys.size(); + const FieldRef& field_ref = + left_side ? left_keys[i] : right_keys[i - left_keys.size()]; + Result<FieldPath> result = field_ref.FindOne(left_side ? left_schema : right_schema); + if (!result.ok()) { + return Status::Invalid("No match or multiple matches for key field reference ", + field_ref.ToString(), left_side ? " on left " : " on right ", + "side of the join"); + } + const FieldPath& match = result.ValueUnsafe(); + const std::shared_ptr<DataType>& type = + (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type(); + if (type->id() == Type::DICTIONARY) { + return Status::Invalid( + "Dictionary type support for join key is not yet implemented, key field " + "reference: ", + field_ref.ToString(), left_side ? " on left " : " on right ", + "side of the join"); + } + if ((type->id() != Type::BOOL && !is_fixed_width(type->id()) && + !is_binary_like(type->id())) || + is_large_binary_like(type->id())) { + return Status::Invalid("Data type ", type->ToString(), + " is not supported in join key field"); + } + } + for (size_t i = 0; i < left_keys.size(); ++i) { + const FieldRef& left_ref = left_keys[i]; + const FieldRef& right_ref = right_keys[i]; + int left_id = left_ref.FindOne(left_schema).ValueUnsafe()[0]; + int right_id = right_ref.FindOne(right_schema).ValueUnsafe()[0]; + const std::shared_ptr<DataType>& left_type = left_schema.fields()[left_id]->type(); + const std::shared_ptr<DataType>& right_type = right_schema.fields()[right_id]->type(); + if (!left_type->Equals(right_type)) { + return Status::Invalid("Mismatched data types for corresponding join field keys: ", + left_ref.ToString(), " of type ", left_type->ToString(), + " and ", right_ref.ToString(), " of type ", + right_type->ToString()); + } + } + + // Check for output fields: + // 1. Output field refs must match exactly one input field + // 2. At least one output field + // 3. Dictionary type is not supported in an output field + // 4. Left semi/anti join (right semi/anti join) must not output fields from right + // (left) + // 5. No name collisions in output fields after adding (potentially empty) + // prefixes to left and right output + // + if (left_output.empty() && right_output.empty()) { + return Status::Invalid("Join must output at least one field"); + } + if (join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI) { + if (!right_output.empty()) { + return Status::Invalid( + join_type == JoinType::LEFT_SEMI ? "Left semi join " : "Left anti-semi join ", + "may not output fields from right side"); + } + } + if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) { + if (!left_output.empty()) { + return Status::Invalid(join_type == JoinType::RIGHT_SEMI ? "Right semi join " + : "Right anti-semi join ", + "may not output fields from left side"); + } + } + std::set<std::string> output_field_names; + for (size_t i = 0; i < left_output.size() + right_output.size(); ++i) { + bool left_side = i < left_output.size(); + const FieldRef& field_ref = + left_side ? left_output[i] : right_output[i - left_output.size()]; + if (!field_ref.IsName()) { + return Status::Invalid("Join output field references must be by name, reference ", + field_ref.ToString(), left_side ? " on left " : " on right ", + "side of the join"); + } + Result<FieldPath> result = field_ref.FindOne(left_side ? left_schema : right_schema); + if (!result.ok()) { + return Status::Invalid("No match or multiple matches for output field reference ", + field_ref.ToString(), left_side ? " on left " : " on right ", + "side of the join"); + } + const FieldPath& match = result.ValueUnsafe(); + const std::shared_ptr<DataType>& type = + (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type(); + if (type->id() == Type::DICTIONARY) { + return Status::Invalid( + "Dictionary type support for join output field is not yet implemented, output " + "field reference: ", + field_ref.ToString(), left_side ? " on left " : " on right ", + "side of the join"); + } + const Field& output_field = + *((left_side ? left_schema.fields() : right_schema.fields())[match[0]]); + std::string output_field_name = + (left_side ? left_field_name_prefix : right_field_name_prefix) + + output_field.name(); + if (output_field_names.find(output_field_name) != output_field_names.end()) { + return Status::Invalid("Output field name collision in join, name: ", + output_field_name); + } + output_field_names.insert(output_field_name); + } + return Status::OK(); +} + +std::shared_ptr<Schema> HashJoinFieldMap::MakeOutputSchema( + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { + std::vector<std::shared_ptr<Field>> fields; + int left_size = num_cols(HashJoinSchemaHandle::FIRST_OUTPUT); + int right_size = num_cols(HashJoinSchemaHandle::SECOND_OUTPUT); + fields.resize(left_size + right_size); + + for (int i = 0; i < left_size + right_size; ++i) { + bool is_left = i < left_size; + HashJoinSchemaHandle schema_handle = is_left ? HashJoinSchemaHandle::FIRST_OUTPUT + : HashJoinSchemaHandle::SECOND_OUTPUT; + int field_id = is_left ? i : i - left_size; + const FieldRef& out_field_ref = field_ref(schema_handle, field_id); + const std::shared_ptr<DataType>& out_data_type = data_type(schema_handle, field_id); + + // TODO: do we need to support field refs that are not by name? + ARROW_DCHECK(out_field_ref.IsName()); + std::string output_field_name = + (is_left ? left_field_name_prefix : right_field_name_prefix) + + *out_field_ref.name(); + + // all fields coming out of join are marked as nullable, TODO: do we need to change + // that? + fields[i] = + std::make_shared<Field>(output_field_name, out_data_type, true /*nullable*/); + } + return std::make_shared<Schema>(std::move(fields)); +} + +class HashJoinNode : public ExecNode { + public: + HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options, + std::shared_ptr<Schema> output_schema, + std::unique_ptr<HashJoinFieldMap> field_map, + std::unique_ptr<HashJoinImpl> impl) + : ExecNode(plan, inputs, {"left", "right"}, + /*output_schema=*/std::move(output_schema), + /*num_outputs=*/1), + join_type_(join_options.join_type), + key_cmp_(join_options.key_cmp), + field_map_(std::move(field_map)), + impl_(std::move(impl)) { + complete_.store(false); + } + + static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs, + const ExecNodeOptions& options) { + // Number of input exec nodes must be 2 + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 2, "HashJoinNode")); + + std::unique_ptr<HashJoinFieldMap> field_map = + ::arrow::internal::make_unique<HashJoinFieldMap>(); + + const auto& join_options = checked_cast<const HashJoinNodeOptions&>(options); + + // This will also validate input schemas + if (join_options.output_all) { + RETURN_NOT_OK(field_map->Init( + join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys, + *(inputs[1]->output_schema()), join_options.right_keys, + join_options.output_prefix_for_left, join_options.output_prefix_for_right)); + } else { + RETURN_NOT_OK(field_map->Init( + join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys, + join_options.left_output, *(inputs[1]->output_schema()), + join_options.right_keys, join_options.right_output, + join_options.output_prefix_for_left, join_options.output_prefix_for_right)); + } + + // Generate output schema + std::shared_ptr<Schema> output_schema = field_map->MakeOutputSchema( + join_options.output_prefix_for_left, join_options.output_prefix_for_right); + + // Create hash join implementation object + ARROW_ASSIGN_OR_RAISE(std::unique_ptr<HashJoinImpl> impl, HashJoinImpl::MakeBasic()); + + return plan->EmplaceNode<HashJoinNode>(plan, inputs, join_options, + std::move(output_schema), std::move(field_map), + std::move(impl)); + } + + const char* kind_name() override { return "HashJoinNode"; } + + void InputReceived(ExecNode* input, ExecBatch batch) override { + ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); + + if (finished_.is_finished()) { + return; + } + + size_t thread_index = thread_indexer_(); + int side = (input == inputs_[0]) ? 0 : 1; + ErrorIfNotOk(impl_->InputReceived(thread_index, side, std::move(batch))); Review comment: I added returns after ErrorIfNotOk calls. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
