yihua commented on code in PR #18432:
URL: https://github.com/apache/hudi/pull/18432#discussion_r3036068066


##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala:
##########
@@ -310,6 +327,61 @@ case class ResolveReferences(spark: SparkSession) extends Rule[LogicalPlan]
       sparkAdapter.getCatalystPlanUtils.unapplyMergeIntoTable(plan)
   }
 
+  /**
+   * Resolves a table reference to a DataFrame. Accepts either a table identifier
+   * (including multi-part identifiers like catalog.db.table) or a file path.
+   */
+  private def resolveTableToDf(table: String): DataFrame = {
+    try {
+      if (table.contains(StoragePath.SEPARATOR)) {
+        spark.read.format("hudi").load(table)
+      } else {
+        spark.table(table)
+      }
+    } catch {
+      case e: Exception => throw new HoodieAnalysisException(
+        s"hudi_vector_search: unable to resolve table '$table': ${e.getMessage}")
+    }
+  }
+
+  private def evaluateQueryVector(expr: Expression): Array[Double] = {
+    if (!expr.foldable) {
+      throw new HoodieAnalysisException(
+        s"Function '${HoodieVectorSearchTableValuedFunction.FUNC_NAME}': " +
+          "query vector must be a constant expression (e.g., ARRAY(1.0, 2.0, 3.0))")
+    }
+    val value = expr.eval(null)
+    if (value == null) {
+      throw new HoodieAnalysisException(
+        s"Function '${HoodieVectorSearchTableValuedFunction.FUNC_NAME}': query vector cannot be null")
+    }
+
+    val arrayData = value.asInstanceOf[ArrayData]
+    val numElements = arrayData.numElements()
+    val elementType = expr.dataType.asInstanceOf[ArrayType].elementType
+
+    // Resolve element extractor once, before the loop.
+    // Spark SQL infers untyped decimal literals (e.g. ARRAY(1.0, 0.5)) as DecimalType,
+    // not DoubleType, so DecimalType is accepted and converted.
+    val getElement: Int => Double = elementType match {
+      case DoubleType     => i => arrayData.getDouble(i)
+      case FloatType      => i => arrayData.getFloat(i).toDouble
+      case IntegerType    => i => arrayData.getInt(i).toDouble
+      case LongType       => i => arrayData.getLong(i).toDouble
+      case d: DecimalType => i => arrayData.getDecimal(i, d.precision, d.scale).toDouble
+      case other => throw new HoodieAnalysisException(
+        s"Function '${HoodieVectorSearchTableValuedFunction.FUNC_NAME}': " +
+          s"query vector element type $other not supported, expected numeric array")
+    }
+
+    (0 until numElements).map { i =>
+      if (arrayData.isNullAt(i)) throw new HoodieAnalysisException(
+        s"Function '${HoodieVectorSearchTableValuedFunction.FUNC_NAME}': " +
+          s"query vector element at index $i is null")
+      getElement(i)
+    }.toArray
+  }

Review Comment:
   _⚠️ Potential issue_ | _🟡 Minor_
   
   **Add validation that `expr.dataType` is `ArrayType` before casting.**
   
   If a user mistakenly passes a non-array expression as the query vector, lines 362 and 364 will throw `ClassCastException` instead of a descriptive `HoodieAnalysisException`. Consider adding an early check:
   
   
   <details>
   <summary>🛡️ Proposed fix to add type validation</summary>
   
   ```diff
      private def evaluateQueryVector(expr: Expression): Array[Double] = {
   +    expr.dataType match {
   +      case _: ArrayType => // valid
   +      case other => throw new HoodieAnalysisException(
   +        s"Function '${HoodieVectorSearchTableValuedFunction.FUNC_NAME}': " +
   +          s"query vector must be an array type, got: $other")
   +    }
        if (!expr.foldable) {
          throw new HoodieAnalysisException(
   ```
   </details>
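   Below is a minimal, dependency-free sketch of the guard written as a single pattern match, so the element type is bound instead of cast. Several names are hypothetical stand-ins. `QType`, `QArray`, `QDecimal` and the other `Q*` types model Spark's `DataType` hierarchy. A `Seq[Any]` plays the role of `ArrayData`, `scala.math.BigDecimal` stands in for Spark's `Decimal`, and `QueryVectorError` stands in for `HoodieAnalysisException`. The `foldable` check is not modeled.

   ```scala
   // Hypothetical stand-ins for Spark's DataType hierarchy (not Spark classes).
   sealed trait QType { def simpleString: String }
   case object QDouble extends QType { val simpleString = "double" }
   case object QInt extends QType { val simpleString = "int" }
   case object QString extends QType { val simpleString = "string" }
   final case class QDecimal(precision: Int, scale: Int) extends QType {
     val simpleString = s"decimal($precision,$scale)"
   }
   final case class QArray(elementType: QType) extends QType {
     val simpleString = s"array<${elementType.simpleString}>"
   }

   // Hypothetical stand-in for HoodieAnalysisException.
   final class QueryVectorError(msg: String) extends RuntimeException(msg)

   // `value` models the result of expr.eval(null); a Seq[Any] plays the role of ArrayData.
   def evaluateQueryVector(funcName: String, dataType: QType, value: Any): Array[Double] = {
     def fail(msg: String): Nothing = throw new QueryVectorError(s"Function '$funcName': $msg")

     // Type guard first: binding the element type in a pattern replaces the unchecked cast.
     val elementType = dataType match {
       case QArray(et) => et
       case other      => fail(s"query vector must be an array type, got ${other.simpleString}")
     }
     if (value == null) fail("query vector cannot be null")
     // Safe here only because the declared type is known to be an array.
     val elems = value.asInstanceOf[Seq[Any]]

     // Resolve the element converter once, as the diff does (decimal literals included).
     val toDouble: Any => Double = elementType match {
       case QDouble     => v => v.asInstanceOf[Double]
       case QInt        => v => v.asInstanceOf[Int].toDouble
       case _: QDecimal => v => v.asInstanceOf[BigDecimal].toDouble
       case other       => fail(s"query vector element type ${other.simpleString} not supported, expected numeric array")
     }

     elems.zipWithIndex.map { case (v, i) =>
       if (v == null) fail(s"query vector element at index $i is null")
       toDouble(v)
     }.toArray
   }
   ```

   In this model, a scalar argument (declared here as `QDouble`) fails with `Function 'hudi_vector_search': query vector must be an array type, got double` instead of a `ClassCastException`. The element-type and null-element checks are unchanged from the diff.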
   
   <details>
   <summary>🤖 Prompt for AI Agents</summary>
   
   ```
   Verify each finding against the current code and only fix it if needed.
   
   In
   `@hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala`
   around lines 350 - 386, In evaluateQueryVector, add an early check that
   expr.dataType is an ArrayType before casting: if it's not an ArrayType throw a
   HoodieAnalysisException with a clear message that the query vector must be an
   array of numeric types; then safely cast expr.dataType.asInstanceOf[ArrayType]
   (used for elementType) and proceed with existing element handling (preserving
   the existing numeric element type matching and null checks). This ensures
   non-array inputs to evaluateQueryVector raise a descriptive
   HoodieAnalysisException instead of a ClassCastException.
   ```
   
   </details>
   
   — *CodeRabbit* ([original](https://github.com/yihua/hudi/pull/10#discussion_r3036004780)) (source:comment#3036004780)



##########
hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala:
##########
@@ -0,0 +1,1359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.functional
+
+import org.apache.hudi.DataSourceWriteOptions._
+import org.apache.hudi.common.schema.HoodieSchema
+import org.apache.hudi.testutils.HoodieSparkClientTestBase
+
+import org.apache.spark.sql.{Row, SaveMode, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+
+/**
+ * End-to-end tests for the hudi_vector_search table-valued function.
+ * Tests both single-query and batch-query modes with Spark SQL and DataFrame API.
+ */
+class TestHoodieVectorSearchFunction extends HoodieSparkClientTestBase {
+
+  var spark: SparkSession = null
+  private val corpusPath = "corpus"
+  private val corpusViewName = "corpus_view"
+
+  // Test corpus: 5 unit-ish vectors in 3D for easy manual verification
+  // doc_1: [1, 0, 0] - x-axis
+  // doc_2: [0, 1, 0] - y-axis
+  // doc_3: [0, 0, 1] - z-axis
+  // doc_4: [0.707, 0.707, 0] - 45 degrees in xy-plane (normalized)
+  // doc_5: [0.577, 0.577, 0.577] - equal in all 3 dims (normalized)
+  private val corpusData = Seq(
+    ("doc_1", Seq(1.0f, 0.0f, 0.0f), "x-axis"),
+    ("doc_2", Seq(0.0f, 1.0f, 0.0f), "y-axis"),
+    ("doc_3", Seq(0.0f, 0.0f, 1.0f), "z-axis"),
+    ("doc_4", Seq(0.70710678f, 0.70710678f, 0.0f), "xy-diagonal"),
+    ("doc_5", Seq(0.57735027f, 0.57735027f, 0.57735027f), "xyz-diagonal")
+  )
+
+  @BeforeEach override def setUp(): Unit = {
+    initPath()
+    initSparkContexts()
+    spark = sqlContext.sparkSession
+    initTestDataGenerator()
+    initHoodieStorage()
+    createCorpusTable()
+  }
+
+  @AfterEach override def tearDown(): Unit = {
+    spark.catalog.dropTempView(corpusViewName)
+    cleanupSparkContexts()
+    cleanupTestDataGenerator()
+    cleanupFileSystem()
+  }
+
+  private def createCorpusTable(): Unit = {
+    val metadata = new MetadataBuilder()
+      .putString(HoodieSchema.TYPE_METADATA_FIELD, "VECTOR(3)")
+      .build()
+
+    val schema = StructType(Seq(
+      StructField("id", StringType, nullable = false),
+      StructField("embedding", ArrayType(FloatType, containsNull = false),
+        nullable = false, metadata),
+      StructField("label", StringType, nullable = true)
+    ))
+
+    val rows = corpusData.map { case (id, emb, label) =>
+      Row(id, emb, label)
+    }
+
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(rows),
+      schema
+    )
+
+    df.write.format("hudi")
+      .option(RECORDKEY_FIELD.key, "id")
+      .option(PRECOMBINE_FIELD.key, "id")
+      .option(TABLE_NAME.key, "vector_search_corpus")
+      .option(TABLE_TYPE.key, "COPY_ON_WRITE")
+      .mode(SaveMode.Overwrite)
+      .save(basePath + "/" + corpusPath)
+
+    spark.read.format("hudi").load(basePath + "/" + corpusPath)
+      .createOrReplaceTempView(corpusViewName)
+  }
+
+  /**
+   * Creates an in-memory Float corpus temp view (no Hudi write).
+   * Schema: id (String), embedding (Array[Float]).
+   */
+  private def createFloatInMemoryView(viewName: String, data: Seq[(String, Seq[Float])]): Unit = {
+    val schema = StructType(Seq(
+      StructField("id", StringType, nullable = false),
+      StructField("embedding", ArrayType(FloatType, containsNull = false), nullable = false)
+    ))
+    val rows = data.map { case (id, emb) => Row(id, emb) }
+    spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
+      .createOrReplaceTempView(viewName)
+  }
+
+  /**
+   * Creates an in-memory Byte corpus temp view (no Hudi write).
+   * Schema: id (String), embedding (Array[Byte]).
+   */
+  private def createByteCorpusView(viewName: String, data: Seq[(String, Seq[Byte])]): Unit = {
+    val schema = StructType(Seq(
+      StructField("id", StringType, nullable = false),
+      StructField("embedding", ArrayType(ByteType, containsNull = false), nullable = false)
+    ))
+    val rows = data.map { case (id, emb) => Row(id, emb) }
+    spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
+      .createOrReplaceTempView(viewName)
+  }
+
+  /**
+   * Creates a Float query temp view with configurable id and vector column names.
+   * Used by batch-query tests to avoid repeating StructType + createDataFrame boilerplate.
+   */
+  private def createFloatQueryView(viewName: String, idCol: String, vecCol: String,
+                                    data: Seq[(String, Seq[Float])]): Unit = {
+    val schema = StructType(Seq(
+      StructField(idCol, StringType, nullable = false),
+      StructField(vecCol, ArrayType(FloatType, containsNull = false), nullable = false)
+    ))
+    val rows = data.map { case (id, vec) => Row(id, vec) }
+    spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
+      .createOrReplaceTempView(viewName)
+  }
+
+  /**
+   * Writes rows to a Hudi table and registers the result as a Spark temp view.
+   * The supplied schema must include an "id" column used as the record key.
+   */
+  private def writeHudiAndCreateView(schema: StructType, data: Seq[Row], tableName: String,
+                                      subPath: String, viewName: String,
+                                      tableType: String = "COPY_ON_WRITE",
+                                      precombineField: String = "id"): Unit = {
+    spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+      .write.format("hudi")
+      .option(RECORDKEY_FIELD.key, "id")
+      .option(PRECOMBINE_FIELD.key, precombineField)
+      .option(TABLE_NAME.key, tableName)
+      .option(TABLE_TYPE.key, tableType)
+      .mode(SaveMode.Overwrite)
+      .save(basePath + "/" + subPath)
+    spark.read.format("hudi").load(basePath + "/" + subPath)
+      .createOrReplaceTempView(viewName)
+  }
+
+  @Test
+  def testSingleQueryDistanceMetrics(): Unit = {
+    // Verify all three distance metrics with query [1,0,0] against the shared corpus.
+
+    // --- Cosine ---
+    val cosine = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search('$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3, 'cosine')
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+    assertEquals(3, cosine.length)
+    assertEquals("doc_1", cosine(0).getAs[String]("id"))
+    assertEquals(0.0, cosine(0).getAs[Double]("_hudi_distance"), 1e-5)
+    assertEquals("doc_4", cosine(1).getAs[String]("id"))
+    assertEquals(1.0 - 0.70710678, cosine(1).getAs[Double]("_hudi_distance"), 1e-4)
+    assertEquals("doc_5", cosine(2).getAs[String]("id"))
+    assertEquals(1.0 - 0.57735027, cosine(2).getAs[Double]("_hudi_distance"), 1e-4)
+
+    // --- L2 ---
+    val l2 = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search('$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3, 'l2')
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+    assertEquals(3, l2.length)
+    assertEquals("doc_1", l2(0).getAs[String]("id"))
+    assertEquals(0.0, l2(0).getAs[Double]("_hudi_distance"), 1e-5)
+    assertEquals("doc_4", l2(1).getAs[String]("id"))
+    val expectedL2Doc4 = math.sqrt(math.pow(1.0 - 0.70710678, 2) + math.pow(0.70710678, 2))
+    assertEquals(expectedL2Doc4, l2(1).getAs[Double]("_hudi_distance"), 1e-4)
+
+    // --- Dot product ---
+    val dot = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search('$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3, 'dot_product')
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+    assertEquals(3, dot.length)
+    assertEquals("doc_1", dot(0).getAs[String]("id"))
+    assertEquals(-1.0, dot(0).getAs[Double]("_hudi_distance"), 1e-5)
+    assertEquals("doc_4", dot(1).getAs[String]("id"))
+    assertEquals(-0.70710678, dot(1).getAs[Double]("_hudi_distance"), 1e-4)
+  }
+
+  @Test
+  def testSingleQueryDefaultMetric(): Unit = {
+    // Omit metric arg, should default to cosine
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  3
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(3, result.length)
+    // Should match cosine: doc_1 first with distance ~0
+    assertEquals("doc_1", result(0).getAs[String]("id"))
+    assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5)
+  }
+
+  @Test
+  def testSingleQueryReturnsAllCorpusColumns(): Unit = {
+    val result = spark.sql(
+      s"""
+         |SELECT *
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  2
+         |)
+         |""".stripMargin
+    )
+
+    // Should have the _hudi_distance column plus original corpus columns (embedding is dropped)
+    assertTrue(result.columns.contains("_hudi_distance"))
+    assertTrue(result.columns.contains("id"))
+    assertTrue(result.columns.contains("label"))
+    assertFalse(result.columns.contains("embedding"))
+    assertEquals(2, result.count())
+  }
+
+  @Test
+  def testKGreaterThanCorpus(): Unit = {
+    // k=100, corpus has 5 rows -> should return all 5
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  100
+         |)
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(5, result.length)
+  }
+
+  @Test
+  def testVectorSearchWithWhereClause(): Unit = {
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  5,
+         |  'cosine'
+         |)
+         |WHERE _hudi_distance < 0.5
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    // doc_1 (distance ~0), doc_4 (distance ~0.29), doc_5 (~0.42) should pass
+    // doc_2 and doc_3 have distance = 1.0 and should be filtered out
+    assertEquals(3, result.length)
+    assertTrue(result.forall(_.getAs[Double]("_hudi_distance") < 0.5))
+  }
+
+  @Test
+  def testVectorSearchAsSubquery(): Unit = {
+    val result = spark.sql(
+      s"""
+         |SELECT sub.id, sub.label, sub._hudi_distance
+         |FROM (
+         |  SELECT *
+         |  FROM hudi_vector_search(
+         |    '$corpusViewName',
+         |    'embedding',
+         |    ARRAY(0.0, 1.0, 0.0),
+         |    3
+         |  )
+         |) sub
+         |WHERE sub.label != 'y-axis'
+         |ORDER BY sub._hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    // doc_2 (y-axis) is filtered out
+    assertTrue(result.forall(_.getAs[String]("id") != "doc_2"))
+  }
+
+  @Test
+  def testBatchQueryResultsPerQuery(): Unit = {
+    createFloatQueryView("batch_queries", "qid", "qvec", Seq(
+      ("q1", Seq(1.0f, 0.0f, 0.0f)),
+      ("q2", Seq(0.0f, 0.0f, 1.0f))
+    ))
+
+    val resultDf = spark.sql(
+      s"""
+         |SELECT *
+         |FROM hudi_vector_search_batch(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  'batch_queries',
+         |  'qvec',
+         |  2,
+         |  'cosine'
+         |)
+         |""".stripMargin
+    )
+
+    // Verify output columns
+    val columns = resultDf.columns
+    assertTrue(columns.contains("_hudi_distance"))
+    assertTrue(columns.contains("_hudi_query_index"))
+
+    // Each query should get exactly 2 results
+    val resultsByQuery = resultDf.groupBy("_hudi_query_index").count().collect()
+    assertEquals(2, resultsByQuery.length)
+    resultsByQuery.foreach { row =>
+      assertEquals(2, row.getLong(1))
+    }
+
+    // Validate that _hudi_query_index has two distinct values
+    val queryIndexValues = resultDf.select("_hudi_query_index").distinct().collect()
+      .map(_.getLong(0)).sorted
+    assertEquals(2, queryIndexValues.length)
+    assertTrue(queryIndexValues(0) != queryIndexValues(1))
+
+    // Verify DataFrame operations work on the result (merged from testBatchQueryViaDataFrameApi)
+    val topResults = resultDf.filter("_hudi_distance < 0.5").select("id", "_hudi_distance", "_hudi_query_index")
+    assertTrue(topResults.count() > 0)
+
+    spark.catalog.dropTempView("batch_queries")
+  }
+
+  @Test
+  def testBatchQuerySameEmbeddingColumnName(): Unit = {
+    // Both corpus and query use the column name "embedding" — previously caused ambiguity error
+    createFloatQueryView("same_col_queries", "query_name", "embedding", Seq(
+      ("q_x", Seq(1.0f, 0.0f, 0.0f)),
+      ("q_y", Seq(0.0f, 1.0f, 0.0f))
+    ))
+
+    val result = spark.sql(
+      s"""
+         |SELECT *
+         |FROM hudi_vector_search_batch(
+         |  '$corpusViewName',
+         |  'embedding',
+         |  'same_col_queries',
+         |  'embedding',
+         |  2,
+         |  'cosine'
+         |)
+         |""".stripMargin
+    ).collect()
+
+    // 2 queries x 2 results each = 4 rows; should not throw AnalysisException
+    assertEquals(4, result.length)
+    assertTrue(result.head.schema.fieldNames.contains("_hudi_distance"))
+
+    spark.catalog.dropTempView("same_col_queries")
+  }
+
+  @Test
+  def testTableByPath(): Unit = {
+    val tablePath = basePath + "/" + corpusPath
+    val result = spark.sql(
+      s"""
+         |SELECT id, _hudi_distance
+         |FROM hudi_vector_search(
+         |  '$tablePath',
+         |  'embedding',
+         |  ARRAY(1.0, 0.0, 0.0),
+         |  2
+         |)
+         |ORDER BY _hudi_distance
+         |""".stripMargin
+    ).collect()
+
+    assertEquals(2, result.length)
+    assertEquals("doc_1", result(0).getAs[String]("id"))
+  }
+
+  @Test
+  def testDoubleVectorEmbeddings(): Unit = {
+    val metadata = new MetadataBuilder()
+      .putString(HoodieSchema.TYPE_METADATA_FIELD, "VECTOR(3, DOUBLE)")
+      .build()
+
+    val schema = StructType(Seq(
+      StructField("id", StringType, nullable = false),
+      StructField("embedding", ArrayType(DoubleType, containsNull = false),
+        nullable = false, metadata)
+    ))
+
+    writeHudiAndCreateView(schema, Seq(
+      Row("d1", Seq(1.0, 0.0, 0.0)),
+      Row("d2", Seq(0.0, 1.0, 0.0)),
+      Row("d3", Seq(0.0, 0.0, 1.0))
+    ), "double_vec_search", "double_search", "double_corpus")
+
+    val result = spark.sql(
+      """
+        |SELECT id, _hudi_distance
+        |FROM hudi_vector_search(
+        |  'double_corpus',
+        |  'embedding',
+        |  ARRAY(1.0, 0.0, 0.0),
+        |  2
+        |)
+        |ORDER BY _hudi_distance
+        |""".stripMargin
+    ).collect()
+
+    assertEquals(2, result.length)
+    assertEquals("d1", result(0).getAs[String]("id"))
+    assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-10)
+
+    spark.catalog.dropTempView("double_corpus")
+  }
+
+  @Test
+  def testInvalidEmbeddingColumn(): Unit = {
+    val ex = assertThrows(classOf[Exception], () => {
+      spark.sql(
+        s"""
+           |SELECT *
+           |FROM hudi_vector_search(
+           |  '$corpusViewName',
+           |  'nonexistent_col',
+           |  ARRAY(1.0, 0.0, 0.0),
+           |  3
+           |)
+           |""".stripMargin
+      ).collect()
+    })
+    assertTrue(ex.getMessage.contains("nonexistent_col") ||
+      ex.getCause.getMessage.contains("nonexistent_col"))

Review Comment:
   _⚠️ Potential issue_ | _🟡 Minor_
   
   **Stop dereferencing `getCause` blindly in these negative tests.**
   
   If Spark throws a top-level exception with no cause, `ex.getCause` is null, so these assertions throw a `NullPointerException` and hide the real regression. You already use the safer pattern on lines 1404-1405; hoisting it into a shared helper would make these checks deterministic.
    
   
   <details>
   <summary>Suggested cleanup</summary>
   
   ```diff
   +  private def rootMessage(t: Throwable): String =
   +    Option(t.getCause).map(rootMessage).getOrElse(Option(t.getMessage).getOrElse(""))
   ```
   
   ```diff
   -    assertTrue(ex.getMessage.contains("nonexistent_col") ||
   -      ex.getCause.getMessage.contains("nonexistent_col"))
   +    assertTrue(rootMessage(ex).contains("nonexistent_col"))
   ```
   </details>
   
   
   Also applies to: 565-566, 582-583
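   Here is a minimal sketch of a null-safe alternative. The helper name `causeChainContains` is hypothetical. The `rootMessage` suggestion returns only the innermost message, so a match that appears only on the top-level message (the original `ex.getMessage.contains(...)` branch) would be lost. This helper instead checks the message at every level of the cause chain.

   ```scala
   import scala.annotation.tailrec

   // Hypothetical test helper: true if any throwable in the cause chain has a
   // message containing `needle`. Null messages are skipped, a null cause ends the
   // walk without an NPE, and the `seen` set stops on a cyclic cause chain.
   def causeChainContains(t: Throwable, needle: String): Boolean = {
     @tailrec
     def loop(cur: Throwable, seen: Set[Throwable]): Boolean =
       if (cur == null || seen.contains(cur)) false
       else if (Option(cur.getMessage).exists(_.contains(needle))) true
       else loop(cur.getCause, seen + cur)
     loop(t, Set.empty)
   }
   ```

   Each negative test then reduces to `assertTrue(causeChainContains(ex, "nonexistent_col"))`.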
   
   <details>
   <summary>🤖 Prompt for AI Agents</summary>
   
   ```
   Verify each finding against the current code and only fix it if needed.
   
   In
   
   `@hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala`
   around lines 545 - 546, The test currently dereferences ex.getCause blindly (in
   TestHoodieVectorSearchFunction) causing NPEs; change the negative assertions to
   safely check both the top-level message and the cause only if present (e.g.,
   check ex.getMessage contains "nonexistent_col" OR (ex.getCause != null &&
   ex.getCause.getMessage contains "nonexistent_col")), and extract that logic into
   a small shared helper (e.g., assertExceptionContains(Throwable ex, String
   substr)) used by the failing tests instead of repeating the pattern so lines
   545-546 (and the similar checks at the other locations) become deterministic and
   null-safe.
   ```
   
   </details>
   
   ✅ Addressed in commits 41bcb03 to f43a6ea
   
   — *CodeRabbit* ([original](https://github.com/yihua/hudi/pull/10#discussion_r3036004782)) (source:comment#3036004782)
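   The expected distances asserted in `testSingleQueryDistanceMetrics` and `testVectorSearchWithWhereClause` above can be reproduced by hand with a minimal brute-force sketch. It makes three assumptions, inferred from those assertions rather than from the implementation: `cosine` reports 1 minus cosine similarity, `l2` reports Euclidean distance, and `dot_product` reports the negated dot product. Under these assumptions, a smaller `_hudi_distance` always means closer. `topK` and the other helpers are hypothetical names.

   ```scala
   // Distance conventions inferred from the test assertions, not read from the
   // hudi_vector_search implementation.
   def dot(a: Seq[Double], b: Seq[Double]): Double = a.zip(b).map { case (x, y) => x * y }.sum
   def norm(a: Seq[Double]): Double = math.sqrt(dot(a, a))
   def cosineDistance(a: Seq[Double], b: Seq[Double]): Double = 1.0 - dot(a, b) / (norm(a) * norm(b))
   def l2Distance(a: Seq[Double], b: Seq[Double]): Double =
     math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)
   def dotProductDistance(a: Seq[Double], b: Seq[Double]): Double = -dot(a, b)

   // Same vectors as corpusData in the test, widened to Double.
   val corpus = Seq(
     "doc_1" -> Seq(1.0, 0.0, 0.0),
     "doc_2" -> Seq(0.0, 1.0, 0.0),
     "doc_3" -> Seq(0.0, 0.0, 1.0),
     "doc_4" -> Seq(0.70710678, 0.70710678, 0.0),
     "doc_5" -> Seq(0.57735027, 0.57735027, 0.57735027)
   )
   val query = Seq(1.0, 0.0, 0.0)

   // Brute-force top-k: score every corpus row and keep the k smallest distances.
   def topK(k: Int, dist: (Seq[Double], Seq[Double]) => Double): Seq[(String, Double)] =
     corpus.map { case (id, v) => id -> dist(query, v) }.sortBy(_._2).take(k)
   ```

   With `k = 3`, all three metrics rank `doc_1`, `doc_4`, `doc_5` first. Under cosine, `doc_2` and `doc_3` tie at 1.0, which is why only three of the five rows survive `WHERE _hudi_distance < 0.5` when `k = 5`.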



##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala:
##########
@@ -310,6 +327,61 @@ case class ResolveReferences(spark: SparkSession) extends Rule[LogicalPlan]
       sparkAdapter.getCatalystPlanUtils.unapplyMergeIntoTable(plan)
   }
 
+  /**
+   * Resolves a table reference to a DataFrame. Accepts either a table identifier
+   * (including multi-part identifiers like catalog.db.table) or a file path.
+   */
+  private def resolveTableToDf(table: String): DataFrame = {
+    try {
+      if (table.contains(StoragePath.SEPARATOR)) {
+        spark.read.format("hudi").load(table)
+      } else {
+        spark.table(table)
+      }
+    } catch {
+      case e: Exception => throw new HoodieAnalysisException(
+        s"hudi_vector_search: unable to resolve table '$table': ${e.getMessage}")
+    }
+  }
+
+  private def evaluateQueryVector(expr: Expression): Array[Double] = {
+    if (!expr.foldable) {
+      throw new HoodieAnalysisException(
+        s"Function '${HoodieVectorSearchTableValuedFunction.FUNC_NAME}': " +
+          "query vector must be a constant expression (e.g., ARRAY(1.0, 2.0, 3.0))")
+    }
+    val value = expr.eval(null)
+    if (value == null) {
+      throw new HoodieAnalysisException(
+        s"Function '${HoodieVectorSearchTableValuedFunction.FUNC_NAME}': query vector cannot be null")
+    }
+
+    val arrayData = value.asInstanceOf[ArrayData]
+    val numElements = arrayData.numElements()
+    val elementType = expr.dataType.asInstanceOf[ArrayType].elementType

Review Comment:
   <a href="#"><img alt="P1" src="https://greptile-static-assets.s3.amazonaws.com/badges/p1.svg?v=7" align="top"></a> **Unchecked cast to `ArrayData`/`ArrayType` yields `ClassCastException` on non-array input**
   
   If a user passes any non-array foldable expression as the `query_vector` argument (for example `hudi_vector_search('t', 'emb', 1.0, 5)` or `hudi_vector_search('t', 'emb', 'text', 5)`), the unchecked cast `value.asInstanceOf[ArrayData]` (and, after it, `expr.dataType.asInstanceOf[ArrayType]`) throws an unhandled `ClassCastException` rather than a `HoodieAnalysisException`. The user sees a raw JVM stack trace with no hint about how to fix the call.
   
   A type guard on `expr.dataType` should precede both casts:
   
   ```suggestion
       if (!expr.dataType.isInstanceOf[ArrayType]) {
         throw new HoodieAnalysisException(
           s"Function '${HoodieVectorSearchTableValuedFunction.FUNC_NAME}': " +
             s"query vector must be an array type, got ${expr.dataType.simpleString}")
       }
       val arrayData = value.asInstanceOf[ArrayData]
       val numElements = arrayData.numElements()
       val elementType = expr.dataType.asInstanceOf[ArrayType].elementType
   ```
   
   — *Greptile* ([original](https://github.com/yihua/hudi/pull/10#discussion_r3036042593)) (source:comment#3036042593)



##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala:
##########
@@ -310,6 +327,61 @@ case class ResolveReferences(spark: SparkSession) extends Rule[LogicalPlan]
       sparkAdapter.getCatalystPlanUtils.unapplyMergeIntoTable(plan)
   }
 
+  /**
+   * Resolves a table reference to a DataFrame. Accepts either a table identifier
+   * (including multi-part identifiers like catalog.db.table) or a file path.
+   */
+  private def resolveTableToDf(table: String): DataFrame = {
+    try {
+      if (table.contains(StoragePath.SEPARATOR)) {
+        spark.read.format("hudi").load(table)
+      } else {
+        spark.table(table)
+      }
+    } catch {
+      case e: Exception => throw new HoodieAnalysisException(
+        s"hudi_vector_search: unable to resolve table '$table': ${e.getMessage}")
+    }

Review Comment:
   _⚠️ Potential issue_ | _🟡 Minor_
   
   **Error message hardcodes function name, but method is shared by both TVFs.**
   
   When `resolveTableToDf` is invoked from the batch query handler, the error message will incorrectly reference `hudi_vector_search` instead of `hudi_vector_search_batch`.
   
   
   
   <details>
   <summary>🛠️ Proposed fix to parameterize the function name</summary>
   
   ```diff
   -  private def resolveTableToDf(table: String): DataFrame = {
   +  private def resolveTableToDf(table: String, funcName: String = "hudi_vector_search"): DataFrame = {
        try {
          if (table.contains(StoragePath.SEPARATOR)) {
            spark.read.format("hudi").load(table)
          } else {
            spark.table(table)
          }
        } catch {
          case e: Exception => throw new HoodieAnalysisException(
   -        s"hudi_vector_search: unable to resolve table '$table': ${e.getMessage}")
   +        s"$funcName: unable to resolve table '$table': ${e.getMessage}")
        }
      }
   ```
   
   Then update the call sites:
   ```diff
        case HoodieVectorSearchTableValuedFunction(args) =>
          ...
   -      val corpusDf = resolveTableToDf(a.table)
   +      val corpusDf = resolveTableToDf(a.table, HoodieVectorSearchTableValuedFunction.FUNC_NAME)
          ...
    
        case HoodieVectorSearchBatchTableValuedFunction(args) =>
          ...
   -      val corpusDf = resolveTableToDf(a.corpusTable)
   -      val queryDf = resolveTableToDf(a.queryTable)
   +      val corpusDf = resolveTableToDf(a.corpusTable, HoodieVectorSearchBatchTableValuedFunction.FUNC_NAME)
   +      val queryDf = resolveTableToDf(a.queryTable, HoodieVectorSearchBatchTableValuedFunction.FUNC_NAME)
   ```
   </details>
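   The same refactor can be sketched without Spark, which makes the message shape easy to check. The sketch uses hypothetical stand-ins:

   - `resolveTable` for `resolveTableToDf`
   - `TableResolutionError` for `HoodieAnalysisException`
   - the `loadPath`/`loadTable` callbacks for `spark.read.format("hudi").load` and `spark.table`

   It also assumes `"/"` equals `StoragePath.SEPARATOR`. Unlike the proposed diff, it takes the function name with no default, so a call site cannot silently fall back to the single-query name. It also keeps the original exception as the cause.

   ```scala
   // Hypothetical stand-in for HoodieAnalysisException; keeps the underlying cause.
   final class TableResolutionError(msg: String, cause: Throwable) extends RuntimeException(msg, cause)

   // Path-vs-identifier dispatch as in resolveTableToDf, with the invoking TVF's
   // name passed in explicitly so the error message names the right function.
   def resolveTable[T](table: String, funcName: String)(
       loadPath: String => T, loadTable: String => T): T =
     try {
       if (table.contains("/")) loadPath(table) else loadTable(table)
     } catch {
       case e: Exception =>
         throw new TableResolutionError(s"$funcName: unable to resolve table '$table': ${e.getMessage}", e)
     }
   ```

   With a required parameter, the compiler flags every call site that still uses the old one-argument form.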
   
   <details>
   <summary>🤖 Prompt for AI Agents</summary>
   
   ```
   Verify each finding against the current code and only fix it if needed.
   
   In
   
   `@hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala`
   around lines 341 - 344, The catch block in resolveTableToDf throws a
   HoodieAnalysisException with a hardcoded function name "hudi_vector_search";
   change resolveTableToDf to accept a functionName (String) parameter and use that
   parameter in the exception message instead of the hardcoded string, then update
   all call sites (e.g., callers from the batch query handler and the other TVF
   caller) to pass the correct function name such as "hudi_vector_search" or
   "hudi_vector_search_batch" so the error message reflects the invoking TVF.
   ```
   
   </details>
   
   — *CodeRabbit* ([original](https://github.com/yihua/hudi/pull/10#discussion_r3036041541)) (source:comment#3036041541)



##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieVectorSearchTableValuedFunction.scala:
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal}
+import org.apache.spark.sql.hudi.command.exception.HoodieAnalysisException
+import org.apache.spark.sql.types.StringType
+
+object HoodieVectorSearchTableValuedFunction {
+
+  val FUNC_NAME = "hudi_vector_search"
+
+  object DistanceMetric extends Enumeration {
+    val COSINE, L2, DOT_PRODUCT = Value
+
+    def fromString(s: String): Value = Option(s).map(_.toLowerCase).getOrElse("") match {
+      case "cosine" => COSINE
+      case "l2" | "euclidean" => L2
+      case "dot_product" | "inner_product" => DOT_PRODUCT
+      case other => throw new HoodieAnalysisException(
+        s"Unsupported distance metric: '$other'. Supported: cosine, l2, dot_product")
+    }
+  }
+
+  object SearchAlgorithm extends Enumeration {
+    val BRUTE_FORCE = Value
+
+    def fromString(s: String): Value = Option(s).map(_.toLowerCase).getOrElse("") match {
+      case "brute_force" => BRUTE_FORCE
+      case other => throw new HoodieAnalysisException(
+        s"Unsupported search algorithm: '$other'. Supported: brute_force")
+    }
+  }
+
+  case class ParsedArgs(
+    table: String,
+    embeddingCol: String,
+    queryVectorExpr: Expression,
+    k: Int,
+    metric: DistanceMetric.Value,
+    algorithm: SearchAlgorithm.Value
+  )
+
+  /**
+   * Parse arguments for the hudi_vector_search TVF (single-query mode).
+   *
+   * Signature (4–6 args):
+   *   hudi_vector_search('table', 'embedding_col', ARRAY(1.0, 2.0, ...), k [, 'metric'] [, 'algorithm'])
+   *   metric defaults to 'cosine'; algorithm defaults to 'brute_force'.
+   */
+  def parseArgs(exprs: Seq[Expression]): ParsedArgs = {
+    if (exprs.size < 4 || exprs.size > 6) {
+      throw new HoodieAnalysisException(
+        s"Function '$FUNC_NAME' expects 4-6 arguments: " +
+          "(table, embedding_col, query_vector, k [, metric] [, algorithm]).")
+    }
+
+    def requireStringLiteral(expr: Expression, argName: String): String = expr match {
+      case Literal(v, StringType) if v != null => v.toString
+      case _ => throw new HoodieAnalysisException(
+        s"Function '$FUNC_NAME': argument '$argName' must be a string literal, got: ${expr.sql}")
+    }
+
+    val table = requireStringLiteral(exprs.head, "table")
+    val embeddingCol = requireStringLiteral(exprs(1), "embedding_col")
+    val queryVectorExpr = exprs(2)
+    val k = parseK(FUNC_NAME, exprs(3))
+    val metric = if (exprs.size >= 5) DistanceMetric.fromString(requireStringLiteral(exprs(4), "metric"))
+    else DistanceMetric.COSINE
+    val algorithm = if (exprs.size >= 6) SearchAlgorithm.fromString(requireStringLiteral(exprs(5), "algorithm"))
+    else SearchAlgorithm.BRUTE_FORCE
+    ParsedArgs(table, embeddingCol, queryVectorExpr, k, metric, algorithm)
+  }
+
+  private[logical] def parseK(funcName: String, expr: Expression): Int = {
+    val rawValue = expr.eval()
+    val kValue = try {
+      rawValue.toString.toInt
+    } catch {
+      case _: NumberFormatException =>
+        throw new HoodieAnalysisException(
+          s"Function '$funcName': k must be a positive integer, got '$rawValue'")
+    }
+    if (kValue <= 0) {
+      throw new HoodieAnalysisException(
+        s"Function '$funcName': k must be a positive integer, got $kValue")
+    }
+    kValue
+  }
+}
+
+/**
+ * Unresolved logical plan node for the {@code hudi_vector_search} table-valued function
+ * (single-query mode). Replaced during analysis by the resolved search plan.
+ *
+ * <p><b>Usage (SQL):</b>
+ * {{{
+ *   SELECT * FROM hudi_vector_search('table_name', 'embedding_col', ARRAY(1.0, 2.0, 3.0), 10)
+ *   SELECT * FROM hudi_vector_search('table_name', 'embedding_col', ARRAY(1.0, 2.0, 3.0), 10, 'cosine', 'brute_force')
+ * }}}
+ *
+ * <p><b>Output columns:</b> all corpus columns (minus embedding) + {@code _hudi_distance: Double}.
+ * Results are ordered by distance ascending (lower = more similar), limited to top-k.
+ *
+ * <p><b>Type matching:</b> the corpus embedding column and query vector must have compatible
+ * element types (e.g. both float or both double). Mismatched types produce an error.
+ */
+case class HoodieVectorSearchTableValuedFunction(args: Seq[Expression]) extends LeafNode {
+
+  override def output: Seq[Attribute] = Nil
+

Review Comment:
   <a href="#"><img alt="P1" src="https://greptile-static-assets.s3.amazonaws.com/badges/p1.svg?v=7" align="top"></a> **NullPointerException when `k` is a null literal**
   
   `expr.eval()` returns `null` whenever the `k` argument is a SQL `NULL` expression (e.g. `hudi_vector_search('t', 'emb', ARRAY(1.0), NULL)`). The subsequent `rawValue.toString.toInt` call then throws `NullPointerException`, which is _not_ a `NumberFormatException` and therefore escapes the catch block entirely. The user sees a cryptic NPE stack trace instead of the intended validation message.
   
   A null check must be added before calling `.toString`:
   
   ```suggestion
     private[logical] def parseK(funcName: String, expr: Expression): Int = {
       val rawValue = expr.eval()
       if (rawValue == null) {
         throw new HoodieAnalysisException(
           s"Function '$funcName': k must be a positive integer, got null")
       }
       val kValue = try {
         rawValue.toString.toInt
       } catch {
         case _: NumberFormatException =>
           throw new HoodieAnalysisException(
             s"Function '$funcName': k must be a positive integer, got '$rawValue'")
       }
       if (kValue <= 0) {
         throw new HoodieAnalysisException(
           s"Function '$funcName': k must be a positive integer, got $kValue")
       }
       kValue
     }
   ```
   
   — *Greptile* ([original](https://github.com/yihua/hudi/pull/10#discussion_r3035998341)) (source:comment#3035998341)
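
The null check in the suggestion above removes the NPE. A stricter alternative is to match on the runtime type of the evaluated literal instead of round-tripping through `toString`. The sketch below is Spark-free so it runs on its own: `KArgSketch` is a hypothetical name, it receives the value `expr.eval()` would return rather than a Catalyst `Expression`, and it throws `IllegalArgumentException` in place of `HoodieAnalysisException`. Unlike the PR's version, it rejects string-typed values such as `'10'` and fractional values.

```scala
object KArgSketch {
  // Hypothetical stand-in for parseK: takes the already-evaluated literal value
  // (what expr.eval() would return) instead of a Catalyst Expression.
  def parseK(funcName: String, rawValue: Any): Int = {
    def fail(got: Any): Nothing =
      throw new IllegalArgumentException(
        s"Function '$funcName': k must be a positive integer, got $got")

    val kValue: Int = rawValue match {
      case i: Int                  => i
      case l: Long if l.isValidInt => l.toInt
      case s: Short                => s.toInt
      case b: Byte                 => b.toInt
      // null, strings, fractional numbers and longs outside the Int range
      case other                   => fail(other)
    }
    if (kValue <= 0) fail(kValue)
    kValue
  }
}
```

Whether string-typed `k` should be accepted is a design choice for the PR: the original `toString.toInt` path accepts `'10'`, while this sketch does not.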



##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieVectorSearchPlanBuilder.scala:
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hudi.analysis
+
+import org.apache.hudi.common.schema.HoodieSchema
+
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.plans.logical.HoodieVectorSearchTableValuedFunction.{DistanceMetric, SearchAlgorithm}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions.{broadcast, col, monotonically_increasing_id, row_number}
+import org.apache.spark.sql.hudi.command.exception.HoodieAnalysisException
+import org.apache.spark.sql.types.{ArrayType, ByteType, DataType, DoubleType, FloatType}
+
+import scala.util.{Failure, Success, Try}
+
+/**
+ * Extension point for vector search algorithms. Each implementation provides
+ * the Spark logical plan for single-query and batch-query KNN search.
+ *
+ * To add a new algorithm (e.g. RowMatrix, HNSW):
+ *  1. Create an object extending this trait
+ *  2. Add a value to [[SearchAlgorithm]]
+ *  3. Register the mapping in [[HoodieVectorSearchPlanBuilder.resolveAlgorithm]]
+ *
+ * Implementations can use the shared validation helpers on
+ * [[HoodieVectorSearchPlanBuilder]] (validateEmbeddingColumn, validateBatchDimensions, etc.)
+ * and the raw distance functions on [[VectorDistanceUtils]].
+ *
+ * The output schema contract:
+ *  - Single-query: all corpus columns (minus the embedding column) + `_hudi_distance: Double`
+ *  - Batch-query: all corpus columns (minus the embedding column) + clashing query columns
+ *    (prefixed with `_hudi_query_`) + `_hudi_distance: Double` + `_hudi_query_index: Long`
+ *  - Results are ordered by `_hudi_distance` ascending (lower = more similar)
+ *  - `_hudi_query_index` is an opaque grouping identifier (not a sequential index). Values may be
+ *    large non-contiguous numbers because they are generated by `monotonically_increasing_id()`.
+ */
+trait VectorSearchAlgorithm {
+
+  /** Human-readable name for error messages and logging. */
+  def name: String
+
+  /**
+   * Build a plan that finds the k nearest corpus rows to a single query vector.
+   *
+   * @param spark          active SparkSession
+   * @param corpusDf       resolved corpus DataFrame (may be Hudi, Parquet, or temp view)
+   * @param embeddingCol   name of the array-typed embedding column in corpusDf
+   * @param queryVector    the query vector, normalized to Array[Double]
+   * @param k              number of nearest neighbors to return
+   * @param metric         distance metric (COSINE, L2, DOT_PRODUCT)
+   * @return an analyzed LogicalPlan whose output matches the single-query schema contract
+   */
+  def buildSingleQueryPlan(
+      spark: SparkSession,
+      corpusDf: DataFrame,
+      embeddingCol: String,
+      queryVector: Array[Double],
+      k: Int,
+      metric: DistanceMetric.Value): LogicalPlan
+
+  /**
+   * Build a plan that finds the k nearest corpus rows for each row in the query table.
+   *
+   * @param spark              active SparkSession
+   * @param corpusDf           resolved corpus DataFrame
+   * @param corpusEmbeddingCol name of the embedding column in corpusDf
+   * @param queryDf            resolved query DataFrame
+   * @param queryEmbeddingCol  name of the embedding column in queryDf
+   * @param k                  number of nearest neighbors per query
+   * @param metric             distance metric (COSINE, L2, DOT_PRODUCT)
+   * @return an analyzed LogicalPlan whose output matches the batch-query schema contract
+   * @note Batch mode broadcasts the query table to all executors via a cross-join.
+   *       This is designed for small-to-medium query sets (tens to low hundreds of rows).
+   *       For large query tables, memory pressure on executors may occur.
+   */
+  def buildBatchQueryPlan(
+      spark: SparkSession,
+      corpusDf: DataFrame,
+      corpusEmbeddingCol: String,
+      queryDf: DataFrame,
+      queryEmbeddingCol: String,
+      k: Int,
+      metric: DistanceMetric.Value): LogicalPlan
+}
+
+/**
+ * Resolves [[SearchAlgorithm]] values to [[VectorSearchAlgorithm]] implementations
+ * and provides shared validation helpers used across algorithms.
+ */
+object HoodieVectorSearchPlanBuilder {
+
+  val DISTANCE_COL = "_hudi_distance"
+  private[analysis] val QUERY_ID_COL = "_hudi_query_index"
+  private[analysis] val QUERY_EMB_ALIAS = "_hudi_query_emb"
+  private[analysis] val RANK_COL = "_hudi_rank"
+  private[analysis] val QUERY_COL_PREFIX = "_hudi_query_"
+
+  /** Resolve a [[SearchAlgorithm]] enum value to its implementation. */
+  def resolveAlgorithm(algorithm: SearchAlgorithm.Value): VectorSearchAlgorithm = algorithm match {
+    case SearchAlgorithm.BRUTE_FORCE => BruteForceSearchAlgorithm
+    case other => throw new HoodieAnalysisException(
+      s"Unsupported search algorithm: $other")
+  }
+
+  private[analysis] def validateEmbeddingColumn(df: DataFrame, colName: String): Unit = {
+    val fieldOpt = df.schema.fields.find(_.name == colName)
+    val field = fieldOpt.getOrElse(

Review Comment:
   <a href="#"><img alt="P1" src="https://greptile-static-assets.s3.amazonaws.com/badges/p1.svg?v=7" align="top"></a> **Case-sensitive column lookup is inconsistent with Spark SQL defaults**
   
   `find(_.name == colName)` is an exact, case-sensitive comparison. Since Spark SQL is case-insensitive by default (`spark.sql.caseSensitive = false`), a table with column `"Embedding"` will fail with `"Embedding column 'embedding' not found"` if the user passes `"embedding"`. Other Hudi analysis helpers use the Spark session resolver for these comparisons.
   
   ```suggestion
       val fieldOpt = df.schema.fields.find(_.name.equalsIgnoreCase(colName))
   ```
   
   — *Greptile* ([original](https://github.com/yihua/hudi/pull/10#discussion_r3036042603)) (source:comment#3036042603)
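
The suggested `equalsIgnoreCase` fixes the default case, but it also matches case-insensitively when a session sets `spark.sql.caseSensitive = true`. Routing the comparison through a resolver honors both settings. Inside Spark the resolver would typically come from the session conf (for example `SQLConf.get.resolver`, which picks between Catalyst's `caseSensitiveResolution` and `caseInsensitiveResolution`). The sketch below re-declares those two functions locally so it runs without Spark; `ColumnLookupSketch`, `resolverFor`, and `findField` are hypothetical names.

```scala
object ColumnLookupSketch {
  // Same shape as Spark's org.apache.spark.sql.catalyst.analysis.Resolver.
  type Resolver = (String, String) => Boolean

  // Local copies of the two resolutions Catalyst defines.
  val caseSensitiveResolution: Resolver = (a, b) => a == b
  val caseInsensitiveResolution: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Hypothetical helper mirroring how the session conf chooses a resolver
  // based on spark.sql.caseSensitive.
  def resolverFor(caseSensitive: Boolean): Resolver =
    if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution

  // Returns the field name as declared in the schema, not as typed by the user.
  def findField(fieldNames: Seq[String], colName: String, resolver: Resolver): Option[String] =
    fieldNames.find(name => resolver(name, colName))
}
```

Returning the schema's own spelling (`Some("Embedding")`) rather than the user's input lets later column references use the declared name.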

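The `VectorSearchAlgorithm` contract in the `HoodieVectorSearchPlanBuilder.scala` hunk above (every metric expressed as a distance, results ordered by `_hudi_distance` ascending, top-k per query in batch mode) can be illustrated with a Spark-free brute-force sketch. The L2 and cosine formulas are the standard ones; negating the dot product and returning 1.0 for a zero vector under cosine are assumptions about how `VectorDistanceUtils` behaves, which the hunks shown do not confirm. `BruteForceKnnSketch` and its members are hypothetical.

```scala
object BruteForceKnnSketch {
  // Hypothetical metric set mirroring the DistanceMetric values in the PR.
  sealed trait Metric
  case object Cosine extends Metric
  case object L2 extends Metric
  case object DotProduct extends Metric

  // Every metric is returned as a distance: lower = more similar.
  def distance(metric: Metric, a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, s"dimension mismatch: ${a.length} vs ${b.length}")
    val dot = a.zip(b).map { case (x, y) => x * y }.sum
    metric match {
      case L2 =>
        math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)
      case DotProduct =>
        // Assumption: negated inner product, so ascending order ranks larger dot products first.
        -dot
      case Cosine =>
        val norms = math.sqrt(a.map(x => x * x).sum) * math.sqrt(b.map(y => y * y).sum)
        // Assumption: a zero vector has no direction; return 1.0 (as if orthogonal) instead of NaN.
        if (norms == 0.0) 1.0 else 1.0 - dot / norms
    }
  }

  // Single query: score every corpus row, sort ascending by distance, keep the first k.
  def topK(corpus: Seq[(Long, Array[Double])], query: Array[Double], k: Int,
      metric: Metric): Seq[(Long, Double)] =
    corpus
      .map { case (id, emb) => (id, distance(metric, emb, query)) }
      .sortBy(_._2)
      .take(k)

  // Batch: score every (query, corpus) pair, then keep the top k per query index.
  def batchTopK(corpus: Seq[(Long, Array[Double])], queries: Seq[Array[Double]], k: Int,
      metric: Metric): Map[Int, Seq[(Long, Double)]] =
    queries.zipWithIndex.map { case (q, qi) => qi -> topK(corpus, q, k, metric) }.toMap
}
```

In Spark, the batch variant would presumably compute the same pairwise distances through a cross join against the (broadcast) query side and keep rows whose `row_number()` over a window partitioned by `_hudi_query_index` and ordered by distance is at most k. The `broadcast`, `Window`, and `row_number` imports in the hunk suggest this, but the plan-building code itself is not shown here. Spark documents that `monotonically_increasing_id()` places the partition ID in the upper 31 bits and the per-partition record number in the lower 33 bits, which is why `_hudi_query_index` values are large and non-contiguous.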


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

