tdas commented on a change in pull request #33336:
URL: https://github.com/apache/spark/pull/33336#discussion_r669983410
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
##########
@@ -389,6 +389,102 @@ case class AppendColumnsWithObjectExec(
copy(child = newChild)
}
+/**
+ * Groups the input rows together and calls the function with each group and
an iterator containing
+ * all elements in the group. The result of this function is flattened before
being output. This
+ * version of the Physical operator takes a user provided initial state.
+ */
+case class MapGroupsWithInitialStateExec(
+ func: (Any, Iterator[Any], LogicalGroupState[Any]) => TraversableOnce[Any],
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ initialStateDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ initialStateGroupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ initialStateDataAttrs: Seq[Attribute],
+ outputObjAttr: Attribute,
+ initialState: SparkPlan,
+ timeoutConf: GroupStateTimeout,
+ child: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
+
+ override def left: SparkPlan = child
+
+ override def right: SparkPlan = initialState
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ private val watermarkPresent = child.output.exists {
+ case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) =>
true
+ case _ => false
+ }
+
+ /**
+ * Distribute by grouping attributes - We need the underlying data and the
initial state data
+ * to have the same grouping so that the data are co-located on the same
task.
+ */
+ override def requiredChildDistribution: Seq[Distribution] = {
+ ClusteredDistribution(groupingAttributes) ::
+ ClusteredDistribution(initialStateGroupingAttributes) :: Nil
+ }
+
+ /**
+ * Ordering needed for using GroupingIterator.
+ * We need the initial state to also use the same ordering as the data so
that we can co-locate the
+ * keys from the underlying data and the initial state.
+ */
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
+ groupingAttributes.map(SortOrder(_, Ascending)),
+ initialStateGroupingAttributes.map(SortOrder(_, Ascending)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().zipPartitions(
+ initialState.execute()
+ ) { case (dataIter, initialStateIter) =>
+ val groupedChildDataIter = GroupedIterator(dataIter, groupingAttributes,
child.output)
+ val groupedInitialStateIter = GroupedIterator(
+ initialStateIter, initialStateGroupingAttributes, initialState.output)
+ val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer,
groupingAttributes)
+ val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer,
dataAttributes)
+ val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
+ val getStateObj =
+ ObjectOperator.deserializeRowToObject(initialStateDeserializer,
initialStateDataAttrs)
+
+ new CoGroupedIterator(
+ groupedChildDataIter, groupedInitialStateIter,
groupingAttributes).flatMap {
+ case (keyRow, valueRowIter, initialStateRowIter) =>
+ var foundInitialStateForKey = false
+ val optionalState = initialStateRowIter.map { initialStateRow =>
Review comment:
add docs here explaining the logic with foundInitialStateForKey
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]