This is an automated email from the ASF dual-hosted git repository.

vbalaji pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c7ac6112da [MINOR] Fix UT error in HUDI-6941 with stage task numbers (#10554)
4c7ac6112da is described below

commit 4c7ac6112daab349ebcdd1fbb2216d9d1138ca14
Author: xuzifu666 <[email protected]>
AuthorDate: Sat Jan 27 11:59:53 2024 +0800

    [MINOR] Fix UT error in HUDI-6941 with stage task numbers (#10554)
    
    * [MINOR] Fix UT error in HUDI-6941 with stage task numbers
---
 .../apache/spark/sql/hudi/TestInsertTable.scala    | 30 ++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
index eb6e20ee931..21369ea34e0 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hudi
 
 import org.apache.hudi.DataSourceWriteOptions._
+import org.apache.hudi.client.common.HoodieSparkEngineContext
 import org.apache.hudi.common.model.HoodieRecord.HoodieRecordType
 import org.apache.hudi.common.model.{HoodieRecord, WriteOperationType}
 import org.apache.hudi.common.table.timeline.HoodieInstant
@@ -28,12 +29,14 @@ import 
org.apache.hudi.exception.{HoodieDuplicateKeyException, HoodieException}
 import org.apache.hudi.execution.bulkinsert.BulkInsertSortMode
 import org.apache.hudi.index.HoodieIndex.IndexType
 import org.apache.hudi.{DataSourceWriteOptions, HoodieCLIUtils, 
HoodieSparkUtils}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.hudi.HoodieSparkSqlTestBase.getLastCommitMetadata
 import 
org.apache.spark.sql.hudi.command.HoodieSparkValidateDuplicateKeyRecordMerger
 import org.junit.jupiter.api.Assertions.assertEquals
 
 import java.io.File
+import java.util.concurrent.CountDownLatch
 
 class TestInsertTable extends HoodieSparkSqlTestBase {
 
@@ -2081,10 +2084,27 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
     })
   }
 
+  var listenerCallCount: Int = 0 // incremented by StageParallelismListener per matching stage; reset to 0 by each test
+  var countDownLatch: CountDownLatch = _ // re-created by each test before use; counted down by StageParallelismListener
+
+  // SparkListener that asserts each submitted stage whose name contains `stageName` has exactly one task.
+  class StageParallelismListener(var stageName: String) extends SparkListener {
+    override def onStageSubmitted(stageSubmitted: 
SparkListenerStageSubmitted): Unit = {
+      if (stageSubmitted.stageInfo.name.contains(stageName)) { // ignore stages whose name does not match
+        assertResult(1)(stageSubmitted.stageInfo.numTasks) // NOTE(review): if this throws, the two lines below are skipped; confirm the test's await cannot hang
+        listenerCallCount = listenerCallCount + 1 // records that a matching stage was observed
+        countDownLatch.countDown // releases the test waiting in countDownLatch.await
+      }
+    }
+
+  }
+
   test("Test multiple partition fields pruning") {
 
     withRecordType()(withTempDir { tmp =>
       val targetTable = generateTableName
+      countDownLatch = new CountDownLatch(1)
+      listenerCallCount = 0
       spark.sql(
         s"""
            |create table ${targetTable} (
@@ -2114,6 +2134,8 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
            |union
            |select '1' as id, 'aa' as name, 123 as dt, '2023-10-12' as `day`, 
12 as `hour`
            |""".stripMargin)
+      val stageClassName = classOf[HoodieSparkEngineContext].getSimpleName
+      spark.sparkContext.addSparkListener(new 
StageParallelismListener(stageName = stageClassName))
       val df = spark.sql(
         s"""
            |select * from ${targetTable} where day='2023-10-12' and hour=11
@@ -2124,12 +2146,16 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
         rddHead = rddHead.firstParent
       }
       assertResult(1)(rddHead.partitions.size)
+      countDownLatch.await
+      assert(listenerCallCount >= 1)
     })
   }
 
   test("Test single partiton field pruning") {
 
     withRecordType()(withTempDir { tmp =>
+      countDownLatch = new CountDownLatch(1)
+      listenerCallCount = 0
       val targetTable = generateTableName
       spark.sql(
         s"""
@@ -2160,6 +2186,8 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
            |union
            |select '1' as id, 'aa' as name, 123 as dt, '2023-10-12' as `day`, 
12 as `hour`
            |""".stripMargin)
+      val stageClassName = classOf[HoodieSparkEngineContext].getSimpleName
+      spark.sparkContext.addSparkListener(new 
StageParallelismListener(stageName = stageClassName))
       val df = spark.sql(
         s"""
            |select * from ${targetTable} where day='2023-10-12' and hour=11
@@ -2170,6 +2198,8 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
         rddHead = rddHead.firstParent
       }
       assertResult(1)(rddHead.partitions.size)
+      countDownLatch.await
+      assert(listenerCallCount >= 1)
     })
   }
 

Reply via email to