ueshin commented on code in PR #41867:
URL: https://github.com/apache/spark/pull/41867#discussion_r1261739197


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -369,19 +381,115 @@ def eval(self, a: int):
         ):
             TestUDTF(rand(0) * 100).collect()
 
-    def test_udtf_no_eval(self):
-        @udtf(returnType="a: int, b: int")
+    def test_udtf_with_struct_input_type(self):
+        @udtf(returnType="x: string")
         class TestUDTF:
-            def run(self, a: int):
-                yield a, a + 1
+            def eval(self, person):
+                yield f"{person.name}: {person.age}",
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+        self.assertEqual(
+            self.spark.sql(
+                "select * from test_udtf(named_struct('name', 'Alice', 'age', 
1))"
+            ).collect(),
+            [Row(x="Alice: 1")],
+        )
 
+    def test_udtf_with_array_input_type(self):
+        @udtf(returnType="x: string")
+        class TestUDTF:
+            def eval(self, args):
+                yield str(args),
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+        self.assertEqual(
+            self.spark.sql("select * from test_udtf(array(1, 2, 
3))").collect(),
+            [Row(x="[1, 2, 3]")],
+        )
+
+    def test_udtf_with_map_input_type(self):
+        @udtf(returnType="x: string")
+        class TestUDTF:
+            def eval(self, m):
+                yield str(m),
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+        self.assertEqual(
+            self.spark.sql("select * from test_udtf(map('key', 
'value'))").collect(),
+            [Row(x="{'key': 'value'}")],
+        )
+
+    def test_udtf_with_struct_output_types(self):
+        @udtf(returnType="x: struct<a:int,b:int>")
+        class TestUDTF:
+            def eval(self, x: int):
+                yield {"a": x, "b": x + 1},
+
+        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x=Row(a=1, b=2))])
+
+    def test_udtf_with_array_output_types(self):
+        @udtf(returnType="x: array<int>")
+        class TestUDTF:
+            def eval(self, x: int):
+                yield [x, x + 1, x + 2],
+
+        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x=[1, 2, 3])])
+
+    def test_udtf_with_map_output_types(self):
+        @udtf(returnType="x: map<int,string>")
+        class TestUDTF:
+            def eval(self, x: int):
+                yield {x: str(x)},
+
+        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(x={1: "1"})])
+
+    def test_udtf_with_pandas_input_type(self):
+        import pandas as pd
+
+        @udtf(returnType="corr: double")
+        class TestUDTF:
+            def eval(self, s1: pd.Series, s2: pd.Series):
+                yield s1.corr(s2)
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+        # TODO(SPARK-43968): check during compile time instead of runtime
         with self.assertRaisesRegex(
-            PythonException,
-            "Failed to execute the user defined table function because it has 
not "
-            "implemented the 'eval' method. Please add the 'eval' method and 
try the "
-            "query again.",
+            PythonException, "AttributeError: 'int' object has no attribute 
'corr'"
         ):
-            TestUDTF(lit(1)).collect()
+            self.spark.sql(
+                "select * from values (1, 2), (2, 3) t(a, b), " "lateral 
test_udtf(a, b)"

Review Comment:
   nit: I guess the formatter combined two lines. We don't need two adjacent string literals separated by a space here; please merge them into a single string.



##########
python/pyspark/worker.py:
##########
@@ -461,24 +462,53 @@ def assign_cols_by_name(runner_conf):
 # ensure the UDTF is valid. This function also prepares a mapper function for 
applying
 # the UDTF logic to input rows.
 def read_udtf(pickleSer, infile, eval_type):
+    if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
+        runner_conf = {}
+        # Load conf used for arrow evaluation.
+        num_conf = read_int(infile)
+        for i in range(num_conf):
+            k = utf8_deserializer.loads(infile)
+            v = utf8_deserializer.loads(infile)
+            runner_conf[k] = v
+
+        # NOTE: if timezone is set here, that implies respectSessionTimeZone 
is True
+        timezone = runner_conf.get("spark.sql.session.timeZone", None)
+        safecheck = (
+            
runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", 
"false").lower()
+            == "true"
+        )
+        ser = ArrowStreamPandasUDTFSerializer(
+            timezone,
+            safecheck,
+            assign_cols_by_name(runner_conf),

Review Comment:
   We shouldn't refer to this config, since the columns of the generated pandas DataFrame have no names.



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -454,6 +474,74 @@ def __repr__(self):
         return "ArrowStreamPandasUDFSerializer"
 
 
+class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
+    """
+
+    def __init__(self, timezone, safecheck, assign_cols_by_name):
+        super(ArrowStreamPandasUDTFSerializer, self).__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=assign_cols_by_name,
+            # Set to 'False' to avoid converting struct type inputs into a 
pandas DataFrame.
+            df_for_struct=False,
+            # Defines how struct type inputs are converted. If set to "row", 
struct type inputs
+            # are converted into Rows. Without this setting, a struct type 
input would be treated
+            # as a dictionary. For example, for named_struct('name', 'Alice', 
'age', 1),
+            # if struct_in_pandas="dict", it becomes {"name": "Alice", "age": 
1}
+            # if struct_in_pandas="row", it becomes Row(name="Alice", age=1)
+            struct_in_pandas="row",
+            # When dealing with array type inputs, Arrow converts them into 
numpy.ndarrays.
+            # To ensure consistency across regular and arrow-optimized UDTFs, 
we further
+            # convert these numpy.ndarrays into Python lists.
+            ndarray_as_list=True,
+            # Enables explicit casting for mismatched return types of Arrow 
Python UDTFs.
+            arrow_cast=True,
+        )
+
+    def _create_batch(self, series):
+        """
+        Create an Arrow record batch from the given pandas.Series 
pandas.DataFrame
+        or list of Series or DataFrame, with optional type.
+
+        Parameters
+        ----------
+        series : pandas.Series or pandas.DataFrame or list
+            A single series or dataframe, list of series or dataframe,
+            or list of (series or dataframe, arrow_type)
+
+        Returns
+        -------
+        pyarrow.RecordBatch
+            Arrow RecordBatch
+        """
+        import pandas as pd
+        import pyarrow as pa
+
+        # Make input conform to [(series1, type1), (series2, type2), ...]
+        if not isinstance(series, (list, tuple)) or (
+            len(series) == 2 and isinstance(series[1], pa.DataType)
+        ):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s 
in series)
+
+        arrs = []
+        for s, t in series:
+            if not isinstance(s, pd.DataFrame):
+                raise PySparkValueError(
+                    "Output of an arrow-optimized Python UDTFs expects "
+                    f"a pandas.DataFrame but got: {type(s)}"
+                )
+
+            arrs.append(self._create_struct_array(s, t))
+
+        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in 
range(len(arrs))])

Review Comment:
   Just an optimization idea, not a strong suggestion:
   
   Now that we have a separate serializer, we can assume:
   - `len(series) == 1`
   - `isinstance(series[0][0], pd.DataFrame)`
   
   then could we flatten `series[0]` into a `new_series` and call `super()._create_batch(new_series)`? In that case, we also wouldn't need to flatten it on the Java side.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to