This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new b36354e3 [AURON #1693] join operation should flush in time on duplicated keys (#1701)
b36354e3 is described below

commit b36354e3dd5ffd7320f8780cfb00986dbd8d2797
Author: bkhan <[email protected]>
AuthorDate: Sat Jan 17 14:49:20 2026 +0800

    [AURON #1693] join operation should flush in time on duplicated keys (#1701)
    
    <!--
    Thanks for sending a pull request! Please keep the following tips in
    mind:
    - Start the PR title with the related issue ID, e.g. '[AURON #XXXX]
    Short summary...'.
    - Make your PR title clear and descriptive, summarizing what this PR
    changes.
      - Provide a concise example to reproduce the issue, if possible.
      - Keep the PR description up to date with all changes.
    -->
    
    # Which issue does this PR close?
    
    <!--
    We generally require a GitHub issue to be filed for all bug fixes and
    enhancements and this helps us generate change logs for our releases.
    You can link an issue to this PR using the GitHub syntax. For example
    `Closes #123` indicates that this PR will close issue #123.
    -->
    
    Closes #1693.
    
    # Rationale for this change
    <!--
    Why are you proposing this change? If this is already explained clearly
    in the issue then this section is not needed.
    Explaining clearly why changes are proposed helps reviewers understand
    your changes and offer better suggestions for fixes.
    -->
    As discussed previously in #1693 and #1694, the join operation should
    check batch size and trigger flushing in a timely manner, to prevent
    extremely large batch sizes.
    
    
    # What changes are included in this PR?
    <!--
    There is no need to duplicate the description in the issue here but it
    is sometimes worth providing a summary of the individual changes in this
    PR.
    -->
    
    # Are there any user-facing changes?
    <!--
    If there are user-facing changes then we may require documentation to be
    updated before approving the PR.
    -->
    
    <!--
    If there are any breaking changes to public APIs, please add the `api
    change` label.
    -->
    
    # How was this patch tested?
    <!--
    If tests were added, say they were added here. Please make sure to add
    some test cases that check the changes thoroughly including negative and
    positive cases if possible.
    If it was tested in a way different from regular unit tests, please
    clarify how you tested step by step, ideally copy and paste-able, so
    that other reviewers can test and check, and descendants can verify in
    the future.
    If tests were not added, please describe why they were not added and/or
    why it was difficult to add.
    -->
    
    ---------
    
    Co-authored-by: Copilot <[email protected]>
---
 .../src/joins/smj/full_join.rs                     |  27 ++-
 .../datafusion-ext-plans/src/joins/test.rs         | 199 ++++++++++++++++++++-
 2 files changed, 222 insertions(+), 4 deletions(-)

diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
index a810b1b2..5a6bcf6d 100644
--- a/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
@@ -56,6 +56,10 @@ impl<const L_OUTER: bool, const R_OUTER: bool> FullJoiner<L_OUTER, R_OUTER> {
         self.lindices.len() >= self.join_params.batch_size
     }
 
+    fn has_enough_room(&self, new_size: usize) -> bool {
+        self.lindices.len() + new_size <= self.join_params.batch_size
+    }
+
     async fn flush(
         mut self: Pin<&mut Self>,
         cur1: &mut StreamCursor,
@@ -158,9 +162,26 @@ impl<const L_OUTER: bool, const R_OUTER: bool> Joiner for FullJoiner<L_OUTER, R_
                         continue;
                     }
 
-                    for (&lidx, &ridx) in equal_lindices.iter().cartesian_product(&equal_rindices) {
-                        self.lindices.push(lidx);
-                        self.rindices.push(ridx);
+                    let new_size = equal_lindices.len() * equal_rindices.len();
+                    if self.has_enough_room(new_size) {
+                        // old cartesian_product way
+                        for (&lidx, &ridx) in
+                            equal_lindices.iter().cartesian_product(&equal_rindices)
+                        {
+                            self.lindices.push(lidx);
+                            self.rindices.push(ridx);
+                        }
+                    } else {
+                        // do more aggressive flush
+                        for &lidx in &equal_lindices {
+                            for &ridx in &equal_rindices {
+                                self.lindices.push(lidx);
+                                self.rindices.push(ridx);
+                                if self.should_flush() {
+                                    self.as_mut().flush(cur1, cur2).await?;
+                                }
+                            }
+                        }
                     }
 
                     if r_equal {
diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs b/native-engine/datafusion-ext-plans/src/joins/test.rs
index 8427b6ed..cb48750c 100644
--- a/native-engine/datafusion-ext-plans/src/joins/test.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/test.rs
@@ -31,7 +31,7 @@ mod tests {
         common::{JoinSide, Result},
         physical_expr::expressions::Column,
         physical_plan::{ExecutionPlan, common, joins::utils::*, test::TestMemoryExec},
-        prelude::SessionContext,
+        prelude::{SessionConfig, SessionContext},
     };
 
     use crate::{
@@ -283,6 +283,91 @@ mod tests {
         Ok((columns, batches))
     }
 
+    async fn join_collect_with_batch_size(
+        test_type: TestType,
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        join_type: JoinType,
+        batch_size: usize,
+    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+        MemManager::init(1000000);
+        let session_config = SessionConfig::new().with_batch_size(batch_size);
+        let session_ctx = SessionContext::new_with_config(session_config);
+        let task_ctx = session_ctx.task_ctx();
+        let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?;
+
+        let join: Arc<dyn ExecutionPlan> = match test_type {
+            SMJ => {
+                let sort_options = vec![SortOptions::default(); on.len()];
+                Arc::new(SortMergeJoinExec::try_new(
+                    schema,
+                    left,
+                    right,
+                    on,
+                    join_type,
+                    sort_options,
+                )?)
+            }
+            BHJLeftProbed => {
+                let right = Arc::new(BroadcastJoinBuildHashMapExec::new(
+                    right,
+                    on.iter().map(|(_, right_key)| right_key.clone()).collect(),
+                ));
+                Arc::new(BroadcastJoinExec::try_new(
+                    schema,
+                    left,
+                    right,
+                    on,
+                    join_type,
+                    JoinSide::Right,
+                    true,
+                    None,
+                )?)
+            }
+            BHJRightProbed => {
+                let left = Arc::new(BroadcastJoinBuildHashMapExec::new(
+                    left,
+                    on.iter().map(|(left_key, _)| left_key.clone()).collect(),
+                ));
+                Arc::new(BroadcastJoinExec::try_new(
+                    schema,
+                    left,
+                    right,
+                    on,
+                    join_type,
+                    JoinSide::Left,
+                    true,
+                    None,
+                )?)
+            }
+            SHJLeftProbed => Arc::new(BroadcastJoinExec::try_new(
+                schema,
+                left,
+                right,
+                on,
+                join_type,
+                JoinSide::Right,
+                false,
+                None,
+            )?),
+            SHJRightProbed => Arc::new(BroadcastJoinExec::try_new(
+                schema,
+                left,
+                right,
+                on,
+                join_type,
+                JoinSide::Left,
+                false,
+                None,
+            )?),
+        };
+        let columns = columns(&join.schema());
+        let stream = join.execute(0, task_ctx)?;
+        let batches = common::collect(stream).await?;
+        Ok((columns, batches))
+    }
+
     const ALL_TEST_TYPE: [TestType; 5] = [
         SMJ,
         BHJLeftProbed,
@@ -447,6 +532,118 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+    async fn join_inner_batchsize() -> Result<()> {
+        for test_type in ALL_TEST_TYPE {
+            let left = build_table(
+                ("a1", &vec![1, 1, 1, 1, 1]),
+                ("b1", &vec![1, 2, 3, 4, 5]),
+                ("c1", &vec![1, 2, 3, 4, 5]),
+            );
+            let right = build_table(
+                ("a2", &vec![1, 1, 1, 1, 1, 1, 1]),
+                ("b2", &vec![1, 2, 3, 4, 5, 6, 7]),
+                ("c2", &vec![1, 2, 3, 4, 5, 6, 7]),
+            );
+            let on: JoinOn = vec![(
+                Arc::new(Column::new_with_schema("a1", &left.schema())?),
+                Arc::new(Column::new_with_schema("a2", &right.schema())?),
+            )];
+            let expected = vec![
+                "+----+----+----+----+----+----+",
+                "| a1 | b1 | c1 | a2 | b2 | c2 |",
+                "+----+----+----+----+----+----+",
+                "| 1  | 1  | 1  | 1  | 1  | 1  |",
+                "| 1  | 1  | 1  | 1  | 2  | 2  |",
+                "| 1  | 1  | 1  | 1  | 3  | 3  |",
+                "| 1  | 1  | 1  | 1  | 4  | 4  |",
+                "| 1  | 1  | 1  | 1  | 5  | 5  |",
+                "| 1  | 1  | 1  | 1  | 6  | 6  |",
+                "| 1  | 1  | 1  | 1  | 7  | 7  |",
+                "| 1  | 2  | 2  | 1  | 1  | 1  |",
+                "| 1  | 2  | 2  | 1  | 2  | 2  |",
+                "| 1  | 2  | 2  | 1  | 3  | 3  |",
+                "| 1  | 2  | 2  | 1  | 4  | 4  |",
+                "| 1  | 2  | 2  | 1  | 5  | 5  |",
+                "| 1  | 2  | 2  | 1  | 6  | 6  |",
+                "| 1  | 2  | 2  | 1  | 7  | 7  |",
+                "| 1  | 3  | 3  | 1  | 1  | 1  |",
+                "| 1  | 3  | 3  | 1  | 2  | 2  |",
+                "| 1  | 3  | 3  | 1  | 3  | 3  |",
+                "| 1  | 3  | 3  | 1  | 4  | 4  |",
+                "| 1  | 3  | 3  | 1  | 5  | 5  |",
+                "| 1  | 3  | 3  | 1  | 6  | 6  |",
+                "| 1  | 3  | 3  | 1  | 7  | 7  |",
+                "| 1  | 4  | 4  | 1  | 1  | 1  |",
+                "| 1  | 4  | 4  | 1  | 2  | 2  |",
+                "| 1  | 4  | 4  | 1  | 3  | 3  |",
+                "| 1  | 4  | 4  | 1  | 4  | 4  |",
+                "| 1  | 4  | 4  | 1  | 5  | 5  |",
+                "| 1  | 4  | 4  | 1  | 6  | 6  |",
+                "| 1  | 4  | 4  | 1  | 7  | 7  |",
+                "| 1  | 5  | 5  | 1  | 1  | 1  |",
+                "| 1  | 5  | 5  | 1  | 2  | 2  |",
+                "| 1  | 5  | 5  | 1  | 3  | 3  |",
+                "| 1  | 5  | 5  | 1  | 4  | 4  |",
+                "| 1  | 5  | 5  | 1  | 5  | 5  |",
+                "| 1  | 5  | 5  | 1  | 6  | 6  |",
+                "| 1  | 5  | 5  | 1  | 7  | 7  |",
+                "+----+----+----+----+----+----+",
+            ];
+            let (_, batches) = join_collect_with_batch_size(
+                test_type,
+                left.clone(),
+                right.clone(),
+                on.clone(),
+                Inner,
+                2,
+            )
+            .await?;
+            assert_batches_sorted_eq!(expected, &batches);
+            let (_, batches) = join_collect_with_batch_size(
+                test_type,
+                left.clone(),
+                right.clone(),
+                on.clone(),
+                Inner,
+                3,
+            )
+            .await?;
+            assert_batches_sorted_eq!(expected, &batches);
+            let (_, batches) = join_collect_with_batch_size(
+                test_type,
+                left.clone(),
+                right.clone(),
+                on.clone(),
+                Inner,
+                4,
+            )
+            .await?;
+            assert_batches_sorted_eq!(expected, &batches);
+            let (_, batches) = join_collect_with_batch_size(
+                test_type,
+                left.clone(),
+                right.clone(),
+                on.clone(),
+                Inner,
+                5,
+            )
+            .await?;
+            assert_batches_sorted_eq!(expected, &batches);
+            let (_, batches) = join_collect_with_batch_size(
+                test_type,
+                left.clone(),
+                right.clone(),
+                on.clone(),
+                Inner,
+                7,
+            )
+            .await?;
+            assert_batches_sorted_eq!(expected, &batches);
+        }
+        Ok(())
+    }
+
     #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
     async fn join_left_one() -> Result<()> {
         for test_type in ALL_TEST_TYPE {

Reply via email to