This is an automated email from the ASF dual-hosted git repository.
allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5be9587a3ae5 [SPARK-48730][SQL] Implement CreateSQLFunctionCommand for SQL Scalar and Table Functions
5be9587a3ae5 is described below
commit 5be9587a3ae587678680359f88f84d8554a70a66
Author: Allison Wang <[email protected]>
AuthorDate: Tue Jan 7 10:14:43 2025 +0800
[SPARK-48730][SQL] Implement CreateSQLFunctionCommand for SQL Scalar and Table Functions
### What changes were proposed in this pull request?
This PR implements `CreateSQLFunctionCommand` to support the creation of SQL scalar and table functions.
Note that the logic for storing and resolving SQL UDFs will be implemented in subsequent PRs, and more SQL tests will be added once the feature works end to end.
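For context, the statements this command targets look like the following. This is a sketch based on the RETURNS/RETURN clauses referenced by the error messages in this diff; the function names and bodies are hypothetical examples, and persistent SQL UDFs only work end to end once the follow-up PRs land:

```sql
-- Hypothetical scalar SQL function: RETURNS declares the result type,
-- RETURN supplies the body expression.
CREATE FUNCTION area(x DOUBLE, y DOUBLE)
  RETURNS DOUBLE
  RETURN x * y;

-- Hypothetical SQL table function: RETURNS TABLE names the output columns,
-- RETURN supplies the body query.
CREATE FUNCTION doubled(x INT)
  RETURNS TABLE (a INT, b INT)
  RETURN SELECT x, x * 2;
```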
### Why are the changes needed?
To support SQL UDFs.
### Does this PR introduce _any_ user-facing change?
Yes. After this PR, users can create persistent SQL UDFs.
### How was this patch tested?
New unit tests, including the new `CreateSQLFunctionParserSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49126 from allisonwang-db/spark-48730-create-sql-udf.
Authored-by: Allison Wang <[email protected]>
Signed-off-by: Allison Wang <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 66 +++++
.../org/apache/spark/sql/types/StructField.scala | 12 +
.../spark/sql/catalyst/analysis/unresolved.scala | 3 +-
.../catalog/UserDefinedFunctionErrors.scala | 68 ++++-
.../sql/catalyst/plans/logical/v2Commands.scala | 22 +-
.../spark/sql/errors/QueryCompilationErrors.scala | 9 +
.../spark/sql/catalyst/analysis/AnalysisTest.scala | 15 +
.../catalyst/analysis/ResolveSessionCatalog.scala | 21 ++
.../spark/sql/catalyst/catalog/SQLFunction.scala | 84 ++++++
.../sql/catalyst/catalog/UserDefinedFunction.scala | 2 +
.../catalyst/plans/logical/SQLFunctionNode.scala | 45 +++
.../spark/sql/execution/SparkSqlParser.scala | 15 +-
.../command/CreateSQLFunctionCommand.scala | 328 ++++++++++++++++++++-
.../command/CreateUserDefinedFunctionCommand.scala | 110 +++++++
.../apache/spark/sql/execution/command/views.scala | 15 +-
.../command/CreateSQLFunctionParserSuite.scala | 203 +++++++++++++
.../sql/execution/command/DDLParserSuite.scala | 44 ---
17 files changed, 1004 insertions(+), 58 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 573e7f3a6a38..52c0315bd073 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1229,6 +1229,18 @@
},
"sqlState" : "4274K"
},
+ "DUPLICATE_ROUTINE_PARAMETER_NAMES" : {
+ "message" : [
+ "Found duplicate name(s) in the parameter list of the user-defined routine <routineName>: <names>."
+ ],
+ "sqlState" : "42734"
+ },
+ "DUPLICATE_ROUTINE_RETURNS_COLUMNS" : {
+ "message" : [
+ "Found duplicate column(s) in the RETURNS clause column list of the user-defined routine <routineName>: <columns>."
+ ],
+ "sqlState" : "42711"
+ },
"EMITTING_ROWS_OLDER_THAN_WATERMARK_NOT_ALLOWED" : {
"message" : [
"Previous node emitted a row with eventTime=<emittedRowEventTime> which is older than current_watermark_value=<currentWatermark>",
@@ -4695,6 +4707,12 @@
],
"sqlState" : "42P01"
},
+ "TABLE_VALUED_ARGUMENTS_NOT_YET_IMPLEMENTED_FOR_SQL_FUNCTIONS" : {
+ "message" : [
+ "Cannot <action> SQL user-defined function <functionName> with TABLE arguments because this functionality is not yet implemented."
+ ],
+ "sqlState" : "0A000"
+ },
"TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON" : {
"message" : [
"Failed to analyze the Python user defined table function: <msg>"
@@ -5827,6 +5845,54 @@
],
"sqlState" : "42K0E"
},
+ "USER_DEFINED_FUNCTIONS" : {
+ "message" : [
+ "User defined function is invalid:"
+ ],
+ "subClass" : {
+ "CANNOT_CONTAIN_COMPLEX_FUNCTIONS" : {
+ "message" : [
+ "SQL scalar function cannot contain aggregate/window/generate functions: <queryText>"
+ ]
+ },
+ "CANNOT_REPLACE_NON_SQL_UDF_WITH_SQL_UDF" : {
+ "message" : [
+ "Cannot replace the non-SQL function <name> with a SQL function."
+ ]
+ },
+ "NOT_A_VALID_DEFAULT_EXPRESSION" : {
+ "message" : [
+ "The DEFAULT expression of `<functionName>`.`<parameterName>` is not supported because it contains a subquery."
+ ]
+ },
+ "NOT_A_VALID_DEFAULT_PARAMETER_POSITION" : {
+ "message" : [
+ "In routine `<functionName>` parameter `<parameterName>` with DEFAULT must not be followed by parameter `<nextParameterName>` without DEFAULT."
+ ]
+ },
+ "NOT_NULL_ON_FUNCTION_PARAMETERS" : {
+ "message" : [
+ "Cannot specify NOT NULL on function parameters: <input>"
+ ]
+ },
+ "RETURN_COLUMN_COUNT_MISMATCH" : {
+ "message" : [
+ "The number of columns produced by the RETURN clause (num: `<outputSize>`) does not match the number of column names specified by the RETURNS clause (num: `<returnParamSize>`) of <name>."
+ ]
+ },
+ "SQL_TABLE_UDF_BODY_MUST_BE_A_QUERY" : {
+ "message" : [
+ "SQL table function <name> body must be a query."
+ ]
+ },
+ "SQL_TABLE_UDF_MISSING_COLUMN_NAMES" : {
+ "message" : [
+ "The relation returned by the query in the CREATE FUNCTION statement for <functionName> with RETURNS TABLE clause lacks explicit names for one or more output columns; please rewrite the function body to provide explicit column names or add column names to the RETURNS TABLE clause, and re-run the command."
+ ]
+ }
+ },
+ "sqlState" : "42601"
+ },
"USER_RAISED_EXCEPTION" : {
"message" : [
"<errorMessage>"
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
index d4e590629921..f33a49e686a5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
@@ -147,6 +147,18 @@ case class StructField(
if (metadata.contains("comment")) Option(metadata.getString("comment"))
else None
}
+ /**
+ * Return the default value of this StructField. This is used for storing the default value of a
+ * function parameter.
+ */
+ private[sql] def getDefault(): Option[String] = {
+ if (metadata.contains("default")) {
+ Option(metadata.getString("default"))
+ } else {
+ None
+ }
+ }
+
/**
* Updates the StructField with a new current default value.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 87a5e94d9f63..b47af90c651a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -206,7 +206,8 @@ case class ResolvedInlineTable(rows: Seq[Seq[Expression]], output: Seq[Attribute
*/
case class UnresolvedTableValuedFunction(
name: Seq[String],
- functionArgs: Seq[Expression])
+ functionArgs: Seq[Expression],
+ override val isStreaming: Boolean = false)
extends UnresolvedLeafNode {
final override val nodePatterns: Seq[TreePattern] =
Seq(UNRESOLVED_TABLE_VALUED_FUNCTION)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunctionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunctionErrors.scala
index a5381669caea..e8cfa8d74e83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunctionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunctionErrors.scala
@@ -18,10 +18,12 @@
package org.apache.spark.sql.catalyst.catalog
import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.errors.QueryErrorsBase
/**
- * Errors during registering and executing [[UserDefinedFunction]]s.
+ * Errors during registering and executing
+ * [[org.apache.spark.sql.expressions.UserDefinedFunction]]s.
*/
object UserDefinedFunctionErrors extends QueryErrorsBase {
def unsupportedUserDefinedFunction(language: RoutineLanguage): Throwable = {
@@ -31,4 +33,68 @@ object UserDefinedFunctionErrors extends QueryErrorsBase {
def unsupportedUserDefinedFunction(language: String): Throwable = {
SparkException.internalError(s"Unsupported user defined function type: $language")
}
+
+ def duplicateParameterNames(routineName: String, names: String): Throwable = {
+ new AnalysisException(
+ errorClass = "DUPLICATE_ROUTINE_PARAMETER_NAMES",
+ messageParameters = Map("routineName" -> routineName, "names" -> names))
+ }
+
+ def duplicateReturnsColumns(routineName: String, columns: String): Throwable = {
+ new AnalysisException(
+ errorClass = "DUPLICATE_ROUTINE_RETURNS_COLUMNS",
+ messageParameters = Map("routineName" -> routineName, "columns" -> columns))
+ }
+
+ def cannotSpecifyNotNullOnFunctionParameters(input: String): Throwable = {
+ new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.NOT_NULL_ON_FUNCTION_PARAMETERS",
+ messageParameters = Map("input" -> input))
+ }
+
+ def bodyIsNotAQueryForSqlTableUdf(functionName: String): Throwable = {
+ new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.SQL_TABLE_UDF_BODY_MUST_BE_A_QUERY",
+ messageParameters = Map("name" -> functionName))
+ }
+
+ def missingColumnNamesForSqlTableUdf(functionName: String): Throwable = {
+ new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.SQL_TABLE_UDF_MISSING_COLUMN_NAMES",
+ messageParameters = Map("functionName" -> toSQLId(functionName)))
+ }
+
+ def invalidTempViewReference(routineName: Seq[String], tempViewName: Seq[String]): Throwable = {
+ new AnalysisException(
+ errorClass = "INVALID_TEMP_OBJ_REFERENCE",
+ messageParameters = Map(
+ "obj" -> "FUNCTION",
+ "objName" -> toSQLId(routineName),
+ "tempObj" -> "VIEW",
+ "tempObjName" -> toSQLId(tempViewName)
+ )
+ )
+ }
+
+ def invalidTempFuncReference(routineName: Seq[String], tempFuncName: String): Throwable = {
+ new AnalysisException(
+ errorClass = "INVALID_TEMP_OBJ_REFERENCE",
+ messageParameters = Map(
+ "obj" -> "FUNCTION",
+ "objName" -> toSQLId(routineName),
+ "tempObj" -> "FUNCTION",
+ "tempObjName" -> toSQLId(tempFuncName)
+ )
+ )
+ }
+
+ def invalidTempVarReference(routineName: Seq[String], varName: Seq[String]): Throwable = {
+ new AnalysisException(
+ errorClass = "INVALID_TEMP_OBJ_REFERENCE",
+ messageParameters = Map(
+ "obj" -> "FUNCTION",
+ "objName" -> toSQLId(routineName),
+ "tempObj" -> "VARIABLE",
+ "tempObjName" -> toSQLId(varName)))
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 85b5e8379d3d..58c62a90225a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -21,8 +21,8 @@ import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperatio
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.catalog.FunctionResource
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, UnaryExpression, Unevaluable, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.trees.BinaryLike
@@ -1072,6 +1072,26 @@ case class CreateFunction(
copy(child = newChild)
}
+/**
+ * The logical plan of the CREATE FUNCTION command for SQL Functions.
+ */
+case class CreateUserDefinedFunction(
+ child: LogicalPlan,
+ inputParamText: Option[String],
+ returnTypeText: String,
+ exprText: Option[String],
+ queryText: Option[String],
+ comment: Option[String],
+ isDeterministic: Option[Boolean],
+ containsSQL: Option[Boolean],
+ language: RoutineLanguage,
+ isTableFunc: Boolean,
+ ignoreIfExists: Boolean,
+ replace: Boolean) extends UnaryCommand {
+ override protected def withNewChildInternal(newChild: LogicalPlan): CreateUserDefinedFunction =
+ copy(child = newChild)
+}
+
/**
* The logical plan of the DROP FUNCTION command.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index d38c7a01e1c4..65ae8da3c4da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2172,6 +2172,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"ability" -> ability))
}
+ def tableValuedArgumentsNotYetImplementedForSqlFunctions(
+ action: String, functionName: String): Throwable = {
+ new AnalysisException(
+ errorClass = "TABLE_VALUED_ARGUMENTS_NOT_YET_IMPLEMENTED_FOR_SQL_FUNCTIONS",
+ messageParameters = Map(
+ "action" -> action,
+ "functionName" -> functionName))
+ }
+
def tableValuedFunctionTooManyTableArgumentsError(num: Int): Throwable = {
new AnalysisException(
errorClass = "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 71744f4d1510..58e6cd7fe169 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -198,6 +198,21 @@ trait AnalysisTest extends PlanTest {
}
}
+ protected def assertParseErrorClass(
+ parser: String => Any,
+ sqlCommand: String,
+ errorClass: String,
+ parameters: Map[String, String],
+ queryContext: Array[ExpectedContext] = Array.empty): Unit = {
+ val e = parseException(parser)(sqlCommand)
+ checkError(
+ exception = e,
+ condition = errorClass,
+ parameters = parameters,
+ queryContext = queryContext
+ )
+ }
+
protected def interceptParseException(parser: String => Any)(
sqlCommand: String, messages: String*)(condition: Option[String] = None): Unit = {
val e = parseException(parser)(sqlCommand)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 87ea3071f490..6a388a7849f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -497,6 +497,27 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
case CreateFunction(ResolvedIdentifier(catalog, _), _, _, _, _) =>
throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "CREATE FUNCTION")
+
+ case c @ CreateUserDefinedFunction(
+ ResolvedIdentifierInSessionCatalog(ident), _, _, _, _, _, _, _, _, _, _, _) =>
+ CreateUserDefinedFunctionCommand(
+ FunctionIdentifier(ident.table, ident.database, ident.catalog),
+ c.inputParamText,
+ c.returnTypeText,
+ c.exprText,
+ c.queryText,
+ c.comment,
+ c.isDeterministic,
+ c.containsSQL,
+ c.language,
+ c.isTableFunc,
+ isTemp = false,
+ c.ignoreIfExists,
+ c.replace)
+
+ case CreateUserDefinedFunction(
+ ResolvedIdentifier(catalog, _), _, _, _, _, _, _, _, _, _, _, _) =>
+ throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "CREATE FUNCTION")
}
private def constructV1TableCmd(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
index 8ae0341e5646..c0bd4ac80f5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
@@ -17,9 +17,16 @@
package org.apache.spark.sql.catalyst.catalog
+import scala.collection.mutable
+
+import org.json4s.JsonAST.{JArray, JString}
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.catalog.UserDefinedFunction._
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalarSubquery}
import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.types.{DataType, StructType}
/**
@@ -56,10 +63,48 @@ case class SQLFunction(
assert((isTableFunc && returnType.isRight) || (!isTableFunc && returnType.isLeft))
override val language: RoutineLanguage = LanguageSQL
+
+ /**
+ * Optionally get the function body as an expression or query using the given parser.
+ */
+ def getExpressionAndQuery(
+ parser: ParserInterface,
+ isTableFunc: Boolean): (Option[Expression], Option[LogicalPlan]) = {
+ // The RETURN clause of the CREATE FUNCTION statement looks like this in the parser:
+ // RETURN (query | expression)
+ // If the 'query' matches and parses as a SELECT clause of one item with no FROM clause, and
+ // this is a scalar function, we skip a level of subquery expression wrapping by using the
+ // referenced expression directly.
+ val parsedExpression = exprText.map(parser.parseExpression)
+ val parsedQuery = queryText.map(parser.parsePlan)
+ (parsedExpression, parsedQuery) match {
+ case (None, Some(Project(expr :: Nil, _: OneRowRelation)))
+ if !isTableFunc =>
+ (Some(expr), None)
+ case (Some(ScalarSubquery(Project(expr :: Nil, _: OneRowRelation), _, _, _, _, _, _)), None)
+ if !isTableFunc =>
+ (Some(expr), None)
+ case (_, _) =>
+ (parsedExpression, parsedQuery)
+ }
+ }
}
object SQLFunction {
+ private val SQL_FUNCTION_PREFIX = "sqlFunction."
+
+ private val FUNCTION_CATALOG_AND_NAMESPACE = "catalogAndNamespace.numParts"
+ private val FUNCTION_CATALOG_AND_NAMESPACE_PART_PREFIX = "catalogAndNamespace.part."
+
+ private val FUNCTION_REFERRED_TEMP_VIEW_NAMES = "referredTempViewNames"
+ private val FUNCTION_REFERRED_TEMP_FUNCTION_NAMES = "referredTempFunctionsNames"
+ private val FUNCTION_REFERRED_TEMP_VARIABLE_NAMES = "referredTempVariableNames"
+
+ def parseDefault(text: String, parser: ParserInterface): Expression = {
+ parser.parseExpression(text)
+ }
+
/**
* This method returns an optional DataType indicating, when present, either the return type for
* scalar user-defined functions, or a StructType indicating the names and types of the columns in
@@ -92,4 +137,43 @@ object SQLFunction {
}
}
}
+
+ def isSQLFunction(className: String): Boolean = className == SQL_FUNCTION_PREFIX
+
+ /**
+ * Convert the current catalog and namespace to properties.
+ */
+ def catalogAndNamespaceToProps(
+ currentCatalog: String,
+ currentNamespace: Seq[String]): Map[String, String] = {
+ val props = new mutable.HashMap[String, String]
+ val parts = currentCatalog +: currentNamespace
+ if (parts.nonEmpty) {
+ props.put(FUNCTION_CATALOG_AND_NAMESPACE, parts.length.toString)
+ parts.zipWithIndex.foreach { case (name, index) =>
+ props.put(s"$FUNCTION_CATALOG_AND_NAMESPACE_PART_PREFIX$index", name)
+ }
+ }
+ props.toMap
+ }
+
+ /**
+ * Convert the temporary object names to properties.
+ */
+ def referredTempNamesToProps(
+ viewNames: Seq[Seq[String]],
+ functionsNames: Seq[String],
+ variableNames: Seq[Seq[String]]): Map[String, String] = {
+ val viewNamesJson =
+ JArray(viewNames.map(nameParts => JArray(nameParts.map(JString).toList)).toList)
+ val functionsNamesJson = JArray(functionsNames.map(JString).toList)
+ val variableNamesJson =
+ JArray(variableNames.map(nameParts => JArray(nameParts.map(JString).toList)).toList)
+
+ val props = new mutable.HashMap[String, String]
+ props.put(FUNCTION_REFERRED_TEMP_VIEW_NAMES, compact(render(viewNamesJson)))
+ props.put(FUNCTION_REFERRED_TEMP_FUNCTION_NAMES, compact(render(functionsNamesJson)))
+ props.put(FUNCTION_REFERRED_TEMP_VARIABLE_NAMES, compact(render(variableNamesJson)))
+ props.toMap
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
index 1473f19cb71b..6567062841de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
@@ -56,6 +56,8 @@ trait UserDefinedFunction {
}
object UserDefinedFunction {
+ val SQL_CONFIG_PREFIX = "sqlConfig."
+
def parseTableSchema(text: String, parser: ParserInterface): StructType = {
val parsed = parser.parseTableSchema(text)
CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SQLFunctionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SQLFunctionNode.scala
new file mode 100644
index 000000000000..0a3274af33b5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SQLFunctionNode.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.catalog.SQLFunction
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.trees.TreePattern.FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+/**
+ * A container for holding a SQL function query plan and its function identifier.
+ *
+ * @param function: the SQL function that this node represents.
+ * @param child: the SQL function body.
+ */
+case class SQLFunctionNode(
+ function: SQLFunction,
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+ override def stringArgs: Iterator[Any] = Iterator(function.name, child)
+ override protected def withNewChildInternal(newChild: LogicalPlan): SQLFunctionNode =
+ copy(child = newChild)
+
+ // Throw a reasonable error message when trying to call a SQL UDF with TABLE argument(s).
+ if (child.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) {
+ throw QueryCompilationErrors
+ .tableValuedArgumentsNotYetImplementedForSqlFunctions("call", toSQLId(function.name.funcName))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 8d5ddb2d85c4..744ab03d5d03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -723,8 +723,19 @@ class SparkSqlAstBuilder extends AstBuilder {
withIdentClause(ctx.identifierReference(), functionIdentifier => {
if (ctx.TEMPORARY == null) {
- // TODO: support creating persistent UDFs.
- operationNotAllowed(s"creating persistent SQL functions is not supported", ctx)
+ CreateUserDefinedFunction(
+ UnresolvedIdentifier(functionIdentifier),
+ inputParamText,
+ returnTypeText,
+ exprText,
+ queryText,
+ comment,
+ deterministic,
+ containsSQL,
+ language,
+ isTableFunc,
+ ctx.EXISTS != null,
+ ctx.REPLACE != null)
} else {
// Disallow to define a temporary function with `IF NOT EXISTS`
if (ctx.EXISTS != null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
index d2aaa93fcca0..25598a12af22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
@@ -17,9 +17,19 @@
package org.apache.spark.sql.execution.command
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.catalog.SQLFunction
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, SQLFunction, UserDefinedFunctionErrors}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Generator, LateralSubquery, Literal, ScalarSubquery, SubqueryExpression, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.logical.{LateralJoin, LogicalPlan, OneRowRelation, Project, SQLFunctionNode, UnresolvedWith}
+import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.command.CreateUserDefinedFunctionCommand._
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
/**
* The DDL command that creates a SQL function.
@@ -52,10 +62,13 @@ case class CreateSQLFunctionCommand(
replace: Boolean)
extends CreateUserDefinedFunctionCommand {
- override def run(sparkSession: SparkSession): Seq[Row] = {
- import SQLFunction._
+ import SQLFunction._
+ override def run(sparkSession: SparkSession): Seq[Row] = {
val parser = sparkSession.sessionState.sqlParser
+ val analyzer = sparkSession.sessionState.analyzer
+ val catalog = sparkSession.sessionState.catalog
+ val conf = sparkSession.sessionState.conf
val inputParam = inputParamText.map(parser.parseTableSchema)
val returnType = parseReturnTypeText(returnTypeText, isTableFunc, parser)
@@ -72,8 +85,313 @@ case class CreateSQLFunctionCommand(
isTableFunc,
Map.empty)
- // TODO: Implement the rest of the method.
+ val newFunction = {
+ val (expression, query) = function.getExpressionAndQuery(parser, isTableFunc)
+ assert(query.nonEmpty || expression.nonEmpty)
+
+ // Check if the function can be replaced.
+ if (replace && catalog.functionExists(name)) {
+ checkFunctionSignatures(catalog, name)
+ }
+
+ // Build function input.
+ val inputPlan = if (inputParam.isDefined) {
+ val param = inputParam.get
+ checkParameterNotNull(param, inputParamText.get)
+ checkParameterNameDuplication(param, conf, name)
+ checkDefaultsTrailing(param, name)
+
+ // Qualify the input parameters with the function name so that attributes referencing
+ // the function input parameters can be resolved correctly.
+ val qualifier = Seq(name.funcName)
+ val input = param.map(p => Alias(
+ {
+ val defaultExpr = p.getDefault()
+ if (defaultExpr.isEmpty) {
+ Literal.create(null, p.dataType)
+ } else {
+ val defaultPlan = parseDefault(defaultExpr.get, parser)
+ if (SubqueryExpression.hasSubquery(defaultPlan)) {
+ throw new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.NOT_A_VALID_DEFAULT_EXPRESSION",
+ messageParameters =
+ Map("functionName" -> name.funcName, "parameterName" -> p.name))
+ } else if (defaultPlan.containsPattern(UNRESOLVED_ATTRIBUTE)) {
+ // TODO(SPARK-50698): use parsed expression instead of expression string.
+ defaultPlan.collect {
+ case a: UnresolvedAttribute =>
+ throw QueryCompilationErrors.unresolvedAttributeError(
+ "UNRESOLVED_COLUMN", a.sql, Seq.empty, a.origin)
+ }
+ }
+ Cast(defaultPlan, p.dataType)
+ }
+ }, p.name)(qualifier = qualifier))
+ Project(input, OneRowRelation())
+ } else {
+ OneRowRelation()
+ }
+
+ // Build the function body and check if the function body can be analyzed successfully.
+ val (unresolvedPlan, analyzedPlan, inferredReturnType) = if (!isTableFunc) {
+ // Build SQL scalar function plan.
+ val outputExpr = if (query.isDefined) ScalarSubquery(query.get) else expression.get
+ val plan: LogicalPlan = returnType.map { t =>
+ val retType: DataType = t match {
+ case Left(t) => t
+ case _ => throw SparkException.internalError(
+ "Unexpected return type for a scalar SQL UDF.")
+ }
+ val outputCast = Seq(Alias(Cast(outputExpr, retType), name.funcName)())
+ Project(outputCast, inputPlan)
+ }.getOrElse {
+ // If no explicit RETURNS clause is present, infer the result type from the function body.
+ val outputAlias = Seq(Alias(outputExpr, name.funcName)())
+ Project(outputAlias, inputPlan)
+ }
+
+ // Check the function body can be analyzed correctly.
+ val analyzed = analyzer.execute(plan)
+ val (resolved, resolvedReturnType) = analyzed match {
+ case p @ Project(expr :: Nil, _) if expr.resolved =>
+ (p, Left(expr.dataType))
+ case other =>
+ (other, function.returnType)
+ }
+
+ // Check if the SQL function body contains aggregate/window functions.
+ // This check needs to be performed before checkAnalysis to provide better error messages.
+ checkAggOrWindowOrGeneratorExpr(resolved)
+
+ // Check if the SQL function body can be analyzed.
+ checkFunctionBodyAnalysis(analyzer, function, resolved)
+
+ (plan, resolved, resolvedReturnType)
+ } else {
+ // Build SQL table function plan.
+ if (query.isEmpty) {
+ throw UserDefinedFunctionErrors.bodyIsNotAQueryForSqlTableUdf(name.funcName)
+ }
+
+ // Construct a lateral join to analyze the function body.
+ val plan = LateralJoin(inputPlan, LateralSubquery(query.get), Inner, None)
+ val analyzed = analyzer.execute(plan)
+ val newPlan = analyzed match {
+ case Project(_, j: LateralJoin) => j
+ case j: LateralJoin => j
+ case _ => throw SparkException.internalError("Unexpected plan returned when " +
+ s"creating a SQL TVF: ${analyzed.getClass.getSimpleName}.")
+ }
+ val maybeResolved = newPlan.asInstanceOf[LateralJoin].right.plan
+
+ // Check if the function body can be analyzed.
+ checkFunctionBodyAnalysis(analyzer, function, maybeResolved)
+
+ // Get the function's return schema.
+ val returnParam: StructType = returnType.map {
+ case Right(t) => t
+ case Left(_) => throw SparkException.internalError(
+ "Unexpected return schema for a SQL table function.")
+ }.getOrElse {
+ // If no explicit RETURNS clause is present, infer the result type from the function body.
+ // To detect this, we search for instances of the UnresolvedAlias expression. Examples:
+ // CREATE TABLE t USING PARQUET AS VALUES (0, 1), (1, 2) AS tab(c1, c2);
+ // SELECT c1 FROM t --> UnresolvedAttribute: 'c1
+ // SELECT c1 + 1 FROM t --> UnresolvedAlias: unresolvedalias(('c1 + 1), None)
+ // SELECT c1 + 1 AS a FROM t --> Alias: ('c1 + 1) AS a#2
+ query.get match {
+ case Project(projectList, _) if projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ throw UserDefinedFunctionErrors.missingColumnNamesForSqlTableUdf(name.funcName)
+ case _ =>
+ StructType(analyzed.asInstanceOf[LateralJoin].right.plan.output.map { col =>
+ StructField(col.name, col.dataType)
+ })
+ }
+ }
+
+ // Check the return columns cannot have NOT NULL specified.
+ checkParameterNotNull(returnParam, returnTypeText)
+
+ // Check duplicated return column names.
+ checkReturnsColumnDuplication(returnParam, conf, name)
+
+ // Check if the actual output size equals the number of return parameters.
+ val outputSize = maybeResolved.output.size
+ if (outputSize != returnParam.size) {
+ throw new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.RETURN_COLUMN_COUNT_MISMATCH",
+ messageParameters = Map(
+ "outputSize" -> s"$outputSize",
+ "returnParamSize" -> s"${returnParam.size}",
+ "name" -> s"$name"
+ )
+ )
+ }
+
+ (plan, analyzed, Right(returnParam))
+ }
+
+ // A permanent function is not allowed to reference temporary objects.
+ // This should be called after `qe.assertAnalyzed()` (i.e., `plan` can be resolved)
+ verifyTemporaryObjectsNotExists(catalog, isTemp, name, unresolvedPlan, analyzedPlan)
+
+ // Generate function properties.
+ val properties = generateFunctionProperties(sparkSession, unresolvedPlan, analyzedPlan)
+
+ // Derive determinism of the SQL function.
+ val deterministic = analyzedPlan.deterministic
+
+ function.copy(
+ // Assign the return type, inferring from the function body if needed.
+ returnType = inferredReturnType,
+ deterministic = Some(function.deterministic.getOrElse(deterministic)),
+ properties = properties
+ )
+ }
+
+ // TODO: create/register SQL functions in the catalog.
Seq.empty
}
+
+ /**
+ * Check if the function body can be analyzed.
+ */
+ private def checkFunctionBodyAnalysis(
+ analyzer: Analyzer,
+ function: SQLFunction,
+ body: LogicalPlan): Unit = {
+ analyzer.checkAnalysis(SQLFunctionNode(function, body))
+ }
+
+ /** Check whether the new function is replacing an existing SQL function. */
+ private def checkFunctionSignatures(catalog: SessionCatalog, name: FunctionIdentifier): Unit = {
+ val info = catalog.lookupFunctionInfo(name)
+ if (!isSQLFunction(info.getClassName)) {
+ throw new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.CANNOT_REPLACE_NON_SQL_UDF_WITH_SQL_UDF",
+ messageParameters = Map("name" -> s"$name")
+ )
+ }
+ }
+
+ /**
+ * Collect all temporary views and functions and return the identifiers separately.
+ * This function traverses the unresolved plan `child`. Below are the reasons:
+ * 1) Analyzer replaces unresolved temporary views by a SubqueryAlias with the corresponding
+ * logical plan. After replacement, it is impossible to detect whether the SubqueryAlias is
+ * added/generated from a temporary view.
+ * 2) The temp functions are represented by multiple classes. Most are inaccessible from this
+ * package (e.g., HiveGenericUDF).
+ * 3) Temporary SQL functions, once resolved, cannot be identified as temp functions.
+ */
+ private def collectTemporaryObjectsInUnresolvedPlan(
+ catalog: SessionCatalog,
+ child: LogicalPlan): (Seq[Seq[String]], Seq[String]) = {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+ def collectTempViews(child: LogicalPlan): Seq[Seq[String]] = {
+ child.flatMap {
+ case UnresolvedRelation(nameParts, _, _) if catalog.isTempView(nameParts) =>
+ Seq(nameParts)
+ case w: UnresolvedWith if !w.resolved => w.innerChildren.flatMap(collectTempViews)
+ case plan if !plan.resolved => plan.expressions.flatMap(_.flatMap {
+ case e: SubqueryExpression => collectTempViews(e.plan)
+ case _ => Seq.empty
+ })
+ case _ => Seq.empty
+ }.distinct
+ }
+
+ def collectTempFunctions(child: LogicalPlan): Seq[String] = {
+ child.flatMap {
+ case w: UnresolvedWith if !w.resolved => w.innerChildren.flatMap(collectTempFunctions)
+ case plan if !plan.resolved =>
+ plan.expressions.flatMap(_.flatMap {
+ case e: SubqueryExpression => collectTempFunctions(e.plan)
+ case e: UnresolvedFunction
+ if catalog.isTemporaryFunction(e.nameParts.asFunctionIdentifier) =>
+ Seq(e.nameParts.asFunctionIdentifier.funcName)
+ case _ => Seq.empty
+ })
+ case _ => Seq.empty
+ }.distinct
+ }
+ (collectTempViews(child), collectTempFunctions(child))
+ }
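Stripped of Catalyst specifics, the collection above is a recursive walk that descends into subquery plans only for unresolved nodes and de-duplicates what it finds. A minimal sketch in Python with hypothetical node classes (not Spark's API):

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """Toy logical-plan node; 'resolved' mirrors Catalyst's flag."""
    kind: str                      # e.g. "UnresolvedRelation", "Project"
    name: str = ""
    resolved: bool = False
    children: list = field(default_factory=list)
    subqueries: list = field(default_factory=list)  # plans nested in expressions

def collect_temp_views(plan, is_temp_view):
    """Collect distinct temp-view names, inspecting subqueries of unresolved nodes only."""
    found = []
    def walk(node):
        if node.kind == "UnresolvedRelation" and is_temp_view(node.name):
            found.append(node.name)
        if not node.resolved:
            for sub in node.subqueries:   # recurse into subquery plans
                walk(sub)
        for child in node.children:       # always walk plan children
            walk(child)
    walk(plan)
    return list(dict.fromkeys(found))     # de-duplicate, keep first-seen order
```

A view referenced both directly and inside a subquery comes back once, matching the `.distinct` in the Scala code.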
+
+ /**
+ * Permanent functions are not allowed to reference temp objects, including temp functions
+ * and temp views.
+ */
+ private def verifyTemporaryObjectsNotExists(
+ catalog: SessionCatalog,
+ isTemporary: Boolean,
+ name: FunctionIdentifier,
+ child: LogicalPlan,
+ analyzed: LogicalPlan): Unit = {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+ if (!isTemporary) {
+ val (tempViews, tempFunctions) = collectTemporaryObjectsInUnresolvedPlan(catalog, child)
+ tempViews.foreach { nameParts =>
+ throw UserDefinedFunctionErrors.invalidTempViewReference(
+ routineName = name.asMultipart, tempViewName = nameParts)
+ }
+ tempFunctions.foreach { funcName =>
+ throw UserDefinedFunctionErrors.invalidTempFuncReference(
+ routineName = name.asMultipart, tempFuncName = funcName)
+ }
+ val tempVars = ViewHelper.collectTemporaryVariables(analyzed)
+ tempVars.foreach { varName =>
+ throw UserDefinedFunctionErrors.invalidTempVarReference(
+ routineName = name.asMultipart, varName = varName)
+ }
+ }
+ }
+
+ /**
+ * Check if the SQL function body contains aggregate/window/generate functions.
+ * Note subqueries inside the SQL function body can contain aggregate/window/generate functions.
+ */
+ private def checkAggOrWindowOrGeneratorExpr(plan: LogicalPlan): Unit = {
+ if (plan.resolved) {
+ plan.transformAllExpressions {
+ case e if e.isInstanceOf[WindowExpression] || e.isInstanceOf[Generator] ||
+ e.isInstanceOf[AggregateExpression] =>
+ throw new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.CANNOT_CONTAIN_COMPLEX_FUNCTIONS",
+ messageParameters = Map("queryText" -> s"${exprText.orElse(queryText).get}")
+ )
+ }
+ }
+ }
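This check amounts to a pre-order scan of the resolved expression trees for disallowed node kinds, failing on the first hit. An illustrative Python sketch (hypothetical `Expr` class, not Catalyst):

```python
class Expr:
    """Toy expression-tree node identified by a kind string."""
    def __init__(self, kind, children=()):
        self.kind = kind              # e.g. "AggregateExpression", "Literal"
        self.children = list(children)

# Kinds a SQL UDF body may not contain directly (subqueries are exempt in Spark).
DISALLOWED = {"AggregateExpression", "WindowExpression", "Generator"}

def check_no_complex_functions(expr):
    """Raise if the tree contains an aggregate/window/generator expression."""
    if expr.kind in DISALLOWED:
        raise ValueError(f"SQL UDF body cannot contain {expr.kind}")
    for child in expr.children:
        check_no_complex_functions(child)
```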
+
+ /**
+ * Generate the function properties, including:
+ * 1. the SQL configs when creating the function.
+ * 2. the catalog and database name when creating the function. This will be used to provide
+ * context during nested function resolution.
+ * 3. referred temporary object names if the function is a temp function.
+ */
+ private def generateFunctionProperties(
+ session: SparkSession,
+ plan: LogicalPlan,
+ analyzed: LogicalPlan): Map[String, String] = {
+ val catalog = session.sessionState.catalog
+ val conf = session.sessionState.conf
+ val manager = session.sessionState.catalogManager
+
+ // Only collect temporary object names when the function is a temp function.
+ val (tempViews, tempFunctions) = if (isTemp) {
+ collectTemporaryObjectsInUnresolvedPlan(catalog, plan)
+ } else {
+ (Nil, Nil)
+ }
+ val tempVars = ViewHelper.collectTemporaryVariables(analyzed)
+
+ sqlConfigsToProps(conf) ++
+ catalogAndNamespaceToProps(
+ manager.currentCatalog.name,
+ manager.currentNamespace.toIndexedSeq) ++
+ referredTempNamesToProps(tempViews, tempFunctions, tempVars)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
index bebb0f5cf6c3..1ee3c8a4c388 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
@@ -17,9 +17,15 @@
package org.apache.spark.sql.execution.command
+import java.util.Locale
+
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.catalog.{LanguageSQL, RoutineLanguage, UserDefinedFunctionErrors}
+import org.apache.spark.sql.catalyst.catalog.UserDefinedFunction._
import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
/**
* The base class for CreateUserDefinedFunctionCommand
@@ -74,4 +80,108 @@ object CreateUserDefinedFunctionCommand {
throw UserDefinedFunctionErrors.unsupportedUserDefinedFunction(other)
}
}
+
+ /**
+ * Convert SQL configs to properties by prefixing all configs with a key.
+ * When converting a function to [[org.apache.spark.sql.catalyst.catalog.CatalogFunction]] or
+ * [[org.apache.spark.sql.catalyst.expressions.ExpressionInfo]], all SQL configs and other
+ * function properties (such as the function parameters and the function return type)
+ * are saved together in a property map.
+ */
+ def sqlConfigsToProps(conf: SQLConf): Map[String, String] = {
+ val modifiedConfs = ViewHelper.getModifiedConf(conf)
+ modifiedConfs.map { case (key, value) => s"$SQL_CONFIG_PREFIX$key" -> value }
+ }
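`sqlConfigsToProps` simply namespaces each captured config under a common prefix so it can live alongside other entries in the function property map. A sketch of the transformation (the prefix value below is hypothetical, for illustration only):

```python
SQL_CONFIG_PREFIX = "sqlConfig."   # hypothetical stand-in for Spark's constant

def sql_configs_to_props(modified_confs: dict) -> dict:
    """Namespace each captured SQL config under a common property prefix."""
    return {SQL_CONFIG_PREFIX + key: value for key, value in modified_confs.items()}
```

The prefix lets the reader of the property map tell captured configs apart from, say, the stored parameter list or return type.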
+
+ /**
+ * Check whether the function parameters contain duplicated column names.
+ * It takes the function input parameter struct as input and verifies that there are no
+ * duplicates in the parameter column names.
+ * If any duplicates are found, it throws an exception with helpful information for users to
+ * fix the wrong function parameters.
+ *
+ * Perform this check while registering the function to fail early.
+ * This check does not need to run the function itself.
+ */
+ def checkParameterNameDuplication(
+ param: StructType,
+ conf: SQLConf,
+ name: FunctionIdentifier): Unit = {
+ val names = if (conf.caseSensitiveAnalysis) {
+ param.fields.map(_.name)
+ } else {
+ param.fields.map(_.name.toLowerCase(Locale.ROOT))
+ }
+ if (names.distinct.length != names.length) {
+ val duplicateColumns = names.groupBy(identity).collect {
+ case (x, ys) if ys.length > 1 => s"`$x`"
+ }
+ throw UserDefinedFunctionErrors.duplicateParameterNames(
+ routineName = name.funcName,
+ names = duplicateColumns.toSeq.sorted.mkString(", "))
+ }
+ }
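The duplicate-name detection reduces to counting normalized names (lowercased unless case-sensitive analysis is enabled) and reporting the ones seen more than once. Roughly, as a Python sketch:

```python
from collections import Counter

def find_duplicate_names(names, case_sensitive=False):
    """Return backtick-quoted parameter names that appear more than once."""
    keys = names if case_sensitive else [n.lower() for n in names]
    counts = Counter(keys)
    return sorted({f"`{k}`" for k, c in counts.items() if c > 1})
```

With case-insensitive analysis, `a` and `A` collide; with case-sensitive analysis, they do not.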
+
+ /**
+ * Check whether the function has duplicate column names in the RETURNS clause.
+ */
+ def checkReturnsColumnDuplication(
+ columns: StructType,
+ conf: SQLConf,
+ name: FunctionIdentifier): Unit = {
+ val names = if (conf.caseSensitiveAnalysis) {
+ columns.fields.map(_.name)
+ } else {
+ columns.fields.map(_.name.toLowerCase(Locale.ROOT))
+ }
+ if (names.distinct.length != names.length) {
+ val duplicateColumns = names.groupBy(identity).collect {
+ case (x, ys) if ys.length > 1 => s"`$x`"
+ }
+ throw UserDefinedFunctionErrors.duplicateReturnsColumns(
+ routineName = name.funcName,
+ columns = duplicateColumns.toSeq.sorted.mkString(", "))
+ }
+ }
+
+ /**
+ * Check whether the function parameters contain non-trailing defaults.
+ * For languages that support default values for input parameters,
+ * this check ensures that once a default value is given to a parameter,
+ * all subsequent parameters also have a default value. It throws an error otherwise.
+ *
+ * Perform this check on function input parameters while registering the function to fail
+ * early. This check does not need to run the function itself.
+ */
+ def checkDefaultsTrailing(param: StructType, name: FunctionIdentifier): Unit = {
+ var defaultFound = false
+ var previousParamName = ""
+ param.fields.foreach { field =>
+ if (field.getDefault().isEmpty && defaultFound) {
+ throw new AnalysisException(
+ errorClass = "USER_DEFINED_FUNCTIONS.NOT_A_VALID_DEFAULT_PARAMETER_POSITION",
+ messageParameters = Map(
+ "functionName" -> name.funcName,
+ "parameterName" -> previousParamName,
+ "nextParameterName" -> field.name))
+ }
+ defaultFound |= field.getDefault().isDefined
+ previousParamName = field.name
+ }
+ }
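The trailing-defaults rule says that once any parameter declares a DEFAULT, every later parameter must declare one too. A minimal Python sketch of the same scan:

```python
def check_defaults_trailing(params):
    """params: list of (name, has_default) pairs, in declaration order.
    Raise when a parameter without a default follows one that has a default."""
    default_seen_at = None
    for name, has_default in params:
        if default_seen_at is not None and not has_default:
            raise ValueError(
                f"parameter {name} lacks a default but follows {default_seen_at}, which has one")
        if has_default:
            default_seen_at = name
```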
+
+ /**
+ * Check whether the function input or return columns (for TABLE return type) have NOT NULL
+ * specified. Throw an exception if NOT NULL is found.
+ *
+ * Perform this check on function input and return parameters while registering the function
+ * to fail early. This check does not need to run the function itself.
+ */
+ def checkParameterNotNull(param: StructType, input: String): Unit = {
+ param.fields.foreach { field =>
+ if (!field.nullable) {
+ throw UserDefinedFunctionErrors.cannotSpecifyNotNullOnFunctionParameters(input)
+ }
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index d5a72fd6c441..f654c846c8a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -464,12 +464,19 @@ object ViewHelper extends SQLConfHelper with Logging {
}
/**
- * Convert the view SQL configs to `properties`.
+ * Get all configurations that are modifiable and should be captured.
*/
- private def sqlConfigsToProps(conf: SQLConf): Map[String, String] = {
- val modifiedConfs = conf.getAllConfs.filter { case (k, _) =>
+ def getModifiedConf(conf: SQLConf): Map[String, String] = {
+ conf.getAllConfs.filter { case (k, _) =>
conf.isModifiable(k) && shouldCaptureConfig(k)
}
+ }
+
+ /**
+ * Convert the view SQL configs to `properties`.
+ */
+ private def sqlConfigsToProps(conf: SQLConf): Map[String, String] = {
+ val modifiedConfs = getModifiedConf(conf)
// Some configs have dynamic default values, such as SESSION_LOCAL_TIMEZONE whose
// default value relies on the JVM system timezone. We need to always capture them
// to make sure we apply the same configs when reading the view.
@@ -690,7 +697,7 @@ object ViewHelper extends SQLConfHelper with Logging {
/**
* Collect all temporary SQL variables and return the identifiers separately.
*/
- private def collectTemporaryVariables(child: LogicalPlan): Seq[Seq[String]] = {
+ def collectTemporaryVariables(child: LogicalPlan): Seq[Seq[String]] = {
def collectTempVars(child: LogicalPlan): Seq[Seq[String]] = {
child.flatMap { plan =>
plan.expressions.flatMap(_.flatMap {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala
new file mode 100644
index 000000000000..75b42c644071
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedIdentifier}
+import org.apache.spark.sql.catalyst.catalog.LanguageSQL
+import org.apache.spark.sql.catalyst.plans.logical.CreateUserDefinedFunction
+import org.apache.spark.sql.execution.SparkSqlParser
+
+class CreateSQLFunctionParserSuite extends AnalysisTest {
+ private lazy val parser = new SparkSqlParser()
+
+ private def intercept(sqlCommand: String, messages: String*): Unit =
+ interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)()
+
+ private def checkParseError(
+ sqlCommand: String,
+ errorClass: String,
+ parameters: Map[String, String],
+ queryContext: Array[ExpectedContext] = Array.empty): Unit =
+ assertParseErrorClass(parser.parsePlan, sqlCommand, errorClass, parameters, queryContext)
+
+ // scalastyle:off argcount
+ private def createSQLFunction(
+ nameParts: Seq[String],
+ inputParamText: Option[String] = None,
+ returnTypeText: String = "INT",
+ exprText: Option[String] = None,
+ queryText: Option[String] = None,
+ comment: Option[String] = None,
+ isDeterministic: Option[Boolean] = None,
+ containsSQL: Option[Boolean] = None,
+ isTableFunc: Boolean = false,
+ ignoreIfExists: Boolean = false,
+ replace: Boolean = false): CreateUserDefinedFunction = {
+ // scalastyle:on argcount
+ CreateUserDefinedFunction(
+ UnresolvedIdentifier(nameParts),
+ inputParamText = inputParamText,
+ returnTypeText = returnTypeText,
+ exprText = exprText,
+ queryText = queryText,
+ comment = comment,
+ isDeterministic = isDeterministic,
+ containsSQL = containsSQL,
+ language = LanguageSQL,
+ isTableFunc = isTableFunc,
+ ignoreIfExists = ignoreIfExists,
+ replace = replace)
+ }
+
+ // scalastyle:off argcount
+ private def createSQLFunctionCommand(
+ name: String,
+ inputParamText: Option[String] = None,
+ returnTypeText: String = "INT",
+ exprText: Option[String] = None,
+ queryText: Option[String] = None,
+ comment: Option[String] = None,
+ isDeterministic: Option[Boolean] = None,
+ containsSQL: Option[Boolean] = None,
+ isTableFunc: Boolean = false,
+ ignoreIfExists: Boolean = false,
+ replace: Boolean = false): CreateSQLFunctionCommand = {
+ // scalastyle:on argcount
+ CreateSQLFunctionCommand(
+ FunctionIdentifier(name),
+ inputParamText = inputParamText,
+ returnTypeText = returnTypeText,
+ exprText = exprText,
+ queryText = queryText,
+ comment = comment,
+ isDeterministic = isDeterministic,
+ containsSQL = containsSQL,
+ isTableFunc = isTableFunc,
+ isTemp = true,
+ ignoreIfExists = ignoreIfExists,
+ replace = replace)
+ }
+
+ test("create temporary SQL functions") {
+ comparePlans(
+ parser.parsePlan("CREATE TEMPORARY FUNCTION a() RETURNS INT RETURN 1"),
+ createSQLFunctionCommand("a", exprText = Some("1")))
+
+ comparePlans(
+ parser.parsePlan(
"CREATE TEMPORARY FUNCTION a(x INT) RETURNS TABLE (a INT) RETURN SELECT x"),
+ createSQLFunctionCommand(
+ name = "a",
+ inputParamText = Some("x INT"),
+ returnTypeText = "a INT",
+ queryText = Some("SELECT x"),
+ isTableFunc = true))
+
+ comparePlans(
+ parser.parsePlan("CREATE OR REPLACE TEMPORARY FUNCTION a() RETURNS INT RETURN 1"),
+ createSQLFunctionCommand("a", exprText = Some("1"), replace = true))
+
+ checkParseError(
+ "CREATE TEMPORARY FUNCTION a.b() RETURNS INT RETURN 1",
+ errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE",
+ parameters = Map("database" -> "`a`"),
+ queryContext = Array(
+ ExpectedContext("CREATE TEMPORARY FUNCTION a.b() RETURNS INT RETURN 1", 0, 51)
+ )
+ )
+
+ checkParseError(
+ "CREATE TEMPORARY FUNCTION a.b.c() RETURNS INT RETURN 1",
+ errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
+ parameters = Map(
+ "statement" -> "CREATE TEMPORARY FUNCTION",
+ "name" -> "`a`.`b`.`c`"),
+ queryContext = Array(
+ ExpectedContext("CREATE TEMPORARY FUNCTION a.b.c() RETURNS INT RETURN 1", 0, 53)
+ )
+ )
+
+ checkParseError(
+ "CREATE TEMPORARY FUNCTION IF NOT EXISTS a() RETURNS INT RETURN 1",
+ errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS",
+ parameters = Map.empty,
+ queryContext = Array(
+ ExpectedContext("CREATE TEMPORARY FUNCTION IF NOT EXISTS a() RETURNS INT RETURN 1", 0, 63)
+ )
+ )
+ }
+
+ test("create persistent SQL functions") {
+ comparePlans(
+ parser.parsePlan("CREATE FUNCTION a() RETURNS INT RETURN 1"),
+ createSQLFunction(Seq("a"), exprText = Some("1")))
+
+ comparePlans(
+ parser.parsePlan("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x"),
+ createSQLFunction(Seq("a", "b"), Some("x INT"), exprText = Some("x")))
+
+ comparePlans(parser.parsePlan(
+ "CREATE FUNCTION a.b.c(x INT) RETURNS TABLE (a INT) RETURN SELECT x"),
+ createSQLFunction(Seq("a", "b", "c"), Some("x INT"), returnTypeText = "a INT", None,
+ Some("SELECT x"), isTableFunc = true))
+
+ comparePlans(parser.parsePlan("CREATE FUNCTION IF NOT EXISTS a() RETURNS INT RETURN 1"),
+ createSQLFunction(Seq("a"), exprText = Some("1"), ignoreIfExists = true)
+ )
+
+ comparePlans(parser.parsePlan("CREATE OR REPLACE FUNCTION a() RETURNS INT RETURN 1"),
+ createSQLFunction(Seq("a"), exprText = Some("1"), replace = true))
+
+ comparePlans(
+ parser.parsePlan(
+ """
+ |CREATE FUNCTION a(x INT COMMENT 'x') RETURNS INT
+ |LANGUAGE SQL DETERMINISTIC CONTAINS SQL
+ |COMMENT 'function'
+ |RETURN x
+ |""".stripMargin),
+ createSQLFunction(Seq("a"), inputParamText = Some("x INT COMMENT 'x'"),
+ exprText = Some("x"), isDeterministic = Some(true), containsSQL = Some(true),
+ comment = Some("function"))
+ )
+
+ intercept("CREATE OR REPLACE FUNCTION IF NOT EXISTS a() RETURNS INT RETURN 1",
+ "Cannot create a routine with both IF NOT EXISTS and REPLACE specified")
+ }
+
+ test("create SQL functions with unsupported routine characteristics") {
+ intercept("CREATE FUNCTION foo() RETURNS INT LANGUAGE blah RETURN 1",
+ "Operation not allowed: Unsupported language for user defined functions: blah")
+
+ intercept("CREATE FUNCTION foo() RETURNS INT SPECIFIC foo1 RETURN 1",
+ "Operation not allowed: SQL function with SPECIFIC name is not supported")
+
+ intercept("CREATE FUNCTION foo() RETURNS INT NO SQL RETURN 1",
+ "Operation not allowed: SQL function with NO SQL is not supported")
+
+ intercept("CREATE FUNCTION foo() RETURNS INT NO SQL CONTAINS SQL RETURN 1",
+ "Found duplicate clauses: SQL DATA ACCESS")
+
+ intercept("CREATE FUNCTION foo() RETURNS INT RETURNS NULL ON NULL INPUT RETURN 1",
+ "Operation not allowed: SQL function with RETURNS NULL ON NULL INPUT is not supported")
+
+ intercept("CREATE FUNCTION foo() RETURNS INT SQL SECURITY INVOKER RETURN 1",
+ "Operation not allowed: SQL function with SQL SECURITY INVOKER is not supported")
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
index d38708ab3745..3dea8593b428 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.command
import org.apache.spark.SparkThrowable
-import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, GlobalTempView, LocalTempView, SchemaCompensation, UnresolvedAttribute, UnresolvedFunctionName, UnresolvedIdentifier}
import org.apache.spark.sql.catalyst.catalog.{ArchiveResource, FileResource, FunctionResource, JarResource}
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -37,9 +36,6 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
super.parseException(parser.parsePlan)(sqlText)
}
- private def intercept(sqlCommand: String, messages: String*): Unit =
- interceptParseException(parser.parsePlan)(sqlCommand, messages: _*)()
-
private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = {
val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null)
comparePlans(plan, expected, checkAnalysis = false)
@@ -827,44 +823,4 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
parser.parsePlan("SHOW CATALOGS LIKE 'defau*'"),
ShowCatalogsCommand(Some("defau*")))
}
-
- test("Create SQL functions") {
- comparePlans(
- parser.parsePlan("CREATE TEMP FUNCTION foo() RETURNS INT RETURN 1"),
- CreateSQLFunctionCommand(
- FunctionIdentifier("foo"),
- inputParamText = None,
- returnTypeText = "INT",
- exprText = Some("1"),
- queryText = None,
- comment = None,
- isDeterministic = None,
- containsSQL = None,
- isTableFunc = false,
- isTemp = true,
- ignoreIfExists = false,
- replace = false))
- intercept("CREATE FUNCTION foo() RETURNS INT RETURN 1",
- "Operation not allowed: creating persistent SQL functions is not supported")
- }
-
- test("create SQL functions with unsupported routine characteristics") {
- intercept("CREATE FUNCTION foo() RETURNS INT LANGUAGE blah RETURN 1",
- "Operation not allowed: Unsupported language for user defined functions: blah")
-
- intercept("CREATE FUNCTION foo() RETURNS INT SPECIFIC foo1 RETURN 1",
- "Operation not allowed: SQL function with SPECIFIC name is not supported")
-
- intercept("CREATE FUNCTION foo() RETURNS INT NO SQL RETURN 1",
- "Operation not allowed: SQL function with NO SQL is not supported")
-
- intercept("CREATE FUNCTION foo() RETURNS INT NO SQL CONTAINS SQL RETURN 1",
- "Found duplicate clauses: SQL DATA ACCESS")
-
- intercept("CREATE FUNCTION foo() RETURNS INT RETURNS NULL ON NULL INPUT RETURN 1",
- "Operation not allowed: SQL function with RETURNS NULL ON NULL INPUT is not supported")
-
- intercept("CREATE FUNCTION foo() RETURNS INT SQL SECURITY INVOKER RETURN 1",
- "Operation not allowed: SQL function with SQL SECURITY INVOKER is not supported")
- }
}