michalursa commented on a change in pull request #11150:
URL: https://github.com/apache/arrow/pull/11150#discussion_r717144846



##########
File path: cpp/src/arrow/compute/exec/schema_util.h
##########
@@ -0,0 +1,197 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/exec/key_encode.h"  // for KeyColumnMetadata
+#include "arrow/type.h"                     // for DataType, FieldRef, Field and Schema
+#include "arrow/util/mutex.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+// Helper class for managing related row schemas.
+// Used to efficiently map any column in one schema to a corresponding column in another
+// schema if such exists.
+// Materialized mappings are generated lazily at the time of the first access.
+// Thread-safe apart from initialization.
+//
+template <typename SchemaHandleType>
+class FieldMap {
+ public:
+  static constexpr int kMissingField = -1;
+
+  // Registers every column of `schema` under `handle`, in schema order.
+  // Each column is recorded as a name-based FieldRef together with its data type
+  // and the key encoder metadata derived from that type.
+  void RegisterSchema(SchemaHandleType handle, const Schema& schema) {
+    const FieldVector& in_fields = schema.fields();
+    std::vector<FieldInfo> out_fields(in_fields.size());
+    for (size_t i = 0; i < in_fields.size(); ++i) {
+      const std::string& name = in_fields[i]->name();
+      const std::shared_ptr<DataType>& type = in_fields[i]->type();
+      out_fields[i].field_ref = FieldRef(name);
+      out_fields[i].data_type = type;
+      out_fields[i].column_metadata = ColumnMetadataFromDataType(type);
+    }
+    // out_fields is dead after this point: move it instead of copying every
+    // FieldRef / shared_ptr it holds.
+    schemas_.push_back(std::make_pair(handle, std::move(out_fields)));
+  }
+
+  // Registers under `handle` the columns of `full_schema` selected by
+  // `selected_fields`, in the order they are listed. Every reference must resolve
+  // to exactly one field of `full_schema`; otherwise the FindOne error is returned
+  // and nothing is registered. Selected columns are stored as name-based FieldRefs.
+  Status RegisterProjectedSchema(SchemaHandleType handle,
+                                 const std::vector<FieldRef>& selected_fields,
+                                 const Schema& full_schema) {
+    const FieldVector& in_fields = full_schema.fields();
+    std::vector<FieldInfo> out_fields(selected_fields.size());
+    for (size_t i = 0; i < selected_fields.size(); ++i) {
+      // All fields must be found in schema without ambiguity
+      ARROW_ASSIGN_OR_RAISE(auto match, selected_fields[i].FindOne(full_schema));
+      // Only the top-level component of the matched path is used.
+      const std::string& name = in_fields[match[0]]->name();
+      const std::shared_ptr<DataType>& type = in_fields[match[0]]->type();
+      out_fields[i].field_ref = FieldRef(name);
+      out_fields[i].data_type = type;
+      out_fields[i].column_metadata = ColumnMetadataFromDataType(type);
+    }
+    // out_fields is dead after this point: move it instead of copying it.
+    schemas_.push_back(std::make_pair(handle, std::move(out_fields)));
+    return Status::OK();
+  }
+
+  // Closes registration by allocating one (initially empty) mapping slot for each
+  // ordered (from, to) pair of registered schemas. map() indexes these slots, so
+  // this must run after the last schema is registered and before any map() call.
+  void RegisterEnd() {
+    const size_t num_pairs = schemas_.size() * schemas_.size();
+    mapping_bufs_.resize(num_pairs);
+    mapping_ptrs_.resize(num_pairs);
+  }
+
+  // Number of columns registered under `schema_handle`.
+  int num_cols(SchemaHandleType schema_handle) const {
+    const auto& columns = schemas_[schema_id(schema_handle)].second;
+    return static_cast<int>(columns.size());
+  }
+
+  // Key encoder metadata of column `field_id` of the schema registered under
+  // `schema_handle`.
+  const KeyEncoder::KeyColumnMetadata& column_metadata(SchemaHandleType schema_handle,
+                                                       int field_id) const {
+    const FieldInfo& info = field(schema_handle, field_id);
+    return info.column_metadata;
+  }
+
+  // Name-based field reference of column `field_id` of the schema registered
+  // under `schema_handle`.
+  const FieldRef& field_ref(SchemaHandleType schema_handle, int field_id) const {
+    const FieldInfo& info = field(schema_handle, field_id);
+    return info.field_ref;
+  }
+
+  // Data type of column `field_id` of the schema registered under `schema_handle`.
+  const std::shared_ptr<DataType>& data_type(SchemaHandleType schema_handle,
+                                             int field_id) const {
+    const FieldInfo& info = field(schema_handle, field_id);
+    return info.data_type;
+  }
+
+  // Returns the column mapping from schema `from` to schema `to`: entry i is the
+  // index in `to` of column i of `from`, or kMissingField when `to` has no column
+  // with an equal field reference and data type. Each mapping is generated on the
+  // first request and cached; RegisterEnd() must have been called beforehand.
+  const int* map(SchemaHandleType from, SchemaHandleType to) {
+    int id_from = schema_id(from);
+    int id_to = schema_id(to);
+    int num_schemas = static_cast<int>(schemas_.size());
+    int pos = id_from * num_schemas + id_to;
+    const int* ptr = mapping_ptrs_[pos];
+    if (!ptr) {
+      auto guard = mutex_.Lock();  // acquire the lock
+      // Re-read the shared slot now that the lock is held: another thread may have
+      // generated this mapping while we were waiting. Testing the local `ptr` here
+      // would always succeed and regenerate the mapping, rewriting a buffer that
+      // other threads may already be reading.
+      if (!mapping_ptrs_[pos]) {
+        GenerateMap(id_from, id_to);
+      }
+      ptr = mapping_ptrs_[pos];
+    }
+    // NOTE(review): the unlocked read of mapping_ptrs_[pos] above is still not
+    // formally synchronized with the store in GenerateMap; a fully race-free version
+    // would hold std::atomic<const int*> in mapping_ptrs_ (acquire load / release
+    // store).
+    return ptr;
+  }
+
+ protected:
+  // Per-column information recorded when a schema is registered.
+  struct FieldInfo {
+    // Reference to the column; both registration functions store a name-based ref.
+    FieldRef field_ref;
+    // Column data type, also compared when building mappings between schemas.
+    std::shared_ptr<DataType> data_type;
+    // Key encoder metadata computed from data_type by ColumnMetadataFromDataType.
+    KeyEncoder::KeyColumnMetadata column_metadata;
+  };
+
+  // Derives key encoder column metadata from an Arrow data type.
+  // Checks are made in this order (dictionary is tested before the generic
+  // fixed-width case):
+  //  - dictionary: fixed length bit_width() / 8 (presumably the index width),
+  //  - boolean: fixed length 0,
+  //  - any other fixed-width type: fixed length bit_width() / 8,
+  //  - binary-like: not fixed length, length sizeof(uint32_t).
+  // Any other type fails a debug assertion and falls back to fixed length 0.
+  KeyEncoder::KeyColumnMetadata ColumnMetadataFromDataType(
+      const std::shared_ptr<DataType>& type) {
+    const Type::type id = type->id();
+    if (id == Type::DICTIONARY) {
+      const auto dict_bits = checked_cast<const FixedWidthType&>(*type).bit_width();
+      ARROW_DCHECK(dict_bits % 8 == 0);
+      return KeyEncoder::KeyColumnMetadata(true, dict_bits / 8);
+    }
+    if (id == Type::BOOL) {
+      return KeyEncoder::KeyColumnMetadata(true, 0);
+    }
+    if (is_fixed_width(id)) {
+      const auto width_bits = checked_cast<const FixedWidthType&>(*type).bit_width();
+      return KeyEncoder::KeyColumnMetadata(true, width_bits / 8);
+    }
+    if (is_binary_like(id)) {
+      return KeyEncoder::KeyColumnMetadata(false, sizeof(uint32_t));
+    }
+    ARROW_DCHECK(false);
+    return KeyEncoder::KeyColumnMetadata(true, 0);
+  }
+
+  // Position of `schema_handle` in registration order, found by linear search.
+  // An unregistered handle fails a debug assertion and yields -1.
+  int schema_id(SchemaHandleType schema_handle) const {
+    int index = 0;
+    for (const auto& entry : schemas_) {
+      if (entry.first == schema_handle) {
+        return index;
+      }
+      ++index;
+    }
+    // We should never get here
+    ARROW_DCHECK(false);
+    return -1;
+  }
+
+  // FieldInfo of column `field_id` of the schema registered under `schema_handle`
+  // (field_id is not bounds-checked).
+  const FieldInfo& field(SchemaHandleType schema_handle, int field_id) const {
+    return schemas_[schema_id(schema_handle)].second[field_id];
+  }
+
+  // Builds mapping slot (id_from, id_to). For every column of schema id_from it
+  // stores the index of the first column of schema id_to with an equal field
+  // reference and an equal data type, or kMissingField if there is none. The slot
+  // is published by storing the buffer's data pointer into mapping_ptrs_.
+  void GenerateMap(int id_from, int id_to) {
+    const int slot = id_from * static_cast<int>(schemas_.size()) + id_to;
+    const std::vector<FieldInfo>& src = schemas_[id_from].second;
+    const std::vector<FieldInfo>& dst = schemas_[id_to].second;
+    const int src_count = static_cast<int>(src.size());
+    const int dst_count = static_cast<int>(dst.size());
+
+    std::vector<int>& out = mapping_bufs_[slot];
+    out.resize(src_count);
+    for (int s = 0; s < src_count; ++s) {
+      const FieldInfo& src_field = src[s];
+      int match = kMissingField;
+      // Stop at the first matching destination column.
+      for (int d = 0; d < dst_count && match == kMissingField; ++d) {
+        const FieldInfo& dst_field = dst[d];
+        if (!src_field.field_ref.Equals(dst_field.field_ref)) {
+          continue;
+        }
+        if (!src_field.data_type->Equals(*dst_field.data_type)) {
+          continue;
+        }
+        ARROW_DCHECK(src_field.column_metadata.is_fixed_length ==
+                         dst_field.column_metadata.is_fixed_length &&
+                     src_field.column_metadata.fixed_length ==
+                         dst_field.column_metadata.fixed_length);
+        match = d;
+      }
+      out[s] = match;
+    }
+    mapping_ptrs_[slot] = out.data();
+  }
+
+  std::vector<int*> mapping_ptrs_;
+  std::vector<std::vector<int>> mapping_bufs_;
+  std::vector<std::pair<SchemaHandleType, std::vector<FieldInfo>>> schemas_;

Review comment:
       comment added

##########
File path: cpp/src/arrow/compute/exec/schema_util.h
##########
@@ -0,0 +1,197 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/exec/key_encode.h"  // for KeyColumnMetadata
+#include "arrow/type.h"                     // for DataType, FieldRef, Field and Schema
+#include "arrow/util/mutex.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+// Helper class for managing related row schemas.
+// Used to efficiently map any column in one schema to a corresponding column in another
+// schema if such exists.
+// Materialized mappings are generated lazily at the time of the first access.
+// Thread-safe apart from initialization.
+//
+template <typename SchemaHandleType>

Review comment:
       Changed to ProjectionIdEnum

##########
File path: cpp/src/arrow/compute/exec/schema_util.h
##########
@@ -0,0 +1,197 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/compute/exec/key_encode.h"  // for KeyColumnMetadata
+#include "arrow/type.h"                     // for DataType, FieldRef, Field and Schema
+#include "arrow/util/mutex.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+// Helper class for managing related row schemas.
+// Used to efficiently map any column in one schema to a corresponding column in another
+// schema if such exists.
+// Materialized mappings are generated lazily at the time of the first access.
+// Thread-safe apart from initialization.
+//

Review comment:
       Changed the comment after narrowing the functionality of this helper class and adjusting it accordingly.

##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,448 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <set>
+
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/hash_join.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/schema_util.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+// Returns the elements of `a` that are not equal (FieldRef::Equals) to any element
+// of `b`, keeping their order in `a` and any duplicates within `a`.
+std::vector<FieldRef> HashJoinFieldMap::VectorDiff(const std::vector<FieldRef>& a,
+                                                   const std::vector<FieldRef>& b) {
+  auto occurs_in_b = [&b](const FieldRef& ref) {
+    for (const FieldRef& candidate : b) {
+      if (ref.Equals(candidate)) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  std::vector<FieldRef> result;
+  for (const FieldRef& ref : a) {
+    if (!occurs_in_b(ref)) {
+      result.push_back(ref);
+    }
+  }
+  return result;
+}
+
+// Convenience overload that outputs every column of each side the join type
+// allows: no left columns for right semi/anti joins and no right columns for left
+// semi/anti joins. Delegates to the overload that takes explicit output lists.
+Status HashJoinFieldMap::Init(JoinType join_type, const Schema& left_schema,
+                              const std::vector<FieldRef>& left_keys,
+                              const Schema& right_schema,
+                              const std::vector<FieldRef>& right_keys,
+                              const std::string& left_field_name_prefix,
+                              const std::string& right_field_name_prefix) {
+  // One name-based reference per top-level field of `schema`, in schema order.
+  auto refs_by_name = [](const Schema& schema) {
+    std::vector<FieldRef> refs;
+    for (const std::shared_ptr<Field>& field : schema.fields()) {
+      refs.push_back(FieldRef(field->name()));
+    }
+    return refs;
+  };
+
+  const bool left_has_output =
+      join_type != JoinType::RIGHT_SEMI && join_type != JoinType::RIGHT_ANTI;
+  const bool right_has_output =
+      join_type != JoinType::LEFT_SEMI && join_type != JoinType::LEFT_ANTI;
+
+  const std::vector<FieldRef> left_output =
+      left_has_output ? refs_by_name(left_schema) : std::vector<FieldRef>();
+  const std::vector<FieldRef> right_output =
+      right_has_output ? refs_by_name(right_schema) : std::vector<FieldRef>();
+
+  return Init(join_type, left_schema, left_keys, left_output, right_schema, right_keys,
+              right_output, left_field_name_prefix, right_field_name_prefix);
+}
+
+// Validates the join specification, then registers, in this order: both full input
+// schemas, the key columns of each side, the payload columns of each side (output
+// columns that are not keys) and the output columns of each side. Finishes with
+// RegisterEnd(). Returns the first validation or registration error.
+Status HashJoinFieldMap::Init(JoinType join_type, const Schema& left_schema,
+                              const std::vector<FieldRef>& left_keys,
+                              const std::vector<FieldRef>& left_output,
+                              const Schema& right_schema,
+                              const std::vector<FieldRef>& right_keys,
+                              const std::vector<FieldRef>& right_output,
+                              const std::string& left_field_name_prefix,
+                              const std::string& right_field_name_prefix) {
+  RETURN_NOT_OK(ValidateSchemas(join_type, left_schema, left_keys, left_output,
+                                right_schema, right_keys, right_output,
+                                left_field_name_prefix, right_field_name_prefix));
+
+  // Payload columns: requested output columns minus the join keys.
+  const std::vector<FieldRef> left_payload = VectorDiff(left_output, left_keys);
+  const std::vector<FieldRef> right_payload = VectorDiff(right_output, right_keys);
+
+  // Full inputs.
+  RegisterSchema(HashJoinSchemaHandle::FIRST_INPUT, left_schema);
+  RegisterSchema(HashJoinSchemaHandle::SECOND_INPUT, right_schema);
+  // Keys.
+  RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::FIRST_KEY, left_keys,
+                                        left_schema));
+  RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::SECOND_KEY, right_keys,
+                                        right_schema));
+  // Payloads.
+  RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::FIRST_PAYLOAD,
+                                        left_payload, left_schema));
+  RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::SECOND_PAYLOAD,
+                                        right_payload, right_schema));
+  // Outputs.
+  RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::FIRST_OUTPUT,
+                                        left_output, left_schema));
+  RETURN_NOT_OK(RegisterProjectedSchema(HashJoinSchemaHandle::SECOND_OUTPUT,
+                                        right_output, right_schema));
+  RegisterEnd();
+  return Status::OK();
+}
+
+// Checks that a hash join specification is usable before any schema is
+// registered. Returns Status::Invalid describing the first violated rule (rules
+// are checked in the order the statements below appear), or Status::OK().
+Status HashJoinFieldMap::ValidateSchemas(JoinType join_type, const Schema& left_schema,
+                                         const std::vector<FieldRef>& left_keys,
+                                         const std::vector<FieldRef>& left_output,
+                                         const Schema& right_schema,
+                                         const std::vector<FieldRef>& right_keys,
+                                         const std::vector<FieldRef>& right_output,
+                                         const std::string& left_field_name_prefix,
+                                         const std::string& right_field_name_prefix) {
+  // Checks for key fields:
+  // 1. Key field refs must match exactly one input field
+  // 2. Same number of key fields on left and right
+  // 3. At least one key field
+  // 4. Equal data types for corresponding key fields
+  // 5. Dictionary type is not supported in a key field
+  // 6. Some other data types may not be allowed in a key field
+  //
+  if (left_keys.size() != right_keys.size()) {
+    return Status::Invalid("Different number of key fields on left (", left_keys.size(),
+                           ") and right (", right_keys.size(), ") side of the join");
+  }
+  if (left_keys.size() < 1) {
+    return Status::Invalid("Join key cannot be empty");
+  }
+  // Single pass over all keys: indices [0, left_keys.size()) are left keys, the
+  // rest are right keys. Checks rules 1, 5 and 6 for each key.
+  for (size_t i = 0; i < left_keys.size() + right_keys.size(); ++i) {
+    bool left_side = i < left_keys.size();
+    const FieldRef& field_ref =
+        left_side ? left_keys[i] : right_keys[i - left_keys.size()];
+    Result<FieldPath> result = field_ref.FindOne(left_side ? left_schema : right_schema);
+    if (!result.ok()) {
+      return Status::Invalid("No match or multiple matches for key field reference ",
+                             field_ref.ToString(), left_side ? " on left " : " on right ",
+                             "side of the join");
+    }
+    const FieldPath& match = result.ValueUnsafe();
+    // Type of the top-level field the reference resolved to.
+    const std::shared_ptr<DataType>& type =
+        (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type();
+    if (type->id() == Type::DICTIONARY) {
+      return Status::Invalid(
+          "Dictionary type support for join key is not yet implemented, key field "
+          "reference: ",
+          field_ref.ToString(), left_side ? " on left " : " on right ",
+          "side of the join");
+    }
+    // Allowed key types: boolean, fixed-width and (non-large) binary-like.
+    if ((type->id() != Type::BOOL && !is_fixed_width(type->id()) &&
+         !is_binary_like(type->id())) ||
+        is_large_binary_like(type->id())) {
+      return Status::Invalid("Data type ", type->ToString(),
+                             " is not supported in join key field");
+    }
+  }
+  // Rule 4: pairwise type equality of corresponding left/right keys. ValueUnsafe()
+  // is safe here because the loop above already rejected any key that does not
+  // resolve to exactly one field.
+  for (size_t i = 0; i < left_keys.size(); ++i) {
+    const FieldRef& left_ref = left_keys[i];
+    const FieldRef& right_ref = right_keys[i];
+    int left_id = left_ref.FindOne(left_schema).ValueUnsafe()[0];
+    int right_id = right_ref.FindOne(right_schema).ValueUnsafe()[0];
+    const std::shared_ptr<DataType>& left_type = left_schema.fields()[left_id]->type();
+    const std::shared_ptr<DataType>& right_type = right_schema.fields()[right_id]->type();
+    if (!left_type->Equals(right_type)) {
+      return Status::Invalid("Mismatched data types for corresponding join field keys: ",
+                             left_ref.ToString(), " of type ", left_type->ToString(),
+                             " and ", right_ref.ToString(), " of type ",
+                             right_type->ToString());
+    }
+  }
+
+  // Check for output fields:
+  // 1. Output field refs must be by name and match exactly one input field
+  // 2. At least one output field
+  // 3. Dictionary type is not supported in an output field
+  // 4. Left semi/anti join (right semi/anti join) must not output fields from
+  //    right (left)
+  // 5. No name collisions in output fields after adding (potentially empty)
+  //    prefixes to left and right output
+  //
+  if (left_output.empty() && right_output.empty()) {
+    return Status::Invalid("Join must output at least one field");
+  }
+  if (join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI) {
+    if (!right_output.empty()) {
+      return Status::Invalid(
+          join_type == JoinType::LEFT_SEMI ? "Left semi join " : "Left anti-semi join ",
+          "may not output fields from right side");
+    }
+  }
+  if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) {
+    if (!left_output.empty()) {
+      return Status::Invalid(join_type == JoinType::RIGHT_SEMI ? "Right semi join "
+                                                               : "Right anti-semi join ",
+                             "may not output fields from left side");
+    }
+  }
+  // Prefixed names of all output fields seen so far, for collision detection.
+  std::set<std::string> output_field_names;
+  // Single pass over all outputs: left outputs first, then right outputs.
+  for (size_t i = 0; i < left_output.size() + right_output.size(); ++i) {
+    bool left_side = i < left_output.size();
+    const FieldRef& field_ref =
+        left_side ? left_output[i] : right_output[i - left_output.size()];
+    if (!field_ref.IsName()) {
+      return Status::Invalid("Join output field references must be by name, reference ",
+                             field_ref.ToString(), left_side ? " on left " : " on right ",
+                             "side of the join");
+    }
+    Result<FieldPath> result = field_ref.FindOne(left_side ? left_schema : right_schema);
+    if (!result.ok()) {
+      return Status::Invalid("No match or multiple matches for output field reference ",
+                             field_ref.ToString(), left_side ? " on left " : " on right ",
+                             "side of the join");
+    }
+    const FieldPath& match = result.ValueUnsafe();
+    const std::shared_ptr<DataType>& type =
+        (left_side ? left_schema.fields() : right_schema.fields())[match[0]]->type();
+    if (type->id() == Type::DICTIONARY) {
+      return Status::Invalid(
+          "Dictionary type support for join output field is not yet implemented, output "
+          "field reference: ",
+          field_ref.ToString(), left_side ? " on left " : " on right ",
+          "side of the join");
+    }
+    const Field& output_field =
+        *((left_side ? left_schema.fields() : right_schema.fields())[match[0]]);
+    // Output name = side prefix + input field name.
+    std::string output_field_name =
+        (left_side ? left_field_name_prefix : right_field_name_prefix) +
+        output_field.name();
+    if (output_field_names.find(output_field_name) != output_field_names.end()) {
+      return Status::Invalid("Output field name collision in join, name: ",
+                             output_field_name);
+    }
+    output_field_names.insert(output_field_name);
+  }
+  return Status::OK();
+}
+
+// Builds the join output schema: all FIRST_OUTPUT columns followed by all
+// SECOND_OUTPUT columns. Each column is named with its side's prefix followed by
+// the registered field name and is always marked nullable.
+std::shared_ptr<Schema> HashJoinFieldMap::MakeOutputSchema(
+    const std::string& left_field_name_prefix,
+    const std::string& right_field_name_prefix) {
+  std::vector<std::shared_ptr<Field>> fields;
+
+  // Appends the columns registered under `handle`, each named `prefix` + name.
+  auto append_side = [&](HashJoinSchemaHandle handle, const std::string& prefix) {
+    const int count = num_cols(handle);
+    for (int field_id = 0; field_id < count; ++field_id) {
+      const FieldRef& out_field_ref = field_ref(handle, field_id);
+      // TODO: do we need to support field refs that are not by name?
+      ARROW_DCHECK(out_field_ref.IsName());
+      std::string output_field_name = prefix + *out_field_ref.name();
+      // all fields coming out of join are marked as nullable, TODO: do we need to
+      // change that?
+      fields.push_back(std::make_shared<Field>(std::move(output_field_name),
+                                               data_type(handle, field_id),
+                                               true /*nullable*/));
+    }
+  };
+
+  append_side(HashJoinSchemaHandle::FIRST_OUTPUT, left_field_name_prefix);
+  append_side(HashJoinSchemaHandle::SECOND_OUTPUT, right_field_name_prefix);
+  return std::make_shared<Schema>(std::move(fields));
+}
+
+class HashJoinNode : public ExecNode {
+ public:
+  // Constructs a two-input ("left", "right"), single-output join node. Takes
+  // ownership of the prepared field map and join implementation; only join_type
+  // and key_cmp are copied out of `join_options`.
+  HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options,
+               std::shared_ptr<Schema> output_schema,
+               std::unique_ptr<HashJoinFieldMap> field_map,
+               std::unique_ptr<HashJoinImpl> impl)
+      // `inputs` is a by-value sink parameter; move it on instead of copying.
+      : ExecNode(plan, std::move(inputs), {"left", "right"},
+                 /*output_schema=*/std::move(output_schema),
+                 /*num_outputs=*/1),
+        join_type_(join_options.join_type),
+        key_cmp_(join_options.key_cmp),
+        field_map_(std::move(field_map)),
+        impl_(std::move(impl)) {
+    complete_.store(false);
+  }
+
+  // Factory. Requires exactly two inputs. Builds the field map from the join
+  // options (all input columns when output_all is set, otherwise the explicit
+  // left/right output lists); this also validates the input schemas. It then
+  // derives the output schema, creates the basic hash join implementation and
+  // adds the node to `plan`.
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    // Number of input exec nodes must be 2
+    RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 2, "HashJoinNode"));
+
+    std::unique_ptr<HashJoinFieldMap> field_map =
+        ::arrow::internal::make_unique<HashJoinFieldMap>();
+
+    const auto& join_options = checked_cast<const HashJoinNodeOptions&>(options);
+
+    // This will also validate input schemas
+    if (join_options.output_all) {
+      RETURN_NOT_OK(field_map->Init(
+          join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys,
+          *(inputs[1]->output_schema()), join_options.right_keys,
+          join_options.output_prefix_for_left, join_options.output_prefix_for_right));
+    } else {
+      RETURN_NOT_OK(field_map->Init(
+          join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys,
+          join_options.left_output, *(inputs[1]->output_schema()),
+          join_options.right_keys, join_options.right_output,
+          join_options.output_prefix_for_left, join_options.output_prefix_for_right));
+    }
+
+    // Generate output schema
+    std::shared_ptr<Schema> output_schema = field_map->MakeOutputSchema(
+        join_options.output_prefix_for_left, join_options.output_prefix_for_right);
+
+    // Create hash join implementation object
+    ARROW_ASSIGN_OR_RAISE(std::unique_ptr<HashJoinImpl> impl, HashJoinImpl::MakeBasic());
+
+    // `inputs` is not used past this point: hand it over instead of copying it.
+    return plan->EmplaceNode<HashJoinNode>(plan, std::move(inputs), join_options,
+                                           std::move(output_schema),
+                                           std::move(field_map), std::move(impl));
+  }
+
+  // Kind label reported for this exec node.
+  const char* kind_name() override {
+    return "HashJoinNode";
+  }
+
+  void InputReceived(ExecNode* input, ExecBatch batch) override {
+    ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != 
inputs_.end());
+
+    if (finished_.is_finished()) {
+      return;
+    }
+
+    size_t thread_index = thread_indexer_();
+    int side = (input == inputs_[0]) ? 0 : 1;
+    ErrorIfNotOk(impl_->InputReceived(thread_index, side, std::move(batch)));

Review comment:
       I added returns after ErrorIfNotOk calls.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to