rahil-c commented on code in PR #18432: URL: https://github.com/apache/hudi/pull/18432#discussion_r3037486865
########## hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieVectorSearchFunction.scala: ########## @@ -0,0 +1,1359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hudi.functional + +import org.apache.hudi.DataSourceWriteOptions._ +import org.apache.hudi.common.schema.HoodieSchema +import org.apache.hudi.testutils.HoodieSparkClientTestBase + +import org.apache.spark.sql.{Row, SaveMode, SparkSession} +import org.apache.spark.sql.types._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.Assertions._ + +/** + * End-to-end tests for the hudi_vector_search table-valued function. + * Tests both single-query and batch-query modes with Spark SQL and DataFrame API. 
+ */ +class TestHoodieVectorSearchFunction extends HoodieSparkClientTestBase { + + var spark: SparkSession = null + private val corpusPath = "corpus" + private val corpusViewName = "corpus_view" + + // Test corpus: 5 unit-ish vectors in 3D for easy manual verification + // doc_1: [1, 0, 0] - x-axis + // doc_2: [0, 1, 0] - y-axis + // doc_3: [0, 0, 1] - z-axis + // doc_4: [0.707, 0.707, 0] - 45 degrees in xy-plane (normalized) + // doc_5: [0.577, 0.577, 0.577] - equal in all 3 dims (normalized) + private val corpusData = Seq( + ("doc_1", Seq(1.0f, 0.0f, 0.0f), "x-axis"), + ("doc_2", Seq(0.0f, 1.0f, 0.0f), "y-axis"), + ("doc_3", Seq(0.0f, 0.0f, 1.0f), "z-axis"), + ("doc_4", Seq(0.70710678f, 0.70710678f, 0.0f), "xy-diagonal"), + ("doc_5", Seq(0.57735027f, 0.57735027f, 0.57735027f), "xyz-diagonal") + ) + + @BeforeEach override def setUp(): Unit = { + initPath() + initSparkContexts() + spark = sqlContext.sparkSession + initTestDataGenerator() + initHoodieStorage() + createCorpusTable() + } + + @AfterEach override def tearDown(): Unit = { + spark.catalog.dropTempView(corpusViewName) + cleanupSparkContexts() + cleanupTestDataGenerator() + cleanupFileSystem() + } + + private def createCorpusTable(): Unit = { + val metadata = new MetadataBuilder() + .putString(HoodieSchema.TYPE_METADATA_FIELD, "VECTOR(3)") + .build() + + val schema = StructType(Seq( + StructField("id", StringType, nullable = false), + StructField("embedding", ArrayType(FloatType, containsNull = false), + nullable = false, metadata), + StructField("label", StringType, nullable = true) + )) + + val rows = corpusData.map { case (id, emb, label) => + Row(id, emb, label) + } + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(rows), + schema + ) + + df.write.format("hudi") + .option(RECORDKEY_FIELD.key, "id") + .option(PRECOMBINE_FIELD.key, "id") + .option(TABLE_NAME.key, "vector_search_corpus") + .option(TABLE_TYPE.key, "COPY_ON_WRITE") + .mode(SaveMode.Overwrite) + .save(basePath + "/" + 
corpusPath) + + spark.read.format("hudi").load(basePath + "/" + corpusPath) + .createOrReplaceTempView(corpusViewName) + } + + /** + * Creates an in-memory Float corpus temp view (no Hudi write). + * Schema: id (String), embedding (Array[Float]). + */ + private def createFloatInMemoryView(viewName: String, data: Seq[(String, Seq[Float])]): Unit = { + val schema = StructType(Seq( + StructField("id", StringType, nullable = false), + StructField("embedding", ArrayType(FloatType, containsNull = false), nullable = false) + )) + val rows = data.map { case (id, emb) => Row(id, emb) } + spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + .createOrReplaceTempView(viewName) + } + + /** + * Creates an in-memory Byte corpus temp view (no Hudi write). + * Schema: id (String), embedding (Array[Byte]). + */ + private def createByteCorpusView(viewName: String, data: Seq[(String, Seq[Byte])]): Unit = { + val schema = StructType(Seq( + StructField("id", StringType, nullable = false), + StructField("embedding", ArrayType(ByteType, containsNull = false), nullable = false) + )) + val rows = data.map { case (id, emb) => Row(id, emb) } + spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + .createOrReplaceTempView(viewName) + } + + /** + * Creates a Float query temp view with configurable id and vector column names. + * Used by batch-query tests to avoid repeating StructType + createDataFrame boilerplate. 
+ */ + private def createFloatQueryView(viewName: String, idCol: String, vecCol: String, + data: Seq[(String, Seq[Float])]): Unit = { + val schema = StructType(Seq( + StructField(idCol, StringType, nullable = false), + StructField(vecCol, ArrayType(FloatType, containsNull = false), nullable = false) + )) + val rows = data.map { case (id, vec) => Row(id, vec) } + spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + .createOrReplaceTempView(viewName) + } + + /** + * Writes rows to a Hudi table and registers the result as a Spark temp view. + * The supplied schema must include an "id" column used as the record key. + */ + private def writeHudiAndCreateView(schema: StructType, data: Seq[Row], tableName: String, + subPath: String, viewName: String, + tableType: String = "COPY_ON_WRITE", + precombineField: String = "id"): Unit = { + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + .write.format("hudi") + .option(RECORDKEY_FIELD.key, "id") + .option(PRECOMBINE_FIELD.key, precombineField) + .option(TABLE_NAME.key, tableName) + .option(TABLE_TYPE.key, tableType) + .mode(SaveMode.Overwrite) + .save(basePath + "/" + subPath) + spark.read.format("hudi").load(basePath + "/" + subPath) + .createOrReplaceTempView(viewName) + } + + @Test + def testSingleQueryDistanceMetrics(): Unit = { + // Verify all three distance metrics with query [1,0,0] against the shared corpus. 
+ + // --- Cosine --- + val cosine = spark.sql( + s""" + |SELECT id, _hudi_distance + |FROM hudi_vector_search('$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3, 'cosine') + |ORDER BY _hudi_distance + |""".stripMargin + ).collect() + assertEquals(3, cosine.length) + assertEquals("doc_1", cosine(0).getAs[String]("id")) + assertEquals(0.0, cosine(0).getAs[Double]("_hudi_distance"), 1e-5) + assertEquals("doc_4", cosine(1).getAs[String]("id")) + assertEquals(1.0 - 0.70710678, cosine(1).getAs[Double]("_hudi_distance"), 1e-4) + assertEquals("doc_5", cosine(2).getAs[String]("id")) + assertEquals(1.0 - 0.57735027, cosine(2).getAs[Double]("_hudi_distance"), 1e-4) + + // --- L2 --- + val l2 = spark.sql( + s""" + |SELECT id, _hudi_distance + |FROM hudi_vector_search('$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3, 'l2') + |ORDER BY _hudi_distance + |""".stripMargin + ).collect() + assertEquals(3, l2.length) + assertEquals("doc_1", l2(0).getAs[String]("id")) + assertEquals(0.0, l2(0).getAs[Double]("_hudi_distance"), 1e-5) + assertEquals("doc_4", l2(1).getAs[String]("id")) + val expectedL2Doc4 = math.sqrt(math.pow(1.0 - 0.70710678, 2) + math.pow(0.70710678, 2)) + assertEquals(expectedL2Doc4, l2(1).getAs[Double]("_hudi_distance"), 1e-4) + + // --- Dot product --- + val dot = spark.sql( + s""" + |SELECT id, _hudi_distance + |FROM hudi_vector_search('$corpusViewName', 'embedding', ARRAY(1.0, 0.0, 0.0), 3, 'dot_product') + |ORDER BY _hudi_distance + |""".stripMargin + ).collect() + assertEquals(3, dot.length) + assertEquals("doc_1", dot(0).getAs[String]("id")) + assertEquals(-1.0, dot(0).getAs[Double]("_hudi_distance"), 1e-5) + assertEquals("doc_4", dot(1).getAs[String]("id")) + assertEquals(-0.70710678, dot(1).getAs[Double]("_hudi_distance"), 1e-4) + } + + @Test + def testSingleQueryDefaultMetric(): Unit = { + // Omit metric arg, should default to cosine + val result = spark.sql( + s""" + |SELECT id, _hudi_distance + |FROM hudi_vector_search( + | '$corpusViewName', + 
| 'embedding', + | ARRAY(1.0, 0.0, 0.0), + | 3 + |) + |ORDER BY _hudi_distance + |""".stripMargin + ).collect() + + assertEquals(3, result.length) + // Should match cosine: doc_1 first with distance ~0 + assertEquals("doc_1", result(0).getAs[String]("id")) + assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-5) + } + + @Test + def testSingleQueryReturnsAllCorpusColumns(): Unit = { + val result = spark.sql( + s""" + |SELECT * + |FROM hudi_vector_search( + | '$corpusViewName', + | 'embedding', + | ARRAY(1.0, 0.0, 0.0), + | 2 + |) + |""".stripMargin + ) + + // Should have the _hudi_distance column plus original corpus columns (embedding is dropped) + assertTrue(result.columns.contains("_hudi_distance")) + assertTrue(result.columns.contains("id")) + assertTrue(result.columns.contains("label")) + assertFalse(result.columns.contains("embedding")) + assertEquals(2, result.count()) + } + + @Test + def testKGreaterThanCorpus(): Unit = { + // k=100, corpus has 5 rows -> should return all 5 + val result = spark.sql( + s""" + |SELECT id, _hudi_distance + |FROM hudi_vector_search( + | '$corpusViewName', + | 'embedding', + | ARRAY(1.0, 0.0, 0.0), + | 100 + |) + |""".stripMargin + ).collect() + + assertEquals(5, result.length) + } + + @Test + def testVectorSearchWithWhereClause(): Unit = { + val result = spark.sql( + s""" + |SELECT id, _hudi_distance + |FROM hudi_vector_search( + | '$corpusViewName', + | 'embedding', + | ARRAY(1.0, 0.0, 0.0), + | 5, + | 'cosine' + |) + |WHERE _hudi_distance < 0.5 + |ORDER BY _hudi_distance + |""".stripMargin + ).collect() + + // doc_1 (distance ~0), doc_4 (distance ~0.29), doc_5 (~0.42) should pass + // doc_2 and doc_3 have distance = 1.0 and should be filtered out + assertEquals(3, result.length) + assertTrue(result.forall(_.getAs[Double]("_hudi_distance") < 0.5)) + } + + @Test + def testVectorSearchAsSubquery(): Unit = { + val result = spark.sql( + s""" + |SELECT sub.id, sub.label, sub._hudi_distance + |FROM ( + | SELECT * + | FROM 
hudi_vector_search( + | '$corpusViewName', + | 'embedding', + | ARRAY(0.0, 1.0, 0.0), + | 3 + | ) + |) sub + |WHERE sub.label != 'y-axis' + |ORDER BY sub._hudi_distance + |""".stripMargin + ).collect() + + // doc_2 (y-axis) is filtered out + assertTrue(result.forall(_.getAs[String]("id") != "doc_2")) + } + + @Test + def testBatchQueryResultsPerQuery(): Unit = { + createFloatQueryView("batch_queries", "qid", "qvec", Seq( + ("q1", Seq(1.0f, 0.0f, 0.0f)), + ("q2", Seq(0.0f, 0.0f, 1.0f)) + )) + + val resultDf = spark.sql( + s""" + |SELECT * + |FROM hudi_vector_search_batch( + | '$corpusViewName', + | 'embedding', + | 'batch_queries', + | 'qvec', + | 2, + | 'cosine' + |) + |""".stripMargin + ) + + // Verify output columns + val columns = resultDf.columns + assertTrue(columns.contains("_hudi_distance")) + assertTrue(columns.contains("_hudi_query_index")) + + // Each query should get exactly 2 results + val resultsByQuery = resultDf.groupBy("_hudi_query_index").count().collect() + assertEquals(2, resultsByQuery.length) + resultsByQuery.foreach { row => + assertEquals(2, row.getLong(1)) + } + + // Validate that _hudi_query_index has two distinct values + val queryIndexValues = resultDf.select("_hudi_query_index").distinct().collect() + .map(_.getLong(0)).sorted + assertEquals(2, queryIndexValues.length) + assertTrue(queryIndexValues(0) != queryIndexValues(1)) + + // Verify DataFrame operations work on the result (merged from testBatchQueryViaDataFrameApi) + val topResults = resultDf.filter("_hudi_distance < 0.5").select("id", "_hudi_distance", "_hudi_query_index") + assertTrue(topResults.count() > 0) + + spark.catalog.dropTempView("batch_queries") + } + + @Test + def testBatchQuerySameEmbeddingColumnName(): Unit = { + // Both corpus and query use the column name "embedding" — previously caused ambiguity error + createFloatQueryView("same_col_queries", "query_name", "embedding", Seq( + ("q_x", Seq(1.0f, 0.0f, 0.0f)), + ("q_y", Seq(0.0f, 1.0f, 0.0f)) + )) + + val result = 
spark.sql( + s""" + |SELECT * + |FROM hudi_vector_search_batch( + | '$corpusViewName', + | 'embedding', + | 'same_col_queries', + | 'embedding', + | 2, + | 'cosine' + |) + |""".stripMargin + ).collect() + + // 2 queries x 2 results each = 4 rows; should not throw AnalysisException + assertEquals(4, result.length) + assertTrue(result.head.schema.fieldNames.contains("_hudi_distance")) + + spark.catalog.dropTempView("same_col_queries") + } + + @Test + def testTableByPath(): Unit = { + val tablePath = basePath + "/" + corpusPath + val result = spark.sql( + s""" + |SELECT id, _hudi_distance + |FROM hudi_vector_search( + | '$tablePath', + | 'embedding', + | ARRAY(1.0, 0.0, 0.0), + | 2 + |) + |ORDER BY _hudi_distance + |""".stripMargin + ).collect() + + assertEquals(2, result.length) + assertEquals("doc_1", result(0).getAs[String]("id")) + } + + @Test + def testDoubleVectorEmbeddings(): Unit = { + val metadata = new MetadataBuilder() + .putString(HoodieSchema.TYPE_METADATA_FIELD, "VECTOR(3, DOUBLE)") + .build() + + val schema = StructType(Seq( + StructField("id", StringType, nullable = false), + StructField("embedding", ArrayType(DoubleType, containsNull = false), + nullable = false, metadata) + )) + + writeHudiAndCreateView(schema, Seq( + Row("d1", Seq(1.0, 0.0, 0.0)), + Row("d2", Seq(0.0, 1.0, 0.0)), + Row("d3", Seq(0.0, 0.0, 1.0)) + ), "double_vec_search", "double_search", "double_corpus") + + val result = spark.sql( + """ + |SELECT id, _hudi_distance + |FROM hudi_vector_search( + | 'double_corpus', + | 'embedding', + | ARRAY(1.0, 0.0, 0.0), + | 2 + |) + |ORDER BY _hudi_distance + |""".stripMargin + ).collect() + + assertEquals(2, result.length) + assertEquals("d1", result(0).getAs[String]("id")) + assertEquals(0.0, result(0).getAs[Double]("_hudi_distance"), 1e-10) + + spark.catalog.dropTempView("double_corpus") + } + + @Test + def testInvalidEmbeddingColumn(): Unit = { + val ex = assertThrows(classOf[Exception], () => { + spark.sql( + s""" + |SELECT * + |FROM 
hudi_vector_search( + | '$corpusViewName', + | 'nonexistent_col', + | ARRAY(1.0, 0.0, 0.0), + | 3 + |) + |""".stripMargin + ).collect() + }) + assertTrue(ex.getMessage.contains("nonexistent_col") || + ex.getCause.getMessage.contains("nonexistent_col")) Review Comment: Acknowledged, I can address this. ########## hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSparkBaseAnalysis.scala: ########## @@ -310,6 +327,61 @@ case class ResolveReferences(spark: SparkSession) extends Rule[LogicalPlan] sparkAdapter.getCatalystPlanUtils.unapplyMergeIntoTable(plan) } + /** + * Resolves a table reference to a DataFrame. Accepts either a table identifier + * (including multi-part identifiers like catalog.db.table) or a file path. + */ + private def resolveTableToDf(table: String): DataFrame = { + try { + if (table.contains(StoragePath.SEPARATOR)) { + spark.read.format("hudi").load(table) + } else { + spark.table(table) + } + } catch { + case e: Exception => throw new HoodieAnalysisException( + s"hudi_vector_search: unable to resolve table '$table': ${e.getMessage}") + } Review Comment: Will address. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
