This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new bba8cf48d14f [SPARK-50734][SQL] Add catalog API for creating and registering SQL UDFs
bba8cf48d14f is described below
commit bba8cf48d14f91109eea04e22fd19be188fce5fb
Author: Allison Wang <[email protected]>
AuthorDate: Tue Jan 7 21:18:44 2025 +0800
[SPARK-50734][SQL] Add catalog API for creating and registering SQL UDFs
### What changes were proposed in this pull request?
This PR adds catalog APIs to support the creation and registration of SQL
UDFs. It uses the Hive Metastore to persist a SQL UDF by serializing the
function information into a FunctionResource and storing it in Hive
(toCatalogFunction). During resolution, it retrieves the catalog function and
deserializes it back into a SQLFunction (fromCatalogFunction).
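The serialized properties are stored as a JSON blob, chunked so that each
piece fits under Hive's default 4000-character resource URI limit, with a
zero-padded 3-digit sequence prefix for reassembly. Here is a minimal,
self-contained Scala sketch of that scheme; the object and method names are
invented for this sketch, and the real logic lives in
UserDefinedFunction.propertiesToFunctionResources and
SQLFunction.fromCatalogFunction in the diff below.

object ResourceChunkingSketch {
  // Mirrors INDEX_LENGTH and HIVE_FUNCTION_RESOURCE_URI_LENGTH_THRESHOLD in
  // the patch: each stored part is "NNN" + payload, at most 4000 chars total.
  private val IndexLength = 3
  private val Threshold = 4000 - IndexLength

  // Split a JSON blob into indexed parts (cf. propertiesToFunctionResources).
  def split(blob: String): Seq[String] =
    blob.grouped(Threshold).zipWithIndex.map { case (part, i) =>
      val index = s"%0${IndexLength}d".format(i)
      // More than 1000 parts would overflow the 3-digit index; the patch
      // raises ROUTINE_PROPERTY_TOO_LARGE in that case.
      require(index.length <= IndexLength, "routine properties are too large")
      index + part
    }.toSeq

  // Reassemble by sorting on the numeric prefix (cf. fromCatalogFunction).
  def reassemble(parts: Seq[String]): String =
    parts.map(p => p.substring(0, IndexLength).toInt -> p.substring(IndexLength))
      .sortBy(_._1).map(_._2).mkString

  def main(args: Array[String]): Unit = {
    val blob = """{"sqlFunction.expression":"a + b"}""" * 500
    assert(reassemble(scala.util.Random.shuffle(split(blob))) == blob)
    println(s"round-tripped ${blob.length} chars in ${split(blob).size} parts")
  }
}

With a 3-digit index there can be at most 1000 parts, i.e. roughly 4 MB of
serialized properties per function.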
This PR adds only the catalog API; a subsequent PR will add the analyzer
logic to resolve SQL UDFs.
### Why are the changes needed?
To support SQL UDFs in Spark.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests. End-to-end tests will be added in the next PR, once SQL UDF
resolution is supported.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49389 from allisonwang-db/spark-50734-sql-udf-catalog-api.
Authored-by: Allison Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 11 ++
.../catalyst/analysis/SQLFunctionExpression.scala | 41 ++++
.../sql/catalyst/analysis/SQLFunctionNode.scala | 64 +++++++
.../spark/sql/catalyst/catalog/SQLFunction.scala | 128 ++++++++++++-
.../sql/catalyst/catalog/SessionCatalog.scala | 206 +++++++++++++++++++--
.../sql/catalyst/catalog/UserDefinedFunction.scala | 145 +++++++++++++++
.../catalog/UserDefinedFunctionErrors.scala | 18 ++
.../spark/sql/catalyst/catalog/interface.scala | 4 +-
.../spark/sql/catalyst/trees/TreePatterns.scala | 2 +
.../sql/catalyst/catalog/UserDefinedFunction.scala | 70 -------
.../command/CreateSQLFunctionCommand.scala | 25 ++-
11 files changed, 624 insertions(+), 90 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 97c8f059bcde..e47387d59fe1 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -758,6 +758,12 @@
],
"sqlState" : "22018"
},
+ "CORRUPTED_CATALOG_FUNCTION" : {
+ "message" : [
+ "Cannot convert the catalog function '<identifier>' into a SQL function
due to corrupted function information in catalog. If the function is not a SQL
function, please make sure the class name '<className>' is loadable."
+ ],
+ "sqlState" : "0A000"
+ },
"CREATE_PERMANENT_VIEW_WITHOUT_ALIAS" : {
"message" : [
"Not allowed to create the permanent view <name> without explicitly
assigning an alias for the expression <attr>."
@@ -5892,6 +5898,11 @@
"The number of columns produced by the RETURN clause (num:
`<outputSize>`) does not match the number of column names specified by the
RETURNS clause (num: `<returnParamSize>`) of <name>."
]
},
+ "ROUTINE_PROPERTY_TOO_LARGE" : {
+ "message" : [
+ "Cannot convert user defined routine <name> to catalog function:
routine properties are too large."
+ ]
+ },
"SQL_TABLE_UDF_BODY_MUST_BE_A_QUERY" : {
"message" : [
"SQL table function <name> body must be a query."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
new file mode 100644
index 000000000000..fb6935d64d4c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.catalog.SQLFunction
+import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION, TreePattern}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Represent a SQL function expression resolved from the catalog SQL function builder.
+ */
+case class SQLFunctionExpression(
+ name: String,
+ function: SQLFunction,
+ inputs: Seq[Expression],
+ returnType: Option[DataType]) extends Expression with Unevaluable {
+ override def children: Seq[Expression] = inputs
+ override def dataType: DataType = returnType.get
+ override def nullable: Boolean = true
+ override def prettyName: String = name
+ override def toString: String = s"$name(${children.mkString(", ")})"
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): SQLFunctionExpression = copy(inputs = newChildren)
+ final override val nodePatterns: Seq[TreePattern] = Seq(SQL_FUNCTION_EXPRESSION)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionNode.scala
new file mode 100644
index 000000000000..38059d9810a7
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionNode.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.catalog.SQLFunction
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION, SQL_TABLE_FUNCTION, TreePattern}
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+/**
+ * A container for holding a SQL function query plan and its function identifier.
+ *
+ * @param function: the SQL function that this node represents.
+ * @param child: the SQL function body.
+ */
+case class SQLFunctionNode(
+ function: SQLFunction,
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+ override def stringArgs: Iterator[Any] = Iterator(function.name, child)
+ override protected def withNewChildInternal(newChild: LogicalPlan): SQLFunctionNode =
+ copy(child = newChild)
+
+ // Throw a reasonable error message when trying to call a SQL UDF with TABLE argument(s).
+ if (child.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) {
+ throw QueryCompilationErrors
+ .tableValuedArgumentsNotYetImplementedForSqlFunctions("call", toSQLId(function.name.funcName))
+ }
+}
+
+/**
+ * Represent a SQL table function plan resolved from the catalog SQL table function builder.
+ */
+case class SQLTableFunction(
+ name: String,
+ function: SQLFunction,
+ inputs: Seq[Expression],
+ override val output: Seq[Attribute]) extends LeafNode {
+ final override val nodePatterns: Seq[TreePattern] = Seq(SQL_TABLE_FUNCTION)
+
+ // Throw a reasonable error message when trying to call a SQL UDF with TABLE argument(s) because
+ // this functionality is not implemented yet.
+ if (inputs.exists(_.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION))) {
+ throw QueryCompilationErrors
+ .tableValuedArgumentsNotYetImplementedForSqlFunctions("call", toSQLId(name))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
similarity index 61%
rename from sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
index c0bd4ac80f5e..923373c1856a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
@@ -22,9 +22,11 @@ import scala.collection.mutable
import org.json4s.JsonAST.{JArray, JString}
import org.json4s.jackson.JsonMethods.{compact, render}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.catalog.UserDefinedFunction._
-import org.apache.spark.sql.catalyst.expressions.{Expression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ScalarSubquery}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.types.{DataType, StructType}
@@ -62,6 +64,8 @@ case class SQLFunction(
assert(exprText.nonEmpty || queryText.nonEmpty)
assert((isTableFunc && returnType.isRight) || (!isTableFunc && returnType.isLeft))
+ import SQLFunction._
+
override val language: RoutineLanguage = LanguageSQL
/**
@@ -88,12 +92,98 @@ case class SQLFunction(
(parsedExpression, parsedQuery)
}
}
+
+ /** Get scalar function return data type. */
+ def getScalarFuncReturnType: DataType = returnType match {
+ case Left(dataType) => dataType
+ case Right(_) =>
+ throw SparkException.internalError(
+ "This function is a table function, not a scalar function.")
+ }
+
+ /** Get table function return columns. */
+ def getTableFuncReturnCols: StructType = returnType match {
+ case Left(_) =>
+ throw SparkException.internalError(
+ "This function is a scalar function, not a table function.")
+ case Right(columns) => columns
+ }
+
+ /**
+ * Convert the SQL function to a [[CatalogFunction]].
+ */
+ def toCatalogFunction: CatalogFunction = {
+ val props = sqlFunctionToProps ++ properties
+ CatalogFunction(
+ identifier = name,
+ className = SQL_FUNCTION_PREFIX,
+ resources = propertiesToFunctionResources(props, name))
+ }
+
+ /**
+ * Convert the SQL function to an [[ExpressionInfo]].
+ */
+ def toExpressionInfo: ExpressionInfo = {
+ val props = sqlFunctionToProps ++ functionMetadataToProps ++ properties
+ val usage = mapper.writeValueAsString(props)
+ new ExpressionInfo(
+ SQL_FUNCTION_PREFIX,
+ name.database.orNull,
+ name.funcName,
+ usage,
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "sql_udf")
+ }
+
+ /**
+ * Convert the SQL function fields into properties.
+ */
+ private def sqlFunctionToProps: Map[String, String] = {
+ val props = new mutable.HashMap[String, String]
+ val inputParamText = inputParam.map(_.fields.map(_.toDDL).mkString(", "))
+ inputParamText.foreach(props.put(INPUT_PARAM, _))
+ val returnTypeText = returnType match {
+ case Left(dataType) => dataType.sql
+ case Right(columns) => columns.toDDL
+ }
+ props.put(RETURN_TYPE, returnTypeText)
+ exprText.foreach(props.put(EXPRESSION, _))
+ queryText.foreach(props.put(QUERY, _))
+ comment.foreach(props.put(COMMENT, _))
+ deterministic.foreach(d => props.put(DETERMINISTIC, d.toString))
+ containsSQL.foreach(x => props.put(CONTAINS_SQL, x.toString))
+ props.put(IS_TABLE_FUNC, isTableFunc.toString)
+ props.toMap
+ }
+
+ private def functionMetadataToProps: Map[String, String] = {
+ val props = new mutable.HashMap[String, String]
+ owner.foreach(props.put(OWNER, _))
+ props.put(CREATE_TIME, createTimeMs.toString)
+ props.toMap
+ }
}
object SQLFunction {
private val SQL_FUNCTION_PREFIX = "sqlFunction."
+ private val INPUT_PARAM: String = SQL_FUNCTION_PREFIX + "inputParam"
+ private val RETURN_TYPE: String = SQL_FUNCTION_PREFIX + "returnType"
+ private val EXPRESSION: String = SQL_FUNCTION_PREFIX + "expression"
+ private val QUERY: String = SQL_FUNCTION_PREFIX + "query"
+ private val COMMENT: String = SQL_FUNCTION_PREFIX + "comment"
+ private val DETERMINISTIC: String = SQL_FUNCTION_PREFIX + "deterministic"
+ private val CONTAINS_SQL: String = SQL_FUNCTION_PREFIX + "containsSQL"
+ private val IS_TABLE_FUNC: String = SQL_FUNCTION_PREFIX + "isTableFunc"
+ private val OWNER: String = SQL_FUNCTION_PREFIX + "owner"
+ private val CREATE_TIME: String = SQL_FUNCTION_PREFIX + "createTime"
+
private val FUNCTION_CATALOG_AND_NAMESPACE = "catalogAndNamespace.numParts"
private val FUNCTION_CATALOG_AND_NAMESPACE_PART_PREFIX = "catalogAndNamespace.part."
@@ -101,6 +191,42 @@ object SQLFunction {
private val FUNCTION_REFERRED_TEMP_FUNCTION_NAMES = "referredTempFunctionsNames"
private val FUNCTION_REFERRED_TEMP_VARIABLE_NAMES = "referredTempVariableNames"
+ /**
+ * Convert a [[CatalogFunction]] into a SQL function.
+ */
+ def fromCatalogFunction(function: CatalogFunction, parser: ParserInterface): SQLFunction = {
+ try {
+ val parts = function.resources.collect { case FunctionResource(FileResource, uri) =>
+ val index = uri.substring(0, INDEX_LENGTH).toInt
+ val body = uri.substring(INDEX_LENGTH)
+ index -> body
+ }
+ val blob = parts.sortBy(_._1).map(_._2).mkString
+ val props = mapper.readValue(blob, classOf[Map[String, String]])
+ val isTableFunc = props(IS_TABLE_FUNC).toBoolean
+ val returnType = parseReturnTypeText(props(RETURN_TYPE), isTableFunc, parser)
+ SQLFunction(
+ name = function.identifier,
+ inputParam = props.get(INPUT_PARAM).map(parseTableSchema(_, parser)),
+ returnType = returnType.get,
+ exprText = props.get(EXPRESSION),
+ queryText = props.get(QUERY),
+ comment = props.get(COMMENT),
+ deterministic = props.get(DETERMINISTIC).map(_.toBoolean),
+ containsSQL = props.get(CONTAINS_SQL).map(_.toBoolean),
+ isTableFunc = isTableFunc,
+ props.filterNot(_._1.startsWith(SQL_FUNCTION_PREFIX)))
+ } catch {
+ case e: Exception =>
+ throw new AnalysisException(
+ errorClass = "CORRUPTED_CATALOG_FUNCTION",
+ messageParameters = Map(
+ "identifier" -> s"${function.identifier}",
+ "className" -> s"${function.className}"), cause = Some(e)
+ )
+ }
+ }
+
def parseDefault(text: String, parser: ParserInterface): Expression = {
parser.parseExpression(text)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index d87678ac3411..3c6dfe5ac844 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -32,12 +32,15 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Expression, ExpressionInfo, NamedExpression, UpCast}
+import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
+import org.apache.spark.sql.catalyst.catalog.SQLFunction.parseDefault
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, ExpressionInfo, NamedArgumentExpression, NamedExpression, UpCast}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, SubqueryAlias, View}
+import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter, LogicalPlan, NamedParametersSupport, Project, SubqueryAlias, View}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils}
import org.apache.spark.sql.connector.catalog.CatalogManager
@@ -1532,10 +1535,49 @@ class SessionCatalog(
}
}
+ /**
+ * Create a user defined function.
+ */
+ def createUserDefinedFunction(function: UserDefinedFunction, ignoreIfExists: Boolean): Unit = {
+ createFunction(function.toCatalogFunction, ignoreIfExists)
+ }
+
// ----------------------------------------------------------------
// | Methods that interact with temporary and metastore functions |
// ----------------------------------------------------------------
+ /**
+ * Constructs a [[FunctionBuilder]] based on the provided class that represents a function.
+ */
+ private def makeSQLFunctionBuilder(function: SQLFunction): FunctionBuilder = {
+ if (function.isTableFunc) {
+ throw UserDefinedFunctionErrors.notAScalarFunction(function.name.nameParts)
+ }
+ (input: Seq[Expression]) => {
+ val args = rearrangeArguments(function.inputParam, input, function.name.toString)
+ val returnType = function.getScalarFuncReturnType
+ SQLFunctionExpression(
+ function.name.unquotedString, function, args, Some(returnType))
+ }
+ }
+
+ /**
+ * Constructs a [[TableFunctionBuilder]] based on the provided class that represents a function.
+ */
+ private def makeSQLTableFunctionBuilder(function: SQLFunction): TableFunctionBuilder = {
+ if (!function.isTableFunc) {
+ throw UserDefinedFunctionErrors.notATableFunction(function.name.nameParts)
+ }
+ (input: Seq[Expression]) => {
+ val args = rearrangeArguments(function.inputParam, input, function.name.toString)
+ val returnParam = function.getTableFuncReturnCols
+ val output = returnParam.fields.map { param =>
+ AttributeReference(param.name, param.dataType, param.nullable)()
+ }
+ SQLTableFunction(function.name.unquotedString, function, args, output.toSeq)
+ }
+ }
+
/**
* Constructs a [[FunctionBuilder]] based on the provided function metadata.
*/
@@ -1550,6 +1592,24 @@ class SessionCatalog(
(input: Seq[Expression]) => functionExpressionBuilder.makeExpression(name, clazz, input)
}
+ private def makeUserDefinedScalarFuncBuilder(func: UserDefinedFunction): FunctionBuilder = {
+ func match {
+ case f: SQLFunction => makeSQLFunctionBuilder(f)
+ case _ =>
+ val clsName = func.getClass.getSimpleName
+ throw UserDefinedFunctionErrors.unsupportedUserDefinedFunction(clsName)
+ }
+ }
+
+ private def makeUserDefinedTableFuncBuilder(func: UserDefinedFunction): TableFunctionBuilder = {
+ func match {
+ case f: SQLFunction => makeSQLTableFunctionBuilder(f)
+ case _ =>
+ val clsName = func.getClass.getSimpleName
+ throw UserDefinedFunctionErrors.unsupportedUserDefinedFunction(clsName)
+ }
+ }
+
/**
* Loads resources such as JARs and Files for a function. Every resource is represented
* by a tuple (resource type, resource uri).
@@ -1597,6 +1657,81 @@ class SessionCatalog(
"hive")
}
+ /**
+ * Registers a temporary or persistent SQL scalar function into a session-specific
+ * [[FunctionRegistry]].
+ */
+ def registerSQLScalarFunction(
+ function: SQLFunction,
+ overrideIfExists: Boolean): Unit = {
+ registerUserDefinedFunction[Expression](
+ function,
+ overrideIfExists,
+ functionRegistry,
+ makeSQLFunctionBuilder(function))
+ }
+
+ /**
+ * Registers a temporary or persistent SQL table function into a session-specific
+ * [[TableFunctionRegistry]].
+ */
+ def registerSQLTableFunction(
+ function: SQLFunction,
+ overrideIfExists: Boolean): Unit = {
+ registerUserDefinedFunction[LogicalPlan](
+ function,
+ overrideIfExists,
+ tableFunctionRegistry,
+ makeSQLTableFunctionBuilder(function))
+ }
+
+ /**
+ * Rearranges the arguments of a UDF into positional order.
+ */
+ private def rearrangeArguments(
+ inputParams: Option[StructType],
+ expressions: Seq[Expression],
+ functionName: String) : Seq[Expression] = {
+ val firstNamedArgumentExpressionIdx =
+ expressions.indexWhere(_.isInstanceOf[NamedArgumentExpression])
+ if (firstNamedArgumentExpressionIdx == -1) {
+ return expressions
+ }
+
+ val paramNames: Seq[InputParameter] =
+ if (inputParams.isDefined) {
+ inputParams.get.map {
+ p => p.getDefault() match {
+ case Some(defaultExpr) =>
+ // This cast is needed to ensure the default value is of the target data type.
+ InputParameter(p.name, Some(Cast(parseDefault(defaultExpr, parser), p.dataType)))
+ case None =>
+ InputParameter(p.name)
+ }
+ }.toSeq
+ } else {
+ Seq()
+ }
+
+ NamedParametersSupport.defaultRearrange(
+ FunctionSignature(paramNames), expressions, functionName)
+ }
+
+ /**
+ * Registers a temporary or permanent SQL function into a session-specific function registry.
+ */
+ private def registerUserDefinedFunction[T](
+ function: UserDefinedFunction,
+ overrideIfExists: Boolean,
+ registry: FunctionRegistryBase[T],
+ functionBuilder: Seq[Expression] => T): Unit = {
+ if (registry.functionExists(function.name) && !overrideIfExists) {
+ throw QueryCompilationErrors.functionAlreadyExistsError(function.name)
+ }
+ val info = function.toExpressionInfo
+ registry.registerFunction(function.name, info, functionBuilder)
+ }
+
/**
* Unregister a temporary or permanent function from a session-specific [[FunctionRegistry]]
* or [[TableFunctionRegistry]]. Return true if function exists.
@@ -1753,7 +1888,11 @@ class SessionCatalog(
requireDbExists(db)
if (externalCatalog.functionExists(db, funcName)) {
val metadata = externalCatalog.getFunction(db, funcName)
- makeExprInfoForHiveFunction(metadata.copy(identifier = qualifiedIdent))
+ if (metadata.isUserDefinedFunction) {
+ UserDefinedFunction.fromCatalogFunction(metadata, parser).toExpressionInfo
+ } else {
+ makeExprInfoForHiveFunction(metadata.copy(identifier = qualifiedIdent))
+ }
} else {
failFunctionLookup(name)
}
@@ -1765,7 +1904,26 @@ class SessionCatalog(
*/
def resolvePersistentFunction(
name: FunctionIdentifier, arguments: Seq[Expression]): Expression = {
- resolvePersistentFunctionInternal(name, arguments, functionRegistry, makeFunctionBuilder)
+ resolvePersistentFunctionInternal[Expression](
+ name,
+ arguments,
+ functionRegistry,
+ registerHiveFunc = func =>
+ registerFunction(
+ func,
+ overrideIfExists = false,
+ registry = functionRegistry,
+ functionBuilder = makeFunctionBuilder(func)
+ ),
+ registerUserDefinedFunc = function => {
+ val builder = makeUserDefinedScalarFuncBuilder(function)
+ registerUserDefinedFunction[Expression](
+ function = function,
+ overrideIfExists = false,
+ registry = functionRegistry,
+ functionBuilder = builder)
+ }
+ )
}
/**
@@ -1774,16 +1932,29 @@ class SessionCatalog(
def resolvePersistentTableFunction(
name: FunctionIdentifier,
arguments: Seq[Expression]): LogicalPlan = {
- // We don't support persistent table functions yet.
- val builder = (func: CatalogFunction) => failFunctionLookup(name)
- resolvePersistentFunctionInternal(name, arguments, tableFunctionRegistry, builder)
+ resolvePersistentFunctionInternal[LogicalPlan](
+ name,
+ arguments,
+ tableFunctionRegistry,
+ // We don't support persistent Hive table functions yet.
+ registerHiveFunc = (func: CatalogFunction) => failFunctionLookup(name),
+ registerUserDefinedFunc = function => {
+ val builder = makeUserDefinedTableFuncBuilder(function)
+ registerUserDefinedFunction[LogicalPlan](
+ function = function,
+ overrideIfExists = false,
+ registry = tableFunctionRegistry,
+ functionBuilder = builder)
+ }
+ )
}
private def resolvePersistentFunctionInternal[T](
name: FunctionIdentifier,
arguments: Seq[Expression],
registry: FunctionRegistryBase[T],
- createFunctionBuilder: CatalogFunction => FunctionRegistryBase[T]#FunctionBuilder): T = {
+ registerHiveFunc: CatalogFunction => Unit,
+ registerUserDefinedFunc: UserDefinedFunction => Unit): T = {
// `synchronized` is used to prevent multiple threads from concurrently resolving the
// same function that has not yet been loaded into the function registry. This is needed
// because calling `registerFunction` twice with `overrideIfExists = false` can lead to
@@ -1799,19 +1970,24 @@ class SessionCatalog(
// The function has not been loaded to the function registry, which means
// that the function is a persistent function (if it actually has been registered
// in the metastore). We need to first put the function in the function registry.
- val catalogFunction = externalCatalog.getFunction(db, funcName)
- loadFunctionResources(catalogFunction.resources)
+ val catalogFunction = try {
+ externalCatalog.getFunction(db, funcName)
+ } catch {
+ case _: AnalysisException => failFunctionLookup(qualifiedIdent)
+ }
// Please note that qualifiedName is provided by the user. However,
// catalogFunction.identifier.unquotedString is returned by the underlying
// catalog. So, it is possible that qualifiedName is not exactly the same as
// catalogFunction.identifier.unquotedString (difference is on case-sensitivity).
// At here, we preserve the input from the user.
val funcMetadata = catalogFunction.copy(identifier = qualifiedIdent)
- registerFunction(
- funcMetadata,
- overrideIfExists = false,
- registry = registry,
- functionBuilder = createFunctionBuilder(funcMetadata))
+ if (!catalogFunction.isUserDefinedFunction) {
+ loadFunctionResources(catalogFunction.resources)
+ registerHiveFunc(funcMetadata)
+ } else {
+ val function = UserDefinedFunction.fromCatalogFunction(funcMetadata, parser)
+ registerUserDefinedFunc(function)
+ }
// Now, we need to create the Expression.
registry.lookupFunction(qualifiedIdent, arguments)
}
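As an aside on rearrangeArguments above: the named-argument handling it
delegates to NamedParametersSupport.defaultRearrange can be modeled with the
self-contained sketch below. All names in the sketch are invented; only the
behavior (positional arguments stay in order, named arguments bind to declared
parameters, declared defaults fill the rest) mirrors the patch.

object RearrangeSketch {
  final case class Param(name: String, default: Option[String] = None)
  sealed trait Arg
  final case class Positional(value: String) extends Arg
  final case class Named(name: String, value: String) extends Arg

  // Rearrange a mixed positional/named argument list into positional order.
  def rearrange(params: Seq[Param], args: Seq[Arg]): Seq[String] = {
    val positional = args.takeWhile(_.isInstanceOf[Positional]).collect {
      case Positional(v) => v
    }
    val named = args.drop(positional.size).collect {
      case Named(n, v) => n -> v
    }.toMap
    params.zipWithIndex.map { case (p, i) =>
      if (i < positional.size) positional(i)
      else named.getOrElse(p.name, p.default.getOrElse(
        throw new IllegalArgumentException(s"missing argument: ${p.name}")))
    }
  }

  def main(args: Array[String]): Unit = {
    // Like SQL `foo(1, z => 5, y => 2)` for parameters (x, y, z DEFAULT 0).
    val params = Seq(Param("x"), Param("y"), Param("z", Some("0")))
    assert(rearrange(params,
      Seq(Positional("1"), Named("z", "5"), Named("y", "2"))) == Seq("1", "2", "5"))
    // An omitted trailing parameter falls back to its declared default.
    assert(rearrange(params,
      Seq(Positional("1"), Named("y", "2"))) == Seq("1", "2", "0"))
  }
}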
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
new file mode 100644
index 000000000000..fe00184e843a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import com.fasterxml.jackson.annotation.JsonInclude.Include
+import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
+import com.fasterxml.jackson.module.scala.{DefaultScalaModule, ScalaObjectMapper}
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * The base class for all user defined functions registered via SQL.
+ */
+trait UserDefinedFunction {
+
+ /**
+ * Qualified name of the function
+ */
+ def name: FunctionIdentifier
+
+ /**
+ * Additional properties to be serialized for the function.
+ * Use this to preserve the runtime configuration that should be used during the function
+ * execution, such as SQL configs etc. See [[SQLConf]] for more info.
+ */
+ def properties: Map[String, String]
+
+ /**
+ * Owner of the function
+ */
+ def owner: Option[String]
+
+ /**
+ * Function creation time in milliseconds since the linux epoch
+ */
+ def createTimeMs: Long
+
+ /**
+ * The language of the user defined function.
+ */
+ def language: RoutineLanguage
+
+ /**
+ * Convert the function to a [[CatalogFunction]].
+ */
+ def toCatalogFunction: CatalogFunction
+
+ /**
+ * Convert the SQL function to an [[ExpressionInfo]].
+ */
+ def toExpressionInfo: ExpressionInfo
+}
+
+object UserDefinedFunction {
+ val SQL_CONFIG_PREFIX = "sqlConfig."
+ val INDEX_LENGTH: Int = 3
+
+ // The default Hive Metastore SQL schema length for function resource uri.
+ private val HIVE_FUNCTION_RESOURCE_URI_LENGTH_THRESHOLD: Int = 4000
+
+ def parseTableSchema(text: String, parser: ParserInterface): StructType = {
+ val parsed = parser.parseTableSchema(text)
+ CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
+ }
+
+ def parseDataType(text: String, parser: ParserInterface): DataType = {
+ val dataType = parser.parseDataType(text)
+ CharVarcharUtils.failIfHasCharVarchar(dataType)
+ }
+
+ private val _mapper: ObjectMapper = getObjectMapper
+
+ /**
+ * A shared [[ObjectMapper]] for serializations.
+ */
+ def mapper: ObjectMapper = _mapper
+
+ /**
+ * Convert the given properties to a list of function resources.
+ */
+ def propertiesToFunctionResources(
+ props: Map[String, String],
+ name: FunctionIdentifier): Seq[FunctionResource] = {
+ val blob = mapper.writeValueAsString(props)
+ val threshold = HIVE_FUNCTION_RESOURCE_URI_LENGTH_THRESHOLD - INDEX_LENGTH
+ blob.grouped(threshold).zipWithIndex.map { case (part, i) =>
+ // Add a sequence number to the part and pad it to a given length.
+ // E.g. 1 will become "001" if the given length is 3.
+ val index = s"%0${INDEX_LENGTH}d".format(i)
+ if (index.length > INDEX_LENGTH) {
+ throw UserDefinedFunctionErrors.routinePropertyTooLarge(name.funcName)
+ }
+ FunctionResource(FileResource, index + part)
+ }.toSeq
+ }
+
+ /**
+ * Get a object mapper to serialize and deserialize function properties.
+ */
+ private def getObjectMapper: ObjectMapper = {
+ val mapper = new ObjectMapper with ScalaObjectMapper
+ mapper.setSerializationInclusion(Include.NON_ABSENT)
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+ mapper.registerModule(DefaultScalaModule)
+ mapper
+ }
+
+ /**
+ * Convert a [[CatalogFunction]] into a corresponding UDF.
+ */
+ def fromCatalogFunction(function: CatalogFunction, parser: ParserInterface)
+ : UserDefinedFunction = {
+ val className = function.className
+ if (SQLFunction.isSQLFunction(className)) {
+ SQLFunction.fromCatalogFunction(function, parser)
+ } else {
+ throw SparkException.internalError(s"Unsupported function type
$className")
+ }
+ }
+
+ /**
+ * Verify if the function is a [[UserDefinedFunction]].
+ */
+ def isUserDefinedFunction(className: String): Boolean = SQLFunction.isSQLFunction(className)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunctionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunctionErrors.scala
index e8cfa8d74e83..904a17bc8ce4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunctionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunctionErrors.scala
@@ -97,4 +97,22 @@ object UserDefinedFunctionErrors extends QueryErrorsBase {
"tempObj" -> "VARIABLE",
"tempObjName" -> toSQLId(varName)))
}
+
+ def routinePropertyTooLarge(routineName: String): Throwable = {
+ new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.ROUTINE_PROPERTY_TOO_LARGE",
+ messageParameters = Map("name" -> toSQLId(routineName)))
+ }
+
+ def notAScalarFunction(functionName: Seq[String]): Throwable = {
+ new AnalysisException(
+ errorClass = "NOT_A_SCALAR_FUNCTION",
+ messageParameters = Map("functionName" -> toSQLId(functionName)))
+ }
+
+ def notATableFunction(functionName: Seq[String]): Throwable = {
+ new AnalysisException(
+ errorClass = "NOT_A_TABLE_FUNCTION",
+ messageParameters = Map("functionName" -> toSQLId(functionName)))
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 858e2cf25b6f..2ebfcf781b97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -104,7 +104,9 @@ trait MetadataMapSupport {
case class CatalogFunction(
identifier: FunctionIdentifier,
className: String,
- resources: Seq[FunctionResource])
+ resources: Seq[FunctionResource]) {
+ val isUserDefinedFunction: Boolean = UserDefinedFunction.isUserDefinedFunction(className)
+}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 1dfb0336ecf0..80531da4a0ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -90,6 +90,8 @@ object TreePattern extends Enumeration {
val SCALA_UDF: Value = Value
val SESSION_WINDOW: Value = Value
val SORT: Value = Value
+ val SQL_FUNCTION_EXPRESSION: Value = Value
+ val SQL_TABLE_FUNCTION: Value = Value
val SUBQUERY_ALIAS: Value = Value
val SUM: Value = Value
val TIME_WINDOW: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
deleted file mode 100644
index 6567062841de..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.catalog
-
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.parser.ParserInterface
-import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.types.{DataType, StructType}
-
-/**
- * The base class for all user defined functions registered via SQL queries.
- */
-trait UserDefinedFunction {
-
- /**
- * Qualified name of the function
- */
- def name: FunctionIdentifier
-
- /**
- * Additional properties to be serialized for the function.
- * Use this to preserve the runtime configuration that should be used during the function
- * execution, such as SQL configs etc. See [[SQLConf]] for more info.
- */
- def properties: Map[String, String]
-
- /**
- * Owner of the function
- */
- def owner: Option[String]
-
- /**
- * Function creation time in milliseconds since the linux epoch
- */
- def createTimeMs: Long
-
- /**
- * The language of the user defined function.
- */
- def language: RoutineLanguage
-}
-
-object UserDefinedFunction {
- val SQL_CONFIG_PREFIX = "sqlConfig."
-
- def parseTableSchema(text: String, parser: ParserInterface): StructType = {
- val parsed = parser.parseTableSchema(text)
- CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
- }
-
- def parseDataType(text: String, parser: ParserInterface): DataType = {
- val dataType = parser.parseDataType(text)
- CharVarcharUtils.failIfHasCharVarchar(dataType)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
index 25598a12af22..fe4e6f121f57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.execution.command
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, SQLFunctionNode, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, SQLFunction, UserDefinedFunctionErrors}
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Generator, LateralSubquery, Literal, ScalarSubquery, SubqueryExpression, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.{LateralJoin, LogicalPlan, OneRowRelation, Project, SQLFunctionNode, UnresolvedWith}
+import org.apache.spark.sql.catalyst.plans.logical.{LateralJoin, LogicalPlan, OneRowRelation, Project, UnresolvedWith}
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command.CreateUserDefinedFunctionCommand._
@@ -249,7 +249,26 @@ case class CreateSQLFunctionCommand(
)
}
- // TODO: create/register sql functions in catalog
+ if (isTemp) {
+ if (isTableFunc) {
+ catalog.registerSQLTableFunction(newFunction, overrideIfExists = replace)
+ } else {
+ catalog.registerSQLScalarFunction(newFunction, overrideIfExists = replace)
+ }
+ } else {
+ if (replace && catalog.functionExists(name)) {
+ // Hive metastore alter function method does not alter function resources
+ // so the existing function must be dropped first when replacing a SQL function.
+ assert(!ignoreIfExists)
+ catalog.dropFunction(name, ignoreIfExists)
+ }
+ // For a persistent function, we will store the metadata into underlying external catalog.
+ // This function will be loaded into the FunctionRegistry when a query uses it.
+ // We do not load it into FunctionRegistry right now, to avoid loading the resource
+ // immediately, as the Spark application to create the function may not have
+ // access to the function.
+ catalog.createUserDefinedFunction(newFunction, ignoreIfExists)
+ }
Seq.empty
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]