chaoqin-li1123 commented on code in PR #45432:
URL: https://github.com/apache/spark/pull/45432#discussion_r1522240457


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala:
##########
@@ -22,19 +22,38 @@ import scala.util.control.NonFatal
 import org.apache.spark.{SparkException, SparkThrowable}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Deduplicate, 
DeduplicateWithinWatermark, Distinct, FlatMapGroupsInPandasWithState, 
FlatMapGroupsWithState, GlobalLimit, Join, LogicalPlan, TransformWithState}
 import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.streaming.DataStreamWriter
 
 class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: 
ExpressionEncoder[T])
   extends Sink {
 
+  /**
+   * Returns true if `logicalPlan` contains a node this check treats as stateful. That is any
+   * streaming Aggregate, Distinct, Deduplicate, DeduplicateWithinWatermark,
+   * FlatMapGroupsWithState, FlatMapGroupsInPandasWithState, TransformWithState or GlobalLimit
+   * node, or a Join whose two sides are both streaming.
+   *
+   * `find` stops at the first matching node. This avoids walking the whole plan and building
+   * a list of every match just to test whether it is non-empty.
+   */
+  private def isQueryStateful(logicalPlan: LogicalPlan): Boolean = {
+    logicalPlan.find {
+      case node @ (_: Aggregate | _: Distinct | _: Deduplicate | _: DeduplicateWithinWatermark
+          | _: FlatMapGroupsWithState | _: FlatMapGroupsInPandasWithState
+          | _: TransformWithState | _: GlobalLimit) =>
+        node.isStreaming
+      case Join(left, right, _, _, _) =>
+        // Only stream-stream joins count; a stream-static join is not flagged here.
+        left.isStreaming && right.isStreaming
+      case _ =>
+        false
+    }.isDefined
+  }
+
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     val node = LogicalRDD.fromDataset(rdd = data.queryExecution.toRdd, 
originDataset = data,
       isStreaming = false)
     implicit val enc = encoder
     val ds = Dataset.ofRows(data.sparkSession, node).as[T]
-    callBatchWriter(ds, batchId)
+    // SPARK-47329 - persist the dataframe for stateful queries to prevent 
state stores
+    // from reloading state multiple times in each batch
+    val isStateful = isQueryStateful(data.logicalPlan)
+    if (isStateful) {
+      ds.persist()
+      callBatchWriter(ds, batchId)
+      ds.unpersist()

Review Comment:
   Shall we wrap this in a try/finally block so that `ds.unpersist()` always runs, even if the batch writer throws? That would avoid leaking the cache.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to