This is an automated email from the ASF dual-hosted git repository.

danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new edc45df2905 [HUDI-7277] Fix `hoodie.bulkinsert.shuffle.parallelism` 
not activated with non-partitioned table (#10532)
edc45df2905 is described below

commit edc45df290590a8bfade3991873a2a1ff316dbd3
Author: KnightChess <[email protected]>
AuthorDate: Sat Jan 20 10:33:02 2024 +0800

    [HUDI-7277] Fix `hoodie.bulkinsert.shuffle.parallelism` not activated with 
non-partitioned table (#10532)
    
    Signed-off-by: wulingqi <[email protected]>
---
 .../hudi/HoodieDatasetBulkInsertHelper.scala       | 29 ++++++------
 .../TestHoodieDatasetBulkInsertHelper.java         | 53 ++++++++++++++++++++++
 2 files changed, 67 insertions(+), 15 deletions(-)

diff --git 
a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieDatasetBulkInsertHelper.scala
 
b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieDatasetBulkInsertHelper.scala
index 75ec069946d..0214b0a1030 100644
--- 
a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieDatasetBulkInsertHelper.scala
+++ 
b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieDatasetBulkInsertHelper.scala
@@ -76,6 +76,9 @@ object HoodieDatasetBulkInsertHelper
 
     val updatedSchema = StructType(metaFields ++ schema.fields)
 
+    val targetParallelism =
+      deduceShuffleParallelism(df, config.getBulkInsertShuffleParallelism)
+
     val updatedDF = if (populateMetaFields) {
       val keyGeneratorClassName = 
config.getStringOrThrow(HoodieWriteConfig.KEYGENERATOR_CLASS_NAME,
         "Key-generator class name is required")
@@ -110,7 +113,7 @@ object HoodieDatasetBulkInsertHelper
         }
 
       val dedupedRdd = if (config.shouldCombineBeforeInsert) {
-        dedupeRows(prependedRdd, updatedSchema, config.getPreCombineField, 
SparkHoodieIndexFactory.isGlobalIndex(config))
+        dedupeRows(prependedRdd, updatedSchema, config.getPreCombineField, 
SparkHoodieIndexFactory.isGlobalIndex(config), targetParallelism)
       } else {
         prependedRdd
       }
@@ -127,9 +130,6 @@ object HoodieDatasetBulkInsertHelper
       HoodieUnsafeUtils.createDataFrameFrom(df.sparkSession, prependedQuery)
     }
 
-    val targetParallelism =
-      deduceShuffleParallelism(updatedDF, 
config.getBulkInsertShuffleParallelism)
-
     partitioner.repartitionRecords(updatedDF, targetParallelism)
   }
 
@@ -193,7 +193,7 @@ object HoodieDatasetBulkInsertHelper
     table.getContext.parallelize(writeStatuses.toList.asJava)
   }
 
-  private def dedupeRows(rdd: RDD[InternalRow], schema: StructType, 
preCombineFieldRef: String, isGlobalIndex: Boolean): RDD[InternalRow] = {
+  private def dedupeRows(rdd: RDD[InternalRow], schema: StructType, 
preCombineFieldRef: String, isGlobalIndex: Boolean, targetParallelism: Int): 
RDD[InternalRow] = {
     val recordKeyMetaFieldOrd = 
schema.fieldIndex(HoodieRecord.RECORD_KEY_METADATA_FIELD)
     val partitionPathMetaFieldOrd = 
schema.fieldIndex(HoodieRecord.PARTITION_PATH_METADATA_FIELD)
     // NOTE: Pre-combine field could be a nested field
@@ -212,16 +212,15 @@ object HoodieDatasetBulkInsertHelper
         //       since Spark might be providing us with a mutable copy 
(updated during the iteration)
         (rowKey, row.copy())
       }
-      .reduceByKey {
-        (oneRow, otherRow) =>
-          val onePreCombineVal = getNestedInternalRowValue(oneRow, 
preCombineFieldPath).asInstanceOf[Comparable[AnyRef]]
-          val otherPreCombineVal = getNestedInternalRowValue(otherRow, 
preCombineFieldPath).asInstanceOf[Comparable[AnyRef]]
-          if 
(onePreCombineVal.compareTo(otherPreCombineVal.asInstanceOf[AnyRef]) >= 0) {
-            oneRow
-          } else {
-            otherRow
-          }
-      }
+      .reduceByKey ((oneRow, otherRow) => {
+        val onePreCombineVal = getNestedInternalRowValue(oneRow, 
preCombineFieldPath).asInstanceOf[Comparable[AnyRef]]
+        val otherPreCombineVal = getNestedInternalRowValue(otherRow, 
preCombineFieldPath).asInstanceOf[Comparable[AnyRef]]
+        if 
(onePreCombineVal.compareTo(otherPreCombineVal.asInstanceOf[AnyRef]) >= 0) {
+          oneRow
+        } else {
+          otherRow
+        }
+      }, targetParallelism)
       .values
   }
 
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java
 
b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java
index 50ec641c182..bb24ee0e52a 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java
@@ -37,8 +37,11 @@ import org.apache.hudi.testutils.HoodieSparkClientTestBase;
 import org.apache.avro.Schema;
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.api.java.function.ReduceFunction;
+import org.apache.spark.scheduler.SparkListener;
+import org.apache.spark.scheduler.SparkListenerStageSubmitted;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.HoodieUnsafeUtils;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
 import org.apache.spark.sql.types.StructType;
@@ -59,6 +62,7 @@ import java.util.stream.Stream;
 import scala.Tuple2;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 
@@ -348,4 +352,53 @@ public class TestHoodieDatasetBulkInsertHelper extends 
HoodieSparkClientTestBase
   private ExpressionEncoder getEncoder(StructType schema) {
     return 
SparkAdapterSupport$.MODULE$.sparkAdapter().getCatalystExpressionUtils().getEncoder(schema);
   }
+
+  @Test
+  public void testBulkInsertParallelismParam() {
+    HoodieWriteConfig config = 
getConfigBuilder(schemaStr).withProps(getPropsAllSet("_row_key"))
+        .combineInput(true, true)
+        .withPreCombineField("ts").build();
+    int checkParallelism = 7;
+    config.setValue("hoodie.bulkinsert.shuffle.parallelism", 
String.valueOf(checkParallelism));
+    StageCheckBulkParallelismListener stageCheckBulkParallelismListener =
+        new 
StageCheckBulkParallelismListener("org.apache.hudi.HoodieDatasetBulkInsertHelper$.dedupeRows");
+    
sqlContext.sparkContext().addSparkListener(stageCheckBulkParallelismListener);
+    List<Row> inserts = DataSourceTestUtils.generateRandomRows(10);
+    Dataset<Row> dataset = sqlContext.createDataFrame(inserts, 
structType).repartition(3);
+    assertNotEquals(checkParallelism, 
HoodieUnsafeUtils.getNumPartitions(dataset));
+    assertNotEquals(checkParallelism, 
sqlContext.sparkContext().defaultParallelism());
+    Dataset<Row> result = 
HoodieDatasetBulkInsertHelper.prepareForBulkInsert(dataset, config,
+        new NonSortPartitionerWithRows(), "000001111");
+    // trigger job
+    result.count();
+    assertEquals(checkParallelism, 
stageCheckBulkParallelismListener.getParallelism());
+    
sqlContext.sparkContext().removeSparkListener(stageCheckBulkParallelismListener);
+  }
+
+  class StageCheckBulkParallelismListener extends SparkListener {
+
+    private boolean checkFlag = false;
+    private String checkMessage;
+    private int parallelism;
+
+    StageCheckBulkParallelismListener(String checkMessage) {
+      this.checkMessage = checkMessage;
+    }
+
+    @Override
+    public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
+      if (checkFlag) {
+        // dedup next stage is reduce task
+        this.parallelism = stageSubmitted.stageInfo().numTasks();
+        checkFlag = false;
+      }
+      if (stageSubmitted.stageInfo().details().contains(checkMessage)) {
+        checkFlag = true;
+      }
+    }
+
+    public int getParallelism() {
+      return parallelism;
+    }
+  }
 }

Reply via email to the mailing list.