This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cb0b51038b0 [SPARK-39748][SQL][SS] Include the origin logical plan for LogicalRDD if it comes from DataFrame
cb0b51038b0 is described below

commit cb0b51038b0ae17ba2a4a38082e322f5b6087e06
Author: Jungtaek Lim <[email protected]>
AuthorDate: Tue Jul 12 17:59:10 2022 +0900

    [SPARK-39748][SQL][SS] Include the origin logical plan for LogicalRDD if it comes from DataFrame
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to include the origin logical plan for LogicalRDD if the
    LogicalRDD is built from a DataFrame's RDD. Once the origin logical plan is
    available, LogicalRDD derives its stats from the origin logical plan rather
    than returning the default size estimate.
    
    Also, this PR applies the change to ForeachBatchSink, which appears to be the
    only such case in the current codebase.
    
    ### Why are the changes needed?
    
    The origin logical plan can be useful for several use cases, including:
    
    1. connecting the two split logical plans into one (consider the foreachBatch
       sink: the origin logical plan represents the plan for the streaming query,
       and the logical plan of the new Dataset represents the plan for the batch
       query in the user function)
    2. inheriting plan stats from the origin logical plan
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit test in ForeachBatchSinkSuite.
    
    Closes #37161 from HeartSaVioR/SPARK-39748.
    
    Authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  1 +
 .../apache/spark/sql/execution/ExistingRDD.scala   | 23 ++++++++---
 .../streaming/sources/ForeachBatchSink.scala       | 22 ++++++++++-
 .../streaming/sources/ForeachBatchSinkSuite.scala  | 45 +++++++++++++++++++++-
 4 files changed, 83 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 39d33d80261..f45c27d3007 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -705,6 +705,7 @@ class Dataset[T] private[sql](
         LogicalRDD(
           logicalPlan.output,
           internalRdd,
+          None,
           outputPartitioning,
           physicalPlan.outputOrdering,
           isStreaming
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 1ab183fe843..bf9ef6991e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -83,10 +83,16 @@ case class ExternalRDDScanExec[T](
   }
 }
 
-/** Logical plan node for scanning data from an RDD of InternalRow. */
+/**
+ * Logical plan node for scanning data from an RDD of InternalRow.
+ *
+ * It is advised to set the field `originLogicalPlan` if the RDD is directly 
built from DataFrame,
+ * as the stat can be inherited from `originLogicalPlan`.
+ */
 case class LogicalRDD(
     output: Seq[Attribute],
     rdd: RDD[InternalRow],
+    originLogicalPlan: Option[LogicalPlan] = None,
     outputPartitioning: Partitioning = UnknownPartitioning(0),
     override val outputOrdering: Seq[SortOrder] = Nil,
     override val isStreaming: Boolean = false)(session: SparkSession)
@@ -113,6 +119,7 @@ case class LogicalRDD(
     LogicalRDD(
       output.map(rewrite),
       rdd,
+      originLogicalPlan,
       rewrittenPartitioning,
       rewrittenOrdering,
       isStreaming
@@ -121,11 +128,15 @@ case class LogicalRDD(
 
   override protected def stringArgs: Iterator[Any] = Iterator(output, 
isStreaming)
 
-  override def computeStats(): Statistics = Statistics(
-    // TODO: Instead of returning a default value here, find a way to return a 
meaningful size
-    // estimate for RDDs. See PR 1238 for more discussions.
-    sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
-  )
+  override def computeStats(): Statistics = {
+    originLogicalPlan.map(_.stats).getOrElse {
+      Statistics(
+        // TODO: Instead of returning a default value here, find a way to 
return a meaningful size
+        // estimate for RDDs. See PR 1238 for more discussions.
+        sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
+      )
+    }
+  }
 }
 
 /** Physical plan node for scanning data from an RDD of InternalRow. */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
index 0893875aff5..1c6bca241af 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming.sources
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.streaming.DataStreamWriter
 
@@ -27,11 +29,29 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) 
=> Unit, encoder: Expr
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     val rdd = data.queryExecution.toRdd
+    val executedPlan = data.queryExecution.executedPlan
+    val node = LogicalRDD(
+      data.schema.toAttributes,
+      rdd,
+      Some(eliminateWriteMarkerNode(data.queryExecution.analyzed)),
+      executedPlan.outputPartitioning,
+      executedPlan.outputOrdering)(data.sparkSession)
     implicit val enc = encoder
-    val ds = data.sparkSession.internalCreateDataFrame(rdd, data.schema).as[T]
+    val ds = Dataset.ofRows(data.sparkSession, node).as[T]
     batchWriter(ds, batchId)
   }
 
+  /**
+   * ForEachBatchSink implementation reuses the logical plan of `data` which 
breaks the contract
+   * of Sink.addBatch, which `data` should be just used to "collect" the 
output data.
+   * We have to deal with eliminating marker node here which we do this in 
streaming specific
+   * optimization rule.
+   */
+  private def eliminateWriteMarkerNode(plan: LogicalPlan): LogicalPlan = plan 
match {
+    case node: WriteToMicroBatchDataSourceV1 => node.child
+    case node => node
+  }
+
   override def toString(): String = "ForeachBatchSink"
 }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
index ce98e2e6a5b..dbac4af90c0 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
@@ -22,7 +22,8 @@ import scala.language.implicitConversions
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.execution.SerializeFromObjectExec
+import org.apache.spark.sql.execution.{LogicalRDD, SerializeFromObjectExec}
+import 
org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming._
@@ -185,6 +186,48 @@ class ForeachBatchSinkSuite extends StreamTest {
     assertPlan(mem2, dsUntyped)
   }
 
+  test("Leaf node of Dataset in foreachBatch should carry over origin logical 
plan") {
+    def assertPlan[T](stream: MemoryStream[Int], ds: Dataset[T]): Unit = {
+      var planAsserted = false
+
+      val writer: (Dataset[T], Long) => Unit = { case (df, _) =>
+        df.logicalPlan.collectLeaves().head match {
+          case l: LogicalRDD =>
+            assert(l.originLogicalPlan.nonEmpty, "Origin logical plan should 
be available in " +
+              "LogicalRDD")
+            l.originLogicalPlan.get.collectLeaves().head match {
+              case _: StreamingDataSourceV2Relation => // pass
+              case p =>
+                fail("Expect StreamingDataSourceV2Relation in the leaf node of 
origin " +
+                  s"logical plan! Actual: $p")
+            }
+
+          case p =>
+            fail(s"Expect LogicalRDD in the leaf node of Dataset! Actual: $p")
+        }
+        planAsserted = true
+      }
+
+      stream.addData(1, 2, 3, 4, 5)
+
+      val query = 
ds.writeStream.trigger(Trigger.Once()).foreachBatch(writer).start()
+      query.awaitTermination()
+
+      assert(planAsserted, "ForeachBatch writer should be called!")
+    }
+
+    // typed
+    val mem = MemoryStream[Int]
+    val ds = mem.toDS.map(_ + 1)
+    assertPlan(mem, ds)
+
+    // untyped
+    val mem2 = MemoryStream[Int]
+    val dsUntyped = mem2.toDF().selectExpr("value + 1 as value")
+    assertPlan(mem2, dsUntyped)
+  }
+
+
   // ============== Helper classes and methods =================
 
   private class ForeachBatchTester[T: Encoder](memoryStream: 
MemoryStream[Int]) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to