This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a2c10380314 [SPARK-39579][SQL][PYTHON][R] Make ListFunctions/getFunction/functionExists compatible with 3 layer namespace
a2c10380314 is described below

commit a2c10380314392ada357a2e235f6a9b64244ae25
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Jul 5 11:16:56 2022 +0800

    [SPARK-39579][SQL][PYTHON][R] Make ListFunctions/getFunction/functionExists compatible with 3 layer namespace

    ### What changes were proposed in this pull request?
    Make ListFunctions/getFunction/functionExists compatible with 3 layer namespace.

    ### Why are the changes needed?
    To support the 3 layer namespace.

    ### Does this PR introduce _any_ user-facing change?
    Yes. `Catalog.listFunctions`, `Catalog.getFunction` and `Catalog.functionExists` now accept names qualified with a catalog name, and `Function` gains `catalog`/`namespace` fields.

    ### How was this patch tested?
    Added unit tests.

    Closes #36977 from zhengruifeng/sql_3L_catalog_list_functions.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 R/pkg/tests/fulltests/test_sparkSQL.R              |   6 +-
 python/pyspark/sql/catalog.py                      |  73 +++++++++++++-
 python/pyspark/sql/tests/test_catalog.py           |  28 ++++--
 .../org/apache/spark/sql/catalog/interface.scala   |  19 +++-
 .../apache/spark/sql/internal/CatalogImpl.scala    | 109 +++++++++++++++++++--
 .../apache/spark/sql/internal/CatalogSuite.scala   |  61 +++++++++++-
 6 files changed, 266 insertions(+), 30 deletions(-)

diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 0f984d0022a..b3218abb133 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -4050,12 +4050,12 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", {
   f <- listFunctions()
   expect_true(nrow(f) >= 200) # 250
   expect_equal(colnames(f),
-               c("name", "database", "description", "className", "isTemporary"))
+               c("name", "catalog", "namespace", "description", "className", "isTemporary"))
   expect_equal(take(orderBy(f, "className"), 1)$className,
                "org.apache.spark.sql.catalyst.expressions.Abs")
   expect_error(listFunctions("zxwtyswklpf_db"),
-               paste("Error in listFunctions : analysis error - Database",
-                     "'zxwtyswklpf_db' does not exist"))
+               paste("Error in listFunctions : no such database - Database",
+                     "'zxwtyswklpf_db' not found"))
 
   # recoverPartitions does not work with temporary view
   expect_error(recoverPartitions("cars"),
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 624e3877db0..42c040c284b 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -68,6 +68,8 @@ class Column(NamedTuple):
 
 class Function(NamedTuple):
     name: str
+    catalog: Optional[str]
+    namespace: Optional[List[str]]
     description: Optional[str]
     className: str
     isTemporary: bool
@@ -287,6 +289,9 @@ class Catalog:
         If no database is specified, the current database is used.
         This includes all temporary functions.
+
+        .. versionchanged:: 3.4
+            Allowed ``dbName`` to be qualified with catalog name.
         """
         if dbName is None:
             dbName = self.currentDatabase()
@@ -294,9 +299,17 @@ class Catalog:
         functions = []
         while iter.hasNext():
             jfunction = iter.next()
+            jnamespace = jfunction.namespace()
+            if jnamespace is not None:
+                namespace = [jnamespace[i] for i in range(0, len(jnamespace))]
+            else:
+                namespace = None
+
             functions.append(
                 Function(
                     name=jfunction.name(),
+                    catalog=jfunction.catalog(),
+                    namespace=namespace,
                     description=jfunction.description(),
                     className=jfunction.className(),
                     isTemporary=jfunction.isTemporary(),
@@ -318,19 +331,75 @@ class Catalog:
             name of the database to check function existence in.
             If no database is specified, the current database is used
+            .. deprecated:: 3.4.0
+
+
         Returns
         -------
         bool
             Indicating whether the function exists
 
+        .. versionchanged:: 3.4
+            Allowed ``functionName`` to be qualified with catalog name
+
         Examples
         --------
         >>> spark.catalog.functionExists("unexisting_function")
         False
+        >>> spark.catalog.functionExists("default.unexisting_function")
+        False
+        >>> spark.catalog.functionExists("spark_catalog.default.unexisting_function")
+        False
         """
         if dbName is None:
-            dbName = self.currentDatabase()
-        return self._jcatalog.functionExists(dbName, functionName)
+            return self._jcatalog.functionExists(functionName)
+        else:
+            warnings.warn(
+                "`dbName` has been deprecated since Spark 3.4 and might be removed in "
+                "a future version. Use functionExists(`dbName.functionName`) instead.",
+                FutureWarning,
+            )
+            return self._jcatalog.functionExists(self.currentDatabase(), functionName)
+
+    def getFunction(self, functionName: str) -> Function:
+        """Get the function with the specified name. This function can be a temporary function
+        or a permanent function. This throws an AnalysisException when the function cannot
+        be found.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        functionName : str
+            name of the function to get.
+
+        Examples
+        --------
+        >>> func = spark.sql("CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'")
+        >>> spark.catalog.getFunction("my_func1")
+        Function(name='my_func1', catalog=None, namespace=['default'], ...
+        >>> spark.catalog.getFunction("default.my_func1")
+        Function(name='my_func1', catalog=None, namespace=['default'], ...
+        >>> spark.catalog.getFunction("spark_catalog.default.my_func1")
+        Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ...
+        >>> spark.catalog.getFunction("my_func2")
+        Traceback (most recent call last):
+            ...
+        pyspark.sql.utils.AnalysisException: ...
+        """
+        jfunction = self._jcatalog.getFunction(functionName)
+        jnamespace = jfunction.namespace()
+        if jnamespace is not None:
+            namespace = [jnamespace[i] for i in range(0, len(jnamespace))]
+        else:
+            namespace = None
+        return Function(
+            name=jfunction.name(),
+            catalog=jfunction.catalog(),
+            namespace=namespace,
+            description=jfunction.description(),
+            className=jfunction.className(),
+            isTemporary=jfunction.isTemporary(),
+        )
 
     def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Column]:
         """Returns a list of columns for the given table/view in the specified database.
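As an aside for reviewers: the Scala `Catalog` API gains the same qualified-name behavior as the Python changes above. A minimal sketch, not part of the patch, assuming a local session where `CREATE FUNCTION` succeeds; the function and class names are reused from the doctests above and any resolvable UDAF class would do:

    import org.apache.spark.sql.SparkSession

    // Sketch only: exercises the 1-, 2- and 3-part name handling added in this patch.
    object ThreeLayerNamespaceDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        spark.sql("CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'")

        // All three spellings resolve to the same persistent function.
        assert(spark.catalog.functionExists("my_func1"))
        assert(spark.catalog.functionExists("default.my_func1"))
        assert(spark.catalog.functionExists("spark_catalog.default.my_func1"))

        // Per the doctests above, `catalog` is only populated when the name
        // was fully qualified with a catalog.
        val f = spark.catalog.getFunction("spark_catalog.default.my_func1")
        println(s"catalog=${f.catalog}, namespace=${f.namespace.mkString(".")}, name=${f.name}")

        spark.stop()
      }
    }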
diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
index 49d96a9b7aa..7d81234bce2 100644
--- a/python/pyspark/sql/tests/test_catalog.py
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -184,8 +184,6 @@ class CatalogTests(ReusedSQLTestCase):
         )
 
     def test_list_functions(self):
-        from pyspark.sql.catalog import Function
-
         spark = self.spark
         with self.database("some_db"):
             spark.sql("CREATE DATABASE some_db")
@@ -199,15 +197,12 @@ class CatalogTests(ReusedSQLTestCase):
             self.assertTrue("to_timestamp" in functions)
             self.assertTrue("to_unix_timestamp" in functions)
             self.assertTrue("current_database" in functions)
+            self.assertEqual(functions["+"].name, "+")
+            self.assertEqual(functions["+"].description, None)
             self.assertEqual(
-                functions["+"],
-                Function(
-                    name="+",
-                    description=None,
-                    className="org.apache.spark.sql.catalyst.expressions.Add",
-                    isTemporary=True,
-                ),
+                functions["+"].className, "org.apache.spark.sql.catalyst.expressions.Add"
             )
+            self.assertTrue(functions["+"].isTemporary)
             self.assertEqual(functions, functionsDefault)
 
         with self.function("func1", "some_db.func2"):
@@ -237,11 +232,26 @@
         spark = self.spark
         with self.function("func1"):
             self.assertFalse(spark.catalog.functionExists("func1"))
+            self.assertFalse(spark.catalog.functionExists("default.func1"))
+            self.assertFalse(spark.catalog.functionExists("spark_catalog.default.func1"))
             self.assertFalse(spark.catalog.functionExists("func1", "default"))
             spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
             self.assertTrue(spark.catalog.functionExists("func1"))
+            self.assertTrue(spark.catalog.functionExists("default.func1"))
+            self.assertTrue(spark.catalog.functionExists("spark_catalog.default.func1"))
             self.assertTrue(spark.catalog.functionExists("func1", "default"))
 
+    def test_get_function(self):
+        spark = self.spark
+        with self.function("func1"):
+            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
+            func1 = spark.catalog.getFunction("spark_catalog.default.func1")
+            self.assertTrue(func1.name == "func1")
+            self.assertTrue(func1.namespace == ["default"])
+            self.assertTrue(func1.catalog == "spark_catalog")
+            self.assertTrue(func1.className == "org.apache.spark.data.bricks")
+            self.assertFalse(func1.isTemporary)
+
     def test_list_columns(self):
         from pyspark.sql.catalog import Column
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
index 1f6cb678f1c..59f8099cbee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
@@ -168,7 +168,8 @@ class Column(
  * A user-defined function in Spark, as returned by `listFunctions` method in [[Catalog]].
  *
  * @param name name of the function.
- * @param database name of the database the function belongs to.
+ * @param catalog name of the catalog that the function belongs to.
+ * @param namespace the namespace that the function belongs to.
  * @param description description of the function; description can be null.
  * @param className the fully qualified class name of the function.
  * @param isTemporary whether the function is a temporary function or not.
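For readers skimming the diff, a tiny sketch (not part of the patch; all values illustrative) of how the widened `Function` constructor behaves together with the backward-compatible `database` view defined in the next hunk, e.g. in the Scala REPL:

    import org.apache.spark.sql.catalog.Function

    // The new primary constructor takes a catalog name plus a multi-part
    // namespace, replacing the single nullable database.
    val fn = new Function(
      "my_func",              // name
      "spark_catalog",        // catalog (nullable)
      Array("default"),       // namespace (nullable)
      null,                   // description (nullable)
      "org.example.MyFunc",   // className (illustrative)
      false)                  // isTemporary

    // The legacy view: `database` is recovered from the namespace only when
    // it has exactly one level; otherwise it degrades to null.
    assert(fn.database == "default")
    val nested = new Function("f", "cat", Array("a", "b"), null, "org.example.F", false)
    assert(nested.database == null)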
@@ -177,12 +178,26 @@ class Column(
 @Stable
 class Function(
     val name: String,
-    @Nullable val database: String,
+    @Nullable val catalog: String,
+    @Nullable val namespace: Array[String],
     @Nullable val description: String,
     val className: String,
     val isTemporary: Boolean)
   extends DefinedByConstructorParams {
 
+  def this(
+      name: String,
+      database: String,
+      description: String,
+      className: String,
+      isTemporary: Boolean) = {
+    this(name, null, Array(database), description, className, isTemporary)
+  }
+
+  def database: String = {
+    if (namespace != null && namespace.length == 1) namespace(0) else null
+  }
+
   override def toString: String = {
     "Function[" +
       s"name='$name', " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index c276fbb677c..880c084ab6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -23,16 +23,17 @@ import scala.util.control.NonFatal
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalog.{Catalog, CatalogMetadata, Column, Database, Function, Table}
 import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedTable, ResolvedView, UnresolvedDBObjectName, UnresolvedNamespace, UnresolvedTable, UnresolvedTableOrView}
+import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedNonPersistentFunc, ResolvedPersistentFunc, ResolvedTable, ResolvedView, UnresolvedDBObjectName, UnresolvedFunc, UnresolvedNamespace, UnresolvedTable, UnresolvedTableOrView}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, RecoverPartitions, ShowNamespaces, ShowTables, SubqueryAlias, TableSpec, View}
+import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, SubqueryAlias, TableSpec, View}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Identifier, SupportsNamespaces, TableCatalog}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, IdentifierHelper, MultipartIdentifierHelper, TransformHelper}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.internal.connector.V1Function
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.storage.StorageLevel
 
@@ -194,11 +195,40 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   @throws[AnalysisException]("database does not exist")
   override def listFunctions(dbName: String): Dataset[Function] = {
-    requireDatabaseExists(dbName)
-    val functions = sessionCatalog.listFunctions(dbName).map { case (functIdent, _) =>
-      makeFunction(functIdent)
+    // `dbName` could be either a single database name (behavior in Spark 3.3 and prior) or
+    // a qualified namespace with catalog name. We assume it's a single database name
+    // and check if we can find the dbName in sessionCatalog. If so we listFunctions under
+    // that database.
+    // Otherwise we try 3-part name parsing and locate the database.
+    if (sessionCatalog.databaseExists(dbName)) {
+      val functions = sessionCatalog.listFunctions(dbName)
+        .map { case (functIdent, _) => makeFunction(functIdent) }
+      CatalogImpl.makeDataset(functions, sparkSession)
+    } else {
+      val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(dbName)
+      val functions = collection.mutable.ArrayBuilder.make[Function]
+
+      // built-in functions
+      val plan0 = ShowFunctions(UnresolvedNamespace(ident),
+        userScope = false, systemScope = true, None)
+      sparkSession.sessionState.executePlan(plan0).toRdd.collect().foreach { row =>
+        // `lookupBuiltinOrTempFunction` and `lookupBuiltinOrTempTableFunction` in Analyzer
+        // require that the input identifier contains only the function name; otherwise,
+        // built-in functions will be skipped.
+        val name = row.getString(0)
+        functions += makeFunction(Seq(name))
+      }
+
+      // user functions
+      val plan1 = ShowFunctions(UnresolvedNamespace(ident),
+        userScope = true, systemScope = false, None)
+      sparkSession.sessionState.executePlan(plan1).toRdd.collect().foreach { row =>
+        // `row.getString(0)` may contain a dbName like `db.function`, so extract the function name.
+        val name = row.getString(0).split("\\.").last
+        functions += makeFunction(ident :+ name)
+      }
+
+      CatalogImpl.makeDataset(functions.result(), sparkSession)
     }
-    CatalogImpl.makeDataset(functions, sparkSession)
   }
 
   private def makeFunction(funcIdent: FunctionIdentifier): Function = {
@@ -211,6 +241,39 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
       isTemporary = metadata.getDb == null)
   }
 
+  private def makeFunction(ident: Seq[String]): Function = {
+    val plan = UnresolvedFunc(ident, "Catalog.makeFunction", false, None)
+    sparkSession.sessionState.executePlan(plan).analyzed match {
+      case f: ResolvedPersistentFunc =>
+        val className = f.func match {
+          case f: V1Function => f.info.getClassName
+          case f => f.getClass.getName
+        }
+        new Function(
+          name = f.identifier.name(),
+          catalog = f.catalog.name(),
+          namespace = f.identifier.namespace(),
+          description = f.func.description(),
+          className = className,
+          isTemporary = false)
+
+      case f: ResolvedNonPersistentFunc =>
+        val className = f.func match {
+          case f: V1Function => f.info.getClassName
+          case f => f.getClass.getName
+        }
+        new Function(
+          name = f.name,
+          catalog = null,
+          namespace = null,
+          description = f.func.description(),
+          className = className,
+          isTemporary = true)
+
+      case _ => throw QueryCompilationErrors.noSuchFunctionError(ident, plan)
+    }
+  }
+
   /**
    * Returns a list of columns for the given table/view or temporary view.
    */
@@ -380,8 +443,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * function. This throws an `AnalysisException` when no `Function` can be found.
    */
   override def getFunction(functionName: String): Function = {
-    val functionIdent = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName)
-    getFunction(functionIdent.database.orNull, functionIdent.funcName)
+    // Call `sqlParser.parseFunctionIdentifier` to parse functionName. If it contains only a
+    // function name and optionally a database name (thus a FunctionIdentifier), then
+    // we look up the function in sessionCatalog.
+    // Otherwise we try `sqlParser.parseMultipartIdentifier` to get a sequence of strings as
+    // the qualified identifier and resolve the function through the SQL analyzer.
+    try {
+      val ident = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName)
+      getFunction(ident.database.orNull, ident.funcName)
+    } catch {
+      case e: org.apache.spark.sql.catalyst.parser.ParseException =>
+        val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(functionName)
+        makeFunction(ident)
+    }
   }
 
   /**
@@ -443,8 +517,23 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * or a function.
    */
   override def functionExists(functionName: String): Boolean = {
-    val functionIdent = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName)
-    functionExists(functionIdent.database.orNull, functionIdent.funcName)
+    try {
+      val ident = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName)
+      functionExists(ident.database.orNull, ident.funcName)
+    } catch {
+      case e: org.apache.spark.sql.catalyst.parser.ParseException =>
+        try {
+          val ident = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(functionName)
+          val plan = UnresolvedFunc(ident, "Catalog.functionExists", false, None)
+          sparkSession.sessionState.executePlan(plan).analyzed match {
+            case _: ResolvedPersistentFunc => true
+            case _: ResolvedNonPersistentFunc => true
+            case _ => false
+          }
+        } catch {
+          case _: org.apache.spark.sql.AnalysisException => false
+        }
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 7e4933b3407..f3133026836 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -32,8 +32,9 @@ import org.apache.spark.sql.catalyst.plans.logical.Range
 import org.apache.spark.sql.connector.FakeV2Provider
 import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryCatalog}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
+import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
@@ -366,7 +367,7 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     val db = new Database("nama", "cataloa", "descripta", "locata")
     val table = new Table("nama", "cataloa", Array("databasa"), "descripta", "typa",
       isTemporary = false)
-    val function = new Function("nama", "databasa", "descripta", "classa", isTemporary = false)
+    val function = new Function("nama", "cataloa", Array("databasa"), "descripta", "classa", false)
     val column = new Column(
       "nama", "descripta", "typa", nullable = false, isPartition = true, isBucket = true)
     val dbFields = ScalaReflection.getConstructorParameterValues(db)
@@ -377,7 +378,9 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     assert(Seq(tableFields(0), tableFields(1), tableFields(3), tableFields(4), tableFields(5)) ==
       Seq("nama", "cataloa", "descripta", "typa", false))
     assert(tableFields(2).asInstanceOf[Array[String]].sameElements(Array("databasa")))
-    assert(functionFields == Seq("nama", "databasa", "descripta", "classa", false))
+    assert((functionFields(0), functionFields(1), functionFields(3), functionFields(4),
+      functionFields(5)) == ("nama", "cataloa", "descripta", "classa", false))
+    assert(functionFields(2).asInstanceOf[Array[String]].sameElements(Array("databasa")))
     assert(columnFields ==
Seq("nama", "descripta", "typa", false, true, true)) val dbString = CatalogImpl.makeDataset(Seq(db), spark).showString(10) val tableString = CatalogImpl.makeDataset(Seq(table), spark).showString(10) @@ -386,7 +389,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf dbFields.foreach { f => assert(dbString.contains(f.toString)) } tableFields.foreach { f => assert(tableString.contains(f.toString) || tableString.contains(f.asInstanceOf[Array[String]].mkString(""))) } - functionFields.foreach { f => assert(functionString.contains(f.toString)) } + functionFields.foreach { f => assert(functionString.contains(f.toString) || + functionString.contains(f.asInstanceOf[Array[String]].mkString(""))) } columnFields.foreach { f => assert(columnString.contains(f.toString)) } } @@ -895,4 +899,53 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf }.getMessage assert(e3.contains("unknown_db")) } + + test("SPARK-39579: Three layer namespace compatibility - " + + "listFunctions, getFunction, functionExists") { + createDatabase("my_db1") + createFunction("my_func1", Some("my_db1")) + + val functions1a = spark.catalog.listFunctions("my_db1").collect().map(_.name) + val functions1b = spark.catalog.listFunctions("spark_catalog.my_db1").collect().map(_.name) + assert(functions1a.length > 200 && functions1a.contains("my_func1")) + assert(functions1b.length > 200 && functions1b.contains("my_func1")) + // functions1b contains 5 more functions: [<>, ||, !=, case, between] + assert(functions1a.intersect(functions1b) === functions1a) + + assert(spark.catalog.functionExists("my_db1.my_func1")) + assert(spark.catalog.functionExists("spark_catalog.my_db1.my_func1")) + + val func1a = spark.catalog.getFunction("my_db1.my_func1") + val func1b = spark.catalog.getFunction("spark_catalog.my_db1.my_func1") + assert(func1a.name === func1b.name && func1a.namespace === func1b.namespace && + func1a.className === func1b.className && func1a.isTemporary === func1b.isTemporary) + assert(func1a.catalog === null && func1b.catalog === "spark_catalog") + assert(func1a.description === null && func1b.description === "N/A.") + + val function: UnboundFunction = new UnboundFunction { + override def bind(inputType: StructType): BoundFunction = new ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(IntegerType) + override def resultType(): DataType = IntegerType + override def name(): String = "my_bound_function" + } + override def description(): String = "hello" + override def name(): String = "my_function" + } + + val testCatalog: InMemoryCatalog = + spark.sessionState.catalogManager.catalog("testcat").asInstanceOf[InMemoryCatalog] + testCatalog.createFunction(Identifier.of(Array("my_db2"), "my_func2"), function) + + val functions2 = spark.catalog.listFunctions("testcat.my_db2").collect().map(_.name) + assert(functions2.length > 200 && functions2.contains("my_func2")) + + assert(spark.catalog.functionExists("testcat.my_db2.my_func2")) + assert(!spark.catalog.functionExists("testcat.my_db2.my_func3")) + + val func2 = spark.catalog.getFunction("testcat.my_db2.my_func2") + assert(func2.name === "my_func2" && func2.namespace === Array("my_db2") && + func2.catalog === "testcat" && func2.description === "hello" && + func2.isTemporary === false && + func2.className.startsWith("org.apache.spark.sql.internal.CatalogSuite")) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For 