This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 0609e8670a [spark] support return score on vector search (#8068)
0609e8670a is described below
commit 0609e8670a87fc6289e195617507c70da5173544
Author: Stefanietry <[email protected]>
AuthorDate: Tue Jun 2 21:35:47 2026 +0800
[spark] support return score on vector search (#8068)
---
.../paimon/spark/PaimonRecordReaderIterator.scala | 27 ++++++++++++++++++----
.../apache/paimon/spark/PaimonSparkTableBase.scala | 3 +++
.../paimon/spark/schema/PaimonMetadataColumn.scala | 9 ++++++--
.../apache/paimon/spark/SparkMultimodalITCase.java | 12 +++++++++-
4 files changed, 44 insertions(+), 7 deletions(-)
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala
index 444b6d6c64..0d67001f0e 100644
---
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonRecordReaderIterator.scala
@@ -20,11 +20,11 @@ package org.apache.paimon.spark
import org.apache.paimon.data.{BinaryString, GenericRow, InternalRow =>
PaimonInternalRow, JoinedRow}
import org.apache.paimon.fs.Path
-import org.apache.paimon.reader.{FileRecordIterator, RecordReader}
+import org.apache.paimon.reader.{FileRecordIterator, RecordReader,
ScoreRecordIterator}
import org.apache.paimon.spark.schema.PaimonMetadataColumn
-import
org.apache.paimon.spark.schema.PaimonMetadataColumn.{PARTITION_AND_BUCKET_META_COLUMNS,
PATH_AND_INDEX_META_COLUMNS}
+import
org.apache.paimon.spark.schema.PaimonMetadataColumn.{PARTITION_AND_BUCKET_META_COLUMNS,
PATH_AND_INDEX_META_COLUMNS, VECTOR_SEARCH_SCORE_COLUMN}
import org.apache.paimon.table.source.{DataSplit, Split}
-import org.apache.paimon.utils.CloseableIterator
+import org.apache.paimon.utils.{CloseableIterator, Preconditions}
import org.apache.spark.sql.PaimonUtils
@@ -48,6 +48,10 @@ case class PaimonRecordReaderIterator(
private val needMetadata = metadataColumns.nonEmpty
private val needPathAndIndexMetadata =
metadataColumns.exists(c => PATH_AND_INDEX_META_COLUMNS.contains(c.name))
+  private val needScoreMetadata =
+    metadataColumns.exists(_.name == VECTOR_SEARCH_SCORE_COLUMN)
+  // When the score column is requested, only updateScoreMetadata fills the metadata row
+  // (updateMetadataRow is skipped), so no other metadata column could be populated alongside it.
+  Preconditions.checkArgument(
+    !needScoreMetadata || metadataColumns.size == 1,
+    s"$VECTOR_SEARCH_SCORE_COLUMN cannot be combined with other metadata columns, but got: " +
+      metadataColumns.map(_.name).mkString(", "))
private val metadataRow: GenericRow =
GenericRow.of(Array.fill(metadataColumns.size)(null.asInstanceOf[AnyRef]):
_*)
@@ -122,7 +126,11 @@ case class PaimonRecordReaderIterator(
while (!stop) {
val dataRow = currentIterator.next()
if (dataRow != null) {
- if (needMetadata) {
+ if (needScoreMetadata) {
+ updateScoreMetadata(
+
currentIterator.asInstanceOf[ScoreRecordIterator[PaimonInternalRow]])
+ currentResult = joinedRow.replace(dataRow, metadataRow)
+ } else if (needMetadata) {
updateMetadataRow(currentIterator.asInstanceOf[FileRecordIterator[PaimonInternalRow]])
currentResult = joinedRow.replace(dataRow, metadataRow)
} else {
@@ -165,4 +173,15 @@ case class PaimonRecordReaderIterator(
}
}
}
+
+  /**
+   * Writes the vector-search score into the metadata row. The caller invokes this right after
+   * `next()`, so `returnedScore()` presumably refers to the row just read.
+   *
+   * @param scoreRecordIterator the iterator that produced the current row and reports its score
+   * @throws IllegalArgumentException if a metadata column other than the score column is requested
+   */
+  private def updateScoreMetadata(
+      scoreRecordIterator: ScoreRecordIterator[PaimonInternalRow]): Unit = {
+    val score = scoreRecordIterator.returnedScore()
+    metadataColumns.zipWithIndex.foreach {
+      case (metadataColumn, index) =>
+        metadataColumn.name match {
+          case VECTOR_SEARCH_SCORE_COLUMN =>
+            metadataRow.setField(index, score)
+          case other =>
+            throw new IllegalArgumentException(
+              s"$other metadata column is not supported for vector search.")
+        }
+    }
+  }
}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSparkTableBase.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSparkTableBase.scala
index 0fc4bd9eb5..94b9128444 100644
---
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSparkTableBase.scala
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonSparkTableBase.scala
@@ -118,6 +118,9 @@ abstract class PaimonSparkTableBase(val table: Table)
_metadataColumns.append(PaimonMetadataColumn.ROW_ID)
_metadataColumns.append(PaimonMetadataColumn.SEQUENCE_NUMBER)
}
+ if (table.isInstanceOf[VectorSearchTable]) {
+ _metadataColumns.append(PaimonMetadataColumn.VECTOR_SEARCH_SCORE)
+ }
_metadataColumns.appendAll(
Seq(
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala
index 4b8ede097c..34343f7380 100644
---
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/schema/PaimonMetadataColumn.scala
@@ -25,7 +25,7 @@ import org.apache.paimon.types.DataField
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.connector.catalog.MetadataColumn
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType,
StringType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, FloatType, IntegerType, LongType,
StringType, StructField, StructType}
case class PaimonMetadataColumn(id: Int, override val name: String, override
val dataType: DataType)
extends MetadataColumn {
@@ -51,6 +51,7 @@ object PaimonMetadataColumn {
val BUCKET_COLUMN = "__paimon_bucket"
val ROW_ID_COLUMN: String = SpecialFields.ROW_ID.name()
val SEQUENCE_NUMBER_COLUMN: String = SpecialFields.SEQUENCE_NUMBER.name()
+ val VECTOR_SEARCH_SCORE_COLUMN: String = "__paimon_vector_search_score"
val PATH_AND_INDEX_META_COLUMNS: Seq[String] = Seq(FILE_PATH_COLUMN,
ROW_INDEX_COLUMN)
val PARTITION_AND_BUCKET_META_COLUMNS: Seq[String] = Seq(PARTITION_COLUMN,
BUCKET_COLUMN)
@@ -62,7 +63,8 @@ object PaimonMetadataColumn {
PARTITION_COLUMN,
BUCKET_COLUMN,
ROW_ID_COLUMN,
- SEQUENCE_NUMBER_COLUMN
+ SEQUENCE_NUMBER_COLUMN,
+ VECTOR_SEARCH_SCORE_COLUMN
)
val ROW_INDEX: PaimonMetadataColumn =
@@ -78,6 +80,8 @@ object PaimonMetadataColumn {
PaimonMetadataColumn(Int.MaxValue - 104, ROW_ID_COLUMN, LongType)
val SEQUENCE_NUMBER: PaimonMetadataColumn =
PaimonMetadataColumn(Int.MaxValue - 105, SEQUENCE_NUMBER_COLUMN, LongType)
+  /**
+   * Float score attached to each row returned by a vector search; registered as a metadata
+   * column only for VectorSearchTable.
+   */
+  val VECTOR_SEARCH_SCORE: PaimonMetadataColumn =
+    PaimonMetadataColumn(Int.MaxValue - 106, VECTOR_SEARCH_SCORE_COLUMN, FloatType)
def dvMetaCols: Seq[PaimonMetadataColumn] = Seq(FILE_PATH, ROW_INDEX)
@@ -89,6 +93,7 @@ object PaimonMetadataColumn {
case BUCKET_COLUMN => BUCKET
case ROW_ID_COLUMN => ROW_ID
case SEQUENCE_NUMBER_COLUMN => SEQUENCE_NUMBER
+ case VECTOR_SEARCH_SCORE_COLUMN => VECTOR_SEARCH_SCORE
case _ =>
throw new IllegalArgumentException(s"$metadataColumn metadata column
is not supported.")
}
diff --git
a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
index c3c15d1b76..4100f54f61 100644
---
a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
+++
b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkMultimodalITCase.java
@@ -21,6 +21,7 @@ package org.apache.paimon.spark;
import org.apache.paimon.fs.Path;
import org.apache.paimon.hive.TestHiveMetastore;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
@@ -77,7 +78,6 @@ public class SparkMultimodalITCase {
spark.sql("CREATE DATABASE IF NOT EXISTS my_db1");
spark.sql("USE spark_catalog.my_db1");
- /** Create table */
spark.sql(
"\n"
+ "CREATE TABLE my_db1.vector_test (gid BIGINT, sid
STRING, embs ARRAY<FLOAT>)"
@@ -128,6 +128,16 @@ public class SparkMultimodalITCase {
"select gid, sid, embs from
vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f), 5)
where date = '20260420'")
.collectAsList();
assertThat(rows).hasSize(5);
+ Dataset<Row> df =
+ spark.sql(
+ "select gid, sid, embs, __paimon_vector_search_score
from vector_search('my_db1.vector_test', 'embs', array(1.0f, 2.0f, 3.0f, 4.0f),
5) where date = '20260420'");
+ assertThat(df.columns()).hasSize(4);
+ rows = df.collectAsList();
+ assertThat(rows).hasSize(5);
+ spark.close();
+
+ spark = builder.getOrCreate();
+ spark.sql("DROP TABLE IF EXISTS `my_db1`.`vector_test`;");
spark.close();
}
}