sadhen commented on a change in pull request #31735:
URL: https://github.com/apache/spark/pull/31735#discussion_r605705270
##########
File path: python/pyspark/sql/pandas/serializers.py
##########
@@ -153,14 +157,15 @@ def _create_batch(self, series):
         from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \
             _convert_dict_to_map_items
         from pandas.api.types import is_categorical_dtype
-        # Make input conform to [(series1, type1), (series2, type2), ...]
-        if not isinstance(series, (list, tuple)) or \
-                (len(series) == 2 and isinstance(series[1], pa.DataType)):
-            series = [series]
-        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

-        def create_array(s, t):
-            mask = s.isnull()
+        def create_array(s, t: pa.DataType, dt: Optional[DataType] = None):
+            if dt is not None:
+                if isinstance(dt, UserDefinedType):
+                    s = s.apply(dt.serialize)
+                elif isinstance(dt, ArrayType) and isinstance(dt.elementType, UserDefinedType):
+                    udt = dt.elementType
+                    s = s.apply(lambda x: [udt.serialize(f) for f in x])
Review comment:
This part has been extracted into a separate PR; see
https://github.com/apache/spark/pull/32026/files
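For context while reviewing: the new `dt` branch pre-serializes UDT values into their `sqlType` representation before handing the series to pyarrow. Here is a minimal standalone sketch of that behavior (not the PR's code; `PointUDT` is a hypothetical stand-in for a `UserDefinedType` subclass):

```python
# Hedged sketch of what create_array's new UDT branch does, outside Spark.
# PointUDT is hypothetical; a real UDT subclasses pyspark.sql.types.UserDefinedType.
import pandas as pd
import pyarrow as pa

class PointUDT:
    """Stand-in for a UserDefinedType whose sqlType is array<double>."""
    def serialize(self, obj):
        # UDT instance -> underlying sqlType value (here, a list of two doubles)
        return [obj[0], obj[1]]

udt = PointUDT()
s = pd.Series([(1.0, 2.0), (3.0, 4.0)])

# Mirrors the patch: s = s.apply(dt.serialize) before the Arrow conversion
serialized = s.apply(udt.serialize)
arr = pa.Array.from_pandas(serialized, type=pa.list_(pa.float64()))
print(arr)  # [[1, 2], [3, 4]]
```

The `ArrayType(UserDefinedType)` branch is the element-wise version of the same idea: `s.apply(lambda x: [udt.serialize(f) for f in x])`.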
##########
File path: python/pyspark/sql/pandas/serializers.py
##########
@@ -183,30 +193,59 @@ def create_array(s, t):
                 raise e
             return array

-        arrs = []
-        for s, t in series:
-            if t is not None and pa.types.is_struct(t):
-                if not isinstance(s, pd.DataFrame):
-                    raise ValueError("A field of type StructType expects a pandas.DataFrame, "
-                                     "but got: %s" % str(type(s)))
+        def create_arrs_names(s, t: pa.DataType, dt: Optional[DataType] = None):
+            # If input s is empty with zero columns, return empty Arrays with struct
+            if len(s) == 0 and len(s.columns) == 0:
+                return [(pa.array([], type=field.type), field.name) for field in t]

-                # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
-                if len(s) == 0 and len(s.columns) == 0:
-                    arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
+            if self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns):
                 # Assign result columns by schema name if user labeled with strings
-                elif self._assign_cols_by_name and any(isinstance(name, str)
-                                                       for name in s.columns):
-                    arrs_names = [(create_array(s[field.name], field.type), field.name)
-                                  for field in t]
+                by_field_name = True
+            else:
                 # Assign result columns by position
+                by_field_name = False
+
+            if dt is None:
+                if by_field_name:
+                    return [(create_array(s[field.name], field.type), field.name) for field in t]
+                else:
+                    return [
+                        (create_array(s[s.columns[i]], field.type), field.name)
+                        for i, field in enumerate(t)
+                    ]
+            else:
+                if by_field_name:
+                    return [
+                        (create_array(s[field.name], field.type, struct_field.dataType), field.name)
+                        for field, struct_field in zip(t, dt.fields)
+                    ]
                 else:
-                    arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
-                                  for i, field in enumerate(t)]
+                    return [
+                        (create_array(s[s.columns[i]], field.type, struct_field.dataType), field.name)
+                        for i, (field, struct_field) in enumerate(zip(t, dt.fields))
+                    ]
+
+        # Make input conform to [(series1, type1), (series2, type2), ...]
+        if not isinstance(series, (list, tuple)) or \
+                (len(series) == 2 and isinstance(series[1], (pa.DataType, DataType))):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+        arrs = []
Review comment:
See https://github.com/apache/spark/pull/32026/files
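To summarize what `create_arrs_names` factors out, here is a simplified, self-contained sketch of the by-name vs. by-position column assignment (UDT handling omitted; `arrs_names` and `assign_cols_by_name` are illustrative names, not the PR's API):

```python
# Hedged sketch: choose result columns by schema field name when the
# pandas.DataFrame has string column labels, otherwise by position.
import pandas as pd
import pyarrow as pa

def arrs_names(pdf: pd.DataFrame, t: pa.StructType, assign_cols_by_name: bool):
    # Empty partition with zero columns: build empty arrays straight from the schema
    if len(pdf) == 0 and len(pdf.columns) == 0:
        return [(pa.array([], type=field.type), field.name) for field in t]
    if assign_cols_by_name and any(isinstance(name, str) for name in pdf.columns):
        # Assign result columns by schema name if user labeled with strings
        return [(pa.Array.from_pandas(pdf[field.name], type=field.type), field.name)
                for field in t]
    # Assign result columns by position
    return [(pa.Array.from_pandas(pdf[pdf.columns[i]], type=field.type), field.name)
            for i, field in enumerate(t)]

struct_type = pa.struct([("a", pa.int64()), ("b", pa.string())])
pdf = pd.DataFrame({"b": ["x", "y"], "a": [1, 2]})  # columns out of schema order
for array, name in arrs_names(pdf, struct_type, assign_cols_by_name=True):
    print(name, array.to_pylist())  # a [1, 2] / b ['x', 'y']
```

With `assign_cols_by_name=False` the same frame would be matched positionally, so column `"b"` would be coerced into field `"a"` and fail the int cast, illustrating why the by-name path is taken for string-labeled columns.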