This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 438233db4b2  [SPARK-39607][SQL][DSV2] Distribution and ordering support V2 function in writing
438233db4b2 is described below

commit 438233db4b2177a884e5f970970da999c491d1ec
Author:     Cheng Pan <cheng...@apache.org>
AuthorDate: Mon Aug 29 14:30:54 2022 +0800

    [SPARK-39607][SQL][DSV2] Distribution and ordering support V2 function in writing

    ### What changes were proposed in this pull request?

    Add a new feature: make distribution and ordering support V2 functions in writing.
    Currently, the rule `V2Writes` supports converting `ApplyTransform` to `TransformExpression`
    (unevaluable); this PR makes `V2Writes` also support converting `TransformExpression` to
    `ApplyFunctionExpression`/`Invoke`/`StaticInvoke` (evaluable).

    ### Why are the changes needed?

    SPARK-33779 introduced an API for DSv2 writers to claim the distribution and ordering of data
    before writing, and with SPARK-34026, Spark can translate `IdentityTransform` into a catalyst
    expression in distribution and ordering expressions. But some databases, such as ClickHouse,
    allow a table partition to be defined by an expression, e.g. `PARTITIONED BY num % 10`, so it
    is useful to also support translating `ApplyTransform`, letting Spark organize the data to fit
    the target storage's requirements before writing.

    ### Does this PR introduce _any_ user-facing change?

    Yes, users can use V2 functions as partition transforms in DSv2 connectors.

    ### How was this patch tested?

    UT added.

    Closes #36995 from pan3793/SPARK-39607.

    Authored-by: Cheng Pan <cheng...@apache.org>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  49 +--------
 .../catalyst/expressions/V2ExpressionUtils.scala   |  62 ++++++++++-
 .../datasources/v2/DataSourceV2Relation.scala      |   6 +-
 .../v2/DistributionAndOrderingUtils.scala          |  44 ++++++--
 .../v2/V2ScanPartitioningAndOrdering.scala         |   8 +-
 .../sql/execution/datasources/v2/V2Writes.scala    |  11 +-
 .../WriteDistributionAndOrderingSuite.scala        | 115 ++++++++++++++++++++-
 .../catalog/functions/transformFunctions.scala     |  27 ++++-
 8 files changed, 245 insertions(+), 77 deletions(-)
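Before the diff itself, here is a minimal, hypothetical sketch of the connector-side declaration this feature targets: a DSv2 `Write` that clusters data by a V2 function applied through `ApplyTransform`. The `DemoWrite` class, the `mod_10` function name, and the `num` column are illustrative assumptions and not part of this commit; they only show the shape of a `RequiresDistributionAndOrdering` declaration that `V2Writes` can now turn into an evaluable repartitioning expression.

```scala
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, SortOrder}
import org.apache.spark.sql.connector.write.RequiresDistributionAndOrdering

// Hypothetical write for a ClickHouse-like table declared with `PARTITIONED BY num % 10`.
// "mod_10" is assumed to be a ScalarFunction registered in the connector's FunctionCatalog.
class DemoWrite extends RequiresDistributionAndOrdering {

  // Cluster incoming rows by the V2 function applied to column `num`, so that all rows
  // belonging to one storage partition end up in the same Spark write task.
  override def requiredDistribution(): Distribution =
    Distributions.clustered(
      Array[Expression](Expressions.apply("mod_10", Expressions.column("num"))))

  // No required sort in this sketch; a connector could also return SortOrders here
  // and they would be resolved through the same V2 function machinery.
  override def requiredOrdering(): Array[SortOrder] = Array.empty
}
```

With such a declaration, the patch below wraps the write query in a repartitioning over the resolved `ApplyFunctionExpression` and then applies the type coercion rules, since the converted expression implements `ImplicitCastInputTypes`.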
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 820202ef9c5..ae177efa05e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import java.lang.reflect.{Method, Modifier}
 import java.util
 import java.util.Locale
 import java.util.concurrent.atomic.AtomicBoolean
@@ -47,8 +46,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
-import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction, UnboundFunction}
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -2388,33 +2386,7 @@ class Analyzer(override val catalogManager: CatalogManager)
           throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
             scalarFunc.name(), "IGNORE NULLS")
         } else {
-          val declaredInputTypes = scalarFunc.inputTypes().toSeq
-          val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
-          findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
-            case Some(m) if Modifier.isStatic(m.getModifiers) =>
-              StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
-                MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
-                propagateNull = false, returnNullable = scalarFunc.isResultNullable,
-                isDeterministic = scalarFunc.isDeterministic)
-            case Some(_) =>
-              val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
-              Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
-                arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
-                returnNullable = scalarFunc.isResultNullable,
-                isDeterministic = scalarFunc.isDeterministic)
-            case _ =>
-              // TODO: handle functions defined in Scala too - in Scala, even if a
-              // subclass do not override the default method in parent interface
-              // defined in Java, the method can still be found from
-              // `getDeclaredMethod`.
-              findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) match {
-                case Some(_) =>
-                  ApplyFunctionExpression(scalarFunc, arguments)
-                case _ =>
-                  failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" +
-                    s" magic method nor override 'produceResult'")
-              }
-          }
+          V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
         }
       }
 
@@ -2429,23 +2401,6 @@ class Analyzer(override val catalogManager: CatalogManager)
         val aggregator = V2Aggregator(aggFunc, arguments)
         aggregator.toAggregateExpression(u.isDistinct, u.filter)
       }
-
-    /**
-     * Check if the input `fn` implements the given `methodName` with parameter types specified
-     * via `argClasses`.
-     */
-    private def findMethod(
-        fn: BoundFunction,
-        methodName: String,
-        argClasses: Seq[Class[_]]): Option[Method] = {
-      val cls = fn.getClass
-      try {
-        Some(cls.getDeclaredMethod(methodName, argClasses: _*))
-      } catch {
-        case _: NoSuchMethodException =>
-          None
-      }
-    }
   }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index c252ea5ccfe..64eb307bb9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.lang.reflect.{Method, Modifier}
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, SQLConfHelper}
 import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
 import org.apache.spark.sql.connector.catalog.functions._
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
 import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
@@ -52,8 +56,11 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
 
   /**
    * Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst.
    */
-  def toCatalystOrdering(ordering: Array[V2SortOrder], query: LogicalPlan): Seq[SortOrder] = {
-    ordering.map(toCatalyst(_, query).asInstanceOf[SortOrder])
+  def toCatalystOrdering(
+      ordering: Array[V2SortOrder],
+      query: LogicalPlan,
+      funCatalogOpt: Option[FunctionCatalog] = None): Seq[SortOrder] = {
+    ordering.map(toCatalyst(_, query, funCatalogOpt).asInstanceOf[SortOrder])
   }
 
   def toCatalyst(
@@ -143,4 +150,53 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
     case V2NullOrdering.NULLS_FIRST => NullsFirst
     case V2NullOrdering.NULLS_LAST => NullsLast
   }
+
+  def resolveScalarFunction(
+      scalarFunc: ScalarFunction[_],
+      arguments: Seq[Expression]): Expression = {
+    val declaredInputTypes = scalarFunc.inputTypes().toSeq
+    val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
+    findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
+      case Some(m) if Modifier.isStatic(m.getModifiers) =>
+        StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
+          MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
+          propagateNull = false, returnNullable = scalarFunc.isResultNullable,
+          isDeterministic = scalarFunc.isDeterministic)
+      case Some(_) =>
+        val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
+        Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
+          arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
+          returnNullable = scalarFunc.isResultNullable,
+          isDeterministic = scalarFunc.isDeterministic)
+      case _ =>
+        // TODO: handle functions defined in Scala too - in Scala, even if a
+        // subclass do not override the default method in parent interface
+        // defined in Java, the method can still be found from
+        // `getDeclaredMethod`.
+        findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) match {
+          case Some(_) =>
+            ApplyFunctionExpression(scalarFunc, arguments)
+          case _ =>
+            throw new AnalysisException(s"ScalarFunction '${scalarFunc.name()}'" +
+              s" neither implement magic method nor override 'produceResult'")
+        }
+    }
+  }
+
+  /**
+   * Check if the input `fn` implements the given `methodName` with parameter types specified
+   * via `argClasses`.
+   */
+  private def findMethod(
+      fn: BoundFunction,
+      methodName: String,
+      argClasses: Seq[Class[_]]): Option[Method] = {
+    val cls = fn.getClass
+    try {
+      Some(cls.getDeclaredMethod(methodName, argClasses: _*))
+    } catch {
+      case _: NoSuchMethodException =>
+        None
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 2045c599337..4fe01ac7607 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
 import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics}
 import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -48,6 +48,10 @@ case class DataSourceV2Relation(
 
   import DataSourceV2Implicits._
 
+  lazy val funCatalog: Option[FunctionCatalog] = catalog.collect {
+    case c: FunctionCatalog => c
+  }
+
   override lazy val metadataOutput: Seq[AttributeReference] = table match {
     case hasMeta: SupportsMetadataColumns =>
       val resolve = conf.resolver
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
index 07ede819880..b0b0d7bbc2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
@@ -17,22 +17,33 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.FunctionCatalog
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.distributions._
 import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 
 object DistributionAndOrderingUtils {
 
-  def prepareQuery(write: Write, query: LogicalPlan): LogicalPlan = write match {
+  def prepareQuery(
+      write: Write,
+      query: LogicalPlan,
+      funCatalogOpt: Option[FunctionCatalog]): LogicalPlan = write match {
     case write: RequiresDistributionAndOrdering =>
       val numPartitions = write.requiredNumPartitions()
 
       val distribution = write.requiredDistribution match {
-        case d: OrderedDistribution => toCatalystOrdering(d.ordering(), query)
-        case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, query)).toSeq
+        case d: OrderedDistribution =>
+          toCatalystOrdering(d.ordering(), query, funCatalogOpt)
+            .map(e => resolveTransformExpression(e).asInstanceOf[SortOrder])
+        case d: ClusteredDistribution =>
+          d.clustering.map(e => toCatalyst(e, query, funCatalogOpt))
+            .map(e => resolveTransformExpression(e)).toSeq
         case _: UnspecifiedDistribution => Seq.empty[Expression]
       }
@@ -53,16 +64,33 @@ object DistributionAndOrderingUtils {
         query
       }
 
-      val ordering = toCatalystOrdering(write.requiredOrdering, query)
+      val ordering = toCatalystOrdering(write.requiredOrdering, query, funCatalogOpt)
       val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
-        Sort(ordering, global = false, queryWithDistribution)
+        Sort(
+          ordering.map(e => resolveTransformExpression(e).asInstanceOf[SortOrder]),
+          global = false,
+          queryWithDistribution)
       } else {
         queryWithDistribution
       }
 
-      queryWithDistributionAndOrdering
-
+      // Apply typeCoercionRules since the converted expression from TransformExpression
+      // implemented ImplicitCastInputTypes
+      typeCoercionRules.foldLeft(queryWithDistributionAndOrdering)((plan, rule) => rule(plan))
     case _ =>
       query
   }
+
+  private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
+  }
+
+  private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) {
+    AnsiTypeCoercion.typeCoercionRules
+  } else {
+    TypeCoercion.typeCoercionRules
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
index 7ea1ca8c244..8ab0dc70726 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
@@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.FunctionCatalog
 import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning}
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning}
 import org.apache.spark.util.collection.Utils.sequenceToOption
@@ -41,14 +40,9 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with SQLConfHelpe
 
   private def partitioning(plan: LogicalPlan) = plan.transformDown {
     case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, None, _) =>
-      val funCatalogOpt = relation.catalog.flatMap {
-        case c: FunctionCatalog => Some(c)
-        case _ => None
-      }
-
       val catalystPartitioning = scan.outputPartitioning() match {
         case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map(
-          V2ExpressionUtils.toCatalystOpt(_, relation, funCatalogOpt)))
+          V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog)))
         case _: UnknownPartitioning => None
         case p => throw new IllegalArgumentException("Unsupported data source V2 partitioning " +
           "type: " + p.getClass.getSimpleName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 2d47d94ff1d..afdcf2c870d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -43,7 +43,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
     case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) =>
       val writeBuilder = newWriteBuilder(r.table, options, query.schema)
       val write = writeBuilder.build()
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       a.copy(write = Some(write), query = newQuery)
 
     case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) =>
@@ -67,7 +67,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
         throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
       }
 
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       o.copy(write = Some(write), query = newQuery)
 
     case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
@@ -79,7 +79,7 @@
         case _ =>
           throw QueryExecutionErrors.dynamicPartitionOverwriteUnsupportedByTableError(table)
       }
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       o.copy(write = Some(write), query = newQuery)
 
     case WriteToMicroBatchDataSource(
@@ -89,14 +89,15 @@
       val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
       val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
       val customMetrics = write.supportedCustomMetrics.toSeq
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val funCatalogOpt = relation.flatMap(_.funCatalog)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
       WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
 
     case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) =>
       val rowSchema = StructType.fromAttributes(rd.dataInput)
       val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
       val write = writeBuilder.build()
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       // project away any metadata columns that could be used for distribution and ordering
       rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
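The test helpers added later in this patch implement `ScalarFunction.produceResult`; for contrast, here is a minimal, hypothetical sketch of the "magic method" form that `V2ExpressionUtils.resolveScalarFunction` binds directly instead of going through `produceResult`. `ModTenFunction`, `UnboundModTenFunction`, and the `mod_10` name are illustrative assumptions, not part of this commit.

```scala
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructType}

// Hypothetical `num % 10` function in the "magic method" style: resolveScalarFunction looks
// for a method named ScalarFunction.MAGIC_METHOD_NAME ("invoke") matching the declared input
// types and binds it via StaticInvoke (static Java method) or Invoke (instance method, as for
// a Scala object like this one), rather than falling back to produceResult.
object ModTenFunction extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(LongType)
  override def resultType(): DataType = IntegerType
  override def name(): String = "mod_10"
  override def canonicalName(): String = name()

  // Magic method: evaluated directly on the unwrapped argument, no InternalRow handling.
  def invoke(num: Long): Int = (num % 10).toInt
}

object UnboundModTenFunction extends UnboundFunction {
  override def bind(inputType: StructType): BoundFunction = ModTenFunction
  override def description(): String = name()
  override def name(): String = "mod_10"
}
```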
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 26baec90f3c..7966add7738 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.connector
 import java.util.Collections
 
 import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal}
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
 import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, UnboundBucketFunction, UnboundStringSelfFunction}
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
-import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder}
+import org.apache.spark.sql.connector.expressions._
 import org.apache.spark.sql.connector.expressions.LogicalExpressions._
 import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
@@ -36,13 +38,21 @@ import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}
 import org.apache.spark.sql.test.SQLTestData.TestData
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.QueryExecutionListener
 
 class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase {
   import testImplicits._
 
+  before {
+    Seq(UnboundBucketFunction, UnboundStringSelfFunction).foreach { f =>
+      catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
+    }
+  }
+
   after {
+    catalog.clearTables()
+    catalog.clearFunctions()
     spark.sessionState.catalogManager.reset()
   }
@@ -987,6 +997,95 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
     }
   }
 
+  test("clustered distribution and local sort contains v2 function: append") {
+    checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("append")
+  }
+
+  test("clustered distribution and local sort contains v2 function: overwrite") {
+    checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("overwrite")
+  }
+
+  test("clustered distribution and local sort contains v2 function: overwriteDynamic") {
+    checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("overwriteDynamic")
+  }
+
+  test("clustered distribution and local sort contains v2 function with numPartitions: append") {
+    checkClusteredDistributionAndLocalSortContainsV2Function("append", Some(10))
+  }
+
+  test("clustered distribution and local sort contains v2 function with numPartitions: " +
+    "overwrite") {
+    checkClusteredDistributionAndLocalSortContainsV2Function("overwrite", Some(10))
+  }
+
+  test("clustered distribution and local sort contains v2 function with numPartitions: " +
+    "overwriteDynamic") {
+    checkClusteredDistributionAndLocalSortContainsV2Function("overwriteDynamic", Some(10))
+  }
+
+  private def checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases(
+      cmd: String): Unit = {
+    Seq(true, false).foreach { distributionStrictlyRequired =>
+      Seq(true, false).foreach { dataSkewed =>
+        Seq(true, false).foreach { coalesce =>
+          checkClusteredDistributionAndLocalSortContainsV2Function(
+            cmd, None, distributionStrictlyRequired, dataSkewed, coalesce)
+        }
+      }
+    }
+  }
+
+  private def checkClusteredDistributionAndLocalSortContainsV2Function(
+      command: String,
+      targetNumPartitions: Option[Int] = None,
+      distributionStrictlyRequired: Boolean = true,
+      dataSkewed: Boolean = false,
+      coalesce: Boolean = false): Unit = {
+    val tableOrdering = Array[SortOrder](
+      sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST),
+      sort(
+        BucketTransform(LiteralValue(10, IntegerType), Seq(FieldReference("id"))),
+        SortDirection.DESCENDING,
+        NullOrdering.NULLS_FIRST)
+    )
+    val tableDistribution = Distributions.clustered(Array(
+      ApplyTransform("string_self", Seq(FieldReference("data")))))
+
+    val writeOrdering = Seq(
+      catalyst.expressions.SortOrder(
+        attr("data"),
+        catalyst.expressions.Descending,
+        catalyst.expressions.NullsFirst,
+        Seq.empty
+      ),
+      catalyst.expressions.SortOrder(
+        ApplyFunctionExpression(BucketFunction, Seq(Literal(10), Cast(attr("id"), LongType))),
+        catalyst.expressions.Descending,
+        catalyst.expressions.NullsFirst,
+        Seq.empty
+      )
+    )
+
+    val writePartitioningExprs = Seq(
+      ApplyFunctionExpression(StringSelfFunction, Seq(attr("data"))))
+    val writePartitioning = if (!coalesce) {
+      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
+    } else {
+      clusteredWritePartitioning(writePartitioningExprs, Some(1))
+    }
+
+    checkWriteRequirements(
+      tableDistribution,
+      tableOrdering,
+      targetNumPartitions,
+      expectedWritePartitioning = writePartitioning,
+      expectedWriteOrdering = writeOrdering,
+      writeCommand = command,
+      distributionStrictlyRequired = distributionStrictlyRequired,
+      dataSkewed = dataSkewed,
+      coalesce = coalesce)
+  }
+
   // scalastyle:off argcount
   private def checkWriteRequirements(
       tableDistribution: Distribution,
@@ -1209,12 +1308,20 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
     if (skewSplit) {
       assert(actualPartitioning.numPartitions > conf.numShufflePartitions)
     } else {
-      assert(actualPartitioning == expectedPartitioning, "partitioning must match")
+      (actualPartitioning, expectedPartitioning) match {
+        case (actual: catalyst.expressions.Expression, expected: catalyst.expressions.Expression) =>
+          assert(actual semanticEquals expected, "partitioning must match")
+        case (actual, expected) =>
+          assert(actual == expected, "partitioning must match")
+      }
     }
 
     val actualOrdering = plan.outputOrdering
     val expectedOrdering = ordering.map(resolveAttrs(_, plan))
-    assert(actualOrdering == expectedOrdering, "ordering must match")
+    assert(actualOrdering.length == expectedOrdering.length)
+    (actualOrdering zip expectedOrdering).foreach { case (actual, expected) =>
+      assert(actual semanticEquals expected, "ordering must match")
+    }
   }
 
   // executes a write operation and keeps the executed physical plan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 1994874d328..9277e8d059f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -16,7 +16,9 @@
  */
 package org.apache.spark.sql.connector.catalog.functions
 
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 object UnboundYearsFunction extends UnboundFunction {
   override def bind(inputType: StructType): BoundFunction = {
@@ -70,9 +72,30 @@ object UnboundBucketFunction extends UnboundFunction {
   override def name(): String = "bucket"
 }
 
-object BucketFunction extends BoundFunction {
-  override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
+object BucketFunction extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
   override def resultType(): DataType = IntegerType
   override def name(): String = "bucket"
   override def canonicalName(): String = name()
+  override def toString: String = name()
+  override def produceResult(input: InternalRow): Int = {
+    (input.getLong(1) % input.getInt(0)).toInt
+  }
+}
+
+object UnboundStringSelfFunction extends UnboundFunction {
+  override def bind(inputType: StructType): BoundFunction = StringSelfFunction
+  override def description(): String = name()
+  override def name(): String = "string_self"
+}
+
+object StringSelfFunction extends ScalarFunction[UTF8String] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = StringType
+  override def name(): String = "string_self"
+  override def canonicalName(): String = name()
+  override def toString: String = name()
+  override def produceResult(input: InternalRow): UTF8String = {
+    input.getUTF8String(0)
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org