This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new d13301ccd0 [spark] support distributed execution of vector search on
spark (#8108)
d13301ccd0 is described below
commit d13301ccd0ad841e1fb4315e53a20b81f4ac2cec
Author: Stefanietry <[email protected]>
AuthorDate: Fri Jun 5 15:52:09 2026 +0800
[spark] support distributed execution of vector search on spark (#8108)
Purpose: Currently, the vector search operation is executed on a single
node, inside the Spark driver, which may lead to performance bottlenecks
when dealing with large amounts of data. This change implements a
distributed execution capability for vector search on Spark.
---
docs/generated/core_configuration.html | 6 +
.../main/java/org/apache/paimon/CoreOptions.java | 10 ++
.../globalindex/GlobalIndexResultSerializer.java | 17 ++
.../apache/paimon/utils/RoaringNavigableMap64.java | 5 +-
.../apache/paimon/table/source/VectorReadImpl.java | 17 +-
.../table/source/VectorSearchBuilderImpl.java | 12 +-
.../paimon/spark/read/SparkEngineContext.java | 63 +++++++
.../paimon/spark/read/SparkVectorReadImpl.java | 182 +++++++++++++++++++++
.../spark/read/SparkVectorSearchBuilderImpl.java | 41 +++++
.../org/apache/paimon/spark/PaimonBaseScan.scala | 12 +-
.../apache/paimon/spark/SparkMultimodalITCase.java | 34 +++-
11 files changed, 375 insertions(+), 24 deletions(-)
diff --git a/docs/generated/core_configuration.html
b/docs/generated/core_configuration.html
index 3094d2e9d0..7908ed1c61 100644
--- a/docs/generated/core_configuration.html
+++ b/docs/generated/core_configuration.html
@@ -1650,6 +1650,12 @@ If the data size allocated for the sorting task is
uneven,which may lead to perf
<td>String</td>
<td>Specifies column names that should be stored as vector type.
This is used when you want to treat a ARRAY column as a VECTOR.</td>
</tr>
+ <tr>
+ <td><h5>vector-search.distribute.enabled</h5></td>
+ <td style="word-wrap: break-word;">false</td>
+ <td>Boolean</td>
+ <td>Whether to process distributed vector search.</td>
+ </tr>
<tr>
<td><h5>vector.file.format</h5></td>
<td style="word-wrap: break-word;">(none)</td>
diff --git a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java
b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java
index 80f0a8cec0..118b029c03 100644
--- a/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java
+++ b/paimon-api/src/main/java/org/apache/paimon/CoreOptions.java
@@ -2597,6 +2597,12 @@ public class CoreOptions implements Serializable {
+ " Default is the same as
TARGET_FILE_SIZE.")
.build());
+ public static final ConfigOption<Boolean> VECTOR_SEARCH_DISTRIBUTE_ENABLED
=
+ key("vector-search.distribute.enabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription("Whether to process distributed vector
search.");
+
@Immutable
public static final ConfigOption<Boolean> PK_CLUSTERING_OVERRIDE =
key("pk-clustering-override")
@@ -4077,6 +4083,10 @@ public class CoreOptions implements Serializable {
.orElse(targetFileSize(false));
}
+ public boolean vectorSearchDistributeEnabled() {
+ return options.get(VECTOR_SEARCH_DISTRIBUTE_ENABLED);
+ }
+
/** Specifies the merge engine for table with primary key. */
public enum MergeEngine implements DescribedEnum {
DEDUPLICATE("deduplicate", "De-duplicate and keep the last row."),
diff --git
a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java
b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java
index 66a43b082d..5559c59c4f 100644
---
a/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java
+++
b/paimon-common/src/main/java/org/apache/paimon/globalindex/GlobalIndexResultSerializer.java
@@ -23,6 +23,7 @@ import org.apache.paimon.io.DataInputDeserializer;
import org.apache.paimon.io.DataInputView;
import org.apache.paimon.io.DataOutputSerializer;
import org.apache.paimon.io.DataOutputView;
+import org.apache.paimon.utils.Preconditions;
import org.apache.paimon.utils.RoaringNavigableMap64;
import java.io.IOException;
@@ -116,4 +117,20 @@ public class GlobalIndexResultSerializer implements
Serializer<GlobalIndexResult
return ScoredGlobalIndexResult.create(roaringNavigableMap64,
scoreMap::get);
}
+
+ public byte[] serialize(GlobalIndexResult globalIndexResult) throws
IOException {
+ DataOutputSerializer dataOutputSerializer = new
DataOutputSerializer(1024);
+ serialize(globalIndexResult, dataOutputSerializer);
+ return dataOutputSerializer.getCopyOfBuffer();
+ }
+
+ public ScoredGlobalIndexResult deserialize(byte[] data) throws IOException
{
+ DataInputDeserializer dataInputDeserializer = new
DataInputDeserializer(data);
+ GlobalIndexResult globalIndexResult =
deserialize(dataInputDeserializer);
+ Preconditions.checkArgument(
+ globalIndexResult instanceof ScoredGlobalIndexResult,
+ "Expected ScoredGlobalIndexResult, but got %s",
+ globalIndexResult == null ? "null" :
globalIndexResult.getClass().getName());
+ return (ScoredGlobalIndexResult) globalIndexResult;
+ }
}
diff --git
a/paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java
b/paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java
index bec44f3fb0..c70623817c 100644
---
a/paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java
+++
b/paimon-common/src/main/java/org/apache/paimon/utils/RoaringNavigableMap64.java
@@ -25,12 +25,15 @@ import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
+import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
/** A compressed bitmap for 64-bit integer aggregated by tree. */
-public class RoaringNavigableMap64 implements Iterable<Long> {
+public class RoaringNavigableMap64 implements Iterable<Long>, Serializable {
+
+ private static final long serialVersionUID = 1L;
private final Roaring64NavigableMap roaring64NavigableMap;
diff --git
a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java
b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java
index a3402c3f1d..2eae2d4877 100644
---
a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java
+++
b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorReadImpl.java
@@ -42,6 +42,7 @@ import org.apache.paimon.utils.RoaringNavigableMap64;
import javax.annotation.Nullable;
import java.io.IOException;
+import java.io.Serializable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
@@ -55,13 +56,15 @@ import static
org.apache.paimon.CoreOptions.GLOBAL_INDEX_THREAD_NUM;
import static org.apache.paimon.utils.Preconditions.checkNotNull;
/** Implementation for {@link VectorRead}. */
-public class VectorReadImpl implements VectorRead {
+public class VectorReadImpl implements VectorRead, Serializable {
- private final FileStoreTable table;
+ private static final long serialVersionUID = 1L;
+
+ protected final FileStoreTable table;
private final Predicate filter;
- private final int limit;
- private final DataField vectorColumn;
- private final float[] vector;
+ protected final int limit;
+ protected final DataField vectorColumn;
+ protected final float[] vector;
public VectorReadImpl(
FileStoreTable table,
@@ -120,7 +123,7 @@ public class VectorReadImpl implements VectorRead {
return result.topK(limit);
}
- private Optional<RoaringNavigableMap64> preFilter(List<VectorSearchSplit>
splits) {
+ protected Optional<RoaringNavigableMap64>
preFilter(List<VectorSearchSplit> splits) {
Set<IndexFileMeta> scalarIndexFiles =
new TreeSet<>(Comparator.comparing(IndexFileMeta::fileName));
for (VectorSearchSplit split : splits) {
@@ -139,7 +142,7 @@ public class VectorReadImpl implements VectorRead {
}
}
- private CompletableFuture<Optional<ScoredGlobalIndexResult>> eval(
+ protected CompletableFuture<Optional<ScoredGlobalIndexResult>> eval(
GlobalIndexer globalIndexer,
IndexPathFactory indexPathFactory,
long rowRangeStart,
diff --git
a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java
b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java
index beb7844e13..a0d11ff21f 100644
---
a/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java
+++
b/paimon-core/src/main/java/org/apache/paimon/table/source/VectorSearchBuilderImpl.java
@@ -32,13 +32,13 @@ public class VectorSearchBuilderImpl implements
VectorSearchBuilder {
private static final long serialVersionUID = 1L;
- private final FileStoreTable table;
+ protected final FileStoreTable table;
- private PartitionPredicate partitionFilter;
- private Predicate filter;
- private int limit;
- private DataField vectorColumn;
- private float[] vector;
+ protected PartitionPredicate partitionFilter;
+ protected Predicate filter;
+ protected int limit;
+ protected DataField vectorColumn;
+ protected float[] vector;
public VectorSearchBuilderImpl(InnerTable table) {
this.table = (FileStoreTable) table;
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkEngineContext.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkEngineContext.java
new file mode 100644
index 0000000000..f5a1abd668
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkEngineContext.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.read;
+
+import org.apache.paimon.utils.SerializableFunction;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.sql.SparkSession;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Stream;
+
+/**
+ * Tiny wrapper around the active {@link SparkSession} that exposes RDD style
{@code map} / {@code
+ * flatMap} primitives over a Java {@link List}. Used by Paimon-on-Spark to
dispatch
+ * embarrassingly-parallel work (e.g. per-split vector search) to the cluster
without forcing the
+ * caller to depend on Spark types directly.
+ */
+public class SparkEngineContext {
+
+ private final JavaSparkContext jsc;
+
+ public SparkEngineContext() {
+ this.jsc =
JavaSparkContext.fromSparkContext(SparkSession.active().sparkContext());
+ }
+
+ public <T> Broadcast<T> broadcast(T value) {
+ return jsc.broadcast(value);
+ }
+
+ public <I, O> List<O> map(List<I> data, SerializableFunction<I, O> func,
int parallelism) {
+ if (data.isEmpty()) {
+ return Collections.emptyList();
+ }
+ return jsc.parallelize(data, parallelism).map(func::apply).collect();
+ }
+
+ public <I, O> List<O> flatMap(
+ List<I> data, SerializableFunction<I, Stream<O>> func, int
parallelism) {
+ if (data.isEmpty()) {
+ return Collections.emptyList();
+ }
+ return jsc.parallelize(data, parallelism).flatMap(x ->
func.apply(x).iterator()).collect();
+ }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorReadImpl.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorReadImpl.java
new file mode 100644
index 0000000000..df60afb489
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorReadImpl.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.read;
+
+import org.apache.paimon.globalindex.GlobalIndexReadThreadPool;
+import org.apache.paimon.globalindex.GlobalIndexResult;
+import org.apache.paimon.globalindex.GlobalIndexResultSerializer;
+import org.apache.paimon.globalindex.GlobalIndexer;
+import org.apache.paimon.globalindex.GlobalIndexerFactoryUtils;
+import org.apache.paimon.globalindex.ScoredGlobalIndexResult;
+import org.apache.paimon.index.IndexPathFactory;
+import org.apache.paimon.predicate.Predicate;
+import org.apache.paimon.table.FileStoreTable;
+import org.apache.paimon.table.source.VectorReadImpl;
+import org.apache.paimon.table.source.VectorSearchSplit;
+import org.apache.paimon.types.DataField;
+import org.apache.paimon.utils.InstantiationUtil;
+import org.apache.paimon.utils.RoaringNavigableMap64;
+import org.apache.paimon.utils.SerializableFunction;
+
+import org.apache.spark.broadcast.Broadcast;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+
+import static org.apache.paimon.CoreOptions.GLOBAL_INDEX_THREAD_NUM;
+
+/**
+ * Spark-aware {@link VectorReadImpl} that distributes grouped vector index
evaluation across the
+ * Spark cluster instead of evaluating them with the local thread pool.
+ */
+public class SparkVectorReadImpl extends VectorReadImpl {
+
+ private static final long serialVersionUID = 1L;
+
+ public SparkVectorReadImpl(
+ FileStoreTable table,
+ Predicate filter,
+ int limit,
+ DataField vectorColumn,
+ float[] vector) {
+ super(table, filter, limit, vectorColumn, vector);
+ }
+
+ @Override
+ public GlobalIndexResult read(List<VectorSearchSplit> splits) {
+ if (splits.isEmpty()) {
+ return GlobalIndexResult.createEmpty();
+ }
+
+ int parallelism =
+ Math.max(1,
table.coreOptions().toConfiguration().get(GLOBAL_INDEX_THREAD_NUM));
+ if (splits.size() < parallelism * 2) {
+ return super.read(splits);
+ }
+
+ RoaringNavigableMap64 preFilter = preFilter(splits).orElse(null);
+ String indexType = splits.get(0).vectorIndexFiles().get(0).indexType();
+ List<byte[]> splitBytes = new ArrayList<>(splits.size());
+ for (VectorSearchSplit split : splits) {
+ try {
+ splitBytes.add(InstantiationUtil.serializeObject(split));
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to serialize
VectorSearchSplit", e);
+ }
+ }
+ List<List<byte[]>> splitGroups = splitGroups(splitBytes, parallelism);
+ SparkEngineContext engineContext = new SparkEngineContext();
+ Broadcast<RoaringNavigableMap64> preFilterBroadcast =
+ preFilter == null ? null : engineContext.broadcast(preFilter);
+
+ SerializableFunction<List<byte[]>, byte[]> task =
+ group -> {
+ GlobalIndexer globalIndexer =
+ GlobalIndexerFactoryUtils.load(indexType)
+ .create(vectorColumn,
table.coreOptions().toConfiguration());
+ IndexPathFactory indexPathFactory =
+
table.store().pathFactory().globalIndexFileFactory();
+
+ RoaringNavigableMap64 includeRowIds =
+ preFilterBroadcast == null ? null :
preFilterBroadcast.value();
+ ExecutorService executor =
+ GlobalIndexReadThreadPool.getExecutorService(
+ Math.min(parallelism, group.size()));
+ List<CompletableFuture<Optional<ScoredGlobalIndexResult>>>
futures =
+ new ArrayList<>(group.size());
+ for (byte[] bytes : group) {
+ VectorSearchSplit split = deserializeSplit(bytes);
+ futures.add(
+ eval(
+ globalIndexer,
+ indexPathFactory,
+ split.rowRangeStart(),
+ split.rowRangeEnd(),
+ split.vectorIndexFiles(),
+ includeRowIds,
+ executor));
+ }
+ CompletableFuture.allOf(futures.toArray(new
CompletableFuture[0])).join();
+ ScoredGlobalIndexResult result =
ScoredGlobalIndexResult.createEmpty();
+ for (CompletableFuture<Optional<ScoredGlobalIndexResult>>
f : futures) {
+ Optional<ScoredGlobalIndexResult> next = f.join();
+ if (next.isPresent()) {
+ result = result.or(next.get());
+ }
+ }
+ result = result.topK(limit);
+ if (result.results().isEmpty()) {
+ return null;
+ }
+ try {
+ return new
GlobalIndexResultSerializer().serialize(result);
+ } catch (IOException e) {
+ throw new RuntimeException(
+ "Failed to serialize ScoredGlobalIndexResult",
e);
+ }
+ };
+
+ List<byte[]> remoteResults;
+ try {
+ remoteResults = engineContext.map(splitGroups, task,
splitGroups.size());
+ } finally {
+ if (preFilterBroadcast != null) {
+ preFilterBroadcast.unpersist(false);
+ }
+ }
+
+ ScoredGlobalIndexResult result = ScoredGlobalIndexResult.createEmpty();
+ GlobalIndexResultSerializer serializer = new
GlobalIndexResultSerializer();
+ for (byte[] bytes : remoteResults) {
+ if (bytes != null) {
+ try {
+ result = result.or(serializer.deserialize(bytes));
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to deserialize
ScoredGlobalIndexResult", e);
+ }
+ }
+ }
+ return result.topK(limit);
+ }
+
+ private VectorSearchSplit deserializeSplit(byte[] bytes) {
+ try {
+ return InstantiationUtil.deserializeObject(
+ bytes, Thread.currentThread().getContextClassLoader());
+ } catch (IOException | ClassNotFoundException e) {
+ throw new RuntimeException("Failed to deserialize
VectorSearchSplit", e);
+ }
+ }
+
+ private List<List<byte[]>> splitGroups(List<byte[]> splitBytes, int
parallelism) {
+ List<List<byte[]>> groups = new ArrayList<>(parallelism);
+ int groupSize = (splitBytes.size() + parallelism - 1) / parallelism;
+ for (int start = 0; start < splitBytes.size(); start += groupSize) {
+ groups.add(
+ new ArrayList<>(
+ splitBytes.subList(
+ start, Math.min(start + groupSize,
splitBytes.size()))));
+ }
+ return groups;
+ }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderImpl.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderImpl.java
new file mode 100644
index 0000000000..0e4e347f64
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/read/SparkVectorSearchBuilderImpl.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.read;
+
+import org.apache.paimon.table.InnerTable;
+import org.apache.paimon.table.source.VectorRead;
+import org.apache.paimon.table.source.VectorSearchBuilderImpl;
+
+/**
+ * Spark-aware {@link VectorSearchBuilderImpl} which produces a {@link
SparkVectorReadImpl} so the
+ * per-split vector index evaluation is dispatched through Spark instead of
the local thread pool.
+ */
+public class SparkVectorSearchBuilderImpl extends VectorSearchBuilderImpl {
+
+ private static final long serialVersionUID = 1L;
+
+ public SparkVectorSearchBuilderImpl(InnerTable table) {
+ super(table);
+ }
+
+ @Override
+ public VectorRead newVectorRead() {
+ return new SparkVectorReadImpl(table, filter, limit, vectorColumn,
vector);
+ }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index ff1e0bde37..abff5ba3dc 100644
---
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -18,11 +18,12 @@
package org.apache.paimon.spark
+import org.apache.paimon.CoreOptions
import org.apache.paimon.globalindex.GlobalIndexResult
import org.apache.paimon.partition.PartitionPredicate
import org.apache.paimon.predicate.PredicateBuilder
import org.apache.paimon.spark.metric.SparkMetricRegistry
-import org.apache.paimon.spark.read.{BaseScan, BatchReadTagCleanupListener,
PaimonSupportsRuntimeFiltering}
+import org.apache.paimon.spark.read.{BaseScan, BatchReadTagCleanupListener,
PaimonSupportsRuntimeFiltering, SparkVectorSearchBuilderImpl}
import org.apache.paimon.spark.sources.PaimonMicroBatchStream
import org.apache.paimon.spark.util.OptionUtils
import org.apache.paimon.table.{DataTable, FileStoreTable, InnerTable}
@@ -78,8 +79,13 @@ abstract class PaimonBaseScan(table: InnerTable)
private def evalVectorSearch(): GlobalIndexResult = {
val vectorSearch = pushedVectorSearch.get
- val vectorBuilder = table
- .newVectorSearchBuilder()
+ val vectorSearchBuilder =
+ if (CoreOptions.fromMap(table.options).vectorSearchDistributeEnabled()) {
+ new SparkVectorSearchBuilderImpl(table)
+ } else {
+ table.newVectorSearchBuilder()
+ }
+ val vectorBuilder = vectorSearchBuilder
.withVector(vectorSearch.vector())
.withVectorColumn(vectorSearch.fieldName())
.withLimit(vectorSearch.limit())
diff --git
a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
index 4100f54f61..435aa6ce85 100644
---
a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
+++
b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
@@ -20,6 +20,7 @@ package org.apache.paimon.spark;
import org.apache.paimon.fs.Path;
import org.apache.paimon.hive.TestHiveMetastore;
+import org.apache.paimon.utils.Pair;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
@@ -31,6 +32,7 @@ import org.junit.jupiter.api.io.TempDir;
import java.io.IOException;
import java.util.List;
+import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
@@ -95,13 +97,17 @@ public class SparkMultimodalITCase {
spark = builder.getOrCreate();
spark.sql(
- "insert overwrite table my_db1.vector_test\n"
- + "VALUES (1, '1', array(cast(1.0 as float), cast(2.0
as float), cast(3.0 as float), cast(4.0 as float)), '20260420'),\n"
+ "insert overwrite table my_db1.vector_test VALUES \n"
+ + "(1, '1', array(cast(1.0 as float), cast(2.0 as
float), cast(3.0 as float), cast(4.0 as float)), '20260420'),\n"
+ "(2, '2', array(cast(2.0 as float), cast(3.0 as
float), cast(4.0 as float), cast(5.0 as float)), '20260420'),\n"
- + "(3, '3', array(cast(3.0 as float), cast(4.0 as
float), cast(5.0 as float), cast(6.0 as float)), '20260420'),\n"
+ + "(3, '3', array(cast(3.0 as float), cast(4.0 as
float), cast(5.0 as float), cast(6.0 as float)), '20260420');");
+ spark.sql(
+ "insert into table my_db1.vector_test VALUES\n"
+ "(4, '4', array(cast(4.0 as float), cast(5.0 as
float), cast(6.0 as float), cast(7.0 as float)), '20260420'),\n"
+ "(5, '5', array(cast(5.0 as float), cast(6.0 as
float), cast(7.0 as float), cast(8.0 as float)), '20260420'),\n"
- + "(6, '6', array(cast(6.0 as float), cast(7.0 as
float), cast(8.0 as float), cast(9.0 as float)), '20260420'),\n"
+ + "(6, '6', array(cast(6.0 as float), cast(7.0 as
float), cast(8.0 as float), cast(9.0 as float)), '20260420');");
+ spark.sql(
+ "insert into table my_db1.vector_test VALUES\n"
+ "(7, '7', array(cast(7.0 as float), cast(8.0 as
float), cast(9.0 as float), cast(10.0 as float)), '20260420'),\n"
+ "(8, '8', array(cast(8.0 as float), cast(9.0 as
float), cast(10.0 as float), cast(11.0 as float)), '20260420');");
spark.close();
@@ -128,12 +134,26 @@ public class SparkMultimodalITCase {
"select gid, sid, embs from
vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f), 5)
where date = '20260420'")
.collectAsList();
assertThat(rows).hasSize(5);
- Dataset<Row> df =
- spark.sql(
- "select gid, sid, embs, __paimon_vector_search_score
from vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f),
5) where date = '20260420'");
+ String vectorSearchSql =
+ "select gid, sid, embs, __paimon_vector_search_score "
+ + "from vector_search('my_db1.vector_test', 'embs',
array(1.0f, 2.0f, 3.0f, 4.0f), 5) "
+ + "where date = '20260420'";
+ Dataset<Row> df = spark.sql(vectorSearchSql);
assertThat(df.columns()).hasSize(4);
rows = df.collectAsList();
assertThat(rows).hasSize(5);
+ spark.sql("SET
`spark.paimon.vector-search.distribute.enabled`=`true`");
+ spark.sql("SET `spark.paimon.global-index.thread-num`=`1`");
+ List<Row> compareRows = spark.sql(vectorSearchSql).collectAsList();
+ assertThat(compareRows).hasSize(5);
+ assertThat(
+ compareRows.stream()
+ .map(row -> Pair.of(row.getLong(0),
row.getString(1)))
+ .collect(Collectors.toList()))
+ .containsExactlyInAnyOrderElementsOf(
+ rows.stream()
+ .map(row -> Pair.of(row.getLong(0),
row.getString(1)))
+ .collect(Collectors.toList()));
spark.close();
spark = builder.getOrCreate();