This is an automated email from the ASF dual-hosted git repository.

raulcd pushed a commit to branch maint-15.0.x
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 0e9bd55b6584441fa078337728d703c9dc1c2049
Author: Jeremy Aguilon <[email protected]>
AuthorDate: Mon Feb 19 09:54:57 2024 -0500

    GH-39803: [C++][Acero] Fix AsOfJoin with differently ordered schemas than 
the output (#39804)
    
    ### Rationale for this change
    
    Issue is described visually in https://github.com/apache/arrow/issues/39803.
    
    The key hasher works by hashing every row of the input tables' key columns. 
An important step is inspecting the [column 
metadata](https://github.com/apache/arrow/blob/main/cpp/src/arrow/acero/asof_join_node.cc#L412)
 for the asof-join key fields. This metadata describes, among other things, 
whether each column is fixed width.
    
    The issue is that we pass the join node's `output_schema` to every key 
hasher, rather than the schema of the corresponding input.
    
    If an input looks like
    
    ```
    key_string_type,ts_int32_type,val
    ```
    
    But our expected output schema looks like:
    
    ```
    ts_int32_type,key_string_type,...
    ```
    Then the hasher will think that the `key_string_type` column is an int32, 
which completely throws off the hashes. Existing tests get away with it because 
they use int32 for every column.
    
    ### What changes are included in this PR?
    
    A one-line fix, plus a test that uses string-typed key columns.
    
    ### Are these changes tested?
    
    Yes. Can see the test run before and after changes here: 
https://gist.github.com/JerAguilon/953d82ed288d58f9ce24d1a925def2cc
    
    Before the change, notice that inputs 0 and 1 have mismatched hashes:
    
    ```
    AsofjoinNode(0x16cf9e2d8): key hasher 1 got hashes [0, 9784892099856512926, 
1050982531982388796, 10763536662319179482, 2029627098739957112, 
11814237723602982167, 3080328155728858293, 12792882290360550483, 
4058972722486426609, 13771526852823217039]
    ...
    AsofjoinNode(0x16cf9dd18): key hasher 0 got hashes [17528465654998409509, 
12047706865972860560, 18017664240540048750, 12358837084497432044, 
8151160321586084686, 8691136767698756332, 15973065724125580046, 
9654919479117127288, 618127929167745505, 3403805303373270709]
    
    ```
    
    And after, they do match:
    
    ```
    AsofjoinNode(0x16f2ea2d8): key hasher 1 got hashes [17528465654998409509, 
12047706865972860560, 18017664240540048750, 12358837084497432044, 
8151160321586084686, 8691136767698756332, 15973065724125580046, 
9654919479117127288, 618127929167745505, 3403805303373270709]
    ...
    AsofjoinNode(0x16f2e9d18): key hasher 0 got hashes [17528465654998409509, 
12047706865972860560, 18017664240540048750, 12358837084497432044, 
8151160321586084686, 8691136767698756332, 15973065724125580046, 
9654919479117127288, 618127929167745505, 3403805303373270709]
    ```
    
    ...which is exactly what you want, since the `key` column for both tables 
looks like `["0", "1", ..., "9"]`.
    
    ### Are there any user-facing changes?
    
    * Closes: #39803
    
    Lead-authored-by: Jeremy Aguilon <[email protected]>
    Co-authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 cpp/src/arrow/acero/asof_join_node.cc      |  2 +-
 cpp/src/arrow/acero/asof_join_node_test.cc | 64 ++++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/acero/asof_join_node.cc 
b/cpp/src/arrow/acero/asof_join_node.cc
index 2609905a0b..e96d5ad44a 100644
--- a/cpp/src/arrow/acero/asof_join_node.cc
+++ b/cpp/src/arrow/acero/asof_join_node.cc
@@ -1098,7 +1098,7 @@ class AsofJoinNode : public ExecNode {
     auto inputs = this->inputs();
     for (size_t i = 0; i < inputs.size(); i++) {
       
RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(),
-                                          output_schema()));
+                                          inputs[i]->output_schema()));
       ARROW_ASSIGN_OR_RAISE(
           auto input_state,
           InputState::Make(i, tolerance_, must_hash_, may_rehash_, 
key_hashers_[i].get(),
diff --git a/cpp/src/arrow/acero/asof_join_node_test.cc 
b/cpp/src/arrow/acero/asof_join_node_test.cc
index e400cc0316..d95d2aaad3 100644
--- a/cpp/src/arrow/acero/asof_join_node_test.cc
+++ b/cpp/src/arrow/acero/asof_join_node_test.cc
@@ -1582,6 +1582,70 @@ TEST(AsofJoinTest, BatchSequencing) {
   return TestSequencing(MakeIntegerBatches, /*num_batches=*/32, 
/*batch_size=*/1);
 }
 
+template <typename BatchesMaker>
+void TestSchemaResolution(BatchesMaker maker, int num_batches, int batch_size) 
{
+  // GH-39803: The key hasher needs to resolve the types of key columns. All 
other
+  // tests use int32 for all columns, but this test converts the key columns to
+  // strings via a projection node to test that the column is correctly 
resolved
+  // to string.
+  auto l_schema =
+      schema({field("time", int32()), field("key", int32()), field("l_value", 
int32())});
+  auto r_schema =
+      schema({field("time", int32()), field("key", int32()), field("r0_value", 
int32())});
+
+  auto make_shift = [&maker, num_batches, batch_size](
+                        const std::shared_ptr<Schema>& schema, int shift) {
+    return maker({[](int row) -> int64_t { return row; },
+                  [num_batches](int row) -> int64_t { return row / 
num_batches; },
+                  [shift](int row) -> int64_t { return row * 10 + shift; }},
+                 schema, num_batches, batch_size);
+  };
+  ASSERT_OK_AND_ASSIGN(auto l_batches, make_shift(l_schema, 0));
+  ASSERT_OK_AND_ASSIGN(auto r_batches, make_shift(r_schema, 1));
+
+  Declaration l_src = {"source",
+                       SourceNodeOptions(l_schema, l_batches.gen(false, 
false))};
+  Declaration r_src = {"source",
+                       SourceNodeOptions(r_schema, r_batches.gen(false, 
false))};
+  Declaration l_project = {
+      "project",
+      {std::move(l_src)},
+      ProjectNodeOptions({compute::field_ref("time"),
+                          compute::call("cast", {compute::field_ref("key")},
+                                        compute::CastOptions::Safe(utf8())),
+                          compute::field_ref("l_value")},
+                         {"time", "key", "l_value"})};
+  Declaration r_project = {
+      "project",
+      {std::move(r_src)},
+      ProjectNodeOptions({compute::call("cast", {compute::field_ref("key")},
+                                        compute::CastOptions::Safe(utf8())),
+                          compute::field_ref("r0_value"), 
compute::field_ref("time")},
+                         {"key", "r0_value", "time"})};
+
+  Declaration asofjoin = {
+      "asofjoin", {l_project, r_project}, GetRepeatedOptions(2, "time", 
{"key"}, 1000)};
+
+  QueryOptions query_options;
+  query_options.use_threads = false;
+  ASSERT_OK_AND_ASSIGN(auto table, DeclarationToTable(asofjoin, 
query_options));
+
+  Int32Builder expected_r0_b;
+  for (int i = 1; i <= 91; i += 10) {
+    ASSERT_OK(expected_r0_b.Append(i));
+  }
+  ASSERT_OK_AND_ASSIGN(auto expected_r0, expected_r0_b.Finish());
+
+  auto actual_r0 = table->GetColumnByName("r0_value");
+  std::vector<std::shared_ptr<arrow::Array>> chunks = {expected_r0};
+  auto expected_r0_chunked = std::make_shared<arrow::ChunkedArray>(chunks);
+  ASSERT_TRUE(actual_r0->Equals(expected_r0_chunked));
+}
+
+TEST(AsofJoinTest, OutputSchemaResolution) {
+  return TestSchemaResolution(MakeIntegerBatches, /*num_batches=*/1, 
/*batch_size=*/10);
+}
+
 namespace {
 
 Result<AsyncGenerator<std::optional<ExecBatch>>> MakeIntegerBatchGenForTest(

Reply via email to