allisonwang-db commented on code in PR #49126:
URL: https://github.com/apache/spark/pull/49126#discussion_r1903204479
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala:
##########
@@ -72,8 +85,312 @@ case class CreateSQLFunctionCommand(
isTableFunc,
Map.empty)
- // TODO: Implement the rest of the method.
+ val newFunction = {
+ val (expression, query) = function.getExpressionAndQuery(parser,
isTableFunc)
+ assert(query.nonEmpty || expression.nonEmpty)
+
+ // Check if the function can be replaced.
+ if (replace && catalog.functionExists(name)) {
+ checkFunctionSignatures(catalog, name)
+ }
+
+ // Build function input.
+ val inputPlan = if (inputParam.isDefined) {
+ val param = inputParam.get
+ checkParameterNotNull(param, inputParamText.get)
+ checkParameterNameDuplication(param, conf, name)
+ checkDefaultsTrailing(param, name)
+
+ // Qualify the input parameters with the function name so that
attributes referencing
+ // the function input parameters can be resolved correctly.
+ val qualifier = Seq(name.funcName)
+ val input = param.map(p => Alias(
+ {
+ val defaultExpr = p.getDefault()
+ if (defaultExpr.isEmpty) {
+ Literal.create(null, p.dataType)
+ } else {
+ val defaultPlan = parseDefault(defaultExpr.get, parser)
+ if (SubqueryExpression.hasSubquery(defaultPlan)) {
+ throw new AnalysisException(
+ errorClass =
"USER_DEFINED_FUNCTIONS.NOT_A_VALID_DEFAULT_EXPRESSION",
+ messageParameters =
+ Map("functionName" -> name.funcName, "parameterName" ->
p.name))
+ } else if (defaultPlan.containsPattern(UNRESOLVED_ATTRIBUTE)) {
+ defaultPlan.collect {
+ case a: UnresolvedAttribute =>
+ throw QueryCompilationErrors.unresolvedAttributeError(
+ "UNRESOLVED_COLUMN", a.sql, Seq.empty, a.origin)
+ }
+ }
+ Cast(defaultPlan, p.dataType)
+ }
+ }, p.name)(qualifier = qualifier))
+ Project(input, OneRowRelation())
+ } else {
+ OneRowRelation()
+ }
+
+ // Build the function body and check if the function body can be
analyzed successfully.
+ val (unresolvedPlan, analyzedPlan, inferredReturnType) = if
(!isTableFunc) {
+ // Build SQL scalar function plan.
+ val outputExpr = if (query.isDefined) ScalarSubquery(query.get) else
expression.get
+ val plan: LogicalPlan = returnType.map { t =>
+ val retType: DataType = t match {
+ case Left(t) => t
+ case _ => throw SparkException.internalError(
+ "Unexpected return type for a scalar SQL UDF.")
+ }
+ val outputCast = Seq(Alias(Cast(outputExpr, retType),
name.funcName)())
+ Project(outputCast, inputPlan)
+ }.getOrElse {
+ // If no explicit RETURNS clause is present, infer the result type
from the function body.
+ val outputAlias = Seq(Alias(outputExpr, name.funcName)())
+ Project(outputAlias, inputPlan)
+ }
+
+ // Check the function body can be analyzed correctly.
+ val analyzed = analyzer.execute(plan)
+ val (resolved, resolvedReturnType) = analyzed match {
+ case p @ Project(expr :: Nil, _) if expr.resolved =>
+ (p, Left(expr.dataType))
+ case other =>
+ (other, function.returnType)
+ }
+
+ // Check if the SQL function body contains aggregate/window functions.
+ // This check needs to be performed before checkAnalysis to provide
better error messages.
+ checkAggOrWindowOrGeneratorExpr(resolved)
+
+ // Check if the SQL function body can be analyzed.
+ checkFunctionBodyAnalysis(analyzer, function, resolved)
+
+ (plan, resolved, resolvedReturnType)
+ } else {
+ // Build SQL table function plan.
+ if (query.isEmpty) {
+ throw
UserDefinedFunctionErrors.bodyIsNotAQueryForSqlTableUdf(name.funcName)
+ }
+
+ // Construct a lateral join to analyze the function body.
+ val plan = LateralJoin(inputPlan, LateralSubquery(query.get), Inner,
None)
+ val analyzed = analyzer.execute(plan)
+ val newPlan = analyzed match {
+ case Project(_, j: LateralJoin) => j
+ case j: LateralJoin => j
+ case _ => throw SparkException.internalError("Unexpected plan
returned when " +
+ s"creating a SQL TVF: ${analyzed.getClass.getSimpleName}.")
+ }
+ val maybeResolved = newPlan.asInstanceOf[LateralJoin].right.plan
+
+ // Check if the function body can be analyzed.
+ checkFunctionBodyAnalysis(analyzer, function, maybeResolved)
+
+ // Get the function's return schema.
+ val returnParam: StructType = returnType.map {
+ case Right(t) => t
+ case Left(_) => throw SparkException.internalError(
+ "Unexpected return schema for a SQL table function.")
+ }.getOrElse {
+ // If no explicit RETURNS clause is present, infer the result type
from the function body.
+ // To detect this, we search for instances of the UnresolvedAlias
expression. Examples:
+ // CREATE TABLE t USING PARQUET AS VALUES (0, 1), (1, 2) AS tab(c1,
c2);
+ // SELECT c1 FROM t --> UnresolvedAttribute: 'c1
+ // SELECT c1 + 1 FROM t --> UnresolvedAlias:
unresolvedalias(('c1 + 1), None)
+ // SELECT c1 + 1 AS a FROM t --> Alias: ('c1 + 1) AS a#2
Review Comment:
Yes, this is to support an optional return type for SQL functions. We don't
require users to specify this alias unless the return type is missing, in which
case it is inferred from the function body.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]