This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 438233db4b2 [SPARK-39607][SQL][DSV2] Distribution and ordering support 
V2 function in writing
438233db4b2 is described below

commit 438233db4b2177a884e5f970970da999c491d1ec
Author: Cheng Pan <cheng...@apache.org>
AuthorDate: Mon Aug 29 14:30:54 2022 +0800

    [SPARK-39607][SQL][DSV2] Distribution and ordering support V2 function in 
writing
    
    ### What changes were proposed in this pull request?
    
    Add a new feature: make distribution and ordering support V2 functions in
    writing.
    
    Currently, the rule `V2Writes` supports converting `ApplyTransform` to
    `TransformExpression` (unevaluable); this PR makes `V2Writes` also support
    converting `TransformExpression` to
    `ApplyFunctionExpression`/`Invoke`/`StaticInvoke` (evaluable).
    
    ### Why are the changes needed?
    
    SPARK-33779 introduced an API for DSv2 writers to declare the required
    distribution and ordering of data before writing. With SPARK-34026, Spark
    can translate `IdentityTransform` into a catalyst expression in distribution
    and ordering expressions.
    
    However, some databases, such as ClickHouse, allow a table partition to be
    defined by an expression, e.g. `PARTITIONED BY num % 10`. For them it is
    useful to support translating `ApplyTransform`, so that Spark can organize
    the data to fit the target storage's requirements before writing.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users can use V2 functions as partition transforms in DSv2 connectors.
    
    ### How was this patch tested?
    
    Unit tests added (`WriteDistributionAndOrderingSuite`).
    
    Closes #36995 from pan3793/SPARK-39607.
    
    Authored-by: Cheng Pan <cheng...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  49 +--------
 .../catalyst/expressions/V2ExpressionUtils.scala   |  62 ++++++++++-
 .../datasources/v2/DataSourceV2Relation.scala      |   6 +-
 .../v2/DistributionAndOrderingUtils.scala          |  44 ++++++--
 .../v2/V2ScanPartitioningAndOrdering.scala         |   8 +-
 .../sql/execution/datasources/v2/V2Writes.scala    |  11 +-
 .../WriteDistributionAndOrderingSuite.scala        | 115 ++++++++++++++++++++-
 .../catalog/functions/transformFunctions.scala     |  27 ++++-
 8 files changed, 245 insertions(+), 77 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 820202ef9c5..ae177efa05e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import java.lang.reflect.{Method, Modifier}
 import java.util
 import java.util.Locale
 import java.util.concurrent.atomic.AtomicBoolean
@@ -47,8 +46,7 @@ import 
org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.TableChange.{After, 
ColumnPosition}
-import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => 
V2AggregateFunction, BoundFunction, ScalarFunction, UnboundFunction}
-import 
org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => 
V2AggregateFunction, ScalarFunction, UnboundFunction}
 import org.apache.spark.sql.connector.expressions.{FieldReference, 
IdentityTransform, Transform}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -2388,33 +2386,7 @@ class Analyzer(override val catalogManager: 
CatalogManager)
         throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
           scalarFunc.name(), "IGNORE NULLS")
       } else {
-        val declaredInputTypes = scalarFunc.inputTypes().toSeq
-        val argClasses = 
declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
-        findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
-          case Some(m) if Modifier.isStatic(m.getModifiers) =>
-            StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
-              MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
-                propagateNull = false, returnNullable = 
scalarFunc.isResultNullable,
-                isDeterministic = scalarFunc.isDeterministic)
-          case Some(_) =>
-            val caller = Literal.create(scalarFunc, 
ObjectType(scalarFunc.getClass))
-            Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
-              arguments, methodInputTypes = declaredInputTypes, propagateNull 
= false,
-              returnNullable = scalarFunc.isResultNullable,
-              isDeterministic = scalarFunc.isDeterministic)
-          case _ =>
-            // TODO: handle functions defined in Scala too - in Scala, even if 
a
-            //  subclass do not override the default method in parent interface
-            //  defined in Java, the method can still be found from
-            //  `getDeclaredMethod`.
-            findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) 
match {
-              case Some(_) =>
-                ApplyFunctionExpression(scalarFunc, arguments)
-              case _ =>
-                failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither 
implement" +
-                  s" magic method nor override 'produceResult'")
-            }
-        }
+        V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
       }
     }
 
@@ -2429,23 +2401,6 @@ class Analyzer(override val catalogManager: 
CatalogManager)
       val aggregator = V2Aggregator(aggFunc, arguments)
       aggregator.toAggregateExpression(u.isDistinct, u.filter)
     }
-
-    /**
-     * Check if the input `fn` implements the given `methodName` with 
parameter types specified
-     * via `argClasses`.
-     */
-    private def findMethod(
-        fn: BoundFunction,
-        methodName: String,
-        argClasses: Seq[Class[_]]): Option[Method] = {
-      val cls = fn.getClass
-      try {
-        Some(cls.getDeclaredMethod(methodName, argClasses: _*))
-      } catch {
-        case _: NoSuchMethodException =>
-          None
-      }
-    }
   }
 
   /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index c252ea5ccfe..64eb307bb9f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.lang.reflect.{Method, Modifier}
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, 
SQLConfHelper}
 import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
 import org.apache.spark.sql.connector.catalog.functions._
+import 
org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
 import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression 
=> V2Expression, FieldReference, IdentityTransform, NamedReference, 
NamedTransform, NullOrdering => V2NullOrdering, SortDirection => 
V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
@@ -52,8 +56,11 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
   /**
    * Converts the array of input V2 [[V2SortOrder]] into their counterparts in 
catalyst.
    */
-  def toCatalystOrdering(ordering: Array[V2SortOrder], query: LogicalPlan): 
Seq[SortOrder] = {
-    ordering.map(toCatalyst(_, query).asInstanceOf[SortOrder])
+  def toCatalystOrdering(
+      ordering: Array[V2SortOrder],
+      query: LogicalPlan,
+      funCatalogOpt: Option[FunctionCatalog] = None): Seq[SortOrder] = {
+    ordering.map(toCatalyst(_, query, funCatalogOpt).asInstanceOf[SortOrder])
   }
 
   def toCatalyst(
@@ -143,4 +150,53 @@ object V2ExpressionUtils extends SQLConfHelper with 
Logging {
     case V2NullOrdering.NULLS_FIRST => NullsFirst
     case V2NullOrdering.NULLS_LAST => NullsLast
   }
+
+  def resolveScalarFunction(
+      scalarFunc: ScalarFunction[_],
+      arguments: Seq[Expression]): Expression = {
+    val declaredInputTypes = scalarFunc.inputTypes().toSeq
+    val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
+    findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
+      case Some(m) if Modifier.isStatic(m.getModifiers) =>
+        StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
+          MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
+          propagateNull = false, returnNullable = scalarFunc.isResultNullable,
+          isDeterministic = scalarFunc.isDeterministic)
+      case Some(_) =>
+        val caller = Literal.create(scalarFunc, 
ObjectType(scalarFunc.getClass))
+        Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
+          arguments, methodInputTypes = declaredInputTypes, propagateNull = 
false,
+          returnNullable = scalarFunc.isResultNullable,
+          isDeterministic = scalarFunc.isDeterministic)
+      case _ =>
+        // TODO: handle functions defined in Scala too - in Scala, even if a
+        //  subclass do not override the default method in parent interface
+        //  defined in Java, the method can still be found from
+        //  `getDeclaredMethod`.
+        findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) 
match {
+          case Some(_) =>
+            ApplyFunctionExpression(scalarFunc, arguments)
+          case _ =>
+            throw new AnalysisException(s"ScalarFunction 
'${scalarFunc.name()}'" +
+              s" neither implement magic method nor override 'produceResult'")
+        }
+    }
+  }
+
+  /**
+   * Check if the input `fn` implements the given `methodName` with parameter 
types specified
+   * via `argClasses`.
+   */
+  private def findMethod(
+      fn: BoundFunction,
+      methodName: String,
+      argClasses: Seq[Class[_]]): Option[Method] = {
+    val cls = fn.getClass
+    try {
+      Some(cls.getDeclaredMethod(methodName, argClasses: _*))
+    } catch {
+      case _: NoSuchMethodException =>
+        None
+    }
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 2045c599337..4fe01ac7607 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -21,7 +21,7 @@ import 
org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat
 import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Expression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, 
LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, 
MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, 
Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
 import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, 
SupportsReportStatistics}
 import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -48,6 +48,10 @@ case class DataSourceV2Relation(
 
   import DataSourceV2Implicits._
 
+  lazy val funCatalog: Option[FunctionCatalog] = catalog.collect {
+    case c: FunctionCatalog => c
+  }
+
   override lazy val metadataOutput: Seq[AttributeReference] = table match {
     case hasMeta: SupportsMetadataColumns =>
       val resolve = conf.resolver
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
index 07ede819880..b0b0d7bbc2d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
@@ -17,22 +17,33 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, 
SortOrder, TransformExpression, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
RebalancePartitions, RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.FunctionCatalog
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.distributions._
 import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, 
Write}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 
 object DistributionAndOrderingUtils {
 
-  def prepareQuery(write: Write, query: LogicalPlan): LogicalPlan = write 
match {
+  def prepareQuery(
+      write: Write,
+      query: LogicalPlan,
+      funCatalogOpt: Option[FunctionCatalog]): LogicalPlan = write match {
     case write: RequiresDistributionAndOrdering =>
       val numPartitions = write.requiredNumPartitions()
 
       val distribution = write.requiredDistribution match {
-        case d: OrderedDistribution => toCatalystOrdering(d.ordering(), query)
-        case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, 
query)).toSeq
+        case d: OrderedDistribution =>
+          toCatalystOrdering(d.ordering(), query, funCatalogOpt)
+            .map(e => resolveTransformExpression(e).asInstanceOf[SortOrder])
+        case d: ClusteredDistribution =>
+          d.clustering.map(e => toCatalyst(e, query, funCatalogOpt))
+            .map(e => resolveTransformExpression(e)).toSeq
         case _: UnspecifiedDistribution => Seq.empty[Expression]
       }
 
@@ -53,16 +64,33 @@ object DistributionAndOrderingUtils {
         query
       }
 
-      val ordering = toCatalystOrdering(write.requiredOrdering, query)
+      val ordering = toCatalystOrdering(write.requiredOrdering, query, 
funCatalogOpt)
       val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
-        Sort(ordering, global = false, queryWithDistribution)
+        Sort(
+          ordering.map(e => 
resolveTransformExpression(e).asInstanceOf[SortOrder]),
+          global = false,
+          queryWithDistribution)
       } else {
         queryWithDistribution
       }
 
-      queryWithDistributionAndOrdering
-
+      // Apply typeCoercionRules since the converted expression from 
TransformExpression
+      // implemented ImplicitCastInputTypes
+      typeCoercionRules.foldLeft(queryWithDistributionAndOrdering)((plan, 
rule) => rule(plan))
     case _ =>
       query
   }
+
+  private def resolveTransformExpression(expr: Expression): Expression = 
expr.transform {
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, 
Some(numBuckets)) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, 
Seq(Literal(numBuckets)) ++ arguments)
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
+  }
+
+  private def typeCoercionRules: List[Rule[LogicalPlan]] = if 
(conf.ansiEnabled) {
+    AnsiTypeCoercion.typeCoercionRules
+  } else {
+    TypeCoercion.typeCoercionRules
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
index 7ea1ca8c244..8ab0dc70726 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
@@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.FunctionCatalog
 import org.apache.spark.sql.connector.read.{SupportsReportOrdering, 
SupportsReportPartitioning}
 import 
org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, 
UnknownPartitioning}
 import org.apache.spark.util.collection.Utils.sequenceToOption
@@ -41,14 +40,9 @@ object V2ScanPartitioningAndOrdering extends 
Rule[LogicalPlan] with SQLConfHelpe
 
   private def partitioning(plan: LogicalPlan) = plan.transformDown {
     case d @ DataSourceV2ScanRelation(relation, scan: 
SupportsReportPartitioning, _, None, _) =>
-      val funCatalogOpt = relation.catalog.flatMap {
-        case c: FunctionCatalog => Some(c)
-        case _ => None
-      }
-
       val catalystPartitioning = scan.outputPartitioning() match {
         case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map(
-          V2ExpressionUtils.toCatalystOpt(_, relation, funCatalogOpt)))
+          V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog)))
         case _: UnknownPartitioning => None
         case p => throw new IllegalArgumentException("Unsupported data source 
V2 partitioning " +
             "type: " + p.getClass.getSimpleName)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 2d47d94ff1d..afdcf2c870d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -43,7 +43,7 @@ object V2Writes extends Rule[LogicalPlan] with 
PredicateHelper {
     case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) =>
       val writeBuilder = newWriteBuilder(r.table, options, query.schema)
       val write = writeBuilder.build()
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, 
r.funCatalog)
       a.copy(write = Some(write), query = newQuery)
 
     case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, 
options, _, None) =>
@@ -67,7 +67,7 @@ object V2Writes extends Rule[LogicalPlan] with 
PredicateHelper {
           throw 
QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
       }
 
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, 
r.funCatalog)
       o.copy(write = Some(write), query = newQuery)
 
     case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, 
options, _, None) =>
@@ -79,7 +79,7 @@ object V2Writes extends Rule[LogicalPlan] with 
PredicateHelper {
         case _ =>
           throw 
QueryExecutionErrors.dynamicPartitionOverwriteUnsupportedByTableError(table)
       }
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, 
r.funCatalog)
       o.copy(write = Some(write), query = newQuery)
 
     case WriteToMicroBatchDataSource(
@@ -89,14 +89,15 @@ object V2Writes extends Rule[LogicalPlan] with 
PredicateHelper {
       val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
       val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
       val customMetrics = write.supportedCustomMetrics.toSeq
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val funCatalogOpt = relation.flatMap(_.funCatalog)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, 
funCatalogOpt)
       WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
 
     case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) =>
       val rowSchema = StructType.fromAttributes(rd.dataInput)
       val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
       val write = writeBuilder.build()
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, 
r.funCatalog)
       // project away any metadata columns that could be used for distribution 
and ordering
       rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 26baec90f3c..7966add7738 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.connector
 import java.util.Collections
 
 import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, 
Cast, Literal}
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
RangePartitioning, UnknownPartitioning}
 import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, 
StringSelfFunction, UnboundBucketFunction, UnboundStringSelfFunction}
 import org.apache.spark.sql.connector.distributions.{Distribution, 
Distributions}
-import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, 
NullOrdering, SortDirection, SortOrder}
+import org.apache.spark.sql.connector.expressions._
 import org.apache.spark.sql.connector.expressions.LogicalExpressions._
 import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
@@ -36,13 +38,21 @@ import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}
 import org.apache.spark.sql.test.SQLTestData.TestData
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, 
StructType}
 import org.apache.spark.sql.util.QueryExecutionListener
 
 class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase {
   import testImplicits._
 
+  before {
+    Seq(UnboundBucketFunction, UnboundStringSelfFunction).foreach { f =>
+      catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
+    }
+  }
+
   after {
+    catalog.clearTables()
+    catalog.clearFunctions()
     spark.sessionState.catalogManager.reset()
   }
 
@@ -987,6 +997,95 @@ class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase
     }
   }
 
+  test("clustered distribution and local sort contains v2 function: append") {
+    
checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("append")
+  }
+
+  test("clustered distribution and local sort contains v2 function: 
overwrite") {
+    
checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("overwrite")
+  }
+
+  test("clustered distribution and local sort contains v2 function: 
overwriteDynamic") {
+    
checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("overwriteDynamic")
+  }
+
+  test("clustered distribution and local sort contains v2 function with 
numPartitions: append") {
+    checkClusteredDistributionAndLocalSortContainsV2Function("append", 
Some(10))
+  }
+
+  test("clustered distribution and local sort contains v2 function with 
numPartitions: " +
+    "overwrite") {
+    checkClusteredDistributionAndLocalSortContainsV2Function("overwrite", 
Some(10))
+  }
+
+  test("clustered distribution and local sort contains v2 function with 
numPartitions: " +
+    "overwriteDynamic") {
+    
checkClusteredDistributionAndLocalSortContainsV2Function("overwriteDynamic", 
Some(10))
+  }
+
+  private def 
checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases(
+    cmd: String): Unit = {
+    Seq(true, false).foreach { distributionStrictlyRequired =>
+      Seq(true, false).foreach { dataSkewed =>
+        Seq(true, false).foreach { coalesce =>
+          checkClusteredDistributionAndLocalSortContainsV2Function(
+            cmd, None, distributionStrictlyRequired, dataSkewed, coalesce)
+        }
+      }
+    }
+  }
+
+  private def checkClusteredDistributionAndLocalSortContainsV2Function(
+      command: String,
+      targetNumPartitions: Option[Int] = None,
+      distributionStrictlyRequired: Boolean = true,
+      dataSkewed: Boolean = false,
+      coalesce: Boolean = false): Unit = {
+    val tableOrdering = Array[SortOrder](
+      sort(FieldReference("data"), SortDirection.DESCENDING, 
NullOrdering.NULLS_FIRST),
+      sort(
+        BucketTransform(LiteralValue(10, IntegerType), 
Seq(FieldReference("id"))),
+        SortDirection.DESCENDING,
+        NullOrdering.NULLS_FIRST)
+    )
+    val tableDistribution = Distributions.clustered(Array(
+      ApplyTransform("string_self", Seq(FieldReference("data")))))
+
+    val writeOrdering = Seq(
+      catalyst.expressions.SortOrder(
+        attr("data"),
+        catalyst.expressions.Descending,
+        catalyst.expressions.NullsFirst,
+        Seq.empty
+      ),
+      catalyst.expressions.SortOrder(
+        ApplyFunctionExpression(BucketFunction, Seq(Literal(10), 
Cast(attr("id"), LongType))),
+        catalyst.expressions.Descending,
+        catalyst.expressions.NullsFirst,
+        Seq.empty
+      )
+    )
+
+    val writePartitioningExprs = Seq(
+      ApplyFunctionExpression(StringSelfFunction, Seq(attr("data"))))
+    val writePartitioning = if (!coalesce) {
+      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
+    } else {
+      clusteredWritePartitioning(writePartitioningExprs, Some(1))
+    }
+
+    checkWriteRequirements(
+      tableDistribution,
+      tableOrdering,
+      targetNumPartitions,
+      expectedWritePartitioning = writePartitioning,
+      expectedWriteOrdering = writeOrdering,
+      writeCommand = command,
+      distributionStrictlyRequired = distributionStrictlyRequired,
+      dataSkewed = dataSkewed,
+      coalesce = coalesce)
+  }
+
   // scalastyle:off argcount
   private def checkWriteRequirements(
       tableDistribution: Distribution,
@@ -1209,12 +1308,20 @@ class WriteDistributionAndOrderingSuite extends 
DistributionAndOrderingSuiteBase
     if (skewSplit) {
       assert(actualPartitioning.numPartitions > conf.numShufflePartitions)
     } else {
-      assert(actualPartitioning == expectedPartitioning, "partitioning must 
match")
+      (actualPartitioning, expectedPartitioning) match {
+        case (actual: catalyst.expressions.Expression, expected: 
catalyst.expressions.Expression) =>
+          assert(actual semanticEquals expected, "partitioning must match")
+        case (actual, expected) =>
+          assert(actual == expected, "partitioning must match")
+      }
     }
 
     val actualOrdering = plan.outputOrdering
     val expectedOrdering = ordering.map(resolveAttrs(_, plan))
-    assert(actualOrdering == expectedOrdering, "ordering must match")
+    assert(actualOrdering.length == expectedOrdering.length)
+    (actualOrdering zip expectedOrdering).foreach { case (actual, expected) =>
+      assert(actual semanticEquals expected, "ordering must match")
+    }
   }
 
   // executes a write operation and keeps the executed physical plan
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 1994874d328..9277e8d059f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -16,7 +16,9 @@
  */
 package org.apache.spark.sql.connector.catalog.functions
 
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 object UnboundYearsFunction extends UnboundFunction {
   override def bind(inputType: StructType): BoundFunction = {
@@ -70,9 +72,30 @@ object UnboundBucketFunction extends UnboundFunction {
   override def name(): String = "bucket"
 }
 
-object BucketFunction extends BoundFunction {
-  override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
+object BucketFunction extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
   override def resultType(): DataType = IntegerType
   override def name(): String = "bucket"
   override def canonicalName(): String = name()
+  override def toString: String = name()
+  override def produceResult(input: InternalRow): Int = {
+    (input.getLong(1) % input.getInt(0)).toInt
+  }
+}
+
+object UnboundStringSelfFunction extends UnboundFunction {
+  override def bind(inputType: StructType): BoundFunction = StringSelfFunction
+  override def description(): String = name()
+  override def name(): String = "string_self"
+}
+
+object StringSelfFunction extends ScalarFunction[UTF8String] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = StringType
+  override def name(): String = "string_self"
+  override def canonicalName(): String = name()
+  override def toString: String = name()
+  override def produceResult(input: InternalRow): UTF8String = {
+    input.getUTF8String(0)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to