This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 438233db4b2 [SPARK-39607][SQL][DSV2] Distribution and ordering support V2 function in writing
438233db4b2 is described below
commit 438233db4b2177a884e5f970970da999c491d1ec
Author: Cheng Pan <[email protected]>
AuthorDate: Mon Aug 29 14:30:54 2022 +0800
[SPARK-39607][SQL][DSV2] Distribution and ordering support V2 function in writing
### What changes were proposed in this pull request?
Add a new feature: make distribution and ordering support V2 functions in writing.
Currently, the rule `V2Writes` supports converting `ApplyTransform` to
`TransformExpression` (unevaluable); this PR makes `V2Writes` also support converting
`TransformExpression` to `ApplyFunctionExpression`/`Invoke`/`StaticInvoke` (evaluable).
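For illustration, here is a minimal sketch of the kind of connector `ScalarFunction` this covers (the `bucket_mod` name and its logic are hypothetical, modeled on the `BucketFunction` test helper added below): once bound and wrapped in a `TransformExpression`, `V2Writes` can now rewrite it into an evaluable `ApplyFunctionExpression` (via `produceResult`) or `StaticInvoke`/`Invoke` (via the magic method).

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// Hypothetical V2 function: maps a long id into one of `numBuckets` buckets.
object BucketModFunction extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
  override def resultType(): DataType = IntegerType
  override def name(): String = "bucket_mod"
  override def canonicalName(): String = name()
  // Reflection-based entry point; after this PR it is evaluated through
  // ApplyFunctionExpression when no magic method is defined.
  override def produceResult(input: InternalRow): Int =
    (input.getLong(1) % input.getInt(0)).toInt
}
```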
### Why are the changes needed?
SPARK-33779 introduced an API for DSv2 writers to claim distributions and orderings
of data before writing; with SPARK-34026, Spark can translate `IdentityTransform`
into catalyst expressions in those distribution and ordering expressions.
But for some databases like ClickHouse, which allow a table partition to be defined
by an expression, e.g. `PARTITIONED BY num % 10`, it is useful to also support
translating `ApplyTransform`, so that Spark can organize the data to fit the target
storage's requirements before writing.
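As a rough sketch of the connector side (the `mod10` function and the class name are hypothetical, not part of this PR), such a write could declare the required clustering with a transform over a V2 function registered in the connector's `FunctionCatalog`:

```scala
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expressions, SortOrder}
import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}

// Hypothetical Write for a table defined with `PARTITIONED BY num % 10`, where
// `mod10` is a ScalarFunction exposed by the connector's FunctionCatalog.
class ClickHouseLikeWrite extends Write with RequiresDistributionAndOrdering {
  override def requiredDistribution(): Distribution =
    Distributions.clustered(Array(Expressions.apply("mod10", Expressions.column("num"))))
  override def requiredOrdering(): Array[SortOrder] = Array.empty
}
```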
### Does this PR introduce _any_ user-facing change?
Yes, users can use V2 functions as partition transforms in DSv2 connectors.
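For example (the catalog and table names below are hypothetical), once the connector's write declares such a distribution, a plain append is enough to have the rows shuffled by the function's result before they reach the writer:

```scala
// Hypothetical table backed by a connector whose Write requires a clustered
// distribution over a V2 function; the shuffle is planned automatically.
spark.range(100)
  .withColumnRenamed("id", "num")
  .writeTo("clickhouse.db.tbl")
  .append()
```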
### How was this patch tested?
Unit tests added in `WriteDistributionAndOrderingSuite`.
Closes #36995 from pan3793/SPARK-39607.
Authored-by: Cheng Pan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 49 +--------
.../catalyst/expressions/V2ExpressionUtils.scala | 62 ++++++++++-
.../datasources/v2/DataSourceV2Relation.scala | 6 +-
.../v2/DistributionAndOrderingUtils.scala | 44 ++++++--
.../v2/V2ScanPartitioningAndOrdering.scala | 8 +-
.../sql/execution/datasources/v2/V2Writes.scala | 11 +-
.../WriteDistributionAndOrderingSuite.scala | 115 ++++++++++++++++++++-
.../catalog/functions/transformFunctions.scala | 27 ++++-
8 files changed, 245 insertions(+), 77 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 820202ef9c5..ae177efa05e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import java.lang.reflect.{Method, Modifier}
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
@@ -47,8 +46,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
-import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction, UnboundFunction}
-import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
+import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -2388,33 +2386,7 @@ class Analyzer(override val catalogManager: CatalogManager)
throw QueryCompilationErrors.functionWithUnsupportedSyntaxError(
scalarFunc.name(), "IGNORE NULLS")
} else {
- val declaredInputTypes = scalarFunc.inputTypes().toSeq
- val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
- findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
- case Some(m) if Modifier.isStatic(m.getModifiers) =>
- StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
- MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
- propagateNull = false, returnNullable = scalarFunc.isResultNullable,
- isDeterministic = scalarFunc.isDeterministic)
- case Some(_) =>
- val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
- Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
- arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
- returnNullable = scalarFunc.isResultNullable,
- isDeterministic = scalarFunc.isDeterministic)
- case _ =>
- // TODO: handle functions defined in Scala too - in Scala, even if a
- // subclass do not override the default method in parent interface
- // defined in Java, the method can still be found from
- // `getDeclaredMethod`.
- findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) match {
- case Some(_) =>
- ApplyFunctionExpression(scalarFunc, arguments)
- case _ =>
- failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither
implement" +
- s" magic method nor override 'produceResult'")
- }
- }
+ V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
}
}
@@ -2429,23 +2401,6 @@ class Analyzer(override val catalogManager: CatalogManager)
val aggregator = V2Aggregator(aggFunc, arguments)
aggregator.toAggregateExpression(u.isDistinct, u.filter)
}
-
- /**
- * Check if the input `fn` implements the given `methodName` with parameter types specified
- * via `argClasses`.
- */
- private def findMethod(
- fn: BoundFunction,
- methodName: String,
- argClasses: Seq[Class[_]]): Option[Method] = {
- val cls = fn.getClass
- try {
- Some(cls.getDeclaredMethod(methodName, argClasses: _*))
- } catch {
- case _: NoSuchMethodException =>
- None
- }
- }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index c252ea5ccfe..64eb307bb9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -17,13 +17,17 @@
package org.apache.spark.sql.catalyst.expressions
+import java.lang.reflect.{Method, Modifier}
+
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
import org.apache.spark.sql.connector.catalog.functions._
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
@@ -52,8 +56,11 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
/**
* Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst.
*/
- def toCatalystOrdering(ordering: Array[V2SortOrder], query: LogicalPlan): Seq[SortOrder] = {
- ordering.map(toCatalyst(_, query).asInstanceOf[SortOrder])
+ def toCatalystOrdering(
+ ordering: Array[V2SortOrder],
+ query: LogicalPlan,
+ funCatalogOpt: Option[FunctionCatalog] = None): Seq[SortOrder] = {
+ ordering.map(toCatalyst(_, query, funCatalogOpt).asInstanceOf[SortOrder])
}
def toCatalyst(
@@ -143,4 +150,53 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
case V2NullOrdering.NULLS_FIRST => NullsFirst
case V2NullOrdering.NULLS_LAST => NullsLast
}
+
+ def resolveScalarFunction(
+ scalarFunc: ScalarFunction[_],
+ arguments: Seq[Expression]): Expression = {
+ val declaredInputTypes = scalarFunc.inputTypes().toSeq
+ val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
+ findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
+ case Some(m) if Modifier.isStatic(m.getModifiers) =>
+ StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
+ MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
+ propagateNull = false, returnNullable = scalarFunc.isResultNullable,
+ isDeterministic = scalarFunc.isDeterministic)
+ case Some(_) =>
+ val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
+ Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
+ arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
+ returnNullable = scalarFunc.isResultNullable,
+ isDeterministic = scalarFunc.isDeterministic)
+ case _ =>
+ // TODO: handle functions defined in Scala too - in Scala, even if a
+ // subclass do not override the default method in parent interface
+ // defined in Java, the method can still be found from
+ // `getDeclaredMethod`.
+ findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) match {
+ case Some(_) =>
+ ApplyFunctionExpression(scalarFunc, arguments)
+ case _ =>
+ throw new AnalysisException(s"ScalarFunction '${scalarFunc.name()}'" +
+ s" neither implement magic method nor override 'produceResult'")
+ }
+ }
+ }
+
+ /**
+ * Check if the input `fn` implements the given `methodName` with parameter types specified
+ * via `argClasses`.
+ */
+ private def findMethod(
+ fn: BoundFunction,
+ methodName: String,
+ argClasses: Seq[Class[_]]): Option[Method] = {
+ val cls = fn.getClass
+ try {
+ Some(cls.getDeclaredMethod(methodName, argClasses: _*))
+ } catch {
+ case _: NoSuchMethodException =>
+ None
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 2045c599337..4fe01ac7607 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics}
import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -48,6 +48,10 @@ case class DataSourceV2Relation(
import DataSourceV2Implicits._
+ lazy val funCatalog: Option[FunctionCatalog] = catalog.collect {
+ case c: FunctionCatalog => c
+ }
+
override lazy val metadataOutput: Seq[AttributeReference] = table match {
case hasMeta: SupportsMetadataColumns =>
val resolve = conf.resolver
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
index 07ede819880..b0b0d7bbc2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala
@@ -17,22 +17,33 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.FunctionCatalog
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.connector.distributions._
import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}
import org.apache.spark.sql.errors.QueryCompilationErrors
object DistributionAndOrderingUtils {
- def prepareQuery(write: Write, query: LogicalPlan): LogicalPlan = write match {
+ def prepareQuery(
+ write: Write,
+ query: LogicalPlan,
+ funCatalogOpt: Option[FunctionCatalog]): LogicalPlan = write match {
case write: RequiresDistributionAndOrdering =>
val numPartitions = write.requiredNumPartitions()
val distribution = write.requiredDistribution match {
- case d: OrderedDistribution => toCatalystOrdering(d.ordering(), query)
- case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, query)).toSeq
+ case d: OrderedDistribution =>
+ toCatalystOrdering(d.ordering(), query, funCatalogOpt)
+ .map(e => resolveTransformExpression(e).asInstanceOf[SortOrder])
+ case d: ClusteredDistribution =>
+ d.clustering.map(e => toCatalyst(e, query, funCatalogOpt))
+ .map(e => resolveTransformExpression(e)).toSeq
case _: UnspecifiedDistribution => Seq.empty[Expression]
}
@@ -53,16 +64,33 @@ object DistributionAndOrderingUtils {
query
}
- val ordering = toCatalystOrdering(write.requiredOrdering, query)
+ val ordering = toCatalystOrdering(write.requiredOrdering, query, funCatalogOpt)
val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
- Sort(ordering, global = false, queryWithDistribution)
+ Sort(
+ ordering.map(e => resolveTransformExpression(e).asInstanceOf[SortOrder]),
+ global = false,
+ queryWithDistribution)
} else {
queryWithDistribution
}
- queryWithDistributionAndOrdering
-
+ // Apply typeCoercionRules since the converted expression from TransformExpression
+ // implemented ImplicitCastInputTypes
+ typeCoercionRules.foldLeft(queryWithDistributionAndOrdering)((plan, rule) => rule(plan))
case _ =>
query
}
+
+ private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
+ case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
+ V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
+ case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+ V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
+ }
+
+ private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) {
+ AnsiTypeCoercion.typeCoercionRules
+ } else {
+ TypeCoercion.typeCoercionRules
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
index 7ea1ca8c244..8ab0dc70726 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
@@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.FunctionCatalog
import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning}
import org.apache.spark.util.collection.Utils.sequenceToOption
@@ -41,14 +40,9 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with SQLConfHelpe
private def partitioning(plan: LogicalPlan) = plan.transformDown {
case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, None, _) =>
- val funCatalogOpt = relation.catalog.flatMap {
- case c: FunctionCatalog => Some(c)
- case _ => None
- }
-
val catalystPartitioning = scan.outputPartitioning() match {
case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map(
- V2ExpressionUtils.toCatalystOpt(_, relation, funCatalogOpt)))
+ V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog)))
case _: UnknownPartitioning => None
case p => throw new IllegalArgumentException("Unsupported data source V2 partitioning " +
"type: " + p.getClass.getSimpleName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 2d47d94ff1d..afdcf2c870d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -43,7 +43,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) =>
val writeBuilder = newWriteBuilder(r.table, options, query.schema)
val write = writeBuilder.build()
- val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+ val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
a.copy(write = Some(write), query = newQuery)
case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) =>
@@ -67,7 +67,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
}
- val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+ val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
o.copy(write = Some(write), query = newQuery)
case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
@@ -79,7 +79,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case _ =>
throw QueryExecutionErrors.dynamicPartitionOverwriteUnsupportedByTableError(table)
}
- val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+ val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
o.copy(write = Some(write), query = newQuery)
case WriteToMicroBatchDataSource(
@@ -89,14 +89,15 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
val customMetrics = write.supportedCustomMetrics.toSeq
- val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+ val funCatalogOpt = relation.flatMap(_.funCatalog)
+ val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) =>
val rowSchema = StructType.fromAttributes(rd.dataInput)
val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
val write = writeBuilder.build()
- val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+ val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
// project away any metadata columns that could be used for distribution and ordering
rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 26baec90f3c..7966add7738 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.connector
import java.util.Collections
import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal}
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, UnboundBucketFunction, UnboundStringSelfFunction}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
-import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder}
+import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.expressions.LogicalExpressions._
import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
@@ -36,13 +38,21 @@ import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}
import org.apache.spark.sql.test.SQLTestData.TestData
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.util.QueryExecutionListener
class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase {
import testImplicits._
+ before {
+ Seq(UnboundBucketFunction, UnboundStringSelfFunction).foreach { f =>
+ catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
+ }
+ }
+
after {
+ catalog.clearTables()
+ catalog.clearFunctions()
spark.sessionState.catalogManager.reset()
}
@@ -987,6 +997,95 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
}
}
+ test("clustered distribution and local sort contains v2 function: append") {
+
checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("append")
+ }
+
+ test("clustered distribution and local sort contains v2 function:
overwrite") {
+
checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("overwrite")
+ }
+
+ test("clustered distribution and local sort contains v2 function:
overwriteDynamic") {
+
checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("overwriteDynamic")
+ }
+
+ test("clustered distribution and local sort contains v2 function with
numPartitions: append") {
+ checkClusteredDistributionAndLocalSortContainsV2Function("append",
Some(10))
+ }
+
+ test("clustered distribution and local sort contains v2 function with
numPartitions: " +
+ "overwrite") {
+ checkClusteredDistributionAndLocalSortContainsV2Function("overwrite",
Some(10))
+ }
+
+ test("clustered distribution and local sort contains v2 function with
numPartitions: " +
+ "overwriteDynamic") {
+
checkClusteredDistributionAndLocalSortContainsV2Function("overwriteDynamic",
Some(10))
+ }
+
+ private def
checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases(
+ cmd: String): Unit = {
+ Seq(true, false).foreach { distributionStrictlyRequired =>
+ Seq(true, false).foreach { dataSkewed =>
+ Seq(true, false).foreach { coalesce =>
+ checkClusteredDistributionAndLocalSortContainsV2Function(
+ cmd, None, distributionStrictlyRequired, dataSkewed, coalesce)
+ }
+ }
+ }
+ }
+
+ private def checkClusteredDistributionAndLocalSortContainsV2Function(
+ command: String,
+ targetNumPartitions: Option[Int] = None,
+ distributionStrictlyRequired: Boolean = true,
+ dataSkewed: Boolean = false,
+ coalesce: Boolean = false): Unit = {
+ val tableOrdering = Array[SortOrder](
+ sort(FieldReference("data"), SortDirection.DESCENDING,
NullOrdering.NULLS_FIRST),
+ sort(
+ BucketTransform(LiteralValue(10, IntegerType),
Seq(FieldReference("id"))),
+ SortDirection.DESCENDING,
+ NullOrdering.NULLS_FIRST)
+ )
+ val tableDistribution = Distributions.clustered(Array(
+ ApplyTransform("string_self", Seq(FieldReference("data")))))
+
+ val writeOrdering = Seq(
+ catalyst.expressions.SortOrder(
+ attr("data"),
+ catalyst.expressions.Descending,
+ catalyst.expressions.NullsFirst,
+ Seq.empty
+ ),
+ catalyst.expressions.SortOrder(
+ ApplyFunctionExpression(BucketFunction, Seq(Literal(10), Cast(attr("id"), LongType))),
+ catalyst.expressions.Descending,
+ catalyst.expressions.NullsFirst,
+ Seq.empty
+ )
+ )
+
+ val writePartitioningExprs = Seq(
+ ApplyFunctionExpression(StringSelfFunction, Seq(attr("data"))))
+ val writePartitioning = if (!coalesce) {
+ clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
+ } else {
+ clusteredWritePartitioning(writePartitioningExprs, Some(1))
+ }
+
+ checkWriteRequirements(
+ tableDistribution,
+ tableOrdering,
+ targetNumPartitions,
+ expectedWritePartitioning = writePartitioning,
+ expectedWriteOrdering = writeOrdering,
+ writeCommand = command,
+ distributionStrictlyRequired = distributionStrictlyRequired,
+ dataSkewed = dataSkewed,
+ coalesce = coalesce)
+ }
+
// scalastyle:off argcount
private def checkWriteRequirements(
tableDistribution: Distribution,
@@ -1209,12 +1308,20 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
if (skewSplit) {
assert(actualPartitioning.numPartitions > conf.numShufflePartitions)
} else {
- assert(actualPartitioning == expectedPartitioning, "partitioning must match")
+ (actualPartitioning, expectedPartitioning) match {
+ case (actual: catalyst.expressions.Expression, expected: catalyst.expressions.Expression) =>
+ assert(actual semanticEquals expected, "partitioning must match")
+ case (actual, expected) =>
+ assert(actual == expected, "partitioning must match")
+ }
}
val actualOrdering = plan.outputOrdering
val expectedOrdering = ordering.map(resolveAttrs(_, plan))
- assert(actualOrdering == expectedOrdering, "ordering must match")
+ assert(actualOrdering.length == expectedOrdering.length)
+ (actualOrdering zip expectedOrdering).foreach { case (actual, expected) =>
+ assert(actual semanticEquals expected, "ordering must match")
+ }
}
// executes a write operation and keeps the executed physical plan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 1994874d328..9277e8d059f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -16,7 +16,9 @@
*/
package org.apache.spark.sql.connector.catalog.functions
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
object UnboundYearsFunction extends UnboundFunction {
override def bind(inputType: StructType): BoundFunction = {
@@ -70,9 +72,30 @@ object UnboundBucketFunction extends UnboundFunction {
override def name(): String = "bucket"
}
-object BucketFunction extends BoundFunction {
- override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
+object BucketFunction extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
override def resultType(): DataType = IntegerType
override def name(): String = "bucket"
override def canonicalName(): String = name()
+ override def toString: String = name()
+ override def produceResult(input: InternalRow): Int = {
+ (input.getLong(1) % input.getInt(0)).toInt
+ }
+}
+
+object UnboundStringSelfFunction extends UnboundFunction {
+ override def bind(inputType: StructType): BoundFunction = StringSelfFunction
+ override def description(): String = name()
+ override def name(): String = "string_self"
+}
+
+object StringSelfFunction extends ScalarFunction[UTF8String] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = StringType
+ override def name(): String = "string_self"
+ override def canonicalName(): String = name()
+ override def toString: String = name()
+ override def produceResult(input: InternalRow): UTF8String = {
+ input.getUTF8String(0)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]