This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new ed0b0cf79f [spark] Fix the bucket join may produce wrong result after bucket rescaled (#5669)
ed0b0cf79f is described below

commit ed0b0cf79fe2303c9fd8bf6db8932112975e0fcd
Author: WenjunMin <[email protected]>
AuthorDate: Thu May 29 10:01:11 2025 +0800

    [spark] Fix the bucket join may produce wrong result after bucket rescaled (#5669)
---
 .../scala/org/apache/paimon/spark/PaimonScan.scala | 34 +++++++++++++---
 .../org/apache/paimon/spark/PaimonBaseScan.scala   | 21 ++++++----
 .../scala/org/apache/paimon/spark/PaimonScan.scala | 34 +++++++++++++---
 .../paimon/spark/sql/BucketedTableQueryTest.scala  | 47 +++++++++++++++++++++-
 4 files changed, 117 insertions(+), 19 deletions(-)

diff --git a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 4c62d58a81..ec589442e8 100644
--- a/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-3.3/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -60,11 +60,16 @@ case class PaimonScan(
           // so we only support one bucket key case.
           assert(bucketSpec.getNumBuckets > 0)
           assert(bucketSpec.getBucketKeys.size() == 1)
-          val bucketKey = bucketSpec.getBucketKeys.get(0)
-          if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
-            Some(Expressions.bucket(bucketSpec.getNumBuckets, bucketKey))
-          } else {
-            None
+          extractBucketNumber() match {
+            case Some(num) =>
+              val bucketKey = bucketSpec.getBucketKeys.get(0)
+              if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
+                Some(Expressions.bucket(num, bucketKey))
+              } else {
+                None
+              }
+
+            case _ => None
           }
         }
 
@@ -72,6 +77,24 @@ case class PaimonScan(
     }
   }
 
+  /**
+   * Extract the bucket number from the splits only if all splits have the same totalBuckets number.
+   */
+  private def extractBucketNumber(): Option[Int] = {
+    val splits = getOriginSplits
+    if (splits.exists(!_.isInstanceOf[DataSplit])) {
+      None
+    } else {
+      val deduplicated =
+        splits.map(s => Option(s.asInstanceOf[DataSplit].totalBuckets())).toSeq.distinct
+
+      deduplicated match {
+        case Seq(Some(num)) => Some(num)
+        case _ => None
+      }
+    }
+  }
+
   private def shouldDoBucketedScan: Boolean = {
    !bucketedScanDisabled && conf.v2BucketingEnabled && extractBucketTransform.isDefined
   }
@@ -120,6 +143,7 @@ case class PaimonScan(
       readBuilder.withFilter(partitionFilter.head)
       // set inputPartitions null to trigger to get the new splits.
       inputPartitions = null
+      inputSplits = null
     }
   }
 }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index 74741f5364..b0447c8830 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -51,6 +51,8 @@ abstract class PaimonBaseScan(
 
   protected var inputPartitions: Seq[PaimonInputPartition] = _
 
+  protected var inputSplits: Array[Split] = _
+
   override val coreOptions: CoreOptions = CoreOptions.fromMap(table.options())
 
   lazy val statistics: Optional[stats.Statistics] = table.statistics()
@@ -65,14 +67,17 @@ abstract class PaimonBaseScan(
 
   @VisibleForTesting
   def getOriginSplits: Array[Split] = {
-    readBuilder
-      .newScan()
-      .asInstanceOf[InnerTableScan]
-      .withMetricRegistry(paimonMetricsRegistry)
-      .plan()
-      .splits()
-      .asScala
-      .toArray
+    if (inputSplits == null) {
+      inputSplits = readBuilder
+        .newScan()
+        .asInstanceOf[InnerTableScan]
+        .withMetricRegistry(paimonMetricsRegistry)
+        .plan()
+        .splits()
+        .asScala
+        .toArray
+    }
+    inputSplits
   }
 
   final def lazyInputPartitions: Seq[PaimonInputPartition] = {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index 20c1cfffad..616c660255 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -62,11 +62,16 @@ case class PaimonScan(
           // so we only support one bucket key case.
           assert(bucketSpec.getNumBuckets > 0)
           assert(bucketSpec.getBucketKeys.size() == 1)
-          val bucketKey = bucketSpec.getBucketKeys.get(0)
-          if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
-            Some(Expressions.bucket(bucketSpec.getNumBuckets, bucketKey))
-          } else {
-            None
+          extractBucketNumber() match {
+            case Some(num) =>
+              val bucketKey = bucketSpec.getBucketKeys.get(0)
+              if (requiredSchema.exists(f => conf.resolver(f.name, bucketKey))) {
+                Some(Expressions.bucket(num, bucketKey))
+              } else {
+                None
+              }
+
+            case _ => None
           }
         }
 
@@ -74,6 +79,24 @@ case class PaimonScan(
     }
   }
 
+  /**
+   * Extract the bucket number from the splits only if all splits have the same totalBuckets number.
+   */
+  private def extractBucketNumber(): Option[Int] = {
+    val splits = getOriginSplits
+    if (splits.exists(!_.isInstanceOf[DataSplit])) {
+      None
+    } else {
+      val deduplicated =
+        splits.map(s => Option(s.asInstanceOf[DataSplit].totalBuckets())).toSeq.distinct
+
+      deduplicated match {
+        case Seq(Some(num)) => Some(num)
+        case _ => None
+      }
+    }
+  }
+
   private def shouldDoBucketedScan: Boolean = {
    !bucketedScanDisabled && conf.v2BucketingEnabled && extractBucketTransform.isDefined
   }
@@ -169,6 +192,7 @@ case class PaimonScan(
       readBuilder.withFilter(partitionFilter.toList.asJava)
       // set inputPartitions null to trigger to get the new splits.
       inputPartitions = null
+      inputSplits = null
     }
   }
 }
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
index 35931924c4..3f87f8ec6f 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
@@ -36,7 +36,9 @@ class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
     }
     withSparkSQLConf(
       "spark.sql.sources.v2.bucketing.enabled" -> "true",
-      "spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      "spark.sql.requireAllClusterKeysForCoPartition" -> "false",
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1"
+    ) {
       val df = spark.sql(query)
       checkAnswer(df, expectedResult.toSeq)
       assert(collect(df.queryExecution.executedPlan) {
@@ -50,6 +52,49 @@ class BucketedTableQueryTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
     }
   }
 
+  test("Query on a rescaled bucket table") {
+    assume(gteqSpark3_3)
+
+    withTable("t1", "t2") {
+
+      spark.sql(
+        "CREATE TABLE t1 (id INT, c STRING, dt STRING) partitioned by (dt) TBLPROPERTIES ('bucket'='2', 'bucket-key' = 'id')")
+      spark.sql(
+        "CREATE TABLE t2 (id INT, c STRING, dt STRING) partitioned by (dt) TBLPROPERTIES ('bucket'='3', 'bucket-key' = 'id')")
+      spark.sql("INSERT INTO t1 VALUES (1, 'x1', '20250101'), (3, 'x2', '20250101')")
+      spark.sql("INSERT INTO t2 VALUES (1, 'x1', '20250101'), (4, 'x2', '20250101')")
+      checkAnswerAndShuffleSorts(
+        "SELECT * FROM t1 JOIN t2 on t1.id = t2.id and t1.dt = '20250101' and t2.dt = '20250101'",
+        2,
+        2)
+      spark.sql("ALTER TABLE t1 SET TBLPROPERTIES ('bucket' = '3')")
+      checkAnswerAndShuffleSorts(
+        "SELECT * FROM t1 JOIN t2 on t1.id = t2.id and t1.dt = t2.dt ",
+        2,
+        2)
+    }
+
+    withTable("t1", "t2") {
+
+      spark.sql(
+        "CREATE TABLE t1 (id INT, c STRING, dt STRING) partitioned by (dt) TBLPROPERTIES ('bucket'='2', 'bucket-key' = 'id')")
+      spark.sql(
+        "CREATE TABLE t2 (id INT, c STRING, dt STRING) partitioned by (dt) TBLPROPERTIES ('bucket'='2', 'bucket-key' = 'id')")
+      // TODO if the input partition is not aligned by bucket value, the bucket join will not be applied.
+      spark.sql("INSERT INTO t1 VALUES (1, 'x1', '20250101'), (2, 'x2', '20250101')")
+      spark.sql("INSERT INTO t2 VALUES (1, 'x1', '20250101'), (5, 'x2', '20250101')")
+      checkAnswerAndShuffleSorts(
+        "SELECT * FROM t1 JOIN t2 on t1.id = t2.id and t1.dt = '20250101' and t2.dt = '20250101'",
+        0,
+        2)
+      spark.sql("ALTER TABLE t1 SET TBLPROPERTIES ('bucket' = '3')")
+      checkAnswerAndShuffleSorts(
+        "SELECT * FROM t1 JOIN t2 on t1.id = t2.id and t1.dt = t2.dt ",
+        0,
+        2)
+    }
+  }
+
   test("Query on a bucketed table - join - positive case") {
     assume(gteqSpark3_3)
 

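For readers of this thread: the core of the fix is that the bucket transform reported to Spark must use the bucket count carried by the splits themselves, and only when every split agrees on it, because after ALTER TABLE ... SET TBLPROPERTIES ('bucket' = ...) old files can still belong to the previous bucket layout. Below is a minimal, self-contained Scala sketch of that deduplication idea; FakeDataSplit is a made-up stand-in rather than Paimon's real DataSplit/Split API, so treat it as an illustration of the patch above, not the actual implementation.

object BucketNumberSketch {
  // Stand-in for a data split that may or may not know its table-wide bucket count.
  // java.lang.Integer is used so an unknown count can be represented as null.
  final case class FakeDataSplit(totalBuckets: Integer)

  // Return the bucket count only when every split reports the same non-null value;
  // any disagreement (e.g. files written before and after a bucket rescale) yields None,
  // which in the patch above disables the bucketed scan instead of producing wrong joins.
  def extractBucketNumber(splits: Seq[FakeDataSplit]): Option[Int] = {
    val deduplicated = splits.map(s => Option(s.totalBuckets).map(_.intValue)).distinct
    deduplicated match {
      case Seq(Some(num)) => Some(num)
      case _ => None
    }
  }

  def main(args: Array[String]): Unit = {
    println(extractBucketNumber(Seq(FakeDataSplit(2), FakeDataSplit(2))))    // Some(2)
    println(extractBucketNumber(Seq(FakeDataSplit(2), FakeDataSplit(3))))    // None
    println(extractBucketNumber(Seq(FakeDataSplit(2), FakeDataSplit(null)))) // None
  }
}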