This is an automated email from the ASF dual-hosted git repository.

shenghang pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 75bc71beb8 [Fix][Connector-V2][Hbase] Avoid duplicate split assignment 
on restore (#10310)
75bc71beb8 is described below

commit 75bc71beb8be0a7135c5c781c4fe9402428de681
Author: yzeng1618 <[email protected]>
AuthorDate: Sun Jan 11 20:48:14 2026 +0800

    [Fix][Connector-V2][Hbase] Avoid duplicate split assignment on restore 
(#10310)
    
    Co-authored-by: zengyi <[email protected]>
---
 .../hbase/source/HbaseSourceSplitEnumerator.java   | 52 +++++++++++--
 .../source/HbaseSourceSplitEnumeratorTest.java     | 86 ++++++++++++++++++++++
 2 files changed, 133 insertions(+), 5 deletions(-)

diff --git 
a/seatunnel-connectors-v2/connector-hbase/src/main/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumerator.java
 
b/seatunnel-connectors-v2/connector-hbase/src/main/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumerator.java
index 73b1d6862a..54306ef6ec 100644
--- 
a/seatunnel-connectors-v2/connector-hbase/src/main/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumerator.java
+++ 
b/seatunnel-connectors-v2/connector-hbase/src/main/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumerator.java
@@ -50,6 +50,9 @@ public class HbaseSourceSplitEnumerator
     /** The splits that have not assigned */
     private Set<HbaseSourceSplit> pendingSplit;
 
+    /** Whether the pending splits have been initialized */
+    private boolean initialized = false;
+
     private HbaseParameters hbaseParameters;
 
     private HbaseClient hbaseClient;
@@ -71,23 +74,40 @@ public class HbaseSourceSplitEnumerator
             Context<HbaseSourceSplit> context,
             HbaseParameters hbaseParameters,
             HbaseClient hbaseClient) {
-        this(context, hbaseParameters, new HashSet<>());
-        this.hbaseClient = hbaseClient;
+        this(context, hbaseParameters, new HashSet<>(), hbaseClient);
+    }
+
+    @VisibleForTesting
+    public HbaseSourceSplitEnumerator(
+            Context<HbaseSourceSplit> context,
+            HbaseParameters hbaseParameters,
+            HbaseSourceState sourceState,
+            HbaseClient hbaseClient) {
+        this(context, hbaseParameters, sourceState.getAssignedSplits(), 
hbaseClient);
     }
 
     private HbaseSourceSplitEnumerator(
             Context<HbaseSourceSplit> context,
             HbaseParameters hbaseParameters,
             Set<HbaseSourceSplit> assignedSplit) {
+        this(context, hbaseParameters, assignedSplit, 
HbaseClient.createInstance(hbaseParameters));
+    }
+
+    private HbaseSourceSplitEnumerator(
+            Context<HbaseSourceSplit> context,
+            HbaseParameters hbaseParameters,
+            Set<HbaseSourceSplit> assignedSplit,
+            HbaseClient hbaseClient) {
         this.context = context;
         this.hbaseParameters = hbaseParameters;
         this.assignedSplit = assignedSplit;
-        this.hbaseClient = HbaseClient.createInstance(hbaseParameters);
+        this.hbaseClient = hbaseClient;
     }
 
     @Override
     public void open() {
         this.pendingSplit = new HashSet<>();
+        this.initialized = false;
     }
 
     @Override
@@ -110,7 +130,9 @@ public class HbaseSourceSplitEnumerator
     public void addSplitsBack(List<HbaseSourceSplit> splits, int subtaskId) {
         if (!splits.isEmpty()) {
             pendingSplit.addAll(splits);
-            assignSplit(subtaskId);
+            if (context.registeredReaders().contains(subtaskId)) {
+                assignSplit(subtaskId);
+            }
         }
     }
 
@@ -121,10 +143,30 @@ public class HbaseSourceSplitEnumerator
 
     @Override
     public void registerReader(int subtaskId) {
-        pendingSplit = getTableSplits();
+        initializePendingSplits();
         assignSplit(subtaskId);
     }
 
+    private void initializePendingSplits() {
+        if (initialized) {
+            return;
+        }
+        Set<HbaseSourceSplit> tableSplits = getTableSplits();
+        Set<String> existedSplitIds =
+                
pendingSplit.stream().map(HbaseSourceSplit::splitId).collect(Collectors.toSet());
+        if (!assignedSplit.isEmpty()) {
+            existedSplitIds.addAll(
+                    assignedSplit.stream()
+                            .map(HbaseSourceSplit::splitId)
+                            .collect(Collectors.toSet()));
+        }
+        pendingSplit.addAll(
+                tableSplits.stream()
+                        .filter(split -> 
!existedSplitIds.contains(split.splitId()))
+                        .collect(Collectors.toSet()));
+        initialized = true;
+    }
+
     @Override
     public HbaseSourceState snapshotState(long checkpointId) throws Exception {
         return new HbaseSourceState(assignedSplit);
diff --git 
a/seatunnel-connectors-v2/connector-hbase/src/test/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumeratorTest.java
 
b/seatunnel-connectors-v2/connector-hbase/src/test/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumeratorTest.java
index fd5eb0cceb..0fffeec0cc 100644
--- 
a/seatunnel-connectors-v2/connector-hbase/src/test/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumeratorTest.java
+++ 
b/seatunnel-connectors-v2/connector-hbase/src/test/java/org/apache/seatunnel/connectors/seatunnel/hbase/source/HbaseSourceSplitEnumeratorTest.java
@@ -27,16 +27,26 @@ import org.apache.hadoop.hbase.util.Bytes;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public class HbaseSourceSplitEnumeratorTest {
@@ -376,4 +386,80 @@ public class HbaseSourceSplitEnumeratorTest {
         assertArrayEquals(Bytes.toBytes("row100"), split.getStartRow());
         assertArrayEquals(Bytes.toBytes("row200"), split.getEndRow());
     }
+
+    @Test
+    void testRestoreOnlyAssignReturnedSplits() throws Exception {
+        when(context.currentParallelism()).thenReturn(1);
+        when(context.registeredReaders()).thenReturn(Collections.emptySet());
+
+        byte[][] startKeys = {
+            HConstants.EMPTY_BYTE_ARRAY, Bytes.toBytes("row100"), 
Bytes.toBytes("row200")
+        };
+        byte[][] endKeys = {
+            Bytes.toBytes("row100"), Bytes.toBytes("row200"), 
HConstants.EMPTY_BYTE_ARRAY
+        };
+        when(regionLocator.getStartKeys()).thenReturn(startKeys);
+        when(regionLocator.getEndKeys()).thenReturn(endKeys);
+
+        Set<HbaseSourceSplit> assignedSplits = new HashSet<>();
+        assignedSplits.add(new HbaseSourceSplit(0, startKeys[0], endKeys[0]));
+        assignedSplits.add(new HbaseSourceSplit(1, startKeys[1], endKeys[1]));
+        assignedSplits.add(new HbaseSourceSplit(2, startKeys[2], endKeys[2]));
+
+        HbaseSourceSplitEnumerator restoredEnumerator =
+                new HbaseSourceSplitEnumerator(
+                        context,
+                        hbaseParameters,
+                        new HbaseSourceState(assignedSplits),
+                        hbaseClient);
+
+        restoredEnumerator.open();
+
+        List<HbaseSourceSplit> returnedSplits =
+                Arrays.asList(
+                        new HbaseSourceSplit(1, startKeys[1], endKeys[1]),
+                        new HbaseSourceSplit(2, startKeys[2], endKeys[2]));
+        restoredEnumerator.addSplitsBack(returnedSplits, 0);
+
+        ArgumentCaptor<List<HbaseSourceSplit>> assignedCaptor = 
ArgumentCaptor.forClass(List.class);
+        restoredEnumerator.registerReader(0);
+
+        verify(context, times(1)).assignSplit(eq(0), assignedCaptor.capture());
+        Set<String> assignedSplitIds =
+                assignedCaptor.getValue().stream()
+                        .map(HbaseSourceSplit::splitId)
+                        .collect(Collectors.toSet());
+        assertEquals(2, assignedSplitIds.size());
+        assertTrue(assignedSplitIds.contains("hbase_source_split_1"));
+        assertTrue(assignedSplitIds.contains("hbase_source_split_2"));
+        assertFalse(assignedSplitIds.contains("hbase_source_split_0"));
+    }
+
+    @Test
+    void 
testRegisterReaderInitializePendingSplitOnlyOnceWhenParallelismMoreThanOne()
+            throws Exception {
+        when(context.currentParallelism()).thenReturn(2);
+
+        byte[][] startKeys = {
+            HConstants.EMPTY_BYTE_ARRAY,
+            Bytes.toBytes("row100"),
+            Bytes.toBytes("row200"),
+            Bytes.toBytes("row300")
+        };
+        byte[][] endKeys = {
+            Bytes.toBytes("row100"),
+            Bytes.toBytes("row200"),
+            Bytes.toBytes("row300"),
+            HConstants.EMPTY_BYTE_ARRAY
+        };
+        when(regionLocator.getStartKeys()).thenReturn(startKeys);
+        when(regionLocator.getEndKeys()).thenReturn(endKeys);
+
+        enumerator.open();
+        enumerator.registerReader(0);
+        enumerator.registerReader(1);
+
+        verify(hbaseClient, times(1)).getRegionLocator("test_table");
+        assertEquals(0, enumerator.currentUnassignedSplitSize());
+    }
 }

Reply via email to