zhengruifeng commented on code in PR #49005:
URL: https://github.com/apache/spark/pull/49005#discussion_r1863210090
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala:
##########
@@ -45,91 +36,24 @@ class CoGroupedArrowPythonRunner(
leftSchema: StructType,
rightSchema: StructType,
timeZoneId: String,
+ largeVarTypes: Boolean,
+ arrowMaxRecordsPerBatch: Int,
conf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
profiler: Option[String])
- extends BasePythonRunner[
- (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
- funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
- with BasicPythonArrowOutput {
-
- override val pythonExec: String =
- SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
- funcs.head._1.funcs.head.pythonExec)
-
- override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
-
- override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
-
- protected def newWriter(
- env: SparkEnv,
- worker: PythonWorker,
- inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])],
- partitionIndex: Int,
- context: TaskContext): Writer = {
-
- new Writer(env, worker, inputIterator, partitionIndex, context) {
-
- protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
- // Write config for the worker as a number of key -> value pairs of strings
- dataOut.writeInt(conf.size)
- for ((k, v) <- conf) {
- PythonRDD.writeUTF(k, dataOut)
- PythonRDD.writeUTF(v, dataOut)
- }
-
- PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
- }
-
- override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
- // For each we first send the number of dataframes in each group then send
- // first df, then send second df. End of data is marked by sending 0.
- if (inputIterator.hasNext) {
- val startData = dataOut.size()
- dataOut.writeInt(2)
- val (nextLeft, nextRight) = inputIterator.next()
- writeGroup(nextLeft, leftSchema, dataOut, "left")
- writeGroup(nextRight, rightSchema, dataOut, "right")
-
- val deltaData = dataOut.size() - startData
- pythonMetrics("pythonDataSent") += deltaData
- true
- } else {
- dataOut.writeInt(0)
- false
- }
- }
-
- private def writeGroup(
- group: Iterator[InternalRow],
- schema: StructType,
- dataOut: DataOutputStream,
- name: String): Unit = {
- val arrowSchema =
- ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
- val allocator = ArrowUtils.rootAllocator.newChildAllocator(
- s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
- val root = VectorSchemaRoot.create(arrowSchema, allocator)
-
- Utils.tryWithSafeFinally {
- val writer = new ArrowStreamWriter(root, null, dataOut)
- val arrowWriter = ArrowWriter.create(root)
- writer.start()
-
- while (group.hasNext) {
- arrowWriter.write(group.next())
- }
- arrowWriter.finish()
- writer.writeBatch()
- writer.end()
- }{
- root.close()
- allocator.close()
- }
- }
- }
+ extends BaseGroupedArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])](
+ funcs, evalType, argOffsets, timeZoneId, largeVarTypes, arrowMaxRecordsPerBatch, conf,
+ pythonMetrics, jobArtifactUUID, profiler) {
+
+ override protected def writeNextGroup(
+ group: (Iterator[InternalRow], Iterator[InternalRow]),
+ dataOut: DataOutputStream): Unit = {
+ val (leftGroup, rightGroup) = group
+
+ dataOut.writeInt(2)
+ writeSingleStream(leftGroup, leftSchema, dataOut, Some("left"))
+ writeSingleStream(rightGroup, rightSchema, dataOut, Some("right"))
Review Comment:
Does this mean that for a cogroup (key, group A, group B), we write the whole group A in batches, and then the whole group B in batches?
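For context, a minimal reader-side sketch of what that framing implies, assuming each side arrives as one complete Arrow IPC stream (the standalone function and its name are illustrative, not the actual code in `pyspark/sql/pandas/serializers.py`):
```python
# Illustrative only: per cogroup the writer sends the int 2, then all
# batches of group A as one Arrow IPC stream, then all batches of group B
# as another; the int 0 marks the end of input.
import pyarrow.ipc as ipc
from pyspark.serializers import read_int

def read_cogroups(stream):
    while True:
        dataframes_in_group = read_int(stream)
        if dataframes_in_group == 0:
            return                                    # end-of-data marker
        assert dataframes_in_group == 2
        left = ipc.open_stream(stream).read_all()     # whole group A
        right = ipc.open_stream(stream).read_all()    # whole group B
        yield left, right
```
If that reading is right, all of group A is sent before any of group B within a cogroup.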
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala:
##########
@@ -82,13 +83,14 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with UnaryExecNode with PythonS
val data = groupedData(iter, dedupAttributes)
- val runner = new ArrowPythonRunner(
+ val runner = new GroupedArrowPythonRunner(
Review Comment:
So for a UDF using the existing `Table -> Table` signature, the input will be split into batches?
I guess this might cause a breaking change: suppose a UDF does an aggregation and a group is larger than `arrowMaxRecordsPerBatch`. It used to output a single row, but will it generate multiple rows after this change?
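To make the concern concrete, a hedged sketch using the existing `applyInArrow` grouped-map API (session setup, column names, and sizes are illustrative): the UDF below collapses its whole input table into one row, so it only stays correct if each group reaches it as a single table.
```python
# Sketch of an aggregating Table -> Table UDF under the existing
# applyInArrow API. If a group larger than arrowMaxRecordsPerBatch were
# fed to the UDF one batch at a time, this would emit one row per batch
# instead of one row per group.
import pyarrow as pa
import pyarrow.compute as pc
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

def summarize(table: pa.Table) -> pa.Table:
    # Collapses the entire input table into a single row.
    return pa.table({
        "key": [table.column("key")[0].as_py()],
        "total": [pc.sum(table.column("v")).as_py()],
    })

df = spark.range(1_000_000).selectExpr("id % 2 AS key", "id AS v")
df.groupBy("key").applyInArrow(summarize, schema="key long, total long").show()
```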
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -728,8 +743,9 @@ def load_stream(self, stream):
dataframes_in_group = read_int(stream)
if dataframes_in_group == 2:
+ # We need to fully load the left batches, but we can lazily load the right batches
Review Comment:
It may still cause a memory issue. Is it possible to do it this way:
```
cogroup: (Iterator[InternalRow], Iterator[InternalRow])
val (left, right) = cogroup
while (left.hasNext || right.hasNext) {
  if (left.hasNext) {
    writeBatch(..., "left")
  } else {
    writeBatch(an empty batch, "left")
  }
  if (right.hasNext) {
    writeBatch(..., "right")
  } else {
    writeBatch(an empty batch, "right")
  }
}
```
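For illustration, one way to express that alternation, sketched in Python over two iterators of `pyarrow.RecordBatch` (the actual change would be on the JVM writer side; the function names and empty-batch padding helper here are illustrative):
```python
# Sketch of the proposed interleaving: each step emits one left batch and
# one right batch, padding the exhausted side with an empty batch so the
# two streams stay aligned and neither group has to be fully materialized.
from itertools import zip_longest
import pyarrow as pa

def empty_batch(schema: pa.Schema) -> pa.RecordBatch:
    return pa.RecordBatch.from_arrays(
        [pa.array([], type=field.type) for field in schema], schema=schema)

def interleave(left_batches, right_batches, left_schema, right_schema):
    for left, right in zip_longest(left_batches, right_batches):
        yield "left", left if left is not None else empty_batch(left_schema)
        yield "right", right if right is not None else empty_batch(right_schema)
```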
##########
python/pyspark/sql/pandas/_typing/__init__.pyi:
##########
@@ -348,10 +349,22 @@ PandasCogroupedMapFunction = Union[
ArrowGroupedMapFunction = Union[
Callable[[pyarrow.Table], pyarrow.Table],
Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
+ Callable[[Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]],
Review Comment:
What about adding a new `ArrowGroupedMapIterFunction` alias for the new signature?
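For reference, a sketch of what such an alias might look like, mirroring the iterator-based entry added to `ArrowGroupedMapFunction` in this diff (the name `ArrowGroupedMapIterFunction` and the keyed variant are only suggestions here, not the merged API):
```python
# Hypothetical shape of the suggested alias; not part of the PR as-is.
from typing import Callable, Iterator, Tuple, Union
import pyarrow

ArrowGroupedMapIterFunction = Union[
    Callable[[Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]],
    Callable[
        [Tuple[pyarrow.Scalar, ...], Iterator[pyarrow.RecordBatch]],
        Iterator[pyarrow.RecordBatch],
    ],
]
```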