This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 57503b67d939 [SPARK-55528][SQL] Add default collation support for SQL UDFs
57503b67d939 is described below

commit 57503b67d93971f22839c11086ad79191d750eee
Author: ilicmarkodb <[email protected]>
AuthorDate: Wed Mar 4 20:50:59 2026 +0800

    [SPARK-55528][SQL] Add default collation support for SQL UDFs
    
    ### What changes were proposed in this pull request?
    
    This PR adds default collation support for SQL user-defined functions, enabling UDFs to inherit schema-level collations and to specify an explicit default collation via the `DEFAULT COLLATION` clause.
    
    **How DEFAULT COLLATION is applied:**
    - **STRING parameters**: Parameters declared as `STRING` without an explicit collation (e.g., `p1 STRING`) receive the default collation
    - **STRING return type**: When `RETURNS STRING` is specified without an explicit collation, the default collation is applied
    - **Free string literals in body**: String literals in the UDF body receive the default collation
    - **Default-string-producing built-in functions in body**: String-producing built-in functions (e.g., `current_database()`) in the UDF body use the default collation for their string outputs
    
    Note: Explicit collations always take precedence. For example, `p1 STRING COLLATE UTF8_BINARY` preserves `UTF8_BINARY` regardless of the default collation.
    
    ### Why are the changes needed?
    
    Currently, SQL UDFs in Spark don't support collation specifications. This PR enables:
    - UDFs to specify a `DEFAULT COLLATION` clause in `CREATE FUNCTION` statements
    - UDFs to automatically inherit the schema's default collation when none is explicitly specified
    - Proper handling of explicit collations (e.g., `STRING COLLATE UTF8_BINARY`) without them being overridden
    - Collation support for table function return columns
    
    ### Does this PR introduce any user-facing change?
    
    Yes. Users can now:
    - Use `DEFAULT COLLATION <collation_name>` in `CREATE FUNCTION` statements
    - Have UDFs automatically inherit the schema's default collation
    
    Example:
    ```sql
    -- UDF with explicit default collation
    CREATE FUNCTION my_func(p1 STRING)
    RETURNS STRING
    DEFAULT COLLATION UTF8_LCASE
    RETURN SELECT upper(p1);
    
    -- The return type gets UTF8_LCASE, as would any string literals in the body
    -- p1 parameter gets UTF8_LCASE (no explicit collation specified)
    ```
    
    ```sql
    -- Explicit collation overrides default
    CREATE FUNCTION my_func2(p1 STRING COLLATE UTF8_BINARY)
    RETURNS STRING COLLATE de
    DEFAULT COLLATION UTF8_LCASE
    RETURN SELECT p1 || 'suffix';
    
    -- p1 keeps UTF8_BINARY (explicit collation specified)
    -- return type is 'de' (explicit collation specified)
    -- 'suffix' literal gets UTF8_LCASE (default applies)
    ```
    
    ### How was this patch tested?
    
    New tests in `DefaultCollationTestSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, co-authored with Claude Sonnet 4.5
    
    Closes #54324 from ilicmarkodb/udf-default-collation.
    
    Authored-by: ilicmarkodb <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |   1 +
 .../org/apache/spark/sql/types/StringType.scala    |  37 ++
 .../sql/connector/catalog/FunctionCatalog.java     |   5 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  20 +-
 .../catalyst/analysis/ApplyDefaultCollation.scala  |  13 +-
 .../catalyst/analysis/CollationTypeCoercion.scala  |   3 +
 .../sql/catalyst/analysis/ResolveCatalogs.scala    |   2 +-
 .../catalyst/analysis/SQLFunctionExpression.scala  |   7 +
 .../spark/sql/catalyst/catalog/SQLFunction.scala   |  27 +-
 .../sql/catalyst/catalog/UserDefinedFunction.scala |  29 +-
 .../sql/catalyst/plans/logical/v2Commands.scala    |   1 +
 .../sql/catalyst/catalog/SessionCatalogSuite.scala |   2 +
 .../catalyst/analysis/ResolveSessionCatalog.scala  |   5 +-
 .../spark/sql/execution/SparkSqlParser.scala       |  10 +-
 .../command/CreateSQLFunctionCommand.scala         |  10 +-
 .../command/CreateUserDefinedFunctionCommand.scala |   2 +
 .../sql/collation/DefaultCollationTestSuite.scala  | 521 ++++++++++++++++++++-
 .../command/CreateSQLFunctionParserSuite.scala     |   2 +
 18 files changed, 664 insertions(+), 33 deletions(-)

diff --git 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index d61dd137ec5a..24a6fb7e6d98 100644
--- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -1571,6 +1571,7 @@ routineCharacteristics
     | sqlDataAccess
     | nullCall
     | commentSpec
+    | collationSpec
     | rightsClause)*
     ;
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index 9f52f647a57a..34467c258d6c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -190,3 +190,40 @@ case object NoConstraint extends StringConstraint
 case class FixedLength(length: Int) extends StringConstraint
 
 case class MaxLength(length: Int) extends StringConstraint
+
+/**
+ * A `UTF8_BINARY` [[StringType]] whose [[typeName]] always spells out its collation. Used in
+ * the context of SQL UDFs when persisting and re-parsing parameter and return types.
+ *
+ * For example, if a UDF parameter is defined as `p1 STRING COLLATE UTF8_BINARY`, calling
+ * [[typeName]] on a regular `UTF8_BINARY` [[StringType]] returns just `STRING`, omitting the
+ * collation information. When that text is parsed back, the parameter becomes the companion
+ * object [[StringType]]. If the UDF has a default collation, that default is then applied to
+ * the companion object, silently replacing the explicitly requested `UTF8_BINARY` collation
+ * with a different one. Overriding [[typeName]] to emit `string collate <name>` preserves
+ * the explicit collation across that round trip.
+ */
+object ExplicitUTF8BinaryStringType
+    extends StringType(CollationFactory.UTF8_BINARY_COLLATION_ID, 
NoConstraint) {
+  override def typeName: String = s"string collate $collationName"
+  override def toString: String = s"StringType($collationName)"
+
+  /**
+   * Transforms the given `dataType` (via `transformRecursively`) by replacing each
+   * [[StringType]] that has an explicit `UTF8_BINARY` collation -- i.e. any `UTF8_BINARY`
+   * instance other than the companion object [[StringType]] -- with
+   * `ExplicitUTF8BinaryStringType`.
+   *
+   * NOTE(review): the replacement is built with `NoConstraint`, so any [[StringConstraint]]
+   * carried by a matched instance is dropped -- confirm that length-constrained string
+   * types can never reach this method.
+   */
+  def transform(dataType: DataType): DataType = {
+    dataType.transformRecursively {
+      case st: StringType if st.isUTF8BinaryCollation && !st.eq(StringType) =>
+        ExplicitUTF8BinaryStringType
+    }
+  }
+
+  /**
+   * Transforms the given `dataType` (via `transformRecursively`) by replacing each companion
+   * object [[StringType]] with `StringType(UTF8_BINARY_COLLATION_ID)`. Presumably that call
+   * returns a fresh instance that is not `eq` to the companion object (TODO confirm), so the
+   * result is treated as an explicit `UTF8_BINARY` collation rather than as uncollated.
+   */
+  def transformDefaultStringType(dataType: DataType): DataType = {
+    dataType.transformRecursively {
+      case st: StringType if st.eq(StringType) =>
+        StringType(CollationFactory.UTF8_BINARY_COLLATION_ID)
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
index de4559011942..09878509da9d 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
@@ -30,6 +30,11 @@ import 
org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 @Evolving
 public interface FunctionCatalog extends CatalogPlugin {
 
+  /**
+   * A reserved property to specify the collation of the function.
+   */
+  String PROP_COLLATION = "collation";
+
   /**
    * List the functions in a namespace from the catalog.
    * <p>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dd86c6c52cb9..12fc0f0a09fa 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -216,6 +216,12 @@ object AnalysisContext {
     try f finally { set(originContext) }
   }
 
+  def withAnalysisContext[A](function: SQLFunction)(f: => A): A = {
+    val originContext = value.get() // saved so it can be restored below
+    val context = originContext.copy(collation = function.collation) // other fields kept
+    set(context) // NB: a None function.collation also clears any caller collation
+    try f finally { set(originContext) } // restored even if `f` throws
+  }
 
   def withNewAnalysisContext[A](f: => A): A = {
     val originContext = value.get()
@@ -2340,8 +2346,10 @@ class Analyzer(
         e: SubqueryExpression,
         outer: LogicalPlan)(
         f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): 
SubqueryExpression = {
-      val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) {
-        executeSameContext(e.plan)
+      val newSubqueryPlan = SQLFunctionContext.withNewContext {
+        AnalysisContext.withOuterPlan(outer) {
+          executeSameContext(e.plan)
+        }
       }
 
       // If the subquery plan is fully resolved, pull the outer references and 
record
@@ -2486,7 +2494,9 @@ class Analyzer(
           Analyzer.retainResolutionConfigsForAnalysis(newConf = newConf, 
existingConf = conf)
         }
         SQLConf.withExistingConf(newConf) {
-          executeSameContext(plan)
+          AnalysisContext.withAnalysisContext(f.function) {
+            executeSameContext(plan)
+          }
         }
       }
       // Fail the analysis eagerly if a SQL function cannot be resolved using 
its input.
@@ -2785,7 +2795,9 @@ class Analyzer(
         val resolved = SQLConf.withExistingConf(newConf) {
           val plan = v1SessionCatalog.makeSQLTableFunctionPlan(name, function, 
inputs, output)
           SQLFunctionContext.withSQLFunction {
-            executeSameContext(plan)
+            AnalysisContext.withAnalysisContext(function) {
+              executeSameContext(plan)
+            }
           }
         }
         // Remove unnecessary lateral joins that are used to resolve the SQL 
function.
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollation.scala
index 3141e71ecadb..67d5b70b30a3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollation.scala
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions.{Cast, 
DefaultStringProducingExpression, Expression, Literal, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns, 
AlterColumnSpec, AlterViewAs, ColumnDefinition, CreateTable, 
CreateTableAsSelect, CreateTempView, CreateView, LogicalPlan, QualifiedColType, 
ReplaceColumns, ReplaceTable, ReplaceTableAsSelect, TableSpec, 
V2CreateTablePlan}
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns, 
AlterColumnSpec, AlterViewAs, ColumnDefinition, CreateTable, 
CreateTableAsSelect, CreateTempView, CreateUserDefinedFunction, CreateView, 
LogicalPlan, QualifiedColType, ReplaceColumns, ReplaceTable, 
ReplaceTableAsSelect, TableSpec, V2CreateTablePlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.{areSameBaseType, 
isDefaultStringCharOrVarcharType, replaceDefaultStringCharAndVarcharTypes}
@@ -220,6 +220,17 @@ object ApplyDefaultCollation extends Rule[LogicalPlan] {
           newAlterViewAs.copyTagsFrom(alterViewAs)
           newAlterViewAs
 
+        case createUserDefinedFunction@CreateUserDefinedFunction(
+        ResolvedIdentifier(catalog: SupportsNamespaces, identifier),
+        _, _, _, _, _, collation, _, _, _, _, _, _) if collation.isEmpty =>
+          val newCreateUserDefinedFunction =
+            CurrentOrigin.withOrigin(createUserDefinedFunction.origin) {
+              createUserDefinedFunction.copy(
+                collation = getCollationFromSchemaMetadata(catalog, 
identifier.namespace()))
+            }
+          newCreateUserDefinedFunction.copyTagsFrom(createUserDefinedFunction)
+          newCreateUserDefinedFunction
+
         case other =>
           other
       }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index 8d5b8c590fa4..75619c9c5ce3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -355,6 +355,9 @@ object CollationTypeCoercion extends SQLConfHelper {
     case expr @ (_: NamedExpression | _: SubqueryExpression | _: 
VariableReference) =>
       Some(addContextToStringType(expr.dataType, Implicit))
 
+    case f: SQLFunctionExpression =>
+      Some(addContextToStringType(f.dataType, Implicit))
+
     case lit: Literal =>
       Some(addContextToStringType(lit.dataType, Default))
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
index d0a05e0495dc..5fa8ffefc012 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
@@ -91,7 +91,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
         "CREATE", nameParts.last)
 
     case CreateUserDefinedFunction(UnresolvedIdentifier(nameParts, _),
-        _, _, _, _, _, _, _, _, _, _, _)
+        _, _, _, _, _, _, _, _, _, _, _, _)
         if isSystemBuiltinName(nameParts) =>
       throw QueryCompilationErrors.operationNotAllowedOnBuiltinFunctionError(
         "CREATE", nameParts.last)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
index 37981f47287d..e7bdc8ec0248 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
@@ -87,4 +87,11 @@ object SQLFunctionContext {
     set(context)
     try f finally { set(originContext) }
   }
+
+  def withNewContext[A](f: => A): A = { // e.g. the Analyzer resolves subquery plans in this
+    val originContext = value.get() // saved so it can be restored below
+    val context = SQLFunctionContext() // fresh default context; nothing inherited
+    set(context)
+    try f finally { set(originContext) } // restored even if `f` throws
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
index 84d87fab8b06..5724ce29742d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
@@ -29,7 +29,7 @@ import 
org.apache.spark.sql.catalyst.catalog.UserDefinedFunction._
 import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, 
ScalarSubquery}
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
OneRowRelation, Project}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, ExplicitUTF8BinaryStringType, 
StructType}
 
 /**
  * Represent a SQL function.
@@ -40,6 +40,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
  * @param exprText function body as an expression
  * @param queryText function body as a query
  * @param comment function comment
+ * @param collation function default collation
  * @param deterministic whether the function is deterministic
  * @param containsSQL whether the function has data access routine to be 
CONTAINS SQL
  * @param isTableFunc whether the function is a table function
@@ -54,6 +55,7 @@ case class SQLFunction(
     exprText: Option[String],
     queryText: Option[String],
     comment: Option[String],
+    collation: Option[String],
     deterministic: Option[Boolean],
     containsSQL: Option[Boolean],
     isTableFunc: Boolean,
@@ -152,16 +154,19 @@ case class SQLFunction(
    */
   private def sqlFunctionToProps: Map[String, String] = {
     val props = new mutable.HashMap[String, String]
-    val inputParamText = inputParam.map(_.fields.map(_.toDDL).mkString(", "))
+    val inputParamText = 
inputParam.map(ExplicitUTF8BinaryStringType.transform(_)
+      .asInstanceOf[StructType].fields.map(_.toDDL).mkString(", "))
     inputParamText.foreach(props.put(INPUT_PARAM, _))
     val returnTypeText = returnType match {
-      case Left(dataType) => dataType.sql
-      case Right(columns) => columns.toDDL
+      case Left(dataType) => 
ExplicitUTF8BinaryStringType.transform(dataType).sql
+      case Right(columns) =>
+        
ExplicitUTF8BinaryStringType.transform(columns).asInstanceOf[StructType].toDDL
     }
     props.put(RETURN_TYPE, returnTypeText)
     exprText.foreach(props.put(EXPRESSION, _))
     queryText.foreach(props.put(QUERY, _))
     comment.foreach(props.put(COMMENT, _))
+    collation.foreach(props.put(COLLATION, _))
     deterministic.foreach(d => props.put(DETERMINISTIC, d.toString))
     containsSQL.foreach(x => props.put(CONTAINS_SQL, x.toString))
     props.put(IS_TABLE_FUNC, isTableFunc.toString)
@@ -185,6 +190,7 @@ object SQLFunction {
   private val EXPRESSION: String = SQL_FUNCTION_PREFIX + "expression"
   private val QUERY: String = SQL_FUNCTION_PREFIX + "query"
   private val COMMENT: String = SQL_FUNCTION_PREFIX + "comment"
+  private val COLLATION: String = SQL_FUNCTION_PREFIX + "collation"
   private val DETERMINISTIC: String = SQL_FUNCTION_PREFIX + "deterministic"
   private val CONTAINS_SQL: String = SQL_FUNCTION_PREFIX + "containsSQL"
   private val IS_TABLE_FUNC: String = SQL_FUNCTION_PREFIX + "isTableFunc"
@@ -211,14 +217,16 @@ object SQLFunction {
       val blob = parts.sortBy(_._1).map(_._2).mkString
       val props = mapper.readValue(blob, classOf[Map[String, String]])
       val isTableFunc = props(IS_TABLE_FUNC).toBoolean
-      val returnType = parseReturnTypeText(props(RETURN_TYPE), isTableFunc, 
parser)
+      val collation = props.get(COLLATION)
+      val returnType = parseReturnTypeText(props(RETURN_TYPE), isTableFunc, 
parser, collation)
       SQLFunction(
         name = function.identifier,
-        inputParam = props.get(INPUT_PARAM).map(parseRoutineParam(_, parser)),
+        inputParam = props.get(INPUT_PARAM).map(parseRoutineParam(_, parser, 
collation)),
         returnType = returnType.get,
         exprText = props.get(EXPRESSION),
         queryText = props.get(QUERY),
         comment = props.get(COMMENT),
+        collation = collation,
         deterministic = props.get(DETERMINISTIC).map(_.toBoolean),
         containsSQL = props.get(CONTAINS_SQL).map(_.toBoolean),
         isTableFunc = isTableFunc,
@@ -249,7 +257,8 @@ object SQLFunction {
   def parseReturnTypeText(
       text: String,
       isTableFunc: Boolean,
-      parser: ParserInterface): Option[Either[DataType, StructType]] = {
+      parser: ParserInterface,
+      collation: Option[String]): Option[Either[DataType, StructType]] = {
     if (!isTableFunc) {
       // This is a scalar user-defined function.
       if (text.isEmpty) {
@@ -257,7 +266,7 @@ object SQLFunction {
         Option.empty[Either[DataType, StructType]]
       } else {
         // The CREATE FUNCTION statement included a RETURNS clause with an 
explicit return type.
-        Some(Left(parseDataType(text, parser)))
+        Some(Left(parseDataType(text, parser, collation)))
       }
     } else {
       // This is a table function.
@@ -266,7 +275,7 @@ object SQLFunction {
         Option.empty[Either[DataType, StructType]]
       } else {
         // The CREATE FUNCTION statement included a RETURNS TABLE clause with 
an explicit schema.
-        Some(Right(parseTableSchema(text, parser)))
+        Some(Right(parseTableSchema(text, parser, collation)))
       }
     }
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
index 3365b11b0742..4887830c4279 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
@@ -25,6 +25,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
 import org.apache.spark.sql.catalyst.parser.ParserInterface
+import 
org.apache.spark.sql.catalyst.types.DataTypeUtils.replaceDefaultStringCharAndVarcharTypes
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.types.{DataType, StructType}
 
@@ -86,21 +87,37 @@ object UserDefinedFunction {
   // The default Hive Metastore SQL schema length for function resource uri.
   private val HIVE_FUNCTION_RESOURCE_URI_LENGTH_THRESHOLD: Int = 4000
 
-  def parseRoutineParam(text: String, parser: ParserInterface): StructType = {
-    val parsed = parser.parseRoutineParam(text)
+  def parseRoutineParam(text: String, parser: ParserInterface, collation: 
Option[String])
+      : StructType = {
+    val parsed = StructType(parser.parseRoutineParam(text)
+      .map(field => field.copy(dataType = resolveReturnType(field.dataType, 
collation))))
     CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
   }
 
-  def parseTableSchema(text: String, parser: ParserInterface): StructType = {
-    val parsed = parser.parseTableSchema(text)
+  def parseTableSchema(text: String, parser: ParserInterface, collation: 
Option[String])
+      : StructType = {
+    val parsed = StructType(parser.parseTableSchema(text)
+      .map(field => field.copy(dataType = resolveReturnType(field.dataType, 
collation))))
     CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
   }
 
-  def parseDataType(text: String, parser: ParserInterface): DataType = {
-    val dataType = parser.parseDataType(text)
+  def parseDataType(text: String, parser: ParserInterface, collation: 
Option[String]): DataType = {
+    val dataType = resolveReturnType(parser.parseDataType(text), collation)
     CharVarcharUtils.failIfHasCharVarchar(dataType)
   }
 
+  /**
+   * Applies the default collation to the non-collated string, char and varchar types inside
+   * the given type (via `replaceDefaultStringCharAndVarcharTypes`).
+   *
+   * Despite the name, this is not limited to return types: `parseRoutineParam` and
+   * `parseTableSchema` also apply it to every parameter and table-column type.
+   *
+   * @param returnType The type to resolve: a type from the RETURNS clause, a type inferred
+   *                   from the function body when that clause is not specified, or a
+   *                   parameter / table-column type.
+   * @param collation The default collation, if specified; otherwise, None.
+   * @return `returnType` unchanged when `collation` is None; otherwise `returnType` with the
+   *         default collation applied.
+   */
+  def resolveReturnType(returnType: DataType, collation: Option[String]): 
DataType = {
+    collation.map(replaceDefaultStringCharAndVarcharTypes(returnType, 
_)).getOrElse(returnType)
+  }
+
   private val _mapper: ObjectMapper = getObjectMapper
 
   /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index e22b55f625be..06a4d85a856c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -1377,6 +1377,7 @@ case class CreateUserDefinedFunction(
     exprText: Option[String],
     queryText: Option[String],
     comment: Option[String],
+    collation: Option[String],
     isDeterministic: Option[Boolean],
     containsSQL: Option[Boolean],
     language: RoutineLanguage,
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 47e0321bdfef..be7b4530e99e 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -2158,6 +2158,7 @@ abstract class SessionCatalogSuite extends AnalysisTest 
with Eventually {
         exprText = None,
         queryText = None,
         comment = None,
+        collation = None,
         deterministic = Some(true),
         containsSQL = Some(false),
         isTableFunc = false,
@@ -2181,6 +2182,7 @@ abstract class SessionCatalogSuite extends AnalysisTest 
with Eventually {
         exprText = Some("SELECT 1"),
         queryText = None,
         comment = None,
+        collation = None,
         deterministic = Some(true),
         containsSQL = Some(true),
         isTableFunc = true,  // But marked as table function
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 92d208813cb3..7efd2e111317 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -588,7 +588,7 @@ class ResolveSessionCatalog(val catalogManager: 
CatalogManager)
       throw 
QueryCompilationErrors.missingCatalogCreateFunctionAbilityError(catalog)
 
     case c @ CreateUserDefinedFunction(
-        CreateFunctionInSessionCatalog(ident), _, _, _, _, _, _, _, _, _, _, 
_) =>
+        CreateFunctionInSessionCatalog(ident), _, _, _, _, _, _, _, _, _, _, 
_, _) =>
       CreateUserDefinedFunctionCommand(
         FunctionIdentifier(ident.table, ident.database, ident.catalog),
         c.inputParamText,
@@ -596,6 +596,7 @@ class ResolveSessionCatalog(val catalogManager: 
CatalogManager)
         c.exprText,
         c.queryText,
         c.comment,
+        c.collation,
         c.isDeterministic,
         c.containsSQL,
         c.language,
@@ -605,7 +606,7 @@ class ResolveSessionCatalog(val catalogManager: 
CatalogManager)
         c.replace)
 
     case CreateUserDefinedFunction(
-        ResolvedIdentifier(catalog, _), _, _, _, _, _, _, _, _, _, _, _) =>
+        ResolvedIdentifier(catalog, _), _, _, _, _, _, _, _, _, _, _, _, _) =>
       throw 
QueryCompilationErrors.missingCatalogCreateFunctionAbilityError(catalog)
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 834dc5035196..4c6df5dbe6cf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -919,7 +919,7 @@ class SparkSqlAstBuilder extends AstBuilder {
       val exprText = Option(ctx.expression()).map(source)
       val queryText = Option(ctx.query()).map(source)
 
-      val (containsSQL, deterministic, comment, optionalLanguage) =
+      val (containsSQL, deterministic, comment, collation, optionalLanguage) =
         visitRoutineCharacteristics(ctx.routineCharacteristics())
       val language: RoutineLanguage = optionalLanguage.getOrElse(LanguageSQL)
       val isTableFunc = ctx.TABLE() != null || 
returnTypeText.equalsIgnoreCase("table")
@@ -933,6 +933,7 @@ class SparkSqlAstBuilder extends AstBuilder {
             exprText,
             queryText,
             comment,
+            collation,
             deterministic,
             containsSQL,
             language,
@@ -954,6 +955,7 @@ class SparkSqlAstBuilder extends AstBuilder {
             exprText,
             queryText,
             comment,
+            collation,
             deterministic,
             containsSQL,
             language,
@@ -979,7 +981,7 @@ class SparkSqlAstBuilder extends AstBuilder {
    * rights: [SQL SECURITY INVOKER | SQL SECURITY DEFINER]
    */
   override def visitRoutineCharacteristics(ctx: RoutineCharacteristicsContext)
-  : (Option[Boolean], Option[Boolean], Option[String], 
Option[RoutineLanguage]) =
+  : (Option[Boolean], Option[Boolean], Option[String], Option[String], 
Option[RoutineLanguage]) =
     withOrigin(ctx) {
       checkDuplicateClauses(ctx.routineLanguage(), "LANGUAGE", ctx)
       checkDuplicateClauses(ctx.specificName(), "SPECIFIC", ctx)
@@ -987,6 +989,7 @@ class SparkSqlAstBuilder extends AstBuilder {
       checkDuplicateClauses(ctx.nullCall(), "NULL CALL", ctx)
       checkDuplicateClauses(ctx.deterministic(), "DETERMINISTIC", ctx)
       checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx)
+      checkDuplicateClauses(ctx.collationSpec(), "DEFAULT COLLATION", ctx)
       checkDuplicateClauses(ctx.rightsClause(), "SQL SECURITY RIGHTS", ctx)
 
       val language: Option[RoutineLanguage] = ctx
@@ -1004,13 +1007,14 @@ class SparkSqlAstBuilder extends AstBuilder {
 
       val deterministic = 
ctx.deterministic().asScala.headOption.map(visitDeterminism)
       val comment = visitCommentSpecList(ctx.commentSpec())
+      val collation = 
ctx.collationSpec().asScala.headOption.map(visitCollationSpec)
 
       ctx.specificName().asScala.headOption.foreach(checkSpecificName)
       ctx.nullCall().asScala.headOption.foreach(checkNullCall)
       ctx.rightsClause().asScala.headOption.foreach(checkRightsClause)
       val containsSQL: Option[Boolean] =
         ctx.sqlDataAccess().asScala.headOption.map(visitDataAccess)
-      (containsSQL, deterministic, comment, language)
+      (containsSQL, deterministic, comment, collation, language)
     }
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
index eb860089b0c8..730c3030428b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
@@ -56,6 +56,7 @@ case class CreateSQLFunctionCommand(
     exprText: Option[String],
     queryText: Option[String],
     comment: Option[String],
+    collation: Option[String],
     isDeterministic: Option[Boolean],
     containsSQL: Option[Boolean],
     isTableFunc: Boolean,
@@ -72,8 +73,8 @@ case class CreateSQLFunctionCommand(
     val catalog = sparkSession.sessionState.catalog
     val conf = sparkSession.sessionState.conf
 
-    val inputParam = 
inputParamText.map(UserDefinedFunction.parseRoutineParam(_, parser))
-    val returnType = parseReturnTypeText(returnTypeText, isTableFunc, parser)
+    val inputParam = 
inputParamText.map(UserDefinedFunction.parseRoutineParam(_, parser, collation))
+    val returnType = parseReturnTypeText(returnTypeText, isTableFunc, parser, 
collation)
 
     val function = SQLFunction(
       name,
@@ -82,6 +83,7 @@ case class CreateSQLFunctionCommand(
       exprText,
       queryText,
       comment,
+      collation,
       isDeterministic,
       containsSQL,
       isTableFunc,
@@ -159,7 +161,7 @@ case class CreateSQLFunctionCommand(
         val analyzed = analyzer.execute(plan)
         val (resolved, resolvedReturnType) = analyzed match {
           case p @ Project(expr :: Nil, _) if expr.resolved =>
-            (p, Left(expr.dataType))
+            (p, Left(resolveReturnType(expr.dataType, collation)))
           case other =>
             (other, function.returnType)
         }
@@ -211,7 +213,7 @@ case class CreateSQLFunctionCommand(
               throw 
UserDefinedFunctionErrors.missingColumnNamesForSqlTableUdf(name.funcName)
             case _ =>
               
StructType(analyzed.asInstanceOf[LateralJoin].right.plan.output.map { col =>
-                StructField(col.name, col.dataType)
+                StructField(col.name, resolveReturnType(col.dataType, 
collation))
               })
           }
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
index a3780a8bff19..f65c7c91251a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
@@ -46,6 +46,7 @@ object CreateUserDefinedFunctionCommand {
       exprText: Option[String],
       queryText: Option[String],
       comment: Option[String],
+      collation: Option[String],
       isDeterministic: Option[Boolean],
       containsSQL: Option[Boolean],
       language: RoutineLanguage,
@@ -67,6 +68,7 @@ object CreateUserDefinedFunctionCommand {
           exprText,
           queryText,
           comment,
+          collation,
           isDeterministic,
           containsSQL,
           isTableFunc,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
index 88be0e79e4e6..bfa4dd982087 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.DatasourceV2SQLBase
 import 
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{BooleanType, StringType, StructType}
 
 abstract class DefaultCollationTestSuite extends QueryTest with 
SharedSparkSession {
 
@@ -759,16 +759,40 @@ abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSessi
   }
 }
 
+
 abstract class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite {
 
+  // This list is used for tests that don't depend on explicitly specifying the data type
+  // (these tests still test the string type), or that are not applicable to char/varchar
+  // types. E.g., UDFs don't support char/varchar as input parameter or return types.
   protected def stringTestNamesV1: Seq[String] = Seq(
     "Check AttributeReference dataType from View with default collation",
     "CTAS with DEFAULT COLLATION and VIEW",
     "default string producing expressions in view definition",
+    // SQL UDF/UDTF tests: UDFs don't accept char/varchar parameters or return types.
+    "Test UDTF with default collation",
+    "Test UDF with default collation",
+    "Test UDTF with default collation and without columns in RETURNS TABLE",
+    "Test UDF with default collation and collation applied to return type",
+    "Test explicit UTF8_BINARY collation for UDF params/return type",
+    "ALTER SCHEMA DEFAULT COLLATION doesn't affect UDF/UDTF collation",
+    "Test applying collation to UDF params",
+    "Test UDF collation behavior with default and mixed collation settings",
+    "Test replacing UDF with default collation",
+    "Nested UDFs with default collation",
     "View with UTF8_LCASE default collation from schema level"
-  )
+  ) ++ schemaAndObjectCollationPairs.flatMap {
+    case (schemaDefaultCollation, udfDefaultCollation) => Seq(
+      // One name per (schema, UDF) default-collation pair, for the plain CREATE variant and
+      // the "OR REPLACE" variant. These strings have to match, character for character, the
+      // names passed to testString by the UDF schema-level collation loop in this class;
+      // presumably a mismatch means those tests are no longer treated as string-only.
+      // NOTE(review): "CREATE OR" and "view default collation" look like leftovers for
+      // "CREATE OR REPLACE" / "UDF default collation"; rename both places together.
+      s"""CREATE UDF/UDTF with schema level collation
+         | (schema default collation = $schemaDefaultCollation,
+         | view default collation = $udfDefaultCollation)""".stripMargin,
+      s"""CREATE OR UDF/UDTF with schema level collation
+         | (schema default collation = $schemaDefaultCollation,
+         | view default collation = $udfDefaultCollation)""".stripMargin
+    )
+  }
 
-    testString("Check AttributeReference dataType from View with default 
collation") {
+
+  testString("Check AttributeReference dataType from View with default 
collation") {
       _ =>
     withView(testView) {
       sql(s"CREATE VIEW $testView DEFAULT COLLATION UTF8_LCASE AS SELECT 'a' 
AS c1")
@@ -1070,6 +1094,404 @@ abstract class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite {
       }
     }
   }
+
+  /** Runs `f` directly; table-less counterpart of `createTable` for tests needing no data. */
+  def emptyCreateTable()(f: => Unit): Unit = {
+    f
+  }
+
+  /**
+   * Creates `testTable1` with columns c1 (`dataType` COLLATE UNICODE), c2 (`dataType` COLLATE
+   * SR_AI) and c3 (INT), inserts the single row ('a', 'a', 1) and runs `f` inside `withTable`,
+   * which drops the table afterwards.
+   */
+  def createTable(dataType: String)(f: => Unit): Unit = {
+    withTable(testTable1) {
+      sql(
+        s"""CREATE TABLE $testTable1
+           | (c1 $dataType COLLATE UNICODE, c2 $dataType COLLATE SR_AI, c3 INT)
+           |""".stripMargin)
+      // Only ASCII literals are inserted, so no scalastyle suppression is needed here.
+      sql(s"INSERT INTO $testTable1 VALUES ('a', 'a', 1)")
+      f
+    }
+  }
+
+  /**
+   * Invokes `createAndCheckUDF` once for every combination of `CREATE [OR REPLACE]` and
+   * `[TEMPORARY]`. The arguments are (replace clause, temporary clause, isTemporary, function
+   * name, prefix), where the prefix is the `<catalog>.<schema>` qualifier that the tests expect
+   * in front of collation names returned by `COLLATION(...)`.
+   */
+  def testUDF()(
+      createAndCheckUDF: (String, String, Boolean, String, String) => Unit): Unit = {
+    val collationPrefix = Seq(CollationFactory.CATALOG, CollationFactory.SCHEMA).mkString(".")
+    for {
+      replaceClause <- Seq("", "OR REPLACE")
+      temporaryClause <- Seq("", "TEMPORARY")
+    } {
+      createAndCheckUDF(
+        replaceClause, temporaryClause, temporaryClause.nonEmpty, "f", collationPrefix)
+    }
+  }
+
+  testString("Test UDTF with default collation") { dataType =>
+    testUDF() { (replace, temporary, isTemporary, functionName, prefix) =>
+      createTable(dataType) {
+        withUserDefinedFunction((functionName, isTemporary)) {
+          // Table function: c1 keeps its explicit UTF8_LCASE collation, while c2 (declared
+          // without a collation) and the literal column c4 are expected in UNICODE_CI, the
+          // function's DEFAULT COLLATION.
+          val createStatement =
+            s"""CREATE $replace $temporary FUNCTION $functionName()
+               | RETURNS TABLE
+               | (c1 $dataType COLLATE UTF8_LCASE, c2 $dataType, c3 INT, c4 $dataType)
+               | DEFAULT COLLATION UNICODE_CI
+               | RETURN
+               |  SELECT *, 'w' AS c4
+               |  FROM $testTable1
+               |  WHERE 'a' = 'A'
+               |""".stripMargin
+          sql(createStatement)
+
+          def query(projection: String) = sql(s"SELECT $projection FROM $functionName()")
+
+          checkAnswer(query("COUNT(*)"), Row(1))
+          Seq("c1" -> "UTF8_LCASE", "c2" -> "UNICODE_CI", "c4" -> "UNICODE_CI").foreach {
+            case (column, collationName) =>
+              checkAnswer(query(s"COLLATION($column)"), Row(s"$prefix.$collationName"))
+          }
+          Seq("c1 = 'A'", "c2 = 'A'", "c4 = 'W'").foreach { predicate =>
+            checkAnswer(query(predicate), Row(true))
+          }
+        }
+      }
+    }
+  }
+
+  testString("Test UDTF with default collation and without columns in RETURNS TABLE") { _ =>
+    testUDF() { (replace, temporary, isTemporary, functionName, prefix) =>
+      withUserDefinedFunction((functionName, isTemporary)) {
+        // No column list in RETURNS TABLE: the output columns come from the body's SELECT list.
+        // The bare literal c1 is expected in the UTF8_LCASE default collation, while the
+        // explicitly collated literals c2/c3 keep their own collations.
+        sql(
+          s"""CREATE $replace $temporary FUNCTION $functionName()
+             | RETURNS TABLE
+             | DEFAULT COLLATION UTF8_LCASE
+             | RETURN
+             |  SELECT 'a' AS c1, 'b' COLLATE UTF8_BINARY AS c2, 'c' COLLATE UNICODE AS c3
+             |  WHERE 'a' = 'A'
+             |""".stripMargin)
+
+        val fromUdtf = s"FROM $functionName()"
+        checkAnswer(sql(s"SELECT * $fromUdtf"), Row("a", "b", "c"))
+        val expectedCollations =
+          Seq("c1" -> "UTF8_LCASE", "c2" -> "UTF8_BINARY", "c3" -> "UNICODE")
+        expectedCollations.foreach { case (column, collationName) =>
+          checkAnswer(sql(s"SELECT COLLATION($column) $fromUdtf"),
+            Row(s"$prefix.$collationName"))
+        }
+      }
+    }
+  }
+
+  testString("Test UDF with default collation") { dataType =>
+    testUDF() { (replace, temporary, isTemporary, functionName, prefix) =>
+      createTable(dataType) {
+        withUserDefinedFunction((functionName, isTemporary)) {
+          // The explicit collation on the return type (UTF8_LCASE) is expected to win over both
+          // the function's DEFAULT COLLATION (UNICODE_CI) and the UNICODE source column c1.
+          val ddl =
+            s"""CREATE $replace $temporary FUNCTION $functionName()
+               | RETURNS $dataType COLLATE UTF8_LCASE
+               | DEFAULT COLLATION UNICODE_CI
+               | RETURN
+               |  SELECT c1
+               |  FROM $testTable1
+               |  WHERE 'a' = 'A'
+               |""".stripMargin
+          sql(ddl)
+
+          val call = s"$functionName()"
+          checkAnswer(sql(s"SELECT COUNT($call)"), Row(1))
+          checkAnswer(sql(s"SELECT COLLATION($call)"), Row(s"$prefix.UTF8_LCASE"))
+          checkAnswer(sql(s"SELECT $call = 'A'"), Row(true))
+        }
+      }
+    }
+  }
+
+  testString("Test UDF with default collation and collation applied to return type") {
+      dataType =>
+    testUDF() { (replace, temporary, isTemporary, functionName, prefix) =>
+      createTable(dataType) {
+        withUserDefinedFunction((functionName, isTemporary)) {
+          // RETURNS without an explicit collation: the return type is expected to take the
+          // function's DEFAULT COLLATION (UNICODE), so the comparison below is case-sensitive.
+          val ddl =
+            s"""CREATE $replace $temporary FUNCTION $functionName()
+               | RETURNS $dataType
+               | DEFAULT COLLATION UNICODE
+               | RETURN
+               |  SELECT c1
+               |  FROM $testTable1
+               |""".stripMargin
+          sql(ddl)
+
+          val call = s"$functionName()"
+          checkAnswer(sql(s"SELECT COUNT($call)"), Row(1))
+          checkAnswer(sql(s"SELECT COLLATION($call)"), Row(s"$prefix.UNICODE"))
+          checkAnswer(sql(s"SELECT $call = 'A'"), Row(false))
+        }
+      }
+    }
+  }
+
+  testString("Test explicit UTF8_BINARY collation for UDF params/return type") { dataType =>
+    // Scalar UDF: p1 and the return type are explicitly UTF8_BINARY, so only p2 and the
+    // literals in the body are expected to pick up the UTF8_LCASE default collation.
+    testUDF() { (replace, temporary, isTemporary, functionName, prefix) =>
+      emptyCreateTable() {
+        withUserDefinedFunction((functionName, isTemporary)) {
+          sql(
+            s"""CREATE $replace $temporary FUNCTION $functionName
+               | (p1 $dataType COLLATE UTF8_BINARY, p2 $dataType)
+               | RETURNS $dataType COLLATE UTF8_BINARY
+               | DEFAULT COLLATION UTF8_LCASE
+               | RETURN
+               |  SELECT CASE WHEN p1 != 'A' AND p2 = 'B' THEN 'C' ELSE 'D' END
+               |""".stripMargin)
+
+          val call = s"$functionName('a', 'b')"
+          checkAnswer(sql(s"SELECT $call = 'C'"), Row(true))
+          checkAnswer(sql(s"SELECT $call = 'c'"), Row(false))
+          checkAnswer(sql(s"SELECT DISTINCT COLLATION($functionName('b', 'c'))"),
+            Row(s"$prefix.UTF8_BINARY"))
+        }
+      }
+    }
+
+    // Table UDF: same idea; c1 is explicitly UTF8_BINARY, c2 inherits UTF8_LCASE.
+    testUDF() { (replace, temporary, isTemporary, functionName, prefix) =>
+      emptyCreateTable() {
+        withUserDefinedFunction((functionName, isTemporary)) {
+          sql(
+            s"""CREATE $replace $temporary FUNCTION $functionName
+               | (p1 $dataType COLLATE UTF8_BINARY, p2 $dataType)
+               | RETURNS TABLE
+               | (c1 $dataType COLLATE UTF8_BINARY, c2 $dataType)
+               | DEFAULT COLLATION UTF8_LCASE
+               | RETURN
+               |  SELECT CASE WHEN p1 != 'A' AND p2 = 'B' THEN 'C' ELSE 'D' END, 'E'
+               |""".stripMargin)
+
+          val fromUdtf = s"FROM $functionName('a', 'b')"
+          checkAnswer(sql(s"SELECT c1 = 'C', c2 = 'E' $fromUdtf"), Row(true, true))
+          checkAnswer(sql(s"SELECT c1 ='c' $fromUdtf"), Row(false))
+          checkAnswer(sql(s"SELECT c2 ='e' $fromUdtf"), Row(true))
+          checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) $fromUdtf"),
+            Row(s"$prefix.UTF8_BINARY"))
+          checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) $fromUdtf"),
+            Row(s"$prefix.UTF8_LCASE"))
+        }
+      }
+    }
+  }
+
+  // UDF with schema level collation tests. Each (schema, UDF) default-collation pair is run
+  // once with CREATE FUNCTION and once with CREATE OR REPLACE FUNCTION. The test names must
+  // stay identical to the ones generated in stringTestNamesV1.
+  schemaAndObjectCollationPairs.foreach {
+    case (schemaDefaultCollation, udfDefaultCollation) =>
+      testString(
+        s"""CREATE UDF/UDTF with schema level collation
+          | (schema default collation = $schemaDefaultCollation,
+          | view default collation = $udfDefaultCollation)""".stripMargin) { dataType =>
+        testCreateUDFWithSchemaLevelCollation(
+          dataType, schemaDefaultCollation, udfDefaultCollation)
+      }
+
+      testString(
+        s"""CREATE OR UDF/UDTF with schema level collation
+          | (schema default collation = $schemaDefaultCollation,
+          | view default collation = $udfDefaultCollation)""".stripMargin) { dataType =>
+        // `replaceUDF = true` is what makes this the OR REPLACE variant; without it this test
+        // would just repeat the plain CREATE FUNCTION scenario above.
+        testCreateUDFWithSchemaLevelCollation(
+          dataType, schemaDefaultCollation, udfDefaultCollation, replaceUDF = true)
+      }
+  }
+
+  testString("ALTER SCHEMA DEFAULT COLLATION doesn't affect UDF/UDTF collation") { dataType =>
+    val functionName = "f"
+    // Same qualifier testUDF uses, derived from CollationFactory instead of being hard-coded.
+    val prefix = s"${CollationFactory.CATALOG}.${CollationFactory.SCHEMA}"
+
+    // Creates `testSchema` with UTF8_LCASE as its default collation and defines the function
+    // with `createFunction`. Then runs `check` twice: once before and once after switching the
+    // schema default to UNICODE. The function and the schema are dropped afterwards.
+    def withAlteredSchema(createFunction: String)(check: => Unit): Unit = {
+      withDatabase(testSchema) {
+        sql(s"CREATE SCHEMA $testSchema DEFAULT COLLATION UTF8_LCASE")
+        sql(s"USE $testSchema")
+        withUserDefinedFunction((functionName, false)) {
+          sql(createFunction)
+          check
+          sql(s"ALTER SCHEMA $testSchema DEFAULT COLLATION UNICODE")
+          check
+        }
+      }
+    }
+
+    // Scalar UDF: the collation captured at creation time (UTF8_LCASE) must survive the ALTER.
+    withAlteredSchema(
+      s"CREATE FUNCTION $functionName() RETURN SELECT 'a' WHERE 'b' = 'B'") {
+      checkAnswer(sql(s"SELECT $functionName()"), Row("a"))
+      checkAnswer(sql(s"SELECT COLLATION($functionName())"), Row(s"$prefix.UTF8_LCASE"))
+    }
+
+    // Table UDF: c1 inherits the schema default at creation time; c2/c3 are explicit.
+    withAlteredSchema(
+      s"""CREATE FUNCTION $functionName()
+         |RETURNS TABLE (c1 $dataType, c2 $dataType COLLATE UTF8_BINARY,
+         |c3 $dataType COLLATE UNICODE)
+         |RETURN
+         |SELECT 'a', 'b', 'c' WHERE 'd' = 'D'
+         |""".stripMargin) {
+      checkAnswer(sql(s"SELECT * FROM $functionName()"), Row("a", "b", "c"))
+      Seq("c1" -> "UTF8_LCASE", "c2" -> "UTF8_BINARY", "c3" -> "UNICODE").foreach {
+        case (column, collationName) =>
+          checkAnswer(sql(s"SELECT COLLATION($column) FROM $functionName()"),
+            Row(s"$prefix.$collationName"))
+      }
+    }
+  }
+
+  testString("Test applying collation to UDF params") { dataType =>
+    testUDF() { (replace, temporary, isTemporary, functionName, prefix) =>
+      emptyCreateTable() {
+        withUserDefinedFunction((functionName, isTemporary)) {
+          // p1 has no explicit collation and is expected to inherit UTF8_LCASE; p2 is explicitly
+          // UNICODE. The same p2 value is returned through three differently collated columns.
+          sql(
+            s"""CREATE $replace $temporary FUNCTION $functionName
+               | (p1 $dataType, p2 $dataType COLLATE UNICODE)
+               | RETURNS TABLE
+               | (c1 BOOLEAN, c2 BOOLEAN, c3 $dataType, c4 $dataType COLLATE UNICODE,
+               | c5 $dataType COLLATE SR_AI)
+               | DEFAULT COLLATION UTF8_LCASE
+               | RETURN
+               |  SELECT p1 = 'A', p2 = 'A', p2, p2, p2
+               |  WHERE p1 = 'A'
+               |""".stripMargin)
+
+          val fromUdtf = s"FROM $functionName('a', 'a')"
+          val expectedSchema = new StructType()
+            .add("c1", BooleanType)
+            .add("c2", BooleanType)
+            .add("c3", StringType)
+            .add("c4", StringType)
+            .add("c5", StringType)
+          val expectedRows = spark.sparkContext.parallelize(Seq(Row(true, false, "a", "a", "a")))
+          checkAnswer(sql(s"SELECT * $fromUdtf"),
+            spark.createDataFrame(expectedRows, expectedSchema))
+
+          // (column, expected collation name, expected result of `column = 'A'`)
+          val expectations = Seq(
+            ("c3", "UTF8_LCASE", true),
+            ("c4", "UNICODE", false),
+            ("c5", "sr_AI", false))
+          expectations.foreach { case (column, collationName, _) =>
+            checkAnswer(sql(s"SELECT COLLATION($column) $fromUdtf"),
+              Row(s"$prefix.$collationName"))
+          }
+          expectations.foreach { case (column, _, equalsUpperA) =>
+            checkAnswer(sql(s"SELECT $column = 'A' $fromUdtf"), Row(equalsUpperA))
+          }
+        }
+      }
+    }
+  }
+
+  testString("Test UDF collation behavior with default and mixed collation settings") {
+      dataType =>
+    testUDF() { (replace, temporary, isTemporary, functionName, prefix) =>
+      emptyCreateTable() {
+        Seq(
+          // (returnsClause, returnTypeCollation, otherCollation, inputChar, equalChar):
+          // the UDF returns `inputChar`, whose result collation must be `returnTypeCollation`.
+          // Under that collation it must equal `equalChar`, and comparing it with a string
+          // explicitly collated as `otherCollation` must fail with an indeterminate collation.
+          ("", "UTF8_LCASE", "SR_AI", "w", "W"),
+          (s"RETURNS $dataType", "UTF8_LCASE", "SR_AI", "w", "W"),
+          // scalastyle:off
+          (s"RETURNS $dataType COLLATE SR_AI", "sr_AI", "UTF8_LCASE", "ć", "č")
+          // scalastyle:on
+        ).foreach {
+          case (returnsClause, returnTypeCollation, otherCollation, inputChar, equalChar) =>
+            withUserDefinedFunction((functionName, isTemporary)) {
+              sql(
+                s"""CREATE $replace $temporary FUNCTION $functionName() $returnsClause
+                  | DEFAULT COLLATION UTF8_LCASE
+                  | RETURN
+                  |  SELECT '$inputChar' AS c1
+                  |  WHERE 'a' = 'A'""".stripMargin)
+
+              checkAnswer(sql(s"SELECT COUNT($functionName())"), Row(1))
+              checkAnswer(sql(s"SELECT DISTINCT COLLATION($functionName())"),
+                Row(s"$prefix.$returnTypeCollation"))
+              checkAnswer(
+                sql(s"SELECT $functionName() =" +
+                  s" (SELECT '$equalChar' COLLATE $returnTypeCollation)"),
+                Row(true))
+
+              val exception = intercept[AnalysisException] {
+                sql(s"SELECT $functionName() = (SELECT 'a' COLLATE $otherCollation)")
+              }
+              assert(exception.getMessage.contains("indeterminate collation"))
+            }
+        }
+      }
+    }
+  }
+
+  testString("Test replacing UDF with default collation") { _ =>
+    val functionName = "f"
+    val prefix = "SYSTEM.BUILTIN"
+
+    withUserDefinedFunction((functionName, false)) {
+      // Define f without a DEFAULT COLLATION clause, then replace it with a definition that
+      // has one: the replacement's UTF8_LCASE collation is the one expected to take effect.
+      val initialDefinition =
+        s"""CREATE FUNCTION $functionName()
+           | RETURN
+           |  SELECT 'a'
+           |""".stripMargin
+      val replacementDefinition =
+        s"""CREATE OR REPLACE FUNCTION $functionName()
+           | DEFAULT COLLATION UTF8_LCASE
+           | RETURN
+           |  SELECT 'a' AS c1
+           |""".stripMargin
+      Seq(initialDefinition, replacementDefinition).foreach(statement => sql(statement))
+
+      checkAnswer(sql(s"SELECT DISTINCT COLLATION($functionName())"),
+        Row(s"$prefix.UTF8_LCASE"))
+      checkAnswer(sql(s"SELECT $functionName() = 'A'"), Row(true))
+    }
+  }
+
+  testString("Nested UDFs with default collation") {
+      dataType =>
+    val function1Name = "f1"
+    val function2Name = "f2"
+    val prefix = s"${CollationFactory.CATALOG}.${CollationFactory.SCHEMA}"
+    withUserDefinedFunction((function1Name, false)) {
+      // f1's parameter has no explicit collation, so it (and f1's result) should be UTF8_LCASE.
+      sql(
+        s"""CREATE FUNCTION $function1Name(s $dataType)
+          | DEFAULT COLLATION UTF8_LCASE
+          | RETURN
+          |  SELECT s
+          |""".stripMargin)
+      checkAnswer(sql(s"SELECT COLLATION($function1Name('a'))"), Row(s"$prefix.UTF8_LCASE"))
+      withUserDefinedFunction((function2Name, false)) {
+        // f2 calls f1 under its own SR_AI (case-sensitive) default. The WHERE clause holds only
+        // if f1's results are compared with f1's UTF8_LCASE collation rather than f2's default.
+        // scalastyle:off
+        sql(
+          s"""CREATE FUNCTION $function2Name()
+            | DEFAULT COLLATION SR_AI
+            | RETURN
+            |  SELECT 'č'
+            |  WHERE $function1Name('a') = $function1Name('A')
+            |""".stripMargin)
+        // scalastyle:on
+        checkAnswer(sql(s"SELECT COUNT($function2Name())"), Row(1))
+        // The literal returned by f2 picks up f2's own default collation.
+        checkAnswer(sql(s"SELECT COLLATION($function2Name())"), Row(s"$prefix.sr_AI"))
+      }
+    }
+  }
 
   // View with schema level collation tests
   schemaAndObjectCollationPairs.foreach {
@@ -1252,6 +1674,99 @@ abstract class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite {
       }
     }
   }
+
+  /**
+   * Runs the schema-level default collation scenario for SQL UDFs and UDTFs.
+   *
+   * The scenario runs twice: once in a schema created directly with `schemaDefaultCollation`,
+   * and once in a schema whose default was switched to it with ALTER SCHEMA. Each run creates
+   * scalar UDFs and UDTFs in `testSchema`. It checks that return types, output columns and
+   * default string producing expressions without an explicit collation resolve to the UDF's
+   * DEFAULT COLLATION when one is given, and to the schema default otherwise. Explicit
+   * collations must be preserved.
+   *
+   * @param dataType string type name under test
+   * @param schemaDefaultCollation default collation of `testSchema`
+   * @param udfDefaultCollation optional collation for the functions' DEFAULT COLLATION clause
+   * @param replaceUDF whether to use CREATE OR REPLACE FUNCTION instead of CREATE FUNCTION
+   */
+  private def testCreateUDFWithSchemaLevelCollation(
+      dataType: String,
+      schemaDefaultCollation: String,
+      udfDefaultCollation: Option[String],
+      replaceUDF: Boolean = false): Unit = {
+    val prefix = "SYSTEM.BUILTIN"
+    val functionName = "f"
+
+    // An explicit UDF default collation wins; otherwise the functions inherit the schema's.
+    val (udfDefaultCollationClause, resolvedDefaultCollation) = udfDefaultCollation match {
+      case Some(collation) => (s"DEFAULT COLLATION $collation", collation)
+      case None => ("", schemaDefaultCollation)
+    }
+    val replace = if (replaceUDF) "OR REPLACE" else ""
+
+    Seq(/* alterSchemaCollation */ false, true).foreach { alterSchemaCollation =>
+      withDatabase(testSchema) {
+        if (alterSchemaCollation) {
+          // Start from a different default and switch to the target one via ALTER SCHEMA.
+          sql(s"CREATE SCHEMA $testSchema DEFAULT COLLATION EN")
+          sql(s"ALTER SCHEMA $testSchema DEFAULT COLLATION $schemaDefaultCollation")
+        } else {
+          sql(s"CREATE SCHEMA $testSchema DEFAULT COLLATION $schemaDefaultCollation")
+        }
+        sql(s"USE $testSchema")
+
+        // Scalar UDF: (returnClause, expected output collation)
+        Seq(
+          ("", resolvedDefaultCollation),
+          (s"RETURNS $dataType", resolvedDefaultCollation),
+          (s"RETURNS $dataType COLLATE FR", "fr")
+        ).foreach { case (returnClause, outputCollation) =>
+          withUserDefinedFunction((functionName, false)) {
+            // scalastyle:off
+            sql(
+              s"""CREATE $replace FUNCTION $functionName
+                 |(p1 $dataType, p2 $dataType COLLATE UTF8_BINARY, p3 $dataType COLLATE SR_AI_CI)
+                 |$returnClause
+                 |$udfDefaultCollationClause
+                 |RETURN SELECT 'a' AS c1 WHERE p2 != 'A' AND p3 = 'Č'
+                 |""".stripMargin)
+
+            checkAnswer(sql(s"SELECT $functionName('x', 'a', 'ć')"), Row("a"))
+            checkAnswer(sql(s"SELECT DISTINCT COLLATION($functionName('x', 'a', 'ć'))"),
+              Row(s"$prefix.$outputCollation"))
+            // scalastyle:on
+          }
+        }
+
+        // UDTF with an explicit column list: c1 has no explicit collation and resolves to the
+        // effective default, while c2/c3 keep their explicit collations.
+        withUserDefinedFunction((functionName, false)) {
+          sql(
+            s"""CREATE $replace FUNCTION $functionName()
+               |RETURNS TABLE
+               |(c1 $dataType, c2 $dataType COLLATE UTF8_BINARY, c3 $dataType COLLATE SR_AI_CI)
+               |$udfDefaultCollationClause
+               |RETURN
+               |SELECT 'a', 'b', 'c'
+               |""".stripMargin)
+
+          checkAnswer(sql(s"SELECT * FROM $functionName()"), Row("a", "b", "c"))
+          checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $functionName()"),
+            Row(s"$prefix.$resolvedDefaultCollation"))
+          checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $functionName()"),
+            Row(s"$prefix.UTF8_BINARY"))
+          checkAnswer(sql(s"SELECT DISTINCT COLLATION(c3) FROM $functionName()"),
+            Row(s"$prefix.sr_CI_AI"))
+        }
+
+        // UDTF returning every default string producing expression: each output column is
+        // declared without a collation and must resolve to the effective default collation.
+        withUserDefinedFunction((functionName, false)) {
+          val (selectItems, columnDefinitions) = defaultStringProducingExpressions
+            .zipWithIndex
+            .map { case (expr, index) =>
+              (s"$expr AS c${index + 1}", s"c${index + 1} $dataType")
+            }
+            .unzip
+          val columns = selectItems.mkString(", ")
+          val returnsClause = columnDefinitions.mkString(", ")
+
+          sql(
+            s"""CREATE $replace FUNCTION $functionName()
+               |RETURNS TABLE
+               |($returnsClause)
+               |$udfDefaultCollationClause
+               |RETURN SELECT $columns
+               |""".stripMargin)
+
+          (1 to defaultStringProducingExpressions.length).foreach { index =>
+            checkAnswer(sql(s"SELECT COLLATION(c$index) FROM $functionName()"),
+              Row(s"$prefix.$resolvedDefaultCollation"))
+          }
+        }
+      }
+    }
+  }
 }
 
 abstract class DefaultCollationTestSuiteV2
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala
index 25d8a74797ce..56316f43f8df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala
@@ -58,6 +58,7 @@ class CreateSQLFunctionParserSuite extends AnalysisTest {
       exprText = exprText,
       queryText = queryText,
       comment = comment,
+      collation = None,
       isDeterministic = isDeterministic,
       containsSQL = containsSQL,
       language = LanguageSQL,
@@ -87,6 +88,7 @@ class CreateSQLFunctionParserSuite extends AnalysisTest {
       exprText = exprText,
       queryText = queryText,
       comment = comment,
+      collation = None,
       isDeterministic = isDeterministic,
       containsSQL = containsSQL,
       isTableFunc = isTableFunc,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to