This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new bb9b7ef034 [GLUTEN-7079][VL] Fix metrics for InputIteratorTransformer of broadcast exchange (#7167)
bb9b7ef034 is described below

commit bb9b7ef034f87d2a532f688b31e606398f66578a
Author: Tengfei Huang <[email protected]>
AuthorDate: Thu Nov 7 08:22:02 2024 +0800

    [GLUTEN-7079][VL] Fix metrics for InputIteratorTransformer of broadcast exchange (#7167)
    
    Closes #7079
    Closes #4672
---
 .../backendsapi/clickhouse/CHMetricsApi.scala      | 30 +++++++++++++---
 .../metrics/GlutenClickHouseTPCHMetricsSuite.scala | 39 ++++++++++++++++++++-
 .../gluten/backendsapi/velox/VeloxMetricsApi.scala | 36 ++++++++++++++-----
 .../metrics/InputIteratorMetricsUpdater.scala      | 21 +++++++-----
 .../gluten/execution/VeloxMetricsSuite.scala       | 40 ++++++++++++++++++++--
 .../org/apache/gluten/backendsapi/MetricsApi.scala |  9 +++--
 .../ColumnarCollapseTransformStages.scala          | 18 ++++++++--
 7 files changed, 164 insertions(+), 29 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
index 6a4f0c9a6f..73b2d0f211 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala
@@ -23,7 +23,8 @@ import org.apache.gluten.substrait.{AggregationParams, 
JoinParams}
 
 import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarInputAdapter, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.QueryStageExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 
 import java.lang.{Long => JLong}
@@ -39,21 +40,40 @@ class CHMetricsApi extends MetricsApi with Logging with 
LogLevelUtil {
   }
 
   override def genInputIteratorTransformerMetrics(
-      sparkContext: SparkContext): Map[String, SQLMetric] = {
+      child: SparkPlan,
+      sparkContext: SparkContext,
+      forBroadcast: Boolean): Map[String, SQLMetric] = {
+    def metricsPlan(plan: SparkPlan): SparkPlan = {
+      plan match {
+        case ColumnarInputAdapter(child) => metricsPlan(child)
+        case q: QueryStageExec => metricsPlan(q.plan)
+        case _ => plan
+      }
+    }
+
+    val outputMetrics = if (forBroadcast) {
+      metricsPlan(child).metrics
+        .filterKeys(key => key.equals("numOutputRows"))
+    } else {
+      Map(
+        "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of 
output rows")
+      )
+    }
+
     Map(
       "iterReadTime" -> SQLMetrics.createTimingMetric(
         sparkContext,
         "time of reading from iterator"),
       "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input 
rows"),
-      "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of 
output rows"),
       "fillingRightJoinSideTime" -> SQLMetrics.createTimingMetric(
         sparkContext,
         "filling right join side time")
-    )
+    ) ++ outputMetrics
   }
 
   override def genInputIteratorTransformerMetricsUpdater(
-      metrics: Map[String, SQLMetric]): MetricsUpdater = {
+      metrics: Map[String, SQLMetric],
+      forBroadcast: Boolean): MetricsUpdater = {
     InputIteratorMetricsUpdater(metrics)
   }
 
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
index 932686433e..3cfb8cc4fc 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala
@@ -21,7 +21,9 @@ import org.apache.gluten.extension.GlutenPlan
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.execution.InputIteratorTransformer
+import org.apache.spark.sql.execution.{ColumnarInputAdapter, 
InputIteratorTransformer}
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
 import org.apache.spark.task.TaskResources
 
 import scala.collection.JavaConverters._
@@ -422,4 +424,39 @@ class GlutenClickHouseTPCHMetricsSuite extends 
GlutenClickHouseTPCHAbstractSuite
       
assert(nativeMetricsDataFinal.metricsDataList.get(2).getName.equals("kProject"))
     }
   }
+
+  test("Metrics for input iterator of broadcast exchange") {
+    createTPCHNotNullTables()
+    val partTableRecords = spark.sql("select * from part").count()
+
+    // Repartition to make sure we have multiple tasks executing the join.
+    spark
+      .sql("select * from lineitem")
+      .repartition(2)
+      .createOrReplaceTempView("lineitem")
+
+    Seq("true", "false").foreach {
+      adaptiveEnabled =>
+        withSQLConf("spark.sql.adaptive.enabled" -> adaptiveEnabled) {
+          val sqlStr =
+            """
+              |select /*+ BROADCAST(part) */ * from part join lineitem
+              |on l_partkey = p_partkey
+              |""".stripMargin
+
+          runQueryAndCompare(sqlStr) {
+            df =>
+              val inputIterator = find(df.queryExecution.executedPlan) {
+                case InputIteratorTransformer(ColumnarInputAdapter(child)) =>
+                  child.isInstanceOf[BroadcastQueryStageExec] || child
+                    .isInstanceOf[BroadcastExchangeLike]
+                case _ => false
+              }
+              assert(inputIterator.isDefined)
+              val metrics = inputIterator.get.metrics
+              assert(metrics("numOutputRows").value == partTableRecords)
+          }
+        }
+    }
+  }
 }
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
index 10b0c493c1..e70e1d13bd 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala
@@ -22,7 +22,8 @@ import org.apache.gluten.substrait.{AggregationParams, 
JoinParams}
 
 import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarInputAdapter, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.QueryStageExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 
 import java.lang.{Long => JLong}
@@ -38,18 +39,37 @@ class VeloxMetricsApi extends MetricsApi with Logging {
   }
 
   override def genInputIteratorTransformerMetrics(
-      sparkContext: SparkContext): Map[String, SQLMetric] = {
+      child: SparkPlan,
+      sparkContext: SparkContext,
+      forBroadcast: Boolean): Map[String, SQLMetric] = {
+    def metricsPlan(plan: SparkPlan): SparkPlan = {
+      plan match {
+        case ColumnarInputAdapter(child) => metricsPlan(child)
+        case q: QueryStageExec => metricsPlan(q.plan)
+        case _ => plan
+      }
+    }
+
+    val outputMetrics = if (forBroadcast) {
+      metricsPlan(child).metrics
+        .filterKeys(key => key.equals("numOutputRows") || 
key.equals("outputVectors"))
+    } else {
+      Map(
+        "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of 
output rows"),
+        "outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of 
output vectors")
+      )
+    }
+
     Map(
       "cpuCount" -> SQLMetrics.createMetric(sparkContext, "cpu wall time 
count"),
-      "wallNanos" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time of 
input iterator"),
-      "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of 
output rows"),
-      "outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of 
output vectors")
-    )
+      "wallNanos" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time of 
input iterator")
+    ) ++ outputMetrics
   }
 
   override def genInputIteratorTransformerMetricsUpdater(
-      metrics: Map[String, SQLMetric]): MetricsUpdater = {
-    InputIteratorMetricsUpdater(metrics)
+      metrics: Map[String, SQLMetric],
+      forBroadcast: Boolean): MetricsUpdater = {
+    InputIteratorMetricsUpdater(metrics, forBroadcast)
   }
 
   override def genBatchScanTransformerMetrics(sparkContext: SparkContext): 
Map[String, SQLMetric] =
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/metrics/InputIteratorMetricsUpdater.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/metrics/InputIteratorMetricsUpdater.scala
index a9067d069e..8002a44ae9 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/metrics/InputIteratorMetricsUpdater.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/metrics/InputIteratorMetricsUpdater.scala
@@ -17,20 +17,23 @@
 package org.apache.gluten.metrics
 import org.apache.spark.sql.execution.metric.SQLMetric
 
-case class InputIteratorMetricsUpdater(metrics: Map[String, SQLMetric]) 
extends MetricsUpdater {
+case class InputIteratorMetricsUpdater(metrics: Map[String, SQLMetric], 
forBroadcast: Boolean)
+  extends MetricsUpdater {
   override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = {
     if (opMetrics != null) {
       val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics]
       metrics("cpuCount") += operatorMetrics.cpuCount
       metrics("wallNanos") += operatorMetrics.wallNanos
-      if (operatorMetrics.outputRows == 0 && operatorMetrics.outputVectors == 
0) {
-        // Sometimes, velox does not update metrics for intermediate operator,
-        // here we try to use the input metrics
-        metrics("numOutputRows") += operatorMetrics.inputRows
-        metrics("outputVectors") += operatorMetrics.inputVectors
-      } else {
-        metrics("numOutputRows") += operatorMetrics.outputRows
-        metrics("outputVectors") += operatorMetrics.outputVectors
+      if (!forBroadcast) {
+        if (operatorMetrics.outputRows == 0 && operatorMetrics.outputVectors 
== 0) {
+          // Sometimes, velox does not update metrics for intermediate 
operator,
+          // here we try to use the input metrics
+          metrics("numOutputRows") += operatorMetrics.inputRows
+          metrics("outputVectors") += operatorMetrics.inputVectors
+        } else {
+          metrics("numOutputRows") += operatorMetrics.outputRows
+          metrics("outputVectors") += operatorMetrics.outputVectors
+        }
       }
     }
   }
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
index 4a25144984..0b74824832 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
@@ -22,8 +22,9 @@ import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.spark.SparkConf
 import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
 import org.apache.spark.sql.TestUtils
-import org.apache.spark.sql.execution.CommandResultExec
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.{ColumnarInputAdapter, 
CommandResultExec, InputIteratorTransformer}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, 
BroadcastQueryStageExec}
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
 import org.apache.spark.sql.internal.SQLConf
 
 class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with 
AdaptiveSparkPlanHelper {
@@ -227,4 +228,39 @@ class VeloxMetricsSuite extends 
VeloxWholeStageTransformerSuite with AdaptiveSpa
 
     assert(inputRecords == (partTableRecords + itemTableRecords))
   }
+
+  test("Metrics for input iterator of broadcast exchange") {
+    createTPCHNotNullTables()
+    val partTableRecords = spark.sql("select * from part").count()
+
+    // Repartition to make sure we have multiple tasks executing the join.
+    spark
+      .sql("select * from lineitem")
+      .repartition(2)
+      .createOrReplaceTempView("lineitem")
+
+    Seq("true", "false").foreach {
+      adaptiveEnabled =>
+        withSQLConf("spark.sql.adaptive.enabled" -> adaptiveEnabled) {
+          val sqlStr =
+            """
+              |select /*+ BROADCAST(part) */ * from part join lineitem
+              |on l_partkey = p_partkey
+              |""".stripMargin
+
+          runQueryAndCompare(sqlStr) {
+            df =>
+              val inputIterator = find(df.queryExecution.executedPlan) {
+                case InputIteratorTransformer(ColumnarInputAdapter(child)) =>
+                  child.isInstanceOf[BroadcastQueryStageExec] || child
+                    .isInstanceOf[BroadcastExchangeLike]
+                case _ => false
+              }
+              assert(inputIterator.isDefined)
+              val metrics = inputIterator.get.metrics
+              assert(metrics("numOutputRows").value == partTableRecords)
+          }
+        }
+    }
+  }
 }
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
index 62008767f5..c67d4b5f88 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala
@@ -33,9 +33,14 @@ trait MetricsApi extends Serializable {
       "pipelineTime" -> SQLMetrics
         .createTimingMetric(sparkContext, 
WholeStageCodegenExec.PIPELINE_DURATION_METRIC))
 
-  def genInputIteratorTransformerMetrics(sparkContext: SparkContext): 
Map[String, SQLMetric]
+  def genInputIteratorTransformerMetrics(
+      child: SparkPlan,
+      sparkContext: SparkContext,
+      forBroadcast: Boolean): Map[String, SQLMetric]
 
-  def genInputIteratorTransformerMetricsUpdater(metrics: Map[String, 
SQLMetric]): MetricsUpdater
+  def genInputIteratorTransformerMetricsUpdater(
+      metrics: Map[String, SQLMetric],
+      forBroadcast: Boolean): MetricsUpdater
 
   def metricsUpdatingFunction(
       child: SparkPlan,
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
index 32575e4f13..d222dcfef8 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
@@ -32,6 +32,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, 
SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -49,14 +51,18 @@ case class InputIteratorTransformer(child: SparkPlan) 
extends UnaryTransformSupp
 
   @transient
   override lazy val metrics: Map[String, SQLMetric] =
-    
BackendsApiManager.getMetricsApiInstance.genInputIteratorTransformerMetrics(sparkContext)
+    
BackendsApiManager.getMetricsApiInstance.genInputIteratorTransformerMetrics(
+      child,
+      sparkContext,
+      forBroadcast())
 
   override def simpleString(maxFields: Int): String = {
     s"$nodeName${truncatedString(output, "[", ", ", "]", maxFields)}"
   }
 
   override def metricsUpdater(): MetricsUpdater =
-    
BackendsApiManager.getMetricsApiInstance.genInputIteratorTransformerMetricsUpdater(metrics)
+    BackendsApiManager.getMetricsApiInstance
+      .genInputIteratorTransformerMetricsUpdater(metrics, forBroadcast())
 
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -75,6 +81,14 @@ case class InputIteratorTransformer(child: SparkPlan) 
extends UnaryTransformSupp
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan 
= {
     copy(child = newChild)
   }
+
+  private def forBroadcast(): Boolean = {
+    child match {
+      case ColumnarInputAdapter(c) if c.isInstanceOf[BroadcastQueryStageExec] 
=> true
+      case ColumnarInputAdapter(c) if c.isInstanceOf[BroadcastExchangeLike] => 
true
+      case _ => false
+    }
+  }
 }
 
 /**


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to