bkietz commented on a change in pull request #11579:
URL: https://github.com/apache/arrow/pull/11579#discussion_r745745837



##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -43,32 +43,49 @@ bool HashJoinSchema::IsTypeSupported(const DataType& type) {
   return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id);
 }
 
-Result<std::vector<FieldRef>> HashJoinSchema::VectorDiff(const Schema& schema,
-                                                         const std::vector<FieldRef>& a,
-                                                         const std::vector<FieldRef>& b) {
-  std::unordered_set<int> b_paths;
-  for (size_t i = 0; i < b.size(); ++i) {
-    ARROW_ASSIGN_OR_RAISE(auto match, b[i].FindOne(schema));
-    b_paths.insert(match[0]);
+Result<std::vector<FieldRef>> HashJoinSchema::ComputePayload(
+    const Schema& schema, const std::vector<FieldRef>& output,
+    const std::vector<FieldRef>& filter, const std::vector<FieldRef>& keys) {
+  // payload = (output + filter) - keys, with no duplicates

Review comment:
       IR requires that fields be referenced by index, so it would be up to a producer to decide how to navigate duplicate field names.
   
   

##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -274,17 +303,83 @@ std::shared_ptr<Schema> HashJoinSchema::MakeOutputSchema(
   return std::make_shared<Schema>(std::move(fields));
 }
 
+Result<Expression> HashJoinSchema::BindFilter(Expression filter,
+                                              const Schema& left_schema,
+                                              const Schema& right_schema) {
+  if (filter.IsBound()) {
+    return std::move(filter);
+  }
+  if (filter != literal(true)) {
+    FieldVector fields;
+    auto left = proj_maps[0].map(HashJoinProjection::FILTER, HashJoinProjection::INPUT);
+    auto right = proj_maps[1].map(HashJoinProjection::FILTER, HashJoinProjection::INPUT);
+
+    auto AppendFieldsInMap = [&fields](const SchemaProjectionMap& map,
+                                       const Schema& schema) {
+      for (int i = 0; i < map.num_cols; i++) {
+        int input_idx = map.get(i);
+        fields.push_back(schema.fields()[input_idx]);
+      }
+    };
+    AppendFieldsInMap(left, left_schema);
+    AppendFieldsInMap(right, right_schema);
+    Schema filter_schema(fields);
+    ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema));
+    if (filter.type()->id() != Type::BOOL) {
+      return Status::TypeError("Filter expression must evaluate to bool, but ",
+                               filter.ToString(), " evaluates to ",
+                               filter.type()->ToString());
+    }
+    return std::move(filter);
+  }
+  return literal(true);
+}
+
+Result<std::vector<FieldRef>> HashJoinSchema::CollectFilterColumns(
+    const Expression& filter, const Schema& schema) {
+  std::vector<FieldRef> nonunique_refs;
+  RETURN_NOT_OK(TraverseExpression(nonunique_refs, filter, schema));
+
+  std::vector<FieldRef> result;
+  std::unordered_set<int> seen_paths;
+  for (auto ref : nonunique_refs) {
+    ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
+    if (seen_paths.find(match[0]) == seen_paths.end()) {
+      seen_paths.insert(match[0]);
+      result.push_back(ref);
+    }
+  }
+  return result;
+}
+
+Status HashJoinSchema::TraverseExpression(std::vector<FieldRef>& refs,
+                                          const Expression& filter,
+                                          const Schema& schema) {
+  if (filter == literal(true)) return Status::OK();
+  if (auto* call = filter.call()) {
+    for (const Expression& arg : call->arguments)
+      RETURN_NOT_OK(TraverseExpression(refs, arg, schema));
+  } else if (auto* param = filter.parameter()) {
+    if (!param->ref.IsName())
+      return Status::Invalid("Filter parameters to join must be by name");
+    ARROW_ASSIGN_OR_RAISE(auto match, param->ref.FindOneOrNone(schema));
+    if (match != FieldPath()) refs.push_back(param->ref);
+  }
+  return Status::OK();
+}
+

Review comment:
       It is: `FieldsInExpression`




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to