sadhen commented on a change in pull request #31735:
URL: https://github.com/apache/spark/pull/31735#discussion_r605705270



##########
File path: python/pyspark/sql/pandas/serializers.py
##########
@@ -153,14 +157,15 @@ def _create_batch(self, series):
         from pyspark.sql.pandas.types import 
_check_series_convert_timestamps_internal, \
             _convert_dict_to_map_items
         from pandas.api.types import is_categorical_dtype
-        # Make input conform to [(series1, type1), (series2, type2), ...]
-        if not isinstance(series, (list, tuple)) or \
-                (len(series) == 2 and isinstance(series[1], pa.DataType)):
-            series = [series]
-        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s 
in series)
 
-        def create_array(s, t):
-            mask = s.isnull()
+        def create_array(s, t: pa.DataType, dt: Optional[DataType] = None):
+            if dt is not None:
+                if isinstance(dt, UserDefinedType):
+                    s = s.apply(dt.serialize)
+                elif isinstance(dt, ArrayType) and isinstance(dt.elementType, 
UserDefinedType):
+                    udt = dt.elementType
+                    s = s.apply(lambda x: [udt.serialize(f) for f in x])

Review comment:
       This part has been extracted into a separate pull request; see 
https://github.com/apache/spark/pull/32026/files

##########
File path: python/pyspark/sql/pandas/serializers.py
##########
@@ -183,30 +193,59 @@ def create_array(s, t):
                     raise e
             return array
 
-        arrs = []
-        for s, t in series:
-            if t is not None and pa.types.is_struct(t):
-                if not isinstance(s, pd.DataFrame):
-                    raise ValueError("A field of type StructType expects a 
pandas.DataFrame, "
-                                     "but got: %s" % str(type(s)))
+        def create_arrs_names(s, t: pa.DataType, dt: Optional[DataType] = 
None):
+            # If input s is empty with zero columns, return empty Arrays with 
struct
+            if len(s) == 0 and len(s.columns) == 0:
+                return [(pa.array([], type=field.type), field.name) for field 
in t]
 
-                # Input partition and result pandas.DataFrame empty, make 
empty Arrays with struct
-                if len(s) == 0 and len(s.columns) == 0:
-                    arrs_names = [(pa.array([], type=field.type), field.name) 
for field in t]
+            if self._assign_cols_by_name and any(isinstance(name, str) for 
name in s.columns):
                 # Assign result columns by schema name if user labeled with 
strings
-                elif self._assign_cols_by_name and any(isinstance(name, str)
-                                                       for name in s.columns):
-                    arrs_names = [(create_array(s[field.name], field.type), 
field.name)
-                                  for field in t]
+                by_field_name = True
+            else:
                 # Assign result columns by  position
+                by_field_name = False
+
+            if dt is None:
+                if by_field_name:
+                    return [(create_array(s[field.name], field.type), 
field.name) for field in t]
+                else: 
+                    return [
+                        (create_array(s[s.columns[i]], field.type), field.name)
+                        for i, field in enumerate(t)
+                    ]
+            else:
+                if by_field_name:
+                    return [
+                        (create_array(s[field.name], field.type, 
struct_field.dataType), field.name)
+                        for field, struct_field in zip(t, dt.fields)
+                    ]
                 else:
-                    arrs_names = [(create_array(s[s.columns[i]], field.type), 
field.name)
-                                  for i, field in enumerate(t)]
+                    return [
+                        (create_array(s[s.columns[i]], field.type, 
struct_field.dataType), field.name)
+                        for i, (field, struct_field) in enumerate(zip(t, 
dt.fields))
+                    ]
+
+        # Make input conform to [(series1, type1), (series2, type2), ...]
+        if not isinstance(series, (list, tuple)) or \
+                (len(series) == 2 and isinstance(series[1], (pa.DataType, 
DataType))):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s 
in series)
 
+        arrs = []

Review comment:
       See https://github.com/apache/spark/pull/32026/files




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to