This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new b2781c80a [spark] Support cross partition write (#3976)
b2781c80a is described below

commit b2781c80a8ce32d5c2d89505fbea4def12f3e624
Author: Zouxxyy <[email protected]>
AuthorDate: Sun Aug 18 22:27:55 2024 +0800

    [spark] Support cross partition write (#3976)
---
 .../paimon/crosspartition}/KeyPartOrRow.java       |   2 +-
 .../KeyPartPartitionKeyExtractor.java              |   2 +-
 .../flink/sink/index/GlobalDynamicBucketSink.java  |   1 +
 .../sink/index/GlobalIndexAssignerOperator.java    |   1 +
 .../flink/sink/index/IndexBootstrapOperator.java   |   1 +
 .../sink/index/KeyPartRowChannelComputer.java      |   1 +
 .../flink/sink/index/KeyWithRowSerializer.java     |   3 +-
 .../paimon/spark/commands/BucketProcessor.scala    | 133 ++++++++++++++--
 .../paimon/spark/commands/PaimonSparkWriter.scala  | 177 +++++++++++++++------
 .../paimon/spark/sql/DynamicBucketTableTest.scala  | 106 ++++++++++--
 10 files changed, 349 insertions(+), 78 deletions(-)
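
For context, a minimal usage sketch of the cross-partition (global dynamic bucket) write mode
this commit enables in Spark, mirroring the new cases in DynamicBucketTableTest below; table
and column names are illustrative, and a Spark session with the Paimon catalog configured is
assumed:

    // A primary-key table with 'bucket' = '-1' selects cross-partition (global dynamic bucket) mode.
    spark.sql("""
      CREATE TABLE T (pt INT, pk INT, v INT)
      TBLPROPERTIES ('primary-key' = 'pk', 'bucket' = '-1')
      PARTITIONED BY (pt)
      """)

    spark.sql("INSERT INTO T VALUES (1, 1, 1), (1, 2, 2)")
    // An upsert may now move a primary key to another partition; the copy in the old partition is retracted.
    spark.sql("INSERT INTO T VALUES (2, 1, 2)")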

diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyPartOrRow.java b/paimon-core/src/main/java/org/apache/paimon/crosspartition/KeyPartOrRow.java
similarity index 97%
rename from paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyPartOrRow.java
rename to paimon-core/src/main/java/org/apache/paimon/crosspartition/KeyPartOrRow.java
index 10c80c8d3..64d8661ea 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyPartOrRow.java
+++ b/paimon-core/src/main/java/org/apache/paimon/crosspartition/KeyPartOrRow.java
@@ -16,7 +16,7 @@
  * limitations under the License.
  */
 
-package org.apache.paimon.flink.sink.index;
+package org.apache.paimon.crosspartition;
 
 /** Type of record, key or full row. */
 public enum KeyPartOrRow {
diff --git a/paimon-core/src/main/java/org/apache/paimon/crosspartition/KeyPartPartitionKeyExtractor.java b/paimon-core/src/main/java/org/apache/paimon/crosspartition/KeyPartPartitionKeyExtractor.java
index ea86358ac..5abfbfffb 100644
--- a/paimon-core/src/main/java/org/apache/paimon/crosspartition/KeyPartPartitionKeyExtractor.java
+++ b/paimon-core/src/main/java/org/apache/paimon/crosspartition/KeyPartPartitionKeyExtractor.java
@@ -30,7 +30,7 @@ import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
-/** A {@link PartitionKeyExtractor} to {@link InternalRow} with only key and partiton fields. */
+/** A {@link PartitionKeyExtractor} to {@link InternalRow} with only key and partition fields. */
 public class KeyPartPartitionKeyExtractor implements PartitionKeyExtractor<InternalRow> {
 
     private final Projection partitionProjection;
diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/GlobalDynamicBucketSink.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/GlobalDynamicBucketSink.java
index 3ce562d89..f4da37072 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/GlobalDynamicBucketSink.java
+++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/GlobalDynamicBucketSink.java
@@ -20,6 +20,7 @@ package org.apache.paimon.flink.sink.index;
 
 import org.apache.paimon.CoreOptions;
 import org.apache.paimon.crosspartition.IndexBootstrap;
+import org.apache.paimon.crosspartition.KeyPartOrRow;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.flink.sink.Committable;
 import org.apache.paimon.flink.sink.DynamicBucketRowWriteOperator;
diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/GlobalIndexAssignerOperator.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/GlobalIndexAssignerOperator.java
index bec047762..7fee3f45f 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/GlobalIndexAssignerOperator.java
+++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/GlobalIndexAssignerOperator.java
@@ -19,6 +19,7 @@
 package org.apache.paimon.flink.sink.index;
 
 import org.apache.paimon.crosspartition.GlobalIndexAssigner;
+import org.apache.paimon.crosspartition.KeyPartOrRow;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.disk.IOManager;
 import org.apache.paimon.table.Table;
diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/IndexBootstrapOperator.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/IndexBootstrapOperator.java
index a75274be9..501e35dff 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/IndexBootstrapOperator.java
+++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/IndexBootstrapOperator.java
@@ -19,6 +19,7 @@
 package org.apache.paimon.flink.sink.index;
 
 import org.apache.paimon.crosspartition.IndexBootstrap;
+import org.apache.paimon.crosspartition.KeyPartOrRow;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.utils.SerializableFunction;
 
diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyPartRowChannelComputer.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyPartRowChannelComputer.java
index dedb07c95..adb234158 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyPartRowChannelComputer.java
+++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyPartRowChannelComputer.java
@@ -20,6 +20,7 @@ package org.apache.paimon.flink.sink.index;
 
 import org.apache.paimon.codegen.CodeGenUtils;
 import org.apache.paimon.codegen.Projection;
+import org.apache.paimon.crosspartition.KeyPartOrRow;
 import org.apache.paimon.data.BinaryRow;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.table.sink.ChannelComputer;
diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyWithRowSerializer.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyWithRowSerializer.java
index 876aa296c..fbb16f7da 100644
--- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyWithRowSerializer.java
+++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sink/index/KeyWithRowSerializer.java
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.flink.sink.index;
 
+import org.apache.paimon.crosspartition.KeyPartOrRow;
 import org.apache.paimon.flink.utils.InternalTypeSerializer;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -28,7 +29,7 @@ import org.apache.flink.core.memory.DataOutputView;
 import java.io.IOException;
 import java.util.Objects;
 
-import static org.apache.paimon.flink.sink.index.KeyPartOrRow.KEY_PART;
+import static org.apache.paimon.crosspartition.KeyPartOrRow.KEY_PART;
 
 /** A {@link InternalTypeSerializer} to serialize KeyPartOrRow with T. */
 public class KeyWithRowSerializer<T> extends InternalTypeSerializer<Tuple2<KeyPartOrRow, T>> {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/BucketProcessor.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/BucketProcessor.scala
index 4a3393497..f252b3bb1 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/BucketProcessor.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/BucketProcessor.scala
@@ -18,22 +18,29 @@
 
 package org.apache.paimon.spark.commands
 
-import org.apache.paimon.data.{InternalRow => PaimonInternalRow}
+import org.apache.paimon.crosspartition.{GlobalIndexAssigner, KeyPartOrRow}
+import org.apache.paimon.data.{BinaryRow, GenericRow, InternalRow => PaimonInternalRow, JoinedRow}
+import org.apache.paimon.disk.IOManager
 import org.apache.paimon.index.HashBucketAssigner
-import org.apache.paimon.spark.SparkRow
+import org.apache.paimon.spark.{SparkInternalRow, SparkRow}
+import org.apache.paimon.spark.SparkUtils.createIOManager
 import org.apache.paimon.spark.util.EncoderUtils
 import org.apache.paimon.table.FileStoreTable
 import org.apache.paimon.table.sink.RowPartitionKeyExtractor
+import org.apache.paimon.types.RowType
+import org.apache.paimon.utils.{CloseableIterator, SerializationUtils}
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.{InternalRow => SparkInternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
 import org.apache.spark.sql.types.StructType
 
 import java.util.UUID
 
+import scala.collection.mutable
+
 case class EncoderSerDeGroup(schema: StructType) {
 
   val encoder: ExpressionEncoder[Row] = EncoderUtils.encode(schema).resolveAndBind()
@@ -42,24 +49,24 @@ case class EncoderSerDeGroup(schema: StructType) {
 
   private val deserializer: Deserializer[Row] = encoder.createDeserializer()
 
-  def rowToInternal(row: Row): SparkInternalRow = {
+  def rowToInternal(row: Row): InternalRow = {
     serializer(row)
   }
 
-  def internalToRow(internalRow: SparkInternalRow): Row = {
+  def internalToRow(internalRow: InternalRow): Row = {
     deserializer(internalRow)
   }
 }
 
-sealed trait BucketProcessor {
-  def processPartition(rowIterator: Iterator[Row]): Iterator[Row]
+sealed trait BucketProcessor[In] {
+  def processPartition(rowIterator: Iterator[In]): Iterator[Row]
 }
 
 case class CommonBucketProcessor(
     table: FileStoreTable,
     bucketColIndex: Int,
     encoderGroup: EncoderSerDeGroup)
-  extends BucketProcessor {
+  extends BucketProcessor[Row] {
 
   def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
     val rowType = table.rowType()
@@ -89,7 +96,7 @@ case class DynamicBucketProcessor(
     numSparkPartitions: Int,
     numAssigners: Int,
     encoderGroup: EncoderSerDeGroup
-) extends BucketProcessor {
+) extends BucketProcessor[Row] {
 
   private val targetBucketRowNumber = fileStoreTable.coreOptions.dynamicBucketTargetRowNum
   private val rowType = fileStoreTable.rowType
@@ -123,3 +130,111 @@ case class DynamicBucketProcessor(
     }
   }
 }
+
+case class GlobalDynamicBucketProcessor(
+    fileStoreTable: FileStoreTable,
+    rowType: RowType,
+    numAssigners: Integer,
+    encoderGroup: EncoderSerDeGroup)
+  extends BucketProcessor[(KeyPartOrRow, Array[Byte])] {
+
+  override def processPartition(
+      rowIterator: Iterator[(KeyPartOrRow, Array[Byte])]): Iterator[Row] = {
+    new GlobalIndexAssignerIterator(
+      rowIterator,
+      fileStoreTable,
+      rowType,
+      numAssigners,
+      encoderGroup)
+  }
+}
+
+class GlobalIndexAssignerIterator(
+    rowIterator: Iterator[(KeyPartOrRow, Array[Byte])],
+    fileStoreTable: FileStoreTable,
+    rowType: RowType,
+    numAssigners: Integer,
+    encoderGroup: EncoderSerDeGroup)
+  extends Iterator[Row]
+  with AutoCloseable {
+
+  private val queue = mutable.Queue[Row]()
+
+  val ioManager: IOManager = createIOManager
+
+  var currentResult: Row = _
+
+  var advanced = false
+
+  val assigner: GlobalIndexAssigner = {
+    val _assigner = new GlobalIndexAssigner(fileStoreTable)
+    _assigner.open(
+      0,
+      ioManager,
+      numAssigners,
+      TaskContext.getPartitionId(),
+      (row, bucket) => {
+        val extraRow: GenericRow = new GenericRow(2)
+        extraRow.setField(0, row.getRowKind.toByteValue)
+        extraRow.setField(1, bucket)
+        queue.enqueue(
+          encoderGroup.internalToRow(
+            SparkInternalRow.fromPaimon(new JoinedRow(row, extraRow), rowType)))
+      }
+    )
+    rowIterator.foreach {
+      row =>
+        {
+          val internalRow = SerializationUtils.deserializeBinaryRow(row._2)
+          row._1 match {
+            case KeyPartOrRow.KEY_PART => _assigner.bootstrapKey(internalRow)
+            case KeyPartOrRow.ROW => _assigner.processInput(internalRow)
+            case _ =>
+              throw new UnsupportedOperationException(s"unknown kind ${row._1}")
+          }
+        }
+    }
+    _assigner
+  }
+
+  private val emitIterator: CloseableIterator[BinaryRow] = assigner.endBoostrapWithoutEmit(true)
+
+  override def hasNext: Boolean = {
+    advanceIfNeeded()
+    currentResult != null
+  }
+
+  override def next(): Row = {
+    if (!hasNext) {
+      throw new NoSuchElementException
+    }
+    advanced = false
+    currentResult
+  }
+
+  def advanceIfNeeded(): Unit = {
+    if (!advanced) {
+      advanced = true
+      currentResult = null
+      var stop = false
+      while (!stop) {
+        if (queue.nonEmpty) {
+          currentResult = queue.dequeue()
+          stop = true
+        } else if (emitIterator.hasNext) {
+          assigner.processInput(emitIterator.next())
+        } else {
+          stop = true
+        }
+      }
+    }
+  }
+
+  override def close(): Unit = {
+    emitIterator.close()
+    assigner.close()
+    if (ioManager != null) {
+      ioManager.close()
+    }
+  }
+}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala
index 7bdd0ce60..d0b2e86ea 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonSparkWriter.scala
@@ -19,20 +19,25 @@
 package org.apache.paimon.spark.commands
 
 import org.apache.paimon.CoreOptions.WRITE_ONLY
+import org.apache.paimon.codegen.CodeGenUtils
+import org.apache.paimon.crosspartition.{IndexBootstrap, KeyPartOrRow}
+import org.apache.paimon.data.serializer.InternalSerializers
 import org.apache.paimon.deletionvectors.DeletionVector
-import org.apache.paimon.deletionvectors.append.{AppendDeletionFileMaintainer, UnawareAppendDeletionFileMaintainer}
+import org.apache.paimon.deletionvectors.append.AppendDeletionFileMaintainer
 import org.apache.paimon.index.{BucketAssigner, SimpleHashBucketAssigner}
 import org.apache.paimon.io.{CompactIncrement, DataIncrement, IndexIncrement}
 import org.apache.paimon.manifest.{FileKind, IndexManifestEntry}
-import org.apache.paimon.spark.{SparkRow, SparkTableWrite}
+import org.apache.paimon.spark.{SparkRow, SparkTableWrite, SparkTypeUtils}
 import org.apache.paimon.spark.schema.SparkSystemColumns.{BUCKET_COL, ROW_KIND_COL}
 import org.apache.paimon.spark.util.SparkRowUtils
-import org.apache.paimon.table.{BucketMode, FileStoreTable}
-import org.apache.paimon.table.BucketMode.BUCKET_UNAWARE
-import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessage, CommitMessageImpl, CommitMessageSerializer, RowPartitionKeyExtractor}
+import org.apache.paimon.table.BucketMode._
+import org.apache.paimon.table.FileStoreTable
+import org.apache.paimon.table.sink._
+import org.apache.paimon.types.{RowKind, RowType}
 import org.apache.paimon.utils.SerializationUtils
 
 import org.apache.spark.{Partitioner, TaskContext}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
 
@@ -49,8 +54,6 @@ case class PaimonSparkWriter(table: FileStoreTable) {
 
   private lazy val bucketMode = table.bucketMode
 
-  private lazy val primaryKeyCols = tableSchema.trimmedPrimaryKeys().asScala
-
   @transient private lazy val serializer = new CommitMessageSerializer
 
   val writeBuilder: BatchWriteBuilder = table.newBatchWriteBuilder()
@@ -59,24 +62,31 @@ case class PaimonSparkWriter(table: FileStoreTable) {
     PaimonSparkWriter(table.copy(singletonMap(WRITE_ONLY.key(), "true")))
   }
 
-  def write(data: Dataset[Row]): Seq[CommitMessage] = {
+  def write(data: DataFrame): Seq[CommitMessage] = {
     val sparkSession = data.sparkSession
     import sparkSession.implicits._
 
-    val rowKindColIdx = SparkRowUtils.getFieldIndex(data.schema, ROW_KIND_COL)
-    assert(
-      rowKindColIdx == -1 || rowKindColIdx == data.schema.length - 1,
-      "Row kind column should be the last field.")
-
-    // append _bucket_ column as placeholder
-    val withInitBucketCol = data.withColumn(BUCKET_COL, lit(-1))
-    val bucketColIdx = withInitBucketCol.schema.size - 1
+    val withInitBucketCol = bucketMode match {
+      case CROSS_PARTITION if !data.schema.fieldNames.contains(ROW_KIND_COL) =>
+        data
+          .withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue))
+          .withColumn(BUCKET_COL, lit(-1))
+      case _ => data.withColumn(BUCKET_COL, lit(-1))
+    }
+    val rowKindColIdx = SparkRowUtils.getFieldIndex(withInitBucketCol.schema, ROW_KIND_COL)
+    val bucketColIdx = SparkRowUtils.getFieldIndex(withInitBucketCol.schema, BUCKET_COL)
     val encoderGroupWithBucketCol = EncoderSerDeGroup(withInitBucketCol.schema)
 
     def newWrite(): SparkTableWrite = new SparkTableWrite(writeBuilder, rowType, rowKindColIdx)
 
-    def writeWithoutBucket(): Dataset[Array[Byte]] = {
-      data.mapPartitions {
+    def sparkParallelism = {
+      val defaultParallelism = sparkSession.sparkContext.defaultParallelism
+      val numShufflePartitions = sparkSession.sessionState.conf.numShufflePartitions
+      Math.max(defaultParallelism, numShufflePartitions)
+    }
+
+    def writeWithoutBucket(dataFrame: DataFrame): Dataset[Array[Byte]] = {
+      dataFrame.mapPartitions {
         iter =>
           {
             val write = newWrite()
@@ -90,12 +100,8 @@ case class PaimonSparkWriter(table: FileStoreTable) {
       }
     }
 
-    def writeWithBucketProcessor(
-        dataFrame: DataFrame,
-        processor: BucketProcessor): Dataset[Array[Byte]] = {
-      val repartitioned = repartitionByPartitionsAndBucket(
-        dataFrame.mapPartitions(processor.processPartition)(encoderGroupWithBucketCol.encoder))
-      repartitioned.mapPartitions {
+    def writeWithBucket(dataFrame: DataFrame): Dataset[Array[Byte]] = {
+      dataFrame.mapPartitions {
         iter =>
           {
             val write = newWrite()
@@ -109,6 +115,16 @@ case class PaimonSparkWriter(table: FileStoreTable) {
       }
     }
 
+    def writeWithBucketProcessor(
+        dataFrame: DataFrame,
+        processor: BucketProcessor[Row]): Dataset[Array[Byte]] = {
+      val repartitioned = repartitionByPartitionsAndBucket(
+        dataFrame
+          .mapPartitions(processor.processPartition)(encoderGroupWithBucketCol.encoder)
+          .toDF())
+      writeWithBucket(repartitioned)
+    }
+
     def writeWithBucketAssigner(
         dataFrame: DataFrame,
         funcFactory: () => Row => Int): Dataset[Array[Byte]] = {
@@ -128,25 +144,44 @@ case class PaimonSparkWriter(table: FileStoreTable) {
     }
 
     val written: Dataset[Array[Byte]] = bucketMode match {
-      case BucketMode.HASH_DYNAMIC =>
-        assert(primaryKeyCols.nonEmpty, "Only primary-key table can support dynamic bucket.")
-
-        val numParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
+      case CROSS_PARTITION =>
+        // Topology: input -> bootstrap -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket
+        val rowType = SparkTypeUtils.toPaimonType(withInitBucketCol.schema).asInstanceOf[RowType]
+        val assignerParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
           .map(_.toInt)
-          .getOrElse {
-            val defaultParallelism = sparkSession.sparkContext.defaultParallelism
-            val numShufflePartitions = sparkSession.sessionState.conf.numShufflePartitions
-            Math.max(defaultParallelism, numShufflePartitions)
-          }
+          .getOrElse(sparkParallelism)
+        val bootstrapped = bootstrapAndRepartitionByKeyHash(
+          withInitBucketCol,
+          assignerParallelism,
+          rowKindColIdx,
+          rowType)
+
+        val globalDynamicBucketProcessor =
+          GlobalDynamicBucketProcessor(
+            table,
+            rowType,
+            assignerParallelism,
+            encoderGroupWithBucketCol)
+        val repartitioned = repartitionByPartitionsAndBucket(
+          sparkSession.createDataFrame(
+            bootstrapped.mapPartitions(globalDynamicBucketProcessor.processPartition),
+            withInitBucketCol.schema))
+
+        writeWithBucket(repartitioned)
+
+      case HASH_DYNAMIC =>
+        val assignerParallelism = Option(table.coreOptions.dynamicBucketAssignerParallelism)
+          .map(_.toInt)
+          .getOrElse(sparkParallelism)
         val numAssigners = Option(table.coreOptions.dynamicBucketInitialBuckets)
-          .map(initialBuckets => Math.min(initialBuckets.toInt, numParallelism))
-          .getOrElse(numParallelism)
+          .map(initialBuckets => Math.min(initialBuckets.toInt, assignerParallelism))
+          .getOrElse(assignerParallelism)
 
         def partitionByKey(): DataFrame = {
           repartitionByKeyPartitionHash(
             sparkSession,
             withInitBucketCol,
-            numParallelism,
+            assignerParallelism,
             numAssigners)
         }
 
@@ -177,18 +212,22 @@ case class PaimonSparkWriter(table: FileStoreTable) {
             DynamicBucketProcessor(
               table,
               bucketColIdx,
-              numParallelism,
+              assignerParallelism,
               numAssigners,
-              encoderGroupWithBucketCol))
+              encoderGroupWithBucketCol)
+          )
         }
+
       case BUCKET_UNAWARE =>
         // Topology: input ->
-        writeWithoutBucket()
-      case BucketMode.HASH_FIXED =>
+        writeWithoutBucket(data)
+
+      case HASH_FIXED =>
         // Topology: input -> bucket-assigner -> shuffle by partition & bucket
         writeWithBucketProcessor(
           withInitBucketCol,
           CommonBucketProcessor(table, bucketColIdx, encoderGroupWithBucketCol))
+
       case _ =>
         throw new UnsupportedOperationException(s"Spark doesn't support $bucketMode mode.")
     }
@@ -286,13 +325,53 @@ case class PaimonSparkWriter(table: FileStoreTable) {
     }
   }
 
-  /** Compute bucket id in dynamic bucket mode. */
+  /** Bootstrap and repartition for cross partition mode. */
+  private def bootstrapAndRepartitionByKeyHash(
+      data: DataFrame,
+      parallelism: Int,
+      rowKindColIdx: Int,
+      rowType: RowType): RDD[(KeyPartOrRow, Array[Byte])] = {
+    val numSparkPartitions = data.rdd.getNumPartitions
+    val primaryKeys = table.schema().primaryKeys()
+    val bootstrapType = IndexBootstrap.bootstrapType(table.schema())
+    data.rdd
+      .mapPartitions {
+        iter =>
+          {
+            val sparkPartitionId = TaskContext.getPartitionId()
+            val keyPartProject = CodeGenUtils.newProjection(bootstrapType, primaryKeys)
+            val rowProject = CodeGenUtils.newProjection(rowType, primaryKeys)
+            val bootstrapSer = InternalSerializers.create(bootstrapType)
+            val rowSer = InternalSerializers.create(rowType)
+            new IndexBootstrap(table)
+              .bootstrap(numSparkPartitions, sparkPartitionId)
+              .toCloseableIterator
+              .asScala
+              .map(
+                row => {
+                  val bytes: Array[Byte] =
+                    SerializationUtils.serializeBinaryRow(bootstrapSer.toBinaryRow(row))
+                  (Math.abs(keyPartProject(row).hashCode()), (KeyPartOrRow.KEY_PART, bytes))
+                }) ++ iter.map(
+              r => {
+                val sparkRow =
+                  new SparkRow(rowType, r, SparkRowUtils.getRowKind(r, rowKindColIdx))
+                val bytes: Array[Byte] =
+                  SerializationUtils.serializeBinaryRow(rowSer.toBinaryRow(sparkRow))
+                (Math.abs(rowProject(sparkRow).hashCode()), (KeyPartOrRow.ROW, bytes))
+              })
+          }
+      }
+      .partitionBy(ModPartitioner(parallelism))
+      .map(_._2)
+  }
+
+  /** Repartition for dynamic bucket mode. */
   private def repartitionByKeyPartitionHash(
       sparkSession: SparkSession,
       data: DataFrame,
-      numParallelism: Int,
+      parallelism: Int,
       numAssigners: Int): DataFrame = {
-
     sparkSession.createDataFrame(
       data.rdd
         .mapPartitions(
@@ -305,21 +384,19 @@ case class PaimonSparkWriter(table: FileStoreTable) {
                 val keyHash = rowPartitionKeyExtractor.trimmedPrimaryKey(sparkRow).hashCode
                 (
                   BucketAssigner
-                    .computeHashKey(partitionHash, keyHash, numParallelism, numAssigners),
+                    .computeHashKey(partitionHash, keyHash, parallelism, numAssigners),
                   row)
               })
-          },
-          preservesPartitioning = true
-        )
-        .partitionBy(ModPartitioner(numParallelism))
+          })
+        .partitionBy(ModPartitioner(parallelism))
         .map(_._2),
       data.schema
     )
   }
 
-  private def repartitionByPartitionsAndBucket(ds: Dataset[Row]): Dataset[Row] = {
+  private def repartitionByPartitionsAndBucket(df: DataFrame): DataFrame = {
     val partitionCols = tableSchema.partitionKeys().asScala.map(col)
-    ds.toDF().repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
+    df.repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
   }
 
   private def deserializeCommitMessage(
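
As a reading aid, a simplified outline of the new CROSS_PARTITION branch in PaimonSparkWriter.write
above (a sketch only, not compilable on its own; names are taken from the diff, schema handling and
error paths are omitted):

    // 1. Bootstrap the existing primary-key index and union it with the input rows,
    //    then shuffle everything by primary-key hash so each key lands on a single assigner task.
    val bootstrapped = bootstrapAndRepartitionByKeyHash(withInitBucketCol, assignerParallelism, rowKindColIdx, rowType)

    // 2. GlobalDynamicBucketProcessor drives a GlobalIndexAssigner, which assigns buckets and,
    //    when a key has moved to a new partition, also emits the rows needed to retract the old copy.
    val assigned = bootstrapped.mapPartitions(
      GlobalDynamicBucketProcessor(table, rowType, assignerParallelism, encoderGroupWithBucketCol).processPartition)

    // 3. Shuffle by partition & bucket, then write with the regular bucketed writer.
    writeWithBucket(
      repartitionByPartitionsAndBucket(
        sparkSession.createDataFrame(assigned, withInitBucketCol.schema)))
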
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DynamicBucketTableTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DynamicBucketTableTest.scala
index 0ba51ff28..023ab1664 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DynamicBucketTableTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DynamicBucketTableTest.scala
@@ -86,22 +86,96 @@ class DynamicBucketTableTest extends PaimonSparkTestBase {
       Row(0) :: Row(1) :: Row(2) :: Nil)
   }
 
-  test(s"Paimon dynamic bucket table: write with global dynamic bucket") {
-    spark.sql(s"""
-                 |CREATE TABLE T (
-                 |  pk STRING,
-                 |  v STRING,
-                 |  pt STRING)
-                 |TBLPROPERTIES (
-                 |  'primary-key' = 'pk',
-                 |  'bucket' = '-1'
-                 |)
-                 |PARTITIONED BY (pt)
-                 |""".stripMargin)
+  test(s"Paimon cross partition table: write with partition change") {
+    sql(s"""
+           |CREATE TABLE T (
+           |  pt INT,
+           |  pk INT,
+           |  v INT)
+           |TBLPROPERTIES (
+           |  'primary-key' = 'pk',
+           |  'bucket' = '-1',
+           |  'dynamic-bucket.target-row-num'='3',
+           |  'dynamic-bucket.assigner-parallelism'='1'
+           |)
+           |PARTITIONED BY (pt)
+           |""".stripMargin)
+
+    sql("INSERT INTO T VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5)")
+    checkAnswer(
+      sql("SELECT * FROM T ORDER BY pk"),
+      Seq(Row(1, 1, 1), Row(1, 2, 2), Row(1, 3, 3), Row(1, 4, 4), Row(1, 5, 5)))
+
+    sql("INSERT INTO T VALUES (1, 3, 33), (1, 1, 11)")
+    checkAnswer(
+      sql("SELECT * FROM T ORDER BY pk"),
+      Seq(Row(1, 1, 11), Row(1, 2, 2), Row(1, 3, 33), Row(1, 4, 4), Row(1, 5, 5)))
+
+    checkAnswer(sql("SELECT DISTINCT bucket FROM `T$FILES`"), Seq(Row(0), Row(1)))
+
+    // change partition
+    sql("INSERT INTO T VALUES (2, 1, 2), (2, 2, 3)")
+    checkAnswer(
+      sql("SELECT * FROM T ORDER BY pk"),
+      Seq(Row(2, 1, 2), Row(2, 2, 3), Row(1, 3, 33), Row(1, 4, 4), Row(1, 5, 5)))
+  }
+
+  test(s"Paimon cross partition table: write with delete") {
+    sql(s"""
+           |CREATE TABLE T (
+           |  pt INT,
+           |  pk INT,
+           |  v INT)
+           |TBLPROPERTIES (
+           |  'primary-key' = 'pk',
+           |  'bucket' = '-1',
+           |  'dynamic-bucket.target-row-num'='3'
+           |)
+           |PARTITIONED BY (pt)
+           |""".stripMargin)
+
+    sql("INSERT INTO T VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5)")
+    checkAnswer(
+      sql("SELECT * FROM T ORDER BY pk"),
+      Seq(Row(1, 1, 1), Row(1, 2, 2), Row(1, 3, 3), Row(1, 4, 4), Row(1, 5, 5)))
+
+    sql("DELETE FROM T WHERE pk = 1")
+    checkAnswer(
+      sql("SELECT * FROM T ORDER BY pk"),
+      Seq(Row(1, 2, 2), Row(1, 3, 3), Row(1, 4, 4), Row(1, 5, 5)))
+
+    // change partition
+    sql("INSERT INTO T VALUES (2, 1, 2), (2, 2, 3)")
+    checkAnswer(
+      sql("SELECT * FROM T ORDER BY pk"),
+      Seq(Row(2, 1, 2), Row(2, 2, 3), Row(1, 3, 3), Row(1, 4, 4), Row(1, 5, 5)))
+
+    sql("DELETE FROM T WHERE pk = 2")
+    checkAnswer(
+      sql("SELECT * FROM T ORDER BY pk"),
+      Seq(Row(2, 1, 2), Row(1, 3, 3), Row(1, 4, 4), Row(1, 5, 5)))
+  }
+
+  test(s"Paimon cross partition table: user define assigner parallelism") {
+    sql(s"""
+           |CREATE TABLE T (
+           |  pt INT,
+           |  pk INT,
+           |  v INT)
+           |TBLPROPERTIES (
+           |  'primary-key' = 'pk',
+           |  'bucket' = '-1',
+           |  'dynamic-bucket.target-row-num'='3',
+           |  'dynamic-bucket.assigner-parallelism'='3'
+           |)
+           |PARTITIONED BY (pt)
+           |""".stripMargin)
+
+    sql("INSERT INTO T VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5)")
+    checkAnswer(
+      sql("SELECT * FROM T ORDER BY pk"),
+      Seq(Row(1, 1, 1), Row(1, 2, 2), Row(1, 3, 3), Row(1, 4, 4), Row(1, 5, 5)))
 
-    val error = intercept[UnsupportedOperationException] {
-      spark.sql("INSERT INTO T VALUES ('1', 'a', 'p')")
-    }.getMessage
-    assert(error.contains("Spark doesn't support CROSS_PARTITION mode"))
+    checkAnswer(sql("SELECT DISTINCT bucket FROM `T$FILES`"), Seq(Row(0), Row(1), Row(2)))
   }
 }
