HeartSaVioR commented on a change in pull request #33081:
URL: https://github.com/apache/spark/pull/33081#discussion_r670895590



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
##########
@@ -511,6 +513,298 @@ case class StateStoreSaveExec(
     copy(child = newChild)
 }
 
+/**
+ * This class sorts input rows and existing sessions in state, and provides output rows
+ * sorted by "group keys + start time of session window".
+ *
+ * Refer to [[MergingSortWithSessionWindowStateIterator]] for more details.
+ */
+case class SessionWindowStateStoreRestoreExec(
+    keyWithoutSessionExpressions: Seq[Attribute],
+    sessionExpression: Attribute,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    eventTimeWatermark: Option[Long],
+    stateFormatVersion: Int,
+    child: SparkPlan)
+  extends UnaryExecNode with StateStoreReader with WatermarkSupport {
+
+  override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions
+
+  private val stateManager = StreamingSessionWindowStateManager.createStateManager(
+    keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion)
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+    assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow")
+
+    child.execute().mapPartitionsWithReadStateStore(
+      getStateInfo,
+      stateManager.getStateKeySchema,
+      stateManager.getStateValueSchema,
+      numColsPrefixKey = stateManager.getNumColsForPrefixKey,
+      session.sessionState,
+      Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
+
+      // We need to filter out outdated inputs
+      val filteredIterator = watermarkPredicateForData match {
+        case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
+        case None => iter
+      }
+
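+      // Merge the sorted input rows with the sorted sessions read from the state
+      // store, so downstream sees rows ordered by group keys + session start time.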
+      new MergingSortWithSessionWindowStateIterator(
+        filteredIterator,
+        stateManager,
+        store,
+        keyWithoutSessionExpressions,
+        sessionExpression,
+        child.output).map { row =>
+        numOutputRows += 1
+        row
+      }
+    }
+  }
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = {
+    (keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))
+  }
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (keyWithoutSessionExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending)))
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+}
+
+/**
+ * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
+ */
+case class SessionWindowStateStoreSaveExec(
+    keyExpressions: Seq[Attribute],
+    sessionExpression: Attribute,
+    stateInfo: Option[StatefulOperatorStateInfo] = None,
+    outputMode: Option[OutputMode] = None,
+    eventTimeWatermark: Option[Long] = None,
+    stateFormatVersion: Int,
+    child: SparkPlan)
+  extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
+
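+  // The grouping key still contains the session window attribute; strip it out,
+  // since the state is keyed only by the remaining grouping expressions.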
+  private val keyWithoutSessionExpressions = keyExpressions.filterNot { p =>
+    p.semanticEquals(sessionExpression)
+  }
+
+  private val stateManager = StreamingSessionWindowStateManager.createStateManager(
+    keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion)
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    metrics // force lazy init at driver
+    assert(outputMode.nonEmpty,
+      "Incorrect planning in IncrementalExecution, outputMode has not been 
set")
+    assert(keyExpressions.nonEmpty,
+      "Grouping key must be specified when using sessionWindow")
+
+    child.execute().mapPartitionsWithStateStore(
+      getStateInfo,
+      stateManager.getStateKeySchema,
+      stateManager.getStateValueSchema,
+      numColsPrefixKey = stateManager.getNumColsForPrefixKey,
+      session.sessionState,
+      Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
+
+      val numOutputRows = longMetric("numOutputRows")
+      val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+      val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
+      val commitTimeMs = longMetric("commitTimeMs")
+
+      outputMode match {
+        // Update and output all rows in the StateStore.
+        case Some(Complete) =>
+          allUpdatesTimeMs += timeTakenMs {
+            putToStore(iter, store, false)
+          }
+          allRemovalsTimeMs += 0
+          commitTimeMs += timeTakenMs {
+            stateManager.commit(store)
+          }
+          setStoreMetrics(store)
+          stateManager.iterator(store).map { row =>
+            numOutputRows += 1
+            row
+          }
+
+        // Update and output only rows being evicted from the StateStore
+        // Assumption: watermark predicates must be non-empty if append mode is allowed
+        case Some(Append) =>
+          allUpdatesTimeMs += timeTakenMs {
+            putToStore(iter, store, true)
+          }
+
+          val removalStartTimeNs = System.nanoTime
+          new NextIterator[InternalRow] {
+            private val removedIter = stateManager.removeByValueCondition(
+              store, watermarkPredicateForData.get.eval)
+
+            override protected def getNext(): InternalRow = {
+              if (!removedIter.hasNext) {
+                finished = true
+                null
+              } else {
+                numOutputRows += 1
+                removedIter.next()
+              }
+            }
+
+            override protected def close(): Unit = {
+              allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
+              commitTimeMs += timeTakenMs { store.commit() }
+              setStoreMetrics(store)
+            }
+          }
+
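+        // Update and output the rows being put into the StateStore; on close, rows
+        // matching the watermark predicate are evicted without being emitted again.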
+        case Some(Update) =>
+          val iterPutToStore = iteratorPutToStore(iter, store, true, true)
+          new NextIterator[InternalRow] {
+            private val updatesStartTimeNs = System.nanoTime
+
+            override protected def getNext(): InternalRow = {
+              if (iterPutToStore.hasNext) {
+                iterPutToStore.next()
+              } else {
+                finished = true
+                null
+              }
+            }
+
+            override protected def close(): Unit = {
+              allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+              allRemovalsTimeMs += timeTakenMs {
+                if (watermarkPredicateForData.nonEmpty) {
+                  val removedIter = stateManager.removeByValueCondition(
+                    store, watermarkPredicateForData.get.eval)
+                  while (removedIter.hasNext) {
+                    removedIter.next()
+                  }
+                }
+              }
+              commitTimeMs += timeTakenMs { store.commit() }
+              setStoreMetrics(store)
+            }
+          }
+
+        case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode")
+      }
+    }
+  }
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (keyExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+    }
+  }
+
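+  // Ask for another (no-data) batch once the watermark has advanced, so that
+  // Append/Update can evict finished sessions even without new input.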
+  override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+    (outputMode.contains(Append) || outputMode.contains(Update)) &&
+      eventTimeWatermark.isDefined &&
+      newMetadata.batchWatermarkMs > eventTimeWatermark.get
+  }
+
+  private def iteratorPutToStore(
+      baseIter: Iterator[InternalRow],
+      store: StateStore,
+      needFilter: Boolean,

Review comment:
       Yeah, this was in the old code; after the refactoring it no longer needs to be passed as a parameter. Nice finding.
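
       For illustration, a minimal sketch of the helper with the flag dropped, assuming
       the watermark filter is simply applied whenever a predicate is defined. The rest
       of the parameter list and the method body are cut off in this hunk, so the sketch
       keeps only the two visible parameters, and `writeSession` is a hypothetical
       stand-in for the actual state write:

```scala
// A sketch, not the actual implementation: the watermark filter is applied
// whenever a predicate is defined, so callers no longer pass `needFilter`.
private def iteratorPutToStore(
    baseIter: Iterator[InternalRow],
    store: StateStore): Iterator[InternalRow] = {
  // Drop rows that are behind the watermark, when one is defined.
  val filtered = watermarkPredicateForData match {
    case Some(predicate) => baseIter.filter((row: InternalRow) => !predicate.eval(row))
    case None => baseIter
  }
  filtered.map { row =>
    writeSession(store, row) // hypothetical stand-in for the real state write
    row
  }
}
```

       With the flag gone, the initial put and the Update path would share the same
       filtering behavior instead of toggling it per call site.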



