rtpsw commented on code in PR #13880:
URL: https://github.com/apache/arrow/pull/13880#discussion_r954671378


##########
cpp/src/arrow/compute/exec/asof_join_node.cc:
##########
@@ -602,52 +828,155 @@ class AsofJoinNode : public ExecNode {
 
  public:
   AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string> 
input_labels,
-               const AsofJoinNodeOptions& join_options,
+               const vec_col_index_t& indices_of_on_key,
+               const std::vector<vec_col_index_t>& indices_of_by_key, OnType 
tolerance,
                std::shared_ptr<Schema> output_schema);
 
+  Status InternalInit(bool must_hash, bool nullable_by_key,
+                      std::vector<std::unique_ptr<KeyHasher>> key_hashers) {
+    key_hashers_.swap(key_hashers);
+    auto inputs = this->inputs();
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      state_.push_back(::arrow::internal::make_unique<InputState>(
+          must_hash, nullable_by_key, key_hashers_[i].get(), 
inputs[i]->output_schema(),
+          indices_of_on_key_[i], indices_of_by_key_[i]));
+    }
+
+    col_index_t dst_offset = 0;
+    for (auto& state : state_)
+      dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset);
+
+    return Status::OK();
+  }
+
   virtual ~AsofJoinNode() {
     process_.Push(false);  // poison pill
     process_thread_.join();
   }
 
+  const vec_col_index_t& indices_of_on_key() { return indices_of_on_key_; }
+  const std::vector<vec_col_index_t>& indices_of_by_key() { return 
indices_of_by_key_; }
+
   static arrow::Result<std::shared_ptr<Schema>> MakeOutputSchema(
-      const std::vector<ExecNode*>& inputs, const AsofJoinNodeOptions& 
options) {
+      const std::vector<ExecNode*>& inputs, const vec_col_index_t& 
indices_of_on_key,
+      const std::vector<vec_col_index_t>& indices_of_by_key) {
     std::vector<std::shared_ptr<arrow::Field>> fields;
 
-    const auto& on_field_name = *options.on_key.name();
-    const auto& by_field_name = *options.by_key.name();
-
+    size_t n_by = indices_of_by_key[0].size();
+    const DataType* on_key_type = NULLPTR;
+    std::vector<const DataType*> by_key_type(n_by, NULLPTR);
     // Take all non-key, non-time RHS fields
     for (size_t j = 0; j < inputs.size(); ++j) {
       const auto& input_schema = inputs[j]->output_schema();
-      const auto& on_field_ix = input_schema->GetFieldIndex(on_field_name);
-      const auto& by_field_ix = input_schema->GetFieldIndex(by_field_name);
+      const auto& on_field_ix = indices_of_on_key[j];
+      const auto& by_field_ix = indices_of_by_key[j];
 
-      if ((on_field_ix == -1) | (by_field_ix == -1)) {
+      if ((on_field_ix == -1) || std_has(by_field_ix, -1)) {
         return Status::Invalid("Missing join key on table ", j);
       }
 
+      const auto& on_field = input_schema->fields()[on_field_ix];
+      std::vector<const Field*> by_field(n_by);
+      for (size_t k = 0; k < n_by; k++) {
+        by_field[k] = input_schema->fields()[by_field_ix[k]].get();
+      }
+
+      if (on_key_type == NULLPTR) {
+        on_key_type = on_field->type().get();
+      } else if (*on_key_type != *on_field->type()) {
+        return Status::Invalid("Expected on-key type ", *on_key_type, " but 
got ",
+                               *on_field->type(), " for field ", 
on_field->name(),
+                               " in input ", j);
+      }
+      for (size_t k = 0; k < n_by; k++) {
+        if (by_key_type[k] == NULLPTR) {
+          by_key_type[k] = by_field[k]->type().get();
+        } else if (*by_key_type[k] != *by_field[k]->type()) {
+          return Status::Invalid("Expected on-key type ", *by_key_type[k], " 
but got ",
+                                 *by_field[k]->type(), " for field ", 
by_field[k]->name(),
+                                 " in input ", j);
+        }
+      }
+
       for (int i = 0; i < input_schema->num_fields(); ++i) {
         const auto field = input_schema->field(i);
-        if (field->name() == on_field_name) {
-          if (kSupportedOnTypes_.find(field->type()) == 
kSupportedOnTypes_.end()) {
-            return Status::Invalid("Unsupported type for on key: ", 
field->name());
+        if (i == on_field_ix) {

Review Comment:
   Will do.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to