jingz-db commented on code in PR #48005:
URL: https://github.com/apache/spark/pull/48005#discussion_r1828246246


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
  * methods for new rows in each trigger and the user's state/state variables 
will be stored
  * persistently across invocations.
  * @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for 
input dataframe
  * @param outputAttrs used to define the output rows
  * @param outputMode defines the output mode for the statefulProcessor
  * @param timeMode the time mode semantics of the stateful processor for 
timers and TTL.
  * @param child logical plan of the underlying data
+ * @param initialState logical plan of initial state
+ * @param initGroupingAttrsLen length of the seq of grouping attributes for 
initial state dataframe
  */
 case class TransformWithStateInPandas(
     functionExpr: Expression,
-    groupingAttributes: Seq[Attribute],
+    groupingAttributesLen: Int,
     outputAttrs: Seq[Attribute],
     outputMode: OutputMode,
     timeMode: TimeMode,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    hasInitialState: Boolean,
+    initialState: LogicalPlan,
+    initGroupingAttrsLen: Int,
+    initialStateSchema: StructType) extends BinaryNode {
+  override def left: LogicalPlan = child
+
+  override def right: LogicalPlan = initialState
 
   override def output: Seq[Attribute] = outputAttrs
 
   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
 
-  override protected def withNewChildInternal(
-      newChild: LogicalPlan): TransformWithStateInPandas = copy(child = 
newChild)
+  override lazy val references: AttributeSet =
+    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) 
-- producedAttributes
+
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPandas 
=
+    copy(child = newLeft, initialState = newRight)
+
+  // We call the following attributes from `SparkStrategies` because attribute 
were
+  // resolved when we got there; if we directly pass Seq of attributes before 
it
+  // is fully analyzed, we will have conflicting attributes when resolving 
(initial state
+  // and child have the same names on grouping attributes)
+  def leftAttributes: Seq[Attribute] = left.output.take(groupingAttributesLen)

Review Comment:
   Nice suggestion — this makes much more sense than leaving the explanation in
the comments. I removed the comments and added the assertion to the code instead.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to