This is an automated email from the ASF dual-hosted git repository.
chl-wxp pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new d991e6bbb2 [Improve][connector-milvus] Improved milvus source
enumerator splits allocation algorithm for subtasks (#10868)
d991e6bbb2 is described below
commit d991e6bbb21518b50873af436939172cc52b3bc6
Author: JeremyXin <[email protected]>
AuthorDate: Thu May 21 11:34:31 2026 +0800
[Improve][connector-milvus] Improved milvus source enumerator splits
allocation algorithm for subtasks (#10868)
good job
---
.../milvus/source/MilvusSourceSplitEnumerator.java | 28 ++-
.../seatunnel/milvus/source/MilvusSourceState.java | 1 +
.../source/MilvusSourceSplitEnumeratorTest.java | 213 +++++++++++++++++++++
3 files changed, 235 insertions(+), 7 deletions(-)
diff --git a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumerator.java b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumerator.java
index b1c242d682..0bed7df9d0 100644
--- a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumerator.java
+++ b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumerator.java
@@ -41,11 +41,14 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
@Slf4j
public class MilvusSourceSplitEnumerator
@@ -57,6 +60,8 @@ public class MilvusSourceSplitEnumerator
private final Map<Integer, List<MilvusSourceSplit>> pendingSplits;
private final Object stateLock = new Object();
private MilvusClient client = null;
+ // Keeps round-robin assignment global across collections and restore.
+ private final AtomicInteger assignCount = new AtomicInteger(0);
private final ReadonlyConfig config;
@@ -74,6 +79,7 @@ public class MilvusSourceSplitEnumerator
} else {
this.pendingTables = new
ConcurrentLinkedQueue<>(sourceState.getPendingTables());
this.pendingSplits = new HashMap<>(sourceState.getPendingSplits());
+ this.assignCount.set(sourceState.getAssignCount());
}
}
@@ -160,22 +166,28 @@ public class MilvusSourceSplitEnumerator
return milvusSourceSplits;
}
- protected String createSplitId(TablePath tablePath, String index) {
+ private String createSplitId(TablePath tablePath, String index) {
return String.format("%s-%s", tablePath, index);
}
private void addPendingSplit(Collection<MilvusSourceSplit> splits) {
int readerCount = context.currentParallelism();
- for (MilvusSourceSplit split : splits) {
- int ownerReader = getSplitOwner(split.splitId(), readerCount);
+
+ List<MilvusSourceSplit> sortedSplits =
+ splits.stream()
+
.sorted(Comparator.comparing(MilvusSourceSplit::getSplitId))
+ .collect(Collectors.toList());
+
+ for (MilvusSourceSplit split : sortedSplits) {
+ int ownerReader = getSplitOwner(assignCount.getAndIncrement(),
readerCount);
log.info("Assigning {} to {} reader.", split, ownerReader);
pendingSplits.computeIfAbsent(ownerReader, r -> new
ArrayList<>()).add(split);
}
}
- private static int getSplitOwner(String tp, int numReaders) {
- return (tp.hashCode() & Integer.MAX_VALUE) % numReaders;
+ private static int getSplitOwner(int assignCount, int numReaders) {
+ return assignCount % numReaders;
}
private void assignSplit(Collection<Integer> readers) {
@@ -210,7 +222,7 @@ public class MilvusSourceSplitEnumerator
splits);
}
}
- log.info("Add back splits {} to JdbcSourceSplitEnumerator.",
splits.size());
+ log.info("Add back splits {} to MilvusSourceSplitEnumerator.",
splits.size());
}
private void addPendingSplit(Collection<MilvusSourceSplit> splits, int
ownerReader) {
@@ -241,7 +253,9 @@ public class MilvusSourceSplitEnumerator
public MilvusSourceState snapshotState(long checkpointId) throws Exception
{
synchronized (stateLock) {
return new MilvusSourceState(
- new ArrayList(pendingTables), new
HashMap<>(pendingSplits));
+ new ArrayList<>(pendingTables),
+ new HashMap<>(pendingSplits),
+ assignCount.get());
}
}
diff --git a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceState.java b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceState.java
index 8af6bc41d1..c85be6e2c0 100644
--- a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceState.java
+++ b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceState.java
@@ -32,4 +32,5 @@ public class MilvusSourceState implements Serializable {
private static final long serialVersionUID = 1718378968826165653L;
private List<TablePath> pendingTables;
private Map<Integer, List<MilvusSourceSplit>> pendingSplits;
+ private int assignCount;
}
diff --git a/seatunnel-connectors-v2/connector-milvus/src/test/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumeratorTest.java b/seatunnel-connectors-v2/connector-milvus/src/test/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumeratorTest.java
new file mode 100644
index 0000000000..96a625f7fe
--- /dev/null
+++ b/seatunnel-connectors-v2/connector-milvus/src/test/java/org/apache/seatunnel/connectors/seatunnel/milvus/source/MilvusSourceSplitEnumeratorTest.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.connectors.seatunnel.milvus.source;
+
+import org.apache.seatunnel.api.common.metrics.MetricsContext;
+import org.apache.seatunnel.api.event.EventListener;
+import org.apache.seatunnel.api.source.SourceEvent;
+import org.apache.seatunnel.api.source.SourceSplitEnumerator;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.catalog.TablePath;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import io.milvus.client.MilvusClient;
+import io.milvus.grpc.CollectionSchema;
+import io.milvus.grpc.DataType;
+import io.milvus.grpc.DescribeCollectionResponse;
+import io.milvus.grpc.FieldSchema;
+import io.milvus.param.R;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class MilvusSourceSplitEnumeratorTest {
+
+ @Test
+ public void shouldBalanceSplitsEvenlyAcrossReaders() throws Exception {
+ TestingContext context = new TestingContext(4);
+
+ MilvusSourceSplitEnumerator enumerator =
+ new MilvusSourceSplitEnumerator(context, null,
Collections.emptyMap(), null);
+
+ Method addPendingSplit =
+ MilvusSourceSplitEnumerator.class.getDeclaredMethod(
+ "addPendingSplit", java.util.Collection.class);
+ addPendingSplit.setAccessible(true);
+ addPendingSplit.invoke(enumerator, buildSplits(10));
+
+ MilvusSourceState state = enumerator.snapshotState(1L);
+ Map<Integer, List<MilvusSourceSplit>> pendingSplits =
state.getPendingSplits();
+
+ Assertions.assertEquals(4, pendingSplits.size());
+ Assertions.assertEquals(3, pendingSplits.get(0).size());
+ Assertions.assertEquals(3, pendingSplits.get(1).size());
+ Assertions.assertEquals(2, pendingSplits.get(2).size());
+ Assertions.assertEquals(2, pendingSplits.get(3).size());
+ }
+
+ @Test
+ public void shouldBalanceSingleSplitCollectionsAcrossReadersInRun() throws
Exception {
+ TestingContext context = new TestingContext(3);
+ Map<TablePath, CatalogTable> tables = new LinkedHashMap<>();
+ tables.put(
+ TablePath.of("db", null, "collection_0"),
+ createCatalogTable(TablePath.of("db", null, "collection_0")));
+ tables.put(
+ TablePath.of("db", null, "collection_1"),
+ createCatalogTable(TablePath.of("db", null, "collection_1")));
+ tables.put(
+ TablePath.of("db", null, "collection_2"),
+ createCatalogTable(TablePath.of("db", null, "collection_2")));
+
+ MilvusSourceSplitEnumerator enumerator =
+ new MilvusSourceSplitEnumerator(context, null, tables, null);
+ setClient(enumerator, mockSingleSplitMilvusClient());
+
+ enumerator.run();
+
+ Assertions.assertEquals(1, context.getAssignmentSize(0));
+ Assertions.assertEquals(1, context.getAssignmentSize(1));
+ Assertions.assertEquals(1, context.getAssignmentSize(2));
+ }
+
+ @Test
+ public void shouldContinueRoundRobinAfterRestore() throws Exception {
+ TestingContext context = new TestingContext(3);
+ TablePath remainingTable = TablePath.of("db", null,
"collection_after_restore");
+ Map<TablePath, CatalogTable> tables = new LinkedHashMap<>();
+ tables.put(remainingTable, createCatalogTable(remainingTable));
+
+ MilvusSourceState restoredState =
+ new MilvusSourceState(
+ Collections.singletonList(remainingTable), new
HashMap<>(), 1);
+ MilvusSourceSplitEnumerator enumerator =
+ new MilvusSourceSplitEnumerator(context, null, tables,
restoredState);
+ setClient(enumerator, mockSingleSplitMilvusClient());
+
+ enumerator.run();
+
+ Assertions.assertEquals(0, context.getAssignmentSize(0));
+ Assertions.assertEquals(1, context.getAssignmentSize(1));
+ Assertions.assertEquals(0, context.getAssignmentSize(2));
+ }
+
+ private List<MilvusSourceSplit> buildSplits(int size) {
+ List<MilvusSourceSplit> splits = new ArrayList<>();
+ for (int i = 0; i < size; i++) {
+ splits.add(MilvusSourceSplit.builder().splitId("split-" +
i).build());
+ }
+ return splits;
+ }
+
+ private CatalogTable createCatalogTable(TablePath tablePath) {
+ CatalogTable catalogTable = mock(CatalogTable.class);
+ when(catalogTable.getTablePath()).thenReturn(tablePath);
+ return catalogTable;
+ }
+
+ private MilvusClient mockSingleSplitMilvusClient() {
+ MilvusClient client = mock(MilvusClient.class);
+ FieldSchema partitionKeyField =
+ FieldSchema.newBuilder()
+ .setName("partition_key")
+ .setDataType(DataType.VarChar)
+ .setIsPartitionKey(true)
+ .build();
+ CollectionSchema schema =
+
CollectionSchema.newBuilder().addFields(partitionKeyField).build();
+ DescribeCollectionResponse response =
+
DescribeCollectionResponse.newBuilder().setSchema(schema).build();
+
+ @SuppressWarnings("unchecked")
+ R<DescribeCollectionResponse> describeResponse = mock(R.class);
+ when(describeResponse.getData()).thenReturn(response);
+ when(client.describeCollection(any())).thenReturn(describeResponse);
+ return client;
+ }
+
+ private void setClient(MilvusSourceSplitEnumerator enumerator,
MilvusClient client)
+ throws Exception {
+ Field clientField =
MilvusSourceSplitEnumerator.class.getDeclaredField("client");
+ clientField.setAccessible(true);
+ clientField.set(enumerator, client);
+ }
+
+ private static class TestingContext
+ implements SourceSplitEnumerator.Context<MilvusSourceSplit> {
+ private final int parallelism;
+ private final Map<Integer, List<MilvusSourceSplit>> assignments = new
HashMap<>();
+ private final Set<Integer> readers;
+
+ private TestingContext(int parallelism) {
+ this.parallelism = parallelism;
+ this.readers = new LinkedHashSet<>();
+ for (int i = 0; i < parallelism; i++) {
+ readers.add(i);
+ }
+ }
+
+ @Override
+ public int currentParallelism() {
+ return parallelism;
+ }
+
+ @Override
+ public Set<Integer> registeredReaders() {
+ return readers;
+ }
+
+ @Override
+ public void assignSplit(int subtaskId, List<MilvusSourceSplit> splits)
{
+ assignments.computeIfAbsent(subtaskId, ignored -> new
ArrayList<>()).addAll(splits);
+ }
+
+ @Override
+ public void signalNoMoreSplits(int subtask) {}
+
+ @Override
+ public void sendEventToSourceReader(int subtaskId, SourceEvent event)
{}
+
+ @Override
+ public MetricsContext getMetricsContext() {
+ return null;
+ }
+
+ @Override
+ public EventListener getEventListener() {
+ return null;
+ }
+
+ private int getAssignmentSize(int subtaskId) {
+ return assignments.getOrDefault(subtaskId,
Collections.emptyList()).size();
+ }
+ }
+}