sunchao commented on code in PR #42306:
URL: https://github.com/apache/spark/pull/42306#discussion_r1287768581
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -344,7 +344,11 @@ case class KeyGroupedPartitioning(
       } else {
         // We'll need to find leaf attributes from the partition expressions first.
         val attributes = expressions.flatMap(_.collectLeaves())
-        attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+
+        // Support only when all cluster keys have an associated partition expression key
+        requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
+          // and if all partition expressions contain only a single partition key.
+          expressions.forall(_.collectLeaves().size == 1)
Review Comment:
hmm why this condition?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -344,7 +344,11 @@ case class KeyGroupedPartitioning(
       } else {
         // We'll need to find leaf attributes from the partition expressions first.
         val attributes = expressions.flatMap(_.collectLeaves())
-        attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+
+        // Support only when all cluster keys have an associated partition expression key
Review Comment:
should we consider the new flag here and still keep the old behavior if it
is not enabled?
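If so, a minimal sketch of the flag-guarded check (the `SQLConf` getter name below is an assumption patterned on the new conf in this PR, not code from it):
```scala
// Sketch only: fall back to the pre-existing check unless the new conf is on.
val attributes = expressions.flatMap(_.collectLeaves())
if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { // assumed getter name
  // New behavior from this PR: some required clustering key must map to a
  // partition expression leaf, and each expression must have exactly one leaf.
  requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
    expressions.forall(_.collectLeaves().size == 1)
} else {
  // Old behavior: every partition expression leaf appears in the clustering keys.
  attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
```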
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -701,41 +705,78 @@ case class KeyGroupedShuffleSpec(
       case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
         distribution.clustering.length == otherDistribution.clustering.length &&
           numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
-          partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
-            case (left, right) =>
-              InternalRowComparableWrapper(left, partitioning.expressions)
-                .equals(InternalRowComparableWrapper(right, partitioning.expressions))
-          }
+          isPartitioningCompatible(otherPartitioning)
       case ShuffleSpecCollection(specs) =>
         specs.exists(isCompatibleWith)
       case _ => false
     }
+
+  def isPartitioningCompatible(otherPartitioning: KeyGroupedPartitioning): Boolean = {
+    val clusterKeySize = keyPositions.size
+    partitioning.partitionValues.zip(otherPartitioning.partitionValues)
+      .forall {
+        case (left, right) =>
+          val leftTypes = partitioning.expressions.map(_.dataType)
Review Comment:
nit: perhaps add some comments here since it's not that clear. Also we can
extract:
```scala
val rightTypes = partitioning.expressions.map(_.dataType)
val rightVals = right.toSeq(rightTypes).take(clusterKeySize).toArray
val newRight = new GenericInternalRow(rightVals)
```
into a separate util method.
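If it helps, one hedged shape for that helper (the name `projectKeys` is illustrative, not from the PR; it assumes the surrounding imports already present in partitioning.scala):
```scala
// Hypothetical helper: project a partition-value row onto the first
// `clusterKeySize` partition expressions and wrap it for value comparison.
private def projectKeys(
    row: InternalRow,
    expressions: Seq[Expression],
    clusterKeySize: Int): InternalRowComparableWrapper = {
  val dataTypes = expressions.map(_.dataType)
  val projected = new GenericInternalRow(
    row.toSeq(dataTypes).take(clusterKeySize).toArray)
  InternalRowComparableWrapper(projected, expressions.take(clusterKeySize))
}
```
The `forall` body would then shrink to comparing `projectKeys(left, ...)` with `projectKeys(right, ...)`.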
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala:
##########
@@ -701,41 +705,78 @@ case class KeyGroupedShuffleSpec(
       case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
         distribution.clustering.length == otherDistribution.clustering.length &&
           numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
-          partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
-            case (left, right) =>
-              InternalRowComparableWrapper(left, partitioning.expressions)
-                .equals(InternalRowComparableWrapper(right, partitioning.expressions))
-          }
+          isPartitioningCompatible(otherPartitioning)
       case ShuffleSpecCollection(specs) =>
         specs.exists(isCompatibleWith)
       case _ => false
     }
+
+  def isPartitioningCompatible(otherPartitioning: KeyGroupedPartitioning): Boolean = {
+    val clusterKeySize = keyPositions.size
+    partitioning.partitionValues.zip(otherPartitioning.partitionValues)
+      .forall {
+        case (left, right) =>
+          val leftTypes = partitioning.expressions.map(_.dataType)
+          val leftVals = left.toSeq(leftTypes).take(clusterKeySize).toArray
+          val newLeft = new GenericInternalRow(leftVals)
+
+          val rightTypes = partitioning.expressions.map(_.dataType)
+          val rightVals = right.toSeq(rightTypes).take(clusterKeySize).toArray
+          val newRight = new GenericInternalRow(rightVals)
+
+          InternalRowComparableWrapper(newLeft, partitioning.expressions.take(clusterKeySize))
+            .equals(InternalRowComparableWrapper(
+              newRight, partitioning.expressions.take(clusterKeySize)))
+      }
+  }
+
   // Whether the partition keys (i.e., partition expressions) are compatible between this and the
   // `other` spec.
   def areKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = {
-    val expressions = partitioning.expressions
-    val otherExpressions = other.partitioning.expressions
-
-    expressions.length == otherExpressions.length && {
-      val otherKeyPositions = other.keyPositions
-      keyPositions.zip(otherKeyPositions).forall { case (left, right) =>
-        left.intersect(right).nonEmpty
+    partitionExpressionsCompatible(other) &&
+      KeyGroupedShuffleSpec.keyPositionsCompatible(
+        keyPositions, other.keyPositions
+      )
+  }
+
+  // Whether the partition keys (i.e., partition expressions) that also are in the set of
+  // cluster keys are compatible between this and the 'other' spec.
Review Comment:
nit: 'other' -> `other`
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala:
##########
@@ -144,8 +159,25 @@ case class BatchScanExec(
s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
"is enabled")
- val groupedPartitions =
groupPartitions(finalPartitions.map(_.head),
- groupSplits = true).get
+ // In the case where we replicate partitions, we have grouped
+ // the partitions by the join key if they differ
+ val groupByExpressions =
Review Comment:
Can we override the `keyGroupedPartitioning` method in this class, and wrap
the logic of handling join keys in that method? We can return a new
`KeyGroupedPartitioning` instance whose `expressions` and `partitionValues`
are "projected" on the join keys.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -1500,6 +1500,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
+  val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
+    buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
+      .doc("Whether to allow storage-partition join in the case where join keys are " +
+        "a subset of the partition keys of the source tables. At planning time, " +
+        "Spark will group the partitions by only those keys that are in the join keys. " +
+        "This is currently enabled only if spark.sql.sources.v2.bucketing.pushPartValues.enabled " +
+        "is also enabled."
+      )
+      .version("3.5.0")
Review Comment:
4.0.0
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -1500,6 +1500,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
+  val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
+    buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
+      .doc("Whether to allow storage-partition join in the case where join keys are " +
+        "a subset of the partition keys of the source tables. At planning time, " +
+        "Spark will group the partitions by only those keys that are in the join keys. " +
+        "This is currently enabled only if spark.sql.sources.v2.bucketing.pushPartValues.enabled " +
+        "is also enabled."
+      )
+      .version("3.5.0")
+      .booleanConf
+      .createWithDefault(true)
Review Comment:
we should default to false
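Taken together with the version comment above, the tail of this builder would presumably end up as (a sketch, not the merged code):
```scala
      .version("4.0.0")
      .booleanConf
      .createWithDefault(false)
```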
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala:
##########
@@ -523,23 +546,25 @@ case class EnsureRequirements(
         joinType == LeftAnti || joinType == LeftOuter
     }
-  // Populate the common partition values down to the scan nodes
-  private def populatePartitionValues(
+  // Populate the storage partition join params down to the scan nodes
+  private def populateStoragePartitionJoinParams(
       plan: SparkPlan,
       values: Seq[(InternalRow, Int)],
+      partitionGroupByPositions: Option[Seq[Boolean]],
       applyPartialClustering: Boolean,
       replicatePartitions: Boolean): SparkPlan = plan match {
     case scan: BatchScanExec =>
       scan.copy(
         spjParams = scan.spjParams.copy(
           commonPartitionValues = Some(values),
+          partitionGroupByPositions = partitionGroupByPositions,
           applyPartialClustering = applyPartialClustering,
           replicatePartitions = replicatePartitions
         )
       )
     case node =>
-      node.mapChildren(child => populatePartitionValues(
-        child, values, applyPartialClustering, replicatePartitions))
+      node.mapChildren(child => populateStoragePartitionJoinParams(
+        child, values, partitionGroupByPositions, applyPartialClustering, replicatePartitions))
Review Comment:
Instead of populating `partitionGroupByPositions`, can we populate
`StoragePartitionJoinParams.keyGroupedPartitioning`, which can be the subset
of expressions that participate in the join?
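A sketch of what that could look like at the call site (assuming `StoragePartitionJoinParams` has, or gains, a `keyGroupedPartitioning: Option[Seq[Expression]]` field; `joinKeyExpressions` is an illustrative name, not from the PR):
```scala
case scan: BatchScanExec =>
  // Hypothetical: push the projected join-key expressions down directly,
  // rather than a Seq[Boolean] of group-by positions.
  scan.copy(
    spjParams = scan.spjParams.copy(
      keyGroupedPartitioning = Some(joinKeyExpressions),
      commonPartitionValues = Some(values),
      applyPartialClustering = applyPartialClustering,
      replicatePartitions = replicatePartitions
    )
  )
```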