gaogaotiantian commented on code in PR #54172:
URL: https://github.com/apache/spark/pull/54172#discussion_r2800581134


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -129,33 +151,25 @@ def dump_stream(self, iterator, stream):
                 writer.close()
 
     def load_stream(self, stream):
-        import pyarrow as pa
+        """Load batches: plain stream if num_dfs=0, grouped otherwise."""
+        if self._num_dfs == 0:
+            import pyarrow as pa
 
-        reader = pa.ipc.open_stream(stream)
-        for batch in reader:
-            yield batch
+            reader = pa.ipc.open_stream(stream)
+            for batch in reader:
+                yield batch
+        elif self._num_dfs == 2:

Review Comment:
   It's a bit weird to do a `0, 2, else` check here. Let's do `0, 1, 2, else assert False`.
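The suggested exhaustive dispatch might look like the following sketch. This is an illustration of the pattern being asked for, not the actual serializer code; the function name and return strings are placeholders:

```python
def dispatch_by_num_dfs(num_dfs: int) -> str:
    # Enumerate every value the serializer expects explicitly,
    # and fail loudly on anything else instead of folding one
    # case into a bare `else` branch.
    if num_dfs == 0:
        return "plain arrow stream"
    elif num_dfs == 1:
        return "single grouped dataframe"
    elif num_dfs == 2:
        return "cogrouped dataframes"
    else:
        assert False, f"unexpected num_dfs: {num_dfs}"
```

The benefit is that an unexpected `num_dfs` surfaces as an immediate assertion failure rather than being silently handled by the fallthrough branch.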



##########
python/pyspark/worker.py:
##########
@@ -2819,25 +2826,48 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         for i in range(num_udfs)
     ]
 
+    if eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
+        import pyarrow as pa
+
+        assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
+        udf_func: Callable[[Iterator[pa.RecordBatch]], Iterator[pa.RecordBatch]] = udfs[0]
+
+        def func(
+            split_index: int, batches: Iterator["pa.RecordBatch"]
+        ) -> Iterator["pa.RecordBatch"]:

Review Comment:
   nit: if you are already using `pa.RecordBatch` (`pa` is definitely imported here), you don't need to use the quoted version.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

