This is an automated email from the ASF dual-hosted git repository.

chl-wxp pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new d991e6bbb2 [Improve][connector-milvus] Improved milvus source 
enumerator splits allocation algorithm for subtasks (#10868)
d991e6bbb2 is described below

commit d991e6bbb21518b50873af436939172cc52b3bc6
Author: JeremyXin <[email protected]>
AuthorDate: Thu May 21 11:34:31 2026 +0800

    [Improve][connector-milvus] Improved milvus source enumerator splits 
allocation algorithm for subtasks (#10868)
    
    good job
---
 .../milvus/source/MilvusSourceSplitEnumerator.java |  28 ++-
 .../seatunnel/milvus/source/MilvusSourceState.java |   1 +
 .../source/MilvusSourceSplitEnumeratorTest.java    | 213 +++++++++++++++++++++
 3 files changed, 235 insertions(+), 7 deletions(-)

diff --git 
a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumerator.java
 
b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumerator.java
index b1c242d682..0bed7df9d0 100644
--- 
a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumerator.java
+++ 
b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumerator.java
@@ -41,11 +41,14 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
 
 @Slf4j
 public class MilvusSourceSplitEnumerator
@@ -57,6 +60,8 @@ public class MilvusSourceSplitEnumerator
     private final Map<Integer, List<MilvusSourceSplit>> pendingSplits;
     private final Object stateLock = new Object();
     private MilvusClient client = null;
+    // Keeps round-robin assignment global across collections and restore.
+    private final AtomicInteger assignCount = new AtomicInteger(0);
 
     private final ReadonlyConfig config;
 
@@ -74,6 +79,7 @@ public class MilvusSourceSplitEnumerator
         } else {
             this.pendingTables = new 
ConcurrentLinkedQueue<>(sourceState.getPendingTables());
             this.pendingSplits = new HashMap<>(sourceState.getPendingSplits());
+            this.assignCount.set(sourceState.getAssignCount());
         }
     }
 
@@ -160,22 +166,28 @@ public class MilvusSourceSplitEnumerator
         return milvusSourceSplits;
     }
 
-    protected String createSplitId(TablePath tablePath, String index) {
+    private String createSplitId(TablePath tablePath, String index) {
         return String.format("%s-%s", tablePath, index);
     }
 
     private void addPendingSplit(Collection<MilvusSourceSplit> splits) {
         int readerCount = context.currentParallelism();
-        for (MilvusSourceSplit split : splits) {
-            int ownerReader = getSplitOwner(split.splitId(), readerCount);
+
+        List<MilvusSourceSplit> sortedSplits =
+                splits.stream()
+                        
.sorted(Comparator.comparing(MilvusSourceSplit::getSplitId))
+                        .collect(Collectors.toList());
+
+        for (MilvusSourceSplit split : sortedSplits) {
+            int ownerReader = getSplitOwner(assignCount.getAndIncrement(), 
readerCount);
             log.info("Assigning {} to {} reader.", split, ownerReader);
 
             pendingSplits.computeIfAbsent(ownerReader, r -> new 
ArrayList<>()).add(split);
         }
     }
 
-    private static int getSplitOwner(String tp, int numReaders) {
-        return (tp.hashCode() & Integer.MAX_VALUE) % numReaders;
+    private static int getSplitOwner(int assignCount, int numReaders) {
+        return assignCount % numReaders;
     }
 
     private void assignSplit(Collection<Integer> readers) {
@@ -210,7 +222,7 @@ public class MilvusSourceSplitEnumerator
                         splits);
             }
         }
-        log.info("Add back splits {} to JdbcSourceSplitEnumerator.", 
splits.size());
+        log.info("Add back splits {} to MilvusSourceSplitEnumerator.", 
splits.size());
     }
 
     private void addPendingSplit(Collection<MilvusSourceSplit> splits, int 
ownerReader) {
@@ -241,7 +253,9 @@ public class MilvusSourceSplitEnumerator
     public MilvusSourceState snapshotState(long checkpointId) throws Exception 
{
         synchronized (stateLock) {
             return new MilvusSourceState(
-                    new ArrayList(pendingTables), new 
HashMap<>(pendingSplits));
+                    new ArrayList<>(pendingTables),
+                    new HashMap<>(pendingSplits),
+                    assignCount.get());
         }
     }
 
diff --git 
a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceState.java
 
b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceState.java
index 8af6bc41d1..c85be6e2c0 100644
--- 
a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceState.java
+++ 
b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceState.java
@@ -32,4 +32,5 @@ public class MilvusSourceState implements Serializable {
     private static final long serialVersionUID = 1718378968826165653L;
     private List<TablePath> pendingTables;
     private Map<Integer, List<MilvusSourceSplit>> pendingSplits;
+    private int assignCount;
 }
diff --git 
a/seatunnel-connectors-v2/connector-milvus/src/test/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumeratorTest.java
 
b/seatunnel-connectors-v2/connector-milvus/src/test/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumeratorTest.java
new file mode 100644
index 0000000000..96a625f7fe
--- /dev/null
+++ 
b/seatunnel-connectors-v2/connector-milvus/src/test/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumeratorTest.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.connectors.seatunnel.milvus.source;
+
+import org.apache.seatunnel.api.common.metrics.MetricsContext;
+import org.apache.seatunnel.api.event.EventListener;
+import org.apache.seatunnel.api.source.SourceEvent;
+import org.apache.seatunnel.api.source.SourceSplitEnumerator;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.catalog.TablePath;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import io.milvus.client.MilvusClient;
+import io.milvus.grpc.CollectionSchema;
+import io.milvus.grpc.DataType;
+import io.milvus.grpc.DescribeCollectionResponse;
+import io.milvus.grpc.FieldSchema;
+import io.milvus.param.R;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class MilvusSourceSplitEnumeratorTest {
+
+    @Test
+    public void shouldBalanceSplitsEvenlyAcrossReaders() throws Exception {
+        TestingContext context = new TestingContext(4);
+
+        MilvusSourceSplitEnumerator enumerator =
+                new MilvusSourceSplitEnumerator(context, null, 
Collections.emptyMap(), null);
+
+        Method addPendingSplit =
+                MilvusSourceSplitEnumerator.class.getDeclaredMethod(
+                        "addPendingSplit", java.util.Collection.class);
+        addPendingSplit.setAccessible(true);
+        addPendingSplit.invoke(enumerator, buildSplits(10));
+
+        MilvusSourceState state = enumerator.snapshotState(1L);
+        Map<Integer, List<MilvusSourceSplit>> pendingSplits = 
state.getPendingSplits();
+
+        Assertions.assertEquals(4, pendingSplits.size());
+        Assertions.assertEquals(3, pendingSplits.get(0).size());
+        Assertions.assertEquals(3, pendingSplits.get(1).size());
+        Assertions.assertEquals(2, pendingSplits.get(2).size());
+        Assertions.assertEquals(2, pendingSplits.get(3).size());
+    }
+
+    @Test
+    public void shouldBalanceSingleSplitCollectionsAcrossReadersInRun() throws 
Exception {
+        TestingContext context = new TestingContext(3);
+        Map<TablePath, CatalogTable> tables = new LinkedHashMap<>();
+        tables.put(
+                TablePath.of("db", null, "collection_0"),
+                createCatalogTable(TablePath.of("db", null, "collection_0")));
+        tables.put(
+                TablePath.of("db", null, "collection_1"),
+                createCatalogTable(TablePath.of("db", null, "collection_1")));
+        tables.put(
+                TablePath.of("db", null, "collection_2"),
+                createCatalogTable(TablePath.of("db", null, "collection_2")));
+
+        MilvusSourceSplitEnumerator enumerator =
+                new MilvusSourceSplitEnumerator(context, null, tables, null);
+        setClient(enumerator, mockSingleSplitMilvusClient());
+
+        enumerator.run();
+
+        Assertions.assertEquals(1, context.getAssignmentSize(0));
+        Assertions.assertEquals(1, context.getAssignmentSize(1));
+        Assertions.assertEquals(1, context.getAssignmentSize(2));
+    }
+
+    @Test
+    public void shouldContinueRoundRobinAfterRestore() throws Exception {
+        TestingContext context = new TestingContext(3);
+        TablePath remainingTable = TablePath.of("db", null, 
"collection_after_restore");
+        Map<TablePath, CatalogTable> tables = new LinkedHashMap<>();
+        tables.put(remainingTable, createCatalogTable(remainingTable));
+
+        MilvusSourceState restoredState =
+                new MilvusSourceState(
+                        Collections.singletonList(remainingTable), new 
HashMap<>(), 1);
+        MilvusSourceSplitEnumerator enumerator =
+                new MilvusSourceSplitEnumerator(context, null, tables, 
restoredState);
+        setClient(enumerator, mockSingleSplitMilvusClient());
+
+        enumerator.run();
+
+        Assertions.assertEquals(0, context.getAssignmentSize(0));
+        Assertions.assertEquals(1, context.getAssignmentSize(1));
+        Assertions.assertEquals(0, context.getAssignmentSize(2));
+    }
+
+    private List<MilvusSourceSplit> buildSplits(int size) {
+        List<MilvusSourceSplit> splits = new ArrayList<>();
+        for (int i = 0; i < size; i++) {
+            splits.add(MilvusSourceSplit.builder().splitId("split-" + 
i).build());
+        }
+        return splits;
+    }
+
+    private CatalogTable createCatalogTable(TablePath tablePath) {
+        CatalogTable catalogTable = mock(CatalogTable.class);
+        when(catalogTable.getTablePath()).thenReturn(tablePath);
+        return catalogTable;
+    }
+
+    private MilvusClient mockSingleSplitMilvusClient() {
+        MilvusClient client = mock(MilvusClient.class);
+        FieldSchema partitionKeyField =
+                FieldSchema.newBuilder()
+                        .setName("partition_key")
+                        .setDataType(DataType.VarChar)
+                        .setIsPartitionKey(true)
+                        .build();
+        CollectionSchema schema =
+                
CollectionSchema.newBuilder().addFields(partitionKeyField).build();
+        DescribeCollectionResponse response =
+                
DescribeCollectionResponse.newBuilder().setSchema(schema).build();
+
+        @SuppressWarnings("unchecked")
+        R<DescribeCollectionResponse> describeResponse = mock(R.class);
+        when(describeResponse.getData()).thenReturn(response);
+        when(client.describeCollection(any())).thenReturn(describeResponse);
+        return client;
+    }
+
+    private void setClient(MilvusSourceSplitEnumerator enumerator, 
MilvusClient client)
+            throws Exception {
+        Field clientField = 
MilvusSourceSplitEnumerator.class.getDeclaredField("client");
+        clientField.setAccessible(true);
+        clientField.set(enumerator, client);
+    }
+
+    private static class TestingContext
+            implements SourceSplitEnumerator.Context<MilvusSourceSplit> {
+        private final int parallelism;
+        private final Map<Integer, List<MilvusSourceSplit>> assignments = new 
HashMap<>();
+        private final Set<Integer> readers;
+
+        private TestingContext(int parallelism) {
+            this.parallelism = parallelism;
+            this.readers = new LinkedHashSet<>();
+            for (int i = 0; i < parallelism; i++) {
+                readers.add(i);
+            }
+        }
+
+        @Override
+        public int currentParallelism() {
+            return parallelism;
+        }
+
+        @Override
+        public Set<Integer> registeredReaders() {
+            return readers;
+        }
+
+        @Override
+        public void assignSplit(int subtaskId, List<MilvusSourceSplit> splits) 
{
+            assignments.computeIfAbsent(subtaskId, ignored -> new 
ArrayList<>()).addAll(splits);
+        }
+
+        @Override
+        public void signalNoMoreSplits(int subtask) {}
+
+        @Override
+        public void sendEventToSourceReader(int subtaskId, SourceEvent event) 
{}
+
+        @Override
+        public MetricsContext getMetricsContext() {
+            return null;
+        }
+
+        @Override
+        public EventListener getEventListener() {
+            return null;
+        }
+
+        private int getAssignmentSize(int subtaskId) {
+            return assignments.getOrDefault(subtaskId, 
Collections.emptyList()).size();
+        }
+    }
+}

Reply via email to