sunchao commented on a change in pull request #32082:
URL: https://github.com/apache/spark/pull/32082#discussion_r618670253



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2021,88 +2032,218 @@ class Analyzer(override val catalogManager: 
CatalogManager)
                   name, other.getClass.getCanonicalName)
               }
             }
-          case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter, 
ignoreNulls) =>
+
+          case u @ UnresolvedFunction(AsFunctionIdentifier(ident), arguments,
+          isDistinct, filter, ignoreNulls) =>
             withPosition(u) {
-              v1SessionCatalog.lookupFunction(funcId, arguments) match {
-                // AggregateWindowFunctions are AggregateFunctions that can 
only be evaluated within
-                // the context of a Window clause. They do not need to be 
wrapped in an
-                // AggregateExpression.
-                case wf: AggregateWindowFunction =>
-                  if (isDistinct) {
-                    throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                      wf.prettyName, "DISTINCT")
-                  } else if (filter.isDefined) {
-                    throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                      wf.prettyName, "FILTER clause")
-                  } else if (ignoreNulls) {
-                    wf match {
-                      case nthValue: NthValue =>
-                        nthValue.copy(ignoreNulls = ignoreNulls)
-                      case _ =>
-                        throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                          wf.prettyName, "IGNORE NULLS")
-                    }
-                  } else {
-                    wf
-                  }
-                case owf: FrameLessOffsetWindowFunction =>
-                  if (isDistinct) {
-                    throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                      owf.prettyName, "DISTINCT")
-                  } else if (filter.isDefined) {
-                    throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                      owf.prettyName, "FILTER clause")
-                  } else if (ignoreNulls) {
-                    owf match {
-                      case lead: Lead =>
-                        lead.copy(ignoreNulls = ignoreNulls)
-                      case lag: Lag =>
-                        lag.copy(ignoreNulls = ignoreNulls)
-                    }
-                  } else {
-                    owf
-                  }
-                // We get an aggregate function, we need to wrap it in an 
AggregateExpression.
-                case agg: AggregateFunction =>
-                  if (filter.isDefined && !filter.get.deterministic) {
-                    throw 
QueryCompilationErrors.nonDeterministicFilterInAggregateError
-                  }
-                  if (ignoreNulls) {
-                    val aggFunc = agg match {
-                      case first: First => first.copy(ignoreNulls = 
ignoreNulls)
-                      case last: Last => last.copy(ignoreNulls = ignoreNulls)
-                      case _ =>
-                        throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                          agg.prettyName, "IGNORE NULLS")
-                    }
-                    AggregateExpression(aggFunc, Complete, isDistinct, filter)
-                  } else {
-                    AggregateExpression(agg, Complete, isDistinct, filter)
-                  }
-                // This function is not an aggregate function, just return the 
resolved one.
-                case other if isDistinct =>
-                  throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                    other.prettyName, "DISTINCT")
-                case other if filter.isDefined =>
-                  throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                    other.prettyName, "FILTER clause")
-                case other if ignoreNulls =>
-                  throw 
QueryCompilationErrors.functionWithUnsupportedSyntaxError(
-                    other.prettyName, "IGNORE NULLS")
-                case e: String2TrimExpression if arguments.size == 2 =>
-                  if (trimWarningEnabled.get) {
-                    log.warn("Two-parameter TRIM/LTRIM/RTRIM function 
signatures are deprecated." +
-                      " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? 
trimStr FROM str)`" +
-                      " instead.")
-                    trimWarningEnabled.set(false)
-                  }
-                  e
-                case other =>
-                  other
+              processFunctionExpr(v1SessionCatalog.lookupFunction(ident, 
arguments),
+                arguments, isDistinct, filter, ignoreNulls)
+            }
+
+          case u @ UnresolvedFunction(parts, arguments, isDistinct, filter, 
ignoreNulls) =>
+            withPosition(u) {
+              // resolve built-in or temporary functions with v2 catalog
+              val resultExpression = if (parts.length == 1) {
+                v1SessionCatalog.lookupBuiltinOrTempFunction(parts.head, 
arguments).map(
+                  processFunctionExpr(_, arguments, isDistinct, filter, 
ignoreNulls)
+                )
+              } else {
+                None
               }
+
+              resultExpression.getOrElse(
+                expandIdentifier(parts) match {
+                  case NonSessionCatalogAndIdentifier(catalog: 
FunctionCatalog, ident) =>
+                    lookupV2Function(catalog, ident, arguments, isDistinct, 
filter, ignoreNulls)
+                  case _ => u
+                }
+              )
             }
         }
     }
+
+    /**
+     * Check if the input `fn` implements the given `methodName`. If 
`inputTypeOpt` is set, it also
+     * tries to match it against the declared parameter types.
+     */
+    private def findMethod(
+        fn: BoundFunction,
+        methodName: String,
+        inputTypeOpt: Option[Seq[DataType]] = None): Option[Method] = {
+      val cls = fn.getClass
+      inputTypeOpt match {
+        case Some(inputType) =>
+          try {
+            val argClasses = inputType.map(ScalaReflection.dataTypeJavaClass)
+            Some(cls.getDeclaredMethod(methodName, argClasses: _*))
+          } catch {
+            case _: NoSuchMethodException =>
+              None
+          }
+        case None =>
+          cls.getDeclaredMethods.find(_.getName == methodName)
+      }
+    }
+
+    private def processFunctionExpr(
+        expr: Expression,
+        arguments: Seq[Expression],
+        isDistinct: Boolean,
+        filter: Option[Expression],
+        ignoreNulls: Boolean): Expression = expr match {
+      // AggregateWindowFunctions are AggregateFunctions that can only be 
evaluated within
+      // the context of a Window clause. They do not need to be wrapped in an
+      // AggregateExpression.
+      case wf: AggregateWindowFunction =>
+        if (isDistinct) {
+          throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+            wf.prettyName, "DISTINCT")
+        } else if (filter.isDefined) {
+          throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+            wf.prettyName, "FILTER clause")
+        } else if (ignoreNulls) {
+          wf match {
+            case nthValue: NthValue =>
+              nthValue.copy(ignoreNulls = ignoreNulls)
+            case _ =>
+              throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+                wf.prettyName, "IGNORE NULLS")
+          }
+        } else {
+          wf
+        }
+      case owf: FrameLessOffsetWindowFunction =>
+        if (isDistinct) {
+          throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+            owf.prettyName, "DISTINCT")
+        } else if (filter.isDefined) {
+          throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+            owf.prettyName, "FILTER clause")
+        } else if (ignoreNulls) {
+          owf match {
+            case lead: Lead =>
+              lead.copy(ignoreNulls = ignoreNulls)
+            case lag: Lag =>
+              lag.copy(ignoreNulls = ignoreNulls)
+          }
+        } else {
+          owf
+        }
+      // We get an aggregate function, we need to wrap it in an 
AggregateExpression.
+      case agg: AggregateFunction =>
+        if (filter.isDefined && !filter.get.deterministic) {
+          throw QueryCompilationErrors.nonDeterministicFilterInAggregateError
+        }
+        if (ignoreNulls) {
+          val aggFunc = agg match {
+            case first: First => first.copy(ignoreNulls = ignoreNulls)
+            case last: Last => last.copy(ignoreNulls = ignoreNulls)
+            case _ =>
+              throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+                agg.prettyName, "IGNORE NULLS")
+          }
+          AggregateExpression(aggFunc, Complete, isDistinct, filter)
+        } else {
+          AggregateExpression(agg, Complete, isDistinct, filter)
+        }
+      // This function is not an aggregate function, just return the resolved 
one.
+      case other if isDistinct =>
+        throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+          other.prettyName, "DISTINCT")
+      case other if filter.isDefined =>
+        throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+          other.prettyName, "FILTER clause")
+      case other if ignoreNulls =>
+        throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+          other.prettyName, "IGNORE NULLS")
+      case e: String2TrimExpression if arguments.size == 2 =>
+        if (trimWarningEnabled.get) {
+          log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are 
deprecated." +
+            " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? trimStr FROM 
str)`" +
+            " instead.")
+          trimWarningEnabled.set(false)
+        }
+        e
+      case other =>
+        other
+    }
+
+    private def lookupV2Function(
+        catalog: FunctionCatalog,
+        ident: Identifier,
+        arguments: Seq[Expression],
+        isDistinct: Boolean,
+        filter: Option[Expression],
+        ignoreNulls: Boolean): Expression = {
+      val unbound = catalog.loadFunction(ident)
+      val inputType = StructType(arguments.zipWithIndex.map {
+        case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
+      })
+      val bound = try {
+        unbound.bind(inputType)
+      } catch {
+        case unsupported: UnsupportedOperationException =>
+          failAnalysis(s"Function '${unbound.name}' cannot process input: " +
+            s"(${arguments.map(_.dataType.simpleString).mkString(", ")}): " +
+            unsupported.getMessage)
+      }
+
+      bound match {
+        case scalarFunc: ScalarFunction[_] =>
+          if (isDistinct) {
+            throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+              scalarFunc.name(), "DISTINCT")
+          } else if (filter.isDefined) {
+            throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+              scalarFunc.name(), "FILTER clause")
+          } else if (ignoreNulls) {
+            throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
+              scalarFunc.name(), "IGNORE NULLS")
+          } else {
+            // TODO: implement type coercion by looking at input type from the 
UDF. We may
+            //  also want to check if the parameter types from the magic 
method match the
+            //  input type through `BoundFunction.inputTypes`.
+            val argClasses = inputType.fields.map(_.dataType)
+            findMethod(scalarFunc, MAGIC_METHOD_NAME, Some(argClasses)) match {
+              case Some(_) =>
+                val caller = Literal.create(scalarFunc, 
ObjectType(scalarFunc.getClass))
+                Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
+                  arguments, returnNullable = scalarFunc.isResultNullable)
+              case _ =>
+                // TODO: handle functions defined in Scala too - in Scala, 
even if a
+                //  subclass does not override the default method in the parent 
interface
+                //  defined in Java, the method can still be found from
+                //  `getDeclaredMethod`.
+                findMethod(scalarFunc, "produceResult", Some(Seq(inputType))) 
match {

Review comment:
       Good point. Added some comments for this.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to