tdas commented on a change in pull request #33093:
URL: https://github.com/apache/spark/pull/33093#discussion_r661555856



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
##########
@@ -39,45 +42,63 @@ import org.apache.spark.util.CompletionIterator
  * @param dataAttributes used to read the data
  * @param outputObjAttr Defines the output object
  * @param stateEncoder used to serialize/deserialize state before calling `func`
+ * @param initStateEncoder encoder for the initial state used to deserialize init state.
  * @param outputMode the output mode of `func`
  * @param timeoutConf used to timeout groups that have not received data in a while
  * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param initialState the user specified initial state
+ * @param hasInitialState indicates whether the initial state is provided or not
+ * @param child the physical plan for the underlying data
  */
 case class FlatMapGroupsWithStateExec(
     func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
     keyDeserializer: Expression,
     valueDeserializer: Expression,
     groupingAttributes: Seq[Attribute],
+    initStateGroupAttrs: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
     stateInfo: Option[StatefulOperatorStateInfo],
     stateEncoder: ExpressionEncoder[Any],
+    initStateEncoder: ExpressionEncoder[Any],
     stateFormatVersion: Int,
     outputMode: OutputMode,
     timeoutConf: GroupStateTimeout,
     batchTimestampMs: Option[Long],
     eventTimeWatermark: Option[Long],
+    initialState: SparkPlan,
+    hasInitialState: Boolean,
     child: SparkPlan
-  ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
+  ) extends BinaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
 
   import FlatMapGroupsWithStateExecHelper._
   import GroupStateImpl._
 
+  override def left: SparkPlan = child
+
+  override def right: SparkPlan = initialState
+
   private val isTimeoutEnabled = timeoutConf != NoTimeout
   private val watermarkPresent = child.output.exists {
     case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
     case _ => false
   }
+
   private[sql] val stateManager =
     createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)
 
   /** Distribute by grouping attributes */
-  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil
+  override def requiredChildDistribution: Seq[Distribution] = {
+    HashClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
+    HashClusteredDistribution(initStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
+      Nil
+  }
 
   /** Ordering needed for using GroupingIterator */
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
+      groupingAttributes.map(SortOrder(_, Ascending)),
+      initStateGroupAttrs.map(SortOrder(_, Ascending)))

Review comment:
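       Same as above: please add more docs here, e.g. explain that each child (the input data and the initial state) must be sorted by its own grouping attributes so that both can be consumed with GroupingIterator.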




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to