zhengruifeng commented on code in PR #49005:
URL: https://github.com/apache/spark/pull/49005#discussion_r1863210090


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala:
##########
@@ -45,91 +36,24 @@ class CoGroupedArrowPythonRunner(
     leftSchema: StructType,
     rightSchema: StructType,
     timeZoneId: String,
+    largeVarTypes: Boolean,
+    arrowMaxRecordsPerBatch: Int,
     conf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     profiler: Option[String])
-  extends BasePythonRunner[
-    (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
-    funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
-  with BasicPythonArrowOutput {
-
-  override val pythonExec: String =
-    SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
-      funcs.head._1.funcs.head.pythonExec)
-
-  override val faultHandlerEnabled: Boolean = 
SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
-
-  override val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback
-
-  protected def newWriter(
-      env: SparkEnv,
-      worker: PythonWorker,
-      inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])],
-      partitionIndex: Int,
-      context: TaskContext): Writer = {
-
-    new Writer(env, worker, inputIterator, partitionIndex, context) {
-
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
-        // Write config for the worker as a number of key -> value pairs of 
strings
-        dataOut.writeInt(conf.size)
-        for ((k, v) <- conf) {
-          PythonRDD.writeUTF(k, dataOut)
-          PythonRDD.writeUTF(v, dataOut)
-        }
-
-        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
-      }
-
-      override def writeNextInputToStream(dataOut: DataOutputStream): Boolean 
= {
-        // For each we first send the number of dataframes in each group then 
send
-        // first df, then send second df.  End of data is marked by sending 0.
-        if (inputIterator.hasNext) {
-          val startData = dataOut.size()
-          dataOut.writeInt(2)
-          val (nextLeft, nextRight) = inputIterator.next()
-          writeGroup(nextLeft, leftSchema, dataOut, "left")
-          writeGroup(nextRight, rightSchema, dataOut, "right")
-
-          val deltaData = dataOut.size() - startData
-          pythonMetrics("pythonDataSent") += deltaData
-          true
-        } else {
-          dataOut.writeInt(0)
-          false
-        }
-      }
-
-      private def writeGroup(
-          group: Iterator[InternalRow],
-          schema: StructType,
-          dataOut: DataOutputStream,
-          name: String): Unit = {
-        val arrowSchema =
-          ArrowUtils.toArrowSchema(schema, timeZoneId, 
errorOnDuplicatedFieldNames = true)
-        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-          s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
-        val root = VectorSchemaRoot.create(arrowSchema, allocator)
-
-        Utils.tryWithSafeFinally {
-          val writer = new ArrowStreamWriter(root, null, dataOut)
-          val arrowWriter = ArrowWriter.create(root)
-          writer.start()
-
-          while (group.hasNext) {
-            arrowWriter.write(group.next())
-          }
-          arrowWriter.finish()
-          writer.writeBatch()
-          writer.end()
-        }{
-          root.close()
-          allocator.close()
-        }
-      }
-    }
+  extends BaseGroupedArrowPythonRunner[(Iterator[InternalRow], 
Iterator[InternalRow])](
+    funcs, evalType, argOffsets, timeZoneId, largeVarTypes, 
arrowMaxRecordsPerBatch, conf,
+    pythonMetrics, jobArtifactUUID, profiler) {
+
+  override protected def writeNextGroup(
+      group: (Iterator[InternalRow], Iterator[InternalRow]),
+      dataOut: DataOutputStream): Unit = {
+    val (leftGroup, rightGroup) = group
+
+    dataOut.writeInt(2)
+    writeSingleStream(leftGroup, leftSchema, dataOut, Some("left"))
+    writeSingleStream(rightGroup, rightSchema, dataOut, Some("right"))

Review Comment:
   Does this mean that for a cogroup (key, group A, group B), we write the whole
group A in batches, and then write the whole group B in batches?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInBatchExec.scala:
##########
@@ -82,13 +83,14 @@ trait FlatMapGroupsInBatchExec extends SparkPlan with 
UnaryExecNode with PythonS
 
       val data = groupedData(iter, dedupAttributes)
 
-      val runner = new ArrowPythonRunner(
+      val runner = new GroupedArrowPythonRunner(

Review Comment:
   So for a UDF using the existing `Table -> Table` signature, will the input
now be split into batches?
   I think this might be a breaking change:
   suppose a UDF performs an aggregation and a group is larger than
`arrowMaxRecordsPerBatch`. It produced a single row before; will it produce
multiple rows after this change?



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -728,8 +743,9 @@ def load_stream(self, stream):
             dataframes_in_group = read_int(stream)
 
             if dataframes_in_group == 2:
+                # We need to fully load the left batches, but we can lazily 
load the right batches

Review Comment:
   This may still cause memory issues. Is it possible to do it this way instead:
   
   ```
   cogroup: (Iterator[InternalRow], Iterator[InternalRow])
   
   (left, right) = cogroup
   
   while (left.hasNext || right.hasNext) {
      if (left.hasNext) {
         writeBatch(... , "left")
      } else {
         writeBatch(A empty Batch , "left")
      }
   
      if (right.hasNext) {
         writeBatch(... , "right")
      } else {
         writeBatch(A empty Batch , "right")
      }
   }
   
   ```



##########
python/pyspark/sql/pandas/_typing/__init__.pyi:
##########
@@ -348,10 +349,22 @@ PandasCogroupedMapFunction = Union[
 ArrowGroupedMapFunction = Union[
     Callable[[pyarrow.Table], pyarrow.Table],
     Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
+    Callable[[Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]],

Review Comment:
   What about adding a new `ArrowGroupedMapIterFunction` type alias for the new signature?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to