This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 428f75942ecf [SPARK-50144][SS] Address the limitation of metrics calculation with DSv1 streaming sources
428f75942ecf is described below

commit 428f75942ecfccdcd87c9a0fdc50be5bafab6123
Author: Jungtaek Lim <[email protected]>
AuthorDate: Thu Oct 31 06:35:23 2024 +0900

    [SPARK-50144][SS] Address the limitation of metrics calculation with DSv1 streaming sources
    
    ### What changes were proposed in this pull request?
    
    This PR proposes making most DSv1 streaming sources carry stream information over from the logical plan to the physical plan, so that Spark can collect their metrics the same way it does for DSv2 streaming sources.
    
    To achieve this, the PR introduces two new traits, one for logical nodes and one for physical nodes:
    
    * Logical plan: StreamSourceAwareLogicalPlan
    * Physical plan: StreamSourceAwareSparkPlan
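
    A minimal, standalone sketch of how the two mixins fit together. It does not use Spark: `SparkDataStream`, `LogicalPlan`, `SparkPlan` and the `...Model` node classes below are simplified stand-ins assumed for illustration; only the `withStream`/`getStream` shapes follow the traits added in this PR.

```scala
// Simplified stand-ins for illustration; NOT Spark's real classes.
final case class SparkDataStream(name: String)

trait LogicalPlan
trait SparkPlan {
  def metrics: Map[String, Long]
}

// Logical-side mixin: a source leaf that can be tagged with its stream.
trait StreamSourceAwareLogicalPlan extends LogicalPlan {
  // Implementations return a copy of the node with the stream set.
  def withStream(stream: SparkDataStream): LogicalPlan
  def getStream: Option[SparkDataStream]
}

// Physical-side mixin: the planned scan exposes the stream and is expected
// to report its row count under "numOutputRows".
trait StreamSourceAwareSparkPlan extends SparkPlan {
  def getStream: Option[SparkDataStream]
}

// LocalRelation-like leaf: the stream is an extra constructor field defaulting to None.
final case class LocalRelationModel(
    rows: Seq[Int],
    isStreaming: Boolean = false,
    stream: Option[SparkDataStream] = None)
  extends StreamSourceAwareLogicalPlan {
  override def withStream(s: SparkDataStream): LocalRelationModel = copy(stream = Some(s))
  override def getStream: Option[SparkDataStream] = stream
}

// LocalTableScanExec-like scan that carries the stream over from the logical leaf.
final case class LocalTableScanModel(rows: Seq[Int], stream: Option[SparkDataStream])
  extends StreamSourceAwareSparkPlan {
  override def metrics: Map[String, Long] = Map("numOutputRows" -> rows.size.toLong)
  override def getStream: Option[SparkDataStream] = stream
}

val source = SparkDataStream("memory-source-0")
val untagged = LocalRelationModel(Seq(1, 2, 3), isStreaming = true)
val tagged = untagged.withStream(source)
val scan = LocalTableScanModel(tagged.rows, tagged.getStream)
```

    Because the stream is a defaulted field, existing constructors keep working, and `withStream` is just a `copy`.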
    
    All DSv1 streaming sources are expected to return from getBatch() a logical plan whose leaf node(s) implement StreamSourceAwareLogicalPlan. Built-in DSv1 streaming sources (and external sources using `SparkSession/SQLContext.internalCreateDataFrame`) mostly use one or more of the following nodes:
    
    * LogicalRDD
    * LocalRelation
    * LogicalRelation
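
    A sketch of the tagging step, assuming the engine walks the plan returned by getBatch() and calls withStream on every StreamSourceAwareLogicalPlan leaf. The MicroBatchExecution.scala part of the diff is not shown in this excerpt, so the real code may differ; the node classes are hypothetical stand-ins named after the nodes listed above.

```scala
// Hypothetical mini plan tree; NOT Spark's classes.
final case class SparkDataStream(name: String)

trait Plan {
  def children: Seq[Plan]
}

// Leaf nodes that can carry the stream (cf. StreamSourceAwareLogicalPlan).
trait StreamSourceAware extends Plan {
  def children: Seq[Plan] = Nil
  def withStream(s: SparkDataStream): Plan
  def getStream: Option[SparkDataStream]
}

final case class LocalRelationLike(rows: Seq[Int], stream: Option[SparkDataStream] = None)
  extends StreamSourceAware {
  def withStream(s: SparkDataStream): Plan = copy(stream = Some(s))
  def getStream: Option[SparkDataStream] = stream
}

final case class LogicalRDDLike(name: String, stream: Option[SparkDataStream] = None)
  extends StreamSourceAware {
  def withStream(s: SparkDataStream): Plan = copy(stream = Some(s))
  def getStream: Option[SparkDataStream] = stream
}

final case class Project(child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
}
final case class Union(children: Seq[Plan]) extends Plan

// Rebuild the tree, tagging every stream-aware leaf with the source's stream.
def tagStream(plan: Plan, s: SparkDataStream): Plan = plan match {
  case leaf: StreamSourceAware => leaf.withStream(s)
  case Project(child)          => Project(tagStream(child, s))
  case Union(cs)               => Union(cs.map(tagStream(_, s)))
  case other                   => other // unknown nodes are left untouched in this sketch
}

// Collect the streams attached to the leaves, in leaf order.
def leafStreams(plan: Plan): Seq[SparkDataStream] = plan match {
  case leaf: StreamSourceAware => leaf.getStream.toSeq
  case other                   => other.children.flatMap(leafStreams)
}

val source = SparkDataStream("rate-source-0")
val batchPlan: Plan = Project(Union(Seq(LocalRelationLike(Seq(1, 2)), LogicalRDDLike("rdd"))))
val taggedPlan = tagStream(batchPlan, source)
```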
    
    Physical planning of LogicalRelation is covered by either 1) RowDataSourceScanExec or 2) FileSourceScanExec. This may not cover every possible type of relation, but it is not easy for data source developers or users to extend the Spark planner to handle additional cases in physical planning anyway. Furthermore, DSv2 is the standard interface for external streaming sources. The remaining nodes are covered by static mappings: LogicalRDD -> RDDScanExec, LocalRelation -> Local [...]
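
    The static logical-to-physical mappings can be modeled as one pattern match in which each physical scan simply copies the `stream` field from its logical leaf, mirroring the SparkStrategies, DataSourceStrategy and FileSourceStrategy changes in the diff. The classes are simplified stand-ins, and the `isFileBased` flag is a hypothetical shortcut for "the relation is a HadoopFsRelation".

```scala
// Simplified stand-ins named after the Spark nodes; NOT Spark's real classes.
final case class SparkDataStream(name: String)

sealed trait LogicalLeaf
final case class LocalRelation(data: Seq[Int], stream: Option[SparkDataStream]) extends LogicalLeaf
final case class LogicalRDD(name: String, stream: Option[SparkDataStream]) extends LogicalLeaf
// isFileBased is a hypothetical shortcut for "relation is a HadoopFsRelation".
final case class LogicalRelation(isFileBased: Boolean, stream: Option[SparkDataStream])
  extends LogicalLeaf

sealed trait PhysicalScan { def stream: Option[SparkDataStream] }
final case class LocalTableScanExec(data: Seq[Int], stream: Option[SparkDataStream])
  extends PhysicalScan
final case class RDDScanExec(name: String, stream: Option[SparkDataStream]) extends PhysicalScan
final case class FileSourceScanExec(stream: Option[SparkDataStream]) extends PhysicalScan
final case class RowDataSourceScanExec(stream: Option[SparkDataStream]) extends PhysicalScan

// Each rule passes the logical leaf's stream straight through to the physical scan.
def planLeaf(leaf: LogicalLeaf): PhysicalScan = leaf match {
  case LocalRelation(data, stream)    => LocalTableScanExec(data, stream)
  case LogicalRDD(_, stream)          => RDDScanExec("ExistingRDD", stream)
  case LogicalRelation(true, stream)  => FileSourceScanExec(stream)
  case LogicalRelation(false, stream) => RowDataSourceScanExec(stream)
}

val source = SparkDataStream("file-source-0")
val leaves: Seq[LogicalLeaf] = Seq(
  LocalRelation(Seq(1), Some(source)),
  LogicalRDD("rdd", Some(source)),
  LogicalRelation(isFileBased = true, stream = Some(source)),
  LogicalRelation(isFileBased = false, stream = None))
val scans = leaves.map(planLeaf)
```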
    
    This PR updates the progress reporter to collect metrics from nodes implementing StreamSourceAwareSparkPlan, which works for both DSv1 and DSv2. To avoid regressions, the progress reporter checks whether the executed plan contains all the streams to be captured; if the executed plan misses any stream, the progress reporter falls back to the previous approach.
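
    The collection-with-fallback logic can be sketched as below. This is a simplified model rather than the actual ProgressReporter code: the node classes are hypothetical, and summing rows when one stream feeds several scan nodes (e.g. a self-join) is an assumption of the sketch. A `None` result means the caller should use the previous approach.

```scala
// Hypothetical, simplified executed-plan nodes; NOT Spark's classes.
final case class SparkDataStream(name: String)

trait ExecNode {
  def children: Seq[ExecNode]
}
// A scan that mixes in the stream-aware trait and reports "numOutputRows".
final case class StreamAwareScan(stream: Option[SparkDataStream], numOutputRows: Long)
  extends ExecNode {
  def children: Seq[ExecNode] = Nil
}
final case class OtherExec(children: Seq[ExecNode]) extends ExecNode

def collectScans(plan: ExecNode): Seq[StreamAwareScan] = plan match {
  case s: StreamAwareScan => Seq(s)
  case other              => other.children.flatMap(collectScans)
}

// Some(rows per stream) if every expected stream is found in the executed plan,
// otherwise None so that the caller can fall back to the previous approach.
def numInputRows(
    executed: ExecNode,
    expected: Set[SparkDataStream]): Option[Map[SparkDataStream, Long]] = {
  val pairs = collectScans(executed).flatMap(s => s.stream.map(_ -> s.numOutputRows))
  if (!expected.subsetOf(pairs.map(_._1).toSet)) None
  else Some(pairs.groupBy(_._1).map { case (stream, ps) => stream -> ps.map(_._2).sum })
}

val a = SparkDataStream("a")
val b = SparkDataStream("b")
val executed = OtherExec(Seq(
  StreamAwareScan(Some(a), 10L),
  StreamAwareScan(Some(a), 5L), // same stream scanned twice, e.g. a self-join
  StreamAwareScan(Some(b), 3L)))
```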
    
    ### Why are the changes needed?
    
    For DSv2 data sources, the source nodes in the executed plan are always MicroBatchScanExec, and these nodes contain the stream information.
    
    For DSv1 data sources, however, the logical and physical nodes representing the source scan are technically arbitrary (any logical node and any physical node). Spark therefore assumes that the leaf nodes correspond one-to-one across the initial logical plan <=> the logical plan for batch N <=> the physical plan for batch N, so that these nodes can be associated. This assumption is fragile, and there have been a non-trivial number of reports of broken metrics.
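
    For contrast, the previous association can be modeled as pairing leaves by position. This is a simplified sketch (the function and its inputs are hypothetical), but it shows why the scheme is fragile: once optimization changes the number of physical leaves, the pairing no longer holds and the batch's metrics cannot be attributed.

```scala
// Hypothetical model of position-based association; NOT Spark's code.
final case class Source(name: String)

// Pair the sources of the logical leaves with the row counts of the physical
// leaves by position; give up (None) if the two lists do not line up.
def associateByPosition(
    logicalLeafSources: Seq[Source],
    physicalLeafRows: Seq[Long]): Option[Map[Source, Long]] = {
  if (logicalLeafSources.size != physicalLeafRows.size) None
  else Some(
    logicalLeafSources.zip(physicalLeafRows)
      .groupBy(_._1)
      .map { case (src, ps) => src -> ps.map(_._2).sum })
}

val src1 = Source("src-1")
val src2 = Source("src-2")
// Leaves line up: one physical scan per logical leaf.
val aligned = associateByPosition(Seq(src1, src2), Seq(100L, 7L))
// An optimizer rule removed one leaf (e.g. pruned an empty input): counts differ.
val misaligned = associateByPosition(Seq(src1, src2), Seq(100L))
```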
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Modified existing tests: some tests were written around the previous limitation, and they are fixed by this PR.
    
    e.g. org.apache.spark.sql.streaming.TriggerAvailableNowSuite
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48676 from HeartSaVioR/SPARK-50144.
    
    Authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 13 ++++----
 .../sql/catalyst/plans/logical/LocalRelation.scala | 16 +++++++--
 .../logical/StreamSourceAwareLogicalPlan.scala     | 35 ++++++++++++++++++++
 .../analysis/ResolveInlineTablesSuite.scala        |  5 +--
 .../optimizer/RewriteDistinctAggregatesSuite.scala |  2 +-
 .../execution/SparkConnectPlanExecution.scala      |  2 +-
 .../spark/sql/execution/DataSourceScanExec.scala   | 17 ++++++++--
 .../apache/spark/sql/execution/ExistingRDD.scala   | 34 ++++++++++++++++---
 .../spark/sql/execution/LocalTableScanExec.scala   | 14 +++++++-
 .../spark/sql/execution/SparkStrategies.scala      |  9 ++---
 .../sql/execution/StreamSourceAwareSparkPlan.scala | 32 ++++++++++++++++++
 .../execution/datasources/DataSourceStrategy.scala |  3 ++
 .../execution/datasources/FileSourceStrategy.scala |  1 +
 .../execution/datasources/LogicalRelation.scala    | 19 ++++++++---
 .../datasources/SaveIntoDataSourceCommand.scala    |  3 +-
 .../datasources/v2/DataSourceV2Strategy.scala      |  5 +--
 .../datasources/v2/MicroBatchScanExec.scala        |  9 +++--
 .../execution/streaming/MicroBatchExecution.scala  | 17 ++++++----
 .../sql/execution/streaming/ProgressReporter.scala | 38 ++++++++++++++++++++--
 .../org/apache/spark/sql/DataFrameJoinSuite.scala  |  2 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      |  2 +-
 .../scala/org/apache/spark/sql/SubquerySuite.scala |  2 +-
 .../execution/OptimizeMetadataOnlyQuerySuite.scala |  4 +--
 .../spark/sql/execution/SparkPlanSuite.scala       |  6 ++--
 .../spark/sql/execution/SparkPlannerSuite.scala    |  4 +--
 .../bucketing/CoalesceBucketsInJoinSuite.scala     |  3 +-
 .../StreamingSymmetricHashJoinHelperSuite.scala    |  4 +--
 .../sql/streaming/TriggerAvailableNowSuite.scala   | 37 +++++++++++++--------
 28 files changed, 267 insertions(+), 71 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fb234c7bda4c..dcfb64ae51fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2228,20 +2228,21 @@ object DecimalAggregates extends Rule[LogicalPlan] {
 object ConvertToLocalRelation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsPattern(LOCAL_RELATION), ruleId) {
-    case Project(projectList, LocalRelation(output, data, isStreaming))
+    case Project(projectList, LocalRelation(output, data, isStreaming, stream))
         if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedMutableProjection(projectList, output)
       projection.initialize(0)
-      LocalRelation(projectList.map(_.toAttribute), data.map(projection(_).copy()), isStreaming)
+      LocalRelation(projectList.map(_.toAttribute), data.map(projection(_).copy()),
+        isStreaming, stream)
 
-    case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) =>
-      LocalRelation(output, data.take(limit), isStreaming)
+    case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming, stream)) =>
+      LocalRelation(output, data.take(limit), isStreaming, stream)
 
-    case Filter(condition, LocalRelation(output, data, isStreaming))
+    case Filter(condition, LocalRelation(output, data, isStreaming, stream))
         if !hasUnevaluableExpr(condition) =>
       val predicate = Predicate.create(condition, output)
       predicate.initialize(0)
-      LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming)
+      LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming, stream)
   }
 
   private def hasUnevaluableExpr(expr: Expression): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 8b9d8c91815f..f52c38a64ab3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -61,19 +62,28 @@ case class LocalRelation(
     output: Seq[Attribute],
     data: Seq[InternalRow] = Nil,
     // Indicates whether this relation has data from a streaming source.
-    override val isStreaming: Boolean = false)
-  extends LeafNode with analysis.MultiInstanceRelation {
+    override val isStreaming: Boolean = false,
+    @transient stream: Option[SparkDataStream] = None)
+  extends LeafNode
+  with StreamSourceAwareLogicalPlan
+  with analysis.MultiInstanceRelation {
 
   // A local relation must have resolved output.
   require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.")
 
+  override def withStream(stream: SparkDataStream): LocalRelation = {
+    copy(stream = Some(stream))
+  }
+
+  override def getStream: Option[SparkDataStream] = stream
+
   /**
    * Returns an identical copy of this relation with new exprIds for all attributes.  Different
    * attributes are required when a relation is going to be included multiple times in the same
    * query.
    */
   override final def newInstance(): this.type = {
-    LocalRelation(output.map(_.newInstance()), data, isStreaming).asInstanceOf[this.type]
+    LocalRelation(output.map(_.newInstance()), data, isStreaming, stream).asInstanceOf[this.type]
   }
 
   override protected def stringArgs: Iterator[Any] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/StreamSourceAwareLogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/StreamSourceAwareLogicalPlan.scala
new file mode 100644
index 000000000000..fd73a19fbf98
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/StreamSourceAwareLogicalPlan.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
+
+/**
+ * This trait is a mixin for source logical nodes to represent the stream. This is required to the
+ * logical nodes which can be used in the leaf node of Source.getBatch().
+ */
+trait StreamSourceAwareLogicalPlan extends LogicalPlan {
+  /**
+   * Set the stream associated with this node.
+   * Spark will use this method to set the stream, and the implementation should copy the node with
+   * setting up the stream.
+   */
+  def withStream(stream: SparkDataStream): LogicalPlan
+
+  /** Get the stream associated with this node. */
+  def getStream: Option[SparkDataStream]
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
index 80cd56a8007a..f231164d5c25 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -105,7 +105,7 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
     assert(resolved.isInstanceOf[ResolvedInlineTable])
 
     EvalInlineTables(ComputeCurrentTime(resolved)) match {
-      case LocalRelation(output, data, _) =>
+      case LocalRelation(output, data, _, _) =>
         assert(output.map(_.dataType) == Seq(TimestampType))
         assert(data.size == 2)
         // Make sure that both CURRENT_TIMESTAMP expressions are evaluated to the same value.
@@ -117,7 +117,8 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
     val table = UnresolvedInlineTable(Seq("c1"),
       Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
     val withTimeZone = ResolveTimeZone.apply(table)
-    val LocalRelation(output, data, _) = EvalInlineTables(ResolveInlineTables.apply(withTimeZone))
+    val LocalRelation(output, data, _, _) =
+      EvalInlineTables(ResolveInlineTables.apply(withTimeZone))
     val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
       .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
     assert(output.map(_.dataType) == Seq(TimestampType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 4d31999ded65..9cb5ee46e0f3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -87,7 +87,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
 
     val rewrite = RewriteDistinctAggregates(input)
     rewrite match {
-      case Aggregate(_, _, LocalRelation(_, _, _)) =>
+      case Aggregate(_, _, _: LocalRelation) =>
       case _ => fail(s"Plan is not as expected:\n$rewrite")
     }
   }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 660951f22984..c0fd00b2eeaa 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -137,7 +137,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     }
 
     dataframe.queryExecution.executedPlan match {
-      case LocalTableScanExec(_, rows) =>
+      case LocalTableScanExec(_, rows, _) =>
         executePlan.eventsManager.postFinished(Some(rows.length))
         var offset = 0L
         converter(rows.iterator).foreach { case (bytes, count) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 2ebbb9664f67..226debc97642 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.util.{truncatedString, CaseInsensitiveMap}
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.datasources._
@@ -46,7 +47,7 @@ import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.BitSet
 
-trait DataSourceScanExec extends LeafExecNode {
+trait DataSourceScanExec extends LeafExecNode with StreamSourceAwareSparkPlan {
   def relation: BaseRelation
   def tableIdentifier: Option[TableIdentifier]
 
@@ -114,6 +115,7 @@ case class RowDataSourceScanExec(
     pushedDownOperators: PushedDownOperators,
     rdd: RDD[InternalRow],
     @transient relation: BaseRelation,
+    @transient stream: Option[SparkDataStream],
     tableIdentifier: Option[TableIdentifier])
   extends DataSourceScanExec with InputRDDCodegen {
 
@@ -201,12 +203,15 @@ case class RowDataSourceScanExec(
       )
   }
 
-  // Don't care about `rdd` and `tableIdentifier` when canonicalizing.
+  // Don't care about `rdd` and `tableIdentifier`, and `stream` when canonicalizing.
   override def doCanonicalize(): SparkPlan =
     copy(
       output.map(QueryPlan.normalizeExpressions(_, output)),
       rdd = null,
-      tableIdentifier = None)
+      tableIdentifier = None,
+      stream = None)
+
+  override def getStream: Option[SparkDataStream] = stream
 }
 
 /**
@@ -599,6 +604,7 @@ trait FileSourceScanLike extends DataSourceScanExec {
  */
 case class FileSourceScanExec(
     @transient override val relation: HadoopFsRelation,
+    @transient stream: Option[SparkDataStream],
     override val output: Seq[Attribute],
     override val requiredSchema: StructType,
     override val partitionFilters: Seq[Expression],
@@ -817,6 +823,9 @@ case class FileSourceScanExec(
   override def doCanonicalize(): FileSourceScanExec = {
     FileSourceScanExec(
       relation,
+      // remove stream on canonicalization; this is needed for reused shuffle to be effective in
+      // self-join
+      None,
       output.map(QueryPlan.normalizeExpressions(_, output)),
       requiredSchema,
       QueryPlan.normalizePredicates(
@@ -827,4 +836,6 @@ case class FileSourceScanExec(
       None,
       disableBucketedScan)
   }
+
+  override def getStream: Option[SparkDataStream] = stream
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 8c7ed7b88d45..fd8f0b85edd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.collection.Utils
 
@@ -97,13 +98,16 @@ case class LogicalRDD(
     rdd: RDD[InternalRow],
     outputPartitioning: Partitioning = UnknownPartitioning(0),
     override val outputOrdering: Seq[SortOrder] = Nil,
-    override val isStreaming: Boolean = false)(
+    override val isStreaming: Boolean = false,
+    @transient stream: Option[SparkDataStream] = None)(
     session: SparkSession,
     // originStats and originConstraints are intentionally placed to "second" parameter list,
     // to prevent catalyst rules to mistakenly transform and rewrite them. Do not change this.
     originStats: Option[Statistics] = None,
     originConstraints: Option[ExpressionSet] = None)
-  extends LeafNode with MultiInstanceRelation {
+  extends LeafNode
+  with StreamSourceAwareLogicalPlan
+  with MultiInstanceRelation {
 
   import LogicalRDD._
 
@@ -134,7 +138,8 @@ case class LogicalRDD(
       rdd,
       rewrittenPartitioning,
       rewrittenOrdering,
-      isStreaming
+      isStreaming,
+      stream
     )(session, rewrittenStatistics, rewrittenConstraints).asInstanceOf[this.type]
   }
 
@@ -158,6 +163,13 @@ case class LogicalRDD(
     // Therefore we assume that all subqueries are non-deterministic, and we do not expose any
     // constraints that contain a subquery.
     .filterNot(SubqueryExpression.hasSubquery)
+
+  override def withStream(stream: SparkDataStream): LogicalRDD = {
+    copy(stream = Some(stream))(session, originStats, originConstraints)
+  }
+
+  override def getStream: Option[SparkDataStream] = stream
+
 }
 
 object LogicalRDD extends Logging {
@@ -191,7 +203,8 @@ object LogicalRDD extends Logging {
       rdd,
       firstLeafPartitioning(executedPlan.outputPartitioning),
       executedPlan.outputOrdering,
-      isStreaming
+      isStreaming,
+      None
     )(originDataset.sparkSession, stats, constraints)
   }
 
@@ -264,7 +277,11 @@ case class RDDScanExec(
     rdd: RDD[InternalRow],
     name: String,
     override val outputPartitioning: Partitioning = UnknownPartitioning(0),
-    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode with InputRDDCodegen {
+    override val outputOrdering: Seq[SortOrder] = Nil,
+    @transient stream: Option[SparkDataStream] = None)
+  extends LeafExecNode
+  with StreamSourceAwareSparkPlan
+  with InputRDDCodegen {
 
   private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("")
 
@@ -293,4 +310,11 @@ case class RDDScanExec(
   override protected val createUnsafeProjection: Boolean = true
 
   override def inputRDD: RDD[InternalRow] = rdd
+
+  // Don't care about `stream` when canonicalizing.
+  override protected def doCanonicalize(): SparkPlan = {
+    super.doCanonicalize().asInstanceOf[RDDScanExec].copy(stream = None)
+  }
+
+  override def getStream: Option[SparkDataStream] = stream
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index 9ac79aab36f6..2d5dbf819959 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.ArrayImplicits._
 
@@ -32,7 +33,11 @@ import org.apache.spark.util.ArrayImplicits._
  */
 case class LocalTableScanExec(
     output: Seq[Attribute],
-    @transient rows: Seq[InternalRow]) extends LeafExecNode with InputRDDCodegen {
+    @transient rows: Seq[InternalRow],
+    @transient stream: Option[SparkDataStream])
+  extends LeafExecNode
+  with StreamSourceAwareSparkPlan
+  with InputRDDCodegen {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -97,8 +102,15 @@ case class LocalTableScanExec(
 
   override def inputRDD: RDD[InternalRow] = rdd
 
+  override def getStream: Option[SparkDataStream] = stream
+
   private def sendDriverMetrics(): Unit = {
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
   }
+
+  // Don't care about `stream` when canonicalizing.
+  override protected def doCanonicalize(): SparkPlan = {
+    super.doCanonicalize().asInstanceOf[LocalTableScanExec].copy(stream = None)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 30b395d0c136..134a69500e57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -865,7 +865,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case MemoryPlan(sink, output) =>
         val encoder = ExpressionEncoder(DataTypeUtils.fromAttributes(output))
         val toRow = encoder.createSerializer()
-        LocalTableScanExec(output, sink.allData.map(r => toRow(r).copy())) :: Nil
+        LocalTableScanExec(output, sink.allData.map(r => toRow(r).copy()), None) :: Nil
 
       case logical.Distinct(child) =>
         throw SparkException.internalError(
@@ -985,8 +985,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
         execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
-      case logical.LocalRelation(output, data, _) =>
-        LocalTableScanExec(output, data) :: Nil
+      case logical.LocalRelation(output, data, _, stream) =>
+        LocalTableScanExec(output, data, stream) :: Nil
       case logical.EmptyRelation(l) => EmptyRelationExec(l) :: Nil
       case CommandResult(output, _, plan, data) => CommandResultExec(output, plan, data) :: Nil
       // We should match the combination of limit and offset first, to get the optimal physical
@@ -1036,7 +1036,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           shuffleOrigin, r.optAdvisoryPartitionSize) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
-        RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
+        RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering,
+          r.stream) :: Nil
       case _: UpdateTable =>
         throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE TABLE")
       case _: MergeIntoTable =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/StreamSourceAwareSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/StreamSourceAwareSparkPlan.scala
new file mode 100644
index 000000000000..cd50b78d203f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/StreamSourceAwareSparkPlan.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
+
+/**
+ * This trait is a mixin for source physical nodes to represent the stream. This is required to the
+ * physical nodes which is transformed from source logical nodes inheriting
+ * [[org.apache.spark.sql.catalyst.plans.logical.StreamSourceAwareLogicalPlan]].
+ *
+ * The node implementing this trait should expose the number of output rows via "numOutputRows"
+ * in `metrics`.
+ * in `metrics`.
+ */
+trait StreamSourceAwareSparkPlan extends SparkPlan {
+  /** Get the stream associated with this node. */
+  def getStream: Option[SparkDataStream]
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index cacb1ef928e8..95746218e879 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -370,6 +370,7 @@ object DataSourceStrategy
         PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty),
         toCatalystRDD(l, baseRelation.buildScan()),
         baseRelation,
+        l.stream,
         None) :: Nil
 
     case _ => Nil
@@ -444,6 +445,7 @@ object DataSourceStrategy
         PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty),
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
+        relation.stream,
         relation.catalogTable.map(_.identifier))
       filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)
     } else {
@@ -467,6 +469,7 @@ object DataSourceStrategy
         PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty),
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
+        relation.stream,
         relation.catalogTable.map(_.identifier))
       execution.ProjectExec(
         projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 27cf9702b155..02235ffb1976 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -321,6 +321,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
       val scan =
         FileSourceScanExec(
           fsRelation,
+          l.stream,
           outputAttributes,
           outputDataSchema,
           partitionKeyFilters.toSeq,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 09502c8ecb3a..725b4a233257 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -20,9 +20,10 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics, StreamSourceAwareLogicalPlan}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.sources.BaseRelation
 
 /**
@@ -39,8 +40,12 @@ case class LogicalRelation(
     relation: BaseRelation,
     output: Seq[AttributeReference],
     catalogTable: Option[CatalogTable],
-    override val isStreaming: Boolean)
-  extends LeafNode with MultiInstanceRelation with ExposesMetadataColumns {
+    override val isStreaming: Boolean,
+    @transient stream: Option[SparkDataStream])
+  extends LeafNode
+  with StreamSourceAwareLogicalPlan
+  with MultiInstanceRelation
+  with ExposesMetadataColumns {
 
   // Only care about relation when canonicalizing.
   override def doCanonicalize(): LogicalPlan = copy(
@@ -92,6 +97,10 @@ case class LogicalRelation(
       this
     }
   }
+
+  override def withStream(stream: SparkDataStream): LogicalRelation = copy(stream = Some(stream))
+
+  override def getStream: Option[SparkDataStream] = stream
 }
 
 object LogicalRelation {
@@ -99,14 +108,14 @@ object LogicalRelation {
     // The v1 source may return schema containing char/varchar type. We replace char/varchar
     // with "annotated" string type here as the query engine doesn't support char/varchar yet.
     val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema)
-    LogicalRelation(relation, toAttributes(schema), None, isStreaming)
+    LogicalRelation(relation, toAttributes(schema), None, isStreaming, None)
   }
 
   def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = {
     // The v1 source may return schema containing char/varchar type. We replace char/varchar
     // with "annotated" string type here as the query engine doesn't support char/varchar yet.
     val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema)
-    LogicalRelation(relation, toAttributes(schema), Some(table), false)
+    LogicalRelation(relation, toAttributes(schema), Some(table), false, None)
   }
 }
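A minimal sketch of the pattern that `LogicalRelation` now follows. The types `DataStream`, `StreamTagged` and `Relation` are simplified, hypothetical stand-ins. They are not Spark's `SparkDataStream` or `StreamSourceAwareLogicalPlan`.

The leaf node carries an optional stream. `withStream` returns a tagged copy instead of mutating the node, which is what `copy(stream = Some(stream))` does in the diff. In the real class the field is also `@transient`, so Java serialization skips it.

```scala
object StreamTagSketch {
  // Hypothetical stand-in for SparkDataStream.
  final case class DataStream(name: String)

  // Hypothetical stand-in for StreamSourceAwareLogicalPlan.
  trait StreamTagged {
    def withStream(stream: DataStream): StreamTagged
    def getStream: Option[DataStream]
  }

  // A leaf relation that optionally remembers which stream produced it.
  final case class Relation(
      tableName: String,
      isStreaming: Boolean,
      stream: Option[DataStream] = None)
    extends StreamTagged {

    // Return a tagged copy; the receiver is left unchanged.
    override def withStream(stream: DataStream): Relation = copy(stream = Some(stream))

    override def getStream: Option[DataStream] = stream
  }
}
```

Code that builds the node without a stream passes `None`. That is why the `LogicalRelation(...)` call sites in this diff gain a trailing `None` argument.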
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index e44f1d35e9cd..51fed315439e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -68,7 +68,8 @@ case class SaveIntoDataSourceCommand(
     }
 
     try {
-      val logicalRelation = LogicalRelation(relation, toAttributes(relation.schema), None, false)
+      val logicalRelation = LogicalRelation(relation, toAttributes(relation.schema), None,
+        false, None)
       sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
     } catch {
       case NonFatal(_) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 553b9447d644..1c265650e02a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -129,13 +129,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         pushedDownOperators,
         unsafeRowRDD,
         v1Relation,
+        None,
         tableIdentifier)
       DataSourceV2Strategy.withProjectAndFilter(
         project, filters, dsScan, needsUnsafeConversion = false) :: Nil
 
     case PhysicalOperation(project, filters,
         DataSourceV2ScanRelation(_, scan: LocalScan, output, _, _)) =>
-      val localScanExec = LocalTableScanExec(output, scan.rows().toImmutableArraySeq)
+      val localScanExec = LocalTableScanExec(output, scan.rows().toImmutableArraySeq, None)
       DataSourceV2Strategy.withProjectAndFilter(
         project, filters, localScanExec, needsUnsafeConversion = false) :: Nil
 
@@ -362,7 +363,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       DropTableExec(r.catalog.asTableCatalog, r.identifier, ifExists, purge, invalidateFunc) :: Nil
 
     case _: NoopCommand =>
-      LocalTableScanExec(Nil, Nil) :: Nil
+      LocalTableScanExec(Nil, Nil, None) :: Nil
 
     case RenameTable(r @ ResolvedTable(catalog, oldIdent, _, _), newIdent, isView) =>
       if (isView) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala
index 07958987fa08..f81ca001fbe2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MicroBatchScanExec.scala
@@ -21,7 +21,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan}
-import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset, SparkDataStream}
+import org.apache.spark.sql.execution.StreamSourceAwareSparkPlan
 import org.apache.spark.util.ArrayImplicits._
 
 /**
@@ -34,7 +35,9 @@ case class MicroBatchScanExec(
     @transient start: Offset,
     @transient end: Offset,
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
-    ordering: Option[Seq[SortOrder]] = None) extends DataSourceV2ScanExecBase {
+    ordering: Option[Seq[SortOrder]] = None)
+  extends DataSourceV2ScanExecBase
+  with StreamSourceAwareSparkPlan {
 
   // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
   override def equals(other: Any): Boolean = other match {
@@ -55,4 +58,6 @@ case class MicroBatchScanExec(
     postDriverMetrics()
     inputRDD
   }
+
+  override def getStream: Option[SparkDataStream] = Some(stream)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 5ce9e13eb8fa..40d58e5a402a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.{LogKeys, MDC}
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan}
 import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream}
 import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -769,21 +769,25 @@ class MicroBatchExecution(
               }
               newRelation
           }
+          val finalDataPlanWithStream = finalDataPlan transformUp {
+            case l: StreamSourceAwareLogicalPlan => l.withStream(source)
+          }
           // SPARK-40460: overwrite the entry with the new logicalPlan
           // because it might contain the _metadata column. It is a necessary change,
           // in the ProgressReporter, we use the following mapping to get correct streaming metrics:
           // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan
-          mutableNewData.put(source, finalDataPlan)
+          mutableNewData.put(source, finalDataPlanWithStream)
           val maxFields = SQLConf.get.maxToStringFields
-          assert(output.size == finalDataPlan.output.size,
+          assert(output.size == finalDataPlanWithStream.output.size,
             s"Invalid batch: ${truncatedString(output, ",", maxFields)} != " +
-              s"${truncatedString(finalDataPlan.output, ",", maxFields)}")
+              s"${truncatedString(finalDataPlanWithStream.output, ",", maxFields)}")
 
-          val aliases = output.zip(finalDataPlan.output).map { case (to, from) =>
+          val aliases = output.zip(finalDataPlanWithStream.output).map { case (to, from) =>
             Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata))
           }
-          Project(aliases, finalDataPlan)
+          Project(aliases, finalDataPlanWithStream)
         }.getOrElse {
+          // Don't track the source node which is known to produce zero rows.
           LocalRelation(output, isStreaming = true)
         }
 
@@ -793,6 +797,7 @@ class MicroBatchExecution(
           case OffsetHolder(start, end) =>
             r.copy(startOffset = Some(start), endOffset = Some(end))
         }.getOrElse {
+          // Don't track the source node which is known to produce zero rows.
           LocalRelation(r.output, isStreaming = true)
         }
     }
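In `MicroBatchExecution`, the `finalDataPlan transformUp { ... }` step tags every stream-aware leaf in the batch plan with the source that produced it. The sketch below reproduces that idea on a hypothetical miniature plan tree. `Scan`, `Union` and `Project` are illustrative classes, not Catalyst's. The `transformUp` here is a simplified post-order rewrite written for this example, not Catalyst's `TreeNode.transformUp`.

```scala
object TagLeavesSketch {
  // Hypothetical miniature plan tree.
  sealed trait Plan {
    def children: Seq[Plan]
    def withNewChildren(newChildren: Seq[Plan]): Plan

    // Post-order rewrite: children are rewritten first, then the rule is
    // tried on this node; nodes the rule does not match are kept as-is.
    def transformUp(rule: PartialFunction[Plan, Plan]): Plan = {
      val rebuilt =
        if (children.isEmpty) this else withNewChildren(children.map(_.transformUp(rule)))
      rule.applyOrElse(rebuilt, (p: Plan) => p)
    }
  }

  // Marker for leaves that can carry the stream they were read from.
  sealed trait StreamAware extends Plan {
    def withStream(stream: String): Plan
  }

  final case class Scan(table: String, stream: Option[String] = None) extends StreamAware {
    def children: Seq[Plan] = Nil
    def withNewChildren(newChildren: Seq[Plan]): Plan = this
    def withStream(s: String): Plan = copy(stream = Some(s))
  }

  final case class Union(left: Plan, right: Plan) extends Plan {
    def children: Seq[Plan] = Seq(left, right)
    def withNewChildren(c: Seq[Plan]): Plan = Union(c(0), c(1))
  }

  final case class Project(columns: Seq[String], child: Plan) extends Plan {
    def children: Seq[Plan] = Seq(child)
    def withNewChildren(c: Seq[Plan]): Plan = Project(columns, c.head)
  }

  // Tag every stream-aware leaf with `source`; other nodes are rebuilt unchanged.
  def tagWithSource(plan: Plan, source: String): Plan =
    plan.transformUp { case l: StreamAware => l.withStream(source) }
}
```

The rule matches only stream-aware leaves, so a plan that reads one source through several leaves ends up with several tagged leaves for that source. The progress reporter therefore sums rows per source. The empty-batch branch returns a plain `LocalRelation`, which is deliberately left untagged because it is known to produce zero rows.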
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index c440ec451b72..fdb4f2813dba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, StreamSourceAwareSparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress}
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent}
@@ -401,7 +401,6 @@ abstract class ProgressContext(
     }
   }
 
-  /** Extract number of input sources for each streaming source in plan */
   private def extractSourceToNumInputRows(
       lastExecution: IncrementalExecution): Map[SparkDataStream, Long] = {
 
@@ -409,6 +408,41 @@ abstract class ProgressContext(
       tuples.groupBy(_._1).transform((_, v) => v.map(_._2).sum) // sum up rows for each source
     }
 
+    val sources = newData.keys.toSet
+
+    val sourceToInputRowsTuples = lastExecution.executedPlan
+      .collect {
+        case node: StreamSourceAwareSparkPlan if node.getStream.isDefined =>
+          val numRows = node.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
+          node.getStream.get -> numRows
+      }
+
+    val capturedSources = sourceToInputRowsTuples.map(_._1).toSet
+
+    if (sources == capturedSources) {
+      logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))
+      sumRows(sourceToInputRowsTuples)
+    } else {
+      // Falling back to the legacy approach to avoid any regression.
+      val inputRows = legacyExtractSourceToNumInputRows(lastExecution)
+      // If the legacy approach fails to extract the input rows, we just pick the new approach
+      // as it is more likely that the source nodes have been pruned in valid reason.
+      if (inputRows.isEmpty) {
+        sumRows(sourceToInputRowsTuples)
+      } else {
+        inputRows
+      }
+    }
+  }
+
+  /** Extract number of input sources for each streaming source in plan */
+  private def legacyExtractSourceToNumInputRows(
+      lastExecution: IncrementalExecution): Map[SparkDataStream, Long] = {
+
+    def sumRows(tuples: Seq[(SparkDataStream, Long)]): Map[SparkDataStream, Long] = {
+      tuples.groupBy(_._1).transform((_, v) => v.map(_._2).sum) // sum up rows for each source
+    }
+
     def unrollCTE(plan: LogicalPlan): LogicalPlan = {
       val containsCTE = plan.exists {
         case _: WithCTE => true
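The new `extractSourceToNumInputRows` works in two steps:

1. It collects `numOutputRows` from every physical node that reports a stream.
2. It falls back to the legacy plan-matching path only when the captured sources differ from the sources that had new data.

The sketch below shows that decision under simplifying assumptions. Sources are plain strings, and a physical node is reduced to a metrics map plus an optional stream tag. `PhysicalNode`, `sumRows` and `extract` are hypothetical names for illustration only.

```scala
object InputRowsSketch {
  // Hypothetical stand-in for a physical node: its SQL metrics reduced to a
  // name -> value map, plus the stream it was tagged with (if any).
  final case class PhysicalNode(metrics: Map[String, Long], stream: Option[String])

  // Sum row counts per source; one source may back several scan nodes.
  def sumRows(tuples: Seq[(String, Long)]): Map[String, Long] =
    tuples.groupBy(_._1).map { case (source, rows) => source -> rows.map(_._2).sum }

  def extract(
      sourcesWithNewData: Set[String],
      nodes: Seq[PhysicalNode],
      legacy: () => Map[String, Long]): Map[String, Long] = {
    // New approach: read numOutputRows from every node that carries a stream.
    val captured = nodes.collect {
      case PhysicalNode(metrics, Some(stream)) =>
        stream -> metrics.getOrElse("numOutputRows", 0L)
    }
    if (captured.map(_._1).toSet == sourcesWithNewData) {
      sumRows(captured)
    } else {
      // Fall back to the legacy path; if it finds nothing, trust the new result.
      val legacyRows = legacy()
      if (legacyRows.isEmpty) sumRows(captured) else legacyRows
    }
  }
}
```

Nodes without a stream tag are ignored, which matches the `node.getStream.isDefined` guard in the diff.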
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index f6fd6b501d79..ed182322aec9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -448,7 +448,7 @@ class DataFrameJoinSuite extends QueryTest
             }
             assert(broadcastExchanges.size == 1)
             val tables = broadcastExchanges.head.collect {
-              case FileSourceScanExec(_, _, _, _, _, _, _, Some(tableIdent), _) => tableIdent
+              case FileSourceScanExec(_, _, _, _, _, _, _, _, Some(tableIdent), _) => tableIdent
             }
             assert(tables.size == 1)
             assert(tables.head ===
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e784169dd59c..ff251ddbbfb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2074,7 +2074,7 @@ class DataFrameSuite extends QueryTest
     val emptyDf = spark.emptyDataFrame.withColumn("id", lit(1L))
     val joined = spark.range(10).join(emptyDf, "id")
     joined.queryExecution.optimizedPlan match {
-      case LocalRelation(Seq(id), Nil, _) =>
+      case LocalRelation(Seq(id), Nil, _, _) =>
         assert(id.name == "id")
       case _ =>
         fail("emptyDataFrame should be foldable")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index f8f7fd246832..9e97c224736d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -1524,7 +1524,7 @@ class SubquerySuite extends QueryTest
       // need to execute the query before we can examine fs.inputRDDs()
       assert(stripAQEPlan(df.queryExecution.executedPlan) match {
         case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(
-            fs @ FileSourceScanExec(_, _, _, partitionFilters, _, _, _, _, _)))) =>
+            fs @ FileSourceScanExec(_, _, _, _, partitionFilters, _, _, _, _, _)))) =>
           partitionFilters.exists(ExecSubqueryExpression.hasSubquery) &&
             fs.inputRDDs().forall(
               _.asInstanceOf[FileScanRDD].filePartitions.forall(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
index cbc565974cd6..3828ad410c89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
@@ -47,14 +47,14 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSparkSession {
 
   private def assertMetadataOnlyQuery(df: DataFrame): Unit = {
     val localRelations = df.queryExecution.optimizedPlan.collect {
-      case l @ LocalRelation(_, _, _) => l
+      case l: LocalRelation => l
     }
     assert(localRelations.size == 1)
   }
 
   private def assertNotMetadataOnlyQuery(df: DataFrame): Unit = {
     val localRelations = df.queryExecution.optimizedPlan.collect {
-      case l @ LocalRelation(_, _, _) => l
+      case l: LocalRelation => l
     }
     assert(localRelations.size == 0)
   }
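Several suites in this commit had to rewrite constructor patterns such as `FileSourceScanExec(_, _, _, ...)` and `LocalRelation(_, _, _)`, because the new field changed the arity of the case class. `OptimizeMetadataOnlyQuerySuite` switches to a type pattern instead. The sketch below shows the difference using a hypothetical `LocalRel` case class, not Spark's `LocalRelation`.

```scala
object PatternSketch {
  // Hypothetical four-field case class; imagine `stream` was appended later.
  final case class LocalRel(
      output: Seq[String],
      data: Seq[Int],
      isStreaming: Boolean,
      stream: Option[String])

  // Constructor pattern: must list every field, so adding a field breaks it.
  def isEmptyRel(p: Any): Boolean = p match {
    case LocalRel(_, data, _, _) => data.isEmpty
    case _ => false
  }

  // Type pattern: names only the class and survives changes to its fields.
  def countRels(plans: Seq[Any]): Int = plans.collect { case l: LocalRel => l }.size
}
```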
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
index 8dc07e2df99f..aed11badb710 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
@@ -92,7 +92,7 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-30780 empty LocalTableScan should use RDD without partitions") {
-    assert(LocalTableScanExec(Nil, Nil).execute().getNumPartitions == 0)
+    assert(LocalTableScanExec(Nil, Nil, None).execute().getNumPartitions == 0)
   }
 
   test("SPARK-33617: change default parallelism of LocalTableScan") {
@@ -119,11 +119,11 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
   }
 
   test("SPARK-37221: The collect-like API in SparkPlan should support columnar output") {
-    val emptyResults = ColumnarOp(LocalTableScanExec(Nil, Nil)).toRowBased.executeCollect()
+    val emptyResults = ColumnarOp(LocalTableScanExec(Nil, Nil, None)).toRowBased.executeCollect()
     assert(emptyResults.isEmpty)
 
     val relation = LocalTableScanExec(
-      Seq(AttributeReference("val", IntegerType)()), Seq(InternalRow(1)))
+      Seq(AttributeReference("val", IntegerType)()), Seq(InternalRow(1)), None)
     val nonEmpty = ColumnarOp(relation).toRowBased.executeCollect()
     assert(nonEmpty === relation.executeCollect())
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala
index b4cb7e3fce3c..d5c8cabe5003 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala
@@ -40,9 +40,9 @@ class SparkPlannerSuite extends SharedSparkSession {
         case u: Union =>
           planned += 1
           UnionExec(u.children.map(planLater)) :: planLater(NeverPlanned) :: Nil
-        case LocalRelation(output, data, _) =>
+        case LocalRelation(output, data, _, stream) =>
           planned += 1
-          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
+          LocalTableScanExec(output, data, stream) :: planLater(NeverPlanned) :: Nil
         case NeverPlanned =>
           fail("QueryPlanner should not go down to this branch.")
         case _ => Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala
index b065c9a27a45..b44e899c1a4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala
@@ -73,7 +73,8 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession {
       bucketSpec = Some(BucketSpec(setting.numBuckets, setting.cols.map(_.name), Nil)),
       fileFormat = new ParquetFileFormat(),
       options = Map.empty)(spark)
-    FileSourceScanExec(relation, setting.cols, relation.dataSchema, Nil, None, None, Nil, None)
+    FileSourceScanExec(relation, None, setting.cols, relation.dataSchema, Nil, None, None, Nil,
+      None)
   }
 
   private def run(setting: JoinSetting): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala
index 26039b9185db..b38d5d7dbce9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala
@@ -29,8 +29,8 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest {
   val rightAttributeC = AttributeReference("c", IntegerType)()
   val rightAttributeD = AttributeReference("d", IntegerType)()
 
-  val left = new LocalTableScanExec(Seq(leftAttributeA, leftAttributeB), Seq())
-  val right = new LocalTableScanExec(Seq(rightAttributeC, rightAttributeD), Seq())
+  val left = new LocalTableScanExec(Seq(leftAttributeA, leftAttributeB), Seq(), None)
+  val right = new LocalTableScanExec(Seq(rightAttributeC, rightAttributeD), Seq(), None)
 
   test("empty") {
     val split = JoinConditionSplitPredicates(None, left, right)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala
index a47c2f839692..3de6273ffb7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala
@@ -57,8 +57,10 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest {
       if (currentOffset == 0) currentOffset = getOffsetValue(end)
       val plan = Range(
         start.map(getOffsetValue).getOrElse(0L) + 1L, getOffsetValue(end) + 1L, 1, None,
-        isStreaming = true)
-      Dataset.ofRows(spark, plan)
+        // Intentionally set isStreaming to false; we only use RDD plan in below.
+        isStreaming = false)
+      sqlContext.internalCreateDataFrame(
+        plan.queryExecution.toRdd, plan.schema, isStreaming = true)
     }
 
     override def incrementAvailableOffset(numNewRows: Int): Unit = {
@@ -115,23 +117,24 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest {
     }
   }
 
-  def testWithConfigMatrix(testName: String)(testFun: => Any): Unit = {
+  def testWithConfigMatrix(testName: String)(testFun: Boolean => Any): Unit = {
     Seq(true, false).foreach { useWrapper =>
       test(testName + s" (using wrapper: $useWrapper)") {
         withSQLConf(
           SQLConf.STREAMING_TRIGGER_AVAILABLE_NOW_WRAPPER_ENABLED.key -> useWrapper.toString) {
-          testFun
+          testFun(useWrapper)
         }
       }
     }
   }
 
   Seq(
-    new TestSource,
-    new TestSourceWithAdmissionControl,
-    new TestMicroBatchStream
-  ).foreach { testSource =>
-    testWithConfigMatrix(s"TriggerAvailableNow for multiple sources with ${testSource.getClass}") {
+    (new TestSource, false),
+    (new TestSourceWithAdmissionControl, false),
+    (new TestMicroBatchStream, true)
+  ).foreach { case (testSource, supportTriggerAvailableNow) =>
+    testWithConfigMatrix(s"TriggerAvailableNow for multiple sources with " +
+      s"${testSource.getClass}") { useWrapper =>
       testSource.reset()
 
       withTempDirs { (src, target) =>
@@ -170,10 +173,16 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest {
 
         val q = startQuery()
 
+        val expectedNumBatches = if (!useWrapper && !supportTriggerAvailableNow) {
+          // Spark will decide to fall back to SingleBatchExecutor.
+          1
+        } else {
+          3
+        }
+
         try {
           assert(q.awaitTermination(streamingTimeout.toMillis))
-          // only one batch has data in both sources, thus counted, see SPARK-24050
-          assert(q.recentProgress.count(_.numInputRows != 0) == 1)
+          assert(q.recentProgress.count(_.numInputRows != 0) == expectedNumBatches)
           q.recentProgress.foreach { p =>
             assert(p.sources.exists(_.description.startsWith(testSource.sourceName)))
           }
@@ -193,8 +202,7 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest {
         val q2 = startQuery()
         try {
           assert(q2.awaitTermination(streamingTimeout.toMillis))
-          // only one batch has data in both sources, thus counted, see SPARK-24050
-          assert(q2.recentProgress.count(_.numInputRows != 0) == 1)
+          assert(q2.recentProgress.count(_.numInputRows != 0) == expectedNumBatches)
           q2.recentProgress.foreach { p =>
             assert(p.sources.exists(_.description.startsWith(testSource.sourceName)))
           }
@@ -212,7 +220,8 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest {
     new TestSourceWithAdmissionControl,
     new TestMicroBatchStream
   ).foreach { testSource =>
-    testWithConfigMatrix(s"TriggerAvailableNow for single source with ${testSource.getClass}") {
+    testWithConfigMatrix(s"TriggerAvailableNow for single source with " +
+      s"${testSource.getClass}") { _ =>
       testSource.reset()
 
       val tableName = "trigger_available_now_test_table"

