ulysses-you commented on code in PR #6793:
URL: https://github.com/apache/kyuubi/pull/6793#discussion_r1830702568


##########
extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala:
##########
@@ -99,23 +102,49 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest
     withListener(sql(sqlString))(callback)
   }
 
-  def withListener(df: => DataFrame)(callback: DataWritingCommand => Unit): 
Unit = {
+  def withListener(df: => DataFrame)(
+      callback: DataWritingCommand => Unit,
+      failIfNotCallback: Boolean = true): Unit = {
+    val writes = Collections.synchronizedList(new 
java.util.ArrayList[DataWritingCommand]())
+
     val listener = new QueryExecutionListener {
       override def onFailure(f: String, qe: QueryExecution, e: Exception): 
Unit = {}
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: 
Long): Unit = {
-        qe.executedPlan match {
-          case write: DataWritingCommandExec => callback(write.cmd)
-          case _ =>
+        def collectWrite(plan: SparkPlan): Unit = {
+          plan match {
+            case write: DataWritingCommandExec =>
+              writes.add(write.cmd)
+            case a: AdaptiveSparkPlanExec => collectWrite(a.executedPlan)
+            case _ =>
+          }
         }
+        collectWrite(qe.executedPlan)
       }
     }
+    // Make sure the listener is registered after all previous events have 
been processed
+    sparkContext.listenerBus.waitUntilEmpty()
     spark.listenerManager.register(listener)
     try {
       df.collect()
       sparkContext.listenerBus.waitUntilEmpty()
     } finally {
       spark.listenerManager.unregister(listener)
     }
+    if (failIfNotCallback && writes.isEmpty) {
+      fail("No write command found")
+    }
+    writes.forEach(callback(_))
+  }
+
+  def collectRebalancePartitions(plan: LogicalPlan): Seq[RebalancePartitions] 
= {
+    def collect(p: LogicalPlan): Seq[RebalancePartitions] = {
+      p.flatMap {
+        case r: RebalancePartitions => Seq(r)
+        case s: LogicalQueryStage => collect(s.logicalPlan)

Review Comment:
   Is this a Spark issue? IMO, `LogicalQueryStage` should not exist in the query execution.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: notifications-unsubscr...@kyuubi.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscr...@kyuubi.apache.org
For additional commands, e-mail: notifications-h...@kyuubi.apache.org

Reply via email to