This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
commit c8db813644f947c84ec64ca8440897e7b5756b8e
Author: Herman van Hovell <[email protected]>
AuthorDate: Thu Aug 1 16:32:12 2024 -0400

    SPARK-49004
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 70 ++++++------------
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 26 ++-----
 .../sql/catalyst/analysis/FunctionRegistry.scala   | 29 +++++++-
 .../sql/catalyst/catalog/SessionCatalog.scala      | 84 ++++++++++++----------
 .../apache/spark/sql/catalyst/identifiers.scala    | 11 +++
 .../sql/connector/catalog/CatalogManager.scala     |  1 +
 .../main/scala/org/apache/spark/sql/Column.scala   | 21 +++---
 .../scala/org/apache/spark/sql/functions.scala     | 13 ++--
 8 files changed, 130 insertions(+), 125 deletions(-)

diff --git a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 7bfacf7cf064..6565c4abf477 100644
--- a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,13 +45,12 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
 import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
-import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
@@ -65,7 +64,8 @@ import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -1614,12 +1614,28 @@ class SparkConnectPlanner(
       fun: proto.Expression.UnresolvedFunction): Expression = {
     if (fun.getIsUserDefinedFunction) {
       UnresolvedFunction(
-        parser.parseFunctionIdentifier(fun.getFunctionName),
+        parser.parseMultipartIdentifier(fun.getFunctionName),
         fun.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = fun.getIsDistinct)
     } else {
+      // In order to retain backwards compatibility we allow functions registered in the
+      // `system`.`internal` namespace to be looked up by their name (instead of their FQN).
+      val builtInName = FunctionIdentifier(fun.getFunctionName)
+      val functionRegistry = session.sessionState.functionRegistry
+      val internalName = builtInName.copy(
+        database = Option(CatalogManager.INTERNAL_NAMESPACE),
+        catalog = Option(CatalogManager.SYSTEM_CATALOG_NAME))
+      // We have to special-case the global built-ins because we can't parse symbolic names
+      // (e.g. `+`, `-`, ...).
+      val names = if (functionRegistry.functionExists(builtInName)) {
+        builtInName.nameParts
+      } else if (functionRegistry.functionExists(internalName)) {
+        internalName.nameParts
+      } else {
+        parser.parseMultipartIdentifier(fun.getFunctionName)
+      }
       UnresolvedFunction(
-        FunctionIdentifier(fun.getFunctionName),
+        names,
         fun.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = fun.getIsDistinct)
     }
@@ -1832,18 +1848,6 @@ class SparkConnectPlanner(
   private def transformUnregisteredFunction(
       fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
     fun.getFunctionName match {
-      case "product" if fun.getArgumentsCount == 1 =>
-        Some(
-          aggregate
-            .Product(transformExpression(fun.getArgumentsList.asScala.head))
-            .toAggregateExpression())
-
-      case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
-        // [col, expectedNumItems: Long, numBits: Long]
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(
-          new BloomFilterAggregate(children(0), children(1), children(2))
-            .toAggregateExpression())
-
       case "timestampdiff" if fun.getArgumentsCount == 3 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
@@ -1864,21 +1868,6 @@ class SparkConnectPlanner(
           throw InvalidPlanInput(s"numBuckets should be a literal integer, but got $other")
         }

-      case "years" if fun.getArgumentsCount == 1 =>
-        Some(Years(transformExpression(fun.getArguments(0))))
-
-      case "months" if fun.getArgumentsCount == 1 =>
-        Some(Months(transformExpression(fun.getArguments(0))))
-
-      case "days" if fun.getArgumentsCount == 1 =>
-        Some(Days(transformExpression(fun.getArguments(0))))
-
-      case "hours" if fun.getArgumentsCount == 1 =>
-        Some(Hours(transformExpression(fun.getArguments(0))))
-
-      case "unwrap_udt" if fun.getArgumentsCount == 1 =>
-        Some(UnwrapUDT(transformExpression(fun.getArguments(0))))
-
       case "from_json" if Seq(2, 3).contains(fun.getArgumentsCount) =>
         // JsonToStructs constructor doesn't accept JSON-formatted schema.
         extractDataTypeFromJSON(fun.getArguments(1)).map { dataType =>
@@ -1928,9 +1917,6 @@ class SparkConnectPlanner(
           Some(CatalystDataToAvro(children.head, jsonFormatSchema))

       // PS(Pandas API on Spark)-specific functions
-      case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
-        Some(DistributedSequenceID())
-
       case "pandas_product" if fun.getArgumentsCount == 2 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val dropna = extractBoolean(children(1), "dropna")
@@ -1941,14 +1927,6 @@ class SparkConnectPlanner(
         val ddof = extractInteger(children(1), "ddof")
         Some(aggregate.PandasStddev(children(0), ddof).toAggregateExpression(false))

-      case "pandas_skew" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(aggregate.PandasSkewness(children(0)).toAggregateExpression(false))
-
-      case "pandas_kurt" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(aggregate.PandasKurtosis(children(0)).toAggregateExpression(false))
-
       case "pandas_var" if fun.getArgumentsCount == 2 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val ddof = extractInteger(children(1), "ddof")
@@ -1968,11 +1946,7 @@ class SparkConnectPlanner(
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val alpha = extractDouble(children(1), "alpha")
         val ignoreNA = extractBoolean(children(2), "ignoreNA")
-        Some(EWM(children(0), alpha, ignoreNA))
-
-      case "null_index" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(NullIndex(children(0)))
+        Some(new EWM(children(0), alpha, ignoreNA))

       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
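The new else-branch resolves an unqualified Connect function name in three steps. A minimal sketch of that order, assuming the client sent the bare name "days" (one of the functions this commit moves into the internal namespace); `functionRegistry` and `parser` refer to the planner fields used in the hunk above, so this is not standalone code:

    // Sketch only; mirrors the fallback above, not part of the commit.
    val builtInName = FunctionIdentifier("days")
    val internalName = builtInName.copy(
      database = Option(CatalogManager.INTERNAL_NAMESPACE),  // "internal"
      catalog = Option(CatalogManager.SYSTEM_CATALOG_NAME))  // "system"
    val names =
      if (functionRegistry.functionExists(builtInName)) {
        builtInName.nameParts                    // a real built-in keeps its bare name
      } else if (functionRegistry.functionExists(internalName)) {
        internalName.nameParts                   // Seq("system", "internal", "days")
      } else {
        parser.parseMultipartIdentifier("days")  // everything else is parsed as-is
      }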
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1b194da5ab0a..ac0b1ea601e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2320,42 +2320,26 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = {
-      if (name.length == 1) {
-        v1SessionCatalog.lookupBuiltinOrTempFunction(name.head)
-      } else {
-        None
-      }
+      v1SessionCatalog.lookupBuiltinOrTempFunction(FunctionIdentifier(name))
     }

     def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = {
-      if (name.length == 1) {
-        v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head)
-      } else {
-        None
-      }
+      v1SessionCatalog.lookupBuiltinOrTempTableFunction(FunctionIdentifier(name))
     }

     private def resolveBuiltinOrTempFunction(
         name: Seq[String],
         arguments: Seq[Expression],
         u: Option[UnresolvedFunction]): Option[Expression] = {
-      if (name.length == 1) {
-        v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments).map { func =>
-          if (u.isDefined) validateFunction(func, arguments.length, u.get) else func
-        }
-      } else {
-        None
+      v1SessionCatalog.resolveBuiltinOrTempFunction(FunctionIdentifier(name), arguments).map {
+        func => if (u.isDefined) validateFunction(func, arguments.length, u.get) else func
       }
     }

     private def resolveBuiltinOrTempTableFunction(
         name: Seq[String],
         arguments: Seq[Expression]): Option[LogicalPlan] = {
-      if (name.length == 1) {
-        v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments)
-      } else {
-        None
-      }
+      v1SessionCatalog.resolveBuiltinOrTempTableFunction(FunctionIdentifier(name), arguments)
     }

     private def resolveV1Function(
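The practical effect of the Analyzer change, sketched: a multi-part name that the old `name.length == 1` guard silently mapped to None is now folded into a `FunctionIdentifier` and handed to the session catalog. The three-part name below is an assumed example:

    // Before: Seq("system", "internal", "unwrap_udt") never reached the catalog.
    // After: the parts become an identifier and are looked up normally.
    val info: Option[ExpressionInfo] =
      v1SessionCatalog.lookupBuiltinOrTempFunction(
        FunctionIdentifier(Seq("system", "internal", "unwrap_udt")))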
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 48123254a8fe..39df8ba4ee46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.variant._
 import org.apache.spark.sql.catalyst.expressions.xml._
 import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.util.ArrayImplicits._
@@ -202,7 +203,7 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging
   // Resolution of the function name is always case insensitive, but the database name
   // depends on the caller
   private def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
-    FunctionIdentifier(name.funcName.toLowerCase(Locale.ROOT), name.database)
+    name.copy(funcName = name.funcName.toLowerCase(Locale.ROOT))
   }

   override def registerFunction(
@@ -883,6 +884,32 @@ object FunctionRegistry {

   val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet

+  /**
+   * Expressions registered in the `system`.`internal` namespace.
+   */
+  registerInternalExpression[Product]("product")
+  registerInternalExpression[BloomFilterAggregate]("bloom_filter_agg")
+  registerInternalExpression[Years]("years")
+  registerInternalExpression[Months]("months")
+  registerInternalExpression[Days]("days")
+  registerInternalExpression[Hours]("hours")
+  registerInternalExpression[UnwrapUDT]("unwrap_udt")
+  registerInternalExpression[DistributedSequenceID]("distributed_sequence_id")
+  registerInternalExpression[PandasSkewness]("pandas_skew")
+  registerInternalExpression[PandasKurtosis]("pandas_kurt")
+  registerInternalExpression[NullIndex]("null_index")
+
+  private def registerInternalExpression[T <: Expression : ClassTag](name: String): Unit = {
+    val (info, builder) = FunctionRegistryBase.build(name, None)
+    builtin.internalRegisterFunction(
+      FunctionIdentifier(
+        name,
+        Option(CatalogManager.INTERNAL_NAMESPACE),
+        Option(CatalogManager.SYSTEM_CATALOG_NAME)),
+      info,
+      builder)
+  }
+
   private def makeExprInfoForVirtualOperator(name: String, usage: String): ExpressionInfo = {
     new ExpressionInfo(
       null,
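As a hedged illustration of what `registerInternalExpression` buys: the builtin registry now answers for the fully qualified identifier, while the bare name stays unregistered (assuming, as the surrounding changes suggest, that `days` was never a global built-in):

    import org.apache.spark.sql.catalyst.FunctionIdentifier
    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

    val qualified = FunctionIdentifier("days", Some("internal"), Some("system"))
    FunctionRegistry.builtin.functionExists(qualified)                   // true
    FunctionRegistry.builtin.functionExists(FunctionIdentifier("days"))  // false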
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 701c68684c34..91300ee6a7eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -1647,10 +1647,17 @@ class SessionCatalog(
    * Look up the `ExpressionInfo` of the given function by name if it's a built-in or temp function.
    * This only supports scalar functions.
    */
-  def lookupBuiltinOrTempFunction(name: String): Option[ExpressionInfo] = {
-    FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT)).orElse {
+  def lookupBuiltinOrTempFunction(funcIdent: FunctionIdentifier): Option[ExpressionInfo] = {
+    val operator = funcIdent match {
+      case FunctionIdentifier(name, None, None) =>
+        FunctionRegistry.builtinOperators.get(name.toLowerCase(Locale.ROOT))
+      case _ => None
+    }
+    operator.orElse {
       synchronized(lookupTempFuncWithViewContext(
-        name, FunctionRegistry.builtin.functionExists, functionRegistry.lookupFunction))
+        funcIdent,
+        FunctionRegistry.builtin.functionExists,
+        functionRegistry.lookupFunction))
     }
   }

@@ -1658,18 +1665,26 @@ class SessionCatalog(
   /**
    * Look up the `ExpressionInfo` of the given function by name if it's a built-in or
    * temp table function.
    */
-  def lookupBuiltinOrTempTableFunction(name: String): Option[ExpressionInfo] = synchronized {
-    lookupTempFuncWithViewContext(
-      name, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry.lookupFunction)
-  }
+  def lookupBuiltinOrTempTableFunction(funcIdent: FunctionIdentifier): Option[ExpressionInfo] =
+    synchronized {
+      lookupTempFuncWithViewContext(
+        funcIdent,
+        TableFunctionRegistry.builtin.functionExists,
+        tableFunctionRegistry.lookupFunction)
+    }

   /**
    * Look up a built-in or temp scalar function by name and resolves it to an Expression if such
    * a function exists.
    */
-  def resolveBuiltinOrTempFunction(name: String, arguments: Seq[Expression]): Option[Expression] = {
+  def resolveBuiltinOrTempFunction(
+      funcIdent: FunctionIdentifier,
+      arguments: Seq[Expression]): Option[Expression] = {
     resolveBuiltinOrTempFunctionInternal(
-      name, arguments, FunctionRegistry.builtin.functionExists, functionRegistry)
+      funcIdent,
+      arguments,
+      FunctionRegistry.builtin.functionExists,
+      functionRegistry)
   }

   /**
@@ -1677,35 +1692,36 @@ class SessionCatalog(
    * a function exists.
    */
   def resolveBuiltinOrTempTableFunction(
-      name: String, arguments: Seq[Expression]): Option[LogicalPlan] = {
+      funcIdent: FunctionIdentifier,
+      arguments: Seq[Expression]): Option[LogicalPlan] = {
     resolveBuiltinOrTempFunctionInternal(
-      name, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry)
+      funcIdent, arguments, TableFunctionRegistry.builtin.functionExists, tableFunctionRegistry)
   }

   private def resolveBuiltinOrTempFunctionInternal[T](
-      name: String,
+      funcIdent: FunctionIdentifier,
       arguments: Seq[Expression],
       isBuiltin: FunctionIdentifier => Boolean,
       registry: FunctionRegistryBase[T]): Option[T] = synchronized {
-    val funcIdent = FunctionIdentifier(name)
     if (!registry.functionExists(funcIdent)) {
       None
     } else {
       lookupTempFuncWithViewContext(
-        name, isBuiltin, ident => Option(registry.lookupFunction(ident, arguments)))
+        funcIdent, isBuiltin, ident => Option(registry.lookupFunction(ident, arguments)))
     }
   }

   private def lookupTempFuncWithViewContext[T](
-      name: String,
+      funcIdent: FunctionIdentifier,
       isBuiltin: FunctionIdentifier => Boolean,
       lookupFunc: FunctionIdentifier => Option[T]): Option[T] = {
-    val funcIdent = FunctionIdentifier(name)
     if (isBuiltin(funcIdent)) {
       lookupFunc(funcIdent)
-    } else {
-      val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty
-      val referredTempFunctionNames = AnalysisContext.get.referredTempFunctionNames
+    } else if (funcIdent.catalog.isEmpty && funcIdent.database.isEmpty) {
+      val name = funcIdent.funcName
+      val context = AnalysisContext.get
+      val isResolvingView = context.catalogAndNamespace.nonEmpty
+      val referredTempFunctionNames = context.referredTempFunctionNames
       if (isResolvingView) {
         // When resolving a view, only return a temp function if it's referred by this view.
         if (referredTempFunctionNames.contains(name)) {
@@ -1719,10 +1735,12 @@ class SessionCatalog(
         // We are not resolving a view and the function is a temp one, add it to
         // `AnalysisContext`, so during the view creation, we can save all referred temp
         // functions to view metadata.
-        AnalysisContext.get.referredTempFunctionNames.add(name)
+        context.referredTempFunctionNames.add(name)
       }
       result
+    } else {
+      None
     }
   }

@@ -1809,33 +1827,21 @@ class SessionCatalog(
   /**
    * Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists.
    */
   def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
-    if (name.database.isEmpty) {
-      lookupBuiltinOrTempFunction(name.funcName)
-        .orElse(lookupBuiltinOrTempTableFunction(name.funcName))
-        .getOrElse(lookupPersistentFunction(name))
-    } else {
-      lookupPersistentFunction(name)
-    }
+    lookupBuiltinOrTempFunction(name)
+      .orElse(lookupBuiltinOrTempTableFunction(name))
+      .getOrElse(lookupPersistentFunction(name))
   }

   // The actual function lookup logic looks up temp/built-in function first, then persistent
   // function from either v1 or v2 catalog. This method only look up v1 catalog.
   def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
-    if (name.database.isEmpty) {
-      resolveBuiltinOrTempFunction(name.funcName, children)
-        .getOrElse(resolvePersistentFunction(name, children))
-    } else {
-      resolvePersistentFunction(name, children)
-    }
+    resolveBuiltinOrTempFunction(name, children)
+      .getOrElse(resolvePersistentFunction(name, children))
   }

   def lookupTableFunction(name: FunctionIdentifier, children: Seq[Expression]): LogicalPlan = {
-    if (name.database.isEmpty) {
-      resolveBuiltinOrTempTableFunction(name.funcName, children)
-        .getOrElse(resolvePersistentTableFunction(name, children))
-    } else {
-      resolvePersistentTableFunction(name, children)
-    }
+    resolveBuiltinOrTempTableFunction(name, children)
+      .getOrElse(resolvePersistentTableFunction(name, children))
   }

   /**
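The reworked `lookupTempFuncWithViewContext` distinguishes three cases instead of two. A compressed sketch of the control flow, where `lookupTempWithViewBookkeeping` is a hypothetical stand-in for the view-context logic of the middle branch:

    // Sketch of the branching only, not the commit's code.
    def sketch[T](funcIdent: FunctionIdentifier): Option[T] = {
      if (isBuiltin(funcIdent)) {
        lookupFunc(funcIdent)            // 1. built-in: resolve immediately
      } else if (funcIdent.catalog.isEmpty && funcIdent.database.isEmpty) {
        // 2. bare name: candidate temp function; a view only sees temp functions
        //    it referred to at creation time, and new references are recorded.
        lookupTempWithViewBookkeeping(funcIdent.funcName)
      } else {
        None                             // 3. qualified non-built-in: callers fall
      }                                  //    through to the persistent lookup
    }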
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
index 2f818fecad93..66e78f12d5c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst

+import org.apache.spark.sql.errors.QueryCompilationErrors
+
 /**
  * An identifier that optionally specifies a database.
  *
@@ -136,6 +138,15 @@ case class FunctionIdentifier(funcName: String, database: Option[String], catalo

 object FunctionIdentifier {
   def apply(funcName: String): FunctionIdentifier = new FunctionIdentifier(funcName)
+
+  def apply(funcName: String, database: Option[String]): FunctionIdentifier =
+    new FunctionIdentifier(funcName, database)
+
+  def apply(names: Seq[String]): FunctionIdentifier = names match {
+    case Seq() => throw QueryCompilationErrors.emptyMultipartIdentifierError()
+    case Seq(name) => new FunctionIdentifier(name)
+    case Seq(database, name) => FunctionIdentifier(name, Option(database))
+    case Seq(catalog, database, name) => FunctionIdentifier(name, Option(database), Option(catalog))
+    case _ => throw QueryCompilationErrors.identifierTooManyNamePartsError(names)
+  }
 }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
index 16c387a82373..d29a4e4a3648 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
@@ -155,4 +155,5 @@ private[sql] object CatalogManager {
   val SESSION_CATALOG_NAME: String = "spark_catalog"
   val SYSTEM_CATALOG_NAME = "system"
   val SESSION_NAMESPACE = "session"
+  val INTERNAL_NAMESPACE = "internal"
 }
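How the new Seq-based `FunctionIdentifier.apply` maps name parts, for illustration:

    import org.apache.spark.sql.catalyst.FunctionIdentifier

    FunctionIdentifier(Seq("f"))               // FunctionIdentifier("f")
    FunctionIdentifier(Seq("db", "f"))         // database = Some("db")
    FunctionIdentifier(Seq("cat", "db", "f"))  // catalog = Some("cat"), database = Some("db")
    // FunctionIdentifier(Seq())                 throws emptyMultipartIdentifierError
    // FunctionIdentifier(Seq("a","b","c","d"))  throws identifierTooManyNamePartsError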
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3108f1886c29..3f115fcb0a3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
+import org.apache.spark.sql.connector.catalog.CatalogManager.{INTERNAL_NAMESPACE, SYSTEM_CATALOG_NAME}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
@@ -61,21 +62,25 @@ private[sql] object Column {
   }

   private[sql] def fn(name: String, inputs: Column*): Column = {
-    fn(name, isDistinct = false, ignoreNulls = false, inputs: _*)
+    fn(name, isDistinct = false, inputs: _*)
   }

   private[sql] def fn(name: String, isDistinct: Boolean, inputs: Column*): Column = {
-    fn(name, isDistinct = isDistinct, ignoreNulls = false, inputs: _*)
+    fn(name :: Nil, isDistinct = isDistinct, inputs: _*)
   }

-  private[sql] def fn(
-      name: String,
+  private[sql] def internalFn(name: String, inputs: Column*): Column = {
+    fn(
+      SYSTEM_CATALOG_NAME :: INTERNAL_NAMESPACE :: name :: Nil,
+      isDistinct = false,
+      inputs: _*)
+  }
+
+  private def fn(
+      names: Seq[String],
       isDistinct: Boolean,
-      ignoreNulls: Boolean,
       inputs: Column*): Column = withOrigin {
-    Column {
-      UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls)
-    }
+    Column(UnresolvedFunction(names, inputs.map(_.expr), isDistinct))
   }
 }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0e62e05900a5..705891a92eb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
@@ -8465,9 +8464,7 @@ object functions {
    * @group udf_funcs
    * @since 3.4.0
    */
-  def unwrap_udt(column: Column): Column = withExpr {
-    UnwrapUDT(column.expr)
-  }
+  def unwrap_udt(column: Column): Column = Column.internalFn("unwrap_udt", column)

   // scalastyle:off
   // TODO(SPARK-45970): Use @static annotation so Java can access to those
@@ -8481,7 +8478,7 @@ object functions {
    * @group partition_transforms
    * @since 4.0.0
    */
-  def years(e: Column): Column = withExpr { Years(e.expr) }
+  def years(e: Column): Column = Column.internalFn("years", e)

   /**
    * (Scala-specific) A transform for timestamps and dates to partition data into months.
@@ -8489,7 +8486,7 @@ object functions {
    * @group partition_transforms
    * @since 4.0.0
    */
-  def months(e: Column): Column = withExpr { Months(e.expr) }
+  def months(e: Column): Column = Column.internalFn("months", e)

   /**
    * (Scala-specific) A transform for timestamps and dates to partition data into days.
@@ -8497,7 +8494,7 @@ object functions {
    * @group partition_transforms
    * @since 4.0.0
    */
-  def days(e: Column): Column = withExpr { Days(e.expr) }
+  def days(e: Column): Column = Column.internalFn("days", e)

   /**
    * (Scala-specific) A transform for timestamps to partition data into hours.
@@ -8505,7 +8502,7 @@ object functions {
    * @group partition_transforms
    * @since 4.0.0
    */
-  def hours(e: Column): Column = withExpr { Hours(e.expr) }
+  def hours(e: Column): Column = Column.internalFn("hours", e)

   /**
    * (Scala-specific) A transform for any type that partitions by a hash of the input column.
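The net effect on the DataFrame API, roughly sketched: the rewritten helpers no longer construct catalyst expressions eagerly; `Column.internalFn` emits an unresolved call to the fully qualified internal function, which the analyzer resolves through the new `system`.`internal` registry entries:

    // What functions.years(col("ts")) now produces, approximately:
    //   Column(UnresolvedFunction(
    //     Seq("system", "internal", "years"),
    //     Seq(UnresolvedAttribute("ts")),
    //     isDistinct = false))
    // instead of wrapping Years(e.expr) directly.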
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
