Copilot commented on code in PR #6950:
URL: https://github.com/apache/paimon/pull/6950#discussion_r2663348303
##########
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala:
##########
@@ -85,17 +89,44 @@ object PaimonTableValuedFunctions {
val sparkCatalog =
catalogManager.catalog(catalogName).asInstanceOf[TableCatalog]
val ident: Identifier = Identifier.of(Array(dbName), tableName)
val sparkTable = sparkCatalog.loadTable(ident)
- val options = tvf.parseArgs(args.tail)
- usingSparkIncrementQuery(tvf, sparkTable, options) match {
- case Some(snapshotIdPair: (Long, Long)) =>
- sparkIncrementQuery(spark, sparkTable, sparkCatalog, ident, options, snapshotIdPair)
+ // Handle vector_search specially
+ tvf match {
+ case vsq: VectorSearchQuery =>
+ resolveVectorSearchQuery(sparkTable, sparkCatalog, ident, vsq, args.tail)
case _ =>
+ val options = tvf.parseArgs(args.tail)
+ usingSparkIncrementQuery(tvf, sparkTable, options) match {
+ case Some(snapshotIdPair: (Long, Long)) =>
+ sparkIncrementQuery(spark, sparkTable, sparkCatalog, ident, options, snapshotIdPair)
+ case _ =>
+ DataSourceV2Relation.create(
+ sparkTable,
+ Some(sparkCatalog),
+ Some(ident),
+ new CaseInsensitiveStringMap(options.asJava))
+ }
+ }
+ }
+
+ private def resolveVectorSearchQuery(
+ sparkTable: Table,
+ sparkCatalog: TableCatalog,
+ ident: Identifier,
+ vsq: VectorSearchQuery,
+ argsWithoutTable: Seq[Expression]): LogicalPlan = {
+ sparkTable match {
+ case st @ SparkTable(innerTable: InnerTable) =>
+ val vectorSearch = vsq.createVectorSearch(argsWithoutTable)
+ val vectorSearchTable = VectorSearchTable.create(innerTable, vectorSearch)
DataSourceV2Relation.create(
- sparkTable,
+ st.copy(table = vectorSearchTable),
Some(sparkCatalog),
Some(ident),
- new CaseInsensitiveStringMap(options.asJava))
+ CaseInsensitiveStringMap.empty())
+ case _ =>
+ throw new RuntimeException(
+ s"vector_search only supports Paimon tables, got
${sparkTable.getClass.getName}")
Review Comment:
The error message could be more helpful by specifying what type of table is
actually supported. Consider changing the message to explicitly state that only
InnerTable instances are supported, or provide guidance on what the user should
do instead.
```suggestion
s"vector_search only supports Paimon SparkTable backed by
InnerTable, " +
s"but got table implementation: ${sparkTable.getClass.getName}")
```
##########
paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VectorSearchPushDownTest.scala:
##########
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonScan
+
+/** Tests for vector search table-valued function with global vector index. */
+class VectorSearchPushDownTest extends BaseVectorSearchPushDownTest {
+ test("vector search with global index") {
+ withTable("T") {
+ spark.sql("""
+ |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+ |TBLPROPERTIES (
+ | 'bucket' = '-1',
+ | 'global-index.row-count-per-shard' = '10000',
+ | 'row-tracking.enabled' = 'true',
+ | 'data-evolution.enabled' = 'true')
+ |""".stripMargin)
+
+ // Insert 100 rows with predictable vectors
+ val values = (0 until 100)
+ .map(
+ i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))")
+ .mkString(",")
+ spark.sql(s"INSERT INTO T VALUES $values")
+
+ // Create vector index
+ val output = spark
+ .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+ .collect()
+ .head
+ assert(output.getBoolean(0))
+
+ // Test vector search with table-valued function syntax
+ val result = spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ |""".stripMargin)
+ .collect()
+
+ // The result should contain 5 rows
+ assert(result.length == 5)
+
+ // Vector (50, 51, 52) should be most similar to the row with id=50
+ assert(result.map(_.getInt(0)).contains(50))
+ }
+ }
+
+ test("vector search pushdown is applied in plan") {
+ withTable("T") {
+ spark.sql("""
+ |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+ |TBLPROPERTIES (
+ | 'bucket' = '-1',
+ | 'global-index.row-count-per-shard' = '10000',
+ | 'row-tracking.enabled' = 'true',
+ | 'data-evolution.enabled' = 'true')
+ |""".stripMargin)
+
+ val values = (0 until 10)
+ .map(
+ i => s"($i, array(cast($i as float), cast(${i + 1} as float), cast(${i + 2} as float)))")
+ .mkString(",")
+ spark.sql(s"INSERT INTO T VALUES $values")
+
+ // Create vector index
+ spark
+ .sql("CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+ .collect()
+
+ // Check that vector search is pushed down with table function syntax
+ val df = spark.sql("""
+ |SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ |""".stripMargin)
+
+ // Get the scan from the executed plan (physical plan)
+ val executedPlan = df.queryExecution.executedPlan
+ val batchScans = executedPlan.collect {
+ case scan: org.apache.spark.sql.execution.datasources.v2.BatchScanExec => scan
+ }
+
+ assert(batchScans.nonEmpty, "Should have a BatchScanExec in executed plan")
+ val paimonScans = batchScans.filter(_.scan.isInstanceOf[PaimonScan])
+ assert(paimonScans.nonEmpty, "Should have a PaimonScan in executed plan")
+
+ val paimonScan = paimonScans.head.scan.asInstanceOf[PaimonScan]
+ assert(paimonScan.pushedVectorSearch.isDefined, "Vector search should be pushed down")
+ assert(paimonScan.pushedVectorSearch.get.fieldName() == "v", "Field name should be 'v'")
+ assert(paimonScan.pushedVectorSearch.get.limit() == 5, "Limit should be 5")
+ }
+ }
+
+ test("vector search topk returns correct results") {
+ withTable("T") {
+ spark.sql("""
+ |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
+ |TBLPROPERTIES (
+ | 'bucket' = '-1',
+ | 'global-index.row-count-per-shard' = '10000',
+ | 'row-tracking.enabled' = 'true',
+ | 'data-evolution.enabled' = 'true')
+ |""".stripMargin)
+
+ // Insert rows with distinct vectors
+ val values = (1 to 100)
+ .map {
+ i =>
+ val v = math.sqrt(3.0 * i * i)
+ val normalized = i.toFloat / v.toFloat
+ s"($i, array($normalized, $normalized, $normalized))"
+ }
+ .mkString(",")
+ spark.sql(s"INSERT INTO T VALUES $values")
+
+ // Create vector index
+ spark.sql(
+ "CALL sys.create_global_index(table => 'test.T', index_column => 'v', index_type => 'lucene-vector-knn', options => 'vector.dim=3')")
+
+ // Query for top 10 similar to (1, 1, 1) normalized
+ val result = spark
+ .sql("""
+ |SELECT * FROM vector_search('T', 'v', array(0.577f, 0.577f, 0.577f), 10)
+ |""".stripMargin)
+ .collect()
+
+ assert(result.length == 10)
+ }
+ }
+}
Review Comment:
The test suite lacks negative test cases for the vector_search function.
Consider adding tests for error scenarios such as: invalid number of
parameters, non-existent table names, invalid column names, non-array query
vectors, invalid limit values (e.g., negative numbers or zero), and mismatched
vector dimensions.
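For illustration, a hedged sketch of a few such negative tests in the style of this suite, assuming its `withTable`/`spark` helpers and ScalaTest's `intercept` are available as in other Paimon Spark suites. A broad `intercept[Throwable]` is used because the arity failure currently surfaces as an `AssertionError` from the size check; the zero-limit case encodes the desired behavior rather than what the current code does:
```scala
  test("vector search rejects invalid arguments") {
    withTable("T") {
      spark.sql("""
                  |CREATE TABLE T (id INT, v ARRAY<FLOAT>)
                  |TBLPROPERTIES (
                  |  'bucket' = '-1',
                  |  'global-index.row-count-per-shard' = '10000',
                  |  'row-tracking.enabled' = 'true',
                  |  'data-evolution.enabled' = 'true')
                  |""".stripMargin)
      spark.sql("INSERT INTO T VALUES (1, array(1.0f, 2.0f, 3.0f))")

      // Missing limit argument: fails the arity assertion in createVectorSearch.
      intercept[Throwable] {
        spark.sql("SELECT * FROM vector_search('T', 'v', array(1.0f, 2.0f, 3.0f))").collect()
      }

      // Non-existent table: fails during table resolution.
      intercept[Throwable] {
        spark.sql("SELECT * FROM vector_search('NoSuchTable', 'v', array(1.0f, 2.0f, 3.0f), 5)").collect()
      }

      // Zero limit: should be rejected once limit validation is added.
      intercept[Throwable] {
        spark.sql("SELECT * FROM vector_search('T', 'v', array(1.0f, 2.0f, 3.0f), 0)").collect()
      }
    }
  }
```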
##########
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala:
##########
@@ -207,3 +238,59 @@ case class IncrementalToAutoTag(override val args: Seq[Expression])
Map(CoreOptions.INCREMENTAL_TO_AUTO_TAG.key -> endTagName)
}
}
+
+/**
+ * Plan for the [[VECTOR_SEARCH]] table-valued function.
+ *
+ * Usage: vector_search(table_name, column_name, query_vector, limit)
+ * - table_name: the Paimon table to search
+ * - column_name: the vector column name
+ * - query_vector: array of floats representing the query vector
+ * - limit: the number of top results to return
+ *
+ * Example: SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ */
+case class VectorSearchQuery(override val args: Seq[Expression])
+ extends PaimonTableValueFunction(VECTOR_SEARCH) {
+
+ override def parseArgs(args: Seq[Expression]): Map[String, String] = {
+ // This method is not used for VectorSearchQuery as we handle it specially
+ Map.empty
+ }
+
+ def createVectorSearch(argsWithoutTable: Seq[Expression]): VectorSearch = {
+ assert(
+ argsWithoutTable.size == 3,
+ s"$VECTOR_SEARCH needs four parameters: table_name, column_name,
query_vector, limit. " +
+ s"Got ${argsWithoutTable.size + 1} parameters."
+ )
+
+ val columnName = argsWithoutTable.head.eval().toString
+ val queryVector = extractQueryVector(argsWithoutTable(1))
+ val limit = argsWithoutTable(2).eval() match {
+ case i: Int => i
+ case l: Long => l.toInt
+ case other => throw new RuntimeException(s"Invalid limit type:
${other.getClass.getName}")
+ }
+
Review Comment:
Missing input validation for the limit parameter. The code should validate
that the limit is a positive integer. Negative or zero values for the limit
parameter could lead to unexpected behavior or errors downstream.
```suggestion
if (limit <= 0) {
throw new IllegalArgumentException(
s"Limit must be a positive integer, but got: $limit"
)
}
```
##########
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala:
##########
@@ -207,3 +238,59 @@ case class IncrementalToAutoTag(override val args: Seq[Expression])
Map(CoreOptions.INCREMENTAL_TO_AUTO_TAG.key -> endTagName)
}
}
+
+/**
+ * Plan for the [[VECTOR_SEARCH]] table-valued function.
+ *
+ * Usage: vector_search(table_name, column_name, query_vector, limit)
+ * - table_name: the Paimon table to search
+ * - column_name: the vector column name
+ * - query_vector: array of floats representing the query vector
+ * - limit: the number of top results to return
+ *
+ * Example: SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ */
+case class VectorSearchQuery(override val args: Seq[Expression])
+ extends PaimonTableValueFunction(VECTOR_SEARCH) {
+
+ override def parseArgs(args: Seq[Expression]): Map[String, String] = {
+ // This method is not used for VectorSearchQuery as we handle it specially
+ Map.empty
+ }
+
+ def createVectorSearch(argsWithoutTable: Seq[Expression]): VectorSearch = {
+ assert(
+ argsWithoutTable.size == 3,
+ s"$VECTOR_SEARCH needs four parameters: table_name, column_name,
query_vector, limit. " +
+ s"Got ${argsWithoutTable.size + 1} parameters."
Review Comment:
The error message is inconsistent with the validation logic. The assertion checks for 3 elements in argsWithoutTable (column_name, query_vector, limit), while the message says "four parameters" because it also counts table_name (hence the size + 1). Either state explicitly that the count includes the table name, or rephrase as "three parameters after table_name".
```suggestion
s"$VECTOR_SEARCH needs three parameters after table_name: column_name,
query_vector, limit. " +
s"Got ${argsWithoutTable.size} parameters after table_name."
```
##########
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/plans/logical/PaimonTableValuedFunctions.scala:
##########
@@ -207,3 +238,59 @@ case class IncrementalToAutoTag(override val args: Seq[Expression])
Map(CoreOptions.INCREMENTAL_TO_AUTO_TAG.key -> endTagName)
}
}
+
+/**
+ * Plan for the [[VECTOR_SEARCH]] table-valued function.
+ *
+ * Usage: vector_search(table_name, column_name, query_vector, limit)
+ * - table_name: the Paimon table to search
+ * - column_name: the vector column name
+ * - query_vector: array of floats representing the query vector
+ * - limit: the number of top results to return
+ *
+ * Example: SELECT * FROM vector_search('T', 'v', array(50.0f, 51.0f, 52.0f), 5)
+ */
+case class VectorSearchQuery(override val args: Seq[Expression])
+ extends PaimonTableValueFunction(VECTOR_SEARCH) {
+
+ override def parseArgs(args: Seq[Expression]): Map[String, String] = {
+ // This method is not used for VectorSearchQuery as we handle it specially
+ Map.empty
+ }
+
+ def createVectorSearch(argsWithoutTable: Seq[Expression]): VectorSearch = {
+ assert(
+ argsWithoutTable.size == 3,
+ s"$VECTOR_SEARCH needs four parameters: table_name, column_name,
query_vector, limit. " +
+ s"Got ${argsWithoutTable.size + 1} parameters."
+ )
+
+ val columnName = argsWithoutTable.head.eval().toString
+ val queryVector = extractQueryVector(argsWithoutTable(1))
+ val limit = argsWithoutTable(2).eval() match {
+ case i: Int => i
+ case l: Long => l.toInt
+ case other => throw new RuntimeException(s"Invalid limit type:
${other.getClass.getName}")
+ }
+
+ new VectorSearch(queryVector, limit, columnName)
+ }
+
+ private def extractQueryVector(expr: Expression): Array[Float] = {
+ expr match {
+ case Literal(arrayData, _) if arrayData != null =>
+ val arr = arrayData.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData]
+ arr.toFloatArray()
+ case CreateArray(elements, _) =>
+ elements.map {
+ case Literal(v: Float, _) => v
+ case Literal(v: Double, _) => v.toFloat
+ case Literal(v: java.lang.Float, _) if v != null => v.floatValue()
+ case Literal(v: java.lang.Double, _) if v != null => v.floatValue()
+ case other => throw new RuntimeException(s"Cannot extract float
from: $other")
+ }.toArray
+ case _ =>
+ throw new RuntimeException(s"Cannot extract query vector from
expression: $expr")
+ }
+ }
Review Comment:
Missing validation for empty query vectors. The extractQueryVector method
should validate that the extracted array is not empty, as an empty query vector
would not make sense for vector search operations and could cause errors in the
underlying search implementation.
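For illustration, a minimal sketch of that guard, reusing the match from the diff above (in Scala the `Literal(v: Float, _)` and `Literal(v: Double, _)` patterns already match the boxed java.lang variants, so those cases are folded in); the choice of IllegalArgumentException is an assumption, not an established convention in this code:
```scala
  private def extractQueryVector(expr: Expression): Array[Float] = {
    val vector = expr match {
      case Literal(arrayData, _) if arrayData != null =>
        arrayData.asInstanceOf[org.apache.spark.sql.catalyst.util.ArrayData].toFloatArray()
      case CreateArray(elements, _) =>
        elements.map {
          case Literal(v: Float, _) => v
          case Literal(v: Double, _) => v.toFloat
          case other => throw new RuntimeException(s"Cannot extract float from: $other")
        }.toArray
      case _ =>
        throw new RuntimeException(s"Cannot extract query vector from expression: $expr")
    }
    // Fail fast on an empty vector instead of handing it to the index lookup.
    if (vector.isEmpty) {
      throw new IllegalArgumentException(s"$VECTOR_SEARCH requires a non-empty query vector")
    }
    vector
  }
```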
##########
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala:
##########
@@ -128,13 +128,20 @@ class PaimonScanBuilder(val table: InnerTable)
localScan match {
case Some(scan) => scan
case None =>
+ val (actualTable, vectorSearch) = table match {
+ case vst: org.apache.paimon.table.VectorSearchTable =>
+ (vst.origin(), Some(vst.vectorSearch()))
Review Comment:
There's a potential logical issue in handling VectorSearchTable. If the
table is already a VectorSearchTable, this code extracts the vectorSearch from
it and ignores any pushedVectorSearch. However, this assumes that
pushedVectorSearch and vst.vectorSearch() are never both set at the same time.
Consider adding validation or documentation about this assumption, or
prioritize one over the other explicitly.
```suggestion
val tableVectorSearch = Option(vst.vectorSearch())
// If both are defined, prioritize the configuration embedded in the table.
(vst.origin(), tableVectorSearch.orElse(pushedVectorSearch))
```
##########
paimon-core/src/main/java/org/apache/paimon/table/VectorSearchTable.java:
##########
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.table;
+
+import org.apache.paimon.fs.FileIO;
+import org.apache.paimon.predicate.VectorSearch;
+import org.apache.paimon.table.source.InnerTableRead;
+import org.apache.paimon.table.source.InnerTableScan;
+import org.apache.paimon.types.RowType;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A table wrapper to hold vector search information. This is used by the Spark engine to pass
+ * vector search pushdown information from logical plan optimization to physical plan execution.
+ */
+public class VectorSearchTable implements ReadonlyTable {
+
+ private final InnerTable origin;
+ private final VectorSearch vectorSearch;
+
+ VectorSearchTable(InnerTable origin, VectorSearch vectorSearch) {
+ this.origin = origin;
+ this.vectorSearch = vectorSearch;
+ }
+
+ public static VectorSearchTable create(InnerTable origin, VectorSearch vectorSearch) {
+ return new VectorSearchTable(origin, vectorSearch);
+ }
+
+ @Nullable
+ public VectorSearch vectorSearch() {
+ return vectorSearch;
Review Comment:
The @Nullable annotation on the vectorSearch() method is misleading. Every call site shown constructs the table with a non-null VectorSearch, so callers should be able to rely on a non-null return. Remove @Nullable, and since the constructor does not currently enforce this, consider adding an Objects.requireNonNull check there so the invariant is enforced rather than implied.