Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/22305#discussion_r239587375
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
---
@@ -144,24 +282,107 @@ case class WindowInPandasExec(
queue.close()
}
- val inputProj = UnsafeProjection.create(allInputs, child.output)
- val pythonInput = grouped.map { case (_, rows) =>
- rows.map { row =>
- queue.add(row.asInstanceOf[UnsafeRow])
- inputProj(row)
+ val stream = iter.map { row =>
+ queue.add(row.asInstanceOf[UnsafeRow])
+ row
+ }
+
+ val pythonInput = new Iterator[Iterator[UnsafeRow]] {
+
+ // Manage the stream and the grouping.
+ var nextRow: UnsafeRow = null
+ var nextGroup: UnsafeRow = null
+ var nextRowAvailable: Boolean = false
+ private[this] def fetchNextRow() {
+ nextRowAvailable = stream.hasNext
+ if (nextRowAvailable) {
+ nextRow = stream.next().asInstanceOf[UnsafeRow]
+ nextGroup = grouping(nextRow)
+ } else {
+ nextRow = null
+ nextGroup = null
+ }
+ }
+ fetchNextRow()
+
+ // Manage the current partition.
+ val buffer: ExternalAppendOnlyUnsafeRowArray =
+ new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold,
spillThreshold)
+ var bufferIterator: Iterator[UnsafeRow] = _
+
+ val indexRow = new
SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType))
+
+ val frames = factories.map(_(indexRow))
+
+ private[this] def fetchNextPartition() {
+ // Collect all the rows in the current partition.
+ // Before we start to fetch new input rows, make a copy of
nextGroup.
+ val currentGroup = nextGroup.copy()
+
+ // clear last partition
+ buffer.clear()
+
+ while (nextRowAvailable && nextGroup == currentGroup) {
--- End diff --
Good catch. @ueshin Do you mind double-checking that what I am doing now is
correct?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]