This is an automated email from the ASF dual-hosted git repository. sunchao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 11abc64a731d [SPARK-47094][SQL] SPJ : Dynamically rebalance number of buckets when they are not equal 11abc64a731d is described below commit 11abc64a731d0e75d837994183396e6da9c45310 Author: Szehon Ho <szehon.apa...@gmail.com> AuthorDate: Fri Apr 5 20:11:54 2024 -0700 [SPARK-47094][SQL] SPJ : Dynamically rebalance number of buckets when they are not equal ### What changes were proposed in this pull request? -- Allow SPJ between 'compatible' bucket functions -- Add a mechanism to define 'reducible' functions, one function whose output can be 'reduced' to another for all inputs. ### Why are the changes needed? -- SPJ currently applies only if the partition transform expressions on both sides are identical. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added new tests in KeyGroupedPartitioningSuite ### Was this patch authored or co-authored using generative AI tooling? No Closes #45267 from szehon-ho/spj-uneven-buckets. 
Authored-by: Szehon Ho <szehon.apa...@gmail.com> Signed-off-by: Chao Sun <c...@openai.com> --- .../sql/connector/catalog/functions/Reducer.java | 42 ++ .../catalog/functions/ReducibleFunction.java | 106 +++++ .../catalyst/expressions/TransformExpression.scala | 57 ++- .../sql/catalyst/plans/physical/partitioning.scala | 50 ++- .../org/apache/spark/sql/internal/SQLConf.scala | 15 + .../execution/datasources/v2/BatchScanExec.scala | 20 +- .../execution/exchange/EnsureRequirements.scala | 50 ++- .../connector/KeyGroupedPartitioningSuite.scala | 474 +++++++++++++++++++++ .../catalog/functions/transformFunctions.scala | 22 +- 9 files changed, 821 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java new file mode 100644 index 000000000000..561d66092d64 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.annotation.Evolving; + +/** + * A 'reducer' for output of user-defined functions. + * + * @see ReducibleFunction + * + * A user defined function f_source(x) is 'reducible' on another user_defined function + * f_target(x) if + * <ul> + * <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for + * all input x, or </li> + * <li> More generally, there exists reducer functions r1(x) and r2(x) such that + * r1(f_source(x)) = r2(f_target(x)) for all input x. </li> + * </ul> + * + * @param <I> reducer input type + * @param <O> reducer output type + * @since 4.0.0 + */ +@Evolving +public interface Reducer<I, O> { + O reduce(I arg); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java new file mode 100644 index 000000000000..ef1a14e50cda --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.annotation.Evolving; + +/** + * Base class for user-defined functions that can be 'reduced' on another function. + * + * A function f_source(x) is 'reducible' on another function f_target(x) if + * <ul> + * <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) + * for all input x, or </li> + * <li> More generally, there exists reducer functions r1(x) and r2(x) such that + * r1(f_source(x)) = r2(f_target(x)) for all input x. </li> + * </ul> + * <p> + * Examples: + * <ul> + * <li>Bucket functions where one side has reducer + * <ul> + * <li>f_source(x) = bucket(4, x)</li> + * <li>f_target(x) = bucket(2, x)</li> + * <li>r(x) = x % 2</li> + * </ul> + * + * <li>Bucket functions where both sides have reducer + * <ul> + * <li>f_source(x) = bucket(16, x)</li> + * <li>f_target(x) = bucket(12, x)</li> + * <li>r1(x) = x % 4</li> + * <li>r2(x) = x % 4</li> + * </ul> + * + * <li>Date functions + * <ul> + * <li>f_source(x) = days(x)</li> + * <li>f_target(x) = hours(x)</li> + * <li>r(x) = x / 24</li> + * </ul> + * </ul> + * @param <I> reducer function input type + * @param <O> reducer function output type + * @since 4.0.0 + */ +@Evolving +public interface ReducibleFunction<I, O> { + + /** + * This method is for the bucket function. + * + * If this bucket function is 'reducible' on another bucket function, + * return the {@link Reducer} function. 
+ * <p> + * For example, to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) + * <ul> + * <li>thisBucketFunction = bucket</li> + * <li>thisNumBuckets = 4</li> + * <li>otherBucketFunction = bucket</li> + * <li>otherNumBuckets = 2</li> + * </ul> + * + * @param thisNumBuckets parameter for this function + * @param otherBucketFunction the other parameterized function + * @param otherNumBuckets parameter for the other function + * @return a reduction function if it is reducible, null if not + */ + default Reducer<I, O> reducer( + int thisNumBuckets, + ReducibleFunction<?, ?> otherBucketFunction, + int otherNumBuckets) { + throw new UnsupportedOperationException(); + } + + /** + * This method is for all other functions. + * + * If this function is 'reducible' on another function, return the {@link Reducer} function. + * <p> + * Example of reducing f_source = days(x) on f_target = hours(x) + * <ul> + * <li>thisFunction = days</li> + * <li>otherFunction = hours</li> + * </ul> + * + * @param otherFunction the other function + * @return a reduction function if it is reducible, null if not. 
+ */ + default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 8412de554b71..d37c9d9f6452 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.connector.catalog.functions.BoundFunction +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction} import org.apache.spark.sql.types.DataType /** @@ -54,6 +54,61 @@ case class TransformExpression( false } + /** + * Whether this [[TransformExpression]]'s function is compatible with the `other` + * [[TransformExpression]]'s function. + * + * This is true if both are instances of [[ReducibleFunction]] and there exists a [[Reducer]] r(x) + * such that r(t1(x)) = t2(x), or r(t2(x)) = t1(x), for all input x. + * + * @param other the transform expression to compare to + * @return true if compatible, false if not + */ + def isCompatible(other: TransformExpression): Boolean = { + if (isSameFunction(other)) { + true + } else { + (function, other.function) match { + case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => + val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt) + val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt) + thisReducer.isDefined || otherReducer.isDefined + case _ => false + } + } + } + + /** + * Return a [[Reducer]] for this transform expression on another + * on the transform expression. 
+ * <p> + * A [[Reducer]] exists for a transform expression function if it is + * 'reducible' on the other expression function. + * <p> + * @return reducer function or None if not reducible on the other transform expression + */ + def reducers(other: TransformExpression): Option[Reducer[_, _]] = { + (function, other.function) match { + case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => + reducer(e1, numBucketsOpt, e2, other.numBucketsOpt) + case _ => None + } + } + + // Return a Reducer for a reducible function on another reducible function + private def reducer( + thisFunction: ReducibleFunction[_, _], + thisNumBucketsOpt: Option[Int], + otherFunction: ReducibleFunction[_, _], + otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { + val res = (thisNumBucketsOpt, otherNumBucketsOpt) match { + case (Some(numBuckets), Some(otherNumBuckets)) => + thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets) + case _ => thisFunction.reducer(otherFunction) + } + Option(res) + } + override def dataType: DataType = function.resultType() override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index c98a2a92a3ab..2364130f79e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.types.{DataType, IntegerType} @@ -833,10 +834,42 @@ case class KeyGroupedShuffleSpec( (left, right) match { case (_: LeafExpression, _: LeafExpression) => true case (left: TransformExpression, right: TransformExpression) => - left.isSameFunction(right) + if (SQLConf.get.v2BucketingPushPartValuesEnabled && + !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && + SQLConf.get.v2BucketingAllowCompatibleTransforms) { + left.isCompatible(right) + } else { + left.isSameFunction(right) + } case _ => false } + /** + * Return a set of [[Reducer]] for the partition expressions of this shuffle spec, + * on the partition expressions of another shuffle spec. + * <p> + * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is + * 'reducible' on the corresponding partition expression function of the other shuffle spec. + * <p> + * If a value is returned, there must be one [[Reducer]] per partition expression. + * A None value in the set indicates that the particular partition expression is not reducible + * on the corresponding expression on the other shuffle spec. + * <p> + * Returning none also indicates that none of the partition expressions can be reduced on the + * corresponding expression on the other shuffle spec. 
+ * + * @param other other key-grouped shuffle spec + */ + def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { + val results = partitioning.expressions.zip(other.partitioning.expressions).map { + case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) + case (_, _) => None + } + + // optimize to not return a value, if none of the partition expressions are reducible + if (results.forall(p => p.isEmpty)) None else Some(results) + } + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && // Only support partition expressions are AttributeReference for now partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) @@ -846,6 +879,21 @@ case class KeyGroupedShuffleSpec( } } +object KeyGroupedShuffleSpec { + def reducePartitionValue( + row: InternalRow, + expressions: Seq[Expression], + reducers: Seq[Option[Reducer[_, _]]]): + InternalRowComparableWrapper = { + val partitionVals = row.toSeq(expressions.map(_.dataType)) + val reducedRow = partitionVals.zip(reducers).map{ + case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) + case (v, _) => v + }.toArray + InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions) + } +} + case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec { override def isCompatibleWith(other: ShuffleSpec): Boolean = { specs.exists(_.isCompatibleWith(other)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9f07722528e8..73cb4fba8637 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1558,6 +1558,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS = + 
buildConf("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled") + .doc("Whether to allow storage-partition join in the case where the partition transforms " + + "are compatible but not identical. This config requires both " + + s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " + + s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + + "to be disabled." + ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -5323,6 +5335,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS) + def v2BucketingAllowCompatibleTransforms: Boolean = + getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 7cce59904018..f949dbf71a37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition} import 
org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.connector.read._ import org.apache.spark.util.ArrayImplicits._ @@ -164,6 +165,18 @@ case class BatchScanExec( (groupedParts, expressions) } + // Also re-group the partitions if we are reducing compatible partition expressions + val finalGroupedPartitions = spjParams.reducers match { + case Some(reducers) => + val result = groupedPartitions.groupBy { case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers) + }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq + val rowOrdering = RowOrdering.createNaturalAscendingOrdering( + partExpressions.map(_.dataType)) + result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + case _ => groupedPartitions + } + // When partially clustered, the input partitions are not grouped by partition // values. Here we'll need to check `commonPartitionValues` and decide how to group // and replicate splits within a partition. @@ -174,7 +187,7 @@ case class BatchScanExec( .get .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap - val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => + val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap .get(InternalRowComparableWrapper(partValue, partExpressions)) @@ -207,7 +220,7 @@ case class BatchScanExec( } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. 
- val partitionMapping = groupedPartitions.map { case (partValue, splits) => + val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) => InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap @@ -259,6 +272,7 @@ case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, + reducers: Option[Seq[Option[Reducer[_, _]]]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 2a7c1206bb41..a0f74ef6c3d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} @@ -505,11 +506,28 @@ case class EnsureRequirements( } } - // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, - applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, - applyPartialClustering, replicateRightSide) + // in case of compatible 
but not identical partition expressions, we apply 'reduce' + // transforms to group one side's partitions as well as the common partition values + val leftReducers = leftSpec.reducers(rightSpec) + val rightReducers = rightSpec.reducers(leftSpec) + + if (leftReducers.isDefined || rightReducers.isDefined) { + mergedPartValues = reduceCommonPartValues(mergedPartValues, + leftSpec.partitioning.expressions, + leftReducers) + mergedPartValues = reduceCommonPartValues(mergedPartValues, + rightSpec.partitioning.expressions, + rightReducers) + val rowOrdering = RowOrdering + .createNaturalAscendingOrdering(partitionExprs.map(_.dataType)) + mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + } + + // Now we need to push-down the common partition information to the scan in each child + newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions, + leftReducers, applyPartialClustering, replicateLeftSide) + newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions, + rightReducers, applyPartialClustering, replicateRightSide) } } @@ -527,11 +545,12 @@ case class EnsureRequirements( joinType == LeftAnti || joinType == LeftOuter } - // Populate the common partition values down to the scan nodes - private def populatePartitionValues( + // Populate the common partition information down to the scan nodes + private def populateCommonPartitionInfo( plan: SparkPlan, values: Seq[(InternalRow, Int)], joinKeyPositions: Option[Seq[Int]], + reducers: Option[Seq[Option[Reducer[_, _]]]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => @@ -539,13 +558,26 @@ case class EnsureRequirements( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), joinKeyPositions = joinKeyPositions, + reducers = reducers, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => - 
node.mapChildren(child => populatePartitionValues( - child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) + node.mapChildren(child => populateCommonPartitionInfo( + child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) + } + + private def reduceCommonPartValues( + commonPartValues: Seq[(InternalRow, Int)], + expressions: Seq[Expression], + reducers: Option[Seq[Option[Reducer[_, _]]]]) = { + reducers match { + case Some(reducers) => commonPartValues.groupBy { case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) + }.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq + case _ => commonPartValues + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 7fdc703007c2..ec275fe101fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -63,11 +63,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Collections.emptyMap[String, String] } private val table: String = "tbl" + private val columns: Array[Column] = Array( Column.create("id", IntegerType), Column.create("data", StringType), Column.create("ts", TimestampType)) + private val columns2: Array[Column] = Array( + Column.create("store_id", IntegerType), + Column.create("dept_id", IntegerType), + Column.create("data", StringType)) + test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { val partitions: Array[Transform] = Array(Expressions.years("ts")) @@ -1309,6 +1315,474 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-47094: Support compatible buckets") { + val table1 = "tab1e1" + val table2 = 
"table2" + + Seq( + ((2, 4), (4, 2)), + ((4, 2), (2, 4)), + ((2, 2), (4, 6)), + ((6, 2), (2, 2))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, "dept_id")) + + Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, columns2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + "(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), " + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> 
"false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. + partitions.length) + val expectedBuckets = Math.min(table1buckets1, table2buckets1) * + Math.min(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), 
+ Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp") + )) + } + } + } + } + + test("SPARK-47094: Support compatible buckets with common divisor") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq( + ((6, 4), (4, 6)), + ((6, 6), (4, 4)), + ((4, 4), (6, 6)), + ((4, 6), (6, 4))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, "dept_id")) + + Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, columns2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + "(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), " + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 
'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt + val expectedBuckets = gcd(table1buckets1, table2buckets1) * + gcd(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), + Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp") + )) + } + } + } + } + + test("SPARK-47094: Support compatible buckets with less join keys than partition keys") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq((2, 4), (4, 2), (2, 6), (6, 2)).foreach { + case (table1buckets, table2buckets) => + catalog.clearTables() + + val partition1 = Array(identity("data"), + bucket(table1buckets, "dept_id")) + val partition2 = 
Array(bucket(3, "store_id"), + bucket(table2buckets, "dept_id")) + + createTable(table1, columns2, partition1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 0, 'aa'), " + + "(1, 0, 'ab'), " + + "(2, 1, 'ac'), " + + "(3, 2, 'ad'), " + + "(4, 3, 'ae'), " + + "(5, 4, 'af'), " + + "(6, 5, 'ag'), " + + + // value without other side match + "(6, 6, 'xx')" + ) + + createTable(table2, columns2, partition2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(6, 0, '01'), " + + "(5, 1, '02'), " + // duplicate partition key + "(5, 1, '03'), " + + "(4, 2, '04'), " + + "(3, 3, '05'), " + + "(2, 4, '06'), " + + "(1, 5, '07'), " + + + // value without other side match + "(7, 7, '99')" + ) + + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t2.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + val expectedBuckets = Math.min(table1buckets, table2buckets) + + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 6, 0, 0, "aa", "01"), + Row(1, 6, 0, 0, "ab", "01"), + Row(2, 5, 1, 1, "ac", "02"), + Row(2, 5, 1, 1, "ac", "03"), + Row(3, 4, 2, 2, "ad", "04"), + Row(4, 3, 3, 3, "ae", "05"), + Row(5, 2, 4, 4, "af", "06"), + Row(6, 1, 5, 5, "ag", "07") + )) + } + } + } + + test("SPARK-47094: Compatible buckets does not support SPJ with " + + "push-down values or partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + + val partition1 = Array(bucket(4, "store_id"), + bucket(2, "dept_id")) + val partition2 = Array(bucket(2, "store_id"), + bucket(2, "dept_id")) + + createTable(table1, columns2, partition1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 0, 'aa'), " + + "(1, 1, 'bb'), " + + "(2, 2, 'cc')" + ) + + createTable(table2, columns2, partition2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 0, 'aa'), " + + "(1, 1, 'bb'), " + + "(2, 2, 'cc')" + ) + + Seq(true, false).foreach{ allowPushDown => + Seq(true, false).foreach{ partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> allowPushDown.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + val scans = 
collectScans(df.queryExecution.executedPlan).map(_.inputRDD. + partitions.length) + + (allowPushDown, partiallyClustered) match { + case (true, false) => + assert(shuffles.isEmpty, "SPJ should be triggered") + assert(scans == Seq(2, 2)) + case (_, _) => + assert(shuffles.nonEmpty, "SPJ should not be triggered") + assert(scans == Seq(3, 2)) + } + + checkAnswer(df, Seq( + Row(0, 0, 0, 0, "aa", "aa"), + Row(1, 1, 1, 1, "bb", "bb"), + Row(2, 2, 2, 2, "cc", "cc") + )) + } + } + } + } + test("SPARK-44647: test join key is the second cluster key") { val table1 = "tab1e1" val table2 = "table2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 61895d49c4a2..5cdb90090105 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -76,7 +76,7 @@ object UnboundBucketFunction extends UnboundFunction { override def name(): String = "bucket" } -object BucketFunction extends ScalarFunction[Int] { +object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) override def resultType(): DataType = IntegerType override def name(): String = "bucket" @@ -85,6 +85,26 @@ object BucketFunction extends ScalarFunction[Int] { override def produceResult(input: InternalRow): Int = { (input.getLong(1) % input.getInt(0)).toInt } + + override def reducer( + thisNumBuckets: Int, + otherFunc: ReducibleFunction[_, _], + otherNumBuckets: Int): Reducer[Int, Int] = { + + if (otherFunc == BucketFunction) { + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) + if (gcd != thisNumBuckets) { + return BucketReducer(thisNumBuckets, gcd) + } + } + null + } + + private def gcd(a: Int, b: Int): 
Int = BigInt(a).gcd(BigInt(b)).toInt +} + +case class BucketReducer(thisNumBuckets: Int, divisor: Int) extends Reducer[Int, Int] { + override def reduce(bucket: Int): Int = bucket % divisor } object UnboundStringSelfFunction extends UnboundFunction { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org