This is an automated email from the ASF dual-hosted git repository.

zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 30adc91e74 GH-45611: [C++][Acero] Improve Swiss join build performance 
by partitioning batches ahead to reduce contention (#45612)
30adc91e74 is described below

commit 30adc91e740dfbfe1d018169c0f1eb4df0aad4c4
Author: Rossi Sun <[email protected]>
AuthorDate: Tue Mar 4 19:47:46 2025 +0800

    GH-45611: [C++][Acero] Improve Swiss join build performance by partitioning 
batches ahead to reduce contention (#45612)
    
    ### Rationale for this change
    
    High contention is observed in the Swiss join build phase, as shown in #45611.
    
    Some background on the contention: to build the hash table in
parallel, we first build `N` partitioned hash tables (the "build" stage), then
merge them into the final hash table (the "merge" stage, less
interesting in this PR). In the build stage, each exec batch from
the build side table is distributed to one of the `M` threads. Each thread
processes each of its assigned batches as follows:
    1. Partition the batch into `N` partitions based on the hash of the
join key;
    2. Insert the rows of each of the `N` partitions into the corresponding one
of the `N` partitioned hash tables.
    
    Because each batch contains arbitrary data, all `M` threads will write to
all `N` partitioned hash tables simultaneously. We therefore guard these
partitioned hash tables with (spin) locks, which is the source of the contention.
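
The locking scheme described above can be sketched roughly as follows. This is a minimal illustration, not the Acero code: `LockedPartition`, `ToyHash`, and `BuildWithLocks` are hypothetical names, keys are plain `uint32_t` values, `std::unordered_set` stands in for the Swiss table, and `std::mutex` stands in for the spin lock.

```cpp
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <thread>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for one partitioned hash table and its lock.
struct LockedPartition {
  std::mutex lock;  // assumption: a mutex replaces the spin lock for brevity
  std::unordered_set<uint32_t> keys;
};

// Toy multiplicative hash, for illustration only.
inline uint32_t ToyHash(uint32_t key) { return key * 2654435761u; }

// Old scheme: M threads each take whole batches, split them into N
// partitions, and insert every partition into a table shared by all threads.
inline void BuildWithLocks(const std::vector<std::vector<uint32_t>>& batches,
                           std::vector<LockedPartition>& prtns,
                           size_t num_threads) {
  std::vector<std::thread> workers;
  for (size_t t = 0; t < num_threads; ++t) {
    workers.emplace_back([&, t] {
      // Thread t handles batches t, t + M, t + 2M, ...
      for (size_t b = t; b < batches.size(); b += num_threads) {
        // Step 1: partition the batch by key hash.
        std::vector<std::vector<uint32_t>> local(prtns.size());
        for (uint32_t key : batches[b]) {
          local[ToyHash(key) % prtns.size()].push_back(key);
        }
        // Step 2: insert each partition; any thread may hit any table,
        // so every insert has to take that table's lock.
        for (size_t p = 0; p < prtns.size(); ++p) {
          std::lock_guard<std::mutex> guard(prtns[p].lock);
          prtns[p].keys.insert(local[p].begin(), local[p].end());
        }
      }
    });
  }
  for (auto& w : workers) w.join();
}
```

With many threads and few partitions, most of the time in step 2 can end up being spent waiting on `lock` rather than inserting.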
    
    ### What changes are included in this PR?
    
    Instead of all `M` threads writing to all `N` partitioned hash tables 
simultaneously, we can further split the build stage into two:
    1. Partition stage: `M` threads, each only partitions the batches and 
preserves the partition info of each batch;
    2. (New) Build stage: `N` threads, each builds one of the `N` partitioned
hash tables. Every thread iterates over all the batches and inserts only the
rows of each batch that belong to its assigned hash table.
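
The two stages listed above can be sketched as follows. This is a simplified illustration under stated assumptions, not the actual implementation: the field names `hashes`, `prtn_ranges`, and `prtn_row_ids` mirror the diff, but `ToyHash`, `BuildPartition`, `BuildTwoStage`, the modulo partitioning, the simplified `PartitionBatch` signature, and the use of `std::unordered_set` as the table are all hypothetical. Batches are assumed to hold at most 65535 rows so row ids fit in `uint16_t`.

```cpp
#include <cstddef>
#include <cstdint>
#include <thread>
#include <unordered_set>
#include <vector>

// Partition info preserved per batch between the two stages.
struct BatchState {
  std::vector<uint32_t> hashes;
  std::vector<uint16_t> prtn_ranges;   // num_prtns + 1 offsets into prtn_row_ids
  std::vector<uint16_t> prtn_row_ids;  // row ids grouped by partition
};

// Toy multiplicative hash, for illustration only.
inline uint32_t ToyHash(uint32_t key) { return key * 2654435761u; }

// Stage 1: hash and partition one batch (counting sort); writes only `state`.
inline void PartitionBatch(const std::vector<uint32_t>& keys, size_t num_prtns,
                           BatchState& state) {
  const size_t num_rows = keys.size();
  state.hashes.resize(num_rows);
  for (size_t i = 0; i < num_rows; ++i) state.hashes[i] = ToyHash(keys[i]);
  state.prtn_ranges.assign(num_prtns + 1, 0);
  for (uint32_t h : state.hashes) ++state.prtn_ranges[h % num_prtns + 1];
  for (size_t p = 0; p < num_prtns; ++p) {
    state.prtn_ranges[p + 1] += state.prtn_ranges[p];
  }
  std::vector<uint16_t> cursor(state.prtn_ranges.begin(),
                               state.prtn_ranges.end() - 1);
  state.prtn_row_ids.resize(num_rows);
  for (size_t i = 0; i < num_rows; ++i) {
    state.prtn_row_ids[cursor[state.hashes[i] % num_prtns]++] =
        static_cast<uint16_t>(i);
  }
}

// Stage 2: one thread per partition scans every batch's preserved info and
// inserts only its own rows, so no lock is needed on `table`.
inline void BuildPartition(size_t prtn_id,
                           const std::vector<std::vector<uint32_t>>& batches,
                           const std::vector<BatchState>& states,
                           std::unordered_set<uint32_t>& table) {
  for (size_t b = 0; b < batches.size(); ++b) {
    const BatchState& s = states[b];
    for (uint16_t i = s.prtn_ranges[prtn_id]; i < s.prtn_ranges[prtn_id + 1]; ++i) {
      table.insert(batches[b][s.prtn_row_ids[i]]);
    }
  }
}

inline std::vector<std::unordered_set<uint32_t>> BuildTwoStage(
    const std::vector<std::vector<uint32_t>>& batches, size_t num_prtns,
    size_t num_threads) {
  std::vector<BatchState> states(batches.size());
  std::vector<std::thread> workers;
  for (size_t t = 0; t < num_threads; ++t) {  // partition stage: M threads
    workers.emplace_back([&, t] {
      for (size_t b = t; b < batches.size(); b += num_threads) {
        PartitionBatch(batches[b], num_prtns, states[b]);
      }
    });
  }
  for (auto& w : workers) w.join();
  workers.clear();
  std::vector<std::unordered_set<uint32_t>> tables(num_prtns);
  for (size_t p = 0; p < num_prtns; ++p) {  // build stage: N threads
    workers.emplace_back([&, p] { BuildPartition(p, batches, states, tables[p]); });
  }
  for (auto& w : workers) w.join();
  return tables;
}
```

In this sketch each `BatchState` is written by exactly one partition-stage thread and each table by exactly one build-stage thread, which is why neither stage needs a lock; the join between the stages is the only synchronization point.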
    
    #### Performance
    
    Take [this
benchmark](https://github.com/apache/arrow/blob/31994b5c2069a768e70fba16d1f521e4de64139e/cpp/src/arrow/acero/hash_join_benchmark.cc#L301),
 which is dedicated to measuring parallel build performance. The results show
that, by eliminating the contention, we can achieve up to a **10x** (on Arm) and
**5x** (on Intel) performance boost for Swiss join build. I picked `krows=64` and
`krows=512` and made a chart.
    
    ![Arm 
(1)](https://github.com/user-attachments/assets/21e8f198-9e47-46c9-a04b-7f24105968a1)
    
    
![Intel](https://github.com/user-attachments/assets/b61fb614-8422-4adc-b57c-2f83b7a7637b)
    
    Note the single thread performance is actually down a little bit (reasons 
detailed later). But IMO this is quite trivial compared to the total win of 
multi-threaded cases.
    
    Detailed benchmark numbers (on Arm) follow.
    
    <details>
    <summary>Benchmark Before (Click to expand)</summary>
    
    ```
    Run on (10 X 24.1216 MHz CPU s)
    CPU Caches:
      L1 Data 64 KiB
      L1 Instruction 128 KiB
      L2 Unified 4096 KiB (x10)
    Load Average: 3.47, 2.76, 2.54
    
-----------------------------------------------------------------------------------------------------------------------------------------
    Benchmark                                                                   
            Time             CPU   Iterations UserCounters...
    
-----------------------------------------------------------------------------------------------------------------------------------------
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable krows:1/process_time  
        53315 ns        53284 ns        12295 rows/sec=19.2179M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable krows:1/process_time  
        73001 ns        80862 ns         8606 rows/sec=12.6636M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable krows:1/process_time  
        88003 ns        95127 ns         7429 rows/sec=10.7645M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable krows:1/process_time  
        93248 ns       120317 ns         5135 rows/sec=8.51085M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable krows:1/process_time  
       109931 ns       140384 ns         4527 rows/sec=7.29427M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable krows:1/process_time  
       127997 ns       180633 ns         3546 rows/sec=5.66897M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable krows:1/process_time  
       125138 ns       185416 ns         3267 rows/sec=5.52271M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable krows:1/process_time  
       142611 ns       236355 ns         3613 rows/sec=4.33247M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable krows:1/process_time  
       169663 ns       336376 ns         2158 rows/sec=3.04421M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable krows:1/process_time 
       174708 ns       362630 ns         1943 rows/sec=2.82381M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable krows:1/process_time 
       186939 ns       409803 ns         1693 rows/sec=2.49876M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable krows:1/process_time 
       196817 ns       451213 ns         1542 rows/sec=2.26944M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable krows:1/process_time 
       209194 ns       501488 ns         1407 rows/sec=2.04192M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable krows:1/process_time 
       218517 ns       544590 ns         1299 rows/sec=1.88031M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable krows:1/process_time 
       224407 ns       579947 ns         1206 rows/sec=1.76568M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable krows:1/process_time 
       236201 ns       630016 ns         1134 rows/sec=1.62536M/s
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable krows:8/process_time  
       213061 ns       213082 ns         3276 rows/sec=38.4453M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable krows:8/process_time  
       260230 ns       374124 ns         1900 rows/sec=21.8965M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable krows:8/process_time  
       275723 ns       483754 ns         1331 rows/sec=16.9342M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable krows:8/process_time  
       326784 ns       711857 ns          974 rows/sec=11.5079M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable krows:8/process_time  
       351987 ns       861883 ns          798 rows/sec=9.50477M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable krows:8/process_time  
       370956 ns      1000389 ns          683 rows/sec=8.18881M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable krows:8/process_time  
       384963 ns      1064672 ns          646 rows/sec=7.69439M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable krows:8/process_time  
       406914 ns      1172464 ns          606 rows/sec=6.987M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable krows:8/process_time  
       425632 ns      1252871 ns          567 rows/sec=6.53858M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable krows:8/process_time 
       433262 ns      1287050 ns          524 rows/sec=6.36494M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable krows:8/process_time 
       443328 ns      1329822 ns          528 rows/sec=6.16022M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable krows:8/process_time 
       450736 ns      1383203 ns          508 rows/sec=5.92249M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable krows:8/process_time 
       465523 ns      1425956 ns          495 rows/sec=5.74492M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable krows:8/process_time 
       471723 ns      1462440 ns          475 rows/sec=5.6016M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable krows:8/process_time 
       484823 ns      1524638 ns          464 rows/sec=5.37308M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable krows:8/process_time 
       485260 ns      1541146 ns          453 rows/sec=5.31553M/s
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable krows:64/process_time 
      1716517 ns      1716522 ns          404 rows/sec=38.1795M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable krows:64/process_time 
      1762125 ns      2982570 ns          235 rows/sec=21.973M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable krows:64/process_time 
      1826549 ns      4331683 ns          161 rows/sec=15.1295M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable krows:64/process_time 
      2032670 ns      6228081 ns          111 rows/sec=10.5227M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable krows:64/process_time 
      2008129 ns      7401860 ns           93 rows/sec=8.85399M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable krows:64/process_time 
      2022595 ns      8733805 ns           77 rows/sec=7.50372M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable krows:64/process_time 
      2084620 ns     10333721 ns           68 rows/sec=6.34196M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable krows:64/process_time 
      2186912 ns     12275696 ns           56 rows/sec=5.33868M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable krows:64/process_time 
      3061302 ns     20949833 ns           24 rows/sec=3.12823M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable 
krows:64/process_time      4241129 ns     34483810 ns           21 
rows/sec=1.90049M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable 
krows:64/process_time      4123000 ns     33438545 ns           22 
rows/sec=1.95989M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable 
krows:64/process_time      5282983 ns     44385773 ns           22 
rows/sec=1.47651M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable 
krows:64/process_time      4214940 ns     33978250 ns           16 
rows/sec=1.92876M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable 
krows:64/process_time      9775500 ns     85277400 ns           10 
rows/sec=768.504k/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable 
krows:64/process_time      8448605 ns     40459190 ns           21 
rows/sec=1.61981M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable 
krows:64/process_time      8311054 ns     74384765 ns           17 
rows/sec=881.041k/s
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable 
krows:512/process_time     15124972 ns     15124152 ns           46 
rows/sec=34.6656M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable 
krows:512/process_time      9977718 ns     19336583 ns           36 
rows/sec=27.1138M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable 
krows:512/process_time      8751039 ns     23240667 ns           30 
rows/sec=22.5591M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable 
krows:512/process_time      9839327 ns     33597150 ns           20 
rows/sec=15.6051M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable 
krows:512/process_time     10058853 ns     41758118 ns           17 
rows/sec=12.5554M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable 
krows:512/process_time     10139465 ns     49509846 ns           13 
rows/sec=10.5896M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable 
krows:512/process_time     10311708 ns     58393545 ns           11 
rows/sec=8.97853M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable 
krows:512/process_time     10327653 ns     65427667 ns            9 
rows/sec=8.01325M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable 
krows:512/process_time     13476536 ns     99947571 ns            7 
rows/sec=5.24563M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable 
krows:512/process_time    17290050 ns    143569000 ns            5 
rows/sec=3.65182M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable 
krows:512/process_time    20576010 ns    176557250 ns            4 
rows/sec=2.96951M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable 
krows:512/process_time    24393117 ns    205985600 ns            5 
rows/sec=2.54527M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable 
krows:512/process_time    21039639 ns    168724000 ns            3 
rows/sec=3.10737M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable 
krows:512/process_time    38604708 ns    333330667 ns            3 
rows/sec=1.57288M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable 
krows:512/process_time    63189833 ns    502763000 ns            1 
rows/sec=1042.81k/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable 
krows:512/process_time    91289749 ns    731794000 ns            1 
rows/sec=716.442k/s
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable 
krows:4096/process_time   164686385 ns    164197000 ns            4 
rows/sec=25.5443M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable 
krows:4096/process_time   112767458 ns    217052333 ns            3 
rows/sec=19.3239M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable 
krows:4096/process_time   100643792 ns    245290000 ns            3 
rows/sec=17.0994M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable 
krows:4096/process_time    74837889 ns    268070667 ns            3 
rows/sec=15.6463M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable 
krows:4096/process_time    63174056 ns    269879667 ns            3 
rows/sec=15.5414M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable 
krows:4096/process_time    59140353 ns    294662000 ns            2 
rows/sec=14.2343M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable 
krows:4096/process_time    64158124 ns    354435000 ns            2 
rows/sec=11.8338M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable 
krows:4096/process_time    70799208 ns    465744500 ns            2 
rows/sec=9.00559M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable 
krows:4096/process_time   118786833 ns    730395500 ns            2 
rows/sec=5.74251M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable 
krows:4096/process_time  158779374 ns   1254764000 ns            1 
rows/sec=3.3427M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable 
krows:4096/process_time  124160834 ns    985925000 ns            1 
rows/sec=4.25418M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable 
krows:4096/process_time  261909918 ns   1956600000 ns            1 
rows/sec=2.14367M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable 
krows:4096/process_time  437582374 ns   3326539000 ns            1 
rows/sec=1.26086M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable 
krows:4096/process_time  225402042 ns   1756542000 ns            1 
rows/sec=2.38782M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable 
krows:4096/process_time  284178668 ns   2485382000 ns            1 
rows/sec=1.68759M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable 
krows:4096/process_time  198744084 ns   1697137000 ns            1 
rows/sec=2.4714M/s
    ```
    
    </details>
    
    <details>
    <summary>Benchmark After (Click to expand)</summary>
    
    ```
    Running ./arrow-acero-hash-join-benchmark
    Run on (10 X 24.1886 MHz CPU s)
    CPU Caches:
      L1 Data 64 KiB
      L1 Instruction 128 KiB
      L2 Unified 4096 KiB (x10)
    Load Average: 3.72, 3.38, 3.20
    
-----------------------------------------------------------------------------------------------------------------------------------------
    Benchmark                                                                   
            Time             CPU   Iterations UserCounters...
    
-----------------------------------------------------------------------------------------------------------------------------------------
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable krows:1/process_time  
        64162 ns        60216 ns        11306 rows/sec=17.0054M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable krows:1/process_time  
        73712 ns        85168 ns         8287 rows/sec=12.0233M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable krows:1/process_time  
        81532 ns       108468 ns         6563 rows/sec=9.44057M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable krows:1/process_time  
        90389 ns       125957 ns         5590 rows/sec=8.12979M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable krows:1/process_time  
        98131 ns       144575 ns         3912 rows/sec=7.08281M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable krows:1/process_time  
       112269 ns       171638 ns         3551 rows/sec=5.96605M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable krows:1/process_time  
       127481 ns       207426 ns         3053 rows/sec=4.93669M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable krows:1/process_time  
       135240 ns       221817 ns         3337 rows/sec=4.61641M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable krows:1/process_time  
       167247 ns       323541 ns         2152 rows/sec=3.16497M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable krows:1/process_time 
       173753 ns       363113 ns         1913 rows/sec=2.82006M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable krows:1/process_time 
       182739 ns       404210 ns         1717 rows/sec=2.53334M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable krows:1/process_time 
       194151 ns       451175 ns         1542 rows/sec=2.26963M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable krows:1/process_time 
       205538 ns       496195 ns         1423 rows/sec=2.06371M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable krows:1/process_time 
       217099 ns       540857 ns         1259 rows/sec=1.89329M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable krows:1/process_time 
       228487 ns       591203 ns         1274 rows/sec=1.73206M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable krows:1/process_time 
       240082 ns       642682 ns         1087 rows/sec=1.59332M/s
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable krows:8/process_time  
       218917 ns       218912 ns         3219 rows/sec=37.4214M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable krows:8/process_time  
       239310 ns       338138 ns         2066 rows/sec=24.2268M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable krows:8/process_time  
       284833 ns       411252 ns         1570 rows/sec=19.9197M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable krows:8/process_time  
       315525 ns       496170 ns         1437 rows/sec=16.5105M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable krows:8/process_time  
       329116 ns       557150 ns         1246 rows/sec=14.7034M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable krows:8/process_time  
       339415 ns       612913 ns         1123 rows/sec=13.3657M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable krows:8/process_time  
       354355 ns       673437 ns         1040 rows/sec=12.1645M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable krows:8/process_time  
       371602 ns       736217 ns          948 rows/sec=11.1271M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable krows:8/process_time  
       388963 ns       788646 ns          870 rows/sec=10.3874M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable krows:8/process_time 
       398060 ns       838691 ns          850 rows/sec=9.76761M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable krows:8/process_time 
       403233 ns       875477 ns          789 rows/sec=9.35719M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable krows:8/process_time 
       410908 ns       917480 ns          748 rows/sec=8.92881M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable krows:8/process_time 
       425442 ns       971118 ns          702 rows/sec=8.43564M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable krows:8/process_time 
       427492 ns      1002726 ns          718 rows/sec=8.16973M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable krows:8/process_time 
       442728 ns      1057910 ns          653 rows/sec=7.74357M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable krows:8/process_time 
       455481 ns      1115695 ns          642 rows/sec=7.34251M/s
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable krows:64/process_time 
      1731379 ns      1731375 ns          403 rows/sec=37.852M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable krows:64/process_time 
      1179658 ns      2152165 ns          328 rows/sec=30.4512M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable krows:64/process_time 
      1116942 ns      2232095 ns          316 rows/sec=29.3608M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable krows:64/process_time 
       814811 ns      2498054 ns          276 rows/sec=26.2348M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable krows:64/process_time 
       900296 ns      2959111 ns          235 rows/sec=22.1472M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable krows:64/process_time 
       917596 ns      3253949 ns          215 rows/sec=20.1405M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable krows:64/process_time 
       920826 ns      3526660 ns          197 rows/sec=18.583M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable krows:64/process_time 
       811062 ns      3789065 ns          184 rows/sec=17.2961M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable krows:64/process_time 
      1031480 ns      5637721 ns          122 rows/sec=11.6246M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable 
krows:64/process_time      1072212 ns      6040280 ns          118 
rows/sec=10.8498M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable 
krows:64/process_time      1088001 ns      6204862 ns          116 
rows/sec=10.562M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable 
krows:64/process_time      1119427 ns      6356310 ns          113 
rows/sec=10.3104M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable 
krows:64/process_time      1128651 ns      6542557 ns          115 
rows/sec=10.0169M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable 
krows:64/process_time      1152430 ns      6731112 ns          107 
rows/sec=9.73628M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable 
krows:64/process_time      1161581 ns      6772318 ns          107 
rows/sec=9.67704M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable 
krows:64/process_time      1171040 ns      6748462 ns          106 
rows/sec=9.71125M/s
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable 
krows:512/process_time     16584785 ns     16419156 ns           45 
rows/sec=31.9315M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable 
krows:512/process_time      9782162 ns     18750500 ns           36 
rows/sec=27.9613M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable 
krows:512/process_time      9204909 ns     18933861 ns           36 
rows/sec=27.6905M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable 
krows:512/process_time      5665851 ns     20187600 ns           35 
rows/sec=25.9708M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable 
krows:512/process_time      6824165 ns     24445690 ns           29 
rows/sec=21.4471M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable 
krows:512/process_time      6476403 ns     25448704 ns           27 
rows/sec=20.6018M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable 
krows:512/process_time      6380011 ns     26670808 ns           26 
rows/sec=19.6577M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable 
krows:512/process_time      4994868 ns     29002792 ns           24 
rows/sec=18.0772M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable 
krows:512/process_time      6097037 ns     37510263 ns           19 
rows/sec=13.9772M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable 
krows:512/process_time     6024000 ns     40356889 ns           18 
rows/sec=12.9913M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable 
krows:512/process_time     6167103 ns     41287529 ns           17 
rows/sec=12.6985M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable 
krows:512/process_time     6087725 ns     40475722 ns           18 
rows/sec=12.9531M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable 
krows:512/process_time     6163463 ns     41720647 ns           17 
rows/sec=12.5666M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable 
krows:512/process_time     6056402 ns     40388529 ns           17 
rows/sec=12.9811M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable 
krows:512/process_time     5972958 ns     40973824 ns           17 
rows/sec=12.7957M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable 
krows:512/process_time     6593174 ns     40719647 ns           17 
rows/sec=12.8756M/s
    BM_HashJoinBasic_BuildParallelism/Threads:1/HashTable 
krows:4096/process_time   174475083 ns    174058000 ns            3 
rows/sec=24.0972M/s
    BM_HashJoinBasic_BuildParallelism/Threads:2/HashTable 
krows:4096/process_time   109935347 ns    200222667 ns            3 
rows/sec=20.9482M/s
    BM_HashJoinBasic_BuildParallelism/Threads:3/HashTable 
krows:4096/process_time    89852042 ns    187011000 ns            3 
rows/sec=22.4281M/s
    BM_HashJoinBasic_BuildParallelism/Threads:4/HashTable 
krows:4096/process_time    57974139 ns    202076667 ns            3 
rows/sec=20.756M/s
    BM_HashJoinBasic_BuildParallelism/Threads:5/HashTable 
krows:4096/process_time    57160194 ns    210744667 ns            3 
rows/sec=19.9023M/s
    BM_HashJoinBasic_BuildParallelism/Threads:6/HashTable 
krows:4096/process_time    56770167 ns    221233000 ns            3 
rows/sec=18.9588M/s
    BM_HashJoinBasic_BuildParallelism/Threads:7/HashTable 
krows:4096/process_time    59031097 ns    241927000 ns            3 
rows/sec=17.3371M/s
    BM_HashJoinBasic_BuildParallelism/Threads:8/HashTable 
krows:4096/process_time    46069291 ns    263787667 ns            3 
rows/sec=15.9003M/s
    BM_HashJoinBasic_BuildParallelism/Threads:9/HashTable 
krows:4096/process_time    51498374 ns    310020500 ns            2 
rows/sec=13.5291M/s
    BM_HashJoinBasic_BuildParallelism/Threads:10/HashTable 
krows:4096/process_time   52055417 ns    319261500 ns            2 
rows/sec=13.1375M/s
    BM_HashJoinBasic_BuildParallelism/Threads:11/HashTable 
krows:4096/process_time   49418250 ns    331526500 ns            2 
rows/sec=12.6515M/s
    BM_HashJoinBasic_BuildParallelism/Threads:12/HashTable 
krows:4096/process_time   53305833 ns    332126000 ns            2 
rows/sec=12.6287M/s
    BM_HashJoinBasic_BuildParallelism/Threads:13/HashTable 
krows:4096/process_time   48910062 ns    325631500 ns            2 
rows/sec=12.8805M/s
    BM_HashJoinBasic_BuildParallelism/Threads:14/HashTable 
krows:4096/process_time   52218458 ns    312798500 ns            2 
rows/sec=13.409M/s
    BM_HashJoinBasic_BuildParallelism/Threads:15/HashTable 
krows:4096/process_time   51131709 ns    344045500 ns            2 
rows/sec=12.1911M/s
    BM_HashJoinBasic_BuildParallelism/Threads:16/HashTable 
krows:4096/process_time   55233376 ns    338843500 ns            2 
rows/sec=12.3783M/s
    ```
    
    </details>
    
    #### Overhead
    
    This change does introduce some overhead. First, in the old
implementation, the partition info is used right away after partitioning the
batch, whereas the new implementation preserves the partition info and uses it
in the next stage (potentially on another thread). This may be less cache
friendly. Second, preserving the partition info requires more memory: the
increased allocation may hurt performance a bit, and worsen the memory profile
by 6 bytes per row (4 bytes for hash and 2 [...]
    
    But as mentioned above, almost all multi-threaded cases are winning. Even
nicer, the increased memory profile spans only a short period and doesn't
really increase the peak memory: the peak always comes in the merge
stage, and by that time, the preserved partition info for all batches has
already been released. This was verified by printing the memory pool stats when
benchmarking in my local environment.
    
    ### Are these changes tested?
    
    Yes. Existing tests suffice.
    
    ### Are there any user-facing changes?
    
    None.
    
    * GitHub Issue: #45611
    
    Lead-authored-by: Rossi Sun <[email protected]>
    Co-authored-by: Sutou Kouhei <[email protected]>
    Signed-off-by: Rossi Sun <[email protected]>
---
 cpp/src/arrow/acero/swiss_join.cc         | 201 +++++++++++++++++-------------
 cpp/src/arrow/acero/swiss_join_internal.h |  49 +++++---
 2 files changed, 148 insertions(+), 102 deletions(-)

diff --git a/cpp/src/arrow/acero/swiss_join.cc 
b/cpp/src/arrow/acero/swiss_join.cc
index a1faef4679..b4d89df290 100644
--- a/cpp/src/arrow/acero/swiss_join.cc
+++ b/cpp/src/arrow/acero/swiss_join.cc
@@ -1102,7 +1102,8 @@ uint32_t SwissTableForJoin::payload_id_to_key_id(uint32_t 
payload_id) const {
 }
 
 Status SwissTableForJoinBuild::Init(SwissTableForJoin* target, int dop, 
int64_t num_rows,
-                                    bool reject_duplicate_keys, bool 
no_payload,
+                                    int64_t num_batches, bool 
reject_duplicate_keys,
+                                    bool no_payload,
                                     const std::vector<KeyColumnMetadata>& 
key_types,
                                     const std::vector<KeyColumnMetadata>& 
payload_types,
                                     MemoryPool* pool, int64_t hardware_flags) {
@@ -1112,7 +1113,7 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* 
target, int dop, int64_t
 
   // Make sure that we do not use many partitions if there are not enough rows.
   //
-  constexpr int64_t min_num_rows_per_prtn = 1 << 18;
+  constexpr int64_t min_num_rows_per_prtn = 1 << 12;
   log_num_prtns_ =
       std::min(bit_util::Log2(dop_),
                bit_util::Log2(bit_util::CeilDiv(num_rows, 
min_num_rows_per_prtn)));
@@ -1123,9 +1124,9 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* 
target, int dop, int64_t
   pool_ = pool;
   hardware_flags_ = hardware_flags;
 
+  batch_states_.resize(num_batches);
   prtn_states_.resize(num_prtns_);
   thread_states_.resize(dop_);
-  prtn_locks_.Init(dop_, num_prtns_);
 
   RowTableMetadata key_row_metadata;
   key_row_metadata.FromColumnMetadataVector(key_types,
@@ -1154,91 +1155,74 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* 
target, int dop, int64_t
   return Status::OK();
 }
 
-Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
-                                             const ExecBatch& key_batch,
-                                             const ExecBatch* 
payload_batch_maybe_null,
-                                             arrow::util::TempVectorStack* 
temp_stack) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t 
batch_id,
+                                              const ExecBatch& key_batch,
+                                              arrow::util::TempVectorStack* 
temp_stack) {
+  DCHECK_LT(thread_id, thread_states_.size());
+  DCHECK_LT(batch_id, static_cast<int64_t>(batch_states_.size()));
   ThreadState& locals = thread_states_[thread_id];
+  BatchState& batch_state = batch_states_[batch_id];
+  uint16_t num_rows = static_cast<uint16_t>(key_batch.length);
 
   // Compute hash
   //
-  locals.batch_hashes.resize(key_batch.length);
-  RETURN_NOT_OK(Hashing32::HashBatch(
-      key_batch, locals.batch_hashes.data(), locals.temp_column_arrays, 
hardware_flags_,
-      temp_stack, /*start_row=*/0, static_cast<int>(key_batch.length)));
+  batch_state.hashes.resize(num_rows);
+  RETURN_NOT_OK(Hashing32::HashBatch(key_batch, batch_state.hashes.data(),
+                                     locals.temp_column_arrays, 
hardware_flags_,
+                                     temp_stack, /*start_row=*/0, num_rows));
 
   // Partition on hash
   //
-  locals.batch_prtn_row_ids.resize(locals.batch_hashes.size());
-  locals.batch_prtn_ranges.resize(num_prtns_ + 1);
-  int num_rows = static_cast<int>(locals.batch_hashes.size());
+  batch_state.prtn_ranges.resize(num_prtns_ + 1);
+  batch_state.prtn_row_ids.resize(num_rows);
   if (num_prtns_ == 1) {
     // We treat single partition case separately to avoid extra checks in row
     // partitioning implementation for general case.
     //
-    locals.batch_prtn_ranges[0] = 0;
-    locals.batch_prtn_ranges[1] = num_rows;
-    for (int i = 0; i < num_rows; ++i) {
-      locals.batch_prtn_row_ids[i] = i;
+    batch_state.prtn_ranges[0] = 0;
+    batch_state.prtn_ranges[1] = num_rows;
+    for (uint16_t i = 0; i < num_rows; ++i) {
+      batch_state.prtn_row_ids[i] = i;
     }
   } else {
     PartitionSort::Eval(
-        static_cast<int>(locals.batch_hashes.size()), num_prtns_,
-        locals.batch_prtn_ranges.data(),
-        [this, &locals](int64_t i) {
+        num_rows, num_prtns_, batch_state.prtn_ranges.data(),
+        [this, &batch_state](int64_t i) {
           // SwissTable uses the highest bits of the hash for block index.
           // We want each partition to correspond to a range of block indices,
           // so we also partition on the highest bits of the hash.
           //
-          return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - 
log_num_prtns_);
+          return batch_state.hashes[i] >> (SwissTable::bits_hash_ - 
log_num_prtns_);
         },
-        [&locals](int64_t i, int pos) {
-          locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);
+        [&batch_state](int64_t i, int pos) {
+          batch_state.prtn_row_ids[pos] = static_cast<uint16_t>(i);
         });
-  }
 
-  // Update hashes, shifting left to get rid of the bits that were already used
-  // for partitioning.
-  //
-  for (size_t i = 0; i < locals.batch_hashes.size(); ++i) {
-    locals.batch_hashes[i] <<= log_num_prtns_;
+    // Update hashes, shifting left to get rid of the bits that were already 
used
+    // for partitioning.
+    //
+    for (size_t i = 0; i < batch_state.hashes.size(); ++i) {
+      batch_state.hashes[i] <<= log_num_prtns_;
+    }
   }
 
-  // For each partition:
-  // - map keys to unique integers using (this partition's) hash table
-  // - append payloads (if present) to (this partition's) row array
-  //
-  locals.temp_prtn_ids.resize(num_prtns_);
-
-  RETURN_NOT_OK(prtn_locks_.ForEachPartition(
-      thread_id, locals.temp_prtn_ids.data(),
-      /*is_prtn_empty_fn=*/
-      [&](int prtn_id) {
-        return locals.batch_prtn_ranges[prtn_id + 1] == 
locals.batch_prtn_ranges[prtn_id];
-      },
-      /*process_prtn_fn=*/
-      [&](int prtn_id) {
-        return ProcessPartition(thread_id, key_batch, payload_batch_maybe_null,
-                                temp_stack, prtn_id);
-      }));
-
   return Status::OK();
 }
 
-Status SwissTableForJoinBuild::ProcessPartition(int64_t thread_id,
-                                                const ExecBatch& key_batch,
-                                                const ExecBatch* 
payload_batch_maybe_null,
-                                                arrow::util::TempVectorStack* 
temp_stack,
-                                                int prtn_id) {
-  ARROW_DCHECK(thread_id < dop_);
+Status SwissTableForJoinBuild::ProcessPartition(
+    size_t thread_id, int64_t batch_id, int prtn_id, const ExecBatch& 
key_batch,
+    const ExecBatch* payload_batch_maybe_null, arrow::util::TempVectorStack* 
temp_stack) {
+  DCHECK_LT(thread_id, thread_states_.size());
+  DCHECK_LT(batch_id, static_cast<int64_t>(batch_states_.size()));
+  DCHECK_LT(static_cast<size_t>(prtn_id), prtn_states_.size());
   ThreadState& locals = thread_states_[thread_id];
+  BatchState& batch_state = batch_states_[batch_id];
+  PartitionState& prtn_state = prtn_states_[prtn_id];
 
   int num_rows_new =
-      locals.batch_prtn_ranges[prtn_id + 1] - 
locals.batch_prtn_ranges[prtn_id];
+      batch_state.prtn_ranges[prtn_id + 1] - batch_state.prtn_ranges[prtn_id];
   const uint16_t* row_ids =
-      locals.batch_prtn_row_ids.data() + locals.batch_prtn_ranges[prtn_id];
-  PartitionState& prtn_state = prtn_states_[prtn_id];
+      batch_state.prtn_row_ids.data() + batch_state.prtn_ranges[prtn_id];
   size_t num_rows_before = prtn_state.key_ids.size();
   // Insert new keys into hash table associated with the current partition
   // and map existing keys to integer ids.
@@ -1247,7 +1231,7 @@ Status SwissTableForJoinBuild::ProcessPartition(int64_t 
thread_id,
   SwissTableWithKeys::Input input(&key_batch, num_rows_new, row_ids, 
temp_stack,
                                   &locals.temp_column_arrays, 
&locals.temp_group_ids);
   RETURN_NOT_OK(prtn_state.keys.MapWithInserts(
-      &input, locals.batch_hashes.data(), prtn_state.key_ids.data() + 
num_rows_before));
+      &input, batch_state.hashes.data(), prtn_state.key_ids.data() + 
num_rows_before));
   // Append input batch rows from current partition to an array of payload
   // rows for this partition.
   //
@@ -2504,6 +2488,13 @@ class SwissJoin : public HashJoinImpl {
   }
 
   void InitTaskGroups() {
+    task_group_partition_ = register_task_group_callback_(
+        [this](size_t thread_index, int64_t task_id) -> Status {
+          return PartitionTask(thread_index, task_id);
+        },
+        [this](size_t thread_index) -> Status {
+          return PartitionFinished(thread_index);
+        });
     task_group_build_ = register_task_group_callback_(
         [this](size_t thread_index, int64_t task_id) -> Status {
           return BuildTask(thread_index, task_id);
@@ -2593,16 +2584,16 @@ class SwissJoin : public HashJoinImpl {
     hash_table_build_ = std::make_unique<SwissTableForJoinBuild>();
     RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->Init(
         &hash_table_, num_threads_, build_side_batches_.row_count(),
-        reject_duplicate_keys, no_payload, key_types, payload_types, pool_,
-        hardware_flags_)));
+        build_side_batches_.batch_count(), reject_duplicate_keys, no_payload, 
key_types,
+        payload_types, pool_, hardware_flags_)));
 
     // Process all input batches
     //
-    return CancelIfNotOK(
-        start_task_group_callback_(task_group_build_, 
build_side_batches_.batch_count()));
+    return CancelIfNotOK(start_task_group_callback_(task_group_partition_,
+                                                    
build_side_batches_.batch_count()));
   }
 
-  Status BuildTask(size_t thread_id, int64_t batch_id) {
+  Status PartitionTask(size_t thread_id, int64_t batch_id) {
     if (IsCancelled()) {
       return Status::OK();
     }
@@ -2610,41 +2601,78 @@ class SwissJoin : public HashJoinImpl {
     DCHECK_GT(build_side_batches_[batch_id].length, 0);
 
     const HashJoinProjectionMaps* schema = schema_[1];
-    DCHECK_NE(hash_table_build_, nullptr);
-    bool no_payload = hash_table_build_->no_payload();
-
     ExecBatch input_batch;
     ARROW_ASSIGN_OR_RAISE(
         input_batch, KeyPayloadFromInput(/*side=*/1, 
&build_side_batches_[batch_id]));
 
-    // Split batch into key batch and optional payload batch
-    //
-    // Input batch is key-payload batch (key columns followed by payload
-    // columns). We split it into two separate batches.
-    //
-    // TODO: Change SwissTableForJoinBuild interface to use key-payload
-    // batch instead to avoid this operation, which involves increasing
-    // shared pointer ref counts.
-    //
     ExecBatch key_batch({}, input_batch.length);
     key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
     for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
       key_batch.values[icol] = input_batch.values[icol];
     }
-    ExecBatch payload_batch({}, input_batch.length);
+    arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
 
+    DCHECK_NE(hash_table_build_, nullptr);
+    return hash_table_build_->PartitionBatch(static_cast<int64_t>(thread_id), 
batch_id,
+                                             key_batch, temp_stack);
+  }
+
+  Status PartitionFinished(size_t thread_id) {
+    RETURN_NOT_OK(status());
+
+    DCHECK_NE(hash_table_build_, nullptr);
+    return CancelIfNotOK(
+        start_task_group_callback_(task_group_build_, 
hash_table_build_->num_prtns()));
+  }
+
+  Status BuildTask(size_t thread_id, int64_t prtn_id) {
+    if (IsCancelled()) {
+      return Status::OK();
+    }
+
+    const HashJoinProjectionMaps* schema = schema_[1];
+    DCHECK_NE(hash_table_build_, nullptr);
+    bool no_payload = hash_table_build_->no_payload();
+    ExecBatch key_batch, payload_batch;
+    auto num_keys = schema->num_cols(HashJoinProjection::KEY);
+    auto num_payloads = schema->num_cols(HashJoinProjection::PAYLOAD);
+    key_batch.values.resize(num_keys);
     if (!no_payload) {
-      
payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD));
-      for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
-        payload_batch.values[icol] =
-            input_batch.values[schema->num_cols(HashJoinProjection::KEY) + 
icol];
-      }
+      payload_batch.values.resize(num_payloads);
     }
     arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
-    DCHECK_NE(hash_table_build_, nullptr);
-    RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->PushNextBatch(
-        static_cast<int64_t>(thread_id), key_batch, no_payload ? nullptr : 
&payload_batch,
-        temp_stack)));
+
+    for (int64_t batch_id = 0;
+         batch_id < static_cast<int64_t>(build_side_batches_.batch_count()); 
++batch_id) {
+      ExecBatch input_batch;
+      ARROW_ASSIGN_OR_RAISE(
+          input_batch, KeyPayloadFromInput(/*side=*/1, 
&build_side_batches_[batch_id]));
+
+      // Split batch into key batch and optional payload batch
+      //
+      // Input batch is key-payload batch (key columns followed by payload
+      // columns). We split it into two separate batches.
+      //
+      // TODO: Change SwissTableForJoinBuild interface to use key-payload
+      // batch instead to avoid this operation, which involves increasing
+      // shared pointer ref counts.
+      //
+      key_batch.length = input_batch.length;
+      for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
+        key_batch.values[icol] = input_batch.values[icol];
+      }
+
+      if (!no_payload) {
+        payload_batch.length = input_batch.length;
+        for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
+          payload_batch.values[icol] = input_batch.values[num_keys + icol];
+        }
+      }
+
+      RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->ProcessPartition(
+          thread_id, batch_id, static_cast<int>(prtn_id), key_batch,
+          no_payload ? nullptr : &payload_batch, temp_stack)));
+    }
 
     return Status::OK();
   }
@@ -2897,6 +2925,7 @@ class SwissJoin : public HashJoinImpl {
   const HashJoinProjectionMaps* schema_[2];
 
   // Task scheduling
+  int task_group_partition_;
   int task_group_build_;
   int task_group_merge_;
   int task_group_scan_;
diff --git a/cpp/src/arrow/acero/swiss_join_internal.h 
b/cpp/src/arrow/acero/swiss_join_internal.h
index c7af6517d7..365f2917d8 100644
--- a/cpp/src/arrow/acero/swiss_join_internal.h
+++ b/cpp/src/arrow/acero/swiss_join_internal.h
@@ -523,19 +523,27 @@ class SwissTableForJoin {
 //
 class SwissTableForJoinBuild {
  public:
-  Status Init(SwissTableForJoin* target, int dop, int64_t num_rows,
+  Status Init(SwissTableForJoin* target, int dop, int64_t num_rows, int64_t 
num_batches,
               bool reject_duplicate_keys, bool no_payload,
               const std::vector<KeyColumnMetadata>& key_types,
               const std::vector<KeyColumnMetadata>& payload_types, MemoryPool* 
pool,
               int64_t hardware_flags);
 
-  // In the first phase of parallel hash table build, threads pick unprocessed
-  // exec batches, partition the rows based on hash, and update all of the
-  // partitions with information related to that batch of rows.
+  // In the first phase of parallel hash table build, each thread picks 
unprocessed exec
+  // batches, hashes the batches and preserves the hashes, then partitions the 
rows based on
+  // hashes.
   //
-  Status PushNextBatch(int64_t thread_id, const ExecBatch& key_batch,
-                       const ExecBatch* payload_batch_maybe_null,
-                       arrow::util::TempVectorStack* temp_stack);
+  Status PartitionBatch(size_t thread_id, int64_t batch_id, const ExecBatch& 
key_batch,
+                        arrow::util::TempVectorStack* temp_stack);
+
+  // In the second phase of parallel hash table build, each thread picks the 
given
+  // partition of all batches, and updates that particular partition with 
information
+  // related to that batch of rows.
+  //
+  Status ProcessPartition(size_t thread_id, int64_t batch_id, int prtn_id,
+                          const ExecBatch& key_batch,
+                          const ExecBatch* payload_batch_maybe_null,
+                          arrow::util::TempVectorStack* temp_stack);
 
   // Allocate memory and initialize counters required for parallel merging of
   // hash table partitions.
@@ -543,7 +551,7 @@ class SwissTableForJoinBuild {
   //
   Status PreparePrtnMerge();
 
-  // Second phase of parallel hash table build.
+  // Third phase of parallel hash table build.
   // Each partition can be processed by a different thread.
   // Parallel step.
   //
@@ -564,9 +572,6 @@ class SwissTableForJoinBuild {
 
  private:
   void InitRowArray();
-  Status ProcessPartition(int64_t thread_id, const ExecBatch& key_batch,
-                          const ExecBatch* payload_batch_maybe_null,
-                          arrow::util::TempVectorStack* temp_stack, int 
prtn_id);
 
   SwissTableForJoin* target_;
   // DOP stands for Degree Of Parallelism - the maximum number of participating
@@ -604,6 +609,22 @@ class SwissTableForJoinBuild {
   MemoryPool* pool_;
   int64_t hardware_flags_;
 
+  // One per batch.
+  //
+  // Information such as the hashes and partitions of each batch, gathered in the 
partition phase
+  // and used in the build phase.
+  //
+  struct BatchState {
+    // Hashes for the batch, preserved in the partition phase to avoid 
recomputation in
+    // the build phase. One element per row in the batch.
+    std::vector<uint32_t> hashes;
+    // Cumulative number of rows in each partition for the batch. 
`num_prtns_` + 1
+    // elements.
+    std::vector<uint16_t> prtn_ranges;
+    // Row ids after partition sorting the batch. One element per row in the 
batch.
+    std::vector<uint16_t> prtn_row_ids;
+  };
+
   // One per partition.
   //
   struct PartitionState {
@@ -620,17 +641,13 @@ class SwissTableForJoinBuild {
   // batches.
   //
   struct ThreadState {
-    std::vector<uint32_t> batch_hashes;
-    std::vector<uint16_t> batch_prtn_ranges;
-    std::vector<uint16_t> batch_prtn_row_ids;
-    std::vector<int> temp_prtn_ids;
     std::vector<uint32_t> temp_group_ids;
     std::vector<KeyColumnArray> temp_column_arrays;
   };
 
+  std::vector<BatchState> batch_states_;
   std::vector<PartitionState> prtn_states_;
   std::vector<ThreadState> thread_states_;
-  PartitionLocks prtn_locks_;
 
   std::vector<int64_t> partition_keys_first_row_id_;
   std::vector<int64_t> partition_payloads_first_row_id_;
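
For readers who want the shape of the change without the Acero plumbing, the following is a minimal standalone sketch of the two stages the patch introduces. It is not the Arrow implementation: `BatchState` here is a simplified stand-in, `PartitionBatch` and `BuildPartition` are free functions written for illustration, a plain counting sort replaces `PartitionSort::Eval`, and the "hash table" is reduced to a vector of hashes. It also skips the left shift that the real code applies to the hashes after partitioning. The point it shows is that stage 1 writes only to per-batch state and stage 2 writes only to per-partition state, so neither stage needs a lock.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified, hypothetical per-batch state kept between the two stages.
struct BatchState {
  std::vector<uint32_t> hashes;        // one element per row
  std::vector<uint16_t> prtn_ranges;   // num_prtns + 1 cumulative row counts
  std::vector<uint16_t> prtn_row_ids;  // row ids grouped by partition
};

// Stage 1 (one task per batch): group row ids by the highest bits of the hash.
// Only `state` is written, so tasks for different batches never conflict.
void PartitionBatch(const std::vector<uint32_t>& hashes, int log_num_prtns,
                    BatchState* state) {
  assert(log_num_prtns >= 0 && log_num_prtns < 32);
  assert(hashes.size() <= UINT16_MAX);
  const int num_prtns = 1 << log_num_prtns;
  const uint16_t num_rows = static_cast<uint16_t>(hashes.size());
  auto prtn_of = [log_num_prtns](uint32_t h) -> uint32_t {
    // Shifting a 32-bit value by 32 is undefined, so handle one partition apart.
    return log_num_prtns == 0 ? 0 : h >> (32 - log_num_prtns);
  };

  state->hashes = hashes;
  state->prtn_ranges.assign(num_prtns + 1, 0);
  state->prtn_row_ids.resize(num_rows);

  // Counting sort: histogram, then prefix sum, then scatter.
  for (uint32_t h : hashes) ++state->prtn_ranges[prtn_of(h) + 1];
  for (int p = 0; p < num_prtns; ++p) {
    state->prtn_ranges[p + 1] += state->prtn_ranges[p];
  }
  std::vector<uint16_t> cursor(state->prtn_ranges.begin(),
                               state->prtn_ranges.end() - 1);
  for (uint16_t i = 0; i < num_rows; ++i) {
    state->prtn_row_ids[cursor[prtn_of(hashes[i])]++] = i;
  }
}

// Stage 2 (one task per partition): visit every batch, but read only the slice
// that belongs to `prtn_id`. The returned vector stands in for that partition's
// hash table, which this task alone writes to.
std::vector<uint32_t> BuildPartition(const std::vector<BatchState>& batches,
                                     int prtn_id) {
  std::vector<uint32_t> partition_rows;
  for (const BatchState& b : batches) {
    for (uint16_t pos = b.prtn_ranges[prtn_id]; pos < b.prtn_ranges[prtn_id + 1];
         ++pos) {
      partition_rows.push_back(b.hashes[b.prtn_row_ids[pos]]);
    }
  }
  return partition_rows;
}
```

In this sketch the cost of removing the locks is visible in `BuildPartition`: every partition task walks the full list of batches, which mirrors the loop over `build_side_batches_` in the new `BuildTask`.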

