allisonwang-db commented on code in PR #41867:
URL: https://github.com/apache/spark/pull/41867#discussion_r1261867838


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -454,6 +474,74 @@ def __repr__(self):
         return "ArrowStreamPandasUDFSerializer"
 
 
+class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
+    """
+
+    def __init__(self, timezone, safecheck, assign_cols_by_name):
+        super(ArrowStreamPandasUDTFSerializer, self).__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=assign_cols_by_name,
+            # Set to 'False' to avoid converting struct type inputs into a 
pandas DataFrame.
+            df_for_struct=False,
+            # Defines how struct type inputs are converted. If set to "row", 
struct type inputs
+            # are converted into Rows. Without this setting, a struct type 
input would be treated
+            # as a dictionary. For example, for named_struct('name', 'Alice', 
'age', 1),
+            # if struct_in_pandas="dict", it becomes {"name": "Alice", "age": 
1}
+            # if struct_in_pandas="row", it becomes Row(name="Alice", age=1)
+            struct_in_pandas="row",
+            # When dealing with array type inputs, Arrow converts them into 
numpy.ndarrays.
+            # To ensure consistency across regular and arrow-optimized UDTFs, 
we further
+            # convert these numpy.ndarrays into Python lists.
+            ndarray_as_list=True,
+            # Enables explicit casting for mismatched return types of Arrow 
Python UDTFs.
+            arrow_cast=True,
+        )
+
+    def _create_batch(self, series):
+        """
+        Create an Arrow record batch from the given pandas.Series 
pandas.DataFrame
+        or list of Series or DataFrame, with optional type.
+
+        Parameters
+        ----------
+        series : pandas.Series or pandas.DataFrame or list
+            A single series or dataframe, list of series or dataframe,
+            or list of (series or dataframe, arrow_type)
+
+        Returns
+        -------
+        pyarrow.RecordBatch
+            Arrow RecordBatch
+        """
+        import pandas as pd
+        import pyarrow as pa
+
+        # Make input conform to [(series1, type1), (series2, type2), ...]
+        if not isinstance(series, (list, tuple)) or (
+            len(series) == 2 and isinstance(series[1], pa.DataType)
+        ):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s 
in series)
+
+        arrs = []
+        for s, t in series:
+            if not isinstance(s, pd.DataFrame):
+                raise PySparkValueError(
+                    "Output of an arrow-optimized Python UDTFs expects "
+                    f"a pandas.DataFrame but got: {type(s)}"
+                )
+
+            arrs.append(self._create_struct_array(s, t))
+
+        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in 
range(len(arrs))])

Review Comment:
   Great idea! I will make a follow-up PR for this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to