This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cb0b51038b0 [SPARK-39748][SQL][SS] Include the origin logical plan for
LogicalRDD if it comes from DataFrame
cb0b51038b0 is described below
commit cb0b51038b0ae17ba2a4a38082e322f5b6087e06
Author: Jungtaek Lim <[email protected]>
AuthorDate: Tue Jul 12 17:59:10 2022 +0900
[SPARK-39748][SQL][SS] Include the origin logical plan for LogicalRDD if it
comes from DataFrame
### What changes were proposed in this pull request?
This PR proposes to include the origin logical plan for LogicalRDD, if the
LogicalRDD is built from a DataFrame's RDD. Once the origin logical plan is
available, LogicalRDD produces the stats from the origin logical plan rather
than the default one.
Also, this PR applies the change to ForeachBatchSink, which seems to be the
only such case in the current codebase as of now.
### Why are the changes needed?
The origin logical plan can be useful for several use cases, including:
1. connecting the two split logical plans into one (consider the case of a
foreachBatch sink: the origin logical plan represents the plan for the
streaming query, and the logical plan for the new Dataset represents the plan
for the batch query in the user function)
2. inheriting plan stats from the origin logical plan
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New UT.
Closes #37161 from HeartSaVioR/SPARK-39748.
Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 1 +
.../apache/spark/sql/execution/ExistingRDD.scala | 23 ++++++++---
.../streaming/sources/ForeachBatchSink.scala | 22 ++++++++++-
.../streaming/sources/ForeachBatchSinkSuite.scala | 45 +++++++++++++++++++++-
4 files changed, 83 insertions(+), 8 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 39d33d80261..f45c27d3007 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -705,6 +705,7 @@ class Dataset[T] private[sql](
LogicalRDD(
logicalPlan.output,
internalRdd,
+ None,
outputPartitioning,
physicalPlan.outputOrdering,
isStreaming
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 1ab183fe843..bf9ef6991e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -83,10 +83,16 @@ case class ExternalRDDScanExec[T](
}
}
-/** Logical plan node for scanning data from an RDD of InternalRow. */
+/**
+ * Logical plan node for scanning data from an RDD of InternalRow.
+ *
+ * It is advised to set the field `originLogicalPlan` if the RDD is directly
built from DataFrame,
+ * as the stat can be inherited from `originLogicalPlan`.
+ */
case class LogicalRDD(
output: Seq[Attribute],
rdd: RDD[InternalRow],
+ originLogicalPlan: Option[LogicalPlan] = None,
outputPartitioning: Partitioning = UnknownPartitioning(0),
override val outputOrdering: Seq[SortOrder] = Nil,
override val isStreaming: Boolean = false)(session: SparkSession)
@@ -113,6 +119,7 @@ case class LogicalRDD(
LogicalRDD(
output.map(rewrite),
rdd,
+ originLogicalPlan,
rewrittenPartitioning,
rewrittenOrdering,
isStreaming
@@ -121,11 +128,15 @@ case class LogicalRDD(
override protected def stringArgs: Iterator[Any] = Iterator(output,
isStreaming)
- override def computeStats(): Statistics = Statistics(
- // TODO: Instead of returning a default value here, find a way to return a
meaningful size
- // estimate for RDDs. See PR 1238 for more discussions.
- sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
- )
+ override def computeStats(): Statistics = {
+ originLogicalPlan.map(_.stats).getOrElse {
+ Statistics(
+ // TODO: Instead of returning a default value here, find a way to
return a meaningful size
+ // estimate for RDDs. See PR 1238 for more discussions.
+ sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
+ )
+ }
+ }
}
/** Physical plan node for scanning data from an RDD of InternalRow. */
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
index 0893875aff5..1c6bca241af 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.DataStreamWriter
@@ -27,11 +29,29 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long)
=> Unit, encoder: Expr
override def addBatch(batchId: Long, data: DataFrame): Unit = {
val rdd = data.queryExecution.toRdd
+ val executedPlan = data.queryExecution.executedPlan
+ val node = LogicalRDD(
+ data.schema.toAttributes,
+ rdd,
+ Some(eliminateWriteMarkerNode(data.queryExecution.analyzed)),
+ executedPlan.outputPartitioning,
+ executedPlan.outputOrdering)(data.sparkSession)
implicit val enc = encoder
- val ds = data.sparkSession.internalCreateDataFrame(rdd, data.schema).as[T]
+ val ds = Dataset.ofRows(data.sparkSession, node).as[T]
batchWriter(ds, batchId)
}
+ /**
+ * ForEachBatchSink implementation reuses the logical plan of `data` which
breaks the contract
+ * of Sink.addBatch, which `data` should be just used to "collect" the
output data.
+ * We have to deal with eliminating marker node here which we do this in
streaming specific
+ * optimization rule.
+ */
+ private def eliminateWriteMarkerNode(plan: LogicalPlan): LogicalPlan = plan
match {
+ case node: WriteToMicroBatchDataSourceV1 => node.child
+ case node => node
+ }
+
override def toString(): String = "ForeachBatchSink"
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
index ce98e2e6a5b..dbac4af90c0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
@@ -22,7 +22,8 @@ import scala.language.implicitConversions
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.execution.SerializeFromObjectExec
+import org.apache.spark.sql.execution.{LogicalRDD, SerializeFromObjectExec}
+import
org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming._
@@ -185,6 +186,48 @@ class ForeachBatchSinkSuite extends StreamTest {
assertPlan(mem2, dsUntyped)
}
+ test("Leaf node of Dataset in foreachBatch should carry over origin logical
plan") {
+ def assertPlan[T](stream: MemoryStream[Int], ds: Dataset[T]): Unit = {
+ var planAsserted = false
+
+ val writer: (Dataset[T], Long) => Unit = { case (df, _) =>
+ df.logicalPlan.collectLeaves().head match {
+ case l: LogicalRDD =>
+ assert(l.originLogicalPlan.nonEmpty, "Origin logical plan should
be available in " +
+ "LogicalRDD")
+ l.originLogicalPlan.get.collectLeaves().head match {
+ case _: StreamingDataSourceV2Relation => // pass
+ case p =>
+ fail("Expect StreamingDataSourceV2Relation in the leaf node of
origin " +
+ s"logical plan! Actual: $p")
+ }
+
+ case p =>
+ fail(s"Expect LogicalRDD in the leaf node of Dataset! Actual: $p")
+ }
+ planAsserted = true
+ }
+
+ stream.addData(1, 2, 3, 4, 5)
+
+ val query =
ds.writeStream.trigger(Trigger.Once()).foreachBatch(writer).start()
+ query.awaitTermination()
+
+ assert(planAsserted, "ForeachBatch writer should be called!")
+ }
+
+ // typed
+ val mem = MemoryStream[Int]
+ val ds = mem.toDS.map(_ + 1)
+ assertPlan(mem, ds)
+
+ // untyped
+ val mem2 = MemoryStream[Int]
+ val dsUntyped = mem2.toDF().selectExpr("value + 1 as value")
+ assertPlan(mem2, dsUntyped)
+ }
+
+
// ============== Helper classes and methods =================
private class ForeachBatchTester[T: Encoder](memoryStream:
MemoryStream[Int]) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]