This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9e2aafb1373 [SPARK-45036][SQL] SPJ: Simplify the logic to handle partially clustered distribution
9e2aafb1373 is described below

commit 9e2aafb13739f9c07f8218cd325c5532063b1a51
Author: Chao Sun <sunc...@apple.com>
AuthorDate: Mon Sep 4 14:05:14 2023 -0700

    [SPARK-45036][SQL] SPJ: Simplify the logic to handle partially clustered distribution

    ### What changes were proposed in this pull request?

    In SPJ, currently the logic to handle partially clustered distribution is a bit complicated. For instance, when the feature is enabled (by enabling both `conf.v2BucketingPushPartValuesEnabled` and `conf.v2BucketingPartiallyClusteredDistributionEnabled`), Spark should postpone the combining of input splits until it is about to create an input RDD in `BatchScanExec`. To implement this, `groupPartitions` in `DataSourceV2ScanExecBase` currently takes the flag as input and has two differen [...]

    This PR introduces a new field in `KeyGroupedPartitioning`, named `originalPartitionValues`, that is used to store the original partition values from input before splits combining has been applied. The field is used when partially clustered distribution is enabled. With this, `groupPartitions` becomes easier to understand.

    In addition, this also simplifies `BatchScanExec.inputRDD` by combining two branches where partially clustered distribution is not enabled.

    ### Why are the changes needed?

    To simplify the current logic in the SPJ w.r.t partially clustered distribution.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Existing tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    Closes #42757 from sunchao/SPARK-45036.
    Authored-by: Chao Sun <sunc...@apple.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/plans/physical/partitioning.scala |  35 +++---
 .../execution/datasources/v2/BatchScanExec.scala   | 117 +++++++++------------
 .../datasources/v2/DataSourceV2ScanExecBase.scala  |  65 +++++++-----
 .../execution/exchange/EnsureRequirements.scala    |   9 +-
 .../execution/exchange/ShuffleExchangeExec.scala   |   4 +-
 .../DistributionAndOrderingSuiteBase.scala         |   6 +-
 .../connector/KeyGroupedPartitioningSuite.scala    |   2 +-
 .../exchange/EnsureRequirementsSuite.scala         |   2 +-
 8 files changed, 122 insertions(+), 118 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index ce557422a08..0be4a61f275 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -312,26 +312,37 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
  * Represents a partitioning where rows are split across partitions based on transforms defined
  * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in
  * ascending order, after evaluated by the transforms in `expressions`, for each input partition.
- * In addition, its length must be the same as the number of input partitions (and thus is a 1-1
- * mapping). The `partitionValues` may contain duplicated partition values.
+ * In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1
+ * mapping), and each row in `partitionValues` must be unique.
  *
- * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValues` is
- * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows
- * in each partition have the same value for column `ts_col` (which is of timestamp type), after
- * being applied by the `years` transform.
+ * The `originalPartitionValues`, on the other hand, are partition values from the original input
+ * splits returned by data sources. It may contain duplicated values.
  *
- * On the other hand, `[0, 0, 1]` is not a valid value for `partitionValues` since `0` is
- * duplicated twice.
+ * For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4
+ * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions`
+ * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which
+ * represents 3 input partitions with distinct partition values. All rows in each partition have
+ * the same value for column `ts_col` (which is of timestamp type), after being applied by the
+ * `years` transform. This is generated after combining the two splits with partition value `2`
+ * into a single Spark partition.
+ *
+ * On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues`
+ * which is calculated from the original input splits.
  *
  * @param expressions partition expressions for the partitioning.
  * @param numPartitions the number of partitions
- * @param partitionValues the values for the cluster keys of the distribution, must be
- *                        in ascending order.
+ * @param partitionValues the values for the final cluster keys (that is, after applying grouping
+ *                        on the input splits according to `expressions`) of the distribution,
+ *                        must be in ascending order, and must NOT contain duplicated values.
+ * @param originalPartitionValues the original input partition values before any grouping has been
+ *                                applied, must be in ascending order, and may contain duplicated
+ *                                values
  */
 case class KeyGroupedPartitioning(
     expressions: Seq[Expression],
     numPartitions: Int,
-    partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
+    partitionValues: Seq[InternalRow] = Seq.empty,
+    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {

   override def satisfies0(required: Distribution): Boolean = {
     super.satisfies0(required) || {
@@ -368,7 +379,7 @@ object KeyGroupedPartitioning {
   def apply(
       expressions: Seq[Expression],
       partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
-    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues)
+    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues)
   }

   def supportsExpressions(expressions: Seq[Expression]): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index cc674961f8e..932ac0f5a1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par
 import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.read._
-import org.apache.spark.sql.internal.SQLConf

 /**
  * Physical plan node for scanning a batch of data from a data source v2.
@@ -101,7 +100,7 @@ case class BatchScanExec(
               "partition values that are not present in the original partitioning.")
           }

-          groupPartitions(newPartitions).get.map(_._2)
+          groupPartitions(newPartitions).get.groupedParts.map(_.parts)

         case _ =>
           // no validation is needed as the data source did not report any specific partitioning
@@ -137,81 +136,63 @@ case class BatchScanExec(

       outputPartitioning match {
         case p: KeyGroupedPartitioning =>
-          if (conf.v2BucketingPushPartValuesEnabled &&
-              conf.v2BucketingPartiallyClusteredDistributionEnabled) {
-            assert(filteredPartitions.forall(_.size == 1),
-              "Expect partitions to be not grouped when " +
-                  s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
-                  "is enabled")
-
-            val groupedPartitions = groupPartitions(finalPartitions.map(_.head),
-              groupSplits = true).get
-
-            // This means the input partitions are not grouped by partition values. We'll need to
-            // check `groupByPartitionValues` and decide whether to group and replicate splits
-            // within a partition.
-            if (spjParams.commonPartitionValues.isDefined &&
-              spjParams.applyPartialClustering) {
-              // A mapping from the common partition values to how many splits the partition
-              // should contain.
-              val commonPartValuesMap = spjParams.commonPartitionValues
+          val groupedPartitions = filteredPartitions.map(splits => {
+            assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
+            (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+          })
+
+          // When partially clustered, the input partitions are not grouped by partition
+          // values. Here we'll need to check `commonPartitionValues` and decide how to group
+          // and replicate splits within a partition.
+          if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) {
+            // A mapping from the common partition values to how many splits the partition
+            // should contain.
+            val commonPartValuesMap = spjParams.commonPartitionValues
                 .get
                 .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
                 .toMap
-              val nestGroupedPartitions = groupedPartitions.map {
-                case (partValue, splits) =>
-                  // `commonPartValuesMap` should contain the part value since it's the super set.
-                  val numSplits = commonPartValuesMap
-                      .get(InternalRowComparableWrapper(partValue, p.expressions))
-                  assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
-                      "common partition values from Spark plan")
-
-                  val newSplits = if (spjParams.replicatePartitions) {
-                    // We need to also replicate partitions according to the other side of join
-                    Seq.fill(numSplits.get)(splits)
-                  } else {
-                    // Not grouping by partition values: this could be the side with partially
-                    // clustered distribution. Because of dynamic filtering, we'll need to check if
-                    // the final number of splits of a partition is smaller than the original
-                    // number, and fill with empty splits if so. This is necessary so that both
-                    // sides of a join will have the same number of partitions & splits.
-                    splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
-                  }
-                  (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+            val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
+              // `commonPartValuesMap` should contain the part value since it's the super set.
+              val numSplits = commonPartValuesMap
+                  .get(InternalRowComparableWrapper(partValue, p.expressions))
+              assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
+                  "common partition values from Spark plan")
+
+              val newSplits = if (spjParams.replicatePartitions) {
+                // We need to also replicate partitions according to the other side of join
+                Seq.fill(numSplits.get)(splits)
+              } else {
+                // Not grouping by partition values: this could be the side with partially
+                // clustered distribution. Because of dynamic filtering, we'll need to check if
+                // the final number of splits of a partition is smaller than the original
+                // number, and fill with empty splits if so. This is necessary so that both
+                // sides of a join will have the same number of partitions & splits.
+                splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
               }
+              (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+            }

-              // Now fill missing partition keys with empty partitions
-              val partitionMapping = nestGroupedPartitions.toMap
-              finalPartitions = spjParams.commonPartitionValues.get.flatMap {
-                case (partValue, numSplits) =>
-                  // Use empty partition for those partition values that are not present.
-                  partitionMapping.getOrElse(
-                    InternalRowComparableWrapper(partValue, p.expressions),
-                    Seq.fill(numSplits)(Seq.empty))
-              }
-            } else {
-              // either `commonPartitionValues` is not defined, or it is defined but
-              // `applyPartialClustering` is false.
-              val partitionMapping = groupedPartitions.map { case (row, parts) =>
-                InternalRowComparableWrapper(row, p.expressions) -> parts
-              }.toMap
-
-              // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
-              // could exist duplicated partition values, as partition grouping is not done
-              // at the beginning and postponed to this method. It is important to use unique
-              // partition values here so that grouped partitions won't get duplicated.
-              finalPartitions = p.uniquePartitionValues.map { partValue =>
-                // Use empty partition for those partition values that are not present
+            // Now fill missing partition keys with empty partitions
+            val partitionMapping = nestGroupedPartitions.toMap
+            finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+              case (partValue, numSplits) =>
+                // Use empty partition for those partition values that are not present.
                 partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
-              }
+                  InternalRowComparableWrapper(partValue, p.expressions),
+                  Seq.fill(numSplits)(Seq.empty))
             }
           } else {
-            val partitionMapping = finalPartitions.map { parts =>
-              val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
-              InternalRowComparableWrapper(row, p.expressions) -> parts
+            // either `commonPartitionValues` is not defined, or it is defined but
+            // `applyPartialClustering` is false.
+            val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
+              InternalRowComparableWrapper(partValue, p.expressions) -> splits
             }.toMap
-            finalPartitions = p.partitionValues.map { partValue =>
+
+            // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
+            // could exist duplicated partition values, as partition grouping is not done
+            // at the beginning and postponed to this method. It is important to use unique
+            // partition values here so that grouped partitions won't get duplicated.
+            finalPartitions = p.uniquePartitionValues.map { partValue =>
               // Use empty partition for those partition values that are not present
               partitionMapping.getOrElse(
                 InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index f688d3514d9..94667fbd00c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     redact(result)
   }

-  def partitions: Seq[Seq[InputPartition]] =
-    groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_)))
+  def partitions: Seq[Seq[InputPartition]] = {
+    groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_)))
+  }

   /**
    * Shorthand for calling redact() without specifying redacting rules
@@ -94,8 +95,10 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     keyGroupedPartitioning match {
       case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
         groupedPartitions
-          .map { partitionValues =>
-            KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1))
+          .map { keyGroupedPartsInfo =>
+            val keyGroupedParts = keyGroupedPartsInfo.groupedParts
+            KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value),
+              keyGroupedPartsInfo.originalParts.map(_.partitionKey()))
           }
           .getOrElse(super.outputPartitioning)
       case _ =>
@@ -103,7 +106,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     }
   }

-  @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = {
+  @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = {
     // Early check if we actually need to materialize the input partitions.
     keyGroupedPartitioning match {
       case Some(_) => groupPartitions(inputPartitions)
@@ -117,24 +120,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
    * - all input partitions implement [[HasPartitionKey]]
    * - `keyGroupedPartitioning` is set
    *
-   * The result, if defined, is a list of tuples where the first element is a partition value,
-   * and the second element is a list of input partitions that share the same partition value.
+   * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of
+   * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits,
+   * sorted according to the partition keys in ascending order.
    *
    * A non-empty result means each partition is clustered on a single key and therefore eligible
    * for further optimizations to eliminate shuffling in some operations such as join and aggregate.
    */
-  def groupPartitions(
-      inputPartitions: Seq[InputPartition],
-      groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled ||
-        !conf.v2BucketingPartiallyClusteredDistributionEnabled):
-    Option[Seq[(InternalRow, Seq[InputPartition])]] = {
-
+  def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = {
     if (!SQLConf.get.v2BucketingEnabled) return None
+
     keyGroupedPartitioning.flatMap { expressions =>
       val results = inputPartitions.takeWhile {
         case _: HasPartitionKey => true
         case _ => false
-      }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p))
+      }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey]))

       if (results.length != inputPartitions.length || inputPartitions.isEmpty) {
         // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here.
@@ -143,32 +143,25 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
         // also sort the input partitions according to their partition key order. This ensures
         // a canonical order from both sides of a bucketed join, for example.
         val partitionDataTypes = expressions.map(_.dataType)
-        val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = {
+        val partitionOrdering: Ordering[(InternalRow, InputPartition)] = {
           RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
         }
-
-        val partitions = if (groupSplits) {
-          // Group the splits by their partition value
-          results
+        val sortedKeyToPartitions = results.sorted(partitionOrdering)
+        val groupedPartitions = sortedKeyToPartitions
             .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2))
             .groupBy(_._1)
             .toSeq
-            .map {
-              case (key, s) => (key.row, s.map(_._2))
-            }
-        } else {
-          // No splits grouping, each split will become a separate Spark partition
-          results.map(t => (t._1, Seq(t._2)))
-        }
+            .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) }

-        Some(partitions.sorted(partitionOrdering))
+        Some(KeyGroupedPartitionInfo(groupedPartitions, sortedKeyToPartitions.map(_._2)))
       }
     }
   }

   override def outputOrdering: Seq[SortOrder] = {
     // when multiple partitions are grouped together, ordering inside partitions is not preserved
-    val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1))
+    val partitioningPreservesOrdering = groupedPartitions
+        .forall(_.groupedParts.forall(_.parts.length <= 1))
     ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering)
   }

@@ -217,3 +210,19 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
     }
   }
 }
+
+/**
+ * A key-grouped Spark partition, which could consist of multiple input splits
+ *
+ * @param value the partition value shared by all the input splits
+ * @param parts the input splits that are grouped into a single Spark partition
+ */
+private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition])
+
+/**
+ * Information about key-grouped partitions, which contains a list of grouped partitions as well
+ * as the original input partitions before the grouping.
+ */
+private[v2] case class KeyGroupedPartitionInfo(
+    groupedParts: Seq[KeyGroupedPartition],
+    originalParts: Seq[HasPartitionKey])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 42c880e7c62..f8e6fd1d016 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -288,12 +288,12 @@ case class EnsureRequirements(
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, leftPartitioning, None))
-      case (Some(KeyGroupedPartitioning(clustering, _, _)), _) =>
+      case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, None, rightPartitioning))
-      case (_, Some(KeyGroupedPartitioning(clustering, _, _))) =>
+      case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
           .orElse(reorderJoinKeysRecursively(
@@ -483,7 +483,10 @@ case class EnsureRequirements(
                 s"'$joinType'. Skipping partially clustered distribution.")
             replicateRightSide = false
           } else {
-            val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
+            // In partially clustered distribution, we should use un-grouped partition values
+            val spec = if (replicateLeftSide) rightSpec else leftSpec
+            val partValues = spec.partitioning.originalPartitionValues
+
             val numExpectedPartitions = partValues
               .map(InternalRowComparableWrapper(_, partitionExprs))
               .groupBy(identity)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 750b96dc83d..509f1e6a1e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -301,7 +301,7 @@ object ShuffleExchangeExec {
           ascending = true,
           samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
       case SinglePartition => new ConstantPartitioner
-      case k @ KeyGroupedPartitioning(expressions, n, _) =>
+      case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
         val valueMap = k.uniquePartitionValues.zipWithIndex.map {
           case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
         }.toMap
@@ -332,7 +332,7 @@ object ShuffleExchangeExec {
         val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
         row => projection(row)
       case SinglePartition => identity
-      case KeyGroupedPartitioning(expressions, _, _) =>
+      case KeyGroupedPartitioning(expressions, _, _, _) =>
         row => bindReferences(expressions, outputAttributes).map(_.eval(row))
       case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
index f4317e63276..1a0efa7c4aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
@@ -51,9 +51,9 @@ abstract class DistributionAndOrderingSuiteBase
       plan: QueryPlan[T]): Partitioning = partitioning match {
     case HashPartitioning(exprs, numPartitions) =>
       HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
-    case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) =>
-      KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions,
-        partitionValues)
+    case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) =>
+      KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues,
+        originalPartValues)
     case PartitioningCollection(partitionings) =>
       PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
     case RangePartitioning(ordering, numPartitions) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 5b5e4021173..b22aba61aab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -131,7 +131,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     // Has exactly one partition.
     val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
     checkQueryPlan(df, distribution,
-      physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues))
+      physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues))
   }

   test("non-clustered distribution: no V2 catalog") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 3c9b92e5f66..3b0bb088a10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -1127,7 +1127,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     EnsureRequirements.apply(smjExec) match {
       case ShuffledHashJoinExec(_, _, _, _, _,
       DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
-      ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv),
+      ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _),
       DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
         assert(left.expressions == a1 :: Nil)
         assert(attrs == a1 :: Nil)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
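
The relationship between `partitionValues` and `originalPartitionValues` described in the commit message can be illustrated with a minimal, Spark-free sketch. It is not the code from this commit: `Split`, `GroupedPartition` and `PartitionValuesSketch` are hypothetical stand-ins, and partition keys are plain `Int`s rather than `InternalRow`s, so no Spark classes are needed. The sketch sorts the splits by key, keeps the sorted keys as the "original" values, and combines splits that share a key into one grouped partition.

```scala
// Hypothetical stand-in for an input split that reports a partition key.
// In Spark the key is an InternalRow; a plain Int keeps the sketch self-contained.
final case class Split(partitionKey: Int, name: String)

// Hypothetical, simplified analogue of a key-grouped partition:
// one partition value plus all splits that share it.
final case class GroupedPartition(value: Int, parts: Seq[Split])

object PartitionValuesSketch {
  // Returns (grouped partitions, original partition values).
  // The original values are sorted ascending and may contain duplicates;
  // the grouped partitions have one entry per distinct value.
  def group(splits: Seq[Split]): (Seq[GroupedPartition], Seq[Int]) = {
    val sorted = splits.sortBy(_.partitionKey) // stable sort by key
    val grouped = sorted
      .groupBy(_.partitionKey)
      .toSeq
      .sortBy(_._1) // groupBy gives no key-order guarantee, so sort explicitly
      .map { case (key, parts) => GroupedPartition(key, parts) }
    (grouped, sorted.map(_.partitionKey))
  }

  def main(args: Array[String]): Unit = {
    val splits = Seq(Split(2, "s3"), Split(0, "s0"), Split(2, "s2"), Split(1, "s1"))
    val (grouped, originalValues) = group(splits)

    // Distinct values after combining splits, one per grouped partition.
    println(grouped.map(_.value))
    // Un-grouped values, one per input split.
    println(originalValues)
    // Number of splits per value, derived from the un-grouped values.
    println(originalValues.groupBy(identity).map { case (k, v) => k -> v.size })
  }
}
```

With the four splits above, the grouped values are `[0, 1, 2]` and the original values are `[0, 1, 2, 2]`, matching the example in the updated `KeyGroupedPartitioning` doc comment. The per-value split count is the kind of information that can only be recovered from the un-grouped values, which is presumably why `EnsureRequirements` switches to `originalPartitionValues` in the partially clustered case.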